use futures::{SinkExt, StreamExt};
use tokio::net::TcpListener;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use wireframe::{
WireframeError,
client::ClientError,
correlation::CorrelatableFrame,
serializer::{BincodeSerializer, Serializer},
};
use super::{
ClientStreamingWorld,
TestResult,
TypedStreamingItem,
server::{
build_interleaved_priority_frames,
build_rate_limited_priority_frames,
send_data_and_terminator,
send_data_frames,
send_mismatch_frame,
},
types::{CorrelationId, MessageId, Payload, StreamTestEnvelope},
};
/// Behaviour selected for the mock streaming server started by
/// `ClientStreamingWorld::start_server`.
///
/// Each variant is dispatched by `run_streaming_mode` after the server has
/// read the client's first request frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamingServerMode {
    /// Answer via `send_data_and_terminator` with `data_count`.
    Normal { data_count: usize },
    /// Send data frames with message ids 1, 200, 2 and 201, then a terminator.
    ControlInterleaved,
    /// Answer via `send_mismatch_frame`.
    Mismatch,
    /// Answer via `send_data_frames` with `data_count`; this module sends no
    /// terminator in this mode.
    Disconnect { data_count: usize },
    /// Send the frames produced by `build_interleaved_priority_frames`.
    InterleavedPriorities,
    /// Send a contention marker frame followed by the frames produced by
    /// `build_rate_limited_priority_frames`.
    SharedRateLimit,
}
// Single-byte payloads of the first frame sent in shared-rate-limit mode.
// The server picks one based on the `bool` returned by
// `build_rate_limited_priority_frames`, and the client-side check compares
// against the same values.

