1use rand::prelude::IndexedRandom;
7use serde::{Deserialize, Serialize};
8use sqlx::postgres::PgPoolOptions;
9use sqlx::PgPool;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12use std::time::Duration;
13
14use crate::error::{DbError, Result};
15use crate::pool::{health_check, HealthCheck, HealthStatus, RetryConfig};
16
17#[derive(Debug, Clone, Deserialize)]
19pub struct ReplicaConfig {
20 pub primary_url: String,
22 pub replica_urls: Vec<String>,
24 pub max_connections: u32,
26 pub min_connections: u32,
28 pub acquire_timeout_secs: u64,
30 pub load_balance_strategy: LoadBalanceStrategy,
32}
33
34impl Default for ReplicaConfig {
35 fn default() -> Self {
36 Self {
37 primary_url: String::new(),
38 replica_urls: Vec::new(),
39 max_connections: 20,
40 min_connections: 5,
41 acquire_timeout_secs: 5,
42 load_balance_strategy: LoadBalanceStrategy::RoundRobin,
43 }
44 }
45}
46
47impl ReplicaConfig {
48 pub fn primary_only(url: impl Into<String>) -> Self {
50 Self {
51 primary_url: url.into(),
52 ..Default::default()
53 }
54 }
55
56 pub fn add_replica(mut self, url: impl Into<String>) -> Self {
58 self.replica_urls.push(url.into());
59 self
60 }
61
62 pub fn strategy(mut self, strategy: LoadBalanceStrategy) -> Self {
64 self.load_balance_strategy = strategy;
65 self
66 }
67}
68
69#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
71#[serde(rename_all = "snake_case")]
72pub enum LoadBalanceStrategy {
73 #[default]
75 RoundRobin,
76 Random,
78 FirstAvailable,
80 LeastConnections,
82}
83
84#[derive(Debug, Clone, Serialize)]
86pub struct ReplicaStatus {
87 pub url_masked: String,
88 pub is_healthy: bool,
89 pub pool_size: u32,
90 pub pool_idle: u32,
91 pub latency_ms: Option<u64>,
92}
93
94pub struct ReplicaPoolManager {
96 primary: PgPool,
98 replicas: Vec<PgPool>,
100 round_robin_index: AtomicUsize,
102 strategy: LoadBalanceStrategy,
104 #[allow(dead_code)]
106 config: ReplicaConfig,
107}
108
109impl ReplicaPoolManager {
110 pub async fn new(config: ReplicaConfig) -> Result<Self> {
112 let retry_config = RetryConfig::default();
113 Self::with_retry(config, &retry_config).await
114 }
115
116 pub async fn with_retry(config: ReplicaConfig, retry_config: &RetryConfig) -> Result<Self> {
118 let primary = create_pool_with_config(&config.primary_url, &config, retry_config).await?;
120 tracing::info!("Primary database pool created");
121
122 let mut replicas = Vec::with_capacity(config.replica_urls.len());
124 for (i, url) in config.replica_urls.iter().enumerate() {
125 match create_pool_with_config(url, &config, retry_config).await {
126 Ok(pool) => {
127 replicas.push(pool);
128 tracing::info!(replica_index = i, "Read replica pool created");
129 }
130 Err(e) => {
131 tracing::warn!(
132 replica_index = i,
133 error = %e,
134 "Failed to create read replica pool, skipping"
135 );
136 }
137 }
138 }
139
140 if replicas.is_empty() && !config.replica_urls.is_empty() {
141 tracing::warn!("No read replicas available, falling back to primary for reads");
142 }
143
144 Ok(Self {
145 primary,
146 replicas,
147 round_robin_index: AtomicUsize::new(0),
148 strategy: config.load_balance_strategy,
149 config,
150 })
151 }
152
153 pub fn write_pool(&self) -> &PgPool {
155 &self.primary
156 }
157
158 pub fn read_pool(&self) -> &PgPool {
160 if self.replicas.is_empty() {
161 return &self.primary;
162 }
163
164 match self.strategy {
165 LoadBalanceStrategy::RoundRobin => self.round_robin_replica(),
166 LoadBalanceStrategy::Random => self.random_replica(),
167 LoadBalanceStrategy::FirstAvailable => self.first_available_replica(),
168 LoadBalanceStrategy::LeastConnections => self.least_connections_replica(),
169 }
170 }
171
172 pub fn primary(&self) -> &PgPool {
174 &self.primary
175 }
176
177 pub fn all_pools(&self) -> impl Iterator<Item = &PgPool> {
179 std::iter::once(&self.primary).chain(self.replicas.iter())
180 }
181
182 pub fn replica_count(&self) -> usize {
184 self.replicas.len()
185 }
186
187 pub fn has_replicas(&self) -> bool {
189 !self.replicas.is_empty()
190 }
191
192 pub async fn health_status(&self) -> ReplicaHealthStatus {
194 let primary_health = health_check(&self.primary).await;
195
196 let mut replica_health = Vec::with_capacity(self.replicas.len());
197 for (i, pool) in self.replicas.iter().enumerate() {
198 let health = health_check(pool).await;
199 replica_health.push(ReplicaStatus {
200 url_masked: format!("replica_{}", i),
201 is_healthy: health.status == HealthStatus::Healthy,
202 pool_size: health.pool_size,
203 pool_idle: health.pool_idle,
204 latency_ms: health.latency_ms,
205 });
206 }
207
208 let healthy_replicas = replica_health.iter().filter(|r| r.is_healthy).count();
209 let overall_status = if primary_health.status != HealthStatus::Healthy {
210 HealthStatus::Unhealthy
211 } else if healthy_replicas < self.replicas.len() {
212 HealthStatus::Degraded
213 } else {
214 HealthStatus::Healthy
215 };
216
217 ReplicaHealthStatus {
218 overall_status,
219 primary: primary_health,
220 replicas: replica_health,
221 healthy_replica_count: healthy_replicas,
222 total_replica_count: self.replicas.len(),
223 }
224 }
225
226 fn round_robin_replica(&self) -> &PgPool {
229 let index = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % self.replicas.len();
230 &self.replicas[index]
231 }
232
233 fn random_replica(&self) -> &PgPool {
234 self.replicas
235 .choose(&mut rand::rng())
236 .unwrap_or(&self.primary)
237 }
238
239 fn first_available_replica(&self) -> &PgPool {
240 for replica in &self.replicas {
242 if replica.num_idle() > 0 {
243 return replica;
244 }
245 }
246 &self.replicas[0]
248 }
249
250 fn least_connections_replica(&self) -> &PgPool {
251 self.replicas
252 .iter()
253 .max_by_key(|p| p.num_idle())
254 .unwrap_or(&self.primary)
255 }
256}
257
258#[derive(Debug, Serialize)]
260pub struct ReplicaHealthStatus {
261 pub overall_status: HealthStatus,
262 pub primary: HealthCheck,
263 pub replicas: Vec<ReplicaStatus>,
264 pub healthy_replica_count: usize,
265 pub total_replica_count: usize,
266}
267
268async fn create_pool_with_config(
270 url: &str,
271 config: &ReplicaConfig,
272 retry_config: &RetryConfig,
273) -> Result<PgPool> {
274 let mut last_error = None;
275
276 for attempt in 0..retry_config.max_attempts {
277 match try_create_pool(url, config).await {
278 Ok(pool) => return Ok(pool),
279 Err(e) => {
280 last_error = Some(e);
281 if attempt + 1 < retry_config.max_attempts {
282 let delay = retry_config.delay_for_attempt(attempt);
283 tokio::time::sleep(delay).await;
284 }
285 }
286 }
287 }
288
289 Err(DbError::Connection(format!(
290 "Failed to create pool after {} attempts: {}",
291 retry_config.max_attempts,
292 last_error.map(|e| e.to_string()).unwrap_or_default()
293 )))
294}
295
296async fn try_create_pool(
297 url: &str,
298 config: &ReplicaConfig,
299) -> std::result::Result<PgPool, sqlx::Error> {
300 PgPoolOptions::new()
301 .max_connections(config.max_connections)
302 .min_connections(config.min_connections)
303 .acquire_timeout(Duration::from_secs(config.acquire_timeout_secs))
304 .idle_timeout(Duration::from_secs(600))
305 .connect(url)
306 .await
307}
308
309pub struct SmartDbClient {
311 manager: Arc<ReplicaPoolManager>,
312}
313
314impl SmartDbClient {
315 pub fn new(manager: ReplicaPoolManager) -> Self {
316 Self {
317 manager: Arc::new(manager),
318 }
319 }
320
321 pub fn from_arc(manager: Arc<ReplicaPoolManager>) -> Self {
322 Self { manager }
323 }
324
325 pub fn read(&self) -> &PgPool {
327 self.manager.read_pool()
328 }
329
330 pub fn write(&self) -> &PgPool {
332 self.manager.write_pool()
333 }
334
335 pub fn manager(&self) -> &ReplicaPoolManager {
337 &self.manager
338 }
339
340 pub fn shared_manager(&self) -> Arc<ReplicaPoolManager> {
342 self.manager.clone()
343 }
344}
345
346impl Clone for SmartDbClient {
347 fn clone(&self) -> Self {
348 Self {
349 manager: self.manager.clone(),
350 }
351 }
352}
353
354pub struct SmartDbClientBuilder {
356 config: ReplicaConfig,
357}
358
359impl SmartDbClientBuilder {
360 pub fn new(primary_url: impl Into<String>) -> Self {
361 Self {
362 config: ReplicaConfig::primary_only(primary_url),
363 }
364 }
365
366 pub fn add_replica(mut self, url: impl Into<String>) -> Self {
367 self.config.replica_urls.push(url.into());
368 self
369 }
370
371 pub fn max_connections(mut self, max: u32) -> Self {
372 self.config.max_connections = max;
373 self
374 }
375
376 pub fn min_connections(mut self, min: u32) -> Self {
377 self.config.min_connections = min;
378 self
379 }
380
381 pub fn strategy(mut self, strategy: LoadBalanceStrategy) -> Self {
382 self.config.load_balance_strategy = strategy;
383 self
384 }
385
386 pub async fn build(self) -> Result<SmartDbClient> {
387 let manager = ReplicaPoolManager::new(self.config).await?;
388 Ok(SmartDbClient::new(manager))
389 }
390}