Skip to main content

fips_core/mmp/
receiver.rs

1//! MMP receiver state machine.
2//!
3//! Tracks what this node has received from a specific peer and produces
4//! ReceiverReport messages on demand. One `ReceiverState` per active peer.
5
6use std::time::{Duration, Instant};
7
8use crate::mmp::algorithms::{JitterEstimator, OwdTrendDetector};
9use crate::mmp::report::ReceiverReport;
10use crate::mmp::{
11    COLD_START_SAMPLES, DEFAULT_COLD_START_INTERVAL_MS, DEFAULT_OWD_WINDOW_SIZE,
12    MAX_REPORT_INTERVAL_MS, MIN_REPORT_INTERVAL_MS,
13};
14
/// Grace period, in seconds, after a rekey before jitter calculation resumes.
///
/// During rekey cutover, frames from the old session may still arrive via the
/// drain window (DRAIN_WINDOW_SECS = 10s). These carry large sender timestamps
/// from the old session, producing enormous transit deltas that spike the EWMA
/// jitter estimator. We suppress jitter updates for drain window + 5s margin.
///
/// Applied by `ReceiverState::reset_for_rekey`, which sets the suppression
/// deadline to `now + REKEY_JITTER_GRACE_SECS`.
const REKEY_JITTER_GRACE_SECS: u64 = 15;
22
23// ============================================================================
24// Gap Tracker (burst loss detection)
25// ============================================================================
26
/// Tracks counter gaps to detect loss bursts.
///
/// Each gap in the counter sequence is a burst of lost frames.
/// Maintains per-interval statistics that are reset when a report is built.
struct GapTracker {
    /// Next expected counter value (`None` until the first frame is seen).
    expected_next: Option<u64>,
    /// Whether we are currently in a burst (gap).
    in_burst: bool,
    /// Length of the current burst, saturating at `u16::MAX`.
    current_burst_len: u16,

    // --- Per-interval stats (reset on report) ---
    /// Number of distinct burst events this interval.
    burst_count: u32,
    /// Longest burst in this interval.
    max_burst_len: u16,
    /// Sum of all burst lengths (for mean computation).
    total_burst_len: u64,
}

impl GapTracker {
    /// Create a tracker that has not yet observed any frame.
    fn new() -> Self {
        Self {
            expected_next: None,
            in_burst: false,
            current_burst_len: 0,
            burst_count: 0,
            max_burst_len: 0,
            total_burst_len: 0,
        }
    }

    /// Process a received counter value. Returns the number of lost frames
    /// detected (0 if in order, late/duplicate, or the first frame).
    ///
    /// A counter above the expected value opens a burst (or extends the one in
    /// progress); any counter at or below the expected value closes the
    /// current burst. Burst lengths saturate at `u16::MAX` rather than
    /// wrapping, while the returned loss count is the exact `u64` gap.
    fn observe(&mut self, counter: u64) -> u64 {
        // Saturate so a counter of u64::MAX cannot overflow the expectation.
        let following = counter.saturating_add(1);

        let Some(expected) = self.expected_next else {
            // First frame: initialize
            self.expected_next = Some(following);
            return 0;
        };

        let lost = if counter > expected {
            // Gap detected
            let gap = counter - expected;
            // Clamp before narrowing: a plain `gap as u16` keeps only the low
            // 16 bits, so a gap of exactly 65_536 would be recorded as 0.
            let gap_len = gap.min(u64::from(u16::MAX)) as u16;
            if self.in_burst {
                // Extend current burst
                self.current_burst_len = self.current_burst_len.saturating_add(gap_len);
            } else {
                // New burst
                self.in_burst = true;
                self.current_burst_len = gap_len;
                self.burst_count = self.burst_count.saturating_add(1);
            }
            gap
        } else {
            // In-order or duplicate (counter <= expected): end any burst.
            self.finish_burst();
            0
        };

        // Advance the expectation unless this was a late/reordered frame.
        if counter >= expected {
            self.expected_next = Some(following);
        }

        lost
    }

    /// Finish the current burst and record its stats. No-op when no burst is
    /// in progress.
    fn finish_burst(&mut self) {
        if self.in_burst {
            self.max_burst_len = self.max_burst_len.max(self.current_burst_len);
            self.total_burst_len += u64::from(self.current_burst_len);
            self.in_burst = false;
            self.current_burst_len = 0;
        }
    }

