connection_pool/
connection_pool.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::net::{TcpStream, ToSocketAddrs};
6use tokio::sync::{Mutex, Semaphore};
7use tokio::task::JoinHandle;
8use tokio::time::{interval, timeout};
9
/// Default number of permits in the pool's semaphore, which is also the default cap on
/// idle connections kept for reuse.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Default time an idle connection may sit in the pool before being discarded (5 minutes).
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(300);
/// Default deadline for establishing one new connection (10 seconds).
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period between background cleanup passes (30 seconds).
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
14
15/// Configuration for background cleanup task
16#[derive(Clone)]
17pub struct CleanupConfig {
18    pub interval: Duration,
19    pub enabled: bool,
20}
21
22impl Default for CleanupConfig {
23    fn default() -> Self {
24        Self {
25            interval: DEFAULT_CLEANUP_INTERVAL,
26            enabled: true,
27        }
28    }
29}
30
31/// Background cleanup task controller
32pub struct CleanupTaskController {
33    handle: Option<JoinHandle<()>>,
34}
35
36impl CleanupTaskController {
37    pub fn new() -> Self {
38        Self { handle: None }
39    }
40
41    pub fn start<T: Send + 'static>(
42        &mut self,
43        connections: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
44        max_idle_time: Duration,
45        cleanup_interval: Duration,
46    ) {
47        if self.handle.is_some() {
48            log::warn!("Cleanup task is already running");
49            return;
50        }
51
52        let handle = tokio::spawn(async move {
53            let mut interval_timer = interval(cleanup_interval);
54            log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
55
56            loop {
57                interval_timer.tick().await;
58
59                let mut connections = connections.lock().await;
60                let initial_count = connections.len();
61                let now = Instant::now();
62
63                connections.retain(|conn| now.duration_since(conn.created_at) < max_idle_time);
64
65                let removed_count = initial_count - connections.len();
66                if removed_count > 0 {
67                    log::debug!("Background cleanup removed {removed_count} expired connections");
68                }
69
70                // Release the lock
71                drop(connections);
72            }
73        });
74
75        self.handle = Some(handle);
76    }
77
78    pub fn stop(&mut self) {
79        if let Some(handle) = self.handle.take() {
80            handle.abort();
81            log::info!("Background cleanup task stopped");
82        }
83    }
84}
85
86impl Drop for CleanupTaskController {
87    fn drop(&mut self) {
88        self.stop();
89    }
90}
91
/// Builds new connections on demand.
///
/// `T` is the connection type and `P` the parameters (for example an address) handed to
/// every call. The pool awaits the returned future under its connection timeout.
pub trait ConnectionCreator<T, P> {
    /// Future resolving to a freshly established connection or a creation error.
    type Future: Future<Output = Result<T, Self::Error>>;
    /// Error produced when a connection cannot be established.
    type Error;

