1use std::sync::atomic::{AtomicBool, Ordering};
27use std::sync::Arc;
28use std::time::Duration;
29
30use dashmap::DashMap;
31use rustc_hash::FxHashMap;
32use tokio::sync::OnceCell;
33
34use crate::error::McpError;
35use crate::types::McpConfig;
36use crate::validation::ValidationConfig;
37use crate::{McpClient, McpConfigInline};
38use nika_event::{EventKind, EventLog};
39
/// A cloneable, thread-safe pool of MCP clients keyed by server name.
///
/// All clones share one underlying state behind an `Arc`, so a client connected
/// through one handle is visible through every other handle. Clients are connected
/// lazily on the first `get_or_connect` for a name and reused afterwards.
#[derive(Clone)]
pub struct McpClientPool {
    // Shared state; cloning the pool only bumps this refcount.
    inner: Arc<PoolInner>,
}
66
/// State shared by every clone of `McpClientPool`.
struct PoolInner {
    /// One lazily-initialized cell per server name. The cell (not the client) is
    /// what gets inserted into the map, so concurrent callers for the same name
    /// await the same `OnceCell` initialization instead of each opening their own
    /// connection. A cell whose initialization failed stays empty and is retried
    /// by the next caller.
    clients: DashMap<String, Arc<OnceCell<Arc<McpClient>>>>,

    /// Inline server configs, looked up by name when a connection is first requested.
    configs: parking_lot::RwLock<FxHashMap<String, McpConfigInline>>,

    /// Destination for connection lifecycle events (`McpConnected` / `McpError`).
    event_log: EventLog,

