Skip to main content

fastmcp_transport/
codec.rs

1//! Message codec for framing JSON-RPC messages.
2//!
3//! MCP uses newline-delimited JSON (NDJSON) for message framing.
4
5use fastmcp_protocol::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
6
/// Stateful NDJSON framer that turns JSON-RPC messages into bytes and back.
///
/// Incoming bytes accumulate in an internal buffer. A cursor records how far decoding
/// has progressed, so consumed bytes do not have to be shifted out on every call.
#[derive(Debug)]
pub struct Codec {
    /// Bytes received from the peer, possibly including a trailing partial line.
    buffer: Vec<u8>,
    /// Offset of the first unconsumed byte in `buffer`; bytes before it are already
    /// decoded (or skipped) and are only retained until the buffer is compacted.
    read_pos: usize,
    /// Size limit, in bytes, enforced while decoding.
    max_message_size: usize,
}
17
18impl Default for Codec {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
/// Number of already-consumed bytes at the front of the decode buffer that triggers
/// compaction.
///
/// Once `read_pos` reaches this many bytes, the next `decode` call drains the consumed
/// prefix so the buffer cannot grow without bound. This is an absolute byte count, not a
/// fraction of the buffer's capacity.
const COMPACT_THRESHOLD: usize = 4096;
26
27impl Codec {
28    /// Creates a new codec with default settings (10MB limit).
29    #[must_use]
30    pub fn new() -> Self {
31        Self {
32            buffer: Vec::new(),
33            read_pos: 0,
34            max_message_size: 10 * 1024 * 1024, // 10MB
35        }
36    }
37
38    /// Returns the maximum allowed message size in bytes.
39    #[must_use]
40    pub fn max_message_size(&self) -> usize {
41        self.max_message_size
42    }
43
44    /// Sets the maximum allowed message size in bytes.
45    pub fn set_max_message_size(&mut self, size: usize) {
46        self.max_message_size = size;
47        let unread = self.buffer.len() - self.read_pos;
48        if unread > size {
49            self.buffer.clear();
50            self.read_pos = 0;
51        }
52    }
53
54    /// Encodes a request to bytes.
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if serialization fails.
59    pub fn encode_request(&self, request: &JsonRpcRequest) -> Result<Vec<u8>, CodecError> {
60        let mut bytes = serde_json::to_vec(request)?;
61        bytes.push(b'\n');
62        Ok(bytes)
63    }
64
65    /// Encodes a response to bytes.
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if serialization fails.
70    pub fn encode_response(&self, response: &JsonRpcResponse) -> Result<Vec<u8>, CodecError> {
71        let mut bytes = serde_json::to_vec(response)?;
72        bytes.push(b'\n');
73        Ok(bytes)
74    }
75
76    /// Decodes bytes into a message, returning any complete messages.
77    ///
78    /// Incomplete data is buffered for the next call.
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if a complete line fails to parse or if the buffer exceeds the limit.
83    pub fn decode(&mut self, data: &[u8]) -> Result<Vec<JsonRpcMessage>, CodecError> {
84        // Calculate unread data size
85        let unread_len = self.buffer.len() - self.read_pos;
86        let projected_size = unread_len.saturating_add(data.len());
87
88        // Check projected size BEFORE extending to prevent temporary memory exhaustion
89        if projected_size > self.max_message_size {
90            self.buffer.clear();
91            self.read_pos = 0;
92            return Err(CodecError::MessageTooLarge(projected_size));
93        }
94
95        // Compact buffer if read_pos is large (to prevent unbounded growth)
96        if self.read_pos >= COMPACT_THRESHOLD {
97            self.buffer.drain(..self.read_pos);
98            self.read_pos = 0;
99        }
100
101        self.buffer.extend_from_slice(data);
102
103        let mut messages = Vec::new();
104        let mut start = self.read_pos;
105
106        #[allow(clippy::mut_range_bound)]
107        for i in start..self.buffer.len() {
108            if self.buffer[i] == b'\n' {
109                let line_len = i - start;
110                if line_len > self.max_message_size {
111                    self.buffer.clear();
112                    self.read_pos = 0;
113                    return Err(CodecError::MessageTooLarge(line_len));
114                }
115                let line = &self.buffer[start..i];
116                if !line.is_empty() {
117                    let msg: JsonRpcMessage = serde_json::from_slice(line)?;
118                    messages.push(msg);
119                }
120                start = i + 1;
121            }
122        }
123
124        // Update read position instead of draining for each decode call
125        self.read_pos = start;
126
127        // Check remaining unread data
128        let remaining = self.buffer.len() - self.read_pos;
129        if remaining > self.max_message_size {
130            self.buffer.clear();
131            self.read_pos = 0;
132            return Err(CodecError::MessageTooLarge(remaining));
133        }
134
135        Ok(messages)
136    }
137
138    /// Clears the internal buffer.
139    pub fn clear(&mut self) {
140        self.buffer.clear();
141        self.read_pos = 0;
142    }
143}
144
145/// Codec error types.
146#[derive(Debug)]
147pub enum CodecError {
148    /// JSON parsing error.
149    Json(serde_json::Error),
150    /// Message too large.
151    MessageTooLarge(usize),
152}
153
154impl std::fmt::Display for CodecError {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        match self {
157            CodecError::Json(e) => write!(f, "JSON error: {e}"),
158            CodecError::MessageTooLarge(size) => write!(f, "Message too large: {size} bytes"),
159        }
160    }
161}
162
163impl std::error::Error for CodecError {
164    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
165        match self {
166            CodecError::Json(e) => Some(e),
167            CodecError::MessageTooLarge(_) => None,
168        }
169    }
170}
171
172impl From<serde_json::Error> for CodecError {
173    fn from(err: serde_json::Error) -> Self {
174        CodecError::Json(err)
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_protocol::RequestId;
    use std::error::Error;

    #[test]
    fn test_encode_decode_roundtrip() {
        let codec = Codec::new();
        let request = JsonRpcRequest::new("test/method", None, 1i64);

        let encoded = codec.encode_request(&request).unwrap();
        assert!(encoded.ends_with(b"\n"));

        let mut codec2 = Codec::new();
        let messages = codec2.decode(&encoded).unwrap();
        assert_eq!(messages.len(), 1);

        assert!(
            matches!(&messages[0], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "test/method");
        }
    }

    #[test]
    fn test_encode_response() {
        let codec = Codec::new();
        let response =
            JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"result": "ok"}));

        let encoded = codec.encode_response(&response).unwrap();
        assert!(encoded.ends_with(b"\n"));

        let mut codec2 = Codec::new();
        let messages = codec2.decode(&encoded).unwrap();
        assert_eq!(messages.len(), 1);

        assert!(
            matches!(&messages[0], JsonRpcMessage::Response(_)),
            "Expected response"
        );
        if let JsonRpcMessage::Response(resp) = &messages[0] {
            assert_eq!(resp.id, Some(RequestId::Number(1)));
        }
    }

    #[test]
    fn test_decode_multiple_messages() {
        let input = b"{\"jsonrpc\":\"2.0\",\"method\":\"test1\",\"id\":1}\n{\"jsonrpc\":\"2.0\",\"method\":\"test2\",\"id\":2}\n";

        let mut codec = Codec::new();
        let messages = codec.decode(input).unwrap();

        assert_eq!(messages.len(), 2);

        assert!(
            matches!(&messages[0], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "test1");
        }

        assert!(
            matches!(&messages[1], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[1] {
            assert_eq!(req.method, "test2");
        }
    }

    #[test]
    fn test_decode_allows_multiple_messages_in_separate_chunks() {
        // The codec pre-checks the unread window (buffer + chunk) against the limit, so
        // send each message in its own chunk.
        let req1 = JsonRpcRequest::new("test1", None, 1i64);
        let req2 = JsonRpcRequest::new("test2", None, 2i64);
        let mut line1 = serde_json::to_vec(&req1).unwrap();
        let mut line2 = serde_json::to_vec(&req2).unwrap();
        line1.push(b'\n');
        line2.push(b'\n');

        let mut codec = Codec::new();
        // Set limit to accommodate one message at a time
        codec.set_max_message_size(line1.len());

        // Decode first message
        let messages1 = codec.decode(&line1).unwrap();
        assert_eq!(messages1.len(), 1);

        // Decode second message
        let messages2 = codec.decode(&line2).unwrap();
        assert_eq!(messages2.len(), 1);
    }

    #[test]
    fn test_decode_rejects_oversized_incomplete_line() {
        let req = JsonRpcRequest::new("oversized", None, 1i64);
        let line = serde_json::to_vec(&req).unwrap();

        let mut codec = Codec::new();
        codec.max_message_size = line.len().saturating_sub(1);

        let result = codec.decode(&line);
        assert!(matches!(result, Err(CodecError::MessageTooLarge(_))));
    }

    #[test]
    fn test_decode_partial_message() {
        let mut codec = Codec::new();

        // Feed partial data without newline
        let partial = b"{\"jsonrpc\":\"2.0\",\"method\":\"test\"";
        let messages = codec.decode(partial).unwrap();
        assert_eq!(messages.len(), 0); // No complete messages yet

        // Feed the rest including newline
        let rest = b",\"id\":1}\n";
        let messages = codec.decode(rest).unwrap();
        assert_eq!(messages.len(), 1);

        assert!(
            matches!(&messages[0], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "test");
        }
    }

    #[test]
    fn test_decode_invalid_json() {
        let mut codec = Codec::new();
        let invalid = b"not valid json\n";

        let result = codec.decode(invalid);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(matches!(err, CodecError::Json(_)));
    }

    #[test]
    fn test_decode_empty_line() {
        let mut codec = Codec::new();
        let input = b"\n{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"id\":1}\n";

        let messages = codec.decode(input).unwrap();
        assert_eq!(messages.len(), 1); // Empty line skipped
    }

    #[test]
    fn test_clear_buffer() {
        let mut codec = Codec::new();

        // Feed partial data
        let partial = b"{\"jsonrpc\":\"2.0\"";
        codec.decode(partial).unwrap();

        // Clear and verify buffer is empty
        codec.clear();

        // Feed a complete message - should parse without old partial data
        let complete = b"{\"jsonrpc\":\"2.0\",\"method\":\"fresh\",\"id\":1}\n";
        let messages = codec.decode(complete).unwrap();

        assert_eq!(messages.len(), 1);
        assert!(
            matches!(&messages[0], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "fresh");
        }
    }

    #[test]
    fn test_codec_error_display() {
        let json_err = CodecError::Json(serde_json::from_str::<()>("invalid").unwrap_err());
        let size_err = CodecError::MessageTooLarge(1000);

        assert!(json_err.to_string().contains("JSON error"));
        assert!(size_err.to_string().contains("1000"));
    }

    #[test]
    fn test_codec_error_source() {
        let json_err = CodecError::Json(serde_json::from_str::<()>("invalid").unwrap_err());
        let size_err = CodecError::MessageTooLarge(1000);

        assert!(json_err.source().is_some());
        assert!(size_err.source().is_none());
    }

    #[test]
    fn test_default_max_message_size() {
        let codec = Codec::new();
        assert_eq!(codec.max_message_size(), 10 * 1024 * 1024);
    }

    #[test]
    fn test_set_max_message_size() {
        let mut codec = Codec::new();
        codec.set_max_message_size(1024);
        assert_eq!(codec.max_message_size(), 1024);
    }

    #[test]
    fn test_set_max_message_size_clears_oversized_buffer() {
        let mut codec = Codec::new();
        // Feed some data into buffer without a newline
        let partial = b"{\"jsonrpc\":\"2.0\",\"method\":\"test\"";
        codec.decode(partial).unwrap();

        // Shrink max size to less than buffered data
        codec.set_max_message_size(5);

        // `{}\n` fits the new limit on its own. Had the old partial data been kept, the
        // unread window would exceed 5 bytes and the decode would be rejected as too
        // large. `{}` itself may or may not be a valid JSON-RPC message; only the size
        // outcome matters here.
        let small = b"{}\n";
        let result = codec.decode(small);
        assert!(
            !matches!(result, Err(CodecError::MessageTooLarge(_))),
            "stale buffered data should have been discarded"
        );
    }

    #[test]
    fn test_codec_default_trait() {
        let codec = Codec::default();
        assert_eq!(codec.max_message_size(), 10 * 1024 * 1024);
    }

    #[test]
    fn test_decode_oversized_projected_data() {
        let mut codec = Codec::new();
        codec.set_max_message_size(50);

        // Feed data that exceeds max when projected
        let big = vec![b'x'; 100];
        let result = codec.decode(&big);
        assert!(matches!(result, Err(CodecError::MessageTooLarge(_))));
    }

    #[test]
    fn test_buffer_compaction_after_threshold() {
        let mut codec = Codec::new();

        // Feed many small complete messages to advance read_pos past COMPACT_THRESHOLD
        let msg = b"{\"jsonrpc\":\"2.0\",\"method\":\"m\",\"id\":1}\n";
        let many_messages: Vec<u8> = msg.repeat(200); // 38 * 200 = 7600 bytes > 4096 threshold

        let messages = codec.decode(&many_messages).unwrap();
        assert_eq!(messages.len(), 200);
        // Every buffered byte has been consumed.
        assert_eq!(codec.read_pos, codec.buffer.len());

        // The consumed prefix is reclaimed at the start of the next decode call.
        let next_msg = b"{\"jsonrpc\":\"2.0\",\"method\":\"after_compact\",\"id\":2}\n";
        let messages = codec.decode(next_msg).unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(codec.buffer.len(), next_msg.len());
        assert!(
            matches!(&messages[0], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "after_compact");
        }
    }

    #[test]
    fn test_decode_utf8_message() {
        let mut codec = Codec::new();
        let json = "{\"jsonrpc\":\"2.0\",\"method\":\"test/日本語\",\"id\":1}\n";
        let messages = codec.decode(json.as_bytes()).unwrap();
        assert_eq!(messages.len(), 1);
        assert!(
            matches!(&messages[0], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "test/日本語");
        }
    }

    #[test]
    fn test_decode_consecutive_newlines() {
        let mut codec = Codec::new();
        let input = b"\n\n{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"id\":1}\n\n\n";
        let messages = codec.decode(input).unwrap();
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn test_clear_resets_state() {
        let mut codec = Codec::new();

        // Feed partial data
        codec.decode(b"{\"jsonrpc\":\"2.0\"").unwrap();
        codec.clear();

        // Verify internal state is reset by sending a fresh complete message
        let complete = b"{\"jsonrpc\":\"2.0\",\"method\":\"post_clear\",\"id\":1}\n";
        let messages = codec.decode(complete).unwrap();
        assert_eq!(messages.len(), 1);
        assert!(
            matches!(&messages[0], JsonRpcMessage::Request(_)),
            "Expected request"
        );
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert_eq!(req.method, "post_clear");
        }
    }

    #[test]
    fn test_codec_error_from_serde() {
        let serde_err = serde_json::from_str::<()>("bad").unwrap_err();
        let codec_err: CodecError = serde_err.into();
        assert!(matches!(codec_err, CodecError::Json(_)));
    }

    #[test]
    fn test_encode_request_contains_newline() {
        let codec = Codec::new();
        let request = JsonRpcRequest::new("m", None, 1i64);
        let encoded = codec.encode_request(&request).unwrap();
        assert_eq!(encoded.last(), Some(&b'\n'));
        // Everything before \n should be valid JSON
        let json_part = &encoded[..encoded.len() - 1];
        let _: JsonRpcRequest = serde_json::from_slice(json_part).expect("valid JSON");
    }

    #[test]
    fn test_encode_response_contains_newline() {
        let codec = Codec::new();
        let response =
            JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"ok": true}));
        let encoded = codec.encode_response(&response).unwrap();
        assert_eq!(encoded.last(), Some(&b'\n'));
    }

    #[test]
    fn test_decode_notification_without_id() {
        let mut codec = Codec::new();
        let input = b"{\"jsonrpc\":\"2.0\",\"method\":\"notifications/test\"}\n";
        let messages = codec.decode(input).unwrap();
        assert_eq!(messages.len(), 1);
        if let JsonRpcMessage::Request(req) = &messages[0] {
            assert!(req.id.is_none());
            assert_eq!(req.method, "notifications/test");
        }
    }
}