wireframe 0.3.0

Simplify building servers and clients for custom binary protocols.
Documentation
//! Streaming modes and verification helpers for `ClientStreamingWorld`.

use futures::{SinkExt, StreamExt};
use tokio::net::TcpListener;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use wireframe::{
    WireframeError,
    client::ClientError,
    correlation::CorrelatableFrame,
    serializer::{BincodeSerializer, Serializer},
};

use super::{
    ClientStreamingWorld,
    TestResult,
    TypedStreamingItem,
    server::{
        build_interleaved_priority_frames,
        build_rate_limited_priority_frames,
        send_data_and_terminator,
        send_data_frames,
        send_mismatch_frame,
    },
    types::{CorrelationId, MessageId, Payload, StreamTestEnvelope},
};

/// Selects which frame sequence the single-connection streaming test server
/// writes after it has read the client's request.
pub enum StreamingServerMode {
    /// Write `data_count` data frames and finish with a terminator frame.
    Normal { data_count: usize },
    /// Write data frames with control frames mixed in between them, then a
    /// terminator frame.
    ControlInterleaved,
    /// Write a single frame carrying a correlation ID that does not match the
    /// request.
    Mismatch,
    /// Write `data_count` data frames and then close the connection without
    /// sending a terminator.
    Disconnect { data_count: usize },
    /// Write the frames produced by interleaving a high-priority queue with a
    /// low-priority queue.
    InterleavedPriorities,
    /// Write the frames produced under a rate limit shared by both priority
    /// levels, preceded by a marker frame that reports whether the limiter
    /// blocked.
    SharedRateLimit,
}

/// Payload byte of the leading marker frame when the rate-limited frame
/// builder reported that it was blocked by the shared limiter.
const SHARED_RATE_LIMIT_CONTENTION_MARKER: u8 = 99;
/// Payload byte of the leading marker frame when the rate-limited frame
/// builder reported no blocking.
const SHARED_RATE_LIMIT_NO_CONTENTION_MARKER: u8 = 98;

/// Encode `frame` and write it to `framed_transport`.
///
/// Returns `true` only if encoding succeeded and the sink accepted the bytes.
/// Any encoding or send error is reported as `false`, which callers use to
/// stop emitting further frames.
async fn send_stream_frame<T>(
    framed_transport: &mut Framed<T, LengthDelimitedCodec>,
    frame: StreamTestEnvelope,
) -> bool
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    // Finish the match before awaiting so no temporary from it lives across
    // the `.await` point.
    let wire_bytes = match frame.serialize_to_bytes() {
        Ok(bytes) => bytes,
        Err(_) => return false,
    };
    framed_transport.send(wire_bytes).await.is_ok()
}

/// Write every envelope in `frames` to `framed_transport`, in order.
///
/// Stops at the first envelope that fails to encode or send and returns
/// `false`. Returns `true` once all envelopes have been written, including
/// when `frames` is empty.
async fn send_stream_frames<T>(
    framed_transport: &mut Framed<T, LengthDelimitedCodec>,
    frames: Vec<StreamTestEnvelope>,
) -> bool
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    for envelope in frames {
        let delivered = send_stream_frame(framed_transport, envelope).await;
        if !delivered {
            return false;
        }
    }
    true
}

