// capnweb_transport/capnweb_codec.rs

1use bytes::{Buf, BufMut, BytesMut};
2use capnweb_core::protocol::Message;
3use serde_json;
4use std::io;
5use tokio_util::codec::{Decoder, Encoder};
6
/// Codec for Cap'n Web protocol messages using length-prefixed framing.
///
/// Each frame on the wire is a 4-byte big-endian length followed by that many
/// bytes of JSON encoding a single [`Message`].
#[derive(Debug, Clone)]
pub struct CapnWebCodec {
    /// Maximum frame payload size in bytes (excluding the 4-byte prefix),
    /// enforced on both encode and decode to bound memory use (DoS protection).
    max_frame_size: usize,
}

impl CapnWebCodec {
    /// Default maximum frame payload size: 10 MiB.
    pub const DEFAULT_MAX_FRAME_SIZE: usize = 10 * 1024 * 1024;

    /// Create a new codec with the default maximum frame size
    /// ([`Self::DEFAULT_MAX_FRAME_SIZE`]).
    pub fn new() -> Self {
        Self::with_max_frame_size(Self::DEFAULT_MAX_FRAME_SIZE)
    }

    /// Create a new codec with a custom maximum frame payload size in bytes.
    ///
    /// The length prefix is a `u32`, so frames larger than `u32::MAX` bytes can
    /// never be represented on the wire regardless of this setting.
    pub fn with_max_frame_size(max_frame_size: usize) -> Self {
        Self { max_frame_size }
    }

    /// Returns the configured maximum frame payload size in bytes.
    pub fn max_frame_size(&self) -> usize {
        self.max_frame_size
    }
}

impl Default for CapnWebCodec {
    fn default() -> Self {
        Self::new()
    }
}
33
/// Wire framing formats supported by this module.
///
/// [`CapnWebCodec`] implements `LengthPrefixed`; [`NewlineDelimitedCodec`]
/// implements `NewlineDelimited`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameFormat {
    /// Length-prefixed frames: a 4-byte big-endian length followed by a JSON payload.
    LengthPrefixed,
    /// Newline-delimited JSON frames: one JSON document per `\n`-terminated line.
    NewlineDelimited,
}
42
43/// Decoder for length-prefixed frames
44impl Decoder for CapnWebCodec {
45    type Item = Message;
46    type Error = CodecError;
47
48    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
49        // Need at least 4 bytes for length prefix
50        if src.len() < 4 {
51            return Ok(None);
52        }
53
54        // Read the length prefix (big-endian u32)
55        let mut length_bytes = [0u8; 4];
56        length_bytes.copy_from_slice(&src[..4]);
57        let frame_len = u32::from_be_bytes(length_bytes) as usize;
58
59        // Check frame size limit
60        if frame_len > self.max_frame_size {
61            return Err(CodecError::FrameTooLarge(frame_len));
62        }
63
64        // Check if we have the complete frame
65        if src.len() < 4 + frame_len {
66            // Need more data
67            src.reserve(4 + frame_len - src.len());
68            return Ok(None);
69        }
70
71        // Extract the frame
72        src.advance(4); // Skip length prefix
73        let frame_data = src.split_to(frame_len);
74
75        // Parse the JSON message
76        let json_value: serde_json::Value = serde_json::from_slice(&frame_data)
77            .map_err(|e| CodecError::JsonError(e.to_string()))?;
78
79        // Parse into Cap'n Web message
80        let message =
81            Message::from_json(&json_value).map_err(|e| CodecError::MessageError(e.to_string()))?;
82
83        Ok(Some(message))
84    }
85}
86
87/// Encoder for length-prefixed frames
88impl Encoder<Message> for CapnWebCodec {
89    type Error = CodecError;
90
91    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
92        // Convert message to JSON
93        let json_value = item.to_json();
94        let json_bytes =
95            serde_json::to_vec(&json_value).map_err(|e| CodecError::JsonError(e.to_string()))?;
96
97        // Check frame size
98        if json_bytes.len() > self.max_frame_size {
99            return Err(CodecError::FrameTooLarge(json_bytes.len()));
100        }
101
102        // Write length prefix (big-endian u32)
103        let length = json_bytes.len() as u32;
104        dst.reserve(4 + json_bytes.len());
105        dst.put_u32(length);
106        dst.put_slice(&json_bytes);
107
108        Ok(())
109    }
110}
111
112/// Newline-delimited JSON codec
113pub struct NewlineDelimitedCodec {
114    max_line_length: usize,
115}
116
117impl NewlineDelimitedCodec {
118    pub fn new() -> Self {
119        Self {
120            max_line_length: 1024 * 1024, // 1MB default
121        }
122    }
123}
124
125impl Default for NewlineDelimitedCodec {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131impl Decoder for NewlineDelimitedCodec {
132    type Item = Message;
133    type Error = CodecError;
134
135    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
136        // Find newline
137        let newline_pos = src.iter().position(|&b| b == b'\n');
138
139        if let Some(pos) = newline_pos {
140            // Check line length
141            if pos > self.max_line_length {
142                return Err(CodecError::LineTooLong(pos));
143            }
144
145            // Extract the line (without newline)
146            let line = src.split_to(pos);
147            src.advance(1); // Skip the newline
148
149            // Parse JSON
150            let json_value: serde_json::Value =
151                serde_json::from_slice(&line).map_err(|e| CodecError::JsonError(e.to_string()))?;
152
153            // Parse message
154            let message = Message::from_json(&json_value)
155                .map_err(|e| CodecError::MessageError(e.to_string()))?;
156
157            Ok(Some(message))
158        } else {
159            // Check if buffer is getting too large
160            if src.len() > self.max_line_length {
161                return Err(CodecError::LineTooLong(src.len()));
162            }
163
164            Ok(None)
165        }
166    }
167}
168
169impl Encoder<Message> for NewlineDelimitedCodec {
170    type Error = CodecError;
171
172    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
173        // Convert to JSON
174        let json_value = item.to_json();
175        let json_bytes =
176            serde_json::to_vec(&json_value).map_err(|e| CodecError::JsonError(e.to_string()))?;
177
178        // Check line length
179        if json_bytes.len() > self.max_line_length {
180            return Err(CodecError::LineTooLong(json_bytes.len()));
181        }
182
183        // Write JSON and newline
184        dst.reserve(json_bytes.len() + 1);
185        dst.put_slice(&json_bytes);
186        dst.put_u8(b'\n');
187
188        Ok(())
189    }
190}
191
/// Errors produced by the codecs in this module ([`CapnWebCodec`] and
/// [`NewlineDelimitedCodec`]).
#[derive(Debug, thiserror::Error)]
pub enum CodecError {
    /// A length-prefixed frame exceeded the maximum frame size, on encode or
    /// decode. Carries the offending length in bytes.
    #[error("Frame too large: {0} bytes")]
    FrameTooLarge(usize),

