1use crate::builder::ClientBuilder;
7use crate::client::Client;
8use mcpkit_core::capability::{ClientCapabilities, ClientInfo};
9use mcpkit_core::error::McpError;
10use mcpkit_transport::Transport;
11use std::collections::HashMap;
12use std::future::Future;
13use std::sync::Arc;
14use tracing::{debug, trace, warn};
15
16use tokio::sync::{Mutex, Semaphore};
18
/// Tuning knobs for a client connection pool.
///
/// Build one with [`PoolConfig::new`] (or [`Default::default`]) and adjust it
/// with the chained setter methods. The type is a plain value: it can be
/// cloned and compared for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolConfig {
    /// Number of permits in the semaphore the pool creates for each key.
    pub max_connections: usize,
    /// Longest time `acquire` waits for a semaphore permit before failing.
    pub acquire_timeout: std::time::Duration,
    /// When `true`, an idle connection is pinged before it is reused.
    pub validate_on_acquire: bool,
    /// Idle connections unused for at least this long are discarded on the
    /// next `acquire` for their key.
    pub max_idle_time: std::time::Duration,
}

impl Default for PoolConfig {
    /// Defaults: 10 connections, 30 s acquire timeout, validation enabled,
    /// 300 s maximum idle time.
    fn default() -> Self {
        Self {
            max_connections: 10,
            acquire_timeout: std::time::Duration::from_secs(30),
            validate_on_acquire: true,
            max_idle_time: std::time::Duration::from_secs(300),
        }
    }
}

impl PoolConfig {
    /// Creates a configuration holding the default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the per-key connection limit.
    #[must_use]
    pub fn max_connections(mut self, max: usize) -> Self {
        self.max_connections = max;
        self
    }

