qrpc 0.1.2

qrpc is a small QUIC + mTLS messaging library
Documentation
use prost::Message;

use crate::error::{QrpcError, QrpcResult};

/// Four-byte marker that opens every frame (ASCII `QRPC`).
pub(crate) const START_MAGIC: [u8; 4] = [b'Q', b'R', b'P', b'C'];
/// Four-byte marker that closes every frame (ASCII `CPRQ`, the opening marker reversed).
pub(crate) const END_MAGIC: [u8; 4] = [b'C', b'P', b'R', b'Q'];
/// Upper bound on a frame body's size: 4 MiB (`4 << 20` == `4 * 1024 * 1024`).
pub(crate) const MAX_PACKET_SIZE: usize = 4 << 20;
/// Target identifier that addresses a packet to every peer.
pub(crate) const BROADCAST_TARGET: &str = "*";

/// Discriminates the kinds of packets exchanged on the wire.
///
/// Mirrored variant-for-variant by the protobuf `ProtoPacketKind` enum used when encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PacketKind {
    /// Handshake packet, as built by `WirePacket::register`.
    Register,
    /// Packet carrying user data, as built by `WirePacket::data`.
    Data,
    /// Disconnect notification, as built by `WirePacket::disconnect`.
    Disconnect,
}

/// Protobuf body structure encoded inside framed packets.
///
/// `WirePacket::encode_frame` builds one of these from a `WirePacket`, and
/// `WirePacket::decode_frame` converts a decoded one back. The `tag` numbers
/// determine the encoded bytes, so changing any of them changes the wire format.
#[derive(Clone, PartialEq, Message)]
struct ProtoPacket {
    /// Field 1: a `ProtoPacketKind` discriminant stored as a raw `i32`.
    /// `decode_frame` validates it with `PacketKind::try_from`.
    #[prost(enumeration = "ProtoPacketKind", tag = "1")]
    kind: i32,
    /// Field 2: copied from `WirePacket::source_id`.
    #[prost(string, tag = "2")]
    source_id: String,
    /// Field 3: copied from `WirePacket::target_id`.
    #[prost(string, tag = "3")]
    target_id: String,
    /// Field 4: copied from `WirePacket::cmd_id`.
    #[prost(uint32, tag = "4")]
    cmd_id: u32,
    /// Field 5: raw payload bytes, copied from `WirePacket::payload`.
    #[prost(bytes, tag = "5")]
    payload: Vec<u8>,
}

/// Protobuf packet kind mirror of `PacketKind`.
///
/// The explicit discriminants are the integer values stored in
/// `ProtoPacket::kind`; changing them changes the wire format.
#[derive(Clone, Copy, Debug, Eq, PartialEq, prost::Enumeration)]
#[repr(i32)]
enum ProtoPacketKind {
    /// Wire value for `PacketKind::Register`.
    Register = 0,
    /// Wire value for `PacketKind::Data`.
    Data = 1,
    /// Wire value for `PacketKind::Disconnect`.
    Disconnect = 2,
}

impl From<PacketKind> for ProtoPacketKind {
    fn from(value: PacketKind) -> Self {
        match value {
            PacketKind::Register => Self::Register,
            PacketKind::Data => Self::Data,
            PacketKind::Disconnect => Self::Disconnect,
        }
    }
}

impl TryFrom<i32> for PacketKind {
    type Error = QrpcError;

    /// Converts a raw protobuf `kind` value back into a [`PacketKind`].
    ///
    /// # Errors
    ///
    /// Returns [`QrpcError::MessageDecode`] when `raw` is not one of the
    /// `ProtoPacketKind` discriminants.
    fn try_from(raw: i32) -> Result<Self, Self::Error> {
        let proto_kind = ProtoPacketKind::try_from(raw).map_err(|_| {
            QrpcError::MessageDecode(format!("invalid packet kind value: {raw}"))
        })?;

        let kind = match proto_kind {
            ProtoPacketKind::Register => Self::Register,
            ProtoPacketKind::Data => Self::Data,
            ProtoPacketKind::Disconnect => Self::Disconnect,
        };
        Ok(kind)
    }
}

