wireframe 0.3.0

Simplify building servers and clients for custom binary protocols.
Documentation
//! Wireframe client runtime implementation.

use std::{fmt, sync::atomic::AtomicU64, time::Instant};

use bytes::Bytes;
use futures::SinkExt;
use tokio::{
    io::{AsyncRead, AsyncWrite},
    net::TcpStream,
};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::Instrument;

use super::{
    ClientCodecConfig,
    ClientError,
    WireframeClientBuilder,
    hooks::{ClientConnectionTeardownHandler, ClientErrorHandler, RequestHooks},
    tracing_config::TracingConfig,
    tracing_helpers::{call_span, close_span, emit_timing_event, send_span},
};
use crate::{
    message::{DecodeWith, EncodeWith},
    rewind_stream::RewindStream,
    serializer::{BincodeSerializer, Serializer},
};

/// Byte stream that can carry a wireframe client connection.
///
/// This acts as a trait alias: any type that is [`AsyncRead`] + [`AsyncWrite`]
/// + [`Unpin`] implements it automatically through the blanket implementation
/// that follows, so it never needs to be implemented by hand.
pub trait ClientStream: AsyncRead + AsyncWrite + Unpin {}

impl<T: AsyncRead + AsyncWrite + Unpin> ClientStream for T {}

/// Client runtime for wireframe connections.
///
/// Wraps a transport stream in length-delimited framing together with a
/// message serializer, and offers `send`, `receive` and `call` for exchanging
/// messages with a peer.
///
/// The client supports connection lifecycle hooks that mirror the server's
/// hooks, enabling consistent instrumentation across both ends of a wireframe
/// connection.
///
/// Type parameters:
/// - `S`: message serializer (defaults to [`BincodeSerializer`]).
/// - `T`: transport stream (defaults to [`TcpStream`]); must implement
///   [`ClientStream`].
/// - `C`: per-connection state produced by the setup hook and handed to the
///   teardown hook (defaults to `()`).
///
/// # Correlation Identifiers
///
/// When using the envelope-aware APIs ([`send_envelope`](Self::send_envelope),
/// [`receive_envelope`](Self::receive_envelope), and
/// [`call_correlated`](Self::call_correlated)), the client automatically
/// generates unique correlation identifiers for each request. The response
/// is validated to ensure its correlation ID matches the request.
///
/// # Examples
///
/// ```no_run
/// use std::net::SocketAddr;
///
/// use wireframe::client::WireframeClient;
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), wireframe::client::ClientError> {
/// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address");
/// let _client = WireframeClient::builder().connect(addr).await?;
/// # Ok(())
/// # }
/// ```
pub struct WireframeClient<S = BincodeSerializer, T = TcpStream, C = ()>
where
    T: ClientStream,
{
    /// Transport stream wrapped in a length-delimited frame codec; every
    /// message is written and read as one frame.
    pub(crate) framed: Framed<T, LengthDelimitedCodec>,
    /// Serializer used by `send` to encode outgoing messages. The
    /// `DecodeWith<S>` bound on `receive` suggests it also drives decoding
    /// (the receive path is implemented elsewhere — confirm there).
    pub(crate) serializer: S,
    /// Codec settings, exposed read-only through `codec_config()`.
    pub(crate) codec_config: ClientCodecConfig,
    /// State produced by the connection setup hook, if one ran. `close`
    /// takes it out and passes it to the teardown handler.
    pub(crate) connection_state: Option<C>,
    /// Teardown handler that `close` invokes with the connection state.
    pub(crate) on_disconnect: Option<ClientConnectionTeardownHandler<C>>,
    /// Optional error hook. Per the `send`/`receive` docs, it runs before an
    /// error is returned to the caller.
    pub(crate) on_error: Option<ClientErrorHandler>,
    /// Hooks invoked on every outgoing and incoming frame.
    pub(crate) request_hooks: RequestHooks,
    /// Tracing configuration for span levels and per-command timing.
    pub(crate) tracing_config: TracingConfig,
    /// Counter for generating unique correlation identifiers.
    pub(crate) correlation_counter: AtomicU64,
}

impl<S, T, C> fmt::Debug for WireframeClient<S, T, C>
where
    T: ClientStream,
{
    /// Render the client for diagnostics.
    ///
    /// Only the codec configuration is shown. The stream, serializer, hooks
    /// and connection state are deliberately omitted, and the output is
    /// marked non-exhaustive (`..`) so readers know fields are hidden.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut view = f.debug_struct("WireframeClient");
        view.field("codec_config", &self.codec_config);
        view.finish_non_exhaustive()
    }
}

