Skip to main content

hoy_protocol/
codec.rs

//! Message stream codec: encodes and decodes length-prefixed JSON protocol frames.
2
3use serde::Serialize;
4use serde::de::DeserializeOwned;
5
6use crate::error::ProtocolError;
7
/// Size in bytes of the frame header, which is a single big-endian `u32`
/// holding the payload length.
const FRAME_HEADER_LEN: usize = std::mem::size_of::<u32>();
10
11/**
12 * Encode a serializable value into a length-prefixed protocol frame.
13 *
14 * The wire format is:
15 * - 4-byte big-endian payload length (`u32`)
16 * - JSON payload bytes
17 *
18 * # Returns
19 * `Ok(Vec<u8>)` on succesfull frame encoding.
20 *
21 * # Errors
22 * Returns `ProtocolError` if:
23 * - serialization fails,
24 * - the serialized payload is too large to fit into a `u32` length prefix,
25 * - the required output buffer capacity would overflow `usize`.
26 */
27pub fn encode_frame(value: &impl Serialize) -> Result<Vec<u8>, ProtocolError> {
28    let payload: Vec<u8> = serde_json::to_vec(value)?;
29
30    let payload_len_u32: u32 = u32::try_from(payload.len()).map_err(|e| {
31        let _ = e;
32        ProtocolError::FrameTooLarge {
33            size: payload.len(),
34        }
35    })?;
36
37    let frame_capacity: usize = FRAME_HEADER_LEN
38        .checked_add(payload.len())
39        .ok_or(ProtocolError::CapacityOverflow)?;
40
41    let mut frame: Vec<u8> = Vec::with_capacity(frame_capacity);
42    frame.extend_from_slice(&payload_len_u32.to_be_bytes());
43    frame.extend_from_slice(&payload);
44
45    Ok(frame)
46}
47
48/**
49 * Decode a length-prefixed protocol frame into a value.
50 *
51 * The input must contain a complete frame:
52 * - 4-byte big-endian payload length (`u32`)
53 * - exactly that many payload bytes
54 *
55 * Extra trailing bytes after the declared payload are ignored by this helper.
56 * A streaming decoder can later handle multi-frame buffers more precisely.
57 *
58 * # Returns
59 * `Ok(impl: DeserializeOwned)` decoded frame on success.
60 *
61 * # Errors
62 * Returns an error if:
63 * - the header is missing or malformed,
64 * - the payload is truncated,
65 * - the decoded frame length cannot be represented as `usize`,
66 * - or JSON deserialization fails.
67 */
68pub fn decode_frame<T>(frame: &[u8]) -> Result<T, ProtocolError>
69where
70    T: DeserializeOwned,
71{
72    match try_decode_frame(frame)? {
73        Some((value, _consumed)) => Ok(value),
74        None => Err(ProtocolError::TruncatedFrame),
75    }
76}
77
78/**
79 * Attempt to decode a single length-prefixed frame from the provided buffer.
80 *
81 * The wire format is:
82 * - 4-byte big-endian payload length (`u32`)
83 * - payload bytes encoded as JSON
84 *
85 * # Returns
86 * - `Ok(None)` if the buffer does not yet contain a full frame,
87 * - deserialized frame with with the number of consumed bytes if buffer contains a full frame.
88 *
89 * Trailing bytes after the decoded frame are not treated as an error.
90 * The caller is expected to keep them in the input buffer and pass them again when
91 * decoding subsequent frames.
92 *
93 * # Errors
94 * Returns `Err(ProtocolError)` an error if:
95 * - the decoded payload length cannot be represented as `usize`,
96 * - frame length arithmetic overflows,
97 * - payload contains invalid JSON for the requested type.
98 */
99pub fn try_decode_frame<T>(buffer: &[u8]) -> Result<Option<(T, usize)>, ProtocolError>
100where
101    T: DeserializeOwned,
102{
103    let header: &[u8] = match buffer.get(..FRAME_HEADER_LEN) {
104        Some(header) => header,
105        None => return Ok(None),
106    };
107
108    let header_array: [u8; FRAME_HEADER_LEN] = match <[u8; FRAME_HEADER_LEN]>::try_from(header) {
109        Ok(array) => array,
110        Err(header_error) => {
111            let _ = header_error;
112            return Ok(None);
113        }
114    };
115
116    let payload_len_u32: u32 = u32::from_be_bytes(header_array);
117
118    let payload_len: usize = match usize::try_from(payload_len_u32) {
119        Ok(len) => len,
120        Err(conversion_error) => {
121            let _ = conversion_error;
122            return Err(ProtocolError::FrameLengthOutOfRange {
123                length: payload_len_u32,
124            });
125        }
126    };
127
128    let frame_len: usize = match FRAME_HEADER_LEN.checked_add(payload_len) {
129        Some(len) => len,
130        None => return Err(ProtocolError::CapacityOverflow),
131    };
132
133    let payload: &[u8] = match buffer.get(FRAME_HEADER_LEN..frame_len) {
134        Some(pld) => pld,
135        None => return Ok(None),
136    };
137
138    let value: T = serde_json::from_slice(payload)?;
139
140    Ok(Some((value, frame_len)))
141}
142
#[cfg(test)]
// Some helpers (e.g. `encode_frame_err`) are not referenced by any test yet;
// allow them here rather than deleting reusable test utilities.
#[allow(dead_code, unused)]
mod tests {
    use hoy_test::assert_err;
    use serde::Serialize;
    use serde::de::DeserializeOwned;

