use std::collections::{BTreeMap, HashSet};
use std::io::Cursor;
use axum::Json;
use axum::extract::State;
use axum::http::{HeaderValue, StatusCode, header};
use axum::response::{IntoResponse, Response};
use mnem_core::id::Cid;
use mnem_transport::car::{CarHeader, write_block, write_header};
use mnem_transport::import::import_with_limit;
use mnem_transport::protocol::{
CAPABILITIES_HEADER, Capability, PROTOCOL_HEADER, PROTOCOL_VERSION, serialize_capabilities,
};
use serde::{Deserialize, Serialize};
use crate::auth::RequireBearer;
use crate::error::RemoteError;
use crate::metrics::AdvanceHeadLabels;
use crate::state::AppState;
/// Name of the only ref the remote endpoints currently accept (and the serde
/// default for `AdvanceHeadRequest::ref`).
const DEFAULT_REF: &str = "main";
/// Headers attached to every `/remote/v1/*` response: the protocol version
/// under `PROTOCOL_HEADER` and the serialized list of all capabilities under
/// `CAPABILITIES_HEADER`.
///
/// # Panics
///
/// Only if the protocol version or the capability list is not a valid header
/// value, which the `expect` messages below treat as an invariant violation.
fn protocol_headers() -> [(axum::http::HeaderName, HeaderValue); 2] {
    let capability_list = serialize_capabilities(Capability::all().iter().copied());

    let version_name = axum::http::HeaderName::from_static(PROTOCOL_HEADER);
    let version_value = HeaderValue::from_str(&PROTOCOL_VERSION.to_string())
        .expect("protocol version is ascii digits");

    let capabilities_name = axum::http::HeaderName::from_static(CAPABILITIES_HEADER);
    let capabilities_value =
        HeaderValue::from_str(&capability_list).expect("capability list is ascii");

    [
        (version_name, version_value),
        (capabilities_name, capabilities_value),
    ]
}
/// JSON body returned by `get_refs` (`GET /remote/v1/refs`).
#[derive(Debug, Serialize)]
pub(crate) struct RefsResponse {
    /// String form of the first op-head reported by the op-heads store, or
    /// `None` (serialized as `null`) when the store has no heads.
    pub head: Option<String>,
    /// Ref name -> CID string. `get_refs` only ever populates `"HEAD"`, with
    /// the same value as `head`; the map is empty when `head` is `None`.
    pub refs: BTreeMap<String, String>,
    /// Wire names (`as_wire_str`) of every capability in `Capability::all()`.
    pub capabilities: Vec<String>,
}
/// `GET /remote/v1/refs`: report the server's current head and capabilities.
///
/// The first entry returned by the op-heads store is reported both as `head`
/// and as `refs["HEAD"]`. When the store is empty, `head` is `null` and
/// `refs` is empty. Protocol headers are attached to the response.
///
/// # Errors
///
/// `RemoteError::Internal` if the repo lock is poisoned or the op-heads
/// store cannot be read.
pub(crate) async fn get_refs(State(state): State<AppState>) -> Result<Response, RemoteError> {
    // Snapshot the head as a string while holding the repo lock, then release it.
    let head_str: Option<String> = {
        let guard = state
            .repo
            .lock()
            .map_err(|_| RemoteError::Internal("server state lock poisoned".into()))?;
        let op_heads = guard.op_heads_store();
        let heads = op_heads
            .current()
            .map_err(|e| RemoteError::Internal(format!("op-heads read: {e}")))?;
        heads.into_iter().next().map(|cid| cid.to_string())
    };

    let refs: BTreeMap<String, String> = head_str
        .iter()
        .map(|cid| ("HEAD".to_string(), cid.clone()))
        .collect();
    let capabilities = Capability::all()
        .iter()
        .map(|cap| cap.as_wire_str().to_string())
        .collect();

    let payload = RefsResponse {
        head: head_str,
        refs,
        capabilities,
    };
    Ok((StatusCode::OK, protocol_headers(), Json(payload)).into_response())
}
/// JSON body accepted by `post_fetch_blocks` (`POST /remote/v1/fetch-blocks`).
#[derive(Debug, Deserialize)]
pub(crate) struct FetchBlocksRequest {
    /// Root CIDs in string form; each is parsed with `Cid::parse_str`. Must be
    /// non-empty or the handler rejects the request.
    pub wants: Vec<String>,
    /// Optional (defaults to empty). The handler currently accepts but ignores
    /// it. NOTE(review): presumably an encoding of blocks the client already
    /// holds; the format is not defined in this file.
    #[serde(default)]
    pub have_set: Vec<u8>,
}
/// `POST /remote/v1/fetch-blocks`: return every block reachable from the
/// requested roots as a CARv1 body (`application/vnd.ipld.car`).
///
/// Duplicate entries in `wants` are collapsed, keeping the first occurrence.
/// As a result, the CAR header lists each root once and no root is walked
/// twice. Each remaining root is walked with the blockstore's
/// `iter_from_root`. Each block is written at most once, even when the walks
/// of different roots overlap. The `remote_fetch_blocks` counter is bumped
/// only after the CAR has been fully built.
///
/// # Errors
///
/// * `RemoteError::BadRequest` if `wants` is empty or contains an unparsable CID.
/// * `RemoteError::NotFound` if the walk hits a CID the store does not have.
/// * `RemoteError::Internal` on lock poisoning, other blockstore errors, or
///   CAR encoding failures.
pub(crate) async fn post_fetch_blocks(
    State(state): State<AppState>,
    Json(req): Json<FetchBlocksRequest>,
) -> Result<Response, RemoteError> {
    if req.wants.is_empty() {
        return Err(RemoteError::BadRequest("wants: must be non-empty".into()));
    }
    let mut wants: Vec<Cid> = req
        .wants
        .iter()
        .map(|s| Cid::parse_str(s).map_err(|e| RemoteError::BadRequest(format!("wants: {e}"))))
        .collect::<Result<_, _>>()?;
    // Collapse repeated roots while preserving request order. A duplicate want
    // would otherwise appear twice in the CAR header and have its entire
    // subgraph re-walked only for every block to be skipped by `visited`.
    let mut seen_roots: HashSet<Cid> = HashSet::with_capacity(wants.len());
    wants.retain(|cid| seen_roots.insert(cid.clone()));
    // Accepted for wire compatibility; not yet used to prune the walk.
    let _have_set = req.have_set;

    let header = CarHeader {
        version: 1,
        roots: wants,
    };
    let mut buf: Vec<u8> = Vec::new();
    write_header(&mut buf, &header)
        .map_err(|e| RemoteError::Internal(format!("CAR header: {e}")))?;
    {
        let repo = state
            .repo
            .lock()
            .map_err(|_| RemoteError::Internal("server state lock poisoned".into()))?;
        let bs = repo.blockstore();
        // Blocks already emitted; overlapping subgraphs are written only once.
        let mut visited: HashSet<Cid> = HashSet::new();
        for want in &header.roots {
            for item in bs.iter_from_root(want) {
                let (cid, data) = item.map_err(|e| match e {
                    mnem_core::error::StoreError::NotFound { cid } => {
                        RemoteError::NotFound(format!("want not in store: {cid}"))
                    }
                    other => RemoteError::Internal(format!("blockstore walk: {other}")),
                })?;
                if !visited.insert(cid.clone()) {
                    continue;
                }
                write_block(&mut buf, &cid, &data)
                    .map_err(|e| RemoteError::Internal(format!("CAR block write: {e}")))?;
            }
        }
    }
    state.metrics.remote_fetch_blocks.inc();

    let mut resp = (StatusCode::OK, buf).into_response();
    let h = resp.headers_mut();
    h.insert(
        header::CONTENT_TYPE,
        HeaderValue::from_static("application/vnd.ipld.car"),
    );
    for (name, value) in protocol_headers() {
        h.insert(name, value);
    }
    Ok(resp)
}
/// JSON body returned by `post_push_blocks` (`POST /remote/v1/push-blocks`).
#[derive(Debug, Serialize)]
pub(crate) struct PushBlocksResponse {
    /// String form of the first root reported by the CAR import, or `None`
    /// if the import reported no roots.
    pub staged: Option<String>,
    /// Block count reported by the import (`stats.blocks`).
    pub blocks_accepted: u64,
}
/// `POST /remote/v1/push-blocks`: import a raw CAR request body into the
/// server's blockstore.
///
/// Requires a bearer token (`RequireBearer`). The body is decoded by
/// `import_with_limit`, capped at `DEFAULT_MAX_IMPORT_BYTES`, while the repo
/// lock is held. This handler does not touch the op-heads store; moving the
/// head is a separate `advance-head` call. On success the
/// `remote_push_blocks` counter is bumped, and the response reports the first
/// imported root and the block count.
///
/// # Errors
///
/// * `RemoteError::Internal` if the repo lock is poisoned.
/// * Import failures are translated by `remote_error_from_transport`.
pub(crate) async fn post_push_blocks(
    State(state): State<AppState>,
    _auth: RequireBearer,
    body: axum::body::Bytes,
) -> Result<Response, RemoteError> {
    let stats = {
        let guard = state
            .repo
            .lock()
            .map_err(|_| RemoteError::Internal("server state lock poisoned".into()))?;
        let blockstore = guard.blockstore();
        let mut car_reader = Cursor::new(body.as_ref());
        import_with_limit(
            &mut car_reader,
            blockstore.as_ref(),
            mnem_transport::import::DEFAULT_MAX_IMPORT_BYTES,
        )
        .map_err(remote_error_from_transport)?
    };
    state.metrics.remote_push_blocks.inc();

    let response = PushBlocksResponse {
        staged: stats.roots.first().map(ToString::to_string),
        blocks_accepted: stats.blocks,
    };
    Ok((StatusCode::OK, protocol_headers(), Json(response)).into_response())
}
/// Translate a transport-layer import failure into an HTTP-facing
/// `RemoteError`.
///
/// The following variants become `BadRequest`:
/// * CAR framing errors, CID mismatches, missing roots, unsupported hashes;
/// * size-limit violations and codec errors.
///
/// Every other variant (including `Store` and `Io`) becomes `Internal`. In
/// both cases the error's `Display` text is the message.
fn remote_error_from_transport(e: mnem_transport::TransportError) -> RemoteError {
    use mnem_transport::TransportError as T;
    let is_bad_request = matches!(
        e,
        T::Car(_)
            | T::CidMismatch { .. }
            | T::MissingRoot { .. }
            | T::UnsupportedHash(_)
            | T::SizeLimit { .. }
            | T::Codec(_)
    );
    let message = e.to_string();
    if is_bad_request {
        RemoteError::BadRequest(message)
    } else {
        // `Store`, `Io`, and any variant not listed above.
        RemoteError::Internal(message)
    }
}
/// JSON body accepted by `post_advance_head` (`POST /remote/v1/advance-head`).
#[derive(Debug, Deserialize)]
pub(crate) struct AdvanceHeadRequest {
    /// CID (string form) the client expects to be a current head; the swap is
    /// refused if it is not.
    pub old: String,
    /// CID (string form) to install in place of `old`.
    pub new: String,
    /// Ref to advance, sent on the wire as `"ref"`. Defaults to `DEFAULT_REF`
    /// via `default_ref_name`; any other value is rejected by the handler.
    #[serde(default = "default_ref_name")]
    pub r#ref: String,
}
/// Serde default for `AdvanceHeadRequest::ref`: an owned copy of `DEFAULT_REF`.
fn default_ref_name() -> String {
    String::from(DEFAULT_REF)
}
/// JSON body returned by `post_advance_head` when the swap succeeds.
#[derive(Debug, Serialize)]
pub(crate) struct AdvanceHeadResponse {
    /// String form of the `new` CID that was installed.
    pub head: String,
}
pub(crate) async fn post_advance_head(
State(state): State<AppState>,
_auth: RequireBearer,
Json(req): Json<AdvanceHeadRequest>,
) -> Result<Response, RemoteError> {
if req.r#ref != DEFAULT_REF {
return Err(RemoteError::BadRequest(format!(
"ref `{}` not supported; only `{DEFAULT_REF}` in B3.1",
req.r#ref
)));
}
let old = Cid::parse_str(&req.old).map_err(|e| RemoteError::BadRequest(format!("old: {e}")))?;
let new = Cid::parse_str(&req.new).map_err(|e| RemoteError::BadRequest(format!("new: {e}")))?;
let inc_ok = |s: &AppState| {
s.metrics
.remote_advance_head
.get_or_create(&AdvanceHeadLabels {
result: "success".into(),
})
.inc();
};
let inc_mismatch = |s: &AppState| {
s.metrics
.remote_advance_head
.get_or_create(&AdvanceHeadLabels {
result: "cas_mismatch".into(),
})
.inc();
};
let repo = state
.repo
.lock()
.map_err(|_| RemoteError::Internal("server state lock poisoned".into()))?;
let ohs = repo.op_heads_store();
let current = ohs
.current()
.map_err(|e| RemoteError::Internal(format!("op-heads read: {e}")))?;
if !current.iter().any(|c| c == &old) {
inc_mismatch(&state);
let current_head = current.into_iter().next();
return Err(RemoteError::CasMismatch {
current: current_head.unwrap_or_else(|| old.clone()),
});
}
ohs.update(new.clone(), std::slice::from_ref(&old))
.map_err(|e| RemoteError::Internal(format!("op-heads update: {e}")))?;
inc_ok(&state);
let body = AdvanceHeadResponse {
head: new.to_string(),
};
Ok((StatusCode::OK, protocol_headers(), Json(body)).into_response())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::test_support::state_with_token;
    use axum::body::Body;
    use axum::http::Request;
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Router exposing the four remote handlers under their `/remote/v1/*` paths.
    fn app(state: AppState) -> axum::Router {
        axum::Router::new()
            .route("/remote/v1/refs", axum::routing::get(get_refs))
            .route(
                "/remote/v1/fetch-blocks",
                axum::routing::post(post_fetch_blocks),
            )
            .route(
                "/remote/v1/push-blocks",
                axum::routing::post(post_push_blocks),
            )
            .route(
                "/remote/v1/advance-head",
                axum::routing::post(post_advance_head),
            )
            .with_state(state)
    }

    /// Build a JSON `POST` to `uri`, adding `Authorization: Bearer <token>`
    /// when `bearer` is provided.
    fn json_post(uri: &str, bearer: Option<&str>, body: &str) -> Request<Body> {
        let mut builder = Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json");
        if let Some(token) = bearer {
            builder = builder.header("authorization", format!("Bearer {token}"));
        }
        builder.body(Body::from(body.to_owned())).unwrap()
    }

    #[tokio::test]
    async fn refs_shape_and_protocol_header() {
        let state = state_with_token(Some("tok".into()));
        let app = app(state);
        let req = Request::builder()
            .uri("/remote/v1/refs")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(
            resp.headers()
                .get(PROTOCOL_HEADER)
                .unwrap()
                .to_str()
                .unwrap(),
            "1"
        );
        assert!(resp.headers().get(CAPABILITIES_HEADER).is_some());
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert!(v["head"].is_null() || v["head"].is_string());
        let refs = v["refs"].as_object().unwrap();
        if v["head"].is_string() {
            assert_eq!(
                refs.get("HEAD").and_then(|s| s.as_str()),
                v["head"].as_str()
            );
        } else {
            assert!(refs.is_empty());
        }
        assert!(!v["capabilities"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn push_blocks_requires_bearer_missing() {
        let state = state_with_token(Some("tok".into()));
        let app = app(state);
        let req = Request::builder()
            .method("POST")
            .uri("/remote/v1/push-blocks")
            .body(Body::from(Vec::<u8>::new()))
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 401);
        assert!(resp.headers().get("www-authenticate").is_some());
    }

    #[tokio::test]
    async fn advance_head_requires_bearer_mismatch() {
        let state = state_with_token(Some("tok".into()));
        let app = app(state);
        let req = json_post(
            "/remote/v1/advance-head",
            Some("wrong"),
            r#"{"old":"x","new":"y"}"#,
        );
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 401);
    }

    #[tokio::test]
    async fn advance_head_cas_mismatch_on_empty_heads() {
        let state = state_with_token(Some("tok".into()));
        let app = app(state);
        let mh = mnem_core::id::Multihash::sha2_256(b"a");
        let cid = mnem_core::id::Cid::new(mnem_core::id::CODEC_RAW, mh);
        let body = serde_json::json!({
            "old": cid.to_string(),
            "new": cid.to_string(),
        });
        let req = json_post("/remote/v1/advance-head", Some("tok"), &body.to_string());
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 409);
    }

    #[tokio::test]
    async fn advance_head_rejects_unsupported_ref() {
        // The ref check runs before CID parsing, so bogus CIDs are fine here.
        let state = state_with_token(Some("tok".into()));
        let app = app(state);
        let req = json_post(
            "/remote/v1/advance-head",
            Some("tok"),
            r#"{"old":"x","new":"y","ref":"dev"}"#,
        );
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 400);
    }

    #[tokio::test]
    async fn fetch_blocks_rejects_empty_wants() {
        let state = state_with_token(Some("tok".into()));
        let app = app(state);
        let req = json_post("/remote/v1/fetch-blocks", None, r#"{"wants":[]}"#);
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 400);
    }

    #[tokio::test]
    async fn fetch_blocks_rejects_unparsable_want() {
        let state = state_with_token(Some("tok".into()));
        let app = app(state);
        let req = json_post("/remote/v1/fetch-blocks", None, r#"{"wants":["not-a-cid"]}"#);
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), 400);
    }

    #[tokio::test]
    async fn metrics_counter_not_bumped_on_fetch_blocks_empty_wants_rejection() {
        let state = state_with_token(Some("tok".into()));
        let before = state.metrics.remote_fetch_blocks.get();
        let app = app(state.clone());
        let req = json_post("/remote/v1/fetch-blocks", None, r#"{"wants":[]}"#);
        let _ = app.oneshot(req).await.unwrap();
        let after = state.metrics.remote_fetch_blocks.get();
        assert_eq!(before, after, "rejected request must not bump counter");
    }
}