use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use dashmap::DashMap;
use rustc_hash::FxHashMap;
use tokio::sync::OnceCell;
use crate::ast::McpConfigInline;
use crate::error::NikaError;
use crate::event::{EventKind, EventLog};
use crate::mcp::types::McpConfig;
use crate::mcp::McpClient;
/// Cheaply cloneable handle to a shared pool of MCP clients.
///
/// All clones share the same underlying state through an `Arc`
/// (see `test_pool_clone_shares_state`).
#[derive(Clone)]
pub struct McpClientPool {
    // Shared state; cloning the pool only bumps this refcount.
    inner: Arc<PoolInner>,
}
/// State shared by every clone of a `McpClientPool`.
struct PoolInner {
    /// One lazily-initialized connection slot per server name. The per-name
    /// `OnceCell` lets concurrent `get_or_connect` callers share a single
    /// connection attempt.
    clients: DashMap<String, Arc<OnceCell<Arc<McpClient>>>>,
    /// Server configurations, keyed by server name.
    configs: parking_lot::RwLock<FxHashMap<String, McpConfigInline>>,
    /// Sink for connection lifecycle events (`McpConnected` / `McpError`).
    event_log: EventLog,
    /// Set by `shutdown_all`; once true, new connections are rejected.
    is_shutdown: AtomicBool,
}
impl McpClientPool {
    /// Creates an empty pool that emits lifecycle events to `event_log`.
    pub fn new(event_log: EventLog) -> Self {
        Self {
            inner: Arc::new(PoolInner {
                clients: DashMap::new(),
                configs: parking_lot::RwLock::new(FxHashMap::default()),
                event_log,
                is_shutdown: AtomicBool::new(false),
            }),
        }
    }

    /// Creates a pool pre-seeded with server configurations.
    pub fn with_configs(event_log: EventLog, configs: FxHashMap<String, McpConfigInline>) -> Self {
        Self {
            inner: Arc::new(PoolInner {
                clients: DashMap::new(),
                configs: parking_lot::RwLock::new(configs),
                event_log,
                is_shutdown: AtomicBool::new(false),
            }),
        }
    }

    /// Replaces the entire configuration map. Existing connections are kept.
    pub fn set_configs(&self, configs: FxHashMap<String, McpConfigInline>) {
        *self.inner.configs.write() = configs;
    }

