use std::collections::HashMap;
use std::fmt::Display;
use std::mem;
use std::net::IpAddr;
use kinesin_rdt::common::ring_buffer::RingBuf;
use tracing::debug;
use tracing::warn;
use crate::connection::Connection;
use crate::connection::ConnectionState;
use crate::connection::Direction;
use crate::serialized::PacketExtra;
use crate::ConnectionHandler;
use crate::TcpMeta;
/// IP protocol number for TCP; used as `Flow::proto` for TCP flows.
pub const IPPROTO_TCP: u8 = 6;
/// IP protocol number for UDP.
pub const IPPROTO_UDP: u8 = 17;
/// A transport flow: an IP protocol number plus a source and destination endpoint.
///
/// Equality is direction-insensitive: a flow compares equal to its reverse
/// (see `Flow::compare`), so looking up either direction of a connection in a
/// `HashMap<Flow, _>` finds the same entry.
#[derive(Debug, Clone)]
pub struct Flow {
    /// IP protocol number, e.g. `IPPROTO_TCP`.
    pub proto: u8,
    /// Source address.
    pub src_addr: IpAddr,
    /// Source port.
    pub src_port: u16,
    /// Destination address.
    pub dst_addr: IpAddr,
    /// Destination port.
    pub dst_port: u16,
}
impl Flow {
pub fn reverse(&mut self) {
mem::swap(&mut self.src_addr, &mut self.dst_addr);
mem::swap(&mut self.src_port, &mut self.dst_port);
}
pub fn compare_tcp_meta(&self, other: &TcpMeta) -> FlowCompare {
self.compare(&other.into())
}
pub fn compare(&self, other: &Self) -> FlowCompare {
if self.proto != other.proto {
FlowCompare::None
} else if self.src_addr == other.src_addr
&& self.dst_addr == other.dst_addr
&& self.src_port == other.src_port
&& self.dst_port == other.dst_port
{
FlowCompare::Forward
} else if self.src_addr == other.dst_addr
&& self.dst_addr == other.src_addr
&& self.src_port == other.dst_port
&& self.dst_port == other.src_port
{
FlowCompare::Reverse
} else {
FlowCompare::None
}
}
}
impl From<&TcpMeta> for Flow {
    /// Builds a TCP flow whose endpoints mirror the packet's source and destination.
    fn from(meta: &TcpMeta) -> Self {
        let (src_addr, src_port) = (meta.src_addr, meta.src_port);
        let (dst_addr, dst_port) = (meta.dst_addr, meta.dst_port);
        Self {
            proto: IPPROTO_TCP,
            src_addr,
            src_port,
            dst_addr,
            dst_port,
        }
    }
}
impl Display for Flow {
    /// Renders the flow as `src:port -> dst:port`. IPv6 addresses are wrapped
    /// in brackets so the trailing `:port` is unambiguous.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Writes one address, bracketing it when it is IPv6.
        fn write_host(f: &mut std::fmt::Formatter<'_>, host: &IpAddr) -> std::fmt::Result {
            match host {
                IpAddr::V4(v4) => Display::fmt(v4, f),
                IpAddr::V6(v6) => {
                    f.write_str("[")?;
                    Display::fmt(v6, f)?;
                    f.write_str("]")
                }
            }
        }

