network_protocol/utils/
replay_cache.rs1use std::collections::{HashMap, VecDeque};
11use std::hash::Hash;
12use std::time::{Duration, Instant};
13use tracing::{debug, instrument, warn};
14
/// One remembered observation of a `(peer, nonce)` pair.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Monotonic local time at which the entry was (re)inserted; compared against
    /// the cache TTL to decide expiry.
    added_at: Instant,
    /// Timestamp that accompanied the nonce; an identical repeat of the
    /// `(peer, nonce, timestamp)` triple is reported as a replay.
    timestamp: u64,
    /// Copy of the nonce (it is also part of the map key).
    // Never read by the cache logic; it is only observable through the derived
    // `Debug` impl, which the dead-code lint does not count as a use — hence the
    // allowance.
    #[allow(dead_code)]
    nonce: [u8; 16],
}
26
/// Lookup key of the replay cache: the sending peer plus the 16-byte nonce.
///
/// Two keys are equal (and hash identically) exactly when both the peer id and
/// every nonce byte match.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CacheKey {
    /// Identifier of the peer that presented the nonce.
    peer_id: String,
    /// Raw nonce bytes.
    nonce: [u8; 16],
}
35
36#[derive(Debug)]
41pub struct ReplayCache {
42 entries: HashMap<CacheKey, CacheEntry>,
44 insertion_order: VecDeque<CacheKey>,
46 ttl: Duration,
48 max_entries: usize,
50}
51
52impl ReplayCache {
53 pub fn new() -> Self {
58 Self {
59 entries: HashMap::new(),
60 insertion_order: VecDeque::new(),
61 ttl: Duration::from_secs(300),
62 max_entries: 10_000,
63 }
64 }
65
66 pub fn with_settings(ttl: Duration, max_entries: usize) -> Self {
68 Self {
69 entries: HashMap::new(),
70 insertion_order: VecDeque::new(),
71 ttl,
72 max_entries,
73 }
74 }
75
76 #[instrument(skip(self, peer_id, nonce))]
81 pub fn is_replay(&mut self, peer_id: &str, nonce: &[u8; 16], timestamp: u64) -> bool {
82 let key = CacheKey {
83 peer_id: peer_id.to_string(),
84 nonce: *nonce,
85 };
86
87 self.cleanup_expired();
89
90 if let Some(entry) = self.entries.get(&key) {
92 if entry.timestamp == timestamp {
94 warn!(
95 peer_id,
96 ?nonce,
97 timestamp,
98 "Replay attack detected - identical nonce and timestamp"
99 );
100 return true;
101 }
102 debug!(
105 peer_id,
106 ?nonce,
107 "Nonce seen before with different timestamp - allowing"
108 );
109 }
110
111 let entry = CacheEntry {
113 added_at: Instant::now(),
114 timestamp,
115 nonce: *nonce,
116 };
117
118 if self.entries.len() >= self.max_entries {
120 let to_remove = self.entries.len() - self.max_entries + 1;
121 self.remove_oldest_entries(to_remove);
122 }
123
124 self.entries.insert(key.clone(), entry);
125 self.insertion_order.push_back(key);
126 debug!(peer_id, ?nonce, timestamp, "New nonce/timestamp cached");
127
128 false
129 }
130
131 fn cleanup_expired(&mut self) {
133 let now = Instant::now();
134 let initial_count = self.entries.len();
135
136 self.entries
137 .retain(|_, entry| now.duration_since(entry.added_at) < self.ttl);
138
139 while let Some(key) = self.insertion_order.front() {
141 if !self.entries.contains_key(key) {
142 self.insertion_order.pop_front();
143 } else {
144 break;
145 }
146 }
147
148 let removed = initial_count - self.entries.len();
149 if removed > 0 {
150 debug!("Cleaned up {} expired replay cache entries", removed);
151 }
152 }
153
154 #[inline]
157 fn remove_oldest_entries(&mut self, count: usize) {
158 if count == 0 {
159 return;
160 }
161
162 for _ in 0..count {
163 if let Some(key) = self.insertion_order.pop_front() {
164 self.entries.remove(&key);
165 }
166 }
167
168 debug!(
169 "Removed {} oldest replay cache entries due to size limit",
170 count
171 );
172 }
173
174 pub fn stats(&self) -> CacheStats {
176 CacheStats {
177 entries: self.entries.len(),
178 max_entries: self.max_entries,
179 ttl_seconds: self.ttl.as_secs(),
180 }
181 }
182
183 pub fn clear(&mut self) {
185 self.entries.clear();
186 self.insertion_order.clear();
187 debug!("Replay cache cleared");
188 }
189}
190
191impl Default for ReplayCache {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
/// Point-in-time snapshot of a [`ReplayCache`]'s size and configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of entries stored when the snapshot was taken.
    pub entries: usize,
    /// Configured capacity of the cache.
    pub max_entries: usize,
    /// Configured time-to-live, in whole seconds.
    pub ttl_seconds: u64,
}
207
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_replay_detection() {
        let mut cache = ReplayCache::with_settings(Duration::from_secs(60), 100);

        let peer_id = "test_peer";
        let nonce = [1u8; 16];
        let timestamp = 1234567890;

        // The first sighting is accepted and cached...
        assert!(!cache.is_replay(peer_id, &nonce, timestamp));
        // ...and an identical second sighting is a replay.
        assert!(cache.is_replay(peer_id, &nonce, timestamp));
    }

    #[test]
    fn test_different_nonce_allowed() {
        let mut cache = ReplayCache::with_settings(Duration::from_secs(60), 100);

        let peer_id = "test_peer";
        let nonce1 = [1u8; 16];
        let nonce2 = [2u8; 16];
        let timestamp = 1234567890;

        assert!(!cache.is_replay(peer_id, &nonce1, timestamp));
        assert!(!cache.is_replay(peer_id, &nonce2, timestamp));
    }

    #[test]
    fn test_same_nonce_different_timestamp_allowed() {
        let mut cache = ReplayCache::with_settings(Duration::from_secs(60), 100);

        let peer_id = "test_peer";
        let nonce = [1u8; 16];
        let timestamp1 = 1234567890;
        let timestamp2 = 1234567891;

        assert!(!cache.is_replay(peer_id, &nonce, timestamp1));
        assert!(!cache.is_replay(peer_id, &nonce, timestamp2));
    }

    #[test]
    fn test_same_nonce_different_peer_allowed() {
        let mut cache = ReplayCache::with_settings(Duration::from_secs(60), 100);
        let nonce = [1u8; 16];

        // Nonces are tracked per peer, so another peer may use the same value.
        assert!(!cache.is_replay("peer_a", &nonce, 1));
        assert!(!cache.is_replay("peer_b", &nonce, 1));
        assert_eq!(cache.stats().entries, 2);
    }

    #[test]
    fn test_expiration() {
        let mut cache = ReplayCache::with_settings(Duration::from_millis(10), 100);

        let peer_id = "test_peer";
        let nonce = [1u8; 16];
        let timestamp = 1234567890;

        assert!(!cache.is_replay(peer_id, &nonce, timestamp));

        thread::sleep(Duration::from_millis(20));

        // Once the TTL has elapsed the entry is purged, so the repeat is accepted.
        assert!(!cache.is_replay(peer_id, &nonce, timestamp));
    }

    #[test]
    fn test_max_entries_limit() {
        let mut cache = ReplayCache::with_settings(Duration::from_secs(60), 5);

        for i in 0..10u8 {
            let nonce = [i; 16];
            assert!(!cache.is_replay("peer", &nonce, 1000 + u64::from(i)));
        }

        // The cache is full but not over capacity...
        assert_eq!(cache.entries.len(), 5);
        // ...the newest entry is still remembered...
        assert!(cache.is_replay("peer", &[9u8; 16], 1009));
        // ...and the oldest one was evicted, so it is accepted again.
        assert!(!cache.is_replay("peer", &[0u8; 16], 1000));
    }

    #[test]
    fn test_stats_reflect_configuration() {
        let mut cache = ReplayCache::with_settings(Duration::from_secs(60), 100);
        assert!(!cache.is_replay("peer", &[1u8; 16], 1));
        assert!(!cache.is_replay("peer", &[2u8; 16], 1));

        let stats = cache.stats();
        assert_eq!(stats.entries, 2);
        assert_eq!(stats.max_entries, 100);
        assert_eq!(stats.ttl_seconds, 60);
    }

    #[test]
    fn test_default_settings() {
        let stats = ReplayCache::default().stats();
        assert_eq!(stats.entries, 0);
        assert_eq!(stats.max_entries, 10_000);
        assert_eq!(stats.ttl_seconds, 300);
    }

    #[test]
    fn test_clear_forgets_nonces() {
        let mut cache = ReplayCache::with_settings(Duration::from_secs(60), 100);
        let nonce = [3u8; 16];

        assert!(!cache.is_replay("peer", &nonce, 7));
        cache.clear();
        assert_eq!(cache.stats().entries, 0);
        // After clearing, the same observation is treated as new.
        assert!(!cache.is_replay("peer", &nonce, 7));
    }
}