    /// A newline-delimited line exceeded the maximum line length. On decode this
    /// may also be buffered data that has no newline yet. Carries the offending
    /// length in bytes.
    #[error("Line too long: {0} bytes")]
    LineTooLong(usize),

    /// JSON serialization (encode) or parsing (decode) failed. Carries the
    /// `serde_json` error message.
    #[error("JSON error: {0}")]
    JsonError(String),

    /// The payload was valid JSON but `Message::from_json` rejected it.
    /// Carries the underlying error message.
    #[error("Message parse error: {0}")]
    MessageError(String),

    /// An I/O error. `#[from]` generates `From<io::Error>`, so `?` converts
    /// I/O errors into this variant.
    #[error("IO error: {0}")]
    IoError(#[from] io::Error),
}
210
#[cfg(test)]
mod tests {
    use super::*;
    use capnweb_core::{
        protocol::{ExportId, ImportId},
        Expression,
    };

    #[test]
    fn test_length_prefixed_encode_decode() {
        let mut codec = CapnWebCodec::new();
        let mut buffer = BytesMut::new();

        let msg = Message::Push(Expression::String("test".to_string()));
        codec.encode(msg, &mut buffer).unwrap();

        // The prefix must announce exactly the payload that follows it.
        assert!(buffer.len() > 4);
        let announced =
            u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]) as usize;
        assert_eq!(announced, buffer.len() - 4);

        let decoded = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded {
            Message::Push(expr) => {
                assert_eq!(expr, Expression::String("test".to_string()));
            }
            _ => panic!("Wrong message type"),
        }

