use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use asupersync::{Cx, Outcome};
use sqlmodel_core::{Connection, Error};
use crate::{Pool, PooledConnection};
/// Rule a [`ReplicaPool`] follows to decide which read replica serves the next read.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReplicaStrategy {
    /// Visit the replicas in order, advancing by one replica per read.
    RoundRobin,
    /// Deterministic pseudo-random choice: a per-pool sequence number is hashed
    /// to pick the replica (no external randomness source is consulted).
    Random,
}
/// A primary connection pool plus zero or more read-replica pools.
///
/// Reads can be spread across `replicas` according to `strategy`; writes and
/// explicit primary acquisitions go to `primary`.
// NOTE: field order is also drop order (primary is dropped before replicas).
pub struct ReplicaPool<C: Connection> {
// Pool for the primary server.
primary: Pool<C>,
// Read-replica pools; may be empty.
replicas: Vec<Pool<C>>,
// How the next replica is chosen for a read.
strategy: ReplicaStrategy,
// Sequence number advanced on every replica selection (by either strategy);
// `AtomicUsize::fetch_add` wraps on overflow.
round_robin_counter: AtomicUsize,
}
impl<C: Connection> ReplicaPool<C> {
pub fn new(primary: Pool<C>, replicas: Vec<Pool<C>>) -> Self {
Self {
primary,
replicas,
strategy: ReplicaStrategy::RoundRobin,
round_robin_counter: AtomicUsize::new(0),
}
}
pub fn with_strategy(
primary: Pool<C>,
replicas: Vec<Pool<C>>,
strategy: ReplicaStrategy,
) -> Self {
Self {
primary,
replicas,
strategy,
round_robin_counter: AtomicUsize::new(0),
}
}
pub async fn acquire_read<F, Fut>(
&self,
cx: &Cx,
factory: F,
) -> Outcome<PooledConnection<C>, Error>
where
F: Fn() -> Fut,
Fut: Future<Output = Outcome<C, Error>>,
{
if self.replicas.is_empty() {
return self.primary.acquire(cx, factory).await;
}
let idx = self.select_replica();
self.replicas[idx].acquire(cx, factory).await
}
pub async fn acquire_write<F, Fut>(
&self,
cx: &Cx,
factory: F,
) -> Outcome<PooledConnection<C>, Error>
where
F: Fn() -> Fut,
Fut: Future<Output = Outcome<C, Error>>,
{
self.primary.acquire(cx, factory).await
}
pub async fn acquire_primary<F, Fut>(
&self,
cx: &Cx,
factory: F,
) -> Outcome<PooledConnection<C>, Error>
where
F: Fn() -> Fut,
Fut: Future<Output = Outcome<C, Error>>,
{
self.primary.acquire(cx, factory).await
}
pub fn primary(&self) -> &Pool<C> {
&self.primary
}
pub fn replicas(&self) -> &[Pool<C>] {
&self.replicas
}
pub fn replica_count(&self) -> usize {
self.replicas.len()
}
pub fn strategy(&self) -> ReplicaStrategy {
self.strategy
}
fn select_replica(&self) -> usize {
match self.strategy {
ReplicaStrategy::RoundRobin => {
let idx = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
idx % self.replicas.len()
}
ReplicaStrategy::Random => {
let seq = self.round_robin_counter.fetch_add(1, Ordering::Relaxed);
#[allow(clippy::cast_possible_truncation)]
let seq32 = seq as u32;
let mixed = seq32.wrapping_mul(2_654_435_761_u32);
(mixed as usize) % self.replicas.len()
}
}
}
}
impl<C: Connection> std::fmt::Debug for ReplicaPool<C> {
    /// Writes a diagnostic summary of the pool.
    ///
    /// The primary pool appears as the placeholder string `"Pool { .. }"` and the
    /// replicas as a count. The selection counter is shown as a relaxed snapshot.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let replica_count = self.replicas.len();
        let counter_snapshot = self.round_robin_counter.load(Ordering::Relaxed);

        let mut summary = formatter.debug_struct("ReplicaPool");
        summary.field("primary", &"Pool { .. }");
        summary.field("replicas", &replica_count);
        summary.field("strategy", &self.strategy);
        summary.field("round_robin_counter", &counter_snapshot);
        summary.finish()
    }
}