//! zmux 0.1.0
//!
//! Rust implementation of the ZMux v1 stream multiplexing protocol.
use crate::error::{Error, ErrorDirection, ErrorOperation, ErrorScope, ErrorSource, Result};
use std::io;

/// Largest value a varint62 can carry: the 62 low bits of a `u64` all set (2^62 - 1).
pub const MAX_VARINT62: u64 = u64::MAX >> 2;
/// Length in bytes of the widest (8-byte) varint62 encoding.
pub const MAX_VARINT_LEN: usize = 8;

#[inline]
pub fn varint_len(v: u64) -> Result<usize> {
    if v <= MAX_VARINT62 {
        Ok(canonical_varint_len(v))
    } else {
        Err(Error::protocol("varint62 value out of range"))
    }
}

#[inline]
pub fn append_varint(dst: &mut Vec<u8>, v: u64) -> Result<()> {
    let n = varint_len(v)?;
    dst.try_reserve(n)
        .map_err(|_| Error::local("zmux: varint allocation failed"))?;
    match n {
        1 => dst.push(v as u8),
        2 => dst.extend_from_slice(&[((v >> 8) as u8 & 0x3f) | 0x40, v as u8]),
        4 => dst.extend_from_slice(&[
            ((v >> 24) as u8 & 0x3f) | 0x80,
            (v >> 16) as u8,
            (v >> 8) as u8,
            v as u8,
        ]),
        8 => dst.extend_from_slice(&[
            ((v >> 56) as u8 & 0x3f) | 0xc0,
            (v >> 48) as u8,
            (v >> 40) as u8,
            (v >> 32) as u8,
            (v >> 24) as u8,
            (v >> 16) as u8,
            (v >> 8) as u8,
            v as u8,
        ]),
        _ => unreachable!(),
    }
    Ok(())
}

#[inline]
pub fn encode_varint_to_slice(dst: &mut [u8], v: u64) -> Result<usize> {
    let n = varint_len(v)?;
    if dst.len() < n {
        return Err(Error::frame_size("varint destination too small"));
    }
    match n {
        1 => dst[0] = v as u8,
        2 => {
            dst[0] = ((v >> 8) as u8 & 0x3f) | 0x40;
            dst[1] = v as u8;
        }
        4 => {
            dst[0] = ((v >> 24) as u8 & 0x3f) | 0x80;
            dst[1] = (v >> 16) as u8;
            dst[2] = (v >> 8) as u8;
            dst[3] = v as u8;
        }
        8 => {
            dst[0] = ((v >> 56) as u8 & 0x3f) | 0xc0;
            dst[1] = (v >> 48) as u8;
            dst[2] = (v >> 40) as u8;
            dst[3] = (v >> 32) as u8;
            dst[4] = (v >> 24) as u8;
            dst[5] = (v >> 16) as u8;
            dst[6] = (v >> 8) as u8;
            dst[7] = v as u8;
        }
        _ => unreachable!(),
    }
    Ok(n)
}

/// Decodes one varint62 from the start of `src`.
///
/// Returns the decoded value together with the number of bytes it occupied.
///
/// # Errors
///
/// Returns a session-scoped protocol error when `src` is empty or shorter than the length
/// announced by its prefix, or when the encoding is not canonical.
#[inline]
pub fn parse_varint(src: &[u8]) -> Result<(u64, usize)> {
    let width = match src.first() {
        Some(&first) => encoded_len_from_first(first),
        None => return Err(varint_wire_error("truncated varint62")),
    };
    match src.get(..width) {
        Some(encoded) => validate_decoded_varint(decode_varint_value(encoded, width), width),
        None => Err(varint_wire_error("truncated varint62")),
    }
}

#[inline]
pub(crate) fn decode_varint_with_len(src: &[u8], n: usize) -> Result<u64> {
    if !matches!(n, 1 | 2 | 4 | 8) || src.len() < n {
        return Err(varint_wire_error("truncated varint62"));
    }
    validate_decoded_varint(decode_varint_value(src, n), n).map(|(value, _)| value)
}

