1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Default value for `max_size`: the number of semaphore permits (concurrent
/// checkouts) and the cap on the idle queue length.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Default idle timeout (5 minutes) after which a pooled connection is discarded.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Default upper bound (10 seconds) on creating a single new connection.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period (30 seconds) of the background cleanup task.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);

/// Configuration for the background task that evicts idle connections.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CleanupConfig {
    /// Time between two cleanup passes.
    pub interval: Duration,
    /// Whether the background cleanup task should be started at all.
    pub enabled: bool,
}

impl Default for CleanupConfig {
    /// Cleanup enabled, running every [`DEFAULT_CLEANUP_INTERVAL`].
    fn default() -> Self {
        Self {
            interval: DEFAULT_CLEANUP_INTERVAL,
            enabled: true,
        }
    }
}
29
30pub struct CleanupTaskController {
32 handle: Option<JoinHandle<()>>,
33}
34
35impl CleanupTaskController {
36 pub fn new() -> Self {
37 Self { handle: None }
38 }
39
40 pub fn start<T: Send + 'static>(
41 &mut self,
42 connections: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
43 max_idle_time: Duration,
44 cleanup_interval: Duration,
45 ) {
46 if self.handle.is_some() {
47 log::warn!("Cleanup task is already running");
48 return;
49 }
50
51 let handle = tokio::spawn(async move {
52 let mut interval_timer = interval(cleanup_interval);
53 log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
54
55 loop {
56 interval_timer.tick().await;
57
58 let mut connections = connections.lock().await;
59 let initial_count = connections.len();
60 let now = Instant::now();
61
62 connections.retain(|conn| now.duration_since(conn.created_at) < max_idle_time);
63
64 let removed_count = initial_count - connections.len();
65 if removed_count > 0 {
66 log::debug!("Background cleanup removed {removed_count} expired connections");
67 }
68
69 drop(connections);
71 }
72 });
73
74 self.handle = Some(handle);
75 }
76
77 pub fn stop(&mut self) {
78 if let Some(handle) = self.handle.take() {
79 handle.abort();
80 log::info!("Background cleanup task stopped");
81 }
82 }
83}
84
85impl Drop for CleanupTaskController {
86 fn drop(&mut self) {
87 self.stop();
88 }
89}
90
/// Creates and health-checks the connections held by a [`ConnectionPool`].
///
/// Implementors must be `Send + Sync`; the futures they return must be `Send`.
pub trait ConnectionManager: Send + Sync {
    /// Connection type handed out by the pool.
    type Connection: Send;

    /// Error produced when a connection cannot be created.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Future returned by [`create_connection`](Self::create_connection).
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;

    /// Opens a brand-new connection.
    fn create_connection(&self) -> Self::CreateFut;

