use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use crate::error::{QueryError, QueryResult};
/// Static configuration for a single member of a replica set.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ReplicaConfig {
/// Unique identifier; used as the key of the router's health table.
pub id: String,
/// Connection URL for this replica.
pub url: String,
/// Role this member was configured with.
pub role: ReplicaRole,
/// Failover priority; the router promotes the healthy secondary with the highest value.
pub priority: u32,
/// Relative weight (100 from the provided constructors).
/// NOTE(review): not read by the routing code in this file.
pub weight: u32,
/// Optional region label matched by region-based read routing.
pub region: Option<String>,
/// Maximum tolerated replication lag (10 s for `ReplicaConfig::secondary`).
/// NOTE(review): not consulted by the routing code in this file.
pub max_lag: Option<Duration>,
}
/// Role a replica-set member is configured with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReplicaRole {
/// Write target; `ConnectionRouter::new` starts with a replica of this role as primary.
Primary,
/// Read replica; the only role considered as a failover candidate.
Secondary,
/// NOTE(review): presumably a voting-only member that holds no data — confirm.
Arbiter,
/// NOTE(review): presumably a member not meant to receive client reads — confirm.
Hidden,
}
impl ReplicaConfig {
    /// Creates a primary replica with priority 100, weight 100, no region and no lag limit.
    pub fn primary(id: impl Into<String>, url: impl Into<String>) -> Self {
        let (id, url) = (id.into(), url.into());
        Self {
            id,
            url,
            role: ReplicaRole::Primary,
            priority: 100,
            weight: 100,
            region: None,
            max_lag: None,
        }
    }

    /// Creates a secondary replica with priority 50, weight 100, no region and a 10 s lag limit.
    pub fn secondary(id: impl Into<String>, url: impl Into<String>) -> Self {
        // Start from the primary defaults and override only what differs.
        Self {
            role: ReplicaRole::Secondary,
            priority: 50,
            max_lag: Some(Duration::from_secs(10)),
            ..Self::primary(id, url)
        }
    }

    /// Returns this replica labelled with `region`.
    pub fn with_region(self, region: impl Into<String>) -> Self {
        Self {
            region: Some(region.into()),
            ..self
        }
    }

    /// Returns this replica with the given relative `weight`.
    pub fn with_weight(self, weight: u32) -> Self {
        Self { weight, ..self }
    }

    /// Returns this replica with the given failover `priority`.
    pub fn with_priority(self, priority: u32) -> Self {
        Self { priority, ..self }
    }

    /// Returns this replica with a maximum tolerated replication lag of `lag`.
    pub fn with_max_lag(self, lag: Duration) -> Self {
        Self {
            max_lag: Some(lag),
            ..self
        }
    }
}
/// Configuration of a whole replica set, usually produced by `ReplicaSetBuilder::build`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicaSetConfig {
/// Name of the replica set.
pub name: String,
/// All configured members, in declaration order.
pub replicas: Vec<ReplicaConfig>,
/// Read preference the router uses when a read is routed without an explicit one.
pub default_read_preference: ReadPreference,
/// Interval between health checks (10 s by default from the builder).
/// NOTE(review): not read by the routing code in this file; presumably consumed elsewhere.
pub health_check_interval: Duration,
/// Failover timeout (30 s by default from the builder).
/// NOTE(review): not read by the routing code in this file.
pub failover_timeout: Duration,
}
impl ReplicaSetConfig {
    /// Starts building a replica set called `name`.
    pub fn new(name: impl Into<String>) -> ReplicaSetBuilder {
        ReplicaSetBuilder::new(name)
    }

    /// Returns the first replica configured with the primary role, if any.
    pub fn primary(&self) -> Option<&ReplicaConfig> {
        self.replicas
            .iter()
            .find(|member| matches!(member.role, ReplicaRole::Primary))
    }

    /// Iterates over replicas configured with the secondary role, in declaration order.
    pub fn secondaries(&self) -> impl Iterator<Item = &ReplicaConfig> {
        self.replicas
            .iter()
            .filter(|member| matches!(member.role, ReplicaRole::Secondary))
    }

