connection_pool/
connection_pool.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Default cap on simultaneously checked-out connections (also the idle-queue cap).
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Default idle timeout (5 minutes): idle connections this old are discarded.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
/// Default time limit (10 seconds) for creating a new connection.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period (30 seconds) of the background cleanup task.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
13
/// Settings for the background task that evicts idle connections from the pool.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CleanupConfig {
    /// How often the cleanup task runs. Must be non-zero.
    pub interval: Duration,
    /// Whether the cleanup task is started at all.
    pub enabled: bool,
}
20
21impl Default for CleanupConfig {
22    fn default() -> Self {
23        Self {
24            interval: DEFAULT_CLEANUP_INTERVAL,
25            enabled: true,
26        }
27    }
28}
29
30/// Background cleanup task controller
31struct CleanupTaskController {
32    handle: Option<JoinHandle<()>>,
33}
34
35impl CleanupTaskController {
36    fn new() -> Self {
37        Self { handle: None }
38    }
39
40    fn start<T: Send + 'static>(
41        &mut self,
42        connections: Arc<Mutex<VecDeque<InnerConnection<T>>>>,
43        max_idle_time: Duration,
44        cleanup_interval: Duration,
45    ) {
46        if self.handle.is_some() {
47            log::warn!("Cleanup task is already running");
48            return;
49        }
50
51        let handle = tokio::spawn(async move {
52            let mut interval_timer = interval(cleanup_interval);
53            log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
54
55            loop {
56                interval_timer.tick().await;
57
58                let mut connections = connections.lock().await;
59                let initial_count = connections.len();
60                let now = Instant::now();
61
62                connections.retain(|conn| now.duration_since(conn.created_at) < max_idle_time);
63
64                let removed_count = initial_count - connections.len();
65                if removed_count > 0 {
66                    log::debug!("Background cleanup removed {removed_count} expired connections");
67                }
68
69                log::trace!("Current pool size after cleanup: {}", connections.len());
70            }
71        });
72
73        self.handle = Some(handle);
74    }
75
76    fn stop(&mut self) {
77        if let Some(handle) = self.handle.take() {
78            handle.abort();
79            log::info!("Background cleanup task stopped");
80        }
81    }
82}
83
84impl Drop for CleanupTaskController {
85    fn drop(&mut self) {
86        self.stop();
87    }
88}
89
/// Creates connections for a [`ConnectionPool`] and decides whether pooled ones may be reused.
pub trait ConnectionManager: Sync + Send {
    /// The connection type produced and checked by this manager.
    type Connection: Send;

    /// Error reported when this manager fails to produce a connection.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`ConnectionManager::create_connection`].
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by [`ConnectionManager::is_valid`].
    ///
    /// Resolves to `true` when the connection may be reused and `false` when it should
    /// be thrown away.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Begins creating a brand-new connection.
    fn create_connection(&self) -> Self::CreateFut;

