connection_pool/
connection_pool.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Pool size used by [`ConnectionPool::new`] when `max_size` is `None`.
///
/// It bounds both the number of concurrently checked-out connections (semaphore
/// permits) and the number of idle connections kept for reuse.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Idle timeout used when `max_idle_time` is `None`: 300 s (five minutes).
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
/// Upper bound on a single `create_connection` call when `connection_timeout` is `None`.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Period of the background cleanup sweep used by [`CleanupConfig::default`].
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
13
14/// Configuration for background cleanup task
15#[derive(Clone)]
16pub struct CleanupConfig {
17    pub interval: Duration,
18    pub enabled: bool,
19}
20
21impl Default for CleanupConfig {
22    fn default() -> Self {
23        Self {
24            interval: DEFAULT_CLEANUP_INTERVAL,
25            enabled: true,
26        }
27    }
28}
29
30/// Background cleanup task controller
31struct CleanupTaskController {
32    handle: Option<JoinHandle<()>>,
33    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
34}
35
36impl CleanupTaskController {
37    fn new() -> Self {
38        Self {
39            handle: None,
40            shutdown_tx: None,
41        }
42    }
43
44    fn start<T, M>(
45        &mut self,
46        connections: Arc<Mutex<VecDeque<InnerConnection<T>>>>,
47        max_idle_time: Duration,
48        cleanup_interval: Duration,
49        manager: Arc<M>,
50    ) where
51        T: Send + 'static,
52        M: ConnectionManager<Connection = T> + Send + Sync + 'static,
53    {
54        if self.handle.is_some() {
55            log::warn!("Cleanup task is already running");
56            return;
57        }
58        let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
59        self.shutdown_tx = Some(shutdown_tx);
60        let handle = tokio::spawn(async move {
61            let mut interval_timer = interval(cleanup_interval);
62            log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
63
64            loop {
65                tokio::select! {
66                    _ = interval_timer.tick() => {}
67                    _ = &mut shutdown_rx => {
68                        log::info!("Received shutdown signal, exiting cleanup loop");
69                        break;
70                    }
71                };
72
73                let mut connections = connections.lock().await;
74                let initial_count = connections.len();
75                let now = Instant::now();
76
77                // check both idle time and is_valid
78                let mut valid_connections = VecDeque::new();
79                for mut conn in connections.drain(..) {
80                    let not_expired = now.duration_since(conn.created_at) < max_idle_time;
81                    let is_valid = if not_expired {
82                        manager.is_valid(&mut conn.connection).await
83                    } else {
84                        false
85                    };
86                    if not_expired && is_valid {
87                        valid_connections.push_back(conn);
88                    }
89                }
90                let removed_count = initial_count - valid_connections.len();
91                *connections = valid_connections;
92
93                if removed_count > 0 {
94                    log::debug!("Background cleanup removed {removed_count} expired/invalid connections");
95                }
96
97                log::debug!("Current pool (remaining {}) after cleanup", connections.len());
98            }
99        });
100
101        self.handle = Some(handle);
102    }
103
104    async fn stop(&mut self) {
105        if let Some(shutdown_tx) = self.shutdown_tx.take() {
106            let _ = shutdown_tx.send(());
107        }
108        if let Some(handle) = self.handle.take() {
109            if let Err(e) = handle.await {
110                log::error!("Error while stopping background cleanup task: {e}");
111            }
112            log::info!("Background cleanup task stopped in async stop");
113        }
114    }
115
116    fn stop_sync(&mut self) {
117        if let Some(shutdown_tx) = self.shutdown_tx.take() {
118            let _ = shutdown_tx.send(());
119        }
120        if let Some(handle) = self.handle.take() {
121            std::thread::sleep(std::time::Duration::from_millis(100));
122            handle.abort();
123            log::info!("Background cleanup task stopped in stop_sync");
124        }
125    }
126}
127
128impl Drop for CleanupTaskController {
129    fn drop(&mut self) {
130        self.stop_sync();
131    }
132}
133
/// Factory and health-checker for the connections a [`ConnectionPool`] hands out.
///
/// The pool calls [`ConnectionManager::create_connection`] when no reusable idle
/// connection exists. It calls [`ConnectionManager::is_valid`] before reusing a
/// connection, when one is returned, and during background cleanup.
pub trait ConnectionManager: Clone + Send + Sync {
    /// The connection type this manager creates and checks.
    type Connection: Send;

