1use iroh::EndpointId;
26use std::collections::{HashMap, VecDeque};
27use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
28use std::sync::{Arc, RwLock};
29use std::time::{Duration, Instant};
30use thiserror::Error;
31
/// Errors returned when the flow-control layer rejects an operation.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum FlowControlError {
    /// The peer's token bucket is empty; the message was rejected.
    #[error("Rate limit exceeded for peer")]
    RateLimitExceeded,
    /// The peer's bounded message queue is at capacity.
    #[error("Message queue full for peer (max {max_size} messages)")]
    QueueFull { max_size: usize },
    /// A sync for this (peer, document) pair happened too recently;
    /// `remaining_ms` is how long the caller must still wait.
    #[error("Sync cooldown active, {remaining_ms}ms remaining")]
    CooldownActive { remaining_ms: u64 },
}
45
/// Tunable limits applied to each peer by the flow controller.
#[derive(Debug, Clone)]
pub struct FlowControlConfig {
    /// Token-bucket capacity: the peak message burst allowed per peer.
    pub max_messages_per_second: u32,
    /// Tokens restored on each refill tick.
    pub tokens_per_refill: u32,
    /// Interval between refill ticks.
    pub refill_interval: Duration,
    /// Hard cap on queued messages per peer.
    pub max_queue_size: usize,
    /// Minimum spacing between syncs of the same (peer, document) pair.
    pub sync_cooldown: Duration,
    /// Per-peer memory budget in bytes.
    pub max_memory_per_peer: usize,
}

impl Default for FlowControlConfig {
    /// Conservative defaults: 100-message bursts refilled at 10 tokens
    /// every 100ms, 1000-message queues, 100ms sync cooldown, 10 MiB of
    /// memory per peer.
    fn default() -> Self {
        FlowControlConfig {
            max_messages_per_second: 100,
            tokens_per_refill: 10,
            refill_interval: Duration::from_millis(100),
            max_queue_size: 1000,
            sync_cooldown: Duration::from_millis(100),
            // 10 MiB per peer.
            max_memory_per_peer: 10 * 1024 * 1024,
        }
    }
}
75
/// A mostly lock-free token bucket used to rate-limit per-peer traffic.
///
/// Tokens are consumed by [`TokenBucket::try_acquire`] and restored in
/// whole `refill_interval` periods by a lazy refill performed on access —
/// no background task is required.
#[derive(Debug)]
pub struct TokenBucket {
    /// Maximum number of tokens the bucket can hold (burst size).
    capacity: u32,
    /// Currently available tokens.
    tokens: AtomicU32,
    /// Tokens restored per elapsed refill period.
    tokens_per_refill: u32,
    /// Length of one refill period.
    refill_interval: Duration,
    /// Instant up to which refills have been accounted; behind a lock
    /// because `Instant` cannot be stored atomically.
    last_refill: RwLock<Instant>,
}

impl TokenBucket {
    /// Creates a bucket that starts full at `capacity` tokens.
    pub fn new(capacity: u32, tokens_per_refill: u32, refill_interval: Duration) -> Self {
        Self {
            capacity,
            tokens: AtomicU32::new(capacity),
            tokens_per_refill,
            refill_interval,
            last_refill: RwLock::new(Instant::now()),
        }
    }

