1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Fallback for the pool's `max_size` when the caller passes `None`.
pub const DEFAULT_MAX_SIZE: usize = 10;

/// Fallback for how long a connection may sit idle in the pool before it is discarded.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Fallback for the time limit applied to creating a new connection.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);

/// Fallback for the period between two background cleanup sweeps.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);

/// Settings for the background task that evicts expired or invalid idle connections.
#[derive(Clone, Debug)]
pub struct CleanupConfig {
    /// Period between two cleanup sweeps.
    pub interval: Duration,
    /// Whether the background cleanup task should run at all.
    pub enabled: bool,
}
20
21impl Default for CleanupConfig {
22 fn default() -> Self {
23 Self {
24 interval: DEFAULT_CLEANUP_INTERVAL,
25 enabled: true,
26 }
27 }
28}
29
30struct CleanupTaskController {
32 handle: Option<JoinHandle<()>>,
33}
34
35impl CleanupTaskController {
36 fn new() -> Self {
37 Self { handle: None }
38 }
39
40 fn start<T, M>(
41 &mut self,
42 connections: Arc<Mutex<VecDeque<InnerConnection<T>>>>,
43 max_idle_time: Duration,
44 cleanup_interval: Duration,
45 manager: Arc<M>,
46 ) where
47 T: Send + 'static,
48 M: ConnectionManager<Connection = T> + Send + Sync + 'static,
49 {
50 if self.handle.is_some() {
51 log::warn!("Cleanup task is already running");
52 return;
53 }
54
55 let handle = tokio::spawn(async move {
56 let mut interval_timer = interval(cleanup_interval);
57 log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
58
59 loop {
60 interval_timer.tick().await;
61
62 let mut connections = connections.lock().await;
63 let initial_count = connections.len();
64 let now = Instant::now();
65
66 let mut valid_connections = VecDeque::new();
68 for mut conn in connections.drain(..) {
69 let not_expired = now.duration_since(conn.created_at) < max_idle_time;
70 let is_valid = if not_expired {
71 manager.is_valid(&mut conn.connection).await
72 } else {
73 false
74 };
75 if not_expired && is_valid {
76 valid_connections.push_back(conn);
77 }
78 }
79 let removed_count = initial_count - valid_connections.len();
80 *connections = valid_connections;
81
82 if removed_count > 0 {
83 log::debug!("Background cleanup removed {removed_count} expired/invalid connections");
84 }
85
86 log::trace!("Current pool size after cleanup: {}", connections.len());
87 }
88 });
89
90 self.handle = Some(handle);
91 }
92
93 fn stop(&mut self) {
94 if let Some(handle) = self.handle.take() {
95 handle.abort();
96 log::info!("Background cleanup task stopped");
97 }
98 }
99}
100
101impl Drop for CleanupTaskController {
102 fn drop(&mut self) {
103 self.stop();
104 }
105}
106
/// Factory and health checker for the connections held by a pool.
///
/// Implementors must be cloneable and shareable across threads.
pub trait ConnectionManager: Clone + Send + Sync {
    /// The connection type produced by this manager.
    type Connection: Send;

    /// Error reported when a connection cannot be created.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`ConnectionManager::create_connection`].
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by [`ConnectionManager::is_valid`]; it may borrow both the
    /// manager and the connection for `'a`.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Starts creating a new connection.
    fn create_connection(&self) -> Self::CreateFut;

