use axum::response::Response;
use futures::future::BoxFuture;
use crate::web::context::RequestContext;
use crate::web::error::{Conflict, HttpError};
/// Largest response body, in bytes (256 KiB), that `run_idempotent` will buffer
/// and record for later replay.
const MAX_STORED_BODY: usize = 256 << 10;
/// Result of [`IdempotencyStore::claim`] for a fully scoped idempotency key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IdempotencyDecision {
    /// This request now holds the key. The caller runs the handler and then
    /// calls either `complete` (to record the response) or `release`.
    Fresh,
    /// A response was previously recorded for this key. The caller returns the
    /// stored `status` and `body` instead of running the handler again.
    Replay { status: u16, body: bytes::Bytes },
    /// Another request holding the same key has not finished yet.
    /// `run_idempotent` answers this with a 409 Conflict.
    InFlight,
    /// The store could not produce a decision. `run_idempotent` counts this as a
    /// store error and runs the handler without idempotency protection.
    Unavailable,
}
/// Persistence backend for idempotency records, consulted by `run_idempotent`.
///
/// Keys arrive already scoped by tenant, subject, route and client key. For a
/// given key the caller's sequence is: `claim`; then, if the decision was
/// `Fresh`, run the handler and finish with exactly one of `complete` or
/// `release`.
pub trait IdempotencyStore: Send + Sync + 'static {
    /// Asks the store what to do with `key`. The caller passes its configured
    /// `ttl_secs`; how the TTL applies to the claim is up to the implementation.
    fn claim<'s>(&'s self, key: &'s str, ttl_secs: u64) -> BoxFuture<'s, IdempotencyDecision>;

    /// Records the final response (`status` plus raw `body` bytes) for `key`, so
    /// that later claims can be answered with `Replay`. `ttl_secs` is the
    /// retention requested by the caller.
    fn complete<'s>(
        &'s self,
        key: &'s str,
        status: u16,
        body: &'s [u8],
        ttl_secs: u64,
    ) -> BoxFuture<'s, ()>;

    /// Drops the claim on `key` without recording anything. The caller uses this
    /// when a response must not be replayed, e.g. a 5xx or a body it could not buffer.
    fn release<'s>(&'s self, key: &'s str) -> BoxFuture<'s, ()>;
}
/// Runs `handler` under `Idempotency-Key` protection.
///
/// Returns the handler's response, or a replayed or conflict response chosen by
/// the injected [`IdempotencyStore`].
///
/// The wrapper is bypassed (the handler simply runs) when the request has no
/// non-blank `idempotency-key` header, or no `Box<dyn IdempotencyStore>` can be
/// injected. Otherwise the store key is `idem:{tenant}:{sub}:{route}:{client_key}`,
/// with `-` substituted for a missing tenant or `sub` claim. The store's decision
/// then selects the outcome:
///
/// - `Replay`: the stored status and body are returned with
///   `idempotency-replayed: true`, and the handler is not run. Only status and
///   body are stored, so the replay always advertises
///   `content-type: application/json`. A stored status that is not a valid
///   HTTP status yields an empty 500.
/// - `InFlight`: a 409 Conflict is returned, and the handler is not run.
/// - `Unavailable`: counted as a store error. The handler runs unprotected
///   (fail-open).
/// - `Fresh`: the handler runs, and then one of three things happens:
///   - A status >= 500 releases the key, so the client may retry.
///   - Any other response whose body fits in [`MAX_STORED_BODY`] bytes is
///     recorded with `complete` and returned.
///   - A body that is known up front to be larger is passed through untouched,
///     and the key is released.
///
/// NOTE(review): if this future is dropped after a `Fresh` claim (e.g. client
/// disconnect), neither `complete` nor `release` runs. The key presumably stays
/// claimed until the store expires it. Confirm against the store implementation.
#[doc(hidden)]
pub async fn run_idempotent<Fut>(
    ctx: &RequestContext,
    ttl_secs: u64,
    route: &'static str,
    handler: Fut,
) -> Response
where
    Fut: std::future::Future<Output = Response>,
{
    let (Some(client_key), Some(store)) = (
        ctx.header("idempotency-key")
            .map(str::trim)
            .filter(|k| !k.is_empty()),
        ctx.try_inject::<Box<dyn IdempotencyStore>>(),
    ) else {
        return handler.await;
    };
    let tenant = ctx.tenant().map(|t| t.id.as_str()).unwrap_or("-");
    let sub = ctx
        .claims()
        .and_then(|c| c.get("sub"))
        .and_then(|v| v.as_str())
        .unwrap_or("-");
    let key = format!("idem:{tenant}:{sub}:{route}:{client_key}");
    match store.claim(&key, ttl_secs).await {
        IdempotencyDecision::Replay { status, body } => {
            metrics::counter!("idempotency_replays_total").increment(1);
            Response::builder()
                .status(status)
                .header("content-type", "application/json")
                .header("idempotency-replayed", "true")
                .body(axum::body::Body::from(body))
                .unwrap_or_else(|_| {
                    // Only an out-of-range stored status can make the builder fail,
                    // which means a corrupt record. Report it as a server error
                    // rather than a misleading empty 200.
                    let mut resp = Response::new(axum::body::Body::empty());
                    *resp.status_mut() = axum::http::StatusCode::INTERNAL_SERVER_ERROR;
                    resp
                })
        }
        IdempotencyDecision::InFlight => {
            metrics::counter!("idempotency_conflicts_total").increment(1);
            crate::http::IntoResponse::into_response(crate::web::error::HttpException::from(
                Conflict::new("a request with this Idempotency-Key is already in flight"),
            ))
        }
        IdempotencyDecision::Unavailable => {
            metrics::counter!("idempotency_store_errors_total").increment(1);
            handler.await
        }
        IdempotencyDecision::Fresh => {
            let resp = handler.await;
            let status = resp.status().as_u16();
            if status >= 500 {
                // Server failures are not pinned: free the key so a retry can run.
                store.release(&key).await;
                return resp;
            }
            let (mut parts, body) = resp.into_parts();
            // A body whose minimum size already exceeds the cap can never be stored.
            // Forward it unchanged instead of buffering and then discarding it.
            // (usize -> u64 is lossless on all supported targets.)
            if axum::body::HttpBody::size_hint(&body).lower() > MAX_STORED_BODY as u64 {
                store.release(&key).await;
                return Response::from_parts(parts, body);
            }
            match axum::body::to_bytes(body, MAX_STORED_BODY).await {
                Ok(bytes) => {
                    store.complete(&key, status, &bytes, ttl_secs).await;
                    Response::from_parts(parts, axum::body::Body::from(bytes))
                }
                Err(_) => {
                    // The body errored, or ran past the cap without advertising its
                    // size. It has been partially consumed and cannot be recovered.
                    store.release(&key).await;
                    // The original length no longer matches the empty body we send.
                    parts.headers.remove(axum::http::header::CONTENT_LENGTH);
                    Response::from_parts(parts, axum::body::Body::empty())
                }
            }
        }
    }
}
// Compile-time guarantee that `Conflict` implements `HttpError`. Naming the
// generic function instantiated at `Conflict` forces the trait bound to be
// checked. Nothing here executes at runtime.
const _: () = {
    fn requires_http_error<E: HttpError>() {}
    let _ = requires_http_error::<Conflict>;
};