use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::{Arc, Mutex, Weak};
use std::task::{Context, Poll};
use ahash::RandomState;
use bytes::Bytes;
use flowscope::{EndReason, FlowExtractor, FlowSide};
use futures_core::Stream;
use tokio::sync::mpsc;
use crate::async_adapters::async_reassembler::{AsyncReassembler, AsyncReassemblerFactory};
use crate::async_adapters::flow_stream::{AsyncReassemblerSlot, FlowStream, NoReassembler};
use crate::async_adapters::tokio_adapter::AsyncCapture;
use crate::error::Error;
use crate::traits::PacketSource;
/// One item produced by [`Conversation::next_chunk`].
///
/// Payload variants carry the bytes delivered for one side of the flow;
/// `Closed` is the terminal marker yielded once the payload channel has
/// been closed.
#[derive(Debug, Clone)]
pub enum ConversationChunk {
/// Bytes that were sent into the conversation tagged `FlowSide::Initiator`.
Initiator(Bytes),
/// Bytes that were sent into the conversation tagged `FlowSide::Responder`.
Responder(Bytes),
/// Terminal marker: no further payload will be delivered.
Closed {
/// Why the conversation ended, as recorded by the reassemblers
/// (`next_chunk` substitutes `EndReason::IdleTimeout` when nothing was recorded).
reason: EndReason,
},
}
/// Consumer-side handle for a single flow, identified by `key`.
///
/// Payload from both sides arrives through one channel, so chunks are
/// observed in the order in which they were sent into that channel.
pub struct Conversation<K> {
/// Flow key this conversation was created for.
pub key: K,
// Receiving half of the per-conversation channel; each item is tagged
// with the side that produced it.
rx: mpsc::Receiver<(FlowSide, Bytes)>,
// Slot shared with the producing reassemblers; they store the end reason
// here and `next_chunk` takes it when the channel closes.
end_reason: Arc<Mutex<Option<EndReason>>>,
// Set once `ConversationChunk::Closed` has been returned, so later calls
// to `next_chunk` return `None`.
closed_emitted: bool,
}
impl<K> Conversation<K> {
    /// Waits for the next chunk of this conversation.
    ///
    /// Returns a payload chunk tagged with the side that produced it while
    /// the channel is open. When the channel reports closure, returns
    /// [`ConversationChunk::Closed`] exactly once, carrying the recorded end
    /// reason or `EndReason::IdleTimeout` if none was recorded. Every call
    /// after that returns `None`.
    ///
    /// # Panics
    ///
    /// Panics if the shared end-reason mutex is poisoned.
    pub async fn next_chunk(&mut self) -> Option<ConversationChunk> {
        if self.closed_emitted {
            return None;
        }

        if let Some((side, payload)) = self.rx.recv().await {
            let chunk = match side {
                FlowSide::Initiator => ConversationChunk::Initiator(payload),
                FlowSide::Responder => ConversationChunk::Responder(payload),
            };
            return Some(chunk);
        }

        // The channel is closed: emit the terminal marker once.
        self.closed_emitted = true;
        // The guard is a temporary, so the lock is released at the end of
        // this statement.
        let recorded = self.end_reason.lock().unwrap().take();
        Some(ConversationChunk::Closed {
            reason: recorded.unwrap_or(EndReason::IdleTimeout),
        })
    }
}
/// Reassembler factory that groups the per-side reassemblers of a flow into
/// one shared [`Conversation`].
pub struct ConversationFactory<K>
where
K: Eq + Hash + Clone + Send + Sync + 'static,
{
// Newly created conversations waiting to be handed out; a clone of this
// handle is returned by `pending()`.
pending_emit: Arc<Mutex<VecDeque<Conversation<K>>>>,
// Weak handles to the shared state of conversations, by flow key. A weak
// handle that no longer upgrades is treated as "no conversation".
in_flight: Arc<Mutex<HashMap<K, Weak<ConvShared>, RandomState>>>,
// Capacity passed to `mpsc::channel` for every new conversation.
channel_capacity: usize,
}
/// Producer-side state shared by the reassemblers of one conversation.
struct ConvShared {
// Sending half of the conversation's payload channel.
tx: mpsc::Sender<(FlowSide, Bytes)>,
// Same slot as `Conversation::end_reason`; written by `fin`/`rst`.
end_reason: Arc<Mutex<Option<EndReason>>>,
}
impl<K> ConversationFactory<K>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
{
    /// Creates a factory whose conversations each use a bounded channel of
    /// `channel_capacity` items.
    ///
    /// A capacity of `0` is raised to `1`: `tokio::sync::mpsc::channel`
    /// panics on a zero-sized buffer, and that panic would otherwise only
    /// surface later, when the first reassembler is requested.
    fn new(channel_capacity: usize) -> Self {
        Self {
            pending_emit: Arc::new(Mutex::new(VecDeque::new())),
            in_flight: Arc::new(Mutex::new(HashMap::with_hasher(RandomState::new()))),
            channel_capacity: channel_capacity.max(1),
        }
    }

