1use crate::builder::ClientBuilder;
7use crate::client::Client;
8use mcpkit_core::capability::{ClientCapabilities, ClientInfo};
9use mcpkit_core::error::McpError;
10use mcpkit_transport::Transport;
11use std::collections::HashMap;
12use std::future::Future;
13use std::sync::Arc;
14use tracing::{debug, trace, warn};
15
16use tokio::sync::{Mutex, Semaphore};
18
/// Tuning knobs for a client connection pool.
///
/// Build one with [`PoolConfig::new`] (or `Default`) and chain the setters to override
/// individual values.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of connections per server key; used as the permit count of the
    /// per-key semaphore the pool creates.
    pub max_connections: usize,
    /// How long acquiring a connection waits for a semaphore permit before failing.
    pub acquire_timeout: std::time::Duration,
    /// Whether a reused idle connection is pinged before being handed out.
    pub validate_on_acquire: bool,
    /// Idle connections that have been unused for at least this long are discarded
    /// instead of reused.
    pub max_idle_time: std::time::Duration,
}

impl Default for PoolConfig {
    /// 10 connections per server, 30 s acquire timeout, validation enabled and a
    /// 5 minute idle limit.
    fn default() -> Self {
        let secs = std::time::Duration::from_secs;
        Self {
            max_connections: 10,
            acquire_timeout: secs(30),
            validate_on_acquire: true,
            max_idle_time: secs(5 * 60),
        }
    }
}

impl PoolConfig {
    /// Returns the default configuration (same as [`PoolConfig::default`]).
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Replaces the per-server connection limit.
    #[must_use]
    pub const fn max_connections(self, max: usize) -> Self {
        Self {
            max_connections: max,
            ..self
        }
    }

    /// Replaces the time allowed for acquiring a connection permit.
    #[must_use]
    pub const fn acquire_timeout(self, timeout: std::time::Duration) -> Self {
        Self {
            acquire_timeout: timeout,
            ..self
        }
    }

    /// Enables or disables pinging reused connections before handing them out.
    #[must_use]
    pub const fn validate_on_acquire(self, validate: bool) -> Self {
        Self {
            validate_on_acquire: validate,
            ..self
        }
    }

