1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Default value for `max_size`: the number of connections that may be checked out at
/// once, and the maximum number of idle connections kept for reuse.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Default maximum idle age of a pooled connection (measured from when it was last
/// returned to the pool) before it is discarded.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Default time limit for a single `ConnectionManager::create_connection` call.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period between two background cleanup passes.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);

/// Settings for the pool's background cleanup task.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CleanupConfig {
    /// Time between two cleanup passes.
    pub interval: Duration,
    /// Whether the cleanup task is started at all.
    pub enabled: bool,
}

impl Default for CleanupConfig {
    /// Cleanup enabled, running every [`DEFAULT_CLEANUP_INTERVAL`].
    fn default() -> Self {
        Self {
            interval: DEFAULT_CLEANUP_INTERVAL,
            enabled: true,
        }
    }
}
29
30struct CleanupTaskController {
32 handle: Option<JoinHandle<()>>,
33 shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
34}
35
36impl CleanupTaskController {
37 fn new() -> Self {
38 Self {
39 handle: None,
40 shutdown_tx: None,
41 }
42 }
43
44 fn start<T, M>(
45 &mut self,
46 connections: Arc<Mutex<VecDeque<InnerConnection<T>>>>,
47 max_idle_time: Duration,
48 cleanup_interval: Duration,
49 manager: Arc<M>,
50 max_size: usize,
51 ) where
52 T: Send + 'static,
53 M: ConnectionManager<Connection = T> + Send + Sync + 'static,
54 {
55 if self.handle.is_some() {
56 log::warn!("Cleanup task is already running");
57 return;
58 }
59 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
60 self.shutdown_tx = Some(shutdown_tx);
61 let handle = tokio::spawn(async move {
62 let mut interval_timer = interval(cleanup_interval);
63 log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
64
65 loop {
66 tokio::select! {
67 _ = interval_timer.tick() => {}
68 _ = &mut shutdown_rx => {
69 log::info!("Received shutdown signal, exiting cleanup loop");
70 break;
71 }
72 };
73
74 let mut connections = connections.lock().await;
75 let initial_count = connections.len();
76 let now = Instant::now();
77
78 let mut valid_connections = VecDeque::new();
80 for mut conn in connections.drain(..) {
81 let not_expired = now.duration_since(conn.created_at) < max_idle_time;
82 let is_valid = if not_expired {
83 manager.is_valid(&mut conn.connection).await
84 } else {
85 false
86 };
87 if not_expired && is_valid {
88 valid_connections.push_back(conn);
89 }
90 }
91 let removed_count = initial_count - valid_connections.len();
92 *connections = valid_connections;
93
94 if removed_count > 0 {
95 log::debug!("Background cleanup removed {removed_count} expired/invalid connections");
96 }
97
98 log::debug!("Current pool (remaining {}/{max_size}) after cleanup", connections.len());
99 }
100 });
101
102 self.handle = Some(handle);
103 }
104
105 async fn stop(&mut self) {
106 if let Some(shutdown_tx) = self.shutdown_tx.take() {
107 let _ = shutdown_tx.send(());
108 }
109 if let Some(handle) = self.handle.take() {
110 if let Err(e) = handle.await {
111 log::error!("Error while stopping background cleanup task: {e}");
112 }
113 log::info!("Background cleanup task stopped in async stop");
114 }
115 }
116
117 fn stop_sync(&mut self) {
118 if let Some(shutdown_tx) = self.shutdown_tx.take() {
119 let _ = shutdown_tx.send(());
120 }
121 if let Some(handle) = self.handle.take() {
122 std::thread::sleep(std::time::Duration::from_millis(100));
123 handle.abort();
124 log::info!("Background cleanup task stopped in stop_sync");
125 }
126 }
127}
128
129impl Drop for CleanupTaskController {
130 fn drop(&mut self) {
131 self.stop_sync();
132 }
133}
134
/// Factory and health checker for the connections stored in a [`ConnectionPool`].
///
/// The pool clones the manager to hand a copy to its background cleanup task, hence the
/// `Clone` bound.
pub trait ConnectionManager: Send + Sync + Clone {
    /// The connection type produced and validated by this manager.
    type Connection: Send;

