tarpc-cat 0.1.0

RPC framework built on comp-cat-rs: typed effects, no async, categorical foundations
Documentation
//! Length-delimited framing over [`Read`] / [`Write`].
//!
//! Each message is encoded as a 4-byte big-endian length prefix
//! followed by that many bytes of JSON payload.  These functions
//! are called inside [`Io::suspend`] blocks at the effectful boundary.
//!
//! [`Io::suspend`]: comp_cat_rs::effect::io::Io::suspend

use std::io::{Read, Write};

use crate::error::Error;

/// Encode `value` as length-delimited JSON and write it.
///
/// Serializes `value` to JSON bytes, writes a 4-byte big-endian
/// length prefix, then writes the payload.
///
/// # Errors
///
/// Returns [`Error::Serialize`] if JSON serialization fails,
/// or [`Error::Io`] if writing fails.
pub fn encode<T: serde::Serialize>(writer: &mut impl Write, value: &T) -> Result<(), Error> {
    let payload = serde_json::to_vec(value).map_err(Error::from_serialize)?;
    let len = u32::try_from(payload.len())
        .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"))?;
    writer.write_all(&len.to_be_bytes())?;
    writer.write_all(&payload)?;
    writer.flush()?;
    Ok(())
}

/// Decode a length-delimited JSON message.
///
/// Reads a 4-byte big-endian length prefix, then reads that many
/// bytes and deserializes them as JSON.
///
/// # Errors
///
/// Returns [`Error::ConnectionClosed`] on EOF before any bytes,
/// [`Error::Io`] on partial reads, or [`Error::Deserialize`] if
/// JSON deserialization fails.
pub fn decode<T: serde::de::DeserializeOwned>(reader: &mut impl Read) -> Result<T, Error> {
    let mut len_buf = [0u8; 4];
    match reader.read_exact(&mut len_buf) {
        Ok(()) => {}
        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
            return Err(Error::ConnectionClosed);
        }
        Err(e) => return Err(Error::from(e)),
    }
    let len = u32::from_be_bytes(len_buf) as usize;
    let mut payload = vec![0u8; len];
    reader.read_exact(&mut payload)?;
    serde_json::from_slice(&payload).map_err(Error::from_deserialize)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{Envelope, RequestId};

    /// An encoded `Envelope::Request` decodes back to the same id and payload.
    #[test]
    fn round_trip_envelope() -> Result<(), Error> {
        let original = Envelope::Request {
            id: RequestId::new(42),
            payload: r#"{"x":1}"#.to_owned(),
        };

        let mut wire = Vec::<u8>::new();
        encode(&mut wire, &original)?;
        let decoded: Envelope = decode(&mut std::io::Cursor::new(wire))?;

        if let Envelope::Request { id, payload } = decoded {
            assert_eq!(id.value(), 42);
            assert_eq!(payload, r#"{"x":1}"#);
            Ok(())
        } else {
            Err(Error::Server {
                message: "wrong variant".to_owned(),
            })
        }
    }

    /// A stream with no bytes at all is reported as a closed connection.
    #[test]
    fn decode_empty_returns_connection_closed() {
        let outcome: Result<Envelope, Error> =
            decode(&mut std::io::Cursor::new(Vec::<u8>::new()));
        assert!(matches!(outcome, Err(Error::ConnectionClosed)));
    }
}