1use std::collections::VecDeque;
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::time::{Duration, Instant};
40
41use parking_lot::Mutex;
42use tokio::sync::Notify;
43
44use super::auto_eject::{AutoEject, AutoEjectState};
45use super::NetError;
46
/// Tunables for a [`ConnPool`].
///
/// The pool reads these values when it is constructed and when it decides
/// whether a new connection may be opened.
#[derive(Debug, Clone)]
pub struct ConnPoolConfig {
    /// Connection cap the pool compares its idle + in-use count against.
    pub max_connections: usize,
    /// Failure threshold handed to the auto-eject tracker (the pool raises it
    /// to at least 1 before use).
    pub server_failure_limit: u32,
    /// Retry timeout in milliseconds handed to the auto-eject tracker; also
    /// the basis for the pool's back-off ceiling.
    pub server_retry_timeout_ms: u64,
    /// Enable flag handed to the auto-eject tracker.
    pub auto_eject: bool,
}

impl Default for ConnPoolConfig {
    /// One connection, three failures, a 30 second retry timeout, auto-eject on.
    fn default() -> Self {
        let max_connections = 1;
        let server_failure_limit = 3;
        let server_retry_timeout_ms = 30_000;
        let auto_eject = true;
        Self {
            max_connections,
            server_failure_limit,
            server_retry_timeout_ms,
            auto_eject,
        }
    }
}
77
78pub type ConnFuture<C> = Pin<Box<dyn Future<Output = Result<C, NetError>> + Send + 'static>>;
80
81pub trait ConnFactory<C>: Send + Sync + 'static {
88 fn connect(&self) -> ConnFuture<C>;
90}
91
92impl<C, F, Fut> ConnFactory<C> for F
93where
94 F: Fn() -> Fut + Send + Sync + 'static,
95 Fut: Future<Output = Result<C, NetError>> + Send + 'static,
96{
97 fn connect(&self) -> ConnFuture<C> {
98 Box::pin(self())
99 }
100}
101
102struct PoolInner<C> {
103 cfg: ConnPoolConfig,
104 idle: VecDeque<C>,
105 in_flight: usize,
106 auto_eject: AutoEject,
107 backoff: Backoff,
108 shutdown: bool,
109}
110
/// Exponential retry delay with an upper bound.
///
/// The delay starts at zero, becomes one second on the first failure, doubles
/// on every further failure and never exceeds `max`. A success resets it to
/// zero.
#[derive(Debug, Clone)]
struct Backoff {
    /// Delay produced by the most recent failure; zero when there is none.
    current: Duration,
    /// Ceiling applied to every delay, including the very first one.
    max: Duration,
}

impl Backoff {
    /// Creates a tracker with no pending delay and the given ceiling.
    fn new(max: Duration) -> Self {
        Self {
            current: Duration::ZERO,
            max,
        }
    }

    /// Registers a failure and returns the delay to apply before retrying.
    ///
    /// The first failure after a reset yields one second; each subsequent one
    /// doubles the previous delay. The result is always clamped to `max`, so a
    /// ceiling shorter than one second is honoured from the first failure on.
    fn record_failure(&mut self) -> Duration {
        let next = if self.current.is_zero() {
            Duration::from_secs(1)
        } else {
            // Saturating so a long failure streak cannot overflow `Duration`.
            self.current.saturating_mul(2)
        };
        self.current = next.min(self.max);
        self.current
    }

