1use std::collections::{HashMap, VecDeque};
7use std::time::{Duration, Instant};
8
9use crate::nat_traversal_api::PeerId;
10
/// A single queued message together with the moment it was buffered.
#[derive(Debug)]
struct PendingEntry {
    /// Raw message payload.
    data: Vec<u8>,
    /// When the entry was queued; compared against the buffer TTL by `cleanup_expired`.
    created_at: Instant,
}
17
/// FIFO queue of messages buffered for a single peer, plus a running byte total.
#[derive(Debug, Default)]
struct PeerPendingData {
    /// Oldest message at the front, newest at the back.
    entries: VecDeque<PendingEntry>,
    /// Sum of payload lengths in `entries`, updated on every insert/removal so the
    /// byte-limit check does not have to walk the queue.
    total_bytes: usize,
}
24
/// Point-in-time snapshot of buffer occupancy plus lifetime counters, as produced by
/// [`BoundedPendingBuffer::stats`].
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly (e.g. in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PendingBufferStats {
    /// Number of peers that currently have a queue in the buffer.
    pub total_peers: usize,
    /// Number of messages currently queued, summed over all peers.
    pub total_messages: usize,
    /// Payload bytes currently queued, summed over all peers.
    pub total_bytes: usize,
    /// Messages discarded by `push` to stay within per-peer limits since the buffer was created.
    pub dropped_messages: u64,
    /// Messages removed by `cleanup_expired` for outliving the TTL since the buffer was created.
    pub expired_messages: u64,
}
39
/// Per-peer message buffer bounded by total bytes, message count and entry age.
///
/// `push` evicts a peer's oldest messages to make room. `cleanup_expired` removes entries
/// older than the TTL. Both kinds of removal are tallied in counters exposed through `stats`.
#[derive(Debug)]
pub struct BoundedPendingBuffer {
    /// Queued messages keyed by peer; a peer's slot is removed once its queue empties.
    data: HashMap<PeerId, PeerPendingData>,
    /// Per-peer cap on summed payload bytes, enforced by `push`.
    max_bytes_per_peer: usize,
    /// Per-peer cap on the number of queued messages, enforced by `push`.
    max_messages_per_peer: usize,
    /// Age at which `cleanup_expired` discards an entry.
    ttl: Duration,
    /// Lifetime count of messages discarded by `push` to respect the caps.
    dropped_messages: u64,
    /// Lifetime count of messages removed by `cleanup_expired`.
    expired_messages: u64,
}
50
51impl BoundedPendingBuffer {
52 pub fn new(max_bytes_per_peer: usize, max_messages_per_peer: usize, ttl: Duration) -> Self {
54 Self {
55 data: HashMap::new(),
56 max_bytes_per_peer,
57 max_messages_per_peer,
58 ttl,
59 dropped_messages: 0,
60 expired_messages: 0,
61 }
62 }
63
64 pub fn push(&mut self, peer_id: &PeerId, data: Vec<u8>) -> Result<(), PendingBufferError> {
66 let data_len = data.len();
67
68 if data_len > self.max_bytes_per_peer {
70 return Err(PendingBufferError::MessageTooLarge {
71 size: data_len,
72 max: self.max_bytes_per_peer,
73 });
74 }
75
76 let peer_data = self.data.entry(*peer_id).or_default();
77
78 while peer_data.total_bytes + data_len > self.max_bytes_per_peer
80 || peer_data.entries.len() >= self.max_messages_per_peer
81 {
82 if let Some(dropped) = peer_data.entries.pop_front() {
83 peer_data.total_bytes = peer_data.total_bytes.saturating_sub(dropped.data.len());
84 self.dropped_messages += 1;
85 } else {
86 break;
87 }
88 }
89
90 peer_data.entries.push_back(PendingEntry {
92 data,
93 created_at: Instant::now(),
94 });
95 peer_data.total_bytes += data_len;
96
97 Ok(())
98 }
99
100 pub fn pop(&mut self, peer_id: &PeerId) -> Option<Vec<u8>> {
102 let peer_data = self.data.get_mut(peer_id)?;
103 let entry = peer_data.entries.pop_front()?;
104 peer_data.total_bytes = peer_data.total_bytes.saturating_sub(entry.data.len());
105
106 if peer_data.entries.is_empty() {
108 self.data.remove(peer_id);
109 }
110
111 Some(entry.data)
112 }
113
114 pub fn pop_any(&mut self) -> Option<(PeerId, Vec<u8>)> {
116 let peer_id = *self.data.keys().next()?;
118 let data = self.pop(&peer_id)?;
119 Some((peer_id, data))
120 }
121
122 pub fn peek_oldest(&self, peer_id: &PeerId) -> Option<&[u8]> {
124 self.data
125 .get(peer_id)?
126 .entries
127 .front()
128 .map(|e| e.data.as_slice())
129 }
130
131 pub fn message_count(&self, peer_id: &PeerId) -> usize {
133 self.data.get(peer_id).map(|d| d.entries.len()).unwrap_or(0)
134 }
135
136 pub fn total_bytes(&self, peer_id: &PeerId) -> usize {
138 self.data.get(peer_id).map(|d| d.total_bytes).unwrap_or(0)
139 }
140
141 pub fn clear_peer(&mut self, peer_id: &PeerId) {
143 self.data.remove(peer_id);
144 }
145
146 pub fn is_empty(&self) -> bool {
148 self.data.is_empty()
149 }
150
151 pub fn cleanup_expired(&mut self) {
153 let now = Instant::now();
154 let ttl = self.ttl;
155
156 self.data.retain(|_, peer_data| {
157 let before_len = peer_data.entries.len();
158
159 peer_data.entries.retain(|entry| {
160 let is_valid = now.duration_since(entry.created_at) < ttl;
161 if !is_valid {
162 peer_data.total_bytes = peer_data.total_bytes.saturating_sub(entry.data.len());
163 }
164 is_valid
165 });
166
167 let expired_count = before_len - peer_data.entries.len();
168 self.expired_messages += expired_count as u64;
169
170 !peer_data.entries.is_empty()
171 });
172 }
173
174 pub fn stats(&self) -> PendingBufferStats {
176 PendingBufferStats {
177 total_peers: self.data.len(),
178 total_messages: self.data.values().map(|d| d.entries.len()).sum(),
179 total_bytes: self.data.values().map(|d| d.total_bytes).sum(),
180 dropped_messages: self.dropped_messages,
181 expired_messages: self.expired_messages,
182 }
183 }
184
185 pub fn iter_peers(&self) -> impl Iterator<Item = &PeerId> {
187 self.data.keys()
188 }
189}
190
191impl Default for BoundedPendingBuffer {
192 fn default() -> Self {
193 Self::new(
194 1024 * 1024, 100, Duration::from_secs(30),
197 )
198 }
199}
200
/// Errors returned by `BoundedPendingBuffer::push`.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PendingBufferError {
    /// The message alone exceeds the per-peer byte budget, so it can never be buffered.
    MessageTooLarge {
        /// Size of the rejected message in bytes.
        size: usize,
        /// Configured per-peer byte limit.
        max: usize,
    },
}
212
213impl std::fmt::Display for PendingBufferError {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 match self {
216 Self::MessageTooLarge { size, max } => {
217 write!(
218 f,
219 "Message too large: {} bytes exceeds max {} bytes",
220 size, max
221 )
222 }
223 }
224 }
225}
226
227impl std::error::Error for PendingBufferError {}
228
#[cfg(test)]
mod tests {
    use super::*;