    /// Set to `true` by `shutdown_all` and never reset; once set, no new
    /// connections are started.
    is_shutdown: AtomicBool,
}
91
92impl McpClientPool {
93 pub fn new(event_log: EventLog) -> Self {
95 Self {
96 inner: Arc::new(PoolInner {
97 clients: DashMap::new(),
98 configs: parking_lot::RwLock::new(FxHashMap::default()),
99 event_log,
100 is_shutdown: AtomicBool::new(false),
101 }),
102 }
103 }
104
105 pub fn with_configs(event_log: EventLog, configs: FxHashMap<String, McpConfigInline>) -> Self {
107 Self {
108 inner: Arc::new(PoolInner {
109 clients: DashMap::new(),
110 configs: parking_lot::RwLock::new(configs),
111 event_log,
112 is_shutdown: AtomicBool::new(false),
113 }),
114 }
115 }
116
117 pub fn set_configs(&self, configs: FxHashMap<String, McpConfigInline>) {
126 *self.inner.configs.write() = configs;
127 }
128
129 pub fn configs(&self) -> parking_lot::RwLockReadGuard<'_, FxHashMap<String, McpConfigInline>> {
131 self.inner.configs.read()
132 }
133
134 pub fn has_config(&self, name: &str) -> bool {
136 self.inner.configs.read().contains_key(name)
137 }
138
139 pub fn config_count(&self) -> usize {
141 self.inner.configs.read().len()
142 }
143
144 pub fn event_log(&self) -> &EventLog {
146 &self.inner.event_log
147 }
148
149 pub async fn get_or_connect(&self, name: &str) -> Result<Arc<McpClient>, McpError> {
167 if self.inner.is_shutdown.load(Ordering::SeqCst) {
169 return Err(McpError::McpStartError {
170 name: name.to_string(),
171 reason: "MCP client pool is shut down".to_string(),
172 });
173 }
174
175 let name_owned = name.to_string();
177
178 let cell = self
182 .inner
183 .clients
184 .entry(name_owned.clone())
185 .or_insert_with(|| Arc::new(OnceCell::new()))
186 .clone();
187
188 let pool_inner = Arc::clone(&self.inner);
191
192 let client = cell
197 .get_or_try_init(|| async {
198 if pool_inner.is_shutdown.load(Ordering::SeqCst) {
202 return Err(McpError::McpStartError {
203 name: name_owned.clone(),
204 reason: "MCP client pool is shut down".to_string(),
205 });
206 }
207 Self::connect_server(&pool_inner.configs, &pool_inner.event_log, &name_owned).await
208 })
209 .await?;
210
211 Ok(Arc::clone(client))
212 }
213
214 async fn connect_server(
216 configs: &parking_lot::RwLock<FxHashMap<String, McpConfigInline>>,
217 event_log: &EventLog,
218 name: &str,
219 ) -> Result<Arc<McpClient>, McpError> {
220 let config = {
222 let guard = configs.read();
223 guard.get(name).cloned()
224 };
225
226 let config = config.ok_or_else(|| McpError::McpNotConfigured {
227 name: name.to_string(),
228 })?;
229
230 let mut mcp_config = McpConfig::new(name, &config.command);
232 for arg in &config.args {
233 mcp_config = mcp_config.with_arg(arg);
234 }
235 for (key, value) in &config.env {
236 mcp_config = mcp_config.with_env(key, value);
237 }
238 if let Some(cwd) = &config.cwd {
239 mcp_config = mcp_config.with_cwd(cwd);
240 }
241
242 let mcp_config = mcp_config
244 .expand_env_vars()
245 .map_err(|e| McpError::McpStartError {
246 name: name.to_string(),
247 reason: format!("Environment variable expansion failed: {}", e),
248 })?;
249
250 let client = McpClient::new(mcp_config)
252 .map_err(|e| McpError::McpStartError {
253 name: name.to_string(),
254 reason: e.to_string(),
255 })?
256 .with_validation(ValidationConfig::default());
257
258 match client.connect().await {
259 Ok(()) => {
260 if let Err(e) = client.list_tools().await {
262 tracing::warn!(mcp_server = %name, error = %e, "Failed to cache tools");
263 }
264
265 tracing::info!(mcp_server = %name, "Connected to MCP server");
266 event_log.emit(EventKind::McpConnected {
267 server_name: name.to_string(),
268 });
269
270 Ok(Arc::new(client))
271 }
272 Err(e) => {
273 let error_msg = e.to_string();
274 event_log.emit(EventKind::McpError {
275 server_name: name.to_string(),
276 error: error_msg.clone(),
277 });
278
279 Err(McpError::McpStartError {
280 name: name.to_string(),
281 reason: error_msg,
282 })
283 }
284 }
285 }
286
287 pub fn is_connected(&self, name: &str) -> bool {
293 self.inner
294 .clients
295 .get(name)
296 .and_then(|cell| cell.get().map(|_| true))
297 .unwrap_or(false)
298 }
299
300 pub fn connected_count(&self) -> usize {
302 self.inner
303 .clients
304 .iter()
305 .filter(|entry| entry.value().get().is_some())
306 .count()
307 }
308
309 pub fn is_shutdown(&self) -> bool {
311 self.inner.is_shutdown.load(Ordering::SeqCst)
312 }
313
314 pub async fn disconnect(&self, name: &str) -> Result<(), McpError> {
323 let disconnect_err = if let Some(cell) = self.inner.clients.get(name) {
325 if let Some(client) = cell.get() {
326 client.disconnect().await.err()
327 } else {
328 None
329 }
330 } else {
331 None
332 };
333
334 self.inner.clients.remove(name);
336
337 if let Some(e) = disconnect_err {
338 return Err(e);
339 }
340 Ok(())
341 }
342
343 pub async fn shutdown_all(&self) {
352 self.inner.is_shutdown.store(true, Ordering::SeqCst);
354
355 let entries: Vec<(String, Arc<OnceCell<Arc<McpClient>>>)> = self
357 .inner
358 .clients
359 .iter()
360 .map(|entry| (entry.key().clone(), Arc::clone(entry.value())))
361 .collect();
362
363 self.inner.clients.clear();
364
365 for (name, cell) in entries {
367 if let Some(client) = cell.get() {
368 let disconnect_result =
369 tokio::time::timeout(Duration::from_secs(5), client.disconnect()).await;
370
371 match disconnect_result {
372 Ok(Ok(())) => {
373 tracing::debug!(server = %name, "MCP server disconnected");
374 }
375 Ok(Err(e)) => {
376 tracing::warn!(server = %name, error = %e, "Error disconnecting MCP server");
377 }
378 Err(_) => {
379 tracing::warn!(server = %name, "MCP server disconnect timed out (5s)");
380 }
381 }
382 }
383 }
384 }
385
386 pub fn inject_mock(&self, name: &str, client: Arc<McpClient>) {
395 let cell = Arc::new(OnceCell::new());
396 let _ = cell.set(client);
398 self.inner.clients.insert(name.to_string(), cell);
399 }
400}
401
402impl std::fmt::Debug for McpClientPool {
403 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404 f.debug_struct("McpClientPool")
405 .field("connected", &self.connected_count())
406 .field("configured", &self.inner.configs.read().len())
407 .field("is_shutdown", &self.is_shutdown())
408 .finish()
409 }
410}
411
412const _: () = {
415 fn _assert_send_sync_clone<T: Send + Sync + Clone>() {}
416 fn _check() {
417 _assert_send_sync_clone::<McpClientPool>();
418 }
419};
420
#[cfg(test)]
mod tests {
    use super::*;
    use nika_event::EventLog;

    /// Minimal inline config that runs `command` with no args, env, or cwd.
    fn inline_config(command: &str) -> McpConfigInline {
        McpConfigInline {
            command: command.to_string(),
            args: vec![],
            env: FxHashMap::default(),
            cwd: None,
        }
    }

    /// Fresh pool with no configs and no clients.
    fn empty_pool() -> McpClientPool {
        McpClientPool::new(EventLog::new())
    }

    #[test]
    fn test_pool_new_is_empty() {
        let pool = empty_pool();
        assert_eq!(pool.connected_count(), 0);
        assert_eq!(pool.config_count(), 0);
        assert!(!pool.is_shutdown());
    }

