mcpkit_client/
pool.rs

1//! Client connection pooling.
2//!
3//! This module provides connection pooling for MCP clients,
4//! allowing efficient reuse of connections to MCP servers.
5
6use crate::builder::ClientBuilder;
7use crate::client::Client;
8use mcpkit_core::capability::{ClientCapabilities, ClientInfo};
9use mcpkit_core::error::McpError;
10use mcpkit_transport::Transport;
11use std::collections::HashMap;
12use std::future::Future;
13use std::sync::Arc;
14use tracing::{debug, trace, warn};
15
16// Pool is tokio-specific due to spawn and timeout requirements
17use tokio::sync::{Mutex, Semaphore};
18
/// Tunable settings for a client connection pool.
///
/// Build one with [`PoolConfig::new`] (or [`Default`]) and adjust it through the
/// chainable setters, each of which consumes the value and returns the updated copy.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on connections for a single server key.
    ///
    /// This is the number of permits given to each per-server semaphore.
    pub max_connections: usize,
    /// How long an acquire may wait for a permit before it fails.
    pub acquire_timeout: std::time::Duration,
    /// When `true`, a cached connection is pinged before it is handed out.
    pub validate_on_acquire: bool,
    /// Idle age at which a cached connection is discarded rather than reused.
    pub max_idle_time: std::time::Duration,
}

impl Default for PoolConfig {
    /// Defaults: 10 connections per server, a 30 second acquire timeout,
    /// validation enabled, and a 5 minute idle limit.
    fn default() -> Self {
        use std::time::Duration;

        Self {
            max_connections: 10,
            acquire_timeout: Duration::from_secs(30),
            validate_on_acquire: true,
            max_idle_time: Duration::from_secs(300),
        }
    }
}

impl PoolConfig {
    /// Start from the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replace the per-server connection limit.
    #[must_use]
    pub const fn max_connections(self, max: usize) -> Self {
        Self {
            max_connections: max,
            ..self
        }
    }

    /// Replace the time an acquire is allowed to wait.
    #[must_use]
    pub const fn acquire_timeout(self, timeout: std::time::Duration) -> Self {
        Self {
            acquire_timeout: timeout,
            ..self
        }
    }

    /// Turn validation of cached connections on or off.
    #[must_use]
    pub const fn validate_on_acquire(self, validate: bool) -> Self {
        Self {
            validate_on_acquire: validate,
            ..self
        }
    }

