// apfsds_crypto/replay.rs

//! Replay protection caches for nonce and UUID deduplication.

use std::time::{Duration, Instant};

use dashmap::mapref::entry::Entry;
use dashmap::DashMap;

6/// Thread-safe replay cache for nonce/UUID deduplication
7pub struct ReplayCache {
8    /// Map of nonce -> expiration time
9    seen: DashMap<[u8; 32], Instant>,
10    /// TTL for entries
11    ttl: Duration,
12}
13
14impl ReplayCache {
15    /// Create a new replay cache with the given TTL
16    pub fn new(ttl: Duration) -> Self {
17        Self {
18            seen: DashMap::new(),
19            ttl,
20        }
21    }
22
23    /// Check if a nonce has been seen and insert it if not
24    /// Returns true if the nonce is new (not a replay)
25    pub fn check_and_insert(&self, nonce: &[u8; 32]) -> bool {
26        let now = Instant::now();
27        let expiry = now + self.ttl;
28
29        // Check if already exists and not expired
30        if let Some(existing) = self.seen.get(nonce) {
31            if *existing > now {
32                return false; // Replay detected
33            }
34        }
35
36        // Insert or update
37        self.seen.insert(*nonce, expiry);
38        true
39    }
40
41    /// Check if a nonce has been seen (without inserting)
42    pub fn contains(&self, nonce: &[u8; 32]) -> bool {
43        if let Some(expiry) = self.seen.get(nonce) {
44            *expiry > Instant::now()
45        } else {
46            false
47        }
48    }
49
50    /// Remove expired entries
51    pub fn cleanup(&self) {
52        let now = Instant::now();
53        self.seen.retain(|_, expiry| *expiry > now);
54    }
55
56    /// Get the number of entries
57    pub fn len(&self) -> usize {
58        self.seen.len()
59    }
60
61    /// Check if empty
62    pub fn is_empty(&self) -> bool {
63        self.seen.is_empty()
64    }
65
66    /// Clear all entries
67    pub fn clear(&self) {
68        self.seen.clear();
69    }
70}
71
72/// UUID-based replay cache (16-byte keys)
73pub struct UuidReplayCache {
74    seen: DashMap<[u8; 16], Instant>,
75    ttl: Duration,
76}
77
78impl UuidReplayCache {
79    pub fn new(ttl: Duration) -> Self {
80        Self {
81            seen: DashMap::new(),
82            ttl,
83        }
84    }
85
86    pub fn check_and_insert(&self, uuid: &[u8; 16]) -> bool {
87        let now = Instant::now();
88        let expiry = now + self.ttl;
89
90        if let Some(existing) = self.seen.get(uuid) {
91            if *existing > now {
92                return false;
93            }
94        }
95
96        self.seen.insert(*uuid, expiry);
97        true
98    }
99
100    pub fn cleanup(&self) {
101        let now = Instant::now();
102        self.seen.retain(|_, expiry| *expiry > now);
103    }
104
105    pub fn len(&self) -> usize {
106        self.seen.len()
107    }
108
109    pub fn is_empty(&self) -> bool {
110        self.seen.is_empty()
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_replay_detection() {
        let cache = ReplayCache::new(Duration::from_secs(60));
        let nonce = [42u8; 32];

        // First time should succeed
        assert!(cache.check_and_insert(&nonce));

        // Second time should fail (replay)
        assert!(!cache.check_and_insert(&nonce));
    }

    #[test]
    fn test_different_nonces() {
        let cache = ReplayCache::new(Duration::from_secs(60));
        let nonce1 = [1u8; 32];
        let nonce2 = [2u8; 32];

        assert!(cache.check_and_insert(&nonce1));
        assert!(cache.check_and_insert(&nonce2));
    }

    #[test]
    fn test_cleanup() {
        let cache = ReplayCache::new(Duration::from_millis(10));
        let nonce = [42u8; 32];

        assert!(cache.check_and_insert(&nonce));
        assert_eq!(cache.len(), 1);

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(20));

        cache.cleanup();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_expired_nonce_is_accepted_again() {
        let cache = ReplayCache::new(Duration::from_millis(10));
        let nonce = [7u8; 32];

        assert!(cache.check_and_insert(&nonce));
        std::thread::sleep(Duration::from_millis(20));

        // No cleanup() call: a stale entry must not block re-acceptance.
        assert!(!cache.contains(&nonce));
        assert!(cache.check_and_insert(&nonce));
        // ...and the refreshed entry blocks the next replay again.
        assert!(!cache.check_and_insert(&nonce));
    }

    #[test]
    fn test_contains_does_not_insert() {
        let cache = ReplayCache::new(Duration::from_secs(60));
        let nonce = [9u8; 32];

        assert!(!cache.contains(&nonce));
        assert!(cache.is_empty());

        assert!(cache.check_and_insert(&nonce));
        assert!(cache.contains(&nonce));
    }

    #[test]
    fn test_clear() {
        let cache = ReplayCache::new(Duration::from_secs(60));
        let nonce = [3u8; 32];

        assert!(cache.check_and_insert(&nonce));
        cache.clear();
        assert!(cache.is_empty());
        assert!(cache.check_and_insert(&nonce));
    }

    #[test]
    fn test_uuid_cache() {
        let cache = UuidReplayCache::new(Duration::from_secs(60));
        let uuid = [42u8; 16];

        assert!(cache.check_and_insert(&uuid));
        assert!(!cache.check_and_insert(&uuid));
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());
    }
}