use prost::Message;
use crate::error::{QrpcError, QrpcResult};
/// Four bytes (`QRPC`) that open every frame built by `WirePacket::encode_frame`.
pub(crate) const START_MAGIC: [u8; 4] = [b'Q', b'R', b'P', b'C'];
/// Four bytes (`CPRQ`, the start magic reversed) that close every frame.
pub(crate) const END_MAGIC: [u8; 4] = [b'C', b'P', b'R', b'Q'];
/// Packet size limit of 4 MiB (4 << 20 bytes).
///
/// NOTE(review): not referenced by the framing code here; confirm where it is
/// enforced and whether it bounds the whole frame or only the protobuf body.
pub(crate) const MAX_PACKET_SIZE: usize = 4 << 20;
/// Target id consisting of a single `*`.
///
/// NOTE(review): presumably addresses every connected peer — confirm against
/// the routing code.
pub(crate) const BROADCAST_TARGET: &str = "*";
/// Kind of a framed packet, as used by the rest of the crate.
///
/// Each variant maps one-to-one onto a `ProtoPacketKind` wire value via
/// `From<PacketKind> for ProtoPacketKind` (encoding) and
/// `TryFrom<i32> for PacketKind` (decoding).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PacketKind {
    /// Announces a sender; built by `WirePacket::register` with only a source id.
    Register,
    /// Carries a command id and payload bytes from a source to a target.
    Data,
    /// Sent from a source to a target with no command id or payload.
    Disconnect,
}
/// Protobuf message that forms the body of every frame.
///
/// Mirrors `WirePacket` field for field: `WirePacket::encode_frame` builds one
/// from a packet and `WirePacket::decode_frame` parses one back. The field tags
/// (1–5) define the wire format; changing them changes what peers see.
#[derive(Clone, PartialEq, Message)]
struct ProtoPacket {
    /// Raw `ProtoPacketKind` value. Decoding maps it to `PacketKind` through
    /// `PacketKind::try_from`, which rejects unknown values.
    #[prost(enumeration = "ProtoPacketKind", tag = "1")]
    kind: i32,
    /// Id of the sending side.
    #[prost(string, tag = "2")]
    source_id: String,
    /// Id of the receiving side; empty for register packets.
    #[prost(string, tag = "3")]
    target_id: String,
    /// Command id; 0 for register and disconnect packets.
    #[prost(uint32, tag = "4")]
    cmd_id: u32,
    /// Payload bytes; empty for register and disconnect packets.
    #[prost(bytes, tag = "5")]
    payload: Vec<u8>,
}
/// Protobuf enumeration stored in the `kind` field of `ProtoPacket`.
///
/// The discriminants are the integers written on the wire (`encode_frame`
/// converts with `as i32`), so existing values must stay stable.
///
/// NOTE(review): `Register = 0` is also the protobuf default, so a body with no
/// `kind` field presumably decodes as `Register` — confirm this is intended.
#[derive(Clone, Copy, Debug, Eq, PartialEq, prost::Enumeration)]
#[repr(i32)]
enum ProtoPacketKind {
    /// Wire value 0.
    Register = 0,
    /// Wire value 1.
    Data = 1,
    /// Wire value 2.
    Disconnect = 2,
}
impl From<PacketKind> for ProtoPacketKind {
fn from(value: PacketKind) -> Self {
match value {
PacketKind::Register => Self::Register,
PacketKind::Data => Self::Data,
PacketKind::Disconnect => Self::Disconnect,
}
}
}
/// Decodes the raw protobuf `kind` value of a frame into a `PacketKind`.
impl TryFrom<i32> for PacketKind {
    type Error = QrpcError;

    /// Converts `value` via `ProtoPacketKind`.
    ///
    /// # Errors
    ///
    /// Returns `QrpcError::MessageDecode` when `value` is not a known
    /// `ProtoPacketKind` discriminant.
    fn try_from(value: i32) -> Result<Self, Self::Error> {
        let wire_kind = ProtoPacketKind::try_from(value).map_err(|_| {
            QrpcError::MessageDecode(format!("invalid packet kind value: {value}"))
        })?;
        let kind = match wire_kind {
            ProtoPacketKind::Register => Self::Register,
            ProtoPacketKind::Data => Self::Data,
            ProtoPacketKind::Disconnect => Self::Disconnect,
        };
        Ok(kind)
    }
}
/// Crate-facing form of a packet, converted to and from framed bytes by
/// `WirePacket::encode_frame` / `WirePacket::decode_frame`.
#[derive(Clone, Debug)]
pub(crate) struct WirePacket {
    /// What the packet means.
    pub(crate) kind: PacketKind,
    /// Id of the sending side.
    pub(crate) source_id: String,
    /// Id of the receiving side; empty for register packets.
    /// NOTE(review): may be `BROADCAST_TARGET` — confirm the routing semantics.
    pub(crate) target_id: String,
    /// Command id; 0 for register and disconnect packets.
    pub(crate) cmd_id: u32,
    /// Payload bytes; empty for register and disconnect packets.
    pub(crate) payload: Vec<u8>,
}
impl WirePacket {
    /// Size of the big-endian `u32` body-length field that follows `START_MAGIC`.
    const LEN_FIELD_SIZE: usize = 4;
    /// Offset of the protobuf body: start magic followed by the length field.
    const HEADER_LEN: usize = START_MAGIC.len() + Self::LEN_FIELD_SIZE;
    /// Non-body bytes in every frame (header plus `END_MAGIC`): 12 bytes.
    const FRAME_OVERHEAD: usize = Self::HEADER_LEN + END_MAGIC.len();