/// Emit the shared-rate-limit scenario for correlation ID `cid`.
///
/// The frames produced by `build_rate_limited_priority_frames` are sent after
/// a leading marker frame (message ID 250). The marker's single payload byte
/// is either `SHARED_RATE_LIMIT_CONTENTION_MARKER` or
/// `SHARED_RATE_LIMIT_NO_CONTENTION_MARKER`, depending on the `was_blocked`
/// flag the builder returned.
///
/// Returns `false` if the builder fails or any frame cannot be sent.
/// Returns `true` otherwise.
async fn send_shared_rate_limit_frames<T>(
    framed_transport: &mut Framed<T, LengthDelimitedCodec>,
    cid: CorrelationId,
) -> bool
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    let (generated_frames, was_blocked) = match build_rate_limited_priority_frames(cid).await {
        Ok(outcome) => outcome,
        Err(_) => return false,
    };

    let marker_byte = if was_blocked {
        SHARED_RATE_LIMIT_CONTENTION_MARKER
    } else {
        SHARED_RATE_LIMIT_NO_CONTENTION_MARKER
    };

    // Put the marker at the front so it is written first. Sending the whole
    // sequence in one call keeps stop-on-first-failure semantics for the
    // marker and the generated frames alike.
    let mut outbound = Vec::with_capacity(generated_frames.len() + 1);
    outbound.push(StreamTestEnvelope::data(
        MessageId::new(250),
        cid,
        Payload::new(vec![marker_byte]),
    ));
    outbound.extend(generated_frames);

    send_stream_frames(framed_transport, outbound).await
}

async fn run_streaming_mode<T>(
    framed_transport: &mut Framed<T, LengthDelimitedCodec>,
    mode: StreamingServerMode,
    cid: CorrelationId,
) -> bool
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    match mode {
        StreamingServerMode::Normal { data_count } => {
            send_data_and_terminator(framed_transport, cid, data_count).await;
            true
        }
        StreamingServerMode::ControlInterleaved => {
            let frames = vec![
                StreamTestEnvelope::data(MessageId::new(1), cid, Payload::new(vec![1])),
                StreamTestEnvelope::data(MessageId::new(200), cid, Payload::new(vec![200])),
                StreamTestEnvelope::data(MessageId::new(2), cid, Payload::new(vec![2])),
                StreamTestEnvelope::data(MessageId::new(201), cid, Payload::new(vec![201])),
                StreamTestEnvelope::terminator(cid),
            ];
            send_stream_frames(framed_transport, frames).await
        }
        StreamingServerMode::Mismatch => {
            send_mismatch_frame(framed_transport, cid).await;
            true
        }
        StreamingServerMode::Disconnect { data_count } => {
            send_data_frames(framed_transport, cid, data_count).await;
            true
        }
        StreamingServerMode::InterleavedPriorities => {
            let Ok(generated_frames) = build_interleaved_priority_frames(cid).await else {
                return false;
            };
            send_stream_frames(framed_transport, generated_frames).await
        }
        StreamingServerMode::SharedRateLimit => {
            send_shared_rate_limit_frames(framed_transport, cid).await
        }
    }
}

impl ClientStreamingWorld {
    /// Start a streaming server that sends `data_count` data frames followed
    /// by a terminator.
    ///
    /// # Errors
    ///
    /// Propagates failures to bind the loopback listener or read its address.
    pub async fn start_normal_server(&mut self, data_count: usize) -> TestResult {
        self.start_server(StreamingServerMode::Normal { data_count })
            .await
    }

    /// Start a streaming server that replies with a mismatched correlation ID.
    ///
    /// # Errors
    ///
    /// Propagates failures to bind the loopback listener or read its address.
    pub async fn start_mismatch_server(&mut self) -> TestResult {
        self.start_server(StreamingServerMode::Mismatch).await
    }

    /// Start a server that interleaves control frames with data frames.
    ///
    /// # Errors
    ///
    /// Propagates failures to bind the loopback listener or read its address.
    pub async fn start_control_interleaved_server(&mut self) -> TestResult {
        self.start_server(StreamingServerMode::ControlInterleaved)
            .await
    }

    /// Start a server that sends `data_count` data frames and then drops the
    /// connection without a terminator.
    ///
    /// # Errors
    ///
    /// Propagates failures to bind the loopback listener or read its address.
    pub async fn start_disconnect_server(&mut self, data_count: usize) -> TestResult {
        self.start_server(StreamingServerMode::Disconnect { data_count })
            .await
    }

    /// Start a server that emits fairness-checked interleaved priority frames.
    ///
    /// # Errors
    ///
    /// Propagates failures to bind the loopback listener or read its address.
    pub async fn start_interleaved_priority_server(&mut self) -> TestResult {
        self.start_server(StreamingServerMode::InterleavedPriorities)
            .await
    }

