#![cfg(not(loom))]
use std::sync::{
Arc,
OnceLock,
atomic::{AtomicBool, Ordering},
};
use log::Level as LogLevel;
use logtest as flexi_logger;
use rstest::rstest;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use wireframe::{
app::{Envelope, Packet, PacketParts},
connection::{ConnectionActor, ConnectionChannels, FairnessConfig},
hooks::{ConnectionContext, ProtocolHooks},
push::{PushHandle, PushQueues},
};
use wireframe_testing::{LoggerHandle, TestResult, logger};
/// Frame id given to every envelope these tests send through the multi-packet
/// stream channel.
const STREAM_ID: u32 = 7;
/// Frame id of the terminator envelope built by `terminator_frame`.
const TERMINATOR_ID: u32 = 255;
/// Builds the end-of-stream frame: id `TERMINATOR_ID`, no correlation id and
/// an empty payload.
fn terminator_frame() -> Envelope {
    let empty_payload = Vec::new();
    Envelope::new(TERMINATOR_ID, None, empty_payload)
}
/// Builds an envelope with the given id and optional correlation id, copying
/// `payload` into an owned buffer.
fn envelope_with_payload(id: u32, correlation: Option<u64>, payload: &[u8]) -> Envelope {
    let owned_payload = Vec::from(payload);
    Envelope::new(id, correlation, owned_payload)
}
/// Test harness pairing a connection actor with the push handle captured for
/// it during construction.
struct ActorHarness {
/// Actor under test; the tests configure it directly and drive it via `run`.
actor: ConnectionActor<Envelope, ()>,
/// Push handle captured by the connection-setup hook; `None` once released.
handle: Option<PushHandle<Envelope>>,
}
impl ActorHarness {
    /// Builds an actor wired to fresh push queues and captures the push handle
    /// that the connection-setup hook receives.
    ///
    /// The `stream_end` hook is configured to emit [`terminator_frame`].
    ///
    /// # Errors
    ///
    /// Fails when the queue builder fails, when the captured-handle cell is
    /// still shared after the actor has been constructed, when the setup hook
    /// never ran, or when the setup hook ran more than once.
    fn new() -> TestResult<Self> {
        let (queues, push_handle) = PushQueues::<Envelope>::builder()
            .high_capacity(4)
            .low_capacity(4)
            .unlimited()
            .build()?;

        // Cell the setup hook writes the handle into, plus its clone for the hook.
        let captured: Arc<OnceLock<PushHandle<Envelope>>> = Arc::new(OnceLock::new());
        let hook_slot = Arc::clone(&captured);

        // Raised by the hook if it is invoked after the cell is already filled.
        let captured_twice = Arc::new(AtomicBool::new(false));
        let hook_flag = Arc::clone(&captured_twice);

        let hooks = ProtocolHooks {
            on_connection_setup: Some(Box::new(move |handle, _ctx| {
                if hook_slot.set(handle).is_err() {
                    hook_flag.store(true, Ordering::Relaxed);
                }
            })),
            stream_end: Some(Box::new(|_ctx: &mut ConnectionContext| {
                Some(terminator_frame())
            })),
            ..ProtocolHooks::default()
        };

        let actor = ConnectionActor::with_hooks(
            ConnectionChannels::new(queues, push_handle),
            None,
            CancellationToken::new(),
            hooks,
        );

        // NOTE(review): unwrapping only succeeds if the hook's clone of the cell
        // has been dropped by now — presumably `with_hooks` runs and consumes
        // the setup hook; confirm against the actor implementation.
        let Ok(cell) = Arc::try_unwrap(captured) else {
            return Err("push handle still shared at teardown".into());
        };
        let Some(handle) = cell.into_inner() else {
            return Err("connection setup hook did not run".into());
        };

        if captured_twice.load(Ordering::Relaxed) {
            Err("push handle already captured".into())
        } else {
            Ok(Self {
                actor,
                handle: Some(handle),
            })
        }
    }

    /// Borrows the captured push handle.
    ///
    /// # Errors
    ///
    /// Fails once [`ActorHarness::release_handle`] has dropped the handle.
    fn handle(&self) -> TestResult<&PushHandle<Envelope>> {
        match &self.handle {
            Some(handle) => Ok(handle),
            None => Err("push handle already released".into()),
        }
    }

    /// Drops the harness's push handle, if it is still held.
    fn release_handle(&mut self) {
        self.handle = None;
    }

