use std::{
sync::Arc,
time::{Duration, Instant},
};
use dashmap::DashMap;
use serde_json::Value;
use xxhash_rust::xxh3::xxh3_64;
/// Outcome of looking up an idempotency key in an [`IdempotencyStore`].
#[derive(Debug)]
pub enum IdempotencyCheck {
    /// No live entry exists for the key. In the in-memory store this also
    /// covers a key whose entry has outlived the TTL.
    New,
    /// A live entry exists and was stored with the same body hash; carries a
    /// copy of the response to send back instead of re-executing the request.
    Replay(StoredResponse),
    /// A live entry exists but was stored with a different body hash, i.e. the
    /// key is being reused for a different request payload.
    Conflict,
}
/// A captured response that can be replayed when the same idempotency key and
/// body are seen again (presumably an HTTP response, judging by the fields).
///
/// The serde derives are compiled only with the `redis-idempotency` feature,
/// presumably so the Redis backend can serialize entries — confirm in
/// `redis_store.rs`.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "redis-idempotency",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct StoredResponse {
    /// Status code of the original response.
    pub status: u16,
    /// Headers as `(name, value)` pairs, kept in the order they were stored.
    pub headers: Vec<(String, String)>,
    /// JSON body of the original response, if it had one.
    pub body: Option<Value>,
}
/// One record held by [`InMemoryIdempotencyStore`].
struct Entry {
    /// The response replayed for a matching repeat request.
    response: StoredResponse,
    /// Hash of the request body that produced `response` (presumably computed
    /// with `hash_body` by the caller). A different hash under the same key
    /// yields `IdempotencyCheck::Conflict`.
    body_hash: u64,
    /// Insertion time; compared against the store TTL for expiry and used to
    /// pick the oldest entry when the store is at capacity.
    created_at: Instant,
}
/// Backend that remembers responses keyed by an idempotency key.
///
/// Methods return boxed futures (rather than using `async fn`) so the trait
/// stays object-safe and can be used as `Arc<dyn IdempotencyStore>`.
///
/// NOTE(review): `check` and `store` are separate calls, so two concurrent
/// requests with the same key can both observe `New` before either stores —
/// confirm callers tolerate or serialize this.
pub trait IdempotencyStore: Send + Sync {
    /// Looks up `key` and classifies it against `body_hash`.
    ///
    /// Resolves to `New` when no live entry exists, `Replay` when the stored
    /// body hash matches, and `Conflict` when it differs.
    fn check(
        &self,
        key: &str,
        body_hash: u64,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = IdempotencyCheck> + Send + '_>>;
    /// Records `response` under `key` together with the hash of the request
    /// body that produced it, replacing any existing entry for `key`.
    fn store(
        &self,
        key: String,
        body_hash: u64,
        response: StoredResponse,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>>;
}
/// Process-local [`IdempotencyStore`] backed by a concurrent [`DashMap`].
///
/// Each instance owns its own map, so entries are not shared across instances
/// or processes. Entries expire after `ttl`, and the map size is kept near
/// `max_entries` by evicting the oldest entry (concurrent stores may briefly
/// overshoot it).
pub struct InMemoryIdempotencyStore {
    /// Idempotency key -> stored record.
    entries: DashMap<String, Entry>,
    /// Lifetime of an entry, measured from its insertion.
    ttl: Duration,
    /// Size at which `store` starts evicting the oldest entry.
    max_entries: usize,
}
impl InMemoryIdempotencyStore {
    /// Creates an empty store.
    ///
    /// * `ttl` — how long a stored response remains replayable; older entries
    ///   are treated as absent and removed lazily.
    /// * `max_entries` — size at which storing evicts the oldest entry first.
    #[must_use]
    pub fn new(ttl: Duration, max_entries: usize) -> Self {
        Self {
            entries: DashMap::new(),
            ttl,
            max_entries,
        }
    }

    /// Removes up to `EVICTION_BATCH` expired entries.
    ///
    /// The batch cap bounds the removal work per call; the scan itself may
    /// still visit every entry when fewer than `EVICTION_BATCH` are expired.
    fn evict_expired(&self) {
        const EVICTION_BATCH: usize = 100;

        // Collect keys first so no shard read lock is held while removing
        // (removing under a live iterator guard on the same shard would deadlock).
        let expired_keys: Vec<String> = self
            .entries
            .iter()
            .filter(|e| e.created_at.elapsed() > self.ttl)
            .take(EVICTION_BATCH)
            .map(|e| e.key().clone())
            .collect();
        for key in expired_keys {
            // Re-check expiry atomically at removal time: between the scan and
            // this call a concurrent `store` may have replaced the key with a
            // fresh entry, which must not be discarded.
            self.entries
                .remove_if(&key, |_, entry| entry.created_at.elapsed() > self.ttl);
        }
    }

    /// Returns the key of the entry with the earliest `created_at`, or `None`
    /// when the store is empty. Linear in the number of entries.
    fn find_oldest_key(&self) -> Option<String> {
        self.entries
            .iter()
            .min_by_key(|e| e.created_at)
            .map(|e| e.key().clone())
    }
}
/// Both methods do their work synchronously and return an already-resolved
/// future (`std::future::ready`).
impl IdempotencyStore for InMemoryIdempotencyStore {
    /// Classifies `key` against `body_hash`.
    ///
    /// An entry older than the TTL counts as absent: it is removed and `New`
    /// is returned.
    fn check(
        &self,
        key: &str,
        body_hash: u64,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = IdempotencyCheck> + Send + '_>> {
        let result = match self.entries.get(key) {
            None => IdempotencyCheck::New,
            Some(entry) if entry.created_at.elapsed() > self.ttl => {
                // Release the shard read guard before taking the write lock
                // below; holding both on the same shard would deadlock.
                drop(entry);
                // Re-check expiry under the write lock so that an entry stored
                // by a concurrent `store` in between is not thrown away.
                self.entries
                    .remove_if(key, |_, e| e.created_at.elapsed() > self.ttl);
                IdempotencyCheck::New
            },
            Some(entry) if entry.body_hash == body_hash => {
                IdempotencyCheck::Replay(entry.response.clone())
            },
            Some(_) => IdempotencyCheck::Conflict,
        };
        Box::pin(std::future::ready(result))
    }

    /// Inserts (or overwrites) the entry for `key`, timestamped now.
    ///
    /// First drops a batch of expired entries; then, if `key` is new and the
    /// store is at capacity, evicts the oldest entry to make room.
    fn store(
        &self,
        key: String,
        body_hash: u64,
        response: StoredResponse,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>> {
        self.evict_expired();
        // Overwriting an existing key does not grow the map, so only make room
        // for genuinely new keys; otherwise a still-valid record would be
        // evicted for nothing.
        let needs_room =
            !self.entries.contains_key(&key) && self.entries.len() >= self.max_entries;
        if needs_room {
            if let Some(oldest_key) = self.find_oldest_key() {
                self.entries.remove(&oldest_key);
            }
        }
        self.entries.insert(
            key,
            Entry {
                response,
                body_hash,
                created_at: Instant::now(),
            },
        );
        Box::pin(std::future::ready(()))
    }
}
#[cfg(feature = "redis-idempotency")]
#[path = "redis_store.rs"]
mod redis_store;
#[cfg(feature = "redis-idempotency")]
pub use redis_store::RedisIdempotencyStore;
/// Computes a 64-bit XXH3 fingerprint of a JSON request body.
///
/// The value is first serialized to compact JSON bytes. If serialization fails
/// (not expected for a `Value`), the hash of the empty byte string is returned.
///
/// NOTE(review): object key order in the bytes follows serde_json's `Map`,
/// which is sorted by default but insertion-ordered with the `preserve_order`
/// feature — confirm which is enabled if semantically equal bodies must match.
#[must_use]
pub fn hash_body(body: &Value) -> u64 {
    let serialized: Vec<u8> = serde_json::to_vec(body).unwrap_or_else(|_| Vec::new());
    xxh3_64(serialized.as_slice())
}
/// Builds the default in-memory idempotency store behind a shared trait object.
///
/// Entries expire `ttl_seconds` after insertion; once 10 000 entries are held,
/// the oldest one is evicted to make room.
#[must_use]
pub fn create_store(ttl_seconds: u64) -> Arc<dyn IdempotencyStore> {
    const DEFAULT_MAX_ENTRIES: usize = 10_000;
    let ttl = Duration::from_secs(ttl_seconds);
    let store = InMemoryIdempotencyStore::new(ttl, DEFAULT_MAX_ENTRIES);
    Arc::new(store)
}
/// Chooses a store backend based on whether a Redis connection is available.
///
/// With a connection manager, returns a Redis-backed store using a TTL of
/// `ttl_seconds`; without one, falls back to [`create_store`].
#[cfg(feature = "redis-idempotency")]
#[must_use]
pub fn create_store_with_redis(
    ttl_seconds: u64,
    redis_pool: Option<redis::aio::ConnectionManager>,
) -> Arc<dyn IdempotencyStore> {
    if let Some(manager) = redis_pool {
        let ttl = Duration::from_secs(ttl_seconds);
        return Arc::new(RedisIdempotencyStore::new(manager, ttl));
    }
    create_store(ttl_seconds)
}
#[cfg(test)]
// `unwrap()` is allowed in tests: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use serde_json::json;
    use super::*;

    /// In-memory store with a TTL of `ttl_secs` seconds and room for 100 entries.
    fn make_store(ttl_secs: u64) -> InMemoryIdempotencyStore {
        InMemoryIdempotencyStore::new(Duration::from_secs(ttl_secs), 100)
    }

    /// A representative response (status 201 with one header and a JSON body).
    fn make_response() -> StoredResponse {
        StoredResponse {
            status: 201,
            headers: vec![("x-request-id".to_string(), "abc".to_string())],
            body: Some(json!({"id": 1, "name": "Alice"})),
        }
    }

    // A key that was never stored is reported as `New`.
    #[tokio::test]
    async fn new_key_returns_new() {
        let store = make_store(3600);
        let body_hash = hash_body(&json!({"name": "Alice"}));
        assert!(matches!(store.check("key1", body_hash).await, IdempotencyCheck::New));
    }

    // Same key + same body hash replays the stored response unchanged.
    #[tokio::test]
    async fn stored_key_replays_response() {
        let store = make_store(3600);
        let body = json!({"name": "Alice"});
        let body_hash = hash_body(&body);
        let response = make_response();
        store.store("key1".to_string(), body_hash, response).await;
        match store.check("key1", body_hash).await {
            IdempotencyCheck::Replay(stored) => {
                assert_eq!(stored.status, 201);
                assert_eq!(stored.body.as_ref().unwrap()["name"], "Alice");
            },
            other => panic!("Expected Replay, got {other:?}"),
        }
    }

    // Reusing a key with a different body is a `Conflict`.
    #[tokio::test]
    async fn same_key_different_body_returns_conflict() {
        let store = make_store(3600);
        let body1 = json!({"name": "Alice"});
        let body2 = json!({"name": "Bob"});
        let hash1 = hash_body(&body1);
        let hash2 = hash_body(&body2);
        store.store("key1".to_string(), hash1, make_response()).await;
        assert!(matches!(store.check("key1", hash2).await, IdempotencyCheck::Conflict));
    }

    // Timing-based: sleeping 5 ms past a 1 ms TTL guarantees the entry has expired.
    #[tokio::test]
    async fn expired_key_treated_as_new() {
        let store = InMemoryIdempotencyStore::new(Duration::from_millis(1), 100);
        let body = json!({"name": "Alice"});
        let body_hash = hash_body(&body);
        store.store("key1".to_string(), body_hash, make_response()).await;
        tokio::time::sleep(Duration::from_millis(5)).await;
        assert!(matches!(store.check("key1", body_hash).await, IdempotencyCheck::New));
    }

    // With capacity 3, storing a 4th key evicts the oldest (key1) and keeps key2.
    // The 1 ms sleeps give each entry a strictly later `created_at`, so
    // "oldest" is well defined.
    #[tokio::test]
    async fn max_entries_evicts_oldest() {
        let store = InMemoryIdempotencyStore::new(Duration::from_secs(3600), 3);
        let hash = hash_body(&json!({}));
        store.store("key1".to_string(), hash, make_response()).await;
        tokio::time::sleep(Duration::from_millis(1)).await;
        store.store("key2".to_string(), hash, make_response()).await;
        tokio::time::sleep(Duration::from_millis(1)).await;
        store.store("key3".to_string(), hash, make_response()).await;
        tokio::time::sleep(Duration::from_millis(1)).await;
        store.store("key4".to_string(), hash, make_response()).await;
        assert!(matches!(store.check("key1", hash).await, IdempotencyCheck::New));
        assert!(matches!(store.check("key2", hash).await, IdempotencyCheck::Replay(_)));
    }

    // Hashing the same value twice yields the same fingerprint.
    #[test]
    fn body_hash_deterministic() {
        let body = json!({"name": "Alice", "age": 30});
        let hash1 = hash_body(&body);
        let hash2 = hash_body(&body);
        assert_eq!(hash1, hash2);
    }

    // Different bodies produce different fingerprints (for this sample pair).
    #[test]
    fn body_hash_different_for_different_bodies() {
        let hash1 = hash_body(&json!({"name": "Alice"}));
        let hash2 = hash_body(&json!({"name": "Bob"}));
        assert_ne!(hash1, hash2);
    }

    // The factory returns a working, empty store behind `Arc<dyn IdempotencyStore>`.
    #[tokio::test]
    async fn create_store_returns_arc() {
        let store = create_store(3600);
        let body_hash = hash_body(&json!({}));
        assert!(matches!(store.check("key1", body_hash).await, IdempotencyCheck::New));
    }
}