parse_tcp/
flow_table.rs

1use std::collections::HashMap;
2use std::fmt::Display;
3use std::mem;
4use std::net::IpAddr;
5
6use kinesin_rdt::common::ring_buffer::RingBuf;
7use tracing::debug;
8use tracing::warn;
9
10use crate::connection::Connection;
11use crate::connection::ConnectionState;
12use crate::connection::Direction;
13use crate::serialized::PacketExtra;
14use crate::ConnectionHandler;
15use crate::TcpMeta;
16
// IP protocol numbers, as assigned by IANA:
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml

/// IP protocol number for TCP.
pub const IPPROTO_TCP: u8 = 6;
/// IP protocol number for UDP.
pub const IPPROTO_UDP: u8 = 17;
20
/// A transport-layer flow: an IP protocol number plus source and destination
/// `(address, port)` endpoints.
///
/// `PartialEq`/`Eq`/`Hash` are implemented by hand rather than derived, so that a
/// flow compares equal to (and hashes like) its reverse.
#[derive(Debug, Clone)]
pub struct Flow {
    /// IP protocol number (e.g. `IPPROTO_TCP`)
    pub proto: u8,
    /// source address
    pub src_addr: IpAddr,
    /// source port
    pub src_port: u16,
    /// destination address
    pub dst_addr: IpAddr,
    /// destination port
    pub dst_port: u16,
}
29
30impl Flow {
31    /// reverse source/destination
32    pub fn reverse(&mut self) {
33        mem::swap(&mut self.src_addr, &mut self.dst_addr);
34        mem::swap(&mut self.src_port, &mut self.dst_port);
35    }
36
37    /// compare to TcpMeta
38    pub fn compare_tcp_meta(&self, other: &TcpMeta) -> FlowCompare {
39        self.compare(&other.into())
40    }
41
42    /// compare to other
43    pub fn compare(&self, other: &Self) -> FlowCompare {
44        if self.proto != other.proto {
45            FlowCompare::None
46        } else if self.src_addr == other.src_addr
47            && self.dst_addr == other.dst_addr
48            && self.src_port == other.src_port
49            && self.dst_port == other.dst_port
50        {
51            // exact match
52            FlowCompare::Forward
53        } else if self.src_addr == other.dst_addr
54            && self.dst_addr == other.src_addr
55            && self.src_port == other.dst_port
56            && self.dst_port == other.src_port
57        {
58            // reverse direction
59            FlowCompare::Reverse
60        } else {
61            FlowCompare::None
62        }
63    }
64}
65
66impl From<&TcpMeta> for Flow {
67    fn from(value: &TcpMeta) -> Self {
68        Flow {
69            proto: IPPROTO_TCP,
70            src_addr: value.src_addr,
71            src_port: value.src_port,
72            dst_addr: value.dst_addr,
73            dst_port: value.dst_port,
74        }
75    }
76}
77
78impl Display for Flow {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        macro_rules! fmt_addr {
81            ($addr:expr) => {
82                match $addr {
83                    IpAddr::V4(addr) => addr.fmt(f)?,
84                    IpAddr::V6(addr) => {
85                        write!(f, "[")?;
86                        addr.fmt(f)?;
87                        write!(f, "]")?;
88                    }
89                }
90            };
91        }
92        fmt_addr!(self.src_addr);
93        write!(f, ":{} -> ", self.src_port)?;
94        fmt_addr!(self.dst_addr);
95        write!(f, ":{}", self.dst_port)?;
96        Ok(())
97    }
98}
99
/// Result of [`Flow::compare`]: how one flow relates to another.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlowCompare {
    /// same protocol and identical endpoints (same direction)
    Forward,
    /// same protocol, with source and destination swapped
    Reverse,
    /// different protocol or unrelated endpoints
    None,
}
110
111impl FlowCompare {
112    /// get direction from compare, or None
113    pub fn to_direction(&self) -> Option<Direction> {
114        match self {
115            FlowCompare::Forward => Some(Direction::Forward),
116            FlowCompare::Reverse => Some(Direction::Reverse),
117            FlowCompare::None => None,
118        }
119    }
120}
121
122impl PartialEq for Flow {
123    fn eq(&self, other: &Self) -> bool {
124        self.compare(other) != FlowCompare::None
125    }
126}
127
128impl Eq for Flow {}
129
130impl std::hash::Hash for Flow {
131    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
132        // order independent hashing
133        if self.src_addr <= self.dst_addr {
134            self.src_addr.hash(state);
135            self.dst_addr.hash(state);
136        } else {
137            self.dst_addr.hash(state);
138            self.src_addr.hash(state);
139        }
140        if self.src_port <= self.dst_port {
141            self.src_port.hash(state);
142            self.dst_port.hash(state);
143        } else {
144            self.dst_port.hash(state);
145            self.src_port.hash(state);
146        }
147    }
148}
149
150/// a table of TCP connections
151pub struct FlowTable<H: ConnectionHandler> {
152    /// map holding flows by tuple
153    pub map: HashMap<Flow, Connection<H>>,
154    /// retired connections (usually closed)
155    // hahahahaha watch this explode
156    pub retired: RingBuf<Connection<H>>,
157    /// whether retired connections should be saved
158    pub save_retired: bool,
159    /// initial data for ConnectionHandler
160    pub handler_init_data: H::InitialData,
161}
162
/// Result of [`FlowTable::handle_packet_direct`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandlePacketResult {
    /// packet successfully processed
    Ok,
    /// packet ignored, possibly because it was a duplicate
    Dropped,
    /// the table has no flow for this packet; the packet was not processed
    NotFound,
    /// after handling this packet the connection is fatally desynchronized;
    /// the flow is left in the table (not retired) for the caller to deal with
    Desync,
}
174
175impl<H: ConnectionHandler> FlowTable<H> {
176    /// create new instance
177    pub fn new(handler_init_data: H::InitialData) -> Self {
178        Self {
179            map: HashMap::new(),
180            retired: RingBuf::new(),
181            save_retired: false,
182            handler_init_data,
183        }
184    }
185
186    /// handle a packet, creating a flow if necessary
187    pub fn handle_packet(
188        &mut self,
189        meta: &TcpMeta,
190        data: &[u8],
191        extra: &PacketExtra,
192    ) -> Result<bool, H::ConstructError> {
193        match self.handle_packet_direct(meta, data, extra) {
194            HandlePacketResult::Ok => Ok(true),
195            HandlePacketResult::Dropped => Ok(false),
196            HandlePacketResult::NotFound => {
197                // create the flow, then process again
198                self.create_flow(meta.into(), self.handler_init_data.clone())?;
199                match self.handle_packet_direct(meta, data, extra) {
200                    HandlePacketResult::Ok => Ok(true),
201                    HandlePacketResult::Dropped => Ok(false),
202                    _ => unreachable!("result not possible"),
203                }
204            }
205            HandlePacketResult::Desync => {
206                // remove flow, then recreate and try again
207                debug!("handle_packet: got desync, recreating flow");
208                let flow: Flow = meta.into();
209                self.retire_flow(flow.clone());
210                self.create_flow(flow, self.handler_init_data.clone())?;
211                match self.handle_packet_direct(meta, data, extra) {
212                    HandlePacketResult::Ok => Ok(true),
213                    HandlePacketResult::Dropped => Ok(false),
214                    _ => unreachable!("result not possible"),
215                }
216            }
217        }
218    }
219
220    /// handle a packet, return Err if flow does not exist (and return args)
221    pub fn handle_packet_direct(
222        &mut self,
223        meta: &TcpMeta,
224        data: &[u8],
225        extra: &PacketExtra,
226    ) -> HandlePacketResult {
227        let flow = meta.into();
228        let did_something;
229        match self.map.get_mut(&flow) {
230            Some(conn) => {
231                did_something = conn.handle_packet(meta, data, extra);
232                match conn.conn_state {
233                    // remove flow if connection is no more
234                    ConnectionState::Closed => self.retire_flow(flow),
235                    ConnectionState::Desync => {
236                        return HandlePacketResult::Desync;
237                    }
238                    _ => {}
239                }
240                if did_something {
241                    HandlePacketResult::Ok
242                } else {
243                    HandlePacketResult::Dropped
244                }
245            }
246            None => HandlePacketResult::NotFound,
247        }
248    }
249
250    /// create flow
251    pub fn create_flow(
252        &mut self,
253        flow: Flow,
254        init_data: H::InitialData,
255    ) -> Result<Option<Connection<H>>, H::ConstructError> {
256        let conn = Connection::new(flow.clone(), init_data)?;
257        debug!("new flow: {} {flow}", conn.uuid);
258        Ok(self.map.insert(flow, conn))
259    }
260
261    pub fn retire_flow(&mut self, flow: Flow) {
262        let Some(mut conn) = self.map.remove(&flow) else {
263            warn!("retire_flow called on non-existent flow?: {flow}");
264            return;
265        };
266
267        debug!("remove flow: {} {flow}", conn.uuid);
268        conn.will_retire();
269        if self.save_retired {
270            self.retired.push_back(conn);
271        }
272    }
273
274    /// close flowtable and retire all flows
275    pub fn close(&mut self) {
276        debug!("flowtable closing");
277        for (flow, mut conn) in self.map.drain() {
278            debug!("remove flow: {} {flow}", conn.uuid);
279            conn.will_retire();
280            if self.save_retired {
281                self.retired.push_back(conn);
282            }
283        }
284    }
285}
286
#[cfg(test)]
mod test {
    use std::collections::HashMap;
    use std::net::Ipv4Addr;

