1use bytes::{Buf, BufMut, Bytes, BytesMut};
6use thiserror::Error;
7
8use crate::frames::Frame;
9
/// Largest payload, in bytes, that a single frame may carry (16 MiB).
///
/// Checked on both the encode and decode paths against the payload length;
/// the length prefix itself is not counted.
pub const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;

/// Size, in bytes, of the big-endian `u32` length prefix that precedes every
/// frame payload on the wire.
pub const LENGTH_PREFIX_SIZE: usize = std::mem::size_of::<u32>();

// The wire format stores the payload length in a `u32`, so every accepted
// frame size must be representable in one. Fail the build, rather than
// silently truncating lengths, if `MAX_FRAME_SIZE` is ever raised past that.
const _: () = assert!(MAX_FRAME_SIZE <= u32::MAX as usize);
15
/// Errors produced while framing, encoding, or decoding [`Frame`]s.
#[derive(Debug, Error)]
pub enum ProtocolError {
    /// A serialized payload (when encoding) or a declared length prefix (when
    /// decoding) exceeded [`MAX_FRAME_SIZE`]. Carries the offending size in
    /// bytes.
    #[error("Frame size {0} exceeds maximum {MAX_FRAME_SIZE}")]
    FrameTooLarge(usize),

    /// The input ended before a complete frame was available. Carries the
    /// number of additional bytes needed given what has been seen so far; if
    /// even the length prefix is incomplete, this counts only the missing
    /// prefix bytes.
    #[error("Incomplete frame: need {0} more bytes")]
    Incomplete(usize),