    /// Begins checking whether a pooled `connection` is still usable.
    fn is_valid<'a>(&'a self, connection: &'a mut Self::Connection) -> Self::ValidFut<'a>;
}
112
113/// Connection pool
114pub struct ConnectionPool<M>
115where
116    M: ConnectionManager + Send + Sync + 'static,
117{
118    connections: Arc<Mutex<VecDeque<InnerConnection<M::Connection>>>>,
119    semaphore: Arc<Semaphore>,
120    max_size: usize,
121    manager: M,
122    max_idle_time: Duration,
123    connection_timeout: Duration,
124    cleanup_controller: Arc<Mutex<CleanupTaskController>>,
125}
126
/// An idle connection parked in the pool, paired with the instant it became idle.
struct InnerConnection<T> {
    /// The parked connection.
    pub connection: T,
    /// Despite the name, this is stamped when the connection is (re)inserted into the
    /// idle queue, not when it was created. The idle-timeout checks measure from it.
    pub created_at: Instant,
}
132
133/// Pooled managed stream, provided for the outer world usage
134pub struct ManagedConnection<M>
135where
136    M: ConnectionManager + Send + Sync + 'static,
137{
138    connection: Option<M::Connection>,
139    pool: Arc<ConnectionPool<M>>,
140    _permit: tokio::sync::OwnedSemaphorePermit,
141}
142
143impl<M> ConnectionPool<M>
144where
145    M: ConnectionManager + Send + Sync + 'static,
146{
147    /// Create a new connection pool
148    pub fn new(
149        max_size: Option<usize>,
150        max_idle_time: Option<Duration>,
151        connection_timeout: Option<Duration>,
152        cleanup_config: Option<CleanupConfig>,
153        manager: M,
154    ) -> Arc<Self> {
155        let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
156        let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
157        let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
158        let cleanup_config = cleanup_config.unwrap_or_default();
159
160        log::info!(
161            "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
162            cleanup_config.enabled
163        );
164
165        let connections = Arc::new(Mutex::new(VecDeque::new()));
166        let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
167
168        let pool = Arc::new(ConnectionPool {
169            connections: connections.clone(),
170            semaphore: Arc::new(Semaphore::new(max_size)),
171            max_size,
172            manager,
173            max_idle_time,
174            connection_timeout,
175            cleanup_controller: cleanup_controller.clone(),
176        });
177
178        // Start background cleanup task if enabled
179        if cleanup_config.enabled {
180            tokio::spawn(async move {
181                let mut controller = cleanup_controller.lock().await;
182                controller.start(connections, max_idle_time, cleanup_config.interval);
183            });
184        }
185
186        pool
187    }
188
189    /// Get a connection from the pool
190    pub async fn get_connection(self: Arc<Self>) -> Result<ManagedConnection<M>, PoolError<M::Error>> {
191        log::debug!("Attempting to get connection from pool");
192
193        // Use semaphore to limit concurrent connections
194        let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
195
196        // Try to get an existing connection from the pool
197        {
198            let mut connections = self.connections.lock().await;
199            loop {
200                let Some(mut pooled_conn) = connections.pop_front() else {
201                    // No available connection, break the loop
202                    break;
203                };
204                log::trace!("Found existing connection in pool, validating...");
205                let age = Instant::now().duration_since(pooled_conn.created_at);
206                let valid = self.manager.is_valid(&mut pooled_conn.connection).await;
207                if age >= self.max_idle_time {
208                    log::debug!("Connection expired (age: {age:?}), discarding");
209                } else if !valid {
210                    log::warn!("Connection validation failed, discarding invalid connection");
211                } else {
212                    log::debug!("Reusing existing connection from pool (remaining: {})", connections.len());
213                    return Ok(ManagedConnection {
214                        connection: Some(pooled_conn.connection),
215                        pool: self.clone(),
216                        _permit: permit,
217                    });
218                }
219            }
220        }
221
222        log::trace!("No valid connection available, creating new connection...");
223        // Create new connection
224        match timeout(self.connection_timeout, self.manager.create_connection()).await {
225            Ok(Ok(connection)) => {
226                log::info!("Successfully created new connection");
227                Ok(ManagedConnection {
228                    connection: Some(connection),
229                    pool: self.clone(),
230                    _permit: permit,
231                })
232            }
233            Ok(Err(e)) => {
234                log::error!("Failed to create new connection");
235                Err(PoolError::Creation(e))
236            }
237            Err(_) => {
238                log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
239                Err(PoolError::Timeout)
240            }
241        }
242    }
243
244    /// Stop the background cleanup task
245    pub async fn stop_cleanup_task(&self) {
246        let mut controller = self.cleanup_controller.lock().await;
247        controller.stop();
248    }
249
250    /// Restart the background cleanup task with new configuration
251    pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
252        let mut controller = self.cleanup_controller.lock().await;
253        controller.stop();
254
255        if cleanup_config.enabled {
256            controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval);
257        }
258    }
259}
260
261impl<M> ConnectionPool<M>
262where
263    M: ConnectionManager + Send + Sync + 'static,
264{
265    async fn return_connection(&self, connection: M::Connection) {
266        let mut connections = self.connections.lock().await;
267        if connections.len() < self.max_size {
268            connections.push_back(InnerConnection {
269                connection,
270                created_at: Instant::now(),
271            });
272            log::trace!("Connection returned to pool (pool size: {})", connections.len());
273        } else {
274            log::trace!("Pool is full, dropping connection (max_size: {})", self.max_size);
275        }
276        // If the pool is full, the connection will be dropped (automatically closed)
277    }
278}
279
280impl<M> Drop for ManagedConnection<M>
281where
282    M: ConnectionManager + Send + Sync + 'static,
283{
284    fn drop(&mut self) {
285        if let Some(connection) = self.connection.take() {
286            let pool = self.pool.clone();
287            if let Ok(handle) = tokio::runtime::Handle::try_current() {
288                log::trace!("Returning connection to pool on drop");
289                tokio::task::block_in_place(|| handle.block_on(pool.return_connection(connection)));
290            } else {
291                log::warn!("No tokio runtime available, connection will be dropped");
292            }
293        }
294    }
295}
296
297// Generic implementations for AsRef and AsMut
298impl<M> AsRef<M::Connection> for ManagedConnection<M>
299where
300    M: ConnectionManager + Send + Sync + 'static,
301{
302    fn as_ref(&self) -> &M::Connection {
303        self.connection.as_ref().unwrap()
304    }
305}
306
307impl<M> AsMut<M::Connection> for ManagedConnection<M>
308where
309    M: ConnectionManager + Send + Sync + 'static,
310{
311    fn as_mut(&mut self) -> &mut M::Connection {
312        self.connection.as_mut().unwrap()
313    }
314}
315
316// Implement Deref and DerefMut for PooledStream
317impl<M> std::ops::Deref for ManagedConnection<M>
318where
319    M: ConnectionManager + Send + Sync + 'static,
320{
321    type Target = M::Connection;
322
323    fn deref(&self) -> &Self::Target {
324        self.connection.as_ref().unwrap()
325    }
326}
327
328impl<M> std::ops::DerefMut for ManagedConnection<M>
329where
330    M: ConnectionManager + Send + Sync + 'static,
331{
332    fn deref_mut(&mut self) -> &mut Self::Target {
333        self.connection.as_mut().unwrap()
334    }
335}
336
/// Errors returned when checking a connection out of a [`ConnectionPool`].
#[derive(Debug)]
pub enum PoolError<E> {
    /// No checkout permit could be acquired because the pool's semaphore was closed.
    PoolClosed,
    /// Creating a new connection exceeded the pool's connection timeout.
    Timeout,
    /// The [`ConnectionManager`] failed to create a new connection.
    Creation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PoolClosed => f.write_str("Connection pool is closed"),
            Self::Timeout => f.write_str("Connection creation timeout"),
            Self::Creation(inner) => write!(f, "Connection creation failed: {inner}"),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    /// Exposes the manager's error for [`PoolError::Creation`]; other variants have no cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Creation(inner) => Some(inner),
            Self::PoolClosed | Self::Timeout => None,
        }
    }
}