use axum::body::Bytes;
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use dashmap::DashMap;
use serde::Serialize;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Request header from which the client-supplied idempotency key is read.
pub const IDEMPOTENCY_KEY_HEADER: &str = "idempotency-key";
/// Response header set to `"true"` on responses replayed from the cache.
pub const IDEMPOTENCY_KEY_USED_HEADER: &str = "idempotency-key-used";
/// Time-to-live used by the default cache: 24 hours.
pub const DEFAULT_TTL: Duration = Duration::from_secs(24 * 60 * 60);
/// Maximum accepted idempotency key length, in bytes.
const MAX_KEY_LENGTH: usize = 256;
/// Entry count at which the cache starts evicting its oldest entries.
const MAX_CACHE_SIZE: usize = 100_000;
/// A validated, client-supplied idempotency key bound to one HTTP method and path.
///
/// Equality and hashing cover both the key text and its `METHOD:path` scope, so the
/// same key reused against a different endpoint is a distinct cache entry.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct IdempotencyKey {
    key: String,
    scope: String,
}

impl IdempotencyKey {
    /// Validates `key` and scopes it to `method` and `path`.
    ///
    /// Returns `None` when `key` is empty, longer than `MAX_KEY_LENGTH` bytes, or
    /// contains anything other than ASCII letters, ASCII digits, `-` or `_`.
    pub fn new(key: impl Into<String>, method: &str, path: &str) -> Option<Self> {
        let candidate: String = key.into();
        // `&&` short-circuits, so oversized input is rejected without scanning it.
        // Checking bytes is equivalent to checking chars here: every byte of a
        // multi-byte UTF-8 sequence is >= 0x80 and therefore fails the test.
        let acceptable = (1..=MAX_KEY_LENGTH).contains(&candidate.len())
            && candidate
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_');
        if !acceptable {
            return None;
        }
        Some(Self {
            key: candidate,
            scope: [method, path].join(":"),
        })
    }

    /// Extracts and validates the key from the [`IDEMPOTENCY_KEY_HEADER`] header.
    ///
    /// Returns `None` if the header is absent, cannot be read as a string
    /// (`HeaderValue::to_str` fails), or does not pass [`IdempotencyKey::new`].
    pub fn from_headers(headers: &HeaderMap, method: &str, path: &str) -> Option<Self> {
        let raw = headers.get(IDEMPOTENCY_KEY_HEADER)?.to_str().ok()?;
        Self::new(raw, method, path)
    }

    /// The client-supplied key text, without its method/path scope.
    pub fn value(&self) -> &str {
        self.key.as_str()
    }
}
/// A response snapshot stored in the idempotency cache, stamped with the time it was cached.
#[derive(Debug, Clone)]
pub struct CachedResponse {
    pub status: StatusCode,
    pub body: Bytes,
    pub content_type: Option<String>,
    pub cached_at: Instant,
}

impl CachedResponse {
    /// Builds a snapshot from its parts, recording the current instant as `cached_at`.
    pub fn new(status: StatusCode, body: Bytes, content_type: Option<String>) -> Self {
        let cached_at = Instant::now();
        Self {
            status,
            body,
            content_type,
            cached_at,
        }
    }

