Skip to main content

nika_mcp/
pool.rs

//! MCP Client Pool
//!
//! Centralized lifecycle manager for MCP client connections.
//! Handles lazy initialization, deduplication, and graceful shutdown.
//!
//! ## Why not a traditional connection pool (bb8/deadpool)?
//!
//! MCP is 1-connection-per-server, not N-connections-per-server.
//! We need per-server lazy init + coordinated shutdown, not pool sizing.
//!
//! ## Thread Safety
//!
//! `McpClientPool` is `Clone + Send + Sync`. Cloning is cheap (Arc inner).
//! Multiple components (TaskExecutor, App, ChatAgent) share the same pool.
//!
//! ## Initialization Pattern
//!
//! Uses `DashMap<String, Arc<OnceCell<Arc<McpClient>>>>`:
//!
//! - **DashMap**: Concurrent access to different servers (shard-level locking)
//! - **OnceCell**: Per-server async init serialization
//!   - Only one task spawns the server process; others wait
//!   - If init fails, the cell stays uninitialized so the next call retries
//! - **Arc wrapping**: The cell is cloned out of the map so the DashMap shard
//!   lock is released before awaiting
25
26use std::sync::atomic::{AtomicBool, Ordering};
27use std::sync::Arc;
28use std::time::Duration;
29
30use dashmap::DashMap;
31use rustc_hash::FxHashMap;
32use tokio::sync::OnceCell;
33
34use crate::error::McpError;
35use crate::types::McpConfig;
36use crate::validation::ValidationConfig;
37use crate::{McpClient, McpConfigInline};
38use nika_event::{EventKind, EventLog};
39
40/// Centralized MCP client lifecycle manager.
41///
42/// Provides lazy connection establishment, per-server deduplication,
43/// and coordinated shutdown of all MCP server processes.
44///
45/// # Clone semantics
46///
47/// Cloning is cheap (Arc inner). All clones share the same pool state.
48/// This enables sharing across TaskExecutor, TUI App, and ChatAgent.
49///
50/// # Example
51///
52/// ```rust,ignore
53/// let pool = McpClientPool::with_configs(event_log, mcp_configs);
54///
55/// // Lazy connect (first call spawns server, subsequent return cached)
56/// let client = pool.get_or_connect("neo4j").await?;
57/// client.call_tool("novanet_search", params).await?;
58///
59/// // Graceful shutdown (disconnects all servers)
60/// pool.shutdown_all().await;
61/// ```
62#[derive(Clone)]
63pub struct McpClientPool {
64    inner: Arc<PoolInner>,
65}
66
67struct PoolInner {
68    /// Per-server lazy-initialized clients.
69    ///
70    /// ## Why `Arc<OnceCell<Arc<McpClient>>>`?
71    ///
72    /// 1. **DashMap shard lock**: `entry().or_insert_with()` holds shard lock.
73    ///    We `.clone()` the `Arc<OnceCell>` to release it before awaiting.
74    /// 2. **OnceCell serialization**: Only one task runs the init closure;
75    ///    concurrent callers wait on the same future.
76    /// 3. **Retry on failure**: If init fails, cell stays uninitialized
77    ///    and next call retries (documented tokio::OnceCell behavior).
78    clients: DashMap<String, Arc<OnceCell<Arc<McpClient>>>>,
79
80    /// Server configurations (workflow-level or global).
81    ///
82    /// parking_lot::RwLock because configs are rarely mutated but read on every get_or_connect().
83    configs: parking_lot::RwLock<FxHashMap<String, McpConfigInline>>,
84
85    /// Event log for McpConnected/McpError observability events.
86    event_log: EventLog,
87
88    /// Shutdown flag. Once true, get_or_connect() returns Err immediately.
89    is_shutdown: AtomicBool,
90}
91
92impl McpClientPool {
93    /// Create an empty pool (no server configurations loaded yet).
94    pub fn new(event_log: EventLog) -> Self {
95        Self {
96            inner: Arc::new(PoolInner {
97                clients: DashMap::new(),
98                configs: parking_lot::RwLock::new(FxHashMap::default()),
99                event_log,
100                is_shutdown: AtomicBool::new(false),
101            }),
102        }
103    }
104
105    /// Create a pool with pre-loaded server configurations.
106    pub fn with_configs(event_log: EventLog, configs: FxHashMap<String, McpConfigInline>) -> Self {
107        Self {
108            inner: Arc::new(PoolInner {
109                clients: DashMap::new(),
110                configs: parking_lot::RwLock::new(configs),
111                event_log,
112                is_shutdown: AtomicBool::new(false),
113            }),
114        }
115    }
116
117    // ═══════════════════════════════════════════════════════════════════════
118    // CONFIG MANAGEMENT
119    // ═══════════════════════════════════════════════════════════════════════
120
121    /// Replace all server configurations.
122    ///
123    /// Does NOT disconnect existing clients. Call `shutdown_all()` first
124    /// if you need to force reconnection with new configs.
125    pub fn set_configs(&self, configs: FxHashMap<String, McpConfigInline>) {
126        *self.inner.configs.write() = configs;
127    }
128
129    /// Get a read reference to the current configs.
130    pub fn configs(&self) -> parking_lot::RwLockReadGuard<'_, FxHashMap<String, McpConfigInline>> {
131        self.inner.configs.read()
132    }
133
134    /// Check if a configuration exists for the given server name.
135    pub fn has_config(&self, name: &str) -> bool {
136        self.inner.configs.read().contains_key(name)
137    }
138
139    /// Return the number of configured MCP servers.
140    pub fn config_count(&self) -> usize {
141        self.inner.configs.read().len()
142    }
143
144    /// Get the event log.
145    pub fn event_log(&self) -> &EventLog {
146        &self.inner.event_log
147    }
148
149    // ═══════════════════════════════════════════════════════════════════════
150    // CLIENT ACCESS (THE MAIN API)
151    // ═══════════════════════════════════════════════════════════════════════
152
153    /// Get an existing client or establish a new connection.
154    ///
155    /// This is the primary API. It:
156    /// 1. Returns a cached client if already connected
157    /// 2. Spawns the server process and connects if this is the first call
158    /// 3. Serializes concurrent init attempts per server (OnceCell)
159    /// 4. Retries automatically if a previous init attempt failed
160    ///
161    /// # Errors
162    ///
163    /// - `McpError::McpNotConfigured` if no config exists for this server
164    /// - `McpError::McpStartError` if the server process fails to spawn
165    /// - `McpError::McpStartError` if the pool is shut down
166    pub async fn get_or_connect(&self, name: &str) -> Result<Arc<McpClient>, McpError> {
167        // Fast path: reject if pool is shutting down
168        if self.inner.is_shutdown.load(Ordering::SeqCst) {
169            return Err(McpError::McpStartError {
170                name: name.to_string(),
171                reason: "MCP client pool is shut down".to_string(),
172            });
173        }
174
175        // Single allocation reused for DashMap entry key and init closure.
176        let name_owned = name.to_string();
177
178        // Get or create the OnceCell for this server.
179        // SAFETY: entry() holds a shard lock. The .clone() immediately releases it.
180        // NEVER access self.inner.clients from within the OnceCell init closure.
181        let cell = self
182            .inner
183            .clients
184            .entry(name_owned.clone())
185            .or_insert_with(|| Arc::new(OnceCell::new()))
186            .clone();
187
188        // Capture Arc<PoolInner> for the async init closure.
189        // We clone the Arc (cheap) to avoid borrowing self across await.
190        let pool_inner = Arc::clone(&self.inner);
191
192        // OnceCell::get_or_try_init ensures:
193        // - Only one task runs the init closure
194        // - Concurrent callers wait for the result
195        // - If init fails, the cell stays empty for retry
196        let client = cell
197            .get_or_try_init(|| async {
198                // Double-check shutdown inside init closure to close TOCTOU race:
199                // A task could pass the outer check, then shutdown_all() runs,
200                // then this closure starts — we must reject here too.
201                if pool_inner.is_shutdown.load(Ordering::SeqCst) {
202                    return Err(McpError::McpStartError {
203                        name: name_owned.clone(),
204                        reason: "MCP client pool is shut down".to_string(),
205                    });
206                }
207                Self::connect_server(&pool_inner.configs, &pool_inner.event_log, &name_owned).await
208            })
209            .await?;
210
211        Ok(Arc::clone(client))
212    }
213
214    /// Internal: spawn and connect to an MCP server.
215    async fn connect_server(
216        configs: &parking_lot::RwLock<FxHashMap<String, McpConfigInline>>,
217        event_log: &EventLog,
218        name: &str,
219    ) -> Result<Arc<McpClient>, McpError> {
220        // Read config (hold read lock only briefly)
221        let config = {
222            let guard = configs.read();
223            guard.get(name).cloned()
224        };
225
226        let config = config.ok_or_else(|| McpError::McpNotConfigured {
227            name: name.to_string(),
228        })?;
229
230        // Build McpConfig from inline config
231        let mut mcp_config = McpConfig::new(name, &config.command);
232        for arg in &config.args {
233            mcp_config = mcp_config.with_arg(arg);
234        }
235        for (key, value) in &config.env {
236            mcp_config = mcp_config.with_env(key, value);
237        }
238        if let Some(cwd) = &config.cwd {
239            mcp_config = mcp_config.with_cwd(cwd);
240        }
241
242        // Expand environment variables ($VAR, ${VAR}, ~) in command/args/env/cwd
243        let mcp_config = mcp_config
244            .expand_env_vars()
245            .map_err(|e| McpError::McpStartError {
246                name: name.to_string(),
247                reason: format!("Environment variable expansion failed: {}", e),
248            })?;
249
250        // Create with validation enabled and connect
251        let client = McpClient::new(mcp_config)
252            .map_err(|e| McpError::McpStartError {
253                name: name.to_string(),
254                reason: e.to_string(),
255            })?
256            .with_validation(ValidationConfig::default());
257
258        match client.connect().await {
259            Ok(()) => {
260                // Cache tools for synchronous get_tool_definitions() access
261                if let Err(e) = client.list_tools().await {
262                    tracing::warn!(mcp_server = %name, error = %e, "Failed to cache tools");
263                }
264
265                tracing::info!(mcp_server = %name, "Connected to MCP server");
266                event_log.emit(EventKind::McpConnected {
267                    server_name: name.to_string(),
268                });
269
270                Ok(Arc::new(client))
271            }
272            Err(e) => {
273                let error_msg = e.to_string();
274                event_log.emit(EventKind::McpError {
275                    server_name: name.to_string(),
276                    error: error_msg.clone(),
277                });
278
279                Err(McpError::McpStartError {
280                    name: name.to_string(),
281                    reason: error_msg,
282                })
283            }
284        }
285    }
286
287    // ═══════════════════════════════════════════════════════════════════════
288    // INSPECTION
289    // ═══════════════════════════════════════════════════════════════════════
290
291    /// Check if a server has an active (initialized) connection.
292    pub fn is_connected(&self, name: &str) -> bool {
293        self.inner
294            .clients
295            .get(name)
296            .and_then(|cell| cell.get().map(|_| true))
297            .unwrap_or(false)
298    }
299
300    /// Count of servers with active connections.
301    pub fn connected_count(&self) -> usize {
302        self.inner
303            .clients
304            .iter()
305            .filter(|entry| entry.value().get().is_some())
306            .count()
307    }
308
309    /// Check if the pool has been shut down.
310    pub fn is_shutdown(&self) -> bool {
311        self.inner.is_shutdown.load(Ordering::SeqCst)
312    }
313
314    // ═══════════════════════════════════════════════════════════════════════
315    // LIFECYCLE
316    // ═══════════════════════════════════════════════════════════════════════
317
318    /// Disconnect a specific server and remove it from the pool.
319    ///
320    /// Always removes the entry to prevent dangling OnceCell references,
321    /// even if disconnect fails. The next call to `get_or_connect()` will re-initialize.
322    pub async fn disconnect(&self, name: &str) -> Result<(), McpError> {
323        // Attempt disconnect, capturing any error
324        let disconnect_err = if let Some(cell) = self.inner.clients.get(name) {
325            if let Some(client) = cell.get() {
326                client.disconnect().await.err()
327            } else {
328                None
329            }
330        } else {
331            None
332        };
333
334        // Always remove to prevent dangling entries with spent OnceCell
335        self.inner.clients.remove(name);
336
337        if let Some(e) = disconnect_err {
338            return Err(e);
339        }
340        Ok(())
341    }
342
343    /// Gracefully shut down all MCP server connections.
344    ///
345    /// After this call:
346    /// - All server processes are terminated
347    /// - The pool is marked as shut down
348    /// - `get_or_connect()` will return Err for all subsequent calls
349    ///
350    /// This method is idempotent.
351    pub async fn shutdown_all(&self) {
352        // 1. Set shutdown flag to reject new connections
353        self.inner.is_shutdown.store(true, Ordering::SeqCst);
354
355        // 2. Drain all clients from the map
356        let entries: Vec<(String, Arc<OnceCell<Arc<McpClient>>>)> = self
357            .inner
358            .clients
359            .iter()
360            .map(|entry| (entry.key().clone(), Arc::clone(entry.value())))
361            .collect();
362
363        self.inner.clients.clear();
364
365        // 3. Disconnect each initialized client with timeout
366        for (name, cell) in entries {
367            if let Some(client) = cell.get() {
368                let disconnect_result =
369                    tokio::time::timeout(Duration::from_secs(5), client.disconnect()).await;
370
371                match disconnect_result {
372                    Ok(Ok(())) => {
373                        tracing::debug!(server = %name, "MCP server disconnected");
374                    }
375                    Ok(Err(e)) => {
376                        tracing::warn!(server = %name, error = %e, "Error disconnecting MCP server");
377                    }
378                    Err(_) => {
379                        tracing::warn!(server = %name, "MCP server disconnect timed out (5s)");
380                    }
381                }
382            }
383        }
384    }
385
386    // ═══════════════════════════════════════════════════════════════════════
387    // TESTING
388    // ═══════════════════════════════════════════════════════════════════════
389
390    /// Inject a pre-built client for testing.
391    ///
392    /// The client is inserted as already-initialized, bypassing connect_server().
393    /// Only intended for test code (production callers use connect_server).
394    pub fn inject_mock(&self, name: &str, client: Arc<McpClient>) {
395        let cell = Arc::new(OnceCell::new());
396        // OnceCell is freshly created so set() always succeeds
397        let _ = cell.set(client);
398        self.inner.clients.insert(name.to_string(), cell);
399    }
400}
401
402impl std::fmt::Debug for McpClientPool {
403    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404        f.debug_struct("McpClientPool")
405            .field("connected", &self.connected_count())
406            .field("configured", &self.inner.configs.read().len())
407            .field("is_shutdown", &self.is_shutdown())
408            .finish()
409    }
410}
411
412// Compile-time assertion: McpClientPool must be Send + Sync + Clone.
413// If a future change introduces a !Send or !Sync field, this fails at definition site.
414const _: () = {
415    fn _assert_send_sync_clone<T: Send + Sync + Clone>() {}
416    fn _check() {
417        _assert_send_sync_clone::<McpClientPool>();
418    }
419};
420
#[cfg(test)]
mod tests {
    use super::*;
    use nika_event::EventLog;

