1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::sync::{Mutex, Semaphore};
6use tokio::task::JoinHandle;
7use tokio::time::{interval, timeout};
8
/// Default limit on concurrently checked-out connections; also caps the idle queue length.
pub const DEFAULT_MAX_SIZE: usize = 10;
/// Default time an idle connection may sit in the pool before it is discarded (5 minutes).
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Default upper bound on how long creating a single new connection may take.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
/// Default period between background idle-connection cleanup sweeps.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);

/// Settings for the pool's background idle-connection cleanup task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CleanupConfig {
    /// How often the cleanup task sweeps the idle queue.
    pub interval: Duration,
    /// Whether the cleanup task runs at all.
    pub enabled: bool,
}
20
21impl Default for CleanupConfig {
22 fn default() -> Self {
23 Self {
24 interval: DEFAULT_CLEANUP_INTERVAL,
25 enabled: true,
26 }
27 }
28}
29
30pub struct CleanupTaskController {
32 handle: Option<JoinHandle<()>>,
33}
34
35impl CleanupTaskController {
36 pub fn new() -> Self {
37 Self { handle: None }
38 }
39
40 pub fn start<T: Send + 'static>(
41 &mut self,
42 connections: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
43 max_idle_time: Duration,
44 cleanup_interval: Duration,
45 ) {
46 if self.handle.is_some() {
47 log::warn!("Cleanup task is already running");
48 return;
49 }
50
51 let handle = tokio::spawn(async move {
52 let mut interval_timer = interval(cleanup_interval);
53 log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
54
55 loop {
56 interval_timer.tick().await;
57
58 let mut connections = connections.lock().await;
59 let initial_count = connections.len();
60 let now = Instant::now();
61
62 connections.retain(|conn| now.duration_since(conn.created_at) < max_idle_time);
63
64 let removed_count = initial_count - connections.len();
65 if removed_count > 0 {
66 log::debug!("Background cleanup removed {removed_count} expired connections");
67 }
68
69 drop(connections);
71 }
72 });
73
74 self.handle = Some(handle);
75 }
76
77 pub fn stop(&mut self) {
78 if let Some(handle) = self.handle.take() {
79 handle.abort();
80 log::info!("Background cleanup task stopped");
81 }
82 }
83}
84
85impl Drop for CleanupTaskController {
86 fn drop(&mut self) {
87 self.stop();
88 }
89}
90
/// Factory and health-checker for the connections held by a [`ConnectionPool`].
///
/// Implementors define how to open a new connection and how to test whether an
/// idle pooled connection is still usable before it is handed out again.
pub trait ConnectionManager: Sync + Send {
    /// The connection type stored in and handed out by the pool.
    type Connection: Send;
    /// Error produced when creating a connection fails; the pool surfaces it to
    /// callers wrapped in `PoolError::Creation`.
    type Error: std::error::Error + Send;
    /// Future returned by [`create_connection`](Self::create_connection).
    type CreateFut: Future<Output = Result<Self::Connection, Self::Error>> + Send;
    /// Future returned by [`is_valid`](Self::is_valid); it may borrow both the
    /// manager and the connection being checked for `'a`.
    type ValidFut<'a>: Future<Output = bool> + Send
    where
        Self: 'a;

