// nomad_protocol/transport/connection.rs

//! Connection state management for the NOMAD transport layer.
//!
//! Implements the connection state machine described in 2-TRANSPORT.md.
4
5use std::net::SocketAddr;
6use std::time::Instant;
7
8use super::frame::SessionId;
9use super::migration::MigrationState;
10use super::pacing::{FramePacer, RetransmitController};
11use super::timing::{RttEstimator, TimestampTracker};
12
/// Where a connection currently is in its lifecycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionPhase {
    /// The handshake has started but has not completed yet.
    Handshaking,
    /// The handshake is done and data transfer is active.
    Established,
    /// A graceful shutdown has been initiated.
    Closing,
    /// The connection has been shut down.
    Closed,
    /// The connection broke down (timeout, too many retransmits, etc).
    Failed,
}
27
/// Number of `u64` words backing the [`NonceWindow`] bitfield.
const NONCE_WINDOW_WORDS: usize = 32;

/// Anti-replay window using a bitfield.
///
/// Tracks received nonces to detect and reject replayed frames. It remembers
/// the highest nonce accepted so far and, for each of the
/// [`NonceWindow::WINDOW_SIZE`] nonces directly below it, whether that nonce
/// has already been accepted. Nonces further below the highest are rejected
/// as too old. The 2048-bit window follows the size recommended by the spec.
#[derive(Debug, Clone)]
pub struct NonceWindow {
    /// The highest nonce accepted so far, or `None` before the first nonce.
    highest: Option<u64>,
    /// Bitfield for nonces below `highest`: bit `i` (bit `i % 64` of word
    /// `i / 64`) is set once nonce `highest - 1 - i` has been accepted.
    window: [u64; NONCE_WINDOW_WORDS],
}

impl Default for NonceWindow {
    fn default() -> Self {
        Self::new()
    }
}

impl NonceWindow {
    /// Number of nonces below the highest one that the window tracks.
    pub const WINDOW_SIZE: usize = NONCE_WINDOW_WORDS * 64;

    /// Create an empty window that has not seen any nonce yet.
    pub fn new() -> Self {
        Self {
            highest: None,
            window: [0; NONCE_WINDOW_WORDS],
        }
    }

    /// Check if a nonce is valid (not replayed) and mark it as seen.
    ///
    /// Returns `true` if the nonce should be accepted. That is the case when:
    /// - it is the first nonce ever seen (any value, including `0`);
    /// - it is above the highest nonce seen so far;
    /// - it is at most [`Self::WINDOW_SIZE`] below the highest and has not
    ///   been seen before.
    ///
    /// Returns `false` for replays and for nonces too old to be tracked.
    pub fn check_and_mark(&mut self, nonce: u64) -> bool {
        let highest = match self.highest {
            Some(highest) => highest,
            None => {
                // Nothing seen yet. Tracking this with `Option` instead of a
                // 0 sentinel means nonce 0 (the first value the sender's
                // counter produces) is accepted, not treated as a duplicate.
                self.highest = Some(nonce);
                return true;
            }
        };

        if nonce > highest {
            // New highest nonce: slide the window, remembering the old highest.
            self.shift_window(nonce - highest);
            self.highest = Some(nonce);
            return true;
        }
        if nonce == highest {
            // Duplicate of the highest.
            return false;
        }

        // Keep the distance in u64: casting to usize first could truncate on
        // 32-bit targets and alias a very old nonce into the window.
        let distance = highest - nonce;
        if distance > Self::WINDOW_SIZE as u64 {
            // Too old, outside our window.
            return false;
        }

        // distance is in 1..=WINDOW_SIZE, so the bit index fits in usize.
        let bit = (distance - 1) as usize;
        let mask = 1u64 << (bit % 64);
        let word = &mut self.window[bit / 64];
        if *word & mask != 0 {
            // Already seen.
            false
        } else {
            *word |= mask;
            true
        }
    }