    use crate::codec::{FRAME_HEADER_LEN, decode_frame, encode_frame, try_decode_frame};
    use crate::error::ProtocolError;
    use crate::packet::{ClientPacket, ServerPacket};

    /// Frame an arbitrary payload by hand, bypassing `encode_frame`, so that
    /// invalid JSON payloads can be wrapped for negative tests.
    fn build_frame(payload: &[u8]) -> Vec<u8> {
        let payload_len_u32 =
            u32::try_from(payload.len()).expect("test payload length capacity overflow");

        let frame_capacity: usize = FRAME_HEADER_LEN
            .checked_add(payload.len())
            .expect("test frame capacity overflow");

        let mut frame: Vec<u8> = Vec::with_capacity(frame_capacity);
        frame.extend_from_slice(&payload_len_u32.to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    fn encode_frame_ok(value: &impl Serialize) -> Vec<u8> {
        encode_frame(value).expect("Frame encoding failed unexpectedly.")
    }

    fn encode_frame_err(value: &impl Serialize, error: &str) -> ProtocolError {
        encode_frame(value).expect_err(&format!("Expected error: {error}."))
    }

    fn decode_frame_ok<T>(frame: &[u8]) -> T
    where
        T: DeserializeOwned,
    {
        decode_frame(frame).expect("Frame deserialization failed unexpectedly.")
    }

    fn decode_frame_err<T>(frame: &[u8], error: &str) -> ProtocolError
    where
        T: DeserializeOwned + std::fmt::Debug,
    {
        decode_frame::<T>(frame).expect_err(&format!("Expected error: {error}."))
    }

    fn try_decode_frame_ok<T>(buffer: &[u8]) -> (T, usize)
    where
        T: DeserializeOwned,
    {
        try_decode_frame::<T>(buffer)
            .expect("Unexpected failure while trying to deserialize frame from buffer.")
            .expect("Frame deserialization should not return None.")
    }

    fn try_decode_frame_none<T>(buffer: &[u8]) -> Option<(T, usize)>
    where
        T: DeserializeOwned + std::fmt::Debug + PartialEq,
    {
        let result = try_decode_frame::<T>(buffer)
            .expect("Unexpected failure while trying to deserialize frame from buffer.");
        assert_eq!(result, None);
        result
    }

    fn try_decode_frame_err<T>(buffer: &[u8], error: &str) -> ProtocolError
    where
        T: DeserializeOwned + std::fmt::Debug,
    {
        try_decode_frame::<T>(buffer).expect_err(&format!("Expected error: {error}."))
    }

    #[test]
    fn encode_and_decode_client_packet_roundtrip() {
        let packet = ClientPacket::Hello {
            username: String::from("bruce_lee"),
        };
        let frame = encode_frame_ok(&packet);
        let decoded: ClientPacket = decode_frame_ok(&frame);

        assert_eq!(decoded, packet);
    }

    #[test]
    fn encode_and_decode_server_packet_roundtrip() {
        let packet: ServerPacket = ServerPacket::ChatMessage {
            from: String::from("bruce_lee"),
            room: String::from("#general"),
            text: String::from("Kung foo..."),
        };
        let frame = encode_frame_ok(&packet);
        let decoded: ServerPacket = decode_frame_ok(&frame);

        assert_eq!(decoded, packet);
    }

    #[test]
    fn encode_frame_prefixes_big_endian_payload_length() {
        let frame = encode_frame_ok(&ClientPacket::Ping);
        let (header, payload) = frame.split_at(FRAME_HEADER_LEN);
        let header: [u8; FRAME_HEADER_LEN] =
            <[u8; FRAME_HEADER_LEN]>::try_from(header).expect("split_at yields a full header");
        let declared_len = usize::try_from(u32::from_be_bytes(header))
            .expect("declared length fits in usize");

        assert_eq!(declared_len, payload.len());
    }

    #[test]
    fn decode_frame_rejects_truncated_header() {
        let frame: Vec<u8> = vec![0, 0, 0];
        let error = decode_frame_err::<ClientPacket>(&frame, "Truncated header");

        assert_err!(error, ProtocolError::TruncatedFrame);
    }

    #[test]
    fn decode_frame_rejects_truncated_payload() {
        let declared_payload_len: u32 = 10;
        let mut frame: Vec<u8> = Vec::new();
        frame.extend_from_slice(&declared_payload_len.to_be_bytes());
        frame.extend_from_slice(b"abc");
        let error = decode_frame_err::<ClientPacket>(&frame, "Truncated payload");

        assert_err!(error, ProtocolError::TruncatedFrame);
    }

    #[test]
    fn decode_frame_rejects_empty_payload() {
        let frame: Vec<u8> = build_frame(b"");
        let error = decode_frame_err::<ClientPacket>(&frame, "Serde error");

        assert_err!(error, ProtocolError::Serde(_));
    }

    #[test]
    fn decode_frame_rejects_invalid_json_payload() {
        let frame: Vec<u8> = build_frame(b"this is not valid json");
        let error = decode_frame_err::<ClientPacket>(&frame, "Serde error");

        assert_err!(error, ProtocolError::Serde(_));
    }

    #[test]
    fn decode_frame_rejects_json_of_wrong_packet_shape() {
        let payload: &[u8] = br#"{"NotARealPacket":{"foo":"bar"}}"#;
        let frame: Vec<u8> = build_frame(payload);
        let error = decode_frame_err::<ClientPacket>(&frame, "Serde error");

        assert_err!(error, ProtocolError::Serde(_));
    }

    #[test]
    fn decode_frame_ignores_trailing_bytes_after_payload() {
        let packet: ClientPacket = ClientPacket::Ping;
        let mut frame = encode_frame_ok(&packet);
        frame.extend_from_slice(b"trailing bytes that belong to a future frame");
        let decoded: ClientPacket = decode_frame_ok(&frame);

        assert_eq!(decoded, packet);
    }

    #[test]
    fn decode_frame_accepts_empty_string_fields() {
        let packet: ClientPacket = ClientPacket::Hello {
            username: String::new(),
        };
        let frame = encode_frame_ok(&packet);
        let decoded: ClientPacket = decode_frame_ok(&frame);

        assert_eq!(decoded, packet);
    }

    #[test]
    fn decode_frame_handles_utf8_content() {
        let packet: ServerPacket = ServerPacket::SystemMessage {
            text: String::from("Ahoj ^^ Привет こんにちは"),
        };
        let frame = encode_frame_ok(&packet);
        let decoded: ServerPacket = decode_frame_ok(&frame);

        assert_eq!(decoded, packet);
    }

    #[test]
    fn try_decode_frame_returns_none_for_empty_buffer() {
        let result = try_decode_frame_none::<ClientPacket>(&[]);

        assert_eq!(result, None);
    }

    #[test]
    fn try_decode_frame_returns_none_for_incomplete_header() {
        let buffer: Vec<u8> = vec![0, 0, 0];

        let result = try_decode_frame_none::<ClientPacket>(&buffer);

        assert_eq!(result, None);
    }

    #[test]
    fn try_decode_frame_returns_none_for_incomplete_payload() {
        let declared_payload_len: u32 = 10;
        let mut buffer: Vec<u8> = Vec::new();
        buffer.extend_from_slice(&declared_payload_len.to_be_bytes());
        buffer.extend_from_slice(b"abc");

        let result = try_decode_frame_none::<ClientPacket>(&buffer);

        assert_eq!(result, None);
    }

    #[test]
    fn try_decode_frame_decodes_complete_frame() {
        let packet = ClientPacket::Ping;
        let frame: Vec<u8> = encode_frame_ok(&packet);

        let (decoded, consumed) = try_decode_frame_ok::<ClientPacket>(&frame);
        assert_eq!(decoded, packet);
        assert_eq!(consumed, frame.len());
    }

    #[test]
    fn try_decode_frame_reports_consumed_len_with_trailing_bytes() {
        let packet = ClientPacket::Ping;
        let mut buffer: Vec<u8> = encode_frame_ok(&packet);
        let frame_len: usize = buffer.len();
        buffer.extend_from_slice(b"trailing bytes");

        let (decoded, consumed) = try_decode_frame_ok::<ClientPacket>(&buffer);
        assert_eq!(decoded, packet);
        assert_eq!(consumed, frame_len);
    }

    #[test]
    fn try_decode_frame_rejects_invalid_complete_payload() {
        let buffer: Vec<u8> = build_frame(b"this is not a valid json");

        let err = try_decode_frame_err::<ClientPacket>(&buffer, "Serde error");

        assert_err!(err, ProtocolError::Serde(_));
    }

    #[test]
    fn try_decode_frame_only_decodes_1_frame() {
        let packet1 = ClientPacket::Ping;
        let packet2 = ClientPacket::Hello {
            username: String::from("bruce_lee"),
        };

        let frame1 = encode_frame_ok(&packet1);
        let frame2 = encode_frame_ok(&packet2);

        let len1 = frame1.len();

        let mut buffer: Vec<u8> = Vec::new();
        buffer.extend_from_slice(&frame1);
        buffer.extend_from_slice(&frame2);

        let (decoded, consumed) = try_decode_frame_ok::<ClientPacket>(&buffer);

        assert_eq!(decoded, packet1);
        assert_eq!(consumed, len1);
    }

    #[test]
    fn try_decode_frame_decodes_consecutive_frames() {
        let packet1 = ClientPacket::Ping;
        let packet2 = ClientPacket::Hello {
            username: String::from("bruce_lee"),
        };

        let mut buffer: Vec<u8> = encode_frame_ok(&packet1);
        buffer.extend_from_slice(&encode_frame_ok(&packet2));

        // Feed the remainder back in, as a streaming caller would.
        let (first, consumed1) = try_decode_frame_ok::<ClientPacket>(&buffer);
        let (second, consumed2) = try_decode_frame_ok::<ClientPacket>(&buffer[consumed1..]);

        assert_eq!(first, packet1);
        assert_eq!(second, packet2);
        assert_eq!(consumed1 + consumed2, buffer.len());
    }
}