    /// Serializes `value` with `serde_json` and wraps it as an `application/json` snapshot.
    ///
    /// Returns `None` if serialization fails; the serializer error is discarded.
    pub fn from_json<T: Serialize>(status: StatusCode, value: &T) -> Option<Self> {
        let encoded = serde_json::to_vec(value).ok()?;
        Some(Self::new(
            status,
            Bytes::from(encoded),
            Some(String::from("application/json")),
        ))
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since `cached_at`.
    pub fn is_expired(&self, ttl: Duration) -> bool {
        ttl < self.cached_at.elapsed()
    }

    /// Converts the snapshot into an axum response for replay.
    ///
    /// The stored content type is applied only if it is a valid header value, and
    /// [`IDEMPOTENCY_KEY_USED_HEADER`] is always set to `"true"`.
    pub fn into_axum_response(self) -> axum::response::Response {
        use axum::http::header::CONTENT_TYPE;
        use axum::response::IntoResponse;

        let Self {
            status,
            body,
            content_type,
            ..
        } = self;
        let mut response = (status, body).into_response();
        let headers = response.headers_mut();
        // An unparsable stored content type is skipped rather than failing the replay.
        if let Some(value) = content_type.and_then(|ct| HeaderValue::from_str(&ct).ok()) {
            headers.insert(CONTENT_TYPE, value);
        }
        headers.insert(IDEMPOTENCY_KEY_USED_HEADER, HeaderValue::from_static("true"));
        response
    }
}
/// Concurrent, TTL-bounded store of responses keyed by [`IdempotencyKey`].
///
/// Clones share the same underlying map through an `Arc`. An entry older than the
/// configured TTL is treated as absent by [`IdempotencyCache::get`] and
/// [`IdempotencyCache::contains`]. Expired entries are physically removed in three
/// cases: lazily when `get` finds them, by [`IdempotencyCache::cleanup`], or through
/// eviction once the cache holds `MAX_CACHE_SIZE` entries.
#[derive(Clone)]
pub struct IdempotencyCache {
    cache: Arc<DashMap<IdempotencyKey, CachedResponse>>,
    ttl: Duration,
}

impl IdempotencyCache {
    /// Creates an empty cache whose entries expire `ttl` after they were cached.
    pub fn new(ttl: Duration) -> Self {
        Self {
            cache: Arc::new(DashMap::new()),
            ttl,
        }
    }

    /// Creates an empty cache using [`DEFAULT_TTL`].
    pub fn default_cache() -> Self {
        Self::new(DEFAULT_TTL)
    }

    /// Returns a clone of the unexpired response stored for `key`, if any.
    ///
    /// If the stored entry has expired, it is removed and `None` is returned.
    pub fn get(&self, key: &IdempotencyKey) -> Option<CachedResponse> {
        {
            // This shard read guard must be released before `remove_if` below takes
            // the write lock on the same shard; holding both would deadlock.
            let entry = self.cache.get(key)?;
            if !entry.is_expired(self.ttl) {
                return Some(entry.value().clone());
            }
        }
        // Re-check expiry under the write lock. Another request may have stored a
        // fresh response for this key after the read guard was released, and an
        // unconditional `remove` would silently discard it.
        self.cache
            .remove_if(key, |_, current| current.is_expired(self.ttl));
        None
    }

    /// Stores `response` for `key`, replacing any previous entry.
    ///
    /// When the cache already holds `MAX_CACHE_SIZE` entries and `key` is not among
    /// them, the oldest entries are evicted first to make room.
    pub fn set(&self, key: IdempotencyKey, response: CachedResponse) {
        // Replacing an existing key does not grow the map, so it never needs to evict.
        if self.cache.len() >= MAX_CACHE_SIZE && !self.cache.contains_key(&key) {
            self.evict_oldest();
        }
        self.cache.insert(key, response);
    }

    /// Removes up to `MAX_CACHE_SIZE / 10` entries with the oldest `cached_at`.
    ///
    /// All entries share one TTL, so the oldest entries are also the first to
    /// expire. As a result, expired entries are always evicted before live ones.
    fn evict_oldest(&self) {
        let evict_count = MAX_CACHE_SIZE / 10;
        let mut entries: Vec<(IdempotencyKey, Instant)> = self
            .cache
            .iter()
            .map(|entry| (entry.key().clone(), entry.value().cached_at))
            .collect();
        // Only the `evict_count` oldest entries need identifying. A partial selection
        // is O(n), whereas a full sort would be O(n log n).
        if entries.len() > evict_count {
            entries.select_nth_unstable_by_key(evict_count, |(_, cached_at)| *cached_at);
            entries.truncate(evict_count);
        }
        // Count only successful removals. The snapshot may include keys that were
        // removed concurrently, or fewer entries than `evict_count`.
        let mut evicted = 0usize;
        for (key, _) in &entries {
            if self.cache.remove(key).is_some() {
                evicted += 1;
            }
        }
        tracing::debug!(
            evicted = evicted,
            remaining = self.cache.len(),
            "Evicted oldest idempotency cache entries"
        );
    }

