use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use ahash::RandomState;
use bytes::Bytes;
use flowscope::tracker::FlowEvents;
use flowscope::{
EndReason, FlowEvent, FlowExtractor, FlowSide, FlowTracker, FlowTrackerConfig, Timestamp,
};
use futures_core::Stream;
use crate::async_adapters::async_reassembler::{AsyncReassembler, AsyncReassemblerFactory};
use crate::async_adapters::tokio_adapter::AsyncCapture;
use crate::error::Error;
use crate::traits::PacketSource;
/// Reassembler marker selecting "no reassembly" mode; the default `R` of [`FlowStream`].
///
/// A stream in this mode tracks packets with `FlowTracker::track` and only yields
/// [`FlowEvent`]s; payload bytes are not collected.
pub struct NoReassembler;
/// Reassembler mode that forwards captured payload bytes to async reassemblers, one per
/// `(flow key, side)` pair.
///
/// Installed by [`FlowStream::with_async_reassembler`].
pub struct AsyncReassemblerSlot<K, F>
where
    K: Eq + std::hash::Hash + Clone + Send + 'static,
    F: AsyncReassemblerFactory<K>,
{
    /// Creates a reassembler the first time a payload is delivered for a `(key, side)` pair.
    factory: F,
    /// Live reassemblers; an entry is removed when its flow's `Ended` event is processed.
    instances: HashMap<(K, FlowSide), F::Reassembler, RandomState>,
    /// FIFO of `(key, side, seq, payload)` tuples captured from the tracker, waiting to be
    /// handed to the matching reassembler's `segment`.
    pending_payloads: VecDeque<(K, FlowSide, u32, Bytes)>,
    /// The single outstanding `segment`/`fin`/`rst` future. A new one is only installed
    /// after the previous one has completed, so at most one is in flight.
    pending_future: Option<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
}
/// Async [`Stream`] of [`FlowEvent`]s produced by tracking packets from an [`AsyncCapture`].
///
/// Type parameters: `S` is the packet source, `E` the flow extractor, `U` the per-flow user
/// state created by the `with_state` initializer (default `()`), and `R` the reassembler
/// mode, either [`NoReassembler`] (default) or [`AsyncReassemblerSlot`].
///
/// The stream never yields `None`. I/O errors from the capture's readiness poll are
/// yielded as `Err` items.
pub struct FlowStream<S, E, U = (), R = NoReassembler>
where
    S: PacketSource + std::os::unix::io::AsRawFd,
    E: FlowExtractor,
    U: Send + 'static,
{
    /// Capture source, polled for read readiness and drained one batch at a time.
    cap: AsyncCapture<S>,
    /// Flow tracker fed with every packet in each batch.
    tracker: FlowTracker<E, U>,
    /// Events produced (by tracking or sweeping) but not yet yielded, in FIFO order.
    pending: VecDeque<FlowEvent<E::Key>>,
    /// Timer whose ticks trigger `tracker.sweep` with the current wall-clock time.
    sweep: tokio::time::Interval,
    /// Reassembler mode state (see the `R` type parameter).
    reassembler: R,
}
impl<S, E> FlowStream<S, E, (), NoReassembler>
where
    S: PacketSource + std::os::unix::io::AsRawFd,
    E: FlowExtractor,
{
    /// Creates a stream over `cap` whose tracker is built with `FlowTracker::new(extractor)`.
    ///
    /// The sweep timer period is taken from that tracker's config.
    pub(crate) fn new(cap: AsyncCapture<S>, extractor: E) -> Self {
        let tracker = FlowTracker::new(extractor);
        let sweep_interval = tracker.config().sweep_interval;
        Self {
            cap,
            tracker,
            pending: VecDeque::new(),
            sweep: tokio::time::interval(sweep_interval),
            reassembler: NoReassembler,
        }
    }

    /// Attaches per-flow user state, created for each flow key by `init`.
    ///
    /// The tracker is rebuilt from its extractor and current config via
    /// `FlowTracker::with_config_and_state`. Per-flow state held by the previous tracker is
    /// not carried over. Events that were already queued, and the sweep timer, are kept.
    pub fn with_state<U, F>(self, init: F) -> FlowStream<S, E, U, NoReassembler>
    where
        U: Send + 'static,
        F: FnMut(&E::Key) -> U + Send + 'static,
    {
        let config = self.tracker.config().clone();
        let extractor = self.tracker.into_extractor();
        FlowStream {
            cap: self.cap,
            tracker: FlowTracker::with_config_and_state(extractor, config, init),
            // The event queue does not depend on `U`, so already-produced events are
            // preserved rather than silently dropped (consistent with
            // `with_async_reassembler`, which also carries `pending` over).
            pending: self.pending,
            sweep: self.sweep,
            reassembler: NoReassembler,
        }
    }
}
impl<S, E, U> FlowStream<S, E, U, NoReassembler>
where
    S: PacketSource + std::os::unix::io::AsRawFd,
    E: FlowExtractor,
    U: Send + 'static,
{
    /// Switches this stream into async-reassembly mode.
    ///
    /// Afterwards, payloads reported while tracking are delivered to per-`(key, side)`
    /// reassemblers that `factory` creates on demand. The capture, tracker, queued events,
    /// and sweep timer all carry over unchanged; the reassembler map and queues start empty.
    pub fn with_async_reassembler<F>(
        self,
        factory: F,
    ) -> FlowStream<S, E, U, AsyncReassemblerSlot<E::Key, F>>
    where
        F: AsyncReassemblerFactory<E::Key>,
    {
        let slot = AsyncReassemblerSlot {
            factory,
            instances: HashMap::with_hasher(RandomState::new()),
            pending_payloads: VecDeque::new(),
            pending_future: None,
        };
        let FlowStream {
            cap,
            tracker,
            pending,
            sweep,
            ..
        } = self;
        FlowStream {
            cap,
            tracker,
            pending,
            sweep,
            reassembler: slot,
        }
    }
}
impl<S, E> FlowStream<S, E, (), NoReassembler>
where
S: PacketSource + std::os::unix::io::AsRawFd,
E: FlowExtractor,
E::Key: Eq + std::hash::Hash + Clone + Send + 'static,
{
pub fn session_stream<F>(
self,
factory: F,
) -> crate::async_adapters::session_stream::SessionStream<S, E, F>
where
F: flowscope::SessionParserFactory<E::Key>,
{
let config = self.tracker.config().clone();
let extractor = self.tracker.into_extractor();
crate::async_adapters::session_stream::SessionStream::new_with_config(
self.cap, extractor, factory, config,
)
}
pub fn datagram_stream<F>(
self,
factory: F,
) -> crate::async_adapters::datagram_stream::DatagramStream<S, E, F>
where
F: flowscope::DatagramParserFactory<E::Key>,
{
let config = self.tracker.config().clone();
let extractor = self.tracker.into_extractor();
crate::async_adapters::datagram_stream::DatagramStream::new_with_config(
self.cap, extractor, factory, config,
)
}
}
impl<S, E, U, R> FlowStream<S, E, U, R>
where
    S: PacketSource + std::os::unix::io::AsRawFd,
    E: FlowExtractor,
    U: Send + 'static,
{
    /// Applies `config` to the tracker and restarts the sweep timer with the new
    /// `config.sweep_interval`.
    ///
    /// # Panics
    ///
    /// Panics if `config.sweep_interval` is zero, because `tokio::time::interval` rejects
    /// a zero period.
    pub fn with_config(self, config: FlowTrackerConfig) -> Self {
        let mut stream = self;
        let period = config.sweep_interval;
        stream.tracker.set_config(config);
        stream.sweep = tokio::time::interval(period);
        stream
    }

    /// Shared access to the underlying flow tracker.
    pub fn tracker(&self) -> &FlowTracker<E, U> {
        &self.tracker
    }

    /// Mutable access to the underlying flow tracker.
    ///
    /// Changing the tracker's config through this handle does not restart the sweep timer;
    /// use [`with_config`](Self::with_config) for that.
    pub fn tracker_mut(&mut self) -> &mut FlowTracker<E, U> {
        &mut self.tracker
    }
}
impl<S, E, U> Stream for FlowStream<S, E, U, NoReassembler>
where
S: PacketSource + std::os::unix::io::AsRawFd + Unpin,
E: FlowExtractor + Unpin,
E::Key: Clone + Unpin,
U: Send + 'static + Unpin,
{
type Item = Result<FlowEvent<E::Key>, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
loop {
if let Some(evt) = this.pending.pop_front() {
return Poll::Ready(Some(Ok(evt)));
}
if this.sweep.poll_tick(cx).is_ready() {
let now = current_timestamp();
for ev in this.tracker.sweep(now) {
this.pending.push_back(ev);
}
if let Some(evt) = this.pending.pop_front() {
return Poll::Ready(Some(Ok(evt)));
}
}
let mut guard = match this.cap.poll_read_ready_mut(cx) {
Poll::Ready(Ok(g)) => g,
Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(Error::Io(e)))),
Poll::Pending => return Poll::Pending,
};
let got_batch = {
let inner = guard.get_inner_mut();
if let Some(batch) = inner.next_batch() {
for pkt in &batch {
let view = pkt.view();
let evts: FlowEvents<E::Key> = this.tracker.track(view);
for ev in evts {
this.pending.push_back(ev);
}
}
drop(batch);
true
} else {
false
}
};
if !got_batch {
guard.clear_ready();
}
}
}
}
/// Async-reassembly mode: yields flow events like the plain stream, and also feeds payload
/// bytes reported by `track_with_payload` to per-`(key, side)` reassemblers.
impl<S, E, U, F> Stream for FlowStream<S, E, U, AsyncReassemblerSlot<E::Key, F>>
where
    S: PacketSource + std::os::unix::io::AsRawFd + Unpin,
    E: FlowExtractor + Unpin,
    E::Key: Clone + Unpin,
    U: Send + 'static + Unpin,
    F: AsyncReassemblerFactory<E::Key> + Unpin,
    F::Reassembler: Unpin,
{
    type Item = Result<FlowEvent<E::Key>, Error>;

    /// Yields the next flow event, driving reassembler work first.
    ///
    /// Each pass handles, in order:
    /// 1. The in-flight reassembler future. While it is pending the stream returns
    ///    `Pending`, and neither the queues nor the capture are touched.
    /// 2. One queued payload, handed to `segment` on the `(key, side)` reassembler. The
    ///    factory creates that reassembler on first use.
    /// 3. One queued event. An `Ended` event first closes its flow's reassemblers, one side
    ///    per pass: `fin` for `Fin`/`IdleTimeout`, `rst` for `Rst`/`Evicted`/`BufferOverflow`.
    ///    The event is only yielded once neither side has a live reassembler.
    /// 4. A due sweep tick.
    /// 5. The next packet batch.
    ///
    /// This method never returns `Ready(None)`.
    ///
    /// NOTE(review): all payloads of a batch go to `pending_payloads` and are drained
    /// before any of that batch's events. A payload from a packet that follows an `Ended`
    /// event in the same batch therefore reaches the reassembler before that event removes
    /// it. If the tracker can start a new flow with the same key inside one batch, that
    /// payload would be closed together with the old flow — confirm whether this matters.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            // (1) Finish the outstanding reassembler future before doing anything else.
            if let Some(fut) = this.reassembler.pending_future.as_mut() {
                match fut.as_mut().poll(cx) {
                    Poll::Ready(()) => {
                        this.reassembler.pending_future = None;
                    }
                    Poll::Pending => return Poll::Pending,
                }
            }
            // (2) Deliver one payload; the resulting future is polled on the next pass.
            if let Some((key, side, seq, payload)) = this.reassembler.pending_payloads.pop_front() {
                let r = this
                    .reassembler
                    .instances
                    .entry((key.clone(), side))
                    .or_insert_with(|| this.reassembler.factory.new_reassembler(&key, side));
                let fut = r.segment(seq, payload);
                this.reassembler.pending_future = Some(fut);
                continue;
            }
            // (3) Yield one event, closing reassemblers first for `Ended`.
            if let Some(evt) = this.pending.pop_front() {
                if let FlowEvent::Ended { key, reason, .. } = &evt {
                    // Copy out of the borrow of `evt` so `evt` can be moved back into
                    // the queue below.
                    let reason_copy = *reason;
                    let key_copy = key.clone();
                    let mut found_fut = None;
                    // Close at most one side per pass; the event is re-queued at the
                    // front, so the other side is handled on a later pass.
                    for side in [FlowSide::Initiator, FlowSide::Responder] {
                        if let Some(mut r) =
                            this.reassembler.instances.remove(&(key_copy.clone(), side))
                        {
                            let fut = match reason_copy {
                                EndReason::Fin | EndReason::IdleTimeout => r.fin(),
                                EndReason::Rst | EndReason::Evicted | EndReason::BufferOverflow => {
                                    r.rst()
                                }
                            };
                            // The stored future must be `'static` (see `pending_future`),
                            // so it cannot borrow `r`, and the reassembler can be dropped now.
                            drop(r);
                            found_fut = Some(fut);
                            break;
                        }
                    }
                    if let Some(fut) = found_fut {
                        this.pending.push_front(evt);
                        this.reassembler.pending_future = Some(fut);
                        continue;
                    }
                }
                return Poll::Ready(Some(Ok(evt)));
            }
            // (4) On a sweep tick, expire flows against wall-clock time.
            if this.sweep.poll_tick(cx).is_ready() {
                let now = current_timestamp();
                for ev in this.tracker.sweep(now) {
                    this.pending.push_back(ev);
                }
                if !this.pending.is_empty() {
                    continue;
                }
            }
            // (5) Wait for capture readiness and track one batch.
            let mut guard = match this.cap.poll_read_ready_mut(cx) {
                Poll::Ready(Ok(g)) => g,
                Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(Error::Io(e)))),
                Poll::Pending => return Poll::Pending,
            };
            let got_batch = {
                let inner = guard.get_inner_mut();
                if let Some(batch) = inner.next_batch() {
                    for pkt in &batch {
                        let view = pkt.view();
                        // Split borrow: the callback pushes here while `this.tracker`
                        // is mutably borrowed.
                        let payloads = &mut this.reassembler.pending_payloads;
                        let evts: FlowEvents<E::Key> =
                            this.tracker
                                .track_with_payload(view, |key, side, seq, payload| {
                                    // `payload` is a borrowed slice (presumably into
                                    // batch memory). An owned copy is queued because it
                                    // is consumed after `batch` is dropped.
                                    payloads.push_back((
                                        key.clone(),
                                        side,
                                        seq,
                                        Bytes::copy_from_slice(payload),
                                    ));
                                });
                        for ev in evts {
                            this.pending.push_back(ev);
                        }
                    }
                    drop(batch);
                    true
                } else {
                    false
                }
            };
            // Clear readiness only once the capture has no batch left, so the next poll
            // waits for a fresh readiness notification.
            if !got_batch {
                guard.clear_ready();
            }
        }
    }
}
/// Returns the current wall-clock time as a flowscope `Timestamp`.
///
/// A system clock set before the Unix epoch yields a zero timestamp. Seconds are narrowed
/// with an `as u32` cast, so they wrap once they exceed `u32::MAX` (early 2106).
fn current_timestamp() -> Timestamp {
    let elapsed = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch,
        Err(_) => Duration::ZERO,
    };
    let secs = elapsed.as_secs() as u32;
    Timestamp::new(secs, elapsed.subsec_nanos())
}
impl<S> AsyncCapture<S>
where
    S: PacketSource + std::os::unix::io::AsRawFd,
{
    /// Wraps this capture in a [`FlowStream`] that tracks flows with `extractor`.
    ///
    /// The returned stream has no per-flow user state and no reassembler. Use
    /// [`FlowStream::with_state`] or [`FlowStream::with_async_reassembler`] to add them.
    pub fn flow_stream<E>(self, extractor: E) -> FlowStream<S, E, (), NoReassembler>
    where
        E: FlowExtractor,
    {
        FlowStream::<S, E, (), NoReassembler>::new(self, extractor)
    }
}