use std::collections::BTreeMap;
use axum::body::Bytes;
use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
use axum::response::{IntoResponse, Response};
use ferro_blob_store::Digest;
use crate::error::{OciError, OciErrorCode};
use crate::reference::validate_name;
use crate::registry::UploadAdmission;
use crate::router::AppState;
use crate::upload::ContentRange;
fn upload_too_large(cap: u64, current: u64, incoming: u64) -> Response {
OciError::new(
OciErrorCode::BlobUploadInvalid,
format!(
"upload exceeds the {cap}-byte session cap \
(buffered {current}, incoming {incoming})"
),
)
.with_status(StatusCode::PAYLOAD_TOO_LARGE)
.into_response()
}
/// Returns `true` when appending `incoming` bytes to a session that already
/// holds `current` bytes would leave it strictly larger than `cap`.
///
/// A total exactly equal to `cap` is allowed. If `current + incoming` overflows
/// `u64`, the true total is larger than any possible cap, so the result is
/// `true`. This also holds when `cap == u64::MAX`, a case where a saturating
/// add would wrongly report "fits".
const fn would_exceed_cap(cap: u64, current: u64, incoming: u64) -> bool {
    match current.checked_add(incoming) {
        Some(total) => total > cap,
        None => true,
    }
}
/// Parses a client-supplied digest string (e.g. from a `digest` query
/// parameter) into a [`Digest`].
///
/// # Errors
///
/// Returns a `DigestInvalid` [`OciError`] that quotes the offending input and
/// the parser's own error text.
fn parse_digest(s: &str) -> Result<Digest, OciError> {
    match s.parse::<Digest>() {
        Ok(digest) => Ok(digest),
        Err(reason) => Err(OciError::new(
            OciErrorCode::DigestInvalid,
            format!("invalid digest `{s}`: {reason}"),
        )),
    }
}
/// Builds the headers describing an in-progress upload session.
///
/// - `Location`: `/v2/{name}/blobs/uploads/{uuid}`
/// - `Range`: `0-{last byte index}`, where the last index is `new_offset - 1`.
///   An empty session (`new_offset == 0`) reports `0-0`.
/// - `Docker-Upload-UUID`: the session id
/// - `OCI-Chunk-Min-Length`: always `0`
///
/// Any value that is not a valid header value is silently left out.
fn upload_location_headers(name: &str, uuid: &str, new_offset: u64) -> HeaderMap {
    // `saturating_sub` maps an empty session to index 0, giving the `0-0` form.
    let last_byte = new_offset.saturating_sub(1);
    let location = format!("/v2/{name}/blobs/uploads/{uuid}");
    let range = format!("0-{last_byte}");

    let mut out = HeaderMap::new();
    if let Ok(value) = HeaderValue::from_str(&location) {
        out.insert(header::LOCATION, value);
    }
    if let Ok(value) = HeaderValue::from_str(&range) {
        out.insert(header::RANGE, value);
    }
    if let Ok(value) = HeaderValue::from_str(uuid) {
        out.insert("Docker-Upload-UUID", value);
    }
    out.insert("OCI-Chunk-Min-Length", HeaderValue::from_static("0"));
    out
}
/// Starts a blob upload for repository `name`.
///
/// If the `digest` query parameter is present, the request body is treated as
/// the complete blob (monolithic upload). The body is hashed with SHA-256 and
/// checked against the declared digest. On success it is stored and the
/// handler returns `201 Created`.
///
/// Otherwise a new upload session is opened in the registry, and the handler
/// returns `202 Accepted` with the session's `Location`, `Range` and
/// `Docker-Upload-UUID` headers. The body is not used in this case.
///
/// Error responses:
/// - the error from [`validate_name`] for an invalid repository name;
/// - `DigestInvalid` for an unparsable digest, a digest whose algorithm differs
///   from the SHA-256 computed here, or a content mismatch;
/// - `TooManyRequests` when the registry reports it is at session capacity;
/// - registry/store failures converted through `OciError::from`.
pub async fn init_upload(
    state: &AppState,
    name: &str,
    _headers: &HeaderMap,
    params: &BTreeMap<String, String>,
    body: Bytes,
) -> Response {
    if let Err(e) = validate_name(name) {
        return e.into_response();
    }
    if let Some(digest_str) = params.get("digest") {
        let digest = match parse_digest(digest_str) {
            Ok(d) => d,
            Err(e) => return e.into_response(),
        };
        let actual = Digest::sha256_of(&body);
        // Only SHA-256 is computed here. A declared digest in any other
        // algorithm cannot be checked against the body, so reject it rather
        // than store unverified bytes under a client-chosen digest.
        if actual.algo() != digest.algo() {
            return OciError::new(
                OciErrorCode::DigestInvalid,
                format!("unsupported digest algorithm in `{digest}`: only sha256 can be verified"),
            )
            .into_response();
        }
        if actual.hex() != digest.hex() {
            return OciError::new(
                OciErrorCode::DigestInvalid,
                format!("digest mismatch: declared {digest}, computed {actual}"),
            )
            .into_response();
        }
        if let Err(e) = state.store_blob_counted(&digest, body).await {
            return OciError::from(e).into_response();
        }
        return blob_created_response(name, &digest);
    }
    let uuid = match state.registry.start_upload(name).await {
        Ok(UploadAdmission::Started(u)) => u,
        Ok(UploadAdmission::AtCapacity(cap)) => {
            return OciError::new(
                OciErrorCode::TooManyRequests,
                format!("upload-session capacity reached ({cap} concurrent sessions); retry later"),
            )
            .into_response();
        }
        Err(e) => return OciError::from(e).into_response(),
    };
    let headers = upload_location_headers(name, &uuid, 0);
    (StatusCode::ACCEPTED, headers).into_response()
}
/// Appends one chunk to the open upload session `uuid` in repository `name`.
///
/// The chunk must start exactly at the session's current offset. If a
/// `Content-Range` header is sent, it must be ASCII and parse as a
/// [`ContentRange`]. Its length must also fit in `u64` and equal the body
/// length. Without the header, the chunk is assumed to start at the current
/// offset.
///
/// A chunk that would push the session past `max_upload_session_bytes()`
/// cancels the session (best effort; the cancel result is ignored) and gets
/// `413 Payload Too Large`. On success the handler returns `202 Accepted`
/// with the updated session headers.
///
/// Out-of-order chunks and Content-Range/body length disagreements get
/// `416 Range Not Satisfiable`.
pub async fn patch_upload(
    state: &AppState,
    name: &str,
    uuid: &str,
    headers: &HeaderMap,
    body: Bytes,
) -> Response {
    if let Err(e) = validate_name(name) {
        return e.into_response();
    }
    let session = match state.registry.get_upload_state(name, uuid).await {
        Ok(Some(session)) => session,
        Ok(None) => {
            return OciError::new(
                OciErrorCode::BlobUploadUnknown,
                format!("unknown upload uuid {uuid}"),
            )
            .into_response();
        }
        Err(e) => return OciError::from(e).into_response(),
    };
    let expected_offset = session.offset();

    let chunk_start = if let Some(raw) = headers.get(header::CONTENT_RANGE) {
        let Ok(text) = raw.to_str() else {
            return OciError::new(OciErrorCode::BlobUploadInvalid, "non-ASCII Content-Range")
                .into_response();
        };
        let parsed = match ContentRange::parse(text) {
            Ok(parsed) => parsed,
            Err(reason) => {
                return OciError::new(
                    OciErrorCode::BlobUploadInvalid,
                    format!("malformed Content-Range `{text}`: {reason}"),
                )
                .into_response();
            }
        };
        let body_len = body.len();
        match parsed.checked_length() {
            None => {
                return OciError::new(
                    OciErrorCode::BlobUploadInvalid,
                    format!("Content-Range `{text}` spans more than u64::MAX bytes"),
                )
                .with_status(StatusCode::RANGE_NOT_SATISFIABLE)
                .into_response();
            }
            Some(declared_len) if declared_len != body_len as u64 => {
                return OciError::new(
                    OciErrorCode::BlobUploadInvalid,
                    format!(
                        "Content-Range length mismatch: range `{text}` spans {declared_len} bytes \
                         but body carries {body_len}"
                    ),
                )
                .with_status(StatusCode::RANGE_NOT_SATISFIABLE)
                .into_response();
            }
            Some(_) => parsed.start,
        }
    } else {
        expected_offset
    };

    if chunk_start != expected_offset {
        return OciError::new(
            OciErrorCode::BlobUploadInvalid,
            format!("out-of-order chunk: expected offset {expected_offset}, got {chunk_start}"),
        )
        .with_status(StatusCode::RANGE_NOT_SATISFIABLE)
        .into_response();
    }

    let cap = state.max_upload_session_bytes();
    let incoming = body.len() as u64;
    if would_exceed_cap(cap, expected_offset, incoming) {
        // The session is being abandoned either way, so a failed cancel is ignored.
        let _ = state.registry.cancel_upload(name, uuid).await;
        return upload_too_large(cap, expected_offset, incoming);
    }

    match state.registry.append_upload(name, uuid, chunk_start, body).await {
        Ok(new_offset) => {
            (StatusCode::ACCEPTED, upload_location_headers(name, uuid, new_offset)).into_response()
        }
        Err(e) => OciError::from(e).into_response(),
    }
}
pub async fn finish_upload(
state: &AppState,
name: &str,
uuid: &str,
params: &BTreeMap<String, String>,
body: Bytes,
) -> Response {
if let Err(e) = validate_name(name) {
return e.into_response();
}
let Some(digest_str) = params.get("digest") else {
return OciError::new(
OciErrorCode::DigestInvalid,
"missing `digest` query parameter",
)
.into_response();
};
let declared = match parse_digest(digest_str) {
Ok(d) => d,
Err(e) => return e.into_response(),
};
let existing = match state.registry.get_upload_state(name, uuid).await {
Ok(v) => v,
Err(e) => return OciError::from(e).into_response(),
};
let Some(state_snapshot) = existing else {
return OciError::new(
OciErrorCode::BlobUploadUnknown,
format!("unknown upload uuid {uuid}"),
)
.into_response();
};
if !body.is_empty() {
let cap = state.max_upload_session_bytes();
if would_exceed_cap(cap, state_snapshot.offset(), body.len() as u64) {
let _ = state.registry.cancel_upload(name, uuid).await;
return upload_too_large(cap, state_snapshot.offset(), body.len() as u64);
}
if let Err(e) = state
.registry
.append_upload(name, uuid, state_snapshot.offset(), body)
.await
{
return OciError::from(e).into_response();
}
}
let bytes = match state.registry.take_upload_bytes(name, uuid).await {
Ok(Some(b)) => b,
Ok(None) => {
return OciError::new(
OciErrorCode::BlobUploadUnknown,
format!("upload {uuid} has no buffered bytes"),
)
.into_response();
}
Err(e) => return OciError::from(e).into_response(),
};
let actual = Digest::sha256_of(&bytes);
if declared.algo() == actual.algo() && actual.hex() != declared.hex() {
return OciError::new(
OciErrorCode::DigestInvalid,
format!("digest mismatch: declared {declared}, computed {actual}"),
)
.into_response();
}
if let Err(e) = state.store_blob_counted(&declared, bytes).await {
return OciError::from(e).into_response();
}
if let Err(e) = state.registry.complete_upload(name, uuid, &declared).await {
return OciError::from(e).into_response();
}
blob_created_response(name, &declared)
}
/// Reports the progress of upload session `uuid` in repository `name`.
///
/// Returns `204 No Content` with the session's `Location`/`Range` headers.
/// Returns `BlobUploadUnknown` if the registry has no such session. Registry
/// failures are converted through `OciError::from`.
pub async fn get_upload_status(state: &AppState, name: &str, uuid: &str) -> Response {
    if let Err(e) = validate_name(name) {
        return e.into_response();
    }
    match state.registry.get_upload_state(name, uuid).await {
        Ok(Some(session)) => {
            let progress = upload_location_headers(name, uuid, session.offset());
            (StatusCode::NO_CONTENT, progress).into_response()
        }
        Ok(None) => OciError::new(
            OciErrorCode::BlobUploadUnknown,
            format!("unknown upload uuid {uuid}"),
        )
        .into_response(),
        Err(e) => OciError::from(e).into_response(),
    }
}
/// Cancels upload session `uuid` in repository `name`.
///
/// Returns `204 No Content` when the registry reports that it removed the
/// session. Returns `BlobUploadUnknown` when it did not. Registry failures are
/// converted through `OciError::from`.
pub async fn cancel_upload(state: &AppState, name: &str, uuid: &str) -> Response {
    if let Err(e) = validate_name(name) {
        return e.into_response();
    }
    match state.registry.cancel_upload(name, uuid).await {
        Ok(true) => (StatusCode::NO_CONTENT, HeaderMap::new()).into_response(),
        Ok(false) => OciError::new(
            OciErrorCode::BlobUploadUnknown,
            format!("unknown upload uuid {uuid}"),
        )
        .into_response(),
        Err(e) => OciError::from(e).into_response(),
    }
}
/// Builds the `201 Created` response for a stored blob.
///
/// Headers:
/// - `Location: /v2/{name}/blobs/{digest}`
/// - `Docker-Content-Digest: {digest}`
/// - `Content-Length: 0`
///
/// Header values that fail validation are silently left out.
fn blob_created_response(name: &str, digest: &Digest) -> Response {
    let digest_text = digest.to_string();
    let location = format!("/v2/{name}/blobs/{digest_text}");

    let mut out = HeaderMap::new();
    if let Ok(value) = HeaderValue::from_str(&location) {
        out.insert(header::LOCATION, value);
    }
    if let Ok(value) = HeaderValue::from_str(&digest_text) {
        out.insert("Docker-Content-Digest", value);
    }
    out.insert(header::CONTENT_LENGTH, HeaderValue::from(0u64));
    (StatusCode::CREATED, out).into_response()
}
#[cfg(test)]
mod tests {
    use super::would_exceed_cap;

    /// Boundary cases for the cap check, as `(cap, current, incoming, expected, why)`.
    #[test]
    fn would_exceed_cap_boundary_is_inclusive_under_strict_greater() {
        let cases = [
            (10, 5, 5, false, "exact fit (sum == cap) must be allowed"),
            (10, 5, 4, false, "under cap is allowed"),
            (10, 5, 6, true, "over cap is rejected"),
        ];
        for (cap, current, incoming, expected, why) in cases {
            assert_eq!(would_exceed_cap(cap, current, incoming), expected, "{why}");
        }
    }

    /// An overflowing `current + incoming` must still count as over the cap.
    #[test]
    fn would_exceed_cap_saturates_on_overflow() {
        let exceeded = would_exceed_cap(1024, u64::MAX, u64::MAX);
        assert!(exceeded);
    }
}