pub mod messages;
pub mod wire_frame;
pub mod wire_prost;
pub use messages::*;
/// Protocol version word stamped into legacy frames whose payload is
/// bincode-encoded (`WireFormat::BincodeV15`).
pub const BINCODE_PROTOCOL_VERSION: u32 = 15;
/// Protocol version word identifying the prost wire format
/// (`WireFormat::ProstV16`).
pub const PROST_PROTOCOL_VERSION: u32 = 16;
/// Default protocol version: an alias of [`BINCODE_PROTOCOL_VERSION`], the
/// format used by [`encode_message`] and [`decode_message`].
pub const PROTOCOL_VERSION: u32 = BINCODE_PROTOCOL_VERSION;
use bytes::{Buf, BufMut, BytesMut};
use prost::Message as ProstMessage;
/// A message decoded by [`decode_wire_message`], tagged with the wire format
/// it arrived in.
///
/// `Bincode` is the caller-chosen payload type for the legacy bincode format
/// and `Prost` the payload type for the prost-based formats.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodedWireMessage<Bincode, Prost> {
    /// Legacy frame whose version word is [`BINCODE_PROTOCOL_VERSION`] (v15).
    BincodeV15(Bincode),
    /// Frame whose version word selects the prost format (v16).
    ProstV16(Prost),
    /// Message decoded by `wire_frame::decode_frame_v1_message`.
    FrameV1 {
        /// The decoded prost payload.
        message: Prost,
        /// Request id carried by the decoded frame.
        request_id: u64,
    },
}
impl<Bincode, Prost> DecodedWireMessage<Bincode, Prost> {
#[must_use]
pub const fn wire_format(&self) -> wire_prost::WireFormat {
match self {
Self::BincodeV15(_) => wire_prost::WireFormat::BincodeV15,
Self::ProstV16(_) => wire_prost::WireFormat::ProstV16,
Self::FrameV1 { .. } => wire_prost::WireFormat::FrameV1,
}
}
}
/// Encodes `msg` in the default wire format ([`PROTOCOL_VERSION`], bincode).
///
/// This is an alias for [`encode_bincode_message`]; see that function for the
/// frame layout and error conditions.
pub fn encode_message<T: serde::Serialize>(msg: &T) -> Result<BytesMut, ProtocolError> {
    encode_bincode_message::<T>(msg)
}
pub fn encode_bincode_message<T: serde::Serialize>(msg: &T) -> Result<BytesMut, ProtocolError> {
let payload =
bincode::serialize(msg).map_err(|e| ProtocolError::Serialization(e.to_string()))?;
let frame_len: u32 = (4 + payload.len())
.try_into()
.map_err(|_| ProtocolError::MessageTooLarge(payload.len()))?;
let mut buf = BytesMut::with_capacity(4 + 4 + payload.len());
buf.put_u32_le(frame_len);
buf.put_u32_le(BINCODE_PROTOCOL_VERSION);
buf.extend_from_slice(&payload);
Ok(buf)
}
/// Decodes one message in the default wire format ([`PROTOCOL_VERSION`],
/// bincode) from the front of `buf`.
///
/// This is an alias for [`decode_bincode_message`]; see that function for
/// buffering semantics and error conditions.
pub fn decode_message<T: serde::de::DeserializeOwned>(
    buf: &mut BytesMut,
) -> Result<Option<T>, ProtocolError> {
    decode_bincode_message::<T>(buf)
}
/// Decodes one legacy bincode frame from the front of `buf`.
///
/// Returns `Ok(None)` until a complete frame is buffered. Once the frame is
/// complete, it is removed from `buf` before its version word is checked.
/// A frame with the wrong version is therefore consumed even though an error
/// is returned.
///
/// # Errors
/// - Header errors from `take_complete_frame` (oversized or undersized
///   length prefix).
/// - [`ProtocolError::VersionMismatch`] if the frame is not stamped with
///   [`BINCODE_PROTOCOL_VERSION`].
/// - [`ProtocolError::Deserialization`] if the payload is not a valid `T`.
pub fn decode_bincode_message<T: serde::de::DeserializeOwned>(
    buf: &mut BytesMut,
) -> Result<Option<T>, ProtocolError> {
    match take_complete_frame(buf)? {
        None => Ok(None),
        Some((received, _)) if received != BINCODE_PROTOCOL_VERSION => {
            Err(ProtocolError::VersionMismatch {
                expected: BINCODE_PROTOCOL_VERSION,
                received,
            })
        }
        Some((_, body)) => bincode::deserialize(&body[..])
            .map(Some)
            .map_err(|err| ProtocolError::Deserialization(err.to_string())),
    }
}
/// Decodes the next message from `buf`, auto-detecting its wire format.
///
/// Detection works in two steps:
/// 1. `wire_frame::buffer_starts_running_process_frame` is consulted first.
///    When it reports a FrameV1 frame, the frame is decoded with
///    `wire_frame::decode_frame_v1_message`.
/// 2. Otherwise the legacy length-prefixed header is peeked, and its
///    protocol-version word selects the bincode (v15) or prost (v16) decoder.
///
/// Returns `Ok(None)` when `buf` does not yet hold enough bytes to decide,
/// or to decode a complete message.
///
/// # Errors
/// Propagates errors from [`peek_frame_protocol_version`] and from the
/// selected decoder. Returns [`ProtocolError::VersionMismatch`], reporting
/// [`PROST_PROTOCOL_VERSION`] as the expected version, when the version word
/// maps to no legacy format.
pub fn decode_wire_message<Bincode, Prost>(
    buf: &mut BytesMut,
) -> Result<Option<DecodedWireMessage<Bincode, Prost>>, ProtocolError>
where
    Bincode: serde::de::DeserializeOwned,
    Prost: ProstMessage + Default,
{
    let Some(is_frame_v1) = wire_frame::buffer_starts_running_process_frame(buf) else {
        return Ok(None);
    };
    if is_frame_v1 {
        let frame = wire_frame::decode_frame_v1_message(buf)?;
        return Ok(frame.map(|f| DecodedWireMessage::FrameV1 {
            message: f.message,
            request_id: f.request_id,
        }));
    }

    let Some(version) = peek_frame_protocol_version(buf)? else {
        return Ok(None);
    };
    let decoded = match wire_prost::wire_format_for_protocol_version(version) {
        Some(wire_prost::WireFormat::BincodeV15) => {
            decode_bincode_message(buf)?.map(DecodedWireMessage::BincodeV15)
        }
        Some(wire_prost::WireFormat::ProstV16) => {
            wire_prost::decode_prost_message(buf)?.map(DecodedWireMessage::ProstV16)
        }
        // A legacy version word that maps to FrameV1 (or to nothing) is
        // rejected; FrameV1 messages are only decoded via the check above.
        Some(wire_prost::WireFormat::FrameV1) | None => {
            return Err(ProtocolError::VersionMismatch {
                expected: PROST_PROTOCOL_VERSION,
                received: version,
            });
        }
    };
    Ok(decoded)
}
/// Returns the protocol-version word of the frame at the front of `buf`,
/// without consuming anything.
///
/// Legacy frames are laid out as `[u32 LE len][u32 LE version][payload]`,
/// where `len` counts the version word plus the payload.
///
/// Returns `Ok(None)` until the whole frame, not just its header, is
/// buffered.
///
/// # Errors
/// - [`ProtocolError::MessageTooLarge`] if the length prefix exceeds
///   `MAX_MESSAGE_SIZE`.
/// - [`ProtocolError::Deserialization`] if the length prefix is too small to
///   hold the protocol version.
pub fn peek_frame_protocol_version(buf: &BytesMut) -> Result<Option<u32>, ProtocolError> {
    // A complete frame has a body of at least 4 bytes, so bytes 4..8 exist.
    Ok(complete_frame_body_len(buf)?
        .map(|_| u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]])))
}