    /// Start a server that emits rate-limited priority frames preceded by a
    /// contention marker frame.
    ///
    /// # Errors
    ///
    /// Propagates failures to bind the loopback listener or read its address.
    pub async fn start_shared_rate_limit_server(&mut self) -> TestResult {
        self.start_server(StreamingServerMode::SharedRateLimit)
            .await
    }

    /// Replace any running server with a new one operating in `mode`.
    ///
    /// Binds an ephemeral loopback port, records it in `self.addr`, and spawns
    /// a task that accepts a single connection. The task decodes the client's
    /// first frame as a `StreamTestEnvelope` to learn its correlation ID, then
    /// emits frames according to `mode`. If the accept, read or decode step
    /// fails, the task exits silently.
    ///
    /// # Errors
    ///
    /// Propagates failures to bind the listener or read its local address. In
    /// that case the previous server has already been aborted, no server
    /// handle is held, and `self.addr` is `None`.
    async fn start_server(&mut self, mode: StreamingServerMode) -> TestResult {
        // Forget the previous address together with the previous server, so a
        // failed bind below cannot leave `addr` pointing at an aborted listener.
        self.abort_server();
        self.addr = None;

        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let handle = tokio::spawn(async move {
            let Ok((stream, _)) = listener.accept().await else {
                return;
            };
            let mut framed = Framed::new(stream, LengthDelimitedCodec::new());

            // Read the client's request to extract the correlation ID.
            let Some(Ok(req_bytes)) = framed.next().await else {
                return;
            };
            let Ok((req, _)): Result<(StreamTestEnvelope, usize), _> =
                BincodeSerializer.deserialize(&req_bytes)
            else {
                return;
            };
            // Fall back to correlation ID 1 when the request carries none.
            let cid = CorrelationId::new(req.correlation_id().unwrap_or(1));

            // The send outcome is deliberately discarded: the client-side
            // verification steps judge whatever frames actually arrived.
            let _ = run_streaming_mode(&mut framed, mode, cid).await;
        });

        self.addr = Some(addr);
        self.server = Some(handle);
        Ok(())
    }

    /// Verify the number of received data frames.
    ///
    /// # Errors
    ///
    /// Returns an error if `received_frames` does not hold exactly `expected`
    /// frames.
    pub fn verify_frame_count(&self, expected: usize) -> TestResult {
        let actual = self.received_frames.len();
        if actual != expected {
            return Err(format!("expected {expected} frames, got {actual}").into());
        }
        Ok(())
    }

    /// Verify frames arrived in order: frame `i` must carry payload `[i + 1]`
    /// (i.e. `[1], [2], ...`).
    ///
    /// # Errors
    ///
    /// Returns an error on the first payload mismatch, or if more than 255
    /// frames were received (the expected byte would overflow `u8`).
    pub fn verify_frame_order(&self) -> TestResult {
        for (i, frame) in self.received_frames.iter().enumerate() {
            let payload_byte =
                u8::try_from(i + 1).map_err(|e| format!("frame index {i} overflows u8: {e}"))?;
            let expected = Payload::new(vec![payload_byte]);
            if frame.payload != expected {
                return Err(format!(
                    "frame {i}: expected payload {expected:?}, got {:?}",
                    frame.payload
                )
                .into());
            }
        }
        Ok(())
    }

    /// Verify that typed items arrived in order after control frames were
    /// skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if the sequence of `TypedStreamingItem::value`s
    /// differs from `expected`.
    pub fn verify_typed_item_order(&self, expected: &[u8]) -> TestResult {
        let actual: Vec<u8> = self
            .typed_items
            .iter()
            .map(TypedStreamingItem::value)
            .collect();
        if actual != expected {
            return Err(format!("expected typed items {expected:?}, got {actual:?}").into());
        }
        Ok(())
    }

