capnweb_core/
codec.rs

1use crate::{Message, RpcError};
2use bytes::{BufMut, Bytes, BytesMut};
3use serde_json;
4use std::io::{self, Read, Write};
5
/// Wire framing used to delimit consecutive JSON-encoded messages on a byte stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum FrameFormat {
    /// Each payload is preceded by its length as a 4-byte big-endian unsigned integer.
    LengthPrefixed,
    /// Each payload is followed by a single `\n` byte. This is the default format.
    #[default]
    NewlineDelimited,
}
12
13pub fn encode_message(msg: &Message) -> Result<Bytes, RpcError> {
14    let json = serde_json::to_vec(msg)?;
15    Ok(Bytes::from(json))
16}
17
18pub fn decode_message(data: &[u8]) -> Result<Message, RpcError> {
19    let msg = serde_json::from_slice(data)?;
20    Ok(msg)
21}
22
23pub fn encode_frame(msg: &Message, format: FrameFormat) -> Result<Bytes, RpcError> {
24    let json = serde_json::to_vec(msg)?;
25
26    match format {
27        FrameFormat::LengthPrefixed => {
28            let len = json.len() as u32;
29            let mut buf = BytesMut::with_capacity(4 + json.len());
30            buf.put_u32(len);
31            buf.put_slice(&json);
32            Ok(buf.freeze())
33        }
34        FrameFormat::NewlineDelimited => {
35            let mut buf = BytesMut::with_capacity(json.len() + 1);
36            buf.put_slice(&json);
37            buf.put_u8(b'\n');
38            Ok(buf.freeze())
39        }
40    }
41}
42
43pub fn decode_frame(data: &[u8], format: FrameFormat) -> Result<(Message, usize), RpcError> {
44    match format {
45        FrameFormat::LengthPrefixed => {
46            if data.len() < 4 {
47                return Err(RpcError::bad_request(
48                    "Incomplete frame: missing length prefix",
49                ));
50            }
51
52            let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
53            let total_len = 4 + len;
54
55            if data.len() < total_len {
56                return Err(RpcError::bad_request("Incomplete frame: insufficient data"));
57            }
58
59            let msg = decode_message(&data[4..total_len])?;
60            Ok((msg, total_len))
61        }
62        FrameFormat::NewlineDelimited => {
63            let newline_pos = data
64                .iter()
65                .position(|&b| b == b'\n')
66                .ok_or_else(|| RpcError::bad_request("No newline found in frame"))?;
67
68            let msg = decode_message(&data[..newline_pos])?;
69            Ok((msg, newline_pos + 1))
70        }
71    }
72}
73
74pub struct FrameReader<R> {
75    reader: R,
76    buffer: BytesMut,
77    format: FrameFormat,
78}
79
80impl<R: Read> FrameReader<R> {
81    pub fn new(reader: R, format: FrameFormat) -> Self {
82        FrameReader {
83            reader,
84            buffer: BytesMut::with_capacity(4096),
85            format,
86        }
87    }
88
89    pub fn read_frame(&mut self) -> Result<Option<Message>, RpcError> {
90        loop {
91            if let Ok((msg, consumed)) = decode_frame(&self.buffer, self.format) {
92                self.buffer.advance(consumed);
93                return Ok(Some(msg));
94            }
95
96            let mut temp_buf = [0u8; 4096];
97            match self.reader.read(&mut temp_buf) {
98                Ok(0) => return Ok(None),
99                Ok(n) => self.buffer.put_slice(&temp_buf[..n]),
100                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
101                Err(e) => return Err(e.into()),
102            }
103        }
104    }
105}
106
107pub struct FrameWriter<W> {
108    writer: W,
109    format: FrameFormat,
110}
111
112impl<W: Write> FrameWriter<W> {
113    pub fn new(writer: W, format: FrameFormat) -> Self {
114        FrameWriter { writer, format }
115    }
116
117    pub fn write_frame(&mut self, msg: &Message) -> Result<(), RpcError> {
118        let data = encode_frame(msg, self.format)?;
119        self.writer.write_all(&data)?;
120        self.writer.flush()?;
121        Ok(())
122    }
123}
124
125trait BytesMutExt {
126    fn advance(&mut self, cnt: usize);
127}
128
129impl BytesMutExt for BytesMut {
130    fn advance(&mut self, cnt: usize) {
131        if cnt >= self.len() {
132            self.clear();
133        } else {
134            let remaining = self.split_off(cnt);
135            *self = remaining;
136        }
137    }
138}
139
#[cfg(feature = "simd")]
pub mod simd {
    //! JSON encoding and decoding backed by `simd_json`, enabled by the `simd` feature.

    use super::*;

    /// Serializes `msg` to bare JSON bytes using `simd_json`.
    ///
    /// # Errors
    ///
    /// Any serializer failure is reported as a bad-request [`RpcError`].
    pub fn encode_message_simd(msg: &Message) -> Result<Bytes, RpcError> {
        simd_json::to_vec(msg)
            .map(Bytes::from)
            .map_err(|err| RpcError::bad_request(format!("SIMD JSON encode error: {}", err)))
    }