    /// Read access to the configuration map. The returned guard holds the
    /// config read lock; drop it promptly and never hold it across an `.await`.
    pub fn configs(&self) -> parking_lot::RwLockReadGuard<'_, FxHashMap<String, McpConfigInline>> {
        self.inner.configs.read()
    }

    /// Whether a server named `name` is configured (connected or not).
    pub fn has_config(&self, name: &str) -> bool {
        self.inner.configs.read().contains_key(name)
    }

    /// Number of configured servers.
    pub fn config_count(&self) -> usize {
        self.inner.configs.read().len()
    }

    /// The event log this pool emits connection events to.
    pub fn event_log(&self) -> &EventLog {
        &self.inner.event_log
    }

    /// Returns the client for `name`, connecting on first use.
    ///
    /// Concurrent callers for the same name share a single connection attempt
    /// through the per-name `OnceCell`; a failed attempt leaves the cell
    /// empty, so a later call can retry.
    ///
    /// # Errors
    /// Fails if the pool is shut down, the server is not configured, or the
    /// connection attempt fails.
    pub async fn get_or_connect(&self, name: &str) -> Result<Arc<McpClient>, NikaError> {
        if self.inner.is_shutdown.load(Ordering::SeqCst) {
            return Err(NikaError::McpStartError {
                name: name.to_string(),
                reason: "MCP client pool is shut down".to_string(),
            });
        }
        let name_owned = name.to_string();
        // The DashMap entry guard is dropped at the end of this statement,
        // before any await point below.
        let cell = self
            .inner
            .clients
            .entry(name_owned.clone())
            .or_insert_with(|| Arc::new(OnceCell::new()))
            .clone();
        let pool_inner = Arc::clone(&self.inner);
        let client = cell
            .get_or_try_init(|| async {
                // Re-check after winning the init race: shutdown may have
                // started between the check above and now.
                if pool_inner.is_shutdown.load(Ordering::SeqCst) {
                    return Err(NikaError::McpStartError {
                        name: name_owned.clone(),
                        reason: "MCP client pool is shut down".to_string(),
                    });
                }
                Self::connect_server(&pool_inner.configs, &pool_inner.event_log, &name_owned).await
            })
            .await?;
        Ok(Arc::clone(client))
    }

    /// Builds a client from the stored configuration for `name` and connects
    /// it, emitting `McpConnected` / `McpError` for the outcome.
    async fn connect_server(
        configs: &parking_lot::RwLock<FxHashMap<String, McpConfigInline>>,
        event_log: &EventLog,
        name: &str,
    ) -> Result<Arc<McpClient>, NikaError> {
        // Clone the config inside a short-lived read guard so the lock is
        // released before any await below.
        let config = {
            let guard = configs.read();
            guard.get(name).cloned()
        };
        let config = config.ok_or_else(|| NikaError::McpNotConfigured {
            name: name.to_string(),
        })?;
        let mut mcp_config = McpConfig::new(name, &config.command);
        for arg in &config.args {
            mcp_config = mcp_config.with_arg(arg);
        }
        for (key, value) in &config.env {
            mcp_config = mcp_config.with_env(key, value);
        }
        if let Some(cwd) = &config.cwd {
            mcp_config = mcp_config.with_cwd(cwd);
        }
        let mcp_config = mcp_config
            .expand_env_vars()
            .map_err(|e| NikaError::McpStartError {
                name: name.to_string(),
                reason: format!("Environment variable expansion failed: {}", e),
            })?;
        let client = McpClient::new(mcp_config).map_err(|e| NikaError::McpStartError {
            name: name.to_string(),
            reason: e.to_string(),
        })?;
        match client.connect().await {
            Ok(()) => {
                // Tool caching is best-effort; a failure here is logged, not fatal.
                if let Err(e) = client.list_tools().await {
                    tracing::warn!(mcp_server = %name, error = %e, "Failed to cache tools");
                }
                tracing::info!(mcp_server = %name, "Connected to MCP server");
                event_log.emit(EventKind::McpConnected {
                    server_name: name.to_string(),
                });
                Ok(Arc::new(client))
            }
            Err(e) => {
                let error_msg = e.to_string();
                event_log.emit(EventKind::McpError {
                    server_name: name.to_string(),
                    error: error_msg.clone(),
                });
                Err(NikaError::McpStartError {
                    name: name.to_string(),
                    reason: error_msg,
                })
            }
        }
    }

    /// Whether `name` has a slot whose connection has completed.
    pub fn is_connected(&self, name: &str) -> bool {
        self.inner
            .clients
            .get(name)
            .map(|cell| cell.get().is_some())
            .unwrap_or(false)
    }

    /// Number of slots whose connection has completed.
    pub fn connected_count(&self) -> usize {
        self.inner
            .clients
            .iter()
            .filter(|entry| entry.value().get().is_some())
            .count()
    }

    /// Whether `shutdown_all` has been called.
    pub fn is_shutdown(&self) -> bool {
        self.inner.is_shutdown.load(Ordering::SeqCst)
    }

    /// Disconnects and removes the client for `name`, if any.
    ///
    /// The cell is cloned out of the map *before* awaiting so that no DashMap
    /// shard guard is held across the await point — DashMap may deadlock if a
    /// reference into the map is held while another task calls into the same
    /// shard (e.g. via `get_or_connect` or `remove`).
    ///
    /// # Errors
    /// Propagates the client's disconnect error; in that case the entry is
    /// left in the map (matching the pre-existing behavior).
    pub async fn disconnect(&self, name: &str) -> Result<(), NikaError> {
        let cell = self
            .inner
            .clients
            .get(name)
            .map(|entry| Arc::clone(entry.value()));
        if let Some(cell) = cell {
            if let Some(client) = cell.get() {
                client.disconnect().await?;
            }
        }
        self.inner.clients.remove(name);
        Ok(())
    }

    /// Marks the pool shut down, removes every slot, and disconnects each
    /// connected client with a 5-second timeout per server. Disconnect errors
    /// and timeouts are logged, never propagated.
    pub async fn shutdown_all(&self) {
        self.inner.is_shutdown.store(true, Ordering::SeqCst);
        // Snapshot and clear first so no map guard is held across the awaits
        // below, and so new lookups miss immediately.
        let entries: Vec<(String, Arc<OnceCell<Arc<McpClient>>>)> = self
            .inner
            .clients
            .iter()
            .map(|entry| (entry.key().clone(), Arc::clone(entry.value())))
            .collect();
        self.inner.clients.clear();
        for (name, cell) in entries {
            if let Some(client) = cell.get() {
                let disconnect_result =
                    tokio::time::timeout(Duration::from_secs(5), client.disconnect()).await;
                match disconnect_result {
                    Ok(Ok(())) => {
                        tracing::debug!(server = %name, "MCP server disconnected");
                    }
                    Ok(Err(e)) => {
                        tracing::warn!(server = %name, error = %e, "Error disconnecting MCP server");
                    }
                    Err(_) => {
                        tracing::warn!(server = %name, "MCP server disconnect timed out (5s)");
                    }
                }
            }
        }
    }

    /// Test-only: installs an already-connected client under `name` without
    /// going through the connect path.
    #[cfg(test)]
    pub fn inject_mock(&self, name: &str, client: Arc<McpClient>) {
        let cell = Arc::new(OnceCell::new());
        cell.set(client)
            .expect("freshly created cell should be empty");
        self.inner.clients.insert(name.to_string(), cell);
    }
}
impl std::fmt::Debug for McpClientPool {
    /// Summarizes pool state (counts and shutdown flag) without exposing
    /// client internals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let configured = self.config_count();
        f.debug_struct("McpClientPool")
            .field("connected", &self.connected_count())
            .field("configured", &configured)
            .field("is_shutdown", &self.is_shutdown())
            .finish()
    }
}
// Compile-time guarantee: the pool handle can be cloned and shared freely
// across threads. This never runs; it only has to type-check.
const _: () = {
    fn require_threadsafe_clone<T>()
    where
        T: Send + Sync + Clone,
    {
    }
    fn _enforce() {
        require_threadsafe_clone::<McpClientPool>();
    }
};
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::EventLog;

    /// Builds a config map containing a single named server entry.
    fn single_config(name: &str, command: &str) -> FxHashMap<String, McpConfigInline> {
        let mut map = FxHashMap::default();
        map.insert(
            name.to_string(),
            McpConfigInline {
                command: command.to_string(),
                args: vec![],
                env: FxHashMap::default(),
                cwd: None,
            },
        );
        map
    }

    #[test]
    fn test_pool_new_is_empty() {
        let pool = McpClientPool::new(EventLog::new());
        assert_eq!(pool.connected_count(), 0);
        assert!(!pool.is_shutdown());
    }

    #[test]
    fn test_pool_with_configs() {
        let pool = McpClientPool::with_configs(EventLog::new(), single_config("test", "echo"));
        assert!(pool.has_config("test"));
        assert!(!pool.has_config("missing"));
    }

    #[test]
    fn test_pool_clone_shares_state() {
        let original = McpClientPool::new(EventLog::new());
        let cloned = original.clone();
        original.inject_mock("test", Arc::new(McpClient::mock("test")));
        // The clone observes state injected through the original handle.
        assert!(cloned.is_connected("test"));
    }

    #[test]
    fn test_pool_is_connected_false_when_empty() {
        assert!(!McpClientPool::new(EventLog::new()).is_connected("neo4j"));
    }

    #[test]
    fn test_pool_inject_mock() {
        let pool = McpClientPool::new(EventLog::new());
        pool.inject_mock("novanet", Arc::new(McpClient::mock("novanet")));
        assert!(pool.is_connected("novanet"));
        assert_eq!(pool.connected_count(), 1);
    }

    #[tokio::test]
    async fn test_pool_get_or_connect_with_mock() {
        let pool = McpClientPool::new(EventLog::new());
        pool.inject_mock("novanet", Arc::new(McpClient::mock("novanet")));
        let client = pool.get_or_connect("novanet").await.unwrap();
        assert!(client.is_connected());
        assert_eq!(client.name(), "novanet");
    }

    #[tokio::test]
    async fn test_pool_get_or_connect_not_configured() {
        let pool = McpClientPool::new(EventLog::new());
        let err = pool
            .get_or_connect("missing")
            .await
            .expect_err("Expected McpNotConfigured error");
        assert!(err.to_string().contains("not configured"));
    }

    #[tokio::test]
    async fn test_pool_shutdown_rejects_new_connections() {
        let pool = McpClientPool::new(EventLog::new());
        pool.shutdown_all().await;
        assert!(pool.is_shutdown());
        let err = pool.get_or_connect("test").await.unwrap_err();
        assert!(err.to_string().contains("shut down"));
    }

    #[tokio::test]
    async fn test_pool_disconnect_single_server() {
        let pool = McpClientPool::new(EventLog::new());
        pool.inject_mock("test", Arc::new(McpClient::mock("test")));
        assert!(pool.is_connected("test"));
        pool.disconnect("test").await.unwrap();
        assert!(!pool.is_connected("test"));
    }

    #[tokio::test]
    async fn test_pool_shutdown_clears_all() {
        let pool = McpClientPool::new(EventLog::new());
        for name in ["a", "b"] {
            pool.inject_mock(name, Arc::new(McpClient::mock(name)));
        }
        assert_eq!(pool.connected_count(), 2);
        pool.shutdown_all().await;
        assert_eq!(pool.connected_count(), 0);
        assert!(pool.is_shutdown());
    }

    #[test]
    fn test_pool_set_configs() {
        let pool = McpClientPool::new(EventLog::new());
        assert!(!pool.has_config("neo4j"));
        pool.set_configs(single_config("neo4j", "npx"));
        assert!(pool.has_config("neo4j"));
    }

    #[test]
    fn test_pool_debug_format() {
        let pool = McpClientPool::new(EventLog::new());
        let rendered = format!("{:?}", pool);
        assert!(rendered.contains("McpClientPool"));
        assert!(rendered.contains("connected: 0"));
    }
}