        write_host(f, &self.src_addr)?;
        write!(f, ":{} -> ", self.src_port)?;
        write_host(f, &self.dst_addr)?;
        write!(f, ":{}", self.dst_port)
    }
}
/// Outcome of `Flow::compare`: how one flow relates to another.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlowCompare {
    /// Same protocol and the same endpoints in the same order. Takes precedence
    /// over `Reverse` when both would match.
    Forward,
    /// Same protocol with source and destination endpoints swapped.
    Reverse,
    /// Different protocol, or the endpoints do not match in either order.
    None,
}
impl FlowCompare {
    /// Converts a comparison outcome into a packet `Direction`.
    ///
    /// Returns `None` when the compared flows are unrelated.
    pub fn to_direction(&self) -> Option<Direction> {
        let direction = match self {
            FlowCompare::None => return None,
            FlowCompare::Forward => Direction::Forward,
            FlowCompare::Reverse => Direction::Reverse,
        };
        Some(direction)
    }
}
impl PartialEq for Flow {
    /// Flows are equal when `Flow::compare` finds them related in either
    /// direction, which requires matching protocols.
    fn eq(&self, other: &Self) -> bool {
        !matches!(self.compare(other), FlowCompare::None)
    }
}
// Direction-insensitive equality is still reflexive, symmetric and transitive
// (each flow is equal only to itself and its reverse), so `Eq` holds.
impl Eq for Flow {}
impl std::hash::Hash for Flow {
    /// Hashes the flow independently of its direction, so a flow and its
    /// reverse (which compare equal) always hash identically.
    ///
    /// The endpoints are hashed as whole `(addr, port)` pairs in a canonical
    /// (sorted) order. Sorting addresses and ports separately would make
    /// distinct flows such as `A:1 -> B:2` and `A:2 -> B:1` collide every time.
    /// `proto` is included because `eq` also requires it to match.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let src = (self.src_addr, self.src_port);
        let dst = (self.dst_addr, self.dst_port);
        let (low, high) = if src <= dst { (src, dst) } else { (dst, src) };
        self.proto.hash(state);
        low.hash(state);
        high.hash(state);
    }
}
/// Tracks one `Connection` per active flow and optionally keeps retired ones.
pub struct FlowTable<H: ConnectionHandler> {
    /// Active connections, keyed by direction-insensitive `Flow`.
    pub map: HashMap<Flow, Connection<H>>,
    /// Connections removed from `map`; only populated when `save_retired` is true.
    pub retired: RingBuf<Connection<H>>,
    /// Whether retired/closed connections are pushed onto `retired` rather than dropped.
    pub save_retired: bool,
    /// Template cloned into every new connection this table creates.
    pub handler_init_data: H::InitialData,
}
/// Outcome of `FlowTable::handle_packet_direct`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandlePacketResult {
    /// The flow's connection reported doing something with the packet
    /// (the connection may have been retired afterwards if it closed).
    Ok,
    /// The flow's connection exists but reported doing nothing with the packet.
    Dropped,
    /// No connection is tracked for the packet's flow.
    NotFound,
    /// The connection ended up desynchronized after this packet; it is still in
    /// the table and the caller decides how to recover.
    Desync,
}
impl<H: ConnectionHandler> FlowTable<H> {
    /// Creates an empty flow table.
    ///
    /// `handler_init_data` is cloned for every connection the table creates.
    /// `save_retired` starts out `false`, so retired connections are dropped
    /// unless the caller opts in.
    pub fn new(handler_init_data: H::InitialData) -> Self {
        Self {
            map: HashMap::new(),
            retired: RingBuf::new(),
            save_retired: false,
            handler_init_data,
        }
    }

    /// Routes one TCP packet to the connection for its flow, creating or
    /// re-creating that connection when necessary.
    ///
    /// - If no connection exists for the flow, one is created and the packet is
    ///   delivered to it.
    /// - If the existing connection reports a desync, it is retired, a fresh
    ///   connection is created, and the packet is delivered to the new one.
    ///
    /// Returns `Ok(true)` when the connection reported doing something with the
    /// packet, and `Ok(false)` when it did not. `Ok(false)` is also returned
    /// when even a freshly created connection desyncs on this packet; that
    /// connection is then retired.
    ///
    /// # Errors
    ///
    /// Returns `H::ConstructError` if creating a new connection fails.
    pub fn handle_packet(
        &mut self,
        meta: &TcpMeta,
        data: &[u8],
        extra: &PacketExtra,
    ) -> Result<bool, H::ConstructError> {
        match self.handle_packet_direct(meta, data, extra) {
            HandlePacketResult::Ok => Ok(true),
            HandlePacketResult::Dropped => Ok(false),
            HandlePacketResult::NotFound => {
                self.create_flow(meta.into(), self.handler_init_data.clone())?;
                Ok(self.handle_packet_on_new_flow(meta, data, extra))
            }
            HandlePacketResult::Desync => {
                debug!("handle_packet: got desync, recreating flow");
                let flow: Flow = meta.into();
                self.retire_flow(flow.clone());
                self.create_flow(flow, self.handler_init_data.clone())?;
                Ok(self.handle_packet_on_new_flow(meta, data, extra))
            }
        }
    }

    /// Delivers a packet to the connection `handle_packet` just created for its
    /// flow, returning whether the connection reported doing something with it.
    ///
    /// If the new connection desyncs as well, it is retired (so no desynced
    /// connection is left in the table) and the packet is reported as dropped
    /// instead of panicking.
    fn handle_packet_on_new_flow(
        &mut self,
        meta: &TcpMeta,
        data: &[u8],
        extra: &PacketExtra,
    ) -> bool {
        match self.handle_packet_direct(meta, data, extra) {
            HandlePacketResult::Ok => true,
            HandlePacketResult::Dropped => false,
            HandlePacketResult::Desync => {
                let flow: Flow = meta.into();
                warn!("handle_packet: newly created flow desynced immediately, retiring: {flow}");
                self.retire_flow(flow);
                false
            }
            // The caller inserted a connection for this flow right before calling us.
            HandlePacketResult::NotFound => unreachable!("result not possible"),
        }
    }