    /// Begins establishing a connection described by `params`.
    fn create_connection(&self, params: &P) -> Self::Future;
}
99
/// Health check applied to an idle connection before the pool hands it out again.
pub trait ConnectionValidator<T> {
    /// Resolves to `true` when `conn` may be reused; `false` makes the pool discard it.
    ///
    /// The returned future is required to be `Send`.
    fn is_valid(&self, conn: &T) -> impl Future<Output = bool> + Send;
}
104
105pub struct ConnectionPool<T, P, C, V>
106where
107    T: Send + 'static,
108    P: Send + Sync + Clone + 'static,
109    C: Send + Sync + 'static,
110    V: Send + Sync + 'static,
111{
112    connections: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
113    semaphore: Arc<Semaphore>,
114    max_size: usize,
115    connection_params: P,
116    connection_creator: C,
117    connection_validator: V,
118    max_idle_time: Duration,
119    connection_timeout: Duration,
120    cleanup_controller: Arc<Mutex<CleanupTaskController>>,
121}
122
/// An idle connection stored in the pool.
#[derive(Debug)]
pub struct PooledConnection<T> {
    /// The underlying connection.
    pub connection: T,
    /// When this entry entered the idle queue. Despite the name, the pool stamps it each
    /// time a connection is returned, so it measures idle time rather than total age.
    pub created_at: Instant,
}
127
128pub struct PooledStream<T, P, C, V>
129where
130    T: Send + 'static,
131    P: Send + Sync + Clone + 'static,
132    C: Send + Sync + 'static,
133    V: Send + Sync + 'static,
134{
135    connection: Option<T>,
136    pool: Arc<ConnectionPool<T, P, C, V>>,
137    _permit: tokio::sync::OwnedSemaphorePermit,
138}
139
140impl<T, P, C, V> ConnectionPool<T, P, C, V>
141where
142    C: ConnectionCreator<T, P> + Send + Sync + 'static,
143    V: ConnectionValidator<T> + Send + Sync + 'static,
144    T: Send + 'static,
145    P: Send + Sync + Clone + 'static,
146{
147    pub fn new(
148        max_size: Option<usize>,
149        max_idle_time: Option<Duration>,
150        connection_timeout: Option<Duration>,
151        cleanup_config: Option<CleanupConfig>,
152        connection_params: P,
153        connection_creator: C,
154        connection_validator: V,
155    ) -> Arc<Self> {
156        let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
157        let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
158        let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
159        let cleanup_config = cleanup_config.unwrap_or_default();
160
161        log::info!(
162            "Creating connection pool with max_size: {}, idle_timeout: {:?}, connection_timeout: {:?}, cleanup_enabled: {}",
163            max_size,
164            max_idle_time,
165            connection_timeout,
166            cleanup_config.enabled
167        );
168
169        let connections = Arc::new(Mutex::new(VecDeque::new()));
170        let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
171
172        let pool = Arc::new(ConnectionPool {
173            connections: connections.clone(),
174            semaphore: Arc::new(Semaphore::new(max_size)),
175            max_size,
176            connection_params,
177            connection_creator,
178            connection_validator,
179            max_idle_time,
180            connection_timeout,
181            cleanup_controller: cleanup_controller.clone(),
182        });
183
184        // Start background cleanup task if enabled
185        if cleanup_config.enabled {
186            tokio::spawn(async move {
187                let mut controller = cleanup_controller.lock().await;
188                controller.start(connections, max_idle_time, cleanup_config.interval);
189            });
190        }
191
192        pool
193    }
194
195    pub async fn get_connection(self: Arc<Self>) -> Result<PooledStream<T, P, C, V>, PoolError<C::Error>> {
196        log::debug!("Attempting to get connection from pool");
197
198        // Use semaphore to limit concurrent connections
199        let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
200
201        {
202            // Try to get an existing connection from the pool
203            let mut connections = self.connections.lock().await;
204
205            // With background cleanup enabled, we can skip the inline cleanup
206            // for better performance, but still do a quick validation
207            if let Some(pooled_conn) = connections.pop_front() {
208                log::trace!("Found existing connection in pool, validating...");
209
210                // Quick check if connection is not obviously expired
211                let age = Instant::now().duration_since(pooled_conn.created_at);
212                if age >= self.max_idle_time {
213                    log::debug!("Connection expired (age: {age:?}), discarding");
214                } else if self.connection_validator.is_valid(&pooled_conn.connection).await {
215                    log::debug!("Reusing existing connection from pool (remaining: {})", connections.len());
216                    return Ok(PooledStream {
217                        connection: Some(pooled_conn.connection),
218                        pool: self.clone(),
219                        _permit: permit,
220                    });
221                } else {
222                    log::warn!("Connection validation failed, discarding invalid connection");
223                }
224            }
225        }
226
227        log::trace!("No valid connection available, creating new connection...");
228        // Create new connection
229        match timeout(
230            self.connection_timeout,
231            self.connection_creator.create_connection(&self.connection_params),
232        )
233        .await
234        {
235            Ok(Ok(connection)) => {
236                log::info!("Successfully created new connection");
237                Ok(PooledStream {
238                    connection: Some(connection),
239                    pool: self.clone(),
240                    _permit: permit,
241                })
242            }
243            Ok(Err(e)) => {
244                log::error!("Failed to create new connection");
245                Err(PoolError::Creation(e))
246            }
247            Err(_) => {
248                log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
249                Err(PoolError::Timeout)
250            }
251        }
252    }
253
254    /// Stop the background cleanup task
255    pub async fn stop_cleanup_task(&self) {
256        let mut controller = self.cleanup_controller.lock().await;
257        controller.stop();
258    }
259
260    /// Restart the background cleanup task with new configuration
261    pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
262        let mut controller = self.cleanup_controller.lock().await;
263        controller.stop();
264
265        if cleanup_config.enabled {
266            controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval);
267        }
268    }
269}
270
271// Implementation without trait bounds for basic operations
272impl<T, P, C, V> ConnectionPool<T, P, C, V>
273where
274    T: Send + 'static,
275    P: Send + Sync + Clone + 'static,
276    C: Send + Sync + 'static,
277    V: Send + Sync + 'static,
278{
279    async fn return_connection(&self, connection: T) {
280        let mut connections = self.connections.lock().await;
281        if connections.len() < self.max_size {
282            connections.push_back(PooledConnection {
283                connection,
284                created_at: Instant::now(),
285            });
286            log::trace!("Connection returned to pool (pool size: {})", connections.len());
287        } else {
288            log::trace!("Pool is full, dropping connection (max_size: {})", self.max_size);
289        }
290        // If the pool is full, the connection will be dropped (automatically closed)
291    }
292}
293
294impl<T, P, C, V> Drop for PooledStream<T, P, C, V>
295where
296    T: Send + 'static,
297    P: Send + Sync + Clone + 'static,
298    C: Send + Sync + 'static,
299    V: Send + Sync + 'static,
300{
301    fn drop(&mut self) {
302        if let Some(connection) = self.connection.take() {
303            let pool = self.pool.clone();
304            if let Ok(handle) = tokio::runtime::Handle::try_current() {
305                log::trace!("Returning connection to pool on drop");
306                tokio::task::block_in_place(|| handle.block_on(pool.return_connection(connection)));
307            } else {
308                log::warn!("No tokio runtime available, connection will be dropped");
309            }
310        }
311    }
312}
313
314// Generic implementations for AsRef and AsMut
315impl<T, P, C, V> AsRef<T> for PooledStream<T, P, C, V>
316where
317    T: Send + 'static,
318    P: Send + Sync + Clone + 'static,
319    C: Send + Sync + 'static,
320    V: Send + Sync + 'static,
321{
322    fn as_ref(&self) -> &T {
323        self.connection.as_ref().unwrap()
324    }
325}
326
327impl<T, P, C, V> AsMut<T> for PooledStream<T, P, C, V>
328where
329    T: Send + 'static,
330    P: Send + Sync + Clone + 'static,
331    C: Send + Sync + 'static,
332    V: Send + Sync + 'static,
333{
334    fn as_mut(&mut self) -> &mut T {
335        self.connection.as_mut().unwrap()
336    }
337}
338
339// Implement Deref and DerefMut for PooledStream
340impl<T, P, C, V> std::ops::Deref for PooledStream<T, P, C, V>
341where
342    T: Send + 'static,
343    P: Send + Sync + Clone + 'static,
344    C: Send + Sync + 'static,
345    V: Send + Sync + 'static,
346{
347    type Target = T;
348
349    fn deref(&self) -> &Self::Target {
350        self.connection.as_ref().unwrap()
351    }
352}
353
354impl<T, P, C, V> std::ops::DerefMut for PooledStream<T, P, C, V>
355where
356    T: Send + 'static,
357    P: Send + Sync + Clone + 'static,
358    C: Send + Sync + 'static,
359    V: Send + Sync + 'static,
360{
361    fn deref_mut(&mut self) -> &mut Self::Target {
362        self.connection.as_mut().unwrap()
363    }
364}
365
/// Errors returned when checking a connection out of the pool.
///
/// `E` is the error type of the pool's connection creator.
#[derive(Debug)]
pub enum PoolError<E> {
    /// No semaphore permit could be acquired because the semaphore was closed.
    PoolClosed,
    /// Creating a new connection exceeded the configured timeout.
    Timeout,
    /// The connection creator failed with the wrapped error.
    Creation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::PoolClosed => "Connection pool is closed",
            Self::Timeout => "Connection creation timeout",
            Self::Creation(inner) => return write!(f, "Connection creation failed: {inner}"),
        };
        f.write_str(text)
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    /// Exposes the creator's error for `Creation`; the other variants have no cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Creation(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
392
393// Implement for TcpStream
394pub struct TcpConnectionCreator;
395
396impl<A> ConnectionCreator<TcpStream, A> for TcpConnectionCreator
397where
398    A: ToSocketAddrs + Send + Sync + Clone + 'static,
399{
400    type Error = std::io::Error;
401    type Future = std::pin::Pin<Box<dyn Future<Output = Result<TcpStream, Self::Error>> + Send>>;
402
403    fn create_connection(&self, address: &A) -> Self::Future {
404        let addr = address.clone();
405        Box::pin(async move { TcpStream::connect(addr).await })
406    }
407}
408
409pub struct TcpConnectionValidator;
410
411impl ConnectionValidator<TcpStream> for TcpConnectionValidator {
412    async fn is_valid(&self, stream: &TcpStream) -> bool {
413        // Simple validation: check if the stream is readable and writable
414        stream
415            .ready(tokio::io::Interest::READABLE | tokio::io::Interest::WRITABLE)
416            .await
417            .is_ok()
418    }
419}
420
421// Convenience type aliases
422pub type TcpConnectionPool<A = std::net::SocketAddr> = ConnectionPool<TcpStream, A, TcpConnectionCreator, TcpConnectionValidator>;
423pub type TcpPooledStream<A = std::net::SocketAddr> = PooledStream<TcpStream, A, TcpConnectionCreator, TcpConnectionValidator>;
424
425impl<A> TcpConnectionPool<A>
426where
427    A: ToSocketAddrs + Send + Sync + Clone + 'static,
428{
429    pub fn new_tcp(
430        max_size: Option<usize>,
431        max_idle_time: Option<Duration>,
432        connection_timeout: Option<Duration>,
433        cleanup_config: Option<CleanupConfig>,
434        address: A,
435    ) -> Arc<Self> {
436        log::info!("Creating TCP connection pool");
437        Self::new(
438            max_size,
439            max_idle_time,
440            connection_timeout,
441            cleanup_config,
442            address,
443            TcpConnectionCreator,
444            TcpConnectionValidator,
445        )
446    }
447}