1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Default number of connections that may be checked out at once; the pool also
/// keeps at most this many idle connections.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Default age after which an idle connection is discarded instead of reused.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Default time allowed for creating a single new connection.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period between two background cleanup rounds.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);

/// Settings for the pool's background cleanup task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CleanupConfig {
    /// Time between two cleanup rounds.
    pub interval: Duration,
    /// Whether the background cleanup task runs at all.
    pub enabled: bool,
}

impl Default for CleanupConfig {
    /// Cleanup enabled, running every [`DEFAULT_CLEANUP_INTERVAL`].
    fn default() -> Self {
        Self {
            interval: DEFAULT_CLEANUP_INTERVAL,
            enabled: true,
        }
    }
}
29
30struct CleanupTaskController {
32 handle: Option<JoinHandle<()>>,
33 shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
34}
35
36impl CleanupTaskController {
37 fn new() -> Self {
38 Self {
39 handle: None,
40 shutdown_tx: None,
41 }
42 }
43
44 fn start<T, M>(
45 &mut self,
46 connections: Arc<Mutex<VecDeque<InnerConnection<T>>>>,
47 max_idle_time: Duration,
48 cleanup_interval: Duration,
49 manager: Arc<M>,
50 ) where
51 T: Send + 'static,
52 M: ConnectionManager<Connection = T> + Send + Sync + 'static,
53 {
54 if self.handle.is_some() {
55 log::warn!("Cleanup task is already running");
56 return;
57 }
58 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
59 self.shutdown_tx = Some(shutdown_tx);
60 let handle = tokio::spawn(async move {
61 let mut interval_timer = interval(cleanup_interval);
62 log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
63
64 loop {
65 tokio::select! {
66 _ = interval_timer.tick() => {}
67 _ = &mut shutdown_rx => {
68 log::info!("Received shutdown signal, exiting cleanup loop");
69 break;
70 }
71 };
72
73 let mut connections = connections.lock().await;
74 let initial_count = connections.len();
75 let now = Instant::now();
76
77 let mut valid_connections = VecDeque::new();
79 for mut conn in connections.drain(..) {
80 let not_expired = now.duration_since(conn.created_at) < max_idle_time;
81 let is_valid = if not_expired {
82 manager.is_valid(&mut conn.connection).await
83 } else {
84 false
85 };
86 if not_expired && is_valid {
87 valid_connections.push_back(conn);
88 }
89 }
90 let removed_count = initial_count - valid_connections.len();
91 *connections = valid_connections;
92
93 if removed_count > 0 {
94 log::debug!("Background cleanup removed {removed_count} expired/invalid connections");
95 }
96
97 log::debug!("Current pool (remaining {}) after cleanup", connections.len());
98 }
99 });
100
101 self.handle = Some(handle);
102 }
103
104 async fn stop(&mut self) {
105 if let Some(shutdown_tx) = self.shutdown_tx.take() {
106 let _ = shutdown_tx.send(());
107 }
108 if let Some(handle) = self.handle.take() {
109 if let Err(e) = handle.await {
110 log::error!("Error while stopping background cleanup task: {e}");
111 }
112 log::info!("Background cleanup task stopped in async stop");
113 }
114 }
115
116 fn stop_sync(&mut self) {
117 if let Some(shutdown_tx) = self.shutdown_tx.take() {
118 let _ = shutdown_tx.send(());
119 }
120 if let Some(handle) = self.handle.take() {
121 std::thread::sleep(std::time::Duration::from_millis(100));
122 handle.abort();
123 log::info!("Background cleanup task stopped in stop_sync");
124 }
125 }
126}
127
128impl Drop for CleanupTaskController {
129 fn drop(&mut self) {
130 self.stop_sync();
131 }
132}
133
/// Creates and health-checks the connections held by a `ConnectionPool`.
///
/// Both operations return named future types rather than boxed futures. The
/// pool calls `is_valid` before reusing an idle connection, when a connection
/// is returned, and during background cleanup.
pub trait ConnectionManager: Send + Sync + Clone {
    /// The connection type handed out by the pool.
    type Connection: Send;