    const MAX_PENDING_BYTES_PER_PEER: usize = 1024 * 1024;
    const MAX_PENDING_MESSAGES_PER_PEER: usize = 100;
    const PENDING_DATA_TTL: Duration = Duration::from_secs(30);

    /// Returns a peer id that is unique within this test process.
    ///
    /// The first 16 bytes hold the full 128-bit wall-clock timestamp. The next 8 bytes hold
    /// a process-wide counter, so two calls landing in the same clock tick still yield
    /// distinct ids.
    fn random_peer_id() -> PeerId {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::time::SystemTime;

        static COUNTER: AtomicU64 = AtomicU64::new(0);

        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let unique = COUNTER.fetch_add(1, Ordering::SeqCst);

        let mut bytes = [0u8; 32];
        bytes[..16].copy_from_slice(&nanos.to_le_bytes());
        bytes[16..24].copy_from_slice(&unique.to_le_bytes());
        PeerId(bytes)
    }

    #[test]
    fn test_pending_buffer_enforces_byte_limit() {
        let mut buffer = BoundedPendingBuffer::new(
            MAX_PENDING_BYTES_PER_PEER,
            MAX_PENDING_MESSAGES_PER_PEER,
            PENDING_DATA_TTL,
        );

        let peer_id = random_peer_id();

        // Two half-budget messages exactly fill the byte budget.
        let large_data = vec![0u8; MAX_PENDING_BYTES_PER_PEER / 2];
        assert!(buffer.push(&peer_id, large_data.clone()).is_ok());
        assert!(buffer.push(&peer_id, large_data.clone()).is_ok());

        // Any further byte must evict the oldest message.
        let result = buffer.push(&peer_id, vec![0u8; 100]);
        assert!(result.is_ok());

        assert!(buffer.total_bytes(&peer_id) <= MAX_PENDING_BYTES_PER_PEER);
        assert_eq!(buffer.message_count(&peer_id), 2);
        assert_eq!(buffer.stats().dropped_messages, 1);
    }