    /// Verify that the stream terminated cleanly (received `None`).
    ///
    /// # Errors
    ///
    /// Returns an error if `stream_terminated_cleanly` is not set.
    pub fn verify_clean_termination(&self) -> TestResult {
        if !self.stream_terminated_cleanly {
            return Err("stream did not terminate cleanly".into());
        }
        Ok(())
    }

    /// Verify fairness-driven ordering for interleaved high/low push frames.
    ///
    /// The received payloads must be exactly `[1], [2], [3], [4], [10], [11]`,
    /// in that order.
    ///
    /// # Errors
    ///
    /// Returns an error if the received payload sequence differs.
    pub fn verify_interleaved_priority_order(&self) -> TestResult {
        let expected = vec![
            Payload::new(vec![1]),
            Payload::new(vec![2]),
            Payload::new(vec![3]),
            Payload::new(vec![4]),
            Payload::new(vec![10]),
            Payload::new(vec![11]),
        ];
        let actual: Vec<Payload> = self
            .received_frames
            .iter()
            .map(|frame| frame.payload.clone())
            .collect();
        if actual != expected {
            return Err(format!(
                "expected interleaved priority payloads {expected:?}, got {actual:?}",
            )
            .into());
        }
        Ok(())
    }

    /// Verify that shared-limiter contention was observed and that the
    /// rate-limited frames arrived in order.
    ///
    /// The first received frame must be the contention marker. The frames
    /// after it must carry payloads `[1]` and `[2]`, in that order. As a side
    /// effect, whether the marker signalled contention is stored in
    /// `shared_rate_limit_blocked`, even when the check then fails.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - no frames were received;
    /// - the marker does not signal contention;
    /// - the remaining payloads differ from `[[1], [2]]`.
    pub fn verify_shared_rate_limit_symmetry(&mut self) -> TestResult {
        let marker = self
            .received_frames
            .first()
            .ok_or("missing rate-limit marker frame")?
            .payload
            .clone()
            .into_inner();
        let was_blocked = marker == vec![SHARED_RATE_LIMIT_CONTENTION_MARKER];
        self.shared_rate_limit_blocked = Some(was_blocked);
        if !was_blocked {
            return Err("expected shared limiter contention marker".into());
        }

        let remaining: Vec<Vec<u8>> = self
            .received_frames
            .iter()
            .skip(1)
            .map(|frame| frame.payload.clone().into_inner())
            .collect();
        if remaining != vec![vec![1], vec![2]] {
            return Err(format!(
                "unexpected payload order under shared rate limiting: {remaining:?}",
            )
            .into());
        }
        Ok(())
    }

    /// Verify that a `StreamCorrelationMismatch` error was returned.
    ///
    /// # Errors
    ///
    /// Returns an error if no error was recorded, or if the recorded error is
    /// a different variant.
    pub fn verify_correlation_mismatch_error(&self) -> TestResult {
        match &self.last_error {
            Some(ClientError::StreamCorrelationMismatch { .. }) => Ok(()),
            Some(err) => Err(format!("expected StreamCorrelationMismatch, got {err:?}").into()),
            None => Err("expected StreamCorrelationMismatch, but no error".into()),
        }
    }

    /// Verify that a transport/disconnect error was returned.
    ///
    /// Only `ClientError::Wireframe(WireframeError::Io(_))` is accepted.
    ///
    /// # Errors
    ///
    /// Returns an error if no error was recorded, or if the recorded error is
    /// anything other than a wireframe I/O error.
    pub fn verify_disconnect_error(&self) -> TestResult {
        match &self.last_error {
            Some(ClientError::Wireframe(WireframeError::Io(_))) => Ok(()),
            Some(err) => Err(format!("expected transport error, got {err:?}").into()),
            None => Err("expected transport error, but no error".into()),
        }
    }

    /// Abort the server task, if one is running.
    ///
    /// The recorded `addr` is left untouched; the handle is taken, so calling
    /// this again is a no-op.
    pub fn abort_server(&mut self) {
        if let Some(handle) = self.server.take() {
            handle.abort();
        }
    }
}