// connection_pool/connection_pool.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Pool size used by `ConnectionPool::new` when `max_size` is `None`.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Idle age (five minutes) after which a pooled connection is discarded, when none is given.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
/// Time limit (ten seconds) for creating a new connection, when none is given.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Period (thirty seconds) between background cleanup sweeps in the default configuration.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
13
/// Configuration for the pool's background cleanup task.
///
/// The task periodically discards idle connections that have expired or that fail
/// validation. See [`DEFAULT_CLEANUP_INTERVAL`] for the default period.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CleanupConfig {
    /// Time between two cleanup sweeps.
    pub interval: Duration,
    /// Whether the cleanup task is started at all.
    pub enabled: bool,
}
20
21impl Default for CleanupConfig {
22    fn default() -> Self {
23        Self {
24            interval: DEFAULT_CLEANUP_INTERVAL,
25            enabled: true,
26        }
27    }
28}
29
/// Owns the join handle of the background cleanup task, if one is running.
///
/// `handle` is `Some` between `start` and the next `stop`; dropping the controller calls
/// `stop`, which aborts the task.
struct CleanupTaskController {
    handle: Option<JoinHandle<()>>,
}
34
35impl CleanupTaskController {
36    fn new() -> Self {
37        Self { handle: None }
38    }
39
40    fn start<T, M>(
41        &mut self,
42        connections: Arc<Mutex<VecDeque<InnerConnection<T>>>>,
43        max_idle_time: Duration,
44        cleanup_interval: Duration,
45        manager: Arc<M>,
46    ) where
47        T: Send + 'static,
48        M: ConnectionManager<Connection = T> + Send + Sync + 'static,
49    {
50        if self.handle.is_some() {
51            log::warn!("Cleanup task is already running");
52            return;
53        }
54
55        let handle = tokio::spawn(async move {
56            let mut interval_timer = interval(cleanup_interval);
57            log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
58
59            loop {
60                interval_timer.tick().await;
61
62                let mut connections = connections.lock().await;
63                let initial_count = connections.len();
64                let now = Instant::now();
65
66                // check both idle time and is_valid
67                let mut valid_connections = VecDeque::new();
68                for mut conn in connections.drain(..) {
69                    let not_expired = now.duration_since(conn.created_at) < max_idle_time;
70                    let is_valid = if not_expired {
71                        manager.is_valid(&mut conn.connection).await
72                    } else {
73                        false
74                    };
75                    if not_expired && is_valid {
76                        valid_connections.push_back(conn);
77                    }
78                }
79                let removed_count = initial_count - valid_connections.len();
80                *connections = valid_connections;
81
82                if removed_count > 0 {
83                    log::debug!("Background cleanup removed {removed_count} expired/invalid connections");
84                }
85
86                log::trace!("Current pool size after cleanup: {}", connections.len());
87            }
88        });
89
90        self.handle = Some(handle);
91    }
92
93    fn stop(&mut self) {
94        if let Some(handle) = self.handle.take() {
95            handle.abort();
96            log::info!("Background cleanup task stopped");
97        }
98    }
99}
100
101impl Drop for CleanupTaskController {
102    fn drop(&mut self) {
103        self.stop();
104    }
105}
106
/// Creates connections for a `ConnectionPool` and decides whether existing ones are still
/// usable.
///
/// Implementors must be [`Clone`] and thread-safe: the pool clones the manager to hand a
/// copy to its background cleanup task.
pub trait ConnectionManager: Clone + Send + Sync {
    /// Connection type produced by [`ConnectionManager::create_connection`] and checked by
    /// [`ConnectionManager::is_valid`].
    type Connection: Send;

    /// Error reported when a connection cannot be created.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`ConnectionManager::create_connection`].
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by [`ConnectionManager::is_valid`]; resolves to `true` when the
    /// connection may be reused. It may borrow both the manager and the connection.
    type ValidFut<'a>: Future<Output = bool> + Send where Self: 'a;

    /// Opens a brand-new connection.
    fn create_connection(&self) -> Self::CreateFut;

