// nomad_protocol/sync/ack.rs

//! Acknowledgment tracking
//!
//! Tracks which versions the peer has acknowledged and manages retransmission of
//! the ones still awaiting an acknowledgment.

use std::time::{Duration, Instant};

/// Retransmission state for one sent version that is awaiting acknowledgment.
#[derive(Debug, Clone)]
pub struct PendingAck {
    /// Version that needs acknowledgment
    pub version: u64,
    /// Time when the message was last sent (restarted on each retransmission)
    pub sent_at: Instant,
    /// Number of retransmissions so far (0 for a message sent only once)
    pub retransmit_count: u32,
    /// Current retransmission timeout
    pub rto: Duration,
}

impl PendingAck {
    /// Create a new pending ack for `version`, stamped with the current time.
    pub fn new(version: u64, rto: Duration) -> Self {
        Self {
            version,
            sent_at: Instant::now(),
            retransmit_count: 0,
            rto,
        }
    }

    /// Check if retransmission is needed, i.e. at least `rto` has elapsed since `sent_at`.
    pub fn needs_retransmit(&self) -> bool {
        self.sent_at.elapsed() >= self.rto
    }

    /// Mark as retransmitted with an updated timeout.
    ///
    /// Restarts the timer, increments `retransmit_count`, and multiplies `rto` by
    /// `backoff_multiplier` (exponential backoff), capped at `max_rto`.
    ///
    /// A multiplier of 0 is treated as 1 so the timeout can never collapse to zero
    /// (which would make the message due for retransmission immediately, forever).
    /// The multiplication saturates instead of panicking on overflow.
    pub fn retransmit(&mut self, backoff_multiplier: u32, max_rto: Duration) {
        self.sent_at = Instant::now();
        self.retransmit_count = self.retransmit_count.saturating_add(1);
        self.rto = self
            .rto
            .saturating_mul(backoff_multiplier.max(1))
            .min(max_rto);
    }

    /// Time until retransmission is needed (`Duration::ZERO` if already overdue).
    pub fn time_until_retransmit(&self) -> Duration {
        self.rto.saturating_sub(self.sent_at.elapsed())
    }
}

/// Default initial retransmission timeout (1 second), used until the first RTT
/// sample is available.
pub const DEFAULT_INITIAL_RTO: Duration = Duration::from_secs(1);

/// Default minimum retransmission timeout (100ms).
/// Floors the RTT-derived RTO so it never becomes too aggressive on low-latency networks.
pub const DEFAULT_MIN_RTO: Duration = Duration::from_millis(100);

/// Default maximum retransmission timeout (60 seconds).
/// Caps RTO growth, including exponential backoff during sustained packet loss.
pub const DEFAULT_MAX_RTO: Duration = Duration::from_millis(60_000);

/// Default exponential backoff multiplier for RTO (2x).
/// A message's RTO is multiplied by this factor each time it is retransmitted.
pub const DEFAULT_BACKOFF_MULTIPLIER: u32 = 2;

/// Default maximum number of retransmission attempts (10).
/// Once a version has been retransmitted this many times it is no longer
/// retransmitted and is reported by [`AckTracker::failed_versions`].
pub const DEFAULT_MAX_RETRANSMITS: u32 = 10;