    /// Pool with no configurations and no clients.
    fn empty_pool() -> McpClientPool {
        McpClientPool::new(EventLog::new())
    }

    /// Config map holding a single argument-less server entry.
    fn single_config(server: &str, command: &str) -> FxHashMap<String, McpConfigInline> {
        let mut configs = FxHashMap::default();
        configs.insert(
            server.to_string(),
            McpConfigInline {
                command: command.to_string(),
                args: vec![],
                env: FxHashMap::default(),
                cwd: None,
            },
        );
        configs
    }

    /// Empty pool with one injected mock client per name.
    fn pool_with_mocks(names: &[&str]) -> McpClientPool {
        let pool = empty_pool();
        for &name in names {
            pool.inject_mock(name, Arc::new(McpClient::mock(name)));
        }
        pool
    }

    #[test]
    fn test_pool_new_is_empty() {
        let pool = empty_pool();

        assert_eq!(pool.connected_count(), 0);
        assert!(!pool.is_shutdown());
    }

    #[test]
    fn test_pool_with_configs() {
        let pool = McpClientPool::with_configs(EventLog::new(), single_config("test", "echo"));

        assert!(pool.has_config("test"));
        assert!(!pool.has_config("missing"));
    }

    #[test]
    fn test_pool_clone_shares_state() {
        let original = empty_pool();
        let copy = original.clone();

        original.inject_mock("test", Arc::new(McpClient::mock("test")));

        // The clone must observe the client injected through the original.
        assert!(copy.is_connected("test"));
    }