/// Marker byte sent when the rate-limited builder reported `false`.
const SHARED_RATE_LIMIT_NO_CONTENTION_MARKER: u8 = 98;
/// Marker byte sent when the rate-limited builder reported `true`.
const SHARED_RATE_LIMIT_CONTENTION_MARKER: u8 = 99;
/// Serialize a single envelope and push it through the framed transport.
///
/// Returns `false` if `serialize_to_bytes` fails or the sink rejects the
/// encoded bytes; `true` once the send completed.
async fn send_stream_frame<T>(
    framed_transport: &mut Framed<T, LengthDelimitedCodec>,
    frame: StreamTestEnvelope,
) -> bool
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    // Bound in a `let` statement (not a `match` around the await) so the
    // serialization `Result` is dropped before the send is awaited.
    let bytes = match frame.serialize_to_bytes() {
        Ok(bytes) => bytes,
        Err(_) => return false,
    };
    let outcome = framed_transport.send(bytes).await;
    outcome.is_ok()
}
/// Send every envelope in `frames`, in iteration order, over the transport.
///
/// Accepts any iterable of envelopes (a `Vec`, an array, an iterator chain),
/// so callers need not collect into a `Vec` first.
///
/// Stops at the first envelope that fails to serialize or send and returns
/// `false`; returns `true` when all envelopes were sent (including when
/// `frames` is empty).
async fn send_stream_frames<T, I>(
    framed_transport: &mut Framed<T, LengthDelimitedCodec>,
    frames: I,
) -> bool
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    I: IntoIterator<Item = StreamTestEnvelope>,
{
    for frame in frames {
        if !send_stream_frame(framed_transport, frame).await {
            return false;
        }
    }
    true
}
/// Serve the shared-rate-limit scenario for correlation id `cid`.
///
/// Frames are first generated by `build_rate_limited_priority_frames`. A
/// marker data frame (message id 250) is then sent. Its single payload byte is
/// `SHARED_RATE_LIMIT_CONTENTION_MARKER` when the builder reported `true` and
/// `SHARED_RATE_LIMIT_NO_CONTENTION_MARKER` otherwise. The generated frames
/// follow the marker.
///
/// Returns `false` if the builder fails or any send fails.
async fn send_shared_rate_limit_frames<T>(
    framed_transport: &mut Framed<T, LengthDelimitedCodec>,
    cid: CorrelationId,
) -> bool
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    let Ok((frames, blocked)) = build_rate_limited_priority_frames(cid).await else {
        return false;
    };

    let marker_byte = if blocked {
        SHARED_RATE_LIMIT_CONTENTION_MARKER
    } else {
        SHARED_RATE_LIMIT_NO_CONTENTION_MARKER
    };
    let marker = StreamTestEnvelope::data(MessageId::new(250), cid, Payload::new(vec![marker_byte]));

    // `&&` short-circuits: the generated frames are only sent once the marker
    // has gone out successfully.
    send_stream_frame(framed_transport, marker).await
        && send_stream_frames(framed_transport, frames).await
}
async fn run_streaming_mode<T>(
framed_transport: &mut Framed<T, LengthDelimitedCodec>,
mode: StreamingServerMode,
cid: CorrelationId,
) -> bool
where
T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
match mode {
StreamingServerMode::Normal { data_count } => {
send_data_and_terminator(framed_transport, cid, data_count).await;
true
}
StreamingServerMode::ControlInterleaved => {
let frames = vec![
StreamTestEnvelope::data(MessageId::new(1), cid, Payload::new(vec![1])),
StreamTestEnvelope::data(MessageId::new(200), cid, Payload::new(vec![200])),
StreamTestEnvelope::data(MessageId::new(2), cid, Payload::new(vec![2])),
StreamTestEnvelope::data(MessageId::new(201), cid, Payload::new(vec![201])),
StreamTestEnvelope::terminator(cid),
];
send_stream_frames(framed_transport, frames).await
}
StreamingServerMode::Mismatch => {
send_mismatch_frame(framed_transport, cid).await;
true
}
StreamingServerMode::Disconnect { data_count } => {
send_data_frames(framed_transport, cid, data_count).await;
true
}
StreamingServerMode::InterleavedPriorities => {
let Ok(generated_frames) = build_interleaved_priority_frames(cid).await else {
return false;
};
send_stream_frames(framed_transport, generated_frames).await
}
StreamingServerMode::SharedRateLimit => {
send_shared_rate_limit_frames(framed_transport, cid).await
}
}
}
impl ClientStreamingWorld {
pub async fn start_normal_server(&mut self, data_count: usize) -> TestResult {
self.start_server(StreamingServerMode::Normal { data_count })
.await
}
pub async fn start_mismatch_server(&mut self) -> TestResult {
self.start_server(StreamingServerMode::Mismatch).await
}
pub async fn start_control_interleaved_server(&mut self) -> TestResult {
self.start_server(StreamingServerMode::ControlInterleaved)
.await
}
pub async fn start_disconnect_server(&mut self, data_count: usize) -> TestResult {
self.start_server(StreamingServerMode::Disconnect { data_count })
.await
}
pub async fn start_interleaved_priority_server(&mut self) -> TestResult {
self.start_server(StreamingServerMode::InterleavedPriorities)
.await
}
pub async fn start_shared_rate_limit_server(&mut self) -> TestResult {
self.start_server(StreamingServerMode::SharedRateLimit)
.await
}
async fn start_server(&mut self, mode: StreamingServerMode) -> TestResult {
self.abort_server();
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let handle = tokio::spawn(async move {
let Ok((stream, _)) = listener.accept().await else {
return;
};
let mut framed = Framed::new(stream, LengthDelimitedCodec::new());
let Some(Ok(req_bytes)) = framed.next().await else {
return;
};
let Ok((req, _)): Result<(StreamTestEnvelope, usize), _> =
BincodeSerializer.deserialize(&req_bytes)
else {
return;
};
let cid = CorrelationId::new(req.correlation_id().unwrap_or(1));
let _ = run_streaming_mode(&mut framed, mode, cid).await;
});
self.addr = Some(addr);
self.server = Some(handle);
Ok(())
}
pub fn verify_frame_count(&self, expected: usize) -> TestResult {
let actual = self.received_frames.len();
if actual != expected {
return Err(format!("expected {expected} frames, got {actual}").into());
}
Ok(())
}
pub fn verify_frame_order(&self) -> TestResult {
for (i, frame) in self.received_frames.iter().enumerate() {
let payload_byte =
u8::try_from(i + 1).map_err(|e| format!("frame index {i} overflows u8: {e}"))?;
let expected = Payload::new(vec![payload_byte]);
if frame.payload != expected {
return Err(format!(
"frame {i}: expected payload {expected:?}, got {:?}",
frame.payload
)
.into());
}
}
Ok(())
}
pub fn verify_typed_item_order(&self, expected: &[u8]) -> TestResult {
let actual: Vec<u8> = self
.typed_items
.iter()
.map(TypedStreamingItem::value)
.collect();
if actual != expected {
return Err(format!("expected typed items {expected:?}, got {actual:?}").into());
}
Ok(())
}
pub fn verify_clean_termination(&self) -> TestResult {
if !self.stream_terminated_cleanly {
return Err("stream did not terminate cleanly".into());
}
Ok(())
}
pub fn verify_interleaved_priority_order(&self) -> TestResult {
let expected = vec![
Payload::new(vec![1]),
Payload::new(vec![2]),
Payload::new(vec![3]),
Payload::new(vec![4]),
Payload::new(vec![10]),
Payload::new(vec![11]),
];
let actual: Vec<Payload> = self
.received_frames
.iter()
.map(|frame| frame.payload.clone())
.collect();
if actual != expected {
return Err(format!(
"expected interleaved priority payloads {expected:?}, got {actual:?}",
)
.into());
}
Ok(())
}
pub fn verify_shared_rate_limit_symmetry(&mut self) -> TestResult {
let marker = self
.received_frames
.first()
.ok_or("missing rate-limit marker frame")?
.payload
.clone()
.into_inner();
let was_blocked = marker == vec![SHARED_RATE_LIMIT_CONTENTION_MARKER];
self.shared_rate_limit_blocked = Some(was_blocked);
if !was_blocked {
return Err("expected shared limiter contention marker".into());
}
let remaining: Vec<Vec<u8>> = self
.received_frames
.iter()
.skip(1)
.map(|frame| frame.payload.clone().into_inner())
.collect();
if remaining != vec![vec![1], vec![2]] {
return Err(format!(
"unexpected payload order under shared rate limiting: {remaining:?}",
)
.into());
}
Ok(())
}
pub fn verify_correlation_mismatch_error(&self) -> TestResult {
match &self.last_error {
Some(ClientError::StreamCorrelationMismatch { .. }) => Ok(()),
Some(err) => Err(format!("expected StreamCorrelationMismatch, got {err:?}").into()),
None => Err("expected StreamCorrelationMismatch, but no error".into()),
}
}
pub fn verify_disconnect_error(&self) -> TestResult {
match &self.last_error {
Some(ClientError::Wireframe(WireframeError::Io(_))) => Ok(()),
Some(err) => Err(format!("expected transport error, got {err:?}").into()),
None => Err("expected transport error, but no error".into()),
}
}
pub fn abort_server(&mut self) {
if let Some(handle) = self.server.take() {
handle.abort();
}
}
}