pulseengine_mcp_client/
transport.rs

1//! Transport layer for MCP client
2//!
3//! Provides abstractions for bidirectional communication with MCP servers.
4
5use crate::error::{ClientError, ClientResult};
6use async_trait::async_trait;
7use pulseengine_mcp_protocol::{NumberOrString, Request, Response};
8use std::sync::Arc;
9use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
10use tokio::sync::Mutex;
11use tracing::{debug, trace};
12
13/// Trait for client-side MCP transport
14///
15/// This trait abstracts the underlying communication mechanism (stdio, WebSocket, etc.)
16/// and provides a simple interface for sending requests and receiving responses.
17#[async_trait]
18pub trait ClientTransport: Send + Sync {
19    /// Send a JSON-RPC request to the server
20    async fn send(&self, request: &Request) -> ClientResult<()>;
21
22    /// Receive the next message from the server
23    ///
24    /// This may be a response to a previous request or a server-initiated request.
25    async fn recv(&self) -> ClientResult<JsonRpcMessage>;
26
27    /// Close the transport
28    async fn close(&self) -> ClientResult<()>;
29}
30
31/// A JSON-RPC message that can be either a request or response
32#[derive(Debug, Clone)]
33pub enum JsonRpcMessage {
34    /// A response to a previous request
35    Response(Response),
36    /// A request from the server (for sampling, roots/list, etc.)
37    Request(Request),
38    /// A notification (no response expected)
39    Notification {
40        /// The notification method
41        method: String,
42        /// The notification parameters
43        params: serde_json::Value,
44    },
45}
46
47impl JsonRpcMessage {
48    /// Parse a JSON string into a JsonRpcMessage
49    pub fn parse(json: &str) -> ClientResult<Self> {
50        let value: serde_json::Value = serde_json::from_str(json)?;
51
52        // Check if it's a response (has result or error, no method)
53        if value.get("result").is_some() || value.get("error").is_some() {
54            let response: Response = serde_json::from_value(value)?;
55            return Ok(Self::Response(response));
56        }
57
58        // Check if it has a method (request or notification)
59        if let Some(method) = value.get("method").and_then(|m| m.as_str()) {
60            // If it has an id, it's a request; otherwise notification
61            if value.get("id").is_some() && !value.get("id").unwrap().is_null() {
62                let request: Request = serde_json::from_value(value)?;
63                return Ok(Self::Request(request));
64            } else {
65                let params = value
66                    .get("params")
67                    .cloned()
68                    .unwrap_or(serde_json::Value::Null);
69                return Ok(Self::Notification {
70                    method: method.to_string(),
71                    params,
72                });
73            }
74        }
75
76        Err(ClientError::protocol(
77            "Invalid JSON-RPC message: no method, result, or error",
78        ))
79    }
80}
81
82/// Standard I/O transport for MCP client
83///
84/// Communicates with an MCP server via stdin/stdout streams.
85/// Typically used with child process spawning.
86pub struct StdioClientTransport<R, W>
87where
88    R: tokio::io::AsyncRead + Unpin + Send,
89    W: tokio::io::AsyncWrite + Unpin + Send,
90{
91    reader: Arc<Mutex<BufReader<R>>>,
92    writer: Arc<Mutex<W>>,
93}
94
95impl<R, W> StdioClientTransport<R, W>
96where
97    R: tokio::io::AsyncRead + Unpin + Send,
98    W: tokio::io::AsyncWrite + Unpin + Send,
99{
100    /// Create a new stdio transport from read and write streams
101    ///
102    /// # Arguments
103    /// * `reader` - The input stream (typically child process stdout)
104    /// * `writer` - The output stream (typically child process stdin)
105    pub fn new(reader: R, writer: W) -> Self {
106        Self {
107            reader: Arc::new(Mutex::new(BufReader::new(reader))),
108            writer: Arc::new(Mutex::new(writer)),
109        }
110    }
111}
112
113#[async_trait]
114impl<R, W> ClientTransport for StdioClientTransport<R, W>
115where
116    R: tokio::io::AsyncRead + Unpin + Send + 'static,
117    W: tokio::io::AsyncWrite + Unpin + Send + 'static,
118{
119    async fn send(&self, request: &Request) -> ClientResult<()> {
120        let json = serde_json::to_string(request)?;
121
122        // Validate: no embedded newlines (MCP spec)
123        if json.contains('\n') || json.contains('\r') {
124            return Err(ClientError::protocol(
125                "Request contains embedded newlines, which is not allowed by MCP spec",
126            ));
127        }
128
129        trace!("Sending request: {}", json);
130
131        let mut writer = self.writer.lock().await;
132        writer
133            .write_all(json.as_bytes())
134            .await
135            .map_err(|e| ClientError::transport(format!("Failed to write: {e}")))?;
136        writer
137            .write_all(b"\n")
138            .await
139            .map_err(|e| ClientError::transport(format!("Failed to write newline: {e}")))?;
140        writer
141            .flush()
142            .await
143            .map_err(|e| ClientError::transport(format!("Failed to flush: {e}")))?;
144
145        debug!(
146            "Sent request: method={}, id={:?}",
147            request.method, request.id
148        );
149        Ok(())
150    }
151
152    async fn recv(&self) -> ClientResult<JsonRpcMessage> {
153        let mut reader = self.reader.lock().await;
154        let mut line = String::new();
155
156        loop {
157            line.clear();
158            let bytes_read = reader
159                .read_line(&mut line)
160                .await
161                .map_err(|e| ClientError::transport(format!("Failed to read: {e}")))?;
162
163            if bytes_read == 0 {
164                return Err(ClientError::transport("EOF: server closed connection"));
165            }
166
167            let trimmed = line.trim();
168            if trimmed.is_empty() {
169                continue; // Skip empty lines
170            }
171
172            trace!("Received message: {}", trimmed);
173            return JsonRpcMessage::parse(trimmed);
174        }
175    }
176
177    async fn close(&self) -> ClientResult<()> {
178        // For stdio, we just flush and let the streams drop
179        let mut writer = self.writer.lock().await;
180        writer
181            .flush()
182            .await
183            .map_err(|e| ClientError::transport(format!("Failed to flush on close: {e}")))?;
184        Ok(())
185    }
186}
187
188/// Create a request ID for tracking
189pub fn next_request_id() -> NumberOrString {
190    use std::sync::atomic::{AtomicU64, Ordering};
191    static COUNTER: AtomicU64 = AtomicU64::new(1);
192    NumberOrString::Number(COUNTER.fetch_add(1, Ordering::Relaxed) as i64)
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw`, panicking if the parser rejects it.
    fn parsed(raw: &str) -> JsonRpcMessage {
        JsonRpcMessage::parse(raw).unwrap()
    }

    #[test]
    fn test_parse_response() {
        let msg = parsed(r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#);
        assert!(matches!(msg, JsonRpcMessage::Response(_)));
    }

    #[test]
    fn test_parse_error_response() {
        let msg = parsed(r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid"}}"#);
        assert!(matches!(msg, JsonRpcMessage::Response(_)));
    }

    #[test]
    fn test_parse_request() {
        let msg = parsed(
            r#"{"jsonrpc":"2.0","method":"sampling/createMessage","params":{},"id":"req-1"}"#,
        );
        assert!(matches!(msg, JsonRpcMessage::Request(_)));
    }

    #[test]
    fn test_parse_notification() {
        let msg = parsed(
            r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{"progress":50}}"#,
        );
        assert!(matches!(msg, JsonRpcMessage::Notification { .. }));
    }

    #[test]
    fn test_next_request_id() {
        let first = next_request_id();
        let second = next_request_id();

        // Two back-to-back calls must yield consecutive numbers.
        match (first, second) {
            (NumberOrString::Number(a), NumberOrString::Number(b)) => assert_eq!(b, a + 1),
            _ => panic!("Expected numeric IDs"),
        }
    }
}