Skip to main content

dynomite/net/
pool.rs

1//! Outbound connection pool with backoff and auto-eject.
2//!
3//! Two pool flavors share the same policy in this codebase: the
4//! per-datastore pool that hands a Redis or memcache backend
5//! connection to a CLIENT FSM, and the per-peer pool that hands a
6//! peer connection to the cluster routing layer. Both share:
7//!
8//! * a cap on the active connection count (`max_connections`),
9//! * round-robin across slots keyed on a caller-supplied `tag`,
//! * exponential connect-failure backoff (1 s on the first failure,
//!   doubling on each consecutive failure, capped at the configured
//!   maximum),
//! * auto-eject of the host after `server_failure_limit` consecutive
//!   failures, with retry after `server_retry_timeout_ms`.
14//!
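//! [`ConnPool`] implements the connection cap and auto-eject parts of
//! this policy in safe Rust; it has no `tag`-keyed slot selection of its
//! own, and it records the connect-failure backoff without applying it
//! as a delay. Connections are manufactured by a caller-supplied
//! [`ConnFactory`] (so tests can inject failure-injecting transports) and
//! handed back to the pool through [`ConnHandle::release`] or by dropping
//! the [`ConnHandle`].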
19//!
20//! # Examples
21//!
22//! ```
23//! use dynomite::net::pool::{ConnPool, ConnPoolConfig};
24//! use dynomite::io::reactor::TcpTransport;
25//!
26//! let pool: ConnPool<TcpTransport> = ConnPool::new(ConnPoolConfig {
27//!     max_connections: 4,
28//!     server_failure_limit: 3,
29//!     server_retry_timeout_ms: 1_000,
30//!     auto_eject: true,
31//! });
32//! assert_eq!(pool.config().max_connections, 4);
33//! ```
34
35use std::collections::VecDeque;
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::time::{Duration, Instant};
40
41use parking_lot::Mutex;
42use tokio::sync::Notify;
43
44use super::auto_eject::{AutoEject, AutoEjectState};
45use super::NetError;
46
/// Tunable knobs taken straight from the YAML pool block.
///
/// `max_connections` mirrors `datastore_connections` /
/// `local_peer_connections` / `remote_peer_connections`.
/// `server_failure_limit` and `server_retry_timeout_ms` mirror the
/// YAML fields with the same name, and `auto_eject` mirrors
/// `auto_eject_hosts`.
///
/// The type derives `PartialEq`/`Eq` so two configurations can be
/// compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnPoolConfig {
    /// Maximum number of concurrent outbound connections kept by the
    /// pool.
    pub max_connections: usize,
    /// Consecutive failures before the host is ejected;
    /// [`ConnPool::new`] treats 0 as 1.
    pub server_failure_limit: u32,
    /// Eject window in milliseconds. [`ConnPool::new`] also uses it as
    /// the cap for the connect-failure backoff, floored at one second.
    pub server_retry_timeout_ms: u64,
    /// Honor the `auto_eject_hosts` policy.
    pub auto_eject: bool,
}
66
67impl Default for ConnPoolConfig {
68    fn default() -> Self {
69        Self {
70            max_connections: 1,
71            server_failure_limit: 3,
72            server_retry_timeout_ms: 30_000,
73            auto_eject: true,
74        }
75    }
76}
77
78/// Boxed future returned by a [`ConnFactory`].
79pub type ConnFuture<C> = Pin<Box<dyn Future<Output = Result<C, NetError>> + Send + 'static>>;
80
81/// Factory that produces a fresh connection on demand.
82///
83/// The factory is invoked from the pool whenever a slot needs a
84/// new connection. Implementations typically wrap a target address
85/// and produce TCP or QUIC connections; tests inject
86/// failure-injecting factories.
87pub trait ConnFactory<C>: Send + Sync + 'static {
88    /// Build a fresh connection.
89    fn connect(&self) -> ConnFuture<C>;
90}
91
92impl<C, F, Fut> ConnFactory<C> for F
93where
94    F: Fn() -> Fut + Send + Sync + 'static,
95    Fut: Future<Output = Result<C, NetError>> + Send + 'static,
96{
97    fn connect(&self) -> ConnFuture<C> {
98        Box::pin(self())
99    }
100}
101
102struct PoolInner<C> {
103    cfg: ConnPoolConfig,
104    idle: VecDeque<C>,
105    in_flight: usize,
106    auto_eject: AutoEject,
107    backoff: Backoff,
108    shutdown: bool,
109}
110
/// Exponential connect-failure backoff.
///
/// The first failure after a reset yields a one second delay, and each
/// further consecutive failure doubles it. Every step (including the
/// first) is capped at `max`. A success resets the state.
#[derive(Debug, Clone)]
struct Backoff {
    /// Delay produced by the most recent failure; zero after a reset.
    current: Duration,
    /// Upper bound on `current`.
    max: Duration,
}

impl Backoff {
    /// Delay for the first failure after a reset, before capping.
    const INITIAL: Duration = Duration::from_secs(1);

    /// Start with no pending delay and the given cap.
    fn new(max: Duration) -> Self {
        Self {
            current: Duration::ZERO,
            max,
        }
    }

    /// Record a failed connect attempt and return the resulting delay.
    fn record_failure(&mut self) -> Duration {
        let next = if self.current.is_zero() {
            Self::INITIAL
        } else {
            self.current.saturating_mul(2)
        };
        // Cap every step, including the first: a `max` below one second
        // must still be honored (previously the first step ignored it).
        self.current = next.min(self.max);
        self.current
    }

    /// Record a successful connect, resetting the backoff.
    fn record_success(&mut self) {
        self.current = Duration::ZERO;
    }
}
144
145/// Outbound connection pool.
146///
147/// The pool keeps an [`AutoEject`] tracker and a small idle list;
148/// callers acquire a connection through [`ConnPool::get`] and
149/// return it through [`ConnHandle::release`] (or by dropping the
150/// handle, which routes to the same path).
151pub struct ConnPool<C> {
152    factory: Option<Arc<dyn ConnFactory<C>>>,
153    state: Arc<Mutex<PoolInner<C>>>,
154    notify: Arc<Notify>,
155}
156
157impl<C> Clone for ConnPool<C> {
158    fn clone(&self) -> Self {
159        Self {
160            factory: self.factory.clone(),
161            state: Arc::clone(&self.state),
162            notify: Arc::clone(&self.notify),
163        }
164    }
165}
166
167impl<C: Send + 'static> ConnPool<C> {
168    /// Build a pool with no factory installed.
169    /// Build a pool with no factory installed.
170    ///
171    /// Callers must install a factory through
172    /// [`ConnPool::set_factory`] before invoking
173    /// [`ConnPool::get`]. The factory-less constructor exists so
174    /// callers can build the pool eagerly during configuration and
175    /// wire the factory once the resolver has populated the target
176    /// address.
177    ///
178    /// # Examples
179    ///
180    /// ```
181    /// use dynomite::io::reactor::TcpTransport;
182    /// use dynomite::net::pool::{ConnPool, ConnPoolConfig};
183    /// let _: ConnPool<TcpTransport> = ConnPool::new(ConnPoolConfig::default());
184    /// ```
185    #[must_use]
186    pub fn new(cfg: ConnPoolConfig) -> Self {
187        let auto_eject = AutoEject::new(
188            cfg.auto_eject,
189            cfg.server_failure_limit.max(1),
190            Duration::from_millis(cfg.server_retry_timeout_ms),
191        );
192        let max_backoff = Duration::from_millis(cfg.server_retry_timeout_ms.max(1_000));
193        Self {
194            factory: None,
195            state: Arc::new(Mutex::new(PoolInner {
196                cfg,
197                idle: VecDeque::new(),
198                in_flight: 0,
199                auto_eject,
200                backoff: Backoff::new(max_backoff),
201                shutdown: false,
202            })),
203            notify: Arc::new(Notify::new()),
204        }
205    }
206
207    /// Build a pool with a factory installed up front.
208    ///
209    /// # Examples
210    ///
211    /// ```
212    /// use dynomite::net::pool::{ConnPool, ConnPoolConfig};
213    /// use dynomite::net::NetError;
214    /// let pool = ConnPool::with_factory(ConnPoolConfig::default(), || async {
215    ///     Ok::<u32, NetError>(0)
216    /// });
217    /// assert_eq!(pool.config().max_connections, 1);
218    /// ```
219    pub fn with_factory<F>(cfg: ConnPoolConfig, factory: F) -> Self
220    where
221        F: ConnFactory<C>,
222    {
223        let mut pool = Self::new(cfg);
224        pool.factory = Some(Arc::new(factory));
225        pool
226    }
227
228    /// Install a connection factory.
229    pub fn set_factory<F: ConnFactory<C>>(&mut self, factory: F) {
230        self.factory = Some(Arc::new(factory));
231    }
232
233    /// Borrow the pool config.
234    #[must_use]
235    pub fn config(&self) -> ConnPoolConfig {
236        self.state.lock().cfg.clone()
237    }
238
239    /// Number of idle connections currently in the pool.
240    #[must_use]
241    pub fn idle_count(&self) -> usize {
242        self.state.lock().idle.len()
243    }
244
245    /// Number of in-flight (handed out) connections.
246    #[must_use]
247    pub fn in_flight(&self) -> usize {
248        self.state.lock().in_flight
249    }
250
251    /// True when the host has been auto-ejected at this instant.
252    #[must_use]
253    pub fn is_ejected(&self, now: Instant) -> bool {
254        let mut g = self.state.lock();
255        g.auto_eject.record_attempt(now) == AutoEjectState::Ejected
256    }
257
258    /// Snapshot of the auto-eject tracker.
259    #[must_use]
260    pub fn auto_eject(&self) -> AutoEject {
261        self.state.lock().auto_eject.clone()
262    }
263
264    /// Shut the pool down.
265    ///
266    /// Wakes every waiter blocked in [`ConnPool::get`]. Subsequent
267    /// `get` calls return [`NetError::PoolShutdown`].
268    pub fn shutdown(&self) {
269        {
270            let mut g = self.state.lock();
271            g.shutdown = true;
272            g.idle.clear();
273        }
274        self.notify.notify_waiters();
275    }
276
277    /// Acquire a connection.
278    ///
279    /// Reuses an idle connection when one is available; otherwise
280    /// invokes the factory until it succeeds or the auto-eject
281    /// tracker reports the target as unreachable.
282    ///
283    /// # Errors
284    /// Returns [`NetError::Ejected`] when the host is in its eject
285    /// window, [`NetError::PoolShutdown`] when the pool was shut
286    /// down, or the underlying factory error otherwise.
287    pub async fn get(&self) -> Result<ConnHandle<C>, NetError> {
288        loop {
289            // Fast path: an idle connection is sitting in the pool.
290            let waiter = {
291                let mut g = self.state.lock();
292                if g.shutdown {
293                    return Err(NetError::PoolShutdown);
294                }
295                if let Some(conn) = g.idle.pop_front() {
296                    g.in_flight += 1;
297                    return Ok(ConnHandle {
298                        pool: self.clone(),
299                        inner: Some(conn),
300                    });
301                }
302                if g.in_flight + g.idle.len() >= g.cfg.max_connections {
303                    true
304                } else {
305                    let now = Instant::now();
306                    if g.auto_eject.record_attempt(now) == AutoEjectState::Ejected {
307                        return Err(NetError::Ejected);
308                    }
309                    false
310                }
311            };
312            if waiter {
313                self.notify.notified().await;
314                continue;
315            }
316
317            let factory = self
318                .factory
319                .as_ref()
320                .ok_or(NetError::PoolExhausted)?
321                .clone();
322            match factory.connect().await {
323                Ok(conn) => {
324                    let mut g = self.state.lock();
325                    g.in_flight += 1;
326                    g.auto_eject.record_success(Instant::now());
327                    g.backoff.record_success();
328                    return Ok(ConnHandle {
329                        pool: self.clone(),
330                        inner: Some(conn),
331                    });
332                }
333                Err(err) => {
334                    let ejected;
335                    {
336                        let mut g = self.state.lock();
337                        let now = Instant::now();
338                        ejected = g.auto_eject.record_failure(now) == AutoEjectState::Ejected;
339                        let _ = g.backoff.record_failure();
340                    }
341                    if ejected {
342                        return Err(NetError::Ejected);
343                    }
344                    return Err(err);
345                }
346            }
347        }
348    }
349
350    fn return_conn(&self, conn: C) {
351        let mut g = self.state.lock();
352        if g.in_flight > 0 {
353            g.in_flight -= 1;
354        }
355        if !g.shutdown && g.idle.len() + g.in_flight < g.cfg.max_connections {
356            g.idle.push_back(conn);
357        }
358        drop(g);
359        self.notify.notify_one();
360    }
361
362    fn drop_conn(&self) {
363        let mut g = self.state.lock();
364        if g.in_flight > 0 {
365            g.in_flight -= 1;
366        }
367        drop(g);
368        self.notify.notify_one();
369    }
370}
371
372impl<C: std::fmt::Debug> std::fmt::Debug for ConnPool<C> {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        let g = self.state.lock();
375        let factory_present = self.factory.is_some();
376        f.debug_struct("ConnPool")
377            .field("cfg", &g.cfg)
378            .field("idle", &g.idle.len())
379            .field("in_flight", &g.in_flight)
380            .field("auto_eject_failures", &g.auto_eject.failure_count())
381            .field("factory_installed", &factory_present)
382            .field("notify", &"<tokio::sync::Notify>")
383            .finish()
384    }
385}
386
387/// Handle returned by [`ConnPool::get`].
388///
389/// Dropping the handle returns the connection to the pool. Call
390/// [`ConnHandle::discard`] when the connection is no longer healthy
391/// (the slot will then be re-filled by the next `get`).
392pub struct ConnHandle<C: Send + 'static> {
393    pool: ConnPool<C>,
394    inner: Option<C>,
395}
396
397impl<C: Send + 'static> std::fmt::Debug for ConnHandle<C> {
398    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
399        // The pool reference and the inner connection are
400        // intentionally elided; only the alive bit is printed.
401        let _ = (&self.pool, &self.inner);
402        f.debug_struct("ConnHandle")
403            .field("alive", &self.inner.is_some())
404            .finish()
405    }
406}
407
408impl<C: Send + 'static> ConnHandle<C> {
409    /// Borrow the wrapped connection.
410    pub fn get(&self) -> &C {
411        self.inner.as_ref().expect("invariant: handle is alive")
412    }
413
414    /// Mutably borrow the wrapped connection.
415    pub fn get_mut(&mut self) -> &mut C {
416        self.inner.as_mut().expect("invariant: handle is alive")
417    }
418
419    /// Return the connection to the pool. Equivalent to dropping
420    /// the handle.
421    pub fn release(mut self) {
422        if let Some(conn) = self.inner.take() {
423            self.pool.return_conn(conn);
424        }
425    }
426
427    /// Discard the connection (do not return it to the pool).
428    pub fn discard(mut self) {
429        self.inner.take();
430        self.pool.drop_conn();
431    }
432}
433
434impl<C: Send + 'static> Drop for ConnHandle<C> {
435    fn drop(&mut self) {
436        if let Some(conn) = self.inner.take() {
437            self.pool.return_conn(conn);
438        }
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Default config with only the connection cap overridden.
    fn capped(max_connections: usize) -> ConnPoolConfig {
        ConnPoolConfig {
            max_connections,
            ..ConnPoolConfig::default()
        }
    }

    /// Fresh connections get distinct ids, and a released one is reused.
    #[tokio::test]
    async fn round_trip_basic() {
        let next_id = Arc::new(AtomicUsize::new(0));
        let factory_ids = Arc::clone(&next_id);
        let pool: ConnPool<usize> = ConnPool::with_factory(capped(2), move || {
            let ids = Arc::clone(&factory_ids);
            async move { Ok::<usize, NetError>(ids.fetch_add(1, Ordering::Relaxed)) }
        });

        let first = pool.get().await.unwrap();
        let second = pool.get().await.unwrap();
        assert_ne!(first.get(), second.get());

        first.release();
        let reused = pool.get().await.unwrap();
        assert_eq!(*reused.get(), 0);

        reused.release();
        second.release();
    }

    /// With the only slot checked out, a second `get` parks until release.
    #[tokio::test]
    async fn max_connections_blocks_until_release() {
        let pool: ConnPool<u32> =
            ConnPool::with_factory(capped(1), || async { Ok::<u32, NetError>(7) });
        let held = pool.get().await.unwrap();
        let contender = {
            let pool = pool.clone();
            tokio::spawn(async move {
                let h2 = pool.get().await.unwrap();
                assert_eq!(*h2.get(), 7);
            })
        };
        // Yield once so the spawned task runs and parks on the full pool.
        tokio::task::yield_now().await;
        assert!(!contender.is_finished());
        drop(held);
        contender.await.unwrap();
    }

    /// Reaching the failure limit ejects the host for the retry window.
    #[tokio::test]
    async fn auto_eject_after_consecutive_failures() {
        let cfg = ConnPoolConfig {
            max_connections: 1,
            server_failure_limit: 2,
            server_retry_timeout_ms: 50,
            auto_eject: true,
        };
        let pool: ConnPool<u8> = ConnPool::with_factory(cfg, || async {
            Err::<u8, NetError>(NetError::Io(std::io::Error::new(
                std::io::ErrorKind::ConnectionRefused,
                "test",
            )))
        });
        // First failure surfaces the io error.
        match pool.get().await {
            Err(NetError::Io(_)) => {}
            other => panic!("expected io error, got {other:?}"),
        }
        // Second failure trips the eject window; the following attempt,
        // still inside the window, is short-circuited the same way.
        for _ in 0..2 {
            match pool.get().await {
                Err(NetError::Ejected) => {}
                other => panic!("expected eject, got {other:?}"),
            }
        }
    }

    /// `shutdown` wakes a task parked on a full pool with `PoolShutdown`.
    #[tokio::test]
    async fn shutdown_unblocks_waiters() {
        let pool: ConnPool<u32> =
            ConnPool::with_factory(capped(1), || async { Ok::<u32, NetError>(1) });
        let _held = pool.get().await.unwrap();
        let parked = {
            let pool = pool.clone();
            tokio::spawn(async move { pool.get().await })
        };
        tokio::task::yield_now().await;
        pool.shutdown();
        let outcome = parked.await.unwrap();
        assert!(matches!(outcome, Err(NetError::PoolShutdown)));
    }
}