    /// Starts creating a new connection. The pool bounds the returned future
    /// with its configured connection timeout.
    fn create_connection(&self) -> Self::CreateFut;
    /// Checks whether a not-yet-expired idle connection is still usable.
    /// Returning `false` makes the pool discard the connection.
    fn is_valid<'a>(&'a self, connection: &'a Self::Connection) -> Self::ValidFut<'a>;
}
102
/// An async pool of connections produced by a [`ConnectionManager`].
///
/// Built with [`ConnectionPool::new`], which returns it inside an `Arc`.
/// Connections are checked out with `get_connection` and handed back to the
/// idle queue when the resulting `PooledStream` is dropped.
pub struct ConnectionPool<M>
where
    M: ConnectionManager + Send + Sync + 'static,
{
    /// Idle connections, oldest at the front; shared with the cleanup task.
    connections: Arc<Mutex<VecDeque<PooledConnection<M::Connection>>>>,
    /// Limits the number of concurrently checked-out connections to `max_size`.
    semaphore: Arc<Semaphore>,
    /// Maximum checked-out connections and maximum idle-queue length.
    max_size: usize,
    /// Creates and validates connections.
    manager: M,
    /// Idle connections at least this old are discarded instead of reused.
    max_idle_time: Duration,
    /// Upper bound on a single `create_connection` call.
    connection_timeout: Duration,
    /// Controls the background cleanup task.
    cleanup_controller: Arc<Mutex<CleanupTaskController>>,
}
115
/// An idle connection stored in the pool's queue.
pub struct PooledConnection<T> {
    /// The underlying connection.
    pub connection: T,
    /// When this entry was queued. Despite the name, the pool sets it to
    /// `Instant::now()` each time a connection is returned, so expiry checks
    /// measure idle time rather than total connection age.
    pub created_at: Instant,
}
120
/// A connection checked out of a [`ConnectionPool`].
///
/// Dereferences to the underlying connection. On drop it attempts to hand the
/// connection back to the pool, and its semaphore permit is released.
pub struct PooledStream<M>
where
    M: ConnectionManager + Send + Sync + 'static,
{
    /// `Some` for the whole life of the stream; `Drop` takes it to return it.
    connection: Option<M::Connection>,
    /// Pool the connection is returned to.
    pool: Arc<ConnectionPool<M>>,
    /// Held for the stream's lifetime; dropping it frees a checkout slot.
    _permit: tokio::sync::OwnedSemaphorePermit,
}
129
130impl<M> ConnectionPool<M>
131where
132 M: ConnectionManager + Send + Sync + 'static,
133{
134 pub fn new(
135 max_size: Option<usize>,
136 max_idle_time: Option<Duration>,
137 connection_timeout: Option<Duration>,
138 cleanup_config: Option<CleanupConfig>,
139 manager: M,
140 ) -> Arc<Self> {
141 let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
142 let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
143 let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
144 let cleanup_config = cleanup_config.unwrap_or_default();
145
146 log::info!(
147 "Creating connection pool with max_size: {max_size}, idle_timeout: {max_idle_time:?}, connection_timeout: {connection_timeout:?}, cleanup_enabled: {}",
148 cleanup_config.enabled
149 );
150
151 let connections = Arc::new(Mutex::new(VecDeque::new()));
152 let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
153
154 let pool = Arc::new(ConnectionPool {
155 connections: connections.clone(),
156 semaphore: Arc::new(Semaphore::new(max_size)),
157 max_size,
158 manager,
159 max_idle_time,
160 connection_timeout,
161 cleanup_controller: cleanup_controller.clone(),
162 });
163
164 if cleanup_config.enabled {
166 tokio::spawn(async move {
167 let mut controller = cleanup_controller.lock().await;
168 controller.start(connections, max_idle_time, cleanup_config.interval);
169 });
170 }
171
172 pool
173 }
174
175 pub async fn get_connection(self: Arc<Self>) -> Result<PooledStream<M>, PoolError<M::Error>> {
176 log::debug!("Attempting to get connection from pool");
177
178 let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
180
181 {
182 let mut connections = self.connections.lock().await;
184
185 if let Some(pooled_conn) = connections.pop_front() {
188 log::trace!("Found existing connection in pool, validating...");
189
190 let age = Instant::now().duration_since(pooled_conn.created_at);
192 if age >= self.max_idle_time {
193 log::debug!("Connection expired (age: {age:?}), discarding");
194 } else if self.manager.is_valid(&pooled_conn.connection).await {
195 log::debug!("Reusing existing connection from pool (remaining: {})", connections.len());
196 return Ok(PooledStream {
197 connection: Some(pooled_conn.connection),
198 pool: self.clone(),
199 _permit: permit,
200 });
201 } else {
202 log::warn!("Connection validation failed, discarding invalid connection");
203 }
204 }
205 }
206
207 log::trace!("No valid connection available, creating new connection...");
208 match timeout(self.connection_timeout, self.manager.create_connection()).await {
210 Ok(Ok(connection)) => {
211 log::info!("Successfully created new connection");
212 Ok(PooledStream {
213 connection: Some(connection),
214 pool: self.clone(),
215 _permit: permit,
216 })
217 }
218 Ok(Err(e)) => {
219 log::error!("Failed to create new connection");
220 Err(PoolError::Creation(e))
221 }
222 Err(_) => {
223 log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
224 Err(PoolError::Timeout)
225 }
226 }
227 }
228
229 pub async fn stop_cleanup_task(&self) {
231 let mut controller = self.cleanup_controller.lock().await;
232 controller.stop();
233 }
234
235 pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
237 let mut controller = self.cleanup_controller.lock().await;
238 controller.stop();
239
240 if cleanup_config.enabled {
241 controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval);
242 }
243 }
244}
245
246impl<M> ConnectionPool<M>
247where
248 M: ConnectionManager + Send + Sync + 'static,
249{
250 async fn return_connection(&self, connection: M::Connection) {
251 let mut connections = self.connections.lock().await;
252 if connections.len() < self.max_size {
253 connections.push_back(PooledConnection {
254 connection,
255 created_at: Instant::now(),
256 });
257 log::trace!("Connection returned to pool (pool size: {})", connections.len());
258 } else {
259 log::trace!("Pool is full, dropping connection (max_size: {})", self.max_size);
260 }
261 }
263}
264
265impl<M> Drop for PooledStream<M>
266where
267 M: ConnectionManager + Send + Sync + 'static,
268{
269 fn drop(&mut self) {
270 if let Some(connection) = self.connection.take() {
271 let pool = self.pool.clone();
272 if let Ok(handle) = tokio::runtime::Handle::try_current() {
273 log::trace!("Returning connection to pool on drop");
274 tokio::task::block_in_place(|| handle.block_on(pool.return_connection(connection)));
275 } else {
276 log::warn!("No tokio runtime available, connection will be dropped");
277 }
278 }
279 }
280}
281
282impl<M> AsRef<M::Connection> for PooledStream<M>
284where
285 M: ConnectionManager + Send + Sync + 'static,
286{
287 fn as_ref(&self) -> &M::Connection {
288 self.connection.as_ref().unwrap()
289 }
290}
291
292impl<M> AsMut<M::Connection> for PooledStream<M>
293where
294 M: ConnectionManager + Send + Sync + 'static,
295{
296 fn as_mut(&mut self) -> &mut M::Connection {
297 self.connection.as_mut().unwrap()
298 }
299}
300
301impl<M> std::ops::Deref for PooledStream<M>
303where
304 M: ConnectionManager + Send + Sync + 'static,
305{
306 type Target = M::Connection;
307
308 fn deref(&self) -> &Self::Target {
309 self.connection.as_ref().unwrap()
310 }
311}
312
313impl<M> std::ops::DerefMut for PooledStream<M>
314where
315 M: ConnectionManager + Send + Sync + 'static,
316{
317 fn deref_mut(&mut self) -> &mut Self::Target {
318 self.connection.as_mut().unwrap()
319 }
320}
321
/// Errors returned when checking a connection out of the pool.
///
/// `E` is the connection manager's own error type, carried by [`PoolError::Creation`].
#[derive(Debug)]
pub enum PoolError<E> {
    /// The pool's semaphore was closed, so no checkout slot could be acquired.
    PoolClosed,
    /// Creating a new connection did not finish within the configured timeout.
    Timeout,
    /// The connection manager failed to create a new connection.
    Creation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::PoolClosed => "Connection pool is closed",
            Self::Timeout => "Connection creation timeout",
            Self::Creation(inner) => return write!(f, "Connection creation failed: {inner}"),
        };
        f.write_str(message)
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    /// Only `Creation` wraps an underlying error; the other variants have no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Creation(inner) => Some(inner),
            Self::PoolClosed | Self::Timeout => None,
        }
    }
}