    /// Attempts to consume one token, returning `false` when empty.
    pub fn try_acquire(&self) -> bool {
        self.refill();

        // CAS loop: decrement only if a token is still available.
        loop {
            let current = self.tokens.load(Ordering::Acquire);
            if current == 0 {
                return false;
            }
            if self
                .tokens
                .compare_exchange(current, current - 1, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
            {
                return true;
            }
        }
    }

    /// Returns the number of tokens currently available (after refilling).
    pub fn available_tokens(&self) -> u32 {
        self.refill();
        self.tokens.load(Ordering::Acquire)
    }

    /// Credits tokens for every whole refill period elapsed since the
    /// last accounted refill, clamped to `capacity`.
    fn refill(&self) {
        let now = Instant::now();
        // Recover from a poisoned lock instead of propagating the panic.
        let mut last = self.last_refill.write().unwrap_or_else(|e| e.into_inner());

        // A zero interval means "no throttling between refills": keep the
        // bucket topped up. This also avoids the division by zero below.
        if self.refill_interval.is_zero() {
            self.tokens.store(self.capacity, Ordering::Release);
            *last = now;
            return;
        }

        let elapsed = now.duration_since(*last);
        if elapsed < self.refill_interval {
            return;
        }

        // `max(1)` guards sub-millisecond intervals, whose `as_millis()`
        // rounds down to zero and previously caused a division by zero.
        let interval_ms = self.refill_interval.as_millis().max(1);
        let periods = elapsed.as_millis() / interval_ms;
        if periods == 0 {
            return;
        }

        // Work in u128 so arbitrarily long idle times cannot overflow the
        // intermediate product; the credit is clamped to `capacity` anyway.
        let tokens_to_add = periods
            .saturating_mul(self.tokens_per_refill as u128)
            .min(self.capacity as u128) as u32;

        loop {
            let current = self.tokens.load(Ordering::Acquire);
            // `saturating_add` prevents a debug-mode overflow panic.
            let new_tokens = current.saturating_add(tokens_to_add).min(self.capacity);
            if self
                .tokens
                .compare_exchange(current, new_tokens, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
            {
                break;
            }
        }

        // Advance the refill clock by whole periods only, preserving the
        // fractional remainder. The previous `*last = now` discarded up to
        // one interval of accumulated time on every refill, slowly
        // under-crediting the bucket under steady traffic.
        *last = now - Duration::from_millis((elapsed.as_millis() % interval_ms) as u64);
    }
}
173
/// A FIFO queue with a hard size cap and drop accounting.
///
/// When full, [`BoundedQueue::enqueue`] evicts the oldest entry while
/// [`BoundedQueue::try_enqueue`] rejects the new one; both strategies
/// keep memory bounded.
#[derive(Debug)]
pub struct BoundedQueue<T> {
    /// Backing FIFO storage.
    queue: VecDeque<T>,
    /// Hard cap on the number of queued items.
    max_size: usize,
    /// Lifetime count of accepted enqueues.
    total_enqueued: u64,
    /// Lifetime count of items evicted due to overflow.
    total_dropped: u64,
}

impl<T> BoundedQueue<T> {
    /// Creates an empty queue capped at `max_size` items.
    pub fn new(max_size: usize) -> Self {
        // Cap the up-front reservation so a huge limit does not
        // preallocate a huge buffer.
        let reserve = max_size.min(1000);
        Self {
            queue: VecDeque::with_capacity(reserve),
            max_size,
            total_enqueued: 0,
            total_dropped: 0,
        }
    }

    /// Appends `item`; when the queue is full, evicts and returns the
    /// oldest element instead of growing.
    pub fn enqueue(&mut self, item: T) -> Option<T> {
        self.total_enqueued += 1;

        let mut evicted = None;
        if self.queue.len() >= self.max_size {
            self.total_dropped += 1;
            evicted = self.queue.pop_front();
        }
        self.queue.push_back(item);
        evicted
    }

    /// Appends `item` only if there is room, otherwise hands it back.
    pub fn try_enqueue(&mut self, item: T) -> Result<(), T> {
        if self.queue.len() < self.max_size {
            self.total_enqueued += 1;
            self.queue.push_back(item);
            Ok(())
        } else {
            Err(item)
        }
    }

    /// Removes and returns the oldest element, if any.
    pub fn dequeue(&mut self) -> Option<T> {
        self.queue.pop_front()
    }

    /// Borrows the oldest element without removing it.
    pub fn peek(&self) -> Option<&T> {
        self.queue.front()
    }

    /// Current number of queued items.
    pub fn len(&self) -> usize {
        self.queue.len()
    }

    /// Returns `true` when no items are queued.
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    /// Lifetime number of accepted enqueues.
    pub fn total_enqueued(&self) -> u64 {
        self.total_enqueued
    }

    /// Lifetime number of overflow evictions.
    pub fn total_dropped(&self) -> u64 {
        self.total_dropped
    }

    /// Discards all queued items; lifetime counters are retained.
    pub fn clear(&mut self) {
        self.queue.clear();
    }
}
265
/// Enforces a minimum spacing between repeated syncs of the same
/// (peer, document) pair.
#[derive(Debug)]
pub struct SyncCooldownTracker {
    /// Instant of the last recorded sync, keyed by (peer, document key).
    last_sync: HashMap<(EndpointId, String), Instant>,
    /// Minimum time that must elapse between syncs of the same pair.
    cooldown: Duration,
    /// Number of sync attempts rejected because the cooldown was active.
    blocked_count: u64,
}
279
280impl SyncCooldownTracker {
281 pub fn new(cooldown: Duration) -> Self {
283 Self {
284 last_sync: HashMap::new(),
285 cooldown,
286 blocked_count: 0,
287 }
288 }
289
290 pub fn check_cooldown(
294 &mut self,
295 peer_id: &EndpointId,
296 doc_key: &str,
297 ) -> Result<(), FlowControlError> {
298 let key = (*peer_id, doc_key.to_string());
299 let now = Instant::now();
300
301 if let Some(last) = self.last_sync.get(&key) {
302 let elapsed = now.duration_since(*last);
303 if elapsed < self.cooldown {
304 self.blocked_count += 1;
305 let remaining = self.cooldown - elapsed;
306 return Err(FlowControlError::CooldownActive {
307 remaining_ms: remaining.as_millis() as u64,
308 });
309 }
310 }
311
312 Ok(())
313 }
314
315 pub fn record_sync(&mut self, peer_id: &EndpointId, doc_key: &str) {
317 let key = (*peer_id, doc_key.to_string());
318 self.last_sync.insert(key, Instant::now());
319 }
320
321 pub fn blocked_count(&self) -> u64 {
323 self.blocked_count
324 }
325
326 pub fn cleanup(&mut self) {
328 let now = Instant::now();
329 let threshold = self.cooldown * 10;
330 self.last_sync
331 .retain(|_, last| now.duration_since(*last) < threshold);
332 }
333}
334
/// Thread-safe per-peer resource accounting: a byte budget plus
/// message counters. All operations use atomics, so a single tracker can
/// be shared freely across threads.
#[derive(Debug)]
pub struct PeerResourceTracker {
    /// Bytes currently allocated on behalf of this peer.
    memory_usage: AtomicU64,
    /// Hard memory budget in bytes.
    max_memory: u64,
    /// Count of messages sent to this peer.
    messages_sent: AtomicU64,
    /// Count of messages received from this peer.
    messages_received: AtomicU64,
    /// Count of messages dropped for this peer.
    messages_dropped: AtomicU64,
}

impl PeerResourceTracker {
    /// Creates a tracker with a budget of `max_memory` bytes and all
    /// counters at zero.
    pub fn new(max_memory: u64) -> Self {
        Self {
            memory_usage: AtomicU64::new(0),
            max_memory,
            messages_sent: AtomicU64::new(0),
            messages_received: AtomicU64::new(0),
            messages_dropped: AtomicU64::new(0),
        }
    }

    /// Attempts to reserve `bytes`, returning `false` when the budget
    /// would be exceeded. Uses `checked_add` so an absurdly large request
    /// cannot wrap the counter and slip past the limit.
    pub fn try_allocate(&self, bytes: u64) -> bool {
        loop {
            let current = self.memory_usage.load(Ordering::Acquire);
            let new_usage = match current.checked_add(bytes) {
                Some(total) if total <= self.max_memory => total,
                // Overflow or over-budget: reject.
                _ => return false,
            };
            if self
                .memory_usage
                .compare_exchange(current, new_usage, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
            {
                return true;
            }
        }
    }

    /// Releases `bytes` from the usage counter, saturating at zero so a
    /// mismatched free cannot underflow to a huge value.
    pub fn free(&self, bytes: u64) {
        // `fetch_update` retries the CAS internally; the closure always
        // returns `Some`, so the result can be ignored.
        let _ = self
            .memory_usage
            .fetch_update(Ordering::Release, Ordering::Acquire, |current| {
                Some(current.saturating_sub(bytes))
            });
    }

    /// Bytes currently allocated for this peer.
    pub fn memory_usage(&self) -> u64 {
        self.memory_usage.load(Ordering::Acquire)
    }

    /// Increments the sent-message counter.
    pub fn record_sent(&self) {
        self.messages_sent.fetch_add(1, Ordering::Relaxed);
    }

    /// Increments the received-message counter.
    pub fn record_received(&self) {
        self.messages_received.fetch_add(1, Ordering::Relaxed);
    }

    /// Increments the dropped-message counter.
    pub fn record_dropped(&self) {
        self.messages_dropped.fetch_add(1, Ordering::Relaxed);
    }

    /// Total messages sent to this peer.
    pub fn messages_sent(&self) -> u64 {
        self.messages_sent.load(Ordering::Relaxed)
    }

    /// Total messages received from this peer.
    pub fn messages_received(&self) -> u64 {
        self.messages_received.load(Ordering::Relaxed)
    }

    /// Total messages dropped for this peer.
    pub fn messages_dropped(&self) -> u64 {
        self.messages_dropped.load(Ordering::Relaxed)
    }
}
420
/// Point-in-time snapshot of flow-control counters across all peers.
#[derive(Debug, Clone, Default)]
pub struct FlowControlStats {
    /// Messages rejected because a peer's token bucket was empty.
    pub rate_limited: u64,
    /// Messages dropped, summed across all peer trackers.
    pub queue_dropped: u64,
    /// Sync attempts blocked by the per-document cooldown.
    pub cooldown_blocked: u64,
    /// Total bytes currently accounted across all peer trackers.
    pub total_memory_usage: u64,
    /// Number of peers with an active resource tracker.
    pub active_peers: usize,
}
435
436pub struct FlowController {
441 config: FlowControlConfig,
443 rate_limiters: Arc<RwLock<HashMap<EndpointId, TokenBucket>>>,
445 cooldowns: Arc<RwLock<SyncCooldownTracker>>,
447 resources: Arc<RwLock<HashMap<EndpointId, PeerResourceTracker>>>,
449 rate_limited_count: AtomicU64,
451}
452
453impl FlowController {
454 pub fn new() -> Self {
456 Self::with_config(FlowControlConfig::default())
457 }
458
459 pub fn with_config(config: FlowControlConfig) -> Self {
461 Self {
462 cooldowns: Arc::new(RwLock::new(SyncCooldownTracker::new(config.sync_cooldown))),
463 config,
464 rate_limiters: Arc::new(RwLock::new(HashMap::new())),
465 resources: Arc::new(RwLock::new(HashMap::new())),
466 rate_limited_count: AtomicU64::new(0),
467 }
468 }
469
470 pub fn check_sync_allowed(
479 &self,
480 peer_id: &EndpointId,
481 doc_key: &str,
482 ) -> Result<(), FlowControlError> {
483 {
485 let mut limiters = self
486 .rate_limiters
487 .write()
488 .unwrap_or_else(|e| e.into_inner());
489 let limiter = limiters.entry(*peer_id).or_insert_with(|| {
490 TokenBucket::new(
491 self.config.max_messages_per_second,
492 self.config.tokens_per_refill,
493 self.config.refill_interval,
494 )
495 });
496
497 if !limiter.try_acquire() {
498 self.rate_limited_count.fetch_add(1, Ordering::Relaxed);
499 return Err(FlowControlError::RateLimitExceeded);
500 }
501 }
502
503 {
505 let mut cooldowns = self.cooldowns.write().unwrap_or_else(|e| e.into_inner());
506 cooldowns.check_cooldown(peer_id, doc_key)?;
507 }
508
509 Ok(())
510 }
511
512 pub fn record_sync(&self, peer_id: &EndpointId, doc_key: &str) {
516 let mut cooldowns = self.cooldowns.write().unwrap_or_else(|e| e.into_inner());
517 cooldowns.record_sync(peer_id, doc_key);
518 }
519
520 pub fn get_resource_tracker(&self, peer_id: &EndpointId) -> Arc<PeerResourceTracker> {
522 let mut resources = self.resources.write().unwrap_or_else(|e| e.into_inner());
523 if !resources.contains_key(peer_id) {
524 resources.insert(
525 *peer_id,
526 PeerResourceTracker::new(self.config.max_memory_per_peer as u64),
527 );
528 }
529 let tracker = resources.get(peer_id).unwrap();
533 Arc::new(PeerResourceTracker {
534 memory_usage: AtomicU64::new(tracker.memory_usage.load(Ordering::Acquire)),
535 max_memory: tracker.max_memory,
536 messages_sent: AtomicU64::new(tracker.messages_sent.load(Ordering::Relaxed)),
537 messages_received: AtomicU64::new(tracker.messages_received.load(Ordering::Relaxed)),
538 messages_dropped: AtomicU64::new(tracker.messages_dropped.load(Ordering::Relaxed)),
539 })
540 }
541
542 pub fn stats(&self) -> FlowControlStats {
544 let cooldowns = self.cooldowns.read().unwrap_or_else(|e| e.into_inner());
545 let resources = self.resources.read().unwrap_or_else(|e| e.into_inner());
546
547 let total_memory: u64 = resources
548 .values()
549 .map(|r| r.memory_usage.load(Ordering::Relaxed))
550 .sum();
551
552 let queue_dropped: u64 = resources
553 .values()
554 .map(|r| r.messages_dropped.load(Ordering::Relaxed))
555 .sum();
556
557 FlowControlStats {
558 rate_limited: self.rate_limited_count.load(Ordering::Relaxed),
559 queue_dropped,
560 cooldown_blocked: cooldowns.blocked_count(),
561 total_memory_usage: total_memory,
562 active_peers: resources.len(),
563 }
564 }
565
566 pub fn cleanup(&self) {
568 let mut cooldowns = self.cooldowns.write().unwrap_or_else(|e| e.into_inner());
569 cooldowns.cleanup();
570 }
571
572 pub fn config(&self) -> &FlowControlConfig {
574 &self.config
575 }
576
577 pub fn available_tokens(&self, peer_id: &EndpointId) -> u32 {
579 let limiters = self.rate_limiters.read().unwrap_or_else(|e| e.into_inner());
580 limiters
581 .get(peer_id)
582 .map(|l| l.available_tokens())
583 .unwrap_or(self.config.max_messages_per_second)
584 }
585}
586
587impl Default for FlowController {
588 fn default() -> Self {
589 Self::new()
590 }
591}
592
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a random peer identity; each call yields a distinct EndpointId.
    fn create_test_peer_id() -> EndpointId {
        use iroh::SecretKey;
        let mut rng = rand::rng();
        SecretKey::generate(&mut rng).public()
    }

    // A fresh bucket starts full and allows exactly `capacity` acquisitions.
    #[test]
    fn test_token_bucket_basic() {
        let bucket = TokenBucket::new(10, 1, Duration::from_millis(100));

        assert_eq!(bucket.available_tokens(), 10);

        // Drain every token.
        for _ in 0..10 {
            assert!(bucket.try_acquire());
        }

        // An empty bucket rejects further acquisitions.
        assert!(!bucket.try_acquire());
        assert_eq!(bucket.available_tokens(), 0);
    }

    // Tokens are restored once refill intervals have elapsed.
    #[test]
    fn test_token_bucket_refill() {
        let bucket = TokenBucket::new(10, 5, Duration::from_millis(10));

        // Empty the bucket.
        for _ in 0..10 {
            bucket.try_acquire();
        }
        assert_eq!(bucket.available_tokens(), 0);

        // Sleep long enough for at least one full 10ms refill period
        // (25ms gives headroom for scheduler jitter).
        std::thread::sleep(Duration::from_millis(25));

        let available = bucket.available_tokens();
        assert!(
            available >= 5,
            "Expected at least 5 tokens, got {}",
            available
        );
    }

    // FIFO ordering and counters on the non-overflowing path.
    #[test]
    fn test_bounded_queue_basic() {
        let mut queue: BoundedQueue<i32> = BoundedQueue::new(3);

        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);

        queue.enqueue(1);
        queue.enqueue(2);
        queue.enqueue(3);

        assert_eq!(queue.len(), 3);
        assert_eq!(queue.total_enqueued(), 3);
        assert_eq!(queue.total_dropped(), 0);

        // Items come back out in insertion order.
        assert_eq!(queue.dequeue(), Some(1));
        assert_eq!(queue.dequeue(), Some(2));
        assert_eq!(queue.dequeue(), Some(3));
        assert_eq!(queue.dequeue(), None);
    }

    // When full, enqueue evicts the oldest element and counts the drop.
    #[test]
    fn test_bounded_queue_overflow() {
        let mut queue: BoundedQueue<i32> = BoundedQueue::new(3);

        queue.enqueue(1);
        queue.enqueue(2);
        queue.enqueue(3);

        // The fourth insert evicts the oldest element (1).
        let dropped = queue.enqueue(4);
        assert_eq!(dropped, Some(1));
        assert_eq!(queue.total_dropped(), 1);

        assert_eq!(queue.dequeue(), Some(2));
        assert_eq!(queue.dequeue(), Some(3));
        assert_eq!(queue.dequeue(), Some(4));
    }

    // try_enqueue rejects (returns the item) instead of evicting when full.
    #[test]
    fn test_bounded_queue_try_enqueue() {
        let mut queue: BoundedQueue<i32> = BoundedQueue::new(2);

        assert!(queue.try_enqueue(1).is_ok());
        assert!(queue.try_enqueue(2).is_ok());
        assert!(queue.try_enqueue(3).is_err()); }

    // Cooldowns apply per (peer, document) pair and expire over time.
    #[test]
    fn test_sync_cooldown_tracker() {
        let peer_id = create_test_peer_id();
        let mut tracker = SyncCooldownTracker::new(Duration::from_millis(50));

        assert!(tracker.check_cooldown(&peer_id, "doc1").is_ok());
        tracker.record_sync(&peer_id, "doc1");

        // Immediately re-syncing the same document is blocked.
        let result = tracker.check_cooldown(&peer_id, "doc1");
        assert!(matches!(
            result,
            Err(FlowControlError::CooldownActive { .. })
        ));

        // A different document of the same peer is unaffected.
        assert!(tracker.check_cooldown(&peer_id, "doc2").is_ok());

        // After the 50ms cooldown elapses, the original document is allowed.
        std::thread::sleep(Duration::from_millis(60));

        assert!(tracker.check_cooldown(&peer_id, "doc1").is_ok());
    }

    // Allocations succeed up to the byte budget, fail beyond it, and
    // free() releases accounted memory.
    #[test]
    fn test_peer_resource_tracker() {
        let tracker = PeerResourceTracker::new(1000);

        assert_eq!(tracker.memory_usage(), 0);

        assert!(tracker.try_allocate(500));
        assert_eq!(tracker.memory_usage(), 500);

        assert!(tracker.try_allocate(400));
        assert_eq!(tracker.memory_usage(), 900);

        // 900 + 200 would exceed the 1000-byte budget.
        assert!(!tracker.try_allocate(200));
        assert_eq!(tracker.memory_usage(), 900);

        tracker.free(300);
        assert_eq!(tracker.memory_usage(), 600);
    }

    // Once the 5-token bucket drains, further syncs are rate limited.
    // The cooldown is zeroed so only the rate limiter is exercised.
    #[test]
    fn test_flow_controller_rate_limiting() {
        let config = FlowControlConfig {
            max_messages_per_second: 5,
            tokens_per_refill: 1,
            refill_interval: Duration::from_millis(200),
            sync_cooldown: Duration::ZERO, ..Default::default()
        };
        let controller = FlowController::with_config(config);
        let peer_id = create_test_peer_id();

        for i in 0..5 {
            assert!(
                controller.check_sync_allowed(&peer_id, "doc1").is_ok(),
                "Sync {} should be allowed",
                i
            );
            controller.record_sync(&peer_id, "doc1");
        }

        let result = controller.check_sync_allowed(&peer_id, "doc1");
        assert!(
            matches!(result, Err(FlowControlError::RateLimitExceeded)),
            "Expected rate limit, got {:?}",
            result
        );
    }

    // stats() reflects activity; no resource tracker was requested, so
    // active_peers (which counts resource trackers) stays zero.
    #[test]
    fn test_flow_controller_stats() {
        let controller = FlowController::new();
        let peer_id = create_test_peer_id();

        controller.check_sync_allowed(&peer_id, "doc1").ok();
        controller.record_sync(&peer_id, "doc1");

        let stats = controller.stats();
        assert_eq!(stats.active_peers, 0); assert_eq!(stats.rate_limited, 0);
    }

    // cleanup() drops cooldown entries older than 10x the cooldown; with a
    // 10ms cooldown, a 150ms-old entry qualifies. Just checks it doesn't panic.
    #[test]
    fn test_flow_controller_cleanup() {
        let config = FlowControlConfig {
            sync_cooldown: Duration::from_millis(10),
            ..Default::default()
        };
        let controller = FlowController::with_config(config);
        let peer_id = create_test_peer_id();

        controller.record_sync(&peer_id, "doc1");

        std::thread::sleep(Duration::from_millis(150));

        controller.cleanup();
    }
}