fastmcp_transport/
codec.rs1use fastmcp_protocol::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
6
7#[derive(Debug)]
9pub struct Codec {
10 buffer: Vec<u8>,
12 read_pos: usize,
14 max_message_size: usize,
16}
17
18impl Default for Codec {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24const COMPACT_THRESHOLD: usize = 4096;
26
27impl Codec {
28 #[must_use]
30 pub fn new() -> Self {
31 Self {
32 buffer: Vec::new(),
33 read_pos: 0,
34 max_message_size: 10 * 1024 * 1024, }
36 }
37
38 #[must_use]
40 pub fn max_message_size(&self) -> usize {
41 self.max_message_size
42 }
43
44 pub fn set_max_message_size(&mut self, size: usize) {
46 self.max_message_size = size;
47 let unread = self.buffer.len() - self.read_pos;
48 if unread > size {
49 self.buffer.clear();
50 self.read_pos = 0;
51 }
52 }
53
54 pub fn encode_request(&self, request: &JsonRpcRequest) -> Result<Vec<u8>, CodecError> {
60 let mut bytes = serde_json::to_vec(request)?;
61 bytes.push(b'\n');
62 Ok(bytes)
63 }
64
65 pub fn encode_response(&self, response: &JsonRpcResponse) -> Result<Vec<u8>, CodecError> {
71 let mut bytes = serde_json::to_vec(response)?;
72 bytes.push(b'\n');
73 Ok(bytes)
74 }
75
76 pub fn decode(&mut self, data: &[u8]) -> Result<Vec<JsonRpcMessage>, CodecError> {
84 let unread_len = self.buffer.len() - self.read_pos;
86 let projected_size = unread_len.saturating_add(data.len());
87
88 if projected_size > self.max_message_size {
90 self.buffer.clear();
91 self.read_pos = 0;
92 return Err(CodecError::MessageTooLarge(projected_size));
93 }
94
95 if self.read_pos >= COMPACT_THRESHOLD {
97 self.buffer.drain(..self.read_pos);
98 self.read_pos = 0;
99 }
100
101 self.buffer.extend_from_slice(data);
102
103 let mut messages = Vec::new();
104 let mut start = self.read_pos;
105
106 #[allow(clippy::mut_range_bound)]
107 for i in start..self.buffer.len() {
108 if self.buffer[i] == b'\n' {
109 let line_len = i - start;
110 if line_len > self.max_message_size {
111 self.buffer.clear();
112 self.read_pos = 0;
113 return Err(CodecError::MessageTooLarge(line_len));
114 }
115 let line = &self.buffer[start..i];
116 if !line.is_empty() {
117 let msg: JsonRpcMessage = serde_json::from_slice(line)?;
118 messages.push(msg);
119 }
120 start = i + 1;
121 }
122 }
123
124 self.read_pos = start;
126
127 let remaining = self.buffer.len() - self.read_pos;
129 if remaining > self.max_message_size {
130 self.buffer.clear();
131 self.read_pos = 0;
132 return Err(CodecError::MessageTooLarge(remaining));
133 }
134
135 Ok(messages)
136 }
137
138 pub fn clear(&mut self) {
140 self.buffer.clear();
141 self.read_pos = 0;
142 }
143}
144
145#[derive(Debug)]
147pub enum CodecError {
148 Json(serde_json::Error),
150 MessageTooLarge(usize),
152}
153
154impl std::fmt::Display for CodecError {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 match self {
157 CodecError::Json(e) => write!(f, "JSON error: {e}"),
158 CodecError::MessageTooLarge(size) => write!(f, "Message too large: {size} bytes"),
159 }
160 }
161}
162
163impl std::error::Error for CodecError {
164 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
165 match self {
166 CodecError::Json(e) => Some(e),
167 CodecError::MessageTooLarge(_) => None,
168 }
169 }
170}
171
172impl From<serde_json::Error> for CodecError {
173 fn from(err: serde_json::Error) -> Self {
174 CodecError::Json(err)
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_protocol::RequestId;
    use std::error::Error;

    /// Builds a `CodecError::Json` from a genuine `serde_json` parse failure.
    fn sample_json_error() -> CodecError {
        CodecError::Json(serde_json::from_str::<()>("invalid").unwrap_err())
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let request = JsonRpcRequest::new("test/method", None, 1i64);
        let encoded = Codec::new().encode_request(&request).unwrap();
        assert!(encoded.ends_with(b"\n"));

        let decoded = Codec::new().decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 1);
    }

    #[test]
    fn test_encode_response() {
        let response =
            JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"result": "ok"}));
        let encoded = Codec::new().encode_response(&response).unwrap();
        assert!(encoded.ends_with(b"\n"));

        let decoded = Codec::new().decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 1);

        let JsonRpcMessage::Response(resp) = &decoded[0] else {
            panic!("Expected response");
        };
        assert_eq!(resp.id, Some(RequestId::Number(1)));
    }

    #[test]
    fn test_decode_multiple_messages() {
        let input = b"{\"jsonrpc\":\"2.0\",\"method\":\"test1\",\"id\":1}\n{\"jsonrpc\":\"2.0\",\"method\":\"test2\",\"id\":2}\n";

        let decoded = Codec::new().decode(input).unwrap();
        assert_eq!(decoded.len(), 2);

        for (msg, expected) in decoded.iter().zip(["test1", "test2"]) {
            let JsonRpcMessage::Request(req) = msg else {
                panic!("Expected request");
            };
            assert_eq!(req.method, expected);
        }
    }

    #[test]
    fn test_decode_allows_multiple_messages_in_separate_chunks() {
        // Serialize a request and append the framing newline.
        let encode_line = |req: &JsonRpcRequest| {
            let mut line = serde_json::to_vec(req).unwrap();
            line.push(b'\n');
            line
        };
        let first = encode_line(&JsonRpcRequest::new("test1", None, 1i64));
        let second = encode_line(&JsonRpcRequest::new("test2", None, 2i64));

        let mut codec = Codec::new();
        codec.set_max_message_size(first.len());

        assert_eq!(codec.decode(&first).unwrap().len(), 1);
        assert_eq!(codec.decode(&second).unwrap().len(), 1);
    }

    #[test]
    fn test_decode_rejects_oversized_incomplete_line() {
        let line = serde_json::to_vec(&JsonRpcRequest::new("oversized", None, 1i64)).unwrap();

        let mut codec = Codec::new();
        codec.max_message_size = line.len().saturating_sub(1);

        assert!(matches!(
            codec.decode(&line),
            Err(CodecError::MessageTooLarge(_))
        ));
    }

    #[test]
    fn test_decode_partial_message() {
        let mut codec = Codec::new();

        let head = b"{\"jsonrpc\":\"2.0\",\"method\":\"test\"";
        assert!(codec.decode(head).unwrap().is_empty());

        let tail = b",\"id\":1}\n";
        let decoded = codec.decode(tail).unwrap();
        assert_eq!(decoded.len(), 1);

        let JsonRpcMessage::Request(req) = &decoded[0] else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "test");
    }

    #[test]
    fn test_decode_invalid_json() {
        let outcome = Codec::new().decode(b"not valid json\n");
        assert!(matches!(outcome, Err(CodecError::Json(_))));
    }

    #[test]
    fn test_decode_empty_line() {
        let input = b"\n{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"id\":1}\n";
        let decoded = Codec::new().decode(input).unwrap();
        assert_eq!(decoded.len(), 1);
    }

    #[test]
    fn test_clear_buffer() {
        let mut codec = Codec::new();

        // Leave a dangling partial message in the buffer, then discard it.
        codec.decode(b"{\"jsonrpc\":\"2.0\"").unwrap();
        codec.clear();

        let decoded = codec
            .decode(b"{\"jsonrpc\":\"2.0\",\"method\":\"fresh\",\"id\":1}\n")
            .unwrap();
        assert_eq!(decoded.len(), 1);

        let JsonRpcMessage::Request(req) = &decoded[0] else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "fresh");
    }

    #[test]
    fn test_codec_error_display() {
        assert!(sample_json_error().to_string().contains("JSON error"));
        assert!(CodecError::MessageTooLarge(1000).to_string().contains("1000"));
    }

    #[test]
    fn test_codec_error_source() {
        assert!(sample_json_error().source().is_some());
        assert!(CodecError::MessageTooLarge(1000).source().is_none());
    }
}