wireframe 0.3.0

Simplify building servers and clients for custom binary protocols.
Documentation
//! `ClientRuntimeWorld` fixture for rstest-bdd tests.
//!
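//! Provides an echo server/client pair to validate client runtime framing
//! behaviour, plus servers that reply with malformed frames or TLS-like bytes
//! to exercise client error handling.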

use std::{
    cell::{Cell, RefCell},
    net::SocketAddr,
};

use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use log::warn;
use rstest::fixture;
use tokio::{net::TcpListener, task::JoinHandle};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use wireframe::{
    WireframeError,
    app::Envelope,
    client::{ClientCodecConfig, ClientError, ClientProtocolError, WireframeClient},
    message::Message,
    rewind_stream::RewindStream,
    serializer::BincodeSerializer,
};

#[path = "../../examples/support/echo_login_contract.rs"]
mod echo_login_contract;

use echo_login_contract::{LOGIN_ROUTE_ID, LoginAck, LoginRequest};
/// `TestResult` for step definitions.
pub use wireframe_testing::TestResult;

/// Simulated TLS 1.2 `ServerHello` record, used to exercise protocol-mismatch
/// error-handling in the wireframe client.
///
/// Only the start of a record is present. The server writes these bytes in
/// place of a length-delimited Wireframe frame.
pub const SIMULATED_TLS_RECORD: [u8; 7] = [
    0x16, // record content type: handshake
    0x03, 0x03, // record protocol version: TLS 1.2
    0x00, 0x02, // record fragment length: 2 bytes
    0x02, 0x28, // fragment: handshake type `server_hello`, then a truncated length
];

/// Concrete client type driven by the fixture: a `WireframeClient` using the
/// bincode serializer over a `RewindStream`-wrapped TCP stream.
type RuntimeClient = WireframeClient<BincodeSerializer, RewindStream<tokio::net::TcpStream>>;

/// Test world exercising the wireframe client runtime.
///
/// Step methods take `&self`. Per-scenario state is therefore held in
/// `Cell`/`RefCell` slots that are filled in as the scenario progresses.
#[derive(Debug)]
pub struct ClientRuntimeWorld {
    /// Runtime driving client and server futures; `None` if creation failed.
    runtime: Option<tokio::runtime::Runtime>,
    /// Message describing why the runtime could not be created, if it failed.
    runtime_error: Option<String>,
    /// Address of the most recently started test server.
    addr: Cell<Option<SocketAddr>>,
    /// Handle of the spawned server task, awaited by the verification steps.
    server: RefCell<Option<JoinHandle<()>>>,
    /// Connected client, if any.
    client: RefCell<Option<RuntimeClient>>,
    /// Payload sent by the most recent successful payload step.
    payload: RefCell<Option<ClientPayload>>,
    /// Response received for `payload`.
    response: RefCell<Option<ClientPayload>>,
    /// Acknowledgement decoded from the most recent login request.
    login_ack: RefCell<Option<LoginAck>>,
    /// Client error captured by a step that expected a failure.
    last_error: RefCell<Option<ClientError>>,
}

impl ClientRuntimeWorld {
    /// Build a new runtime-backed client world.
    ///
    /// Failing to create the Tokio runtime does not panic here. The error
    /// message is stored and returned by `runtime` (and therefore `block_on`)
    /// when a step first needs the runtime.
    pub fn new() -> Self {
        // Resolve the runtime outcome first so the remaining (always empty)
        // state is written in a single struct literal instead of two copies.
        let (runtime, runtime_error) = match tokio::runtime::Runtime::new() {
            Ok(runtime) => (Some(runtime), None),
            Err(err) => (None, Some(format!("failed to create runtime: {err}"))),
        };
        Self {
            runtime,
            runtime_error,
            addr: Cell::new(None),
            server: RefCell::new(None),
            client: RefCell::new(None),
            payload: RefCell::new(None),
            response: RefCell::new(None),
            login_ack: RefCell::new(None),
            last_error: RefCell::new(None),
        }
    }

    /// Borrow the fixture's Tokio runtime.
    ///
    /// # Errors
    /// Returns the recorded creation error, or `"runtime unavailable"` when no
    /// runtime exists and no error message was recorded.
    fn runtime(&self) -> TestResult<&tokio::runtime::Runtime> {
        self.runtime.as_ref().ok_or_else(|| {
            self.runtime_error
                .clone()
                .unwrap_or_else(|| "runtime unavailable".to_string())
                .into()
        })
    }

    /// Drive `future` to completion on the fixture runtime.
    ///
    /// # Errors
    /// Returns an error when called while a Tokio runtime is already active on
    /// this thread. `Runtime::block_on` would panic in that case, so this
    /// reports it as a step failure instead. Also returns an error when the
    /// runtime is unavailable.
    fn block_on<F, T>(&self, future: F) -> TestResult<T>
    where
        F: std::future::Future<Output = T>,
    {
        if tokio::runtime::Handle::try_current().is_ok() {
            return Err("nested Tokio runtime detected in client runtime fixture".into());
        }
        let runtime = self.runtime()?;
        Ok(runtime.block_on(future))
    }
}

/// Payload exchanged between the client and the test servers.
///
/// Serialized through the derived bincode `Encode`/`BorrowDecode` impls.
/// `PartialEq` is used by `verify_echo` to compare the sent payload with the
/// echoed response.
#[derive(bincode::Encode, bincode::BorrowDecode, Debug, PartialEq, Eq, Clone)]
struct ClientPayload {
    /// Raw payload bytes; their length is chosen by the sending step.
    data: Vec<u8>,
}

/// Fixture for `ClientRuntimeWorld`.
///
/// Builds a fresh world on each invocation. A runtime creation failure is
/// recorded in the world rather than panicking here.
// rustfmt collapses simple fixtures into one line, which triggers unused_braces;
// `#[rustfmt::skip]` keeps the body on its own line.
#[rustfmt::skip]
#[fixture]
pub fn client_runtime_world() -> ClientRuntimeWorld {
    ClientRuntimeWorld::new()
}

impl ClientRuntimeWorld {
    /// Start an echo server with the specified maximum frame length.
    ///
    /// # Errors
    /// Returns an error if binding or spawning the server fails.
    pub fn start_server(&self, max_frame_length: usize) -> TestResult {
        let listener = self.block_on(async { TcpListener::bind("127.0.0.1:0").await })??;
        let addr = listener.local_addr()?;
        let handle = self.runtime()?.spawn(async move {
            let Ok((stream, _)) = listener.accept().await else {
                warn!("client runtime server failed to accept connection");
                return;
            };
            let codec = LengthDelimitedCodec::builder()
                .max_frame_length(max_frame_length)
                .new_codec();
            let mut framed = Framed::new(stream, codec);
            let Some(result) = framed.next().await else {
                warn!("client runtime server closed before receiving a frame");
                return;
            };
            let Ok(frame) = result else {
                warn!("client runtime server failed to decode frame");
                return;
            };
            if let Err(err) = framed.send(frame.freeze()).await {
                warn!("client runtime server failed to send response: {err:?}");
            }
        });

        self.addr.set(Some(addr));
        *self.server.borrow_mut() = Some(handle);
        Ok(())
    }

    /// Start a server that always sends malformed response bytes.
    ///
    /// # Errors
    /// Returns an error if binding or spawning the server fails.
    pub fn start_malformed_response_server(&self) -> TestResult {
        let listener = self.block_on(async { TcpListener::bind("127.0.0.1:0").await })??;
        let addr = listener.local_addr()?;
        let handle = self.runtime()?.spawn(async move {
            let Ok((stream, _)) = listener.accept().await else {
                warn!("client runtime malformed server failed to accept connection");
                return;
            };
            let mut framed = Framed::new(stream, LengthDelimitedCodec::new());
            let Some(result) = framed.next().await else {
                warn!("client runtime malformed server closed before receiving a frame");
                return;
            };
            let Ok(_frame) = result else {
                warn!("client runtime malformed server failed to decode request frame");
                return;
            };
            if let Err(err) = framed
                .send(Bytes::from_static(&[0xff, 0xff, 0xff, 0xff]))
                .await
            {
                warn!("client runtime malformed server failed to send invalid frame: {err:?}");
            }
        });

        self.addr.set(Some(addr));
        *self.server.borrow_mut() = Some(handle);
        Ok(())
    }

    /// Start a server that replies with TLS-like bytes instead of Wireframe frames.
    ///
    /// # Errors
    /// Returns an error if binding or spawning the server fails.
    pub fn start_tls_mismatch_server(&self) -> TestResult {
        let listener = self.block_on(async { TcpListener::bind("127.0.0.1:0").await })??;
        let addr = listener.local_addr()?;
        let handle = self.runtime()?.spawn(async move {
            use tokio::io::{AsyncReadExt, AsyncWriteExt};

            let Ok((mut stream, _)) = listener.accept().await else {
                warn!("client runtime TLS-mismatch server failed to accept connection");
                return;
            };
            let mut request_buf = [0_u8; 64];
            match stream.read(&mut request_buf).await {
                Ok(_) => {}
                Err(err) => {
                    warn!("client runtime TLS-mismatch server failed to read request: {err:?}");
                    return;
                }
            }
            if let Err(err) = stream.write_all(&SIMULATED_TLS_RECORD).await {
                warn!("client runtime TLS-mismatch server failed to write response: {err:?}");
                return;
            }
            if let Err(err) = stream.shutdown().await {
                warn!("client runtime TLS-mismatch server failed to shutdown: {err:?}");
            }
        });

        self.addr.set(Some(addr));
        *self.server.borrow_mut() = Some(handle);
        Ok(())
    }

    /// Connect a client using the specified maximum frame length.
    ///
    /// # Errors
    /// Returns an error if the server has not started or the client fails to connect.
    pub fn connect_client(&self, max_frame_length: usize) -> TestResult {
        let addr = self.addr.get().ok_or("server address missing")?;
        let codec_config = ClientCodecConfig::default().max_frame_length(max_frame_length);
        let client = self.block_on(async {
            WireframeClient::builder()
                .codec_config(codec_config)
                .connect(addr)
                .await
        })??;
        *self.client.borrow_mut() = Some(client);
        Ok(())
    }

    /// Send a payload of the specified size and capture the response.
    ///
    /// # Errors
    /// Returns an error if the client is missing or communication fails.
    pub fn send_payload(&self, size: usize) -> TestResult {
        let (payload, result) = self.send_payload_inner(size)?;
        let response = result?;
        *self.payload.borrow_mut() = Some(payload);
        *self.response.borrow_mut() = Some(response);
        *self.login_ack.borrow_mut() = None;
        *self.last_error.borrow_mut() = None;
        Ok(())
    }

    /// Send a payload that should exceed the peer's frame limit.
    ///
    /// # Errors
    /// Returns an error if the client is missing or if no failure is observed.
    pub fn send_payload_expect_error(&self, size: usize) -> TestResult {
        let (_payload, result) = self.send_payload_inner(size)?;
        match result {
            Ok(_) => return Err("expected client error for oversized payload".into()),
            Err(err) => {
                *self.last_error.borrow_mut() = Some(err);
                *self.login_ack.borrow_mut() = None;
            }
        }
        Ok(())
    }

    fn send_payload_inner(
        &self,
        size: usize,
    ) -> TestResult<(ClientPayload, Result<ClientPayload, ClientError>)> {
        let payload = ClientPayload {
            data: vec![7_u8; size],
        };
        let mut client = self
            .client
            .borrow_mut()
            .take()
            .ok_or("client not connected")?;
        let result = self.block_on(async { client.call(&payload).await })?;
        *self.client.borrow_mut() = Some(client);
        Ok((payload, result))
    }

    /// Verify that the client received the echoed payload.
    ///
    /// # Errors
    /// Returns an error if the response is missing or mismatched.
    pub fn verify_echo(&self) -> TestResult {
        let payload_ref = self.payload.borrow();
        let response_ref = self.response.borrow();
        let payload = payload_ref.as_ref().ok_or("payload missing")?;
        let response = response_ref.as_ref().ok_or("response missing")?;
        if payload != response {
            return Err("response did not match payload".into());
        }
        self.await_server()?;
        Ok(())
    }

    /// Send a login request and capture the echoed acknowledgement.
    ///
    /// # Errors
    /// Returns an error if the client is missing or deserialization fails.
    pub fn send_login_request(&self, username: String) -> TestResult {
        *self.login_ack.borrow_mut() = None;
        let login = LoginRequest { username };
        let mut client = self
            .client
            .borrow_mut()
            .take()
            .ok_or("client not connected")?;
        let request = Envelope::new(LOGIN_ROUTE_ID, None, login.to_bytes()?);
        let response: Envelope =
            self.block_on(async { client.call_correlated(request).await })??;
        let (ack, _) = LoginAck::from_bytes(response.payload_bytes())?;
        *self.client.borrow_mut() = Some(client);
        *self.login_ack.borrow_mut() = Some(ack);
        *self.last_error.borrow_mut() = None;
        Ok(())
    }

    /// Verify that login acknowledgement decoding succeeded for `expected`.
    ///
    /// # Errors
    /// Returns an error when the acknowledgement is missing or mismatched.
    pub fn verify_login_acknowledgement(&self, expected: &str) -> TestResult {
        let ack_ref = self.login_ack.borrow();
        let ack = ack_ref.as_ref().ok_or("login acknowledgement missing")?;
        if ack.username != expected {
            return Err(format!(
                "expected login acknowledgement for '{expected}', got '{}'",
                ack.username
            )
            .into());
        }
        self.await_server()?;
        Ok(())
    }

    /// Verify that the recorded error is a transport `WireframeError`.
    ///
    /// # Errors
    /// Returns an error if no failure was observed or the error variant differs.
    pub fn verify_transport_wireframe_error(&self) -> TestResult {
        let error_ref = self.last_error.borrow();
        let err = error_ref
            .as_ref()
            .ok_or("expected client error was not captured")?;
        if !matches!(err, ClientError::Wireframe(WireframeError::Io(_))) {
            return Err(format!("expected transport WireframeError::Io, got {err:?}").into());
        }
        self.await_server()?;
        Ok(())
    }

    /// Verify that the recorded error is a decode `WireframeError`.
    ///
    /// # Errors
    /// Returns an error if no failure was observed or the error variant differs.
    pub fn verify_decode_wireframe_error(&self) -> TestResult {
        let error_ref = self.last_error.borrow();
        let err = error_ref
            .as_ref()
            .ok_or("expected client error was not captured")?;
        if !matches!(
            err,
            ClientError::Wireframe(WireframeError::Protocol(ClientProtocolError::Deserialize(
                _
            )))
        ) {
            return Err(format!(
                "expected decode WireframeError::Protocol(ClientProtocolError::Deserialize(_)), \
                 got {err:?}"
            )
            .into());
        }
        self.await_server()?;
        Ok(())
    }

    fn await_server(&self) -> TestResult {
        if let Some(handle) = self.server.borrow_mut().take() {
            self.block_on(async {
                handle
                    .await
                    .map_err(|err| format!("server task failed: {err}"))
            })??;
        }
        Ok(())
    }
}