nomad_protocol/transport/
connection.rs1use std::net::SocketAddr;
6use std::time::Instant;
7
8use super::frame::SessionId;
9use super::migration::MigrationState;
10use super::pacing::{FramePacer, RetransmitController};
11use super::timing::{RttEstimator, TimestampTracker};
12
/// Lifecycle phase of a connection, stored in `ConnectionState::phase`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionPhase {
    /// Handshake still in progress (the state built by `ConnectionState::handshaking`).
    Handshaking,
    /// Handshake finished; set by `ConnectionState::new` and `complete_handshake`.
    Established,
    /// Close has been requested via `ConnectionState::close`.
    Closing,
    /// Connection has been marked closed via `ConnectionState::mark_closed`.
    Closed,
    /// Connection has been marked failed via `ConnectionState::mark_failed`.
    Failed,
}
27
/// Number of 64-bit words backing the [`NonceWindow`] replay bitmap.
const NONCE_WINDOW_WORDS: usize = 32;

/// Sliding-window replay filter for received nonces.
///
/// Tracks the highest nonce accepted so far plus a bitmap covering the
/// [`NonceWindow::WINDOW_SIZE`] nonces immediately below it. Every nonce is
/// accepted at most once. Nonces more than `WINDOW_SIZE` below the highest
/// accepted nonce are rejected.
#[derive(Debug, Clone)]
pub struct NonceWindow {
    /// Highest nonce accepted so far, or `None` before the first nonce arrives.
    ///
    /// An explicit `None` (rather than a `0` sentinel) lets nonce `0` be
    /// accepted exactly once, including as the very first nonce.
    highest: Option<u64>,
    /// Replay bitmap. Bit `k` (word `k / 64`, bit `k % 64`) is set when nonce
    /// `highest - 1 - k` has been accepted. The highest nonce itself is
    /// recorded by `highest`, not by the bitmap.
    window: [u64; NONCE_WINDOW_WORDS],
}

impl Default for NonceWindow {
    fn default() -> Self {
        Self::new()
    }
}

impl NonceWindow {
    /// Number of nonces below the highest accepted nonce that are tracked.
    pub const WINDOW_SIZE: usize = NONCE_WINDOW_WORDS * 64;

    /// Creates an empty window that has not accepted any nonce yet.
    pub fn new() -> Self {
        Self {
            highest: None,
            window: [0; NONCE_WINDOW_WORDS],
        }
    }

    /// Checks `nonce` against the window and records it if it is fresh.
    ///
    /// Returns `true` if `nonce` has not been seen before and is recent enough
    /// to be tracked; the nonce is then marked as seen. Returns `false` for a
    /// replay, and for any nonce more than [`Self::WINDOW_SIZE`] below the
    /// highest accepted nonce.
    pub fn check_and_mark(&mut self, nonce: u64) -> bool {
        let highest = match self.highest {
            Some(highest) => highest,
            None => {
                self.highest = Some(nonce);
                return true;
            }
        };

        if nonce > highest {
            self.shift_window(nonce - highest);
            self.highest = Some(nonce);
            return true;
        }

        // Compare in u64 before casting so large gaps cannot be truncated on
        // 32-bit targets. `offset == 0` means `nonce == highest`, always a replay.
        let offset = highest - nonce;
        if offset == 0 || offset > Self::WINDOW_SIZE as u64 {
            return false;
        }

        // `offset` is in 1..=WINDOW_SIZE, so this cast cannot truncate.
        let bit = (offset - 1) as usize;
        let (word, mask) = (bit / 64, 1u64 << (bit % 64));
        if self.window[word] & mask != 0 {
            false
        } else {
            self.window[word] |= mask;
            true
        }
    }

