pcapsql_core/stream/
connection.rs

1use std::collections::HashMap;
2use std::net::IpAddr;
3
4use super::Direction;
5
6/// Normalized connection key (lower IP/port first for consistent lookup).
7#[derive(Debug, Clone, Hash, Eq, PartialEq)]
8pub struct ConnectionKey {
9    ip_a: IpAddr,
10    port_a: u16,
11    ip_b: IpAddr,
12    port_b: u16,
13}
14
15impl ConnectionKey {
16    /// Create a normalized connection key.
17    /// Ensures (ip_a, port_a) <= (ip_b, port_b) lexicographically.
18    pub fn new(src_ip: IpAddr, src_port: u16, dst_ip: IpAddr, dst_port: u16) -> Self {
19        if (src_ip, src_port) <= (dst_ip, dst_port) {
20            Self {
21                ip_a: src_ip,
22                port_a: src_port,
23                ip_b: dst_ip,
24                port_b: dst_port,
25            }
26        } else {
27            Self {
28                ip_a: dst_ip,
29                port_a: dst_port,
30                ip_b: src_ip,
31                port_b: src_port,
32            }
33        }
34    }
35
36    /// Determine direction based on who sent this packet.
37    pub fn direction(&self, src_ip: IpAddr, src_port: u16) -> Direction {
38        if src_ip == self.ip_a && src_port == self.port_a {
39            Direction::ToServer // A is client, sending to server
40        } else {
41            Direction::ToClient
42        }
43    }
44}
45
/// Simplified TCP connection state.
///
/// Variant names mirror the TCP states of RFC 793. The tracker only approximates them from
/// the flags of captured packets.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    SynSent,
    SynReceived,
    Established,
    FinWait1,
    FinWait2,
    CloseWait,
    Closing,
    LastAck,
    TimeWait,
    Closed,
    /// A RST segment was seen.
    Reset,
    /// Connection started mid-capture (the client's opening SYN was not seen).
    MidStream,
}

