1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Default number of connections that may be checked out at once; also the cap on how many
/// idle connections the pool keeps.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Default idle time (5 minutes) after which a pooled connection is discarded.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Default time limit for a single `ConnectionManager::create_connection` call.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period between background cleanup sweeps.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);

/// Settings for the pool's background cleanup task.
#[derive(Clone, Debug)]
pub struct CleanupConfig {
    /// Time between cleanup sweeps. Should be non-zero: `tokio::time::interval` rejects a
    /// zero period.
    pub interval: Duration,
    /// Whether the background cleanup task is started at all.
    pub enabled: bool,
}
20
21impl Default for CleanupConfig {
22 fn default() -> Self {
23 Self {
24 interval: DEFAULT_CLEANUP_INTERVAL,
25 enabled: true,
26 }
27 }
28}
29
30struct CleanupTaskController {
32 handle: Option<JoinHandle<()>>,
33 shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
34}
35
36impl CleanupTaskController {
37 fn new() -> Self {
38 Self {
39 handle: None,
40 shutdown_tx: None,
41 }
42 }
43
44 fn start<T, M>(
45 &mut self,
46 connections: Arc<Mutex<VecDeque<InnerConnection<T>>>>,
47 max_idle_time: Duration,
48 cleanup_interval: Duration,
49 manager: Arc<M>,
50 ) where
51 T: Send + 'static,
52 M: ConnectionManager<Connection = T> + Send + Sync + 'static,
53 {
54 if self.handle.is_some() {
55 log::warn!("Cleanup task is already running");
56 return;
57 }
58 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
59 self.shutdown_tx = Some(shutdown_tx);
60 let handle = tokio::spawn(async move {
61 let mut interval_timer = interval(cleanup_interval);
62 log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
63
64 loop {
65 tokio::select! {
66 _ = interval_timer.tick() => {}
67 _ = &mut shutdown_rx => {
68 log::info!("Received shutdown signal, exiting cleanup loop");
69 break;
70 }
71 };
72
73 let mut connections = connections.lock().await;
74 let initial_count = connections.len();
75 let now = Instant::now();
76
77 let mut valid_connections = VecDeque::new();
79 for mut conn in connections.drain(..) {
80 let not_expired = now.duration_since(conn.created_at) < max_idle_time;
81 let is_valid = if not_expired {
82 manager.is_valid(&mut conn.connection).await
83 } else {
84 false
85 };
86 if not_expired && is_valid {
87 valid_connections.push_back(conn);
88 }
89 }
90 let removed_count = initial_count - valid_connections.len();
91 *connections = valid_connections;
92
93 if removed_count > 0 {
94 log::debug!("Background cleanup removed {removed_count} expired/invalid connections");
95 }
96
97 log::trace!("Current pool size after cleanup: {}", connections.len());
98 }
99 });
100
101 self.handle = Some(handle);
102 }
103
104 async fn stop(&mut self) {
105 if let Some(shutdown_tx) = self.shutdown_tx.take() {
106 let _ = shutdown_tx.send(());
107 }
108 if let Some(handle) = self.handle.take() {
109 if let Err(e) = handle.await {
110 log::error!("Error while stopping background cleanup task: {e}");
111 }
112 log::info!("Background cleanup task stopped in async stop");
113 }
114 }
115
116 fn stop_sync(&mut self) {
117 if let Some(shutdown_tx) = self.shutdown_tx.take() {
118 let _ = shutdown_tx.send(());
119 }
120 if let Some(handle) = self.handle.take() {
121 std::thread::sleep(std::time::Duration::from_millis(100));
122 handle.abort();
123 log::info!("Background cleanup task stopped in stop_sync");
124 }
125 }
126}
127
128impl Drop for CleanupTaskController {
129 fn drop(&mut self) {
130 self.stop_sync();
131 }
132}
133
/// Creates and health-checks the connections managed by a `ConnectionPool`.
///
/// The pool clones the manager when it starts its background cleanup task.
pub trait ConnectionManager: Send + Sync + Clone {
    /// The connection type produced and validated by this manager.
    type Connection: Send;