    /// Iterates over replicas whose region label equals `region`; unlabelled replicas are skipped.
    pub fn in_region(&self, region: &str) -> impl Iterator<Item = &ReplicaConfig> {
        self.replicas.iter().filter(move |member| {
            matches!(member.region.as_deref(), Some(label) if label == region)
        })
    }
}
/// Fluent builder for `ReplicaSetConfig`.
#[derive(Debug, Clone)]
pub struct ReplicaSetBuilder {
/// Replica-set name.
name: String,
/// Members added so far, in the order they were added.
replicas: Vec<ReplicaConfig>,
/// Default read preference (`ReadPreference::Primary` unless overridden).
default_read_preference: ReadPreference,
/// Health-check interval (10 s unless overridden).
health_check_interval: Duration,
/// Failover timeout (30 s unless overridden).
failover_timeout: Duration,
}
impl ReplicaSetBuilder {
    /// Creates a builder with no replicas, `ReadPreference::Primary`, a 10 s health-check
    /// interval and a 30 s failover timeout.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            replicas: Vec::new(),
            default_read_preference: ReadPreference::Primary,
            health_check_interval: Duration::from_secs(10),
            failover_timeout: Duration::from_secs(30),
        }
    }

    /// Appends a fully configured replica.
    pub fn replica(mut self, config: ReplicaConfig) -> Self {
        self.replicas.push(config);
        self
    }

    /// Appends a primary built with `ReplicaConfig::primary` defaults.
    pub fn primary(self, id: impl Into<String>, url: impl Into<String>) -> Self {
        let member = ReplicaConfig::primary(id, url);
        self.replica(member)
    }

    /// Appends a secondary built with `ReplicaConfig::secondary` defaults.
    pub fn secondary(self, id: impl Into<String>, url: impl Into<String>) -> Self {
        let member = ReplicaConfig::secondary(id, url);
        self.replica(member)
    }

    /// Sets the read preference used when a read names none.
    pub fn read_preference(self, pref: ReadPreference) -> Self {
        Self {
            default_read_preference: pref,
            ..self
        }
    }

    /// Sets the health-check interval.
    pub fn health_check_interval(self, interval: Duration) -> Self {
        Self {
            health_check_interval: interval,
            ..self
        }
    }

    /// Sets the failover timeout.
    pub fn failover_timeout(self, timeout: Duration) -> Self {
        Self {
            failover_timeout: timeout,
            ..self
        }
    }

    /// Finishes the builder, moving every setting into a `ReplicaSetConfig`.
    pub fn build(self) -> ReplicaSetConfig {
        let Self {
            name,
            replicas,
            default_read_preference,
            health_check_interval,
            failover_timeout,
        } = self;
        ReplicaSetConfig {
            name,
            replicas,
            default_read_preference,
            health_check_interval,
            failover_timeout,
        }
    }
}
/// Where read queries should be routed within a replica set.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReadPreference {
/// Only the current primary.
Primary,
/// The primary, falling back to a secondary.
PrimaryPreferred,
/// Only healthy secondaries.
Secondary,
/// A healthy secondary, falling back to the primary.
SecondaryPreferred,
/// The healthy replica with the lowest recorded latency.
Nearest,
/// Healthy replicas labelled with this region (falls back to `Nearest` when there are none).
Region(String),
/// Tag sets to match.
/// NOTE(review): the router in this file ignores the tags and routes these reads like `Nearest`.
TagSet(Vec<HashMap<String, String>>),
}
impl ReadPreference {
    /// Shorthand for `ReadPreference::Region(region)`.
    pub fn region(region: impl Into<String>) -> Self {
        Self::Region(region.into())
    }

    /// Shorthand for `ReadPreference::TagSet(tags)`.
    pub fn tag_set(tags: Vec<HashMap<String, String>>) -> Self {
        Self::TagSet(tags)
    }

