Skip to main content

devboy_mcp/
transport.rs

1//! Transport layer for MCP JSON-RPC communication.
2//!
3//! MCP uses newline-delimited JSON over stdin/stdout.
4
5use std::io::{self, BufRead, Write};
6
7use crate::protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
8
9/// Message that can be received from the client.
10#[derive(Debug)]
11pub enum IncomingMessage {
12    Request(JsonRpcRequest),
13    Notification(JsonRpcNotification),
14}
15
16/// Transport for reading/writing JSON-RPC messages.
17pub struct StdioTransport {
18    reader: Box<dyn BufRead + Send>,
19    writer: Box<dyn Write + Send>,
20}
21
22impl StdioTransport {
23    /// Create a transport using stdin/stdout.
24    pub fn stdio() -> Self {
25        Self {
26            reader: Box::new(io::BufReader::new(io::stdin())),
27            writer: Box::new(io::stdout()),
28        }
29    }
30
31    /// Create a transport with custom reader/writer (for testing).
32    #[cfg(test)]
33    pub fn new(reader: Box<dyn BufRead + Send>, writer: Box<dyn Write + Send>) -> Self {
34        Self { reader, writer }
35    }
36
37    /// Read a single JSON-RPC message from the transport.
38    pub fn read_message(&mut self) -> io::Result<Option<IncomingMessage>> {
39        let mut line = String::new();
40
41        match self.reader.read_line(&mut line) {
42            Ok(0) => Ok(None), // EOF
43            Ok(_) => {
44                let line = line.trim();
45                if line.is_empty() {
46                    return Ok(None);
47                }
48
49                tracing::debug!("Received: {}", line);
50
51                // Try to parse as request first (has id field)
52                if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(line) {
53                    return Ok(Some(IncomingMessage::Request(request)));
54                }
55
56                // Try as notification (no id field)
57                if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(line) {
58                    return Ok(Some(IncomingMessage::Notification(notification)));
59                }
60
61                tracing::warn!("Failed to parse message: {}", line);
62                Err(io::Error::new(
63                    io::ErrorKind::InvalidData,
64                    format!("Invalid JSON-RPC message: {}", line),
65                ))
66            }
67            Err(e) => Err(e),
68        }
69    }
70
71    /// Write a JSON-RPC response to the transport.
72    pub fn write_response(&mut self, response: &JsonRpcResponse) -> io::Result<()> {
73        let json = serde_json::to_string(response).map_err(|e| {
74            io::Error::new(
75                io::ErrorKind::InvalidData,
76                format!("Serialization error: {}", e),
77            )
78        })?;
79
80        tracing::debug!("Sending: {}", json);
81
82        writeln!(self.writer, "{}", json)?;
83        self.writer.flush()
84    }
85
86    /// Write a JSON-RPC notification to the transport.
87    pub fn write_notification(&mut self, notification: &JsonRpcNotification) -> io::Result<()> {
88        let json = serde_json::to_string(notification).map_err(|e| {
89            io::Error::new(
90                io::ErrorKind::InvalidData,
91                format!("Serialization error: {}", e),
92            )
93        })?;
94
95        tracing::debug!("Sending notification: {}", json);
96
97        writeln!(self.writer, "{}", json)?;
98        self.writer.flush()
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RequestId;
    use std::io::Cursor;
    use std::sync::{Arc, Mutex};

    /// In-memory sink whose bytes stay inspectable after the transport takes
    /// ownership of a clone of it (clones share the same buffer).
    #[derive(Clone, Default)]
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl SharedWriter {
        /// Everything written so far, decoded as UTF-8.
        fn contents(&self) -> String {
            String::from_utf8(self.0.lock().unwrap().clone()).unwrap()
        }
    }

    impl std::io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    /// Transport that reads `input` and discards anything written.
    fn reading(input: &str) -> StdioTransport {
        StdioTransport::new(
            Box::new(Cursor::new(input.to_owned())),
            Box::new(Vec::new()),
        )
    }

    /// Transport with empty input whose output is captured by the returned sink.
    fn writing() -> (StdioTransport, SharedWriter) {
        let sink = SharedWriter::default();
        let transport = StdioTransport::new(
            Box::new(Cursor::new(String::new())),
            Box::new(sink.clone()),
        );
        (transport, sink)
    }

    #[test]
    fn test_read_request() {
        let mut transport = reading(concat!(
            r#"{"jsonrpc":"2.0","id":1,"method":"test","params":{}}"#,
            "\n"
        ));

        match transport.read_message().unwrap() {
            Some(IncomingMessage::Request(req)) => {
                assert_eq!(req.method, "test");
                assert_eq!(req.id, RequestId::Number(1));
            }
            _ => panic!("Expected request"),
        }
    }

    #[test]
    fn test_read_notification() {
        let mut transport = reading(concat!(r#"{"jsonrpc":"2.0","method":"initialized"}"#, "\n"));

        match transport.read_message().unwrap() {
            Some(IncomingMessage::Notification(notif)) => {
                assert_eq!(notif.method, "initialized");
            }
            _ => panic!("Expected notification"),
        }
    }

    #[test]
    fn test_read_request_with_crlf_terminator() {
        let mut transport = reading(concat!(r#"{"jsonrpc":"2.0","id":7,"method":"ping"}"#, "\r\n"));

        match transport.read_message().unwrap() {
            Some(IncomingMessage::Request(req)) => {
                assert_eq!(req.method, "ping");
                assert_eq!(req.id, RequestId::Number(7));
            }
            _ => panic!("Expected request"),
        }
    }

    #[test]
    fn test_read_consecutive_messages() {
        let mut transport = reading(concat!(
            r#"{"jsonrpc":"2.0","id":1,"method":"first"}"#,
            "\n",
            r#"{"jsonrpc":"2.0","method":"second"}"#,
            "\n"
        ));

        match transport.read_message().unwrap() {
            Some(IncomingMessage::Request(req)) => assert_eq!(req.method, "first"),
            _ => panic!("Expected request"),
        }
        match transport.read_message().unwrap() {
            Some(IncomingMessage::Notification(notif)) => assert_eq!(notif.method, "second"),
            _ => panic!("Expected notification"),
        }
        assert!(transport.read_message().unwrap().is_none());
    }

    #[test]
    fn test_write_response() {
        let (mut transport, sink) = writing();

        let response =
            JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"test": true}));

        transport.write_response(&response).unwrap();

        let output = sink.contents();
        assert!(output.contains("\"jsonrpc\":\"2.0\""));
        assert!(output.contains("\"id\":1"));
        assert!(output.ends_with('\n'));
    }

    #[test]
    fn test_read_eof() {
        let mut transport = reading("");

        assert!(transport.read_message().unwrap().is_none());
    }

    #[test]
    fn test_read_empty_line() {
        let mut transport = reading("\n");

        assert!(transport.read_message().unwrap().is_none());
    }

    #[test]
    fn test_read_invalid_json() {
        let mut transport = reading("not valid json\n");

        let err = transport.read_message().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_write_notification() {
        let (mut transport, sink) = writing();

        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "test/notification".to_string(),
            params: Some(serde_json::json!({"key": "value"})),
        };

        transport.write_notification(&notification).unwrap();

        let output = sink.contents();
        assert!(output.contains("\"jsonrpc\":\"2.0\""));
        assert!(output.contains("\"method\":\"test/notification\""));
        assert!(output.ends_with('\n'));
    }

    #[test]
    fn test_write_notification_without_params() {
        let (mut transport, sink) = writing();

        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "initialized".to_string(),
            params: None,
        };

        transport.write_notification(&notification).unwrap();

        assert!(sink.contents().contains("\"method\":\"initialized\""));
    }

    #[test]
    fn test_read_request_with_string_id() {
        let mut transport = reading(concat!(r#"{"jsonrpc":"2.0","id":"abc","method":"ping"}"#, "\n"));

        match transport.read_message().unwrap() {
            Some(IncomingMessage::Request(req)) => {
                assert_eq!(req.method, "ping");
                assert_eq!(req.id, RequestId::String("abc".to_string()));
            }
            _ => panic!("Expected request"),
        }
    }
}