impl ConnectionState {
    /// Stable snake_case name of the state, e.g. `"fin_wait_1"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::SynSent => "syn_sent",
            Self::SynReceived => "syn_received",
            Self::Established => "established",
            Self::FinWait1 => "fin_wait_1",
            Self::FinWait2 => "fin_wait_2",
            Self::CloseWait => "close_wait",
            Self::Closing => "closing",
            Self::LastAck => "last_ack",
            Self::TimeWait => "time_wait",
            Self::Closed => "closed",
            Self::Reset => "reset",
            Self::MidStream => "mid_stream",
        }
    }
}
83
/// The TCP control bits consulted by the connection state machine.
///
/// Only the four flags that drive state transitions are modeled. `Default` sets all of them
/// to `false`.
#[derive(Clone, Copy, Debug, Default)]
pub struct TcpFlags {
    /// SYN bit: the segment synchronizes sequence numbers (connection open).
    pub syn: bool,
    /// ACK bit: the acknowledgment number is significant.
    pub ack: bool,
    /// FIN bit: the sender has no more data to send.
    pub fin: bool,
    /// RST bit: the connection is being reset.
    pub rst: bool,
}
92
93/// A tracked TCP connection.
94#[derive(Debug, Clone)]
95pub struct Connection {
96    pub id: u64,
97    pub key: ConnectionKey,
98    pub state: ConnectionState,
99
100    /// Which endpoint is the client (sent SYN).
101    /// True if ip_a/port_a is client.
102    pub client_is_a: bool,
103
104    /// Initial sequence numbers.
105    pub client_isn: u32,
106    pub server_isn: u32,
107
108    /// Timing (microseconds).
109    pub start_time: i64,
110    pub last_activity: i64,
111    pub end_time: Option<i64>,
112
113    /// Packet counts.
114    pub packets_to_server: u32,
115    pub packets_to_client: u32,
116
117    /// Byte counts (payload only).
118    pub bytes_to_server: u64,
119    pub bytes_to_client: u64,
120
121    /// Frame references.
122    pub first_frame: u64,
123    pub last_frame: u64,
124}
125
126impl Connection {
127    /// Get client IP.
128    pub fn client_ip(&self) -> IpAddr {
129        if self.client_is_a {
130            self.key.ip_a
131        } else {
132            self.key.ip_b
133        }
134    }
135
136    /// Get server IP.
137    pub fn server_ip(&self) -> IpAddr {
138        if self.client_is_a {
139            self.key.ip_b
140        } else {
141            self.key.ip_a
142        }
143    }
144
145    /// Get client port.
146    pub fn client_port(&self) -> u16 {
147        if self.client_is_a {
148            self.key.port_a
149        } else {
150            self.key.port_b
151        }
152    }
153
154    /// Get server port.
155    pub fn server_port(&self) -> u16 {
156        if self.client_is_a {
157            self.key.port_b
158        } else {
159            self.key.port_a
160        }
161    }
162
163    /// Determine direction based on source IP/port.
164    /// This correctly accounts for which endpoint is the client.
165    pub fn direction(&self, src_ip: IpAddr, src_port: u16) -> Direction {
166        let is_from_a = src_ip == self.key.ip_a && src_port == self.key.port_a;
167
168        if self.client_is_a {
169            // A is client, B is server
170            if is_from_a {
171                Direction::ToServer // Client sending to server
172            } else {
173                Direction::ToClient // Server sending to client
174            }
175        } else {
176            // B is client, A is server
177            if is_from_a {
178                Direction::ToClient // Server sending to client
179            } else {
180                Direction::ToServer // Client sending to server
181            }
182        }
183    }
184}
185
186/// Tracks TCP connections.
187pub struct ConnectionTracker {
188    connections: HashMap<ConnectionKey, Connection>,
189    next_id: u64,
190}
191
192impl ConnectionTracker {
193    pub fn new() -> Self {
194        Self {
195            connections: HashMap::new(),
196            next_id: 1,
197        }
198    }
199
200    /// Get or create a connection for the given packet.
201    /// Returns (connection, direction).
202    #[allow(clippy::too_many_arguments)]
203    pub fn get_or_create(
204        &mut self,
205        src_ip: IpAddr,
206        src_port: u16,
207        dst_ip: IpAddr,
208        dst_port: u16,
209        flags: TcpFlags,
210        seq: u32,
211        frame_number: u64,
212        timestamp: i64,
213    ) -> (&mut Connection, Direction) {
214        let key = ConnectionKey::new(src_ip, src_port, dst_ip, dst_port);
215
216        if !self.connections.contains_key(&key) {
217            // Determine who is client based on SYN
218            let (state, client_is_a, client_isn) = if flags.syn && !flags.ack {
219                // This is the SYN - sender is client
220                let client_is_a = src_ip == key.ip_a && src_port == key.port_a;
221                (ConnectionState::SynSent, client_is_a, seq)
222            } else {
223                // Mid-stream connection - guess client by port (lower port = server)
224                let client_is_a = key.port_a > key.port_b;
225                (ConnectionState::MidStream, client_is_a, 0)
226            };
227
228            let conn = Connection {
229                id: self.next_id,
230                key: key.clone(),
231                state,
232                client_is_a,
233                client_isn,
234                server_isn: 0,
235                start_time: timestamp,
236                last_activity: timestamp,
237                end_time: None,
238                packets_to_server: 0,
239                packets_to_client: 0,
240                bytes_to_server: 0,
241                bytes_to_client: 0,
242                first_frame: frame_number,
243                last_frame: frame_number,
244            };
245
246            self.next_id += 1;
247            self.connections.insert(key.clone(), conn);
248        }
249
250        let conn = self.connections.get_mut(&key).unwrap();
251        conn.last_activity = timestamp;
252        conn.last_frame = frame_number;
253
254        // Compute direction using the connection's knowledge of client/server roles
255        let direction = conn.direction(src_ip, src_port);
256
257        (conn, direction)
258    }
259
260    /// Update connection state based on TCP flags.
261    pub fn update_state(conn: &mut Connection, flags: TcpFlags, direction: Direction, seq: u32) {
262        use ConnectionState::*;
263
264        // Update packet counts
265        match direction {
266            Direction::ToServer => conn.packets_to_server += 1,
267            Direction::ToClient => conn.packets_to_client += 1,
268        }
269
270        // Handle RST
271        if flags.rst {
272            conn.state = Reset;
273            return;
274        }
275
276        // State machine transitions
277        conn.state = match (conn.state, flags.syn, flags.ack, flags.fin) {
278            // SYN-ACK from server
279            (SynSent, true, true, false) if direction == Direction::ToClient => {
280                conn.server_isn = seq;
281                SynReceived
282            }
283            // ACK completing handshake
284            (SynReceived, false, true, false) if direction == Direction::ToServer => Established,
285
286            // FIN from either side
287            (Established, false, _, true) => match direction {
288                Direction::ToServer => FinWait1,
289                Direction::ToClient => CloseWait,
290            },
291
292            // ACK of FIN
293            (FinWait1, false, true, false) => FinWait2,
294            (CloseWait, false, _, true) => LastAck,
295            (FinWait2, false, _, true) => TimeWait,
296            (LastAck, false, true, false) => Closed,
297
298            // Simultaneous close
299            (FinWait1, false, _, true) => Closing,
300            (Closing, false, true, false) => TimeWait,
301
302            // Mid-stream can transition to established on data
303            (MidStream, false, true, false) => Established,
304
305            // Stay in current state
306            (current, _, _, _) => current,
307        };
308    }
309
310    /// Add payload bytes to connection stats.
311    pub fn add_bytes(conn: &mut Connection, direction: Direction, bytes: usize) {
312        match direction {
313            Direction::ToServer => conn.bytes_to_server += bytes as u64,
314            Direction::ToClient => conn.bytes_to_client += bytes as u64,
315        }
316    }
317
318    /// Get a connection by key.
319    pub fn get(&self, key: &ConnectionKey) -> Option<&Connection> {
320        self.connections.get(key)
321    }
322
323    /// Get all connections.
324    pub fn connections(&self) -> impl Iterator<Item = &Connection> {
325        self.connections.values()
326    }
327
328    /// Remove timed-out connections.
329    pub fn cleanup_timeout(&mut self, current_time: i64, timeout_us: i64) -> Vec<Connection> {
330        let mut removed = Vec::new();
331        self.connections.retain(|_, conn| {
332            if current_time - conn.last_activity > timeout_us {
333                removed.push(conn.clone());
334                false
335            } else {
336                true
337            }
338        });
339        removed
340    }
341}
342
343impl Default for ConnectionTracker {
344    fn default() -> Self {
345        Self::new()
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    fn ip(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    // Test 1: Connection key normalization
    #[test]
    fn test_connection_key_normalization() {
        let key1 = ConnectionKey::new(ip(192, 168, 1, 1), 54321, ip(192, 168, 1, 2), 80);
        let key2 = ConnectionKey::new(ip(192, 168, 1, 2), 80, ip(192, 168, 1, 1), 54321);
        assert_eq!(key1, key2);
    }

    // Test 2: SYN -> SYN-ACK -> ACK
    #[test]
    fn test_three_way_handshake() {
        let mut tracker = ConnectionTracker::new();

        // SYN from client
        let syn = TcpFlags {
            syn: true,
            ..Default::default()
        };
        let (conn, dir) = tracker.get_or_create(
            ip(192, 168, 1, 1),
            54321,
            ip(192, 168, 1, 2),
            80,
            syn,
            1000,
            1,
            0,
        );
        assert_eq!(conn.state, ConnectionState::SynSent);
        assert_eq!(dir, Direction::ToServer);

        // SYN-ACK from server
        let syn_ack = TcpFlags {
            syn: true,
            ack: true,
            ..Default::default()
        };
        let (conn, dir) = tracker.get_or_create(
            ip(192, 168, 1, 2),
            80,
            ip(192, 168, 1, 1),
            54321,
            syn_ack,
            2000,
            2,
            1,
        );
        ConnectionTracker::update_state(conn, syn_ack, dir, 2000);
        assert_eq!(conn.state, ConnectionState::SynReceived);

        // ACK from client
        let ack = TcpFlags {
            ack: true,
            ..Default::default()
        };
        let (conn, dir) = tracker.get_or_create(
            ip(192, 168, 1, 1),
            54321,
            ip(192, 168, 1, 2),
            80,
            ack,
            1001,
            3,
            2,
        );
        ConnectionTracker::update_state(conn, ack, dir, 1001);
        assert_eq!(conn.state, ConnectionState::Established);
    }

    // Test 3: Client FIN on an established connection starts an active close
    #[test]
    fn test_fin_handshake() {
        let mut tracker = ConnectionTracker::new();

        // Establish connection first (simplified)
        let ack = TcpFlags {
            ack: true,
            ..Default::default()
        };
        let (conn, _) = tracker.get_or_create(
            ip(192, 168, 1, 1),
            54321,
            ip(192, 168, 1, 2),
            80,
            ack,
            1000,
            1,
            0,
        );
        conn.state = ConnectionState::Established;

        // FIN from client
        let fin = TcpFlags {
            fin: true,
            ack: true,
            ..Default::default()
        };
        ConnectionTracker::update_state(conn, fin, Direction::ToServer, 1000);
        assert_eq!(conn.state, ConnectionState::FinWait1);
    }

    // Test 4: RST handling
    #[test]
    fn test_rst_handling() {
        let mut tracker = ConnectionTracker::new();
        let ack = TcpFlags {
            ack: true,
            ..Default::default()
        };
        let (conn, _) = tracker.get_or_create(
            ip(192, 168, 1, 1),
            54321,
            ip(192, 168, 1, 2),
            80,
            ack,
            1000,
            1,
            0,
        );
        conn.state = ConnectionState::Established;

        let rst = TcpFlags {
            rst: true,
            ..Default::default()
        };
        ConnectionTracker::update_state(conn, rst, Direction::ToServer, 1000);
        assert_eq!(conn.state, ConnectionState::Reset);
    }

    // Test 5: Mid-stream detection
    #[test]
    fn test_mid_stream() {
        let mut tracker = ConnectionTracker::new();
        let ack = TcpFlags {
            ack: true,
            ..Default::default()
        };
        let (conn, _) = tracker.get_or_create(
            ip(192, 168, 1, 1),
            54321,
            ip(192, 168, 1, 2),
            80,
            ack,
            1000,
            1,
            0, // No SYN, just data
        );
        assert_eq!(conn.state, ConnectionState::MidStream);
    }

    // Test 6: Connection lookup
    #[test]
    fn test_connection_lookup() {
        let mut tracker = ConnectionTracker::new();
        let syn = TcpFlags {
            syn: true,
            ..Default::default()
        };
        tracker.get_or_create(
            ip(192, 168, 1, 1),
            54321,
            ip(192, 168, 1, 2),
            80,
            syn,
            1000,
            1,
            0,
        );

        let key = ConnectionKey::new(ip(192, 168, 1, 1), 54321, ip(192, 168, 1, 2), 80);
        assert!(tracker.get(&key).is_some());
    }

    // Test 7: Packet counting
    #[test]
    fn test_packet_counting() {
        let mut tracker = ConnectionTracker::new();
        let ack = TcpFlags {
            ack: true,
            ..Default::default()
        };

        // Packet to server
        let (conn, dir) = tracker.get_or_create(
            ip(192, 168, 1, 1),
            54321,
            ip(192, 168, 1, 2),
            80,
            ack,
            1000,
            1,
            0,
        );
        ConnectionTracker::update_state(conn, ack, dir, 1000);

        // Packet to client
        let (conn, dir) = tracker.get_or_create(
            ip(192, 168, 1, 2),
            80,
            ip(192, 168, 1, 1),
            54321,
            ack,
            2000,
            2,
            1,
        );
        ConnectionTracker::update_state(conn, ack, dir, 2000);

        assert_eq!(conn.packets_to_server, 1);
        assert_eq!(conn.packets_to_client, 1);
    }

    // Test 8: Timeout cleanup
    #[test]
    fn test_timeout_cleanup() {
        let mut tracker = ConnectionTracker::new();
        let syn = TcpFlags {
            syn: true,
            ..Default::default()
        };
        tracker.get_or_create(
            ip(192, 168, 1, 1),
            54321,
            ip(192, 168, 1, 2),
            80,
            syn,
            1000,
            1,
            0,
        );

        // No timeout yet
        let removed = tracker.cleanup_timeout(1000000, 2000000);
        assert!(removed.is_empty());

        // After timeout
        let removed = tracker.cleanup_timeout(5000000, 2000000);
        assert_eq!(removed.len(), 1);
    }

    // Test 9: Simultaneous open (both endpoints send a bare SYN)
    #[test]
    fn test_simultaneous_open() {
        let mut tracker = ConnectionTracker::new();
        let syn = TcpFlags {
            syn: true,
            ..Default::default()
        };

        // First SYN: its sender becomes the client.
        let (conn, dir) = tracker.get_or_create(
            ip(192, 168, 1, 1),
            1000,
            ip(192, 168, 1, 2),
            1001,
            syn,
            100,
            1,
            0,
        );
        assert_eq!(conn.state, ConnectionState::SynSent);
        assert_eq!(dir, Direction::ToServer);
        let first_id = conn.id;
        ConnectionTracker::update_state(conn, syn, dir, 100);

        // Second SYN from the other endpoint lands on the same connection, and the roles
        // chosen by the first SYN are kept.
        let (conn, dir) = tracker.get_or_create(
            ip(192, 168, 1, 2),
            1001,
            ip(192, 168, 1, 1),
            1000,
            syn,
            200,
            2,
            1,
        );
        assert_eq!(conn.id, first_id);
        assert_eq!(dir, Direction::ToClient);
        ConnectionTracker::update_state(conn, syn, dir, 200);
        assert_eq!(conn.state, ConnectionState::SynSent);
        assert_eq!(conn.client_ip(), ip(192, 168, 1, 1));
        assert_eq!(conn.client_port(), 1000);
    }

    // Test 10: Connection ID uniqueness
    #[test]
    fn test_connection_id_uniqueness() {
        let mut tracker = ConnectionTracker::new();
        let syn = TcpFlags {
            syn: true,
            ..Default::default()
        };

        let (conn1, _) = tracker.get_or_create(
            ip(192, 168, 1, 1),
            54321,
            ip(192, 168, 1, 2),
            80,
            syn,
            1000,
            1,
            0,
        );
        let id1 = conn1.id;

        let (conn2, _) = tracker.get_or_create(
            ip(192, 168, 1, 3),
            54322,
            ip(192, 168, 1, 4),
            443,
            syn,
            2000,
            2,
            1,
        );
        let id2 = conn2.id;

        assert_ne!(id1, id2);
    }

    // Test 11: Role accessors when the client is endpoint B of the normalized key
    #[test]
    fn test_client_is_endpoint_b() {
        let mut tracker = ConnectionTracker::new();
        let syn = TcpFlags {
            syn: true,
            ..Default::default()
        };

        // 10.0.0.1:80 sorts first, so the SYN sender 10.0.0.2:40000 is endpoint B.
        let (conn, dir) =
            tracker.get_or_create(ip(10, 0, 0, 2), 40000, ip(10, 0, 0, 1), 80, syn, 500, 1, 0);
        assert_eq!(dir, Direction::ToServer);
        assert!(!conn.client_is_a);
        assert_eq!(conn.client_isn, 500);
        assert_eq!((conn.client_ip(), conn.client_port()), (ip(10, 0, 0, 2), 40000));
        assert_eq!((conn.server_ip(), conn.server_port()), (ip(10, 0, 0, 1), 80));
        assert_eq!(conn.direction(ip(10, 0, 0, 1), 80), Direction::ToClient);
    }

    // Test 12: Payload byte counting and frame/timestamp bookkeeping
    #[test]
    fn test_byte_counting_and_bookkeeping() {
        let mut tracker = ConnectionTracker::new();
        let ack = TcpFlags {
            ack: true,
            ..Default::default()
        };

        let (conn, dir) =
            tracker.get_or_create(ip(192, 168, 1, 1), 54321, ip(192, 168, 1, 2), 80, ack, 1000, 1, 0);
        ConnectionTracker::add_bytes(conn, dir, 100);

        let (conn, dir) =
            tracker.get_or_create(ip(192, 168, 1, 2), 80, ip(192, 168, 1, 1), 54321, ack, 2000, 2, 1);
        ConnectionTracker::add_bytes(conn, dir, 1500);

        assert_eq!(conn.bytes_to_server, 100);
        assert_eq!(conn.bytes_to_client, 1500);
        assert_eq!((conn.first_frame, conn.last_frame), (1, 2));
        assert_eq!((conn.start_time, conn.last_activity), (0, 1));
        assert_eq!(tracker.connections().count(), 1);
    }

    // Test 13: State names
    #[test]
    fn test_state_names() {
        assert_eq!(ConnectionState::SynReceived.as_str(), "syn_received");
        assert_eq!(ConnectionState::FinWait1.as_str(), "fin_wait_1");
        assert_eq!(ConnectionState::MidStream.as_str(), "mid_stream");
    }
}