1use serde::Serialize;
4use serde::de::DeserializeOwned;
5
6use crate::error::ProtocolError;
7
/// Length, in bytes, of the frame header: a big-endian `u32` that holds the
/// payload length and precedes every JSON payload on the wire.
const FRAME_HEADER_LEN: usize = std::mem::size_of::<u32>();
10
11pub fn encode_frame(value: &impl Serialize) -> Result<Vec<u8>, ProtocolError> {
28 let payload: Vec<u8> = serde_json::to_vec(value)?;
29
30 let payload_len_u32: u32 = u32::try_from(payload.len()).map_err(|e| {
31 let _ = e;
32 ProtocolError::FrameTooLarge {
33 size: payload.len(),
34 }
35 })?;
36
37 let frame_capacity: usize = FRAME_HEADER_LEN
38 .checked_add(payload.len())
39 .ok_or(ProtocolError::CapacityOverflow)?;
40
41 let mut frame: Vec<u8> = Vec::with_capacity(frame_capacity);
42 frame.extend_from_slice(&payload_len_u32.to_be_bytes());
43 frame.extend_from_slice(&payload);
44
45 Ok(frame)
46}
47
48pub fn decode_frame<T>(frame: &[u8]) -> Result<T, ProtocolError>
69where
70 T: DeserializeOwned,
71{
72 match try_decode_frame(frame)? {
73 Some((value, _consumed)) => Ok(value),
74 None => Err(ProtocolError::TruncatedFrame),
75 }
76}
77
78pub fn try_decode_frame<T>(buffer: &[u8]) -> Result<Option<(T, usize)>, ProtocolError>
100where
101 T: DeserializeOwned,
102{
103 let header: &[u8] = match buffer.get(..FRAME_HEADER_LEN) {
104 Some(header) => header,
105 None => return Ok(None),
106 };
107
108 let header_array: [u8; FRAME_HEADER_LEN] = match <[u8; FRAME_HEADER_LEN]>::try_from(header) {
109 Ok(array) => array,
110 Err(header_error) => {
111 let _ = header_error;
112 return Ok(None);
113 }
114 };
115
116 let payload_len_u32: u32 = u32::from_be_bytes(header_array);
117
118 let payload_len: usize = match usize::try_from(payload_len_u32) {
119 Ok(len) => len,
120 Err(conversion_error) => {
121 let _ = conversion_error;
122 return Err(ProtocolError::FrameLengthOutOfRange {
123 length: payload_len_u32,
124 });
125 }
126 };
127
128 let frame_len: usize = match FRAME_HEADER_LEN.checked_add(payload_len) {
129 Some(len) => len,
130 None => return Err(ProtocolError::CapacityOverflow),
131 };
132
133 let payload: &[u8] = match buffer.get(FRAME_HEADER_LEN..frame_len) {
134 Some(pld) => pld,
135 None => return Ok(None),
136 };
137
138 let value: T = serde_json::from_slice(payload)?;
139
140 Ok(Some((value, frame_len)))
141}
142
#[cfg(test)]
// Not every helper (e.g. `encode_frame_err`) or helper return value is used by
// the current tests; keep them available without unused-item warnings.
#[allow(dead_code, unused)]
mod tests {
    use hoy_test::assert_err;
    use serde::Serialize;
    use serde::de::DeserializeOwned;

    use crate::codec::{decode_frame, encode_frame, try_decode_frame};
    use crate::error::ProtocolError;
    use crate::packet::{ClientPacket, ServerPacket};