    /// Get interval stats and reset for next interval.
    ///
    /// Returns `(burst_count, max_burst_len, mean_burst_len)` where the mean
    /// is encoded as u8.8 fixed-point (mean × 256). The float-to-integer cast
    /// saturates, so a mean of 256 frames or more is reported as `u16::MAX`.
    fn take_interval_stats(&mut self) -> (u32, u16, u16) {
        // Finish any in-progress burst so it is counted in this interval.
        self.finish_burst();

        let count = self.burst_count;
        let max_len = self.max_burst_len;
        let mean_len = if count > 0 {
            // u8.8 fixed-point: (total / count) * 256
            let mean_f = (self.total_burst_len as f64) / f64::from(count);
            (mean_f * 256.0) as u16
        } else {
            0
        };

        // Reset interval
        self.burst_count = 0;
        self.max_burst_len = 0;
        self.total_burst_len = 0;

        (count, max_len, mean_len)
    }
}
133
134// ============================================================================
135// ReceiverState
136// ============================================================================
137
138/// Per-peer receiver-side MMP state.
139///
140/// Accumulates per-frame observations and produces `ReceiverReport` snapshots.
141pub struct ReceiverState {
142    // --- Cumulative (lifetime) ---
143    cumulative_packets_recv: u64,
144    cumulative_bytes_recv: u64,
145    cumulative_reorder_count: u64,
146
147    /// Highest counter value ever received.
148    highest_counter: u64,
149
150    // --- Current interval ---
151    interval_packets_recv: u32,
152    interval_bytes_recv: u32,
153
154    // --- Jitter ---
155    jitter: JitterEstimator,
156
157    // --- OWD trend ---
158    owd_trend: OwdTrendDetector,
159    /// Monotonic sequence counter for OWD samples.
160    owd_seq: u32,
161
162    // --- Loss tracking ---
163    gap_tracker: GapTracker,
164
165    // --- ECN ---
166    ecn_ce_count: u32,
167
168    // --- Timestamp echo ---
169    /// Sender timestamp from the most recent frame (for echo).
170    last_sender_timestamp: u32,
171    /// Local time when the most recent frame was received (for dwell computation).
172    last_recv_time: Option<Instant>,
173
174    // --- Rekey grace ---
175    /// When set, jitter updates are suppressed until this instant passes.
176    /// Prevents drain-window frames from spiking the jitter estimator.
177    rekey_jitter_grace_until: Option<Instant>,
178
179    // --- Report timing ---
180    last_report_time: Option<Instant>,
181    report_interval: Duration,
182    /// Whether any frames have been received since the last report.
183    interval_has_data: bool,
184
185    // --- Cold-start tracking ---
186    /// Number of SRTT-based interval updates received.
187    srtt_sample_count: u32,
188}
189
190impl ReceiverState {
191    pub fn new(owd_window_size: usize) -> Self {
192        Self::new_with_cold_start(owd_window_size, DEFAULT_COLD_START_INTERVAL_MS)
193    }
194
195    /// Create with a custom cold-start interval (ms).
196    ///
197    /// Used by session-layer MMP which needs a longer initial interval
198    /// since reports consume bandwidth on every transit link.
199    pub fn new_with_cold_start(owd_window_size: usize, cold_start_ms: u64) -> Self {
200        Self {
201            cumulative_packets_recv: 0,
202            cumulative_bytes_recv: 0,
203            cumulative_reorder_count: 0,
204            highest_counter: 0,
205            interval_packets_recv: 0,
206            interval_bytes_recv: 0,
207            jitter: JitterEstimator::new(),
208            owd_trend: OwdTrendDetector::new(owd_window_size),
209            owd_seq: 0,
210            gap_tracker: GapTracker::new(),
211            ecn_ce_count: 0,
212            last_sender_timestamp: 0,
213            last_recv_time: None,
214            rekey_jitter_grace_until: None,
215            last_report_time: None,
216            report_interval: Duration::from_millis(cold_start_ms),
217            interval_has_data: false,
218            srtt_sample_count: 0,
219        }
220    }
221
222    /// Reset counter-dependent state for rekey cutover.
223    ///
224    /// After cutover, the new session starts with counter 0 and reset
225    /// timestamps. Without resetting, the old `highest_counter` and
226    /// `GapTracker.expected_next` cause false reorder/loss detection.
227    pub fn reset_for_rekey(&mut self, now: Instant) {
228        self.highest_counter = 0;
229        self.cumulative_reorder_count = 0;
230        self.gap_tracker = GapTracker::new();
231        self.interval_packets_recv = 0;
232        self.interval_bytes_recv = 0;
233        self.jitter = JitterEstimator::new();
234        self.owd_trend.clear();
235        self.owd_seq = 0;
236        self.last_sender_timestamp = 0;
237        self.last_recv_time = None;
238        self.rekey_jitter_grace_until = Some(now + Duration::from_secs(REKEY_JITTER_GRACE_SECS));
239        self.ecn_ce_count = 0;
240        self.interval_has_data = false;
241        // Keep cumulative_packets_recv, cumulative_bytes_recv (lifetime stats)
242        // Keep last_report_time, report_interval (report scheduling)
243    }
244
245    /// Record a received frame from this peer.
246    ///
247    /// Called on the RX path after AEAD decryption, before message dispatch.
248    ///
249    /// - `counter`: AEAD counter from outer header
250    /// - `sender_timestamp_ms`: session-relative timestamp from inner header (ms)
251    /// - `bytes`: wire payload size
252    /// - `ce_flag`: CE bit from flags byte
253    /// - `now`: current local time
254    pub fn record_recv(
255        &mut self,
256        counter: u64,
257        sender_timestamp_ms: u32,
258        bytes: usize,
259        ce_flag: bool,
260        now: Instant,
261    ) {
262        self.interval_has_data = true;
263        self.cumulative_packets_recv += 1;
264        self.cumulative_bytes_recv += bytes as u64;
265        self.interval_packets_recv = self.interval_packets_recv.saturating_add(1);
266        self.interval_bytes_recv = self.interval_bytes_recv.saturating_add(bytes as u32);
267
268        // Reordering detection: counter < highest means out-of-order
269        if counter < self.highest_counter {
270            self.cumulative_reorder_count += 1;
271        } else {
272            self.highest_counter = counter;
273        }
274
275        // Loss/burst detection
276        let _lost = self.gap_tracker.observe(counter);
277
278        // ECN
279        if ce_flag {
280            self.ecn_ce_count = self.ecn_ce_count.saturating_add(1);
281        }
282
283        // Jitter: compute transit time delta
284        // Transit = recv_local - sender_timestamp (in µs for precision)
285        // We use a monotonic local reference derived from Instant offsets.
286        let sender_us = (sender_timestamp_ms as i64) * 1000;
287        // We can't get absolute µs from Instant, but we can compute the delta
288        // between consecutive transits using relative Instant differences.
289        // Skip during post-rekey grace period to avoid drain-window spikes.
290        let in_grace = self
291            .rekey_jitter_grace_until
292            .is_some_and(|deadline| now < deadline);
293        if !in_grace {
294            self.rekey_jitter_grace_until = None; // clear expired grace
295            if let Some(prev_recv) = self.last_recv_time {
296                let recv_delta_us = now.duration_since(prev_recv).as_micros() as i64;
297                let send_delta_us = sender_us - (self.last_sender_timestamp as i64 * 1000);
298                let transit_delta = (recv_delta_us - send_delta_us) as i32;
299                self.jitter.update(transit_delta);
300            }
301        }
302
303        // OWD trend: use sender timestamp as a proxy for send time
304        // and Instant delta from a fixed reference as receive time.
305        // Since we only need the *trend* (slope), absolute offsets cancel out.
306        if let Some(first_recv) = self.last_recv_time.or(Some(now)) {
307            let recv_offset_us = now.duration_since(first_recv).as_micros() as i64;
308            let owd_us = recv_offset_us - sender_us;
309            self.owd_seq = self.owd_seq.wrapping_add(1);
310            self.owd_trend.push(self.owd_seq, owd_us);
311        }
312
313        // Timestamp echo state
314        self.last_sender_timestamp = sender_timestamp_ms;
315        self.last_recv_time = Some(now);
316    }
317
318    /// Build a ReceiverReport from current state and reset the interval.
319    ///
320    /// Returns `None` if no frames have been received since the last report.
321    pub fn build_report(&mut self, now: Instant) -> Option<ReceiverReport> {
322        if !self.interval_has_data {
323            return None;
324        }
325
326        // Dwell time: ms between last frame reception and report generation
327        let dwell_time = self
328            .last_recv_time
329            .map(|t| now.duration_since(t).as_millis() as u16)
330            .unwrap_or(0);
331
332        let (burst_count, max_burst, mean_burst) = self.gap_tracker.take_interval_stats();
333
334        let report = ReceiverReport {
335            highest_counter: self.highest_counter,
336            cumulative_packets_recv: self.cumulative_packets_recv,
337            cumulative_bytes_recv: self.cumulative_bytes_recv,
338            timestamp_echo: self.last_sender_timestamp,
339            dwell_time,
340            max_burst_loss: max_burst,
341            mean_burst_loss: mean_burst,
342            jitter: self.jitter.jitter_us(),
343            ecn_ce_count: self.ecn_ce_count,
344            owd_trend: self.owd_trend.trend_us_per_sec(),
345            burst_loss_count: burst_count,
346            cumulative_reorder_count: self.cumulative_reorder_count as u32,
347            interval_packets_recv: self.interval_packets_recv,
348            interval_bytes_recv: self.interval_bytes_recv,
349        };
350
351        // Reset interval
352        self.interval_packets_recv = 0;
353        self.interval_bytes_recv = 0;
354        self.interval_has_data = false;
355        self.last_report_time = Some(now);
356
357        Some(report)
358    }
359
360    /// Check if it's time to send a report.
361    pub fn should_send_report(&self, now: Instant) -> bool {
362        if !self.interval_has_data {
363            return false;
364        }
365        match self.last_report_time {
366            None => true,
367            Some(last) => now.duration_since(last) >= self.report_interval,
368        }
369    }
370
371    /// Update the report interval based on SRTT (link-layer defaults).
372    ///
373    /// Receiver reports at 1× SRTT clamped to [floor, MAX]. During cold-start
374    /// (first `COLD_START_SAMPLES` updates), the floor is the cold-start
375    /// interval (200ms) for fast SRTT convergence. After that, it rises to
376    /// `MIN_REPORT_INTERVAL_MS` (1000ms) for steady-state efficiency.
377    pub fn update_report_interval_from_srtt(&mut self, srtt_us: i64) {
378        self.srtt_sample_count = self.srtt_sample_count.saturating_add(1);
379        let floor = if self.srtt_sample_count <= COLD_START_SAMPLES {
380            DEFAULT_COLD_START_INTERVAL_MS
381        } else {
382            MIN_REPORT_INTERVAL_MS
383        };
384        self.update_report_interval_with_bounds(srtt_us, floor, MAX_REPORT_INTERVAL_MS);
385    }
386
387    /// Update the report interval based on SRTT with custom bounds.
388    ///
389    /// Used by session-layer MMP which needs higher clamp values since
390    /// each report consumes bandwidth on every transit link.
391    pub fn update_report_interval_with_bounds(&mut self, srtt_us: i64, min_ms: u64, max_ms: u64) {
392        if srtt_us <= 0 {
393            return;
394        }
395        let interval_ms = ((srtt_us as u64) / 1000).clamp(min_ms, max_ms);
396        self.report_interval = Duration::from_millis(interval_ms);
397    }
398
399    // --- Accessors ---
400
401    pub fn cumulative_packets_recv(&self) -> u64 {
402        self.cumulative_packets_recv
403    }
404
405    pub fn cumulative_bytes_recv(&self) -> u64 {
406        self.cumulative_bytes_recv
407    }
408
409    pub fn highest_counter(&self) -> u64 {
410        self.highest_counter
411    }
412
413    pub fn jitter_us(&self) -> u32 {
414        self.jitter.jitter_us()
415    }
416
417    pub fn report_interval(&self) -> Duration {
418        self.report_interval
419    }
420
421    pub fn last_recv_time(&self) -> Option<Instant> {
422        self.last_recv_time
423    }
424
425    pub fn ecn_ce_count(&self) -> u32 {
426        self.ecn_ce_count
427    }
428}
429
430impl Default for ReceiverState {
431    fn default() -> Self {
432        Self::new(DEFAULT_OWD_WINDOW_SIZE)
433    }
434}
435
436// ============================================================================
437// Tests
438// ============================================================================
439
#[cfg(test)]
mod tests {
    //! Unit tests for `ReceiverState` and `GapTracker`.