    /// Delivers a packet to the existing connection for its flow, without ever
    /// creating one.
    ///
    /// - If the connection ends up `Closed`, it is retired.
    /// - If it ends up `Desync`, `HandlePacketResult::Desync` is returned and
    ///   the connection is left in the table for the caller to handle.
    /// - Otherwise the result reflects whether the connection reported doing
    ///   something with the packet.
    pub fn handle_packet_direct(
        &mut self,
        meta: &TcpMeta,
        data: &[u8],
        extra: &PacketExtra,
    ) -> HandlePacketResult {
        let flow: Flow = meta.into();
        let Some(conn) = self.map.get_mut(&flow) else {
            return HandlePacketResult::NotFound;
        };
        let did_something = conn.handle_packet(meta, data, extra);
        match conn.conn_state {
            ConnectionState::Closed => self.retire_flow(flow),
            ConnectionState::Desync => return HandlePacketResult::Desync,
            _ => {}
        }
        if did_something {
            HandlePacketResult::Ok
        } else {
            HandlePacketResult::Dropped
        }
    }

    /// Creates a connection for `flow` and inserts it into the table.
    ///
    /// Returns the connection previously stored under an equal flow, if any.
    /// Because `Flow` equality ignores direction, replacing an entry keeps the
    /// map's existing key (standard `HashMap::insert` behavior).
    ///
    /// # Errors
    ///
    /// Returns `H::ConstructError` if the connection cannot be constructed.
    pub fn create_flow(
        &mut self,
        flow: Flow,
        init_data: H::InitialData,
    ) -> Result<Option<Connection<H>>, H::ConstructError> {
        let conn = Connection::new(flow.clone(), init_data)?;
        debug!("new flow: {} {flow}", conn.uuid);
        Ok(self.map.insert(flow, conn))
    }

    /// Removes the connection for `flow` and notifies it via `will_retire`.
    ///
    /// The connection is kept in `retired` when `save_retired` is set and is
    /// dropped otherwise. Logs a warning if no such flow is tracked.
    pub fn retire_flow(&mut self, flow: Flow) {
        let Some(mut conn) = self.map.remove(&flow) else {
            warn!("retire_flow called on non-existent flow?: {flow}");
            return;
        };
        debug!("remove flow: {} {flow}", conn.uuid);
        conn.will_retire();
        if self.save_retired {
            self.retired.push_back(conn);
        }
    }

    /// Retires every active connection, leaving the table empty.
    ///
    /// Each connection gets `will_retire` and is kept in `retired` when
    /// `save_retired` is set, exactly as in `retire_flow`.
    pub fn close(&mut self) {
        debug!("flowtable closing");
        for (flow, mut conn) in self.map.drain() {
            debug!("remove flow: {} {flow}", conn.uuid);
            conn.will_retire();
            if self.save_retired {
                self.retired.push_back(conn);
            }
        }
    }
}
#[cfg(test)]
mod test {
    use std::collections::HashMap;
    use std::net::Ipv4Addr;

    use super::{Flow, IPPROTO_TCP};

    /// Checks that `Flow` behaves as a direction-insensitive `HashMap` key.
    #[test]
    fn hash_map() {
        // A TCP flow and the same flow as seen from the other endpoint.
        let forward = Flow {
            proto: IPPROTO_TCP,
            src_addr: Ipv4Addr::new(10, 3, 160, 24).into(),
            src_port: 35619,
            dst_addr: Ipv4Addr::new(1, 1, 1, 1).into(),
            dst_port: 53,
        };
        let reverse = Flow {
            proto: IPPROTO_TCP,
            src_addr: forward.dst_addr,
            src_port: forward.dst_port,
            dst_addr: forward.src_addr,
            dst_port: forward.src_port,
        };
        // Same source endpoint but a different destination: a distinct flow.
        let unrelated = Flow {
            proto: IPPROTO_TCP,
            src_addr: Ipv4Addr::new(10, 3, 160, 24).into(),
            src_port: 35619,
            dst_addr: Ipv4Addr::new(8, 8, 8, 8).into(),
            dst_port: 53,
        };

        // Equality ignores direction but still distinguishes other flows.
        assert_eq!(forward, reverse);
        assert_ne!(forward, unrelated);

        let mut map: HashMap<Flow, String> = HashMap::new();
        map.insert(forward.clone(), "test 1".into());
        // Lookups from either direction hit the same entry.
        assert_eq!(map.get(&forward), Some(&"test 1".into()));
        assert_eq!(map.get(&reverse), Some(&"test 1".into()));
        assert_eq!(map.get(&unrelated), None);
        // Inserting the reverse flow replaces the value stored for the forward flow.
        assert_eq!(
            map.insert(reverse.clone(), "test 2".into()),
            Some("test 1".into())
        );
        // The unrelated flow gets an entry of its own.
        assert_eq!(map.insert(unrelated.clone(), "test 3".into()), None);
        assert_eq!(map.get(&forward), Some(&"test 2".into()));
        assert_eq!(map.get(&unrelated), Some(&"test 3".into()));
    }
}