    /// Slide the window after the highest nonce grew by `shift` (>= 1).
    ///
    /// Every tracked nonce moves `shift` positions further from the new
    /// highest (bit `i` -> bit `i + shift`). Bits pushed past the end are
    /// dropped. The previous highest then lands on bit `shift - 1` and is
    /// marked as seen if that bit is still inside the window.
    fn shift_window(&mut self, shift: u64) {
        if shift > Self::WINDOW_SIZE as u64 {
            // Every tracked nonce, including the old highest, is now too old.
            self.window = [0; NONCE_WINDOW_WORDS];
            return;
        }

        // shift <= WINDOW_SIZE here, so the conversion is lossless.
        let shift = shift as usize;
        let word_shift = shift / 64;
        let bit_shift = shift % 64;

        // Moving marks to higher bit positions means shifting each word
        // *left* and carrying its top bits into the next-higher word. Reading
        // from a copy ensures no source word is clobbered mid-update.
        let old = self.window;
        for (i, word) in self.window.iter_mut().enumerate() {
            let upper = i.checked_sub(word_shift).map_or(0, |j| old[j]);
            *word = if bit_shift == 0 {
                upper
            } else {
                let lower = i.checked_sub(word_shift + 1).map_or(0, |j| old[j]);
                (upper << bit_shift) | (lower >> (64 - bit_shift))
            };
        }

        if shift > 0 {
            // The old highest is now `shift` below the new one, i.e. bit
            // `shift - 1`. For shift == WINDOW_SIZE this is the last bit; it
            // must still be set, or the old highest could be replayed.
            let bit = shift - 1;
            self.window[bit / 64] |= 1u64 << (bit % 64);
        }
    }
}
145
146/// Full connection state as specified in 2-TRANSPORT.md.
147#[derive(Debug)]
148pub struct ConnectionState {
149    /// Session identifier from handshake.
150    pub session_id: SessionId,
151    /// Current connection phase.
152    pub phase: ConnectionPhase,
153    /// Remote peer address (may change during migration).
154    pub remote_endpoint: SocketAddr,
155    /// When we last received an authenticated frame.
156    pub last_received: Instant,
157    /// Current epoch (increments on rekey).
158    pub epoch: u32,
159
160    /// Outbound nonce counter (monotonically increasing).
161    pub send_nonce: u64,
162    /// Inbound anti-replay window.
163    pub recv_nonce_window: NonceWindow,
164
165    /// RTT estimation.
166    pub rtt: RttEstimator,
167    /// Timestamp tracking for RTT measurement.
168    pub timestamps: TimestampTracker,
169    /// Frame pacing.
170    pub pacer: FramePacer,
171    /// Retransmission control.
172    pub retransmit: RetransmitController,
173    /// Migration state.
174    pub migration: MigrationState,
175
176    /// Highest state version we've sent.
177    pub local_state_version: u64,
178    /// Highest state version we've acknowledged from peer.
179    pub remote_state_version: u64,
180    /// Highest state version the peer has acknowledged from us.
181    pub acked_state_version: u64,
182}
183
184impl ConnectionState {
185    /// Create a new connection state for an established session.
186    pub fn new(session_id: SessionId, remote_endpoint: SocketAddr) -> Self {
187        let now = Instant::now();
188        Self {
189            session_id,
190            phase: ConnectionPhase::Established,
191            remote_endpoint,
192            last_received: now,
193            epoch: 0,
194
195            send_nonce: 0,
196            recv_nonce_window: NonceWindow::new(),
197
198            rtt: RttEstimator::new(),
199            timestamps: TimestampTracker::new(),
200            pacer: FramePacer::new(),
201            retransmit: RetransmitController::new(super::timing::constants::INITIAL_RTO),
202            migration: MigrationState::new(remote_endpoint),
203
204            local_state_version: 0,
205            remote_state_version: 0,
206            acked_state_version: 0,
207        }
208    }
209
210    /// Create a connection state in handshaking phase.
211    pub fn handshaking(remote_endpoint: SocketAddr) -> Self {
212        let now = Instant::now();
213        Self {
214            session_id: SessionId::zero(),
215            phase: ConnectionPhase::Handshaking,
216            remote_endpoint,
217            last_received: now,
218            epoch: 0,
219
220            send_nonce: 0,
221            recv_nonce_window: NonceWindow::new(),
222
223            rtt: RttEstimator::new(),
224            timestamps: TimestampTracker::new(),
225            pacer: FramePacer::new(),
226            retransmit: RetransmitController::new(super::timing::constants::INITIAL_RTO),
227            migration: MigrationState::new(remote_endpoint),
228
229            local_state_version: 0,
230            remote_state_version: 0,
231            acked_state_version: 0,
232        }
233    }
234
235    /// Get the next nonce for sending and increment the counter.
236    pub fn next_send_nonce(&mut self) -> u64 {
237        let nonce = self.send_nonce;
238        self.send_nonce = self.send_nonce.saturating_add(1);
239        nonce
240    }
241
242    /// Check if a received nonce is valid (not replayed).
243    pub fn check_recv_nonce(&mut self, nonce: u64) -> bool {
244        self.recv_nonce_window.check_and_mark(nonce)
245    }
246
247    /// Update state after receiving an authenticated frame.
248    pub fn on_authenticated_frame(&mut self, from: SocketAddr) {
249        self.last_received = Instant::now();
250
251        // Handle potential migration
252        if from != self.remote_endpoint && self.migration.validate_address(from) {
253            self.remote_endpoint = from;
254        }
255    }
256
257    /// Check if the connection is still alive.
258    pub fn is_alive(&self) -> bool {
259        !self.pacer.is_connection_dead(self.last_received) && !self.retransmit.is_failed()
260    }
261
262    /// Check if the connection has failed.
263    pub fn is_failed(&self) -> bool {
264        self.phase == ConnectionPhase::Failed
265            || self.pacer.is_connection_dead(self.last_received)
266            || self.retransmit.is_failed()
267    }
268
269    /// Check if there's unacknowledged data.
270    pub fn has_unacked_data(&self) -> bool {
271        self.local_state_version > self.acked_state_version
272    }
273
274    /// Update the acked state version.
275    pub fn on_ack(&mut self, acked_version: u64) {
276        if acked_version > self.acked_state_version {
277            self.acked_state_version = acked_version;
278            self.retransmit.on_ack();
279        }
280    }
281
282    /// Transition to closed state.
283    pub fn close(&mut self) {
284        self.phase = ConnectionPhase::Closing;
285    }
286
287    /// Mark as fully closed.
288    pub fn mark_closed(&mut self) {
289        self.phase = ConnectionPhase::Closed;
290    }
291
292    /// Mark as failed.
293    pub fn mark_failed(&mut self) {
294        self.phase = ConnectionPhase::Failed;
295    }
296
297    /// Complete handshake and transition to established.
298    pub fn complete_handshake(&mut self, session_id: SessionId) {
299        self.session_id = session_id;
300        self.phase = ConnectionPhase::Established;
301        self.timestamps = TimestampTracker::new(); // Reset timestamps
302    }
303
304    /// Increment epoch (on rekey).
305    pub fn on_rekey(&mut self) {
306        self.epoch = self.epoch.saturating_add(1);
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Loopback (127.0.0.1) socket address on `port`.
    fn test_addr(port: u16) -> SocketAddr {
        SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port)
    }

    /// Feed `steps` into `window` in order, checking each verdict.
    fn assert_sequence(window: &mut NonceWindow, steps: &[(u64, bool)]) {
        for &(nonce, expected) in steps {
            assert_eq!(window.check_and_mark(nonce), expected, "nonce {}", nonce);
        }
    }

    #[test]
    fn test_nonce_window_new() {
        // First nonce accepted, its duplicate rejected, the successor accepted.
        let mut window = NonceWindow::new();
        assert_sequence(&mut window, &[(1, true), (1, false), (2, true)]);
    }

    #[test]
    fn test_nonce_window_gap() {
        // Skipped-over nonces remain acceptable exactly once; the highest
        // cannot be repeated.
        let mut window = NonceWindow::new();
        assert_sequence(
            &mut window,
            &[
                (1, true),
                (100, true),
                (50, true),
                (75, true),
                (50, false),
                (100, false),
            ],
        );
    }

    #[test]
    fn test_nonce_window_too_old() {
        // 3000 - 1 and 3000 - 500 = 2500 both exceed the 2048-nonce window.
        let mut window = NonceWindow::new();
        assert_sequence(&mut window, &[(3000, true), (1, false), (500, false)]);
    }

    #[test]
    fn test_connection_state_nonces() {
        let mut conn = ConnectionState::new(SessionId::zero(), test_addr(8080));

        let issued: Vec<u64> = (0..3).map(|_| conn.next_send_nonce()).collect();
        assert_eq!(issued, vec![0, 1, 2]);
        assert_eq!(conn.send_nonce, 3);
    }

    #[test]
    fn test_connection_state_lifecycle() {
        let mut conn = ConnectionState::handshaking(test_addr(8080));
        assert_eq!(conn.phase, ConnectionPhase::Handshaking);

        let session_id = SessionId::from_bytes([1, 2, 3, 4, 5, 6]);
        conn.complete_handshake(session_id);
        assert_eq!(conn.phase, ConnectionPhase::Established);
        assert_eq!(conn.session_id, session_id);

        conn.close();
        assert_eq!(conn.phase, ConnectionPhase::Closing);

        conn.mark_closed();
        assert_eq!(conn.phase, ConnectionPhase::Closed);
    }

    #[test]
    fn test_connection_state_ack() {
        let mut conn = ConnectionState::new(SessionId::zero(), test_addr(8080));
        conn.local_state_version = 10;
        assert!(conn.has_unacked_data());

        // (ack value, data still pending afterwards)
        for &(ack, still_pending) in &[(5u64, true), (10u64, false)] {
            conn.on_ack(ack);
            assert_eq!(conn.acked_state_version, ack);
            assert_eq!(conn.has_unacked_data(), still_pending);
        }
    }

    #[test]
    fn test_connection_alive_check() {
        let conn = ConnectionState::new(SessionId::zero(), test_addr(8080));
        assert!(conn.is_alive());
        assert!(!conn.is_failed());
    }
}
420}