    /// MongoDB read-preference mode name.
    ///
    /// `Region` and `TagSet` have no MongoDB mode of their own and map to `"nearest"`; their
    /// payload is not carried over.
    pub fn to_mongodb(&self) -> &'static str {
        match self {
            Self::Primary => "primary",
            Self::PrimaryPreferred => "primaryPreferred",
            Self::Secondary => "secondary",
            Self::SecondaryPreferred => "secondaryPreferred",
            Self::Nearest => "nearest",
            Self::Region(_) | Self::TagSet(_) => "nearest",
        }
    }

    /// Whether a read under this preference may be served by the primary.
    ///
    /// Every mode except `Secondary` may. In particular `SecondaryPreferred` falls back to the
    /// primary when no secondary is available (the router's read routing does exactly that),
    /// so it must report `true`.
    pub fn allows_primary(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Secondary => false,
            Self::Primary
            | Self::PrimaryPreferred
            | Self::SecondaryPreferred
            | Self::Nearest
            | Self::Region(_)
            | Self::TagSet(_) => true,
        }
    }

    /// Whether a read under this preference may be served by a secondary (all modes but `Primary`).
    pub fn allows_secondary(&self) -> bool {
        !matches!(self, Self::Primary)
    }
}
impl Default for ReadPreference {
fn default() -> Self {
Self::Primary
}
}
/// Health state recorded for a replica.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum HealthStatus {
/// Usable for routing.
Healthy,
/// Impaired but still usable for routing.
Degraded,
/// Not usable for routing.
Unhealthy,
/// No check recorded yet; not usable for routing.
Unknown,
}
/// Health record kept for one replica.
#[derive(Debug, Clone)]
pub struct ReplicaHealth {
/// Id of the replica this record describes.
pub id: String,
/// Most recently recorded status.
pub status: HealthStatus,
/// Last reported replication lag, if any.
pub lag: Option<Duration>,
/// When the status was last updated.
pub last_check: Option<Instant>,
/// Last measured latency; used to rank replicas for nearest reads.
pub latency: Option<Duration>,
/// Number of unhealthy marks since the last healthy mark.
pub consecutive_failures: u32,
}
impl ReplicaHealth {
pub fn new(id: impl Into<String>) -> Self {
Self {
id: id.into(),
status: HealthStatus::Unknown,
lag: None,
last_check: None,
latency: None,
consecutive_failures: 0,
}
}
pub fn mark_healthy(&mut self, latency: Duration, lag: Option<Duration>) {
self.status = HealthStatus::Healthy;
self.latency = Some(latency);
self.lag = lag;
self.last_check = Some(Instant::now());
self.consecutive_failures = 0;
}
pub fn mark_degraded(&mut self, reason: &str) {
self.status = HealthStatus::Degraded;
self.last_check = Some(Instant::now());
let _ = reason; }
pub fn mark_unhealthy(&mut self) {
self.status = HealthStatus::Unhealthy;
self.last_check = Some(Instant::now());
self.consecutive_failures += 1;
}
pub fn is_usable(&self) -> bool {
matches!(self.status, HealthStatus::Healthy | HealthStatus::Degraded)
}
}
/// Kind of query being routed. `Write` and `Transaction` are always sent to the current primary;
/// `Read` is routed according to a `ReadPreference`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
Read,
Write,
Transaction,
}
/// Chooses which replica serves each query, based on configuration and recorded health.
#[derive(Debug)]
pub struct ConnectionRouter {
/// Replica-set configuration being routed over.
config: ReplicaSetConfig,
/// Health record per replica id.
health: HashMap<String, ReplicaHealth>,
/// Id of the replica currently treated as primary; replaced by `initiate_failover`.
current_primary: Option<String>,
/// Counter used to rotate reads across equally eligible replicas.
round_robin: AtomicUsize,
/// Set for the duration of `initiate_failover`.
in_failover: AtomicBool,
}
impl ConnectionRouter {
pub fn new(config: ReplicaSetConfig) -> Self {
let mut health = HashMap::new();
let mut primary_id = None;
for replica in &config.replicas {
health.insert(replica.id.clone(), ReplicaHealth::new(&replica.id));
if replica.role == ReplicaRole::Primary {
primary_id = Some(replica.id.clone());
}
}
Self {
config,
health,
current_primary: primary_id,
round_robin: AtomicUsize::new(0),
in_failover: AtomicBool::new(false),
}
}
pub fn route(
&self,
query_type: QueryType,
preference: Option<&ReadPreference>,
) -> QueryResult<&ReplicaConfig> {
let pref = preference.unwrap_or(&self.config.default_read_preference);
match query_type {
QueryType::Write | QueryType::Transaction => self.get_primary(),
QueryType::Read => self.route_read(pref),
}
}
pub fn get_primary(&self) -> QueryResult<&ReplicaConfig> {
let primary_id = self
.current_primary
.as_ref()
.ok_or_else(|| QueryError::connection("No primary replica available"))?;
self.config
.replicas
.iter()
.find(|r| &r.id == primary_id)
.ok_or_else(|| QueryError::connection("Primary replica not found"))
}
fn route_read(&self, preference: &ReadPreference) -> QueryResult<&ReplicaConfig> {
match preference {
ReadPreference::Primary => self.get_primary(),
ReadPreference::PrimaryPreferred => {
self.get_primary().or_else(|_| self.get_any_secondary())
}
ReadPreference::Secondary => self.get_any_secondary(),
ReadPreference::SecondaryPreferred => {
self.get_any_secondary().or_else(|_| self.get_primary())
}
ReadPreference::Nearest => self.get_nearest(),
ReadPreference::Region(region) => self.get_in_region(region),
ReadPreference::TagSet(_tags) => {
self.get_nearest()
}
}
}
fn get_any_secondary(&self) -> QueryResult<&ReplicaConfig> {
let secondaries: Vec<_> = self
.config
.secondaries()
.filter(|r| self.is_replica_healthy(&r.id))
.collect();
if secondaries.is_empty() {
return Err(QueryError::connection(
"No healthy secondary replicas available",
));
}
let idx = self.round_robin.fetch_add(1, Ordering::Relaxed) % secondaries.len();
Ok(secondaries[idx])
}
fn get_nearest(&self) -> QueryResult<&ReplicaConfig> {
let mut best: Option<(&ReplicaConfig, Duration)> = None;
for replica in &self.config.replicas {
if !self.is_replica_healthy(&replica.id) {
continue;
}
if let Some(health) = self.health.get(&replica.id) {
if let Some(latency) = health.latency {
match &best {
None => best = Some((replica, latency)),
Some((_, best_latency)) if latency < *best_latency => {
best = Some((replica, latency));
}
_ => {}
}
}
}
}
best.map(|(r, _)| r)
.ok_or_else(|| QueryError::connection("No healthy replicas available"))
}
fn get_in_region(&self, region: &str) -> QueryResult<&ReplicaConfig> {
let replicas: Vec<_> = self
.config
.in_region(region)
.filter(|r| self.is_replica_healthy(&r.id))
.collect();
if replicas.is_empty() {
return self.get_nearest();
}
let idx = self.round_robin.fetch_add(1, Ordering::Relaxed) % replicas.len();
Ok(replicas[idx])
}
fn is_replica_healthy(&self, id: &str) -> bool {
self.health.get(id).map(|h| h.is_usable()).unwrap_or(false)
}
pub fn update_health(
&mut self,
id: &str,
status: HealthStatus,
latency: Option<Duration>,
lag: Option<Duration>,
) {
if let Some(health) = self.health.get_mut(id) {
match status {
HealthStatus::Healthy => {
health.mark_healthy(latency.unwrap_or(Duration::ZERO), lag);
}
HealthStatus::Degraded => {
health.mark_degraded("degraded");
}
HealthStatus::Unhealthy => {
health.mark_unhealthy();
}
HealthStatus::Unknown => {}
}
}
}
pub fn check_lag(&self, replica_id: &str, max_lag: Duration) -> bool {
self.health
.get(replica_id)
.and_then(|h| h.lag)
.map(|lag| lag <= max_lag)
.unwrap_or(false)
}
pub fn initiate_failover(&mut self) -> QueryResult<String> {
self.in_failover.store(true, Ordering::SeqCst);
let candidate = self
.config
.replicas
.iter()
.filter(|r| r.role == ReplicaRole::Secondary)
.filter(|r| self.is_replica_healthy(&r.id))
.max_by_key(|r| r.priority);
match candidate {
Some(new_primary) => {
let new_primary_id = new_primary.id.clone();
self.current_primary = Some(new_primary_id.clone());
self.in_failover.store(false, Ordering::SeqCst);
Ok(new_primary_id)
}
None => {
self.in_failover.store(false, Ordering::SeqCst);
Err(QueryError::connection(
"No suitable failover candidate found",
))
}
}
}
pub fn is_in_failover(&self) -> bool {
self.in_failover.load(Ordering::SeqCst)
}
}
/// Tracks replication lag per replica and flags replicas whose latest sample exceeds a limit.
#[derive(Debug)]
pub struct LagMonitor {
    /// Running statistics keyed by replica id.
    measurements: HashMap<String, LagMeasurement>,
    /// Samples strictly above this value count as lagging.
    max_acceptable_lag: Duration,
}

