// connection_pool/connection_pool.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Fallback for `max_size` when `ConnectionPool::new` receives `None`.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Fallback idle timeout used by `ConnectionPool::new` (300 s = 5 minutes).
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
/// Fallback limit on a single connection-creation attempt (10 seconds).
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period of the background cleanup sweep (30 seconds).
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
13
/// Configuration for the background task that evicts idle connections.
///
/// Accepted by `ConnectionPool::new` (where `None` means `CleanupConfig::default()`) and by
/// `ConnectionPool::restart_cleanup_task`.
#[derive(Debug, Clone)]
pub struct CleanupConfig {
    /// Time between two consecutive eviction sweeps.
    pub interval: Duration,
    /// When `false`, no background task is started.
    pub enabled: bool,
}
20
21impl Default for CleanupConfig {
22    fn default() -> Self {
23        Self {
24            interval: DEFAULT_CLEANUP_INTERVAL,
25            enabled: true,
26        }
27    }
28}
29
30/// Background cleanup task controller
31pub struct CleanupTaskController {
32    handle: Option<JoinHandle<()>>,
33}
34
35impl CleanupTaskController {
36    pub fn new() -> Self {
37        Self { handle: None }
38    }
39
40    pub fn start<T: Send + 'static>(
41        &mut self,
42        connections: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
43        max_idle_time: Duration,
44        cleanup_interval: Duration,
45    ) {
46        if self.handle.is_some() {
47            log::warn!("Cleanup task is already running");
48            return;
49        }
50
51        let handle = tokio::spawn(async move {
52            let mut interval_timer = interval(cleanup_interval);
53            log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
54
55            loop {
56                interval_timer.tick().await;
57
58                let mut connections = connections.lock().await;
59                let initial_count = connections.len();
60                let now = Instant::now();
61
62                connections.retain(|conn| now.duration_since(conn.created_at) < max_idle_time);
63
64                let removed_count = initial_count - connections.len();
65                if removed_count > 0 {
66                    log::debug!("Background cleanup removed {removed_count} expired connections");
67                }
68
69                // Release the lock
70                drop(connections);
71            }
72        });
73
74        self.handle = Some(handle);
75    }
76
77    pub fn stop(&mut self) {
78        if let Some(handle) = self.handle.take() {
79            handle.abort();
80            log::info!("Background cleanup task stopped");
81        }
82    }
83}
84
85impl Drop for CleanupTaskController {
86    fn drop(&mut self) {
87        self.stop();
88    }
89}
90
/// Creates and health-checks the connections handed out by a `ConnectionPool`.
///
/// Both operations return futures, so implementations are free to perform async work.
pub trait ConnectionManager: Send + Sync {
    /// The connection type being pooled.
    type Connection: Send;
    /// Error produced when a connection cannot be created.
    type Error: std::error::Error + Send;

    /// Future returned by `create_connection`.
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Produces a new connection; the pool calls this when no reusable idle connection exists.
    fn create_connection(&self) -> Self::CreateFut;

