//! Wire framing for the WebSocket protocol layer
//! (`oxigdal_websocket/protocol/framing.rs`): a fixed-size [`FrameHeader`]
//! followed by a raw payload, plus the [`FrameCodec`] that encodes and decodes them.

use bytes::{BufMut, Bytes, BytesMut};

use crate::error::{Error, Result};
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq)]
8#[repr(u8)]
9pub enum FrameType {
10 Data = 0,
12 Control = 1,
14 Heartbeat = 2,
16 FragmentStart = 3,
18 FragmentContinuation = 4,
20 FragmentEnd = 5,
22}
23
24impl TryFrom<u8> for FrameType {
25 type Error = Error;
26
27 fn try_from(value: u8) -> Result<Self> {
28 match value {
29 0 => Ok(FrameType::Data),
30 1 => Ok(FrameType::Control),
31 2 => Ok(FrameType::Heartbeat),
32 3 => Ok(FrameType::FragmentStart),
33 4 => Ok(FrameType::FragmentContinuation),
34 5 => Ok(FrameType::FragmentEnd),
35 _ => Err(Error::Protocol(format!("Invalid frame type: {}", value))),
36 }
37 }
38}
39
40#[derive(Debug, Clone)]
48pub struct FrameHeader {
49 pub frame_type: FrameType,
51 pub version: u8,
53 pub compressed: bool,
55 pub fragmented: bool,
57 pub payload_length: u32,
59}
60
61impl FrameHeader {
62 pub const SIZE: usize = 8;
64
65 pub fn new(frame_type: FrameType, version: u8, compressed: bool, payload_length: u32) -> Self {
67 Self {
68 frame_type,
69 version,
70 compressed,
71 fragmented: false,
72 payload_length,
73 }
74 }
75
76 pub fn encode(&self) -> [u8; Self::SIZE] {
78 let mut buf = [0u8; Self::SIZE];
79
80 buf[0] = ((self.frame_type as u8) << 4) | (self.version & 0x0F);
82
83 let mut flags = 0u8;
85 if self.compressed {
86 flags |= 0x80; }
88 if self.fragmented {
89 flags |= 0x40; }
91 buf[1] = flags;
92
93 buf[2..6].copy_from_slice(&self.payload_length.to_be_bytes());
95
96 buf
98 }
99
100 pub fn decode(data: &[u8]) -> Result<Self> {
102 if data.len() < Self::SIZE {
103 return Err(Error::Protocol(format!(
104 "Insufficient data for frame header: expected {}, got {}",
105 Self::SIZE,
106 data.len()
107 )));
108 }
109
110 let frame_type = FrameType::try_from(data[0] >> 4)?;
112 let version = data[0] & 0x0F;
113
114 let compressed = (data[1] & 0x80) != 0;
116 let fragmented = (data[1] & 0x40) != 0;
117
118 let payload_length = u32::from_be_bytes([data[2], data[3], data[4], data[5]]);
120
121 Ok(Self {
122 frame_type,
123 version,
124 compressed,
125 fragmented,
126 payload_length,
127 })
128 }
129
130 pub fn total_size(&self) -> usize {
132 Self::SIZE + self.payload_length as usize
133 }
134}
135
136#[derive(Debug, Clone)]
138pub struct Frame {
139 pub header: FrameHeader,
141 pub payload: Bytes,
143}
144
145impl Frame {
146 pub fn new(frame_type: FrameType, version: u8, compressed: bool, payload: Bytes) -> Self {
148 let header = FrameHeader::new(frame_type, version, compressed, payload.len() as u32);
149 Self { header, payload }
150 }
151
152 pub fn data(version: u8, compressed: bool, payload: Bytes) -> Self {
154 Self::new(FrameType::Data, version, compressed, payload)
155 }
156
157 pub fn control(version: u8, payload: Bytes) -> Self {
159 Self::new(FrameType::Control, version, false, payload)
160 }
161
162 pub fn heartbeat(version: u8) -> Self {
164 Self::new(FrameType::Heartbeat, version, false, Bytes::new())
165 }
166
167 pub fn size(&self) -> usize {
169 self.header.total_size()
170 }
171}
172
173pub struct FrameCodec {
175 max_payload_size: u32,
176}
177
178impl FrameCodec {
179 pub fn new() -> Self {
181 Self {
182 max_payload_size: 16 * 1024 * 1024, }
184 }
185
186 pub fn with_max_payload_size(max_payload_size: u32) -> Self {
188 Self { max_payload_size }
189 }
190
191 pub fn encode(&self, frame: &Frame) -> Result<Bytes> {
193 if frame.header.payload_length > self.max_payload_size {
194 return Err(Error::Protocol(format!(
195 "Payload size {} exceeds maximum {}",
196 frame.header.payload_length, self.max_payload_size
197 )));
198 }
199
200 let mut buf = BytesMut::with_capacity(frame.size());
201
202 buf.put_slice(&frame.header.encode());
204
205 buf.put_slice(&frame.payload);
207
208 Ok(buf.freeze())
209 }
210
211 pub fn decode(&self, data: &[u8]) -> Result<Frame> {
213 let header = FrameHeader::decode(data)?;
215
216 if header.payload_length > self.max_payload_size {
218 return Err(Error::Protocol(format!(
219 "Payload size {} exceeds maximum {}",
220 header.payload_length, self.max_payload_size
221 )));
222 }
223
224 let total_size = header.total_size();
226 if data.len() < total_size {
227 return Err(Error::Protocol(format!(
228 "Insufficient data for frame: expected {}, got {}",
229 total_size,
230 data.len()
231 )));
232 }
233
234 let payload = Bytes::copy_from_slice(&data[FrameHeader::SIZE..total_size]);
236
237 Ok(Frame { header, payload })
238 }
239
240 pub fn decode_all(&self, data: &[u8]) -> Result<Vec<Frame>> {
242 let mut frames = Vec::new();
243 let mut offset = 0;
244
245 while offset < data.len() {
246 if data.len() - offset < FrameHeader::SIZE {
247 break;
248 }
249
250 let header = FrameHeader::decode(&data[offset..])?;
251 let total_size = header.total_size();
252
253 if data.len() - offset < total_size {
254 break;
255 }
256
257 let frame = self.decode(&data[offset..])?;
258 frames.push(frame);
259
260 offset += total_size;
261 }
262
263 Ok(frames)
264 }
265}
266
267impl Default for FrameCodec {
268 fn default() -> Self {
269 Self::new()
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frame_header_encode_decode() -> Result<()> {
        let header = FrameHeader::new(FrameType::Data, 1, true, 1024);
        let encoded = header.encode();
        let decoded = FrameHeader::decode(&encoded)?;

        assert_eq!(header.frame_type, decoded.frame_type);
        assert_eq!(header.version, decoded.version);
        assert_eq!(header.compressed, decoded.compressed);
        assert_eq!(header.fragmented, decoded.fragmented);
        assert_eq!(header.payload_length, decoded.payload_length);
        Ok(())
    }

    #[test]
    fn test_frame_header_fragmented_flag_round_trip() -> Result<()> {
        let mut header = FrameHeader::new(FrameType::FragmentStart, 15, false, 0);
        header.fragmented = true;

        let encoded = header.encode();
        // Type 3 in the high nibble, version 15 in the low nibble; only the
        // fragmented bit set in the flags byte.
        assert_eq!(encoded[0], 0x3F);
        assert_eq!(encoded[1], 0x40);

        let decoded = FrameHeader::decode(&encoded)?;
        assert_eq!(decoded.frame_type, FrameType::FragmentStart);
        assert_eq!(decoded.version, 15);
        assert!(!decoded.compressed);
        assert!(decoded.fragmented);
        Ok(())
    }

    #[test]
    fn test_frame_header_rejects_invalid_input() {
        // One byte short of a full header.
        assert!(FrameHeader::decode(&[0u8; FrameHeader::SIZE - 1]).is_err());
        // High nibble 6 is not a FrameType discriminant.
        assert!(FrameHeader::decode(&[0x61, 0, 0, 0, 0, 0, 0, 0]).is_err());
    }

    #[test]
    fn test_frame_encode_decode() -> Result<()> {
        let codec = FrameCodec::new();
        let payload = Bytes::from(vec![1, 2, 3, 4, 5]);
        let frame = Frame::data(1, false, payload.clone());

        let encoded = codec.encode(&frame)?;
        let decoded = codec.decode(&encoded)?;

        assert_eq!(frame.header.frame_type, decoded.header.frame_type);
        assert_eq!(frame.payload, decoded.payload);
        Ok(())
    }

    #[test]
    fn test_frame_codec_decode_incomplete_frame() -> Result<()> {
        let codec = FrameCodec::new();
        let encoded = codec.encode(&Frame::data(1, false, Bytes::from(vec![1, 2, 3, 4, 5])))?;

        assert!(codec.decode(&encoded[..encoded.len() - 1]).is_err());
        Ok(())
    }

    #[test]
    fn test_frame_codec_decode_rejects_oversized_payload() {
        let codec = FrameCodec::with_max_payload_size(4);
        let mut data = FrameHeader::new(FrameType::Data, 1, false, 5).encode().to_vec();
        data.extend_from_slice(&[0u8; 5]);

        assert!(codec.decode(&data).is_err());
    }

    #[test]
    fn test_frame_codec_decode_all() -> Result<()> {
        let codec = FrameCodec::new();

        let frame1 = Frame::data(1, false, Bytes::from(vec![1, 2, 3]));
        let frame2 = Frame::data(1, false, Bytes::from(vec![4, 5, 6]));

        let mut buf = BytesMut::new();
        buf.put_slice(&codec.encode(&frame1)?);
        buf.put_slice(&codec.encode(&frame2)?);

        let frames = codec.decode_all(&buf)?;

        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0].payload, Bytes::from(vec![1, 2, 3]));
        assert_eq!(frames[1].payload, Bytes::from(vec![4, 5, 6]));
        Ok(())
    }

    #[test]
    fn test_frame_codec_decode_all_stops_at_partial_frame() -> Result<()> {
        let codec = FrameCodec::new();
        let complete = codec.encode(&Frame::data(1, false, Bytes::from(vec![1, 2, 3])))?;
        let truncated = codec.encode(&Frame::data(1, false, Bytes::from(vec![4, 5, 6])))?;

        let mut buf = BytesMut::new();
        buf.put_slice(&complete);
        buf.put_slice(&truncated[..truncated.len() - 1]);

        let frames = codec.decode_all(&buf)?;

        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].payload, Bytes::from(vec![1, 2, 3]));
        Ok(())
    }

    #[test]
    fn test_frame_heartbeat() -> Result<()> {
        let codec = FrameCodec::new();
        let frame = Frame::heartbeat(1);

        let encoded = codec.encode(&frame)?;
        let decoded = codec.decode(&encoded)?;

        assert_eq!(decoded.header.frame_type, FrameType::Heartbeat);
        assert!(decoded.payload.is_empty());
        Ok(())
    }

    #[test]
    fn test_frame_max_size() {
        let codec = FrameCodec::with_max_payload_size(100);
        let payload = Bytes::from(vec![0; 200]);
        let frame = Frame::data(1, false, payload);

        let result = codec.encode(&frame);
        assert!(result.is_err());
    }
}