#![cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::time::Instant;
use anyhow::Result;
use bytes::Bytes;
use http::HeaderName;
use http::HeaderValue;
use http::Method;
use http::StatusCode;
use http::header::CONTENT_LENGTH;
use http::header::CONTENT_TYPE;
use http::header::RETRY_AFTER;
use http_body_util::BodyExt;
use scc::HashMap as SccHashMap;
use sha1::Digest;
use sha1::Sha1;
use tako_rs_core::body::TakoBody;
use tako_rs_core::middleware::Next;
use tako_rs_core::plugins::TakoPlugin;
use tako_rs_core::responder::Responder;
use tako_rs_core::router::Router;
use tako_rs_core::types::Request;
use tako_rs_core::types::Response;
use tokio::sync::Notify;
#[cfg(not(feature = "compio"))]
use tokio::time::timeout;
/// Determines which parts of a request form the idempotency cache key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// The idempotency key alone: the same key used on different routes maps to one entry.
    KeyOnly,
    /// The idempotency key combined with the request method and URI path.
    MethodAndPath,
}
/// Settings for [`IdempotencyPlugin`].
///
/// The [`Default`] values are:
/// - header `idempotency-key` and methods `[POST]`;
/// - a TTL of 86400 seconds (24 hours) and [`Scope::MethodAndPath`];
/// - in-flight coalescing enabled, with no wait timeout;
/// - 1 MiB limits for both cached response bodies and buffered request bodies;
/// - payload verification enabled;
/// - error statuses cached for the full TTL.
#[derive(Clone, Debug)]
pub struct Config {
    /// Request header that carries the idempotency key.
    pub header: HeaderName,
    /// HTTP methods the middleware applies to; requests with other methods pass through untouched.
    pub methods: Vec<Method>,
    /// Lifetime of a completed response, in seconds. It is also clamped to `5..=3600`
    /// to derive the period of the background sweep.
    pub ttl_secs: u64,
    /// How the cache key is derived from the request.
    pub scope: Scope,
    /// What happens when a duplicate arrives while the first request is still running.
    /// `true` waits for it and replays its response; `false` answers `409 Conflict`
    /// with a `Retry-After` header.
    pub coalesce_inflight: bool,
    /// Upper bound, in milliseconds, on that wait. `None` waits until the first request
    /// finishes or is abandoned.
    pub inflight_wait_timeout_ms: Option<u64>,
    /// Largest response body stored for replay. A larger body is stored empty, so replays
    /// carry only the status and headers.
    pub max_cached_body_bytes: usize,
    /// Largest request body buffered; larger requests are rejected with `413 Payload Too Large`.
    pub max_request_body_bytes: usize,
    /// Whether a reused key must arrive with the same method, path, `Content-Type` and body.
    /// The comparison uses a SHA-1 digest; mismatches get `409 Conflict`.
    pub verify_payload: bool,
    /// When `false`, 4xx/5xx responses are kept for one second instead of `ttl_secs`.
    pub cache_error_statuses: bool,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            header: HeaderName::from_static("idempotency-key"),
            methods: vec![Method::POST],
            ttl_secs: 86400,
            scope: Scope::MethodAndPath,
            coalesce_inflight: true,
            inflight_wait_timeout_ms: None,
            max_cached_body_bytes: 1024 * 1024,
            max_request_body_bytes: 1024 * 1024,
            verify_payload: true,
            cache_error_statuses: true,
        }
    }
}
/// Fluent constructor for [`IdempotencyPlugin`], starting from [`Config::default`].
#[must_use = "a builder does nothing until `.build()` is called"]
pub struct IdempotencyBuilder(Config);

impl Default for IdempotencyBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl IdempotencyBuilder {
    /// Creates a builder holding [`Config::default`].
    pub fn new() -> Self {
        Self(Config::default())
    }

    /// Sets the request header that carries the idempotency key.
    pub fn header(mut self, h: HeaderName) -> Self {
        self.0.header = h;
        self
    }

    /// Replaces the list of HTTP methods the middleware applies to.
    pub fn methods(mut self, m: &[Method]) -> Self {
        self.0.methods = m.to_vec();
        self
    }