    /// Error produced when opening a new connection fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`ConnectionManager::create_connection`].
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by [`ConnectionManager::is_valid`]. It may borrow both the
    /// manager and the connection for `'a`.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Opens a brand-new connection.
    fn create_connection(&self) -> Self::CreateFut;

    /// Resolves to `true` when `connection` is still usable and may be kept.
    fn is_valid<'a>(&'a self, connection: &'a mut Self::Connection) -> Self::ValidFut<'a>;
}
156
157pub struct ConnectionPool<M>
159where
160 M: ConnectionManager + Send + Sync + Clone + 'static,
161{
162 connections: Arc<Mutex<VecDeque<InnerConnection<M::Connection>>>>,
163 semaphore: Arc<Semaphore>,
164 max_size: usize,
165 manager: M,
166 max_idle_time: Duration,
167 connection_timeout: Duration,
168 cleanup_controller: Arc<Mutex<CleanupTaskController>>,
169}
170
/// An idle connection stored in the pool, with the instant it was stored.
///
/// The pool compares `created_at` against its idle timeout. When a connection
/// is returned to the pool, this timestamp is reset.
struct InnerConnection<T> {
    /// The pooled connection itself.
    pub connection: T,
    /// When this entry was (re)inserted into the pool.
    pub created_at: Instant,
}
176
177pub struct ManagedConnection<M>
179where
180 M: ConnectionManager + Send + Sync + 'static,
181{
182 connection: Option<M::Connection>,
183 pool: Arc<ConnectionPool<M>>,
184 _permit: tokio::sync::OwnedSemaphorePermit,
185}
186
187impl<M> ConnectionPool<M>
188where
189 M: ConnectionManager + Send + Sync + Clone + 'static,
190{
191 pub fn new(
193 max_size: Option<usize>,
194 max_idle_time: Option<Duration>,
195 connection_timeout: Option<Duration>,
196 cleanup_config: Option<CleanupConfig>,
197 manager: M,
198 ) -> Arc<Self> {
199 let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
200 let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
201 let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
202 let cleanup_config = cleanup_config.unwrap_or_default();
203
204 log::info!(
205 "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
206 cleanup_config.enabled
207 );
208
209 let connections = Arc::new(Mutex::new(VecDeque::new()));
210 let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
211
212 let pool = Arc::new(ConnectionPool {
213 connections: connections.clone(),
214 semaphore: Arc::new(Semaphore::new(max_size)),
215 max_size,
216 manager,
217 max_idle_time,
218 connection_timeout,
219 cleanup_controller: cleanup_controller.clone(),
220 });
221
222 if cleanup_config.enabled {
224 let manager = Arc::new(pool.manager.clone());
225 tokio::spawn(async move {
226 let mut controller = cleanup_controller.lock().await;
227 controller.start(connections, max_idle_time, cleanup_config.interval, manager);
228 });
229 }
230
231 pool
232 }
233
234 pub async fn get_connection(self: Arc<Self>) -> Result<ManagedConnection<M>, PoolError<M::Error>> {
236 log::debug!("Attempting to get connection from pool");
237
238 let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
240
241 {
243 let mut connections = self.connections.lock().await;
244 loop {
245 let Some(mut pooled_conn) = connections.pop_front() else {
246 break;
248 };
249 log::trace!("Found existing connection in pool, validating...");
250 let age = Instant::now().duration_since(pooled_conn.created_at);
251 let valid = self.manager.is_valid(&mut pooled_conn.connection).await;
252 if age >= self.max_idle_time {
253 log::debug!("Connection expired (age: {age:?}), discarding");
254 } else if !valid {
255 log::warn!("Connection validation failed, discarding invalid connection");
256 } else {
257 let size = connections.len();
258 log::debug!("Reusing existing connection from pool (remaining: {size}/{})", self.max_size);
259 return Ok(ManagedConnection {
260 connection: Some(pooled_conn.connection),
261 pool: self.clone(),
262 _permit: permit,
263 });
264 }
265 }
266 }
267
268 log::trace!("No valid connection available, creating new connection...");
269 match timeout(self.connection_timeout, self.manager.create_connection()).await {
271 Ok(Ok(connection)) => {
272 log::info!("Successfully created new connection");
273 Ok(ManagedConnection {
274 connection: Some(connection),
275 pool: self.clone(),
276 _permit: permit,
277 })
278 }
279 Ok(Err(e)) => {
280 log::error!("Failed to create new connection");
281 Err(PoolError::Creation(e))
282 }
283 Err(_) => {
284 log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
285 Err(PoolError::Timeout)
286 }
287 }
288 }
289
290 pub async fn stop_cleanup_task(&self) {
292 let mut controller = self.cleanup_controller.lock().await;
293 controller.stop().await;
294 }
295
296 pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
298 self.stop_cleanup_task().await;
299
300 if cleanup_config.enabled {
301 let manager = Arc::new(self.manager.clone());
302 let mut controller = self.cleanup_controller.lock().await;
303 controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval, manager);
304 }
305 }
306}
307
308impl<M> ConnectionPool<M>
309where
310 M: ConnectionManager + Send + Sync + Clone + 'static,
311{
312 async fn recycle(&self, mut connection: M::Connection) {
313 if !self.manager.is_valid(&mut connection).await {
314 log::debug!("Invalid connection, dropping");
315 return;
316 }
317 let mut connections = self.connections.lock().await;
318 if connections.len() < self.max_size {
319 connections.push_back(InnerConnection {
320 connection,
321 created_at: Instant::now(),
322 });
323 log::debug!("Connection returned to pool (pool size: {}/{})", connections.len(), self.max_size);
324 } else {
325 log::debug!("Pool is full, dropping connection (pool max size: {})", self.max_size);
326 }
327 }
329}
330
331impl<M> Drop for ManagedConnection<M>
332where
333 M: ConnectionManager + Send + Sync + Clone + 'static,
334{
335 fn drop(&mut self) {
336 if let Some(connection) = self.connection.take() {
337 let pool = self.pool.clone();
338 _ = tokio::spawn(async move {
339 log::trace!("Returning connection to pool on drop");
340 pool.recycle(connection).await;
341 });
342 }
343 }
344}
345
346impl<M> AsRef<M::Connection> for ManagedConnection<M>
348where
349 M: ConnectionManager + Send + Sync + Clone + 'static,
350{
351 fn as_ref(&self) -> &M::Connection {
352 self.connection.as_ref().unwrap()
353 }
354}
355
356impl<M> AsMut<M::Connection> for ManagedConnection<M>
357where
358 M: ConnectionManager + Send + Sync + Clone + 'static,
359{
360 fn as_mut(&mut self) -> &mut M::Connection {
361 self.connection.as_mut().unwrap()
362 }
363}
364
365impl<M> std::ops::Deref for ManagedConnection<M>
367where
368 M: ConnectionManager + Send + Sync + Clone + 'static,
369{
370 type Target = M::Connection;
371
372 fn deref(&self) -> &Self::Target {
373 self.connection.as_ref().unwrap()
374 }
375}
376
377impl<M> std::ops::DerefMut for ManagedConnection<M>
378where
379 M: ConnectionManager + Send + Sync + Clone + 'static,
380{
381 fn deref_mut(&mut self) -> &mut Self::Target {
382 self.connection.as_mut().unwrap()
383 }
384}
385
/// Errors returned when checking a connection out of the pool.
#[derive(Debug)]
pub enum PoolError<E> {
    /// No permit could be acquired because the pool's semaphore was closed.
    PoolClosed,
    /// Creating a new connection did not finish within the configured timeout.
    Timeout,
    /// The connection manager failed to create a connection.
    Creation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let summary = match self {
            Self::PoolClosed => "Connection pool is closed",
            Self::Timeout => "Connection creation timeout",
            Self::Creation(cause) => return write!(f, "Connection creation failed: {cause}"),
        };
        f.write_str(summary)
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    /// Exposes the manager's error for `Creation`; other variants have no cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Creation(cause) = self {
            Some(cause)
        } else {
            None
        }
    }
}