    /// Error produced when a connection cannot be created.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by `create_connection`.
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by `is_valid`; it may borrow both the manager and the connection
    /// for `'a`.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Starts creating a new connection. The pool bounds the returned future with its
    /// connection timeout.
    fn create_connection(&self) -> Self::CreateFut;

    /// Reports whether `connection` is still usable. A `false` result makes the pool discard
    /// the connection.
    fn is_valid<'a>(&'a self, connection: &'a mut Self::Connection) -> Self::ValidFut<'a>;
}
156
157pub struct ConnectionPool<M>
159where
160 M: ConnectionManager + Send + Sync + Clone + 'static,
161{
162 connections: Arc<Mutex<VecDeque<InnerConnection<M::Connection>>>>,
163 semaphore: Arc<Semaphore>,
164 max_size: usize,
165 manager: M,
166 max_idle_time: Duration,
167 connection_timeout: Duration,
168 cleanup_controller: Arc<Mutex<CleanupTaskController>>,
169}
170
/// An idle connection stored in the pool together with its timestamp.
struct InnerConnection<T> {
    /// Stamped with `Instant::now()` when `recycle` puts the connection back in the pool.
    /// It therefore measures idle time rather than connection age, and is compared
    /// against `max_idle_time`.
    pub created_at: Instant,
    /// The pooled connection itself.
    pub connection: T,
}
176
177pub struct ManagedConnection<M>
179where
180 M: ConnectionManager + Send + Sync + 'static,
181{
182 connection: Option<M::Connection>,
183 pool: Arc<ConnectionPool<M>>,
184 _permit: tokio::sync::OwnedSemaphorePermit,
185}
186
187impl<M> ConnectionPool<M>
188where
189 M: ConnectionManager + Send + Sync + Clone + 'static,
190{
191 pub fn new(
193 max_size: Option<usize>,
194 max_idle_time: Option<Duration>,
195 connection_timeout: Option<Duration>,
196 cleanup_config: Option<CleanupConfig>,
197 manager: M,
198 ) -> Arc<Self> {
199 let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
200 let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
201 let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
202 let cleanup_config = cleanup_config.unwrap_or_default();
203
204 log::info!(
205 "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
206 cleanup_config.enabled
207 );
208
209 let connections = Arc::new(Mutex::new(VecDeque::new()));
210 let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
211
212 let pool = Arc::new(ConnectionPool {
213 connections: connections.clone(),
214 semaphore: Arc::new(Semaphore::new(max_size)),
215 max_size,
216 manager,
217 max_idle_time,
218 connection_timeout,
219 cleanup_controller: cleanup_controller.clone(),
220 });
221
222 if cleanup_config.enabled {
224 let manager = Arc::new(pool.manager.clone());
225 tokio::spawn(async move {
226 let mut controller = cleanup_controller.lock().await;
227 controller.start(connections, max_idle_time, cleanup_config.interval, manager);
228 });
229 }
230
231 pool
232 }
233
234 pub async fn get_connection(self: Arc<Self>) -> Result<ManagedConnection<M>, PoolError<M::Error>> {
236 log::debug!("Attempting to get connection from pool");
237
238 let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
240
241 {
243 let mut connections = self.connections.lock().await;
244 loop {
245 let Some(mut pooled_conn) = connections.pop_front() else {
246 break;
248 };
249 log::trace!("Found existing connection in pool, validating...");
250 let age = Instant::now().duration_since(pooled_conn.created_at);
251 let valid = self.manager.is_valid(&mut pooled_conn.connection).await;
252 if age >= self.max_idle_time {
253 log::debug!("Connection expired (age: {age:?}), discarding");
254 } else if !valid {
255 log::warn!("Connection validation failed, discarding invalid connection");
256 } else {
257 log::debug!("Reusing existing connection from pool (remaining: {})", connections.len());
258 return Ok(ManagedConnection {
259 connection: Some(pooled_conn.connection),
260 pool: self.clone(),
261 _permit: permit,
262 });
263 }
264 }
265 }
266
267 log::trace!("No valid connection available, creating new connection...");
268 match timeout(self.connection_timeout, self.manager.create_connection()).await {
270 Ok(Ok(connection)) => {
271 log::info!("Successfully created new connection");
272 Ok(ManagedConnection {
273 connection: Some(connection),
274 pool: self.clone(),
275 _permit: permit,
276 })
277 }
278 Ok(Err(e)) => {
279 log::error!("Failed to create new connection");
280 Err(PoolError::Creation(e))
281 }
282 Err(_) => {
283 log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
284 Err(PoolError::Timeout)
285 }
286 }
287 }
288
289 pub async fn stop_cleanup_task(&self) {
291 let mut controller = self.cleanup_controller.lock().await;
292 controller.stop().await;
293 }
294
295 pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
297 self.stop_cleanup_task().await;
298
299 if cleanup_config.enabled {
300 let manager = Arc::new(self.manager.clone());
301 let mut controller = self.cleanup_controller.lock().await;
302 controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval, manager);
303 }
304 }
305}
306
307impl<M> ConnectionPool<M>
308where
309 M: ConnectionManager + Send + Sync + Clone + 'static,
310{
311 async fn recycle(&self, mut connection: M::Connection) {
312 if !self.manager.is_valid(&mut connection).await {
313 log::debug!("Invalid connection, dropping");
314 return;
315 }
316 let mut connections = self.connections.lock().await;
317 if connections.len() < self.max_size {
318 connections.push_back(InnerConnection {
319 connection,
320 created_at: Instant::now(),
321 });
322 log::trace!("Connection returned to pool (pool size: {})", connections.len());
323 } else {
324 log::trace!("Pool is full, dropping connection (max_size: {})", self.max_size);
325 }
326 }
328}
329
330impl<M> Drop for ManagedConnection<M>
331where
332 M: ConnectionManager + Send + Sync + Clone + 'static,
333{
334 fn drop(&mut self) {
335 if let Some(connection) = self.connection.take() {
336 let pool = self.pool.clone();
337 _ = tokio::spawn(async move {
338 log::trace!("Returning connection to pool on drop");
339 pool.recycle(connection).await;
340 });
341 }
342 }
343}
344
345impl<M> AsRef<M::Connection> for ManagedConnection<M>
347where
348 M: ConnectionManager + Send + Sync + Clone + 'static,
349{
350 fn as_ref(&self) -> &M::Connection {
351 self.connection.as_ref().unwrap()
352 }
353}
354
355impl<M> AsMut<M::Connection> for ManagedConnection<M>
356where
357 M: ConnectionManager + Send + Sync + Clone + 'static,
358{
359 fn as_mut(&mut self) -> &mut M::Connection {
360 self.connection.as_mut().unwrap()
361 }
362}
363
364impl<M> std::ops::Deref for ManagedConnection<M>
366where
367 M: ConnectionManager + Send + Sync + Clone + 'static,
368{
369 type Target = M::Connection;
370
371 fn deref(&self) -> &Self::Target {
372 self.connection.as_ref().unwrap()
373 }
374}
375
376impl<M> std::ops::DerefMut for ManagedConnection<M>
377where
378 M: ConnectionManager + Send + Sync + Clone + 'static,
379{
380 fn deref_mut(&mut self) -> &mut Self::Target {
381 self.connection.as_mut().unwrap()
382 }
383}
384
/// Errors returned by `ConnectionPool::get_connection`.
#[derive(Debug)]
pub enum PoolError<E> {
    /// No permit could be acquired because the pool's semaphore is closed.
    PoolClosed,
    /// `create_connection` did not complete within the configured connection timeout.
    Timeout,
    /// `create_connection` failed; the manager's error is carried unchanged.
    Creation(E),
}
392
393impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
394 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395 match self {
396 PoolError::PoolClosed => write!(f, "Connection pool is closed"),
397 PoolError::Timeout => write!(f, "Connection creation timeout"),
398 PoolError::Creation(e) => write!(f, "Connection creation failed: {e}"),
399 }
400 }
401}
402
403impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
404 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
405 match self {
406 PoolError::Creation(e) => Some(e),
407 _ => None,
408 }
409 }
410}