    /// Sets how long completed responses are replayed, in seconds.
    pub fn ttl_secs(mut self, s: u64) -> Self {
        self.0.ttl_secs = s;
        self
    }

    /// Sets how the cache key is derived from the request.
    pub fn scope(mut self, s: Scope) -> Self {
        self.0.scope = s;
        self
    }

    /// Chooses between waiting for an in-flight duplicate (`true`) and rejecting it with `409` (`false`).
    pub fn coalesce_inflight(mut self, yes: bool) -> Self {
        self.0.coalesce_inflight = yes;
        self
    }

    /// Bounds the in-flight wait in milliseconds; `None` waits without a limit.
    pub fn inflight_wait_timeout_ms(mut self, ms: Option<u64>) -> Self {
        self.0.inflight_wait_timeout_ms = ms;
        self
    }

    /// Sets the largest response body stored for replay.
    pub fn max_cached_body_bytes(mut self, n: usize) -> Self {
        self.0.max_cached_body_bytes = n;
        self
    }

    /// Sets the largest request body the middleware will buffer.
    pub fn max_request_body_bytes(mut self, n: usize) -> Self {
        self.0.max_request_body_bytes = n;
        self
    }

    /// Enables or disables payload-signature checks for reused keys.
    pub fn verify_payload(mut self, yes: bool) -> Self {
        self.0.verify_payload = yes;
        self
    }

    /// When `false`, error responses are kept only briefly instead of for the full TTL.
    pub fn cache_error_statuses(mut self, yes: bool) -> Self {
        self.0.cache_error_statuses = yes;
        self
    }

    /// Finishes configuration and creates the plugin.
    #[must_use]
    pub fn build(self) -> IdempotencyPlugin {
        IdempotencyPlugin::new(self.0)
    }
}
/// What gets replayed for a finished request.
///
/// Holds the status, the headers that survived `filter_headers`, and the body. The body is left
/// empty when the original exceeded the configured cache limit.
#[derive(Clone)]
struct CachedResponse {
    /// Status code of the original response.
    status: StatusCode,
    /// Header name/value pairs in their original order, duplicates included.
    headers: Vec<(HeaderName, HeaderValue)>,
    /// Response body bytes.
    body: Bytes,
}
/// A finished request stored under its cache key.
#[derive(Clone)]
struct Completed {
    /// Signature of the payload that produced the response; all zeros when verification is off.
    payload_sig: [u8; 20],
    /// Shared response snapshot, cheap to clone for each replay.
    cached: Arc<CachedResponse>,
    /// Instant after which the entry is considered expired.
    expires_at: Instant,
}
/// State stored for one cache key.
enum Entry {
    /// A leader request is executing the handler for this key.
    ///
    /// `notify` is signalled when the leader stores its result or abandons the key. `started` is
    /// recorded when the entry is installed but is not read anywhere in this file.
    InFlight { payload_sig: [u8; 20], notify: Arc<Notify>, started: Instant },
    /// The response is stored and can be replayed.
    Completed(Completed),
}
/// Shared idempotency table mapping cache keys to [`Entry`] values.
///
/// Cloning copies the `Arc`, so all clones observe the same table.
#[derive(Clone)]
struct Store(Arc<SccHashMap<String, Entry>>);
/// Cleanup guard held by the leader request for a cache key.
///
/// While armed, dropping the guard removes the key's store entry and then wakes all waiters. This
/// releases coalesced duplicates if the leader returns early or its future is dropped before it
/// stores a result. The leader calls `disarm` after storing the completed response.
struct InflightGuard {
    store: Store,
    cache_key: String,
    notify: Arc<Notify>,
    armed: bool,
}

impl InflightGuard {
    /// Creates an armed guard for `cache_key`.
    fn new(store: Store, cache_key: String, notify: Arc<Notify>) -> Self {
        let armed = true;
        Self { armed, store, cache_key, notify }
    }

    /// Turns the drop-time cleanup off.
    fn disarm(&mut self) {
        self.armed = false;
    }
}

