commonware-stream 2026.4.0

Exchange messages over arbitrary transport.
Documentation
use crate::encrypted::Error;
use commonware_codec::{
    varint::{Decoder, UInt, MAX_U32_VARINT_SIZE},
    Encode, EncodeSize, Write,
};
use commonware_runtime::{Buf, IoBuf, IoBufMut, IoBufs, Sink, Stream};

/// Validates the frame size and assembles the frame via the caller's closure.
///
/// The `assemble` closure receives the varint prefix and must combine it with
/// the payload. This allows callers to choose between:
/// - Chunked: prepend the prefix as a separate buffer
/// - Contiguous: write the prefix directly into a pre-allocated buffer
///
/// Returns an error if the message is too large.
pub(crate) fn build_frame<T>(
    payload_len: usize,
    max_message_size: u32,
    assemble: impl FnOnce(UInt<u32>) -> Result<T, Error>,
) -> Result<T, Error> {
    if payload_len > max_message_size as usize {
        return Err(Error::SendTooLarge(payload_len));
    }
    let prefix = UInt(payload_len as u32);
    assemble(prefix)
}

/// Returns the total size of a length-prefixed frame.
pub(crate) fn framed_len(payload_len: usize, max_message_size: u32) -> Result<usize, Error> {
    build_frame(payload_len, max_message_size, |prefix| {
        Ok(prefix.encode_size() + payload_len)
    })
}

/// Appends one length-prefixed frame to a contiguous output buffer.
///
/// The callback receives the offset of the frame payload, which is useful when
/// callers need to operate on the payload bytes after copying them.
pub(crate) fn append_frame(
    frame: &mut IoBufMut,
    payload_len: usize,
    max_message_size: u32,
    append_payload: impl FnOnce(&mut IoBufMut, usize) -> Result<(), Error>,
) -> Result<usize, Error> {
    build_frame(payload_len, max_message_size, |prefix| {
        let start = frame.len();
        prefix.write(frame);
        let payload_offset = frame.len();
        append_payload(frame, payload_offset)?;
        assert_eq!(frame.len() - payload_offset, payload_len);
        Ok(frame.len() - start)
    })
}

/// Sends data to the sink with a varint length prefix.
///
/// The varint length prefix is prepended to the buffer(s), which results in a
/// chunked `IoBufs`.
///
/// Returns an error if the message is too large or the sink is closed.
pub async fn send_frame<S: Sink>(
    sink: &mut S,
    bufs: impl Into<IoBufs> + Send,
    max_message_size: u32,
) -> Result<(), Error> {
    let mut bufs = bufs.into();

    let frame = build_frame(bufs.len(), max_message_size, |prefix| {
        bufs.prepend(IoBuf::from(prefix.encode()));
        Ok(bufs)
    })?;
    sink.send(frame).await.map_err(Error::SendFailed)
}

/// Receives data from the stream with a varint length prefix.
/// Returns an error if the message is too large, the varint is invalid, or the
/// stream is closed.
pub async fn recv_frame<T: Stream>(stream: &mut T, max_message_size: u32) -> Result<IoBufs, Error> {
    let (len, skip) = recv_length(stream).await?;
    if len > max_message_size as usize {
        return Err(Error::RecvTooLarge(len));
    }

    stream
        .recv(skip + len)
        .await
        .map(|mut bufs| {
            bufs.advance(skip);
            bufs
        })
        .map_err(Error::RecvFailed)
}