    /// Builds a `PacketKind::Register` packet announcing `source_id`.
    ///
    /// The target id is empty, `cmd_id` is 0 and the payload is empty.
    pub(crate) fn register(source_id: impl Into<String>) -> Self {
        Self {
            kind: PacketKind::Register,
            source_id: source_id.into(),
            target_id: String::new(),
            cmd_id: 0,
            payload: Vec::new(),
        }
    }

    /// Builds a `PacketKind::Disconnect` packet from `source_id` to `target_id`,
    /// with `cmd_id` 0 and an empty payload.
    pub(crate) fn disconnect(source_id: impl Into<String>, target_id: impl Into<String>) -> Self {
        Self {
            kind: PacketKind::Disconnect,
            source_id: source_id.into(),
            target_id: target_id.into(),
            cmd_id: 0,
            payload: Vec::new(),
        }
    }

    /// Builds a `PacketKind::Data` packet carrying `cmd_id` and `payload` from
    /// `source_id` to `target_id`.
    pub(crate) fn data(
        source_id: impl Into<String>,
        target_id: impl Into<String>,
        cmd_id: u32,
        payload: Vec<u8>,
    ) -> Self {
        Self {
            kind: PacketKind::Data,
            source_id: source_id.into(),
            target_id: target_id.into(),
            cmd_id,
            payload,
        }
    }

    /// Serialises the packet into a single frame:
    ///
    /// ```text
    /// START_MAGIC (4) | body length, u32 big-endian (4) | protobuf body | END_MAGIC (4)
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if the encoded protobuf body is longer than `u32::MAX` bytes, because
    /// its length cannot be represented in the 4-byte length field.
    pub(crate) fn encode_frame(&self) -> Vec<u8> {
        let proto = ProtoPacket {
            kind: ProtoPacketKind::from(self.kind) as i32,
            source_id: self.source_id.clone(),
            target_id: self.target_id.clone(),
            cmd_id: self.cmd_id,
            payload: self.payload.clone(),
        };
        let body_len = proto.encoded_len();
        // An `as u32` cast would silently truncate and write a length prefix that
        // disagrees with the body, corrupting the frame for every reader.
        let declared_len =
            u32::try_from(body_len).expect("encoded packet body exceeds u32::MAX bytes");
        let mut out = Vec::with_capacity(Self::FRAME_OVERHEAD + body_len);
        out.extend_from_slice(&START_MAGIC);
        out.extend_from_slice(&declared_len.to_be_bytes());
        // Encode straight into the frame buffer (prost appends to the Vec) rather
        // than into a separate body buffer that would then have to be copied.
        proto
            .encode(&mut out)
            .expect("encoding to Vec should not fail");
        out.extend_from_slice(&END_MAGIC);
        out
    }

