use std::{collections::HashMap, sync::Arc};
use rmcp::{RoleClient, model::InitializeRequestParams, service::RunningService};
use tokio::sync::RwLock;
use tracing::debug;
use crate::error::RegistryError;
use pctx_config::server::ServerConfig;
/// Shared handle to a running upstream MCP client, as stored in and handed out by the pool.
type PooledClient = Arc<RunningService<RoleClient, InitializeRequestParams>>;

/// Cache of upstream MCP client connections, keyed by server name.
///
/// The map lives behind an `Arc`, so cloning the pool is cheap and every clone
/// observes the same set of cached connections.
#[derive(Debug, Default, Clone)]
pub struct McpConnectionPool {
    /// Server name -> cached client, guarded by an async (tokio) `RwLock`.
    connections: Arc<RwLock<HashMap<String, PooledClient>>>,
}
impl McpConnectionPool {
pub fn new() -> Self {
Self::default()
}
pub async fn get_or_connect(
&self,
cfg: &ServerConfig,
) -> Result<(PooledClient, bool), RegistryError> {
{
let connections = self.connections.read().await;
if let Some(client) = connections.get(&cfg.name) {
if !client.is_closed() {
debug!(server = %cfg.name, "Reusing cached upstream connection");
return Ok((client.clone(), true));
}
debug!(server = %cfg.name, "Cached connection is closed, reconnecting");
}
}
debug!(server = %cfg.name, "Connecting to upstream MCP server");
let new_client = Arc::new(cfg.connect().await.map_err(RegistryError::from)?);
let mut connections = self.connections.write().await;
if let Some(existing) = connections.get(&cfg.name) {
if !existing.is_closed() {
debug!(server = %cfg.name, "Lost connection race, using existing");
return Ok((existing.clone(), true));
}
}
connections.insert(cfg.name.clone(), new_client.clone());
Ok((new_client, false))
}
pub async fn cancel_all(&self) {
let mut connections = self.connections.write().await;
let count = connections.len();
for (name, client) in connections.drain() {
debug!(server = %name, "Cancelling upstream connection");
client.cancellation_token().cancel();
}
if count > 0 {
debug!(count, "Cancelled all upstream connections");
}
}
pub async fn len(&self) -> usize {
self.connections.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.connections.read().await.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created pool holds no connections, and `len` agrees with `is_empty`.
    #[tokio::test]
    async fn pool_starts_empty() {
        let pool = McpConnectionPool::new();
        assert!(pool.is_empty().await);
        assert_eq!(pool.len().await, 0);
    }

    /// Cancelling an empty pool leaves it empty, and doing so repeatedly is harmless.
    #[tokio::test]
    async fn cancel_all_on_empty_pool_is_a_noop() {
        let pool = McpConnectionPool::new();
        pool.cancel_all().await;
        assert!(pool.is_empty().await);

        pool.cancel_all().await;
        assert_eq!(pool.len().await, 0);
    }
}