impl Drop for InflightGuard {
    fn drop(&mut self) {
        if !self.armed {
            return;
        }
        // Vacate the slot first so waiters that re-read the store after waking see it gone.
        self.store.remove(&self.cache_key);
        self.notify.notify_waiters();
    }
}
impl Store {
    /// Creates an empty store.
    fn new() -> Self {
        Self(Arc::new(SccHashMap::new()))
    }

    /// Returns an owned copy of `entry`; the shared parts are reference-counted.
    fn snapshot(entry: &Entry) -> Entry {
        match entry {
            Entry::InFlight { payload_sig, notify, started } => Entry::InFlight {
                payload_sig: *payload_sig,
                notify: Arc::clone(notify),
                started: *started,
            },
            Entry::Completed(c) => Entry::Completed(c.clone()),
        }
    }

    /// Builds a fresh in-flight marker that shares `notify` with the caller.
    fn inflight(payload_sig: [u8; 20], notify: &Arc<Notify>) -> Entry {
        Entry::InFlight {
            payload_sig,
            notify: Arc::clone(notify),
            started: Instant::now(),
        }
    }

    /// Returns a snapshot of the entry for `k`, if any.
    ///
    /// Expired completed entries are still returned. Coalesced waiters rely on this to read a
    /// result whose TTL is very short.
    fn get(&self, k: &str) -> Option<Entry> {
        self.0.get_sync(k).map(|e| Self::snapshot(&e))
    }

    /// Atomically claims `k` for a new leader, or reports what already occupies it.
    ///
    /// Returns `Ok(notify)` when the caller became the leader. This happens when the key was
    /// vacant, or when it held a completed entry whose `expires_at` has passed. An expired entry
    /// is replaced in place, so stale responses are never replayed while waiting for the
    /// periodic sweep. Otherwise returns `Err` with a snapshot of the live entry.
    fn install_inflight_or_get_existing(
        &self,
        k: String,
        payload_sig: [u8; 20],
    ) -> Result<Arc<Notify>, Entry> {
        use scc::hash_map::Entry as MapEntry;
        match self.0.entry_sync(k) {
            MapEntry::Vacant(v) => {
                let notify = Arc::new(Notify::new());
                v.insert_entry(Self::inflight(payload_sig, &notify));
                Ok(notify)
            }
            MapEntry::Occupied(mut o) => {
                let expired = matches!(
                    o.get(),
                    Entry::Completed(c) if c.expires_at <= Instant::now()
                );
                if expired {
                    let notify = Arc::new(Notify::new());
                    *o.get_mut() = Self::inflight(payload_sig, &notify);
                    Ok(notify)
                } else {
                    Err(Self::snapshot(o.get()))
                }
            }
        }
    }

    /// Stores the finished response for `k`, replacing whatever was there.
    fn complete(&self, k: String, completed: Completed) {
        self.0.upsert_sync(k, Entry::Completed(completed));
    }

    /// Removes `k` if present.
    fn remove(&self, k: &str) {
        let _ = self.0.remove_sync(k);
    }

    /// Drops completed entries whose `expires_at` has passed.
    ///
    /// In-flight entries are kept; they are cleared by their leader's guard or overwritten on
    /// completion.
    fn retain_expired(&self) {
        let now = Instant::now();
        self.0.retain_sync(|_, v| match v {
            Entry::Completed(c) => c.expires_at > now,
            Entry::InFlight { .. } => true,
        });
    }
}
/// Tako plugin that replays responses for requests carrying a repeated idempotency key.
///
/// Clones share the same response store and janitor flag.
#[derive(Clone)]
#[doc(alias = "idempotency")]
pub struct IdempotencyPlugin {
    cfg: Config,
    store: Store,
    janitor_started: Arc<AtomicBool>,
}

impl IdempotencyPlugin {
    /// Starts an [`IdempotencyBuilder`] holding the default configuration.
    pub fn builder() -> IdempotencyBuilder {
        IdempotencyBuilder::new()
    }