    /// Parses exactly one frame in the layout produced by `encode_frame`.
    ///
    /// `bytes` must hold the whole frame and nothing else. The declared body
    /// length has to account for every byte between the header and `END_MAGIC`.
    ///
    /// # Errors
    ///
    /// - `QrpcError::FrameTooShort` if `bytes` is shorter than the 12-byte frame overhead.
    /// - `QrpcError::InvalidMagic` if the frame does not start with `START_MAGIC`
    ///   or does not end with `END_MAGIC`.
    /// - `QrpcError::MessageDecode` if the declared body length disagrees with the
    ///   bytes present, or if the packet kind is not a known value.
    /// - The protobuf decode error, converted into `QrpcError` by `?`, if the body
    ///   is not a valid `ProtoPacket`.
    pub(crate) fn decode_frame(bytes: &[u8]) -> QrpcResult<Self> {
        if bytes.len() < Self::FRAME_OVERHEAD {
            return Err(QrpcError::FrameTooShort);
        }
        if !bytes.starts_with(&START_MAGIC) || !bytes.ends_with(&END_MAGIC) {
            return Err(QrpcError::InvalidMagic);
        }
        let len_field: [u8; 4] = bytes[START_MAGIC.len()..Self::HEADER_LEN]
            .try_into()
            .map_err(|_| QrpcError::FrameTooShort)?;
        // Lossless widening on 32- and 64-bit targets.
        let body_len = u32::from_be_bytes(len_field) as usize;
        // Compare the declared length against the body bytes actually present
        // instead of computing `header + body_len + trailer`. That sum can overflow
        // `usize` on 32-bit targets when the length prefix is hostile. This
        // subtraction cannot underflow because of the length check above.
        let actual_len = bytes.len() - Self::FRAME_OVERHEAD;
        if body_len != actual_len {
            return Err(QrpcError::MessageDecode(format!(
                "body length mismatch, declared={body_len}, actual={actual_len}"
            )));
        }
        let body = &bytes[Self::HEADER_LEN..bytes.len() - END_MAGIC.len()];
        let proto = ProtoPacket::decode(body)?;
        let kind = PacketKind::try_from(proto.kind)?;
        Ok(Self {
            kind,
            source_id: proto.source_id,
            target_id: proto.target_id,
            cmd_id: proto.cmd_id,
            payload: proto.payload,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an already-encoded protobuf body in frame magic and a length prefix.
    /// This bypasses `encode_frame`, so tests can build frames it would never emit.
    fn frame_body(body: &[u8]) -> Vec<u8> {
        let len = u32::try_from(body.len()).expect("test body fits in u32");
        let mut frame = Vec::new();
        frame.extend_from_slice(&START_MAGIC);
        frame.extend_from_slice(&len.to_be_bytes());
        frame.extend_from_slice(body);
        frame.extend_from_slice(&END_MAGIC);
        frame
    }

    #[test]
    fn packet_roundtrip_works() {
        let packet = WirePacket::data("a", "b", 12, b"hello".to_vec());
        let encoded = packet.encode_frame();
        let decoded = WirePacket::decode_frame(&encoded).expect("must decode");
        assert_eq!(decoded.kind, PacketKind::Data);
        assert_eq!(decoded.source_id, "a");
        assert_eq!(decoded.target_id, "b");
        assert_eq!(decoded.cmd_id, 12);
        assert_eq!(decoded.payload, b"hello");
    }

    #[test]
    fn register_and_disconnect_roundtrip() {
        let encoded = WirePacket::register("node").encode_frame();
        let decoded = WirePacket::decode_frame(&encoded).expect("must decode");
        assert_eq!(decoded.kind, PacketKind::Register);
        assert_eq!(decoded.source_id, "node");
        assert!(decoded.target_id.is_empty());
        assert_eq!(decoded.cmd_id, 0);
        assert!(decoded.payload.is_empty());

        let encoded = WirePacket::disconnect("a", BROADCAST_TARGET).encode_frame();
        let decoded = WirePacket::decode_frame(&encoded).expect("must decode");
        assert_eq!(decoded.kind, PacketKind::Disconnect);
        assert_eq!(decoded.source_id, "a");
        assert_eq!(decoded.target_id, BROADCAST_TARGET);
    }

    #[test]
    fn empty_body_frame_roundtrips() {
        // Every field is at its protobuf default, so the body is empty and the
        // frame is just magic + zero length + magic.
        let encoded = WirePacket::register("").encode_frame();
        assert_eq!(encoded.len(), 12);
        let decoded = WirePacket::decode_frame(&encoded).expect("must decode");
        assert_eq!(decoded.kind, PacketKind::Register);
        assert!(decoded.source_id.is_empty());
    }

    #[test]
    fn invalid_magic_fails() {
        let mut encoded = WirePacket::data("a", "b", 12, vec![]).encode_frame();
        encoded[0] = b'X';
        assert!(matches!(
            WirePacket::decode_frame(&encoded),
            Err(QrpcError::InvalidMagic)
        ));
    }

    #[test]
    fn invalid_end_magic_fails() {
        let mut encoded = WirePacket::data("a", "b", 12, vec![]).encode_frame();
        let last = encoded.len() - 1;
        encoded[last] = b'X';
        assert!(matches!(
            WirePacket::decode_frame(&encoded),
            Err(QrpcError::InvalidMagic)
        ));
    }

    #[test]
    fn short_frame_fails() {
        assert!(matches!(
            WirePacket::decode_frame(b""),
            Err(QrpcError::FrameTooShort)
        ));
        let encoded = WirePacket::data("a", "b", 12, vec![1]).encode_frame();
        assert!(matches!(
            WirePacket::decode_frame(&encoded[..11]),
            Err(QrpcError::FrameTooShort)
        ));
    }

    #[test]
    fn invalid_length_fails() {
        let mut encoded = WirePacket::data("a", "b", 12, vec![1, 2]).encode_frame();
        encoded[7] = 250;
        assert!(matches!(
            WirePacket::decode_frame(&encoded),
            Err(QrpcError::MessageDecode(_))
        ));
    }

    #[test]
    fn unknown_kind_fails() {
        let proto = ProtoPacket {
            kind: 7,
            source_id: "a".to_string(),
            target_id: String::new(),
            cmd_id: 0,
            payload: Vec::new(),
        };
        let mut body = Vec::new();
        proto
            .encode(&mut body)
            .expect("encoding to Vec should not fail");
        assert!(matches!(
            WirePacket::decode_frame(&frame_body(&body)),
            Err(QrpcError::MessageDecode(_))
        ));
    }
}