    use super::*;

    /// A freshly constructed state reports zero packets, bytes and counter.
    #[test]
    fn test_new_receiver_state() {
        let r = ReceiverState::new(32);
        assert_eq!(r.cumulative_packets_recv(), 0);
        assert_eq!(r.cumulative_bytes_recv(), 0);
        assert_eq!(r.highest_counter(), 0);
    }

    /// Two in-order frames accumulate packet/byte totals and advance the
    /// highest counter.
    #[test]
    fn test_record_recv_basic() {
        let mut r = ReceiverState::new(32);
        let now = Instant::now();
        r.record_recv(1, 100, 500, false, now);
        r.record_recv(2, 200, 600, false, now + Duration::from_millis(100));

        assert_eq!(r.cumulative_packets_recv(), 2);
        assert_eq!(r.cumulative_bytes_recv(), 1100);
        assert_eq!(r.highest_counter(), 2);
    }

    /// A frame whose counter is below the highest seen counts as one reorder
    /// and leaves the highest counter untouched.
    #[test]
    fn test_reorder_detection() {
        let mut r = ReceiverState::new(32);
        let now = Instant::now();
        r.record_recv(5, 500, 100, false, now);
        r.record_recv(3, 300, 100, false, now + Duration::from_millis(10));

        assert_eq!(r.cumulative_reorder_count, 1);
        assert_eq!(r.highest_counter(), 5); // not changed by out-of-order
    }