/// Running lag statistics for one replica.
#[derive(Debug, Clone)]
pub struct LagMeasurement {
    /// Most recently recorded lag.
    pub current: Duration,
    /// Exponentially weighted moving average of recorded lag (newest sample weighted 0.3).
    pub average: Duration,
    /// Largest lag recorded so far.
    pub max: Duration,
    /// When the most recent sample was recorded.
    pub timestamp: Instant,
    /// Number of samples recorded.
    pub samples: u64,
}

impl LagMonitor {
    /// Weight of the newest sample in the moving average.
    const EWMA_ALPHA: f64 = 0.3;

    /// Creates an empty monitor that treats lag above `max_acceptable_lag` as unacceptable.
    pub fn new(max_acceptable_lag: Duration) -> Self {
        Self {
            measurements: HashMap::new(),
            max_acceptable_lag,
        }
    }

    /// Records a lag sample for `replica_id`, updating the current, max and average values.
    pub fn record(&mut self, replica_id: &str, lag: Duration) {
        let now = Instant::now();
        let entry = self
            .measurements
            .entry(replica_id.to_string())
            .or_insert_with(|| LagMeasurement {
                current: Duration::ZERO,
                average: Duration::ZERO,
                max: Duration::ZERO,
                timestamp: now,
                samples: 0,
            });
        entry.current = lag;
        entry.max = entry.max.max(lag);
        entry.samples = entry.samples.saturating_add(1);
        entry.average = if entry.samples == 1 {
            // Seed the average with the first sample. Blending it with the zero placeholder
            // would under-report lag (0.3 x first sample) until many samples accumulated.
            lag
        } else {
            Duration::from_secs_f64(
                entry.average.as_secs_f64() * (1.0 - Self::EWMA_ALPHA)
                    + lag.as_secs_f64() * Self::EWMA_ALPHA,
            )
        };
        entry.timestamp = now;
    }