    #[test]
    fn test_pending_buffer_enforces_message_limit() {
        let mut buffer = BoundedPendingBuffer::new(MAX_PENDING_BYTES_PER_PEER, 10, PENDING_DATA_TTL);

        let peer_id = random_peer_id();

        for i in 0..10 {
            assert!(buffer.push(&peer_id, vec![i as u8]).is_ok());
        }

        buffer
            .push(&peer_id, vec![10u8])
            .expect("push should succeed");
        assert_eq!(buffer.message_count(&peer_id), 10);
        assert_eq!(buffer.stats().dropped_messages, 1);

        // Message 0 was evicted, so message 1 is now the oldest.
        let first = buffer.peek_oldest(&peer_id).expect("should have data");
        assert_eq!(first[0], 1u8);
    }

    #[tokio::test]
    async fn test_pending_buffer_expires_old_entries() {
        let mut buffer = BoundedPendingBuffer::new(
            MAX_PENDING_BYTES_PER_PEER,
            MAX_PENDING_MESSAGES_PER_PEER,
            Duration::from_millis(50),
        );

        let peer_id = random_peer_id();
        buffer
            .push(&peer_id, vec![1, 2, 3])
            .expect("push should succeed");

        assert_eq!(buffer.message_count(&peer_id), 1);

        // Sleep past the 50 ms TTL.
        tokio::time::sleep(Duration::from_millis(100)).await;

        buffer.cleanup_expired();
        assert_eq!(buffer.message_count(&peer_id), 0);
        assert!(buffer.is_empty());
        assert_eq!(buffer.stats().expired_messages, 1);
    }

    #[test]
    fn test_pending_buffer_cleanup_keeps_fresh_entries() {
        let mut buffer = BoundedPendingBuffer::new(
            MAX_PENDING_BYTES_PER_PEER,
            MAX_PENDING_MESSAGES_PER_PEER,
            PENDING_DATA_TTL,
        );

        let peer_id = random_peer_id();
        buffer
            .push(&peer_id, vec![1, 2, 3])
            .expect("push should succeed");

        buffer.cleanup_expired();
        assert_eq!(buffer.message_count(&peer_id), 1);
        assert_eq!(buffer.total_bytes(&peer_id), 3);
        assert_eq!(buffer.stats().expired_messages, 0);
    }

