1use fastmcp_protocol::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
6
/// Incremental encoder/decoder for newline-delimited JSON-RPC messages.
///
/// Incoming bytes accumulate in `buffer`. The prefix `buffer[..read_pos]`
/// has already been decoded and is reclaimed lazily by the decoder.
#[derive(Debug)]
pub struct Codec {
    /// Raw input received so far, including bytes that were already consumed.
    buffer: Vec<u8>,
    /// Offset of the first byte in `buffer` that has not been consumed yet.
    read_pos: usize,
    /// Upper bound, in bytes, on the size of a message accepted by decoding.
    max_message_size: usize,
}
17
18impl Default for Codec {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
/// Once at least this many consumed bytes sit at the front of the decode
/// buffer, `Codec::decode` drains them before appending new input.
const COMPACT_THRESHOLD: usize = 4 * 1024;
26
27impl Codec {
28 #[must_use]
30 pub fn new() -> Self {
31 Self {
32 buffer: Vec::new(),
33 read_pos: 0,
34 max_message_size: 10 * 1024 * 1024, }
36 }
37
38 #[must_use]
40 pub fn max_message_size(&self) -> usize {
41 self.max_message_size
42 }
43
44 pub fn set_max_message_size(&mut self, size: usize) {
46 self.max_message_size = size;
47 let unread = self.buffer.len() - self.read_pos;
48 if unread > size {
49 self.buffer.clear();
50 self.read_pos = 0;
51 }
52 }
53
54 pub fn encode_request(&self, request: &JsonRpcRequest) -> Result<Vec<u8>, CodecError> {
60 let mut bytes = serde_json::to_vec(request)?;
61 bytes.push(b'\n');
62 Ok(bytes)
63 }
64
65 pub fn encode_response(&self, response: &JsonRpcResponse) -> Result<Vec<u8>, CodecError> {
71 let mut bytes = serde_json::to_vec(response)?;
72 bytes.push(b'\n');
73 Ok(bytes)
74 }
75
76 pub fn decode(&mut self, data: &[u8]) -> Result<Vec<JsonRpcMessage>, CodecError> {
84 let unread_len = self.buffer.len() - self.read_pos;
86 let projected_size = unread_len.saturating_add(data.len());
87
88 if projected_size > self.max_message_size {
90 self.buffer.clear();
91 self.read_pos = 0;
92 return Err(CodecError::MessageTooLarge(projected_size));
93 }
94
95 if self.read_pos >= COMPACT_THRESHOLD {
97 self.buffer.drain(..self.read_pos);
98 self.read_pos = 0;
99 }
100
101 self.buffer.extend_from_slice(data);
102
103 let mut messages = Vec::new();
104 let mut start = self.read_pos;
105
106 #[allow(clippy::mut_range_bound)]
107 for i in start..self.buffer.len() {
108 if self.buffer[i] == b'\n' {
109 let line_len = i - start;
110 if line_len > self.max_message_size {
111 self.buffer.clear();
112 self.read_pos = 0;
113 return Err(CodecError::MessageTooLarge(line_len));
114 }
115 let line = &self.buffer[start..i];
116 if !line.is_empty() {
117 let msg: JsonRpcMessage = serde_json::from_slice(line)?;
118 messages.push(msg);
119 }
120 start = i + 1;
121 }
122 }
123
124 self.read_pos = start;
126
127 let remaining = self.buffer.len() - self.read_pos;
129 if remaining > self.max_message_size {
130 self.buffer.clear();
131 self.read_pos = 0;
132 return Err(CodecError::MessageTooLarge(remaining));
133 }
134
135 Ok(messages)
136 }
137
138 pub fn clear(&mut self) {
140 self.buffer.clear();
141 self.read_pos = 0;
142 }
143}
144
145#[derive(Debug)]
147pub enum CodecError {
148 Json(serde_json::Error),
150 MessageTooLarge(usize),
152}
153
154impl std::fmt::Display for CodecError {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 match self {
157 CodecError::Json(e) => write!(f, "JSON error: {e}"),
158 CodecError::MessageTooLarge(size) => write!(f, "Message too large: {size} bytes"),
159 }
160 }
161}
162
163impl std::error::Error for CodecError {
164 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
165 match self {
166 CodecError::Json(e) => Some(e),
167 CodecError::MessageTooLarge(_) => None,
168 }
169 }
170}
171
172impl From<serde_json::Error> for CodecError {
173 fn from(err: serde_json::Error) -> Self {
174 CodecError::Json(err)
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_protocol::RequestId;
    use std::error::Error;

    #[test]
    fn test_encode_decode_roundtrip() {
        let codec = Codec::new();
        let request = JsonRpcRequest::new("test/method", None, 1i64);

        let encoded = codec.encode_request(&request).unwrap();
        assert!(encoded.ends_with(b"\n"));

        let mut codec2 = Codec::new();
        let messages = codec2.decode(&encoded).unwrap();
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn test_encode_response() {
        let codec = Codec::new();
        let response =
            JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"result": "ok"}));

        let encoded = codec.encode_response(&response).unwrap();
        assert!(encoded.ends_with(b"\n"));

        let mut codec2 = Codec::new();
        let messages = codec2.decode(&encoded).unwrap();
        assert_eq!(messages.len(), 1);

        assert!(
            matches!(&messages[0], JsonRpcMessage::Response(_)),
            "Expected response"
        );
        if let JsonRpcMessage::Response(resp) = &messages[0] {
            assert_eq!(resp.id, Some(RequestId::Number(1)));
        }
    }

    #[test]
    fn test_decode_multiple_messages() {
        let input = b"{\"jsonrpc\":\"2.0\",\"method\":\"test1\",\"id\":1}\n{\"jsonrpc\":\"2.0\",\"method\":\"test2\",\"id\":2}\n";

        let mut codec = Codec::new();
        let messages = codec.decode(input).unwrap();

        assert_eq!(messages.len(), 2);

        assert!(
            matches!(&messages[0], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "test1");
        }

        assert!(
            matches!(&messages[1], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[1] {
            assert_eq!(req.method, "test2");
        }
    }

    #[test]
    fn test_decode_allows_multiple_messages_in_separate_chunks() {
        let req1 = JsonRpcRequest::new("test1", None, 1i64);
        let req2 = JsonRpcRequest::new("test2", None, 2i64);
        let mut line1 = serde_json::to_vec(&req1).unwrap();
        let mut line2 = serde_json::to_vec(&req2).unwrap();
        line1.push(b'\n');
        line2.push(b'\n');

        let mut codec = Codec::new();
        // The limit fits exactly one framed line; each chunk must still pass.
        codec.set_max_message_size(line1.len());

        let messages1 = codec.decode(&line1).unwrap();
        assert_eq!(messages1.len(), 1);

        let messages2 = codec.decode(&line2).unwrap();
        assert_eq!(messages2.len(), 1);
    }

    #[test]
    fn test_decode_rejects_oversized_incomplete_line() {
        let req = JsonRpcRequest::new("oversized", None, 1i64);
        let line = serde_json::to_vec(&req).unwrap();

        let mut codec = Codec::new();
        codec.max_message_size = line.len().saturating_sub(1);

        let result = codec.decode(&line);
        assert!(matches!(result, Err(CodecError::MessageTooLarge(_))));
        // Oversized input must not be retained.
        assert!(codec.buffer.is_empty());
        assert_eq!(codec.read_pos, 0);
    }

    #[test]
    fn test_decode_partial_message() {
        let mut codec = Codec::new();

        let partial = b"{\"jsonrpc\":\"2.0\",\"method\":\"test\"";
        let messages = codec.decode(partial).unwrap();
        assert_eq!(messages.len(), 0);

        let rest = b",\"id\":1}\n";
        let messages = codec.decode(rest).unwrap();
        assert_eq!(messages.len(), 1);

        assert!(
            matches!(&messages[0], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "test");
        }
    }

    #[test]
    fn test_decode_invalid_json() {
        let mut codec = Codec::new();
        let invalid = b"not valid json\n";

        let result = codec.decode(invalid);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(matches!(err, CodecError::Json(_)));
    }

    #[test]
    fn test_decode_empty_line() {
        let mut codec = Codec::new();
        let input = b"\n{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"id\":1}\n";

        let messages = codec.decode(input).unwrap();
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn test_clear_buffer() {
        let mut codec = Codec::new();

        let partial = b"{\"jsonrpc\":\"2.0\"";
        codec.decode(partial).unwrap();

        codec.clear();
        assert!(codec.buffer.is_empty());
        assert_eq!(codec.read_pos, 0);

        let complete = b"{\"jsonrpc\":\"2.0\",\"method\":\"fresh\",\"id\":1}\n";
        let messages = codec.decode(complete).unwrap();

        assert_eq!(messages.len(), 1);
        assert!(
            matches!(&messages[0], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "fresh");
        }
    }

    #[test]
    fn test_codec_error_display() {
        let json_err = CodecError::Json(serde_json::from_str::<()>("invalid").unwrap_err());
        let size_err = CodecError::MessageTooLarge(1000);

        assert!(json_err.to_string().contains("JSON error"));
        assert!(size_err.to_string().contains("1000"));
    }

    #[test]
    fn test_codec_error_source() {
        let json_err = CodecError::Json(serde_json::from_str::<()>("invalid").unwrap_err());
        let size_err = CodecError::MessageTooLarge(1000);

        assert!(json_err.source().is_some());
        assert!(size_err.source().is_none());
    }

    #[test]
    fn test_default_max_message_size() {
        let codec = Codec::new();
        assert_eq!(codec.max_message_size(), 10 * 1024 * 1024);
    }

    #[test]
    fn test_set_max_message_size() {
        let mut codec = Codec::new();
        codec.set_max_message_size(1024);
        assert_eq!(codec.max_message_size(), 1024);
    }

    #[test]
    fn test_set_max_message_size_clears_oversized_buffer() {
        let mut codec = Codec::new();
        let partial = b"{\"jsonrpc\":\"2.0\",\"method\":\"test\"";
        codec.decode(partial).unwrap();
        assert_eq!(codec.buffer.len(), partial.len());

        codec.set_max_message_size(5);
        assert!(codec.buffer.is_empty());
        assert_eq!(codec.read_pos, 0);

        // The discarded partial line must not count against the new limit.
        // "{}" may or may not be a valid message, but it is not oversized.
        let small = b"{}\n";
        let result = codec.decode(small);
        assert!(
            !matches!(result, Err(CodecError::MessageTooLarge(_))),
            "stale buffered data should have been cleared"
        );
    }

    #[test]
    fn test_codec_default_trait() {
        let codec = Codec::default();
        assert_eq!(codec.max_message_size(), 10 * 1024 * 1024);
    }

    #[test]
    fn test_decode_oversized_projected_data() {
        let mut codec = Codec::new();
        codec.set_max_message_size(50);

        let big = vec![b'x'; 100];
        let result = codec.decode(&big);
        assert!(matches!(result, Err(CodecError::MessageTooLarge(_))));
        assert!(codec.buffer.is_empty());
    }

    #[test]
    fn test_buffer_compaction_after_threshold() {
        let mut codec = Codec::new();

        let msg = b"{\"jsonrpc\":\"2.0\",\"method\":\"m\",\"id\":1}\n";
        let many_messages: Vec<u8> = msg.repeat(200);
        let messages = codec.decode(&many_messages).unwrap();
        assert_eq!(messages.len(), 200);
        // Enough bytes were consumed that the next decode must compact.
        assert!(codec.read_pos >= COMPACT_THRESHOLD);

        let next_msg = b"{\"jsonrpc\":\"2.0\",\"method\":\"after_compact\",\"id\":2}\n";
        let messages = codec.decode(next_msg).unwrap();
        assert_eq!(messages.len(), 1);
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "after_compact");
        }
        // Only the newly appended line remains after compaction.
        assert_eq!(codec.buffer.len(), next_msg.len());
        assert_eq!(codec.read_pos, next_msg.len());
    }

    #[test]
    fn test_decode_utf8_message() {
        let mut codec = Codec::new();
        let json = "{\"jsonrpc\":\"2.0\",\"method\":\"test/日本語\",\"id\":1}\n";
        let messages = codec.decode(json.as_bytes()).unwrap();
        assert_eq!(messages.len(), 1);
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "test/日本語");
        }
    }

    #[test]
    fn test_decode_consecutive_newlines() {
        let mut codec = Codec::new();
        let input = b"\n\n{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"id\":1}\n\n\n";
        let messages = codec.decode(input).unwrap();
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn test_clear_resets_state() {
        let mut codec = Codec::new();

        codec.decode(b"{\"jsonrpc\":\"2.0\"").unwrap();
        codec.clear();
        assert!(codec.buffer.is_empty());
        assert_eq!(codec.read_pos, 0);

        let complete = b"{\"jsonrpc\":\"2.0\",\"method\":\"post_clear\",\"id\":1}\n";
        let messages = codec.decode(complete).unwrap();
        assert_eq!(messages.len(), 1);
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "post_clear");
        }
    }

    #[test]
    fn test_codec_error_from_serde() {
        let serde_err = serde_json::from_str::<()>("bad").unwrap_err();
        let codec_err: CodecError = serde_err.into();
        assert!(matches!(codec_err, CodecError::Json(_)));
    }

    #[test]
    fn test_encode_request_contains_newline() {
        let codec = Codec::new();
        let request = JsonRpcRequest::new("m", None, 1i64);
        let encoded = codec.encode_request(&request).unwrap();
        assert_eq!(encoded.last(), Some(&b'\n'));
        let json_part = &encoded[..encoded.len() - 1];
        let _: JsonRpcRequest = serde_json::from_slice(json_part).expect("valid JSON");
    }

    #[test]
    fn test_encode_response_contains_newline() {
        let codec = Codec::new();
        let response =
            JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"ok": true}));
        let encoded = codec.encode_response(&response).unwrap();
        assert_eq!(encoded.last(), Some(&b'\n'));
    }

    #[test]
    fn test_decode_notification_without_id() {
        let mut codec = Codec::new();
        let input = b"{\"jsonrpc\":\"2.0\",\"method\":\"notifications/test\"}\n";
        let messages = codec.decode(input).unwrap();
        assert_eq!(messages.len(), 1);
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert!(req.id.is_none());
            assert_eq!(req.method, "notifications/test");
        }
    }
}