use std::sync::Arc;
use std::sync::atomic::{AtomicI64, Ordering};
use axum::Router;
use axum::extract::DefaultBodyLimit;
use axum::routing::{delete, get, post};
use crate::handlers::{base, catalog};
use crate::registry::RegistryMeta;
use ferro_blob_store::SharedBlobStore;
/// Maximum request body size in bytes (512 MiB).
///
/// `router` installs this value through axum's `DefaultBodyLimit` layer.
pub const MAX_BODY_BYTES: usize = 512 * 1024 * 1024;
/// Shared state attached to the `/v2` router and extracted by the dispatch
/// handlers as `State<Arc<AppState>>`.
///
/// Bundles the blob store, the registry metadata backend, a gauge counting
/// stored blobs, and the configured upload-session size limit.
pub struct AppState {
/// Blob storage backend.
pub blob_store: SharedBlobStore,
/// Registry metadata backend.
pub registry: Arc<dyn RegistryMeta>,
/// Gauge of stored blobs; kept behind an `Arc` so `blob_count_handle` can
/// hand out clones that observe the same counter.
blob_count: Arc<AtomicI64>,
/// Upload-session size limit in bytes, exposed via
/// `max_upload_session_bytes()`.
/// NOTE(review): presumably enforced by the upload handlers — not visible here.
max_upload_session_bytes: u64,
/// Held by `store_blob_counted` across its contains/put/increment sequence
/// so concurrent stores of one digest are counted once.
blob_accounting: tokio::sync::Mutex<()>,
}
impl AppState {
    /// Creates state using the default upload-session limit
    /// (`crate::upload::MAX_UPLOAD_SESSION_BYTES`); the blob gauge starts at 0.
    #[must_use]
    pub fn new(blob_store: SharedBlobStore, registry: Arc<dyn RegistryMeta>) -> Arc<Self> {
        Self::with_max_upload_session_bytes(
            blob_store,
            registry,
            crate::upload::MAX_UPLOAD_SESSION_BYTES,
        )
    }

    /// Creates state with an explicit upload-session limit in bytes; the blob
    /// gauge starts at 0.
    #[must_use]
    pub fn with_max_upload_session_bytes(
        blob_store: SharedBlobStore,
        registry: Arc<dyn RegistryMeta>,
        max_upload_session_bytes: u64,
    ) -> Arc<Self> {
        let state = Self {
            blob_store,
            registry,
            blob_count: Arc::new(AtomicI64::new(0)),
            max_upload_session_bytes,
            blob_accounting: tokio::sync::Mutex::new(()),
        };
        Arc::new(state)
    }

    /// Returns the configured upload-session limit in bytes.
    #[must_use]
    pub const fn max_upload_session_bytes(&self) -> u64 {
        self.max_upload_session_bytes
    }

    /// Returns another handle to the blob gauge; the clone shares the same
    /// underlying counter as this state.
    #[must_use]
    pub fn blob_count_handle(&self) -> Arc<AtomicI64> {
        Arc::clone(&self.blob_count)
    }

    /// Adds one to the blob gauge.
    pub fn inc_blob_count(&self) {
        self.blob_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Stores `body` under `digest` and bumps the blob gauge only when the
    /// digest was not already present.
    ///
    /// The accounting mutex is held for the whole check/put/increment
    /// sequence, so concurrent stores of the same digest are counted once.
    ///
    /// # Errors
    ///
    /// Returns the error from the blob store's `put`. A failing `contains`
    /// lookup is not reported; the blob is then treated as absent.
    pub async fn store_blob_counted(
        &self,
        digest: &ferro_blob_store::Digest,
        body: axum::body::Bytes,
    ) -> ferro_blob_store::Result<()> {
        let _guard = self.blob_accounting.lock().await;
        // NOTE(review): a lookup error counts the blob as new, which can
        // over-count when the blob actually exists — confirm this is intended.
        let is_new = !self.blob_store.contains(digest).await.unwrap_or(false);
        self.blob_store.put(digest, body).await?;
        if is_new {
            self.inc_blob_count();
        }
        Ok(())
    }

    /// Subtracts one from the blob gauge, never letting it drop below zero.
    pub fn dec_blob_count(&self) {
        // The closure always yields `Some`, so the update cannot be rejected;
        // the previous value it returns is not needed.
        let _ = self
            .blob_count
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(if current > 0 { current - 1 } else { 0 })
            });
    }

    /// Returns the current value of the blob gauge.
    #[must_use]
    pub fn blob_count(&self) -> i64 {
        self.blob_count.load(Ordering::Relaxed)
    }
}
#[cfg(test)]
mod app_state_tests {
    use std::sync::Arc;