    /// `true` when the latest sample is within the limit. Replicas never sampled count as
    /// acceptable.
    pub fn is_acceptable(&self, replica_id: &str) -> bool {
        self.measurements
            .get(replica_id)
            .map(|m| m.current <= self.max_acceptable_lag)
            .unwrap_or(true)
    }

    /// Latest recorded lag for `replica_id`, if it has been sampled.
    pub fn get_lag(&self, replica_id: &str) -> Option<Duration> {
        self.measurements.get(replica_id).map(|m| m.current)
    }

    /// Ids of replicas whose latest sample exceeds the limit, sorted so output is deterministic.
    pub fn get_lagging_replicas(&self) -> Vec<&str> {
        let mut lagging: Vec<&str> = self
            .measurements
            .iter()
            .filter(|(_, m)| m.current > self.max_acceptable_lag)
            .map(|(id, _)| id.as_str())
            .collect();
        lagging.sort_unstable();
        lagging
    }
}
pub mod lag_queries {
    //! Engine-specific SQL for inspecting replication state.

    use crate::sql::DatabaseType;

    /// Statement that reports replication lag on `db_type`.
    ///
    /// PostgreSQL, SQL Server and SQLite return a single `lag_seconds` column; SQLite always
    /// reports `0`. MySQL returns the full `SHOW SLAVE STATUS` row, and the caller must extract
    /// the lag from it.
    ///
    /// NOTE(review): `replica_status_sql` uses `SHOW REPLICA STATUS` while this uses the older
    /// `SHOW SLAVE STATUS` (deprecated since MySQL 8.0.22). Switching also renames the lag
    /// column (`Seconds_Behind_Master` -> `Seconds_Behind_Source`), so confirm what the result
    /// parser expects before changing it.
    /// NOTE(review): the PostgreSQL query presumably yields NULL on a server that has not
    /// replayed WAL, and grows on an idle-but-caught-up replica — confirm this is acceptable.
    pub fn check_lag_sql(db_type: DatabaseType) -> &'static str {
        const POSTGRES: &str =
            "SELECT EXTRACT(EPOCH FROM (now() - pg_last_xact_replay_timestamp()))::INT AS lag_seconds";
        const MYSQL: &str = "SHOW SLAVE STATUS";
        const MSSQL: &str = "SELECT datediff(s, last_commit_time, getdate()) AS lag_seconds \
                             FROM sys.dm_hadr_database_replica_states \
                             WHERE is_local = 1";
        const SQLITE: &str = "SELECT 0 AS lag_seconds";