/// Runtime wire packet used across transport internals.
///
/// Converted to framed bytes by `encode_frame` and back by `decode_frame`;
/// every field round-trips through the protobuf body.
#[derive(Clone, Debug)]
pub(crate) struct WirePacket {
    /// Which kind of packet this is.
    pub(crate) kind: PacketKind,
    /// Source identifier (presumably the sending peer's id — confirm with callers).
    pub(crate) source_id: String,
    /// Target identifier; empty for packets built by `register`.
    /// NOTE(review): `BROADCAST_TARGET` (`"*"`) presumably addresses all peers — confirm with the routing code.
    pub(crate) target_id: String,
    /// Command identifier supplied to `data`; `0` for register and disconnect packets.
    pub(crate) cmd_id: u32,
    /// Payload bytes; empty for register and disconnect packets.
    pub(crate) payload: Vec<u8>,
}

impl WirePacket {
    /// Creates a register packet.
    pub(crate) fn register(source_id: impl Into<String>) -> Self {
        Self {
            kind: PacketKind::Register,
            source_id: source_id.into(),
            target_id: String::new(),
            cmd_id: 0,
            payload: Vec::new(),
        }
    }

    /// Creates a disconnect packet.
    pub(crate) fn disconnect(source_id: impl Into<String>, target_id: impl Into<String>) -> Self {
        Self {
            kind: PacketKind::Disconnect,
            source_id: source_id.into(),
            target_id: target_id.into(),
            cmd_id: 0,
            payload: Vec::new(),
        }
    }

    /// Creates a data packet.
    pub(crate) fn data(
        source_id: impl Into<String>,
        target_id: impl Into<String>,
        cmd_id: u32,
        payload: Vec<u8>,
    ) -> Self {
        Self {
            kind: PacketKind::Data,
            source_id: source_id.into(),
            target_id: target_id.into(),
            cmd_id,
            payload,
        }
    }

    /// Encodes packet into framed bytes: `START_MAGIC + len + proto + END_MAGIC`.
    pub(crate) fn encode_frame(&self) -> Vec<u8> {
        let proto = ProtoPacket {
            kind: ProtoPacketKind::from(self.kind) as i32,
            source_id: self.source_id.clone(),
            target_id: self.target_id.clone(),
            cmd_id: self.cmd_id,
            payload: self.payload.clone(),
        };

        let mut body = Vec::with_capacity(proto.encoded_len());
        proto
            .encode(&mut body)
            .expect("encoding to Vec should not fail");

        let mut out = Vec::with_capacity(START_MAGIC.len() + 4 + body.len() + END_MAGIC.len());
        out.extend_from_slice(&START_MAGIC);
        out.extend_from_slice(&(body.len() as u32).to_be_bytes());
        out.extend_from_slice(&body);
        out.extend_from_slice(&END_MAGIC);
        out
    }

