zerodds_websocket_bridge/
codec.rs1use alloc::vec::Vec;
7use core::fmt;
8
9use crate::frame::{Frame, Opcode};
10use crate::masking::apply_mask;
11
/// Errors reported by [`encode`] and [`decode`] when a frame violates the wire-format
/// rules checked in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodecError {
    /// The input ends before the fixed 2-byte header, or before the extended length
    /// field announced by that header, is complete.
    HeaderTooShort,
    /// An extended length field carries a value that would have fit a shorter encoding
    /// (a 16-bit length of at most 125, or a 64-bit length of at most `0xFFFF`).
    NonMinimalLength,
    /// The most significant bit of the 64-bit extended payload length is set.
    PayloadLengthMsbSet,
    /// Fewer payload bytes are available than the header declares. `decode` also reports
    /// this when the declared 64-bit length does not fit in `usize`.
    PayloadTruncated,
    /// The mask bit is set but fewer than 4 masking-key bytes follow the length field.
    MaskingKeyTruncated,
    /// A control frame carries a payload longer than 125 bytes.
    ControlFrameTooLong,
    /// A control frame has FIN = 0.
    FragmentedControlFrame,
}
31
32impl fmt::Display for CodecError {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 Self::HeaderTooShort => f.write_str("header too short"),
36 Self::NonMinimalLength => f.write_str("non-minimal payload length encoding"),
37 Self::PayloadLengthMsbSet => f.write_str("64-bit payload length MSB set"),
38 Self::PayloadTruncated => f.write_str("payload truncated"),
39 Self::MaskingKeyTruncated => f.write_str("masking key truncated"),
40 Self::ControlFrameTooLong => f.write_str("control frame payload > 125 bytes"),
41 Self::FragmentedControlFrame => f.write_str("control frame with FIN=0"),
42 }
43 }
44}
45
// Marker impl so `CodecError` can be used as a `std::error::Error`; compiled only when the
// `std` feature is enabled because the trait is referenced through `std`.
#[cfg(feature = "std")]
impl std::error::Error for CodecError {}
48
49pub fn encode(frame: &Frame) -> Result<Vec<u8>, CodecError> {
60 if frame.opcode.is_control() {
61 if !frame.fin {
62 return Err(CodecError::FragmentedControlFrame);
63 }
64 if frame.payload.len() > 125 {
65 return Err(CodecError::ControlFrameTooLong);
66 }
67 }
68 let mut out = Vec::with_capacity(2 + 8 + 4 + frame.payload.len());
69
70 let mut byte0 = frame.opcode.to_bits() & 0x0F;
72 if frame.fin {
73 byte0 |= 0x80;
74 }
75 if frame.rsv1 {
76 byte0 |= 0x40;
77 }
78 if frame.rsv2 {
79 byte0 |= 0x20;
80 }
81 if frame.rsv3 {
82 byte0 |= 0x10;
83 }
84 out.push(byte0);
85
86 let payload_len = frame.payload.len();
88 let masked = frame.masking_key.is_some();
89 let (len7, ext_len) = encode_payload_length(payload_len);
90 let byte1 = (if masked { 0x80 } else { 0x00 }) | (len7 & 0x7F);
91 out.push(byte1);
92 out.extend_from_slice(&ext_len);
93
94 if let Some(key) = frame.masking_key {
96 out.extend_from_slice(&key);
97 let mut masked_payload = frame.payload.clone();
99 apply_mask(&mut masked_payload, key);
100 out.extend_from_slice(&masked_payload);
101 } else {
102 out.extend_from_slice(&frame.payload);
103 }
104
105 Ok(out)
106}
107
/// Splits a payload length into the 7-bit length field and its extended-length bytes.
///
/// Returns `(len7, ext)` where:
/// - for `len <= 125`, `len7` is the length itself and `ext` is empty;
/// - for `len <= 0xFFFF`, `len7` is 126 and `ext` is the length as 2 big-endian bytes;
/// - otherwise, `len7` is 127 and `ext` is the length as 8 big-endian bytes.
// The narrowing casts are guarded by the range of the match arm they appear in.
#[allow(clippy::cast_possible_truncation)]
fn encode_payload_length(len: usize) -> (u8, Vec<u8>) {
    match len {
        0..=125 => (len as u8, Vec::new()),
        126..=0xFFFF => (126, (len as u16).to_be_bytes().to_vec()),
        _ => (127, (len as u64).to_be_bytes().to_vec()),
    }
}
123
124pub fn decode(bytes: &[u8]) -> Result<(Frame, usize), CodecError> {
133 if bytes.len() < 2 {
134 return Err(CodecError::HeaderTooShort);
135 }
136 let byte0 = bytes[0];
137 let fin = (byte0 & 0x80) != 0;
138 let rsv1 = (byte0 & 0x40) != 0;
139 let rsv2 = (byte0 & 0x20) != 0;
140 let rsv3 = (byte0 & 0x10) != 0;
141 let opcode = Opcode::from_bits(byte0 & 0x0F);
142
143 let byte1 = bytes[1];
144 let masked = (byte1 & 0x80) != 0;
145 let len7 = byte1 & 0x7F;
146
147 let mut cursor = 2usize;
148 let payload_len = match len7.cmp(&126) {
149 core::cmp::Ordering::Less => usize::from(len7),
150 core::cmp::Ordering::Equal => {
151 if bytes.len() < cursor + 2 {
152 return Err(CodecError::HeaderTooShort);
153 }
154 let v = u16::from_be_bytes([bytes[cursor], bytes[cursor + 1]]);
155 cursor += 2;
156 if v <= 125 {
157 return Err(CodecError::NonMinimalLength);
158 }
159 usize::from(v)
160 }
161 core::cmp::Ordering::Greater => {
162 if bytes.len() < cursor + 8 {
164 return Err(CodecError::HeaderTooShort);
165 }
166 let mut buf = [0u8; 8];
167 buf.copy_from_slice(&bytes[cursor..cursor + 8]);
168 let v = u64::from_be_bytes(buf);
169 cursor += 8;
170 if (v & 0x8000_0000_0000_0000) != 0 {
171 return Err(CodecError::PayloadLengthMsbSet);
172 }
173 if v <= 0xFFFF {
174 return Err(CodecError::NonMinimalLength);
175 }
176 usize::try_from(v).map_err(|_| CodecError::PayloadTruncated)?
177 }
178 };
179
180 if opcode.is_control() {
181 if !fin {
182 return Err(CodecError::FragmentedControlFrame);
183 }
184 if payload_len > 125 {
185 return Err(CodecError::ControlFrameTooLong);
186 }
187 }
188
189 let masking_key = if masked {
190 if bytes.len() < cursor + 4 {
191 return Err(CodecError::MaskingKeyTruncated);
192 }
193 let key = [
194 bytes[cursor],
195 bytes[cursor + 1],
196 bytes[cursor + 2],
197 bytes[cursor + 3],
198 ];
199 cursor += 4;
200 Some(key)
201 } else {
202 None
203 };
204
205 if bytes.len() < cursor + payload_len {
206 return Err(CodecError::PayloadTruncated);
207 }
208 let mut payload = bytes[cursor..cursor + payload_len].to_vec();
209 cursor += payload_len;
210
211 if let Some(key) = masking_key {
212 apply_mask(&mut payload, key);
213 }
214
215 Ok((
216 Frame {
217 fin,
218 rsv1,
219 rsv2,
220 rsv3,
221 opcode,
222 masking_key,
223 payload,
224 },
225 cursor,
226 ))
227}
228
// Unit tests for `encode` / `decode`: exact header bytes for each length class, round trips,
// and every `CodecError` the two functions can return.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn smallest_text_frame_encodes_to_2_byte_header_plus_payload() {
        let bytes = encode(&Frame::text("hi")).expect("encode");
        // 2 header bytes + 2 payload bytes.
        assert_eq!(bytes.len(), 4);
        // FIN bit (0x80) combined with opcode nibble 0x1.
        assert_eq!(bytes[0], 0x81);
        // Mask bit clear, 7-bit length = 2.
        assert_eq!(bytes[1], 0x02);
        assert_eq!(&bytes[2..], b"hi");
    }

    #[test]
    fn medium_payload_uses_extended_16_bit_length() {
        // 200 bytes does not fit the 7-bit length field (max 125).
        let payload = alloc::vec![0xAA; 200];
        let f = Frame::binary(payload.clone());
        let bytes = encode(&f).expect("encode");
        // FIN bit (0x80) combined with opcode nibble 0x2.
        assert_eq!(bytes[0], 0x82);
        // 126 marks a 2-byte big-endian extended length.
        assert_eq!(bytes[1] & 0x7F, 126);
        assert_eq!(&bytes[2..4], &200u16.to_be_bytes());
        assert_eq!(&bytes[4..], &payload[..]);
    }

    #[test]
    fn large_payload_uses_extended_64_bit_length() {
        // 70_000 exceeds 0xFFFF, so the 16-bit form is not enough.
        let payload = alloc::vec![0xBB; 70_000];
        let f = Frame::binary(payload.clone());
        let bytes = encode(&f).expect("encode");
        // 127 marks an 8-byte big-endian extended length.
        assert_eq!(bytes[1] & 0x7F, 127);
        let mut len_buf = [0u8; 8];
        len_buf.copy_from_slice(&bytes[2..10]);
        assert_eq!(u64::from_be_bytes(len_buf), 70_000);
        // The most significant bit of the 64-bit length must stay clear.
        assert_eq!(bytes[2] & 0x80, 0);
    }

    #[test]
    fn round_trip_unmasked_text() {
        let f = Frame::text("hello world");
        let bytes = encode(&f).expect("encode");
        let (parsed, consumed) = decode(&bytes).expect("decode");
        assert_eq!(parsed, f);
        // The whole buffer is exactly one frame.
        assert_eq!(consumed, bytes.len());
    }

    #[test]
    fn round_trip_masked_payload_unmasked_on_decode() {
        let f = Frame::text("masked!").with_mask([0x12, 0x34, 0x56, 0x78]);
        let bytes = encode(&f).expect("encode");
        // 2 header bytes + 4 key bytes precede the payload, which must be masked on the wire.
        assert_ne!(&bytes[6..], b"masked!");
        let (parsed, _) = decode(&bytes).expect("decode");
        // Decoding restores the plain payload and keeps the key.
        assert_eq!(parsed.payload, b"masked!");
        assert_eq!(parsed.masking_key, Some([0x12, 0x34, 0x56, 0x78]));
    }

    #[test]
    fn round_trip_medium_and_large_payloads() {
        // Sizes sit on both sides of the 16-bit / 64-bit length boundaries.
        for size in [126, 200, 65535, 65536, 100_000] {
            let f = Frame::binary(alloc::vec![0xAB; size]);
            let bytes = encode(&f).expect("encode");
            let (parsed, _) = decode(&bytes).expect("decode");
            assert_eq!(parsed.payload.len(), size);
        }
    }

    #[test]
    fn ping_frame_round_trip() {
        let f = Frame::ping(alloc::vec![1, 2, 3]);
        let bytes = encode(&f).expect("encode");
        let (parsed, _) = decode(&bytes).expect("decode");
        assert_eq!(parsed.opcode, Opcode::Ping);
        assert_eq!(parsed.payload, alloc::vec![1, 2, 3]);
    }

    #[test]
    fn close_frame_carries_status_code() {
        let f = Frame::close(1000, "");
        let bytes = encode(&f).expect("encode");
        let (parsed, _) = decode(&bytes).expect("decode");
        assert_eq!(parsed.opcode, Opcode::Close);
        // The first two payload bytes hold the status code in big-endian order.
        assert_eq!(&parsed.payload[..2], &1000u16.to_be_bytes());
    }

    #[test]
    fn header_too_short_decode_fails() {
        // Fewer than the 2 mandatory header bytes.
        assert_eq!(decode(&[]), Err(CodecError::HeaderTooShort));
        assert_eq!(decode(&[0x81]), Err(CodecError::HeaderTooShort));
    }

    #[test]
    fn extended_16_bit_length_truncated_fails() {
        // 0x7E = 126 announces a 2-byte extended length that is missing.
        assert_eq!(decode(&[0x81, 0x7E]), Err(CodecError::HeaderTooShort));
    }

    #[test]
    fn extended_64_bit_length_msb_set_rejected() {
        // 0x7F = 127 selects the 64-bit length; its first byte 0x80 sets the top bit.
        let bytes = [0x82u8, 0x7F, 0x80, 0, 0, 0, 0, 0, 0, 0];
        assert_eq!(decode(&bytes), Err(CodecError::PayloadLengthMsbSet));
    }

    #[test]
    fn non_minimal_16_bit_length_rejected() {
        // A 16-bit length of 100 would have fit in the 7-bit field.
        let bytes = [0x82u8, 0x7E, 0, 100, 0xAA, 0xBB];
        assert_eq!(decode(&bytes), Err(CodecError::NonMinimalLength));
    }

    #[test]
    fn non_minimal_64_bit_length_rejected() {
        let mut bytes = alloc::vec![0x82u8, 0x7F];
        // 65000 <= 0xFFFF, so it would have fit in the 16-bit form.
        bytes.extend_from_slice(&65000u64.to_be_bytes());
        assert_eq!(decode(&bytes), Err(CodecError::NonMinimalLength));
    }

    #[test]
    fn control_frame_with_long_payload_rejected_on_encode() {
        // 200 bytes exceeds the 125-byte limit enforced for control frames.
        let f = Frame::ping(alloc::vec![0; 200]);
        assert_eq!(encode(&f), Err(CodecError::ControlFrameTooLong));
    }

    #[test]
    fn fragmented_control_frame_rejected_on_encode() {
        let mut f = Frame::ping(alloc::vec![1, 2]);
        // Clearing FIN on a control frame must be refused.
        f.fin = false;
        assert_eq!(encode(&f), Err(CodecError::FragmentedControlFrame));
    }

    #[test]
    fn masked_frame_without_key_bytes_decode_fails() {
        // 0x80 in the second byte: mask bit set, length 0, but no key bytes follow.
        let bytes = [0x81u8, 0x80];
        assert_eq!(decode(&bytes), Err(CodecError::MaskingKeyTruncated));
    }

    #[test]
    fn payload_truncation_decode_fails() {
        // Header declares 10 payload bytes, only 2 are present.
        let bytes = [0x81u8, 0x0A, 0xAA, 0xBB];
        assert_eq!(decode(&bytes), Err(CodecError::PayloadTruncated));
    }

    #[test]
    fn rsv_bits_propagate_to_decoded_frame() {
        let mut f = Frame::binary(alloc::vec![1]);
        // Set RSV1 and RSV3 only, so RSV2 is checked to stay clear.
        f.rsv1 = true;
        f.rsv3 = true;
        let bytes = encode(&f).expect("encode");
        let (parsed, _) = decode(&bytes).expect("decode");
        assert!(parsed.rsv1);
        assert!(!parsed.rsv2);
        assert!(parsed.rsv3);
    }

    #[test]
    fn fin_zero_text_frame_round_trip() {
        let mut f = Frame::text("part-1");
        // FIN = 0 is only rejected for control opcodes; a text frame may carry it.
        f.fin = false;
        let bytes = encode(&f).expect("encode");
        let (parsed, _) = decode(&bytes).expect("decode");
        assert!(!parsed.fin);
        assert_eq!(parsed.opcode, Opcode::Text);
    }
}