pulseengine_mcp_client/
transport.rs1use crate::error::{ClientError, ClientResult};
6use async_trait::async_trait;
7use pulseengine_mcp_protocol::{NumberOrString, Request, Response};
8use std::sync::Arc;
9use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
10use tokio::sync::Mutex;
11use tracing::{debug, trace};
12
13#[async_trait]
18pub trait ClientTransport: Send + Sync {
19 async fn send(&self, request: &Request) -> ClientResult<()>;
21
22 async fn recv(&self) -> ClientResult<JsonRpcMessage>;
26
27 async fn close(&self) -> ClientResult<()>;
29}
30
31#[derive(Debug, Clone)]
33pub enum JsonRpcMessage {
34 Response(Response),
36 Request(Request),
38 Notification {
40 method: String,
42 params: serde_json::Value,
44 },
45}
46
47impl JsonRpcMessage {
48 pub fn parse(json: &str) -> ClientResult<Self> {
50 let value: serde_json::Value = serde_json::from_str(json)?;
51
52 if value.get("result").is_some() || value.get("error").is_some() {
54 let response: Response = serde_json::from_value(value)?;
55 return Ok(Self::Response(response));
56 }
57
58 if let Some(method) = value.get("method").and_then(|m| m.as_str()) {
60 if value.get("id").is_some() && !value.get("id").unwrap().is_null() {
62 let request: Request = serde_json::from_value(value)?;
63 return Ok(Self::Request(request));
64 } else {
65 let params = value
66 .get("params")
67 .cloned()
68 .unwrap_or(serde_json::Value::Null);
69 return Ok(Self::Notification {
70 method: method.to_string(),
71 params,
72 });
73 }
74 }
75
76 Err(ClientError::protocol(
77 "Invalid JSON-RPC message: no method, result, or error",
78 ))
79 }
80}
81
82pub struct StdioClientTransport<R, W>
87where
88 R: tokio::io::AsyncRead + Unpin + Send,
89 W: tokio::io::AsyncWrite + Unpin + Send,
90{
91 reader: Arc<Mutex<BufReader<R>>>,
92 writer: Arc<Mutex<W>>,
93}
94
95impl<R, W> StdioClientTransport<R, W>
96where
97 R: tokio::io::AsyncRead + Unpin + Send,
98 W: tokio::io::AsyncWrite + Unpin + Send,
99{
100 pub fn new(reader: R, writer: W) -> Self {
106 Self {
107 reader: Arc::new(Mutex::new(BufReader::new(reader))),
108 writer: Arc::new(Mutex::new(writer)),
109 }
110 }
111}
112
113#[async_trait]
114impl<R, W> ClientTransport for StdioClientTransport<R, W>
115where
116 R: tokio::io::AsyncRead + Unpin + Send + 'static,
117 W: tokio::io::AsyncWrite + Unpin + Send + 'static,
118{
119 async fn send(&self, request: &Request) -> ClientResult<()> {
120 let json = serde_json::to_string(request)?;
121
122 if json.contains('\n') || json.contains('\r') {
124 return Err(ClientError::protocol(
125 "Request contains embedded newlines, which is not allowed by MCP spec",
126 ));
127 }
128
129 trace!("Sending request: {}", json);
130
131 let mut writer = self.writer.lock().await;
132 writer
133 .write_all(json.as_bytes())
134 .await
135 .map_err(|e| ClientError::transport(format!("Failed to write: {e}")))?;
136 writer
137 .write_all(b"\n")
138 .await
139 .map_err(|e| ClientError::transport(format!("Failed to write newline: {e}")))?;
140 writer
141 .flush()
142 .await
143 .map_err(|e| ClientError::transport(format!("Failed to flush: {e}")))?;
144
145 debug!(
146 "Sent request: method={}, id={:?}",
147 request.method, request.id
148 );
149 Ok(())
150 }
151
152 async fn recv(&self) -> ClientResult<JsonRpcMessage> {
153 let mut reader = self.reader.lock().await;
154 let mut line = String::new();
155
156 loop {
157 line.clear();
158 let bytes_read = reader
159 .read_line(&mut line)
160 .await
161 .map_err(|e| ClientError::transport(format!("Failed to read: {e}")))?;
162
163 if bytes_read == 0 {
164 return Err(ClientError::transport("EOF: server closed connection"));
165 }
166
167 let trimmed = line.trim();
168 if trimmed.is_empty() {
169 continue; }
171
172 trace!("Received message: {}", trimmed);
173 return JsonRpcMessage::parse(trimmed);
174 }
175 }
176
177 async fn close(&self) -> ClientResult<()> {
178 let mut writer = self.writer.lock().await;
180 writer
181 .flush()
182 .await
183 .map_err(|e| ClientError::transport(format!("Failed to flush on close: {e}")))?;
184 Ok(())
185 }
186}
187
188pub fn next_request_id() -> NumberOrString {
190 use std::sync::atomic::{AtomicU64, Ordering};
191 static COUNTER: AtomicU64 = AtomicU64::new(1);
192 NumberOrString::Number(COUNTER.fetch_add(1, Ordering::Relaxed) as i64)
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_response() {
        let json = r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#;
        let msg = JsonRpcMessage::parse(json).unwrap();
        assert!(matches!(msg, JsonRpcMessage::Response(_)));
    }

    #[test]
    fn test_parse_error_response() {
        let json = r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid"}}"#;
        let msg = JsonRpcMessage::parse(json).unwrap();
        assert!(matches!(msg, JsonRpcMessage::Response(_)));
    }

    #[test]
    fn test_parse_request() {
        let json =
            r#"{"jsonrpc":"2.0","method":"sampling/createMessage","params":{},"id":"req-1"}"#;
        let msg = JsonRpcMessage::parse(json).unwrap();
        assert!(matches!(msg, JsonRpcMessage::Request(_)));
    }

    #[test]
    fn test_parse_notification() {
        let json =
            r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{"progress":50}}"#;
        let msg = JsonRpcMessage::parse(json).unwrap();
        assert!(matches!(msg, JsonRpcMessage::Notification { .. }));
    }

    #[test]
    fn test_parse_notification_with_null_id() {
        // A null `id` is classified like a missing one.
        let json = r#"{"jsonrpc":"2.0","method":"notifications/initialized","id":null}"#;
        let msg = JsonRpcMessage::parse(json).unwrap();
        assert!(matches!(msg, JsonRpcMessage::Notification { .. }));
    }

    #[test]
    fn test_parse_notification_without_params_defaults_to_null() {
        let json = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        match JsonRpcMessage::parse(json).unwrap() {
            JsonRpcMessage::Notification { method, params } => {
                assert_eq!(method, "notifications/initialized");
                assert!(params.is_null());
            }
            other => panic!("Expected notification, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_rejects_message_without_method_result_or_error() {
        let json = r#"{"jsonrpc":"2.0","id":1}"#;
        assert!(JsonRpcMessage::parse(json).is_err());
    }

    #[test]
    fn test_parse_rejects_malformed_json() {
        assert!(JsonRpcMessage::parse("{not json").is_err());
    }

    #[test]
    fn test_next_request_id() {
        // The counter is a process-wide static and the test harness runs tests in
        // parallel threads, so another caller may take ids between these two calls.
        // Only strict ordering of sequential calls on one thread is guaranteed.
        let id1 = next_request_id();
        let id2 = next_request_id();

        match (id1, id2) {
            (NumberOrString::Number(n1), NumberOrString::Number(n2)) => {
                assert!(n2 > n1, "ids must increase: {n1} then {n2}");
            }
            _ => panic!("Expected numeric IDs"),
        }
    }
}