/// Receives and decodes the varint length prefix from the stream.
/// Returns (payload_len, bytes_to_skip) where bytes_to_skip is:
/// - varint_len if decoded from peek buffer (bytes not yet consumed)
/// - 0 if decoded via recv (bytes already consumed)
async fn recv_length<T: Stream>(stream: &mut T) -> Result<(usize, usize), Error> {
    let mut decoder = Decoder::<u32>::new();

    // Fast path: decode from peek buffer without blocking
    let peeked = {
        let peeked = stream.peek(MAX_U32_VARINT_SIZE);
        for (i, byte) in peeked.iter().enumerate() {
            match decoder.feed(*byte) {
                Ok(Some(len)) => return Ok((len as usize, i + 1)),
                Ok(None) => continue,
                Err(_) => return Err(Error::InvalidVarint),
            }
        }
        peeked.len()
    };

    // Slow path: fetch bytes one at a time (skipping already-decoded peek bytes)
    let mut buf = stream.recv(peeked + 1).await.map_err(Error::RecvFailed)?;
    buf.advance(peeked);

    loop {
        match decoder.feed(buf.get_u8()) {
            Ok(Some(len)) => return Ok((len as usize, 0)),
            Ok(None) => {}
            Err(_) => return Err(Error::InvalidVarint),
        }
        buf = stream.recv(1).await.map_err(Error::RecvFailed)?;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use commonware_runtime::{deterministic, mocks, BufMut, IoBufMut, Runner};
    use rand::Rng;

    const MAX_MESSAGE_SIZE: u32 = 1024;

    #[test]
    fn test_send_recv_at_max_message_size() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf.len());
            assert_eq!(data.coalesce(), buf);
        });
    }

    #[test]
    fn test_send_recv_multiple() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf1 = [0u8; MAX_MESSAGE_SIZE as usize];
            let mut buf2 = [0u8; (MAX_MESSAGE_SIZE as usize) / 2];
            context.fill(&mut buf1);
            context.fill(&mut buf2);

            // Send two messages of different sizes
            let result = send_frame(&mut sink, buf1.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());
            let result = send_frame(&mut sink, buf2.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            // Read both messages in order
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf1.len());
            assert_eq!(data.coalesce(), buf1);
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), buf2.len());
            assert_eq!(data.coalesce(), buf2);
        });
    }

    #[test]
    fn test_send_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await;
            assert!(result.is_ok());

            // Do the reading manually without using recv_frame
            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08] (2 bytes)
            let read = stream.recv(2).await.unwrap();
            assert_eq!(read.coalesce(), &[0x80, 0x08]); // 1024 as varint
            let read = stream.recv(MAX_MESSAGE_SIZE as usize).await.unwrap();
            assert_eq!(read.coalesce(), buf);
        });
    }

    #[test]
    fn test_build_frame_closure_error() {
        let result: Result<IoBufs, _> = build_frame(10, MAX_MESSAGE_SIZE, |_prefix| {
            Err(Error::HandshakeError(
                commonware_cryptography::handshake::Error::EncryptionFailed,
            ))
        });
        assert!(matches!(&result, Err(Error::HandshakeError(_))));
    }

    #[test]
    fn test_build_frame_too_large() {
        let result: Result<IoBufs, _> = build_frame(
            MAX_MESSAGE_SIZE as usize + 1,
            MAX_MESSAGE_SIZE,
            |_prefix| unreachable!(),
        );
        assert!(
            matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize + 1)
        );
    }

    #[test]
    fn test_send_frame_too_large() {
        let (mut sink, _) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            let mut buf = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut buf);

            let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_read_frame() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // Do the writing manually without using send_frame
            let mut msg = [0u8; MAX_MESSAGE_SIZE as usize];
            context.fill(&mut msg);

            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08]
            let mut buf = IoBufMut::with_capacity(2 + msg.len());
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            buf.put_slice(&msg);
            sink.send(buf.freeze()).await.unwrap();

            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize);
            assert_eq!(data.coalesce(), msg);
        });
    }

    #[test]
    fn test_read_frame_too_large() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Manually insert a frame that gives MAX_MESSAGE_SIZE as the size
            // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08]
            let mut buf = IoBufMut::with_capacity(2);
            buf.put_u8(0x80);
            buf.put_u8(0x08);
            sink.send(buf.freeze()).await.unwrap();

            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await;
            assert!(
                matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize)
            );
        });
    }

    #[test]
    fn test_recv_frame_incomplete_varint() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send incomplete varint (continuation bit set but no following byte)
            let mut buf = IoBufMut::with_capacity(1);
            buf.put_u8(0x80); // Continuation bit set, expects more bytes

            sink.send(buf.freeze()).await.unwrap();
            drop(sink); // Close the sink to simulate a closed stream

            // Expect an error because varint is incomplete
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::RecvFailed(_))));
        });
    }

    #[test]
    fn test_recv_frame_invalid_varint_overflow() {
        let (mut sink, mut stream) = mocks::Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send a varint that overflows u32 (more than 5 bytes with continuation bits)
            let mut buf = IoBufMut::with_capacity(6);
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 7 bits + continue
            buf.put_u8(0xFF); // 5th byte with overflow bits set + continue
            buf.put_u8(0x01); // 6th byte

            sink.send(buf.freeze()).await.unwrap();

            // Expect an error because varint overflows u32
            let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await;
            assert!(matches!(&result, Err(Error::InvalidVarint)));
        });
    }

    #[test]
    fn test_recv_frame_peek_paths() {
        let executor = deterministic::Runner::default();
        executor.start(|mut context| async move {
            // 300 encodes as [0xAC, 0x02] (2-byte varint)
            let mut payload = vec![0u8; 300];
            context.fill(&mut payload[..]);

            // Fast path: peek returns complete varint
            let (mut sink, mut stream) = mocks::Channel::init();
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // Slow path: peek returns empty
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(0);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);

            // Slow path: peek returns partial varint
            let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(1);
            send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();
            let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap();
            assert_eq!(data.coalesce(), &payload[..]);
        });
    }
}