impl WireframeClient<BincodeSerializer, TcpStream, ()> {
    /// Create a [`WireframeClientBuilder`] preconfigured with the bincode
    /// serializer and no connection state.
    ///
    /// This is the usual entry point for constructing a client: configure the
    /// builder, then call `connect` to obtain a connected client.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::net::SocketAddr;
    ///
    /// use wireframe::client::WireframeClient;
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), wireframe::client::ClientError> {
    /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address");
    /// let _client = WireframeClient::builder().connect(addr).await?;
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn builder() -> WireframeClientBuilder<BincodeSerializer, (), ()> {
        let fresh: WireframeClientBuilder<BincodeSerializer, (), ()> = WireframeClientBuilder::new();
        fresh
    }
}

impl<S, T, C> WireframeClient<S, T, C>
where
    S: Serializer + Send + Sync,
    T: ClientStream,
{
    /// Serialize `message` and write it to the peer as one length-delimited
    /// frame.
    ///
    /// The registered before-send hooks receive a mutable reference to the
    /// serialized bytes before they are framed. On any failure, the
    /// registered error hook (if present) runs before the error is returned.
    ///
    /// # Errors
    /// Returns [`ClientError`] when serialization fails or the frame cannot
    /// be written. Transport failures are surfaced through
    /// [`crate::WireframeError::Io`] within
    /// [`ClientError::Wireframe`](super::ClientError::Wireframe).
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::net::SocketAddr;
    ///
    /// use wireframe::client::{ClientError, WireframeClient};
    ///
    /// #[derive(bincode::Encode, bincode::BorrowDecode)]
    /// struct Ping(u8);
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), ClientError> {
    /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address");
    /// let mut client = WireframeClient::builder().connect(addr).await?;
    /// client.send(&Ping(1)).await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn send<M: EncodeWith<S>>(&mut self, message: &M) -> Result<(), ClientError> {
        let timing_start = self.tracing_config.send_timing.then(Instant::now);

        let mut frame = match self.serializer.serialize(message) {
            Ok(encoded) => encoded,
            Err(cause) => {
                // Nothing reached the wire, but the timing measurement is
                // still closed out before the failure is reported.
                let failure = ClientError::Serialize(cause);
                emit_timing_event(timing_start);
                self.invoke_error_hook(&failure).await;
                return Err(failure);
            }
        };

        self.invoke_before_send_hooks(&mut frame);
        let span = send_span(&self.tracing_config, frame.len());

        // The write and its timing event both happen inside the send span.
        let outcome = async {
            let written = self.framed.send(Bytes::from(frame)).await;
            emit_timing_event(timing_start);
            written
        }
        .instrument(span)
        .await;

        match outcome {
            Ok(()) => Ok(()),
            Err(cause) => {
                let failure = ClientError::from(cause);
                self.invoke_error_hook(&failure).await;
                Err(failure)
            }
        }
    }

    /// Wait for the next frame from the peer and decode it as `M`.
    ///
    /// On failure, the registered error hook (if present) runs before the
    /// error is returned.
    ///
    /// # Errors
    /// Returns [`ClientError`] if the connection closes, decoding fails, or an
    /// I/O error occurs. Transport failures are surfaced through
    /// [`crate::WireframeError::Io`], while decode failures are surfaced
    /// through [`crate::WireframeError::Protocol`].
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::net::SocketAddr;
    ///
    /// use wireframe::client::{ClientError, WireframeClient};
    ///
    /// #[derive(bincode::Encode, bincode::BorrowDecode, Debug, PartialEq)]
    /// struct Pong(u8);
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), ClientError> {
    /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address");
    /// let mut client = WireframeClient::builder().connect(addr).await?;
    /// let _pong: Pong = client.receive().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn receive<M: DecodeWith<S>>(&mut self) -> Result<M, ClientError> {
        self.receive_internal::<M>().await
    }

    /// Perform a request/response round trip: send `request`, then decode
    /// the next incoming frame as `Resp`.
    ///
    /// The whole exchange runs inside a call span, whose `result` field is
    /// set to `"err"` on failure. Error hooks are run by the underlying
    /// `send`/`receive` steps, so they fire exactly once per failure.
    ///
    /// # Errors
    /// Returns [`ClientError`] if the request cannot be sent or the response
    /// cannot be decoded. Transport failures are surfaced through
    /// [`crate::WireframeError::Io`], while decode failures are surfaced
    /// through [`crate::WireframeError::Protocol`].
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::net::SocketAddr;
    ///
    /// use wireframe::client::{ClientError, WireframeClient};
    ///
    /// #[derive(bincode::Encode, bincode::BorrowDecode)]
    /// struct Login {
    ///     username: String,
    /// }
    ///
    /// #[derive(bincode::Encode, bincode::BorrowDecode, Debug, PartialEq)]
    /// struct LoginAck {
    ///     ok: bool,
    /// }
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), ClientError> {
    /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address");
    /// let mut client = WireframeClient::builder().connect(addr).await?;
    /// let login = Login {
    ///     username: "guest".to_string(),
    /// };
    /// let _ack: LoginAck = client.call(&login).await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn call<Req: EncodeWith<S>, Resp: DecodeWith<S>>(
        &mut self,
        request: &Req,
    ) -> Result<Resp, ClientError> {
        let span = call_span(&self.tracing_config);
        let timing_start = self.tracing_config.call_timing.then(Instant::now);

        async {
            // A send failure short-circuits; the response is only awaited
            // once the request has been written successfully.
            let outcome = match self.send(request).await {
                Ok(()) => self.receive::<Resp>().await,
                Err(failure) => Err(failure),
            };

            if outcome.is_err() {
                // The error hook already ran inside `send` / `receive_internal`.
                span.record("result", "err");
                emit_timing_event(timing_start);
            } else {
                Self::traced_ok(&span, timing_start);
            }
            outcome
        }
        .instrument(span.clone())
        .await
    }

    /// Borrow the codec settings this client was built with.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::net::SocketAddr;
    ///
    /// use wireframe::client::{ClientError, WireframeClient};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), ClientError> {
    /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address");
    /// let client = WireframeClient::builder().connect(addr).await?;
    /// let codec = client.codec_config();
    /// assert_eq!(codec.max_frame_length_value(), 1024);
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub const fn codec_config(&self) -> &ClientCodecConfig {
        &self.codec_config
    }

    /// Borrow the transport stream underneath the frame codec.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::net::SocketAddr;
    ///
    /// use wireframe::client::{ClientError, WireframeClient};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), ClientError> {
    /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address");
    /// let client = WireframeClient::builder().connect(addr).await?;
    /// let _stream = client.stream();
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn stream(&self) -> &T {
        let framed = &self.framed;
        framed.get_ref()
    }
}

