connection_pool/
connection_pool.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Default number of permits on the pool's semaphore, which is also the cap on idle pooled connections.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Default age (time since a connection was put back into the pool) after which it is discarded: 5 minutes.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
/// Default upper bound on how long creating a single new connection may take: 10 seconds.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period between two background cleanup passes: 30 seconds.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
13
/// Configuration for the pool's background cleanup task.
///
/// When `enabled` is true the pool spawns a task that wakes up every `interval`
/// and evicts idle connections that are too old or that fail validation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CleanupConfig {
    /// Time between two consecutive cleanup passes.
    pub interval: Duration,
    /// Whether the background cleanup task is started at all.
    pub enabled: bool,
}
20
21impl Default for CleanupConfig {
22    fn default() -> Self {
23        Self {
24            interval: DEFAULT_CLEANUP_INTERVAL,
25            enabled: true,
26        }
27    }
28}
29
30/// Background cleanup task controller
31struct CleanupTaskController {
32    handle: Option<JoinHandle<()>>,
33    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
34}
35
36impl CleanupTaskController {
37    fn new() -> Self {
38        Self {
39            handle: None,
40            shutdown_tx: None,
41        }
42    }
43
44    fn start<T, M>(
45        &mut self,
46        connections: Arc<Mutex<VecDeque<InnerConnection<T>>>>,
47        max_idle_time: Duration,
48        cleanup_interval: Duration,
49        manager: Arc<M>,
50        max_size: usize,
51    ) where
52        T: Send + 'static,
53        M: ConnectionManager<Connection = T> + Send + Sync + 'static,
54    {
55        if self.handle.is_some() {
56            log::warn!("Cleanup task is already running");
57            return;
58        }
59        let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
60        self.shutdown_tx = Some(shutdown_tx);
61        let handle = tokio::spawn(async move {
62            let mut interval_timer = interval(cleanup_interval);
63            log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
64
65            loop {
66                tokio::select! {
67                    _ = interval_timer.tick() => {}
68                    _ = &mut shutdown_rx => {
69                        log::info!("Received shutdown signal, exiting cleanup loop");
70                        break;
71                    }
72                };
73
74                let mut connections = connections.lock().await;
75                let initial_count = connections.len();
76                let now = Instant::now();
77
78                // check both idle time and is_valid
79                let mut valid_connections = VecDeque::new();
80                for mut conn in connections.drain(..) {
81                    let not_expired = now.duration_since(conn.created_at) < max_idle_time;
82                    let is_valid = if not_expired {
83                        manager.is_valid(&mut conn.connection).await
84                    } else {
85                        false
86                    };
87                    if not_expired && is_valid {
88                        valid_connections.push_back(conn);
89                    }
90                }
91                let removed_count = initial_count - valid_connections.len();
92                *connections = valid_connections;
93
94                if removed_count > 0 {
95                    log::debug!("Background cleanup removed {removed_count} expired/invalid connections");
96                }
97
98                log::debug!("Current pool (remaining {}/{max_size}) after cleanup", connections.len());
99            }
100        });
101
102        self.handle = Some(handle);
103    }
104
105    async fn stop(&mut self) {
106        if let Some(shutdown_tx) = self.shutdown_tx.take() {
107            let _ = shutdown_tx.send(());
108        }
109        if let Some(handle) = self.handle.take() {
110            if let Err(e) = handle.await {
111                log::error!("Error while stopping background cleanup task: {e}");
112            }
113            log::info!("Background cleanup task stopped in async stop");
114        }
115    }
116
117    fn stop_sync(&mut self) {
118        if let Some(shutdown_tx) = self.shutdown_tx.take() {
119            let _ = shutdown_tx.send(());
120        }
121        if let Some(handle) = self.handle.take() {
122            std::thread::sleep(std::time::Duration::from_millis(100));
123            handle.abort();
124            log::info!("Background cleanup task stopped in stop_sync");
125        }
126    }
127}
128
129impl Drop for CleanupTaskController {
130    fn drop(&mut self) {
131        self.stop_sync();
132    }
133}
134
/// Manager responsible for creating new [`ConnectionManager::Connection`]s or checking existing ones.
///
/// The pool stores one manager and also hands a clone of it to its background
/// cleanup task.
pub trait ConnectionManager: Send + Sync + Clone {
    /// Connection type handed out by the pool.
    type Connection: Send;

    /// Error reported when [`ConnectionManager::create_connection`] fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`ConnectionManager::create_connection`].
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by [`ConnectionManager::is_valid`]; resolves to `true` when the connection may be reused.
    type ValidFut<'a>: Future<Output = bool> + Send where Self: 'a;

    /// Open a brand-new connection. The pool bounds this call with its connection timeout.
    fn create_connection(&self) -> Self::CreateFut;

