// connection_pool/connection_pool.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Default `max_size`: caps concurrent checkouts and the number of idle connections retained.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Default maximum idle age before a pooled connection is discarded (five minutes).
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
/// Default deadline for establishing a single new connection (ten seconds).
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period between background idle-connection sweeps (thirty seconds).
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
13
/// Configuration for the background task that evicts expired idle connections.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CleanupConfig {
    /// Period between sweeps of the idle-connection queue.
    pub interval: Duration,
    /// Whether the background cleanup task should be started at all.
    pub enabled: bool,
}
20
21impl Default for CleanupConfig {
22    fn default() -> Self {
23        Self {
24            interval: DEFAULT_CLEANUP_INTERVAL,
25            enabled: true,
26        }
27    }
28}
29
30/// Background cleanup task controller
31pub struct CleanupTaskController {
32    handle: Option<JoinHandle<()>>,
33}
34
35impl CleanupTaskController {
36    pub fn new() -> Self {
37        Self { handle: None }
38    }
39
40    pub fn start<T: Send + 'static>(
41        &mut self,
42        connections: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
43        max_idle_time: Duration,
44        cleanup_interval: Duration,
45    ) {
46        if self.handle.is_some() {
47            log::warn!("Cleanup task is already running");
48            return;
49        }
50
51        let handle = tokio::spawn(async move {
52            let mut interval_timer = interval(cleanup_interval);
53            log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
54
55            loop {
56                interval_timer.tick().await;
57
58                let mut connections = connections.lock().await;
59                let initial_count = connections.len();
60                let now = Instant::now();
61
62                connections.retain(|conn| now.duration_since(conn.created_at) < max_idle_time);
63
64                let removed_count = initial_count - connections.len();
65                if removed_count > 0 {
66                    log::debug!("Background cleanup removed {removed_count} expired connections");
67                }
68
69                // Release the lock
70                drop(connections);
71            }
72        });
73
74        self.handle = Some(handle);
75    }
76
77    pub fn stop(&mut self) {
78        if let Some(handle) = self.handle.take() {
79            handle.abort();
80            log::info!("Background cleanup task stopped");
81        }
82    }
83}
84
85impl Drop for CleanupTaskController {
86    fn drop(&mut self) {
87        self.stop();
88    }
89}
90
/// Creates new connections for a [`ConnectionPool`] and checks whether idle ones are still usable.
///
/// Implementors must be `Send + Sync`.
pub trait ConnectionManager: Sync + Send {
    /// Connection type produced by [`ConnectionManager::create_connection`].
    type Connection: Send;

    /// Error reported when a connection cannot be created.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`ConnectionManager::create_connection`].
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by [`ConnectionManager::is_valid`]; it may borrow both the manager and
    /// the connection for `'a`.
    type ValidFut<'a>: Future<Output = bool> + Send where Self: 'a;

    /// Establish a brand-new connection.
    fn create_connection(&self) -> Self::CreateFut;

