1use std::io::{self, BufRead, Write};
6
7use crate::protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
8
/// A message received from the peer, as classified by
/// `StdioTransport::read_message`.
///
/// Each incoming line is tried as a [`JsonRpcRequest`] first and, failing
/// that, as a [`JsonRpcNotification`].
#[derive(Debug)]
pub enum IncomingMessage {
    /// A JSON-RPC request (carries an `id`).
    Request(JsonRpcRequest),
    /// A JSON-RPC notification (no `id` field).
    Notification(JsonRpcNotification),
}
15
/// Newline-delimited JSON-RPC transport over a reader/writer pair.
///
/// Each message is a single JSON document on its own line. In production the
/// pair is the process's stdin/stdout (see `StdioTransport::stdio`). Tests
/// inject in-memory buffers through the test-only `StdioTransport::new`.
pub struct StdioTransport {
    /// Buffered source of incoming lines (stdin in production).
    reader: Box<dyn BufRead + Send>,
    /// Sink for outgoing messages; flushed after every message written.
    writer: Box<dyn Write + Send>,
}
21
22impl StdioTransport {
23 pub fn stdio() -> Self {
25 Self {
26 reader: Box::new(io::BufReader::new(io::stdin())),
27 writer: Box::new(io::stdout()),
28 }
29 }
30
31 #[cfg(test)]
33 pub fn new(reader: Box<dyn BufRead + Send>, writer: Box<dyn Write + Send>) -> Self {
34 Self { reader, writer }
35 }
36
37 pub fn read_message(&mut self) -> io::Result<Option<IncomingMessage>> {
39 let mut line = String::new();
40
41 match self.reader.read_line(&mut line) {
42 Ok(0) => Ok(None), Ok(_) => {
44 let line = line.trim();
45 if line.is_empty() {
46 return Ok(None);
47 }
48
49 tracing::debug!("Received: {}", line);
50
51 if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(line) {
53 return Ok(Some(IncomingMessage::Request(request)));
54 }
55
56 if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(line) {
58 return Ok(Some(IncomingMessage::Notification(notification)));
59 }
60
61 tracing::warn!("Failed to parse message: {}", line);
62 Err(io::Error::new(
63 io::ErrorKind::InvalidData,
64 format!("Invalid JSON-RPC message: {}", line),
65 ))
66 }
67 Err(e) => Err(e),
68 }
69 }
70
71 pub fn write_response(&mut self, response: &JsonRpcResponse) -> io::Result<()> {
73 let json = serde_json::to_string(response).map_err(|e| {
74 io::Error::new(
75 io::ErrorKind::InvalidData,
76 format!("Serialization error: {}", e),
77 )
78 })?;
79
80 tracing::debug!("Sending: {}", json);
81
82 writeln!(self.writer, "{}", json)?;
83 self.writer.flush()
84 }
85
86 pub fn write_notification(&mut self, notification: &JsonRpcNotification) -> io::Result<()> {
88 let json = serde_json::to_string(notification).map_err(|e| {
89 io::Error::new(
90 io::ErrorKind::InvalidData,
91 format!("Serialization error: {}", e),
92 )
93 })?;
94
95 tracing::debug!("Sending notification: {}", json);
96
97 writeln!(self.writer, "{}", json)?;
98 self.writer.flush()
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RequestId;
    use std::io::Cursor;
    use std::sync::{Arc, Mutex};

    /// `Write` sink that appends into a shared buffer.
    ///
    /// The transport takes ownership of its writer, so the test keeps the
    /// other `Arc` handle to inspect what was written.
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Builds a transport that reads `input` and discards anything written.
    fn reading_transport(input: &str) -> StdioTransport {
        StdioTransport::new(
            Box::new(Cursor::new(input.to_string())),
            Box::new(Vec::<u8>::new()),
        )
    }

    /// Builds a transport with empty input.
    ///
    /// Its output is captured in the returned buffer.
    fn capturing_transport() -> (StdioTransport, Arc<Mutex<Vec<u8>>>) {
        let buffer = Arc::new(Mutex::new(Vec::new()));
        let writer = SharedWriter(Arc::clone(&buffer));
        let transport =
            StdioTransport::new(Box::new(Cursor::new(Vec::<u8>::new())), Box::new(writer));
        (transport, buffer)
    }

    /// Returns everything written to `buffer` so far, as UTF-8 text.
    fn captured(buffer: &Arc<Mutex<Vec<u8>>>) -> String {
        String::from_utf8(buffer.lock().unwrap().clone()).unwrap()
    }

    #[test]
    fn test_read_request() {
        let input = r#"{"jsonrpc":"2.0","id":1,"method":"test","params":{}}"#;
        let mut transport = reading_transport(&format!("{}\n", input));

        match transport.read_message().unwrap() {
            Some(IncomingMessage::Request(req)) => {
                assert_eq!(req.method, "test");
                assert_eq!(req.id, RequestId::Number(1));
            }
            _ => panic!("Expected request"),
        }
    }

    #[test]
    fn test_read_notification() {
        let input = r#"{"jsonrpc":"2.0","method":"initialized"}"#;
        let mut transport = reading_transport(&format!("{}\n", input));

        match transport.read_message().unwrap() {
            Some(IncomingMessage::Notification(notif)) => {
                assert_eq!(notif.method, "initialized");
            }
            _ => panic!("Expected notification"),
        }
    }

    #[test]
    fn test_write_response() {
        let (mut transport, buffer) = capturing_transport();

        let response =
            JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"test": true}));
        transport.write_response(&response).unwrap();

        let output = captured(&buffer);
        assert!(output.contains("\"jsonrpc\":\"2.0\""));
        assert!(output.contains("\"id\":1"));
        assert!(output.ends_with('\n'));
    }

    #[test]
    fn test_read_eof() {
        let mut transport = reading_transport("");

        assert!(transport.read_message().unwrap().is_none());
    }

    #[test]
    fn test_read_empty_line() {
        let mut transport = reading_transport("\n");

        assert!(transport.read_message().unwrap().is_none());
    }

    #[test]
    fn test_read_invalid_json() {
        let mut transport = reading_transport("not valid json\n");

        let err = transport.read_message().unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_recovers_after_invalid_line() {
        let valid = r#"{"jsonrpc":"2.0","id":2,"method":"next"}"#;
        let mut transport = reading_transport(&format!("not valid json\n{}\n", valid));

        assert!(transport.read_message().is_err());
        match transport.read_message().unwrap() {
            Some(IncomingMessage::Request(req)) => assert_eq!(req.method, "next"),
            _ => panic!("Expected request after invalid line"),
        }
    }

    #[test]
    fn test_write_notification() {
        let (mut transport, buffer) = capturing_transport();

        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "test/notification".to_string(),
            params: Some(serde_json::json!({"key": "value"})),
        };
        transport.write_notification(&notification).unwrap();

        let output = captured(&buffer);
        assert!(output.contains("\"jsonrpc\":\"2.0\""));
        assert!(output.contains("\"method\":\"test/notification\""));
        assert!(output.ends_with('\n'));
    }

    #[test]
    fn test_write_notification_without_params() {
        let (mut transport, buffer) = capturing_transport();

        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "initialized".to_string(),
            params: None,
        };
        transport.write_notification(&notification).unwrap();

        let output = captured(&buffer);
        assert!(output.contains("\"method\":\"initialized\""));
    }

    #[test]
    fn test_read_request_with_string_id() {
        let input = r#"{"jsonrpc":"2.0","id":"abc","method":"ping"}"#;
        let mut transport = reading_transport(&format!("{}\n", input));

        match transport.read_message().unwrap() {
            Some(IncomingMessage::Request(req)) => {
                assert_eq!(req.method, "ping");
                assert_eq!(req.id, RequestId::String("abc".to_string()));
            }
            _ => panic!("Expected request"),
        }
    }
}