    #[test]
    fn test_pending_buffer_pop_returns_oldest_first() {
        let mut buffer = BoundedPendingBuffer::new(
            MAX_PENDING_BYTES_PER_PEER,
            MAX_PENDING_MESSAGES_PER_PEER,
            PENDING_DATA_TTL,
        );

        let peer_id = random_peer_id();
        buffer.push(&peer_id, vec![1]).expect("push should succeed");
        buffer.push(&peer_id, vec![2]).expect("push should succeed");
        buffer.push(&peer_id, vec![3]).expect("push should succeed");

        assert_eq!(buffer.pop(&peer_id), Some(vec![1]));
        assert_eq!(buffer.pop(&peer_id), Some(vec![2]));
        assert_eq!(buffer.pop(&peer_id), Some(vec![3]));
        assert_eq!(buffer.pop(&peer_id), None);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_pending_buffer_clear_peer() {
        let mut buffer = BoundedPendingBuffer::new(
            MAX_PENDING_BYTES_PER_PEER,
            MAX_PENDING_MESSAGES_PER_PEER,
            PENDING_DATA_TTL,
        );

        let peer_id = random_peer_id();
        buffer
            .push(&peer_id, vec![1, 2, 3])
            .expect("push should succeed");
        buffer
            .push(&peer_id, vec![4, 5, 6])
            .expect("push should succeed");

        buffer.clear_peer(&peer_id);
        assert_eq!(buffer.message_count(&peer_id), 0);
        assert_eq!(buffer.total_bytes(&peer_id), 0);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_pending_buffer_stats() {
        let mut buffer = BoundedPendingBuffer::new(
            MAX_PENDING_BYTES_PER_PEER,
            MAX_PENDING_MESSAGES_PER_PEER,
            PENDING_DATA_TTL,
        );

        let peer1 = PeerId([1u8; 32]);
        let peer2 = PeerId([2u8; 32]);

        buffer
            .push(&peer1, vec![1, 2, 3])
            .expect("push should succeed");
        buffer
            .push(&peer2, vec![4, 5])
            .expect("push should succeed");

        let stats = buffer.stats();
        assert_eq!(stats.total_peers, 2);
        assert_eq!(stats.total_messages, 2);
        assert_eq!(stats.total_bytes, 5);
        assert_eq!(stats.dropped_messages, 0);
        assert_eq!(stats.expired_messages, 0);
    }

    #[test]
    fn test_pending_buffer_pop_any() {
        let mut buffer = BoundedPendingBuffer::new(
            MAX_PENDING_BYTES_PER_PEER,
            MAX_PENDING_MESSAGES_PER_PEER,
            PENDING_DATA_TTL,
        );

        let peer1 = PeerId([1u8; 32]);
        buffer
            .push(&peer1, vec![1, 2, 3])
            .expect("push should succeed");

        let (peer_id, data) = buffer.pop_any().expect("buffer should not be empty");
        assert_eq!(peer_id, peer1);
        assert_eq!(data, vec![1, 2, 3]);

        // Draining the last message releases the peer slot.
        assert!(buffer.is_empty());
        assert!(buffer.pop_any().is_none());
    }

    #[test]
    fn test_pending_buffer_rejects_too_large_message() {
        let mut buffer =
            BoundedPendingBuffer::new(1000, MAX_PENDING_MESSAGES_PER_PEER, PENDING_DATA_TTL);

        let peer_id = random_peer_id();

        let result = buffer.push(&peer_id, vec![0u8; 2000]);
        assert!(matches!(
            result,
            Err(PendingBufferError::MessageTooLarge {
                size: 2000,
                max: 1000
            })
        ));
        // A rejected push must not leave an empty slot behind.
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_pending_buffer_dropped_count() {
        let mut buffer = BoundedPendingBuffer::new(MAX_PENDING_BYTES_PER_PEER, 5, PENDING_DATA_TTL);

        let peer_id = random_peer_id();

        for i in 0..5 {
            buffer.push(&peer_id, vec![i]).expect("push should succeed");
        }

        // Each push beyond the cap evicts exactly one older message.
        for i in 5..8 {
            buffer.push(&peer_id, vec![i]).expect("push should succeed");
        }

        let stats = buffer.stats();
        assert_eq!(stats.dropped_messages, 3);
        assert_eq!(stats.total_messages, 5);
    }

    #[test]
    fn test_pending_buffer_default() {
        let buffer = BoundedPendingBuffer::default();
        assert!(buffer.is_empty());
        let stats = buffer.stats();
        assert_eq!(stats.total_peers, 0);
        assert_eq!(stats.total_messages, 0);
    }
}