Skip to main content

oxios_kernel/mcp/
client.rs

1//! MCP client — manages a single MCP server process lifecycle.
2//!
3//! `McpClient` spawns a child process and communicates with it over stdin/stdout
4//! using JSON-RPC 2.0 messages (one JSON object per line).
5//!
6//! I/O handles are stored persistently (not consumed via `take()`) so that
7//! multiple requests can be serialized through the same connection. A write
8//! lock on both stdin and stdout is held for the duration of each request-response
9//! cycle, ensuring correct ordering.
10
11use anyhow::{anyhow, Context, Result};
12use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
13use tokio::process::{Child, ChildStdin, ChildStdout, Command};
14use tokio::sync::RwLock;
15use tokio::time::{timeout, Duration};
16
17use super::protocol::*;
18
19// ---------------------------------------------------------------------------
20// McpClient — manages a single MCP server process lifecycle
21// ---------------------------------------------------------------------------
22
/// Manages a single MCP server process with stdio JSON-RPC communication.
///
/// The child's stdin/stdout handles are kept inside the client (behind locks)
/// instead of being consumed per call. Each request takes those locks for the
/// whole request-response cycle, so concurrent callers are serialized over
/// the one connection.
///
/// All state sits behind async locks, which is why every method takes `&self`.
///
/// # Example
///
/// ```ignore
/// let client = McpClient::new(server_config);
/// client.initialize().await?;
/// let tools = client.list_tools().await?;
/// let result = client.call_tool("my_tool", serde_json::json!({"arg": "value"})).await?;
/// client.shutdown().await?;
/// ```
pub struct McpClient {
    /// Server configuration (command, args, env and display name).
    server: McpServer,
    /// Child process handle (`None` until `initialize()` spawns it and again
    /// after `shutdown()`).
    child: RwLock<Option<Child>>,
    /// Persistent, buffered stdin handle for writing to the server process.
    stdin: RwLock<Option<tokio::io::BufWriter<ChildStdin>>>,
    /// Persistent, buffered stdout handle for reading from the server process.
    stdout: RwLock<Option<BufReader<ChildStdout>>>,
    /// Whether the `initialize` handshake has completed.
    initialized: RwLock<bool>,
    /// Cached tool list (filled by `refresh_tools`, cleared by `shutdown`).
    tool_cache: RwLock<Option<Vec<McpTool>>>,
    /// Server info received during initialize.
    server_info: RwLock<Option<ServerInfo>>,
    /// Timeout applied to the write and read phases of a request.
    request_timeout: Duration,
}
55
56impl McpClient {
57    /// Create a new MCP client for the given server configuration.
58    ///
59    /// Does NOT spawn the process yet — call `initialize()` to start and negotiate.
60    pub fn new(server: McpServer) -> Self {
61        Self {
62            server,
63            child: RwLock::new(None),
64            stdin: RwLock::new(None),
65            stdout: RwLock::new(None),
66            initialized: RwLock::new(false),
67            tool_cache: RwLock::new(None),
68            server_info: RwLock::new(None),
69            request_timeout: Duration::from_secs(30),
70        }
71    }
72
73    /// Set the request timeout duration.
74    #[must_use]
75    pub fn with_timeout(mut self, timeout: Duration) -> Self {
76        self.request_timeout = timeout;
77        self
78    }
79
80    /// Spawn the MCP server process and establish communication.
81    pub async fn initialize(&self) -> Result<()> {
82        if *self.initialized.read().await {
83            return Ok(());
84        }
85
86        // Spawn the child process
87        let mut child = Command::new(&self.server.command)
88            .args(&self.server.args)
89            .envs(&self.server.env)
90            .stdin(std::process::Stdio::piped())
91            .stdout(std::process::Stdio::piped())
92            .stderr(std::process::Stdio::piped())
93            .spawn()
94            .with_context(|| format!("Failed to spawn MCP server '{}'", self.server.name))?;
95
96        let stdin = child
97            .stdin
98            .take()
99            .expect("stdin not captured — stdin was piped");
100        let stdout = child
101            .stdout
102            .take()
103            .expect("stdout not captured — stdout was piped");
104
105        // Store persistent I/O handles (separate from child process handle)
106        *self.stdin.write().await = Some(tokio::io::BufWriter::new(stdin));
107        *self.stdout.write().await = Some(BufReader::new(stdout));
108
109        // Store child handle
110        *self.child.write().await = Some(child);
111
112        // Send initialize request using persistent handles
113        let params = InitializeParams::default();
114        let request = McpRequest::new("initialize").with_params(serde_json::to_value(&params)?);
115
116        // Use do_request directly (not send_request) to avoid recursion
117        // since send_request may call restart() which calls initialize().
118        let response = self.do_request(request).await?;
119
120        // Parse initialize result
121        let result_json = response.into_result()?;
122        let init_result: InitializeResult = serde_json::from_value(result_json)?;
123
124        *self.server_info.write().await = Some(init_result.server_info.clone());
125        *self.initialized.write().await = true;
126
127        // Send initialised notification (JSON-RPC 2.0 requires this)
128        let notification = McpRequest::new("notifications/initialized");
129        self.send_notification(notification).await?;
130
131        tracing::debug!(
132            server = %self.server.name,
133            version = %init_result.server_info.version,
134            "MCP server initialized"
135        );
136
137        Ok(())
138    }
139
140    /// Check if the server has been initialized
141    pub async fn is_initialized(&self) -> bool {
142        *self.initialized.read().await
143    }
144
145    /// Get the server info received during initialize
146    pub async fn server_info(&self) -> Option<ServerInfo> {
147        self.server_info.read().await.clone()
148    }
149
150    /// Send a JSON-RPC request using persistent I/O handles.
151    ///
152    /// Acquires write locks on both stdin and stdout for the duration of
153    /// the request-response cycle, serializing concurrent access.
154    async fn do_request(&self, request: McpRequest) -> Result<McpResponse> {
155        let request_id = request.id.clone();
156
157        // Acquire stdin lock for writing
158        let mut stdin_guard = self.stdin.write().await;
159        let stdin = stdin_guard
160            .as_mut()
161            .ok_or_else(|| anyhow!("stdin not available on '{}'", self.server.name))?;
162
163        // Write the request
164        let json = request.to_jsonl()?;
165        timeout(self.request_timeout, async {
166            stdin.write_all(&json).await?;
167            stdin.flush().await?;
168            Ok::<(), tokio::io::Error>(())
169        })
170        .await
171        .map_err(|e| anyhow::anyhow!("MCP request timed out (write): {}", e))??;
172
173        // Acquire stdout lock for reading
174        let mut stdout_guard = self.stdout.write().await;
175        let stdout = stdout_guard
176            .as_mut()
177            .ok_or_else(|| anyhow!("stdout not available on '{}'", self.server.name))?;
178
179        // Read the response (single JSON line)
180        let line: std::io::Result<Option<String>> = timeout(self.request_timeout, async {
181            stdout.lines().next_line().await
182        })
183        .await
184        .map_err(|e| anyhow::anyhow!("MCP request timed out (read): {}", e))?;
185
186        let response_str: String = line
187            .context("Failed to read MCP response line from stdout")?
188            .with_context(|| format!("MCP server {} returned no response", self.server.name))?;
189
190        let parsed: McpResponse = serde_json::from_str(&response_str)
191            .with_context(|| format!("Failed to parse MCP response JSON: {}", response_str))?;
192
193        // Sanity check: ID should match
194        if parsed.id != request_id {
195            tracing::warn!(
196                server = %self.server.name,
197                expected_id = ?request_id,
198                got_id = ?parsed.id,
199                "MCP response ID mismatch"
200            );
201        }
202
203        Ok(parsed)
204    }
205
206    /// Send a JSON-RPC notification (no response expected).
207    async fn send_notification(&self, notification: McpRequest) -> Result<()> {
208        let mut stdin_guard = self.stdin.write().await;
209        let stdin = stdin_guard
210            .as_mut()
211            .ok_or_else(|| anyhow!("stdin not available on '{}'", self.server.name))?;
212
213        let json = notification.to_jsonl()?;
214        stdin.write_all(&json).await?;
215        stdin.flush().await?;
216
217        Ok(())
218    }
219
220    /// Send a JSON-RPC request via persistent I/O handles.
221    ///
222    /// If the server is not running, attempts one automatic restart before failing.
223    /// The restart itself uses the low-level `do_request` path, not `send_request`,
224    /// to avoid async recursion.
225    pub(crate) async fn send_request(&self, request: McpRequest) -> Result<McpResponse> {
226        // Verify server is running; attempt auto-restart if not
227        {
228            let child = self.child.read().await;
229            if child.is_none() {
230                tracing::warn!(
231                    server = %self.server.name,
232                    "MCP server not running, attempting auto-start"
233                );
234                drop(child);
235                // Use restart (shutdown + initialize) which doesn't call send_request
236                self.restart().await?;
237            }
238        }
239
240        match self.do_request(request).await {
241            Ok(resp) => Ok(resp),
242            Err(e) => {
243                // Auto-restart on communication errors (crashed server)
244                let err_str = e.to_string();
245                let is_comm_error = err_str.contains("not available")
246                    || err_str.contains("broken pipe")
247                    || err_str.contains("timed out")
248                    || err_str.contains("no response");
249
250                if is_comm_error {
251                    tracing::warn!(
252                        server = %self.server.name,
253                        error = %err_str,
254                        "MCP communication error, attempting auto-restart"
255                    );
256                    self.restart().await?;
257                    anyhow::bail!(
258                        "MCP server '{}' restarted after error. Please retry the request.",
259                        self.server.name
260                    );
261                } else {
262                    Err(e)
263                }
264            }
265        }
266    }
267
268    /// List all tools available from this MCP server.
269    ///
270    /// Results are cached and refreshed on [refresh_tools](Self::refresh_tools).
271    pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
272        // Return cached tools if available
273        if let Some(cached) = self.tool_cache.read().await.clone() {
274            return Ok(cached);
275        }
276
277        self.refresh_tools().await
278    }
279
280    /// Force-refresh the tool list from the server.
281    pub async fn refresh_tools(&self) -> Result<Vec<McpTool>> {
282        let request = McpRequest::new("tools/list");
283        let response = self.send_request(request).await?;
284
285        let result_json = response.into_result()?;
286        let tools_result: McpToolsResult = serde_json::from_value(result_json)?;
287
288        let tools = tools_result.tools;
289        *self.tool_cache.write().await = Some(tools.clone());
290
291        tracing::debug!(
292            server = %self.server.name,
293            count = tools.len(),
294            "Refreshed tool cache"
295        );
296
297        Ok(tools)
298    }
299
300    /// Call a tool on this MCP server.
301    ///
302    /// The server must be initialized first.
303    pub async fn call_tool(
304        &self,
305        tool_name: &str,
306        arguments: serde_json::Value,
307    ) -> Result<McpToolCallResult> {
308        let params = serde_json::json!({
309            "name": tool_name,
310            "arguments": arguments,
311        });
312
313        let request = McpRequest::new("tools/call").with_params(params);
314        let response = self.send_request(request).await?;
315
316        let result_json = response.into_result()?;
317        let call_result: McpToolCallResult = serde_json::from_value(result_json)?;
318
319        tracing::debug!(
320            server = %self.server.name,
321            tool = tool_name,
322            "Tool call completed"
323        );
324
325        Ok(call_result)
326    }
327
328    /// Call a tool and return the result content as a string.
329    ///
330    /// Returns the first text content block, or an error if no text content.
331    pub async fn call_tool_text(
332        &self,
333        tool_name: &str,
334        arguments: serde_json::Value,
335    ) -> Result<String> {
336        let result = self.call_tool(tool_name, arguments).await?;
337
338        for block in result.content {
339            if let McpContentBlock::Text { text } = block {
340                return Ok(text);
341            }
342        }
343
344        Err(anyhow!("Tool '{}' returned no text content", tool_name))
345    }
346
347    /// Gracefully shutdown the MCP server process.
348    ///
349    /// Drops persistent I/O handles first, then kills the child process.
350    pub async fn shutdown(&self) -> Result<()> {
351        // Drop persistent I/O handles first
352        *self.stdin.write().await = None;
353        *self.stdout.write().await = None;
354
355        let mut child_guard = self.child.write().await;
356
357        if let Some(mut child) = child_guard.take() {
358            tracing::debug!(server = %self.server.name, "Shutting down MCP server");
359
360            // Try graceful shutdown first
361            let _ = child.try_wait();
362
363            // Kill the process
364            child.kill().await?;
365            let _ = child.wait().await;
366        }
367
368        *self.initialized.write().await = false;
369        *self.tool_cache.write().await = None;
370
371        Ok(())
372    }
373
374    /// Restart the server (shutdown then initialize).
375    pub async fn restart(&self) -> Result<()> {
376        self.shutdown().await?;
377        self.initialize().await
378    }
379
380    /// Get the server configuration
381    pub fn server(&self) -> &McpServer {
382        &self.server
383    }
384}
385
386impl std::fmt::Debug for McpClient {
387    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388        f.debug_struct("McpClient")
389            .field("server", &self.server.name)
390            .field("initialized", &self.initialized)
391            .finish()
392    }
393}
394
395// ============================================================================
396// Tests
397// ============================================================================
398
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Duration;

    // --- McpClient construction and configuration tests ---
    //
    // This module is a child of the module defining `McpClient`, so private
    // fields such as `request_timeout` can be asserted on directly.

    #[test]
    fn test_client_construction() {
        let server = McpServer::new("test-server", "npx");
        let client = McpClient::new(server);

        // Verify the server config is stored correctly
        assert_eq!(client.server.name, "test-server");
        assert_eq!(client.server.command, "npx");
    }

    #[test]
    fn test_client_with_timeout() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server).with_timeout(Duration::from_secs(60));

        // The builder must replace the timeout and leave the config alone
        assert_eq!(client.request_timeout, Duration::from_secs(60));
        assert_eq!(client.server.name, "test");
    }

    #[test]
    fn test_client_with_timeout_short() {
        let server = McpServer::new("test", "sleep");
        let client = McpClient::new(server).with_timeout(Duration::from_millis(50));

        // Sub-second timeouts must be stored exactly
        assert_eq!(client.request_timeout, Duration::from_millis(50));
        assert_eq!(client.server.name, "test");
    }

    #[test]
    fn test_client_debug_format() {
        let server = McpServer::new("debug-test", "echo");
        let client = McpClient::new(server);

        let debug_str = format!("{:?}", client);

        // Debug output should contain the server name
        assert!(debug_str.contains("debug-test"));
        assert!(debug_str.contains("McpClient"));
    }

    #[test]
    fn test_client_debug_different_servers() {
        let server1 = McpServer::new("server-a", "cmd1");
        let server2 = McpServer::new("server-b", "cmd2");

        let client1 = McpClient::new(server1);
        let client2 = McpClient::new(server2);

        let debug1 = format!("{:?}", client1);
        let debug2 = format!("{:?}", client2);

        assert!(debug1.contains("server-a"));
        assert!(debug2.contains("server-b"));
        assert_ne!(debug1, debug2);
    }

    #[tokio::test]
    async fn test_is_initialized_false_on_new() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);

        // New client should not be initialized
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_is_initialized_after_failed_init() {
        let server = McpServer::new("ghost", "nonexistent-binary-xyz-123");
        let client = McpClient::new(server);

        // Failed init should leave client not initialized
        let result = client.initialize().await;
        assert!(result.is_err());
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_shutdown_when_not_running() {
        let server = McpServer::new("test-shutdown", "echo");
        let client = McpClient::new(server);

        // Shutting down without ever starting should succeed gracefully
        let result = client.shutdown().await;
        assert!(result.is_ok());

        // Client should still report as not initialized
        assert!(!client.is_initialized().await);
    }

    #[tokio::test]
    async fn test_shutdown_idempotent() {
        let server = McpServer::new("test-idempotent", "echo");
        let client = McpClient::new(server);

        // First shutdown
        let first = client.shutdown().await;
        assert!(first.is_ok());

        // Second shutdown should also succeed (idempotent)
        let second = client.shutdown().await;
        assert!(second.is_ok());
    }

    #[test]
    fn test_client_server_config_passed_through() {
        let server = McpServer::new("config-test", "npx")
            .with_args(vec!["-y".to_string(), "@some/mcp-server".to_string()])
            .with_env("DEBUG", "true");

        let client = McpClient::new(server);

        assert_eq!(client.server.name, "config-test");
        assert_eq!(client.server.command, "npx");
        assert_eq!(client.server.args, vec!["-y", "@some/mcp-server"]);
        assert_eq!(client.server.env.get("DEBUG"), Some(&"true".to_string()));
    }

    #[test]
    fn test_client_server_method() {
        let server = McpServer::new("method-test", "python");
        let client = McpClient::new(server);

        // server() method should return a reference to the server config
        let retrieved_server = client.server();
        assert_eq!(retrieved_server.name, "method-test");
    }

    #[tokio::test]
    async fn test_server_info_none_on_new_client() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);

        // Server info should be None until initialized
        assert!(client.server_info().await.is_none());
    }

    #[tokio::test]
    async fn test_initialize_already_initialized_skipped() {
        let server = McpServer::new("echo", "echo");
        let client = McpClient::new(server);

        // `echo` does not speak MCP (or does not exist as a binary at all),
        // so the handshake cannot succeed.
        let first = client.initialize().await;
        assert!(first.is_err());
        assert!(!client.is_initialized().await);

        // A second attempt must fail the same way rather than panic or
        // report success from stale state.
        let second = client.initialize().await;
        assert!(second.is_err());
        assert!(!client.is_initialized().await);
    }

    #[test]
    fn test_client_default_timeout_is_30_seconds() {
        let server = McpServer::new("test", "echo");
        let client = McpClient::new(server);

        assert_eq!(client.request_timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_shutdown_clears_initialized_flag() {
        let server = McpServer::new("test-clear", "echo");
        let client = McpClient::new(server);

        // Ensure initialized is false
        assert!(!client.is_initialized().await);

        // Shutdown should keep it false
        client.shutdown().await.unwrap();
        assert!(!client.is_initialized().await);
    }
}