1use std::sync::{
4 atomic::{AtomicUsize, Ordering},
5 Arc,
6};
7
/// Policy for choosing which read replica serves a query.
///
/// Only the variants live here; the code that acts on each strategy is not
/// part of this type (NOTE(review): confirm where each one is implemented).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ReadStrategy {
    /// The default strategy. Presumably paired with `RoundRobinCounter` to
    /// cycle through replicas in order — confirm at the call site.
    #[default]
    RoundRobin,
    /// Presumably a random replica choice — selection logic is not in this type.
    Random,
    /// Presumably prefers the least-loaded replica — selection logic is not
    /// in this type.
    LeastConnections,
}
22
23#[derive(Debug, Clone)]
38pub struct DatabaseConfig {
39 pub write: String,
41 pub read: Vec<String>,
43 pub read_strategy: ReadStrategy,
45}
46
47impl DatabaseConfig {
48 pub fn new(write: impl Into<String>) -> Self {
49 Self {
50 write: write.into(),
51 read: Vec::new(),
52 read_strategy: ReadStrategy::default(),
53 }
54 }
55
56 pub fn with_replica(mut self, url: impl Into<String>) -> Self {
57 self.read.push(url.into());
58 self
59 }
60
61 pub fn with_strategy(mut self, strategy: ReadStrategy) -> Self {
62 self.read_strategy = strategy;
63 self
64 }
65
66 pub fn has_replicas(&self) -> bool {
67 !self.read.is_empty()
68 }
69}
70
/// Shared counter used to hand out indices in round-robin order.
///
/// The counter sits behind an `Arc<AtomicUsize>`, so it can be used from
/// several threads. Cloning copies the `Arc`, not the count: every clone
/// advances one shared sequence.
#[derive(Clone, Debug, Default)]
pub struct RoundRobinCounter(Arc<AtomicUsize>);
78
79impl RoundRobinCounter {
80 pub fn new() -> Self {
81 Self(Arc::new(AtomicUsize::new(0)))
82 }
83
84 pub fn next(&self, len: usize) -> usize {
86 if len == 0 {
87 return 0;
88 }
89 self.0.fetch_add(1, Ordering::Relaxed) % len
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn database_config_builder() {
        let cfg = DatabaseConfig::new("postgres://primary/db")
            .with_replica("postgres://r1/db")
            .with_replica("postgres://r2/db")
            .with_strategy(ReadStrategy::RoundRobin);

        assert_eq!(cfg.write, "postgres://primary/db");
        assert_eq!(cfg.read.len(), 2);
        assert!(cfg.has_replicas());
        assert_eq!(cfg.read_strategy, ReadStrategy::RoundRobin);
    }

    #[test]
    fn no_replicas_has_replicas_false() {
        let cfg = DatabaseConfig::new("postgres://primary/db");
        assert!(!cfg.has_replicas());
    }

    #[test]
    fn default_strategy_is_round_robin() {
        assert_eq!(ReadStrategy::default(), ReadStrategy::RoundRobin);
        assert_eq!(
            DatabaseConfig::new("postgres://primary/db").read_strategy,
            ReadStrategy::RoundRobin
        );
    }

    #[test]
    fn round_robin_wraps() {
        let counter = RoundRobinCounter::new();
        assert_eq!(counter.next(3), 0);
        assert_eq!(counter.next(3), 1);
        assert_eq!(counter.next(3), 2);
        assert_eq!(counter.next(3), 0);
    }

    #[test]
    fn round_robin_zero_len() {
        let counter = RoundRobinCounter::new();
        assert_eq!(counter.next(0), 0);
    }

    #[test]
    fn round_robin_zero_len_does_not_advance() {
        let counter = RoundRobinCounter::new();
        assert_eq!(counter.next(0), 0);
        assert_eq!(counter.next(0), 0);
        // The empty-collection calls above must not have consumed a ticket.
        assert_eq!(counter.next(2), 0);
        assert_eq!(counter.next(2), 1);
    }

    #[test]
    fn round_robin_cloned_shares_state() {
        let a = RoundRobinCounter::new();
        let b = a.clone();
        a.next(4);
        assert_eq!(b.next(4), 1);
    }

    #[test]
    fn round_robin_is_fair_across_threads() {
        const THREADS: usize = 4;
        const CALLS_PER_THREAD: usize = 100;
        const LEN: usize = 4;

        let counter = RoundRobinCounter::new();
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let counter = counter.clone();
                std::thread::spawn(move || {
                    let mut hits = [0usize; LEN];
                    for _ in 0..CALLS_PER_THREAD {
                        hits[counter.next(LEN)] += 1;
                    }
                    hits
                })
            })
            .collect();

        let mut totals = [0usize; LEN];
        for handle in handles {
            let hits = handle.join().expect("worker thread panicked");
            for (total, n) in totals.iter_mut().zip(hits.iter()) {
                *total += n;
            }
        }

        // The shared counter hands out the distinct tickets 0..400, which map
        // onto each of the 4 indices exactly 100 times.
        assert_eq!(totals, [THREADS * CALLS_PER_THREAD / LEN; LEN]);
    }
}