74/// Acknowledgment tracker
75///
76/// Tracks pending acknowledgments and manages retransmission logic.
77#[derive(Debug)]
78pub struct AckTracker {
79    /// Currently pending acknowledgments (version -> pending ack)
80    pending: Vec<PendingAck>,
81
82    /// Highest version acknowledged by peer
83    highest_acked: u64,
84
85    /// RTO configuration
86    initial_rto: Duration,
87    min_rto: Duration,
88    max_rto: Duration,
89    backoff_multiplier: u32,
90    max_retransmits: u32,
91
92    /// Smoothed RTT and RTT variance (RFC 6298)
93    srtt: Option<Duration>,
94    rttvar: Option<Duration>,
95}
96
97impl AckTracker {
98    /// Create a new ack tracker with default settings
99    pub fn new() -> Self {
100        Self {
101            pending: Vec::new(),
102            highest_acked: 0,
103            initial_rto: DEFAULT_INITIAL_RTO,
104            min_rto: DEFAULT_MIN_RTO,
105            max_rto: DEFAULT_MAX_RTO,
106            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
107            max_retransmits: DEFAULT_MAX_RETRANSMITS,
108            srtt: None,
109            rttvar: None,
110        }
111    }
112
113    /// Create with custom RTO settings
114    pub fn with_rto(
115        initial_rto: Duration,
116        min_rto: Duration,
117        max_rto: Duration,
118        backoff_multiplier: u32,
119        max_retransmits: u32,
120    ) -> Self {
121        Self {
122            pending: Vec::new(),
123            highest_acked: 0,
124            initial_rto,
125            min_rto,
126            max_rto,
127            backoff_multiplier,
128            max_retransmits,
129            srtt: None,
130            rttvar: None,
131        }
132    }
133
134    /// Register a sent message that needs acknowledgment
135    pub fn register_sent(&mut self, version: u64) {
136        // Don't register if already pending
137        if self.pending.iter().any(|p| p.version == version) {
138            return;
139        }
140
141        let rto = self.current_rto();
142        self.pending.push(PendingAck::new(version, rto));
143    }
144
145    /// Process an incoming acknowledgment
146    ///
147    /// Returns the RTT sample if this ack is for a pending message.
148    pub fn process_ack(&mut self, acked_version: u64) -> Option<Duration> {
149        if acked_version <= self.highest_acked {
150            return None;
151        }
152
153        self.highest_acked = acked_version;
154
155        // Find and remove all pending acks up to this version
156        let mut rtt_sample = None;
157
158        self.pending.retain(|pending| {
159            if pending.version <= acked_version {
160                // Only use as RTT sample if not retransmitted
161                if pending.retransmit_count == 0 && rtt_sample.is_none() {
162                    rtt_sample = Some(pending.sent_at.elapsed());
163                }
164                false // Remove from pending
165            } else {
166                true // Keep in pending
167            }
168        });
169
170        // Update RTT estimates if we got a sample
171        if let Some(rtt) = rtt_sample {
172            self.update_rtt(rtt);
173        }
174
175        rtt_sample
176    }
177
178    /// Update RTT estimates using RFC 6298 algorithm
179    fn update_rtt(&mut self, rtt: Duration) {
180        let rtt_secs = rtt.as_secs_f64();
181
182        match (self.srtt, self.rttvar) {
183            (None, None) => {
184                // First measurement
185                self.srtt = Some(rtt);
186                self.rttvar = Some(rtt / 2);
187            }
188            (Some(srtt), Some(rttvar)) => {
189                // Subsequent measurements
190                let srtt_secs = srtt.as_secs_f64();
191                let rttvar_secs = rttvar.as_secs_f64();
192
193                // RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R'|
194                // where beta = 1/4
195                let new_rttvar =
196                    0.75 * rttvar_secs + 0.25 * (srtt_secs - rtt_secs).abs();
197
198                // SRTT = (1 - alpha) * SRTT + alpha * R'
199                // where alpha = 1/8
200                let new_srtt = 0.875 * srtt_secs + 0.125 * rtt_secs;
201
202                self.srtt = Some(Duration::from_secs_f64(new_srtt));
203                self.rttvar = Some(Duration::from_secs_f64(new_rttvar));
204            }
205            _ => {}
206        }
207    }
208
209    /// Get current RTO based on RTT estimates
210    pub fn current_rto(&self) -> Duration {
211        match (self.srtt, self.rttvar) {
212            (Some(srtt), Some(rttvar)) => {
213                // RTO = SRTT + max(G, K*RTTVAR) where K=4, G=clock granularity
214                // We use 1ms as clock granularity
215                let k = 4;
216                let g = Duration::from_millis(1);
217                let rto = srtt + (g.max(rttvar * k));
218                rto.clamp(self.min_rto, self.max_rto)
219            }
220            _ => self.initial_rto,
221        }
222    }
223
224    /// Get the smoothed RTT if available
225    pub fn srtt(&self) -> Option<Duration> {
226        self.srtt
227    }
228
229    /// Get the RTT variance if available
230    pub fn rttvar(&self) -> Option<Duration> {
231        self.rttvar
232    }
233
234    /// Get pending acks that need retransmission
235    pub fn needs_retransmit(&self) -> impl Iterator<Item = u64> + '_ {
236        self.pending
237            .iter()
238            .filter(|p| p.needs_retransmit() && p.retransmit_count < self.max_retransmits)
239            .map(|p| p.version)
240    }
241
242    /// Get versions that have exceeded max retransmits
243    pub fn failed_versions(&self) -> impl Iterator<Item = u64> + '_ {
244        self.pending
245            .iter()
246            .filter(|p| p.retransmit_count >= self.max_retransmits)
247            .map(|p| p.version)
248    }
249
250    /// Mark a version as retransmitted
251    pub fn mark_retransmitted(&mut self, version: u64) {
252        if let Some(pending) = self.pending.iter_mut().find(|p| p.version == version) {
253            pending.retransmit(self.backoff_multiplier, self.max_rto);
254        }
255    }
256
257    /// Check if there are pending acknowledgments
258    pub fn has_pending(&self) -> bool {
259        !self.pending.is_empty()
260    }
261
262    /// Get number of pending acknowledgments
263    pub fn pending_count(&self) -> usize {
264        self.pending.len()
265    }
266
267    /// Get highest acknowledged version
268    pub fn highest_acked(&self) -> u64 {
269        self.highest_acked
270    }
271
272    /// Get time until next retransmission is needed
273    pub fn time_until_retransmit(&self) -> Option<Duration> {
274        self.pending
275            .iter()
276            .filter(|p| p.retransmit_count < self.max_retransmits)
277            .map(|p| p.time_until_retransmit())
278            .min()
279    }
280
281    /// Cancel a pending ack (e.g., on connection close)
282    pub fn cancel(&mut self, version: u64) {
283        self.pending.retain(|p| p.version != version);
284    }
285
286    /// Cancel all pending acks
287    pub fn cancel_all(&mut self) {
288        self.pending.clear();
289    }
290
291    /// Reset tracker state
292    pub fn reset(&mut self) {
293        self.pending.clear();
294        self.highest_acked = 0;
295        self.srtt = None;
296        self.rttvar = None;
297    }
298}
299
300impl Default for AckTracker {
301    fn default() -> Self {
302        Self::new()
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_tracker() {
        let tracker = AckTracker::new();
        assert!(!tracker.has_pending());
        assert_eq!(tracker.highest_acked(), 0);
        assert_eq!(tracker.current_rto(), DEFAULT_INITIAL_RTO);
    }

    #[test]
    fn test_register_sent() {
        let mut tracker = AckTracker::new();

        tracker.register_sent(1);
        assert!(tracker.has_pending());
        assert_eq!(tracker.pending_count(), 1);

        // Duplicate registration should not add another
        tracker.register_sent(1);
        assert_eq!(tracker.pending_count(), 1);

        tracker.register_sent(2);
        assert_eq!(tracker.pending_count(), 2);
    }

    #[test]
    fn test_process_ack() {
        let mut tracker = AckTracker::new();

        tracker.register_sent(1);
        tracker.register_sent(2);
        tracker.register_sent(3);

        // Ack version 2 should clear 1 and 2
        tracker.process_ack(2);
        assert_eq!(tracker.highest_acked(), 2);
        assert_eq!(tracker.pending_count(), 1); // Only version 3 remains

        // Lower ack should be ignored
        tracker.process_ack(1);
        assert_eq!(tracker.highest_acked(), 2);
    }

    #[test]
    fn test_rtt_sample() {
        let mut tracker = AckTracker::new();

        tracker.register_sent(1);
        thread::sleep(Duration::from_millis(10));

        let rtt = tracker.process_ack(1);
        assert!(rtt.is_some());
        assert!(rtt.unwrap() >= Duration::from_millis(10));

        // After first sample, we should have RTT estimates
        assert!(tracker.srtt().is_some());
        assert!(tracker.rttvar().is_some());
    }

    #[test]
    fn test_retransmit() {
        let mut tracker = AckTracker::with_rto(
            Duration::from_millis(10),
            Duration::from_millis(10),
            Duration::from_secs(1),
            2,
            3,
        );

        tracker.register_sent(1);

        // Initially should not need retransmit
        assert_eq!(tracker.needs_retransmit().count(), 0);

        // Wait for RTO
        thread::sleep(Duration::from_millis(15));

        // Now should need retransmit
        let versions: Vec<_> = tracker.needs_retransmit().collect();
        assert_eq!(versions, vec![1]);

        // Mark as retransmitted
        tracker.mark_retransmitted(1);

        // Should not immediately need retransmit again
        assert_eq!(tracker.needs_retransmit().count(), 0);
    }

    #[test]
    fn test_max_retransmits() {
        let mut tracker = AckTracker::with_rto(
            Duration::from_millis(1),
            Duration::from_millis(1),
            Duration::from_millis(10),
            1, // No backoff
            2, // Max 2 retransmits
        );

        tracker.register_sent(1);
        thread::sleep(Duration::from_millis(5));

        // First retransmit
        tracker.mark_retransmitted(1);
        thread::sleep(Duration::from_millis(5));

        // Second retransmit
        tracker.mark_retransmitted(1);
        thread::sleep(Duration::from_millis(5));

        // Should now be in failed state
        let failed: Vec<_> = tracker.failed_versions().collect();
        assert_eq!(failed, vec![1]);

        // Should not show up in needs_retransmit
        assert_eq!(tracker.needs_retransmit().count(), 0);
    }

    #[test]
    fn test_cancel() {
        let mut tracker = AckTracker::new();

        tracker.register_sent(1);
        tracker.register_sent(2);
        tracker.register_sent(3);

        tracker.cancel(2);
        assert_eq!(tracker.pending_count(), 2);

        tracker.cancel_all();
        assert!(!tracker.has_pending());
    }

    #[test]
    fn test_reset() {
        let mut tracker = AckTracker::new();

        tracker.register_sent(1);
        tracker.process_ack(1);

        tracker.reset();

        assert!(!tracker.has_pending());
        assert_eq!(tracker.highest_acked(), 0);
        assert!(tracker.srtt().is_none());
    }

    #[test]
    fn test_time_until_retransmit() {
        let mut tracker = AckTracker::with_rto(
            Duration::from_millis(100),
            Duration::from_millis(100),
            Duration::from_secs(1),
            2,
            10,
        );

        assert!(tracker.time_until_retransmit().is_none());

        tracker.register_sent(1);
        let time = tracker.time_until_retransmit();
        assert!(time.is_some());
        assert!(time.unwrap() <= Duration::from_millis(100));
    }

    #[test]
    fn test_retransmitted_message_gives_no_rtt_sample() {
        // Karn's rule: an ack for a retransmitted message must not feed the RTT estimate.
        let mut tracker = AckTracker::new();
        tracker.register_sent(1);
        tracker.mark_retransmitted(1);

        assert!(tracker.process_ack(1).is_none());
        assert!(tracker.srtt().is_none());
        assert_eq!(tracker.current_rto(), DEFAULT_INITIAL_RTO);

        // The ack still clears the entry and advances the high-water mark.
        assert!(!tracker.has_pending());
        assert_eq!(tracker.highest_acked(), 1);
    }

    #[test]
    fn test_duplicate_and_stale_acks_ignored() {
        let mut tracker = AckTracker::new();
        tracker.register_sent(1);

        assert!(tracker.process_ack(1).is_some());
        assert!(tracker.process_ack(1).is_none());
        assert!(tracker.process_ack(0).is_none());
        assert_eq!(tracker.highest_acked(), 1);
    }

    #[test]
    fn test_backoff_is_capped_at_max_rto() {
        let mut tracker = AckTracker::with_rto(
            Duration::from_millis(10),
            Duration::from_millis(10),
            Duration::from_millis(15),
            2,
            5,
        );

        tracker.register_sent(1);
        assert_eq!(tracker.pending[0].rto, Duration::from_millis(10));

        // 10ms * 2 = 20ms would exceed max_rto, so it is capped at 15ms.
        tracker.mark_retransmitted(1);
        assert_eq!(tracker.pending[0].retransmit_count, 1);
        assert_eq!(tracker.pending[0].rto, Duration::from_millis(15));
    }

    #[test]
    fn test_rto_after_first_sample_is_bounded() {
        let mut tracker = AckTracker::new();
        tracker.register_sent(1);

        let rtt = tracker
            .process_ack(1)
            .expect("ack of a never-retransmitted message yields a sample");

        // RFC 6298 first measurement: SRTT = R, RTTVAR = R / 2.
        assert_eq!(tracker.srtt(), Some(rtt));
        assert_eq!(tracker.rttvar(), Some(rtt / 2));

        let rto = tracker.current_rto();
        assert!(rto >= DEFAULT_MIN_RTO && rto <= DEFAULT_MAX_RTO);
    }

    #[test]
    fn test_unknown_versions_are_noops() {
        let mut tracker = AckTracker::new();
        tracker.register_sent(1);

        tracker.mark_retransmitted(99);
        tracker.cancel(99);

        assert_eq!(tracker.pending_count(), 1);
        assert_eq!(tracker.pending[0].retransmit_count, 0);
    }

    #[test]
    fn test_default_matches_new() {
        let tracker = AckTracker::default();
        assert!(!tracker.has_pending());
        assert_eq!(tracker.highest_acked(), 0);
        assert_eq!(tracker.current_rto(), DEFAULT_INITIAL_RTO);
    }
}