/// Reads one varint62 from `reader`, consuming exactly the bytes it occupies.
///
/// Returns the decoded value together with its encoded length.
///
/// # Errors
///
/// Returns the structured "truncated varint62" protocol error if the stream ends early, and a
/// protocol error for non-canonical encodings. Other I/O errors are converted via `Error::from`.
#[inline]
pub fn read_varint<R: io::Read>(reader: &mut R) -> Result<(u64, usize)> {
    let first = read_first_varint_byte(reader)?;
    let width = encoded_len_from_first(first);
    let mut encoded = [0u8; MAX_VARINT_LEN];
    encoded[0] = first;
    if width > 1 {
        // Pull in the remaining bytes announced by the prefix.
        read_varint_tail_bytes(reader, &mut encoded[1..width])?;
    }
    let value = decode_varint_value(&encoded, width);
    validate_decoded_varint(value, width)
}

#[inline]
pub fn encode_varint(v: u64) -> Result<Vec<u8>> {
    let mut buf = [0u8; MAX_VARINT_LEN];
    let n = encode_varint_to_slice(&mut buf, v)?;
    let mut out = Vec::new();
    out.try_reserve_exact(n)
        .map_err(|_| Error::local("zmux: varint allocation failed"))?;
    out.extend_from_slice(&buf[..n]);
    Ok(out)
}

/// Fills all of `dst` from `reader`, retrying reads that fail with `Interrupted`.
///
/// Unlike a plain loop over `Read::read`, this function rejects a reader that claims to have
/// read more bytes than the buffer it was given.
///
/// # Errors
///
/// - `UnexpectedEof` if the reader returns 0 before `dst` is full.
/// - `InvalidData` if the reader reports impossible progress.
/// - Any other error from the reader, passed through unchanged.
pub(crate) fn read_exact_checked<R: io::Read>(
    reader: &mut R,
    dst: &mut [u8],
) -> io::Result<()> {
    let mut filled = 0usize;
    while filled < dst.len() {
        let window = &mut dst[filled..];
        let wanted = window.len();
        match reader.read(window) {
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
            Ok(0) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
            Ok(got) if got > wanted => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "reader reported invalid progress",
                ));
            }
            Ok(got) => filled += got,
        }
    }
    Ok(())
}

/// Maps the two high bits of a varint62's first byte to its total encoded length.
#[inline]
fn encoded_len_from_first(first: u8) -> usize {
    match first & 0xc0 {
        0x00 => 1,
        0x40 => 2,
        0x80 => 4,
        _ => 8,
    }
}

/// Returns the shortest legal varint62 width for `v`.
///
/// Values above [`MAX_VARINT62`] also report 8; callers range-check separately.
#[inline]
fn canonical_varint_len(v: u64) -> usize {
    if v < 1 << 6 {
        1
    } else if v < 1 << 14 {
        2
    } else if v < 1 << 30 {
        4
    } else {
        8
    }
}

/// Decodes the `n`-byte varint62 at the start of `src`, ignoring the 2-bit length prefix.
///
/// # Panics
///
/// Panics if `n` is not 1, 2, 4, or 8, or if `src` holds fewer than `n` bytes.
#[inline]
fn decode_varint_value(src: &[u8], n: usize) -> u64 {
    debug_assert!(src.len() >= n);
    if !matches!(n, 1 | 2 | 4 | 8) {
        unreachable!("varint62 prefix produces only 1, 2, 4, or 8 byte lengths");
    }
    // Right-align the encoded bytes in a big-endian word, then clear the length tag.
    let mut word = [0u8; 8];
    let start = word.len() - n;
    word[start..].copy_from_slice(&src[..n]);
    word[start] &= 0x3f;
    u64::from_be_bytes(word)
}

/// Accepts a decoded `value` only if it is in range and `n` is its canonical width.
///
/// On success, returns `(value, n)` unchanged.
///
/// # Errors
///
/// Returns a session-scoped protocol error: "varint62 value out of range" or
/// "non-canonical varint62".
#[inline]
fn validate_decoded_varint(value: u64, n: usize) -> Result<(u64, usize)> {
    let message = if value > MAX_VARINT62 {
        "varint62 value out of range"
    } else if canonical_varint_len(value) != n {
        "non-canonical varint62"
    } else {
        return Ok((value, n));
    };
    Err(varint_wire_error(message))
}

/// Reads the single prefix byte that starts a varint62.
///
/// # Errors
///
/// End-of-stream becomes the structured "truncated varint62" protocol error. Any other I/O
/// failure is converted through `Error::from`.
#[inline]
fn read_first_varint_byte<R: io::Read>(reader: &mut R) -> Result<u8> {
    let mut byte = [0u8; 1];
    read_varint_bytes(reader, &mut byte)?;
    Ok(byte[0])
}