    /// Returns another handle to the queue of conversations that have been
    /// created but not yet handed out.
    fn pending(&self) -> Arc<Mutex<VecDeque<Conversation<K>>>> {
        Arc::clone(&self.pending_emit)
    }
}
impl<K> AsyncReassemblerFactory<K> for ConversationFactory<K>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
{
    type Reassembler = ConvSideReassembler;

    /// Returns the reassembler for `side` of the flow identified by `key`.
    ///
    /// If a conversation for `key` is still alive (some reassembler or
    /// pending future still holds its shared state), the new reassembler
    /// joins it. Otherwise a fresh conversation is created, queued on the
    /// pending-emit queue, and registered under `key`.
    ///
    /// # Panics
    ///
    /// Panics if either internal mutex is poisoned. `mpsc::channel` also
    /// panics when asked for a capacity of 0 (see the tokio documentation).
    fn new_reassembler(&mut self, key: &K, side: FlowSide) -> ConvSideReassembler {
        let mut in_flight = self.in_flight.lock().unwrap();

        if let Some(shared) = in_flight.get(key).and_then(Weak::upgrade) {
            return ConvSideReassembler { shared, side };
        }

        // Entries are only ever inserted here, so without a sweep every
        // finished flow would leave its key and a dead `Weak` behind forever.
        // Sweeping only once the map has filled its current capacity keeps
        // the O(n) pass infrequent instead of paying it on every new flow.
        // Dead entries never upgrade, so removing them does not change what
        // any lookup observes.
        if in_flight.len() >= in_flight.capacity() {
            in_flight.retain(|_, conv| conv.strong_count() > 0);
        }

        let (tx, rx) = mpsc::channel(self.channel_capacity);
        let end_reason = Arc::new(Mutex::new(None));
        self.pending_emit.lock().unwrap().push_back(Conversation {
            key: key.clone(),
            rx,
            end_reason: Arc::clone(&end_reason),
            closed_emitted: false,
        });

        let shared = Arc::new(ConvShared { tx, end_reason });
        // Only a weak handle is stored: the conversation's channel closes as
        // soon as the last reassembler (or in-flight future) drops its `Arc`.
        in_flight.insert(key.clone(), Arc::downgrade(&shared));
        ConvSideReassembler { shared, side }
    }
}
/// Reassembler for one side of a flow; forwards what it receives into the
/// conversation it shares with the other side.
pub struct ConvSideReassembler {
// Shared producer-side state (channel sender and end-reason slot).
shared: Arc<ConvShared>,
// Side used to tag every payload this reassembler forwards.
side: FlowSide,
}
impl AsyncReassembler for ConvSideReassembler {
fn segment(
&mut self,
_seq: u32,
payload: Bytes,
) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
let shared = self.shared.clone();
let side = self.side;
Box::pin(async move {
let _ = shared.tx.send((side, payload)).await;
})
}
fn fin(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
let shared = self.shared.clone();
Box::pin(async move {
let mut g = shared.end_reason.lock().unwrap();
if g.is_none() {
*g = Some(EndReason::Fin);
}
})
}
fn rst(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
let shared = self.shared.clone();
Box::pin(async move {
let mut g = shared.end_reason.lock().unwrap();
*g = Some(EndReason::Rst);
})
}
}
/// The `FlowStream` instantiation that [`ConversationStream`] wraps: source
/// `S`, extractor `E`, `()` as the third type parameter, and an
/// `AsyncReassemblerSlot` over a [`ConversationFactory`] keyed by the
/// extractor's key type.
type ConvInnerStream<S, E> = FlowStream<
S,
E,
(),
AsyncReassemblerSlot<<E as FlowExtractor>::Key, ConversationFactory<<E as FlowExtractor>::Key>>,
>;
/// Stream that yields a [`Conversation`] for each one queued by the
/// conversation factory while the wrapped flow stream is driven.
pub struct ConversationStream<S, E>
where
S: PacketSource + std::os::unix::io::AsRawFd,
E: FlowExtractor,
E::Key: Eq + Hash + Clone + Send + Sync + 'static,
{
// Underlying flow stream; polled to make progress, its events are not
// forwarded.
inner: ConvInnerStream<S, E>,
// Handle to the factory's queue of conversations awaiting emission.
pending: Arc<Mutex<VecDeque<Conversation<E::Key>>>>,
}
impl<S, E> Stream for ConversationStream<S, E>
where
    S: PacketSource + std::os::unix::io::AsRawFd + Unpin,
    E: FlowExtractor + Unpin,
    E::Key: Eq + Hash + Clone + Send + Sync + Unpin + 'static,
{
    type Item = Result<Conversation<E::Key>, Error>;

    /// Yields the next queued conversation, driving the inner flow stream
    /// whenever the queue is empty.
    ///
    /// Successful events from the inner stream are discarded; its errors are
    /// forwarded, and the stream ends when the inner stream ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let me = self.get_mut();
        loop {
            // The guard is a temporary, so the queue is unlocked before the
            // inner stream is polled (presumably the inner stream pushes to
            // this same queue through the factory — confirm in FlowStream).
            let queued = me.pending.lock().unwrap().pop_front();
            if let Some(conversation) = queued {
                return Poll::Ready(Some(Ok(conversation)));
            }

            match Pin::new(&mut me.inner).poll_next(cx) {
                // An event may have queued a conversation; check again.
                Poll::Ready(Some(Ok(_))) => {}
                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
impl<S, E> FlowStream<S, E, (), NoReassembler>
where
S: PacketSource + std::os::unix::io::AsRawFd,
E: FlowExtractor,
E::Key: Eq + Hash + Clone + Send + Sync + 'static,
{
pub fn into_conversations(self) -> ConversationStream<S, E> {
self.into_conversations_with_capacity(64)
}
pub fn into_conversations_with_capacity(self, capacity: usize) -> ConversationStream<S, E> {
let factory = ConversationFactory::<E::Key>::new(capacity);
let pending = factory.pending();
let inner = self.with_async_reassembler(factory);
ConversationStream { inner, pending }
}
}
impl<S> AsyncCapture<S>
where
    S: PacketSource + std::os::unix::io::AsRawFd,
{
    /// Builds a flow stream from this capture with `extractor` and converts
    /// it into a [`ConversationStream`] with the default channel capacity.
    pub fn flow_conversations<E>(self, extractor: E) -> ConversationStream<S, E>
    where
        E: FlowExtractor,
        E::Key: Eq + Hash + Clone + Send + Sync + 'static,
    {
        let flows = self.flow_stream(extractor);
        flows.into_conversations()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a factory with channel capacity 8, registers both sides of
    /// flow `key`, and returns `(initiator, responder, conversation)`.
    fn open_flow(key: u32) -> (ConvSideReassembler, ConvSideReassembler, Conversation<u32>) {
        let mut factory = ConversationFactory::<u32>::new(8);
        let queue = factory.pending();
        let initiator = factory.new_reassembler(&key, FlowSide::Initiator);
        let responder = factory.new_reassembler(&key, FlowSide::Responder);
        let conversation = queue.lock().unwrap().pop_front().unwrap();
        (initiator, responder, conversation)
    }

    /// Both sides of one flow share a conversation; a second flow gets its own.
    #[tokio::test(flavor = "current_thread")]
    async fn factory_emits_one_conversation_per_flow() {
        let mut factory = ConversationFactory::<u32>::new(8);
        let queue = factory.pending();
        // The reassemblers stay bound so the shared state of flow 1 is still
        // alive when its responder side is requested.
        let _flow_one_initiator = factory.new_reassembler(&1u32, FlowSide::Initiator);
        let _flow_one_responder = factory.new_reassembler(&1u32, FlowSide::Responder);
        let _flow_two_initiator = factory.new_reassembler(&2u32, FlowSide::Initiator);

        let emitted: Vec<Conversation<u32>> = queue.lock().unwrap().drain(..).collect();
        assert_eq!(emitted.len(), 2, "expected 2 conversations");
        assert_eq!(emitted[0].key, 1);
        assert_eq!(emitted[1].key, 2);
    }

    /// Payload from each side comes out tagged with that side, in send order.
    #[tokio::test(flavor = "current_thread")]
    async fn segment_dispatch_round_trips() {
        let (mut initiator, mut responder, mut conv) = open_flow(7);
        initiator.segment(0, Bytes::from_static(b"hello")).await;
        responder.segment(0, Bytes::from_static(b"world")).await;

        match conv.next_chunk().await.unwrap() {
            ConversationChunk::Initiator(b) => assert_eq!(&b[..], b"hello"),
            other => panic!("expected Initiator(hello), got {other:?}"),
        }
        match conv.next_chunk().await.unwrap() {
            ConversationChunk::Responder(b) => assert_eq!(&b[..], b"world"),
            other => panic!("expected Responder(world), got {other:?}"),
        }
    }

    /// After both sides FIN and are dropped, the conversation closes with `Fin`.
    #[tokio::test(flavor = "current_thread")]
    async fn fin_emits_closed_with_fin_reason() {
        let (mut initiator, mut responder, mut conv) = open_flow(7);
        initiator.segment(0, Bytes::from_static(b"x")).await;
        initiator.fin().await;
        drop(initiator);
        responder.fin().await;
        drop(responder);

        let c1 = conv.next_chunk().await.unwrap();
        assert!(matches!(c1, ConversationChunk::Initiator(_)));

        let c2 = conv.next_chunk().await.unwrap();
        assert!(
            matches!(
                c2,
                ConversationChunk::Closed {
                    reason: EndReason::Fin
                }
            ),
            "expected Closed{{Fin}}, got {c2:?}"
        );

        assert!(conv.next_chunk().await.is_none());
    }

    /// After both sides RST and are dropped, the conversation closes with `Rst`.
    #[tokio::test(flavor = "current_thread")]
    async fn rst_emits_closed_with_rst_reason() {
        let (mut initiator, mut responder, mut conv) = open_flow(7);
        initiator.rst().await;
        drop(initiator);
        responder.rst().await;
        drop(responder);

        let closing = conv.next_chunk().await.unwrap();
        assert!(matches!(
            closing,
            ConversationChunk::Closed {
                reason: EndReason::Rst
            }
        ));
    }

    /// A flow where only the initiator sends data still delivers and closes.
    #[tokio::test(flavor = "current_thread")]
    async fn unidirectional_flow_works() {
        let (mut initiator, mut responder, mut conv) = open_flow(7);
        initiator
            .segment(0, Bytes::from_static(b"only-initiator"))
            .await;
        initiator.fin().await;
        drop(initiator);
        responder.fin().await;
        drop(responder);

        let first = conv.next_chunk().await.unwrap();
        assert!(matches!(first, ConversationChunk::Initiator(_)));

        let second = conv.next_chunk().await.unwrap();
        assert!(matches!(
            second,
            ConversationChunk::Closed {
                reason: EndReason::Fin
            }
        ));
    }
}