    /// Removes the entry for `key`, whether expired or not.
    pub fn remove(&self, key: &IdempotencyKey) {
        self.cache.remove(key);
    }

    /// Returns `true` if an unexpired response is stored for `key`.
    ///
    /// This is consistent with [`IdempotencyCache::get`], but never removes anything.
    pub fn contains(&self, key: &IdempotencyKey) -> bool {
        match self.cache.get(key) {
            Some(entry) => !entry.is_expired(self.ttl),
            None => false,
        }
    }

    /// Removes every entry whose age exceeds the TTL.
    pub fn cleanup(&self) {
        self.cache
            .retain(|_, response| !response.is_expired(self.ttl));
    }

    /// Number of stored entries, including expired entries not yet removed.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` if no entries are stored (expired entries count as stored).
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// The time-to-live applied to every entry.
    pub fn ttl(&self) -> Duration {
        self.ttl
    }
}
impl Default for IdempotencyCache {
fn default() -> Self {
Self::default_cache()
}
}
/// Scope guard pairing an [`IdempotencyKey`] with an [`IdempotencyCache`].
///
/// [`IdempotencyGuard::complete`] stores the final response. If the guard is dropped
/// without being completed, its `Drop` impl removes whatever entry the cache
/// currently holds for the key.
///
/// NOTE(review): the guard never inserts a placeholder itself. An uncompleted drop
/// can therefore delete a response stored under the same key by a different
/// request; confirm this is intended before wiring the guard up.
// `new` is private and has no visible callers, hence the dead-code allowance.
#[allow(dead_code)]
pub struct IdempotencyGuard<'a> {
    cache: &'a IdempotencyCache,
    key: IdempotencyKey,
    completed: bool,
}

#[allow(dead_code)]
impl<'a> IdempotencyGuard<'a> {
    fn new(cache: &'a IdempotencyCache, key: IdempotencyKey) -> Self {
        let completed = false;
        Self {
            cache,
            key,
            completed,
        }
    }

    /// Stores `response` under the guarded key and disarms the drop-time removal.
    pub fn complete(mut self, response: CachedResponse) {
        // The guard is dropped at the end of this call. Marking it completed first
        // keeps `Drop` from removing the entry stored on the next line.
        self.completed = true;
        // `Self` implements `Drop`, so `key` cannot be moved out; a clone is required.
        let stored_key = self.key.clone();
        self.cache.set(stored_key, response);
    }
}