    /// Registers a success, clearing any accumulated delay.
    fn record_success(&mut self) {
        self.current = Duration::ZERO;
    }
}
144
145pub struct ConnPool<C> {
152 factory: Option<Arc<dyn ConnFactory<C>>>,
153 state: Arc<Mutex<PoolInner<C>>>,
154 notify: Arc<Notify>,
155}
156
157impl<C> Clone for ConnPool<C> {
158 fn clone(&self) -> Self {
159 Self {
160 factory: self.factory.clone(),
161 state: Arc::clone(&self.state),
162 notify: Arc::clone(&self.notify),
163 }
164 }
165}
166
167impl<C: Send + 'static> ConnPool<C> {
168 #[must_use]
186 pub fn new(cfg: ConnPoolConfig) -> Self {
187 let auto_eject = AutoEject::new(
188 cfg.auto_eject,
189 cfg.server_failure_limit.max(1),
190 Duration::from_millis(cfg.server_retry_timeout_ms),
191 );
192 let max_backoff = Duration::from_millis(cfg.server_retry_timeout_ms.max(1_000));
193 Self {
194 factory: None,
195 state: Arc::new(Mutex::new(PoolInner {
196 cfg,
197 idle: VecDeque::new(),
198 in_flight: 0,
199 auto_eject,
200 backoff: Backoff::new(max_backoff),
201 shutdown: false,
202 })),
203 notify: Arc::new(Notify::new()),
204 }
205 }
206
207 pub fn with_factory<F>(cfg: ConnPoolConfig, factory: F) -> Self
220 where
221 F: ConnFactory<C>,
222 {
223 let mut pool = Self::new(cfg);
224 pool.factory = Some(Arc::new(factory));
225 pool
226 }
227
228 pub fn set_factory<F: ConnFactory<C>>(&mut self, factory: F) {
230 self.factory = Some(Arc::new(factory));
231 }
232
233 #[must_use]
235 pub fn config(&self) -> ConnPoolConfig {
236 self.state.lock().cfg.clone()
237 }
238
239 #[must_use]
241 pub fn idle_count(&self) -> usize {
242 self.state.lock().idle.len()
243 }
244
245 #[must_use]
247 pub fn in_flight(&self) -> usize {
248 self.state.lock().in_flight
249 }
250
251 #[must_use]
253 pub fn is_ejected(&self, now: Instant) -> bool {
254 let mut g = self.state.lock();
255 g.auto_eject.record_attempt(now) == AutoEjectState::Ejected
256 }
257
258 #[must_use]
260 pub fn auto_eject(&self) -> AutoEject {
261 self.state.lock().auto_eject.clone()
262 }
263
264 pub fn shutdown(&self) {
269 {
270 let mut g = self.state.lock();
271 g.shutdown = true;
272 g.idle.clear();
273 }
274 self.notify.notify_waiters();
275 }
276
277 pub async fn get(&self) -> Result<ConnHandle<C>, NetError> {
288 loop {
289 let waiter = {
291 let mut g = self.state.lock();
292 if g.shutdown {
293 return Err(NetError::PoolShutdown);
294 }
295 if let Some(conn) = g.idle.pop_front() {
296 g.in_flight += 1;
297 return Ok(ConnHandle {
298 pool: self.clone(),
299 inner: Some(conn),
300 });
301 }
302 if g.in_flight + g.idle.len() >= g.cfg.max_connections {
303 true
304 } else {
305 let now = Instant::now();
306 if g.auto_eject.record_attempt(now) == AutoEjectState::Ejected {
307 return Err(NetError::Ejected);
308 }
309 false
310 }
311 };
312 if waiter {
313 self.notify.notified().await;
314 continue;
315 }
316
317 let factory = self
318 .factory
319 .as_ref()
320 .ok_or(NetError::PoolExhausted)?
321 .clone();
322 match factory.connect().await {
323 Ok(conn) => {
324 let mut g = self.state.lock();
325 g.in_flight += 1;
326 g.auto_eject.record_success(Instant::now());
327 g.backoff.record_success();
328 return Ok(ConnHandle {
329 pool: self.clone(),
330 inner: Some(conn),
331 });
332 }
333 Err(err) => {
334 let ejected;
335 {
336 let mut g = self.state.lock();
337 let now = Instant::now();
338 ejected = g.auto_eject.record_failure(now) == AutoEjectState::Ejected;
339 let _ = g.backoff.record_failure();
340 }
341 if ejected {
342 return Err(NetError::Ejected);
343 }
344 return Err(err);
345 }
346 }
347 }
348 }
349
350 fn return_conn(&self, conn: C) {
351 let mut g = self.state.lock();
352 if g.in_flight > 0 {
353 g.in_flight -= 1;
354 }
355 if !g.shutdown && g.idle.len() + g.in_flight < g.cfg.max_connections {
356 g.idle.push_back(conn);
357 }
358 drop(g);
359 self.notify.notify_one();
360 }
361
362 fn drop_conn(&self) {
363 let mut g = self.state.lock();
364 if g.in_flight > 0 {
365 g.in_flight -= 1;
366 }
367 drop(g);
368 self.notify.notify_one();
369 }
370}
371
372impl<C: std::fmt::Debug> std::fmt::Debug for ConnPool<C> {
373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 let g = self.state.lock();
375 let factory_present = self.factory.is_some();
376 f.debug_struct("ConnPool")
377 .field("cfg", &g.cfg)
378 .field("idle", &g.idle.len())
379 .field("in_flight", &g.in_flight)
380 .field("auto_eject_failures", &g.auto_eject.failure_count())
381 .field("factory_installed", &factory_present)
382 .field("notify", &"<tokio::sync::Notify>")
383 .finish()
384 }
385}
386
387pub struct ConnHandle<C: Send + 'static> {
393 pool: ConnPool<C>,
394 inner: Option<C>,
395}
396
397impl<C: Send + 'static> std::fmt::Debug for ConnHandle<C> {
398 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
399 let _ = (&self.pool, &self.inner);
402 f.debug_struct("ConnHandle")
403 .field("alive", &self.inner.is_some())
404 .finish()
405 }
406}
407
408impl<C: Send + 'static> ConnHandle<C> {
409 pub fn get(&self) -> &C {
411 self.inner.as_ref().expect("invariant: handle is alive")
412 }
413
414 pub fn get_mut(&mut self) -> &mut C {
416 self.inner.as_mut().expect("invariant: handle is alive")
417 }
418
419 pub fn release(mut self) {
422 if let Some(conn) = self.inner.take() {
423 self.pool.return_conn(conn);
424 }
425 }
426
427 pub fn discard(mut self) {
429 self.inner.take();
430 self.pool.drop_conn();
431 }
432}
433
434impl<C: Send + 'static> Drop for ConnHandle<C> {
435 fn drop(&mut self) {
436 if let Some(conn) = self.inner.take() {
437 self.pool.return_conn(conn);
438 }
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Default configuration with only the connection cap overridden.
    fn capped(max_connections: usize) -> ConnPoolConfig {
        ConnPoolConfig {
            max_connections,
            ..ConnPoolConfig::default()
        }
    }

    /// A released connection is handed out again before a new one is opened.
    #[tokio::test]
    async fn round_trip_basic() {
        let next_id = Arc::new(AtomicUsize::new(0));
        let pool = ConnPool::<usize>::with_factory(capped(2), {
            let next_id = Arc::clone(&next_id);
            move || {
                let next_id = Arc::clone(&next_id);
                async move {
                    let id = next_id.fetch_add(1, Ordering::Relaxed);
                    Ok::<usize, NetError>(id)
                }
            }
        });

        let first = pool.get().await.unwrap();
        let second = pool.get().await.unwrap();
        assert_ne!(first.get(), second.get());

        first.release();
        let reused = pool.get().await.unwrap();
        assert_eq!(*reused.get(), 0);

        reused.release();
        second.release();
    }

    /// With a cap of one, a second caller waits until the first handle is dropped.
    #[tokio::test]
    async fn max_connections_blocks_until_release() {
        let pool = ConnPool::<u32>::with_factory(capped(1), || async { Ok::<u32, NetError>(7) });
        let held = pool.get().await.unwrap();

        let contender = {
            let pool = pool.clone();
            tokio::spawn(async move {
                let second = pool.get().await.unwrap();
                assert_eq!(*second.get(), 7);
            })
        };

        tokio::task::yield_now().await;
        assert!(!contender.is_finished());

        drop(held);
        contender.await.unwrap();
    }

    /// Reaching the failure limit turns factory errors into `Ejected`.
    #[tokio::test]
    async fn auto_eject_after_consecutive_failures() {
        let cfg = ConnPoolConfig {
            max_connections: 1,
            server_failure_limit: 2,
            server_retry_timeout_ms: 50,
            auto_eject: true,
        };
        let pool = ConnPool::<u8>::with_factory(cfg, || async {
            Err::<u8, NetError>(NetError::Io(std::io::Error::new(
                std::io::ErrorKind::ConnectionRefused,
                "test",
            )))
        });

        match pool.get().await {
            Err(NetError::Io(_)) => {}
            other => panic!("expected io error, got {other:?}"),
        }
        // The second failure trips the limit; the third call is refused up front.
        for _ in 0..2 {
            match pool.get().await {
                Err(NetError::Ejected) => {}
                other => panic!("expected eject, got {other:?}"),
            }
        }
    }

    /// Shutting the pool down releases a caller parked in `get`.
    #[tokio::test]
    async fn shutdown_unblocks_waiters() {
        let pool = ConnPool::<u32>::with_factory(capped(1), || async { Ok::<u32, NetError>(1) });
        let _held = pool.get().await.unwrap();

        let blocked = {
            let pool = pool.clone();
            tokio::spawn(async move { pool.get().await })
        };

        tokio::task::yield_now().await;
        pool.shutdown();

        let outcome = blocked.await.unwrap();
        assert!(matches!(outcome, Err(NetError::PoolShutdown)));
    }
}