#![allow(missing_docs)]
use std::future::Future;
use std::pin::Pin;
use bytes::Bytes;
use flowscope::FlowSide;
use tokio::sync::mpsc;
/// Asynchronous consumer of the payload stream for one side of one flow.
///
/// Every hook returns a boxed `Send + 'static` future. Because of that `'static`
/// bound, the returned future cannot borrow `self`: anything it needs must be
/// moved or cloned into it before the method returns.
pub trait AsyncReassembler: Send + 'static {
    /// Handles one chunk of `payload` tagged with sequence number `seq`.
    fn segment(
        &mut self,
        seq: u32,
        payload: Bytes,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>;

    /// Orderly-close hook (by its name, presumably driven by a FIN — confirm with
    /// the caller). The default implementation does nothing.
    fn fin(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
        Box::pin(std::future::ready(()))
    }

    /// Abortive-close hook (by its name, presumably driven by an RST — confirm with
    /// the caller). The default implementation does nothing.
    fn rst(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
        Box::pin(std::future::ready(()))
    }
}
/// Produces one [`AsyncReassembler`] per (flow key, side) pair.
///
/// `K` is the caller's flow-key type. The factory only ever receives it by
/// reference.
pub trait AsyncReassemblerFactory<K>
where
    Self: Send + 'static,
{
    /// Concrete reassembler type handed out by this factory.
    type Reassembler: AsyncReassembler;

    /// Creates a fresh reassembler for the `side` of the flow identified by `key`.
    fn new_reassembler(&mut self, key: &K, side: FlowSide) -> Self::Reassembler;
}
/// Builds an [`AsyncReassemblerFactory`] whose reassemblers forward payloads into
/// `mpsc` channels.
///
/// `make_sender` is called once per [`AsyncReassemblerFactory::new_reassembler`]
/// call, with the borrowed flow key and side. The sender it returns becomes the
/// destination of the resulting [`ChannelReassembler`].
///
/// The key type only needs to be `Send + 'static`. It is never cloned or stored;
/// the factory just passes the borrowed key through to `make_sender`.
pub fn channel_factory<K, F>(make_sender: F) -> ChannelFactory<K, F>
where
    F: FnMut(&K, FlowSide) -> mpsc::Sender<Bytes> + Send + 'static,
    K: Send + 'static,
{
    ChannelFactory {
        make_sender,
        _phantom: std::marker::PhantomData,
    }
}

/// Factory created by [`channel_factory`]; see that function for details.
pub struct ChannelFactory<K, F> {
    /// User callback that picks the channel for each (key, side).
    make_sender: F,
    // `fn(&K)` ties the factory to `K` without storing or owning a `K`. Function
    // pointers are always `Send`/`Sync`, so this marker adds no auto-trait
    // requirements on `K`.
    _phantom: std::marker::PhantomData<fn(&K)>,
}

impl<K, F> AsyncReassemblerFactory<K> for ChannelFactory<K, F>
where
    F: FnMut(&K, FlowSide) -> mpsc::Sender<Bytes> + Send + 'static,
    // `'static` is required so that `PhantomData<fn(&K)>` (and therefore `Self`)
    // satisfies the trait's `'static` supertrait bound. `Clone` was never needed.
    K: Send + 'static,
{
    type Reassembler = ChannelReassembler;

    /// Asks `make_sender` for a channel and wraps it in a [`ChannelReassembler`].
    fn new_reassembler(&mut self, key: &K, side: FlowSide) -> ChannelReassembler {
        let tx = (self.make_sender)(key, side);
        ChannelReassembler { tx: Some(tx) }
    }
}
/// [`AsyncReassembler`] that forwards each segment's payload, unmodified, into a
/// tokio `mpsc` channel.
///
/// Payloads are forwarded in the order `segment` is called. The sequence number is
/// not used, so no reordering or de-duplication happens here.
#[derive(Debug)]
pub struct ChannelReassembler {
    /// Destination channel. It becomes `None` once `fin` or `rst` has run; after
    /// that, further segments are discarded.
    tx: Option<mpsc::Sender<Bytes>>,
}
impl AsyncReassembler for ChannelReassembler {
    /// Sends `payload` into the channel, waiting for capacity when it is full.
    ///
    /// `_seq` is ignored, so payloads are forwarded in call order. The returned
    /// future owns its own clone of the sender. As a result, the channel stays
    /// open until that future finishes, even if `fin`/`rst` runs in the meantime.
    /// After `fin` or `rst`, the future completes without sending anything. A send
    /// that fails because the receiver is gone is silently discarded.
    fn segment(
        &mut self,
        _seq: u32,
        payload: Bytes,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
        // The future must be `'static`, so it cannot borrow `self.tx`; hand it an
        // owned handle instead.
        let sender = self.tx.clone();
        Box::pin(async move {
            let Some(sender) = sender else {
                return;
            };
            // Best-effort delivery: a closed receiver is not treated as an error.
            let _ = sender.send(payload).await;
        })
    }

