capnweb_transport/
capnweb_codec.rs1use bytes::{Buf, BufMut, BytesMut};
2use capnweb_core::protocol::Message;
3use serde_json;
4use std::io;
5use tokio_util::codec::{Decoder, Encoder};
6
/// Length-prefixed JSON codec for Cap'n Web [`Message`]s.
///
/// Each frame on the wire is a 4-byte big-endian `u32` payload length followed by that
/// many bytes of JSON. Payloads longer than `max_frame_size` are rejected both when
/// decoding and when encoding.
#[derive(Debug, Clone)]
pub struct CapnWebCodec {
    /// Largest accepted payload length in bytes (the 4-byte length header is not counted).
    max_frame_size: usize,
}

impl CapnWebCodec {
    /// Payload limit used by [`CapnWebCodec::new`] and [`Default`]: 10 MiB.
    pub const DEFAULT_MAX_FRAME_SIZE: usize = 10 * 1024 * 1024;

    /// Creates a codec whose payload limit is [`Self::DEFAULT_MAX_FRAME_SIZE`].
    pub fn new() -> Self {
        Self::with_max_frame_size(Self::DEFAULT_MAX_FRAME_SIZE)
    }

    /// Creates a codec that rejects payloads longer than `max_frame_size` bytes.
    pub fn with_max_frame_size(max_frame_size: usize) -> Self {
        Self { max_frame_size }
    }

    /// Returns the configured payload limit in bytes.
    pub fn max_frame_size(&self) -> usize {
        self.max_frame_size
    }
}

impl Default for CapnWebCodec {
    /// Same as [`CapnWebCodec::new`].
    fn default() -> Self {
        Self::new()
    }
}
33
/// Wire framing styles, named after the two codecs in this module.
///
/// NOTE(review): this enum is not referenced by the code visible in this file; presumably
/// callers use it to choose between `CapnWebCodec` and `NewlineDelimitedCodec` — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameFormat {
    /// A 4-byte big-endian `u32` byte count followed by the JSON payload (as in `CapnWebCodec`).
    LengthPrefixed,
    /// A JSON payload terminated by `\n` (as in `NewlineDelimitedCodec`).
    NewlineDelimited,
}
42
43impl Decoder for CapnWebCodec {
45 type Item = Message;
46 type Error = CodecError;
47
48 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
49 if src.len() < 4 {
51 return Ok(None);
52 }
53
54 let mut length_bytes = [0u8; 4];
56 length_bytes.copy_from_slice(&src[..4]);
57 let frame_len = u32::from_be_bytes(length_bytes) as usize;
58
59 if frame_len > self.max_frame_size {
61 return Err(CodecError::FrameTooLarge(frame_len));
62 }
63
64 if src.len() < 4 + frame_len {
66 src.reserve(4 + frame_len - src.len());
68 return Ok(None);
69 }
70
71 src.advance(4); let frame_data = src.split_to(frame_len);
74
75 let json_value: serde_json::Value = serde_json::from_slice(&frame_data)
77 .map_err(|e| CodecError::JsonError(e.to_string()))?;
78
79 let message =
81 Message::from_json(&json_value).map_err(|e| CodecError::MessageError(e.to_string()))?;
82
83 Ok(Some(message))
84 }
85}
86
87impl Encoder<Message> for CapnWebCodec {
89 type Error = CodecError;
90
91 fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
92 let json_value = item.to_json();
94 let json_bytes =
95 serde_json::to_vec(&json_value).map_err(|e| CodecError::JsonError(e.to_string()))?;
96
97 if json_bytes.len() > self.max_frame_size {
99 return Err(CodecError::FrameTooLarge(json_bytes.len()));
100 }
101
102 let length = json_bytes.len() as u32;
104 dst.reserve(4 + json_bytes.len());
105 dst.put_u32(length);
106 dst.put_slice(&json_bytes);
107
108 Ok(())
109 }
110}
111
/// Newline-delimited JSON codec for Cap'n Web [`Message`]s.
///
/// Each frame is one JSON document followed by a single `\n`. Lines whose JSON text is
/// longer than `max_line_length` bytes are rejected both when decoding and when encoding.
#[derive(Debug, Clone)]
pub struct NewlineDelimitedCodec {
    /// Largest accepted line length in bytes, not counting the trailing `\n`.
    max_line_length: usize,
}