        match db_type {
            DatabaseType::PostgreSQL => POSTGRES,
            DatabaseType::MySQL => MYSQL,
            DatabaseType::MSSQL => MSSQL,
            DatabaseType::SQLite => SQLITE,
        }
    }

    /// Statement returning an `is_primary` column that tells whether the connected server
    /// accepts writes. SQLite always reports `1`.
    pub fn is_primary_sql(db_type: DatabaseType) -> &'static str {
        const POSTGRES: &str = "SELECT NOT pg_is_in_recovery() AS is_primary";
        const MYSQL: &str = "SELECT @@read_only = 0 AS is_primary";
        const MSSQL: &str = "SELECT CASE WHEN role = 1 THEN 1 ELSE 0 END AS is_primary \
                             FROM sys.dm_hadr_availability_replica_states \
                             WHERE is_local = 1";
        const SQLITE: &str = "SELECT 1 AS is_primary";

        match db_type {
            DatabaseType::PostgreSQL => POSTGRES,
            DatabaseType::MySQL => MYSQL,
            DatabaseType::MSSQL => MSSQL,
            DatabaseType::SQLite => SQLITE,
        }
    }

    /// Statement describing the connected server's replication status on `db_type`.
    pub fn replica_status_sql(db_type: DatabaseType) -> &'static str {
        const POSTGRES: &str = "SELECT \
                                pg_is_in_recovery() AS is_replica, \
                                pg_last_wal_receive_lsn() AS receive_lsn, \
                                pg_last_wal_replay_lsn() AS replay_lsn";
        const MYSQL: &str = "SHOW REPLICA STATUS";
        const MSSQL: &str = "SELECT synchronization_state_desc, synchronization_health_desc \
                             FROM sys.dm_hadr_database_replica_states \
                             WHERE is_local = 1";
        const SQLITE: &str = "SELECT 'primary' AS status";

        match db_type {
            DatabaseType::PostgreSQL => POSTGRES,
            DatabaseType::MySQL => MYSQL,
            DatabaseType::MSSQL => MSSQL,
            DatabaseType::SQLite => SQLITE,
        }
    }
}
pub mod mongodb {
    //! MongoDB-specific read concern, write concern and read-preference helpers.

    use serde::{Deserialize, Serialize};
    use serde_json::Value as JsonValue;

    use super::ReadPreference;

