Skip to main content

matrixcode_core/matrixrpc/transport/
codec.rs

//! Frame Codec for JSON-RPC messages
//!
//! Implements Content-Length based framing (LSP/MCP style):
//! ```text
//! Content-Length: <length>\r\n
//! \r\n
//! <json-payload>
//! ```
//!
//! This is a widely adopted format used by:
//! - Language Server Protocol (LSP)
//! - Model Context Protocol (MCP)
//! - Debug Adapter Protocol (DAP)
14
15use std::io;
16
17use crate::matrixrpc::protocol::JsonRpcMessage;
18
/// Frame codec for encoding/decoding JSON-RPC messages
///
/// Uses Content-Length header framing similar to LSP/MCP.
#[derive(Debug)]
pub struct FrameCodec {
    /// Maximum allowed message size, in bytes
    max_message_size: usize,
}

impl Default for FrameCodec {
    /// Returns a codec with the same 16 MiB limit that `FrameCodec::new` uses.
    ///
    /// A derived `Default` would set the limit to zero, which makes the codec
    /// refuse every non-empty message, so the limit is spelled out here.
    fn default() -> Self {
        Self {
            max_message_size: 16 * 1024 * 1024,
        }
    }
}
27
28impl FrameCodec {
29    /// Create a new frame codec with default settings
30    pub fn new() -> Self {
31        Self {
32            max_message_size: 16 * 1024 * 1024, // 16MB default
33        }
34    }
35
36    /// Create a frame codec with a custom max message size
37    pub fn with_max_size(max_message_size: usize) -> Self {
38        Self { max_message_size }
39    }
40
41    /// Encode a message into a framed byte vector
42    ///
43    /// Format: `Content-Length: <length>\r\n\r\n<json>`
44    pub fn encode(&self, message: &JsonRpcMessage) -> io::Result<Vec<u8>> {
45        let json = message.to_json().map_err(|e| {
46            io::Error::new(
47                io::ErrorKind::InvalidData,
48                format!("JSON encode error: {}", e),
49            )
50        })?;
51
52        let json_bytes = json.into_bytes();
53        if json_bytes.len() > self.max_message_size {
54            return Err(io::Error::new(
55                io::ErrorKind::InvalidData,
56                format!(
57                    "Message size {} exceeds maximum {}",
58                    json_bytes.len(),
59                    self.max_message_size
60                ),
61            ));
62        }
63
64        let header = format!("Content-Length: {}\r\n\r\n", json_bytes.len());
65        let mut frame = header.into_bytes();
66        frame.extend(json_bytes);
67
68        Ok(frame)
69    }
70
71    /// Encode a message and write to a writer
72    pub fn encode_to_writer<W: std::io::Write>(
73        &self,
74        writer: &mut W,
75        message: &JsonRpcMessage,
76    ) -> io::Result<()> {
77        let frame = self.encode(message)?;
78        writer.write_all(&frame)?;
79        writer.flush()?;
80        Ok(())
81    }
82
83    /// Decode a message from a complete buffer
84    ///
85    /// Returns the remaining unconsumed bytes and the parsed message.
86    /// Returns Ok((remaining, None)) if more data is needed.
87    pub fn decode_from_buffer<'a>(
88        &self,
89        buffer: &'a [u8],
90    ) -> io::Result<(&'a [u8], Option<JsonRpcMessage>)> {
91        // Find end of headers
92        let header_end = match find_header_end(buffer) {
93            Some(pos) => pos,
94            None => return Ok((buffer, None)), // Need more data
95        };
96
97        // Parse headers
98        let header_str = std::str::from_utf8(&buffer[..header_end]).map_err(|e| {
99            io::Error::new(
100                io::ErrorKind::InvalidData,
101                format!("Invalid UTF-8 in headers: {}", e),
102            )
103        })?;
104
105        let content_length = parse_content_length(header_str)?;
106
107        // Check if we have enough data
108        let body_start = header_end + 4; // Skip \r\n\r\n
109        if buffer.len() < body_start + content_length {
110            return Ok((buffer, None)); // Need more data
111        }
112
113        // Extract and parse the body
114        let body = &buffer[body_start..body_start + content_length];
115        let json_str = std::str::from_utf8(body).map_err(|e| {
116            io::Error::new(
117                io::ErrorKind::InvalidData,
118                format!("Invalid UTF-8 in body: {}", e),
119            )
120        })?;
121
122        let message = JsonRpcMessage::from_json(json_str).map_err(|e| {
123            io::Error::new(
124                io::ErrorKind::InvalidData,
125                format!("JSON decode error: {}", e),
126            )
127        })?;
128
129        // Return remaining buffer
130        let remaining = &buffer[body_start + content_length..];
131        Ok((remaining, Some(message)))
132    }
133
134    /// Get the maximum message size
135    pub fn max_message_size(&self) -> usize {
136        self.max_message_size
137    }
138}
139
/// Find the end of HTTP-like headers (double CRLF)
///
/// Returns the index of the first byte of the first `\r\n\r\n` sequence in
/// `buffer`, or `None` when the sequence is absent (including when the buffer
/// is shorter than four bytes).
fn find_header_end(buffer: &[u8]) -> Option<usize> {
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";

    // `windows` yields nothing for inputs shorter than the terminator, so no
    // explicit length guard is needed.
    buffer
        .windows(HEADER_TERMINATOR.len())
        .position(|window| window == HEADER_TERMINATOR)
}
154
/// Extract the `Content-Length` value from a block of header lines.
///
/// Each line is trimmed and split at its first `:`; the header name is
/// compared case-insensitively. The first matching header decides the
/// outcome: its trimmed value must parse as a `usize`.
///
/// # Errors
///
/// Returns `InvalidData` when the first `Content-Length` value is not a valid
/// `usize`, or when no such header is present.
fn parse_content_length(headers: &str) -> io::Result<usize> {
    let raw_value = headers
        .lines()
        .filter_map(|entry| entry.trim().split_once(':'))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("Content-Length"))
        .map(|(_, text)| text.trim());

    match raw_value {
        Some(digits) => digits.parse::<usize>().map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Invalid Content-Length: {}", e),
            )
        }),
        None => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "Missing Content-Length header",
        )),
    }
}
176
177#[allow(dead_code)]
178/// Encode a single message to bytes (convenience function)
179pub fn encode_message(message: &JsonRpcMessage) -> io::Result<Vec<u8>> {
180#[allow(dead_code)]
181    FrameCodec::new().encode(message)
182}
183
184#[allow(dead_code)]
185/// Decode a single message from bytes (convenience function)
186///
187/// Expects the buffer to contain exactly one framed message (with headers)
188#[allow(dead_code)]
189pub fn decode_message_from_buffer(buffer: &[u8]) -> io::Result<(Vec<u8>, JsonRpcMessage)> {
190    let codec = FrameCodec::new();
191    let (remaining, message) = codec.decode_from_buffer(buffer)?;
192    match message {
193        Some(msg) => Ok((remaining.to_vec(), msg)),
194        None => Err(io::Error::new(
195            io::ErrorKind::UnexpectedEof,
196            "Incomplete message",
197        )),
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrixrpc::protocol::{JsonRpcError, JsonRpcRequest, JsonRpcResponse};
    use serde_json::json;

    /// A minimal well-formed JSON-RPC request body shared by the decode tests.
    const SAMPLE_REQUEST: &str = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;

    /// Wraps `body` in a frame whose Content-Length matches the body exactly.
    fn framed(body: &str) -> String {
        format!("Content-Length: {}\r\n\r\n{}", body.len(), body)
    }

    /// Returns the part of an encoded frame that follows the header terminator.
    fn body_of(frame: &[u8]) -> String {
        let text = String::from_utf8_lossy(frame);
        let start = text
            .find("\r\n\r\n")
            .expect("encoded frame contains a header terminator")
            + 4;
        text[start..].to_string()
    }

    #[test]
    fn test_encode_message() {
        let request = JsonRpcMessage::Request(
            JsonRpcRequest::new("test_method").params(json!({"key": "value"})),
        );

        let frame = FrameCodec::new().encode(&request).unwrap();

        let frame_str = String::from_utf8_lossy(&frame);
        assert!(frame_str.starts_with("Content-Length:"));
        assert!(frame_str.contains("\r\n\r\n"));
        assert!(frame_str.contains("\"method\":\"test_method\""));
    }

    #[test]
    fn test_decode_message() {
        let frame = framed(SAMPLE_REQUEST);

        let codec = FrameCodec::new();
        let (remaining, message) = codec.decode_from_buffer(frame.as_bytes()).unwrap();

        let msg = message.expect("a complete frame decodes to a message");
        assert!(msg.is_request());
        assert!(remaining.is_empty());
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let request = JsonRpcMessage::Request(
            JsonRpcRequest::with_id("test_method", 42).params(json!({"arg": "value"})),
        );

        let codec = FrameCodec::new();
        let frame = codec.encode(&request).unwrap();

        let (_, decoded) = codec.decode_from_buffer(&frame).unwrap();
        let decoded = decoded.unwrap();

        assert_eq!(
            decoded.as_request().unwrap().method,
            request.as_request().unwrap().method
        );
    }

    #[test]
    fn test_max_message_size() {
        let codec = FrameCodec::with_max_size(10);
        let request = JsonRpcMessage::Request(JsonRpcRequest::new("test"));

        let err = codec
            .encode(&request)
            .expect_err("body is larger than the 10-byte limit");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_incomplete_message() {
        // Header announces the body, but the body itself is missing.
        let partial_frame = format!("Content-Length: {}\r\n\r\n", SAMPLE_REQUEST.len());

        let codec = FrameCodec::new();
        let (remaining, message) = codec.decode_from_buffer(partial_frame.as_bytes()).unwrap();

        assert!(message.is_none());
        // Nothing is consumed while the frame is incomplete.
        assert_eq!(remaining, partial_frame.as_bytes());
    }

    #[test]
    fn test_multiple_messages_in_buffer() {
        let json1 = r#"{"jsonrpc":"2.0","method":"test1","id":1}"#;
        let json2 = r#"{"jsonrpc":"2.0","method":"test2","id":2}"#;

        let codec = FrameCodec::new();
        let mut buffer = Vec::new();
        for body in [json1, json2] {
            let request = JsonRpcMessage::Request(JsonRpcRequest::from_json(body).unwrap());
            buffer.extend(codec.encode(&request).unwrap());
        }

        // Decode first message
        let (remaining1, msg1) = codec.decode_from_buffer(&buffer).unwrap();
        let msg1 = msg1.unwrap();
        assert_eq!(msg1.as_request().unwrap().method, "test1");

        // Decode second message
        let (remaining2, msg2) = codec.decode_from_buffer(remaining1).unwrap();
        let msg2 = msg2.unwrap();
        assert_eq!(msg2.as_request().unwrap().method, "test2");
        assert!(remaining2.is_empty());
    }

    #[test]
    fn test_convenience_functions() {
        let request = JsonRpcMessage::Request(JsonRpcRequest::new("test_method"));

        let encoded = encode_message(&request).unwrap();
        let (_, decoded) = decode_message_from_buffer(&encoded).unwrap();

        assert!(decoded.is_request());
    }

    // ==================== Additional Edge Case Tests ====================

    #[test]
    fn test_decode_missing_content_length() {
        let frame = format!("Content-Type: application/json\r\n\r\n{}", SAMPLE_REQUEST);

        let err = FrameCodec::new()
            .decode_from_buffer(frame.as_bytes())
            .expect_err("a frame without Content-Length is rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Missing Content-Length"));
    }

    #[test]
    fn test_decode_malformed_content_length() {
        let frame = format!("Content-Length: abc\r\n\r\n{}", SAMPLE_REQUEST);

        let err = FrameCodec::new()
            .decode_from_buffer(frame.as_bytes())
            .expect_err("a non-numeric Content-Length is rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Invalid Content-Length"));
    }

    #[test]
    fn test_decode_negative_content_length() {
        let frame = format!("Content-Length: -1\r\n\r\n{}", SAMPLE_REQUEST);

        // The length is parsed as `usize`, so a sign makes it malformed.
        let err = FrameCodec::new()
            .decode_from_buffer(frame.as_bytes())
            .expect_err("a negative Content-Length is rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Invalid Content-Length"));
    }

    #[test]
    fn test_decode_case_insensitive_header() {
        // Test with different case variations
        for header in [
            "content-length",
            "CONTENT-LENGTH",
            "Content-length",
            "CONTENT-length",
        ] {
            let frame = format!(
                "{}: {}\r\n\r\n{}",
                header,
                SAMPLE_REQUEST.len(),
                SAMPLE_REQUEST
            );
            let codec = FrameCodec::new();
            let (_, message) = codec.decode_from_buffer(frame.as_bytes()).unwrap();
            assert!(
                message.is_some(),
                "Failed to parse with header: {}",
                header
            );
        }
    }

    #[test]
    fn test_decode_with_extra_headers() {
        let frame = format!(
            "Content-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
            SAMPLE_REQUEST.len(),
            SAMPLE_REQUEST
        );

        let codec = FrameCodec::new();
        let (_, message) = codec.decode_from_buffer(frame.as_bytes()).unwrap();
        assert!(message.is_some());
    }

    #[test]
    fn test_decode_zero_content_length() {
        let frame = "Content-Length: 0\r\n\r\n";

        // Zero-length body should fail JSON parse
        let err = FrameCodec::new()
            .decode_from_buffer(frame.as_bytes())
            .expect_err("an empty body is not a JSON-RPC message");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_decode_invalid_utf8_in_header() {
        // Invalid UTF-8 sits before the `\r\n\r\n`, so the header conversion
        // is the step that fails.
        let invalid_bytes = b"Content-Length: \xFF\xFE\r\n\r\n{}";

        let err = FrameCodec::new()
            .decode_from_buffer(invalid_bytes)
            .expect_err("non-UTF-8 headers are rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("Invalid UTF-8"));
    }

    #[test]
    fn test_decode_invalid_json_body() {
        let invalid_json = r#"{"jsonrpc":"2.0","method":}"#; // Malformed JSON
        let frame = framed(invalid_json);

        let err = FrameCodec::new()
            .decode_from_buffer(frame.as_bytes())
            .expect_err("malformed JSON is rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("JSON decode error"));
    }

    #[test]
    fn test_decode_empty_buffer() {
        let codec = FrameCodec::new();
        let (remaining, message) = codec.decode_from_buffer(b"").unwrap();
        assert!(message.is_none());
        assert!(remaining.is_empty());
    }

    #[test]
    fn test_decode_partial_header() {
        let partial = b"Content-Length: 10";

        let codec = FrameCodec::new();
        let (remaining, message) = codec.decode_from_buffer(partial).unwrap();
        assert!(message.is_none());
        assert_eq!(remaining, partial);
    }

    #[test]
    fn test_decode_partial_body() {
        // Header says 100 bytes but only provides partial body
        let partial = format!("Content-Length: 100\r\n\r\n{}", SAMPLE_REQUEST);

        let codec = FrameCodec::new();
        let (remaining, message) = codec.decode_from_buffer(partial.as_bytes()).unwrap();
        assert!(message.is_none());
        // The whole input is handed back so the caller can append to it.
        assert_eq!(remaining, partial.as_bytes());
    }

    #[test]
    fn test_encode_response_message() {
        let response =
            JsonRpcMessage::Response(JsonRpcResponse::success(1, json!({"result": "ok"})));

        let frame = FrameCodec::new().encode(&response).unwrap();
        let frame_str = String::from_utf8_lossy(&frame);

        assert!(frame_str.contains("\"result\":"));
        assert!(frame_str.contains("\"ok\""));
    }

    #[test]
    fn test_encode_error_response() {
        let error = JsonRpcMessage::Response(JsonRpcResponse::error(
            1,
            JsonRpcError::method_not_found("unknown"),
        ));

        let frame = FrameCodec::new().encode(&error).unwrap();
        let frame_str = String::from_utf8_lossy(&frame);

        assert!(frame_str.contains("\"error\""));
        assert!(frame_str.contains("Method 'unknown' not found"));
    }

    #[test]
    fn test_encode_batch_message() {
        let batch = JsonRpcMessage::Batch(vec![
            JsonRpcMessage::Request(JsonRpcRequest::new("method1")),
            JsonRpcMessage::Request(JsonRpcRequest::new("method2")),
        ]);

        let frame = FrameCodec::new().encode(&batch).unwrap();
        let body = body_of(&frame);

        // The frame itself begins with the header, so the array bracket is
        // looked for in the body only.
        assert!(body.contains('['));
        assert!(body.contains("method1"));
        assert!(body.contains("method2"));
    }

    #[test]
    fn test_encode_notification() {
        let notification = JsonRpcMessage::Request(
            JsonRpcRequest::notification("notify_event").params(json!({"event": "test"})),
        );

        let frame = FrameCodec::new().encode(&notification).unwrap();

        // Notification should have no id
        let parsed: serde_json::Value = serde_json::from_str(&body_of(&frame)).unwrap();
        assert!(parsed.get("id").is_none());
        assert_eq!(parsed["method"], "notify_event");
    }

    #[test]
    fn test_decode_with_trailing_data() {
        let frame = format!("{}extra_data", framed(SAMPLE_REQUEST));

        let codec = FrameCodec::new();
        let (remaining, message) = codec.decode_from_buffer(frame.as_bytes()).unwrap();

        assert!(message.is_some());
        assert_eq!(remaining, b"extra_data");
    }

    #[test]
    fn test_decode_message_from_buffer_incomplete() {
        let partial = b"Content-Length: 100\r\n\r\n{}";

        let err = decode_message_from_buffer(partial)
            .expect_err("an incomplete frame is reported as an error");
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_content_length_whitespace() {
        // Content-Length with extra whitespace
        let frame = format!(
            "Content-Length:  {}  \r\n\r\n{}",
            SAMPLE_REQUEST.len(),
            SAMPLE_REQUEST
        );

        let codec = FrameCodec::new();
        let (_, message) = codec.decode_from_buffer(frame.as_bytes()).unwrap();
        assert!(message.is_some());
    }

    #[test]
    fn test_large_message_within_limit() {
        // Create a large message within the default 16MB limit
        let large_params = "x".repeat(1024 * 1024); // 1MB of data
        let request = JsonRpcMessage::Request(
            JsonRpcRequest::new("test").params(json!({"data": large_params})),
        );

        let result = FrameCodec::new().encode(&request);
        assert!(result.is_ok());
    }
}