    /// Decodes one framed packet.
    pub(crate) fn decode_frame(bytes: &[u8]) -> QrpcResult<Self> {
        if bytes.len() < START_MAGIC.len() + END_MAGIC.len() + 4 {
            return Err(QrpcError::FrameTooShort);
        }
        if bytes[0..4] != START_MAGIC {
            return Err(QrpcError::InvalidMagic);
        }
        if bytes[bytes.len() - 4..bytes.len()] != END_MAGIC {
            return Err(QrpcError::InvalidMagic);
        }

        let body_len = u32::from_be_bytes(
            bytes[4..8]
                .try_into()
                .map_err(|_| QrpcError::FrameTooShort)?,
        ) as usize;
        let expected_len = 4 + 4 + body_len + 4;
        if bytes.len() != expected_len {
            return Err(QrpcError::MessageDecode(format!(
                "body length mismatch, declared={body_len}, actual={}",
                bytes.len() - 12
            )));
        }

        let proto = ProtoPacket::decode(&bytes[8..8 + body_len])?;
        let kind = PacketKind::try_from(proto.kind)?;

        Ok(Self {
            kind,
            source_id: proto.source_id,
            target_id: proto.target_id,
            cmd_id: proto.cmd_id,
            payload: proto.payload,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an arbitrary `body` in frame markers and a matching length prefix,
    /// bypassing `encode_frame` so malformed bodies can be tested.
    fn frame_with_body(body: &[u8]) -> Vec<u8> {
        let len = u32::try_from(body.len()).expect("test body fits in u32");
        let mut out = Vec::new();
        out.extend_from_slice(&START_MAGIC);
        out.extend_from_slice(&len.to_be_bytes());
        out.extend_from_slice(body);
        out.extend_from_slice(&END_MAGIC);
        out
    }

    #[test]
    fn packet_roundtrip_works() {
        let packet = WirePacket::data("a", "b", 12, b"hello".to_vec());
        let encoded = packet.encode_frame();
        let decoded = WirePacket::decode_frame(&encoded).expect("must decode");

        assert_eq!(decoded.kind, PacketKind::Data);
        assert_eq!(decoded.source_id, "a");
        assert_eq!(decoded.target_id, "b");
        assert_eq!(decoded.cmd_id, 12);
        assert_eq!(decoded.payload, b"hello");
    }

    #[test]
    fn register_and_disconnect_roundtrip() {
        let decoded = WirePacket::decode_frame(&WirePacket::register("node-1").encode_frame())
            .expect("must decode");
        assert_eq!(decoded.kind, PacketKind::Register);
        assert_eq!(decoded.source_id, "node-1");
        assert!(decoded.target_id.is_empty());
        assert_eq!(decoded.cmd_id, 0);
        assert!(decoded.payload.is_empty());

        let decoded =
            WirePacket::decode_frame(&WirePacket::disconnect("a", BROADCAST_TARGET).encode_frame())
                .expect("must decode");
        assert_eq!(decoded.kind, PacketKind::Disconnect);
        assert_eq!(decoded.source_id, "a");
        assert_eq!(decoded.target_id, BROADCAST_TARGET);
    }

    #[test]
    fn frame_layout_matches_declared_length() {
        let encoded = WirePacket::data("a", "b", 12, b"hello".to_vec()).encode_frame();
        assert_eq!(&encoded[..4], &START_MAGIC);
        assert_eq!(&encoded[encoded.len() - 4..], &END_MAGIC);
        let declared = u32::from_be_bytes(encoded[4..8].try_into().expect("4 bytes")) as usize;
        assert_eq!(declared, encoded.len() - 12);
    }

    #[test]
    fn invalid_magic_fails() {
        let mut encoded = WirePacket::data("a", "b", 12, vec![]).encode_frame();
        encoded[0] = b'X';
        assert!(matches!(
            WirePacket::decode_frame(&encoded),
            Err(QrpcError::InvalidMagic)
        ));
    }

    #[test]
    fn invalid_end_magic_fails() {
        let mut encoded = WirePacket::data("a", "b", 12, vec![]).encode_frame();
        let last = encoded.len() - 1;
        encoded[last] = b'X';
        assert!(matches!(
            WirePacket::decode_frame(&encoded),
            Err(QrpcError::InvalidMagic)
        ));
    }

    #[test]
    fn short_frame_fails() {
        assert!(matches!(
            WirePacket::decode_frame(b"QRPC"),
            Err(QrpcError::FrameTooShort)
        ));
        assert!(matches!(
            WirePacket::decode_frame(&[]),
            Err(QrpcError::FrameTooShort)
        ));
    }

    #[test]
    fn invalid_length_fails() {
        let mut encoded = WirePacket::data("a", "b", 12, vec![1, 2]).encode_frame();
        encoded[7] = 250;
        assert!(matches!(
            WirePacket::decode_frame(&encoded),
            Err(QrpcError::MessageDecode(_))
        ));
    }

    #[test]
    fn oversized_declared_length_fails() {
        let mut encoded = WirePacket::data("a", "b", 12, vec![]).encode_frame();
        encoded[4..8].copy_from_slice(&u32::MAX.to_be_bytes());
        assert!(matches!(
            WirePacket::decode_frame(&encoded),
            Err(QrpcError::MessageDecode(_))
        ));
    }

    #[test]
    fn unknown_packet_kind_fails() {
        // Protobuf field 1 (varint) set to 7, which is not a ProtoPacketKind value.
        let frame = frame_with_body(&[0x08, 0x07]);
        assert!(matches!(
            WirePacket::decode_frame(&frame),
            Err(QrpcError::MessageDecode(_))
        ));
    }
}