impl NewlineDelimitedCodec {
    /// Line limit used by [`NewlineDelimitedCodec::new`] and [`Default`]: 1 MiB.
    pub const DEFAULT_MAX_LINE_LENGTH: usize = 1024 * 1024;

    /// Creates a codec whose line limit is [`Self::DEFAULT_MAX_LINE_LENGTH`].
    pub fn new() -> Self {
        Self::with_max_line_length(Self::DEFAULT_MAX_LINE_LENGTH)
    }

    /// Creates a codec that rejects lines longer than `max_line_length` bytes.
    ///
    /// This mirrors `CapnWebCodec::with_max_frame_size` for the length-prefixed codec.
    pub fn with_max_line_length(max_line_length: usize) -> Self {
        Self { max_line_length }
    }

    /// Returns the configured line limit in bytes.
    pub fn max_line_length(&self) -> usize {
        self.max_line_length
    }
}

impl Default for NewlineDelimitedCodec {
    /// Same as [`NewlineDelimitedCodec::new`].
    fn default() -> Self {
        Self::new()
    }
}
130
131impl Decoder for NewlineDelimitedCodec {
132 type Item = Message;
133 type Error = CodecError;
134
135 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
136 let newline_pos = src.iter().position(|&b| b == b'\n');
138
139 if let Some(pos) = newline_pos {
140 if pos > self.max_line_length {
142 return Err(CodecError::LineTooLong(pos));
143 }
144
145 let line = src.split_to(pos);
147 src.advance(1); let json_value: serde_json::Value =
151 serde_json::from_slice(&line).map_err(|e| CodecError::JsonError(e.to_string()))?;
152
153 let message = Message::from_json(&json_value)
155 .map_err(|e| CodecError::MessageError(e.to_string()))?;
156
157 Ok(Some(message))
158 } else {
159 if src.len() > self.max_line_length {
161 return Err(CodecError::LineTooLong(src.len()));
162 }
163
164 Ok(None)
165 }
166 }
167}
168
169impl Encoder<Message> for NewlineDelimitedCodec {
170 type Error = CodecError;
171
172 fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
173 let json_value = item.to_json();
175 let json_bytes =
176 serde_json::to_vec(&json_value).map_err(|e| CodecError::JsonError(e.to_string()))?;
177
178 if json_bytes.len() > self.max_line_length {
180 return Err(CodecError::LineTooLong(json_bytes.len()));
181 }
182
183 dst.reserve(json_bytes.len() + 1);
185 dst.put_slice(&json_bytes);
186 dst.put_u8(b'\n');
187
188 Ok(())
189 }
190}
191
192#[derive(Debug, thiserror::Error)]
194pub enum CodecError {
195 #[error("Frame too large: {0} bytes")]
196 FrameTooLarge(usize),
197
198 #[error("Line too long: {0} bytes")]
199 LineTooLong(usize),
200
201 #[error("JSON error: {0}")]
202 JsonError(String),
203
204 #[error("Message parse error: {0}")]
205 MessageError(String),
206
207 #[error("IO error: {0}")]
208 IoError(#[from] io::Error),
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use capnweb_core::{
        protocol::{ExportId, ImportId},
        Expression,
    };

    #[test]
    fn test_length_prefixed_encode_decode() {
        let mut codec = CapnWebCodec::new();
        let mut buffer = BytesMut::new();

        let msg = Message::Push(Expression::String("test".to_string()));
        codec.encode(msg.clone(), &mut buffer).unwrap();

        assert!(buffer.len() > 4);
        // The header must announce exactly the number of payload bytes that follow it.
        let announced = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]) as usize;
        assert_eq!(announced, buffer.len() - 4);

        let decoded = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded {
            Message::Push(expr) => {
                assert_eq!(expr, Expression::String("test".to_string()));
            }
            _ => panic!("Wrong message type"),
        }

        assert_eq!(buffer.len(), 0);
    }

    #[test]
    fn test_newline_delimited_encode_decode() {
        let mut codec = NewlineDelimitedCodec::new();
        let mut buffer = BytesMut::new();

        let msg = Message::Pull(ImportId(42));
        codec.encode(msg, &mut buffer).unwrap();

        assert_eq!(buffer[buffer.len() - 1], b'\n');
        // Compact JSON never contains a raw newline, so the terminator is the only one.
        assert_eq!(buffer.iter().filter(|&&b| b == b'\n').count(), 1);

        let decoded = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded {
            Message::Pull(id) => {
                assert_eq!(id, ImportId(42));
            }
            _ => panic!("Wrong message type"),
        }

        assert!(buffer.is_empty());
    }

    #[test]
    fn test_partial_frame() {
        let mut codec = CapnWebCodec::new();
        let mut buffer = BytesMut::new();

        // Only half of the length header.
        buffer.put_u8(0);
        buffer.put_u8(0);
        assert!(codec.decode(&mut buffer).unwrap().is_none());

        // Full header announcing 10 bytes, but no payload yet.
        buffer.put_u8(0);
        buffer.put_u8(10);
        assert!(codec.decode(&mut buffer).unwrap().is_none());
        assert_eq!(buffer.len(), 4);
    }

    #[test]
    fn test_frame_too_large() {
        let mut codec = CapnWebCodec::with_max_frame_size(100);
        let mut buffer = BytesMut::new();

        let large_string = "x".repeat(200);
        let msg = Message::Push(Expression::String(large_string));

        assert!(matches!(
            codec.encode(msg, &mut buffer),
            Err(CodecError::FrameTooLarge(_))
        ));
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_decode_rejects_oversized_length_header() {
        let mut codec = CapnWebCodec::with_max_frame_size(16);
        let mut buffer = BytesMut::new();
        buffer.put_u32(17);

        assert!(matches!(
            codec.decode(&mut buffer),
            Err(CodecError::FrameTooLarge(17))
        ));
        // The header is rejected before anything is consumed.
        assert_eq!(buffer.len(), 4);
    }

    #[test]
    fn test_newline_line_too_long() {
        let mut codec = NewlineDelimitedCodec { max_line_length: 8 };

        let mut terminated = BytesMut::from(&b"0123456789\n"[..]);
        assert!(matches!(
            codec.decode(&mut terminated),
            Err(CodecError::LineTooLong(10))
        ));

        let mut unterminated = BytesMut::from(&b"0123456789"[..]);
        assert!(matches!(
            codec.decode(&mut unterminated),
            Err(CodecError::LineTooLong(10))
        ));
    }

    #[test]
    fn test_newline_partial_line_waits() {
        let mut codec = NewlineDelimitedCodec::new();
        let mut buffer = BytesMut::from(&b"{\"partial\""[..]);
        let before = buffer.len();

        assert!(codec.decode(&mut buffer).unwrap().is_none());
        assert_eq!(buffer.len(), before);
    }

    #[test]
    fn test_newline_encode_line_too_long() {
        let mut codec = NewlineDelimitedCodec { max_line_length: 100 };
        let mut buffer = BytesMut::new();

        let msg = Message::Push(Expression::String("x".repeat(200)));
        assert!(matches!(
            codec.encode(msg, &mut buffer),
            Err(CodecError::LineTooLong(_))
        ));
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_multiple_messages() {
        let mut codec = NewlineDelimitedCodec::new();
        let mut buffer = BytesMut::new();

        let msg1 = Message::Push(Expression::String("first".to_string()));
        let msg2 = Message::Pull(ImportId(1));
        let msg3 = Message::Resolve(
            ExportId(-1),
            Expression::Number(serde_json::Number::from(42)),
        );

        codec.encode(msg1, &mut buffer).unwrap();
        codec.encode(msg2, &mut buffer).unwrap();
        codec.encode(msg3, &mut buffer).unwrap();

        let decoded1 = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded1 {
            Message::Push(expr) => {
                assert_eq!(expr, Expression::String("first".to_string()));
            }
            _ => panic!("Wrong message type"),
        }

        let decoded2 = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded2 {
            Message::Pull(id) => {
                assert_eq!(id, ImportId(1));
            }
            _ => panic!("Wrong message type"),
        }

        let decoded3 = codec.decode(&mut buffer).unwrap().unwrap();
        match decoded3 {
            Message::Resolve(id, expr) => {
                assert_eq!(id, ExportId(-1));
                match expr {
                    Expression::Number(n) => assert_eq!(n.as_i64(), Some(42)),
                    _ => panic!("Wrong expression type"),
                }
            }
            _ => panic!("Wrong message type"),
        }

        // Everything was consumed; a further decode simply waits for more input.
        assert!(buffer.is_empty());
        assert!(codec.decode(&mut buffer).unwrap().is_none());
    }
}