1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Default number of connections that may be checked out at once (also the idle-queue cap).
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Default age after which an idle pooled connection is discarded instead of reused (5 minutes).
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Default upper bound on creating a single new connection.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period between two background cleanup passes.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);

/// Configuration of the pool's background cleanup task.
///
/// `Default` yields `interval = DEFAULT_CLEANUP_INTERVAL` and `enabled = true`.
#[derive(Clone, Debug)]
pub struct CleanupConfig {
    /// Time between cleanup passes. Should be non-zero: a zero period cannot drive
    /// `tokio::time::interval`.
    pub interval: Duration,
    /// Whether the background cleanup task is started at all.
    pub enabled: bool,
}

impl Default for CleanupConfig {
    fn default() -> Self {
        Self {
            interval: DEFAULT_CLEANUP_INTERVAL,
            enabled: true,
        }
    }
}
29
30struct CleanupTaskController {
32 handle: Option<JoinHandle<()>>,
33 shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
34}
35
36impl CleanupTaskController {
37 fn new() -> Self {
38 Self {
39 handle: None,
40 shutdown_tx: None,
41 }
42 }
43
44 fn start<T, M>(
45 &mut self,
46 connections: Arc<Mutex<VecDeque<InnerConnection<T>>>>,
47 max_idle_time: Duration,
48 cleanup_interval: Duration,
49 manager: Arc<M>,
50 max_size: usize,
51 ) where
52 T: Send + 'static,
53 M: ConnectionManager<Connection = T> + Send + Sync + 'static,
54 {
55 if self.handle.is_some() {
56 log::warn!("Cleanup task is already running");
57 return;
58 }
59 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
60 self.shutdown_tx = Some(shutdown_tx);
61 let handle = tokio::spawn(async move {
62 let mut interval_timer = interval(cleanup_interval);
63 log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
64
65 loop {
66 tokio::select! {
67 _ = interval_timer.tick() => {}
68 _ = &mut shutdown_rx => {
69 log::info!("Received shutdown signal, exiting cleanup loop");
70 break;
71 }
72 };
73
74 let mut connections = connections.lock().await;
75 let initial_count = connections.len();
76 let now = Instant::now();
77
78 let mut valid_connections = VecDeque::new();
80 for mut conn in connections.drain(..) {
81 let not_expired = now.duration_since(conn.created_at) < max_idle_time;
82 let is_valid = if not_expired {
83 manager.is_valid(&mut conn.connection).await
84 } else {
85 false
86 };
87 if not_expired && is_valid {
88 valid_connections.push_back(conn);
89 }
90 }
91 let removed_count = initial_count - valid_connections.len();
92 *connections = valid_connections;
93
94 if removed_count > 0 {
95 log::debug!("Background cleanup removed {removed_count} expired/invalid connections");
96 }
97
98 log::debug!("Current pool (remaining {}/{max_size}) after cleanup", connections.len());
99 }
100 });
101
102 self.handle = Some(handle);
103 }
104
105 async fn stop(&mut self) {
106 if let Some(shutdown_tx) = self.shutdown_tx.take() {
107 let _ = shutdown_tx.send(());
108 }
109 if let Some(handle) = self.handle.take() {
110 if let Err(e) = handle.await {
111 log::error!("Error while stopping background cleanup task: {e}");
112 }
113 log::info!("Background cleanup task stopped in async stop");
114 }
115 }
116
117 fn stop_sync(&mut self) {
118 if let Some(shutdown_tx) = self.shutdown_tx.take() {
119 let _ = shutdown_tx.send(());
120 }
121 if let Some(handle) = self.handle.take() {
122 std::thread::sleep(std::time::Duration::from_millis(100));
123 handle.abort();
124 log::info!("Background cleanup task stopped in stop_sync");
125 }
126 }
127}
128
129impl Drop for CleanupTaskController {
130 fn drop(&mut self) {
131 self.stop_sync();
132 }
133}
134
/// Creates and health-checks the connections stored in a [`ConnectionPool`].
pub trait ConnectionManager: Send + Sync + Clone {
    /// The connection type handed out by the pool.
    type Connection: Send;