    #[test]
    fn test_pool_with_configs() {
        let mut configs = FxHashMap::default();
        configs.insert("test".to_string(), inline_config("echo"));

        let pool = McpClientPool::with_configs(EventLog::new(), configs);
        assert!(pool.has_config("test"));
        assert!(!pool.has_config("missing"));
        assert_eq!(pool.config_count(), 1);
    }

    #[test]
    fn test_pool_clone_shares_state() {
        let pool1 = empty_pool();
        let pool2 = pool1.clone();

        pool1.inject_mock("test", Arc::new(McpClient::mock("test")));

        // Both handles share the same inner state.
        assert!(pool2.is_connected("test"));
    }

    #[test]
    fn test_pool_is_connected_false_when_empty() {
        let pool = empty_pool();
        assert!(!pool.is_connected("neo4j"));
    }

    #[test]
    fn test_pool_inject_mock() {
        let pool = empty_pool();
        pool.inject_mock("novanet", Arc::new(McpClient::mock("novanet")));

        assert!(pool.is_connected("novanet"));
        assert_eq!(pool.connected_count(), 1);
    }

    #[tokio::test]
    async fn test_pool_get_or_connect_with_mock() {
        let pool = empty_pool();
        pool.inject_mock("novanet", Arc::new(McpClient::mock("novanet")));

        let client = pool.get_or_connect("novanet").await.unwrap();
        assert!(client.is_connected());
        assert_eq!(client.name(), "novanet");
    }

    #[tokio::test]
    async fn test_pool_get_or_connect_not_configured() {
        let pool = empty_pool();
        let result = pool.get_or_connect("missing").await;
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("not configured"),
            "Expected McpNotConfigured error"
        );
    }

    #[tokio::test]
    async fn test_pool_shutdown_rejects_new_connections() {
        let pool = empty_pool();
        pool.shutdown_all().await;

        assert!(pool.is_shutdown());
        let result = pool.get_or_connect("test").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("shut down"));
    }

    #[tokio::test]
    async fn test_pool_disconnect_single_server() {
        let pool = empty_pool();
        pool.inject_mock("test", Arc::new(McpClient::mock("test")));

        assert!(pool.is_connected("test"));
        pool.disconnect("test").await.unwrap();
        assert!(!pool.is_connected("test"));
    }

    #[tokio::test]
    async fn test_pool_disconnect_unknown_server_is_ok() {
        let pool = empty_pool();
        assert!(pool.disconnect("missing").await.is_ok());
        assert_eq!(pool.connected_count(), 0);
    }

    #[tokio::test]
    async fn test_pool_reconnect_after_disconnect_needs_config() {
        let pool = empty_pool();
        pool.inject_mock("test", Arc::new(McpClient::mock("test")));
        pool.disconnect("test").await.unwrap();

        // The mock is gone, so a new connection must come from a config.
        let err = pool.get_or_connect("test").await.unwrap_err();
        assert!(err.to_string().contains("not configured"));
    }

    #[tokio::test]
    async fn test_pool_shutdown_clears_all() {
        let pool = empty_pool();
        pool.inject_mock("a", Arc::new(McpClient::mock("a")));
        pool.inject_mock("b", Arc::new(McpClient::mock("b")));
        assert_eq!(pool.connected_count(), 2);

        pool.shutdown_all().await;
        assert_eq!(pool.connected_count(), 0);
        assert!(pool.is_shutdown());
    }

    #[tokio::test]
    async fn test_pool_shutdown_is_idempotent() {
        let pool = empty_pool();
        pool.inject_mock("a", Arc::new(McpClient::mock("a")));

        pool.shutdown_all().await;
        pool.shutdown_all().await;

        assert!(pool.is_shutdown());
        assert!(!pool.is_connected("a"));
        assert_eq!(pool.connected_count(), 0);
    }

    #[test]
    fn test_pool_set_configs() {
        let pool = empty_pool();
        assert!(!pool.has_config("neo4j"));

        let mut configs = FxHashMap::default();
        configs.insert("neo4j".to_string(), inline_config("npx"));
        pool.set_configs(configs);
        assert!(pool.has_config("neo4j"));
    }

    #[test]
    fn test_pool_set_configs_replaces_previous() {
        let mut first = FxHashMap::default();
        first.insert("old".to_string(), inline_config("echo"));
        let pool = McpClientPool::with_configs(EventLog::new(), first);

        let mut second = FxHashMap::default();
        second.insert("new".to_string(), inline_config("echo"));
        pool.set_configs(second);

        assert!(!pool.has_config("old"));
        assert!(pool.has_config("new"));
        assert_eq!(pool.config_count(), 1);
    }

    #[test]
    fn test_pool_debug_format() {
        let pool = empty_pool();
        let debug = format!("{:?}", pool);
        assert!(debug.contains("McpClientPool"));
        assert!(debug.contains("connected: 0"));
    }
}