Skip to main content

fastmcp_transport/
codec.rs

1//! Message codec for framing JSON-RPC messages.
2//!
3//! MCP uses newline-delimited JSON (NDJSON) for message framing.
4
5use fastmcp_protocol::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
6
/// Newline-delimited JSON-RPC framer: turns outgoing messages into `\n`-terminated
/// byte frames and reassembles incoming byte chunks into complete messages.
#[derive(Debug)]
pub struct Codec {
    /// Received bytes; everything before `read_pos` has already been decoded.
    buffer: Vec<u8>,
    /// Offset into `buffer` of the first byte that has not been consumed yet.
    read_pos: usize,
    /// Limit, in bytes, on undecoded data the codec will accept.
    max_message_size: usize,
}
17
18impl Default for Codec {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
/// Read offset, in bytes, at or beyond which `Codec::decode` compacts its buffer.
///
/// Once at least this many already-decoded bytes sit at the front of the buffer,
/// they are drained so the buffer does not grow without bound. This is an
/// absolute byte count, not a fraction of the buffer's capacity.
const COMPACT_THRESHOLD: usize = 4096;
26
27impl Codec {
28    /// Creates a new codec with default settings (10MB limit).
29    #[must_use]
30    pub fn new() -> Self {
31        Self {
32            buffer: Vec::new(),
33            read_pos: 0,
34            max_message_size: 10 * 1024 * 1024, // 10MB
35        }
36    }
37
38    /// Returns the maximum allowed message size in bytes.
39    #[must_use]
40    pub fn max_message_size(&self) -> usize {
41        self.max_message_size
42    }
43
44    /// Sets the maximum allowed message size in bytes.
45    pub fn set_max_message_size(&mut self, size: usize) {
46        self.max_message_size = size;
47        let unread = self.buffer.len() - self.read_pos;
48        if unread > size {
49            self.buffer.clear();
50            self.read_pos = 0;
51        }
52    }
53
54    /// Encodes a request to bytes.
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if serialization fails.
59    pub fn encode_request(&self, request: &JsonRpcRequest) -> Result<Vec<u8>, CodecError> {
60        let mut bytes = serde_json::to_vec(request)?;
61        bytes.push(b'\n');
62        Ok(bytes)
63    }
64
65    /// Encodes a response to bytes.
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if serialization fails.
70    pub fn encode_response(&self, response: &JsonRpcResponse) -> Result<Vec<u8>, CodecError> {
71        let mut bytes = serde_json::to_vec(response)?;
72        bytes.push(b'\n');
73        Ok(bytes)
74    }
75
76    /// Decodes bytes into a message, returning any complete messages.
77    ///
78    /// Incomplete data is buffered for the next call.
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if a complete line fails to parse or if the buffer exceeds the limit.
83    pub fn decode(&mut self, data: &[u8]) -> Result<Vec<JsonRpcMessage>, CodecError> {
84        // Calculate unread data size
85        let unread_len = self.buffer.len() - self.read_pos;
86        let projected_size = unread_len.saturating_add(data.len());
87
88        // Check projected size BEFORE extending to prevent temporary memory exhaustion
89        if projected_size > self.max_message_size {
90            self.buffer.clear();
91            self.read_pos = 0;
92            return Err(CodecError::MessageTooLarge(projected_size));
93        }
94
95        // Compact buffer if read_pos is large (to prevent unbounded growth)
96        if self.read_pos >= COMPACT_THRESHOLD {
97            self.buffer.drain(..self.read_pos);
98            self.read_pos = 0;
99        }
100
101        self.buffer.extend_from_slice(data);
102
103        let mut messages = Vec::new();
104        let mut start = self.read_pos;
105
106        #[allow(clippy::mut_range_bound)]
107        for i in start..self.buffer.len() {
108            if self.buffer[i] == b'\n' {
109                let line_len = i - start;
110                if line_len > self.max_message_size {
111                    self.buffer.clear();
112                    self.read_pos = 0;
113                    return Err(CodecError::MessageTooLarge(line_len));
114                }
115                let line = &self.buffer[start..i];
116                if !line.is_empty() {
117                    let msg: JsonRpcMessage = serde_json::from_slice(line)?;
118                    messages.push(msg);
119                }
120                start = i + 1;
121            }
122        }
123
124        // Update read position instead of draining for each decode call
125        self.read_pos = start;
126
127        // Check remaining unread data
128        let remaining = self.buffer.len() - self.read_pos;
129        if remaining > self.max_message_size {
130            self.buffer.clear();
131            self.read_pos = 0;
132            return Err(CodecError::MessageTooLarge(remaining));
133        }
134
135        Ok(messages)
136    }
137
138    /// Clears the internal buffer.
139    pub fn clear(&mut self) {
140        self.buffer.clear();
141        self.read_pos = 0;
142    }
143}
144
145/// Codec error types.
146#[derive(Debug)]
147pub enum CodecError {
148    /// JSON parsing error.
149    Json(serde_json::Error),
150    /// Message too large.
151    MessageTooLarge(usize),
152}
153
154impl std::fmt::Display for CodecError {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        match self {
157            CodecError::Json(e) => write!(f, "JSON error: {e}"),
158            CodecError::MessageTooLarge(size) => write!(f, "Message too large: {size} bytes"),
159        }
160    }
161}
162
163impl std::error::Error for CodecError {
164    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
165        match self {
166            CodecError::Json(e) => Some(e),
167            CodecError::MessageTooLarge(_) => None,
168        }
169    }
170}
171
172impl From<serde_json::Error> for CodecError {
173    fn from(err: serde_json::Error) -> Self {
174        CodecError::Json(err)
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_protocol::RequestId;
    use std::error::Error;

    /// Serializes `request` and appends the NDJSON delimiter.
    fn ndjson_line(request: &JsonRpcRequest) -> Vec<u8> {
        let mut line = serde_json::to_vec(request).unwrap();
        line.push(b'\n');
        line
    }

    /// Builds a `CodecError::Json` from a deliberately invalid document.
    fn json_error() -> CodecError {
        CodecError::Json(serde_json::from_str::<()>("invalid").unwrap_err())
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let request = JsonRpcRequest::new("test/method", None, 1i64);
        let encoded = Codec::new().encode_request(&request).unwrap();
        assert!(encoded.ends_with(b"\n"));

        let decoded = Codec::new().decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 1);
    }

    #[test]
    fn test_encode_response() {
        let response =
            JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"result": "ok"}));
        let encoded = Codec::new().encode_response(&response).unwrap();
        assert!(encoded.ends_with(b"\n"));

        let decoded = Codec::new().decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 1);

        let JsonRpcMessage::Response(resp) = &decoded[0] else {
            panic!("Expected response");
        };
        assert_eq!(resp.id, Some(RequestId::Number(1)));
    }

    #[test]
    fn test_decode_multiple_messages() {
        let input = b"{\"jsonrpc\":\"2.0\",\"method\":\"test1\",\"id\":1}\n{\"jsonrpc\":\"2.0\",\"method\":\"test2\",\"id\":2}\n";

        let messages = Codec::new().decode(input).unwrap();
        assert_eq!(messages.len(), 2);

        for (message, expected) in messages.iter().zip(["test1", "test2"]) {
            let JsonRpcMessage::Request(req) = message else {
                panic!("Expected request");
            };
            assert_eq!(req.method, expected);
        }
    }

    #[test]
    fn test_decode_allows_multiple_messages_in_separate_chunks() {
        // The limit covers each call's input as a whole, so deliver one message per chunk.
        let first = ndjson_line(&JsonRpcRequest::new("test1", None, 1i64));
        let second = ndjson_line(&JsonRpcRequest::new("test2", None, 2i64));

        let mut codec = Codec::new();
        // Exactly one line's worth: each chunk fits on its own.
        codec.set_max_message_size(first.len());

        assert_eq!(codec.decode(&first).unwrap().len(), 1);
        assert_eq!(codec.decode(&second).unwrap().len(), 1);
    }

    #[test]
    fn test_decode_rejects_oversized_incomplete_line() {
        // No trailing newline: an unterminated line one byte over the limit.
        let line = serde_json::to_vec(&JsonRpcRequest::new("oversized", None, 1i64)).unwrap();

        let mut codec = Codec::new();
        codec.max_message_size = line.len().saturating_sub(1);

        assert!(matches!(codec.decode(&line), Err(CodecError::MessageTooLarge(_))));
    }

    #[test]
    fn test_decode_partial_message() {
        let mut codec = Codec::new();

        // The first half has no newline, so nothing can be decoded yet.
        let pending = codec.decode(b"{\"jsonrpc\":\"2.0\",\"method\":\"test\"").unwrap();
        assert!(pending.is_empty());

        // The remainder terminates the line.
        let messages = codec.decode(b",\"id\":1}\n").unwrap();
        assert_eq!(messages.len(), 1);

        let JsonRpcMessage::Request(req) = &messages[0] else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "test");
    }

    #[test]
    fn test_decode_invalid_json() {
        let outcome = Codec::new().decode(b"not valid json\n");
        assert!(matches!(outcome, Err(CodecError::Json(_))));
    }

    #[test]
    fn test_decode_empty_line() {
        let input = b"\n{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"id\":1}\n";
        // The leading blank line must be skipped rather than parsed.
        assert_eq!(Codec::new().decode(input).unwrap().len(), 1);
    }

    #[test]
    fn test_clear_buffer() {
        let mut codec = Codec::new();
        codec.decode(b"{\"jsonrpc\":\"2.0\"").unwrap();

        // Dropping the stale fragment means the next line parses on its own.
        codec.clear();

        let messages = codec
            .decode(b"{\"jsonrpc\":\"2.0\",\"method\":\"fresh\",\"id\":1}\n")
            .unwrap();
        assert_eq!(messages.len(), 1);

        let JsonRpcMessage::Request(req) = &messages[0] else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "fresh");
    }

    #[test]
    fn test_codec_error_display() {
        assert!(json_error().to_string().contains("JSON error"));
        assert!(CodecError::MessageTooLarge(1000).to_string().contains("1000"));
    }

    #[test]
    fn test_codec_error_source() {
        assert!(json_error().source().is_some());
        assert!(CodecError::MessageTooLarge(1000).source().is_none());
    }
}