    #[test]
    fn test_pool_is_connected_false_when_empty() {
        assert!(!empty_pool().is_connected("neo4j"));
    }

    #[test]
    fn test_pool_inject_mock() {
        let pool = pool_with_mocks(&["novanet"]);

        assert!(pool.is_connected("novanet"));
        assert_eq!(pool.connected_count(), 1);
    }

    #[tokio::test]
    async fn test_pool_get_or_connect_with_mock() {
        let pool = pool_with_mocks(&["novanet"]);

        let client = pool.get_or_connect("novanet").await.unwrap();

        assert!(client.is_connected());
        assert_eq!(client.name(), "novanet");
    }

    #[tokio::test]
    async fn test_pool_get_or_connect_not_configured() {
        let pool = empty_pool();

        let err = pool
            .get_or_connect("missing")
            .await
            .expect_err("connecting to an unconfigured server must fail");

        assert!(
            err.to_string().contains("not configured"),
            "Expected McpNotConfigured error"
        );
    }

    #[tokio::test]
    async fn test_pool_shutdown_rejects_new_connections() {
        let pool = empty_pool();
        pool.shutdown_all().await;
        assert!(pool.is_shutdown());

        let err = pool
            .get_or_connect("test")
            .await
            .expect_err("a shut-down pool must reject new connections");

        assert!(err.to_string().contains("shut down"));
    }

    #[tokio::test]
    async fn test_pool_disconnect_single_server() {
        let pool = pool_with_mocks(&["test"]);
        assert!(pool.is_connected("test"));

        pool.disconnect("test").await.unwrap();

        assert!(!pool.is_connected("test"));
    }

    #[tokio::test]
    async fn test_pool_shutdown_clears_all() {
        let pool = pool_with_mocks(&["a", "b"]);
        assert_eq!(pool.connected_count(), 2);

        pool.shutdown_all().await;

        assert_eq!(pool.connected_count(), 0);
        assert!(pool.is_shutdown());
    }

    #[test]
    fn test_pool_set_configs() {
        let pool = empty_pool();
        assert!(!pool.has_config("neo4j"));

        pool.set_configs(single_config("neo4j", "npx"));

        assert!(pool.has_config("neo4j"));
    }

    #[test]
    fn test_pool_debug_format() {
        let rendered = format!("{:?}", empty_pool());

        assert!(rendered.contains("McpClientPool"));
        assert!(rendered.contains("connected: 0"));
    }
}