memberlist_plumtree/
scheduler.rs

//! Background scheduling for IHave announcements and Graft timeouts.
//!
//! Batches IHave messages for periodic delivery to lazy peers, using a
//! lock-free queue for an efficient producer/consumer pattern, and tracks
//! pending Graft requests with exponential backoff.
5
6use crossbeam_queue::SegQueue;
7use smallvec::SmallVec;
8use std::{
9    sync::{
10        atomic::{AtomicBool, AtomicUsize, Ordering},
11        Arc,
12    },
13    time::Duration,
14};
15
16use crate::message::MessageId;
17
/// An IHave announcement waiting in the [`IHaveQueue`] to be sent.
#[derive(Debug, Clone)]
pub struct PendingIHave {
    /// Message ID to announce.
    pub message_id: MessageId,
    /// Round number for this announcement, as supplied to `IHaveQueue::push`.
    pub round: u32,
}
26
/// Bounded, lock-free queue for pending IHave announcements.
///
/// Uses crossbeam's `SegQueue` so producers and consumers can access it
/// concurrently without lock contention in the hot path. The item count is
/// kept in a separate atomic counter, so the reported length is approximate
/// while pushes and pops are in flight.
#[derive(Debug)]
pub struct IHaveQueue {
    /// Lock-free queue of pending announcements.
    queue: SegQueue<PendingIHave>,
    /// Approximate number of queued items (may be slightly stale).
    len: AtomicUsize,
    /// Capacity: pushes are rejected once `len` reaches this value.
    max_size: usize,
    /// Whether new pushes are accepted (cleared by `stop`, set by `resume`).
    accepting: AtomicBool,
    /// Queue length at which `should_flush` starts returning `true`.
    flush_threshold: AtomicUsize,
}
44
45impl IHaveQueue {
46    /// Create a new IHave queue with the specified maximum size.
47    pub fn new(max_size: usize) -> Self {
48        Self {
49            queue: SegQueue::new(),
50            len: AtomicUsize::new(0),
51            max_size,
52            accepting: AtomicBool::new(true),
53            flush_threshold: AtomicUsize::new(16), // Default batch size
54        }
55    }
56
57    /// Create a new IHave queue with a specific flush threshold.
58    pub fn with_flush_threshold(max_size: usize, flush_threshold: usize) -> Self {
59        Self {
60            queue: SegQueue::new(),
61            len: AtomicUsize::new(0),
62            max_size,
63            accepting: AtomicBool::new(true),
64            flush_threshold: AtomicUsize::new(flush_threshold),
65        }
66    }
67
68    /// Set the flush threshold (batch size).
69    pub fn set_flush_threshold(&self, threshold: usize) {
70        self.flush_threshold.store(threshold, Ordering::Relaxed);
71    }
72
73    /// Push a new IHave announcement to the queue.
74    ///
75    /// Returns `true` if the item was queued, `false` if the queue is full
76    /// or not accepting new items.
77    pub fn push(&self, message_id: MessageId, round: u32) -> bool {
78        // Check if accepting
79        if !self.accepting.load(Ordering::Acquire) {
80            return false;
81        }
82
83        // Check approximate size
84        let current_len = self.len.load(Ordering::Relaxed);
85        if current_len >= self.max_size {
86            return false;
87        }
88
89        // Push to queue
90        self.queue.push(PendingIHave { message_id, round });
91        self.len.fetch_add(1, Ordering::Relaxed);
92        true
93    }
94
95    /// Check if the queue has reached the flush threshold.
96    ///
97    /// This can be used to trigger early batch flush in high-throughput scenarios.
98    pub fn should_flush(&self) -> bool {
99        let current_len = self.len.load(Ordering::Relaxed);
100        let threshold = self.flush_threshold.load(Ordering::Relaxed);
101        current_len >= threshold
102    }
103
104    /// Get the current flush threshold.
105    pub fn flush_threshold(&self) -> usize {
106        self.flush_threshold.load(Ordering::Relaxed)
107    }
108
109    /// Pop a batch of IHave announcements from the queue.
110    ///
111    /// Returns up to `max_batch` items.
112    pub fn pop_batch(&self, max_batch: usize) -> SmallVec<[PendingIHave; 16]> {
113        let mut batch = SmallVec::new();
114
115        for _ in 0..max_batch {
116            if let Some(item) = self.queue.pop() {
117                self.len.fetch_sub(1, Ordering::Relaxed);
118                batch.push(item);
119            } else {
120                break;
121            }
122        }
123
124        batch
125    }
126
127    /// Get the approximate length of the queue.
128    pub fn len(&self) -> usize {
129        self.len.load(Ordering::Relaxed)
130    }
131
132    /// Check if the queue is empty.
133    pub fn is_empty(&self) -> bool {
134        self.len() == 0
135    }
136
137    /// Stop accepting new items.
138    pub fn stop(&self) {
139        self.accepting.store(false, Ordering::Release);
140    }
141
142    /// Resume accepting new items.
143    pub fn resume(&self) {
144        self.accepting.store(true, Ordering::Release);
145    }
146
147    /// Clear all pending items from the queue.
148    pub fn clear(&self) {
149        while self.queue.pop().is_some() {
150            self.len.fetch_sub(1, Ordering::Relaxed);
151        }
152    }
153}
154
155impl Default for IHaveQueue {
156    fn default() -> Self {
157        Self::new(10000)
158    }
159}
160
/// Holds the IHave queue together with its batching configuration.
///
/// Producers push announcements into the shared [`IHaveQueue`]. A consumer
/// drains it in batches of at most `batch_size` items via `pop_batch`. The
/// `interval` is only stored and returned here; the periodic send loop that
/// uses it is presumably driven by the caller (TODO confirm).
#[derive(Debug)]
pub struct IHaveScheduler {
    /// Queue of pending IHave announcements (shared via `Arc`).
    queue: Arc<IHaveQueue>,
    /// Interval between IHave batches.
    interval: Duration,
    /// Maximum batch size; also passed to the queue as its flush threshold.
    batch_size: usize,
    /// Shutdown flag.
    shutdown: AtomicBool,
}
175
176impl IHaveScheduler {
177    /// Create a new IHave scheduler.
178    pub fn new(interval: Duration, batch_size: usize, max_queue_size: usize) -> Self {
179        Self {
180            queue: Arc::new(IHaveQueue::with_flush_threshold(max_queue_size, batch_size)),
181            interval,
182            batch_size,
183            shutdown: AtomicBool::new(false),
184        }
185    }
186
187    /// Get a reference to the IHave queue.
188    pub fn queue(&self) -> &Arc<IHaveQueue> {
189        &self.queue
190    }
191
192    /// Get the configured interval.
193    pub fn interval(&self) -> Duration {
194        self.interval
195    }
196
197    /// Get the configured batch size.
198    pub fn batch_size(&self) -> usize {
199        self.batch_size
200    }
201
202    /// Check if shutdown has been requested.
203    pub fn is_shutdown(&self) -> bool {
204        self.shutdown.load(Ordering::Acquire)
205    }
206
207    /// Request shutdown.
208    pub fn shutdown(&self) {
209        self.shutdown.store(true, Ordering::Release);
210        self.queue.stop();
211    }
212
213    /// Pop a batch for sending.
214    pub fn pop_batch(&self) -> SmallVec<[PendingIHave; 16]> {
215        self.queue.pop_batch(self.batch_size)
216    }
217}
218
/// Tracks pending Graft requests with timeouts and exponential backoff.
///
/// When an IHave is received, the caller registers the message ID with
/// [`GraftTimer::expect_message`] (or `expect_message_with_alternatives`). It
/// cancels the entry with [`GraftTimer::message_received`] once the payload arrives.
///
/// Entries still pending after their timeout are reported by
/// [`GraftTimer::get_expired`] / [`GraftTimer::get_expired_with_failures`],
/// so the caller can send a Graft. Each further attempt waits exponentially
/// longer, capped at `max_timeout`. This type does not send anything itself.
///
/// # Performance
///
/// Uses a dual-index structure for efficient operations:
/// - `HashMap<MessageId, GraftEntry>` for O(1) average lookup by message ID (cancel on receive)
/// - `BTreeMap<Instant, HashSet<MessageId>>` so due entries are found with a range scan
///   whose cost is proportional to the number of due entries, not all pending ones
///
/// This avoids the O(N) full scan on every timer tick, which is critical under high load
/// when thousands of messages may be pending Graft.
///
/// # Type Parameters
///
/// - `I`: Peer identifier type (the methods require `Clone + Send + Sync + 'static`)
#[derive(Debug)]
pub struct GraftTimer<I> {
    /// Inner state protected by mutex.
    inner: parking_lot::Mutex<GraftTimerInner<I>>,
    /// Delay before the first Graft attempt; also the base of the exponential backoff.
    base_timeout: Duration,
    /// Upper bound on any backoff delay.
    max_timeout: Duration,
    /// Number of Graft attempts reported before an entry is dropped and
    /// reported as failed (at least one attempt is always reported).
    max_retries: u32,
}
247
/// Inner state of GraftTimer (held under lock).
///
/// Whenever the lock is released, each entry in `entries` has its ID stored in
/// `timeouts` under the key equal to that entry's `next_retry`. Every method
/// that mutates this state updates both maps together.
#[derive(Debug)]
struct GraftTimerInner<I> {
    /// Pending message IDs waiting for content (fast lookup by ID).
    entries: std::collections::HashMap<MessageId, GraftEntry<I>>,
    /// Time-sorted index for efficient expired entry retrieval.
    /// Maps timeout instant -> set of message IDs expiring at that time.
    /// Emptied sets are removed so stale instants do not accumulate.
    timeouts: std::collections::BTreeMap<std::time::Instant, std::collections::HashSet<MessageId>>,
}
257
258impl<I> Default for GraftTimerInner<I> {
259    fn default() -> Self {
260        Self {
261            entries: std::collections::HashMap::new(),
262            timeouts: std::collections::BTreeMap::new(),
263        }
264    }
265}
266
/// Bookkeeping for one message we have seen an IHave for but not the payload.
#[derive(Debug, Clone)]
struct GraftEntry<I> {
    /// When this entry was created.
    ///
    /// Only read when the `metrics` feature is enabled, to record Graft
    /// latency. The dead-code lint is therefore silenced for other builds only.
    #[cfg_attr(not(feature = "metrics"), allow(dead_code))]
    created: std::time::Instant,
    /// When the next retry should occur. This is also the key under which the
    /// message ID is stored in the timeout index.
    next_retry: std::time::Instant,
    /// Node that sent the IHave.
    from: I,
    /// Alternative peers to try, in round-robin order, after the first attempt.
    alternative_peers: Vec<I>,
    /// Round from the IHave.
    round: u32,
    /// Number of Graft attempts reported so far.
    retry_count: u32,
}
284
/// A due Graft attempt, as reported when checking for expired Graft timers.
#[derive(Debug, Clone)]
pub struct ExpiredGraft<I> {
    /// Message ID that needs to be requested.
    pub message_id: MessageId,
    /// Peer to request from. This is the IHave sender on the first attempt, then
    /// the alternative peers in round-robin order (or the sender if there are none).
    pub peer: I,
    /// Round number from the original IHave.
    pub round: u32,
    /// Which attempt this is (0 = first attempt).
    pub retry_count: u32,
}
297
/// Information about a Graft that was given up after the retry limit was reached.
#[derive(Debug, Clone)]
pub struct FailedGraft<I> {
    /// Message ID that could not be retrieved.
    pub message_id: MessageId,
    /// Original peer that sent the IHave (potential zombie).
    pub original_peer: I,
    /// Number of Graft attempts reported for this message, including the first.
    pub total_retries: u32,
}
308
309impl<I: Clone + Send + Sync + 'static> GraftTimer<I> {
310    /// Create a new Graft timer with default backoff settings.
311    pub fn new(timeout: Duration) -> Self {
312        Self {
313            inner: parking_lot::Mutex::new(GraftTimerInner::default()),
314            base_timeout: timeout,
315            max_timeout: timeout * 8, // Max 8x base timeout
316            max_retries: 5,
317        }
318    }
319
320    /// Create a new Graft timer with custom backoff settings.
321    pub fn with_backoff(base_timeout: Duration, max_timeout: Duration, max_retries: u32) -> Self {
322        Self {
323            inner: parking_lot::Mutex::new(GraftTimerInner::default()),
324            base_timeout,
325            max_timeout,
326            max_retries,
327        }
328    }
329
330    /// Add a message ID to the timeout index at a given instant.
331    fn add_to_timeout_index(
332        inner: &mut GraftTimerInner<I>,
333        timeout: std::time::Instant,
334        message_id: MessageId,
335    ) {
336        inner
337            .timeouts
338            .entry(timeout)
339            .or_default()
340            .insert(message_id);
341    }
342
343    /// Remove a message ID from the timeout index at a given instant.
344    fn remove_from_timeout_index(
345        inner: &mut GraftTimerInner<I>,
346        timeout: std::time::Instant,
347        message_id: &MessageId,
348    ) {
349        if let Some(ids) = inner.timeouts.get_mut(&timeout) {
350            ids.remove(message_id);
351            // Clean up empty sets to avoid memory leak
352            if ids.is_empty() {
353                inner.timeouts.remove(&timeout);
354            }
355        }
356    }
357
358    /// Record that we're expecting a message (received IHave).
359    pub fn expect_message(&self, message_id: MessageId, from: I, round: u32) {
360        let now = std::time::Instant::now();
361        let next_retry = now + self.base_timeout;
362        let mut inner = self.inner.lock();
363
364        // Only insert if not already present
365        if inner.entries.contains_key(&message_id) {
366            return;
367        }
368
369        inner.entries.insert(
370            message_id,
371            GraftEntry {
372                created: now,
373                next_retry,
374                from,
375                alternative_peers: Vec::new(),
376                round,
377                retry_count: 0,
378            },
379        );
380        Self::add_to_timeout_index(&mut inner, next_retry, message_id);
381    }
382
383    /// Record that we're expecting a message with alternative peers to try.
384    pub fn expect_message_with_alternatives(
385        &self,
386        message_id: MessageId,
387        from: I,
388        alternatives: Vec<I>,
389        round: u32,
390    ) {
391        let now = std::time::Instant::now();
392        let next_retry = now + self.base_timeout;
393        let mut inner = self.inner.lock();
394
395        // Only insert if not already present
396        if inner.entries.contains_key(&message_id) {
397            return;
398        }
399
400        inner.entries.insert(
401            message_id,
402            GraftEntry {
403                created: now,
404                next_retry,
405                from,
406                alternative_peers: alternatives,
407                round,
408                retry_count: 0,
409            },
410        );
411        Self::add_to_timeout_index(&mut inner, next_retry, message_id);
412    }
413
414    /// Mark that a message was received (cancel Graft timer).
415    ///
416    /// This records a successful graft and the latency from when the entry
417    /// was created to when the message was received.
418    ///
419    /// Returns `true` if a graft was actually pending and sent (retry_count > 0),
420    /// which indicates a successful graft that the adaptive batcher should track.
421    pub fn message_received(&self, message_id: &MessageId) -> bool {
422        let mut inner = self.inner.lock();
423        if let Some(entry) = inner.entries.remove(message_id) {
424            // Also remove from timeout index
425            Self::remove_from_timeout_index(&mut inner, entry.next_retry, message_id);
426
427            // Only count as success if we actually sent a graft (retry_count > 0)
428            let was_graft_sent = entry.retry_count > 0;
429
430            // Record metrics if the feature is enabled
431            #[cfg(feature = "metrics")]
432            if was_graft_sent {
433                crate::metrics::record_graft_success();
434                let latency = entry.created.elapsed().as_secs_f64();
435                crate::metrics::record_graft_latency(latency);
436            }
437
438            return was_graft_sent;
439        }
440        false
441    }
442
443    /// Calculate backoff duration for a given retry count.
444    fn calculate_backoff(&self, retry_count: u32) -> Duration {
445        // Exponential backoff: base * 2^retry_count, capped at max
446        let multiplier = 1u32.checked_shl(retry_count).unwrap_or(u32::MAX);
447        let backoff = self.base_timeout.saturating_mul(multiplier);
448        std::cmp::min(backoff, self.max_timeout)
449    }
450
451    /// Get expired entries that need Graft requests.
452    ///
453    /// Returns list of expired grafts with peer to try and retry info.
454    /// Entries that exceed max_retries are removed.
455    pub fn get_expired(&self) -> Vec<ExpiredGraft<I>> {
456        let (expired, _) = self.get_expired_with_failures();
457        expired
458    }
459
460    /// Get expired entries and failed entries (max retries exceeded).
461    ///
462    /// This operation is O(K) where K is the number of expired entries,
463    /// NOT O(N) where N is total pending entries.
464    ///
465    /// Returns:
466    /// - `Vec<ExpiredGraft<I>>`: Entries that need retry
467    /// - `Vec<FailedGraft<I>>`: Entries that exceeded max retries (zombie peer detection)
468    pub fn get_expired_with_failures(&self) -> (Vec<ExpiredGraft<I>>, Vec<FailedGraft<I>>) {
469        let now = std::time::Instant::now();
470        let mut inner = self.inner.lock();
471
472        let mut expired = Vec::new();
473        let mut failed = Vec::new();
474
475        // Collect expired timeout keys - only iterate over times <= now
476        // This is O(K) where K is the number of expired time buckets
477        let expired_times: Vec<std::time::Instant> =
478            inner.timeouts.range(..=now).map(|(t, _)| *t).collect();
479
480        // Collect updates to apply after processing (to avoid borrow issues)
481        let mut to_reschedule: Vec<(MessageId, std::time::Instant)> = Vec::new();
482        let mut to_remove: Vec<MessageId> = Vec::new();
483
484        // Process each expired time bucket
485        for timeout in expired_times {
486            // Remove the entire bucket from the BTreeMap
487            let Some(message_ids) = inner.timeouts.remove(&timeout) else {
488                continue;
489            };
490
491            for message_id in message_ids {
492                // Get the entry - it might have been removed by message_received()
493                let Some(entry) = inner.entries.get_mut(&message_id) else {
494                    continue;
495                };
496
497                // Verify this entry is actually expired (defensive check)
498                if now < entry.next_retry {
499                    // Need to re-add to timeout index at correct time
500                    to_reschedule.push((message_id, entry.next_retry));
501                    continue;
502                }
503
504                // Determine which peer to try
505                let peer = if entry.retry_count == 0 {
506                    // First attempt: use original sender
507                    entry.from.clone()
508                } else {
509                    // Subsequent attempts: try alternatives in round-robin
510                    let alt_idx =
511                        (entry.retry_count - 1) as usize % entry.alternative_peers.len().max(1);
512                    if alt_idx < entry.alternative_peers.len() {
513                        entry.alternative_peers[alt_idx].clone()
514                    } else {
515                        entry.from.clone()
516                    }
517                };
518
519                expired.push(ExpiredGraft {
520                    message_id,
521                    peer,
522                    round: entry.round,
523                    retry_count: entry.retry_count,
524                });
525
526                entry.retry_count += 1;
527
528                if entry.retry_count >= self.max_retries {
529                    // Max retries exceeded, record failure for zombie detection
530                    failed.push(FailedGraft {
531                        message_id,
532                        original_peer: entry.from.clone(),
533                        total_retries: entry.retry_count,
534                    });
535
536                    // Record failure metric
537                    #[cfg(feature = "metrics")]
538                    crate::metrics::record_graft_failed();
539
540                    // Mark for removal
541                    to_remove.push(message_id);
542                } else {
543                    // Schedule next retry with backoff
544                    let backoff = self.calculate_backoff(entry.retry_count);
545                    let new_timeout = now + backoff;
546                    entry.next_retry = new_timeout;
547                    // Mark for rescheduling
548                    to_reschedule.push((message_id, new_timeout));
549                }
550            }
551        }
552
553        // Apply deferred updates
554        for id in to_remove {
555            inner.entries.remove(&id);
556        }
557        for (id, timeout) in to_reschedule {
558            Self::add_to_timeout_index(&mut inner, timeout, id);
559        }
560
561        (expired, failed)
562    }
563
564    /// Clear all pending entries.
565    pub fn clear(&self) {
566        let mut inner = self.inner.lock();
567        inner.entries.clear();
568        inner.timeouts.clear();
569    }
570
571    /// Get the number of pending entries.
572    pub fn pending_count(&self) -> usize {
573        self.inner.lock().entries.len()
574    }
575
576    /// Get the base timeout.
577    pub fn base_timeout(&self) -> Duration {
578        self.base_timeout
579    }
580
581    /// Get the max retries.
582    pub fn max_retries(&self) -> u32 {
583        self.max_retries
584    }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: the Graft timer tests below use real `std::thread::sleep` timing.
    // In the backoff and alternatives tests the margin between a deadline and
    // the check is only ~10ms, so they may be flaky on heavily loaded machines.

    #[test]
    fn test_ihave_queue_push_pop() {
        let queue = IHaveQueue::new(100);

        let id = MessageId::new();
        assert!(queue.push(id, 0));

        let batch = queue.pop_batch(10);
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].message_id, id);
    }

    #[test]
    fn test_ihave_queue_capacity() {
        let queue = IHaveQueue::new(3);

        // Pushes beyond max_size (3) must be rejected.
        for i in 0..5 {
            let pushed = queue.push(MessageId::new(), i);
            if i < 3 {
                assert!(pushed);
            } else {
                assert!(!pushed);
            }
        }

        assert_eq!(queue.len(), 3);
    }

    #[test]
    fn test_ihave_queue_batch() {
        let queue = IHaveQueue::new(100);

        for i in 0..10 {
            queue.push(MessageId::new(), i);
        }

        let batch = queue.pop_batch(5);
        assert_eq!(batch.len(), 5);
        assert_eq!(queue.len(), 5);
    }

    #[test]
    fn test_ihave_queue_stop() {
        let queue = IHaveQueue::new(100);

        assert!(queue.push(MessageId::new(), 0));
        queue.stop();
        assert!(!queue.push(MessageId::new(), 0));
        queue.resume();
        assert!(queue.push(MessageId::new(), 0));
    }

    #[test]
    fn test_graft_timer() {
        let timer: GraftTimer<u64> = GraftTimer::new(Duration::from_millis(50));

        let id = MessageId::new();
        timer.expect_message(id, 42u64, 0);

        // Not expired yet
        let expired = timer.get_expired();
        assert!(expired.is_empty());

        // Wait for timeout (due at ~50ms; checked at ~100ms)
        std::thread::sleep(Duration::from_millis(100));

        let expired = timer.get_expired();
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0].message_id, id);
        assert_eq!(expired[0].peer, 42u64);
        assert_eq!(expired[0].retry_count, 0);
    }

    #[test]
    fn test_graft_timer_message_received() {
        let timer: GraftTimer<u64> = GraftTimer::new(Duration::from_millis(50));

        let id = MessageId::new();
        timer.expect_message(id, 42u64, 0);
        // Receiving the payload cancels the pending Graft.
        timer.message_received(&id);

        std::thread::sleep(Duration::from_millis(100));

        let expired = timer.get_expired();
        assert!(expired.is_empty());
    }

    #[test]
    fn test_graft_timer_backoff() {
        // base 20ms, cap 160ms, 3 attempts in total
        let timer: GraftTimer<u64> =
            GraftTimer::with_backoff(Duration::from_millis(20), Duration::from_millis(160), 3);

        let id = MessageId::new();
        timer.expect_message(id, 1u64, 0);

        // First expiry after base timeout
        // (due ~20ms, checked ~30ms; next attempt scheduled at ~30 + 40 = 70ms)
        std::thread::sleep(Duration::from_millis(30));
        let expired = timer.get_expired();
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0].retry_count, 0);

        // Second expiry should be after 2x base timeout (40ms)
        // (checked at ~60ms, before the ~70ms deadline)
        std::thread::sleep(Duration::from_millis(30));
        let expired = timer.get_expired();
        assert!(expired.is_empty()); // Not yet

        // (checked at ~80ms; next attempt scheduled at ~80 + 80 = 160ms)
        std::thread::sleep(Duration::from_millis(20));
        let expired = timer.get_expired();
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0].retry_count, 1);

        // Third expiry after 4x base timeout (80ms)
        // (checked at ~170ms; this is the 3rd attempt, so the entry is dropped)
        std::thread::sleep(Duration::from_millis(90));
        let expired = timer.get_expired();
        assert_eq!(expired.len(), 1);
        assert_eq!(expired[0].retry_count, 2);

        // After max retries, entry should be removed
        assert_eq!(timer.pending_count(), 0);
    }

    #[test]
    fn test_graft_timer_alternatives() {
        // base 20ms, cap 200ms, 4 attempts in total
        let timer: GraftTimer<u64> =
            GraftTimer::with_backoff(Duration::from_millis(20), Duration::from_millis(200), 4);

        let id = MessageId::new();
        let primary = 1u64;
        let alt1 = 2u64;
        let alt2 = 3u64;
        timer.expect_message_with_alternatives(id, primary, vec![alt1, alt2], 0);

        // First try: primary peer
        // (checked ~30ms; next due ~30 + 40 = 70ms)
        std::thread::sleep(Duration::from_millis(30));
        let expired = timer.get_expired();
        assert_eq!(expired[0].peer, primary);

        // Second try: first alternative
        // (checked ~80ms; next due ~80 + 80 = 160ms)
        std::thread::sleep(Duration::from_millis(50));
        let expired = timer.get_expired();
        assert_eq!(expired[0].peer, alt1);

        // Third try: second alternative
        // (checked ~170ms; next due ~170 + 160 = 330ms)
        std::thread::sleep(Duration::from_millis(90));
        let expired = timer.get_expired();
        assert_eq!(expired[0].peer, alt2);

        // Fourth try: back to first alternative (round-robin)
        // (checked ~340ms)
        std::thread::sleep(Duration::from_millis(170));
        let expired = timer.get_expired();
        assert_eq!(expired[0].peer, alt1);
    }

    #[test]
    fn test_scheduler() {
        let scheduler = IHaveScheduler::new(Duration::from_millis(100), 16, 1000);

        scheduler.queue().push(MessageId::new(), 0);
        scheduler.queue().push(MessageId::new(), 1);

        let batch = scheduler.pop_batch();
        assert_eq!(batch.len(), 2);
    }
}