    /// Only frames with the CE flag set increment the ECN counter.
    #[test]
    fn test_ecn_counting() {
        let mut r = ReceiverState::new(32);
        let now = Instant::now();
        r.record_recv(1, 100, 100, true, now);
        r.record_recv(2, 200, 100, false, now);
        r.record_recv(3, 300, 100, true, now);

        assert_eq!(r.ecn_ce_count, 2);
    }

    /// No report is produced before any frame has been recorded.
    #[test]
    fn test_build_report_empty() {
        let mut r = ReceiverState::new(32);
        assert!(r.build_report(Instant::now()).is_none());
    }

    /// A report reflects the counters, echo timestamp and interval totals of
    /// the frames recorded so far.
    #[test]
    fn test_build_report() {
        let mut r = ReceiverState::new(32);
        let t0 = Instant::now();
        r.record_recv(1, 100, 500, false, t0);
        r.record_recv(2, 200, 600, false, t0 + Duration::from_millis(100));

        let report = r.build_report(t0 + Duration::from_millis(150)).unwrap();
        assert_eq!(report.highest_counter, 2);
        assert_eq!(report.cumulative_packets_recv, 2);
        assert_eq!(report.cumulative_bytes_recv, 1100);
        assert_eq!(report.timestamp_echo, 200); // last sender timestamp
        assert_eq!(report.interval_packets_recv, 2);
        assert_eq!(report.interval_bytes_recv, 1100);
    }