        // The whole frame (prefix + payload) must have been consumed.
        assert_eq!(buffer.len(), 0);
    }

    #[test]
    fn test_newline_delimited_encode_decode() {
        let mut codec = NewlineDelimitedCodec::new();
        let mut buffer = BytesMut::new();

        let msg = Message::Pull(ImportId(42));
        codec.encode(msg, &mut buffer).unwrap();

        // Exactly one trailing newline terminates the line.
        assert_eq!(buffer[buffer.len() - 1], b'\n');

        let decoded = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded {
            Message::Pull(id) => {
                assert_eq!(id, ImportId(42));
            }
            _ => panic!("Wrong message type"),
        }

        // The line and its newline must have been consumed.
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_partial_frame() {
        let mut codec = CapnWebCodec::new();
        let mut buffer = BytesMut::new();

        // Partial length prefix: need more data.
        buffer.put_u8(0);
        buffer.put_u8(0);
        assert!(codec.decode(&mut buffer).unwrap().is_none());

        // Complete prefix (frame length = 10) but no payload yet.
        buffer.put_u8(0);
        buffer.put_u8(10);
        assert!(codec.decode(&mut buffer).unwrap().is_none());

        // Nothing may be consumed while waiting for the rest of the frame.
        assert_eq!(buffer.len(), 4);
    }

    #[test]
    fn test_frame_delivered_byte_by_byte() {
        let mut codec = CapnWebCodec::new();
        let mut frame = BytesMut::new();
        codec.encode(Message::Pull(ImportId(7)), &mut frame).unwrap();

        // Every strict prefix of the frame is incomplete.
        let mut buffer = BytesMut::new();
        let (last, head) = frame.split_last().unwrap();
        for &byte in head {
            buffer.put_u8(byte);
            assert!(codec.decode(&mut buffer).unwrap().is_none());
        }

        // The final byte completes it.
        buffer.put_u8(*last);
        match codec.decode(&mut buffer).unwrap() {
            Some(Message::Pull(id)) => assert_eq!(id, ImportId(7)),
            _ => panic!("Wrong message type"),
        }
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_frame_too_large() {
        let mut codec = CapnWebCodec::with_max_frame_size(100);
        let mut buffer = BytesMut::new();

        let large_string = "x".repeat(200);
        let msg = Message::Push(Expression::String(large_string));

        // Encoding must fail with the right error and write nothing.
        let err = codec.encode(msg, &mut buffer).unwrap_err();
        assert!(matches!(err, CodecError::FrameTooLarge(_)));
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_decode_rejects_oversized_frame() {
        let mut codec = CapnWebCodec::with_max_frame_size(100);
        let mut buffer = BytesMut::new();
        buffer.put_u32(200);

        // Rejected from the prefix alone, before any payload is buffered.
        assert!(matches!(
            codec.decode(&mut buffer),
            Err(CodecError::FrameTooLarge(200))
        ));
    }

    #[test]
    fn test_newline_delimited_partial_line() {
        let mut codec = NewlineDelimitedCodec::new();
        let mut encoded = BytesMut::new();
        codec.encode(Message::Pull(ImportId(3)), &mut encoded).unwrap();

        let mut buffer = BytesMut::new();
        let (newline, body) = encoded.split_last().unwrap();
        buffer.extend_from_slice(body);

        // Repeated polls without a newline yield nothing and consume nothing.
        assert!(codec.decode(&mut buffer).unwrap().is_none());
        assert!(codec.decode(&mut buffer).unwrap().is_none());
        assert_eq!(buffer.len(), body.len());

        buffer.put_u8(*newline);
        match codec.decode(&mut buffer).unwrap() {
            Some(Message::Pull(id)) => assert_eq!(id, ImportId(3)),
            _ => panic!("Wrong message type"),
        }
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_newline_delimited_line_too_long() {
        let mut codec = NewlineDelimitedCodec::new();
        let mut buffer = BytesMut::new();

        // An unterminated line longer than the 1 MiB default must be rejected.
        buffer.extend_from_slice("x".repeat(1024 * 1024 + 1).as_bytes());
        assert!(matches!(
            codec.decode(&mut buffer),
            Err(CodecError::LineTooLong(_))
        ));
    }

    #[test]
    fn test_multiple_messages() {
        let mut codec = NewlineDelimitedCodec::new();
        let mut buffer = BytesMut::new();

        let msg1 = Message::Push(Expression::String("first".to_string()));
        let msg2 = Message::Pull(ImportId(1));
        let msg3 = Message::Resolve(
            ExportId(-1),
            Expression::Number(serde_json::Number::from(42)),
        );

        codec.encode(msg1, &mut buffer).unwrap();
        codec.encode(msg2, &mut buffer).unwrap();
        codec.encode(msg3, &mut buffer).unwrap();

        let decoded1 = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded1 {
            Message::Push(expr) => {
                assert_eq!(expr, Expression::String("first".to_string()));
            }
            _ => panic!("Wrong message type"),
        }

        let decoded2 = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded2 {
            Message::Pull(id) => {
                assert_eq!(id, ImportId(1));
            }
            _ => panic!("Wrong message type"),
        }

        let decoded3 = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded3 {
            Message::Resolve(id, expr) => {
                assert_eq!(id, ExportId(-1));
                match expr {
                    Expression::Number(n) => assert_eq!(n.as_i64(), Some(42)),
                    _ => panic!("Wrong expression type"),
                }
            }
            _ => panic!("Wrong message type"),
        }

        // All three lines consumed; nothing left to decode.
        assert!(buffer.is_empty());
        assert!(codec.decode(&mut buffer).unwrap().is_none());
    }
}