1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Default maximum number of connections that may be checked out at once.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Default age (5 minutes) after which an idle pooled connection is discarded.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Default upper bound on how long creating a single connection may take.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period between background cleanup sweeps.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);

/// Configuration for the pool's background idle-connection cleanup task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CleanupConfig {
    /// Time between cleanup sweeps. Must be non-zero (`tokio::time::interval` rejects a zero
    /// period).
    pub interval: Duration,
    /// Whether the cleanup task should run at all.
    pub enabled: bool,
}
20
21impl Default for CleanupConfig {
22 fn default() -> Self {
23 Self {
24 interval: DEFAULT_CLEANUP_INTERVAL,
25 enabled: true,
26 }
27 }
28}
29
30struct CleanupTaskController {
32 handle: Option<JoinHandle<()>>,
33}
34
35impl CleanupTaskController {
36 fn new() -> Self {
37 Self { handle: None }
38 }
39
40 fn start<T: Send + 'static>(
41 &mut self,
42 connections: Arc<Mutex<VecDeque<InnerConnection<T>>>>,
43 max_idle_time: Duration,
44 cleanup_interval: Duration,
45 ) {
46 if self.handle.is_some() {
47 log::warn!("Cleanup task is already running");
48 return;
49 }
50
51 let handle = tokio::spawn(async move {
52 let mut interval_timer = interval(cleanup_interval);
53 log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
54
55 loop {
56 interval_timer.tick().await;
57
58 let mut connections = connections.lock().await;
59 let initial_count = connections.len();
60 let now = Instant::now();
61
62 connections.retain(|conn| now.duration_since(conn.created_at) < max_idle_time);
63
64 let removed_count = initial_count - connections.len();
65 if removed_count > 0 {
66 log::debug!("Background cleanup removed {removed_count} expired connections");
67 }
68
69 log::trace!("Current pool size after cleanup: {}", connections.len());
70 }
71 });
72
73 self.handle = Some(handle);
74 }
75
76 fn stop(&mut self) {
77 if let Some(handle) = self.handle.take() {
78 handle.abort();
79 log::info!("Background cleanup task stopped");
80 }
81 }
82}
83
84impl Drop for CleanupTaskController {
85 fn drop(&mut self) {
86 self.stop();
87 }
88}
89
/// Creates and health-checks the connections held by a [`ConnectionPool`].
///
/// Implementors must be both `Send` and `Sync`.
pub trait ConnectionManager: Send + Sync {
    /// The connection type handed out by the pool.
    type Connection: Send;

    /// Error reported when creating a connection fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`create_connection`](Self::create_connection); it carries no
    /// lifetime, so it cannot borrow the manager.
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Future returned by [`is_valid`](Self::is_valid); it may borrow the manager and the
    /// connection for `'a`.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Starts creating a brand-new connection.
    fn create_connection(&self) -> Self::CreateFut;

