use crate::error::Result;
use crate::mcp::tools::ToolResult;
use crate::mcp::{McpClient, McpClientConfig, McpClientTrait};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
/// Seconds a pooled connection may go unused before it is evicted (five minutes).
const MAX_IDLE_SECS: u64 = 5 * 60;
/// Maximum number of connections the pool keeps for a single server name.
const MAX_CONNECTIONS_PER_SERVER: usize = 4;
/// An MCP client owned by the pool, together with the bookkeeping used for
/// idle eviction.
struct PooledConnection {
    /// The client this pool entry owns.
    client: McpClient,
    /// When the entry was created or last marked as used.
    last_used: Instant,
    /// How many times the entry has been marked as used.
    call_count: u64,
}
impl PooledConnection {
    /// Wraps `client` with a fresh timestamp and a call count of zero.
    fn new(client: McpClient) -> Self {
        let last_used = Instant::now();
        Self {
            client,
            call_count: 0,
            last_used,
        }
    }

    /// Returns `true` once strictly more than `MAX_IDLE_SECS` seconds have
    /// elapsed since the entry was created or last marked as used.
    fn is_idle_too_long(&self) -> bool {
        let limit = Duration::from_secs(MAX_IDLE_SECS);
        limit < self.last_used.elapsed()
    }

