Skip to main content

fips_core/noise/
replay.rs

1use super::REPLAY_WINDOW_SIZE;
2use std::fmt;
3
4/// Sliding window for replay protection.
5///
6/// Tracks which packet counters have been received within a window of
7/// REPLAY_WINDOW_SIZE. Packets with counters below the window or already
8/// seen within the window are rejected.
9///
10/// Based on WireGuard's anti-replay mechanism (RFC 6479 style).
11#[derive(Clone)]
12pub struct ReplayWindow {
13    /// Highest counter value seen.
14    highest: u64,
15    /// Bitmap tracking which counters in the window have been seen.
16    /// Bit i corresponds to counter (highest - i).
17    bitmap: [u64; REPLAY_WINDOW_SIZE / 64],
18}
19
20impl ReplayWindow {
21    /// Create a new replay window.
22    pub fn new() -> Self {
23        Self {
24            highest: 0,
25            bitmap: [0; REPLAY_WINDOW_SIZE / 64],
26        }
27    }
28
29    /// Check if a counter is valid (not replayed, not too old).
30    ///
31    /// Returns true if the counter is acceptable, false if it should be rejected.
32    /// Does NOT update the window - call `accept` after successful decryption.
33    pub fn check(&self, counter: u64) -> bool {
34        if counter > self.highest {
35            // New highest - always acceptable
36            return true;
37        }
38
39        // Counter is <= highest, check if it's within the window
40        let diff = self.highest - counter;
41        if diff as usize >= REPLAY_WINDOW_SIZE {
42            // Too old (outside window)
43            return false;
44        }
45
46        // Check bitmap - bit is set if counter was already seen
47        let word_idx = (diff as usize) / 64;
48        let bit_idx = (diff as usize) % 64;
49        (self.bitmap[word_idx] & (1u64 << bit_idx)) == 0
50    }
51
52    /// Accept a counter into the window.
53    ///
54    /// Call this only after successful decryption to prevent
55    /// DoS attacks that exhaust the window.
56    pub fn accept(&mut self, counter: u64) {
57        if counter > self.highest {
58            // Shift the window
59            let shift = counter - self.highest;
60            if shift as usize >= REPLAY_WINDOW_SIZE {
61                // Complete reset
62                self.bitmap = [0; REPLAY_WINDOW_SIZE / 64];
63            } else {
64                // Shift bitmap
65                self.shift_bitmap(shift as usize);
66            }
67            self.highest = counter;
68            // Mark counter 0 (which is now the highest) as seen
69            self.bitmap[0] |= 1;
70        } else {
71            // Mark the counter as seen
72            let diff = self.highest - counter;
73            let word_idx = (diff as usize) / 64;
74            let bit_idx = (diff as usize) % 64;
75            self.bitmap[word_idx] |= 1u64 << bit_idx;
76        }
77    }
78
79    /// Shift the bitmap by the given number of positions.
80    ///
81    /// This moves old counters to higher bit positions to make room for the
82    /// new highest counter at position 0.
83    fn shift_bitmap(&mut self, shift: usize) {
84        if shift >= REPLAY_WINDOW_SIZE {
85            self.bitmap = [0; REPLAY_WINDOW_SIZE / 64];
86            return;
87        }
88
89        let word_shift = shift / 64;
90        let bit_shift = shift % 64;
91
92        // Shift entire words first (from high to low to avoid overwriting)
93        if word_shift > 0 {
94            for i in (word_shift..self.bitmap.len()).rev() {
95                self.bitmap[i] = self.bitmap[i - word_shift];
96            }
97            for i in 0..word_shift {
98                self.bitmap[i] = 0;
99            }
100        }
101
102        // Shift bits within words (from low to high so carry propagates correctly)
103        if bit_shift > 0 {
104            let mut carry = 0u64;
105            for i in 0..self.bitmap.len() {
106                let new_carry = self.bitmap[i] >> (64 - bit_shift);
107                self.bitmap[i] = (self.bitmap[i] << bit_shift) | carry;
108                carry = new_carry;
109            }
110        }
111    }
112
113    /// Get the highest counter seen.
114    pub fn highest(&self) -> u64 {
115        self.highest
116    }
117
118    /// Reset the window (use when rekeying).
119    pub fn reset(&mut self) {
120        self.highest = 0;
121        self.bitmap = [0; REPLAY_WINDOW_SIZE / 64];
122    }
123}
124
125impl Default for ReplayWindow {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131impl fmt::Debug for ReplayWindow {
132    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133        f.debug_struct("ReplayWindow")
134            .field("highest", &self.highest)
135            .field("window_size", &REPLAY_WINDOW_SIZE)
136            .finish()
137    }
138}