    /// Resolves to `true` if `connection` is still usable; the pool discards it on `false`.
    fn is_valid<'a>(&'a self, connection: &'a mut Self::Connection) -> Self::ValidFut<'a>;
}
112
113pub struct ConnectionPool<M>
115where
116 M: ConnectionManager + Send + Sync + 'static,
117{
118 connections: Arc<Mutex<VecDeque<InnerConnection<M::Connection>>>>,
119 semaphore: Arc<Semaphore>,
120 max_size: usize,
121 manager: M,
122 max_idle_time: Duration,
123 connection_timeout: Duration,
124 cleanup_controller: Arc<Mutex<CleanupTaskController>>,
125}
126
/// An idle connection sitting in the pool, plus the time it entered the pool.
struct InnerConnection<T> {
    /// The pooled connection itself.
    pub connection: T,
    /// Timestamp used for idle-expiry checks; `recycle` sets it to `Instant::now()` whenever a
    /// connection is put back, so it measures idle time rather than true creation time.
    pub created_at: Instant,
}
132
133pub struct ManagedConnection<M>
135where
136 M: ConnectionManager + Send + Sync + 'static,
137{
138 connection: Option<M::Connection>,
139 pool: Arc<ConnectionPool<M>>,
140 _permit: tokio::sync::OwnedSemaphorePermit,
141}
142
143impl<M> ConnectionPool<M>
144where
145 M: ConnectionManager + Send + Sync + 'static,
146{
147 pub fn new(
149 max_size: Option<usize>,
150 max_idle_time: Option<Duration>,
151 connection_timeout: Option<Duration>,
152 cleanup_config: Option<CleanupConfig>,
153 manager: M,
154 ) -> Arc<Self> {
155 let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
156 let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
157 let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
158 let cleanup_config = cleanup_config.unwrap_or_default();
159
160 log::info!(
161 "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
162 cleanup_config.enabled
163 );
164
165 let connections = Arc::new(Mutex::new(VecDeque::new()));
166 let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
167
168 let pool = Arc::new(ConnectionPool {
169 connections: connections.clone(),
170 semaphore: Arc::new(Semaphore::new(max_size)),
171 max_size,
172 manager,
173 max_idle_time,
174 connection_timeout,
175 cleanup_controller: cleanup_controller.clone(),
176 });
177
178 if cleanup_config.enabled {
180 tokio::spawn(async move {
181 let mut controller = cleanup_controller.lock().await;
182 controller.start(connections, max_idle_time, cleanup_config.interval);
183 });
184 }
185
186 pool
187 }
188
189 pub async fn get_connection(self: Arc<Self>) -> Result<ManagedConnection<M>, PoolError<M::Error>> {
191 log::debug!("Attempting to get connection from pool");
192
193 let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
195
196 {
198 let mut connections = self.connections.lock().await;
199 loop {
200 let Some(mut pooled_conn) = connections.pop_front() else {
201 break;
203 };
204 log::trace!("Found existing connection in pool, validating...");
205 let age = Instant::now().duration_since(pooled_conn.created_at);
206 let valid = self.manager.is_valid(&mut pooled_conn.connection).await;
207 if age >= self.max_idle_time {
208 log::debug!("Connection expired (age: {age:?}), discarding");
209 } else if !valid {
210 log::warn!("Connection validation failed, discarding invalid connection");
211 } else {
212 log::debug!("Reusing existing connection from pool (remaining: {})", connections.len());
213 return Ok(ManagedConnection {
214 connection: Some(pooled_conn.connection),
215 pool: self.clone(),
216 _permit: permit,
217 });
218 }
219 }
220 }
221
222 log::trace!("No valid connection available, creating new connection...");
223 match timeout(self.connection_timeout, self.manager.create_connection()).await {
225 Ok(Ok(connection)) => {
226 log::info!("Successfully created new connection");
227 Ok(ManagedConnection {
228 connection: Some(connection),
229 pool: self.clone(),
230 _permit: permit,
231 })
232 }
233 Ok(Err(e)) => {
234 log::error!("Failed to create new connection");
235 Err(PoolError::Creation(e))
236 }
237 Err(_) => {
238 log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
239 Err(PoolError::Timeout)
240 }
241 }
242 }
243
244 pub async fn stop_cleanup_task(&self) {
246 let mut controller = self.cleanup_controller.lock().await;
247 controller.stop();
248 }
249
250 pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
252 let mut controller = self.cleanup_controller.lock().await;
253 controller.stop();
254
255 if cleanup_config.enabled {
256 controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval);
257 }
258 }
259}
260
261impl<M> ConnectionPool<M>
262where
263 M: ConnectionManager + Send + Sync + 'static,
264{
265 async fn recycle(&self, mut connection: M::Connection) {
266 if !self.manager.is_valid(&mut connection).await {
267 log::debug!("Invalid connection, dropping");
268 return;
269 }
270 let mut connections = self.connections.lock().await;
271 if connections.len() < self.max_size {
272 connections.push_back(InnerConnection {
273 connection,
274 created_at: Instant::now(),
275 });
276 log::trace!("Connection returned to pool (pool size: {})", connections.len());
277 } else {
278 log::trace!("Pool is full, dropping connection (max_size: {})", self.max_size);
279 }
280 }
282}
283
284impl<M> Drop for ManagedConnection<M>
285where
286 M: ConnectionManager + Send + Sync + 'static,
287{
288 fn drop(&mut self) {
289 if let Some(connection) = self.connection.take() {
290 let pool = self.pool.clone();
291 if let Ok(handle) = tokio::runtime::Handle::try_current() {
292 log::trace!("Returning connection to pool on drop");
293 tokio::task::block_in_place(|| handle.block_on(pool.recycle(connection)));
294 } else {
295 log::warn!("No tokio runtime available, connection will be dropped");
296 }
297 }
298 }
299}
300
301impl<M> AsRef<M::Connection> for ManagedConnection<M>
303where
304 M: ConnectionManager + Send + Sync + 'static,
305{
306 fn as_ref(&self) -> &M::Connection {
307 self.connection.as_ref().unwrap()
308 }
309}
310
311impl<M> AsMut<M::Connection> for ManagedConnection<M>
312where
313 M: ConnectionManager + Send + Sync + 'static,
314{
315 fn as_mut(&mut self) -> &mut M::Connection {
316 self.connection.as_mut().unwrap()
317 }
318}
319
320impl<M> std::ops::Deref for ManagedConnection<M>
322where
323 M: ConnectionManager + Send + Sync + 'static,
324{
325 type Target = M::Connection;
326
327 fn deref(&self) -> &Self::Target {
328 self.connection.as_ref().unwrap()
329 }
330}
331
332impl<M> std::ops::DerefMut for ManagedConnection<M>
333where
334 M: ConnectionManager + Send + Sync + 'static,
335{
336 fn deref_mut(&mut self) -> &mut Self::Target {
337 self.connection.as_mut().unwrap()
338 }
339}
340
/// Failure modes of checking a connection out of the pool.
#[derive(Debug)]
pub enum PoolError<E> {
    /// A pool permit could not be acquired because the permit semaphore was closed.
    PoolClosed,
    /// Creating a new connection exceeded the pool's connection timeout.
    Timeout,
    /// The connection manager reported an error while creating a connection.
    Creation(E),
}
348
349impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
350 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351 match self {
352 PoolError::PoolClosed => write!(f, "Connection pool is closed"),
353 PoolError::Timeout => write!(f, "Connection creation timeout"),
354 PoolError::Creation(e) => write!(f, "Connection creation failed: {e}"),
355 }
356 }
357}
358
359impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
360 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
361 match self {
362 PoolError::Creation(e) => Some(e),
363 _ => None,
364 }
365 }
366}