use axum::{
extract::Request,
response::Response,
body::Body,
};
use axum::http::StatusCode;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tower::{Layer, Service};
use alun_cache::Cache;
/// Tower [`Layer`] that wraps a service in an [`IdempotencyService`].
///
/// Requests carrying an `x-idempotency-key` header have their responses stored
/// in the shared cache and replayed to later requests presenting the same key.
#[derive(Clone)]
pub struct IdempotencyLayer {
    /// Shared cache that holds stored responses.
    cache: Arc<alun_cache::SharedCache>,
    /// Time-to-live handed to the cache (truncated to whole seconds) when a
    /// response is stored.
    ttl: Duration,
}

impl IdempotencyLayer {
    /// Builds a layer that stores responses in `cache` with the given `ttl`.
    pub fn new(cache: Arc<alun_cache::SharedCache>, ttl: Duration) -> Self {
        IdempotencyLayer { ttl, cache }
    }
}

impl<S> Layer<S> for IdempotencyLayer {
    type Service = IdempotencyService<S>;

    /// Wraps `inner`, sharing this layer's cache handle and TTL.
    fn layer(&self, inner: S) -> Self::Service {
        let shared_cache = Arc::clone(&self.cache);
        let entry_ttl = self.ttl;
        IdempotencyService {
            cache: shared_cache,
            ttl: entry_ttl,
            inner,
        }
    }
}
/// Service produced by `IdempotencyLayer`.
///
/// Forwards requests to `inner`, and stores/replays responses for requests
/// that carry an `x-idempotency-key` header.
#[derive(Clone)]
pub struct IdempotencyService<S> {
    /// Wrapped service that handles requests not answered from the cache.
    inner: S,
    /// Shared cache holding stored responses.
    cache: Arc<alun_cache::SharedCache>,
    /// Time-to-live for stored responses; converted with `as_secs()` (whole
    /// seconds) before being passed to the cache.
    ttl: Duration,
}
/// Snapshot of a response as stored in the idempotency cache.
///
/// Only the status code, the `Content-Type` value and the body are kept; any
/// other response headers are not stored and therefore not replayed.
///
/// NOTE(review): the cache's serialization format is not visible here; if it
/// is positional, reordering or retyping these fields breaks existing entries.
#[derive(serde::Serialize, serde::Deserialize)]
struct CachedResponse {
    /// Numeric HTTP status code.
    status: u16,
    /// `Content-Type` header value (or a fallback when the response had none).
    content_type: String,
    /// Response body; a `String`, so binary bodies cannot be stored verbatim.
    body: String,
}
impl<S> Service<Request<Body>> for IdempotencyService<S>
where
S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let idem_key = req.headers()
.get("x-idempotency-key")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let has_idem_key = idem_key.is_some();
let cache = self.cache.clone();
let ttl_secs = self.ttl.as_secs();
let mut inner = self.inner.clone();
Box::pin(async move {
if let Some(ref k) = idem_key {
let cache_key = format!("idem:{}", k);
if let Ok(Some(cached)) = cache.get::<CachedResponse>(&cache_key).await {
return Ok(Response::builder()
.status(StatusCode::from_u16(cached.status).unwrap_or(StatusCode::OK))
.header("Content-Type", &cached.content_type)
.body(Body::from(cached.body))
.expect("response body build failed"));
}
}
let resp = inner.call(req).await?;
if has_idem_key {
let k = idem_key.expect("has_idem_key checked above");
let status = resp.status().as_u16();
let content_type = resp.headers()
.get("Content-Type")
.and_then(|v| v.to_str().ok())
.unwrap_or("application/json; charset=utf-8")
.to_string();
let body_bytes = axum::body::to_bytes(resp.into_body(), 1024 * 1024).await;
let body_str = match body_bytes {
Ok(bytes) => String::from_utf8_lossy(&bytes).to_string(),
Err(_) => String::new(),
};
let cache_key = format!("idem:{}", k);
let _ = cache.set_ex(
&cache_key,
&CachedResponse {
status,
content_type: content_type.clone(),
body: body_str.clone(),
},
ttl_secs,
).await;
return Ok(Response::builder()
.status(StatusCode::from_u16(status).unwrap_or(StatusCode::OK))
.header("Content-Type", &content_type)
.body(Body::from(body_str))
.expect("response body build failed"));
}
Ok(resp)
})
}
}