1use std::collections::HashMap;
2use std::fmt::Display;
3use std::mem;
4use std::net::IpAddr;
5
6use kinesin_rdt::common::ring_buffer::RingBuf;
7use tracing::debug;
8use tracing::warn;
9
10use crate::connection::Connection;
11use crate::connection::ConnectionState;
12use crate::connection::Direction;
13use crate::serialized::PacketExtra;
14use crate::ConnectionHandler;
15use crate::TcpMeta;
16
/// IP protocol number for TCP; every flow built from a [`TcpMeta`] carries this value.
pub const IPPROTO_TCP: u8 = 6;
/// IP protocol number for UDP.
pub const IPPROTO_UDP: u8 = 17;
20
/// A transport-level flow: an IP protocol number plus source and destination endpoints.
///
/// `PartialEq`, `Eq` and `Hash` are implemented by hand so that a flow and its
/// reverse (source and destination swapped) are treated as the same key.
#[derive(Clone, Debug)]
pub struct Flow {
    /// IP protocol number, e.g. [`IPPROTO_TCP`].
    pub proto: u8,
    /// Source IP address.
    pub src_addr: IpAddr,
    /// Source port.
    pub src_port: u16,
    /// Destination IP address.
    pub dst_addr: IpAddr,
    /// Destination port.
    pub dst_port: u16,
}
29
30impl Flow {
31 pub fn reverse(&mut self) {
33 mem::swap(&mut self.src_addr, &mut self.dst_addr);
34 mem::swap(&mut self.src_port, &mut self.dst_port);
35 }
36
37 pub fn compare_tcp_meta(&self, other: &TcpMeta) -> FlowCompare {
39 self.compare(&other.into())
40 }
41
42 pub fn compare(&self, other: &Self) -> FlowCompare {
44 if self.proto != other.proto {
45 FlowCompare::None
46 } else if self.src_addr == other.src_addr
47 && self.dst_addr == other.dst_addr
48 && self.src_port == other.src_port
49 && self.dst_port == other.dst_port
50 {
51 FlowCompare::Forward
53 } else if self.src_addr == other.dst_addr
54 && self.dst_addr == other.src_addr
55 && self.src_port == other.dst_port
56 && self.dst_port == other.src_port
57 {
58 FlowCompare::Reverse
60 } else {
61 FlowCompare::None
62 }
63 }
64}
65
66impl From<&TcpMeta> for Flow {
67 fn from(value: &TcpMeta) -> Self {
68 Flow {
69 proto: IPPROTO_TCP,
70 src_addr: value.src_addr,
71 src_port: value.src_port,
72 dst_addr: value.dst_addr,
73 dst_port: value.dst_port,
74 }
75 }
76}
77
78impl Display for Flow {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 macro_rules! fmt_addr {
81 ($addr:expr) => {
82 match $addr {
83 IpAddr::V4(addr) => addr.fmt(f)?,
84 IpAddr::V6(addr) => {
85 write!(f, "[")?;
86 addr.fmt(f)?;
87 write!(f, "]")?;
88 }
89 }
90 };
91 }
92 fmt_addr!(self.src_addr);
93 write!(f, ":{} -> ", self.src_port)?;
94 fmt_addr!(self.dst_addr);
95 write!(f, ":{}", self.dst_port)?;
96 Ok(())
97 }
98}
99
/// Result of comparing two flows with [`Flow::compare`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FlowCompare {
    /// Same protocol and identical endpoints, in the same direction.
    Forward,
    /// Same protocol, with source and destination swapped.
    Reverse,
    /// Different protocol or unrelated endpoints.
    None,
}
110
111impl FlowCompare {
112 pub fn to_direction(&self) -> Option<Direction> {
114 match self {
115 FlowCompare::Forward => Some(Direction::Forward),
116 FlowCompare::Reverse => Some(Direction::Reverse),
117 FlowCompare::None => None,
118 }
119 }
120}
121
122impl PartialEq for Flow {
123 fn eq(&self, other: &Self) -> bool {
124 self.compare(other) != FlowCompare::None
125 }
126}
127
128impl Eq for Flow {}
129
130impl std::hash::Hash for Flow {
131 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
132 if self.src_addr <= self.dst_addr {
134 self.src_addr.hash(state);
135 self.dst_addr.hash(state);
136 } else {
137 self.dst_addr.hash(state);
138 self.src_addr.hash(state);
139 }
140 if self.src_port <= self.dst_port {
141 self.src_port.hash(state);
142 self.dst_port.hash(state);
143 } else {
144 self.dst_port.hash(state);
145 self.src_port.hash(state);
146 }
147 }
148}
149
150pub struct FlowTable<H: ConnectionHandler> {
152 pub map: HashMap<Flow, Connection<H>>,
154 pub retired: RingBuf<Connection<H>>,
157 pub save_retired: bool,
159 pub handler_init_data: H::InitialData,
161}
162
/// Outcome of delivering a packet with `FlowTable::handle_packet_direct`.
///
/// Derives the same traits as `FlowCompare`, so results can be logged,
/// copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandlePacketResult {
    /// A connection existed and reported that it did something with the packet.
    Ok,
    /// A connection existed but reported that it did nothing with the packet.
    Dropped,
    /// No connection is stored for the packet's flow, in either direction.
    NotFound,
    /// The connection processed the packet and is now in the desync state.
    /// It is left in the table.
    Desync,
}
174
175impl<H: ConnectionHandler> FlowTable<H> {
176 pub fn new(handler_init_data: H::InitialData) -> Self {
178 Self {
179 map: HashMap::new(),
180 retired: RingBuf::new(),
181 save_retired: false,
182 handler_init_data,
183 }
184 }
185
186 pub fn handle_packet(
188 &mut self,
189 meta: &TcpMeta,
190 data: &[u8],
191 extra: &PacketExtra,
192 ) -> Result<bool, H::ConstructError> {
193 match self.handle_packet_direct(meta, data, extra) {
194 HandlePacketResult::Ok => Ok(true),
195 HandlePacketResult::Dropped => Ok(false),
196 HandlePacketResult::NotFound => {
197 self.create_flow(meta.into(), self.handler_init_data.clone())?;
199 match self.handle_packet_direct(meta, data, extra) {
200 HandlePacketResult::Ok => Ok(true),
201 HandlePacketResult::Dropped => Ok(false),
202 _ => unreachable!("result not possible"),
203 }
204 }
205 HandlePacketResult::Desync => {
206 debug!("handle_packet: got desync, recreating flow");
208 let flow: Flow = meta.into();
209 self.retire_flow(flow.clone());
210 self.create_flow(flow, self.handler_init_data.clone())?;
211 match self.handle_packet_direct(meta, data, extra) {
212 HandlePacketResult::Ok => Ok(true),
213 HandlePacketResult::Dropped => Ok(false),
214 _ => unreachable!("result not possible"),
215 }
216 }
217 }
218 }
219
220 pub fn handle_packet_direct(
222 &mut self,
223 meta: &TcpMeta,
224 data: &[u8],
225 extra: &PacketExtra,
226 ) -> HandlePacketResult {
227 let flow = meta.into();
228 let did_something;
229 match self.map.get_mut(&flow) {
230 Some(conn) => {
231 did_something = conn.handle_packet(meta, data, extra);
232 match conn.conn_state {
233 ConnectionState::Closed => self.retire_flow(flow),
235 ConnectionState::Desync => {
236 return HandlePacketResult::Desync;
237 }
238 _ => {}
239 }
240 if did_something {
241 HandlePacketResult::Ok
242 } else {
243 HandlePacketResult::Dropped
244 }
245 }
246 None => HandlePacketResult::NotFound,
247 }
248 }
249
250 pub fn create_flow(
252 &mut self,
253 flow: Flow,
254 init_data: H::InitialData,
255 ) -> Result<Option<Connection<H>>, H::ConstructError> {
256 let conn = Connection::new(flow.clone(), init_data)?;
257 debug!("new flow: {} {flow}", conn.uuid);
258 Ok(self.map.insert(flow, conn))
259 }
260
261 pub fn retire_flow(&mut self, flow: Flow) {
262 let Some(mut conn) = self.map.remove(&flow) else {
263 warn!("retire_flow called on non-existent flow?: {flow}");
264 return;
265 };
266
267 debug!("remove flow: {} {flow}", conn.uuid);
268 conn.will_retire();
269 if self.save_retired {
270 self.retired.push_back(conn);
271 }
272 }
273
274 pub fn close(&mut self) {
276 debug!("flowtable closing");
277 for (flow, mut conn) in self.map.drain() {
278 debug!("remove flow: {} {flow}", conn.uuid);
279 conn.will_retire();
280 if self.save_retired {
281 self.retired.push_back(conn);
282 }
283 }
284 }
285}
286
#[cfg(test)]
mod test {
    use std::collections::HashMap;
    use std::net::Ipv4Addr;