    /// Error reported when creating a connection fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`ConnectionManager::create_connection`].
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by [`ConnectionManager::is_valid`]; resolves to `true`
    /// when the connection may still be used.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Start creating a brand-new connection.
    fn create_connection(&self) -> Self::CreateFut;

    /// Start checking whether `connection` is still usable.
    fn is_valid<'a>(&'a self, connection: &'a mut Self::Connection) -> Self::ValidFut<'a>;
}
156
157/// Connection pool
158pub struct ConnectionPool<M>
159where
160    M: ConnectionManager + Send + Sync + Clone + 'static,
161{
162    connections: Arc<Mutex<VecDeque<InnerConnection<M::Connection>>>>,
163    semaphore: Arc<Semaphore>,
164    max_size: usize,
165    manager: M,
166    max_idle_time: Duration,
167    connection_timeout: Duration,
168    cleanup_controller: Arc<Mutex<CleanupTaskController>>,
169}
170
/// An idle connection stored inside the pool, paired with its insertion time.
///
/// `created_at` is stamped when the connection is (re)inserted into the pool,
/// so age checks measure how long the connection has been sitting idle.
struct InnerConnection<T> {
    /// The raw connection produced by the manager.
    pub connection: T,
    /// Moment the connection was placed into the idle queue.
    pub created_at: Instant,
}
176
177/// Pooled managed stream, provided for the outer world usage
178pub struct ManagedConnection<M>
179where
180    M: ConnectionManager + Send + Sync + 'static,
181{
182    connection: Option<M::Connection>,
183    pool: Arc<ConnectionPool<M>>,
184    _permit: tokio::sync::OwnedSemaphorePermit,
185}
186
187impl<M> ConnectionPool<M>
188where
189    M: ConnectionManager + Send + Sync + Clone + 'static,
190{
191    /// Create a new connection pool
192    pub fn new(
193        max_size: Option<usize>,
194        max_idle_time: Option<Duration>,
195        connection_timeout: Option<Duration>,
196        cleanup_config: Option<CleanupConfig>,
197        manager: M,
198    ) -> Arc<Self> {
199        let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
200        let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
201        let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
202        let cleanup_config = cleanup_config.unwrap_or_default();
203
204        log::info!(
205            "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
206            cleanup_config.enabled
207        );
208
209        let connections = Arc::new(Mutex::new(VecDeque::new()));
210        let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
211
212        let pool = Arc::new(ConnectionPool {
213            connections: connections.clone(),
214            semaphore: Arc::new(Semaphore::new(max_size)),
215            max_size,
216            manager,
217            max_idle_time,
218            connection_timeout,
219            cleanup_controller: cleanup_controller.clone(),
220        });
221
222        // Start background cleanup task if enabled
223        if cleanup_config.enabled {
224            let manager = Arc::new(pool.manager.clone());
225            tokio::spawn(async move {
226                let mut controller = cleanup_controller.lock().await;
227                controller.start(connections, max_idle_time, cleanup_config.interval, manager);
228            });
229        }
230
231        pool
232    }
233
234    /// Get a connection from the pool
235    pub async fn get_connection(self: Arc<Self>) -> Result<ManagedConnection<M>, PoolError<M::Error>> {
236        log::debug!("Attempting to get connection from pool");
237
238        // Use semaphore to limit concurrent connections
239        let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
240
241        // Try to get an existing connection from the pool
242        {
243            let mut connections = self.connections.lock().await;
244            loop {
245                let Some(mut pooled_conn) = connections.pop_front() else {
246                    // No available connection, break the loop
247                    break;
248                };
249                log::trace!("Found existing connection in pool, validating...");
250                let age = Instant::now().duration_since(pooled_conn.created_at);
251                let valid = self.manager.is_valid(&mut pooled_conn.connection).await;
252                if age >= self.max_idle_time {
253                    log::debug!("Connection expired (age: {age:?}), discarding");
254                } else if !valid {
255                    log::warn!("Connection validation failed, discarding invalid connection");
256                } else {
257                    let size = connections.len();
258                    log::debug!("Reusing existing connection from pool (remaining: {size}/{})", self.max_size);
259                    return Ok(ManagedConnection {
260                        connection: Some(pooled_conn.connection),
261                        pool: self.clone(),
262                        _permit: permit,
263                    });
264                }
265            }
266        }
267
268        log::trace!("No valid connection available, creating new connection...");
269        // Create new connection
270        match timeout(self.connection_timeout, self.manager.create_connection()).await {
271            Ok(Ok(connection)) => {
272                log::info!("Successfully created new connection");
273                Ok(ManagedConnection {
274                    connection: Some(connection),
275                    pool: self.clone(),
276                    _permit: permit,
277                })
278            }
279            Ok(Err(e)) => {
280                log::error!("Failed to create new connection");
281                Err(PoolError::Creation(e))
282            }
283            Err(_) => {
284                log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
285                Err(PoolError::Timeout)
286            }
287        }
288    }
289
290    /// Stop the background cleanup task
291    pub async fn stop_cleanup_task(&self) {
292        let mut controller = self.cleanup_controller.lock().await;
293        controller.stop().await;
294    }
295
296    /// Restart the background cleanup task with new configuration
297    pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
298        self.stop_cleanup_task().await;
299
300        if cleanup_config.enabled {
301            let manager = Arc::new(self.manager.clone());
302            let mut controller = self.cleanup_controller.lock().await;
303            controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval, manager);
304        }
305    }
306}
307
308impl<M> ConnectionPool<M>
309where
310    M: ConnectionManager + Send + Sync + Clone + 'static,
311{
312    async fn recycle(&self, mut connection: M::Connection) {
313        if !self.manager.is_valid(&mut connection).await {
314            log::debug!("Invalid connection, dropping");
315            return;
316        }
317        let mut connections = self.connections.lock().await;
318        if connections.len() < self.max_size {
319            connections.push_back(InnerConnection {
320                connection,
321                created_at: Instant::now(),
322            });
323            log::debug!("Connection returned to pool (pool size: {}/{})", connections.len(), self.max_size);
324        } else {
325            log::debug!("Pool is full, dropping connection (pool max size: {})", self.max_size);
326        }
327        // If the pool is full, the connection will be dropped (automatically closed)
328    }
329}
330
331impl<M> Drop for ManagedConnection<M>
332where
333    M: ConnectionManager + Send + Sync + Clone + 'static,
334{
335    fn drop(&mut self) {
336        if let Some(connection) = self.connection.take() {
337            let pool = self.pool.clone();
338            _ = tokio::spawn(async move {
339                log::trace!("Returning connection to pool on drop");
340                pool.recycle(connection).await;
341            });
342        }
343    }
344}
345
346// Generic implementations for AsRef and AsMut
347impl<M> AsRef<M::Connection> for ManagedConnection<M>
348where
349    M: ConnectionManager + Send + Sync + Clone + 'static,
350{
351    fn as_ref(&self) -> &M::Connection {
352        self.connection.as_ref().unwrap()
353    }
354}
355
356impl<M> AsMut<M::Connection> for ManagedConnection<M>
357where
358    M: ConnectionManager + Send + Sync + Clone + 'static,
359{
360    fn as_mut(&mut self) -> &mut M::Connection {
361        self.connection.as_mut().unwrap()
362    }
363}
364
365// Implement Deref and DerefMut for PooledStream
366impl<M> std::ops::Deref for ManagedConnection<M>
367where
368    M: ConnectionManager + Send + Sync + Clone + 'static,
369{
370    type Target = M::Connection;
371
372    fn deref(&self) -> &Self::Target {
373        self.connection.as_ref().unwrap()
374    }
375}
376
377impl<M> std::ops::DerefMut for ManagedConnection<M>
378where
379    M: ConnectionManager + Send + Sync + Clone + 'static,
380{
381    fn deref_mut(&mut self) -> &mut Self::Target {
382        self.connection.as_mut().unwrap()
383    }
384}
385
/// Errors returned by [`ConnectionPool::get_connection`].
#[derive(Debug)]
pub enum PoolError<E> {
    /// The pool's semaphore was closed, so no permit could be acquired.
    PoolClosed,
    /// Creating a new connection did not finish within the pool's connection timeout.
    Timeout,
    /// The connection manager failed to create a new connection.
    Creation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::PoolClosed => "Connection pool is closed",
            Self::Timeout => "Connection creation timeout",
            Self::Creation(inner) => return write!(f, "Connection creation failed: {inner}"),
        };
        f.write_str(message)
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    /// Only [`PoolError::Creation`] wraps an underlying error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Creation(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}