    /// Creates a plugin with an empty store using `cfg`.
    pub fn new(cfg: Config) -> Self {
        let janitor_started = Arc::new(AtomicBool::new(false));
        let store = Store::new();
        Self { cfg, store, janitor_started }
    }
}
impl TakoPlugin for IdempotencyPlugin {
fn name(&self) -> &'static str {
"IdempotencyPlugin"
}
fn setup(&self, router: &Router) -> Result<()> {
let cfg = self.cfg.clone();
let store = self.store.clone();
router.middleware(move |req, next| {
let cfg = cfg.clone();
let store = store.clone();
async move { handle(req, next, cfg, store).await }
});
if !self.janitor_started.swap(true, Ordering::SeqCst) {
let store = self.store.clone();
let ttl = self.cfg.ttl_secs;
#[cfg(not(feature = "compio"))]
tokio::spawn(async move {
let mut tick = tokio::time::interval(Duration::from_secs(ttl.clamp(5, 3600)));
loop {
tick.tick().await;
store.retain_expired();
}
});
#[cfg(feature = "compio")]
compio::runtime::spawn(async move {
let interval = Duration::from_secs(ttl.clamp(5, 3600));
loop {
compio::time::sleep(interval).await;
store.retain_expired();
}
})
.detach();
}
Ok(())
}
}
async fn handle(req: Request, next: Next, cfg: Config, store: Store) -> impl Responder {
if !cfg.methods.iter().any(|m| m == req.method()) {
return next.run(req).await;
}
let key = match req.headers().get(&cfg.header) {
Some(v) => match v.to_str() {
Ok(s) if !s.is_empty() => s.to_string(),
Ok(_) => return next.run(req).await,
Err(_) => {
return (
http::StatusCode::BAD_REQUEST,
"Idempotency-Key must be visible ASCII",
)
.into_response();
}
},
None => return next.run(req).await,
};
let (parts, body) = req.into_parts();
let limited = http_body_util::Limited::new(body, cfg.max_request_body_bytes);
let collected = match limited.collect().await {
Ok(c) => c.to_bytes(),
Err(_) => {
return http::Response::builder()
.status(StatusCode::PAYLOAD_TOO_LARGE)
.body(TakoBody::empty())
.unwrap();
}
};
let body_bytes = collected.clone();
let mut hasher = Sha1::new();
if cfg.verify_payload {
hasher.update(parts.method.as_str().as_bytes());
hasher.update(parts.uri.path().as_bytes());
if let Some(ct) = parts.headers.get(CONTENT_TYPE) {
hasher.update(ct.as_bytes());
}
hasher.update(&body_bytes);
}
let sig: [u8; 20] = if cfg.verify_payload {
hasher.finalize().into()
} else {
[0u8; 20]
};
let new_req = http::Request::from_parts(parts, TakoBody::from(body_bytes));
let cache_key = match cfg.scope {
Scope::KeyOnly => key,
Scope::MethodAndPath => format!("{}|{}|{}", key, new_req.method(), new_req.uri().path()),
};
let notify = match store.install_inflight_or_get_existing(cache_key.clone(), sig) {
Err(Entry::Completed(c)) => {
let legacy_unverified = c.payload_sig == [0u8; 20];
if cfg.verify_payload && !legacy_unverified && c.payload_sig != sig {
return conflict();
}
return build_response_from_cache(&c.cached);
}
Err(Entry::InFlight {
payload_sig,
notify,
..
}) => {
if !cfg.coalesce_inflight {
return conflict_inflight();
}
let legacy_unverified = payload_sig == [0u8; 20];
if cfg.verify_payload && !legacy_unverified && payload_sig != sig {
return conflict();
}
if let Some(ms) = cfg.inflight_wait_timeout_ms {
#[cfg(not(feature = "compio"))]
{
let _ = timeout(Duration::from_millis(ms), notify.notified()).await;
}
#[cfg(feature = "compio")]
{
let timeout_signal = Arc::new(Notify::new());
let timer_signal = timeout_signal.clone();
let timer_task = compio::runtime::spawn(async move {
compio::time::sleep(Duration::from_millis(ms)).await;
timer_signal.notify_waiters();
});
futures_util::future::select(
std::pin::pin!(notify.notified()),
std::pin::pin!(timeout_signal.notified()),
)
.await;
drop(timer_task);
}
} else {
notify.notified().await;
}
if let Some(Entry::Completed(c2)) = store.get(&cache_key) {
if cfg.verify_payload && c2.payload_sig != sig {
return conflict();
}
return build_response_from_cache(&c2.cached);
}
return conflict_inflight();
}
Ok(notify) => notify,
};
let mut inflight_guard = InflightGuard::new(store.clone(), cache_key.clone(), notify.clone());
let mut resp = next.run(new_req).await;
let collected = match resp.body_mut().collect().await {
Ok(c) => c.to_bytes(),
Err(_) => {
return bad_gateway();
}
};
let body_bytes = if collected.len() > cfg.max_cached_body_bytes {
Bytes::new()
} else {
collected
};
let status = resp.status();
let is_error = status.is_client_error() || status.is_server_error();
let cached = Arc::new(CachedResponse {
status,
headers: filter_headers(resp.headers()),
body: body_bytes.clone(),
});
let ttl = if is_error && !cfg.cache_error_statuses {
Duration::from_secs(1)
} else {
Duration::from_secs(cfg.ttl_secs)
};
let completed = Completed {
payload_sig: sig,
cached: cached.clone(),
expires_at: Instant::now() + ttl,
};
store.complete(cache_key.clone(), completed);
notify.notify_waiters();
inflight_guard.disarm();
*resp.body_mut() = TakoBody::from(cached.body.clone());
resp.into_response()
}
/// `Retry-After` value, in seconds, sent when a duplicate arrives while the first request runs.
const INFLIGHT_RETRY_AFTER_SECS: u32 = 3;