impl Drop for IdempotencyGuard<'_> {
    fn drop(&mut self) {
        if self.completed {
            return;
        }
        self.cache.remove(&self.key);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_idempotency_key_validation() {
        assert!(IdempotencyKey::new("abc123", "POST", "/v1/tables").is_some());
        assert!(IdempotencyKey::new("uuid-with-dashes", "POST", "/v1/tables").is_some());
        assert!(IdempotencyKey::new("key_with_underscores", "POST", "/v1/tables").is_some());
        assert!(IdempotencyKey::new("", "POST", "/v1/tables").is_none());
        assert!(IdempotencyKey::new("key with spaces", "POST", "/v1/tables").is_none());
        assert!(IdempotencyKey::new("key@symbol", "POST", "/v1/tables").is_none());
        // Non-ASCII letters are rejected even though they are alphanumeric in Unicode.
        assert!(IdempotencyKey::new("caf\u{e9}", "POST", "/v1/tables").is_none());
        // Boundary: exactly MAX_KEY_LENGTH is accepted, one more is rejected.
        let max_key = "a".repeat(MAX_KEY_LENGTH);
        assert!(IdempotencyKey::new(&max_key, "POST", "/v1/tables").is_some());
        let long_key = "a".repeat(MAX_KEY_LENGTH + 1);
        assert!(IdempotencyKey::new(&long_key, "POST", "/v1/tables").is_none());
    }

    #[test]
    fn test_idempotency_key_scoping() {
        let key1 = IdempotencyKey::new("same-key", "POST", "/v1/tables").unwrap();
        let key2 = IdempotencyKey::new("same-key", "DELETE", "/v1/tables").unwrap();
        let key3 = IdempotencyKey::new("same-key", "POST", "/v1/namespaces").unwrap();
        assert_ne!(key1, key2);
        assert_ne!(key1, key3);
        assert_ne!(key2, key3);
    }

    #[test]
    fn test_idempotency_key_from_headers() {
        let mut headers = HeaderMap::new();
        headers.insert(
            IDEMPOTENCY_KEY_HEADER,
            HeaderValue::from_static("test-key-123"),
        );
        let key = IdempotencyKey::from_headers(&headers, "POST", "/v1/tables").unwrap();
        assert_eq!(key.value(), "test-key-123");
    }

    #[test]
    fn test_cached_response_expiry() {
        let response = CachedResponse::new(
            StatusCode::OK,
            Bytes::from("test"),
            Some("application/json".to_string()),
        );
        assert!(!response.is_expired(Duration::from_secs(60)));
        // Sleep past the TTL instead of relying on a 1ns TTL. `Instant` can have
        // coarse resolution on some platforms, so `elapsed()` may read as zero
        // immediately after creation and make such a test flaky.
        thread::sleep(Duration::from_millis(5));
        assert!(response.is_expired(Duration::from_millis(1)));
    }

    #[test]
    fn test_idempotency_cache_basic() {
        let cache = IdempotencyCache::new(Duration::from_secs(60));
        let key = IdempotencyKey::new("test-key", "POST", "/v1/tables").unwrap();
        assert!(cache.get(&key).is_none());
        let response = CachedResponse::new(
            StatusCode::CREATED,
            Bytes::from(r#"{"result": "ok"}"#),
            Some("application/json".to_string()),
        );
        cache.set(key.clone(), response);
        let cached = cache.get(&key).unwrap();
        assert_eq!(cached.status, StatusCode::CREATED);
    }

    #[test]
    fn test_idempotency_cache_expiry() {
        let cache = IdempotencyCache::new(Duration::from_millis(10));
        let key = IdempotencyKey::new("test-key", "POST", "/v1/tables").unwrap();
        let response = CachedResponse::new(StatusCode::OK, Bytes::from("test"), None);
        cache.set(key.clone(), response);
        assert!(cache.get(&key).is_some());
        thread::sleep(Duration::from_millis(20));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_idempotency_cache_cleanup() {
        let cache = IdempotencyCache::new(Duration::from_millis(10));
        for i in 0..5 {
            let key = IdempotencyKey::new(format!("key-{}", i), "POST", "/v1/tables").unwrap();
            let response = CachedResponse::new(StatusCode::OK, Bytes::from("test"), None);
            cache.set(key, response);
        }
        assert_eq!(cache.len(), 5);
        thread::sleep(Duration::from_millis(20));
        cache.cleanup();
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cached_response_from_json() {
        #[derive(Serialize)]
        struct TestResponse {
            message: String,
        }
        let value = TestResponse {
            message: "success".to_string(),
        };
        let response = CachedResponse::from_json(StatusCode::CREATED, &value).unwrap();
        assert_eq!(response.status, StatusCode::CREATED);
        assert_eq!(response.content_type, Some("application/json".to_string()));
        assert!(std::str::from_utf8(&response.body)
            .unwrap()
            .contains("success"));
    }

    #[test]
    fn test_idempotency_cache_bounded_size() {
        let cache = IdempotencyCache::new(Duration::from_secs(3600));
        let test_size = 1000;
        for i in 0..test_size {
            let key = IdempotencyKey::new(format!("key-{}", i), "POST", "/v1/tables").unwrap();
            let response = CachedResponse::new(StatusCode::OK, Bytes::from("test"), None);
            cache.set(key, response);
        }
        assert_eq!(cache.len(), test_size);
        cache.evict_oldest();
        assert!(cache.len() < test_size);
    }
}