    /// Slides the window forward by `shift` (> 0) for a new highest nonce.
    ///
    /// Every tracked nonce moves `shift` positions further from the new
    /// highest. Entries that leave the window are dropped. The previous
    /// highest nonce is recorded at offset `shift` (bit `shift - 1`).
    fn shift_window(&mut self, shift: u64) {
        if shift > Self::WINDOW_SIZE as u64 {
            // Everything, including the previous highest, is now out of range.
            self.window = [0; NONCE_WINDOW_WORDS];
            return;
        }

        // `shift` is in 1..=WINDOW_SIZE, so this cast cannot truncate.
        let shift = shift as usize;
        let word_shift = shift / 64;
        let bit_shift = shift % 64;

        if word_shift > 0 {
            // Move whole words toward higher indices (older nonces). Words
            // pushed past the end are discarded.
            self.window
                .copy_within(..NONCE_WINDOW_WORDS - word_shift, word_shift);
            self.window[..word_shift].fill(0);
        }

        if bit_shift > 0 {
            // Shift bits toward higher offsets. The top bits of each word
            // carry into the low bits of the next (older) word. Bits carried
            // out of the last word fall outside the window.
            let mut carry = 0u64;
            for word in self.window.iter_mut() {
                let spill = *word >> (64 - bit_shift);
                *word = (*word << bit_shift) | carry;
                carry = spill;
            }
        }

        // The previous highest nonce is now exactly `shift` below the new one.
        // This also covers `shift == WINDOW_SIZE` (bit WINDOW_SIZE - 1).
        let bit = shift - 1;
        self.window[bit / 64] |= 1u64 << (bit % 64);
    }
}
145
146#[derive(Debug)]
148pub struct ConnectionState {
149 pub session_id: SessionId,
151 pub phase: ConnectionPhase,
153 pub remote_endpoint: SocketAddr,
155 pub last_received: Instant,
157 pub epoch: u32,
159
160 pub send_nonce: u64,
162 pub recv_nonce_window: NonceWindow,
164
165 pub rtt: RttEstimator,
167 pub timestamps: TimestampTracker,
169 pub pacer: FramePacer,
171 pub retransmit: RetransmitController,
173 pub migration: MigrationState,
175
176 pub local_state_version: u64,
178 pub remote_state_version: u64,
180 pub acked_state_version: u64,
182}
183
184impl ConnectionState {
185 pub fn new(session_id: SessionId, remote_endpoint: SocketAddr) -> Self {
187 let now = Instant::now();
188 Self {
189 session_id,
190 phase: ConnectionPhase::Established,
191 remote_endpoint,
192 last_received: now,
193 epoch: 0,
194
195 send_nonce: 0,
196 recv_nonce_window: NonceWindow::new(),
197
198 rtt: RttEstimator::new(),
199 timestamps: TimestampTracker::new(),
200 pacer: FramePacer::new(),
201 retransmit: RetransmitController::new(super::timing::constants::INITIAL_RTO),
202 migration: MigrationState::new(remote_endpoint),
203
204 local_state_version: 0,
205 remote_state_version: 0,
206 acked_state_version: 0,
207 }
208 }
209
210 pub fn handshaking(remote_endpoint: SocketAddr) -> Self {
212 let now = Instant::now();
213 Self {
214 session_id: SessionId::zero(),
215 phase: ConnectionPhase::Handshaking,
216 remote_endpoint,
217 last_received: now,
218 epoch: 0,
219
220 send_nonce: 0,
221 recv_nonce_window: NonceWindow::new(),
222
223 rtt: RttEstimator::new(),
224 timestamps: TimestampTracker::new(),
225 pacer: FramePacer::new(),
226 retransmit: RetransmitController::new(super::timing::constants::INITIAL_RTO),
227 migration: MigrationState::new(remote_endpoint),
228
229 local_state_version: 0,
230 remote_state_version: 0,
231 acked_state_version: 0,
232 }
233 }
234
235 pub fn next_send_nonce(&mut self) -> u64 {
237 let nonce = self.send_nonce;
238 self.send_nonce = self.send_nonce.saturating_add(1);
239 nonce
240 }
241
242 pub fn check_recv_nonce(&mut self, nonce: u64) -> bool {
244 self.recv_nonce_window.check_and_mark(nonce)
245 }
246
247 pub fn on_authenticated_frame(&mut self, from: SocketAddr) {
249 self.last_received = Instant::now();
250
251 if from != self.remote_endpoint && self.migration.validate_address(from) {
253 self.remote_endpoint = from;
254 }
255 }
256
257 pub fn is_alive(&self) -> bool {
259 !self.pacer.is_connection_dead(self.last_received) && !self.retransmit.is_failed()
260 }
261
262 pub fn is_failed(&self) -> bool {
264 self.phase == ConnectionPhase::Failed
265 || self.pacer.is_connection_dead(self.last_received)
266 || self.retransmit.is_failed()
267 }
268
269 pub fn has_unacked_data(&self) -> bool {
271 self.local_state_version > self.acked_state_version
272 }
273
274 pub fn on_ack(&mut self, acked_version: u64) {
276 if acked_version > self.acked_state_version {
277 self.acked_state_version = acked_version;
278 self.retransmit.on_ack();
279 }
280 }
281
282 pub fn close(&mut self) {
284 self.phase = ConnectionPhase::Closing;
285 }
286
287 pub fn mark_closed(&mut self) {
289 self.phase = ConnectionPhase::Closed;
290 }
291
292 pub fn mark_failed(&mut self) {
294 self.phase = ConnectionPhase::Failed;
295 }
296
297 pub fn complete_handshake(&mut self, session_id: SessionId) {
299 self.session_id = session_id;
300 self.phase = ConnectionPhase::Established;
301 self.timestamps = TimestampTracker::new(); }
303
304 pub fn on_rekey(&mut self) {
306 self.epoch = self.epoch.saturating_add(1);
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Loopback (127.0.0.1) address on `port`, used as a stand-in peer.
    fn test_addr(port: u16) -> SocketAddr {
        SocketAddr::from((Ipv4Addr::LOCALHOST, port))
    }

    /// Builds an established connection to a loopback peer.
    fn established() -> ConnectionState {
        ConnectionState::new(SessionId::zero(), test_addr(8080))
    }

    #[test]
    fn test_nonce_window_new() {
        let mut window = NonceWindow::new();
        // First sighting passes; an immediate replay does not; the next nonce passes.
        assert!(window.check_and_mark(1));
        assert!(!window.check_and_mark(1));
        assert!(window.check_and_mark(2));
    }

    #[test]
    fn test_nonce_window_gap() {
        let mut window = NonceWindow::new();
        // A forward jump, then two out-of-order nonces inside the window.
        for &fresh in &[1u64, 100, 50, 75] {
            assert!(window.check_and_mark(fresh));
        }
        // Both of these were already seen.
        for &replayed in &[50u64, 100] {
            assert!(!window.check_and_mark(replayed));
        }
    }

    #[test]
    fn test_nonce_window_too_old() {
        let mut window = NonceWindow::new();
        assert!(window.check_and_mark(3000));
        // Both lie more than WINDOW_SIZE below 3000.
        for &stale in &[1u64, 500] {
            assert!(!window.check_and_mark(stale));
        }
    }

    #[test]
    fn test_connection_state_nonces() {
        let mut conn = established();
        for expected in 0..3u64 {
            assert_eq!(conn.next_send_nonce(), expected);
        }
        assert_eq!(conn.send_nonce, 3);
    }

    #[test]
    fn test_connection_state_lifecycle() {
        let mut conn = ConnectionState::handshaking(test_addr(8080));
        assert_eq!(conn.phase, ConnectionPhase::Handshaking);

        let session_id = SessionId::from_bytes([1, 2, 3, 4, 5, 6]);
        conn.complete_handshake(session_id);
        assert_eq!(conn.phase, ConnectionPhase::Established);
        assert_eq!(conn.session_id, session_id);

        conn.close();
        assert_eq!(conn.phase, ConnectionPhase::Closing);

        conn.mark_closed();
        assert_eq!(conn.phase, ConnectionPhase::Closed);
    }

    #[test]
    fn test_connection_state_ack() {
        let mut conn = established();
        conn.local_state_version = 10;
        assert!(conn.has_unacked_data());

        // A partial ack leaves data outstanding; acking the latest version clears it.
        for &(acked, still_pending) in &[(5u64, true), (10u64, false)] {
            conn.on_ack(acked);
            assert_eq!(conn.acked_state_version, acked);
            assert_eq!(conn.has_unacked_data(), still_pending);
        }
    }

    #[test]
    fn test_connection_alive_check() {
        let conn = established();
        assert!(conn.is_alive());
        assert!(!conn.is_failed());
    }
}