impl<S, C> WireframeClient<S, RewindStream<TcpStream>, C>
where
    S: Serializer + Send + Sync,
    C: Send + 'static,
{
    /// Access the underlying [`TcpStream`].
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::net::SocketAddr;
    ///
    /// use wireframe::client::{ClientError, WireframeClient};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), ClientError> {
    /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address");
    /// let client = WireframeClient::builder().connect(addr).await?;
    /// let _stream = client.tcp_stream();
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn tcp_stream(&self) -> &TcpStream { self.framed.get_ref().inner() }

    /// Access the rewind stream wrapper.
    ///
    /// This provides access to the [`RewindStream`] that wraps the TCP stream,
    /// which may contain leftover bytes from preamble exchange.
    #[must_use]
    pub fn rewind_stream(&self) -> &RewindStream<TcpStream> { self.framed.get_ref() }

    /// Gracefully close the connection, invoking teardown hooks.
    ///
    /// This method flushes any pending frames and sends EOF to the peer before
    /// invoking the teardown hook. If a teardown hook was registered via
    /// [`on_connection_teardown`](super::WireframeClientBuilder::on_connection_teardown),
    /// it is invoked with the connection state produced by the setup hook.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use std::net::SocketAddr;
    ///
    /// use wireframe::client::{ClientError, WireframeClient};
    ///
    /// struct Session {
    ///     id: u64,
    /// }
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), ClientError> {
    /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address");
    /// let client = WireframeClient::builder()
    ///     .on_connection_setup(|| async { Session { id: 42 } })
    ///     .on_connection_teardown(|session| async move {
    ///         println!("Session {} closed", session.id);
    ///     })
    ///     .connect(addr)
    ///     .await?;
    /// client.close().await;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn close(mut self) {
        let span = close_span(&self.tracing_config);
        let timing_start = self.tracing_config.close_timing.then(Instant::now);

        async {
            // Flush pending frames and send EOF before teardown.
            // Ignore errors since we're closing anyway.
            let _ = self.framed.close().await;

            if let (Some(state), Some(handler)) =
                (self.connection_state.take(), &self.on_disconnect)
            {
                handler(state).await;
            }
            emit_timing_event(timing_start);
        }
        .instrument(span)
        .await;
    }
}