Skip to main content

reddb_server/mcp/
protocol.rs

1//! JSON-RPC protocol handling for MCP server.
2//!
3//! Implements the Content-Length framed JSON-RPC transport used by the
4//! Model Context Protocol over stdio.
5
6use crate::json::{Map, Value as JsonValue};
7use std::io::{BufRead, Write};
8
/// Read one Content-Length framed JSON-RPC payload from a buffered reader.
///
/// Header lines are read until the blank line that ends the header section.
/// A `Content-Length: N` header (name matched ASCII case-insensitively) gives
/// the body size in bytes; all other header lines are ignored, and if the
/// header is repeated the last value wins. Exactly `N` bytes of body are then
/// read and returned as a UTF-8 string.
///
/// Blank lines that appear *before* the first header line are skipped. This is
/// how an optional newline emitted by a peer after a body is tolerated: it is
/// discarded at the start of the next call rather than probed for right after
/// the body. Probing would mean asking the underlying reader for more input,
/// which on an interactive stream (stdin, a pipe) blocks until the peer sends
/// its next message, so a peer waiting for our reply would wait forever.
///
/// # Returns
///
/// * `Ok(Some(body))` when a complete payload was read.
/// * `Ok(None)` when end-of-input is reached while reading header lines.
///
/// # Errors
///
/// Returns a human-readable message when a header line cannot be read, the
/// `Content-Length` value is not a valid `usize`, the header section ends
/// without a `Content-Length` header, the body is shorter than announced (or
/// reading it fails), or the body is not valid UTF-8.
pub fn read_payload<R: BufRead>(reader: &mut R) -> Result<Option<String>, String> {
    let mut content_length: Option<usize> = None;
    // Becomes true once any non-blank header line of the current message has
    // been seen; until then blank lines are separators, not terminators.
    let mut saw_header = false;
    let mut header = String::new();

    loop {
        header.clear();
        let bytes = reader
            .read_line(&mut header)
            .map_err(|e| format!("failed to read header: {}", e))?;
        if bytes == 0 {
            return Ok(None);
        }

        let line = header.trim_end_matches(['\n', '\r']);
        if line.is_empty() {
            if saw_header {
                // Blank line after at least one header: end of header section.
                break;
            }
            // Leftover newline from the previous message; skip it.
            continue;
        }
        saw_header = true;

        // The header name is everything before the first ':'; it must match
        // exactly (no surrounding whitespace), ignoring ASCII case.
        if let Some((name, value)) = line.split_once(':') {
            if name.eq_ignore_ascii_case("content-length") {
                let length = value
                    .trim()
                    .parse::<usize>()
                    .map_err(|_| "invalid Content-Length header".to_string())?;
                content_length = Some(length);
            }
        }
    }

    let length = content_length.ok_or_else(|| "missing Content-Length header".to_string())?;
    let mut buffer = vec![0u8; length];
    reader
        .read_exact(&mut buffer)
        .map_err(|e| format!("failed to read payload: {}", e))?;

    String::from_utf8(buffer)
        .map(Some)
        .map_err(|_| "payload is not UTF-8".to_string())
}
65
/// Emit `body` as a single Content-Length framed message and flush the writer.
///
/// The frame is `Content-Length: <n>\r\n\r\n` followed by `body`, where `<n>`
/// is the length of `body` in bytes (not characters).
///
/// # Errors
///
/// Returns a descriptive message if writing the frame or flushing fails.
pub fn write_message<W: Write>(writer: &mut W, body: &str) -> Result<(), String> {
    let byte_count = body.len();
    if let Err(err) =
        writer.write_fmt(format_args!("Content-Length: {}\r\n\r\n{}", byte_count, body))
    {
        return Err(format!("failed to write response: {}", err));
    }

    match writer.flush() {
        Ok(()) => Ok(()),
        Err(err) => Err(format!("failed to flush: {}", err)),
    }
}
74
75/// Build a JSON-RPC 2.0 result message.
76pub fn build_result_message(id: Option<&JsonValue>, result: JsonValue) -> String {
77    let mut object = Map::new();
78    object.insert("jsonrpc".to_string(), JsonValue::String("2.0".to_string()));
79    match id {
80        Some(identifier) => {
81            object.insert("id".to_string(), identifier.clone());
82        }
83        None => {
84            object.insert("id".to_string(), JsonValue::Null);
85        }
86    }
87    object.insert("result".to_string(), result);
88    JsonValue::Object(object).to_string_compact()
89}
90
91/// Build a JSON-RPC 2.0 error message.
92pub fn build_error_message(id: Option<&JsonValue>, code: i64, message: &str) -> String {
93    let mut error = Map::new();
94    error.insert("code".to_string(), JsonValue::Number(code as f64));
95    error.insert(
96        "message".to_string(),
97        JsonValue::String(message.to_string()),
98    );
99
100    let mut object = Map::new();
101    object.insert("jsonrpc".to_string(), JsonValue::String("2.0".to_string()));
102    match id {
103        Some(identifier) => {
104            object.insert("id".to_string(), identifier.clone());
105        }
106        None => {
107            object.insert("id".to_string(), JsonValue::Null);
108        }
109    }
110    object.insert("error".to_string(), JsonValue::Object(error));
111    JsonValue::Object(object).to_string_compact()
112}
113
114/// Build a JSON-RPC 2.0 notification (no id, no response expected).
115pub fn build_notification(method: &str, params: JsonValue) -> String {
116    let mut object = Map::new();
117    object.insert("jsonrpc".to_string(), JsonValue::String("2.0".to_string()));
118    object.insert("method".to_string(), JsonValue::String(method.to_string()));
119    object.insert("params".to_string(), params);
120    JsonValue::Object(object).to_string_compact()
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::json::from_str;

    /// A result envelope carries the protocol version and echoes the request id.
    #[test]
    fn test_build_result_message() {
        let request_id = JsonValue::Number(1.0);
        let encoded = build_result_message(Some(&request_id), JsonValue::Bool(true));
        let decoded: JsonValue = from_str(&encoded).unwrap();

        let version = decoded.get("jsonrpc").and_then(|field| field.as_str());
        assert_eq!(version, Some("2.0"));

        let echoed_id = decoded.get("id").and_then(|field| field.as_f64());
        assert_eq!(echoed_id, Some(1.0));
    }

    /// An error envelope nests the code and message under `error`.
    #[test]
    fn test_build_error_message() {
        let request_id = JsonValue::Number(2.0);
        let encoded = build_error_message(Some(&request_id), -32601, "method not found");
        let decoded: JsonValue = from_str(&encoded).unwrap();

        let version = decoded.get("jsonrpc").and_then(|field| field.as_str());
        assert_eq!(version, Some("2.0"));

        let detail = decoded.get("error").unwrap();
        let code = detail.get("code").and_then(|field| field.as_f64());
        assert_eq!(code, Some(-32601.0));
        let text = detail.get("message").and_then(|field| field.as_str());
        assert_eq!(text, Some("method not found"));
    }

    /// A notification names its method and has no `id` member.
    #[test]
    fn test_build_notification() {
        let encoded = build_notification("test/event", JsonValue::Null);
        let decoded: JsonValue = from_str(&encoded).unwrap();

        let method = decoded.get("method").and_then(|field| field.as_str());
        assert_eq!(method, Some("test/event"));
        assert!(decoded.get("id").is_none());
    }

    /// A single well-formed frame yields exactly its body.
    #[test]
    fn test_read_payload_basic() {
        let body = r#"{"id":1}"#;
        let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
        let mut input = std::io::BufReader::new(frame.as_bytes());

        let received = read_payload(&mut input).unwrap();
        assert_eq!(received, Some(body.to_string()));
    }

    /// An empty stream is reported as end-of-input, not as an error.
    #[test]
    fn test_read_payload_eof() {
        let mut input = std::io::BufReader::new(&b""[..]);

        let received = read_payload(&mut input).unwrap();
        assert!(received.is_none());
    }
}