    /// Error returned when [`create_connection`](Self::create_connection) fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`create_connection`](Self::create_connection).
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by [`is_valid`](Self::is_valid); it may borrow both the manager and
    /// the connection for `'a`.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Opens a brand-new connection.
    fn create_connection(&self) -> Self::CreateFut;

    /// Checks whether `connection` is still usable; `false` makes the pool discard it.
    fn is_valid<'a>(&'a self, connection: &'a mut Self::Connection) -> Self::ValidFut<'a>;
}
157
158pub struct ConnectionPool<M: ConnectionManager> {
160 connections: Arc<Mutex<VecDeque<InnerConnection<M::Connection>>>>,
161 semaphore: Arc<Semaphore>,
162 max_size: usize,
163 manager: M,
164 max_idle_time: Duration,
165 connection_timeout: Duration,
166 cleanup_controller: Arc<Mutex<CleanupTaskController>>,
167 outstanding_count: Arc<std::sync::atomic::AtomicUsize>,
168}
169
/// An idle connection stored in the pool together with its timestamp.
struct InnerConnection<T> {
    /// The pooled connection itself.
    pub connection: T,
    /// Reference point for the idle-expiry checks (set when the connection is pushed).
    pub created_at: Instant,
}
175
176pub struct ManagedConnection<M>
178where
179 M: ConnectionManager + Send + Sync + 'static,
180{
181 connection: Option<M::Connection>,
182 pool: Arc<ConnectionPool<M>>,
183 _permit: tokio::sync::OwnedSemaphorePermit,
184}
185
186impl<M: ConnectionManager> ManagedConnection<M> {
187 pub fn into_inner(mut self) -> M::Connection {
189 self.connection.take().unwrap()
190 }
191
192 fn new(connection: M::Connection, pool: Arc<ConnectionPool<M>>, permit: tokio::sync::OwnedSemaphorePermit) -> Self {
193 pool.outstanding_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
194 ManagedConnection {
195 connection: Some(connection),
196 pool,
197 _permit: permit,
198 }
199 }
200}
201
202impl<M> ConnectionPool<M>
203where
204 M: ConnectionManager + Send + Sync + Clone + 'static,
205{
206 pub fn new(
208 max_size: Option<usize>,
209 max_idle_time: Option<Duration>,
210 connection_timeout: Option<Duration>,
211 cleanup_config: Option<CleanupConfig>,
212 manager: M,
213 ) -> Arc<Self> {
214 let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
215 let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
216 let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
217 let cleanup_config = cleanup_config.unwrap_or_default();
218
219 log::info!(
220 "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
221 cleanup_config.enabled
222 );
223
224 let connections = Arc::new(Mutex::new(VecDeque::new()));
225 let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
226
227 let pool = Arc::new(ConnectionPool {
228 connections: connections.clone(),
229 semaphore: Arc::new(Semaphore::new(max_size)),
230 max_size,
231 manager,
232 max_idle_time,
233 connection_timeout,
234 cleanup_controller: cleanup_controller.clone(),
235 outstanding_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
236 });
237
238 if cleanup_config.enabled {
240 let manager = Arc::new(pool.manager.clone());
241 tokio::spawn(async move {
242 let mut controller = cleanup_controller.lock().await;
243 controller.start(connections, max_idle_time, cleanup_config.interval, manager, max_size);
244 });
245 }
246
247 pool
248 }
249
250 pub async fn get_connection(self: Arc<Self>) -> Result<ManagedConnection<M>, PoolError<M::Error>> {
252 log::debug!("Attempting to get connection from pool");
253
254 let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
256
257 {
259 let mut connections = self.connections.lock().await;
260 loop {
261 let Some(mut pooled_conn) = connections.pop_front() else {
262 break;
264 };
265 log::trace!("Found existing connection in pool, validating...");
266 let age = Instant::now().duration_since(pooled_conn.created_at);
267 let is_valid = if age < self.max_idle_time {
268 let r = self.manager.is_valid(&mut pooled_conn.connection).await;
269 if !r {
270 log::warn!("Connection validation failed, discarding invalid connection");
271 }
272 r
273 } else {
274 log::debug!("Connection expired (age: {age:?}), discarding");
275 false
276 };
277 if is_valid {
278 let size = connections.len();
279 log::debug!("Reusing existing connection from pool (remaining: {size}/{})", self.max_size);
280 return Ok(ManagedConnection::new(pooled_conn.connection, self.clone(), permit));
281 }
282 }
283 }
284
285 log::trace!("No valid connection available, creating new connection...");
286 match timeout(self.connection_timeout, self.manager.create_connection()).await {
288 Ok(Ok(connection)) => {
289 log::info!("Successfully created new connection");
290 Ok(ManagedConnection::new(connection, self.clone(), permit))
291 }
292 Ok(Err(e)) => {
293 log::error!("Failed to create new connection");
294 Err(PoolError::Creation(e))
295 }
296 Err(_) => {
297 log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
298 Err(PoolError::Timeout)
299 }
300 }
301 }
302
303 pub fn outstanding_count(&self) -> usize {
304 self.outstanding_count.load(std::sync::atomic::Ordering::SeqCst)
305 }
306
307 pub async fn pool_size(&self) -> usize {
308 self.connections.lock().await.len()
309 }
310
311 pub fn max_size(&self) -> usize {
312 self.max_size
313 }
314
315 pub async fn stop_cleanup_task(&self) {
317 let mut controller = self.cleanup_controller.lock().await;
318 controller.stop().await;
319 }
320
321 pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
323 self.stop_cleanup_task().await;
324
325 if cleanup_config.enabled {
326 let manager = Arc::new(self.manager.clone());
327 let mut controller = self.cleanup_controller.lock().await;
328 let m = self.max_size;
329 controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval, manager, m);
330 }
331 }
332}
333
334impl<M: ConnectionManager> ConnectionPool<M> {
335 async fn recycle(&self, mut connection: M::Connection) {
336 if !self.manager.is_valid(&mut connection).await {
337 log::debug!("Invalid connection, dropping");
338 return;
339 }
340 let mut connections = self.connections.lock().await;
341 if connections.len() < self.max_size {
342 connections.push_back(InnerConnection {
343 connection,
344 created_at: Instant::now(),
345 });
346 log::debug!("Connection recycled to pool (pool size: {}/{})", connections.len(), self.max_size);
347 } else {
348 log::debug!("Pool is full, dropping connection (pool max size: {})", self.max_size);
349 }
350 }
352}
353
354impl<M: ConnectionManager> Drop for ManagedConnection<M> {
355 fn drop(&mut self) {
356 self.pool.outstanding_count.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
357 if let Some(connection) = self.connection.take() {
358 let pool = self.pool.clone();
359 _ = tokio::spawn(async move {
360 log::trace!("Recycling connection to pool on drop");
361 pool.recycle(connection).await;
362 });
363 }
364 }
365}
366
367impl<M: ConnectionManager> AsRef<M::Connection> for ManagedConnection<M> {
369 fn as_ref(&self) -> &M::Connection {
370 self.connection.as_ref().unwrap()
371 }
372}
373
374impl<M: ConnectionManager> AsMut<M::Connection> for ManagedConnection<M> {
375 fn as_mut(&mut self) -> &mut M::Connection {
376 self.connection.as_mut().unwrap()
377 }
378}
379
380impl<M: ConnectionManager> std::ops::Deref for ManagedConnection<M> {
382 type Target = M::Connection;
383
384 fn deref(&self) -> &Self::Target {
385 self.connection.as_ref().unwrap()
386 }
387}
388
389impl<M: ConnectionManager> std::ops::DerefMut for ManagedConnection<M> {
390 fn deref_mut(&mut self) -> &mut Self::Target {
391 self.connection.as_mut().unwrap()
392 }
393}
394
/// Errors returned by `ConnectionPool::get_connection`.
#[derive(Debug)]
pub enum PoolError<E> {
    /// Acquiring a semaphore permit failed because the semaphore was closed.
    PoolClosed,
    /// Creating a new connection took longer than the pool's `connection_timeout`.
    Timeout,
    /// The manager failed to create a connection; carries the manager's error.
    Creation(E),
}
402
403impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
404 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405 match self {
406 PoolError::PoolClosed => write!(f, "Connection pool is closed"),
407 PoolError::Timeout => write!(f, "Connection creation timeout"),
408 PoolError::Creation(e) => write!(f, "Connection creation failed: {e}"),
409 }
410 }
411}
412
413impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
414 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
415 match self {
416 PoolError::Creation(e) => Some(e),
417 _ => None,
418 }
419 }
420}