    /// Reports whether `connection` can still be used.
    ///
    /// The connection is passed mutably, so the check may interact with it.
    fn is_valid<'a>(&'a self, connection: &'a mut Self::Connection) -> Self::ValidFut<'a>;
}
129
/// Pool of reusable connections created by a [`ConnectionManager`].
///
/// Connections are obtained with `get_connection` and handed back for recycling when the
/// returned [`ManagedConnection`] is dropped.
pub struct ConnectionPool<M>
where
    M: ConnectionManager + Send + Sync + Clone + 'static,
{
    /// Idle connections: reused from the front, returned to the back.
    connections: Arc<Mutex<VecDeque<InnerConnection<M::Connection>>>>,
    /// Starts with `max_size` permits; every checked-out connection holds one.
    semaphore: Arc<Semaphore>,
    /// Upper bound on checked-out connections and on idle connections kept.
    max_size: usize,
    /// Creates and validates connections.
    manager: M,
    /// Idle connections whose `created_at` is at least this old are discarded, not reused.
    max_idle_time: Duration,
    /// Time limit for creating a new connection.
    connection_timeout: Duration,
    /// Controls the background cleanup task; behind a mutex so it can be stopped or
    /// restarted through `&self`.
    cleanup_controller: Arc<Mutex<CleanupTaskController>>,
}
143
/// An idle connection stored in the pool's queue.
struct InnerConnection<T> {
    /// The underlying connection.
    pub connection: T,
    /// Set when the connection is pushed into the idle queue. Despite the name, `recycle`
    /// stamps it with the current time on every return, so age checks against it measure
    /// idle time rather than total connection lifetime.
    pub created_at: Instant,
}
149
/// A connection checked out of a `ConnectionPool`, handed to callers.
///
/// Dereferences to the underlying connection. When dropped, the connection is handed to the
/// pool for recycling and the checkout permit is released.
pub struct ManagedConnection<M>
where
    M: ConnectionManager + Send + Sync + 'static,
{
    /// `Some` from checkout until `Drop` takes it to return it to the pool.
    connection: Option<M::Connection>,
    /// Pool the connection came from and goes back to.
    pool: Arc<ConnectionPool<M>>,
    /// Semaphore permit counted against the pool's `max_size`; released when this value is dropped.
    _permit: tokio::sync::OwnedSemaphorePermit,
}
159
160impl<M> ConnectionPool<M>
161where
162    M: ConnectionManager + Send + Sync + Clone + 'static,
163{
164    /// Create a new connection pool
165    pub fn new(
166        max_size: Option<usize>,
167        max_idle_time: Option<Duration>,
168        connection_timeout: Option<Duration>,
169        cleanup_config: Option<CleanupConfig>,
170        manager: M,
171    ) -> Arc<Self> {
172        let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
173        let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
174        let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
175        let cleanup_config = cleanup_config.unwrap_or_default();
176
177        log::info!(
178            "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
179            cleanup_config.enabled
180        );
181
182        let connections = Arc::new(Mutex::new(VecDeque::new()));
183        let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
184
185        let pool = Arc::new(ConnectionPool {
186            connections: connections.clone(),
187            semaphore: Arc::new(Semaphore::new(max_size)),
188            max_size,
189            manager,
190            max_idle_time,
191            connection_timeout,
192            cleanup_controller: cleanup_controller.clone(),
193        });
194
195        // Start background cleanup task if enabled
196        if cleanup_config.enabled {
197            let manager = Arc::new(pool.manager.clone());
198            tokio::spawn(async move {
199                let mut controller = cleanup_controller.lock().await;
200                controller.start(connections, max_idle_time, cleanup_config.interval, manager);
201            });
202        }
203
204        pool
205    }
206
207    /// Get a connection from the pool
208    pub async fn get_connection(self: Arc<Self>) -> Result<ManagedConnection<M>, PoolError<M::Error>> {
209        log::debug!("Attempting to get connection from pool");
210
211        // Use semaphore to limit concurrent connections
212        let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
213
214        // Try to get an existing connection from the pool
215        {
216            let mut connections = self.connections.lock().await;
217            loop {
218                let Some(mut pooled_conn) = connections.pop_front() else {
219                    // No available connection, break the loop
220                    break;
221                };
222                log::trace!("Found existing connection in pool, validating...");
223                let age = Instant::now().duration_since(pooled_conn.created_at);
224                let valid = self.manager.is_valid(&mut pooled_conn.connection).await;
225                if age >= self.max_idle_time {
226                    log::debug!("Connection expired (age: {age:?}), discarding");
227                } else if !valid {
228                    log::warn!("Connection validation failed, discarding invalid connection");
229                } else {
230                    log::debug!("Reusing existing connection from pool (remaining: {})", connections.len());
231                    return Ok(ManagedConnection {
232                        connection: Some(pooled_conn.connection),
233                        pool: self.clone(),
234                        _permit: permit,
235                    });
236                }
237            }
238        }
239
240        log::trace!("No valid connection available, creating new connection...");
241        // Create new connection
242        match timeout(self.connection_timeout, self.manager.create_connection()).await {
243            Ok(Ok(connection)) => {
244                log::info!("Successfully created new connection");
245                Ok(ManagedConnection {
246                    connection: Some(connection),
247                    pool: self.clone(),
248                    _permit: permit,
249                })
250            }
251            Ok(Err(e)) => {
252                log::error!("Failed to create new connection");
253                Err(PoolError::Creation(e))
254            }
255            Err(_) => {
256                log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
257                Err(PoolError::Timeout)
258            }
259        }
260    }
261
262    /// Stop the background cleanup task
263    pub async fn stop_cleanup_task(&self) {
264        let mut controller = self.cleanup_controller.lock().await;
265        controller.stop();
266    }
267
268    /// Restart the background cleanup task with new configuration
269    pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
270        self.stop_cleanup_task().await;
271
272        if cleanup_config.enabled {
273            let manager = Arc::new(self.manager.clone());
274            let mut controller = self.cleanup_controller.lock().await;
275            controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval, manager);
276        }
277    }
278}
279
280impl<M> ConnectionPool<M>
281where
282    M: ConnectionManager + Send + Sync + Clone + 'static,
283{
284    async fn recycle(&self, mut connection: M::Connection) {
285        if !self.manager.is_valid(&mut connection).await {
286            log::debug!("Invalid connection, dropping");
287            return;
288        }
289        let mut connections = self.connections.lock().await;
290        if connections.len() < self.max_size {
291            connections.push_back(InnerConnection {
292                connection,
293                created_at: Instant::now(),
294            });
295            log::trace!("Connection returned to pool (pool size: {})", connections.len());
296        } else {
297            log::trace!("Pool is full, dropping connection (max_size: {})", self.max_size);
298        }
299        // If the pool is full, the connection will be dropped (automatically closed)
300    }
301}
302
303impl<M> Drop for ManagedConnection<M>
304where
305    M: ConnectionManager + Send + Sync + Clone + 'static,
306{
307    fn drop(&mut self) {
308        if let Some(connection) = self.connection.take() {
309            let pool = self.pool.clone();
310            _ = tokio::spawn(async move {
311                log::trace!("Returning connection to pool on drop");
312                pool.recycle(connection).await;
313            });
314        }
315    }
316}
317
318// Generic implementations for AsRef and AsMut
319impl<M> AsRef<M::Connection> for ManagedConnection<M>
320where
321    M: ConnectionManager + Send + Sync + Clone + 'static,
322{
323    fn as_ref(&self) -> &M::Connection {
324        self.connection.as_ref().unwrap()
325    }
326}
327
328impl<M> AsMut<M::Connection> for ManagedConnection<M>
329where
330    M: ConnectionManager + Send + Sync + Clone + 'static,
331{
332    fn as_mut(&mut self) -> &mut M::Connection {
333        self.connection.as_mut().unwrap()
334    }
335}
336
// Implement Deref and DerefMut for ManagedConnection
338impl<M> std::ops::Deref for ManagedConnection<M>
339where
340    M: ConnectionManager + Send + Sync + Clone + 'static,
341{
342    type Target = M::Connection;
343
344    fn deref(&self) -> &Self::Target {
345        self.connection.as_ref().unwrap()
346    }
347}
348
349impl<M> std::ops::DerefMut for ManagedConnection<M>
350where
351    M: ConnectionManager + Send + Sync + Clone + 'static,
352{
353    fn deref_mut(&mut self) -> &mut Self::Target {
354        self.connection.as_mut().unwrap()
355    }
356}
357
/// Errors returned by `ConnectionPool::get_connection`.
#[derive(Debug)]
pub enum PoolError<E> {
    /// The pool's semaphore was closed, so no checkout permit could be acquired.
    PoolClosed,
    /// Creating a new connection took longer than the pool's connection timeout.
    Timeout,
    /// The connection manager failed to create a connection; wraps its error.
    Creation(E),
}
365
366impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
367    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368        match self {
369            PoolError::PoolClosed => write!(f, "Connection pool is closed"),
370            PoolError::Timeout => write!(f, "Connection creation timeout"),
371            PoolError::Creation(e) => write!(f, "Connection creation failed: {e}"),
372        }
373    }
374}
375
376impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
377    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
378        match self {
379            PoolError::Creation(e) => Some(e),
380            _ => None,
381        }
382    }
383}