    /// Replace the idle age after which cached connections are discarded.
    #[must_use]
    pub const fn max_idle_time(self, time: std::time::Duration) -> Self {
        Self {
            max_idle_time: time,
            ..self
        }
    }
}
78
79/// A pooled client connection.
80///
81/// When dropped, the connection is returned to the pool.
82pub struct PooledClient<T: Transport + 'static> {
83    client: Option<Client<T>>,
84    pool: Arc<ClientPoolInner<T>>,
85    key: String,
86}
87
88impl<T: Transport + 'static> PooledClient<T> {
89    /// Get a reference to the underlying client.
90    ///
91    /// # Panics
92    ///
93    /// Panics if the client was already dropped or taken. This should never
94    /// happen in normal use - the panic indicates a bug in the pool implementation.
95    pub fn client(&self) -> &Client<T> {
96        self.client.as_ref().expect("Client already dropped")
97    }
98
99    /// Get a mutable reference to the underlying client.
100    ///
101    /// # Panics
102    ///
103    /// Panics if the client was already dropped or taken. This should never
104    /// happen in normal use - the panic indicates a bug in the pool implementation.
105    pub fn client_mut(&mut self) -> &mut Client<T> {
106        self.client.as_mut().expect("Client already dropped")
107    }
108}
109
110impl<T: Transport + 'static> std::ops::Deref for PooledClient<T> {
111    type Target = Client<T>;
112
113    fn deref(&self) -> &Self::Target {
114        self.client()
115    }
116}
117
118impl<T: Transport + 'static> std::ops::DerefMut for PooledClient<T> {
119    fn deref_mut(&mut self) -> &mut Self::Target {
120        self.client_mut()
121    }
122}
123
124impl<T: Transport + 'static> Drop for PooledClient<T> {
125    fn drop(&mut self) {
126        if let Some(client) = self.client.take() {
127            // Return the connection to the pool
128            let pool = Arc::clone(&self.pool);
129            let key = self.key.clone();
130            tokio::spawn(async move {
131                pool.return_connection(key, client).await;
132            });
133        }
134    }
135}
136
137/// Internal pool state.
138struct ClientPoolInner<T: Transport> {
139    /// Configuration.
140    config: PoolConfig,
141    /// Available connections by server key.
142    connections: Mutex<HashMap<String, Vec<PooledEntry<T>>>>,
143    /// Semaphore for limiting concurrent connections.
144    semaphores: Mutex<HashMap<String, Arc<Semaphore>>>,
145    /// Client info to use for new connections.
146    client_info: ClientInfo,
147    /// Client capabilities.
148    client_caps: ClientCapabilities,
149}
150
151/// An entry in the pool.
152struct PooledEntry<T: Transport> {
153    client: Client<T>,
154    last_used: std::time::Instant,
155}
156
157impl<T: Transport> ClientPoolInner<T> {
158    /// Return a connection to the pool.
159    async fn return_connection(&self, key: String, client: Client<T>) {
160        trace!(%key, "Returning connection to pool");
161
162        let entry = PooledEntry {
163            client,
164            last_used: std::time::Instant::now(),
165        };
166
167        let mut connections = self.connections.lock().await;
168        connections.entry(key).or_insert_with(Vec::new).push(entry);
169    }
170
171    /// Get a semaphore for rate limiting connections to a server.
172    async fn get_semaphore(&self, key: &str) -> Arc<Semaphore> {
173        let mut semaphores = self.semaphores.lock().await;
174        semaphores
175            .entry(key.to_string())
176            .or_insert_with(|| Arc::new(Semaphore::new(self.config.max_connections)))
177            .clone()
178    }
179}
180
181/// A pool of client connections.
182///
183/// The pool manages connections to multiple MCP servers, reusing
184/// existing connections when possible and creating new ones as needed.
185///
186/// # Example
187///
188/// ```no_run
189/// use mcpkit_client::{ClientPool, ClientPoolBuilder};
190/// use mcpkit_transport::SpawnedTransport;
191/// use mcpkit_core::error::McpError;
192///
193/// # async fn example() -> Result<(), McpError> {
194/// let pool = ClientPool::<SpawnedTransport>::builder()
195///     .client_info("my-client", "1.0.0")
196///     .max_connections(5)
197///     .build();
198///
199/// let client = pool.acquire("server-key", || async {
200///     // Create a new connection to a server
201///     // TransportError converts to McpError automatically
202///     Ok::<_, McpError>(
203///         SpawnedTransport::spawn("my-server", &[] as &[&str]).await?
204///     )
205/// }).await?;
206///
207/// // Use the client
208/// let tools = client.list_tools().await?;
209///
210/// // Client is returned to pool when dropped
211/// # Ok(())
212/// # }
213/// ```
214pub struct ClientPool<T: Transport> {
215    inner: Arc<ClientPoolInner<T>>,
216}
217
218impl<T: Transport + 'static> ClientPool<T> {
219    /// Create a new pool builder.
220    #[must_use]
221    pub fn builder() -> ClientPoolBuilder {
222        ClientPoolBuilder::new()
223    }
224
225    /// Create a new pool with default configuration.
226    #[must_use]
227    pub fn new(client_info: ClientInfo, client_caps: ClientCapabilities) -> Self {
228        Self::with_config(client_info, client_caps, PoolConfig::default())
229    }
230
231    /// Create a new pool with custom configuration.
232    #[must_use]
233    pub fn with_config(
234        client_info: ClientInfo,
235        client_caps: ClientCapabilities,
236        config: PoolConfig,
237    ) -> Self {
238        Self {
239            inner: Arc::new(ClientPoolInner {
240                config,
241                connections: Mutex::new(HashMap::new()),
242                semaphores: Mutex::new(HashMap::new()),
243                client_info,
244                client_caps,
245            }),
246        }
247    }
248
249    /// Acquire a connection from the pool.
250    ///
251    /// If a cached connection is available, it is returned. Otherwise,
252    /// the `connect` function is called to create a new connection.
253    ///
254    /// # Arguments
255    ///
256    /// * `key` - A unique key identifying the server
257    /// * `connect` - A function that creates a new transport connection
258    ///
259    /// # Errors
260    ///
261    /// Returns an error if the connection cannot be acquired.
262    pub async fn acquire<F, Fut>(
263        &self,
264        key: impl Into<String>,
265        connect: F,
266    ) -> Result<PooledClient<T>, McpError>
267    where
268        F: FnOnce() -> Fut,
269        Fut: Future<Output = Result<T, McpError>>,
270    {
271        let key = key.into();
272        debug!(%key, "Acquiring connection from pool");
273
274        // Get the semaphore for rate limiting
275        let semaphore = self.inner.get_semaphore(&key).await;
276
277        // Acquire a permit (with timeout)
278        let _permit =
279            tokio::time::timeout(self.inner.config.acquire_timeout, semaphore.acquire_owned())
280                .await
281                .map_err(|_| McpError::Internal {
282                    message: format!("Timeout acquiring connection for {key}"),
283                    source: None,
284                })?
285                .map_err(|_| McpError::Internal {
286                    message: "Pool semaphore closed".to_string(),
287                    source: None,
288                })?;
289
290        // Try to get an existing connection
291        {
292            let mut connections = self.inner.connections.lock().await;
293            if let Some(entries) = connections.get_mut(&key) {
294                // Remove stale connections
295                let max_idle = self.inner.config.max_idle_time;
296                entries.retain(|e| e.last_used.elapsed() < max_idle);
297
298                // Get a connection if available
299                if let Some(entry) = entries.pop() {
300                    trace!(%key, "Reusing existing connection");
301
302                    // Optionally validate the connection
303                    if self.inner.config.validate_on_acquire {
304                        // Try to ping
305                        if entry.client.ping().await.is_ok() {
306                            return Ok(PooledClient {
307                                client: Some(entry.client),
308                                pool: Arc::clone(&self.inner),
309                                key,
310                            });
311                        }
312                        warn!(%key, "Cached connection failed validation");
313                    } else {
314                        return Ok(PooledClient {
315                            client: Some(entry.client),
316                            pool: Arc::clone(&self.inner),
317                            key,
318                        });
319                    }
320                }
321            }
322        }
323
324        // Create a new connection
325        debug!(%key, "Creating new connection");
326        let transport = connect().await?;
327
328        let client = ClientBuilder::new()
329            .name(self.inner.client_info.name.clone())
330            .version(self.inner.client_info.version.clone())
331            .capabilities(self.inner.client_caps.clone())
332            .build(transport)
333            .await?;
334
335        Ok(PooledClient {
336            client: Some(client),
337            pool: Arc::clone(&self.inner),
338            key,
339        })
340    }
341
342    /// Clear all cached connections.
343    pub async fn clear(&self) {
344        let mut connections = self.inner.connections.lock().await;
345        connections.clear();
346        debug!("Cleared all pooled connections");
347    }
348
349    /// Clear cached connections for a specific server.
350    pub async fn clear_server(&self, key: &str) {
351        let mut connections = self.inner.connections.lock().await;
352        connections.remove(key);
353        debug!(%key, "Cleared pooled connections for server");
354    }
355
356    /// Get statistics about the pool.
357    pub async fn stats(&self) -> PoolStats {
358        let connections = self.inner.connections.lock().await;
359        let mut total = 0;
360        let mut per_server = HashMap::new();
361
362        for (key, entries) in connections.iter() {
363            let count = entries.len();
364            total += count;
365            per_server.insert(key.clone(), count);
366        }
367
368        PoolStats {
369            total_connections: total,
370            connections_per_server: per_server,
371            max_connections: self.inner.config.max_connections,
372        }
373    }
374}
375
376impl<T: Transport + 'static> Clone for ClientPool<T> {
377    fn clone(&self) -> Self {
378        Self {
379            inner: Arc::clone(&self.inner),
380        }
381    }
382}
383
/// A point-in-time snapshot of what a connection pool is caching.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Cached connections summed over every server.
    pub total_connections: usize,
    /// Cached connection count, keyed by server.
    pub connections_per_server: HashMap<String, usize>,
    /// The configured per-server connection limit.
    pub max_connections: usize,
}
394
395/// Builder for creating a client pool.
396pub struct ClientPoolBuilder {
397    config: PoolConfig,
398    client_info: Option<ClientInfo>,
399    client_caps: ClientCapabilities,
400}
401
402impl ClientPoolBuilder {
403    /// Create a new pool builder.
404    #[must_use]
405    pub fn new() -> Self {
406        Self {
407            config: PoolConfig::default(),
408            client_info: None,
409            client_caps: ClientCapabilities::default(),
410        }
411    }
412
413    /// Set the client info.
414    pub fn client_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
415        self.client_info = Some(ClientInfo {
416            name: name.into(),
417            version: version.into(),
418        });
419        self
420    }
421
422    /// Set the client capabilities.
423    #[must_use]
424    pub fn capabilities(mut self, caps: ClientCapabilities) -> Self {
425        self.client_caps = caps;
426        self
427    }
428
429    /// Set the maximum number of connections per server.
430    #[must_use]
431    pub const fn max_connections(mut self, max: usize) -> Self {
432        self.config.max_connections = max;
433        self
434    }
435
436    /// Set the acquire timeout.
437    #[must_use]
438    pub const fn acquire_timeout(mut self, timeout: std::time::Duration) -> Self {
439        self.config.acquire_timeout = timeout;
440        self
441    }
442
443    /// Set whether to validate connections on acquire.
444    #[must_use]
445    pub const fn validate_on_acquire(mut self, validate: bool) -> Self {
446        self.config.validate_on_acquire = validate;
447        self
448    }
449
450    /// Set the maximum idle time.
451    #[must_use]
452    pub const fn max_idle_time(mut self, time: std::time::Duration) -> Self {
453        self.config.max_idle_time = time;
454        self
455    }
456
457    /// Build the pool.
458    ///
459    /// # Panics
460    ///
461    /// Panics if `client_info` was not set.
462    #[must_use]
463    pub fn build<T: Transport + 'static>(self) -> ClientPool<T> {
464        let client_info = self
465            .client_info
466            .expect("client_info must be set before building pool");
467
468        ClientPool::with_config(client_info, self.client_caps, self.config)
469    }
470}
471
472impl Default for ClientPoolBuilder {
473    fn default() -> Self {
474        Self::new()
475    }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Each `PoolConfig` setter stores exactly the value it was given.
    #[test]
    fn test_pool_config() {
        let acquire = Duration::from_secs(10);
        let idle = Duration::from_secs(60);

        let config = PoolConfig::new()
            .max_connections(5)
            .acquire_timeout(acquire)
            .validate_on_acquire(false)
            .max_idle_time(idle);

        assert_eq!(5, config.max_connections);
        assert_eq!(10, config.acquire_timeout.as_secs());
        assert!(!config.validate_on_acquire);
        assert_eq!(60, config.max_idle_time.as_secs());
    }

    /// Builder setters are forwarded into the embedded `PoolConfig`.
    #[test]
    fn test_pool_builder() {
        let builder = ClientPoolBuilder::new()
            .client_info("test-client", "1.0.0")
            .max_connections(10)
            .validate_on_acquire(true);

        let PoolConfig {
            max_connections,
            validate_on_acquire,
            ..
        } = builder.config;

        assert_eq!(max_connections, 10);
        assert!(validate_on_acquire);
    }
}