    /// Checks whether `connection` is still usable; resolves to `true` when it is.
    fn is_valid<'a>(&'a self, connection: &'a mut Self::Connection) -> Self::ValidFut<'a>;
}
129
130pub struct ConnectionPool<M>
132where
133 M: ConnectionManager + Send + Sync + Clone + 'static,
134{
135 connections: Arc<Mutex<VecDeque<InnerConnection<M::Connection>>>>,
136 semaphore: Arc<Semaphore>,
137 max_size: usize,
138 manager: M,
139 max_idle_time: Duration,
140 connection_timeout: Duration,
141 cleanup_controller: Arc<Mutex<CleanupTaskController>>,
142}
143
/// An idle connection together with the moment it entered the idle queue.
struct InnerConnection<T> {
    /// The pooled connection itself.
    pub connection: T,
    /// When the entry was created; compared against the pool's idle timeout.
    pub created_at: Instant,
}
149
150pub struct ManagedConnection<M>
152where
153 M: ConnectionManager + Send + Sync + 'static,
154{
155 connection: Option<M::Connection>,
156 pool: Arc<ConnectionPool<M>>,
157 _permit: tokio::sync::OwnedSemaphorePermit,
158}
159
160impl<M> ConnectionPool<M>
161where
162 M: ConnectionManager + Send + Sync + Clone + 'static,
163{
164 pub fn new(
166 max_size: Option<usize>,
167 max_idle_time: Option<Duration>,
168 connection_timeout: Option<Duration>,
169 cleanup_config: Option<CleanupConfig>,
170 manager: M,
171 ) -> Arc<Self> {
172 let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
173 let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
174 let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
175 let cleanup_config = cleanup_config.unwrap_or_default();
176
177 log::info!(
178 "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
179 cleanup_config.enabled
180 );
181
182 let connections = Arc::new(Mutex::new(VecDeque::new()));
183 let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
184
185 let pool = Arc::new(ConnectionPool {
186 connections: connections.clone(),
187 semaphore: Arc::new(Semaphore::new(max_size)),
188 max_size,
189 manager,
190 max_idle_time,
191 connection_timeout,
192 cleanup_controller: cleanup_controller.clone(),
193 });
194
195 if cleanup_config.enabled {
197 let manager = Arc::new(pool.manager.clone());
198 tokio::spawn(async move {
199 let mut controller = cleanup_controller.lock().await;
200 controller.start(connections, max_idle_time, cleanup_config.interval, manager);
201 });
202 }
203
204 pool
205 }
206
207 pub async fn get_connection(self: Arc<Self>) -> Result<ManagedConnection<M>, PoolError<M::Error>> {
209 log::debug!("Attempting to get connection from pool");
210
211 let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
213
214 {
216 let mut connections = self.connections.lock().await;
217 loop {
218 let Some(mut pooled_conn) = connections.pop_front() else {
219 break;
221 };
222 log::trace!("Found existing connection in pool, validating...");
223 let age = Instant::now().duration_since(pooled_conn.created_at);
224 let valid = self.manager.is_valid(&mut pooled_conn.connection).await;
225 if age >= self.max_idle_time {
226 log::debug!("Connection expired (age: {age:?}), discarding");
227 } else if !valid {
228 log::warn!("Connection validation failed, discarding invalid connection");
229 } else {
230 log::debug!("Reusing existing connection from pool (remaining: {})", connections.len());
231 return Ok(ManagedConnection {
232 connection: Some(pooled_conn.connection),
233 pool: self.clone(),
234 _permit: permit,
235 });
236 }
237 }
238 }
239
240 log::trace!("No valid connection available, creating new connection...");
241 match timeout(self.connection_timeout, self.manager.create_connection()).await {
243 Ok(Ok(connection)) => {
244 log::info!("Successfully created new connection");
245 Ok(ManagedConnection {
246 connection: Some(connection),
247 pool: self.clone(),
248 _permit: permit,
249 })
250 }
251 Ok(Err(e)) => {
252 log::error!("Failed to create new connection");
253 Err(PoolError::Creation(e))
254 }
255 Err(_) => {
256 log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
257 Err(PoolError::Timeout)
258 }
259 }
260 }
261
262 pub async fn stop_cleanup_task(&self) {
264 let mut controller = self.cleanup_controller.lock().await;
265 controller.stop();
266 }
267
268 pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
270 self.stop_cleanup_task().await;
271
272 if cleanup_config.enabled {
273 let manager = Arc::new(self.manager.clone());
274 let mut controller = self.cleanup_controller.lock().await;
275 controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval, manager);
276 }
277 }
278}
279
280impl<M> ConnectionPool<M>
281where
282 M: ConnectionManager + Send + Sync + Clone + 'static,
283{
284 async fn recycle(&self, mut connection: M::Connection) {
285 if !self.manager.is_valid(&mut connection).await {
286 log::debug!("Invalid connection, dropping");
287 return;
288 }
289 let mut connections = self.connections.lock().await;
290 if connections.len() < self.max_size {
291 connections.push_back(InnerConnection {
292 connection,
293 created_at: Instant::now(),
294 });
295 log::trace!("Connection returned to pool (pool size: {})", connections.len());
296 } else {
297 log::trace!("Pool is full, dropping connection (max_size: {})", self.max_size);
298 }
299 }
301}
302
303impl<M> Drop for ManagedConnection<M>
304where
305 M: ConnectionManager + Send + Sync + Clone + 'static,
306{
307 fn drop(&mut self) {
308 if let Some(connection) = self.connection.take() {
309 let pool = self.pool.clone();
310 _ = tokio::spawn(async move {
311 log::trace!("Returning connection to pool on drop");
312 pool.recycle(connection).await;
313 });
314 }
315 }
316}
317
318impl<M> AsRef<M::Connection> for ManagedConnection<M>
320where
321 M: ConnectionManager + Send + Sync + Clone + 'static,
322{
323 fn as_ref(&self) -> &M::Connection {
324 self.connection.as_ref().unwrap()
325 }
326}
327
328impl<M> AsMut<M::Connection> for ManagedConnection<M>
329where
330 M: ConnectionManager + Send + Sync + Clone + 'static,
331{
332 fn as_mut(&mut self) -> &mut M::Connection {
333 self.connection.as_mut().unwrap()
334 }
335}
336
337impl<M> std::ops::Deref for ManagedConnection<M>
339where
340 M: ConnectionManager + Send + Sync + Clone + 'static,
341{
342 type Target = M::Connection;
343
344 fn deref(&self) -> &Self::Target {
345 self.connection.as_ref().unwrap()
346 }
347}
348
349impl<M> std::ops::DerefMut for ManagedConnection<M>
350where
351 M: ConnectionManager + Send + Sync + Clone + 'static,
352{
353 fn deref_mut(&mut self) -> &mut Self::Target {
354 self.connection.as_mut().unwrap()
355 }
356}
357
/// Failures reported when checking a connection out of the pool.
#[derive(Debug)]
pub enum PoolError<E> {
    /// The pool's semaphore was closed, so no permit could be acquired.
    PoolClosed,
    /// Creating a new connection exceeded the configured time limit.
    Timeout,
    /// The connection manager failed to create a connection; carries its error.
    Creation(E),
}
365
366impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
367 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 match self {
369 PoolError::PoolClosed => write!(f, "Connection pool is closed"),
370 PoolError::Timeout => write!(f, "Connection creation timeout"),
371 PoolError::Creation(e) => write!(f, "Connection creation failed: {e}"),
372 }
373 }
374}
375
376impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
377 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
378 match self {
379 PoolError::Creation(e) => Some(e),
380 _ => None,
381 }
382 }
383}