    /// Replaces the idle-age limit after which pooled connections are discarded.
    #[must_use]
    pub const fn max_idle_time(self, time: std::time::Duration) -> Self {
        Self {
            max_idle_time: time,
            ..self
        }
    }
}
78
79pub struct PooledClient<T: Transport + 'static> {
83 client: Option<Client<T>>,
84 pool: Arc<ClientPoolInner<T>>,
85 key: String,
86}
87
88impl<T: Transport + 'static> PooledClient<T> {
89 pub fn client(&self) -> &Client<T> {
96 self.client.as_ref().expect("Client already dropped")
97 }
98
99 pub fn client_mut(&mut self) -> &mut Client<T> {
106 self.client.as_mut().expect("Client already dropped")
107 }
108}
109
110impl<T: Transport + 'static> std::ops::Deref for PooledClient<T> {
111 type Target = Client<T>;
112
113 fn deref(&self) -> &Self::Target {
114 self.client()
115 }
116}
117
118impl<T: Transport + 'static> std::ops::DerefMut for PooledClient<T> {
119 fn deref_mut(&mut self) -> &mut Self::Target {
120 self.client_mut()
121 }
122}
123
124impl<T: Transport + 'static> Drop for PooledClient<T> {
125 fn drop(&mut self) {
126 if let Some(client) = self.client.take() {
127 let pool = Arc::clone(&self.pool);
129 let key = self.key.clone();
130 tokio::spawn(async move {
131 pool.return_connection(key, client).await;
132 });
133 }
134 }
135}
136
137struct ClientPoolInner<T: Transport> {
139 config: PoolConfig,
141 connections: Mutex<HashMap<String, Vec<PooledEntry<T>>>>,
143 semaphores: Mutex<HashMap<String, Arc<Semaphore>>>,
145 client_info: ClientInfo,
147 client_caps: ClientCapabilities,
149}
150
151struct PooledEntry<T: Transport> {
153 client: Client<T>,
154 last_used: std::time::Instant,
155}
156
157impl<T: Transport> ClientPoolInner<T> {
158 async fn return_connection(&self, key: String, client: Client<T>) {
160 trace!(%key, "Returning connection to pool");
161
162 let entry = PooledEntry {
163 client,
164 last_used: std::time::Instant::now(),
165 };
166
167 let mut connections = self.connections.lock().await;
168 connections.entry(key).or_insert_with(Vec::new).push(entry);
169 }
170
171 async fn get_semaphore(&self, key: &str) -> Arc<Semaphore> {
173 let mut semaphores = self.semaphores.lock().await;
174 semaphores
175 .entry(key.to_string())
176 .or_insert_with(|| Arc::new(Semaphore::new(self.config.max_connections)))
177 .clone()
178 }
179}
180
181pub struct ClientPool<T: Transport> {
215 inner: Arc<ClientPoolInner<T>>,
216}
217
218impl<T: Transport + 'static> ClientPool<T> {
219 #[must_use]
221 pub fn builder() -> ClientPoolBuilder {
222 ClientPoolBuilder::new()
223 }
224
225 #[must_use]
227 pub fn new(client_info: ClientInfo, client_caps: ClientCapabilities) -> Self {
228 Self::with_config(client_info, client_caps, PoolConfig::default())
229 }
230
231 #[must_use]
233 pub fn with_config(
234 client_info: ClientInfo,
235 client_caps: ClientCapabilities,
236 config: PoolConfig,
237 ) -> Self {
238 Self {
239 inner: Arc::new(ClientPoolInner {
240 config,
241 connections: Mutex::new(HashMap::new()),
242 semaphores: Mutex::new(HashMap::new()),
243 client_info,
244 client_caps,
245 }),
246 }
247 }
248
249 pub async fn acquire<F, Fut>(
263 &self,
264 key: impl Into<String>,
265 connect: F,
266 ) -> Result<PooledClient<T>, McpError>
267 where
268 F: FnOnce() -> Fut,
269 Fut: Future<Output = Result<T, McpError>>,
270 {
271 let key = key.into();
272 debug!(%key, "Acquiring connection from pool");
273
274 let semaphore = self.inner.get_semaphore(&key).await;
276
277 let _permit =
279 tokio::time::timeout(self.inner.config.acquire_timeout, semaphore.acquire_owned())
280 .await
281 .map_err(|_| McpError::Internal {
282 message: format!("Timeout acquiring connection for {key}"),
283 source: None,
284 })?
285 .map_err(|_| McpError::Internal {
286 message: "Pool semaphore closed".to_string(),
287 source: None,
288 })?;
289
290 {
292 let mut connections = self.inner.connections.lock().await;
293 if let Some(entries) = connections.get_mut(&key) {
294 let max_idle = self.inner.config.max_idle_time;
296 entries.retain(|e| e.last_used.elapsed() < max_idle);
297
298 if let Some(entry) = entries.pop() {
300 trace!(%key, "Reusing existing connection");
301
302 if self.inner.config.validate_on_acquire {
304 if entry.client.ping().await.is_ok() {
306 return Ok(PooledClient {
307 client: Some(entry.client),
308 pool: Arc::clone(&self.inner),
309 key,
310 });
311 }
312 warn!(%key, "Cached connection failed validation");
313 } else {
314 return Ok(PooledClient {
315 client: Some(entry.client),
316 pool: Arc::clone(&self.inner),
317 key,
318 });
319 }
320 }
321 }
322 }
323
324 debug!(%key, "Creating new connection");
326 let transport = connect().await?;
327
328 let client = ClientBuilder::new()
329 .name(self.inner.client_info.name.clone())
330 .version(self.inner.client_info.version.clone())
331 .capabilities(self.inner.client_caps.clone())
332 .build(transport)
333 .await?;
334
335 Ok(PooledClient {
336 client: Some(client),
337 pool: Arc::clone(&self.inner),
338 key,
339 })
340 }
341
342 pub async fn clear(&self) {
344 let mut connections = self.inner.connections.lock().await;
345 connections.clear();
346 debug!("Cleared all pooled connections");
347 }
348
349 pub async fn clear_server(&self, key: &str) {
351 let mut connections = self.inner.connections.lock().await;
352 connections.remove(key);
353 debug!(%key, "Cleared pooled connections for server");
354 }
355
356 pub async fn stats(&self) -> PoolStats {
358 let connections = self.inner.connections.lock().await;
359 let mut total = 0;
360 let mut per_server = HashMap::new();
361
362 for (key, entries) in connections.iter() {
363 let count = entries.len();
364 total += count;
365 per_server.insert(key.clone(), count);
366 }
367
368 PoolStats {
369 total_connections: total,
370 connections_per_server: per_server,
371 max_connections: self.inner.config.max_connections,
372 }
373 }
374}
375
376impl<T: Transport + 'static> Clone for ClientPool<T> {
377 fn clone(&self) -> Self {
378 Self {
379 inner: Arc::clone(&self.inner),
380 }
381 }
382}
383
/// Snapshot of a pool's idle connections, as produced by `ClientPool::stats`.
///
/// Only connections currently idle in the pool are counted; clients that are checked
/// out at the time of the snapshot are not included.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Idle connections summed over every server key.
    pub total_connections: usize,
    /// Idle connection count for each server key.
    pub connections_per_server: std::collections::HashMap<String, usize>,
    /// The pool's configured `max_connections` value.
    pub max_connections: usize,
}
394
395pub struct ClientPoolBuilder {
397 config: PoolConfig,
398 client_info: Option<ClientInfo>,
399 client_caps: ClientCapabilities,
400}
401
402impl ClientPoolBuilder {
403 #[must_use]
405 pub fn new() -> Self {
406 Self {
407 config: PoolConfig::default(),
408 client_info: None,
409 client_caps: ClientCapabilities::default(),
410 }
411 }
412
413 pub fn client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
415 self.client_info = Some(ClientInfo {
416 name: name.into(),
417 version: version.into(),
418 });
419 self
420 }
421
422 #[must_use]
424 pub fn capabilities(mut self, caps: ClientCapabilities) -> Self {
425 self.client_caps = caps;
426 self
427 }
428
429 #[must_use]
431 pub const fn max_connections(mut self, max: usize) -> Self {
432 self.config.max_connections = max;
433 self
434 }
435
436 #[must_use]
438 pub const fn acquire_timeout(mut self, timeout: std::time::Duration) -> Self {
439 self.config.acquire_timeout = timeout;
440 self
441 }
442
443 #[must_use]
445 pub const fn validate_on_acquire(mut self, validate: bool) -> Self {
446 self.config.validate_on_acquire = validate;
447 self
448 }
449
450 #[must_use]
452 pub const fn max_idle_time(mut self, time: std::time::Duration) -> Self {
453 self.config.max_idle_time = time;
454 self
455 }
456
457 #[must_use]
463 pub fn build<T: Transport + 'static>(self) -> ClientPool<T> {
464 let client_info = self
465 .client_info
466 .expect("client_info must be set before building pool");
467
468 ClientPool::with_config(client_info, self.client_caps, self.config)
469 }
470}
471
472impl Default for ClientPoolBuilder {
473 fn default() -> Self {
474 Self::new()
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `PoolConfig` setter should overwrite exactly its own field.
    #[test]
    fn test_pool_config() {
        let ten_secs = std::time::Duration::from_secs(10);
        let one_minute = std::time::Duration::from_secs(60);

        let cfg = PoolConfig::new()
            .max_connections(5)
            .acquire_timeout(ten_secs)
            .validate_on_acquire(false)
            .max_idle_time(one_minute);

        assert_eq!(cfg.max_connections, 5);
        assert_eq!(cfg.acquire_timeout.as_secs(), 10);
        assert!(!cfg.validate_on_acquire);
        assert_eq!(cfg.max_idle_time.as_secs(), 60);
    }

    /// Builder setters should be reflected in the builder's internal config.
    #[test]
    fn test_pool_builder() {
        let ClientPoolBuilder { config, .. } = ClientPoolBuilder::new()
            .client_info("test-client", "1.0.0")
            .max_connections(10)
            .validate_on_acquire(true);

        assert_eq!(config.max_connections, 10);
        assert!(config.validate_on_acquire);
    }
}