use axum::body::Bytes;
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use moka::sync::Cache;
use serde::Serialize;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Request header from which the client-supplied idempotency key is read.
pub const IDEMPOTENCY_KEY_HEADER: &str = "idempotency-key";
/// Response header (value `true`) added by `CachedResponse::into_axum_response`.
pub const IDEMPOTENCY_KEY_USED_HEADER: &str = "idempotency-key-used";
/// Default lifetime of a cached response: 24 hours.
pub const DEFAULT_TTL: Duration = Duration::from_secs(24 * 60 * 60);
/// Longest accepted idempotency key, in bytes (keys are ASCII-only, so also in chars).
const MAX_KEY_LENGTH: usize = 256;
/// Maximum number of entries the in-memory cache holds before evicting.
const MAX_CACHE_SIZE: usize = 100_000;
/// Client-supplied idempotency key bound to the HTTP method and path it arrived with.
///
/// Equality and hashing cover both the key text and its scope, so reusing the same
/// key on a different method or path yields a distinct cache entry.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IdempotencyKey {
    /// Key text as sent by the client, after validation in `IdempotencyKey::new`.
    key: String,
    /// `"{method}:{path}"` string the key is bound to.
    scope: String,
}
impl IdempotencyKey {
    /// Validates `key` and binds it to the request's `method` and `path`.
    ///
    /// Returns `None` when the key is empty, longer than `MAX_KEY_LENGTH` bytes, or
    /// contains anything other than ASCII letters, ASCII digits, `-` and `_`.
    pub fn new(key: impl Into<String>, method: &str, path: &str) -> Option<Self> {
        let key: String = key.into();
        let length_ok = (1..=MAX_KEY_LENGTH).contains(&key.len());
        // Any non-ASCII character encodes to bytes >= 0x80, which this predicate rejects.
        let allowed = |b: u8| b.is_ascii_alphanumeric() || b == b'-' || b == b'_';
        if !length_ok || !key.bytes().all(allowed) {
            return None;
        }
        let scope = [method, path].join(":");
        Some(Self { key, scope })
    }

    /// Reads and validates the `idempotency-key` header.
    ///
    /// Returns `None` if the header is absent, is not valid visible ASCII, or fails
    /// the checks performed by [`IdempotencyKey::new`].
    pub fn from_headers(headers: &HeaderMap, method: &str, path: &str) -> Option<Self> {
        let raw = headers.get(IDEMPOTENCY_KEY_HEADER)?.to_str().ok()?;
        Self::new(raw, method, path)
    }

    /// The validated key text, without its method/path scope.
    pub fn value(&self) -> &str {
        self.key.as_str()
    }
}
/// A response captured for replay when a request with the same idempotency key repeats.
#[derive(Clone, Debug)]
pub struct CachedResponse {
    /// HTTP status to replay.
    pub status: StatusCode,
    /// Raw response body to replay.
    pub body: Bytes,
    /// Value for the `Content-Type` header, if one should be set on replay.
    pub content_type: Option<String>,
    /// Moment the response was captured; `is_expired` measures age from here.
    pub cached_at: Instant,
}
impl CachedResponse {
    /// Captures a response, stamping it with the current instant.
    pub fn new(status: StatusCode, body: Bytes, content_type: Option<String>) -> Self {
        let cached_at = Instant::now();
        Self {
            status,
            body,
            content_type,
            cached_at,
        }
    }

    /// Serializes `value` to JSON and wraps it with an `application/json` content type.
    ///
    /// Returns `None`, after logging a warning, if serialization fails.
    pub fn from_json<T: Serialize>(status: StatusCode, value: &T) -> Option<Self> {
        let json = serde_json::to_vec(value)
            .map_err(|e| {
                tracing::warn!(error = %e, "Failed to serialize response for idempotency cache");
            })
            .ok()?;
        Some(Self::new(
            status,
            Bytes::from(json),
            Some("application/json".to_string()),
        ))
    }

    /// True once strictly more than `ttl` has elapsed since the response was captured.
    pub fn is_expired(&self, ttl: Duration) -> bool {
        ttl < self.cached_at.elapsed()
    }