    /// Report whether `connection` is still usable.
    ///
    /// The pool calls this before reusing an idle connection, when a connection
    /// is handed back, and during background cleanup.
    fn is_valid<'a>(&'a self, connection: &'a mut Self::Connection) -> Self::ValidFut<'a>;
}
157
158/// Connection pool
159pub struct ConnectionPool<M>
160where
161    M: ConnectionManager + Send + Sync + Clone + 'static,
162{
163    connections: Arc<Mutex<VecDeque<InnerConnection<M::Connection>>>>,
164    semaphore: Arc<Semaphore>,
165    max_size: usize,
166    manager: M,
167    max_idle_time: Duration,
168    connection_timeout: Duration,
169    cleanup_controller: Arc<Mutex<CleanupTaskController>>,
170    outstanding_count: Arc<std::sync::atomic::AtomicUsize>,
171}
172
/// An idle connection parked in the pool, together with the moment it was parked.
struct InnerConnection<T> {
    /// The connection itself.
    pub connection: T,
    /// Timestamp used for idle-expiry checks; set when the connection is put into the pool.
    pub created_at: Instant,
}
178
179/// Pooled managed stream, provided for the outer world usage
180pub struct ManagedConnection<M>
181where
182    M: ConnectionManager + Send + Sync + 'static,
183{
184    connection: Option<M::Connection>,
185    pool: Arc<ConnectionPool<M>>,
186    _permit: tokio::sync::OwnedSemaphorePermit,
187}
188
189impl<M> ManagedConnection<M>
190where
191    M: ConnectionManager,
192{
193    /// Consume the managed connection and return the inner connection
194    pub fn into_inner(mut self) -> M::Connection {
195        self.connection.take().unwrap()
196    }
197
198    fn new(connection: M::Connection, pool: Arc<ConnectionPool<M>>, permit: tokio::sync::OwnedSemaphorePermit) -> Self {
199        pool.outstanding_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
200        ManagedConnection {
201            connection: Some(connection),
202            pool,
203            _permit: permit,
204        }
205    }
206}
207
208impl<M> ConnectionPool<M>
209where
210    M: ConnectionManager + Send + Sync + Clone + 'static,
211{
212    /// Create a new connection pool
213    pub fn new(
214        max_size: Option<usize>,
215        max_idle_time: Option<Duration>,
216        connection_timeout: Option<Duration>,
217        cleanup_config: Option<CleanupConfig>,
218        manager: M,
219    ) -> Arc<Self> {
220        let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
221        let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
222        let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
223        let cleanup_config = cleanup_config.unwrap_or_default();
224
225        log::info!(
226            "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
227            cleanup_config.enabled
228        );
229
230        let connections = Arc::new(Mutex::new(VecDeque::new()));
231        let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
232
233        let pool = Arc::new(ConnectionPool {
234            connections: connections.clone(),
235            semaphore: Arc::new(Semaphore::new(max_size)),
236            max_size,
237            manager,
238            max_idle_time,
239            connection_timeout,
240            cleanup_controller: cleanup_controller.clone(),
241            outstanding_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
242        });
243
244        // Start background cleanup task if enabled
245        if cleanup_config.enabled {
246            let manager = Arc::new(pool.manager.clone());
247            tokio::spawn(async move {
248                let mut controller = cleanup_controller.lock().await;
249                controller.start(connections, max_idle_time, cleanup_config.interval, manager, max_size);
250            });
251        }
252
253        pool
254    }
255
256    /// Get a connection from the pool
257    pub async fn get_connection(self: Arc<Self>) -> Result<ManagedConnection<M>, PoolError<M::Error>> {
258        log::debug!("Attempting to get connection from pool");
259
260        // Use semaphore to limit concurrent connections
261        let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
262
263        // Try to get an existing connection from the pool
264        {
265            let mut connections = self.connections.lock().await;
266            loop {
267                let Some(mut pooled_conn) = connections.pop_front() else {
268                    // No available connection, break the loop
269                    break;
270                };
271                log::trace!("Found existing connection in pool, validating...");
272                let age = Instant::now().duration_since(pooled_conn.created_at);
273                let valid = self.manager.is_valid(&mut pooled_conn.connection).await;
274                if age >= self.max_idle_time {
275                    log::debug!("Connection expired (age: {age:?}), discarding");
276                } else if !valid {
277                    log::warn!("Connection validation failed, discarding invalid connection");
278                } else {
279                    let size = connections.len();
280                    log::debug!("Reusing existing connection from pool (remaining: {size}/{})", self.max_size);
281                    return Ok(ManagedConnection::new(pooled_conn.connection, self.clone(), permit));
282                }
283            }
284        }
285
286        log::trace!("No valid connection available, creating new connection...");
287        // Create new connection
288        match timeout(self.connection_timeout, self.manager.create_connection()).await {
289            Ok(Ok(connection)) => {
290                log::info!("Successfully created new connection");
291                Ok(ManagedConnection::new(connection, self.clone(), permit))
292            }
293            Ok(Err(e)) => {
294                log::error!("Failed to create new connection");
295                Err(PoolError::Creation(e))
296            }
297            Err(_) => {
298                log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
299                Err(PoolError::Timeout)
300            }
301        }
302    }
303
304    pub fn outstanding_count(&self) -> usize {
305        self.outstanding_count.load(std::sync::atomic::Ordering::SeqCst)
306    }
307
308    pub async fn pool_size(&self) -> usize {
309        self.connections.lock().await.len()
310    }
311
312    pub fn max_size(&self) -> usize {
313        self.max_size
314    }
315
316    /// Stop the background cleanup task
317    pub async fn stop_cleanup_task(&self) {
318        let mut controller = self.cleanup_controller.lock().await;
319        controller.stop().await;
320    }
321
322    /// Restart the background cleanup task with new configuration
323    pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
324        self.stop_cleanup_task().await;
325
326        if cleanup_config.enabled {
327            let manager = Arc::new(self.manager.clone());
328            let mut controller = self.cleanup_controller.lock().await;
329            let m = self.max_size;
330            controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval, manager, m);
331        }
332    }
333}
334
335impl<M> ConnectionPool<M>
336where
337    M: ConnectionManager,
338{
339    async fn recycle(&self, mut connection: M::Connection) {
340        if !self.manager.is_valid(&mut connection).await {
341            log::debug!("Invalid connection, dropping");
342            return;
343        }
344        let mut connections = self.connections.lock().await;
345        if connections.len() < self.max_size {
346            connections.push_back(InnerConnection {
347                connection,
348                created_at: Instant::now(),
349            });
350            log::debug!("Connection recycled to pool (pool size: {}/{})", connections.len(), self.max_size);
351        } else {
352            log::debug!("Pool is full, dropping connection (pool max size: {})", self.max_size);
353        }
354        // If the pool is full, the connection will be dropped (automatically closed)
355    }
356}
357
358impl<M> Drop for ManagedConnection<M>
359where
360    M: ConnectionManager + Send + Sync + Clone + 'static,
361{
362    fn drop(&mut self) {
363        self.pool.outstanding_count.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
364        if let Some(connection) = self.connection.take() {
365            let pool = self.pool.clone();
366            _ = tokio::spawn(async move {
367                log::trace!("Recycling connection to pool on drop");
368                pool.recycle(connection).await;
369            });
370        }
371    }
372}
373
374// Generic implementations for AsRef and AsMut
375impl<M> AsRef<M::Connection> for ManagedConnection<M>
376where
377    M: ConnectionManager + Send + Sync + Clone + 'static,
378{
379    fn as_ref(&self) -> &M::Connection {
380        self.connection.as_ref().unwrap()
381    }
382}
383
384impl<M> AsMut<M::Connection> for ManagedConnection<M>
385where
386    M: ConnectionManager + Send + Sync + Clone + 'static,
387{
388    fn as_mut(&mut self) -> &mut M::Connection {
389        self.connection.as_mut().unwrap()
390    }
391}
392
393// Implement Deref and DerefMut for PooledStream
394impl<M> std::ops::Deref for ManagedConnection<M>
395where
396    M: ConnectionManager + Send + Sync + Clone + 'static,
397{
398    type Target = M::Connection;
399
400    fn deref(&self) -> &Self::Target {
401        self.connection.as_ref().unwrap()
402    }
403}
404
405impl<M> std::ops::DerefMut for ManagedConnection<M>
406where
407    M: ConnectionManager + Send + Sync + Clone + 'static,
408{
409    fn deref_mut(&mut self) -> &mut Self::Target {
410        self.connection.as_mut().unwrap()
411    }
412}
413
/// Errors returned when checking a connection out of the pool.
#[derive(Debug)]
pub enum PoolError<E> {
    /// No semaphore permit could be acquired because the semaphore was closed.
    PoolClosed,
    /// Creating a new connection did not finish within the connection timeout.
    Timeout,
    /// The connection manager failed to create a connection; carries its error.
    Creation(E),
}
421
422impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
423    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424        match self {
425            PoolError::PoolClosed => write!(f, "Connection pool is closed"),
426            PoolError::Timeout => write!(f, "Connection creation timeout"),
427            PoolError::Creation(e) => write!(f, "Connection creation failed: {e}"),
428        }
429    }
430}
431
432impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
433    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
434        match self {
435            PoolError::Creation(e) => Some(e),
436            _ => None,
437        }
438    }
439}