    /// Records one use: bumps the call counter and restarts the idle timer.
    fn mark_used(&mut self) {
        self.call_count += 1;
        self.last_used = Instant::now();
    }
}
/// Point-in-time snapshot of the pool's counters.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections currently held in the pool, summed over all servers.
    pub active_connections: usize,
    /// Number of `call_tool` invocations.
    pub total_calls: u64,
    /// Calls that reused an already-pooled connection.
    pub cache_hits: u64,
    /// Calls that had to create a new connection.
    pub cache_misses: u64,
    /// Connections evicted because they exceeded the idle timeout.
    pub idle_closes: u64,
}
/// Pool of reusable MCP client connections, grouped by server name.
pub struct ConnectionPool {
    /// Server name -> connections currently held for that server.
    connections: Arc<RwLock<HashMap<String, Vec<PooledConnection>>>>,
    /// Atomic usage counters, readable without taking the connection lock.
    stats: Arc<PoolStatsInner>,
}
/// Shared atomic counters backing the public `PoolStats` snapshot.
struct PoolStatsInner {
    /// Incremented once per `call_tool` invocation.
    total_calls: AtomicU64,
    /// Incremented when a pooled connection is reused.
    cache_hits: AtomicU64,
    /// Incremented when a new connection must be created.
    cache_misses: AtomicU64,
    /// Incremented by the number of idle connections evicted.
    idle_closes: AtomicU64,
}
impl ConnectionPool {
pub fn new() -> Self {
Self {
connections: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(PoolStatsInner {
total_calls: AtomicU64::new(0),
cache_hits: AtomicU64::new(0),
cache_misses: AtomicU64::new(0),
idle_closes: AtomicU64::new(0),
}),
}
}
pub async fn call_tool(
&self,
config: &McpClientConfig,
tool_name: &str,
args: serde_json::Value,
) -> Result<ToolResult> {
self.stats.total_calls.fetch_add(1, Ordering::Relaxed);
let mut connections = self.connections.write().await;
let server_name = &config.name;
self.cleanup_idle(&mut connections);
if let Some(pool) = connections.get_mut(server_name) {
if let Some(conn) = pool.iter_mut().find(|c| !c.is_idle_too_long()) {
conn.mark_used();
self.stats.cache_hits.fetch_add(1, Ordering::Relaxed);
debug!(server = %server_name, "Reusing pooled connection");
let result = conn.client.call_tool(tool_name, args.clone()).await;
if let Err(ref e) = result {
warn!(error = %e, "Pooled connection failed, will retry with new connection");
pool.retain(|c| !c.is_idle_too_long());
drop(connections);
return self.create_and_call(config, tool_name, args).await;
}
return result;
}
}
drop(connections);
self.create_and_call(config, tool_name, args).await
}
async fn create_and_call(
&self,
config: &McpClientConfig,
tool_name: &str,
args: serde_json::Value,
) -> Result<ToolResult> {
self.stats.cache_misses.fetch_add(1, Ordering::Relaxed);
let mut client = McpClient::new(config.clone());
client.connect().await?;
info!(server = %config.name, "Created new pooled connection");
let result = client.call_tool(tool_name, args).await?;
let mut connections = self.connections.write().await;
let pool = connections.entry(config.name.clone()).or_default();
if pool.len() < MAX_CONNECTIONS_PER_SERVER {
let mut conn = PooledConnection::new(client);
conn.mark_used();
pool.push(conn);
debug!(
server = %config.name,
pool_size = pool.len(),
"Added connection to pool"
);
} else {
let _ = client.disconnect().await;
}
Ok(result)
}
fn cleanup_idle(&self, connections: &mut HashMap<String, Vec<PooledConnection>>) {
for (server, pool) in connections.iter_mut() {
let before = pool.len();
pool.retain(|c| !c.is_idle_too_long());
let removed = before - pool.len();
if removed > 0 {
self.stats
.idle_closes
.fetch_add(removed as u64, Ordering::Relaxed);
debug!(server = %server, removed = removed, "Cleaned up idle connections");
}
}
connections.retain(|_, pool| !pool.is_empty());
}
pub fn clear(&self) {
let rt = tokio::runtime::Handle::try_current();
if let Ok(handle) = rt {
let connections = self.connections.clone();
handle.spawn(async move {
let mut conns = connections.write().await;
for (_, pool) in conns.drain() {
for mut conn in pool {
let _ = conn.client.disconnect().await;
}
}
});
}
}
pub fn stats(&self) -> PoolStats {
let connections = self.connections.try_read();
let active = connections
.map(|c| c.values().map(|p| p.len()).sum())
.unwrap_or(0);
PoolStats {
active_connections: active,
total_calls: self.stats.total_calls.load(Ordering::Relaxed),
cache_hits: self.stats.cache_hits.load(Ordering::Relaxed),
cache_misses: self.stats.cache_misses.load(Ordering::Relaxed),
idle_closes: self.stats.idle_closes.load(Ordering::Relaxed),
}
}
pub async fn server_connections(&self, server_name: &str) -> usize {
let connections = self.connections.read().await;
connections.get(server_name).map(|p| p.len()).unwrap_or(0)
}
}
impl Default for ConnectionPool {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default snapshot reports zero for every counter.
    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.total_calls, 0);
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
        assert_eq!(stats.idle_closes, 0);
    }

    /// A freshly created pool holds no connections and has recorded no calls.
    #[test]
    fn test_pool_creation() {
        let pool = ConnectionPool::new();
        let stats = pool.stats();
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.total_calls, 0);
        assert_eq!(stats.idle_closes, 0);
    }

    /// Unknown servers report zero pooled connections.
    #[tokio::test]
    async fn test_server_connections_empty() {
        let pool = ConnectionPool::new();
        let count = pool.server_connections("nonexistent").await;
        assert_eq!(count, 0);
    }

    /// Clearing an empty pool inside a runtime must not panic.
    #[tokio::test]
    async fn test_clear_empty_pool() {
        let pool = ConnectionPool::new();
        pool.clear();
        // Give the spawned clear task a chance to run.
        tokio::task::yield_now().await;
        assert_eq!(pool.server_connections("nonexistent").await, 0);
    }

    /// A new connection is not idle, and `mark_used` bumps its call count.
    #[test]
    fn test_pooled_connection_idle() {
        let config = McpClientConfig {
            name: "test".to_string(),
            command: "echo".to_string(),
            args: vec![],
            env: HashMap::new(),
            timeout_secs: 30,
            auto_reconnect: false,
            max_retries: 1,
        };
        let client = McpClient::new(config);
        let mut conn = PooledConnection::new(client);
        assert!(!conn.is_idle_too_long());
        assert_eq!(conn.call_count, 0);
        conn.mark_used();
        assert_eq!(conn.call_count, 1);
        assert!(!conn.is_idle_too_long());
    }

    /// Guards against accidental changes to the pool's tuning constants.
    #[test]
    fn test_max_connections() {
        assert_eq!(MAX_CONNECTIONS_PER_SERVER, 4);
        assert_eq!(MAX_IDLE_SECS, 300);
    }
}