mcpkit_client/
pool.rs

1//! Client connection pooling.
2//!
3//! This module provides connection pooling for MCP clients,
4//! allowing efficient reuse of connections to MCP servers.
5
6use crate::builder::ClientBuilder;
7use crate::client::Client;
8use mcpkit_core::capability::{ClientCapabilities, ClientInfo};
9use mcpkit_core::error::McpError;
10use mcpkit_transport::Transport;
11use std::collections::HashMap;
12use std::future::Future;
13use std::sync::Arc;
14use tracing::{debug, trace, warn};
15
16// Pool is tokio-specific due to spawn and timeout requirements
17use tokio::sync::{Mutex, Semaphore};
18
/// Tunable settings for a client connection pool.
///
/// Build one with [`PoolConfig::new`] (or [`Default`]) and refine it with the
/// chainable setters, each of which consumes and returns the configuration.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of connections per server.
    pub max_connections: usize,
    /// How long to wait when acquiring a connection before giving up.
    pub acquire_timeout: std::time::Duration,
    /// Whether cached connections are validated before being handed out.
    pub validate_on_acquire: bool,
    /// How long a connection may sit idle before it is discarded.
    pub max_idle_time: std::time::Duration,
}

impl Default for PoolConfig {
    /// Defaults: 10 connections, 30 s acquire timeout, validation enabled,
    /// 300 s maximum idle time.
    fn default() -> Self {
        let acquire_timeout = std::time::Duration::from_secs(30);
        let max_idle_time = std::time::Duration::from_secs(300);

        Self {
            max_connections: 10,
            acquire_timeout,
            validate_on_acquire: true,
            max_idle_time,
        }
    }
}

impl PoolConfig {
    /// Create a configuration populated with the default values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Return the configuration with `max_connections` replaced by `max`.
    #[must_use]
    pub fn max_connections(self, max: usize) -> Self {
        Self {
            max_connections: max,
            ..self
        }
    }

    /// Return the configuration with `acquire_timeout` replaced by `timeout`.
    #[must_use]
    pub fn acquire_timeout(self, timeout: std::time::Duration) -> Self {
        Self {
            acquire_timeout: timeout,
            ..self
        }
    }

    /// Return the configuration with `validate_on_acquire` replaced by `validate`.
    #[must_use]
    pub fn validate_on_acquire(self, validate: bool) -> Self {
        Self {
            validate_on_acquire: validate,
            ..self
        }
    }

