use crossbeam_queue::SegQueue;
use smallvec::SmallVec;
use std::{
    sync::{
        atomic::{AtomicBool, AtomicUsize, Ordering},
        Arc,
    },
    time::Duration,
};

use crate::message::MessageId;

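/// A single IHAVE announcement waiting to be batched and sent lazily: the id
/// of a message this node holds plus the gossip round it was received in.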
#[derive(Debug, Clone)]
pub struct PendingIHave {
    pub message_id: MessageId,
    pub round: u32,
}

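/// Lock-free queue of pending IHAVE announcements with a soft capacity bound,
/// an adjustable flush threshold, and a flag for pausing intake.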
#[derive(Debug)]
pub struct IHaveQueue {
    queue: SegQueue<PendingIHave>,
    len: AtomicUsize,
    max_size: usize,
    accepting: AtomicBool,
    flush_threshold: AtomicUsize,
}

impl IHaveQueue {
    pub fn new(max_size: usize) -> Self {
        Self {
            queue: SegQueue::new(),
            len: AtomicUsize::new(0),
            max_size,
            accepting: AtomicBool::new(true),
            flush_threshold: AtomicUsize::new(16),
        }
    }

    pub fn with_flush_threshold(max_size: usize, flush_threshold: usize) -> Self {
        Self {
            queue: SegQueue::new(),
            len: AtomicUsize::new(0),
            max_size,
            accepting: AtomicBool::new(true),
            flush_threshold: AtomicUsize::new(flush_threshold),
        }
    }

    pub fn set_flush_threshold(&self, threshold: usize) {
        self.flush_threshold.store(threshold, Ordering::Relaxed);
    }

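    /// Enqueues an announcement. Returns `false` if the queue has been stopped
    /// or is at capacity. The capacity check is best-effort (load then push),
    /// so the bound can be slightly exceeded under concurrent pushes.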
    pub fn push(&self, message_id: MessageId, round: u32) -> bool {
        if !self.accepting.load(Ordering::Acquire) {
            return false;
        }

        let current_len = self.len.load(Ordering::Relaxed);
        if current_len >= self.max_size {
            return false;
        }

        self.queue.push(PendingIHave { message_id, round });
        self.len.fetch_add(1, Ordering::Relaxed);
        true
    }

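    /// Returns `true` once the queue length has reached the flush threshold,
    /// signalling that a batch is worth emitting before the next tick.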
    pub fn should_flush(&self) -> bool {
        let current_len = self.len.load(Ordering::Relaxed);
        let threshold = self.flush_threshold.load(Ordering::Relaxed);
        current_len >= threshold
    }

    pub fn flush_threshold(&self) -> usize {
        self.flush_threshold.load(Ordering::Relaxed)
    }

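    /// Pops up to `max_batch` announcements, stopping early if the queue
    /// drains; small batches stay inline thanks to `SmallVec`.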
    pub fn pop_batch(&self, max_batch: usize) -> SmallVec<[PendingIHave; 16]> {
        let mut batch = SmallVec::new();

        for _ in 0..max_batch {
            if let Some(item) = self.queue.pop() {
                self.len.fetch_sub(1, Ordering::Relaxed);
                batch.push(item);
            } else {
                break;
            }
        }

        batch
    }

    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

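    /// Stops accepting new announcements; items already queued remain poppable.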
    pub fn stop(&self) {
        self.accepting.store(false, Ordering::Release);
    }

    pub fn resume(&self) {
        self.accepting.store(true, Ordering::Release);
    }

    pub fn clear(&self) {
        while self.queue.pop().is_some() {
            self.len.fetch_sub(1, Ordering::Relaxed);
        }
    }
}

impl Default for IHaveQueue {
    fn default() -> Self {
        Self::new(10_000)
    }
}

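/// Shared state and configuration for a periodic IHAVE flush task: the queue,
/// the flush interval and batch size, and a shutdown flag. The driving loop
/// itself lives outside this type.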
#[derive(Debug)]
pub struct IHaveScheduler {
    queue: Arc<IHaveQueue>,
    interval: Duration,
    batch_size: usize,
    shutdown: AtomicBool,
}

impl IHaveScheduler {
    pub fn new(interval: Duration, batch_size: usize, max_queue_size: usize) -> Self {
        Self {
            queue: Arc::new(IHaveQueue::with_flush_threshold(max_queue_size, batch_size)),
            interval,
            batch_size,
            shutdown: AtomicBool::new(false),
        }
    }

    pub fn queue(&self) -> &Arc<IHaveQueue> {
        &self.queue
    }

    pub fn interval(&self) -> Duration {
        self.interval
    }

    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    pub fn is_shutdown(&self) -> bool {
        self.shutdown.load(Ordering::Acquire)
    }

    pub fn shutdown(&self) {
        self.shutdown.store(true, Ordering::Release);
        self.queue.stop();
    }

    pub fn pop_batch(&self) -> SmallVec<[PendingIHave; 16]> {
        self.queue.pop_batch(self.batch_size)
    }
}

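/// Tracks messages that were announced (e.g. via IHAVE) but have not yet been
/// received, and reports entries whose retry deadline has passed so the caller
/// can send GRAFT requests, backing off exponentially and rotating through
/// alternative peers.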
#[derive(Debug)]
pub struct GraftTimer<I> {
    inner: parking_lot::Mutex<GraftTimerInner<I>>,
    base_timeout: Duration,
    max_timeout: Duration,
    max_retries: u32,
}

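/// Mutable timer state: per-message entries plus a `BTreeMap` index from retry
/// deadline to the message ids due at that instant, for cheap expiry scans.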
#[derive(Debug)]
struct GraftTimerInner<I> {
    entries: std::collections::HashMap<MessageId, GraftEntry<I>>,
    timeouts:
        std::collections::BTreeMap<std::time::Instant, std::collections::HashSet<MessageId>>,
}

impl<I> Default for GraftTimerInner<I> {
    fn default() -> Self {
        Self {
            entries: std::collections::HashMap::new(),
            timeouts: std::collections::BTreeMap::new(),
        }
    }
}

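/// Book-keeping for one awaited message: who announced it, which peers can be
/// tried next, and how many retries have been issued so far.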
#[derive(Debug, Clone)]
struct GraftEntry<I> {
    #[allow(dead_code)]
    created: std::time::Instant,
    next_retry: std::time::Instant,
    from: I,
    alternative_peers: Vec<I>,
    round: u32,
    retry_count: u32,
}

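/// An expectation whose deadline passed; the caller should send a GRAFT for
/// `message_id` to `peer`.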
#[derive(Debug, Clone)]
pub struct ExpiredGraft<I> {
    pub message_id: MessageId,
    pub peer: I,
    pub round: u32,
    pub retry_count: u32,
}

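/// An expectation that exhausted its retries without the message arriving.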
#[derive(Debug, Clone)]
pub struct FailedGraft<I> {
    pub message_id: MessageId,
    pub original_peer: I,
    pub total_retries: u32,
}

impl<I: Clone + Send + Sync + 'static> GraftTimer<I> {
    pub fn new(timeout: Duration) -> Self {
        Self {
            inner: parking_lot::Mutex::new(GraftTimerInner::default()),
            base_timeout: timeout,
            max_timeout: timeout * 8,
            max_retries: 5,
        }
    }

    pub fn with_backoff(base_timeout: Duration, max_timeout: Duration, max_retries: u32) -> Self {
        Self {
            inner: parking_lot::Mutex::new(GraftTimerInner::default()),
            base_timeout,
            max_timeout,
            max_retries,
        }
    }

    fn add_to_timeout_index(
        inner: &mut GraftTimerInner<I>,
        timeout: std::time::Instant,
        message_id: MessageId,
    ) {
        inner
            .timeouts
            .entry(timeout)
            .or_default()
            .insert(message_id);
    }

    fn remove_from_timeout_index(
        inner: &mut GraftTimerInner<I>,
        timeout: std::time::Instant,
        message_id: &MessageId,
    ) {
        if let Some(ids) = inner.timeouts.get_mut(&timeout) {
            ids.remove(message_id);
            if ids.is_empty() {
                inner.timeouts.remove(&timeout);
            }
        }
    }

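    /// Starts tracking a message announced by `from`, with the first retry
    /// deadline one base timeout from now. Re-registering a message that is
    /// already tracked is a no-op.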
    pub fn expect_message(&self, message_id: MessageId, from: I, round: u32) {
        let now = std::time::Instant::now();
        let next_retry = now + self.base_timeout;
        let mut inner = self.inner.lock();

        if inner.entries.contains_key(&message_id) {
            return;
        }

        inner.entries.insert(
            message_id,
            GraftEntry {
                created: now,
                next_retry,
                from,
                alternative_peers: Vec::new(),
                round,
                retry_count: 0,
            },
        );
        Self::add_to_timeout_index(&mut inner, next_retry, message_id);
    }

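    /// Like [`Self::expect_message`], but records fallback peers to rotate
    /// through on later retries.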
    pub fn expect_message_with_alternatives(
        &self,
        message_id: MessageId,
        from: I,
        alternatives: Vec<I>,
        round: u32,
    ) {
        let now = std::time::Instant::now();
        let next_retry = now + self.base_timeout;
        let mut inner = self.inner.lock();

        if inner.entries.contains_key(&message_id) {
            return;
        }

        inner.entries.insert(
            message_id,
            GraftEntry {
                created: now,
                next_retry,
                from,
                alternative_peers: alternatives,
                round,
                retry_count: 0,
            },
        );
        Self::add_to_timeout_index(&mut inner, next_retry, message_id);
    }

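    /// Stops tracking `message_id`. Returns `true` if at least one retry had
    /// already been issued for it, i.e. a GRAFT was outstanding when the
    /// message finally arrived.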
    pub fn message_received(&self, message_id: &MessageId) -> bool {
        let mut inner = self.inner.lock();
        if let Some(entry) = inner.entries.remove(message_id) {
            Self::remove_from_timeout_index(&mut inner, entry.next_retry, message_id);

            let was_graft_sent = entry.retry_count > 0;

            #[cfg(feature = "metrics")]
            if was_graft_sent {
                crate::metrics::record_graft_success();
                let latency = entry.created.elapsed().as_secs_f64();
                crate::metrics::record_graft_latency(latency);
            }

            return was_graft_sent;
        }
        false
    }

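    /// Exponential backoff: `base_timeout * 2^retry_count`, saturating and
    /// capped at `max_timeout`.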
    fn calculate_backoff(&self, retry_count: u32) -> Duration {
        let multiplier = 1u32.checked_shl(retry_count).unwrap_or(u32::MAX);
        let backoff = self.base_timeout.saturating_mul(multiplier);
        std::cmp::min(backoff, self.max_timeout)
    }

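    /// Convenience wrapper around [`Self::get_expired_with_failures`] that
    /// discards the failure list.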
    pub fn get_expired(&self) -> Vec<ExpiredGraft<I>> {
        let (expired, _) = self.get_expired_with_failures();
        expired
    }

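    /// Collects every expectation whose retry deadline has passed. Each one
    /// yields an `ExpiredGraft` naming the next peer to try (the original
    /// sender first, then alternatives in rotation) and is rescheduled with
    /// exponential backoff; entries that reach `max_retries` are removed and
    /// reported as `FailedGraft`s instead.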
    pub fn get_expired_with_failures(&self) -> (Vec<ExpiredGraft<I>>, Vec<FailedGraft<I>>) {
        let now = std::time::Instant::now();
        let mut inner = self.inner.lock();

        let mut expired = Vec::new();
        let mut failed = Vec::new();

        let expired_times: Vec<std::time::Instant> =
            inner.timeouts.range(..=now).map(|(t, _)| *t).collect();

        let mut to_reschedule: Vec<(MessageId, std::time::Instant)> = Vec::new();
        let mut to_remove: Vec<MessageId> = Vec::new();

        for timeout in expired_times {
            let Some(message_ids) = inner.timeouts.remove(&timeout) else {
                continue;
            };

            for message_id in message_ids {
                let Some(entry) = inner.entries.get_mut(&message_id) else {
                    continue;
                };

                // Defensive: if the entry's deadline was pushed into the
                // future, keep it and re-index it under the newer deadline.
                if now < entry.next_retry {
                    to_reschedule.push((message_id, entry.next_retry));
                    continue;
                }

                // First attempt goes to the announcing peer; later retries
                // rotate through the alternatives, falling back to `from`
                // when none were provided.
                let peer = if entry.retry_count == 0 {
                    entry.from.clone()
                } else {
                    let alt_idx =
                        (entry.retry_count - 1) as usize % entry.alternative_peers.len().max(1);
                    if alt_idx < entry.alternative_peers.len() {
                        entry.alternative_peers[alt_idx].clone()
                    } else {
                        entry.from.clone()
                    }
                };

                expired.push(ExpiredGraft {
                    message_id,
                    peer,
                    round: entry.round,
                    retry_count: entry.retry_count,
                });

                entry.retry_count += 1;

                if entry.retry_count >= self.max_retries {
                    failed.push(FailedGraft {
                        message_id,
                        original_peer: entry.from.clone(),
                        total_retries: entry.retry_count,
                    });

                    #[cfg(feature = "metrics")]
                    crate::metrics::record_graft_failed();

                    to_remove.push(message_id);
                } else {
                    let backoff = self.calculate_backoff(entry.retry_count);
                    let new_timeout = now + backoff;
                    entry.next_retry = new_timeout;
                    to_reschedule.push((message_id, new_timeout));
                }
            }
        }

        for id in to_remove {
            inner.entries.remove(&id);
        }
        for (id, timeout) in to_reschedule {
            Self::add_to_timeout_index(&mut inner, timeout, id);
        }

        (expired, failed)
    }

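    /// Drops all tracked expectations and their timeout index entries.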
    pub fn clear(&self) {
        let mut inner = self.inner.lock();
        inner.entries.clear();
        inner.timeouts.clear();
    }

    pub fn pending_count(&self) -> usize {
        self.inner.lock().entries.len()
    }

    pub fn base_timeout(&self) -> Duration {
        self.base_timeout
    }

    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ihave_queue_push_pop() {
        let queue = IHaveQueue::new(100);

        let id = MessageId::new();
        assert!(queue.push(id, 0));

        let batch = queue.pop_batch(10);
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].message_id, id);
    }

    #[test]
    fn test_ihave_queue_capacity() {
        let queue = IHaveQueue::new(3);

        for i in 0..5 {
            let pushed = queue.push(MessageId::new(), i);
            if i < 3 {
                assert!(pushed);
            } else {
                assert!(!pushed);
            }
        }

        assert_eq!(queue.len(), 3);
    }

    #[test]
    fn test_ihave_queue_batch() {
        let queue = IHaveQueue::new(100);

        for i in 0..10 {
            queue.push(MessageId::new(), i);
        }

        let batch = queue.pop_batch(5);
        assert_eq!(batch.len(), 5);
        assert_eq!(queue.len(), 5);
    }

    #[test]
    fn test_ihave_queue_stop() {
        let queue = IHaveQueue::new(100);

        assert!(queue.push(MessageId::new(), 0));
        queue.stop();
        assert!(!queue.push(MessageId::new(), 0));
        queue.resume();
        assert!(queue.push(MessageId::new(), 0));
    }

    #[test]
    fn test_graft_timer() {
        let timer: GraftTimer<u64> = GraftTimer::new(Duration::from_millis(50));

        let id = MessageId::new();
        timer.expect_message(id, 42u64, 0);

        let expired = timer.get_expired();
        assert!(expired.is_empty());

        std::thread::sleep(Duration::from_millis(100));

        let expired = timer.get_expired();
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0].message_id, id);
        assert_eq!(expired[0].peer, 42u64);
        assert_eq!(expired[0].retry_count, 0);
    }

    #[test]
    fn test_graft_timer_message_received() {
        let timer: GraftTimer<u64> = GraftTimer::new(Duration::from_millis(50));

        let id = MessageId::new();
        timer.expect_message(id, 42u64, 0);
        timer.message_received(&id);

        std::thread::sleep(Duration::from_millis(100));

        let expired = timer.get_expired();
        assert!(expired.is_empty());
    }

    #[test]
    fn test_graft_timer_backoff() {
        let timer: GraftTimer<u64> =
            GraftTimer::with_backoff(Duration::from_millis(20), Duration::from_millis(160), 3);

        let id = MessageId::new();
        timer.expect_message(id, 1u64, 0);

        // First expiry after the 20ms base timeout.
        std::thread::sleep(Duration::from_millis(30));
        let expired = timer.get_expired();
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0].retry_count, 0);

        // Backoff doubles to 40ms, so nothing expires after only 30ms more.
        std::thread::sleep(Duration::from_millis(30));
        let expired = timer.get_expired();
        assert!(expired.is_empty());

        std::thread::sleep(Duration::from_millis(20));
        let expired = timer.get_expired();
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0].retry_count, 1);

        // Third expiry exhausts max_retries (3) and the entry is dropped.
        std::thread::sleep(Duration::from_millis(90));
        let expired = timer.get_expired();
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0].retry_count, 2);

        assert_eq!(timer.pending_count(), 0);
    }

    #[test]
    fn test_graft_timer_alternatives() {
        let timer: GraftTimer<u64> =
            GraftTimer::with_backoff(Duration::from_millis(20), Duration::from_millis(200), 4);

        let id = MessageId::new();
        let primary = 1u64;
        let alt1 = 2u64;
        let alt2 = 3u64;
        timer.expect_message_with_alternatives(id, primary, vec![alt1, alt2], 0);

        // Retries target the primary first, then rotate through the alternatives.
        std::thread::sleep(Duration::from_millis(30));
        let expired = timer.get_expired();
        assert_eq!(expired[0].peer, primary);

        std::thread::sleep(Duration::from_millis(50));
        let expired = timer.get_expired();
        assert_eq!(expired[0].peer, alt1);

        std::thread::sleep(Duration::from_millis(90));
        let expired = timer.get_expired();
        assert_eq!(expired[0].peer, alt2);

        std::thread::sleep(Duration::from_millis(170));
        let expired = timer.get_expired();
        assert_eq!(expired[0].peer, alt1);
    }

    #[test]
    fn test_scheduler() {
        let scheduler = IHaveScheduler::new(Duration::from_millis(100), 16, 1000);

        scheduler.queue().push(MessageId::new(), 0);
        scheduler.queue().push(MessageId::new(), 1);

        let batch = scheduler.pop_batch();
        assert_eq!(batch.len(), 2);
    }
}