Skip to main content

connection_pool/
connection_pool.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Pool capacity used when the caller does not supply one.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Idle lifetime used when the caller does not supply one (5 minutes).
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
/// Connection-creation timeout used when the caller does not supply one (10 seconds).
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Period of the background cleanup task used by [`CleanupConfig::default`] (30 seconds).
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
13
/// Configuration for the background cleanup task.
///
/// Derives `Debug` so the configuration can be logged or printed in
/// diagnostics, and `Clone` so it can be handed to the task by value.
#[derive(Debug, Clone)]
pub struct CleanupConfig {
    /// Time between two consecutive cleanup passes.
    pub interval: Duration,
    /// Whether the background cleanup task should run at all.
    pub enabled: bool,
}
20
21impl Default for CleanupConfig {
22    fn default() -> Self {
23        Self {
24            interval: DEFAULT_CLEANUP_INTERVAL,
25            enabled: true,
26        }
27    }
28}
29
30/// Background cleanup task controller
31struct CleanupTaskController {
32    handle: Option<JoinHandle<()>>,
33    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
34}
35
36impl CleanupTaskController {
37    fn new() -> Self {
38        Self {
39            handle: None,
40            shutdown_tx: None,
41        }
42    }
43
44    fn start<T, M>(
45        &mut self,
46        connections: Arc<Mutex<VecDeque<InnerConnection<T>>>>,
47        max_idle_time: Duration,
48        cleanup_interval: Duration,
49        manager: Arc<M>,
50        max_size: usize,
51    ) where
52        T: Send + 'static,
53        M: ConnectionManager<Connection = T> + Send + Sync + 'static,
54    {
55        if self.handle.is_some() {
56            log::warn!("Cleanup task is already running");
57            return;
58        }
59        let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
60        self.shutdown_tx = Some(shutdown_tx);
61        let handle = tokio::spawn(async move {
62            let mut interval_timer = interval(cleanup_interval);
63            log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
64
65            loop {
66                tokio::select! {
67                    _ = interval_timer.tick() => {}
68                    _ = &mut shutdown_rx => {
69                        log::info!("Received shutdown signal, exiting cleanup loop");
70                        break;
71                    }
72                };
73
74                let mut connections = connections.lock().await;
75                let initial_count = connections.len();
76                let now = Instant::now();
77
78                // check both idle time and is_valid
79                let mut valid_connections = VecDeque::new();
80                for mut conn in connections.drain(..) {
81                    let not_expired = now.duration_since(conn.created_at) < max_idle_time;
82                    let is_valid = if not_expired {
83                        manager.is_valid(&mut conn.connection).await
84                    } else {
85                        false
86                    };
87                    if not_expired && is_valid {
88                        valid_connections.push_back(conn);
89                    }
90                }
91                let removed_count = initial_count - valid_connections.len();
92                *connections = valid_connections;
93
94                if removed_count > 0 {
95                    log::debug!("Background cleanup removed {removed_count} expired/invalid connections");
96                }
97
98                log::debug!("Current pool (remaining {}/{max_size}) after cleanup", connections.len());
99            }
100        });
101
102        self.handle = Some(handle);
103    }
104
105    async fn stop(&mut self) {
106        if let Some(shutdown_tx) = self.shutdown_tx.take() {
107            let _ = shutdown_tx.send(());
108        }
109        if let Some(handle) = self.handle.take() {
110            if let Err(e) = handle.await {
111                log::error!("Error while stopping background cleanup task: {e}");
112            }
113            log::info!("Background cleanup task stopped in async stop");
114        }
115    }
116
117    fn stop_sync(&mut self) {
118        if let Some(shutdown_tx) = self.shutdown_tx.take() {
119            let _ = shutdown_tx.send(());
120        }
121        if let Some(handle) = self.handle.take() {
122            std::thread::sleep(std::time::Duration::from_millis(100));
123            handle.abort();
124            log::info!("Background cleanup task stopped in stop_sync");
125        }
126    }
127}
128
129impl Drop for CleanupTaskController {
130    fn drop(&mut self) {
131        self.stop_sync();
132    }
133}
134
/// Factory and health checker for the [`ConnectionManager::Connection`]s held by a pool.
pub trait ConnectionManager: Send + Sync + Clone {
    /// The connection type this manager creates and validates.
    type Connection: Send;