/// Reads the continuation bytes of a varint62 whose length is already known from its prefix.
///
/// # Errors
///
/// End-of-stream becomes the structured "truncated varint62" protocol error. Any other I/O
/// failure is converted through `Error::from`.
#[inline]
fn read_varint_tail_bytes<R: io::Read>(reader: &mut R, dst: &mut [u8]) -> Result<()> {
    read_varint_bytes(reader, dst)
}

/// Fills `dst` from `reader` and maps read failures to crate errors.
///
/// An EOF inside a varint is reported as a remote protocol error via `varint_wire_error`
/// rather than as a bare I/O error. Other I/O errors go through `Error::from` so their kind
/// stays visible to callers.
#[inline]
fn read_varint_bytes<R: io::Read>(reader: &mut R, dst: &mut [u8]) -> Result<()> {
    read_exact_checked(reader, dst).map_err(|err| {
        if err.kind() == io::ErrorKind::UnexpectedEof {
            varint_wire_error("truncated varint62")
        } else {
            Error::from(err)
        }
    })
}

fn varint_wire_error(message: &'static str) -> Error {
    Error::protocol(message)
        .with_scope(ErrorScope::Session)
        .with_operation(ErrorOperation::Read)
        .with_source(ErrorSource::Remote)
        .with_direction(ErrorDirection::Read)
}

#[cfg(test)]
mod tests {
    use super::{
        append_varint, decode_varint_with_len, encode_varint, encode_varint_to_slice,
        parse_varint, read_exact_checked, read_varint, varint_len, MAX_VARINT62,
    };
    use crate::error::{Error, ErrorCode, ErrorDirection, ErrorOperation, ErrorScope, ErrorSource};
    use std::io;

    #[test]
    fn varint_len_and_round_trip_cover_encoding_boundaries() {
        let cases = [
            (0, 1, &[0x00][..]),
            (63, 1, &[0x3f][..]),
            (64, 2, &[0x40, 0x40][..]),
            (16_383, 2, &[0x7f, 0xff][..]),
            (16_384, 4, &[0x80, 0x00, 0x40, 0x00][..]),
            (1_073_741_823, 4, &[0xbf, 0xff, 0xff, 0xff][..]),
            (
                1_073_741_824,
                8,
                &[0xc0, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00][..],
            ),
            (
                MAX_VARINT62,
                8,
                &[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff][..],
            ),
        ];

        for (value, expected_len, expected_bytes) in cases {
            assert_eq!(varint_len(value).unwrap(), expected_len);
            let encoded = encode_varint(value).unwrap();
            assert_eq!(encoded, expected_bytes);
            assert_eq!(parse_varint(&encoded).unwrap(), (value, expected_len));
            let mut reader = io::Cursor::new(&encoded);
            assert_eq!(read_varint(&mut reader).unwrap(), (value, expected_len));

            let mut appended = Vec::new();
            append_varint(&mut appended, value).unwrap();
            assert_eq!(appended, expected_bytes);
        }
    }

    #[test]
    fn out_of_range_and_too_small_destination_do_not_write_partial_data() {
        let mut dst = vec![0xaa];
        let err = append_varint(&mut dst, MAX_VARINT62 + 1).unwrap_err();
        assert_eq!(err.code(), Some(ErrorCode::Protocol));
        assert!(err.is_protocol_message("varint62 value out of range"));
        assert_eq!(dst, [0xaa]);

        let wrapped: io::Error = err.clone().into();
        let preserved = wrapped
            .get_ref()
            .and_then(|cause| cause.downcast_ref::<Error>())
            .expect("structured varint error should be preserved inside io::Error");
        assert_eq!(preserved.code(), Some(ErrorCode::Protocol));
        assert!(preserved.is_protocol_message("varint62 value out of range"));

        let mut short = [0u8; 1];
        let err = encode_varint_to_slice(&mut short, 64).unwrap_err();
        assert_eq!(err.code(), Some(ErrorCode::FrameSize));
        assert!(err.to_string().contains("varint destination too small"));
        // The failed encode must not have written a partial prefix byte.
        assert_eq!(short, [0u8; 1]);
    }