    /// Resolve to `true` if `connection` may be reused, or `false` if the pool should discard it.
    fn is_valid<'a>(&'a self, connection: &'a Self::Connection) -> Self::ValidFut<'a>;
}
113
114/// Connection pool
115pub struct ConnectionPool<M>
116where
117    M: ConnectionManager + Send + Sync + 'static,
118{
119    connections: Arc<Mutex<VecDeque<PooledConnection<M::Connection>>>>,
120    semaphore: Arc<Semaphore>,
121    max_size: usize,
122    manager: M,
123    max_idle_time: Duration,
124    connection_timeout: Duration,
125    cleanup_controller: Arc<Mutex<CleanupTaskController>>,
126}
127
/// An idle connection stored inside the pool.
#[derive(Debug)]
pub struct PooledConnection<T> {
    /// The underlying connection.
    pub connection: T,
    /// Timestamp used for idle-expiry checks; the pool sets it when a connection is returned.
    pub created_at: Instant,
}
133
134/// Pooled stream, provided for the outer world usage
135pub struct PooledStream<M>
136where
137    M: ConnectionManager + Send + Sync + 'static,
138{
139    connection: Option<M::Connection>,
140    pool: Arc<ConnectionPool<M>>,
141    _permit: tokio::sync::OwnedSemaphorePermit,
142}
143
144impl<M> ConnectionPool<M>
145where
146    M: ConnectionManager + Send + Sync + 'static,
147{
148    /// Create a new connection pool
149    pub fn new(
150        max_size: Option<usize>,
151        max_idle_time: Option<Duration>,
152        connection_timeout: Option<Duration>,
153        cleanup_config: Option<CleanupConfig>,
154        manager: M,
155    ) -> Arc<Self> {
156        let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
157        let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
158        let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
159        let cleanup_config = cleanup_config.unwrap_or_default();
160
161        log::info!(
162            "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
163            cleanup_config.enabled
164        );
165
166        let connections = Arc::new(Mutex::new(VecDeque::new()));
167        let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
168
169        let pool = Arc::new(ConnectionPool {
170            connections: connections.clone(),
171            semaphore: Arc::new(Semaphore::new(max_size)),
172            max_size,
173            manager,
174            max_idle_time,
175            connection_timeout,
176            cleanup_controller: cleanup_controller.clone(),
177        });
178
179        // Start background cleanup task if enabled
180        if cleanup_config.enabled {
181            tokio::spawn(async move {
182                let mut controller = cleanup_controller.lock().await;
183                controller.start(connections, max_idle_time, cleanup_config.interval);
184            });
185        }
186
187        pool
188    }
189
190    /// Get a connection from the pool
191    pub async fn get_connection(self: Arc<Self>) -> Result<PooledStream<M>, PoolError<M::Error>> {
192        log::debug!("Attempting to get connection from pool");
193
194        // Use semaphore to limit concurrent connections
195        let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
196
197        {
198            // Try to get an existing connection from the pool
199            let mut connections = self.connections.lock().await;
200
201            // With background cleanup enabled, we can skip the inline cleanup
202            // for better performance, but still do a quick validation
203            if let Some(pooled_conn) = connections.pop_front() {
204                log::trace!("Found existing connection in pool, validating...");
205
206                // Quick check if connection is not obviously expired
207                let age = Instant::now().duration_since(pooled_conn.created_at);
208                if age >= self.max_idle_time {
209                    log::debug!("Connection expired (age: {age:?}), discarding");
210                } else if self.manager.is_valid(&pooled_conn.connection).await {
211                    log::debug!("Reusing existing connection from pool (remaining: {})", connections.len());
212                    return Ok(PooledStream {
213                        connection: Some(pooled_conn.connection),
214                        pool: self.clone(),
215                        _permit: permit,
216                    });
217                } else {
218                    log::warn!("Connection validation failed, discarding invalid connection");
219                }
220            }
221        }
222
223        log::trace!("No valid connection available, creating new connection...");
224        // Create new connection
225        match timeout(self.connection_timeout, self.manager.create_connection()).await {
226            Ok(Ok(connection)) => {
227                log::info!("Successfully created new connection");
228                Ok(PooledStream {
229                    connection: Some(connection),
230                    pool: self.clone(),
231                    _permit: permit,
232                })
233            }
234            Ok(Err(e)) => {
235                log::error!("Failed to create new connection");
236                Err(PoolError::Creation(e))
237            }
238            Err(_) => {
239                log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
240                Err(PoolError::Timeout)
241            }
242        }
243    }
244
245    /// Stop the background cleanup task
246    pub async fn stop_cleanup_task(&self) {
247        let mut controller = self.cleanup_controller.lock().await;
248        controller.stop();
249    }
250
251    /// Restart the background cleanup task with new configuration
252    pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
253        let mut controller = self.cleanup_controller.lock().await;
254        controller.stop();
255
256        if cleanup_config.enabled {
257            controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval);
258        }
259    }
260}
261
262impl<M> ConnectionPool<M>
263where
264    M: ConnectionManager + Send + Sync + 'static,
265{
266    async fn return_connection(&self, connection: M::Connection) {
267        let mut connections = self.connections.lock().await;
268        if connections.len() < self.max_size {
269            connections.push_back(PooledConnection {
270                connection,
271                created_at: Instant::now(),
272            });
273            log::trace!("Connection returned to pool (pool size: {})", connections.len());
274        } else {
275            log::trace!("Pool is full, dropping connection (max_size: {})", self.max_size);
276        }
277        // If the pool is full, the connection will be dropped (automatically closed)
278    }
279}
280
281impl<M> Drop for PooledStream<M>
282where
283    M: ConnectionManager + Send + Sync + 'static,
284{
285    fn drop(&mut self) {
286        if let Some(connection) = self.connection.take() {
287            let pool = self.pool.clone();
288            if let Ok(handle) = tokio::runtime::Handle::try_current() {
289                log::trace!("Returning connection to pool on drop");
290                tokio::task::block_in_place(|| handle.block_on(pool.return_connection(connection)));
291            } else {
292                log::warn!("No tokio runtime available, connection will be dropped");
293            }
294        }
295    }
296}
297
298// Generic implementations for AsRef and AsMut
299impl<M> AsRef<M::Connection> for PooledStream<M>
300where
301    M: ConnectionManager + Send + Sync + 'static,
302{
303    fn as_ref(&self) -> &M::Connection {
304        self.connection.as_ref().unwrap()
305    }
306}
307
308impl<M> AsMut<M::Connection> for PooledStream<M>
309where
310    M: ConnectionManager + Send + Sync + 'static,
311{
312    fn as_mut(&mut self) -> &mut M::Connection {
313        self.connection.as_mut().unwrap()
314    }
315}
316
317// Implement Deref and DerefMut for PooledStream
318impl<M> std::ops::Deref for PooledStream<M>
319where
320    M: ConnectionManager + Send + Sync + 'static,
321{
322    type Target = M::Connection;
323
324    fn deref(&self) -> &Self::Target {
325        self.connection.as_ref().unwrap()
326    }
327}
328
329impl<M> std::ops::DerefMut for PooledStream<M>
330where
331    M: ConnectionManager + Send + Sync + 'static,
332{
333    fn deref_mut(&mut self) -> &mut Self::Target {
334        self.connection.as_mut().unwrap()
335    }
336}
337
/// Errors produced when checking a connection out of the pool.
#[derive(Debug)]
pub enum PoolError<E> {
    /// No semaphore permit could be acquired because the semaphore was closed.
    PoolClosed,
    /// Creating a new connection did not finish within the configured timeout.
    Timeout,
    /// The connection manager failed to create a connection.
    Creation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::Creation(inner) => return write!(f, "Connection creation failed: {inner}"),
            Self::Timeout => "Connection creation timeout",
            Self::PoolClosed => "Connection pool is closed",
        };
        f.write_str(text)
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    /// Only `Creation` carries an underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Creation(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}