1use std::collections::HashMap;
2use std::net::IpAddr;
3
4use super::Direction;
5
/// Direction-independent identifier of a connection between two endpoints.
///
/// The two `(ip, port)` pairs are stored in sorted order (see
/// `ConnectionKey::new`), so packets travelling in either direction produce
/// equal keys that hash identically.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ConnectionKey {
    /// Address of endpoint A, the smaller `(ip, port)` pair.
    ip_a: IpAddr,
    /// Port of endpoint A.
    port_a: u16,
    /// Address of endpoint B, the larger `(ip, port)` pair.
    ip_b: IpAddr,
    /// Port of endpoint B.
    port_b: u16,
}
14
15impl ConnectionKey {
16 pub fn new(src_ip: IpAddr, src_port: u16, dst_ip: IpAddr, dst_port: u16) -> Self {
19 if (src_ip, src_port) <= (dst_ip, dst_port) {
20 Self {
21 ip_a: src_ip,
22 port_a: src_port,
23 ip_b: dst_ip,
24 port_b: dst_port,
25 }
26 } else {
27 Self {
28 ip_a: dst_ip,
29 port_a: dst_port,
30 ip_b: src_ip,
31 port_b: src_port,
32 }
33 }
34 }
35
36 pub fn direction(&self, src_ip: IpAddr, src_port: u16) -> Direction {
38 if src_ip == self.ip_a && src_port == self.port_a {
39 Direction::ToServer } else {
41 Direction::ToClient
42 }
43 }
44}
45
/// Lifecycle state of a tracked TCP connection.
///
/// The variants follow the classic TCP state names; the transitions between
/// them are driven by `ConnectionTracker::update_state`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// The first segment seen for the connection was a SYN without ACK.
    SynSent,
    /// The SYN-ACK answering the initial SYN has been seen.
    SynReceived,
    /// The handshake completed, or a mid-stream connection saw a bare ACK.
    Established,
    /// A FIN travelling toward the server was seen while established.
    FinWait1,
    /// Entered from `FinWait1` on a bare ACK.
    FinWait2,
    /// A FIN travelling toward the client was seen while established.
    CloseWait,
    /// Entered from `FinWait1` on a FIN.
    Closing,
    /// Entered from `CloseWait` on a FIN.
    LastAck,
    /// Entered from `FinWait2` on a FIN, or from `Closing` on a bare ACK.
    TimeWait,
    /// Entered from `LastAck` on a bare ACK.
    Closed,
    /// A segment carrying RST was seen.
    Reset,
    /// Tracking started on a segment that was not an initial SYN.
    MidStream,
}
63
64impl ConnectionState {
65 pub fn as_str(&self) -> &'static str {
67 match self {
68 ConnectionState::SynSent => "syn_sent",
69 ConnectionState::SynReceived => "syn_received",
70 ConnectionState::Established => "established",
71 ConnectionState::FinWait1 => "fin_wait_1",
72 ConnectionState::FinWait2 => "fin_wait_2",
73 ConnectionState::CloseWait => "close_wait",
74 ConnectionState::Closing => "closing",
75 ConnectionState::LastAck => "last_ack",
76 ConnectionState::TimeWait => "time_wait",
77 ConnectionState::Closed => "closed",
78 ConnectionState::Reset => "reset",
79 ConnectionState::MidStream => "mid_stream",
80 }
81 }
82}
83
/// The TCP control flags the connection tracker looks at.
///
/// `Default` yields a value with every flag cleared.
#[derive(Clone, Copy, Debug, Default)]
pub struct TcpFlags {
    /// SYN flag of the segment.
    pub syn: bool,
    /// ACK flag of the segment.
    pub ack: bool,
    /// FIN flag of the segment.
    pub fin: bool,
    /// RST flag of the segment.
    pub rst: bool,
}
92
93#[derive(Debug, Clone)]
95pub struct Connection {
96 pub id: u64,
97 pub key: ConnectionKey,
98 pub state: ConnectionState,
99
100 pub client_is_a: bool,
103
104 pub client_isn: u32,
106 pub server_isn: u32,
107
108 pub start_time: i64,
110 pub last_activity: i64,
111 pub end_time: Option<i64>,
112
113 pub packets_to_server: u32,
115 pub packets_to_client: u32,
116
117 pub bytes_to_server: u64,
119 pub bytes_to_client: u64,
120
121 pub first_frame: u64,
123 pub last_frame: u64,
124}
125
126impl Connection {
127 pub fn client_ip(&self) -> IpAddr {
129 if self.client_is_a {
130 self.key.ip_a
131 } else {
132 self.key.ip_b
133 }
134 }
135
136 pub fn server_ip(&self) -> IpAddr {
138 if self.client_is_a {
139 self.key.ip_b
140 } else {
141 self.key.ip_a
142 }
143 }
144
145 pub fn client_port(&self) -> u16 {
147 if self.client_is_a {
148 self.key.port_a
149 } else {
150 self.key.port_b
151 }
152 }
153
154 pub fn server_port(&self) -> u16 {
156 if self.client_is_a {
157 self.key.port_b
158 } else {
159 self.key.port_a
160 }
161 }
162
163 pub fn direction(&self, src_ip: IpAddr, src_port: u16) -> Direction {
166 let is_from_a = src_ip == self.key.ip_a && src_port == self.key.port_a;
167
168 if self.client_is_a {
169 if is_from_a {
171 Direction::ToServer } else {
173 Direction::ToClient }
175 } else {
176 if is_from_a {
178 Direction::ToClient } else {
180 Direction::ToServer }
182 }
183 }
184}
185
186pub struct ConnectionTracker {
188 connections: HashMap<ConnectionKey, Connection>,
189 next_id: u64,
190}
191
192impl ConnectionTracker {
193 pub fn new() -> Self {
194 Self {
195 connections: HashMap::new(),
196 next_id: 1,
197 }
198 }
199
200 #[allow(clippy::too_many_arguments)]
203 pub fn get_or_create(
204 &mut self,
205 src_ip: IpAddr,
206 src_port: u16,
207 dst_ip: IpAddr,
208 dst_port: u16,
209 flags: TcpFlags,
210 seq: u32,
211 frame_number: u64,
212 timestamp: i64,
213 ) -> (&mut Connection, Direction) {
214 let key = ConnectionKey::new(src_ip, src_port, dst_ip, dst_port);
215
216 if !self.connections.contains_key(&key) {
217 let (state, client_is_a, client_isn) = if flags.syn && !flags.ack {
219 let client_is_a = src_ip == key.ip_a && src_port == key.port_a;
221 (ConnectionState::SynSent, client_is_a, seq)
222 } else {
223 let client_is_a = key.port_a > key.port_b;
225 (ConnectionState::MidStream, client_is_a, 0)
226 };
227
228 let conn = Connection {
229 id: self.next_id,
230 key: key.clone(),
231 state,
232 client_is_a,
233 client_isn,
234 server_isn: 0,
235 start_time: timestamp,
236 last_activity: timestamp,
237 end_time: None,
238 packets_to_server: 0,
239 packets_to_client: 0,
240 bytes_to_server: 0,
241 bytes_to_client: 0,
242 first_frame: frame_number,
243 last_frame: frame_number,
244 };
245
246 self.next_id += 1;
247 self.connections.insert(key.clone(), conn);
248 }
249
250 let conn = self.connections.get_mut(&key).unwrap();
251 conn.last_activity = timestamp;
252 conn.last_frame = frame_number;
253
254 let direction = conn.direction(src_ip, src_port);
256
257 (conn, direction)
258 }
259
260 pub fn update_state(conn: &mut Connection, flags: TcpFlags, direction: Direction, seq: u32) {
262 use ConnectionState::*;
263
264 match direction {
266 Direction::ToServer => conn.packets_to_server += 1,
267 Direction::ToClient => conn.packets_to_client += 1,
268 }
269
270 if flags.rst {
272 conn.state = Reset;
273 return;
274 }
275
276 conn.state = match (conn.state, flags.syn, flags.ack, flags.fin) {
278 (SynSent, true, true, false) if direction == Direction::ToClient => {
280 conn.server_isn = seq;
281 SynReceived
282 }
283 (SynReceived, false, true, false) if direction == Direction::ToServer => Established,
285
286 (Established, false, _, true) => match direction {
288 Direction::ToServer => FinWait1,
289 Direction::ToClient => CloseWait,
290 },
291
292 (FinWait1, false, true, false) => FinWait2,
294 (CloseWait, false, _, true) => LastAck,
295 (FinWait2, false, _, true) => TimeWait,
296 (LastAck, false, true, false) => Closed,
297
298 (FinWait1, false, _, true) => Closing,
300 (Closing, false, true, false) => TimeWait,
301
302 (MidStream, false, true, false) => Established,
304
305 (current, _, _, _) => current,
307 };
308 }
309
310 pub fn add_bytes(conn: &mut Connection, direction: Direction, bytes: usize) {
312 match direction {
313 Direction::ToServer => conn.bytes_to_server += bytes as u64,
314 Direction::ToClient => conn.bytes_to_client += bytes as u64,
315 }
316 }
317
318 pub fn get(&self, key: &ConnectionKey) -> Option<&Connection> {
320 self.connections.get(key)
321 }
322
323 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
325 self.connections.values()
326 }
327
328 pub fn cleanup_timeout(&mut self, current_time: i64, timeout_us: i64) -> Vec<Connection> {
330 let mut removed = Vec::new();
331 self.connections.retain(|_, conn| {
332 if current_time - conn.last_activity > timeout_us {
333 removed.push(conn.clone());
334 false
335 } else {
336 true
337 }
338 });
339 removed
340 }
341}
342
343impl Default for ConnectionTracker {
344 fn default() -> Self {
345 Self::new()
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Port used by the client endpoint in most tests.
    const CLIENT_PORT: u16 = 54321;
    /// Port used by the server endpoint in most tests.
    const SERVER_PORT: u16 = 80;

    /// Shorthand for an IPv4 `IpAddr`.
    fn ip(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Address of the client endpoint used by most tests.
    fn client() -> IpAddr {
        ip(192, 168, 1, 1)
    }

    /// Address of the server endpoint used by most tests.
    fn server() -> IpAddr {
        ip(192, 168, 1, 2)
    }

    /// Flags of a bare SYN.
    fn syn_only() -> TcpFlags {
        TcpFlags {
            syn: true,
            ..Default::default()
        }
    }

    /// Flags of a bare ACK.
    fn ack_only() -> TcpFlags {
        TcpFlags {
            ack: true,
            ..Default::default()
        }
    }

    /// Feeds a client -> server segment to the tracker.
    fn from_client(
        tracker: &mut ConnectionTracker,
        flags: TcpFlags,
        seq: u32,
        frame: u64,
        timestamp: i64,
    ) -> (&mut Connection, Direction) {
        tracker.get_or_create(
            client(),
            CLIENT_PORT,
            server(),
            SERVER_PORT,
            flags,
            seq,
            frame,
            timestamp,
        )
    }

    /// Feeds a server -> client segment to the tracker.
    fn from_server(
        tracker: &mut ConnectionTracker,
        flags: TcpFlags,
        seq: u32,
        frame: u64,
        timestamp: i64,
    ) -> (&mut Connection, Direction) {
        tracker.get_or_create(
            server(),
            SERVER_PORT,
            client(),
            CLIENT_PORT,
            flags,
            seq,
            frame,
            timestamp,
        )
    }

    #[test]
    fn test_connection_key_normalization() {
        // Both directions of the same flow must map to one key.
        let forward = ConnectionKey::new(client(), CLIENT_PORT, server(), SERVER_PORT);
        let reverse = ConnectionKey::new(server(), SERVER_PORT, client(), CLIENT_PORT);
        assert_eq!(forward, reverse);
    }

    #[test]
    fn test_three_way_handshake() {
        let mut tracker = ConnectionTracker::new();

        // SYN from the client creates the connection.
        let (conn, dir) = from_client(&mut tracker, syn_only(), 1000, 1, 0);
        assert_eq!(conn.state, ConnectionState::SynSent);
        assert_eq!(dir, Direction::ToServer);

        // SYN-ACK from the server.
        let syn_ack = TcpFlags {
            syn: true,
            ack: true,
            ..Default::default()
        };
        let (conn, dir) = from_server(&mut tracker, syn_ack, 2000, 2, 1);
        ConnectionTracker::update_state(conn, syn_ack, dir, 2000);
        assert_eq!(conn.state, ConnectionState::SynReceived);

        // Final ACK from the client.
        let ack = ack_only();
        let (conn, dir) = from_client(&mut tracker, ack, 1001, 3, 2);
        ConnectionTracker::update_state(conn, ack, dir, 1001);
        assert_eq!(conn.state, ConnectionState::Established);
    }

    #[test]
    fn test_fin_handshake() {
        let mut tracker = ConnectionTracker::new();

        let (conn, _) = from_client(&mut tracker, ack_only(), 1000, 1, 0);
        conn.state = ConnectionState::Established;

        // The client starts closing.
        let fin_ack = TcpFlags {
            fin: true,
            ack: true,
            ..Default::default()
        };
        ConnectionTracker::update_state(conn, fin_ack, Direction::ToServer, 1000);
        assert_eq!(conn.state, ConnectionState::FinWait1);
    }

    #[test]
    fn test_rst_handling() {
        let mut tracker = ConnectionTracker::new();

        let (conn, _) = from_client(&mut tracker, ack_only(), 1000, 1, 0);
        conn.state = ConnectionState::Established;

        let rst = TcpFlags {
            rst: true,
            ..Default::default()
        };
        ConnectionTracker::update_state(conn, rst, Direction::ToServer, 1000);
        assert_eq!(conn.state, ConnectionState::Reset);
    }

    #[test]
    fn test_mid_stream() {
        let mut tracker = ConnectionTracker::new();

        // A first segment without SYN means the handshake was missed.
        let (conn, _) = from_client(&mut tracker, ack_only(), 1000, 1, 0);
        assert_eq!(conn.state, ConnectionState::MidStream);
    }

    #[test]
    fn test_connection_lookup() {
        let mut tracker = ConnectionTracker::new();
        from_client(&mut tracker, syn_only(), 1000, 1, 0);

        let key = ConnectionKey::new(client(), CLIENT_PORT, server(), SERVER_PORT);
        assert!(tracker.get(&key).is_some());
    }

    #[test]
    fn test_packet_counting() {
        let mut tracker = ConnectionTracker::new();
        let ack = ack_only();

        // One packet toward the server...
        let (conn, dir) = from_client(&mut tracker, ack, 1000, 1, 0);
        ConnectionTracker::update_state(conn, ack, dir, 1000);

        // ...and one back toward the client.
        let (conn, dir) = from_server(&mut tracker, ack, 2000, 2, 1);
        ConnectionTracker::update_state(conn, ack, dir, 2000);

        assert_eq!(conn.packets_to_server, 1);
        assert_eq!(conn.packets_to_client, 1);
    }

    #[test]
    fn test_timeout_cleanup() {
        let mut tracker = ConnectionTracker::new();
        from_client(&mut tracker, syn_only(), 1000, 1, 0);

        // Idle for less than the timeout: kept.
        let removed = tracker.cleanup_timeout(1000000, 2000000);
        assert!(removed.is_empty());

        // Idle for longer than the timeout: evicted.
        let removed = tracker.cleanup_timeout(5000000, 2000000);
        assert_eq!(removed.len(), 1);
    }

    #[test]
    fn test_simultaneous_open() {
        let mut tracker = ConnectionTracker::new();

        let (conn, _) = tracker.get_or_create(
            ip(192, 168, 1, 1),
            1000,
            ip(192, 168, 1, 2),
            1001,
            syn_only(),
            100,
            1,
            0,
        );
        assert_eq!(conn.state, ConnectionState::SynSent);
    }

    #[test]
    fn test_connection_id_uniqueness() {
        let mut tracker = ConnectionTracker::new();
        let syn = syn_only();

        let (first, _) = from_client(&mut tracker, syn, 1000, 1, 0);
        let first_id = first.id;

        let (second, _) = tracker.get_or_create(
            ip(192, 168, 1, 3),
            54322,
            ip(192, 168, 1, 4),
            443,
            syn,
            2000,
            2,
            1,
        );
        let second_id = second.id;

        assert_ne!(first_id, second_id);
    }
}