use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
/// Default idle timeout (60 seconds), used by `ConnectionIdleTimeout::default()`.
///
/// NOTE(review): despite its name this is the *idle-timeout* default; the
/// connection lifetime policy defaults to unlimited. Consider renaming.
const DEFAULT_POOL_LIFETIME: Duration = Duration::from_secs(60);
/// Tuning knobs for a connection pool.
///
/// Start from [`ConnectionPoolOptions::default`] and chain the builder
/// methods; the fields are public so the configured values can be read back.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ConnectionPoolOptions {
    /// Idle-timeout policy for pooled connections. Defaults to
    /// `ConnectionIdleTimeout::default()`, i.e. `Limited` with one minute.
    pub connection_idle_timeout: ConnectionIdleTimeout,
    /// Maximum number of connections. Defaults to `usize::MAX`, i.e. effectively unbounded.
    pub max_connections: usize,
    /// `Some((count, selection))` only when more than one pool was requested;
    /// `None` (the default) means a single pool.
    pub multiple_pools: Option<(usize, PoolSelection)>,
    /// Lifetime policy for individual connections. Defaults to unlimited.
    pub connection_lifetime: ConnectionLifetime,
}

impl Default for ConnectionPoolOptions {
    /// One-minute idle timeout, `usize::MAX` connections, a single pool and
    /// no connection lifetime limit.
    fn default() -> Self {
        Self {
            max_connections: usize::MAX,
            multiple_pools: None,
            connection_idle_timeout: Default::default(),
            connection_lifetime: Default::default(),
        }
    }
}

impl ConnectionPoolOptions {
    /// Sets the idle timeout: pass a `Duration` to limit it, or `None` to disable it.
    #[must_use]
    pub fn connection_idle_timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
        self.connection_idle_timeout = timeout
            .into()
            .map_or(ConnectionIdleTimeout::Unlimited, ConnectionIdleTimeout::Limited);
        self
    }

    /// Sets the maximum number of connections.
    #[must_use]
    pub fn max_connections(self, max: usize) -> Self {
        Self {
            max_connections: max,
            ..self
        }
    }

    /// Sets the per-connection lifetime policy.
    #[must_use]
    pub fn connection_lifetime(self, lifetime: ConnectionLifetime) -> Self {
        Self {
            connection_lifetime: lifetime,
            ..self
        }
    }

    /// Requests `count` pools, chosen between using `selection`.
    ///
    /// A `count` of 0 or 1 clears the setting (`multiple_pools` becomes
    /// `None`) and `selection` is dropped.
    #[must_use]
    pub fn multiple_pools(self, count: usize, selection: PoolSelection) -> Self {
        let multiple_pools = if count > 1 {
            Some((count, selection))
        } else {
            None
        };
        Self {
            multiple_pools,
            ..self
        }
    }
}
/// Idle-timeout policy for pooled connections.
///
/// The default is `Limited(DEFAULT_POOL_LIFETIME)`, i.e. one minute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionIdleTimeout {
    /// No idle timeout (what `connection_idle_timeout(None)` selects).
    Unlimited,
    /// Idle timeout of the given duration.
    Limited(Duration),
}

impl Default for ConnectionIdleTimeout {
    fn default() -> Self {
        ConnectionIdleTimeout::Limited(DEFAULT_POOL_LIFETIME)
    }
}
/// Maximum-lifetime policy for pooled connections.
///
/// Build one with [`ConnectionLifetime::unlimited`] (the default),
/// [`ConnectionLifetime::fixed`] or [`ConnectionLifetime::per_connection`],
/// and evaluate it with [`ConnectionLifetime::resolve`].
#[derive(Default, Clone)]
#[repr(transparent)]
pub struct ConnectionLifetime(Inner);

