1use crate::{Message, RpcError};
2use bytes::{BufMut, Bytes, BytesMut};
3use serde_json;
4use std::io::{self, Read, Write};
5
/// Wire framing used to delimit JSON-encoded messages on a byte stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum FrameFormat {
    /// Each frame is a 4-byte big-endian `u32` payload length followed by the JSON payload.
    LengthPrefixed,
    /// Each frame is the JSON payload followed by a single `\n` byte. This is the default.
    #[default]
    NewlineDelimited,
}
12
13pub fn encode_message(msg: &Message) -> Result<Bytes, RpcError> {
14 let json = serde_json::to_vec(msg)?;
15 Ok(Bytes::from(json))
16}
17
18pub fn decode_message(data: &[u8]) -> Result<Message, RpcError> {
19 let msg = serde_json::from_slice(data)?;
20 Ok(msg)
21}
22
23pub fn encode_frame(msg: &Message, format: FrameFormat) -> Result<Bytes, RpcError> {
24 let json = serde_json::to_vec(msg)?;
25
26 match format {
27 FrameFormat::LengthPrefixed => {
28 let len = json.len() as u32;
29 let mut buf = BytesMut::with_capacity(4 + json.len());
30 buf.put_u32(len);
31 buf.put_slice(&json);
32 Ok(buf.freeze())
33 }
34 FrameFormat::NewlineDelimited => {
35 let mut buf = BytesMut::with_capacity(json.len() + 1);
36 buf.put_slice(&json);
37 buf.put_u8(b'\n');
38 Ok(buf.freeze())
39 }
40 }
41}
42
43pub fn decode_frame(data: &[u8], format: FrameFormat) -> Result<(Message, usize), RpcError> {
44 match format {
45 FrameFormat::LengthPrefixed => {
46 if data.len() < 4 {
47 return Err(RpcError::bad_request(
48 "Incomplete frame: missing length prefix",
49 ));
50 }
51
52 let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
53 let total_len = 4 + len;
54
55 if data.len() < total_len {
56 return Err(RpcError::bad_request("Incomplete frame: insufficient data"));
57 }
58
59 let msg = decode_message(&data[4..total_len])?;
60 Ok((msg, total_len))
61 }
62 FrameFormat::NewlineDelimited => {
63 let newline_pos = data
64 .iter()
65 .position(|&b| b == b'\n')
66 .ok_or_else(|| RpcError::bad_request("No newline found in frame"))?;
67
68 let msg = decode_message(&data[..newline_pos])?;
69 Ok((msg, newline_pos + 1))
70 }
71 }
72}
73
74pub struct FrameReader<R> {
75 reader: R,
76 buffer: BytesMut,
77 format: FrameFormat,
78}
79
80impl<R: Read> FrameReader<R> {
81 pub fn new(reader: R, format: FrameFormat) -> Self {
82 FrameReader {
83 reader,
84 buffer: BytesMut::with_capacity(4096),
85 format,
86 }
87 }
88
89 pub fn read_frame(&mut self) -> Result<Option<Message>, RpcError> {
90 loop {
91 if let Ok((msg, consumed)) = decode_frame(&self.buffer, self.format) {
92 self.buffer.advance(consumed);
93 return Ok(Some(msg));
94 }
95
96 let mut temp_buf = [0u8; 4096];
97 match self.reader.read(&mut temp_buf) {
98 Ok(0) => return Ok(None),
99 Ok(n) => self.buffer.put_slice(&temp_buf[..n]),
100 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
101 Err(e) => return Err(e.into()),
102 }
103 }
104 }
105}
106
107pub struct FrameWriter<W> {
108 writer: W,
109 format: FrameFormat,
110}
111
112impl<W: Write> FrameWriter<W> {
113 pub fn new(writer: W, format: FrameFormat) -> Self {
114 FrameWriter { writer, format }
115 }
116
117 pub fn write_frame(&mut self, msg: &Message) -> Result<(), RpcError> {
118 let data = encode_frame(msg, self.format)?;
119 self.writer.write_all(&data)?;
120 self.writer.flush()?;
121 Ok(())
122 }
123}
124
125trait BytesMutExt {
126 fn advance(&mut self, cnt: usize);
127}
128
129impl BytesMutExt for BytesMut {
130 fn advance(&mut self, cnt: usize) {
131 if cnt >= self.len() {
132 self.clear();
133 } else {
134 let remaining = self.split_off(cnt);
135 *self = remaining;
136 }
137 }
138}
139
#[cfg(feature = "simd")]
pub mod simd {
    //! JSON (de)serialization backed by `simd_json`, compiled only with the `simd` feature.
    use super::*;
    use simd_json;

    /// Serializes `msg` to JSON bytes using `simd_json`.
    ///
    /// # Errors
    ///
    /// Any encoder failure is reported as a `bad_request` error carrying the encoder's message.
    pub fn encode_message_simd(msg: &Message) -> Result<Bytes, RpcError> {
        simd_json::to_vec(msg)
            .map(Bytes::from)
            .map_err(|err| RpcError::bad_request(format!("SIMD JSON encode error: {}", err)))
    }