    #[test]
    fn parse_varint_rejects_truncated_and_non_canonical_encodings() {
        for raw in [&[][..], &[0x40][..], &[0x80, 0x00, 0x00][..]] {
            let err = parse_varint(raw).unwrap_err();
            assert_eq!(err.code(), Some(ErrorCode::Protocol));
            assert!(err.is_protocol_message("truncated varint62"));
            assert_varint_wire_error(&err);
        }

        for raw in [&[0x40, 0x01][..], &[0x80, 0x00, 0x00, 0x01][..]] {
            let err = parse_varint(raw).unwrap_err();
            assert_eq!(err.code(), Some(ErrorCode::Protocol));
            assert!(err.is_protocol_message("non-canonical varint62"));
            assert_varint_wire_error(&err);
        }
    }

    #[test]
    fn decode_varint_with_len_rejects_bad_lengths_and_non_canonical_values() {
        assert_eq!(decode_varint_with_len(&[0x25], 1).unwrap(), 37);
        assert_eq!(decode_varint_with_len(&[0x40, 0x40, 0xaa], 2).unwrap(), 64);

        // An illegal width, a too-short buffer and an empty buffer all report truncation.
        for (raw, n) in [(&[0x40, 0x40, 0x00][..], 3), (&[0x40][..], 2), (&[][..], 1)] {
            let err = decode_varint_with_len(raw, n).unwrap_err();
            assert_eq!(err.code(), Some(ErrorCode::Protocol));
            assert!(err.is_protocol_message("truncated varint62"));
            assert_varint_wire_error(&err);
        }

        let err = decode_varint_with_len(&[0x40, 0x01], 2).unwrap_err();
        assert_eq!(err.code(), Some(ErrorCode::Protocol));
        assert!(err.is_protocol_message("non-canonical varint62"));
        assert_varint_wire_error(&err);
    }

    #[test]
    fn read_varint_maps_eof_to_structured_truncated_protocol_error() {
        let mut empty = io::Cursor::new(&[][..]);
        let err = read_varint(&mut empty).unwrap_err();
        assert_eq!(err.code(), Some(ErrorCode::Protocol));
        assert!(err.is_protocol_message("truncated varint62"));
        assert_varint_wire_error(&err);
        assert_eq!(err.source_io_error_kind(), None);

        for raw in [&[0x40][..], &[0x80, 0x00, 0x00][..]] {
            let mut cursor = io::Cursor::new(raw);
            let err = read_varint(&mut cursor).unwrap_err();
            assert_eq!(err.code(), Some(ErrorCode::Protocol));
            assert!(err.is_protocol_message("truncated varint62"));
            assert_varint_wire_error(&err);
            assert_eq!(err.source_io_error_kind(), None);
        }
    }

    #[test]
    fn read_varint_preserves_non_eof_io_errors() {
        struct FailingReader;

        impl io::Read for FailingReader {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::from(io::ErrorKind::TimedOut))
            }
        }

        let mut reader = FailingReader;
        let err = read_varint(&mut reader).unwrap_err();
        assert_eq!(err.source_io_error_kind(), Some(io::ErrorKind::TimedOut));
    }

    #[test]
    fn read_exact_checked_retries_interrupts_and_rejects_over_reporting() {
        // Fails once with `Interrupted`, then serves bytes from an in-memory cursor.
        struct InterruptOnce {
            interrupted: bool,
            inner: io::Cursor<Vec<u8>>,
        }

        impl io::Read for InterruptOnce {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                if !self.interrupted {
                    self.interrupted = true;
                    return Err(io::Error::from(io::ErrorKind::Interrupted));
                }
                io::Read::read(&mut self.inner, buf)
            }
        }

        // Violates the `Read` contract by claiming more bytes than the buffer holds.
        struct OverReporting;

        impl io::Read for OverReporting {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                Ok(buf.len() + 1)
            }
        }

        let mut reader = InterruptOnce {
            interrupted: false,
            inner: io::Cursor::new(vec![1, 2, 3]),
        };
        let mut dst = [0u8; 3];
        read_exact_checked(&mut reader, &mut dst).unwrap();
        assert_eq!(dst, [1, 2, 3]);

        let mut dst = [0u8; 1];
        let err = read_exact_checked(&mut reader, &mut dst).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);

        let mut dst = [0u8; 2];
        let err = read_exact_checked(&mut OverReporting, &mut dst).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    /// Asserts the metadata every varint wire error is expected to carry.
    fn assert_varint_wire_error(err: &Error) {
        assert_eq!(err.scope(), ErrorScope::Session);
        assert_eq!(err.operation(), ErrorOperation::Read);
        assert_eq!(err.source(), ErrorSource::Remote);
        assert_eq!(err.direction(), ErrorDirection::Read);
    }
}