/// Private representation; callers can only construct it through the
/// `ConnectionLifetime` constructors. Cloning shares a `PerConnection`
/// generator via its `Arc`.
#[derive(Default, Clone)]
enum Inner {
    /// No limit; `resolve` yields `None`.
    #[default]
    Unlimited,
    /// The same limit for every connection.
    Fixed(Duration),
    /// Generator invoked on every `resolve` call.
    PerConnection(Arc<dyn Fn() -> Option<Duration> + Send + Sync + 'static>),
}

impl ConnectionLifetime {
    /// Policy without a lifetime limit; `resolve` always returns `None`.
    #[must_use]
    pub const fn unlimited() -> Self {
        Self(Inner::Unlimited)
    }

    /// Policy that gives every connection the same `duration`.
    #[must_use]
    pub const fn fixed(duration: Duration) -> Self {
        Self(Inner::Fixed(duration))
    }

    /// Policy that calls `generator` on every `resolve`, so each connection
    /// can receive a different limit. A generator returning `None` resolves
    /// exactly like [`ConnectionLifetime::unlimited`].
    #[must_use]
    pub fn per_connection<F>(generator: F) -> Self
    where
        F: Fn() -> Option<Duration> + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn() -> Option<Duration> + Send + Sync + 'static> = Arc::new(generator);
        Self(Inner::PerConnection(shared))
    }

    /// Computes the lifetime for one connection (`None` = no limit).
    ///
    /// For `per_connection` policies this invokes the generator each time.
    #[must_use]
    pub fn resolve(&self) -> Option<Duration> {
        let Self(inner) = self;
        match inner {
            Inner::Unlimited => None,
            Inner::Fixed(limit) => Some(*limit),
            Inner::PerConnection(make) => make(),
        }
    }
}

impl fmt::Debug for ConnectionLifetime {
    /// Renders as `ConnectionLifetime(Unlimited)`, `ConnectionLifetime(<duration>)`
    /// or `ConnectionLifetime(<closure>)`; closures have no useful `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("ConnectionLifetime");
        match &self.0 {
            Inner::Unlimited => tuple.field(&format_args!("Unlimited")),
            Inner::Fixed(duration) => tuple.field(duration),
            Inner::PerConnection(_) => tuple.field(&format_args!("<closure>")),
        };
        tuple.finish()
    }
}
/// Zero-based position of a client/pool in the slice handed to a pool
/// selector; returned alongside the chosen element by the selection strategies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolIndex(usize);

impl PoolIndex {
    /// Wraps a raw zero-based index. `const` so it can be used in constants.
    #[must_use]
    pub const fn new(index: usize) -> Self {
        Self(index)
    }

    /// Returns the raw zero-based index.
    #[must_use]
    pub const fn index(self) -> usize {
        self.0
    }
}

impl fmt::Display for PoolIndex {
    /// Formats as `PoolIndex(<n>)`, e.g. `PoolIndex(3)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "PoolIndex({})", self.0)
    }
}
/// How a client is chosen when several pools are configured (see the
/// `multiple_pools` option on `ConnectionPoolOptions`).
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PoolSelection {
    mode: Mode,
}

impl PoolSelection {
    /// Default value for the `requests_per_client` argument of
    /// [`PoolSelection::saturating`].
    pub const DEFAULT_REQUESTS_PER_CLIENT: u32 = 100;

    /// Sends `requests_per_client` consecutive selections to the same client
    /// before advancing to the next one, wrapping around at the end.
    ///
    /// # Panics
    ///
    /// Panics if `requests_per_client` is zero.
    #[must_use]
    pub const fn saturating(requests_per_client: u32) -> Self {
        assert!(requests_per_client > 0, "requests_per_client must be > 0");
        Self {
            mode: Mode::Saturating { requests_per_client },
        }
    }

    /// Advances to the next client on every selection.
    #[must_use]
    pub const fn round_robin() -> Self {
        Self { mode: Mode::RoundRobin }
    }

    /// Consumes the selection and returns a closure that picks an element of
    /// a slice and reports its [`PoolIndex`].
    ///
    /// Each returned closure owns its own counter starting at zero, so its
    /// first call picks index 0.
    ///
    /// # Panics
    ///
    /// The returned closure panics if called with an empty slice.
    // `#[must_use]` like every other constructor here: a discarded selector
    // is always a mistake.
    #[must_use]
    pub fn into_selector<T>(self) -> impl Fn(&[T]) -> (&T, PoolIndex) {
        let strategy = PoolSelectionStrategy::from(self);
        move |clients: &[T]| strategy.select(clients)
    }
}