    /// Builds a frame by hand, independently of `encode_frame`, so tests can
    /// feed arbitrary (possibly invalid) payloads to the decoder. The header
    /// width is hard-coded on purpose to act as an independent oracle.
    fn build_frame(payload: &[u8]) -> Vec<u8> {
        let payload_len_u32 =
            u32::try_from(payload.len()).expect("test payload length capacity overflow");

        let frame_capacity: usize = 4_usize
            .checked_add(payload.len())
            .expect("test frame capacity overflow");

        let mut frame: Vec<u8> = Vec::with_capacity(frame_capacity);
        frame.extend_from_slice(&payload_len_u32.to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    fn encode_frame_ok(value: &impl Serialize) -> Vec<u8> {
        encode_frame(value).expect("Frame encoding failed unexpectedly.")
    }

    fn encode_frame_err(value: &impl Serialize, error: &str) -> ProtocolError {
        encode_frame(value).expect_err(&format!("Expected error: {error}."))
    }

    fn decode_frame_ok<T>(frame: &[u8]) -> T
    where
        T: DeserializeOwned,
    {
        decode_frame(frame).expect("Frame deserialization failed unexpectedly.")
    }

    fn decode_frame_err<T>(frame: &[u8], error: &str) -> ProtocolError
    where
        T: DeserializeOwned + std::fmt::Debug,
    {
        decode_frame::<T>(frame).expect_err(&format!("Expected error: {error}."))
    }

    fn try_decode_frame_ok<T>(buffer: &[u8]) -> (T, usize)
    where
        T: DeserializeOwned,
    {
        try_decode_frame::<T>(buffer)
            .expect("Unexpected failure while trying to deserialize frame from buffer.")
            .expect("Frame deserialization should not return None.")
    }

    fn try_decode_frame_none<T>(buffer: &[u8]) -> Option<(T, usize)>
    where
        T: DeserializeOwned + std::fmt::Debug + PartialEq,
    {
        let result = try_decode_frame::<T>(buffer)
            .expect("Unexpected failure while trying to deserialize frame from buffer.");
        assert_eq!(result, None);
        result
    }

    fn try_decode_frame_err<T>(buffer: &[u8], error: &str) -> ProtocolError
    where
        T: DeserializeOwned + std::fmt::Debug,
    {
        try_decode_frame::<T>(buffer).expect_err(&format!("Expected error: {error}."))
    }

    #[test]
    fn encode_and_decode_client_packet_roundtrip() {
        let packet = ClientPacket::Hello {
            username: String::from("bruce_lee"),
        };
        let frame = encode_frame_ok(&packet);
        let decoded: ClientPacket = decode_frame_ok(&frame);

        assert_eq!(decoded, packet);
    }

    #[test]
    fn encode_and_decode_server_packet_roundtrip() {
        let packet: ServerPacket = ServerPacket::ChatMessage {
            from: String::from("bruce_lee"),
            room: String::from("#general"),
            text: String::from("Kung foo..."),
        };
        let frame = encode_frame_ok(&packet);
        let decoded: ServerPacket = decode_frame_ok(&frame);

        assert_eq!(decoded, packet);
    }

    #[test]
    fn encode_frame_prefixes_big_endian_payload_length() {
        let packet = ClientPacket::Ping;
        let frame = encode_frame_ok(&packet);

        let (header, payload) = frame.split_at(4);
        let header: [u8; 4] = header.try_into().expect("header slice is 4 bytes");
        let declared_len =
            usize::try_from(u32::from_be_bytes(header)).expect("declared length fits in usize");

        assert_eq!(declared_len, payload.len());
    }

    #[test]
    fn decode_frame_rejects_truncated_header() {
        let frame: Vec<u8> = vec![0, 0, 0];
        let error = decode_frame_err::<ClientPacket>(&frame, "Truncated header");

        assert_err!(error, ProtocolError::TruncatedFrame);
    }

    #[test]
    fn decode_frame_rejects_truncated_payload() {
        let declared_payload_len: u32 = 10;
        let mut frame: Vec<u8> = Vec::new();
        frame.extend_from_slice(&declared_payload_len.to_be_bytes());
        frame.extend_from_slice(b"abc");
        let error = decode_frame_err::<ClientPacket>(&frame, "Truncated payload");

        assert_err!(error, ProtocolError::TruncatedFrame);
    }

    #[test]
    fn decode_frame_rejects_invalid_json_payload() {
        let frame: Vec<u8> = build_frame(b"this is not valid json");
        let error = decode_frame_err::<ClientPacket>(&frame, "Serde error");

        assert_err!(error, ProtocolError::Serde(_));
    }

    #[test]
    fn decode_frame_rejects_empty_payload() {
        let frame: Vec<u8> = build_frame(b"");
        let error = decode_frame_err::<ClientPacket>(&frame, "Serde error");

        assert_err!(error, ProtocolError::Serde(_));
    }

    #[test]
    fn decode_frame_rejects_json_of_wrong_packet_shape() {
        let payload: &[u8] = br#"{"NotARealPacket":{"foo":"bar"}}"#;
        let frame: Vec<u8> = build_frame(payload);
        let error = decode_frame_err::<ClientPacket>(&frame, "Serde error");

        assert_err!(error, ProtocolError::Serde(_));
    }

    #[test]
    fn decode_frame_ignores_trailing_bytes_after_payload() {
        let packet: ClientPacket = ClientPacket::Ping;
        let mut frame = encode_frame_ok(&packet);
        frame.extend_from_slice(b"trailing bytes that belong to a future frame");
        let decoded: ClientPacket = decode_frame_ok(&frame);

        assert_eq!(decoded, packet);
    }

    #[test]
    fn decode_frame_accepts_empty_string_fields() {
        let packet: ClientPacket = ClientPacket::Hello {
            username: String::new(),
        };
        let frame = encode_frame_ok(&packet);
        let decoded: ClientPacket = decode_frame_ok(&frame);

        assert_eq!(decoded, packet);
    }

    #[test]
    fn decode_frame_handles_utf8_content() {
        let packet: ServerPacket = ServerPacket::SystemMessage {
            text: String::from("Ahoj ^^ Привет こんにちは"),
        };
        let frame = encode_frame_ok(&packet);
        let decoded: ServerPacket = decode_frame_ok(&frame);

        assert_eq!(decoded, packet);
    }

    #[test]
    fn try_decode_frame_returns_none_for_empty_buffer() {
        let result = try_decode_frame_none::<ClientPacket>(&[]);

        assert_eq!(result, None);
    }

    #[test]
    fn try_decode_frame_returns_none_for_incomplete_header() {
        let buffer: Vec<u8> = vec![0, 0, 0];

        let result = try_decode_frame_none::<ClientPacket>(&buffer);

        assert_eq!(result, None);
    }

    #[test]
    fn try_decode_frame_returns_none_for_incomplete_payload() {
        let declared_payload_len: u32 = 10;
        let mut buffer: Vec<u8> = Vec::new();
        buffer.extend_from_slice(&declared_payload_len.to_be_bytes());
        buffer.extend_from_slice(b"abc");

        let result = try_decode_frame_none::<ClientPacket>(&buffer);

        assert_eq!(result, None);
    }

    #[test]
    fn try_decode_frame_decodes_complete_frame() {
        let packet = ClientPacket::Ping;
        let frame: Vec<u8> = encode_frame_ok(&packet);

        let (decoded, consumed) = try_decode_frame_ok::<ClientPacket>(&frame);
        assert_eq!(decoded, packet);
        assert_eq!(consumed, frame.len());
    }

    #[test]
    fn try_decode_frame_reports_consumed_len_with_trailing_bytes() {
        let packet = ClientPacket::Ping;
        let mut buffer: Vec<u8> = encode_frame_ok(&packet);
        let frame_len: usize = buffer.len();
        buffer.extend_from_slice(b"trailing bytes");

        let (decoded, consumed) = try_decode_frame_ok::<ClientPacket>(&buffer);
        assert_eq!(decoded, packet);
        assert_eq!(consumed, frame_len);
    }

    #[test]
    fn try_decode_frame_rejects_invalid_complete_payload() {
        let buffer: Vec<u8> = build_frame(b"this is not a valid json");

        let err = try_decode_frame_err::<ClientPacket>(&buffer, "Serde error");

        assert_err!(err, ProtocolError::Serde(_));
    }

    #[test]
    fn try_decode_frame_only_decodes_1_frame() {
        let packet1 = ClientPacket::Ping;
        let packet2 = ClientPacket::Hello {
            username: String::from("bruce_lee"),
        };

        let frame1 = encode_frame_ok(&packet1);
        let frame2 = encode_frame_ok(&packet2);

        let len1 = frame1.len();

        let mut buffer: Vec<u8> = Vec::new();
        buffer.extend_from_slice(&frame1);
        buffer.extend_from_slice(&frame2);

        let (decoded, consumed) = try_decode_frame_ok::<ClientPacket>(&buffer);

        assert_eq!(decoded, packet1);
        assert_eq!(consumed, len1);
    }
}