use async_stream::try_stream;
use rstest::fixture;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use wireframe::{
app::Envelope,
connection::ConnectionActor,
correlation::CorrelatableFrame,
response::FrameStream,
};
pub use wireframe_testing::TestResult;
use crate::build_small_queues;
/// Scenario state for correlation-id tests.
///
/// Holds the correlation id the scenario expects and the frames handed to the
/// connection actor's `run` call, which `verify` later inspects.
#[derive(Debug, Default)]
pub struct CorrelationWorld {
/// Correlation id every frame is expected to carry; `None` means frames
/// are expected to carry no correlation id.
expected: Option<u64>,
/// Frame buffer passed to the actor in `process` / `process_multi`.
frames: Vec<Envelope>,
}
/// rstest fixture yielding a default-initialised [`CorrelationWorld`]
/// (no expected correlation id, empty frame buffer).
// NOTE(review): `#[rustfmt::skip]` presumably stops rustfmt from collapsing
// the body onto one line (which can trip `unused_braces` through the fixture
// macro) — confirm before removing.
#[rustfmt::skip]
#[fixture]
pub fn correlation_world() -> CorrelationWorld {
CorrelationWorld::default()
}
impl CorrelationWorld {
    /// Store the correlation id that subsequent processing and verification
    /// should use. Passing `None` means frames are expected to carry no id.
    pub fn set_expected(&mut self, expected: Option<u64>) {
        self.expected = expected;
    }

    /// Return the correlation id currently configured for the scenario.
    #[must_use]
    pub fn expected(&self) -> Option<u64> {
        self.expected
    }

    /// Run a connection actor over a two-frame response stream whose frames
    /// are stamped with the configured correlation id.
    ///
    /// The actor is given `self.frames` as the buffer for its `run` call.
    ///
    /// # Errors
    ///
    /// Fails when no correlation id has been configured, when the queues
    /// cannot be built, or when the actor's `run` call reports an error.
    pub async fn process(&mut self) -> TestResult {
        // A streaming response is only meaningful with an id to stamp on it.
        let Some(correlation) = self.expected else {
            return Err("streaming scenario requires a correlation id".into());
        };

        let responses: FrameStream<Envelope> = Box::pin(try_stream! {
            for payload in [1, 2] {
                yield Envelope::new(1, Some(correlation), vec![payload]);
            }
        });

        let (queues, handle) = build_small_queues::<Envelope>()?;
        let mut actor =
            ConnectionActor::new(queues, handle, Some(responses), CancellationToken::new());

        if let Err(err) = actor.run(&mut self.frames).await {
            return Err(format!("actor run failed: {err:?}").into());
        }
        Ok(())
    }

    /// Run a connection actor fed by a multi-packet channel, asking the actor
    /// to apply the configured correlation id to that channel.
    ///
    /// Two envelopes are queued before the sender is dropped: one without a
    /// correlation id and one carrying the unrelated id `99`.
    ///
    /// # Errors
    ///
    /// Fails when an envelope cannot be queued, when the queues cannot be
    /// built, when installing the multi-packet channel is rejected, or when
    /// the actor's `run` call reports an error.
    pub async fn process_multi(&mut self) -> TestResult {
        let (sender, receiver) = mpsc::channel(4);
        let queued = [
            Envelope::new(1, None, vec![1]),
            Envelope::new(1, Some(99), vec![2]),
        ];
        for envelope in queued {
            sender.send(envelope).await?;
        }
        // Dropping the only sender closes the channel once both envelopes
        // have been received.
        drop(sender);

        let (queues, handle) = build_small_queues::<Envelope>()?;
        let mut actor = ConnectionActor::<Envelope, ()>::new(
            queues,
            handle,
            None,
            CancellationToken::new(),
        );

        if let Err(err) = actor.set_multi_packet_with_correlation(Some(receiver), self.expected) {
            return Err(format!("set_multi_packet_with_correlation failed: {err}").into());
        }
        if let Err(err) = actor.run(&mut self.frames).await {
            return Err(format!("actor run failed: {err:?}").into());
        }
        Ok(())
    }

    /// Check that every frame in the buffer carries exactly the configured
    /// correlation id (or none at all when no id is configured).
    ///
    /// An empty buffer passes the check.
    ///
    /// # Errors
    ///
    /// Fails with a message describing the mismatch when any frame's
    /// correlation id differs from the configured value.
    pub fn verify(&self) -> TestResult {
        let wanted = self.expected;
        let mismatch = self
            .frames
            .iter()
            .any(|frame| frame.correlation_id() != wanted);

        if mismatch {
            return match wanted {
                Some(cid) => Err(format!("frames missing expected correlation id {cid}").into()),
                None => Err("frames unexpectedly carried correlation id".into()),
            };
        }
        Ok(())
    }
}