    use ferro_blob_store::InMemoryBlobStore;

    use super::AppState;
    use crate::registry::InMemoryRegistryMeta;

    /// Builds a fresh `AppState` backed by in-memory stores.
    fn state() -> Arc<AppState> {
        AppState::new(
            Arc::new(InMemoryBlobStore::new()),
            Arc::new(InMemoryRegistryMeta::new()),
        )
    }

    #[test]
    fn blob_count_reflects_increments_and_saturating_decrements() {
        let st = state();
        assert_eq!(st.blob_count(), 0, "fresh state starts at 0");

        for _ in 0..2 {
            st.inc_blob_count();
        }
        assert_eq!(st.blob_count(), 2, "two increments give 2");

        st.dec_blob_count();
        assert_eq!(st.blob_count(), 1, "one decrement gives 1");
    }

    #[test]
    fn blob_count_saturates_at_zero_never_negative() {
        let st = state();
        st.dec_blob_count();
        assert_eq!(st.blob_count(), 0, "decrement on empty stays at 0");
    }

    /// Sixteen tasks store the same blob at once; the gauge must end at 1.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn r3_3_concurrent_same_digest_put_increments_gauge_once() {
        use bytes::Bytes;
        use ferro_blob_store::Digest;

        let st = state();
        let body = Bytes::from_static(b"the-same-blob-bytes");
        let digest = Digest::sha256_of(&body);

        // Collecting eagerly spawns every task before any of them is awaited.
        let tasks: Vec<_> = (0..16)
            .map(|_| {
                let st = Arc::clone(&st);
                let digest = digest.clone();
                let body = body.clone();
                tokio::spawn(async move {
                    st.store_blob_counted(&digest, body).await.expect("put");
                })
            })
            .collect();

        for task in tasks {
            task.await.expect("join");
        }