    /// Sets how long `acquire` may wait for a free slot.
    #[must_use]
    pub fn acquire_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.acquire_timeout = timeout;
        self
    }

    /// Enables or disables pinging idle connections before reuse.
    #[must_use]
    pub fn validate_on_acquire(mut self, validate: bool) -> Self {
        self.validate_on_acquire = validate;
        self
    }

    /// Sets how long an idle connection may sit in the pool before it is
    /// discarded.
    #[must_use]
    pub fn max_idle_time(mut self, time: std::time::Duration) -> Self {
        self.max_idle_time = time;
        self
    }
}
78
79pub struct PooledClient<T: Transport + 'static> {
83 client: Option<Client<T>>,
84 pool: Arc<ClientPoolInner<T>>,
85 key: String,
86}
87
88impl<T: Transport + 'static> PooledClient<T> {
89 pub fn client(&self) -> &Client<T> {
91 self.client.as_ref().expect("Client already dropped")
92 }
93
94 pub fn client_mut(&mut self) -> &mut Client<T> {
96 self.client.as_mut().expect("Client already dropped")
97 }
98}
99
100impl<T: Transport + 'static> std::ops::Deref for PooledClient<T> {
101 type Target = Client<T>;
102
103 fn deref(&self) -> &Self::Target {
104 self.client()
105 }
106}
107
108impl<T: Transport + 'static> std::ops::DerefMut for PooledClient<T> {
109 fn deref_mut(&mut self) -> &mut Self::Target {
110 self.client_mut()
111 }
112}
113
114impl<T: Transport + 'static> Drop for PooledClient<T> {
115 fn drop(&mut self) {
116 if let Some(client) = self.client.take() {
117 let pool = Arc::clone(&self.pool);
119 let key = self.key.clone();
120 tokio::spawn(async move {
121 pool.return_connection(key, client).await;
122 });
123 }
124 }
125}
126
127struct ClientPoolInner<T: Transport> {
129 config: PoolConfig,
131 connections: Mutex<HashMap<String, Vec<PooledEntry<T>>>>,
133 semaphores: Mutex<HashMap<String, Arc<Semaphore>>>,
135 client_info: ClientInfo,
137 client_caps: ClientCapabilities,
139}
140
141struct PooledEntry<T: Transport> {
143 client: Client<T>,
144 last_used: std::time::Instant,
145}
146
147impl<T: Transport> ClientPoolInner<T> {
148 async fn return_connection(&self, key: String, client: Client<T>) {
150 trace!(%key, "Returning connection to pool");
151
152 let entry = PooledEntry {
153 client,
154 last_used: std::time::Instant::now(),
155 };
156
157 let mut connections = self.connections.lock().await;
158 connections
159 .entry(key)
160 .or_insert_with(Vec::new)
161 .push(entry);
162 }
163
164 async fn get_semaphore(&self, key: &str) -> Arc<Semaphore> {
166 let mut semaphores = self.semaphores.lock().await;
167 semaphores
168 .entry(key.to_string())
169 .or_insert_with(|| Arc::new(Semaphore::new(self.config.max_connections)))
170 .clone()
171 }
172}
173
174pub struct ClientPool<T: Transport> {
208 inner: Arc<ClientPoolInner<T>>,
209}
210
211impl<T: Transport + 'static> ClientPool<T> {
212 pub fn builder() -> ClientPoolBuilder {
214 ClientPoolBuilder::new()
215 }
216
217 pub fn new(client_info: ClientInfo, client_caps: ClientCapabilities) -> Self {
219 Self::with_config(client_info, client_caps, PoolConfig::default())
220 }
221
222 pub fn with_config(
224 client_info: ClientInfo,
225 client_caps: ClientCapabilities,
226 config: PoolConfig,
227 ) -> Self {
228 Self {
229 inner: Arc::new(ClientPoolInner {
230 config,
231 connections: Mutex::new(HashMap::new()),
232 semaphores: Mutex::new(HashMap::new()),
233 client_info,
234 client_caps,
235 }),
236 }
237 }
238
239 pub async fn acquire<F, Fut>(
253 &self,
254 key: impl Into<String>,
255 connect: F,
256 ) -> Result<PooledClient<T>, McpError>
257 where
258 F: FnOnce() -> Fut,
259 Fut: Future<Output = Result<T, McpError>>,
260 {
261 let key = key.into();
262 debug!(%key, "Acquiring connection from pool");
263
264 let semaphore = self.inner.get_semaphore(&key).await;
266
267 let _permit = tokio::time::timeout(
269 self.inner.config.acquire_timeout,
270 semaphore.acquire_owned(),
271 )
272 .await
273 .map_err(|_| McpError::Internal {
274 message: format!("Timeout acquiring connection for {key}"),
275 source: None,
276 })?
277 .map_err(|_| McpError::Internal {
278 message: "Pool semaphore closed".to_string(),
279 source: None,
280 })?;
281
282 {
284 let mut connections = self.inner.connections.lock().await;
285 if let Some(entries) = connections.get_mut(&key) {
286 let max_idle = self.inner.config.max_idle_time;
288 entries.retain(|e| e.last_used.elapsed() < max_idle);
289
290 if let Some(entry) = entries.pop() {
292 trace!(%key, "Reusing existing connection");
293
294 if self.inner.config.validate_on_acquire {
296 if entry.client.ping().await.is_ok() {
298 return Ok(PooledClient {
299 client: Some(entry.client),
300 pool: Arc::clone(&self.inner),
301 key,
302 });
303 }
304 warn!(%key, "Cached connection failed validation");
305 } else {
306 return Ok(PooledClient {
307 client: Some(entry.client),
308 pool: Arc::clone(&self.inner),
309 key,
310 });
311 }
312 }
313 }
314 }
315
316 debug!(%key, "Creating new connection");
318 let transport = connect().await?;
319
320 let client = ClientBuilder::new()
321 .name(self.inner.client_info.name.clone())
322 .version(self.inner.client_info.version.clone())
323 .capabilities(self.inner.client_caps.clone())
324 .build(transport)
325 .await?;
326
327 Ok(PooledClient {
328 client: Some(client),
329 pool: Arc::clone(&self.inner),
330 key,
331 })
332 }
333
334 pub async fn clear(&self) {
336 let mut connections = self.inner.connections.lock().await;
337 connections.clear();
338 debug!("Cleared all pooled connections");
339 }
340
341 pub async fn clear_server(&self, key: &str) {
343 let mut connections = self.inner.connections.lock().await;
344 connections.remove(key);
345 debug!(%key, "Cleared pooled connections for server");
346 }
347
348 pub async fn stats(&self) -> PoolStats {
350 let connections = self.inner.connections.lock().await;
351 let mut total = 0;
352 let mut per_server = HashMap::new();
353
354 for (key, entries) in connections.iter() {
355 let count = entries.len();
356 total += count;
357 per_server.insert(key.clone(), count);
358 }
359
360 PoolStats {
361 total_connections: total,
362 connections_per_server: per_server,
363 max_connections: self.inner.config.max_connections,
364 }
365 }
366}
367
368impl<T: Transport + 'static> Clone for ClientPool<T> {
369 fn clone(&self) -> Self {
370 Self {
371 inner: Arc::clone(&self.inner),
372 }
373 }
374}
375
/// Snapshot of a pool's idle connections, as produced by `ClientPool::stats`.
///
/// Snapshots are plain values and can be cloned and compared for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolStats {
    /// Sum of the per-key counts in `connections_per_server`.
    pub total_connections: usize,
    /// Number of pooled connections for each key.
    pub connections_per_server: HashMap<String, usize>,
    /// The pool's configured `max_connections` value.
    pub max_connections: usize,
}
386
387pub struct ClientPoolBuilder {
389 config: PoolConfig,
390 client_info: Option<ClientInfo>,
391 client_caps: ClientCapabilities,
392}
393
394impl ClientPoolBuilder {
395 pub fn new() -> Self {
397 Self {
398 config: PoolConfig::default(),
399 client_info: None,
400 client_caps: ClientCapabilities::default(),
401 }
402 }
403
404 pub fn client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
406 self.client_info = Some(ClientInfo {
407 name: name.into(),
408 version: version.into(),
409 });
410 self
411 }
412
413 pub fn capabilities(mut self, caps: ClientCapabilities) -> Self {
415 self.client_caps = caps;
416 self
417 }
418
419 pub fn max_connections(mut self, max: usize) -> Self {
421 self.config.max_connections = max;
422 self
423 }
424
425 pub fn acquire_timeout(mut self, timeout: std::time::Duration) -> Self {
427 self.config.acquire_timeout = timeout;
428 self
429 }
430
431 pub fn validate_on_acquire(mut self, validate: bool) -> Self {
433 self.config.validate_on_acquire = validate;
434 self
435 }
436
437 pub fn max_idle_time(mut self, time: std::time::Duration) -> Self {
439 self.config.max_idle_time = time;
440 self
441 }
442
443 pub fn build<T: Transport + 'static>(self) -> ClientPool<T> {
449 let client_info = self
450 .client_info
451 .expect("client_info must be set before building pool");
452
453 ClientPool::with_config(client_info, self.client_caps, self.config)
454 }
455}
456
457impl Default for ClientPoolBuilder {
458 fn default() -> Self {
459 Self::new()
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `PoolConfig` setter must overwrite exactly its own field.
    #[test]
    fn test_pool_config() {
        let ten_seconds = std::time::Duration::from_secs(10);
        let one_minute = std::time::Duration::from_secs(60);

        let PoolConfig {
            max_connections,
            acquire_timeout,
            validate_on_acquire,
            max_idle_time,
        } = PoolConfig::new()
            .max_connections(5)
            .acquire_timeout(ten_seconds)
            .validate_on_acquire(false)
            .max_idle_time(one_minute);

        assert_eq!(max_connections, 5);
        assert_eq!(acquire_timeout.as_secs(), 10);
        assert!(!validate_on_acquire);
        assert_eq!(max_idle_time.as_secs(), 60);
    }

    /// Settings applied through the pool builder must land in its config.
    #[test]
    fn test_pool_builder() {
        let ClientPoolBuilder { config, .. } = ClientPoolBuilder::new()
            .client_info("test-client", "1.0.0")
            .max_connections(10)
            .validate_on_acquire(true);

        assert_eq!(config.max_connections, 10);
        assert!(config.validate_on_acquire);
    }
}