    /// Converts the cached data into an axum response marked as a replay.
    ///
    /// The stored content type is applied only if it is a valid header value.
    /// `idempotency-key-used: true` is always added.
    pub fn into_axum_response(self) -> axum::response::Response {
        use axum::http::header::CONTENT_TYPE;
        use axum::response::IntoResponse;
        let Self {
            status,
            body,
            content_type,
            ..
        } = self;
        let mut response = (status, body).into_response();
        let headers = response.headers_mut();
        if let Some(value) = content_type
            .as_deref()
            .and_then(|ct| HeaderValue::from_str(ct).ok())
        {
            headers.insert(CONTENT_TYPE, value);
        }
        headers.insert(
            IDEMPOTENCY_KEY_USED_HEADER,
            HeaderValue::from_static("true"),
        );
        response
    }
}
use crate::catalog::idempotency_store::{IdempotencyEntry, IdempotencyStore};
/// Bounded, TTL-expiring cache of idempotent responses, optionally mirrored to a
/// persistent store.
///
/// Cloning is cheap: the moka cache and the store handle are shared between clones.
#[derive(Clone)]
pub struct IdempotencyCache {
    /// In-memory entries keyed by scoped idempotency key.
    cache: Cache<IdempotencyKey, CachedResponse>,
    /// Lifetime applied to in-memory entries and to persisted entries.
    ttl: Duration,
    /// Optional durable backing store; `set` mirrors writes into it.
    persistent_store: Option<Arc<dyn IdempotencyStore>>,
}
impl IdempotencyCache {
pub fn new(ttl: Duration) -> Self {
Self {
cache: Cache::builder()
.max_capacity(MAX_CACHE_SIZE as u64)
.time_to_live(ttl)
.build(),
ttl,
persistent_store: None,
}
}
pub fn with_persistent_store(ttl: Duration, store: Arc<dyn IdempotencyStore>) -> Self {
Self {
cache: Cache::builder()
.max_capacity(MAX_CACHE_SIZE as u64)
.time_to_live(ttl)
.build(),
ttl,
persistent_store: Some(store),
}
}
pub fn default_cache() -> Self {
Self::new(DEFAULT_TTL)
}
pub async fn bootstrap_from_store(&self) -> crate::error::Result<usize> {
let store = match &self.persistent_store {
Some(s) => s,
None => return Ok(0),
};
let _ = store.cleanup_expired().await?;
let count = store.count().await?;
tracing::info!(
entries = count,
"Bootstrapped idempotency cache from persistent store"
);
Ok(count)
}
pub fn get(&self, key: &IdempotencyKey) -> Option<CachedResponse> {
self.cache.get(key)
}
pub async fn get_from_persistent(&self, key: &IdempotencyKey) -> Option<CachedResponse> {
let store = self.persistent_store.as_ref()?;
match store.get(&key.scope, key.value()).await {
Ok(Some(entry)) => {
let response = CachedResponse::new(
StatusCode::from_u16(entry.status_code).unwrap_or(StatusCode::OK),
Bytes::from(entry.response_body.clone()),
entry.content_type.clone(),
);
self.cache.insert(key.clone(), response.clone());
Some(response)
}
_ => None,
}
}
pub fn set(&self, key: IdempotencyKey, response: CachedResponse) {
if let Some(store) = &self.persistent_store {
let store = store.clone();
let key_value = key.value().to_string();
let scope = key.scope.clone();
let status_code = response.status.as_u16();
let response_body = response.body.to_vec();
let content_type = response.content_type.clone();
let ttl = self.ttl;
tokio::spawn(async move {
let entry = IdempotencyEntry::new(
key_value,
scope,
status_code,
response_body,
content_type,
ttl,
);
if let Err(e) = store.set(entry).await {
tracing::warn!(error = %e, "Failed to persist idempotency entry");
}
});
}
self.cache.insert(key, response);
}
pub fn remove(&self, key: &IdempotencyKey) {
self.cache.invalidate(key);
}
pub fn contains(&self, key: &IdempotencyKey) -> bool {
self.cache.contains_key(key)
}
pub fn try_begin(&self, key: IdempotencyKey) -> Result<IdempotencyGuard<'_>, CachedResponse> {
if let Some(existing) = self.get(&key) {
return Err(existing);
}
let placeholder = CachedResponse::new(
StatusCode::ACCEPTED,
Bytes::from_static(b""),
Some("application/json".to_string()),
);
self.cache.insert(key.clone(), placeholder);
Ok(IdempotencyGuard::new(self, key))
}
pub fn cleanup(&self) {
self.cache.run_pending_tasks();
}
pub fn len(&self) -> usize {
self.cache.entry_count() as usize
}
pub fn is_empty(&self) -> bool {
self.cache.entry_count() == 0
}
pub fn ttl(&self) -> Duration {
self.ttl
}
}
impl Default for IdempotencyCache {
fn default() -> Self {
Self::default_cache()
}
}
/// Reservation on an idempotency key, returned by `IdempotencyCache::try_begin`.
///
/// While the guard is alive, the key maps to an in-flight placeholder. Call
/// `complete` to replace the placeholder with the real response. Dropping the guard
/// without completing removes the placeholder so the request can be retried.
#[must_use = "dropping the guard immediately releases the idempotency reservation"]
pub struct IdempotencyGuard<'a> {
    /// Cache holding the placeholder for `key`.
    cache: &'a IdempotencyCache,
    /// Key reserved by this guard.
    key: IdempotencyKey,
    /// Set once `complete` has stored the final response; suppresses removal on drop.
    completed: bool,
}
impl<'a> IdempotencyGuard<'a> {
fn new(cache: &'a IdempotencyCache, key: IdempotencyKey) -> Self {
Self {
cache,
key,
completed: false,
}
}
pub fn complete(mut self, response: CachedResponse) {
self.completed = true;
self.cache.set(self.key.clone(), response);
}
}
impl Drop for IdempotencyGuard<'_> {
    /// Releases an unfinished reservation by removing the key from memory.
    ///
    /// NOTE(review): the removal is unconditional. If this guard's placeholder already
    /// expired (TTL) and another request reserved the same key, that request's
    /// placeholder is removed instead.
    fn drop(&mut self) {
        if self.completed {
            return;
        }
        self.cache.remove(&self.key);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Builds a valid key scoped to `POST /v1/tables`, the scope most tests share.
    fn tables_key(key: &str) -> IdempotencyKey {
        IdempotencyKey::new(key, "POST", "/v1/tables").expect("test key should be valid")
    }

    #[test]
    fn test_idempotency_key_validation() {
        assert!(IdempotencyKey::new("abc123", "POST", "/v1/tables").is_some());
        assert!(IdempotencyKey::new("uuid-with-dashes", "POST", "/v1/tables").is_some());
        assert!(IdempotencyKey::new("key_with_underscores", "POST", "/v1/tables").is_some());
        assert!(IdempotencyKey::new("", "POST", "/v1/tables").is_none());
        assert!(IdempotencyKey::new("key with spaces", "POST", "/v1/tables").is_none());
        assert!(IdempotencyKey::new("key@symbol", "POST", "/v1/tables").is_none());
        assert!(IdempotencyKey::new("clé", "POST", "/v1/tables").is_none());
        // Boundary: exactly MAX_KEY_LENGTH is accepted, one more is rejected.
        let max_key = "a".repeat(MAX_KEY_LENGTH);
        assert!(IdempotencyKey::new(&max_key, "POST", "/v1/tables").is_some());
        let long_key = "a".repeat(MAX_KEY_LENGTH + 1);
        assert!(IdempotencyKey::new(&long_key, "POST", "/v1/tables").is_none());
    }

    #[test]
    fn test_idempotency_key_scoping() {
        let key1 = IdempotencyKey::new("same-key", "POST", "/v1/tables").unwrap();
        let key2 = IdempotencyKey::new("same-key", "DELETE", "/v1/tables").unwrap();
        let key3 = IdempotencyKey::new("same-key", "POST", "/v1/namespaces").unwrap();
        assert_ne!(key1, key2);
        assert_ne!(key1, key3);
        assert_ne!(key2, key3);
    }

    #[test]
    fn test_idempotency_key_from_headers() {
        let mut headers = HeaderMap::new();
        headers.insert(
            IDEMPOTENCY_KEY_HEADER,
            HeaderValue::from_static("test-key-123"),
        );
        let key = IdempotencyKey::from_headers(&headers, "POST", "/v1/tables").unwrap();
        assert_eq!(key.value(), "test-key-123");
    }

    #[test]
    fn test_idempotency_key_from_headers_missing() {
        let headers = HeaderMap::new();
        assert!(IdempotencyKey::from_headers(&headers, "POST", "/v1/tables").is_none());
    }

    #[test]
    fn test_cached_response_expiry() {
        let response = CachedResponse::new(
            StatusCode::OK,
            Bytes::from("test"),
            Some("application/json".to_string()),
        );
        assert!(!response.is_expired(Duration::from_secs(60)));
        // Sleep past the TTL rather than relying on a 1ns TTL. Coarse monotonic clocks
        // can report zero elapsed time right after construction.
        thread::sleep(Duration::from_millis(2));
        assert!(response.is_expired(Duration::from_millis(1)));
    }

    #[test]
    fn test_idempotency_cache_basic() {
        let cache = IdempotencyCache::new(Duration::from_secs(60));
        let key = tables_key("test-key");
        assert!(cache.get(&key).is_none());
        let response = CachedResponse::new(
            StatusCode::CREATED,
            Bytes::from(r#"{"result": "ok"}"#),
            Some("application/json".to_string()),
        );
        cache.set(key.clone(), response);
        let cached = cache.get(&key).unwrap();
        assert_eq!(cached.status, StatusCode::CREATED);
    }

    #[test]
    fn test_idempotency_cache_expiry() {
        // Generous margins so a slow CI machine does not expire the entry before the
        // first read.
        let cache = IdempotencyCache::new(Duration::from_millis(50));
        let key = tables_key("test-key");
        let response = CachedResponse::new(StatusCode::OK, Bytes::from("test"), None);
        cache.set(key.clone(), response);
        assert!(cache.get(&key).is_some());
        thread::sleep(Duration::from_millis(100));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_idempotency_cache_cleanup() {
        let cache = IdempotencyCache::new(Duration::from_millis(100));
        for i in 0..5 {
            let key = tables_key(&format!("key-{}", i));
            let response = CachedResponse::new(StatusCode::OK, Bytes::from("test"), None);
            cache.set(key, response);
        }
        cache.cleanup();
        assert!(
            cache.len() >= 4,
            "Expected at least 4 entries, got {}",
            cache.len()
        );
        thread::sleep(Duration::from_millis(150));
        cache.cleanup();
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cached_response_from_json() {
        #[derive(Serialize)]
        struct TestResponse {
            message: String,
        }
        let value = TestResponse {
            message: "success".to_string(),
        };
        let response = CachedResponse::from_json(StatusCode::CREATED, &value).unwrap();
        assert_eq!(response.status, StatusCode::CREATED);
        assert_eq!(response.content_type, Some("application/json".to_string()));
        assert!(std::str::from_utf8(&response.body)
            .unwrap()
            .contains("success"));
    }

    #[test]
    fn test_idempotency_cache_bounded_size() {
        let cache = IdempotencyCache::new(Duration::from_secs(3600));
        let test_size = 1000;
        for i in 0..test_size {
            let key = tables_key(&format!("key-{}", i));
            let response = CachedResponse::new(StatusCode::OK, Bytes::from("test"), None);
            cache.set(key, response);
        }
        cache.cleanup();
        assert!(cache.len() <= test_size);
        assert!(cache.get(&tables_key("key-0")).is_some());
    }

    #[test]
    fn test_idempotency_guard_complete() {
        let cache = IdempotencyCache::new(Duration::from_secs(3600));
        let key = tables_key("guard-ok");
        let guard = cache.try_begin(key.clone()).expect("should acquire guard");
        let response = CachedResponse::new(StatusCode::OK, Bytes::from("done"), None);
        guard.complete(response);
        let cached = cache.get(&key).expect("entry should exist after complete");
        assert_eq!(cached.status, StatusCode::OK);
        assert_eq!(&cached.body[..], b"done");
    }

    #[test]
    fn test_idempotency_guard_drop_without_complete() {
        let cache = IdempotencyCache::new(Duration::from_secs(3600));
        let key = tables_key("guard-drop");
        {
            let _guard = cache.try_begin(key.clone()).expect("should acquire guard");
        }
        assert!(
            cache.get(&key).is_none(),
            "entry should be removed on guard drop"
        );
    }

    #[test]
    fn test_try_begin_rejects_duplicate_while_in_flight() {
        let cache = IdempotencyCache::new(Duration::from_secs(3600));
        let key = tables_key("in-flight");
        let _guard = cache
            .try_begin(key.clone())
            .expect("first caller should acquire guard");
        let placeholder = match cache.try_begin(key) {
            Err(resp) => resp,
            Ok(_) => panic!("second caller must not get a guard while the first is in flight"),
        };
        assert_eq!(placeholder.status, StatusCode::ACCEPTED);
        assert!(placeholder.body.is_empty());
    }

    #[test]
    fn test_try_begin_returns_cached_response() {
        let cache = IdempotencyCache::new(Duration::from_secs(3600));
        let key = tables_key("try-begin");
        let response = CachedResponse::new(StatusCode::CREATED, Bytes::from("existing"), None);
        cache.set(key.clone(), response);
        let cached = match cache.try_begin(key) {
            Err(resp) => resp,
            Ok(_) => panic!("expected Err with cached response"),
        };
        assert_eq!(cached.status, StatusCode::CREATED);
        assert_eq!(&cached.body[..], b"existing");
    }
}