/// Removes the complete frame at the front of `buf` and splits off its
/// protocol-version word.
///
/// Returns `Ok(None)`, leaving `buf` untouched, until the whole frame is
/// buffered. Otherwise it returns the version and the remaining payload bytes.
///
/// # Errors
/// Same header errors as [`peek_frame_protocol_version`].
fn take_complete_frame(buf: &mut BytesMut) -> Result<Option<(u32, BytesMut)>, ProtocolError> {
    let Some(body_len) = complete_frame_body_len(buf)? else {
        return Ok(None);
    };
    // Drop the 4-byte length prefix, then detach exactly one frame body.
    buf.advance(4);
    let mut body = buf.split_to(body_len);
    let version = body.get_u32_le();
    Ok(Some((version, body)))
}

/// Validates the 4-byte little-endian length prefix at the front of `buf`.
///
/// Shared by `peek_frame_protocol_version` and `take_complete_frame` so both
/// apply identical header rules.
///
/// Returns:
/// - `Ok(Some(len))` once `buf` holds the whole frame. `len` is the body
///   length, which includes the 4-byte protocol version.
/// - `Ok(None)` while the prefix or the body is still incomplete.
///
/// # Errors
/// - [`ProtocolError::MessageTooLarge`] if the prefix exceeds
///   `MAX_MESSAGE_SIZE`.
/// - [`ProtocolError::Deserialization`] if the prefix is too small to hold
///   the protocol version.
fn complete_frame_body_len(buf: &BytesMut) -> Result<Option<usize>, ProtocolError> {
    if buf.len() < 4 {
        return Ok(None);
    }
    let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
    if len > MAX_MESSAGE_SIZE {
        return Err(ProtocolError::MessageTooLarge(len));
    }
    if len < 4 {
        return Err(ProtocolError::Deserialization(
            "frame too small for protocol version".into(),
        ));
    }
    if buf.len() < 4 + len {
        return Ok(None);
    }
    Ok(Some(len))
}
/// Largest value (16 MiB) the 4-byte length prefix of a legacy frame may
/// carry. Larger prefixes are rejected with `ProtocolError::MessageTooLarge`
/// while reading the frame header.
const MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024;
/// Errors produced while encoding or decoding protocol frames.
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
    /// Serializing an outgoing message failed; holds the serializer's error
    /// text.
    #[error("serialization error: {0}")]
    Serialization(String),
    /// An incoming frame was malformed or its payload could not be
    /// deserialized; holds a description of the failure.
    #[error("deserialization error: {0}")]
    Deserialization(String),
    /// A message or frame length exceeded the allowed size; holds the
    /// offending byte count.
    #[error("message too large: {0} bytes")]
    MessageTooLarge(usize),
    /// The frame's protocol version does not match the one expected here.
    ///
    /// NOTE(review): the message tells the user to run `zccache stop`,
    /// presumably to shut down a process built against an older protocol —
    /// confirm.
    #[error(
        "protocol version mismatch: expected v{expected}, received v{received}. \
         Run `zccache stop` first."
    )]
    VersionMismatch {
        /// Version this side expected.
        expected: u32,
        /// Version found in the received frame.
        received: u32,
    },
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_decode_roundtrip() {
        let msg = messages::Request::Ping;
        let encoded = encode_message(&msg).unwrap();
        let mut buf = BytesMut::from(&encoded[..]);
        let decoded: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert_eq!(decoded, Some(messages::Request::Ping));
        assert!(buf.is_empty());
    }

    #[test]
    fn frame_includes_protocol_version() {
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        let ver = u32::from_le_bytes([encoded[4], encoded[5], encoded[6], encoded[7]]);
        assert_eq!(ver, PROTOCOL_VERSION);
    }

    #[test]
    fn version_mismatch_returns_error() {
        let mut encoded = encode_message(&messages::Request::Ping).unwrap();
        let bad_ver: u32 = PROTOCOL_VERSION + 1;
        encoded[4..8].copy_from_slice(&bad_ver.to_le_bytes());
        let mut buf = BytesMut::from(&encoded[..]);
        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }

    #[test]
    fn old_frame_without_protocol_version_fails() {
        let payload = bincode::serialize(&messages::Request::Ping).unwrap();
        let len = payload.len() as u32;
        let mut buf = BytesMut::with_capacity(4 + payload.len());
        buf.put_u32_le(len);
        buf.extend_from_slice(&payload);
        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(
            result.is_err(),
            "old-format frame must not decode successfully"
        );
    }

    #[test]
    fn incomplete_frame_returns_none() {
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        let mut buf = BytesMut::from(&encoded[..encoded.len() - 1]);
        let result: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert!(result.is_none());
    }

    /// A length prefix above `MAX_MESSAGE_SIZE` is rejected before the body
    /// is ever buffered.
    #[test]
    fn oversized_length_prefix_is_rejected() {
        let oversized = MAX_MESSAGE_SIZE + 1;
        let mut buf = BytesMut::new();
        buf.put_u32_le(u32::try_from(oversized).unwrap());
        buf.put_u32_le(BINCODE_PROTOCOL_VERSION);
        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(
            result,
            Err(ProtocolError::MessageTooLarge(len)) if len == oversized
        ));
    }

    /// A length prefix too small to hold the 4-byte version word is an error.
    #[test]
    fn length_prefix_smaller_than_version_is_rejected() {
        let mut buf = BytesMut::new();
        buf.put_u32_le(3);
        buf.extend_from_slice(&[0, 0, 0]);
        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(result, Err(ProtocolError::Deserialization(_))));
    }

    #[test]
    fn peek_reports_version_of_complete_frame() {
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        let buf = BytesMut::from(&encoded[..]);
        assert_eq!(
            peek_frame_protocol_version(&buf).unwrap(),
            Some(BINCODE_PROTOCOL_VERSION)
        );
        assert_eq!(buf.len(), encoded.len());
    }

    #[test]
    fn peek_incomplete_frame_returns_none() {
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        let buf = BytesMut::from(&encoded[..encoded.len() - 1]);
        assert_eq!(peek_frame_protocol_version(&buf).unwrap(), None);
    }

    /// Each decode call consumes exactly one frame from a shared buffer.
    #[test]
    fn back_to_back_frames_decode_in_order() {
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&encode_message(&messages::Request::Ping).unwrap());
        buf.extend_from_slice(&encode_message(&messages::Request::Ping).unwrap());
        let first: Option<messages::Request> = decode_message(&mut buf).unwrap();
        let second: Option<messages::Request> = decode_message(&mut buf).unwrap();
        let third: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert_eq!(first, Some(messages::Request::Ping));
        assert_eq!(second, Some(messages::Request::Ping));
        assert_eq!(third, None);
        assert!(buf.is_empty());
    }
}