1use crate::error::{DbError, Result};
4use chrono::{DateTime, Utc};
5use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use sqlx::PgPool;
8use std::collections::HashMap;
9use std::sync::Arc;
10
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
13pub struct Region {
14 pub code: String,
16 pub name: String,
18 pub latitude: f64,
20 pub longitude: f64,
22}
23
24impl Region {
25 pub fn new(code: String, name: String, latitude: f64, longitude: f64) -> Self {
27 Self {
28 code,
29 name,
30 latitude,
31 longitude,
32 }
33 }
34
35 pub fn distance_to(&self, other: &Region) -> f64 {
37 let r = 6371.0; let lat1 = self.latitude.to_radians();
40 let lat2 = other.latitude.to_radians();
41 let delta_lat = (other.latitude - self.latitude).to_radians();
42 let delta_lon = (other.longitude - self.longitude).to_radians();
43
44 let a = (delta_lat / 2.0).sin().powi(2)
45 + lat1.cos() * lat2.cos() * (delta_lon / 2.0).sin().powi(2);
46 let c = 2.0 * a.sqrt().atan2((1.0 - a).sqrt());
47
48 r * c
49 }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct RegionalPoolConfig {
55 pub region: Region,
57 pub database_url: String,
59 pub is_primary: bool,
61 pub failover_priority: u32,
63 pub is_active: bool,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ReplicationLag {
70 pub source_region: String,
72 pub replica_region: String,
74 pub lag_seconds: f64,
76 pub measured_at: DateTime<Utc>,
78 pub is_healthy: bool,
80}
81
82pub struct MultiRegionPoolManager {
84 pools: Arc<RwLock<HashMap<String, PgPool>>>,
86 configs: Arc<RwLock<Vec<RegionalPoolConfig>>>,
88 current_region: Arc<RwLock<Option<String>>>,
90 replication_lags: Arc<RwLock<HashMap<String, ReplicationLag>>>,
92 max_lag_seconds: f64,
94}
95
96impl MultiRegionPoolManager {
97 pub fn new(max_lag_seconds: f64) -> Self {
99 Self {
100 pools: Arc::new(RwLock::new(HashMap::new())),
101 configs: Arc::new(RwLock::new(Vec::new())),
102 current_region: Arc::new(RwLock::new(None)),
103 replication_lags: Arc::new(RwLock::new(HashMap::new())),
104 max_lag_seconds,
105 }
106 }
107
108 pub async fn add_regional_pool(&self, config: RegionalPoolConfig) -> Result<()> {
110 let pool = PgPool::connect(&config.database_url).await.map_err(|e| {
111 DbError::Connection(format!(
112 "Failed to connect to {}: {}",
113 config.region.code, e
114 ))
115 })?;
116
117 let region_code = config.region.code.clone();
118 let is_primary = config.is_primary;
119
120 let mut pools = self.pools.write();
121 let mut configs = self.configs.write();
122
123 pools.insert(region_code.clone(), pool);
124 configs.push(config);
125
126 if is_primary {
128 let mut current = self.current_region.write();
129 if current.is_none() {
130 *current = Some(region_code);
131 }
132 }
133
134 Ok(())
135 }
136
137 pub fn get_regional_pool(&self, region_code: &str) -> Result<PgPool> {
139 let pools = self.pools.read();
140 pools
141 .get(region_code)
142 .cloned()
143 .ok_or_else(|| DbError::Other(format!("Region {} not found", region_code)))
144 }
145
146 pub fn get_closest_pool(&self, latitude: f64, longitude: f64) -> Result<(String, PgPool)> {
148 let configs = self.configs.read();
149 let pools = self.pools.read();
150
151 let user_location =
152 Region::new("user".to_string(), "User".to_string(), latitude, longitude);
153
154 let mut closest: Option<(f64, String)> = None;
155
156 for config in configs.iter().filter(|c| c.is_active) {
157 let distance = user_location.distance_to(&config.region);
158
159 if let Some((min_dist, _)) = closest {
160 if distance < min_dist {
161 closest = Some((distance, config.region.code.clone()));
162 }
163 } else {
164 closest = Some((distance, config.region.code.clone()));
165 }
166 }
167
168 let region_code = closest
169 .map(|(_, code)| code)
170 .ok_or_else(|| DbError::Other("No active regions available".to_string()))?;
171
172 let pool = pools
173 .get(®ion_code)
174 .cloned()
175 .ok_or_else(|| DbError::Other(format!("Pool for region {} not found", region_code)))?;
176
177 Ok((region_code, pool))
178 }
179
180 pub fn get_primary_pool(&self) -> Result<PgPool> {
182 let configs = self.configs.read();
183 let pools = self.pools.read();
184
185 let primary_region = configs
186 .iter()
187 .find(|c| c.is_primary && c.is_active)
188 .map(|c| c.region.code.clone())
189 .ok_or_else(|| DbError::Other("No active primary region found".to_string()))?;
190
191 pools
192 .get(&primary_region)
193 .cloned()
194 .ok_or_else(|| DbError::Other("Primary pool not found".to_string()))
195 }
196
197 pub fn get_current_pool(&self) -> Result<PgPool> {
199 let current = self.current_region.read();
200 let pools = self.pools.read();
201
202 let region_code = current
203 .as_ref()
204 .ok_or_else(|| DbError::Other("No current region set".to_string()))?;
205
206 pools
207 .get(region_code)
208 .cloned()
209 .ok_or_else(|| DbError::Other("Current region pool not found".to_string()))
210 }
211
212 pub fn failover_to_region(&self, region_code: &str) -> Result<()> {
214 let pools = self.pools.read();
215 let configs = self.configs.read();
216
217 let _config = configs
219 .iter()
220 .find(|c| c.region.code == region_code && c.is_active)
221 .ok_or_else(|| {
222 DbError::Other(format!("Region {} not found or not active", region_code))
223 })?;
224
225 if !pools.contains_key(region_code) {
226 return Err(DbError::Other(format!(
227 "Pool for region {} not found",
228 region_code
229 )));
230 }
231
232 let mut current = self.current_region.write();
233 *current = Some(region_code.to_string());
234
235 tracing::info!("Failed over to region: {}", region_code);
236
237 Ok(())
238 }
239
240 pub async fn measure_replication_lag(
242 &self,
243 source_region: &str,
244 replica_region: &str,
245 ) -> Result<ReplicationLag> {
246 let source_pool = self.get_regional_pool(source_region)?;
247 let replica_pool = self.get_regional_pool(replica_region)?;
248
249 let source_lsn: (String,) = sqlx::query_as("SELECT pg_current_wal_lsn()::text")
251 .fetch_one(&source_pool)
252 .await?;
253
254 let replica_lsn: (Option<String>,) =
256 sqlx::query_as("SELECT pg_last_wal_receive_lsn()::text")
257 .fetch_one(&replica_pool)
258 .await?;
259
260 let lag_query = format!(
262 "SELECT COALESCE(pg_wal_lsn_diff('{}', '{}'), 0) / 1024.0 / 1024.0",
263 source_lsn.0,
264 replica_lsn.0.unwrap_or_else(|| "0/0".to_string())
265 );
266
267 let lag_mb: (f64,) = sqlx::query_as(&lag_query).fetch_one(&source_pool).await?;
268
269 let lag_seconds = lag_mb.0 / 10.0;
271
272 let lag = ReplicationLag {
273 source_region: source_region.to_string(),
274 replica_region: replica_region.to_string(),
275 lag_seconds,
276 measured_at: Utc::now(),
277 is_healthy: lag_seconds < self.max_lag_seconds,
278 };
279
280 let mut lags = self.replication_lags.write();
282 lags.insert(replica_region.to_string(), lag.clone());
283
284 Ok(lag)
285 }
286
287 pub fn get_replication_lags(&self) -> HashMap<String, ReplicationLag> {
289 let lags = self.replication_lags.read();
290 lags.clone()
291 }
292
293 pub fn are_all_replicas_healthy(&self) -> bool {
295 let lags = self.replication_lags.read();
296 lags.values().all(|lag| lag.is_healthy)
297 }
298
299 pub fn get_unhealthy_replicas(&self) -> Vec<String> {
301 let lags = self.replication_lags.read();
302 lags.iter()
303 .filter(|(_, lag)| !lag.is_healthy)
304 .map(|(region, _)| region.clone())
305 .collect()
306 }
307
308 pub fn auto_failover(&self) -> Result<String> {
310 let configs = self.configs.read();
311 let lags = self.replication_lags.read();
312
313 let mut candidates: Vec<_> = configs
315 .iter()
316 .filter(|c| c.is_active)
317 .map(|c| {
318 let lag_ok = lags
319 .get(&c.region.code)
320 .map(|l| l.is_healthy)
321 .unwrap_or(true);
322 (c, lag_ok)
323 })
324 .collect();
325
326 candidates.sort_by(|a, b| match (a.1, b.1) {
328 (true, false) => std::cmp::Ordering::Less,
329 (false, true) => std::cmp::Ordering::Greater,
330 _ => a.0.failover_priority.cmp(&b.0.failover_priority),
331 });
332
333 let target_region = candidates
334 .first()
335 .map(|(c, _)| c.region.code.clone())
336 .ok_or_else(|| DbError::Other("No suitable region for failover".to_string()))?;
337
338 self.failover_to_region(&target_region)?;
339
340 Ok(target_region)
341 }
342
343 pub fn list_regions(&self) -> Vec<RegionalPoolConfig> {
345 let configs = self.configs.read();
346 configs.clone()
347 }
348
349 pub fn get_current_region(&self) -> Option<String> {
351 let current = self.current_region.read();
352 current.clone()
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// The sample US East and EU West coordinates must be far apart.
    #[test]
    fn test_region_distance() {
        let origin = Region::new(String::from("us-east-1"), String::from("US East"), 39.0, -77.0);
        let destination =
            Region::new(String::from("eu-west-1"), String::from("EU West"), 53.0, -8.0);

        let kilometres = origin.distance_to(&destination);

        assert!(kilometres > 4000.0);
    }

    /// A freshly created manager has no current region and no regions listed.
    #[test]
    fn test_multi_region_manager_creation() {
        let fresh = MultiRegionPoolManager::new(5.0);

        assert!(fresh.get_current_region().is_none());
        assert_eq!(fresh.list_regions().len(), 0);
    }

    /// With no lag measurements recorded, every replica counts as healthy.
    #[test]
    fn test_region_health_check() {
        let fresh = MultiRegionPoolManager::new(5.0);

        assert!(fresh.are_all_replicas_healthy());
        assert_eq!(fresh.get_unhealthy_replicas().len(), 0);
    }
}