    use super::{Flow, FlowCompare, IPPROTO_TCP, IPPROTO_UDP};

    /// Flow from a local client to a DNS server, used as the baseline in these tests.
    fn client_flow() -> Flow {
        Flow {
            proto: IPPROTO_TCP,
            src_addr: Ipv4Addr::new(10, 3, 160, 24).into(),
            src_port: 35619,
            dst_addr: Ipv4Addr::new(1, 1, 1, 1).into(),
            dst_port: 53,
        }
    }

    #[test]
    fn hash_map() {
        let forward = client_flow();
        let reverse = Flow {
            proto: IPPROTO_TCP,
            src_addr: forward.dst_addr,
            src_port: forward.dst_port,
            dst_addr: forward.src_addr,
            dst_port: forward.src_port,
        };
        let unrelated = Flow {
            dst_addr: Ipv4Addr::new(8, 8, 8, 8).into(),
            ..forward.clone()
        };
        assert_eq!(forward, reverse);
        assert_ne!(forward, unrelated);

        let mut map: HashMap<Flow, String> = HashMap::new();
        map.insert(forward.clone(), "test 1".into());
        assert_eq!(map.get(&forward), Some(&"test 1".into()));
        assert_eq!(map.get(&reverse), Some(&"test 1".into()));
        assert_eq!(map.get(&unrelated), None);

        assert_eq!(
            map.insert(reverse.clone(), "test 2".into()),
            Some("test 1".into())
        );
        assert_eq!(map.insert(unrelated.clone(), "test 3".into()), None);
        assert_eq!(map.get(&forward), Some(&"test 2".into()));
        assert_eq!(map.get(&unrelated), Some(&"test 3".into()));
    }

    #[test]
    fn compare_and_reverse() {
        let forward = client_flow();
        let mut reverse = forward.clone();
        reverse.reverse();

        assert_eq!(reverse.src_addr, forward.dst_addr);
        assert_eq!(reverse.src_port, forward.dst_port);
        assert_eq!(forward.compare(&forward), FlowCompare::Forward);
        assert_eq!(forward.compare(&reverse), FlowCompare::Reverse);
        assert_eq!(reverse.compare(&forward), FlowCompare::Reverse);
        assert!(forward.compare(&reverse).to_direction().is_some());
        assert!(FlowCompare::None.to_direction().is_none());
    }

    #[test]
    fn protocol_distinguishes_flows() {
        let tcp = client_flow();
        let udp = Flow {
            proto: IPPROTO_UDP,
            ..tcp.clone()
        };
        assert_eq!(tcp.compare(&udp), FlowCompare::None);
        assert_ne!(tcp, udp);

        let mut map: HashMap<Flow, u32> = HashMap::new();
        map.insert(tcp.clone(), 1);
        map.insert(udp.clone(), 2);
        assert_eq!(map.len(), 2);
        assert_eq!(map.get(&tcp), Some(&1));
        assert_eq!(map.get(&udp), Some(&2));
    }
}