1use crate::error::{Result, TransportError};
7use crate::shared::transport::{Transport, TransportMessage};
8use async_trait::async_trait;
9#[cfg(not(target_arch = "wasm32"))]
10use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
11#[cfg(not(target_arch = "wasm32"))]
12use tokio::sync::Mutex;
13
/// Prefix of the framing header written before every message body; the body's byte
/// length (in decimal) follows it, then a `\r\n\r\n` separator.
const CONTENT_LENGTH_HEADER: &str = "Content-Length: ";
16
/// Transport that exchanges `Content-Length`-framed JSON-RPC messages over this
/// process's standard input and standard output.
///
/// Each frame is a `Content-Length: N` header, a blank line (`\r\n\r\n`), and then
/// `N` bytes of JSON.
#[derive(Debug)]
pub struct StdioTransport {
    /// Buffered stdin; header lines are read line-by-line, then the body by length.
    stdin: Mutex<BufReader<tokio::io::Stdin>>,
    /// Stdout; flushed after every written message.
    stdout: Mutex<tokio::io::Stdout>,
    /// Set by `close()` and when EOF is hit while reading headers; `send`/`receive`
    /// refuse to operate once it is `true`.
    closed: std::sync::atomic::AtomicBool,
}
38
39impl StdioTransport {
40 pub fn new() -> Self {
51 Self {
52 stdin: Mutex::new(BufReader::new(tokio::io::stdin())),
53 stdout: Mutex::new(tokio::io::stdout()),
54 closed: std::sync::atomic::AtomicBool::new(false),
55 }
56 }
57
58 fn parse_content_length(line: &str) -> Option<usize> {
62 line.strip_prefix(CONTENT_LENGTH_HEADER)
63 .and_then(|content| content.trim().parse().ok())
64 }
65}
66
67impl Default for StdioTransport {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73#[async_trait]
74impl Transport for StdioTransport {
75 async fn send(&mut self, message: TransportMessage) -> Result<()> {
76 if self.closed.load(std::sync::atomic::Ordering::Acquire) {
77 return Err(TransportError::ConnectionClosed.into());
78 }
79
80 let json_bytes = Self::serialize_message(&message)?;
81 self.write_message(&json_bytes).await
82 }
83
84 async fn receive(&mut self) -> Result<TransportMessage> {
85 if self.closed.load(std::sync::atomic::Ordering::Acquire) {
86 return Err(TransportError::ConnectionClosed.into());
87 }
88
89 let content_length = self.read_headers().await?;
90 let buffer = self.read_message_body(content_length).await?;
91 Self::parse_message(&buffer)
92 }
93
94 async fn close(&mut self) -> Result<()> {
95 self.closed
96 .store(true, std::sync::atomic::Ordering::Release);
97
98 let mut stdout = self.stdout.lock().await;
100 stdout.flush().await.map_err(TransportError::from)?;
101 drop(stdout);
102
103 Ok(())
104 }
105
106 fn is_connected(&self) -> bool {
107 !self.closed.load(std::sync::atomic::Ordering::Acquire)
108 }
109
110 fn transport_type(&self) -> &'static str {
111 "stdio"
112 }
113}
114
115impl StdioTransport {
116 pub fn serialize_message(message: &TransportMessage) -> Result<Vec<u8>> {
118 match message {
119 TransportMessage::Request { id, request } => {
120 let jsonrpc_request = crate::shared::create_request(id.clone(), request.clone());
121 serde_json::to_vec(&jsonrpc_request).map_err(|e| {
122 TransportError::InvalidMessage(format!("Failed to serialize request: {}", e))
123 .into()
124 })
125 },
126 TransportMessage::Response(response) => serde_json::to_vec(response).map_err(|e| {
127 TransportError::InvalidMessage(format!("Failed to serialize response: {}", e))
128 .into()
129 }),
130 TransportMessage::Notification(notification) => {
131 let jsonrpc_notification = crate::shared::create_notification(notification.clone());
132 serde_json::to_vec(&jsonrpc_notification).map_err(|e| {
133 TransportError::InvalidMessage(format!(
134 "Failed to serialize notification: {}",
135 e
136 ))
137 .into()
138 })
139 },
140 }
141 }
142
143 async fn write_message(&self, json_bytes: &[u8]) -> Result<()> {
145 let mut stdout = self.stdout.lock().await;
146
147 let header = format!("{}{}\r\n\r\n", CONTENT_LENGTH_HEADER, json_bytes.len());
149 stdout
150 .write_all(header.as_bytes())
151 .await
152 .map_err(TransportError::from)?;
153
154 stdout
156 .write_all(json_bytes)
157 .await
158 .map_err(TransportError::from)?;
159
160 stdout.flush().await.map_err(TransportError::from)?;
162 drop(stdout);
163
164 Ok(())
165 }
166
167 async fn read_headers(&self) -> Result<usize> {
169 let mut stdin = self.stdin.lock().await;
170 let mut line = String::new();
171 let mut content_length = None;
172
173 loop {
175 line.clear();
176 let bytes_read = stdin
177 .read_line(&mut line)
178 .await
179 .map_err(TransportError::from)?;
180
181 if bytes_read == 0 {
182 drop(stdin);
184 self.closed
185 .store(true, std::sync::atomic::Ordering::Release);
186 return Err(TransportError::ConnectionClosed.into());
187 }
188
189 let line = line.trim();
190
191 if line.is_empty() {
192 break;
194 }
195
196 if let Some(length) = Self::parse_content_length(line) {
197 content_length = Some(length);
198 }
199 }
200 drop(stdin);
201
202 content_length.ok_or_else(|| {
203 TransportError::InvalidMessage("Missing Content-Length header".to_string()).into()
204 })
205 }
206
207 async fn read_message_body(&self, content_length: usize) -> Result<Vec<u8>> {
209 let mut stdin = self.stdin.lock().await;
210 let mut buffer = vec![0u8; content_length];
211 stdin
212 .read_exact(&mut buffer)
213 .await
214 .map_err(TransportError::from)?;
215 drop(stdin);
216 Ok(buffer)
217 }
218
219 pub fn parse_message(buffer: &[u8]) -> Result<TransportMessage> {
221 let json_value: serde_json::Value = serde_json::from_slice(buffer)
222 .map_err(|e| TransportError::InvalidMessage(format!("Invalid JSON: {}", e)))?;
223
224 if json_value.get("method").is_some() {
225 Self::parse_method_message(json_value)
226 } else if json_value.get("result").is_some() || json_value.get("error").is_some() {
227 Self::parse_response_message(json_value)
228 } else {
229 Err(TransportError::InvalidMessage("Unknown message type".to_string()).into())
230 }
231 }
232
233 fn parse_method_message(json_value: serde_json::Value) -> Result<TransportMessage> {
235 if json_value.get("id").is_some() {
236 let request: crate::types::JSONRPCRequest<serde_json::Value> =
238 serde_json::from_value(json_value).map_err(|e| {
239 TransportError::InvalidMessage(format!("Invalid request: {}", e))
240 })?;
241
242 let parsed_request = crate::shared::parse_request(request)
243 .map_err(|e| TransportError::InvalidMessage(format!("Invalid request: {}", e)))?;
244
245 Ok(TransportMessage::Request {
246 id: parsed_request.0,
247 request: parsed_request.1,
248 })
249 } else {
250 let parsed_notification =
252 crate::shared::parse_notification(json_value).map_err(|e| {
253 TransportError::InvalidMessage(format!("Invalid notification: {}", e))
254 })?;
255
256 Ok(TransportMessage::Notification(parsed_notification))
257 }
258 }
259
260 fn parse_response_message(json_value: serde_json::Value) -> Result<TransportMessage> {
262 let response: crate::types::JSONRPCResponse = serde_json::from_value(json_value)
263 .map_err(|e| TransportError::InvalidMessage(format!("Invalid response: {}", e)))?;
264
265 Ok(TransportMessage::Response(response))
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_content_length_valid() {
        assert_eq!(
            StdioTransport::parse_content_length("Content-Length: 42"),
            Some(42)
        );
        assert_eq!(
            StdioTransport::parse_content_length("Content-Length: 0"),
            Some(0)
        );
        assert_eq!(
            StdioTransport::parse_content_length("Content-Length: 999999"),
            Some(999_999)
        );
        // Trailing whitespace around the value is tolerated.
        assert_eq!(
            StdioTransport::parse_content_length("Content-Length: 42 "),
            Some(42)
        );
    }

    #[test]
    fn parse_content_length_invalid() {
        assert_eq!(
            StdioTransport::parse_content_length("Content-Type: application/json"),
            None
        );
        assert_eq!(
            StdioTransport::parse_content_length("Content-Length: abc"),
            None
        );
        assert_eq!(StdioTransport::parse_content_length(""), None);
        assert_eq!(
            StdioTransport::parse_content_length("Content-Length: -42"),
            None
        );
        assert_eq!(StdioTransport::parse_content_length("Content-Length"), None);
    }

    #[test]
    fn parse_message_rejects_invalid_json() {
        assert!(StdioTransport::parse_message(b"not json").is_err());
        assert!(StdioTransport::parse_message(b"").is_err());
    }

    #[test]
    fn parse_message_rejects_unknown_shape() {
        // No "method", "result" or "error": neither request, notification nor response.
        assert!(StdioTransport::parse_message(br#"{"jsonrpc":"2.0","id":1}"#).is_err());
        // Non-object JSON has none of those keys either.
        assert!(StdioTransport::parse_message(b"[1, 2, 3]").is_err());
    }

    #[tokio::test]
    async fn transport_properties() {
        let transport = StdioTransport::new();
        assert!(transport.is_connected());
        assert_eq!(transport.transport_type(), "stdio");
    }

    #[tokio::test]
    async fn test_close() {
        let mut transport = StdioTransport::new();
        assert!(transport.is_connected());

        transport.close().await.unwrap();
        assert!(!transport.is_connected());
    }

    #[tokio::test]
    async fn receive_after_close_fails_without_reading() {
        // The closed check happens before stdin is touched, so this cannot block.
        let mut transport = StdioTransport::new();
        transport.close().await.unwrap();
        assert!(transport.receive().await.is_err());
    }
}