    /// Serializing a frame with `rmp_serde` failed.
    #[error("Encoding error: {0}")]
    Encode(#[from] rmp_serde::encode::Error),

    /// The payload bytes could not be deserialized into a frame.
    #[error("Decoding error: {0}")]
    Decode(#[from] rmp_serde::decode::Error),

    /// A frame that is structurally invalid, with a human-readable reason.
    // NOTE(review): not constructed anywhere in this module's visible code;
    // presumably raised by callers that validate frame contents — confirm.
    #[error("Invalid frame: {0}")]
    Invalid(String),
}
39
40pub fn encode(frame: &Frame) -> Result<Bytes, ProtocolError> {
50 let payload = rmp_serde::to_vec_named(frame)?;
51
52 if payload.len() > MAX_FRAME_SIZE {
53 return Err(ProtocolError::FrameTooLarge(payload.len()));
54 }
55
56 let mut buf = BytesMut::with_capacity(LENGTH_PREFIX_SIZE + payload.len());
57 buf.put_u32(payload.len() as u32);
58 buf.extend_from_slice(&payload);
59
60 Ok(buf.freeze())
61}
62
63pub fn encode_into(frame: &Frame, buf: &mut BytesMut) -> Result<(), ProtocolError> {
69 let payload = rmp_serde::to_vec_named(frame)?;
70
71 if payload.len() > MAX_FRAME_SIZE {
72 return Err(ProtocolError::FrameTooLarge(payload.len()));
73 }
74
75 buf.reserve(LENGTH_PREFIX_SIZE + payload.len());
76 buf.put_u32(payload.len() as u32);
77 buf.extend_from_slice(&payload);
78
79 Ok(())
80}
81
82pub fn decode(data: &[u8]) -> Result<Frame, ProtocolError> {
88 if data.len() < LENGTH_PREFIX_SIZE {
89 return Err(ProtocolError::Incomplete(LENGTH_PREFIX_SIZE - data.len()));
90 }
91
92 let length = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
93
94 if length > MAX_FRAME_SIZE {
95 return Err(ProtocolError::FrameTooLarge(length));
96 }
97
98 let total_size = LENGTH_PREFIX_SIZE + length;
99 if data.len() < total_size {
100 return Err(ProtocolError::Incomplete(total_size - data.len()));
101 }
102
103 let frame = rmp_serde::from_slice(&data[LENGTH_PREFIX_SIZE..total_size])?;
104 Ok(frame)
105}
106
107pub fn decode_from(buf: &mut BytesMut) -> Result<Option<Frame>, ProtocolError> {
116 if buf.len() < LENGTH_PREFIX_SIZE {
117 return Ok(None);
118 }
119
120 let length = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
121
122 if length > MAX_FRAME_SIZE {
123 return Err(ProtocolError::FrameTooLarge(length));
124 }
125
126 let total_size = LENGTH_PREFIX_SIZE + length;
127 if buf.len() < total_size {
128 return Ok(None);
129 }
130
131 buf.advance(LENGTH_PREFIX_SIZE);
132 let payload = buf.split_to(length);
133 let frame = rmp_serde::from_slice(&payload)?;
134
135 Ok(Some(frame))
136}
137
/// Stateless handle over this module's framing functions.
///
/// It carries no data, so it is `Copy` and can be freely duplicated or
/// shared between tasks.
#[derive(Debug, Default, Clone, Copy)]
pub struct FrameCodec {}
143
144impl FrameCodec {
145 #[must_use]
147 pub fn new() -> Self {
148 Self::default()
149 }
150
151 pub fn encode(&self, frame: &Frame) -> Result<Bytes, ProtocolError> {
157 encode(frame)
158 }
159
160 pub fn decode(&self, data: &[u8]) -> Result<Frame, ProtocolError> {
166 decode(data)
167 }
168
169 pub fn decode_from(&self, buf: &mut BytesMut) -> Result<Option<Frame>, ProtocolError> {
175 decode_from(buf)
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_decode_roundtrip() {
        let frames = vec![
            Frame::subscribe(1, "test-channel"),
            Frame::publish("chat:room", b"Hello, world!".to_vec()),
            Frame::ack(42),
            Frame::error(1, 1001, "Invalid frame"),
            Frame::ping(),
            Frame::connect(1, Some("token123".to_string())),
            Frame::connected("conn-123", 1, 30000),
        ];

        for frame in frames {
            let encoded = encode(&frame).unwrap();
            let decoded = decode(&encoded).unwrap();
            assert_eq!(frame, decoded);
        }
    }

    #[test]
    fn test_decode_incomplete() {
        let frame = Frame::subscribe(1, "test");
        let encoded = encode(&frame).unwrap();

        // Full prefix but truncated payload.
        let partial = &encoded[..5];
        match decode(partial) {
            Err(ProtocolError::Incomplete(_)) => {}
            other => panic!("Expected Incomplete error, got {:?}", other),
        }
    }

    #[test]
    fn test_decode_short_prefix() {
        // Only half of the 4-byte length prefix is present.
        match decode(&[0, 0]) {
            Err(ProtocolError::Incomplete(2)) => {}
            other => panic!("Expected Incomplete(2), got {:?}", other),
        }
    }

    #[test]
    fn test_frame_too_large() {
        let large_payload = vec![0u8; MAX_FRAME_SIZE + 1];
        let frame = Frame::publish("test", large_payload);

        match encode(&frame) {
            Err(ProtocolError::FrameTooLarge(_)) => {}
            other => panic!("Expected FrameTooLarge error, got {:?}", other),
        }
    }

    #[test]
    fn test_oversized_length_prefix_rejected() {
        let oversized = MAX_FRAME_SIZE + 1;
        let prefix = u32::try_from(oversized).unwrap().to_be_bytes();

        match decode(&prefix) {
            Err(ProtocolError::FrameTooLarge(n)) if n == oversized => {}
            other => panic!("Expected FrameTooLarge({}), got {:?}", oversized, other),
        }

        // The streaming decoder must reject it too, without consuming bytes.
        let mut buf = BytesMut::from(&prefix[..]);
        match decode_from(&mut buf) {
            Err(ProtocolError::FrameTooLarge(n)) if n == oversized => {}
            other => panic!("Expected FrameTooLarge({}), got {:?}", oversized, other),
        }
        assert_eq!(buf.len(), LENGTH_PREFIX_SIZE);
    }

    #[test]
    fn test_streaming_decode() {
        let frame1 = Frame::subscribe(1, "test1");
        let frame2 = Frame::subscribe(2, "test2");

        let mut buf = BytesMut::new();
        encode_into(&frame1, &mut buf).unwrap();
        encode_into(&frame2, &mut buf).unwrap();

        let decoded1 = decode_from(&mut buf).unwrap().unwrap();
        let decoded2 = decode_from(&mut buf).unwrap().unwrap();

        assert_eq!(frame1, decoded1);
        assert_eq!(frame2, decoded2);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_streaming_decode_waits_for_full_frame() {
        let frame = Frame::subscribe(1, "partial");
        let encoded = encode(&frame).unwrap();
        let (head, tail) = encoded.split_at(encoded.len() - 1);

        let mut buf = BytesMut::new();
        assert!(decode_from(&mut buf).unwrap().is_none());

        // All but the last byte: the decoder must wait and leave the buffer intact.
        buf.extend_from_slice(head);
        assert!(decode_from(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), head.len());

        buf.extend_from_slice(tail);
        assert_eq!(decode_from(&mut buf).unwrap(), Some(frame));
        assert!(buf.is_empty());
    }

    #[test]
    fn test_codec_roundtrip() {
        let codec = FrameCodec::new();
        let frame = Frame::ack(7);

        let encoded = codec.encode(&frame).unwrap();
        assert_eq!(codec.decode(&encoded).unwrap(), frame);

        let mut buf = BytesMut::from(&encoded[..]);
        assert_eq!(codec.decode_from(&mut buf).unwrap(), Some(frame));
        assert!(buf.is_empty());
    }
}