/// `409 Conflict` without a `Retry-After` header; `handle` uses it for payload mismatches.
fn conflict() -> Response {
    let retry_after: Option<u32> = None;
    conflict_response(retry_after)
}

/// `409 Conflict` with `Retry-After`, for a key whose first request has not finished.
fn conflict_inflight() -> Response {
    conflict_response(Some(INFLIGHT_RETRY_AFTER_SECS))
}
/// Builds an empty `409 Conflict`, adding `Retry-After: <secs>` when `retry_after_secs` is set.
///
/// `HeaderValue: From<u32>` always yields a valid header value, so no fallback is needed.
fn conflict_response(retry_after_secs: Option<u32>) -> Response {
    let mut resp = http::Response::builder()
        .status(StatusCode::CONFLICT)
        .body(TakoBody::empty())
        .unwrap();
    if let Some(secs) = retry_after_secs {
        resp.headers_mut().insert(RETRY_AFTER, HeaderValue::from(secs));
    }
    resp
}
/// Empty `502 Bad Gateway`, returned when the handler's response body cannot be read.
fn bad_gateway() -> Response {
    let mut resp = http::Response::new(TakoBody::empty());
    *resp.status_mut() = StatusCode::BAD_GATEWAY;
    resp
}
/// Rebuilds a response from a cached snapshot.
///
/// Headers are re-added with `append`, so every value of a multi-valued header (e.g. `Vary`,
/// `Link`) is replayed rather than only the last one. `Content-Length` is stripped. Returns an
/// empty `500` if the builder reports an error.
fn build_response_from_cache(c: &CachedResponse) -> Response {
    fn internal_error() -> Response {
        http::Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .body(TakoBody::empty())
            .expect("static 500 builder")
    }

    let mut b = http::Response::builder().status(c.status);
    let Some(headers) = b.headers_mut() else {
        return internal_error();
    };
    for (k, v) in &c.headers {
        headers.append(k, v.clone());
    }
    headers.remove(CONTENT_LENGTH);
    b.body(TakoBody::from(c.body.clone()))
        .unwrap_or_else(|_| internal_error())
}
/// Selects the response headers that are stored for replay.
///
/// Hop-by-hop headers, `Content-Length` and `Set-Cookie` are dropped. Every value of every other
/// header is kept, in iteration order, including repeated values of the same name.
fn filter_headers(src: &http::HeaderMap) -> Vec<(HeaderName, HeaderValue)> {
    const DENY: &[&str] = &[
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "content-length",
        "set-cookie",
    ];
    // `HeaderName` is always stored lowercase, so `as_str()` compares directly against the
    // lowercase deny list without allocating a lowered copy per header.
    src.iter()
        .filter(|(name, _)| !DENY.contains(&name.as_str()))
        .map(|(name, value)| (name.clone(), value.clone()))
        .collect()
}