    /// Future returned by [`is_valid`](Self::is_valid); it may borrow both the
    /// manager and the connection for `'a`.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Reports whether an idle `connection` may still be reused; the pool
    /// discards connections for which this resolves to `false`.
    fn is_valid<'a>(&'a self, connection: &'a Self::Connection) -> Self::ValidFut<'a>;
}
113
114pub struct ConnectionPool<M>
116where
117 M: ConnectionManager + Send + Sync + 'static,
118{
119 connections: Arc<Mutex<VecDeque<PooledConnection<M::Connection>>>>,
120 semaphore: Arc<Semaphore>,
121 max_size: usize,
122 manager: M,
123 max_idle_time: Duration,
124 connection_timeout: Duration,
125 cleanup_controller: Arc<Mutex<CleanupTaskController>>,
126}
127
/// An idle connection stored in the pool's queue together with a timestamp.
///
/// The pool sets `created_at` to `Instant::now()` when a connection is returned
/// to it, so despite the name the field records when the connection became
/// idle. Idle-timeout checks measure the connection's age from this value.
#[derive(Debug)]
pub struct PooledConnection<T> {
    /// The pooled connection itself.
    pub connection: T,
    /// Timestamp used for idle-timeout checks (see the type-level docs).
    pub created_at: Instant,
}
133
134pub struct PooledStream<M>
136where
137 M: ConnectionManager + Send + Sync + 'static,
138{
139 connection: Option<M::Connection>,
140 pool: Arc<ConnectionPool<M>>,
141 _permit: tokio::sync::OwnedSemaphorePermit,
142}
143
144impl<M> ConnectionPool<M>
145where
146 M: ConnectionManager + Send + Sync + 'static,
147{
148 pub fn new(
150 max_size: Option<usize>,
151 max_idle_time: Option<Duration>,
152 connection_timeout: Option<Duration>,
153 cleanup_config: Option<CleanupConfig>,
154 manager: M,
155 ) -> Arc<Self> {
156 let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
157 let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
158 let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
159 let cleanup_config = cleanup_config.unwrap_or_default();
160
161 log::info!(
162 "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
163 cleanup_config.enabled
164 );
165
166 let connections = Arc::new(Mutex::new(VecDeque::new()));
167 let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
168
169 let pool = Arc::new(ConnectionPool {
170 connections: connections.clone(),
171 semaphore: Arc::new(Semaphore::new(max_size)),
172 max_size,
173 manager,
174 max_idle_time,
175 connection_timeout,
176 cleanup_controller: cleanup_controller.clone(),
177 });
178
179 if cleanup_config.enabled {
181 tokio::spawn(async move {
182 let mut controller = cleanup_controller.lock().await;
183 controller.start(connections, max_idle_time, cleanup_config.interval);
184 });
185 }
186
187 pool
188 }
189
190 pub async fn get_connection(self: Arc<Self>) -> Result<PooledStream<M>, PoolError<M::Error>> {
192 log::debug!("Attempting to get connection from pool");
193
194 let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
196
197 {
199 let mut connections = self.connections.lock().await;
200 loop {
201 let Some(pooled_conn) = connections.pop_front() else {
202 break;
204 };
205 log::trace!("Found existing connection in pool, validating...");
206 let age = Instant::now().duration_since(pooled_conn.created_at);
207 let valid = self.manager.is_valid(&pooled_conn.connection).await;
208 if age >= self.max_idle_time {
209 log::debug!("Connection expired (age: {age:?}), discarding");
210 } else if !valid {
211 log::warn!("Connection validation failed, discarding invalid connection");
212 } else {
213 log::debug!("Reusing existing connection from pool (remaining: {})", connections.len());
214 return Ok(PooledStream {
215 connection: Some(pooled_conn.connection),
216 pool: self.clone(),
217 _permit: permit,
218 });
219 }
220 }
221 }
222
223 log::trace!("No valid connection available, creating new connection...");
224 match timeout(self.connection_timeout, self.manager.create_connection()).await {
226 Ok(Ok(connection)) => {
227 log::info!("Successfully created new connection");
228 Ok(PooledStream {
229 connection: Some(connection),
230 pool: self.clone(),
231 _permit: permit,
232 })
233 }
234 Ok(Err(e)) => {
235 log::error!("Failed to create new connection");
236 Err(PoolError::Creation(e))
237 }
238 Err(_) => {
239 log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
240 Err(PoolError::Timeout)
241 }
242 }
243 }
244
245 pub async fn stop_cleanup_task(&self) {
247 let mut controller = self.cleanup_controller.lock().await;
248 controller.stop();
249 }
250
251 pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
253 let mut controller = self.cleanup_controller.lock().await;
254 controller.stop();
255
256 if cleanup_config.enabled {
257 controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval);
258 }
259 }
260}
261
262impl<M> ConnectionPool<M>
263where
264 M: ConnectionManager + Send + Sync + 'static,
265{
266 async fn return_connection(&self, connection: M::Connection) {
267 let mut connections = self.connections.lock().await;
268 if connections.len() < self.max_size {
269 connections.push_back(PooledConnection {
270 connection,
271 created_at: Instant::now(),
272 });
273 log::trace!("Connection returned to pool (pool size: {})", connections.len());
274 } else {
275 log::trace!("Pool is full, dropping connection (max_size: {})", self.max_size);
276 }
277 }
279}
280
281impl<M> Drop for PooledStream<M>
282where
283 M: ConnectionManager + Send + Sync + 'static,
284{
285 fn drop(&mut self) {
286 if let Some(connection) = self.connection.take() {
287 let pool = self.pool.clone();
288 if let Ok(handle) = tokio::runtime::Handle::try_current() {
289 log::trace!("Returning connection to pool on drop");
290 tokio::task::block_in_place(|| handle.block_on(pool.return_connection(connection)));
291 } else {
292 log::warn!("No tokio runtime available, connection will be dropped");
293 }
294 }
295 }
296}
297
298impl<M> AsRef<M::Connection> for PooledStream<M>
300where
301 M: ConnectionManager + Send + Sync + 'static,
302{
303 fn as_ref(&self) -> &M::Connection {
304 self.connection.as_ref().unwrap()
305 }
306}
307
308impl<M> AsMut<M::Connection> for PooledStream<M>
309where
310 M: ConnectionManager + Send + Sync + 'static,
311{
312 fn as_mut(&mut self) -> &mut M::Connection {
313 self.connection.as_mut().unwrap()
314 }
315}
316
317impl<M> std::ops::Deref for PooledStream<M>
319where
320 M: ConnectionManager + Send + Sync + 'static,
321{
322 type Target = M::Connection;
323
324 fn deref(&self) -> &Self::Target {
325 self.connection.as_ref().unwrap()
326 }
327}
328
329impl<M> std::ops::DerefMut for PooledStream<M>
330where
331 M: ConnectionManager + Send + Sync + 'static,
332{
333 fn deref_mut(&mut self) -> &mut Self::Target {
334 self.connection.as_mut().unwrap()
335 }
336}
337
/// Errors returned when checking a connection out of the pool.
///
/// `E` is the connection manager's creation error type.
#[derive(Debug)]
pub enum PoolError<E> {
    /// The pool's semaphore was closed, so no permit could be acquired.
    PoolClosed,
    /// Creating a new connection did not finish within the configured timeout.
    Timeout,
    /// The connection manager failed to create a connection.
    Creation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PoolError::PoolClosed => write!(f, "Connection pool is closed"),
            PoolError::Timeout => write!(f, "Connection creation timeout"),
            PoolError::Creation(e) => write!(f, "Connection creation failed: {e}"),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Exhaustive on purpose: a new variant wrapping an error must decide here
        // whether to expose it as the source.
        match self {
            PoolError::Creation(e) => Some(e),
            PoolError::PoolClosed | PoolError::Timeout => None,
        }
    }
}