    /// Return the configuration with `max_idle_time` replaced by `time`.
    #[must_use]
    pub fn max_idle_time(self, time: std::time::Duration) -> Self {
        Self {
            max_idle_time: time,
            ..self
        }
    }
}
78
/// A pooled client connection.
///
/// Wraps a [`Client`] handed out by the pool and dereferences to it.
/// When dropped, the connection is returned to the pool under the key it
/// was acquired with.
pub struct PooledClient<T: Transport + 'static> {
    /// The wrapped client. Stored as an `Option` so that `Drop` can move it
    /// out; the accessors panic if it is `None`.
    client: Option<Client<T>>,
    /// Shared pool state that the client is handed back to on drop.
    pool: Arc<ClientPoolInner<T>>,
    /// Server key this client belongs to.
    key: String,
}
87
88impl<T: Transport + 'static> PooledClient<T> {
89    /// Get a reference to the underlying client.
90    pub fn client(&self) -> &Client<T> {
91        self.client.as_ref().expect("Client already dropped")
92    }
93
94    /// Get a mutable reference to the underlying client.
95    pub fn client_mut(&mut self) -> &mut Client<T> {
96        self.client.as_mut().expect("Client already dropped")
97    }
98}
99
100impl<T: Transport + 'static> std::ops::Deref for PooledClient<T> {
101    type Target = Client<T>;
102
103    fn deref(&self) -> &Self::Target {
104        self.client()
105    }
106}
107
108impl<T: Transport + 'static> std::ops::DerefMut for PooledClient<T> {
109    fn deref_mut(&mut self) -> &mut Self::Target {
110        self.client_mut()
111    }
112}
113
114impl<T: Transport + 'static> Drop for PooledClient<T> {
115    fn drop(&mut self) {
116        if let Some(client) = self.client.take() {
117            // Return the connection to the pool
118            let pool = Arc::clone(&self.pool);
119            let key = self.key.clone();
120            tokio::spawn(async move {
121                pool.return_connection(key, client).await;
122            });
123        }
124    }
125}
126
/// Internal pool state.
///
/// Shared behind an `Arc` by the pool handle and by every checked-out
/// `PooledClient`.
struct ClientPoolInner<T: Transport> {
    /// Configuration.
    config: PoolConfig,
    /// Available (idle) connections, grouped by server key.
    connections: Mutex<HashMap<String, Vec<PooledEntry<T>>>>,
    /// One semaphore per server key, created lazily with
    /// `config.max_connections` permits.
    semaphores: Mutex<HashMap<String, Arc<Semaphore>>>,
    /// Client info (name and version) to use for new connections.
    client_info: ClientInfo,
    /// Client capabilities to use for new connections.
    client_caps: ClientCapabilities,
}
140
/// An entry in the pool: an idle client plus the time it was returned.
struct PooledEntry<T: Transport> {
    /// The idle client.
    client: Client<T>,
    /// When the client was put back into the pool; compared against
    /// `max_idle_time` to discard stale entries.
    last_used: std::time::Instant,
}
146
147impl<T: Transport> ClientPoolInner<T> {
148    /// Return a connection to the pool.
149    async fn return_connection(&self, key: String, client: Client<T>) {
150        trace!(%key, "Returning connection to pool");
151
152        let entry = PooledEntry {
153            client,
154            last_used: std::time::Instant::now(),
155        };
156
157        let mut connections = self.connections.lock().await;
158        connections
159            .entry(key)
160            .or_insert_with(Vec::new)
161            .push(entry);
162    }
163
164    /// Get a semaphore for rate limiting connections to a server.
165    async fn get_semaphore(&self, key: &str) -> Arc<Semaphore> {
166        let mut semaphores = self.semaphores.lock().await;
167        semaphores
168            .entry(key.to_string())
169            .or_insert_with(|| Arc::new(Semaphore::new(self.config.max_connections)))
170            .clone()
171    }
172}
173
/// A pool of client connections.
///
/// The pool manages connections to multiple MCP servers, each identified by
/// a caller-chosen string key. It reuses cached connections when possible
/// and creates new ones as needed. The handle is cheap to clone: all clones
/// share the same underlying state.
///
/// # Example
///
/// ```no_run
/// use mcpkit_client::{ClientPool, ClientPoolBuilder};
/// use mcpkit_transport::SpawnedTransport;
/// use mcpkit_core::error::McpError;
///
/// # async fn example() -> Result<(), McpError> {
/// let pool = ClientPool::<SpawnedTransport>::builder()
///     .client_info("my-client", "1.0.0")
///     .max_connections(5)
///     .build();
///
/// let client = pool.acquire("server-key", || async {
///     // Create a new connection to a server
///     // TransportError converts to McpError automatically
///     Ok::<_, McpError>(
///         SpawnedTransport::spawn("my-server", &[] as &[&str]).await?
///     )
/// }).await?;
///
/// // Use the client
/// let tools = client.list_tools().await?;
///
/// // Client is returned to pool when dropped
/// # Ok(())
/// # }
/// ```
pub struct ClientPool<T: Transport> {
    /// Shared state; also referenced by every `PooledClient` handed out.
    inner: Arc<ClientPoolInner<T>>,
}
210
211impl<T: Transport + 'static> ClientPool<T> {
212    /// Create a new pool builder.
213    pub fn builder() -> ClientPoolBuilder {
214        ClientPoolBuilder::new()
215    }
216
217    /// Create a new pool with default configuration.
218    pub fn new(client_info: ClientInfo, client_caps: ClientCapabilities) -> Self {
219        Self::with_config(client_info, client_caps, PoolConfig::default())
220    }
221
222    /// Create a new pool with custom configuration.
223    pub fn with_config(
224        client_info: ClientInfo,
225        client_caps: ClientCapabilities,
226        config: PoolConfig,
227    ) -> Self {
228        Self {
229            inner: Arc::new(ClientPoolInner {
230                config,
231                connections: Mutex::new(HashMap::new()),
232                semaphores: Mutex::new(HashMap::new()),
233                client_info,
234                client_caps,
235            }),
236        }
237    }
238
239    /// Acquire a connection from the pool.
240    ///
241    /// If a cached connection is available, it is returned. Otherwise,
242    /// the `connect` function is called to create a new connection.
243    ///
244    /// # Arguments
245    ///
246    /// * `key` - A unique key identifying the server
247    /// * `connect` - A function that creates a new transport connection
248    ///
249    /// # Errors
250    ///
251    /// Returns an error if the connection cannot be acquired.
252    pub async fn acquire<F, Fut>(
253        &self,
254        key: impl Into<String>,
255        connect: F,
256    ) -> Result<PooledClient<T>, McpError>
257    where
258        F: FnOnce() -> Fut,
259        Fut: Future<Output = Result<T, McpError>>,
260    {
261        let key = key.into();
262        debug!(%key, "Acquiring connection from pool");
263
264        // Get the semaphore for rate limiting
265        let semaphore = self.inner.get_semaphore(&key).await;
266
267        // Acquire a permit (with timeout)
268        let _permit = tokio::time::timeout(
269            self.inner.config.acquire_timeout,
270            semaphore.acquire_owned(),
271        )
272        .await
273        .map_err(|_| McpError::Internal {
274            message: format!("Timeout acquiring connection for {key}"),
275            source: None,
276        })?
277        .map_err(|_| McpError::Internal {
278            message: "Pool semaphore closed".to_string(),
279            source: None,
280        })?;
281
282        // Try to get an existing connection
283        {
284            let mut connections = self.inner.connections.lock().await;
285            if let Some(entries) = connections.get_mut(&key) {
286                // Remove stale connections
287                let max_idle = self.inner.config.max_idle_time;
288                entries.retain(|e| e.last_used.elapsed() < max_idle);
289
290                // Get a connection if available
291                if let Some(entry) = entries.pop() {
292                    trace!(%key, "Reusing existing connection");
293
294                    // Optionally validate the connection
295                    if self.inner.config.validate_on_acquire {
296                        // Try to ping
297                        if entry.client.ping().await.is_ok() {
298                            return Ok(PooledClient {
299                                client: Some(entry.client),
300                                pool: Arc::clone(&self.inner),
301                                key,
302                            });
303                        }
304                        warn!(%key, "Cached connection failed validation");
305                    } else {
306                        return Ok(PooledClient {
307                            client: Some(entry.client),
308                            pool: Arc::clone(&self.inner),
309                            key,
310                        });
311                    }
312                }
313            }
314        }
315
316        // Create a new connection
317        debug!(%key, "Creating new connection");
318        let transport = connect().await?;
319
320        let client = ClientBuilder::new()
321            .name(self.inner.client_info.name.clone())
322            .version(self.inner.client_info.version.clone())
323            .capabilities(self.inner.client_caps.clone())
324            .build(transport)
325            .await?;
326
327        Ok(PooledClient {
328            client: Some(client),
329            pool: Arc::clone(&self.inner),
330            key,
331        })
332    }
333
334    /// Clear all cached connections.
335    pub async fn clear(&self) {
336        let mut connections = self.inner.connections.lock().await;
337        connections.clear();
338        debug!("Cleared all pooled connections");
339    }
340
341    /// Clear cached connections for a specific server.
342    pub async fn clear_server(&self, key: &str) {
343        let mut connections = self.inner.connections.lock().await;
344        connections.remove(key);
345        debug!(%key, "Cleared pooled connections for server");
346    }
347
348    /// Get statistics about the pool.
349    pub async fn stats(&self) -> PoolStats {
350        let connections = self.inner.connections.lock().await;
351        let mut total = 0;
352        let mut per_server = HashMap::new();
353
354        for (key, entries) in connections.iter() {
355            let count = entries.len();
356            total += count;
357            per_server.insert(key.clone(), count);
358        }
359
360        PoolStats {
361            total_connections: total,
362            connections_per_server: per_server,
363            max_connections: self.inner.config.max_connections,
364        }
365    }
366}
367
368impl<T: Transport + 'static> Clone for ClientPool<T> {
369    fn clone(&self) -> Self {
370        Self {
371            inner: Arc::clone(&self.inner),
372        }
373    }
374}
375
/// Statistics about a connection pool.
///
/// A point-in-time snapshot of the connections cached in the pool.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Total number of cached connections across all servers.
    pub total_connections: usize,
    /// Number of cached connections per server key.
    pub connections_per_server: HashMap<String, usize>,
    /// Configured maximum connections per server.
    pub max_connections: usize,
}
386
/// Builder for creating a client pool.
///
/// Client info must be provided before `build` is called; every other
/// setting has a default.
pub struct ClientPoolBuilder {
    /// Pool configuration being assembled; starts from `PoolConfig::default()`.
    config: PoolConfig,
    /// Client name and version; `None` until `client_info` is called.
    client_info: Option<ClientInfo>,
    /// Client capabilities; starts from `ClientCapabilities::default()`.
    client_caps: ClientCapabilities,
}
393
394impl ClientPoolBuilder {
395    /// Create a new pool builder.
396    pub fn new() -> Self {
397        Self {
398            config: PoolConfig::default(),
399            client_info: None,
400            client_caps: ClientCapabilities::default(),
401        }
402    }
403
404    /// Set the client info.
405    pub fn client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
406        self.client_info = Some(ClientInfo {
407            name: name.into(),
408            version: version.into(),
409        });
410        self
411    }
412
413    /// Set the client capabilities.
414    pub fn capabilities(mut self, caps: ClientCapabilities) -> Self {
415        self.client_caps = caps;
416        self
417    }
418
419    /// Set the maximum number of connections per server.
420    pub fn max_connections(mut self, max: usize) -> Self {
421        self.config.max_connections = max;
422        self
423    }
424
425    /// Set the acquire timeout.
426    pub fn acquire_timeout(mut self, timeout: std::time::Duration) -> Self {
427        self.config.acquire_timeout = timeout;
428        self
429    }
430
431    /// Set whether to validate connections on acquire.
432    pub fn validate_on_acquire(mut self, validate: bool) -> Self {
433        self.config.validate_on_acquire = validate;
434        self
435    }
436
437    /// Set the maximum idle time.
438    pub fn max_idle_time(mut self, time: std::time::Duration) -> Self {
439        self.config.max_idle_time = time;
440        self
441    }
442
443    /// Build the pool.
444    ///
445    /// # Panics
446    ///
447    /// Panics if client_info was not set.
448    pub fn build<T: Transport + 'static>(self) -> ClientPool<T> {
449        let client_info = self
450            .client_info
451            .expect("client_info must be set before building pool");
452
453        ClientPool::with_config(client_info, self.client_caps, self.config)
454    }
455}
456
457impl Default for ClientPoolBuilder {
458    fn default() -> Self {
459        Self::new()
460    }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `PoolConfig` setter stores the value it was given.
    #[test]
    fn test_pool_config() {
        let ten_secs = std::time::Duration::from_secs(10);
        let sixty_secs = std::time::Duration::from_secs(60);

        let PoolConfig {
            max_connections,
            acquire_timeout,
            validate_on_acquire,
            max_idle_time,
        } = PoolConfig::new()
            .max_connections(5)
            .acquire_timeout(ten_secs)
            .validate_on_acquire(false)
            .max_idle_time(sixty_secs);

        assert_eq!(max_connections, 5);
        assert_eq!(acquire_timeout.as_secs(), 10);
        assert!(!validate_on_acquire);
        assert_eq!(max_idle_time.as_secs(), 60);
    }

    /// Builder setters are forwarded into the embedded `PoolConfig`.
    #[test]
    fn test_pool_builder() {
        let ClientPoolBuilder { config, .. } = ClientPoolBuilder::new()
            .client_info("test-client", "1.0.0")
            .max_connections(10)
            .validate_on_acquire(true);

        assert_eq!(config.max_connections, 10);
        assert!(config.validate_on_acquire);
    }
}