    /// MongoDB read concern level.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    pub enum ReadConcern {
        Local,
        Majority,
        Linearizable,
        Snapshot,
        Available,
    }

    impl ReadConcern {
        /// Wire name of this read concern level.
        pub fn as_str(&self) -> &'static str {
            match self {
                Self::Local => "local",
                Self::Majority => "majority",
                Self::Linearizable => "linearizable",
                Self::Snapshot => "snapshot",
                Self::Available => "available",
            }
        }
    }

    /// MongoDB write concern.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    pub enum WriteConcern {
        /// `{ "w": 1 }`.
        W1,
        /// `{ "w": "majority" }`.
        Majority,
        /// `{ "w": n }`.
        W(u32),
        /// `{ "w": "<tag>" }`.
        Tag(String),
    }

    impl WriteConcern {
        /// Write-concern document, e.g. `{ "w": "majority" }`.
        pub fn to_options(&self) -> JsonValue {
            match self {
                Self::W1 => serde_json::json!({ "w": 1 }),
                Self::Majority => serde_json::json!({ "w": "majority" }),
                Self::W(n) => serde_json::json!({ "w": n }),
                Self::Tag(tag) => serde_json::json!({ "w": tag }),
            }
        }
    }

    /// A read preference together with MongoDB-specific options.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct MongoReadPreference {
        /// Mode, rendered via `ReadPreference::to_mongodb`.
        pub mode: ReadPreference,
        /// Optional `maxStalenessSeconds`.
        pub max_staleness_seconds: Option<u32>,
        /// Tag sets, in priority order.
        pub tag_sets: Vec<serde_json::Map<String, JsonValue>>,
        /// Optional hedged-read flag.
        pub hedge: Option<bool>,
    }

    impl MongoReadPreference {
        /// Creates a preference with only a mode set.
        pub fn new(mode: ReadPreference) -> Self {
            Self {
                mode,
                max_staleness_seconds: None,
                tag_sets: Vec::new(),
                hedge: None,
            }
        }

        /// Sets `maxStalenessSeconds`.
        pub fn max_staleness(mut self, seconds: u32) -> Self {
            self.max_staleness_seconds = Some(seconds);
            self
        }

        /// Appends a tag set (tried after any previously added ones).
        pub fn tag_set(mut self, tags: serde_json::Map<String, JsonValue>) -> Self {
            self.tag_sets.push(tags);
            self
        }

        /// Enables hedged reads.
        pub fn hedged(mut self) -> Self {
            self.hedge = Some(true);
            self
        }

        /// Connection-string query options, joined with `&`.
        ///
        /// Emits `readPreference`, `maxStalenessSeconds` (when set) and one
        /// `readPreferenceTags=key:value,...` entry per tag set, in order. String tag values are
        /// written raw; other JSON values use their JSON text. Keys and values are not
        /// percent-encoded, so `,`, `:` and `&` inside them would need escaping by the caller.
        /// The hedge flag is not emitted here; see `to_command_options`.
        pub fn to_connection_options(&self) -> String {
            let mut opts = vec![format!("readPreference={}", self.mode.to_mongodb())];
            if let Some(staleness) = self.max_staleness_seconds {
                opts.push(format!("maxStalenessSeconds={}", staleness));
            }
            // Tag sets used to be silently dropped from the connection string.
            for tag_set in &self.tag_sets {
                let pairs: Vec<String> = tag_set
                    .iter()
                    .map(|(key, value)| match value {
                        JsonValue::String(text) => format!("{}:{}", key, text),
                        other => format!("{}:{}", key, other),
                    })
                    .collect();
                opts.push(format!("readPreferenceTags={}", pairs.join(",")));
            }
            opts.join("&")
        }

        /// Read-preference document for command options.
        ///
        /// Includes `mode`, plus `maxStalenessSeconds`, `tagSets` and `hedge` when set.
        pub fn to_command_options(&self) -> JsonValue {
            let mut opts = serde_json::Map::new();
            opts.insert(
                "mode".to_string(),
                serde_json::json!(self.mode.to_mongodb()),
            );
            if let Some(staleness) = self.max_staleness_seconds {
                opts.insert(
                    "maxStalenessSeconds".to_string(),
                    serde_json::json!(staleness),
                );
            }
            if !self.tag_sets.is_empty() {
                opts.insert("tagSets".to_string(), serde_json::json!(self.tag_sets));
            }
            if let Some(hedge) = self.hedge {
                opts.insert("hedge".to_string(), serde_json::json!({ "enabled": hedge }));
            }
            serde_json::json!(opts)
        }
    }

    /// Replica-set status document.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct ReplicaSetStatus {
        /// Replica-set name.
        pub set: String,
        /// Status of each member.
        pub members: Vec<MemberStatus>,
    }

    /// Status of one replica-set member.
    ///
    /// Serialization keeps the Rust field names. Deserialization additionally accepts the key
    /// names used in MongoDB's `replSetGetStatus` output (`_id`, `stateStr`), so that output can
    /// be parsed directly.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct MemberStatus {
        /// Member id.
        #[serde(alias = "_id")]
        pub id: u32,
        /// Member host name.
        pub name: String,
        /// State name such as `"PRIMARY"` or `"SECONDARY"`.
        #[serde(alias = "stateStr")]
        pub state_str: String,
        /// Health indicator; `is_healthy` treats values >= 1.0 as healthy.
        pub health: f64,
        /// Optional lag in seconds; `None` when absent from the input.
        #[serde(default)]
        pub lag_seconds: Option<i64>,
    }

    impl MemberStatus {
        /// `true` when the state is exactly `"PRIMARY"`.
        pub fn is_primary(&self) -> bool {
            self.state_str == "PRIMARY"
        }

        /// `true` when the state is exactly `"SECONDARY"`.
        pub fn is_secondary(&self) -> bool {
            self.state_str == "SECONDARY"
        }

        /// `true` when `health` is at least `1.0`.
        pub fn is_healthy(&self) -> bool {
            self.health >= 1.0
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a router over `config` and applies `(id, status, latency, lag)` updates in order.
    fn router_with(
        config: ReplicaSetConfig,
        updates: &[(&str, HealthStatus, Option<Duration>, Option<Duration>)],
    ) -> ConnectionRouter {
        let mut router = ConnectionRouter::new(config);
        for (id, status, latency, lag) in updates {
            router.update_health(id, status.clone(), *latency, *lag);
        }
        router
    }

    #[test]
    fn test_replica_config() {
        let replica = ReplicaConfig::primary("pg1", "postgres://primary:5432/db")
            .with_region("us-east-1");
        assert_eq!(replica.role, ReplicaRole::Primary);
        assert_eq!(replica.region.as_deref(), Some("us-east-1"));
    }

    #[test]
    fn test_replica_set_builder() {
        let set = ReplicaSetConfig::new("myapp")
            .primary("pg1", "postgres://primary:5432/db")
            .secondary("pg2", "postgres://secondary1:5432/db")
            .secondary("pg3", "postgres://secondary2:5432/db")
            .read_preference(ReadPreference::SecondaryPreferred)
            .build();
        assert_eq!(set.name, "myapp");
        assert_eq!(set.replicas.len(), 3);
        assert!(set.primary().is_some());
        assert_eq!(set.secondaries().count(), 2);
    }

    #[test]
    fn test_read_preference_mongodb() {
        let cases = [
            (ReadPreference::Primary, "primary"),
            (ReadPreference::SecondaryPreferred, "secondaryPreferred"),
            (ReadPreference::Nearest, "nearest"),
        ];
        for (preference, expected) in &cases {
            assert_eq!(preference.to_mongodb(), *expected);
        }
    }

    #[test]
    fn test_connection_router_write() {
        let set = ReplicaSetConfig::new("test")
            .primary("pg1", "postgres://primary:5432/db")
            .secondary("pg2", "postgres://secondary:5432/db")
            .build();
        let router = router_with(
            set,
            &[
                ("pg1", HealthStatus::Healthy, Some(Duration::from_millis(5)), None),
                (
                    "pg2",
                    HealthStatus::Healthy,
                    Some(Duration::from_millis(10)),
                    Some(Duration::from_secs(1)),
                ),
            ],
        );
        let chosen = router.route(QueryType::Write, None).unwrap();
        assert_eq!(chosen.id, "pg1");
    }

    #[test]
    fn test_connection_router_read_secondary() {
        let set = ReplicaSetConfig::new("test")
            .primary("pg1", "postgres://primary:5432/db")
            .secondary("pg2", "postgres://secondary:5432/db")
            .read_preference(ReadPreference::Secondary)
            .build();
        let router = router_with(
            set,
            &[
                ("pg1", HealthStatus::Healthy, Some(Duration::from_millis(5)), None),
                (
                    "pg2",
                    HealthStatus::Healthy,
                    Some(Duration::from_millis(10)),
                    Some(Duration::from_secs(1)),
                ),
            ],
        );
        let chosen = router.route(QueryType::Read, None).unwrap();
        assert_eq!(chosen.id, "pg2");
    }

    #[test]
    fn test_lag_monitor() {
        let mut lag_monitor = LagMonitor::new(Duration::from_secs(10));
        lag_monitor.record("pg2", Duration::from_secs(5));
        assert!(lag_monitor.is_acceptable("pg2"));
        lag_monitor.record("pg3", Duration::from_secs(15));
        assert!(!lag_monitor.is_acceptable("pg3"));
        assert_eq!(lag_monitor.get_lagging_replicas(), vec!["pg3"]);
    }

    #[test]
    fn test_failover() {
        let set = ReplicaSetConfig::new("test")
            .primary("pg1", "postgres://primary:5432/db")
            .replica(
                ReplicaConfig::secondary("pg2", "postgres://secondary1:5432/db").with_priority(80),
            )
            .replica(
                ReplicaConfig::secondary("pg3", "postgres://secondary2:5432/db").with_priority(60),
            )
            .build();
        let mut router = router_with(
            set,
            &[
                ("pg1", HealthStatus::Unhealthy, None, None),
                ("pg2", HealthStatus::Healthy, Some(Duration::from_millis(10)), None),
                ("pg3", HealthStatus::Healthy, Some(Duration::from_millis(15)), None),
            ],
        );
        let promoted = router.initiate_failover().unwrap();
        assert_eq!(promoted, "pg2");
    }

    mod mongodb_tests {
        use super::super::mongodb::*;
        use super::*;

        #[test]
        fn test_read_concern() {
            let expectations = [(ReadConcern::Majority, "majority"), (ReadConcern::Local, "local")];
            for (concern, wire_name) in expectations {
                assert_eq!(concern.as_str(), wire_name);
            }
        }

        #[test]
        fn test_write_concern() {
            let majority = WriteConcern::Majority.to_options();
            assert_eq!(majority["w"], "majority");
            let three = WriteConcern::W(3).to_options();
            assert_eq!(three["w"], 3);
        }

        #[test]
        fn test_mongo_read_preference() {
            let preference = MongoReadPreference::new(ReadPreference::SecondaryPreferred)
                .max_staleness(90)
                .hedged();
            let uri_options = preference.to_connection_options();
            for fragment in ["readPreference=secondaryPreferred", "maxStalenessSeconds=90"] {
                assert!(
                    uri_options.contains(fragment),
                    "missing {} in {}",
                    fragment,
                    uri_options
                );
            }
            let command = preference.to_command_options();
            assert_eq!(command["mode"], "secondaryPreferred");
            assert_eq!(command["maxStalenessSeconds"], 90);
        }
    }
}