    /// Building a report clears the per-interval counters while cumulative
    /// totals keep growing.
    #[test]
    fn test_build_report_resets_interval() {
        let mut r = ReceiverState::new(32);
        let t0 = Instant::now();
        r.record_recv(1, 100, 500, false, t0);
        let _ = r.build_report(t0);

        // No new data
        assert!(r.build_report(t0).is_none());

        // New data
        r.record_recv(2, 200, 300, false, t0 + Duration::from_millis(100));
        let report = r.build_report(t0 + Duration::from_millis(150)).unwrap();
        assert_eq!(report.interval_packets_recv, 1);
        assert_eq!(report.interval_bytes_recv, 300);
        // Cumulative continues
        assert_eq!(report.cumulative_packets_recv, 2);
    }

    /// Consecutive counters produce no bursts.
    #[test]
    fn test_gap_tracker_no_loss() {
        let mut g = GapTracker::new();
        g.observe(1);
        g.observe(2);
        g.observe(3);
        let (count, max, mean) = g.take_interval_stats();
        assert_eq!(count, 0);
        assert_eq!(max, 0);
        assert_eq!(mean, 0);
    }

    /// One gap of two missing counters is reported as a single burst of
    /// length 2.
    #[test]
    fn test_gap_tracker_single_burst() {
        let mut g = GapTracker::new();
        g.observe(1);
        // frames 2, 3 lost
        g.observe(4);
        g.observe(5);
        let (count, max, _mean) = g.take_interval_stats();
        assert_eq!(count, 1);
        assert_eq!(max, 2);
    }

    /// Two gaps separated by an in-order frame are counted as two bursts, and
    /// the mean is encoded in u8.8 fixed-point.
    #[test]
    fn test_gap_tracker_multiple_bursts() {
        let mut g = GapTracker::new();
        g.observe(1);
        g.observe(4); // burst of 2 (frames 2,3 lost)
        g.observe(5);
        g.observe(8); // burst of 2 (frames 6,7 lost)
        g.observe(9);
        let (count, max, mean) = g.take_interval_stats();
        assert_eq!(count, 2);
        assert_eq!(max, 2);
        // mean = 2.0 in u8.8 = 512
        assert_eq!(mean, 512);
    }