    /// Runs the actor, collecting every frame it emits into a vector.
    ///
    /// # Errors
    ///
    /// Wraps any actor failure in an `io::Error` carrying its debug output.
    async fn run(&mut self) -> TestResult<Vec<Envelope>> {
        let mut frames = Vec::new();
        let outcome = self.actor.run(&mut frames).await;
        outcome.map_err(|e| {
            let wrapped: Box<dyn std::error::Error + Send + Sync> = Box::new(
                std::io::Error::other(format!("connection actor run failed: {e:?}")),
            );
            wrapped
        })?;
        Ok(frames)
    }
}
/// Clones `frame` and splits the copy into its parts, leaving the original
/// untouched.
fn parts(frame: &Envelope) -> PacketParts {
    let owned = frame.clone();
    owned.into_parts()
}
/// Streams two payload frames through a multi-packet channel and expects them
/// back in order, followed by an empty terminator frame, with every emitted
/// frame carrying the stream's correlation id.
#[tokio::test]
#[expect(
    clippy::panic_in_result_fn,
    reason = "asserts provide clearer diagnostics in tests"
)]
async fn client_receives_multi_packet_stream_with_terminator() -> TestResult<()> {
    let mut harness = ActorHarness::new()?;
    let (stream_tx, stream_rx) = mpsc::channel(4);
    let correlation = Some(88_u64);

    // Frames are queued without a correlation id; the assertions below expect
    // the emitted frames to carry the stream's id instead.
    let chunks: [&[u8]; 2] = [&[1], &[2, 3]];
    for chunk in chunks {
        let frame = envelope_with_payload(STREAM_ID, None, chunk);
        stream_tx
            .send(frame)
            .await
            .map_err(|e| format!("send frame: {e}"))?;
    }
    drop(stream_tx);

    harness
        .actor
        .set_multi_packet_with_correlation(Some(stream_rx), correlation)
        .map_err(|e| format!("set_multi_packet_with_correlation: {e}"))?;
    harness.release_handle();

    let frames = harness.run().await?;
    assert_eq!(frames.len(), 3, "expected two frames plus terminator");

    let mut payloads: Vec<Vec<u8>> = Vec::with_capacity(frames.len());
    for frame in &frames {
        payloads.push(parts(frame).into_payload());
    }
    assert_eq!(payloads.first(), Some(&vec![1]), "first payload mismatch");
    assert_eq!(
        payloads.get(1),
        Some(&vec![2, 3]),
        "second payload mismatch"
    );
    assert_eq!(
        payloads.get(2),
        Some(&Vec::<u8>::new()),
        "terminator payload should be empty"
    );

    for frame in &frames {
        let observed = parts(frame).correlation_id();
        assert_eq!(observed, correlation, "correlation id mismatch");
    }
    Ok(())
}
/// Returns `true` for a warn-level record whose message contains both
/// `multi-packet stream closed` and `reason=disconnected`.
fn is_disconnect_log(record: &flexi_logger::Record) -> bool {
    if record.level() != LogLevel::Warn {
        return false;
    }
    let message = record.args();
    message.contains("multi-packet stream closed") && message.contains("reason=disconnected")
}
/// Drops the stream sender before the actor runs and expects the pushed frame,
/// then a terminator carrying the stream's correlation id, plus a warn-level
/// log record reporting the disconnect.
#[rstest]
#[tokio::test]
async fn multi_packet_logs_disconnected_when_sender_dropped(
    mut logger: LoggerHandle,
) -> TestResult<()> {
    logger.clear();
    let mut harness = ActorHarness::new()?;

    // The sender is dropped without sending anything, so the receiver handed
    // to the actor has no frames and no live sender.
    let (stream_tx, stream_rx) = mpsc::channel(1);
    let correlation = Some(41_u64);
    drop(stream_tx);

    harness
        .actor
        .set_multi_packet_with_correlation(Some(stream_rx), correlation)
        .map_err(|e| format!("set_multi_packet_with_correlation: {e}"))?;
    harness.actor.set_fairness(interleaving_fairness());

    let pushed = envelope_with_payload(11, Some(5), b"hi");
    harness.handle()?.push_high_priority(pushed).await?;
    harness.release_handle();

    let frames = harness.run().await?;
    assert_eq!(
        frames.len(),
        2,
        "expected push frame followed by terminator"
    );
    let terminator = frames.last().expect("terminator missing");
    assert_eq!(
        parts(terminator).correlation_id(),
        correlation,
        "terminator correlation mismatch"
    );

    // Drain captured records until the disconnect warning turns up or the
    // logger runs dry.
    let saw_disconnect = loop {
        match logger.pop() {
            Some(record) if is_disconnect_log(&record) => break true,
            Some(_) => {}
            None => break false,
        }
    };
    assert!(saw_disconnect, "missing disconnect log");
    Ok(())
}
/// Static description of a frame that `push_sequence` turns into an envelope.
struct FrameSpec {
/// Frame id passed to the envelope.
id: u32,
/// Correlation id attached to the envelope (always present for pushed frames).
correlation: u64,
/// Payload bytes copied into the envelope.
payload: &'static [u8],
}
/// Frames that `push_interleaved_frames` pushes on the high-priority queue
/// (correlation ids 1 and 3).
const HIGH_PRIORITY_FRAMES: [FrameSpec; 2] = [
FrameSpec {
id: 2,
correlation: 1,
payload: b"A",
},
FrameSpec {
id: 4,
correlation: 3,
payload: b"C",
},
];
/// Frames that `push_interleaved_frames` pushes on the low-priority queue
/// (correlation ids 2 and 4).
const LOW_PRIORITY_FRAMES: [FrameSpec; 2] = [
FrameSpec {
id: 3,
correlation: 2,
payload: b"B",
},
FrameSpec {
id: 5,
correlation: 4,
payload: b"D",
},
];
/// Selects which push method `push_sequence` calls for each frame.
enum PushPriority {
/// Send via `push_high_priority`.
High,
/// Send via `push_low_priority`.
Low,
}
/// Fairness settings used by the interleaving tests: `max_high_before_low` of
/// one and no time slice.
fn interleaving_fairness() -> FairnessConfig {
    let max_high_before_low = 1;
    let time_slice = None;
    FairnessConfig {
        max_high_before_low,
        time_slice,
    }
}
async fn push_sequence(
handle: &PushHandle<Envelope>,
priority: PushPriority,
frames: &[FrameSpec],
) -> TestResult<()> {
for spec in frames {
let envelope = envelope_with_payload(spec.id, Some(spec.correlation), spec.payload);
match priority {
PushPriority::High => handle.push_high_priority(envelope).await?,
PushPriority::Low => handle.push_low_priority(envelope).await?,
}
}
Ok(())
}
/// Creates a channel preloaded with one `STREAM_ID` frame per payload (no
/// correlation id) and returns the receiver after dropping the sender.
///
/// The channel capacity is the payload count, with a floor of one so an empty
/// slice still yields a valid capacity.
///
/// # Errors
///
/// Fails if any frame cannot be sent into the channel.
async fn setup_stream_channel(payloads: &[&[u8]]) -> TestResult<mpsc::Receiver<Envelope>> {
    let capacity = usize::max(payloads.len(), 1);
    let (sender, receiver) = mpsc::channel(capacity);
    for &payload in payloads {
        let frame = envelope_with_payload(STREAM_ID, None, payload);
        sender
            .send(frame)
            .await
            .map_err(|e| format!("send frame to multi-packet stream: {e}"))?;
    }
    drop(sender);
    Ok(receiver)
}
/// Queues all of `HIGH_PRIORITY_FRAMES` and then all of `LOW_PRIORITY_FRAMES`
/// on `handle`.
///
/// # Errors
///
/// Returns the first push failure; low-priority frames are not attempted if a
/// high-priority push fails.
async fn push_interleaved_frames(handle: &PushHandle<Envelope>) -> TestResult<()> {
    push_sequence(handle, PushPriority::High, &HIGH_PRIORITY_FRAMES).await?;
    push_sequence(handle, PushPriority::Low, &LOW_PRIORITY_FRAMES).await
}
/// Asserts that the correlation ids of `frames`, in order, equal `expected`.
///
/// # Panics
///
/// Panics when the sequences differ.
fn assert_correlation_ordering(frames: &[Envelope], expected: &[Option<u64>]) {
    let mut correlations: Vec<Option<u64>> = Vec::with_capacity(frames.len());
    for frame in frames {
        correlations.push(parts(frame).correlation_id());
    }
    assert_eq!(correlations, expected, "unexpected correlation ordering");
}
/// Asserts that the frame ids of `frames`, in order, equal `expected`.
///
/// # Panics
///
/// Panics when the sequences differ.
fn assert_frame_identities(frames: &[Envelope], expected: &[u32]) {
    let mut ids: Vec<u32> = Vec::with_capacity(frames.len());
    for frame in frames {
        ids.push(parts(frame).id());
    }
    assert_eq!(
        ids, expected,
        "frame sequence did not preserve request identities",
    );
}
/// Queues a three-frame stream plus high- and low-priority pushes, then checks
/// that every emitted frame keeps its own id and correlation id.
#[tokio::test]
async fn interleaved_multi_packet_and_push_frames_preserve_correlations() -> TestResult<()> {
    let mut harness = ActorHarness::new()?;
    let stream_correlation = Some(73_u64);

    let stream_payloads: [&[u8]; 3] = [&[10], &[20], &[30]];
    let stream_rx = setup_stream_channel(&stream_payloads).await?;
    harness
        .actor
        .set_multi_packet_with_correlation(Some(stream_rx), stream_correlation)
        .map_err(|e| format!("set_multi_packet_with_correlation: {e}"))?;
    harness.actor.set_fairness(interleaving_fairness());

    push_interleaved_frames(harness.handle()?).await?;
    harness.release_handle();

    let frames = harness.run().await?;

    // Expected output: the four pushed frames alternating high/low, then the
    // three stream frames and the terminator under the stream's correlation id.
    let expected_correlations = [
        Some(1),
        Some(2),
        Some(3),
        Some(4),
        stream_correlation,
        stream_correlation,
        stream_correlation,
        stream_correlation,
    ];
    assert_correlation_ordering(&frames, &expected_correlations);

    let expected_ids = [2, 3, 4, 5, STREAM_ID, STREAM_ID, STREAM_ID, TERMINATOR_ID];
    assert_frame_identities(&frames, &expected_ids);
    Ok(())
}