    /// Parses bare JSON bytes into a [`Message`] using `simd_json`.
    ///
    /// The buffer is borrowed mutably because `simd_json::from_slice` requires a mutable
    /// slice and may rewrite it while parsing. Do not rely on its contents afterwards.
    ///
    /// # Errors
    ///
    /// Any parse failure is reported as a bad-request [`RpcError`].
    pub fn decode_message_simd(data: &mut [u8]) -> Result<Message, RpcError> {
        simd_json::from_slice(data)
            .map_err(|err| RpcError::bad_request(format!("SIMD JSON decode error: {}", err)))
    }
}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::{CallId, CapId};
    use crate::msg::{Outcome, Target};
    use serde_json::json;
    use std::io::Cursor;

    /// Writes `messages` through a `FrameWriter` and returns the raw bytes it produced.
    fn write_all_frames(messages: &[Message], format: FrameFormat) -> Vec<u8> {
        let mut buffer = Vec::new();
        {
            let mut writer = FrameWriter::new(&mut buffer, format);
            for msg in messages {
                writer.write_frame(msg).unwrap();
            }
        }
        buffer
    }

    #[test]
    fn test_default_frame_format_is_newline_delimited() {
        assert_eq!(FrameFormat::default(), FrameFormat::NewlineDelimited);
    }

    #[test]
    fn test_encode_decode_message() {
        let msg = Message::call(
            CallId::new(1),
            Target::cap(CapId::new(42)),
            "test".to_string(),
            vec![json!("hello"), json!(123)],
        );

        let encoded = encode_message(&msg).unwrap();
        let decoded = decode_message(&encoded).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_encode_decode_frame_newline() {
        let msg = Message::cap_ref(CapId::new(99));

        let frame = encode_frame(&msg, FrameFormat::NewlineDelimited).unwrap();
        assert_eq!(frame.last(), Some(&b'\n'));

        let (decoded, consumed) = decode_frame(&frame, FrameFormat::NewlineDelimited).unwrap();
        assert_eq!(msg, decoded);
        assert_eq!(consumed, frame.len());
    }

    #[test]
    fn test_encode_decode_frame_length_prefixed() {
        let msg = Message::dispose(vec![CapId::new(1), CapId::new(2), CapId::new(3)]);

        let frame = encode_frame(&msg, FrameFormat::LengthPrefixed).unwrap();
        assert!(frame.len() > 4);
        // The big-endian header must describe exactly the bytes that follow it.
        let declared = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]) as usize;
        assert_eq!(declared, frame.len() - 4);

        let (decoded, consumed) = decode_frame(&frame, FrameFormat::LengthPrefixed).unwrap();
        assert_eq!(msg, decoded);
        assert_eq!(consumed, frame.len());
    }

    #[test]
    fn test_frame_reader_writer() {
        let messages = vec![
            Message::call(
                CallId::new(1),
                Target::cap(CapId::new(10)),
                "method".to_string(),
                vec![json!("test")],
            ),
            Message::result(
                CallId::new(1),
                Outcome::Success {
                    value: json!({"result": true}),
                },
            ),
        ];

        let bytes = write_all_frames(&messages, FrameFormat::NewlineDelimited);
        let mut reader = FrameReader::new(Cursor::new(bytes), FrameFormat::NewlineDelimited);

        for expected in &messages {
            let msg = reader.read_frame().unwrap().expect("Expected message");
            assert_eq!(msg, *expected);
        }

        assert_eq!(reader.read_frame().unwrap(), None);
    }

    #[test]
    fn test_frame_reader_writer_length_prefixed() {
        let messages = vec![
            Message::dispose(vec![CapId::new(7), CapId::new(8)]),
            Message::cap_ref(CapId::new(9)),
        ];

        let bytes = write_all_frames(&messages, FrameFormat::LengthPrefixed);
        let mut reader = FrameReader::new(Cursor::new(bytes), FrameFormat::LengthPrefixed);

        for expected in &messages {
            let msg = reader.read_frame().unwrap().expect("Expected message");
            assert_eq!(msg, *expected);
        }

        assert_eq!(reader.read_frame().unwrap(), None);
    }

    #[test]
    fn test_incomplete_frame() {
        let msg = Message::cap_ref(CapId::new(42));
        let frame = encode_frame(&msg, FrameFormat::LengthPrefixed).unwrap();

        let result = decode_frame(&frame[..2], FrameFormat::LengthPrefixed);
        assert!(result.is_err());

        let result = decode_frame(&frame[..frame.len() - 1], FrameFormat::LengthPrefixed);
        assert!(result.is_err());

        let nl_frame = encode_frame(&msg, FrameFormat::NewlineDelimited).unwrap();
        let result = decode_frame(&nl_frame[..nl_frame.len() - 1], FrameFormat::NewlineDelimited);
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_frame_rejects_malformed_payload() {
        assert!(decode_frame(b"not json\n", FrameFormat::NewlineDelimited).is_err());

        let mut frame = BytesMut::new();
        frame.put_u32(3);
        frame.put_slice(b"{{{");
        assert!(decode_frame(&frame, FrameFormat::LengthPrefixed).is_err());
    }

    #[test]
    fn test_multiple_frames_in_buffer() {
        let msg1 = Message::cap_ref(CapId::new(1));
        let msg2 = Message::cap_ref(CapId::new(2));

        let frame1 = encode_frame(&msg1, FrameFormat::NewlineDelimited).unwrap();
        let frame2 = encode_frame(&msg2, FrameFormat::NewlineDelimited).unwrap();

        let mut combined = BytesMut::new();
        combined.put_slice(&frame1);
        combined.put_slice(&frame2);

        let (decoded1, consumed1) = decode_frame(&combined, FrameFormat::NewlineDelimited).unwrap();
        assert_eq!(decoded1, msg1);

        let (decoded2, consumed2) =
            decode_frame(&combined[consumed1..], FrameFormat::NewlineDelimited).unwrap();
        assert_eq!(decoded2, msg2);
        assert_eq!(consumed1 + consumed2, combined.len());
    }
}