1use std::collections::VecDeque;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5use tokio::net::{TcpStream, ToSocketAddrs};
6use tokio::sync::{Mutex, Semaphore};
7use tokio::task::JoinHandle;
8use tokio::time::{interval, timeout};
9
/// Default for `max_size`: the number of connections that may be checked out at
/// once (semaphore permits) and the cap on idle connections kept in the pool.
pub const DEFAULT_MAX_SIZE: usize = 10;

/// Default time an idle connection may sit in the pool before it is discarded.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Default time allowed for establishing a new connection.
pub const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);

/// Default period between background sweeps for expired idle connections.
pub const DEFAULT_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);

/// Settings for a pool's background cleanup task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CleanupConfig {
    /// How often the cleanup task sweeps the idle queue.
    pub interval: Duration,
    /// Whether the cleanup task should run at all.
    pub enabled: bool,
}

impl Default for CleanupConfig {
    /// Cleanup enabled, sweeping every [`DEFAULT_CLEANUP_INTERVAL`].
    fn default() -> Self {
        Self {
            interval: DEFAULT_CLEANUP_INTERVAL,
            enabled: true,
        }
    }
}
30
31pub struct CleanupTaskController {
33 handle: Option<JoinHandle<()>>,
34}
35
36impl CleanupTaskController {
37 pub fn new() -> Self {
38 Self { handle: None }
39 }
40
41 pub fn start<T: Send + 'static>(
42 &mut self,
43 connections: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
44 max_idle_time: Duration,
45 cleanup_interval: Duration,
46 ) {
47 if self.handle.is_some() {
48 log::warn!("Cleanup task is already running");
49 return;
50 }
51
52 let handle = tokio::spawn(async move {
53 let mut interval_timer = interval(cleanup_interval);
54 log::info!("Background cleanup task started with interval: {cleanup_interval:?}");
55
56 loop {
57 interval_timer.tick().await;
58
59 let mut connections = connections.lock().await;
60 let initial_count = connections.len();
61 let now = Instant::now();
62
63 connections.retain(|conn| now.duration_since(conn.created_at) < max_idle_time);
64
65 let removed_count = initial_count - connections.len();
66 if removed_count > 0 {
67 log::debug!("Background cleanup removed {removed_count} expired connections");
68 }
69
70 drop(connections);
72 }
73 });
74
75 self.handle = Some(handle);
76 }
77
78 pub fn stop(&mut self) {
79 if let Some(handle) = self.handle.take() {
80 handle.abort();
81 log::info!("Background cleanup task stopped");
82 }
83 }
84}
85
86impl Drop for CleanupTaskController {
87 fn drop(&mut self) {
88 self.stop();
89 }
90}
91
/// Produces new connections of type `T` from connection parameters `P`.
pub trait ConnectionCreator<T, P> {
    /// Error reported when a connection cannot be established.
    type Error;

    /// Future that resolves to the newly created connection.
    type Future: Future<Output = Result<T, Self::Error>>;