    /// Drops this reassembler's sender. The receiver then sees end-of-stream once
    /// every other sender clone (including ones held by pending `segment`
    /// futures) is gone.
    fn fin(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
        drop(self.tx.take());
        Box::pin(std::future::ready(()))
    }

    /// Same as `fin`: drops this reassembler's sender.
    fn rst(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
        drop(self.tx.take());
        Box::pin(std::future::ready(()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The factory must call `make_sender` exactly once per reassembler, with the
    /// key and side it was given, in call order.
    #[tokio::test(flavor = "current_thread")]
    async fn channel_factory_dispatches_per_flow_and_side() {
        let calls = std::sync::Arc::new(std::sync::Mutex::new(Vec::<(String, FlowSide)>::new()));
        let calls_clone = calls.clone();
        let mut factory = channel_factory(move |key: &String, side: FlowSide| {
            calls_clone.lock().unwrap().push((key.clone(), side));
            let (tx, _rx) = mpsc::channel::<Bytes>(8);
            tx
        });
        let _r1 = factory.new_reassembler(&"flow-A".to_string(), FlowSide::Initiator);
        let _r2 = factory.new_reassembler(&"flow-A".to_string(), FlowSide::Responder);
        let _r3 = factory.new_reassembler(&"flow-B".to_string(), FlowSide::Initiator);
        let recorded = calls.lock().unwrap();
        assert_eq!(recorded.len(), 3);
        assert_eq!(recorded[0].0, "flow-A");
        assert_eq!(recorded[0].1, FlowSide::Initiator);
        assert_eq!(recorded[1].0, "flow-A");
        assert_eq!(recorded[1].1, FlowSide::Responder);
        assert_eq!(recorded[2].0, "flow-B");
        assert_eq!(recorded[2].1, FlowSide::Initiator);
    }

    /// Segments arrive on the receiver in the order they were pushed.
    #[tokio::test(flavor = "current_thread")]
    async fn segment_pushes_to_channel() {
        let (tx, mut rx) = mpsc::channel::<Bytes>(4);
        let mut r = ChannelReassembler { tx: Some(tx) };
        r.segment(0, Bytes::from_static(b"abc")).await;
        r.segment(3, Bytes::from_static(b"def")).await;
        assert_eq!(rx.recv().await.unwrap(), Bytes::from_static(b"abc"));
        assert_eq!(rx.recv().await.unwrap(), Bytes::from_static(b"def"));
    }

    /// `fin` closes the channel after already-sent data has been delivered.
    #[tokio::test(flavor = "current_thread")]
    async fn fin_closes_channel() {
        let (tx, mut rx) = mpsc::channel::<Bytes>(4);
        let mut r = ChannelReassembler { tx: Some(tx) };
        r.segment(0, Bytes::from_static(b"x")).await;
        r.fin().await;
        assert_eq!(rx.recv().await.unwrap(), Bytes::from_static(b"x"));
        assert_eq!(rx.recv().await, None);
    }

    /// `rst` closes the channel the same way `fin` does.
    #[tokio::test(flavor = "current_thread")]
    async fn rst_closes_channel() {
        let (tx, mut rx) = mpsc::channel::<Bytes>(4);
        let mut r = ChannelReassembler { tx: Some(tx) };
        r.segment(0, Bytes::from_static(b"y")).await;
        r.rst().await;
        assert_eq!(rx.recv().await.unwrap(), Bytes::from_static(b"y"));
        assert_eq!(rx.recv().await, None);
    }

    /// Segments pushed after the side was closed are discarded, not delivered.
    #[tokio::test(flavor = "current_thread")]
    async fn segment_after_fin_is_dropped() {
        let (tx, mut rx) = mpsc::channel::<Bytes>(4);
        let mut r = ChannelReassembler { tx: Some(tx) };
        r.fin().await;
        r.segment(0, Bytes::from_static(b"late")).await;
        assert_eq!(rx.recv().await, None);
    }

    /// A vanished receiver must not make `segment` panic or hang.
    #[tokio::test(flavor = "current_thread")]
    async fn segment_after_receiver_dropped_is_ignored() {
        let (tx, rx) = mpsc::channel::<Bytes>(1);
        drop(rx);
        let mut r = ChannelReassembler { tx: Some(tx) };
        r.segment(0, Bytes::from_static(b"z")).await;
        r.segment(1, Bytes::from_static(b"z")).await;
    }
}