/// Internal representation of the configured [`PoolSelection`] strategy.
#[derive(Debug, Clone)]
enum Mode {
    /// Stay on one client for `requests_per_client` selections, then advance.
    Saturating { requests_per_client: u32 },
    /// Advance to the next client on every selection.
    RoundRobin,
}
/// Runtime selector built from a [`PoolSelection`]; owns the shared request counter.
#[derive(Debug)]
pub(crate) enum PoolSelectionStrategy {
    /// See [`PoolSelection::saturating`].
    Saturating { requests_per_client: u32, counter: AtomicU32 },
    /// See [`PoolSelection::round_robin`].
    RoundRobin { counter: AtomicU32 },
}

impl PoolSelectionStrategy {
    /// Returns the next element of `clients` together with its index.
    ///
    /// Every call advances the counter by one:
    /// * `RoundRobin` picks `counter % clients.len()`;
    /// * `Saturating` picks `(counter / requests_per_client) % clients.len()`,
    ///   i.e. it stays on one client for `requests_per_client` calls.
    ///
    /// `Ordering::Relaxed` suffices: only the counter's own value is shared,
    /// and no other memory is published through it.
    ///
    /// NOTE(review): the `u32` counter wraps after 2^32 calls. At the wrap, the
    /// sequence may repeat or skip a client once. For example, round robin over
    /// 3 clients picks index 0 twice in a row.
    ///
    /// # Panics
    ///
    /// Panics if `clients` is empty.
    pub(crate) fn select<'a, T>(&self, clients: &'a [T]) -> (&'a T, PoolIndex) {
        assert!(!clients.is_empty(), "clients must not be empty");
        let ticket = match self {
            Self::RoundRobin { counter } => counter.fetch_add(1, Ordering::Relaxed),
            Self::Saturating {
                requests_per_client,
                counter,
            } => counter.fetch_add(1, Ordering::Relaxed) / *requests_per_client,
        };
        let slot = ticket as usize % clients.len();
        (&clients[slot], PoolIndex::new(slot))
    }
}

