use std::{
collections::{hash_map::DefaultHasher, HashMap},
hash::{Hash, Hasher},
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use lru::LruCache;
use parking_lot::Mutex;
use rmcp::{service::RunningService, RoleClient};
use super::config::{McpProxyConfig, McpServerConfig, McpTransport};
use crate::error::McpResult;
type McpClient = RunningService<RoleClient, ()>;
type EvictionCallback = Arc<dyn Fn(&PoolKey) + Send + Sync>;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PoolKey {
pub url: String,
pub auth_hash: u64,
pub tenant_id: Option<String>,
}
impl PoolKey {
pub fn new(url: impl Into<String>, auth_hash: u64, tenant_id: Option<String>) -> Self {
Self {
url: url.into(),
auth_hash,
tenant_id,
}
}
pub fn from_config(config: &McpServerConfig, tenant_id: Option<String>) -> Self {
let (url, auth_hash) = match &config.transport {
McpTransport::Streamable {
url,
token,
headers,
} => (url.clone(), Self::hash_auth(token.as_ref(), headers)),
McpTransport::Sse {
url,
token,
headers,
} => (url.clone(), Self::hash_auth(token.as_ref(), headers)),
McpTransport::Stdio { command, args, .. } => {
(format!("{}:{}", command, args.join(" ")), 0)
}
};
Self {
url,
auth_hash,
tenant_id,
}
}
fn hash_auth(token: Option<&String>, headers: &HashMap<String, String>) -> u64 {
if token.is_none() && headers.is_empty() {
return 0;
}
let mut hasher = DefaultHasher::new();
if let Some(t) = token {
t.hash(&mut hasher);
}
if !headers.is_empty() {
let mut sorted_headers: Vec<_> = headers.iter().collect();
sorted_headers.sort_by_key(|(k, _)| *k);
for (key, value) in sorted_headers {
key.hash(&mut hasher);
value.hash(&mut hasher);
}
}
hasher.finish()
}
#[inline]
pub fn url(&self) -> &str {
&self.url
}
}
#[derive(Clone)]
pub(crate) struct CachedConnection {
pub client: Arc<McpClient>,
}
impl CachedConnection {
pub fn new(client: Arc<McpClient>) -> Self {
Self { client }
}
}
pub struct McpConnectionPool {
connections: Arc<Mutex<LruCache<PoolKey, CachedConnection>>>,
connection_count: AtomicUsize,
max_connections: usize,
global_proxy: Option<McpProxyConfig>,
eviction_callback: Option<EvictionCallback>,
}
impl McpConnectionPool {
const DEFAULT_MAX_CONNECTIONS: usize = 200;
pub fn new() -> Self {
Self::with_full_config(Self::DEFAULT_MAX_CONNECTIONS, McpProxyConfig::from_env())
}
pub fn with_capacity(max_connections: usize) -> Self {
Self::with_full_config(max_connections, McpProxyConfig::from_env())
}
pub fn with_full_config(max_connections: usize, global_proxy: Option<McpProxyConfig>) -> Self {
let max_connections = max_connections.max(1);
let cache_cap =
std::num::NonZeroUsize::new(max_connections).unwrap_or(std::num::NonZeroUsize::MIN);
Self {
connections: Arc::new(Mutex::new(LruCache::new(cache_cap))),
connection_count: AtomicUsize::new(0),
max_connections,
global_proxy,
eviction_callback: None,
}
}
pub fn set_eviction_callback<F>(&mut self, callback: F)
where
F: Fn(&PoolKey) + Send + Sync + 'static,
{
self.eviction_callback = Some(Arc::new(callback));
}
pub async fn get_or_create<F, Fut>(
&self,
key: PoolKey,
server_config: McpServerConfig,
connect_fn: F,
) -> McpResult<Arc<McpClient>>
where
F: FnOnce(McpServerConfig, Option<McpProxyConfig>) -> Fut,
Fut: std::future::Future<Output = McpResult<McpClient>>,
{
{
let mut connections = self.connections.lock();
if let Some(cached) = connections.get(&key) {
return Ok(Arc::clone(&cached.client));
}
}
let client = connect_fn(server_config.clone(), self.global_proxy.clone()).await?;
let client_arc = Arc::new(client);
let cached = CachedConnection::new(Arc::clone(&client_arc));
{
let mut connections = self.connections.lock();
match connections.push(key, cached) {
Some((evicted_key, _)) => {
if let Some(callback) = &self.eviction_callback {
callback(&evicted_key);
}
}
None => {
self.connection_count.fetch_add(1, Ordering::Relaxed);
}
}
}
Ok(client_arc)
}
pub fn len(&self) -> usize {
self.connection_count.load(Ordering::Relaxed)
}
pub fn is_empty(&self) -> bool {
self.connection_count.load(Ordering::Relaxed) == 0
}
pub fn clear(&self) {
let mut connections = self.connections.lock();
connections.clear();
self.connection_count.store(0, Ordering::Relaxed);
}
pub fn stats(&self) -> PoolStats {
PoolStats {
total_connections: self.connection_count.load(Ordering::Relaxed),
capacity: self.max_connections,
}
}
pub fn list_keys(&self) -> Vec<PoolKey> {
self.connections
.lock()
.iter()
.map(|(key, _)| key.clone())
.collect()
}
pub fn get(&self, key: &PoolKey) -> Option<Arc<McpClient>> {
self.connections
.lock()
.get(key)
.map(|cached| Arc::clone(&cached.client))
}
pub fn contains(&self, key: &PoolKey) -> bool {
self.connections.lock().contains(key)
}
pub fn get_by_url(&self, url: &str) -> Option<Arc<McpClient>> {
self.connections
.lock()
.iter()
.find(|(key, _)| key.url == url)
.map(|(_, cached)| Arc::clone(&cached.client))
}
pub fn contains_url(&self, url: &str) -> bool {
self.connections
.lock()
.iter()
.any(|(key, _)| key.url == url)
}
}
impl Default for McpConnectionPool {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct PoolStats {
pub total_connections: usize,
pub capacity: usize,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::config::McpTransport;
fn create_test_config(url: &str) -> McpServerConfig {
McpServerConfig {
name: "test_server".to_string(),
transport: McpTransport::Streamable {
url: url.to_string(),
token: None,
headers: HashMap::new(),
},
proxy: None,
required: false,
tools: None,
builtin_type: None,
builtin_tool_name: None,
}
}
#[tokio::test]
async fn test_pool_creation() {
let pool = McpConnectionPool::new();
assert_eq!(pool.len(), 0);
assert!(pool.is_empty());
}
#[test]
fn test_pool_stats() {
let pool = McpConnectionPool::with_capacity(10);
let stats = pool.stats();
assert_eq!(stats.total_connections, 0);
assert_eq!(stats.capacity, 10);
}
#[test]
fn test_pool_clear() {
let pool = McpConnectionPool::new();
assert_eq!(pool.len(), 0);
pool.clear();
assert!(pool.is_empty());
}
#[test]
fn test_pool_key_from_config() {
let config = create_test_config("http://localhost:3000");
let key = PoolKey::from_config(&config, None);
assert_eq!(key.url, "http://localhost:3000");
assert_eq!(key.auth_hash, 0);
assert_eq!(key.tenant_id, None);
let config_with_token = McpServerConfig {
name: "test".to_string(),
transport: McpTransport::Streamable {
url: "http://localhost:3000".to_string(),
token: Some("secret-token".to_string()),
headers: HashMap::new(),
},
proxy: None,
required: false,
tools: None,
builtin_type: None,
builtin_tool_name: None,
};
let key_with_token = PoolKey::from_config(&config_with_token, None);
assert_eq!(key_with_token.url, "http://localhost:3000");
assert_ne!(key_with_token.auth_hash, 0);
let key_with_tenant = PoolKey::from_config(&config, Some("tenant-123".to_string()));
assert_eq!(key_with_tenant.tenant_id, Some("tenant-123".to_string()));
}
#[test]
fn test_pool_key_different_tokens() {
let config1 = McpServerConfig {
name: "test".to_string(),
transport: McpTransport::Streamable {
url: "http://localhost:3000".to_string(),
token: Some("token-a".to_string()),
headers: HashMap::new(),
},
proxy: None,
required: false,
tools: None,
builtin_type: None,
builtin_tool_name: None,
};
let config2 = McpServerConfig {
name: "test".to_string(),
transport: McpTransport::Streamable {
url: "http://localhost:3000".to_string(),
token: Some("token-b".to_string()),
headers: HashMap::new(),
},
proxy: None,
required: false,
tools: None,
builtin_type: None,
builtin_tool_name: None,
};
let key1 = PoolKey::from_config(&config1, None);
let key2 = PoolKey::from_config(&config2, None);
assert_eq!(key1.url, key2.url);
assert_ne!(key1.auth_hash, key2.auth_hash);
assert_ne!(key1, key2);
}
#[test]
fn test_pool_key_with_headers() {
let mut headers1 = HashMap::new();
headers1.insert("X-API-Key".to_string(), "key-1".to_string());
let mut headers2 = HashMap::new();
headers2.insert("X-API-Key".to_string(), "key-2".to_string());
let config1 = McpServerConfig {
name: "test".to_string(),
transport: McpTransport::Sse {
url: "http://localhost:3000".to_string(),
token: None,
headers: headers1,
},
proxy: None,
required: false,
tools: None,
builtin_type: None,
builtin_tool_name: None,
};
let config2 = McpServerConfig {
name: "test".to_string(),
transport: McpTransport::Sse {
url: "http://localhost:3000".to_string(),
token: None,
headers: headers2,
},
proxy: None,
required: false,
tools: None,
builtin_type: None,
builtin_tool_name: None,
};
let key1 = PoolKey::from_config(&config1, None);
let key2 = PoolKey::from_config(&config2, None);
assert_eq!(key1.url, key2.url);
assert_ne!(key1.auth_hash, key2.auth_hash);
assert_ne!(key1, key2);
}
#[test]
fn test_pool_with_global_proxy() {
use crate::core::config::McpProxyConfig;
let proxy = McpProxyConfig {
http: Some("http://proxy.example.com:8080".to_string()),
https: None,
no_proxy: Some("localhost,127.0.0.1".to_string()),
username: None,
password: None,
};
let pool = McpConnectionPool::with_full_config(100, Some(proxy.clone()));
assert!(pool.global_proxy.is_some());
let stored_proxy = pool.global_proxy.as_ref().unwrap();
assert_eq!(
stored_proxy.http.as_ref().unwrap(),
"http://proxy.example.com:8080"
);
assert_eq!(
stored_proxy.no_proxy.as_ref().unwrap(),
"localhost,127.0.0.1"
);
}
#[test]
fn test_pool_proxy_from_env() {
let pool = McpConnectionPool::new();
assert!(pool.global_proxy.is_some() || pool.global_proxy.is_none());
}
}