        assert_eq!(
            st.blob_count(),
            1,
            "concurrent puts of one digest must count it exactly once"
        );
    }

    /// Each of five distinct blobs is stored twice in a row; only the first
    /// store of each may bump the gauge.
    #[tokio::test]
    async fn r3_3_distinct_digests_each_counted_once() {
        use bytes::Bytes;
        use ferro_blob_store::Digest;

        let st = state();
        for fill in 0..5u8 {
            for label in ["put", "put2"] {
                let body = Bytes::from(vec![fill; 4]);
                let digest = Digest::sha256_of(&body);
                st.store_blob_counted(&digest, body).await.expect(label);
            }
        }
        assert_eq!(st.blob_count(), 5, "five distinct blobs counted once each");
    }
}
/// Builds the `/v2` API router bound to `state`.
///
/// Fixed routes serve the version check (`/v2` and `/v2/`) and the catalog
/// (`/v2/_catalog`). Every other path under `/v2/` is captured by a wildcard
/// and handed to the `dispatch` module, one handler per HTTP method. Request
/// bodies are capped at [`MAX_BODY_BYTES`].
pub fn router(state: Arc<AppState>) -> Router {
    let wildcard = "/v2/{*rest}";

    // Methods that carry a request body share one method router.
    let body_methods = post(dispatch::dispatch_post)
        .patch(dispatch::dispatch_patch_inner)
        .put(dispatch::dispatch_put_inner);

    // The wildcard is registered once per method group, in the same order as
    // before, so the method routers are merged exactly as they always were.
    Router::new()
        .route("/v2/", get(base::version_check))
        .route("/v2", get(base::version_check))
        .route("/v2/_catalog", get(catalog::list_catalog))
        .route(wildcard, get(dispatch::dispatch_get))
        .route(wildcard, axum::routing::head(dispatch::dispatch_head))
        .route(wildcard, delete(dispatch::dispatch_delete))
        .route(wildcard, body_methods)
        .layer(DefaultBodyLimit::max(MAX_BODY_BYTES))
        .with_state(state)
}
pub fn probe_routes() -> Router {
use axum::Json;
use axum::http::StatusCode;
use serde_json::json;
Router::new()
.route("/live", get(|| async { (StatusCode::OK, "OK") }))
.route(
"/healthz",
get(|| async { (StatusCode::OK, Json(json!({ "status": "ok" }))) }),
)
.route("/ready", get(|| async { (StatusCode::OK, "OK") }))
}
pub mod dispatch {
use std::sync::Arc;
use axum::body::Bytes;
use axum::extract::{Path, Query, State};
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::{IntoResponse, Response};
use super::AppState;
use crate::error::{OciError, OciErrorCode};
use crate::handlers::{blob, blob_upload, manifest as manifest_h, referrers, tags};
/// Splits the wildcard tail of a `/v2/...` path into `(name, suffix)`.
///
/// `suffix` starts at an endpoint keyword (`blobs/`, `manifests/`,
/// `tags/list` or `referrers/`) and `name` is everything before the `/` that
/// precedes it. Keywords are tried in that order, and for each one only its
/// last occurrence in `rest` is considered.
///
/// Returns `None` when no keyword occurrence is preceded by a `/`, which
/// includes a keyword sitting at the very start of `rest` (no name at all).
///
/// A keyword found at offset 0 no longer aborts the search: the remaining
/// keywords are still tried, so a repository whose first path component
/// equals a keyword (for example `blobs/manifests/latest`) is routable.
fn split_rest(rest: &str) -> Option<(&str, &str)> {
    const KEYWORDS: [&str; 4] = ["blobs/", "manifests/", "tags/list", "referrers/"];

    let bytes = rest.as_bytes();
    KEYWORDS.iter().find_map(|&kw| {
        let idx = rest.rfind(kw)?;
        // Offset 0 leaves no room for a name; give the next keyword a chance.
        let sep = idx.checked_sub(1)?;
        // Inspect the raw byte rather than slicing the string: `idx - 1` can
        // fall inside a multi-byte UTF-8 character, where a `str` slice would
        // panic. Once the byte is an ASCII `/`, `sep` is a valid boundary.
        (bytes[sep] == b'/').then(|| (&rest[..sep], &rest[idx..]))
    })
}
/// Owned-string wrapper around `split_rest`.
///
/// # Errors
///
/// Returns a `NameUnknown` error quoting the path when `rest` cannot be split
/// into a repository name and an endpoint suffix.
fn decode(rest: &str) -> Result<(String, String), OciError> {
    match split_rest(rest) {
        Some((name, suffix)) => Ok((name.to_owned(), suffix.to_owned())),
        None => Err(OciError::new(
            OciErrorCode::NameUnknown,
            format!("cannot route `{rest}`"),
        )),
    }
}
/// `GET /v2/{*rest}` entry point.
///
/// Decodes the wildcard tail and forwards to `dispatch_inner` with the query
/// parameters and an empty body. An undecodable path becomes the error
/// response produced by `decode`.
pub async fn dispatch_get(
    State(state): State<Arc<AppState>>,
    Path(rest): Path<String>,
    Query(params): Query<std::collections::BTreeMap<String, String>>,
    headers: HeaderMap,
) -> Response {
    match decode(&rest) {
        Ok((name, suffix)) => {
            let no_body = Bytes::new();
            dispatch_inner(state, name, suffix, Method::GET, headers, params, no_body).await
        }
        Err(err) => err.into_response(),
    }
}
/// `HEAD /v2/{*rest}` entry point.
///
/// Decodes the wildcard tail and forwards to `dispatch_inner` with no query
/// parameters and an empty body. An undecodable path becomes the error
/// response produced by `decode`.
pub async fn dispatch_head(
    State(state): State<Arc<AppState>>,
    Path(rest): Path<String>,
    headers: HeaderMap,
) -> Response {
    match decode(&rest) {
        Ok((name, suffix)) => {
            let no_params = std::collections::BTreeMap::default();
            let no_body = Bytes::new();
            dispatch_inner(state, name, suffix, Method::HEAD, headers, no_params, no_body).await
        }
        Err(err) => err.into_response(),
    }
}
/// `DELETE /v2/{*rest}` entry point.
///
/// Decodes the wildcard tail and forwards to `dispatch_inner` with no query
/// parameters and an empty body. An undecodable path becomes the error
/// response produced by `decode`.
pub async fn dispatch_delete(
    State(state): State<Arc<AppState>>,
    Path(rest): Path<String>,
    headers: HeaderMap,
) -> Response {
    match decode(&rest) {
        Ok((name, suffix)) => {
            let no_params = std::collections::BTreeMap::default();
            let no_body = Bytes::new();
            dispatch_inner(state, name, suffix, Method::DELETE, headers, no_params, no_body).await
        }
        Err(err) => err.into_response(),
    }
}
/// `POST /v2/{*rest}` entry point.
///
/// Decodes the wildcard tail and forwards to `dispatch_inner` together with
/// the query parameters and the request body. An undecodable path becomes the
/// error response produced by `decode`.
pub async fn dispatch_post(
    State(state): State<Arc<AppState>>,
    Path(rest): Path<String>,
    Query(params): Query<std::collections::BTreeMap<String, String>>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    match decode(&rest) {
        Ok((name, suffix)) => {
            dispatch_inner(state, name, suffix, Method::POST, headers, params, body).await
        }
        Err(err) => err.into_response(),
    }
}
/// `PATCH /v2/{*rest}` entry point.
///
/// Decodes the wildcard tail and forwards to `dispatch_inner` with the
/// request body and no query parameters. An undecodable path becomes the
/// error response produced by `decode`.
pub async fn dispatch_patch_inner(
    State(state): State<Arc<AppState>>,
    Path(rest): Path<String>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    match decode(&rest) {
        Ok((name, suffix)) => {
            let no_params = std::collections::BTreeMap::default();
            dispatch_inner(state, name, suffix, Method::PATCH, headers, no_params, body).await
        }
        Err(err) => err.into_response(),
    }
}
/// `PUT /v2/{*rest}` entry point.
///
/// Decodes the wildcard tail and forwards to `dispatch_inner` together with
/// the query parameters and the request body. An undecodable path becomes the
/// error response produced by `decode`.
pub async fn dispatch_put_inner(
    State(state): State<Arc<AppState>>,
    Path(rest): Path<String>,
    Query(params): Query<std::collections::BTreeMap<String, String>>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    match decode(&rest) {
        Ok((name, suffix)) => {
            dispatch_inner(state, name, suffix, Method::PUT, headers, params, body).await
        }
        Err(err) => err.into_response(),
    }
}
#[allow(clippy::too_many_arguments)]
async fn dispatch_inner(
state: Arc<AppState>,
name: String,
suffix: String,
method: Method,
headers: HeaderMap,
params: std::collections::BTreeMap<String, String>,
body: Bytes,
) -> Response {
if suffix == "tags/list" {
return if method == Method::GET {
tags::list_tags(&state, &name, ¶ms)
.await
.into_response()
} else {
OciError::new(OciErrorCode::Unsupported, "unsupported method")
.with_status(StatusCode::METHOD_NOT_ALLOWED)
.into_response()
};
}
if let Some(rest) = suffix.strip_prefix("referrers/") {
return if method == Method::GET {
referrers::get_referrers(&state, &name, rest, ¶ms)
.await
.into_response()
} else {
OciError::new(OciErrorCode::Unsupported, "unsupported method")
.with_status(StatusCode::METHOD_NOT_ALLOWED)
.into_response()
};
}
if let Some(rest) = suffix.strip_prefix("manifests/") {
return match method {
Method::GET => manifest_h::get_manifest(&state, &name, rest, &headers)
.await
.into_response(),
Method::HEAD => manifest_h::head_manifest(&state, &name, rest)
.await
.into_response(),
Method::PUT => manifest_h::put_manifest(&state, &name, rest, &headers, body)
.await
.into_response(),
Method::DELETE => manifest_h::delete_manifest(&state, &name, rest)
.await
.into_response(),
_ => OciError::new(OciErrorCode::Unsupported, "unsupported method")
.with_status(StatusCode::METHOD_NOT_ALLOWED)
.into_response(),
};
}
if let Some(rest) = suffix.strip_prefix("blobs/uploads/") {
let uuid = rest.trim_end_matches('/');
return match method {
Method::POST => {
blob_upload::init_upload(&state, &name, &headers, ¶ms, body)
.await
.into_response()
}
Method::PATCH => blob_upload::patch_upload(&state, &name, uuid, &headers, body)
.await
.into_response(),
Method::PUT => blob_upload::finish_upload(&state, &name, uuid, ¶ms, body)
.await
.into_response(),
Method::GET => blob_upload::get_upload_status(&state, &name, uuid)
.await
.into_response(),
Method::DELETE => blob_upload::cancel_upload(&state, &name, uuid)
.await
.into_response(),
_ => OciError::new(OciErrorCode::Unsupported, "unsupported method")
.with_status(StatusCode::METHOD_NOT_ALLOWED)
.into_response(),
};
}
if let Some(rest) = suffix.strip_prefix("blobs/") {
return match method {
Method::GET => blob::get_blob(&state, &name, rest).await.into_response(),
Method::HEAD => blob::head_blob(&state, &name, rest).await.into_response(),
Method::DELETE => blob::delete_blob(&state, &name, rest).await.into_response(),
_ => OciError::new(OciErrorCode::Unsupported, "unsupported method")
.with_status(StatusCode::METHOD_NOT_ALLOWED)
.into_response(),
};
}
OciError::new(
OciErrorCode::NameUnknown,
format!("cannot route `{name}/{suffix}`"),
)
.into_response()
}
#[cfg(test)]
mod tests {
    use super::split_rest;

    /// Asserts that `path` splits into exactly `(want_name, want_suffix)`.
    fn assert_split(path: &str, want_name: &str, want_suffix: &str) {
        let (name, suffix) = split_rest(path).expect("split");
        assert_eq!(name, want_name);
        assert_eq!(suffix, want_suffix);
    }

    #[test]
    fn split_simple_manifest_path() {
        assert_split("alpine/manifests/latest", "alpine", "manifests/latest");
    }

    #[test]
    fn split_nested_blob_path() {
        assert_split(
            "my-org/lib/alpine/blobs/uploads/abc",
            "my-org/lib/alpine",
            "blobs/uploads/abc",
        );
    }

    #[test]
    fn split_tags_list() {
        assert_split("lib/alpine/tags/list", "lib/alpine", "tags/list");
    }

    #[test]
    fn split_referrers() {
        assert_split(
            "lib/alpine/referrers/sha256:abcd",
            "lib/alpine",
            "referrers/sha256:abcd",
        );
    }

    #[test]
    fn split_none_for_bare_name() {
        assert_eq!(split_rest("alpine"), None);
    }
}
}