    /// Error produced when creating a [`ConnectionManager::Connection`] fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`ConnectionManager::create_connection`]; resolves to
    /// the new connection or to the creation error.
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by [`ConnectionManager::is_valid`]; resolves to `true`
    /// when the connection is valid and to `false` otherwise.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Starts creating a new connection.
    fn create_connection(&self) -> Self::CreateFut;

    /// Starts checking whether `connection` is valid.
    fn is_valid<'a>(&'a self, connection: &'a mut Self::Connection) -> Self::ValidFut<'a>;
}
157
158/// Connection pool
159pub struct ConnectionPool<M: ConnectionManager> {
160    connections: Arc<Mutex<VecDeque<InnerConnection<M::Connection>>>>,
161    semaphore: Arc<Semaphore>,
162    max_size: usize,
163    manager: M,
164    max_idle_time: Duration,
165    connection_timeout: Duration,
166    cleanup_controller: Arc<Mutex<CleanupTaskController>>,
167    outstanding_count: Arc<std::sync::atomic::AtomicUsize>,
168}
169
/// An idle connection as stored inside the pool's queue.
struct InnerConnection<T> {
    /// The pooled connection itself.
    pub connection: T,
    /// Timestamp the idle age is measured from.
    pub created_at: Instant,
}
175
176/// Pooled managed stream, provided for the outer world usage
177pub struct ManagedConnection<M>
178where
179    M: ConnectionManager + Send + Sync + 'static,
180{
181    connection: Option<M::Connection>,
182    pool: Arc<ConnectionPool<M>>,
183    _permit: tokio::sync::OwnedSemaphorePermit,
184}
185
186impl<M: ConnectionManager> ManagedConnection<M> {
187    /// Consume the managed connection and return the inner connection
188    pub fn into_inner(mut self) -> M::Connection {
189        self.connection.take().unwrap()
190    }
191
192    fn new(connection: M::Connection, pool: Arc<ConnectionPool<M>>, permit: tokio::sync::OwnedSemaphorePermit) -> Self {
193        pool.outstanding_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
194        ManagedConnection {
195            connection: Some(connection),
196            pool,
197            _permit: permit,
198        }
199    }
200}
201
202impl<M> ConnectionPool<M>
203where
204    M: ConnectionManager + Send + Sync + Clone + 'static,
205{
206    /// Create a new connection pool
207    pub fn new(
208        max_size: Option<usize>,
209        max_idle_time: Option<Duration>,
210        connection_timeout: Option<Duration>,
211        cleanup_config: Option<CleanupConfig>,
212        manager: M,
213    ) -> Arc<Self> {
214        let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
215        let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
216        let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
217        let cleanup_config = cleanup_config.unwrap_or_default();
218
219        log::info!(
220            "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
221            cleanup_config.enabled
222        );
223
224        let connections = Arc::new(Mutex::new(VecDeque::new()));
225        let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
226
227        let pool = Arc::new(ConnectionPool {
228            connections: connections.clone(),
229            semaphore: Arc::new(Semaphore::new(max_size)),
230            max_size,
231            manager,
232            max_idle_time,
233            connection_timeout,
234            cleanup_controller: cleanup_controller.clone(),
235            outstanding_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
236        });
237
238        // Start background cleanup task if enabled
239        if cleanup_config.enabled {
240            let manager = Arc::new(pool.manager.clone());
241            tokio::spawn(async move {
242                let mut controller = cleanup_controller.lock().await;
243                controller.start(connections, max_idle_time, cleanup_config.interval, manager, max_size);
244            });
245        }
246
247        pool
248    }
249
250    /// Get a connection from the pool
251    pub async fn get_connection(self: Arc<Self>) -> Result<ManagedConnection<M>, PoolError<M::Error>> {
252        log::debug!("Attempting to get connection from pool");
253
254        // Use semaphore to limit concurrent connections
255        let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
256
257        // Try to get an existing connection from the pool
258        {
259            let mut connections = self.connections.lock().await;
260            loop {
261                let Some(mut pooled_conn) = connections.pop_front() else {
262                    // No available connection, break the loop
263                    break;
264                };
265                log::trace!("Found existing connection in pool, validating...");
266                let age = Instant::now().duration_since(pooled_conn.created_at);
267                let is_valid = if age < self.max_idle_time {
268                    let r = self.manager.is_valid(&mut pooled_conn.connection).await;
269                    if !r {
270                        log::warn!("Connection validation failed, discarding invalid connection");
271                    }
272                    r
273                } else {
274                    log::debug!("Connection expired (age: {age:?}), discarding");
275                    false
276                };
277                if is_valid {
278                    let size = connections.len();
279                    log::debug!("Reusing existing connection from pool (remaining: {size}/{})", self.max_size);
280                    return Ok(ManagedConnection::new(pooled_conn.connection, self.clone(), permit));
281                }
282            }
283        }
284
285        log::trace!("No valid connection available, creating new connection...");
286        // Create new connection
287        match timeout(self.connection_timeout, self.manager.create_connection()).await {
288            Ok(Ok(connection)) => {
289                log::info!("Successfully created new connection");
290                Ok(ManagedConnection::new(connection, self.clone(), permit))
291            }
292            Ok(Err(e)) => {
293                log::error!("Failed to create new connection");
294                Err(PoolError::Creation(e))
295            }
296            Err(_) => {
297                log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
298                Err(PoolError::Timeout)
299            }
300        }
301    }
302
303    pub fn outstanding_count(&self) -> usize {
304        self.outstanding_count.load(std::sync::atomic::Ordering::SeqCst)
305    }
306
307    pub async fn pool_size(&self) -> usize {
308        self.connections.lock().await.len()
309    }
310
311    pub fn max_size(&self) -> usize {
312        self.max_size
313    }
314
315    /// Stop the background cleanup task
316    pub async fn stop_cleanup_task(&self) {
317        let mut controller = self.cleanup_controller.lock().await;
318        controller.stop().await;
319    }
320
321    /// Restart the background cleanup task with new configuration
322    pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
323        self.stop_cleanup_task().await;
324
325        if cleanup_config.enabled {
326            let manager = Arc::new(self.manager.clone());
327            let mut controller = self.cleanup_controller.lock().await;
328            let m = self.max_size;
329            controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval, manager, m);
330        }
331    }
332}
333
334impl<M: ConnectionManager> ConnectionPool<M> {
335    async fn recycle(&self, mut connection: M::Connection) {
336        if !self.manager.is_valid(&mut connection).await {
337            log::debug!("Invalid connection, dropping");
338            return;
339        }
340        let mut connections = self.connections.lock().await;
341        if connections.len() < self.max_size {
342            connections.push_back(InnerConnection {
343                connection,
344                created_at: Instant::now(),
345            });
346            log::debug!("Connection recycled to pool (pool size: {}/{})", connections.len(), self.max_size);
347        } else {
348            log::debug!("Pool is full, dropping connection (pool max size: {})", self.max_size);
349        }
350        // If the pool is full, the connection will be dropped (automatically closed)
351    }
352}
353
354impl<M: ConnectionManager> Drop for ManagedConnection<M> {
355    fn drop(&mut self) {
356        self.pool.outstanding_count.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
357        if let Some(connection) = self.connection.take() {
358            let pool = self.pool.clone();
359            _ = tokio::spawn(async move {
360                log::trace!("Recycling connection to pool on drop");
361                pool.recycle(connection).await;
362            });
363        }
364    }
365}
366
367// Generic implementations for AsRef and AsMut
368impl<M: ConnectionManager> AsRef<M::Connection> for ManagedConnection<M> {
369    fn as_ref(&self) -> &M::Connection {
370        self.connection.as_ref().unwrap()
371    }
372}
373
374impl<M: ConnectionManager> AsMut<M::Connection> for ManagedConnection<M> {
375    fn as_mut(&mut self) -> &mut M::Connection {
376        self.connection.as_mut().unwrap()
377    }
378}
379
380// Implement Deref and DerefMut for PooledStream
381impl<M: ConnectionManager> std::ops::Deref for ManagedConnection<M> {
382    type Target = M::Connection;
383
384    fn deref(&self) -> &Self::Target {
385        self.connection.as_ref().unwrap()
386    }
387}
388
389impl<M: ConnectionManager> std::ops::DerefMut for ManagedConnection<M> {
390    fn deref_mut(&mut self) -> &mut Self::Target {
391        self.connection.as_mut().unwrap()
392    }
393}
394
/// Errors reported by the connection pool.
#[derive(Debug)]
pub enum PoolError<E> {
    /// The pool no longer hands out connections.
    PoolClosed,
    /// Creating a new connection took longer than the configured timeout.
    Timeout,
    /// The connection manager failed to create a connection.
    Creation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Creation(cause) => write!(f, "Connection creation failed: {cause}"),
            Self::Timeout => f.write_str("Connection creation timeout"),
            Self::PoolClosed => f.write_str("Connection pool is closed"),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    /// Only [`PoolError::Creation`] carries an underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Creation(cause) = self {
            Some(cause)
        } else {
            None
        }
    }
}