    /// Reports are due only when data is pending and either no report has
    /// been sent yet or the report interval has elapsed.
    #[test]
    fn test_should_send_report_timing() {
        let mut r = ReceiverState::new(32);
        let t0 = Instant::now();

        assert!(!r.should_send_report(t0)); // no data

        r.record_recv(1, 100, 500, false, t0);
        assert!(r.should_send_report(t0)); // first time, has data

        let _ = r.build_report(t0);
        r.record_recv(2, 200, 500, false, t0);
        assert!(!r.should_send_report(t0)); // just reported

        let t1 = t0 + r.report_interval() + Duration::from_millis(1);
        assert!(r.should_send_report(t1));
    }

    /// During cold-start the interval floor is the cold-start interval.
    #[test]
    fn test_update_report_interval_cold_start() {
        let mut r = ReceiverState::new(32);
        // During cold-start, floor is 200ms (DEFAULT_COLD_START_INTERVAL_MS)
        // 50ms SRTT → 50ms receiver interval (1× SRTT), clamped to cold-start floor 200ms
        r.update_report_interval_from_srtt(50_000);
        assert_eq!(r.report_interval(), Duration::from_millis(200));

        // 500ms SRTT → 500ms (above cold-start floor)
        r.update_report_interval_from_srtt(500_000);
        assert_eq!(r.report_interval(), Duration::from_millis(500));
    }

    /// Once `COLD_START_SAMPLES` updates have been consumed, the floor rises
    /// to `MIN_REPORT_INTERVAL_MS`.
    #[test]
    fn test_update_report_interval_after_cold_start() {
        let mut r = ReceiverState::new(32);
        // Burn through cold-start samples
        for _ in 0..COLD_START_SAMPLES {
            r.update_report_interval_from_srtt(500_000);
        }

        // First sample past cold-start: steady state, floor is
        // MIN_REPORT_INTERVAL_MS (1000ms)
        // 50ms SRTT → 50ms receiver interval (1× SRTT), clamped to 1000ms
        r.update_report_interval_from_srtt(50_000);
        assert_eq!(
            r.report_interval(),
            Duration::from_millis(MIN_REPORT_INTERVAL_MS)
        );

        // 3s SRTT → 3000ms, within [1000, 5000]
        r.update_report_interval_from_srtt(3_000_000);
        assert_eq!(r.report_interval(), Duration::from_millis(3000));
    }

    /// Frames arriving inside the post-rekey grace window must not move the
    /// jitter estimate, and resuming after the window must not produce a
    /// multi-second jitter value.
    #[test]
    fn test_rekey_jitter_grace_suppresses_spikes() {
        let mut r = ReceiverState::new(32);
        let t0 = Instant::now();

        // Establish baseline with two frames so jitter starts updating
        r.record_recv(1, 1000, 100, false, t0);
        r.record_recv(2, 2000, 100, false, t0 + Duration::from_secs(1));
        assert_eq!(r.jitter_us(), 0); // perfect 1s spacing → 0 jitter

        // Simulate rekey: reset, then send a frame with a large old-session
        // timestamp followed by a new-session timestamp near zero.
        // Without grace, this would produce a huge jitter spike.
        r.reset_for_rekey(t0 + Duration::from_secs(2));

        // Frame arrives during grace period with old-session timestamp
        r.record_recv(0, 120_000, 100, false, t0 + Duration::from_secs(3));
        // Next frame with new-session timestamp near zero
        r.record_recv(1, 100, 100, false, t0 + Duration::from_secs(4));
        // Jitter should still be zero — updates suppressed during grace
        assert_eq!(r.jitter_us(), 0);

        // After grace expires, jitter updates resume
        let after_grace =
            t0 + Duration::from_secs(2) + Duration::from_secs(REKEY_JITTER_GRACE_SECS + 1);
        r.record_recv(2, 200, 100, false, after_grace);
        r.record_recv(3, 300, 100, false, after_grace + Duration::from_millis(100));
        // Now jitter should be updating (non-zero or zero depending on timing)
        // The key assertion is that it's not a multi-second spike
        assert!(r.jitter_us() < 1_000_000); // less than 1 second
    }
}
649}