    /// Parses one JSON message from `data` using `simd_json`.
    ///
    /// `data` is taken mutably because `simd_json::from_slice` requires a mutable buffer.
    ///
    /// # Errors
    ///
    /// Any parser failure is reported as a `bad_request` error carrying the parser's message.
    pub fn decode_message_simd(data: &mut [u8]) -> Result<Message, RpcError> {
        simd_json::from_slice(data)
            .map_err(|err| RpcError::bad_request(format!("SIMD JSON decode error: {}", err)))
    }
}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::{CallId, CapId};
    use crate::msg::{Outcome, Target};
    use serde_json::json;
    use std::io::Cursor;

    /// A call followed by its result: a small but realistic stream for reader/writer tests.
    fn sample_messages() -> Vec<Message> {
        vec![
            Message::call(
                CallId::new(1),
                Target::cap(CapId::new(10)),
                "method".to_string(),
                vec![json!("test")],
            ),
            Message::result(
                CallId::new(1),
                Outcome::Success {
                    value: json!({"result": true}),
                },
            ),
        ]
    }

    /// Writes `messages` with a `FrameWriter`, reads them back with a `FrameReader`, and checks
    /// that every message round-trips and the reader then reports a clean end of stream.
    fn assert_stream_round_trip(messages: Vec<Message>, format: FrameFormat) {
        let mut wire = Vec::new();
        {
            let mut writer = FrameWriter::new(&mut wire, format);
            for msg in &messages {
                writer.write_frame(msg).unwrap();
            }
        }

        let mut reader = FrameReader::new(Cursor::new(wire), format);
        for expected_msg in messages {
            let msg = reader.read_frame().unwrap().expect("Expected message");
            assert_eq!(msg, expected_msg);
        }
        assert_eq!(reader.read_frame().unwrap(), None);
    }

    #[test]
    fn test_encode_decode_message() {
        let msg = Message::call(
            CallId::new(1),
            Target::cap(CapId::new(42)),
            "test".to_string(),
            vec![json!("hello"), json!(123)],
        );

        let encoded = encode_message(&msg).unwrap();
        let decoded = decode_message(&encoded).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_encode_decode_frame_newline() {
        let msg = Message::cap_ref(CapId::new(99));

        let frame = encode_frame(&msg, FrameFormat::NewlineDelimited).unwrap();
        // The terminator must be the only newline, or the frame would be split on read.
        assert_eq!(frame.last(), Some(&b'\n'));
        assert_eq!(frame.iter().filter(|&&b| b == b'\n').count(), 1);

        let (decoded, consumed) = decode_frame(&frame, FrameFormat::NewlineDelimited).unwrap();
        assert_eq!(msg, decoded);
        assert_eq!(consumed, frame.len());
    }

    #[test]
    fn test_encode_decode_frame_length_prefixed() {
        let msg = Message::dispose(vec![CapId::new(1), CapId::new(2), CapId::new(3)]);

        let frame = encode_frame(&msg, FrameFormat::LengthPrefixed).unwrap();
        assert!(frame.len() > 4);
        let prefix = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]) as usize;
        assert_eq!(prefix, frame.len() - 4);

        let (decoded, consumed) = decode_frame(&frame, FrameFormat::LengthPrefixed).unwrap();
        assert_eq!(msg, decoded);
        assert_eq!(consumed, frame.len());
    }

    #[test]
    fn test_frame_reader_writer() {
        assert_stream_round_trip(sample_messages(), FrameFormat::NewlineDelimited);
    }

    #[test]
    fn test_frame_reader_writer_length_prefixed() {
        assert_stream_round_trip(sample_messages(), FrameFormat::LengthPrefixed);
    }

    #[test]
    fn test_incomplete_frame() {
        let msg = Message::cap_ref(CapId::new(42));
        let frame = encode_frame(&msg, FrameFormat::LengthPrefixed).unwrap();

        let result = decode_frame(&frame[..2], FrameFormat::LengthPrefixed);
        assert!(result.is_err());

        let result = decode_frame(&frame[..frame.len() - 1], FrameFormat::LengthPrefixed);
        assert!(result.is_err());

        let frame = encode_frame(&msg, FrameFormat::NewlineDelimited).unwrap();
        let result = decode_frame(&frame[..frame.len() - 1], FrameFormat::NewlineDelimited);
        assert!(result.is_err());
    }

    #[test]
    fn test_malformed_frame_is_error() {
        assert!(decode_frame(b"not json\n", FrameFormat::NewlineDelimited).is_err());

        let mut frame = BytesMut::new();
        frame.put_u32(3);
        frame.put_slice(b"{{{");
        assert!(decode_frame(&frame, FrameFormat::LengthPrefixed).is_err());
    }

    #[test]
    fn test_multiple_frames_in_buffer() {
        let msg1 = Message::cap_ref(CapId::new(1));
        let msg2 = Message::cap_ref(CapId::new(2));

        let frame1 = encode_frame(&msg1, FrameFormat::NewlineDelimited).unwrap();
        let frame2 = encode_frame(&msg2, FrameFormat::NewlineDelimited).unwrap();

        let mut combined = BytesMut::new();
        combined.put_slice(&frame1);
        combined.put_slice(&frame2);

        let (decoded1, consumed1) = decode_frame(&combined, FrameFormat::NewlineDelimited).unwrap();
        assert_eq!(decoded1, msg1);
        assert_eq!(consumed1, frame1.len());

        let (decoded2, consumed2) =
            decode_frame(&combined[consumed1..], FrameFormat::NewlineDelimited).unwrap();
        assert_eq!(decoded2, msg2);
        assert_eq!(consumed1 + consumed2, combined.len());
    }
}