impl From<PoolSelection> for PoolSelectionStrategy {
    /// Builds a fresh strategy whose counter starts at zero.
    fn from(selection: PoolSelection) -> Self {
        let counter = AtomicU32::new(0);
        match selection.mode {
            Mode::RoundRobin => Self::RoundRobin { counter },
            Mode::Saturating { requests_per_client } => Self::Saturating {
                requests_per_client,
                counter,
            },
        }
    }
}
// Unit tests for the connection-pool option types, the connection lifetime
// policy, and the pool-selection strategies defined in this file.
#[cfg(test)]
mod tests {
// `Debug` is named in the trait list of `assert_connection_idle_timeout_type`.
use std::fmt::Debug;
use super::*;
// ---- ConnectionIdleTimeout ----
#[test]
fn connection_idle_timeout_default() {
let default = ConnectionIdleTimeout::default();
assert!(matches!(
default,
ConnectionIdleTimeout::Limited(d) if d == DEFAULT_POOL_LIFETIME
));
}
// Pins the derived `Debug` output of both variants.
#[test]
fn connection_idle_timeout_debug() {
assert_eq!(format!("{:?}", ConnectionIdleTimeout::Unlimited), "Unlimited");
assert_eq!(
format!("{:?}", ConnectionIdleTimeout::Limited(Duration::from_secs(1))),
"Limited(1s)"
);
}
// Compile-time trait-bound check via the external `static_assertions` crate
// (presumably a dev-dependency; confirm in Cargo.toml).
#[test]
fn assert_connection_idle_timeout_type() {
static_assertions::assert_impl_all!(
ConnectionIdleTimeout: Send,
Sync,
Clone,
Debug,
Default
);
}
// ---- ConnectionPoolOptions defaults and builders ----
#[test]
fn connection_pool_options_default() {
let options = ConnectionPoolOptions::default();
assert_eq!(options.max_connections, usize::MAX);
assert!(matches!(
options.connection_idle_timeout,
ConnectionIdleTimeout::Limited(d) if d == Duration::from_mins(1)
));
}
#[test]
fn connection_pool_options_connection_idle_timeout_set() {
let options = ConnectionPoolOptions::default().connection_idle_timeout(Duration::from_mins(2));
assert!(matches!(
options.connection_idle_timeout,
ConnectionIdleTimeout::Limited(d) if d == Duration::from_mins(2)
));
}
// Passing `None` after a limit must switch back to `Unlimited`.
#[test]
fn connection_pool_options_connection_idle_timeout_none() {
let options = ConnectionPoolOptions::default()
.connection_idle_timeout(Duration::from_mins(1))
.connection_idle_timeout(None);
assert!(matches!(options.connection_idle_timeout, ConnectionIdleTimeout::Unlimited));
}
#[test]
fn connection_pool_options_max_connections() {
let options = ConnectionPoolOptions::default().max_connections(100);
assert_eq!(options.max_connections, 100);
}
// NOTE(review): the `*_field_*` tests below largely overlap the builder
// tests above (same builders, same public fields); kept as-is.
#[test]
fn connection_idle_timeout_field_returns_configured_value() {
let options = ConnectionPoolOptions::default().connection_idle_timeout(Duration::from_secs(45));
assert!(matches!(
options.connection_idle_timeout,
ConnectionIdleTimeout::Limited(d) if d == Duration::from_secs(45)
));
}
#[test]
fn connection_idle_timeout_field_returns_unlimited_when_disabled() {
let options = ConnectionPoolOptions::default().connection_idle_timeout(None);
assert!(matches!(options.connection_idle_timeout, ConnectionIdleTimeout::Unlimited));
}
#[test]
fn connection_idle_timeout_field_default_is_sixty_seconds() {
let options = ConnectionPoolOptions::default();
assert!(matches!(
options.connection_idle_timeout,
ConnectionIdleTimeout::Limited(d) if d == Duration::from_mins(1)
));
}
#[test]
fn max_connections_field_returns_configured_value() {
let options = ConnectionPoolOptions::default().max_connections(42);
assert_eq!(options.max_connections, 42);
}
#[test]
fn max_connections_field_default_is_unlimited() {
let options = ConnectionPoolOptions::default();
assert_eq!(options.max_connections, usize::MAX);
}
// ---- ConnectionLifetime via ConnectionPoolOptions ----
#[test]
fn connection_lifetime_field_returns_configured_value() {
let options = ConnectionPoolOptions::default().connection_lifetime(ConnectionLifetime::fixed(Duration::from_hours(1)));
assert_eq!(options.connection_lifetime.resolve(), Some(Duration::from_hours(1)));
}
#[test]
fn connection_lifetime_field_default_is_unlimited() {
let options = ConnectionPoolOptions::default();
assert_eq!(options.connection_lifetime.resolve(), None);
}
#[test]
fn connection_lifetime_field_returns_per_connection() {
let options =
ConnectionPoolOptions::default().connection_lifetime(ConnectionLifetime::per_connection(|| Some(Duration::from_mins(2))));
assert_eq!(options.connection_lifetime.resolve(), Some(Duration::from_mins(2)));
}
#[test]
fn connection_pool_options_connection_lifetime_default() {
let options = ConnectionPoolOptions::default();
assert_eq!(options.connection_lifetime.resolve(), None);
}
#[test]
fn connection_pool_options_connection_lifetime_set() {
let options = ConnectionPoolOptions::default().connection_lifetime(ConnectionLifetime::fixed(Duration::from_hours(1)));
assert_eq!(options.connection_lifetime.resolve(), Some(Duration::from_hours(1)));
}
// A later `connection_lifetime` call replaces the earlier policy.
#[test]
fn connection_pool_options_connection_lifetime_set_unlimited() {
let options = ConnectionPoolOptions::default()
.connection_lifetime(ConnectionLifetime::fixed(Duration::from_mins(1)))
.connection_lifetime(ConnectionLifetime::unlimited());
assert_eq!(options.connection_lifetime.resolve(), None);
}
#[test]
fn connection_pool_options_connection_lifetime_per_connection() {
let options =
ConnectionPoolOptions::default().connection_lifetime(ConnectionLifetime::per_connection(|| Some(Duration::from_secs(7))));
assert_eq!(options.connection_lifetime.resolve(), Some(Duration::from_secs(7)));
}
// ---- ConnectionLifetime::resolve and Debug ----
#[test]
fn connection_lifetime_resolve_unlimited() {
assert_eq!(ConnectionLifetime::unlimited().resolve(), None);
}
#[test]
fn connection_lifetime_resolve_fixed() {
let policy = ConnectionLifetime::fixed(Duration::from_secs(10));
assert_eq!(policy.resolve(), Some(Duration::from_secs(10)));
}
// The generator runs on every `resolve()` call, not once: two calls, count == 2.
#[test]
fn connection_lifetime_resolve_per_connection_evaluates_closure() {
let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let policy = ConnectionLifetime::per_connection(move || {
counter_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Some(Duration::from_secs(42))
});
assert_eq!(policy.resolve(), Some(Duration::from_secs(42)));
assert_eq!(policy.resolve(), Some(Duration::from_secs(42)));
assert_eq!(counter.load(std::sync::atomic::Ordering::Relaxed), 2);
}
#[test]
fn connection_lifetime_resolve_per_connection_can_return_none() {
let policy = ConnectionLifetime::per_connection(|| None);
assert_eq!(policy.resolve(), None);
}
// Substring checks only; the exact `Debug` layout is not pinned here.
#[test]
fn connection_lifetime_debug() {
assert!(format!("{:?}", ConnectionLifetime::unlimited()).contains("Unlimited"));
let fixed = format!("{:?}", ConnectionLifetime::fixed(Duration::from_secs(1)));
assert!(fixed.contains("1s"));
let policy = ConnectionLifetime::per_connection(|| None);
assert!(format!("{policy:?}").contains("<closure>"));
}
// ---- PoolSelection constructors and PoolSelectionStrategy ----
#[test]
fn saturating_ok() {
let mode = PoolSelection::saturating(PoolSelection::DEFAULT_REQUESTS_PER_CLIENT);
assert!(matches!(mode.mode, Mode::Saturating { requests_per_client: 100 }));
}
#[test]
fn saturating_with_custom_value() {
let mode = PoolSelection::saturating(50);
assert!(matches!(mode.mode, Mode::Saturating { requests_per_client: 50 }));
}
#[test]
#[should_panic(expected = "requests_per_client must be > 0")]
fn saturating_panics_on_zero() {
let _ = PoolSelection::saturating(0);
}
#[test]
fn default_requests_per_client_constant() {
assert_eq!(PoolSelection::DEFAULT_REQUESTS_PER_CLIENT, 100);
}
// Saturating with 100 requests per client: 100 picks of client 1, then 100 of
// client 2, then back to 1. The strategy is stateful, so loop order matters.
#[test]
fn distributes_requests_across_clients() {
let clients = [1, 2];
let strategy = PoolSelectionStrategy::from(PoolSelection::saturating(PoolSelection::DEFAULT_REQUESTS_PER_CLIENT));
for _ in 0..100 {
assert_eq!(strategy.select(&clients).0, &1);
}
for _ in 0..100 {
assert_eq!(strategy.select(&clients).0, &2);
}
for _ in 0..100 {
assert_eq!(strategy.select(&clients).0, &1);
}
}
#[test]
fn round_robin_ok() {
let mode = PoolSelection::round_robin();
assert!(matches!(mode.mode, Mode::RoundRobin));
}
// A fresh strategy starts at index 0 and cycles through all clients in order.
#[test]
fn round_robin_distributes_requests_evenly() {
let clients = [1, 2, 3];
let strategy = PoolSelectionStrategy::from(PoolSelection::round_robin());
let selection = strategy.select(&clients);
assert_eq!(selection.0, &1);
assert_eq!(selection.1, PoolIndex::new(0));
assert_eq!(strategy.select(&clients).0, &2);
assert_eq!(strategy.select(&clients).0, &3);
assert_eq!(strategy.select(&clients).0, &1);
assert_eq!(strategy.select(&clients).0, &2);
assert_eq!(strategy.select(&clients).0, &3);
}
#[test]
fn round_robin_with_two_clients() {
let clients = [1, 2];
let strategy = PoolSelectionStrategy::from(PoolSelection::round_robin());
for _ in 0..50 {
assert_eq!(strategy.select(&clients).0, &1);
assert_eq!(strategy.select(&clients).0, &2);
}
}
// ---- ConnectionPoolOptions::multiple_pools ----
#[test]
fn multiple_pools_field_returns_none_when_single_pool() {
let options = ConnectionPoolOptions::default();
assert!(options.multiple_pools.is_none());
}
#[test]
fn multiple_pools_field_returns_some_when_configured() {
let selection = PoolSelection::saturating(50);
let options = ConnectionPoolOptions::default().multiple_pools(4, selection);
let result = options.multiple_pools;
assert!(result.is_some());
let (count, _sel) = result.unwrap();
assert_eq!(count, 4);
}
// A count of 1 is treated as "single pool" by the builder.
#[test]
fn multiple_pools_field_returns_none_when_pool_count_is_one() {
let selection = PoolSelection::round_robin();
let options = ConnectionPoolOptions::default().multiple_pools(1, selection);
assert!(options.multiple_pools.is_none());
}
#[test]
fn multiple_pools_field_with_round_robin() {
let selection = PoolSelection::round_robin();
let options = ConnectionPoolOptions::default().multiple_pools(3, selection);
let result = options.multiple_pools;
assert_eq!(result.map(|(count, _)| count), Some(3));
}
// ---- PoolSelection::into_selector ----
// The selector returns both the element and its index, wrapping after the last.
#[test]
fn into_selector_round_robin_with_integers() {
let selector = PoolSelection::round_robin().into_selector::<i32>();
let clients = [10, 20, 30];
let (first, index) = selector(&clients);
assert_eq!(*first, 10);
assert_eq!(index, PoolIndex::new(0));
let (second, index) = selector(&clients);
assert_eq!(*second, 20);
assert_eq!(index, PoolIndex::new(1));
let (third, index) = selector(&clients);
assert_eq!(*third, 30);
assert_eq!(index, PoolIndex::new(2));
let (wrapped, index) = selector(&clients);
assert_eq!(*wrapped, 10);
assert_eq!(index, PoolIndex::new(0));
}
// requests_per_client = 2 over two clients gives the sequence a, a, b, b, a.
#[test]
fn into_selector_saturating_with_strings() {
let selector = PoolSelection::saturating(2).into_selector::<&str>();
let clients = ["a", "b"];
assert_eq!(*selector(&clients).0, "a");
assert_eq!(*selector(&clients).0, "a");
assert_eq!(*selector(&clients).0, "b");
assert_eq!(*selector(&clients).0, "b");
assert_eq!(*selector(&clients).0, "a");
}
// The selection stored in the options can be moved out and turned into a selector.
#[test]
fn multiple_pools_field_returns_owned_selection() {
let selection = PoolSelection::round_robin();
let options = ConnectionPoolOptions::default().multiple_pools(2, selection);
let (count, owned_selection) = options.multiple_pools.unwrap();
assert_eq!(count, 2);
let selector = owned_selection.into_selector::<i32>();
let clients = [1, 2];
let (selected, _) = selector(&clients);
assert_eq!(*selected, 1);
}
// ---- PoolIndex ----
#[test]
fn pool_index_index_returns_value() {
assert_eq!(PoolIndex::new(7).index(), 7);
}
#[test]
fn pool_index_display() {
assert_eq!(PoolIndex::new(3).to_string(), "PoolIndex(3)");
}
}