use std::collections::HashMap;
use ahash::RandomState;
use crate::event::{EndReason, FlowEvent, FlowSide};
use crate::extractor::FlowExtractor;
use crate::reassembler::{Reassembler, ReassemblerFactory};
use crate::tracker::{FlowEvents, FlowTracker, FlowTrackerConfig};
use crate::view::PacketView;
/// Drives a [`FlowTracker`] and owns one reassembler per flow direction.
///
/// Payload that the tracker hands out for a `(flow key, side)` is fed to a
/// reassembler created on demand by the [`ReassemblerFactory`]. When the
/// tracker reports a flow as ended, that flow's reassemblers are removed and
/// signalled with `fin()` or `rst()` depending on the end reason.
///
/// Type parameters:
/// - `E`: the flow extractor; `E::Key` is the flow key carried by events.
/// - `F`: factory producing an `F::Reassembler` for each `(key, side)`.
/// - `S`: forwarded to `FlowTracker<E, S>`; defaults to `()`.
///   NOTE(review): presumably per-flow user state — confirm against `FlowTracker`.
pub struct FlowDriver<E, F, S = ()>
where
E: FlowExtractor,
F: ReassemblerFactory<E::Key>,
S: Send + 'static,
{
/// Tracks flows; yields the events and payload callbacks this driver consumes.
tracker: FlowTracker<E, S>,
/// Builds a reassembler the first time payload arrives for a flow side.
factory: F,
/// Live reassemblers keyed by flow key and direction.
reassemblers: HashMap<(E::Key, FlowSide), F::Reassembler, RandomState>,
}
impl<E, F, S> FlowDriver<E, F, S>
where
    E: FlowExtractor,
    F: ReassemblerFactory<E::Key>,
    S: Default + Send + 'static,
{
    /// Creates a driver whose tracker uses `FlowTrackerConfig::default()`.
    ///
    /// Shorthand for [`FlowDriver::with_config`] with the default config.
    pub fn new(extractor: E, factory: F) -> Self {
        let config = FlowTrackerConfig::default();
        Self::with_config(extractor, factory, config)
    }

    /// Creates a driver whose tracker is built from `extractor` and `config`.
    ///
    /// The driver starts with no reassemblers. They are created by `factory`
    /// the first time the tracker delivers payload for a given flow side.
    pub fn with_config(extractor: E, factory: F, config: FlowTrackerConfig) -> Self {
        let tracker = FlowTracker::with_config(extractor, config);
        let reassemblers = HashMap::with_hasher(RandomState::new());
        Self {
            tracker,
            factory,
            reassemblers,
        }
    }
}
impl<E, F, S> FlowDriver<E, F, S>
where
E: FlowExtractor,
F: ReassemblerFactory<E::Key>,
S: Send + 'static,
{
pub fn track(&mut self, view: PacketView<'_>) -> FlowEvents<E::Key> {
let factory = &mut self.factory;
let reassemblers = &mut self.reassemblers;
let events = self
.tracker
.track_with_payload(view, |key, side, seq, payload| {
let r = reassemblers
.entry((key.clone(), side))
.or_insert_with(|| factory.new_reassembler(key, side));
r.segment(seq, payload);
});
for ev in &events {
if let FlowEvent::Ended { key, reason, .. } = ev {
for side in [FlowSide::Initiator, FlowSide::Responder] {
if let Some(mut r) = reassemblers.remove(&(key.clone(), side)) {
match reason {
EndReason::Fin | EndReason::IdleTimeout => r.fin(),
EndReason::Rst | EndReason::Evicted => r.rst(),
}
}
}
}
}
events
}
pub fn sweep(&mut self, now: crate::Timestamp) -> Vec<FlowEvent<E::Key>> {
let events = self.tracker.sweep(now);
for ev in &events {
if let FlowEvent::Ended { key, reason, .. } = ev {
for side in [FlowSide::Initiator, FlowSide::Responder] {
if let Some(mut r) = self.reassemblers.remove(&(key.clone(), side)) {
match reason {
EndReason::Fin | EndReason::IdleTimeout => r.fin(),
EndReason::Rst | EndReason::Evicted => r.rst(),
}
}
}
}
}
events
}
pub fn tracker(&self) -> &FlowTracker<E, S> {
&self.tracker
}
pub fn tracker_mut(&mut self) -> &mut FlowTracker<E, S> {
&mut self.tracker
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extract::FiveTuple;
    use crate::extract::parse::test_frames::*;
    use crate::reassembler::{BufferedReassembler, BufferedReassemblerFactory};
    use crate::{FlowEvent, Timestamp};

    /// Wraps a raw frame in a `PacketView` stamped at `sec` seconds.
    fn view(frame: &[u8], sec: u32) -> PacketView<'_> {
        PacketView::new(frame, Timestamp::new(sec, 0))
    }

    /// Drives a handshake, request/response, and FIN teardown through the
    /// driver. Checks that the flow ends exactly once and that the driver
    /// drops every reassembler it created for that flow.
    #[test]
    fn buffered_reassembly_in_order() {
        let mut d = FlowDriver::<_, _>::new(FiveTuple::bidirectional(), BufferedReassemblerFactory);
        let syn = ipv4_tcp(
            [0; 6],
            [0; 6],
            [10, 0, 0, 1],
            [10, 0, 0, 2],
            1234,
            80,
            1000,
            0,
            0x02,
            b"",
        );
        let synack = ipv4_tcp(
            [0; 6],
            [0; 6],
            [10, 0, 0, 2],
            [10, 0, 0, 1],
            80,
            1234,
            5000,
            1001,
            0x12,
            b"",
        );
        let ack = ipv4_tcp(
            [0; 6],
            [0; 6],
            [10, 0, 0, 1],
            [10, 0, 0, 2],
            1234,
            80,
            1001,
            5001,
            0x10,
            b"",
        );
        let req = ipv4_tcp(
            [0; 6],
            [0; 6],
            [10, 0, 0, 1],
            [10, 0, 0, 2],
            1234,
            80,
            1001,
            5001,
            0x18,
            b"GET / HTTP/1.1\r\n\r\n",
        );
        let resp = ipv4_tcp(
            [0; 6],
            [0; 6],
            [10, 0, 0, 2],
            [10, 0, 0, 1],
            80,
            1234,
            5001,
            1019,
            0x18,
            b"HTTP/1.1 200 OK\r\n\r\nbody",
        );
        d.track(view(&syn, 0));
        d.track(view(&synack, 0));
        d.track(view(&ack, 0));
        d.track(view(&req, 0));
        d.track(view(&resp, 0));
        // Both directions carried payload, so each side now owns a reassembler.
        assert_eq!(d.reassemblers.len(), 2);

        let fin = ipv4_tcp(
            [0; 6],
            [0; 6],
            [10, 0, 0, 1],
            [10, 0, 0, 2],
            1234,
            80,
            1019,
            5024,
            0x11,
            b"",
        );
        let fin_resp = ipv4_tcp(
            [0; 6],
            [0; 6],
            [10, 0, 0, 2],
            [10, 0, 0, 1],
            80,
            1234,
            5024,
            1020,
            0x11,
            b"",
        );
        let last_ack = ipv4_tcp(
            [0; 6],
            [0; 6],
            [10, 0, 0, 1],
            [10, 0, 0, 2],
            1234,
            80,
            1020,
            5025,
            0x10,
            b"",
        );
        let mut all_events = Vec::new();
        all_events.extend(d.track(view(&fin, 0)));
        all_events.extend(d.track(view(&fin_resp, 0)));
        all_events.extend(d.track(view(&last_ack, 0)));
        let ended_count = all_events
            .iter()
            .filter(|e| matches!(e, FlowEvent::Ended { .. }))
            .count();
        assert_eq!(ended_count, 1);
        // Ending the flow must tear down every reassembler it owned.
        assert!(d.reassemblers.is_empty());
    }

    /// A SYN without payload must not cause the driver to create a reassembler.
    #[test]
    fn no_dispatch_on_empty_payload() {
        // Records the side of every reassembler it is asked to create.
        // `new_reassembler` takes `&mut self`, so a plain `Vec` suffices. No
        // interior mutability and no `unsafe impl Send/Sync` are needed.
        struct CountingFactory(Vec<FlowSide>);
        impl ReassemblerFactory<crate::extract::FiveTupleKey> for CountingFactory {
            type Reassembler = BufferedReassembler;
            fn new_reassembler(
                &mut self,
                _key: &crate::extract::FiveTupleKey,
                side: FlowSide,
            ) -> BufferedReassembler {
                self.0.push(side);
                BufferedReassembler::new()
            }
        }
        let factory = CountingFactory(Vec::new());
        let mut d = FlowDriver::<_, _>::new(FiveTuple::bidirectional(), factory);
        let syn = ipv4_tcp(
            [0; 6],
            [0; 6],
            [10, 0, 0, 1],
            [10, 0, 0, 2],
            1234,
            80,
            0,
            0,
            0x02,
            b"",
        );
        d.track(view(&syn, 0));
        assert!(d.factory.0.is_empty());
        assert!(d.reassemblers.is_empty());
    }
}