    /// Error produced when a connection cannot be created; the pool surfaces it as
    /// [`PoolError::Creation`].
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`create_connection`](Self::create_connection).
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by [`is_valid`](Self::is_valid); it may borrow both the manager and
    /// the connection for `'a`.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Opens a new connection.
    fn create_connection(&self) -> Self::CreateFut;

    /// Resolves to `true` when `connection` is still usable; the pool discards it otherwise.
    fn is_valid<'a>(&'a self, connection: &'a mut Self::Connection) -> Self::ValidFut<'a>;
}
157
158pub struct ConnectionPool<M>
160where
161 M: ConnectionManager + Send + Sync + Clone + 'static,
162{
163 connections: Arc<Mutex<VecDeque<InnerConnection<M::Connection>>>>,
164 semaphore: Arc<Semaphore>,
165 max_size: usize,
166 manager: M,
167 max_idle_time: Duration,
168 connection_timeout: Duration,
169 cleanup_controller: Arc<Mutex<CleanupTaskController>>,
170 outstanding_count: Arc<std::sync::atomic::AtomicUsize>,
171}
172
/// An idle connection stored in the pool's queue.
struct InnerConnection<T> {
    /// When the connection was put into the queue. `recycle` stamps it with `Instant::now()`,
    /// so this is effectively the time it was last returned, and idle expiry is measured from it.
    pub created_at: Instant,
    /// The pooled connection itself.
    pub connection: T,
}
178
179pub struct ManagedConnection<M>
181where
182 M: ConnectionManager + Send + Sync + 'static,
183{
184 connection: Option<M::Connection>,
185 pool: Arc<ConnectionPool<M>>,
186 _permit: tokio::sync::OwnedSemaphorePermit,
187}
188
189impl<M> ManagedConnection<M>
190where
191 M: ConnectionManager,
192{
193 pub fn into_inner(mut self) -> M::Connection {
195 self.connection.take().unwrap()
196 }
197
198 fn new(connection: M::Connection, pool: Arc<ConnectionPool<M>>, permit: tokio::sync::OwnedSemaphorePermit) -> Self {
199 pool.outstanding_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
200 ManagedConnection {
201 connection: Some(connection),
202 pool,
203 _permit: permit,
204 }
205 }
206}
207
208impl<M> ConnectionPool<M>
209where
210 M: ConnectionManager + Send + Sync + Clone + 'static,
211{
212 pub fn new(
214 max_size: Option<usize>,
215 max_idle_time: Option<Duration>,
216 connection_timeout: Option<Duration>,
217 cleanup_config: Option<CleanupConfig>,
218 manager: M,
219 ) -> Arc<Self> {
220 let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
221 let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
222 let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
223 let cleanup_config = cleanup_config.unwrap_or_default();
224
225 log::info!(
226 "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
227 cleanup_config.enabled
228 );
229
230 let connections = Arc::new(Mutex::new(VecDeque::new()));
231 let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
232
233 let pool = Arc::new(ConnectionPool {
234 connections: connections.clone(),
235 semaphore: Arc::new(Semaphore::new(max_size)),
236 max_size,
237 manager,
238 max_idle_time,
239 connection_timeout,
240 cleanup_controller: cleanup_controller.clone(),
241 outstanding_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
242 });
243
244 if cleanup_config.enabled {
246 let manager = Arc::new(pool.manager.clone());
247 tokio::spawn(async move {
248 let mut controller = cleanup_controller.lock().await;
249 controller.start(connections, max_idle_time, cleanup_config.interval, manager, max_size);
250 });
251 }
252
253 pool
254 }
255
256 pub async fn get_connection(self: Arc<Self>) -> Result<ManagedConnection<M>, PoolError<M::Error>> {
258 log::debug!("Attempting to get connection from pool");
259
260 let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
262
263 {
265 let mut connections = self.connections.lock().await;
266 loop {
267 let Some(mut pooled_conn) = connections.pop_front() else {
268 break;
270 };
271 log::trace!("Found existing connection in pool, validating...");
272 let age = Instant::now().duration_since(pooled_conn.created_at);
273 let valid = self.manager.is_valid(&mut pooled_conn.connection).await;
274 if age >= self.max_idle_time {
275 log::debug!("Connection expired (age: {age:?}), discarding");
276 } else if !valid {
277 log::warn!("Connection validation failed, discarding invalid connection");
278 } else {
279 let size = connections.len();
280 log::debug!("Reusing existing connection from pool (remaining: {size}/{})", self.max_size);
281 return Ok(ManagedConnection::new(pooled_conn.connection, self.clone(), permit));
282 }
283 }
284 }
285
286 log::trace!("No valid connection available, creating new connection...");
287 match timeout(self.connection_timeout, self.manager.create_connection()).await {
289 Ok(Ok(connection)) => {
290 log::info!("Successfully created new connection");
291 Ok(ManagedConnection::new(connection, self.clone(), permit))
292 }
293 Ok(Err(e)) => {
294 log::error!("Failed to create new connection");
295 Err(PoolError::Creation(e))
296 }
297 Err(_) => {
298 log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
299 Err(PoolError::Timeout)
300 }
301 }
302 }
303
304 pub fn outstanding_count(&self) -> usize {
305 self.outstanding_count.load(std::sync::atomic::Ordering::SeqCst)
306 }
307
308 pub async fn pool_size(&self) -> usize {
309 self.connections.lock().await.len()
310 }
311
312 pub fn max_size(&self) -> usize {
313 self.max_size
314 }
315
316 pub async fn stop_cleanup_task(&self) {
318 let mut controller = self.cleanup_controller.lock().await;
319 controller.stop().await;
320 }
321
322 pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
324 self.stop_cleanup_task().await;
325
326 if cleanup_config.enabled {
327 let manager = Arc::new(self.manager.clone());
328 let mut controller = self.cleanup_controller.lock().await;
329 let m = self.max_size;
330 controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval, manager, m);
331 }
332 }
333}
334
335impl<M> ConnectionPool<M>
336where
337 M: ConnectionManager,
338{
339 async fn recycle(&self, mut connection: M::Connection) {
340 if !self.manager.is_valid(&mut connection).await {
341 log::debug!("Invalid connection, dropping");
342 return;
343 }
344 let mut connections = self.connections.lock().await;
345 if connections.len() < self.max_size {
346 connections.push_back(InnerConnection {
347 connection,
348 created_at: Instant::now(),
349 });
350 log::debug!("Connection recycled to pool (pool size: {}/{})", connections.len(), self.max_size);
351 } else {
352 log::debug!("Pool is full, dropping connection (pool max size: {})", self.max_size);
353 }
354 }
356}
357
358impl<M> Drop for ManagedConnection<M>
359where
360 M: ConnectionManager + Send + Sync + Clone + 'static,
361{
362 fn drop(&mut self) {
363 self.pool.outstanding_count.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
364 if let Some(connection) = self.connection.take() {
365 let pool = self.pool.clone();
366 _ = tokio::spawn(async move {
367 log::trace!("Recycling connection to pool on drop");
368 pool.recycle(connection).await;
369 });
370 }
371 }
372}
373
374impl<M> AsRef<M::Connection> for ManagedConnection<M>
376where
377 M: ConnectionManager + Send + Sync + Clone + 'static,
378{
379 fn as_ref(&self) -> &M::Connection {
380 self.connection.as_ref().unwrap()
381 }
382}
383
384impl<M> AsMut<M::Connection> for ManagedConnection<M>
385where
386 M: ConnectionManager + Send + Sync + Clone + 'static,
387{
388 fn as_mut(&mut self) -> &mut M::Connection {
389 self.connection.as_mut().unwrap()
390 }
391}
392
393impl<M> std::ops::Deref for ManagedConnection<M>
395where
396 M: ConnectionManager + Send + Sync + Clone + 'static,
397{
398 type Target = M::Connection;
399
400 fn deref(&self) -> &Self::Target {
401 self.connection.as_ref().unwrap()
402 }
403}
404
405impl<M> std::ops::DerefMut for ManagedConnection<M>
406where
407 M: ConnectionManager + Send + Sync + Clone + 'static,
408{
409 fn deref_mut(&mut self) -> &mut Self::Target {
410 self.connection.as_mut().unwrap()
411 }
412}
413
/// Errors returned by `ConnectionPool::get_connection`.
///
/// `E` is the manager's creation error type.
#[derive(Debug)]
pub enum PoolError<E> {
    /// The pool's semaphore was closed, so no permit could be acquired.
    PoolClosed,
    /// Creating a new connection took longer than the configured connection timeout.
    Timeout,
    /// The connection manager failed to create a connection.
    Creation(E),
}
421
422impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
423 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424 match self {
425 PoolError::PoolClosed => write!(f, "Connection pool is closed"),
426 PoolError::Timeout => write!(f, "Connection creation timeout"),
427 PoolError::Creation(e) => write!(f, "Connection creation failed: {e}"),
428 }
429 }
430}
431
432impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
433 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
434 match self {
435 PoolError::Creation(e) => Some(e),
436 _ => None,
437 }
438 }
439}