Skip to main content

zccache_protocol/
lib.rs

1//! IPC protocol types and serialization for zccache.
2//!
3//! Defines the message types exchanged between CLI/wrapper and daemon,
4//! and provides serialization/deserialization using bincode.
5
6pub mod messages;
7
8pub use messages::*;
9
/// Version of the IPC wire format.
///
/// Increment whenever the encoded representation changes, i.e. when enum
/// variants are added, removed or reordered, or when struct fields change.
/// Releases that leave the wire format untouched keep the current value.
pub const PROTOCOL_VERSION: u32 = 5;
14
15use bytes::{Buf, BufMut, BytesMut};
16
17/// Serialize a message to a length-prefixed byte buffer with protocol version.
18///
19/// Format: `[4-byte LE length][4-byte LE protocol version][bincode payload]`
20///
21/// The length field covers the protocol version + payload bytes.
22///
23/// # Errors
24///
25/// Returns an error if serialization fails.
26pub fn encode_message<T: serde::Serialize>(msg: &T) -> Result<BytesMut, ProtocolError> {
27    let payload =
28        bincode::serialize(msg).map_err(|e| ProtocolError::Serialization(e.to_string()))?;
29    let frame_len: u32 = (4 + payload.len())
30        .try_into()
31        .map_err(|_| ProtocolError::MessageTooLarge(payload.len()))?;
32
33    let mut buf = BytesMut::with_capacity(4 + 4 + payload.len());
34    buf.put_u32_le(frame_len);
35    buf.put_u32_le(PROTOCOL_VERSION);
36    buf.extend_from_slice(&payload);
37    Ok(buf)
38}
39
40/// Try to decode a message from a byte buffer.
41///
42/// Returns `None` if the buffer does not contain a complete message.
43/// Advances the buffer past the consumed message on success.
44///
45/// # Errors
46///
47/// Returns `VersionMismatch` if the sender's protocol version differs.
48/// Returns a deserialization error if the payload is malformed.
49pub fn decode_message<T: serde::de::DeserializeOwned>(
50    buf: &mut BytesMut,
51) -> Result<Option<T>, ProtocolError> {
52    if buf.len() < 4 {
53        return Ok(None);
54    }
55
56    let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
57
58    if len > MAX_MESSAGE_SIZE {
59        return Err(ProtocolError::MessageTooLarge(len));
60    }
61
62    if buf.len() < 4 + len {
63        return Ok(None);
64    }
65
66    if len < 4 {
67        return Err(ProtocolError::Deserialization(
68            "frame too small for protocol version".into(),
69        ));
70    }
71
72    buf.advance(4);
73    let frame = buf.split_to(len);
74
75    let remote_ver = u32::from_le_bytes([frame[0], frame[1], frame[2], frame[3]]);
76    if remote_ver != PROTOCOL_VERSION {
77        return Err(ProtocolError::VersionMismatch {
78            expected: PROTOCOL_VERSION,
79            received: remote_ver,
80        });
81    }
82
83    let msg = bincode::deserialize(&frame[4..])
84        .map_err(|e| ProtocolError::Deserialization(e.to_string()))?;
85    Ok(Some(msg))
86}
87
/// Largest value accepted in a frame's length field: 16 MiB (16 * 1024 * 1024 bytes).
const MAX_MESSAGE_SIZE: usize = 16 << 20;
90
91/// Protocol-level errors.
92#[derive(Debug, thiserror::Error)]
93pub enum ProtocolError {
94    #[error("serialization error: {0}")]
95    Serialization(String),
96
97    #[error("deserialization error: {0}")]
98    Deserialization(String),
99
100    #[error("message too large: {0} bytes")]
101    MessageTooLarge(usize),
102
103    #[error(
104        "protocol version mismatch: expected v{expected}, received v{received}. \
105         Run `zccache stop` first."
106    )]
107    VersionMismatch { expected: u32, received: u32 },
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// A message survives an encode/decode round trip and the buffer is drained.
    #[test]
    fn encode_decode_roundtrip() {
        let msg = messages::Request::Ping;
        let encoded = encode_message(&msg).unwrap();
        let mut buf = BytesMut::from(&encoded[..]);
        let decoded: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert_eq!(decoded, Some(messages::Request::Ping));
        assert!(buf.is_empty());
    }

    /// The protocol version sits right after the length prefix.
    #[test]
    fn frame_includes_protocol_version() {
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        // Bytes 4..8 should be PROTOCOL_VERSION in LE
        let ver = u32::from_le_bytes([encoded[4], encoded[5], encoded[6], encoded[7]]);
        assert_eq!(ver, PROTOCOL_VERSION);
    }

    /// A frame carrying a different protocol version is rejected.
    #[test]
    fn version_mismatch_returns_error() {
        let mut encoded = encode_message(&messages::Request::Ping).unwrap();
        // Overwrite protocol version with a different value
        let bad_ver: u32 = PROTOCOL_VERSION + 1;
        encoded[4..8].copy_from_slice(&bad_ver.to_le_bytes());

        let mut buf = BytesMut::from(&encoded[..]);
        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(result, Err(ProtocolError::VersionMismatch { .. })));
    }

    /// Frames in the old `[len][payload]` layout must never decode successfully.
    #[test]
    fn old_frame_without_protocol_version_fails() {
        // Simulate an old-format frame: [len][payload] with no protocol version.
        // Build a raw old-style frame (4-byte len + bincode payload, no proto ver).
        let payload = bincode::serialize(&messages::Request::Ping).unwrap();
        let len = payload.len() as u32;
        let mut buf = BytesMut::with_capacity(4 + payload.len());
        buf.put_u32_le(len);
        buf.extend_from_slice(&payload);

        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        // Either VersionMismatch (garbage proto ver) or Deserialization error —
        // either way, it must not succeed.
        assert!(
            result.is_err(),
            "old-format frame must not decode successfully"
        );
    }

    /// A truncated frame is reported as "not yet complete", not as an error.
    #[test]
    fn incomplete_frame_returns_none() {
        let encoded = encode_message(&messages::Request::Ping).unwrap();
        // Provide only part of the frame
        let mut buf = BytesMut::from(&encoded[..encoded.len() - 1]);
        let result: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert!(result.is_none());
    }

    /// A length header above the limit is rejected from the header alone.
    #[test]
    fn oversized_length_header_is_rejected() {
        let declared = MAX_MESSAGE_SIZE + 1;
        let mut buf = BytesMut::new();
        buf.put_u32_le(u32::try_from(declared).unwrap());

        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(
            result,
            Err(ProtocolError::MessageTooLarge(n)) if n == declared
        ));
    }

    /// A frame too short to hold the protocol version is a decode error.
    #[test]
    fn zero_length_frame_is_rejected() {
        let mut buf = BytesMut::new();
        buf.put_u32_le(0);

        let result: Result<Option<messages::Request>, _> = decode_message(&mut buf);
        assert!(matches!(result, Err(ProtocolError::Deserialization(_))));
    }

    /// Several frames in one buffer are decoded one at a time, in order.
    #[test]
    fn consecutive_frames_decode_in_order() {
        let frame = encode_message(&messages::Request::Ping).unwrap();
        let mut buf = BytesMut::with_capacity(frame.len() * 2);
        buf.extend_from_slice(&frame);
        buf.extend_from_slice(&frame);

        for _ in 0..2 {
            let decoded: Option<messages::Request> = decode_message(&mut buf).unwrap();
            assert_eq!(decoded, Some(messages::Request::Ping));
        }
        assert!(buf.is_empty());

        // Nothing left: the decoder asks for more data instead of failing.
        let trailing: Option<messages::Request> = decode_message(&mut buf).unwrap();
        assert!(trailing.is_none());
    }
}