    /// Future returned by `is_valid`; it may borrow `self` and the connection for `'a`.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Reports whether an idle `connection` may be reused; `false` makes the pool discard it.
    fn is_valid<'a>(&'a self, connection: &'a Self::Connection) -> Self::ValidFut<'a>;
}
102
103pub struct ConnectionPool<M>
104where
105    M: ConnectionManager + Send + Sync + 'static,
106{
107    connections: Arc<Mutex<VecDeque<PooledConnection<M::Connection>>>>,
108    semaphore: Arc<Semaphore>,
109    max_size: usize,
110    manager: M,
111    max_idle_time: Duration,
112    connection_timeout: Duration,
113    cleanup_controller: Arc<Mutex<CleanupTaskController>>,
114}
115
/// An idle connection kept in the pool together with a timestamp.
#[derive(Debug)]
pub struct PooledConnection<T> {
    /// The wrapped connection.
    pub connection: T,
    /// When this entry entered the idle queue. Despite the name, the pool stamps it with
    /// `Instant::now()` every time a connection is returned. Idle-expiry checks therefore
    /// measure time spent idle, not the connection's true age.
    pub created_at: Instant,
}
120
121pub struct PooledStream<M>
122where
123    M: ConnectionManager + Send + Sync + 'static,
124{
125    connection: Option<M::Connection>,
126    pool: Arc<ConnectionPool<M>>,
127    _permit: tokio::sync::OwnedSemaphorePermit,
128}
129
130impl<M> ConnectionPool<M>
131where
132    M: ConnectionManager + Send + Sync + 'static,
133{
134    pub fn new(
135        max_size: Option<usize>,
136        max_idle_time: Option<Duration>,
137        connection_timeout: Option<Duration>,
138        cleanup_config: Option<CleanupConfig>,
139        manager: M,
140    ) -> Arc<Self> {
141        let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
142        let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
143        let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
144        let cleanup_config = cleanup_config.unwrap_or_default();
145
146        log::info!(
147            "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
148            cleanup_config.enabled
149        );
150
151        let connections = Arc::new(Mutex::new(VecDeque::new()));
152        let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
153
154        let pool = Arc::new(ConnectionPool {
155            connections: connections.clone(),
156            semaphore: Arc::new(Semaphore::new(max_size)),
157            max_size,
158            manager,
159            max_idle_time,
160            connection_timeout,
161            cleanup_controller: cleanup_controller.clone(),
162        });
163
164        // Start background cleanup task if enabled
165        if cleanup_config.enabled {
166            tokio::spawn(async move {
167                let mut controller = cleanup_controller.lock().await;
168                controller.start(connections, max_idle_time, cleanup_config.interval);
169            });
170        }
171
172        pool
173    }
174
175    pub async fn get_connection(self: Arc<Self>) -> Result<PooledStream<M>, PoolError<M::Error>> {
176        log::debug!("Attempting to get connection from pool");
177
178        // Use semaphore to limit concurrent connections
179        let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
180
181        {
182            // Try to get an existing connection from the pool
183            let mut connections = self.connections.lock().await;
184
185            // With background cleanup enabled, we can skip the inline cleanup
186            // for better performance, but still do a quick validation
187            if let Some(pooled_conn) = connections.pop_front() {
188                log::trace!("Found existing connection in pool, validating...");
189
190                // Quick check if connection is not obviously expired
191                let age = Instant::now().duration_since(pooled_conn.created_at);
192                if age >= self.max_idle_time {
193                    log::debug!("Connection expired (age: {age:?}), discarding");
194                } else if self.manager.is_valid(&pooled_conn.connection).await {
195                    log::debug!("Reusing existing connection from pool (remaining: {})", connections.len());
196                    return Ok(PooledStream {
197                        connection: Some(pooled_conn.connection),
198                        pool: self.clone(),
199                        _permit: permit,
200                    });
201                } else {
202                    log::warn!("Connection validation failed, discarding invalid connection");
203                }
204            }
205        }
206
207        log::trace!("No valid connection available, creating new connection...");
208        // Create new connection
209        match timeout(self.connection_timeout, self.manager.create_connection()).await {
210            Ok(Ok(connection)) => {
211                log::info!("Successfully created new connection");
212                Ok(PooledStream {
213                    connection: Some(connection),
214                    pool: self.clone(),
215                    _permit: permit,
216                })
217            }
218            Ok(Err(e)) => {
219                log::error!("Failed to create new connection");
220                Err(PoolError::Creation(e))
221            }
222            Err(_) => {
223                log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
224                Err(PoolError::Timeout)
225            }
226        }
227    }
228
229    /// Stop the background cleanup task
230    pub async fn stop_cleanup_task(&self) {
231        let mut controller = self.cleanup_controller.lock().await;
232        controller.stop();
233    }
234
235    /// Restart the background cleanup task with new configuration
236    pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
237        let mut controller = self.cleanup_controller.lock().await;
238        controller.stop();
239
240        if cleanup_config.enabled {
241            controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval);
242        }
243    }
244}
245
246impl<M> ConnectionPool<M>
247where
248    M: ConnectionManager + Send + Sync + 'static,
249{
250    async fn return_connection(&self, connection: M::Connection) {
251        let mut connections = self.connections.lock().await;
252        if connections.len() < self.max_size {
253            connections.push_back(PooledConnection {
254                connection,
255                created_at: Instant::now(),
256            });
257            log::trace!("Connection returned to pool (pool size: {})", connections.len());
258        } else {
259            log::trace!("Pool is full, dropping connection (max_size: {})", self.max_size);
260        }
261        // If the pool is full, the connection will be dropped (automatically closed)
262    }
263}
264
265impl<M> Drop for PooledStream<M>
266where
267    M: ConnectionManager + Send + Sync + 'static,
268{
269    fn drop(&mut self) {
270        if let Some(connection) = self.connection.take() {
271            let pool = self.pool.clone();
272            if let Ok(handle) = tokio::runtime::Handle::try_current() {
273                log::trace!("Returning connection to pool on drop");
274                tokio::task::block_in_place(|| handle.block_on(pool.return_connection(connection)));
275            } else {
276                log::warn!("No tokio runtime available, connection will be dropped");
277            }
278        }
279    }
280}
281
282// Generic implementations for AsRef and AsMut
283impl<M> AsRef<M::Connection> for PooledStream<M>
284where
285    M: ConnectionManager + Send + Sync + 'static,
286{
287    fn as_ref(&self) -> &M::Connection {
288        self.connection.as_ref().unwrap()
289    }
290}
291
292impl<M> AsMut<M::Connection> for PooledStream<M>
293where
294    M: ConnectionManager + Send + Sync + 'static,
295{
296    fn as_mut(&mut self) -> &mut M::Connection {
297        self.connection.as_mut().unwrap()
298    }
299}
300
301// Implement Deref and DerefMut for PooledStream
302impl<M> std::ops::Deref for PooledStream<M>
303where
304    M: ConnectionManager + Send + Sync + 'static,
305{
306    type Target = M::Connection;
307
308    fn deref(&self) -> &Self::Target {
309        self.connection.as_ref().unwrap()
310    }
311}
312
313impl<M> std::ops::DerefMut for PooledStream<M>
314where
315    M: ConnectionManager + Send + Sync + 'static,
316{
317    fn deref_mut(&mut self) -> &mut Self::Target {
318        self.connection.as_mut().unwrap()
319    }
320}
321
/// Errors returned when checking a connection out of the pool.
#[derive(Debug)]
pub enum PoolError<E> {
    /// The pool's semaphore is closed, so no checkout slot can be acquired.
    PoolClosed,
    /// Creating a new connection took longer than the configured timeout.
    Timeout,
    /// The connection manager failed to create a connection; carries its error.
    Creation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PoolClosed => f.write_str("Connection pool is closed"),
            Self::Timeout => f.write_str("Connection creation timeout"),
            Self::Creation(cause) => write!(f, "Connection creation failed: {cause}"),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    /// Exposes the manager's error for `Creation`; the other variants have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let Self::Creation(cause) = self else {
            return None;
        };
        Some(cause)
    }
}