#![cfg_attr(docsrs, doc(cfg(feature = "plugins")))]
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::time::Instant;
use anyhow::Result;
use bytes::Bytes;
use http::HeaderName;
use http::HeaderValue;
use http::Method;
use http::StatusCode;
use http::header::CONTENT_LENGTH;
use http::header::CONTENT_TYPE;
use http::header::LOCATION;
use http::header::RETRY_AFTER;
use http_body_util::BodyExt;
use scc::HashMap as SccHashMap;
use sha1::Digest;
use sha1::Sha1;
use tokio::sync::Notify;
#[cfg(not(feature = "compio"))]
use tokio::time::timeout;
use crate::body::TakoBody;
use crate::middleware::Next;
use crate::plugins::TakoPlugin;
use crate::responder::Responder;
use crate::router::Router;
use crate::types::Request;
use crate::types::Response;
#[derive(Clone, Copy)]
pub enum Scope {
KeyOnly,
MethodAndPath,
}
#[derive(Clone)]
pub struct Config {
pub header: HeaderName,
pub methods: Vec<Method>,
pub ttl_secs: u64,
pub scope: Scope,
pub coalesce_inflight: bool,
pub inflight_wait_timeout_ms: Option<u64>,
pub max_cached_body_bytes: usize,
pub max_request_body_bytes: usize,
pub verify_payload: bool,
pub cache_error_statuses: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
header: HeaderName::from_static("idempotency-key"),
methods: vec![Method::POST],
ttl_secs: 30,
scope: Scope::MethodAndPath,
coalesce_inflight: true,
inflight_wait_timeout_ms: None,
max_cached_body_bytes: 1 * 1024 * 1024,
max_request_body_bytes: 1 * 1024 * 1024,
verify_payload: true,
cache_error_statuses: true,
}
}
}
pub struct IdempotencyBuilder(Config);
impl IdempotencyBuilder {
pub fn new() -> Self {
Self(Config::default())
}
pub fn header(mut self, h: HeaderName) -> Self {
self.0.header = h;
self
}
pub fn methods(mut self, m: &[Method]) -> Self {
self.0.methods = m.to_vec();
self
}
pub fn ttl_secs(mut self, s: u64) -> Self {
self.0.ttl_secs = s;
self
}
pub fn scope(mut self, s: Scope) -> Self {
self.0.scope = s;
self
}
pub fn coalesce_inflight(mut self, yes: bool) -> Self {
self.0.coalesce_inflight = yes;
self
}
pub fn inflight_wait_timeout_ms(mut self, ms: Option<u64>) -> Self {
self.0.inflight_wait_timeout_ms = ms;
self
}
pub fn max_cached_body_bytes(mut self, n: usize) -> Self {
self.0.max_cached_body_bytes = n;
self
}
pub fn max_request_body_bytes(mut self, n: usize) -> Self {
self.0.max_request_body_bytes = n;
self
}
pub fn verify_payload(mut self, yes: bool) -> Self {
self.0.verify_payload = yes;
self
}
pub fn cache_error_statuses(mut self, yes: bool) -> Self {
self.0.cache_error_statuses = yes;
self
}
pub fn build(self) -> IdempotencyPlugin {
IdempotencyPlugin::new(self.0)
}
}
#[derive(Clone)]
struct CachedResponse {
status: StatusCode,
headers: Vec<(HeaderName, HeaderValue)>,
body: Bytes,
}
#[derive(Clone)]
struct Completed {
payload_sig: [u8; 20],
cached: Arc<CachedResponse>,
expires_at: Instant,
}
enum Entry {
InFlight {
payload_sig: [u8; 20],
notify: Arc<Notify>,
started: Instant,
},
Completed(Completed),
}
#[derive(Clone)]
struct Store(Arc<SccHashMap<String, Entry>>);
impl Store {
fn new() -> Self {
Self(Arc::new(SccHashMap::new()))
}
fn get(&self, k: &str) -> Option<Entry> {
self.0.get_sync(k).map(|e| match &*e {
Entry::InFlight {
payload_sig,
notify,
started,
} => Entry::InFlight {
payload_sig: *payload_sig,
notify: notify.clone(),
started: *started,
},
Entry::Completed(c) => Entry::Completed(c.clone()),
})
}
fn insert_inflight(&self, k: String, payload_sig: [u8; 20]) -> Arc<Notify> {
let notify = Arc::new(Notify::new());
std::mem::drop(self.0.insert_sync(
k,
Entry::InFlight {
payload_sig,
notify: notify.clone(),
started: Instant::now(),
},
));
notify
}
fn complete(&self, k: String, completed: Completed) {
std::mem::drop(self.0.insert_sync(k, Entry::Completed(completed)));
}
fn remove(&self, k: &str) {
let _ = self.0.remove_sync(k);
}
fn retain_expired(&self) {
let now = Instant::now();
self.0.retain_sync(|_, v| match v {
Entry::Completed(c) => c.expires_at > now,
Entry::InFlight { .. } => true,
});
}
}
#[derive(Clone)]
#[doc(alias = "idempotency")]
pub struct IdempotencyPlugin {
cfg: Config,
store: Store,
janitor_started: Arc<AtomicBool>,
}
impl IdempotencyPlugin {
pub fn builder() -> IdempotencyBuilder {
IdempotencyBuilder::new()
}
pub fn new(cfg: Config) -> Self {
Self {
cfg,
store: Store::new(),
janitor_started: Arc::new(AtomicBool::new(false)),
}
}
}
impl TakoPlugin for IdempotencyPlugin {
fn name(&self) -> &'static str {
"IdempotencyPlugin"
}
fn setup(&self, router: &Router) -> Result<()> {
let cfg = self.cfg.clone();
let store = self.store.clone();
router.middleware(move |req, next| {
let cfg = cfg.clone();
let store = store.clone();
async move { handle(req, next, cfg, store).await }
});
if !self.janitor_started.swap(true, Ordering::SeqCst) {
let store = self.store.clone();
let ttl = self.cfg.ttl_secs;
#[cfg(not(feature = "compio"))]
tokio::spawn(async move {
let mut tick = tokio::time::interval(Duration::from_secs(ttl.max(60).min(3600)));
loop {
tick.tick().await;
store.retain_expired();
}
});
#[cfg(feature = "compio")]
compio::runtime::spawn(async move {
let interval = Duration::from_secs(ttl.max(60).min(3600));
loop {
compio::time::sleep(interval).await;
store.retain_expired();
}
})
.detach();
}
Ok(())
}
}
async fn handle(req: Request, next: Next, cfg: Config, store: Store) -> impl Responder {
if !cfg.methods.iter().any(|m| m == req.method()) {
return next.run(req).await;
}
let key = match req.headers().get(&cfg.header) {
Some(v) => v.to_str().unwrap_or("").to_string(),
None => String::new(),
};
if key.is_empty() {
return next.run(req).await;
}
let (parts, body) = req.into_parts();
let collected = match body.collect().await {
Ok(c) => c.to_bytes(),
Err(_) => Bytes::new(),
};
if collected.len() > cfg.max_request_body_bytes {
return http::Response::builder()
.status(StatusCode::PAYLOAD_TOO_LARGE)
.body(TakoBody::empty())
.unwrap();
}
let body_bytes = collected.clone();
let mut hasher = Sha1::new();
if cfg.verify_payload {
hasher.update(parts.method.as_str().as_bytes());
hasher.update(parts.uri.path().as_bytes());
if let Some(ct) = parts.headers.get(CONTENT_TYPE) {
hasher.update(ct.as_bytes());
}
hasher.update(&body_bytes);
}
let sig: [u8; 20] = if cfg.verify_payload {
hasher.finalize().into()
} else {
[0u8; 20]
};
let new_req = http::Request::from_parts(parts, TakoBody::from(Bytes::from(body_bytes)));
let cache_key = match cfg.scope {
Scope::KeyOnly => key,
Scope::MethodAndPath => format!("{}|{}|{}", key, new_req.method(), new_req.uri().path()),
};
if let Some(entry) = store.get(&cache_key) {
match entry {
Entry::Completed(c) => {
if cfg.verify_payload && c.payload_sig != sig {
return conflict();
}
return build_response_from_cache(&c.cached);
}
Entry::InFlight {
payload_sig,
notify,
..
} => {
if !cfg.coalesce_inflight {
return conflict_inflight();
}
if cfg.verify_payload && payload_sig != sig {
return conflict();
}
if let Some(_ms) = cfg.inflight_wait_timeout_ms {
#[cfg(not(feature = "compio"))]
{
let _ = timeout(Duration::from_millis(_ms), notify.notified()).await;
}
#[cfg(feature = "compio")]
{
notify.notified().await;
}
} else {
notify.notified().await;
}
if let Some(Entry::Completed(c2)) = store.get(&cache_key) {
if cfg.verify_payload && c2.payload_sig != sig {
return conflict();
}
return build_response_from_cache(&c2.cached);
}
return conflict_inflight();
}
}
}
let notify = store.insert_inflight(cache_key.clone(), sig);
let mut resp = next.run(new_req).await;
let collected = match resp.body_mut().collect().await {
Ok(c) => c.to_bytes(),
Err(_) => Bytes::new(),
};
let body_bytes = if collected.len() > cfg.max_cached_body_bytes {
Bytes::new()
} else {
collected
};
let status = resp.status();
let should_cache = status.is_success() || status.is_redirection() || cfg.cache_error_statuses;
if should_cache {
let cached = Arc::new(CachedResponse {
status,
headers: filter_headers(resp.headers()),
body: body_bytes.clone(),
});
let completed = Completed {
payload_sig: sig,
cached: cached.clone(),
expires_at: Instant::now() + Duration::from_secs(cfg.ttl_secs),
};
store.complete(cache_key.clone(), completed);
notify.notify_waiters();
*resp.body_mut() = TakoBody::from(Bytes::from(cached.body.clone()));
return resp.into_response();
} else {
store.remove(&cache_key);
notify.notify_waiters();
*resp.body_mut() = TakoBody::from(Bytes::from(body_bytes));
return resp.into_response();
}
}
fn conflict() -> Response {
http::Response::builder()
.status(StatusCode::CONFLICT)
.body(TakoBody::empty())
.unwrap()
}
fn conflict_inflight() -> Response {
let mut resp = http::Response::builder()
.status(StatusCode::CONFLICT)
.body(TakoBody::empty())
.unwrap();
resp
.headers_mut()
.insert(RETRY_AFTER, HeaderValue::from_static("3"));
resp
}
fn build_response_from_cache(c: &CachedResponse) -> Response {
let mut b = http::Response::builder().status(c.status);
let headers = b.headers_mut().unwrap();
for (k, v) in &c.headers {
let _ = headers.insert(k, v.clone());
}
headers.remove(CONTENT_LENGTH);
b.body(TakoBody::from(Bytes::from(c.body.clone()))).unwrap()
}
fn filter_headers(src: &http::HeaderMap) -> Vec<(HeaderName, HeaderValue)> {
const EXCLUDE: &[&str] = &[
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"transfer-encoding",
"upgrade",
];
let mut out = Vec::new();
for (name, v) in src.iter() {
let name_lc = name.as_str().to_ascii_lowercase();
if EXCLUDE.contains(&name_lc.as_str()) {
continue;
}
if name == &CONTENT_LENGTH {
continue;
}
if name == &CONTENT_TYPE || name == &LOCATION {
out.push((name.clone(), v.clone()));
continue;
}
if name_lc.starts_with("x-") {
out.push((name.clone(), v.clone()));
continue;
}
}
out
}