    use super::{Flow, IPPROTO_TCP};

    /// Build a TCP flow between two IPv4 endpoints.
    fn tcp_flow(src: Ipv4Addr, src_port: u16, dst: Ipv4Addr, dst_port: u16) -> Flow {
        Flow {
            proto: IPPROTO_TCP,
            src_addr: src.into(),
            src_port,
            dst_addr: dst.into(),
            dst_port,
        }
    }

    /// A flow and its reverse must share one `HashMap` entry; an unrelated flow
    /// must get its own.
    #[test]
    fn hash_map() {
        let client = Ipv4Addr::new(10, 3, 160, 24);
        let resolver = Ipv4Addr::new(1, 1, 1, 1);
        let forward = tcp_flow(client, 35619, resolver, 53);
        let reverse = tcp_flow(resolver, 53, client, 35619);
        let unrelated = tcp_flow(client, 35619, Ipv4Addr::new(8, 8, 8, 8), 53);

        assert_eq!(forward, reverse);
        assert_ne!(forward, unrelated);

        let mut map: HashMap<Flow, String> = HashMap::new();
        map.insert(forward.clone(), "test 1".to_string());
        assert_eq!(map.get(&forward).map(String::as_str), Some("test 1"));
        assert_eq!(map.get(&reverse).map(String::as_str), Some("test 1"));
        assert!(map.get(&unrelated).is_none());

        // inserting under the reverse key replaces the existing entry
        assert_eq!(
            map.insert(reverse.clone(), "test 2".to_string()).as_deref(),
            Some("test 1")
        );
        assert_eq!(map.insert(unrelated.clone(), "test 3".to_string()), None);
        assert_eq!(map.get(&forward).map(String::as_str), Some("test 2"));
        assert_eq!(map.get(&unrelated).map(String::as_str), Some("test 3"));
    }
}