use crate::error::AwsError;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Default lifetime of a cached idempotent result: one day (86 400 s = 24 h).
pub const DEFAULT_TTL: Duration = Duration::from_secs(86_400);
/// Outcome of looking up a client token in an [`IdempotencyCache`].
///
/// `Clone`/`PartialEq`/`Eq` are derived (bounded on `V` only for those impls) so
/// callers can copy and compare outcomes directly instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Lookup<V> {
    /// No live entry for the token: it was never stored, or its entry expired.
    Miss,
    /// The token was stored with the same request fingerprint; carries the stored value.
    Hit(V),
    /// The token was stored with a different request fingerprint.
    Mismatch,
}
/// A cached result together with the fingerprint of the request that produced it.
#[derive(Debug, Clone)]
struct Entry<V> {
    /// Fingerprint of the request parameters (e.g. from `hash_request`).
    request_hash: u64,
    /// The result handed back on a lookup whose fingerprint matches.
    value: V,
    /// When the entry was stored; entries older than the cache TTL count as absent.
    inserted_at: Instant,
}

/// Mutex-guarded map from client token to cached result, with time-based expiry.
///
/// The `V: Clone` requirement lives on the `impl` blocks that clone values rather
/// than on the type itself, so the type can be named without that bound.
#[derive(Debug)]
pub struct IdempotencyCache<V> {
    /// Token → entry map; the mutex lets one cache be shared between threads.
    inner: Mutex<HashMap<String, Entry<V>>>,
    /// How long an entry stays valid after it was inserted.
    ttl: Duration,
}
impl<V: Clone> IdempotencyCache<V> {
pub fn new() -> Self {
Self::with_ttl(DEFAULT_TTL)
}
pub fn with_ttl(ttl: Duration) -> Self {
Self {
inner: Mutex::new(HashMap::new()),
ttl,
}
}
pub fn lookup(&self, token: &str, request_hash: u64) -> Lookup<V> {
let mut g = self.inner.lock().unwrap();
if let Some(entry) = g.get(token) {
if entry.inserted_at.elapsed() > self.ttl {
g.remove(token);
return Lookup::Miss;
}
return if entry.request_hash == request_hash {
Lookup::Hit(entry.value.clone())
} else {
Lookup::Mismatch
};
}
Lookup::Miss
}
pub fn insert(&self, token: impl Into<String>, request_hash: u64, value: V) {
let token = token.into();
let mut g = self.inner.lock().unwrap();
g.insert(
token,
Entry {
request_hash,
value,
inserted_at: Instant::now(),
},
);
}
pub fn lookup_or_insert<F>(
&self,
token: &str,
request_hash: u64,
compute: F,
) -> Result<V, AwsError>
where
F: FnOnce() -> Result<V, AwsError>,
{
match self.lookup(token, request_hash) {
Lookup::Hit(v) => Ok(v),
Lookup::Mismatch => Err(AwsError::bad_request(
"IdempotencyParameterMismatchException",
"Request parameters do not match those used in a prior call with the same ClientToken.",
)),
Lookup::Miss => {
let value = compute()?;
self.insert(token, request_hash, value.clone());
Ok(value)
}
}
}
pub fn sweep(&self) {
let ttl = self.ttl;
let mut g = self.inner.lock().unwrap();
g.retain(|_, e| e.inserted_at.elapsed() <= ttl);
}
pub fn len(&self) -> usize {
self.inner.lock().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl<V: Clone> Default for IdempotencyCache<V> {
fn default() -> Self {
Self::new()
}
}
/// Per-scope caches, keyed by `(account_id, region)` as built in `scope`.
type ScopeMap<V> = HashMap<(String, String), Arc<IdempotencyCache<V>>>;
/// Idempotency caches partitioned by account and region: each `(account_id, region)`
/// pair gets its own [`IdempotencyCache`], so the same client token used in
/// different scopes never collides.
#[derive(Debug)]
pub struct AccountRegionIdempotencyCache<V: Clone> {
    // Scope → shared cache handle; entries are created on first use of a scope.
    inner: Mutex<ScopeMap<V>>,
    // TTL passed to every per-scope cache this instance creates.
    ttl: Duration,
}
impl<V: Clone> AccountRegionIdempotencyCache<V> {
pub fn new() -> Self {
Self::with_ttl(DEFAULT_TTL)
}
pub fn with_ttl(ttl: Duration) -> Self {
Self {
inner: Mutex::new(HashMap::new()),
ttl,
}
}
pub fn scope(&self, account_id: &str, region: &str) -> Arc<IdempotencyCache<V>> {
let key = (account_id.to_string(), region.to_string());
let mut g = self.inner.lock().unwrap();
g.entry(key)
.or_insert_with(|| Arc::new(IdempotencyCache::with_ttl(self.ttl)))
.clone()
}
pub fn sweep(&self) {
let scopes: Vec<Arc<IdempotencyCache<V>>> =
self.inner.lock().unwrap().values().cloned().collect();
for s in scopes {
s.sweep();
}
}
pub fn total_len(&self) -> usize {
let scopes: Vec<Arc<IdempotencyCache<V>>> =
self.inner.lock().unwrap().values().cloned().collect();
scopes.iter().map(|s| s.len()).sum()
}
}
impl<V: Clone> Default for AccountRegionIdempotencyCache<V> {
fn default() -> Self {
Self::new()
}
}
/// Validates a caller-supplied idempotency `ClientToken`.
///
/// A token must be 1-64 characters long and consist solely of printable, non-space
/// ASCII (bytes `0x21..=0x7E`).
///
/// # Errors
///
/// Returns a validation error for the first rule the token breaks: the length rule
/// is checked before the character-set rule.
pub fn validate_token(token: &str) -> Result<(), AwsError> {
    // Count characters rather than UTF-8 bytes so the "1-64 characters" message is
    // accurate. Short non-ASCII tokens fall through to the charset error instead.
    // `nth(64)` stops at the 65th character rather than walking a huge string.
    if token.is_empty() || token.chars().nth(64).is_some() {
        return Err(AwsError::validation(
            "ClientToken must be 1-64 characters long.",
        ));
    }
    if !token.bytes().all(|b| (0x21..=0x7e).contains(&b)) {
        return Err(AwsError::validation(
            "ClientToken must contain only printable ASCII characters.",
        ));
    }
    Ok(())
}
/// Produces a `u64` fingerprint of request parameters, suitable for the
/// `request_hash` arguments of `IdempotencyCache`.
///
/// `H` may be unsized, so `str` and slices can be hashed directly without an extra
/// reference layer. Hashing `&T` yields the same value as hashing `T`.
///
/// Uses `DefaultHasher::new()`, which produces identical hashers within one build of
/// the program. The algorithm is unspecified across Rust releases, so fingerprints
/// must not be persisted. Distinct requests whose 64-bit hashes collide would be
/// treated as the same request.
pub fn hash_request<H: std::hash::Hash + ?Sized>(value: &H) -> u64 {
    use std::hash::Hasher;
    let mut h = std::collections::hash_map::DefaultHasher::new();
    value.hash(&mut h);
    h.finish()
}
// Unit tests for the idempotency cache, its per-account/region wrapper, and the
// token/fingerprint helpers. Expiry tests use millisecond TTLs plus short sleeps.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    // A token that was never inserted is a miss.
    #[test]
    fn miss_on_unseen_token() {
        let cache: IdempotencyCache<String> = IdempotencyCache::new();
        assert!(matches!(cache.lookup("tok-1", 0), Lookup::Miss));
    }

    // Same token and same fingerprint returns the stored value.
    #[test]
    fn hit_on_same_token_and_params() {
        let cache: IdempotencyCache<String> = IdempotencyCache::new();
        cache.insert("tok-1", 42, "first-result".to_string());
        let result = cache.lookup("tok-1", 42);
        assert!(matches!(result, Lookup::Hit(ref v) if v == "first-result"));
    }

    // Same token with a different fingerprint is reported as a mismatch.
    #[test]
    fn mismatch_on_same_token_different_params() {
        let cache: IdempotencyCache<String> = IdempotencyCache::new();
        cache.insert("tok-1", 42, "first".to_string());
        assert!(matches!(cache.lookup("tok-1", 999), Lookup::Mismatch));
    }

    // Sleeping well past the 5 ms TTL makes the entry invisible to lookup.
    #[test]
    fn expired_entry_treated_as_miss() {
        let cache: IdempotencyCache<String> = IdempotencyCache::with_ttl(Duration::from_millis(5));
        cache.insert("tok-1", 1, "x".into());
        sleep(Duration::from_millis(20));
        assert!(matches!(cache.lookup("tok-1", 1), Lookup::Miss));
    }

    // sweep() removes expired entries even if they were never looked up.
    #[test]
    fn sweep_drops_expired_entries() {
        let cache: IdempotencyCache<String> = IdempotencyCache::with_ttl(Duration::from_millis(5));
        cache.insert("a", 1, "x".into());
        cache.insert("b", 2, "y".into());
        sleep(Duration::from_millis(20));
        cache.sweep();
        assert!(cache.is_empty());
    }

    // '!' (0x21) and '~' (0x7e) are the inclusive bounds of the accepted byte range.
    #[test]
    fn validate_token_accepts_printable_ascii() {
        validate_token("abc-123_XYZ").unwrap();
        validate_token("!~").unwrap();
    }

    #[test]
    fn validate_token_rejects_empty() {
        assert!(validate_token("").is_err());
    }

    // 65 one-byte characters is one over the limit.
    #[test]
    fn validate_token_rejects_over_64_chars() {
        let long: String = "a".repeat(65);
        assert!(validate_token(&long).is_err());
    }

    // Tab, newline, and plain space (0x20) all fall outside 0x21..=0x7e.
    #[test]
    fn validate_token_rejects_control_chars() {
        assert!(validate_token("with\tspace").is_err());
        assert!(validate_token("with space").is_err());
        assert!(validate_token("with\ncontrol").is_err());
    }

    // First call computes and caches; the second must be served from the cache.
    // The panicking closure proves compute is not invoked on a hit.
    #[test]
    fn lookup_or_insert_runs_compute_on_miss() {
        let cache: IdempotencyCache<String> = IdempotencyCache::new();
        let result = cache
            .lookup_or_insert("tok", 7, || Ok("computed".to_string()))
            .unwrap();
        assert_eq!(result, "computed");
        let result = cache
            .lookup_or_insert("tok", 7, || panic!("compute must not run on hit"))
            .unwrap();
        assert_eq!(result, "computed");
    }

    #[test]
    fn lookup_or_insert_returns_mismatch_exception() {
        let cache: IdempotencyCache<String> = IdempotencyCache::new();
        cache.insert("tok", 1, "first".into());
        let err = cache
            .lookup_or_insert("tok", 2, || Ok("second".to_string()))
            .unwrap_err();
        assert_eq!(err.code, "IdempotencyParameterMismatchException");
    }

    // A failed compute leaves nothing behind, so a retry with the same token runs again.
    #[test]
    fn lookup_or_insert_does_not_cache_compute_errors() {
        let cache: IdempotencyCache<String> = IdempotencyCache::new();
        let err = cache
            .lookup_or_insert("tok", 1, || Err(AwsError::validation("boom")))
            .unwrap_err();
        assert_eq!(err.code, "ValidationException");
        let result = cache
            .lookup_or_insert("tok", 1, || Ok("ok".to_string()))
            .unwrap();
        assert_eq!(result, "ok");
    }

    // Different accounts in the same region do not share tokens.
    #[test]
    fn account_region_cache_isolates_scopes() {
        let cache: AccountRegionIdempotencyCache<String> = AccountRegionIdempotencyCache::new();
        let a = cache.scope("111111111111", "us-east-1");
        let b = cache.scope("222222222222", "us-east-1");
        a.insert("tok", 1, "alice".to_string());
        assert!(matches!(b.lookup("tok", 1), Lookup::Miss));
        assert!(matches!(a.lookup("tok", 1), Lookup::Hit(ref v) if v == "alice"));
    }

    // The same account in different regions does not share tokens.
    #[test]
    fn account_region_cache_isolates_regions() {
        let cache: AccountRegionIdempotencyCache<String> = AccountRegionIdempotencyCache::new();
        let east = cache.scope("111111111111", "us-east-1");
        let west = cache.scope("111111111111", "us-west-2");
        east.insert("tok", 7, "east-only".to_string());
        assert!(matches!(west.lookup("tok", 7), Lookup::Miss));
    }

    // A second scope() call for the same pair sees writes made through the first handle.
    #[test]
    fn account_region_cache_returns_same_handle_per_scope() {
        let cache: AccountRegionIdempotencyCache<String> = AccountRegionIdempotencyCache::new();
        let first = cache.scope("111111111111", "us-east-1");
        first.insert("tok", 1, "v".into());
        let second = cache.scope("111111111111", "us-east-1");
        assert!(matches!(second.lookup("tok", 1), Lookup::Hit(ref v) if v == "v"));
    }

    // total_len counts stored entries regardless of age, so it reads 2 before sweep()
    // and 0 after the sleep-plus-sweep.
    #[test]
    fn account_region_cache_sweep_clears_every_scope() {
        let cache: AccountRegionIdempotencyCache<String> =
            AccountRegionIdempotencyCache::with_ttl(Duration::from_millis(5));
        cache
            .scope("a", "us-east-1")
            .insert("t1", 1, "x".to_string());
        cache
            .scope("b", "us-west-2")
            .insert("t2", 2, "y".to_string());
        assert_eq!(cache.total_len(), 2);
        sleep(Duration::from_millis(20));
        cache.sweep();
        assert_eq!(cache.total_len(), 0);
    }

    // Equal inputs give equal fingerprints; changing one field changes the fingerprint.
    #[test]
    fn hash_request_stable_across_calls() {
        let a = ("CreateUser", "alice", 42u32);
        assert_eq!(hash_request(&a), hash_request(&a));
        let b = ("CreateUser", "bob", 42u32);
        assert_ne!(hash_request(&a), hash_request(&b));
    }
}