    /// Starts creating a connection described by `params`.
    ///
    /// `Self::Future` carries no lifetime parameter, so the returned future cannot
    /// borrow `self` or `params`; implementations must copy out whatever they need.
    fn create_connection(&self, params: &P) -> Self::Future;
}
99
/// Decides whether an idle pooled connection may be handed out again.
pub trait ConnectionValidator<T> {
    /// Resolves to `true` when `connection` may be reused and `false` when it
    /// should be discarded.
    ///
    /// The returned future must be `Send`; implementors may write `async fn`.
    fn is_valid(&self, connection: &T) -> impl Future<Output = bool> + Send;
}
104
105pub struct ConnectionPool<T, P, C, V>
106where
107 T: Send + 'static,
108 P: Send + Sync + Clone + 'static,
109 C: Send + Sync + 'static,
110 V: Send + Sync + 'static,
111{
112 connections: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
113 semaphore: Arc<Semaphore>,
114 max_size: usize,
115 connection_params: P,
116 connection_creator: C,
117 connection_validator: V,
118 max_idle_time: Duration,
119 connection_timeout: Duration,
120 cleanup_controller: Arc<Mutex<CleanupTaskController>>,
121}
122
/// An idle connection stored in the pool together with a timestamp.
///
/// Despite its name, the pool sets `created_at` to the moment the connection is
/// put back into the idle queue, so expiry checks measure idle time rather than
/// the connection's total age.
#[derive(Debug)]
pub struct PooledConnection<T> {
    /// The underlying connection.
    pub connection: T,
    /// When the connection was placed in the idle queue.
    pub created_at: Instant,
}
127
128pub struct PooledStream<T, P, C, V>
129where
130 T: Send + 'static,
131 P: Send + Sync + Clone + 'static,
132 C: Send + Sync + 'static,
133 V: Send + Sync + 'static,
134{
135 connection: Option<T>,
136 pool: Arc<ConnectionPool<T, P, C, V>>,
137 _permit: tokio::sync::OwnedSemaphorePermit,
138}
139
140impl<T, P, C, V> ConnectionPool<T, P, C, V>
141where
142 C: ConnectionCreator<T, P> + Send + Sync + 'static,
143 V: ConnectionValidator<T> + Send + Sync + 'static,
144 T: Send + 'static,
145 P: Send + Sync + Clone + 'static,
146{
147 pub fn new(
148 max_size: Option<usize>,
149 max_idle_time: Option<Duration>,
150 connection_timeout: Option<Duration>,
151 cleanup_config: Option<CleanupConfig>,
152 connection_params: P,
153 connection_creator: C,
154 connection_validator: V,
155 ) -> Arc<Self> {
156 let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
157 let max_idle_time = max_idle_time.unwrap_or(DEFAULT_IDLE_TIMEOUT);
158 let connection_timeout = connection_timeout.unwrap_or(DEFAULT_CONNECTION_TIMEOUT);
159 let cleanup_config = cleanup_config.unwrap_or_default();
160
161 log::info!(
162 "Creating connection pool with max_size: {}, idle_timeout: {:?}, connection_timeout: {:?}, cleanup_enabled: {}",
163 max_size,
164 max_idle_time,
165 connection_timeout,
166 cleanup_config.enabled
167 );
168
169 let connections = Arc::new(Mutex::new(VecDeque::new()));
170 let cleanup_controller = Arc::new(Mutex::new(CleanupTaskController::new()));
171
172 let pool = Arc::new(ConnectionPool {
173 connections: connections.clone(),
174 semaphore: Arc::new(Semaphore::new(max_size)),
175 max_size,
176 connection_params,
177 connection_creator,
178 connection_validator,
179 max_idle_time,
180 connection_timeout,
181 cleanup_controller: cleanup_controller.clone(),
182 });
183
184 if cleanup_config.enabled {
186 tokio::spawn(async move {
187 let mut controller = cleanup_controller.lock().await;
188 controller.start(connections, max_idle_time, cleanup_config.interval);
189 });
190 }
191
192 pool
193 }
194
195 pub async fn get_connection(self: Arc<Self>) -> Result<PooledStream<T, P, C, V>, PoolError<C::Error>> {
196 log::debug!("Attempting to get connection from pool");
197
198 let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| PoolError::PoolClosed)?;
200
201 {
202 let mut connections = self.connections.lock().await;
204
205 if let Some(pooled_conn) = connections.pop_front() {
208 log::trace!("Found existing connection in pool, validating...");
209
210 let age = Instant::now().duration_since(pooled_conn.created_at);
212 if age >= self.max_idle_time {
213 log::debug!("Connection expired (age: {age:?}), discarding");
214 } else if self.connection_validator.is_valid(&pooled_conn.connection).await {
215 log::debug!("Reusing existing connection from pool (remaining: {})", connections.len());
216 return Ok(PooledStream {
217 connection: Some(pooled_conn.connection),
218 pool: self.clone(),
219 _permit: permit,
220 });
221 } else {
222 log::warn!("Connection validation failed, discarding invalid connection");
223 }
224 }
225 }
226
227 log::trace!("No valid connection available, creating new connection...");
228 match timeout(
230 self.connection_timeout,
231 self.connection_creator.create_connection(&self.connection_params),
232 )
233 .await
234 {
235 Ok(Ok(connection)) => {
236 log::info!("Successfully created new connection");
237 Ok(PooledStream {
238 connection: Some(connection),
239 pool: self.clone(),
240 _permit: permit,
241 })
242 }
243 Ok(Err(e)) => {
244 log::error!("Failed to create new connection");
245 Err(PoolError::Creation(e))
246 }
247 Err(_) => {
248 log::warn!("Connection creation timed out after {:?}", self.connection_timeout);
249 Err(PoolError::Timeout)
250 }
251 }
252 }
253
254 pub async fn stop_cleanup_task(&self) {
256 let mut controller = self.cleanup_controller.lock().await;
257 controller.stop();
258 }
259
260 pub async fn restart_cleanup_task(&self, cleanup_config: CleanupConfig) {
262 let mut controller = self.cleanup_controller.lock().await;
263 controller.stop();
264
265 if cleanup_config.enabled {
266 controller.start(self.connections.clone(), self.max_idle_time, cleanup_config.interval);
267 }
268 }
269}
270
271impl<T, P, C, V> ConnectionPool<T, P, C, V>
273where
274 T: Send + 'static,
275 P: Send + Sync + Clone + 'static,
276 C: Send + Sync + 'static,
277 V: Send + Sync + 'static,
278{
279 async fn return_connection(&self, connection: T) {
280 let mut connections = self.connections.lock().await;
281 if connections.len() < self.max_size {
282 connections.push_back(PooledConnection {
283 connection,
284 created_at: Instant::now(),
285 });
286 log::trace!("Connection returned to pool (pool size: {})", connections.len());
287 } else {
288 log::trace!("Pool is full, dropping connection (max_size: {})", self.max_size);
289 }
290 }
292}
293
294impl<T, P, C, V> Drop for PooledStream<T, P, C, V>
295where
296 T: Send + 'static,
297 P: Send + Sync + Clone + 'static,
298 C: Send + Sync + 'static,
299 V: Send + Sync + 'static,
300{
301 fn drop(&mut self) {
302 if let Some(connection) = self.connection.take() {
303 let pool = self.pool.clone();
304 if let Ok(handle) = tokio::runtime::Handle::try_current() {
305 log::trace!("Returning connection to pool on drop");
306 tokio::task::block_in_place(|| handle.block_on(pool.return_connection(connection)));
307 } else {
308 log::warn!("No tokio runtime available, connection will be dropped");
309 }
310 }
311 }
312}
313
314impl<T, P, C, V> AsRef<T> for PooledStream<T, P, C, V>
316where
317 T: Send + 'static,
318 P: Send + Sync + Clone + 'static,
319 C: Send + Sync + 'static,
320 V: Send + Sync + 'static,
321{
322 fn as_ref(&self) -> &T {
323 self.connection.as_ref().unwrap()
324 }
325}
326
327impl<T, P, C, V> AsMut<T> for PooledStream<T, P, C, V>
328where
329 T: Send + 'static,
330 P: Send + Sync + Clone + 'static,
331 C: Send + Sync + 'static,
332 V: Send + Sync + 'static,
333{
334 fn as_mut(&mut self) -> &mut T {
335 self.connection.as_mut().unwrap()
336 }
337}
338
339impl<T, P, C, V> std::ops::Deref for PooledStream<T, P, C, V>
341where
342 T: Send + 'static,
343 P: Send + Sync + Clone + 'static,
344 C: Send + Sync + 'static,
345 V: Send + Sync + 'static,
346{
347 type Target = T;
348
349 fn deref(&self) -> &Self::Target {
350 self.connection.as_ref().unwrap()
351 }
352}
353
354impl<T, P, C, V> std::ops::DerefMut for PooledStream<T, P, C, V>
355where
356 T: Send + 'static,
357 P: Send + Sync + Clone + 'static,
358 C: Send + Sync + 'static,
359 V: Send + Sync + 'static,
360{
361 fn deref_mut(&mut self) -> &mut Self::Target {
362 self.connection.as_mut().unwrap()
363 }
364}
365
/// Errors returned when checking a connection out of a pool.
///
/// `E` is the error type produced by the pool's connection creator.
#[derive(Debug)]
pub enum PoolError<E> {
    /// The pool's semaphore was closed, so no permit could be acquired.
    PoolClosed,
    /// Creating a new connection did not finish within the configured timeout.
    Timeout,
    /// The connection creator returned an error.
    Creation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for PoolError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::PoolClosed => "Connection pool is closed",
            Self::Timeout => "Connection creation timeout",
            Self::Creation(cause) => return write!(f, "Connection creation failed: {cause}"),
        };
        f.write_str(message)
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PoolError<E> {
    /// Exposes the creator's error as the source; the other variants have none.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Creation(cause) => Some(cause),
            Self::PoolClosed | Self::Timeout => None,
        }
    }
}
392
/// Connection creator that opens a new tokio `TcpStream` to the pool's address.
#[derive(Debug, Clone, Copy, Default)]
pub struct TcpConnectionCreator;
395
396impl<A> ConnectionCreator<TcpStream, A> for TcpConnectionCreator
397where
398 A: ToSocketAddrs + Send + Sync + Clone + 'static,
399{
400 type Error = std::io::Error;
401 type Future = std::pin::Pin<Box<dyn Future<Output = Result<TcpStream, Self::Error>> + Send>>;
402
403 fn create_connection(&self, address: &A) -> Self::Future {
404 let addr = address.clone();
405 Box::pin(async move { TcpStream::connect(addr).await })
406 }
407}
408
/// Connection validator that checks an idle tokio `TcpStream` via its readiness state.
#[derive(Debug, Clone, Copy, Default)]
pub struct TcpConnectionValidator;
410
411impl ConnectionValidator<TcpStream> for TcpConnectionValidator {
412 async fn is_valid(&self, stream: &TcpStream) -> bool {
413 stream
415 .ready(tokio::io::Interest::READABLE | tokio::io::Interest::WRITABLE)
416 .await
417 .is_ok()
418 }
419}
420
421pub type TcpConnectionPool<A = std::net::SocketAddr> = ConnectionPool<TcpStream, A, TcpConnectionCreator, TcpConnectionValidator>;
423pub type TcpPooledStream<A = std::net::SocketAddr> = PooledStream<TcpStream, A, TcpConnectionCreator, TcpConnectionValidator>;
424
425impl<A> TcpConnectionPool<A>
426where
427 A: ToSocketAddrs + Send + Sync + Clone + 'static,
428{
429 pub fn new_tcp(
430 max_size: Option<usize>,
431 max_idle_time: Option<Duration>,
432 connection_timeout: Option<Duration>,
433 cleanup_config: Option<CleanupConfig>,
434 address: A,
435 ) -> Arc<Self> {
436 log::info!("Creating TCP connection pool");
437 Self::new(
438 max_size,
439 max_idle_time,
440 connection_timeout,
441 cleanup_config,
442 address,
443 TcpConnectionCreator,
444 TcpConnectionValidator,
445 )
446 }
447}