1use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::hash::{Hash, Hasher};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
/// How records are distributed across the configured shards.
///
/// Serialized in `snake_case`, e.g. `"round_robin"` or `{"custom": "name"}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ShardingStrategy {
    /// Consistent hashing of the routing key over a virtual-node ring.
    Hash,
    /// Lexicographic key ranges (`range_min`/`range_max` on each shard).
    Range,
    /// Explicit tenant assignment via each shard's `tenant_ids` list.
    Tenant,
    /// Time-derived keys; routed through the same logic as `Range`.
    Temporal,
    /// Shard chosen by matching the shard's `region` field.
    Geographic,
    /// Stateless rotation across active shards (wall-clock based).
    RoundRobin,
    /// Externally defined strategy; the router currently falls back to
    /// the default shard for this variant.
    Custom(String),
}
37
/// Configuration for a single physical shard.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardConfig {
    /// Stable identifier used for routing and ring placement.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Database connection string for this shard.
    pub connection_string: String,
    /// Relative capacity weight; scales the number of virtual nodes
    /// this shard receives on the consistent-hash ring.
    pub weight: u32,
    /// Inactive shards are skipped by every routing strategy.
    pub active: bool,
    /// Region tag matched by the `Geographic` strategy.
    pub region: Option<String>,
    /// Inclusive lower bound for `Range` routing; `None` = unbounded.
    pub range_min: Option<String>,
    /// Exclusive upper bound for `Range` routing; `None` = unbounded.
    pub range_max: Option<String>,
    /// Tenants pinned to this shard under the `Tenant` strategy.
    pub tenant_ids: Vec<String>,
}
60
/// Top-level sharding configuration consumed by [`ShardRouter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardingConfig {
    /// Strategy used to map keys to shards.
    pub strategy: ShardingStrategy,
    /// Base virtual-node count for the consistent-hash ring; each
    /// shard receives a weight-scaled share of these.
    pub virtual_shards: u32,
    /// Desired copies of each record (not enforced in this module).
    pub replication_factor: u32,
    /// All known shards, active or not.
    pub shards: Vec<ShardConfig>,
    /// `id` of the shard used when no strategy-specific match exists.
    pub default_shard: String,
}
75
76impl Default for ShardingConfig {
77 fn default() -> Self {
78 Self {
79 strategy: ShardingStrategy::Hash,
80 virtual_shards: 256,
81 replication_factor: 1,
82 shards: vec![ShardConfig {
83 id: "default".to_string(),
84 name: "Default Shard".to_string(),
85 connection_string: "sqlite://chasm.db".to_string(),
86 weight: 100,
87 active: true,
88 region: None,
89 range_min: None,
90 range_max: None,
91 tenant_ids: vec![],
92 }],
93 default_shard: "default".to_string(),
94 }
95 }
96}
97
/// Routes keys to shards according to a [`ShardingConfig`], using a
/// consistent-hash ring pre-built at construction for the `Hash` strategy.
pub struct ShardRouter {
    config: ShardingConfig,
    // Built once from `config`; not rebuilt if `config` were to change.
    ring: ConsistentHashRing,
}
107
108impl ShardRouter {
109 pub fn new(config: ShardingConfig) -> Self {
111 let ring = ConsistentHashRing::new(&config);
112 Self { config, ring }
113 }
114
115 pub fn get_shard(&self, key: &str) -> &ShardConfig {
117 match self.config.strategy {
118 ShardingStrategy::Hash => self.get_shard_by_hash(key),
119 ShardingStrategy::Range => self.get_shard_by_range(key),
120 ShardingStrategy::Tenant => self.get_shard_by_tenant(key),
121 ShardingStrategy::Temporal => self.get_shard_by_time(key),
122 ShardingStrategy::Geographic => self.get_shard_by_region(key),
123 ShardingStrategy::RoundRobin => self.get_shard_round_robin(),
124 ShardingStrategy::Custom(_) => self.get_default_shard(),
125 }
126 }
127
128 fn get_shard_by_hash(&self, key: &str) -> &ShardConfig {
130 let shard_id = self.ring.get_node(key);
131 self.config
132 .shards
133 .iter()
134 .find(|s| s.id == shard_id && s.active)
135 .unwrap_or_else(|| self.get_default_shard())
136 }
137
138 fn get_shard_by_range(&self, key: &str) -> &ShardConfig {
140 for shard in &self.config.shards {
141 if !shard.active {
142 continue;
143 }
144 let in_min = shard.range_min.as_ref().map(|m| key >= m.as_str()).unwrap_or(true);
145 let in_max = shard.range_max.as_ref().map(|m| key < m.as_str()).unwrap_or(true);
146 if in_min && in_max {
147 return shard;
148 }
149 }
150 self.get_default_shard()
151 }
152
153 fn get_shard_by_tenant(&self, tenant_id: &str) -> &ShardConfig {
155 self.config
156 .shards
157 .iter()
158 .find(|s| s.active && s.tenant_ids.contains(&tenant_id.to_string()))
159 .unwrap_or_else(|| self.get_default_shard())
160 }
161
162 fn get_shard_by_time(&self, time_key: &str) -> &ShardConfig {
164 self.get_shard_by_range(time_key)
167 }
168
169 fn get_shard_by_region(&self, region: &str) -> &ShardConfig {
171 self.config
172 .shards
173 .iter()
174 .find(|s| s.active && s.region.as_deref() == Some(region))
175 .unwrap_or_else(|| self.get_default_shard())
176 }
177
178 fn get_shard_round_robin(&self) -> &ShardConfig {
180 let active_shards: Vec<_> = self.config.shards.iter().filter(|s| s.active).collect();
181 if active_shards.is_empty() {
182 return self.get_default_shard();
183 }
184 let idx = (Utc::now().timestamp_millis() as usize) % active_shards.len();
185 active_shards[idx]
186 }
187
188 fn get_default_shard(&self) -> &ShardConfig {
190 self.config
191 .shards
192 .iter()
193 .find(|s| s.id == self.config.default_shard)
194 .unwrap_or(&self.config.shards[0])
195 }
196
197 pub fn get_all_shards(&self) -> Vec<&ShardConfig> {
199 self.config.shards.iter().filter(|s| s.active).collect()
200 }
201}
202
/// Virtual-node consistent-hash ring used by the `Hash` strategy.
struct ConsistentHashRing {
    // (vnode hash, shard id), sorted ascending by hash so `get_node`
    // can binary-search and wrap around.
    ring: Vec<(u64, String)>,
}
211
212impl ConsistentHashRing {
213 fn new(config: &ShardingConfig) -> Self {
214 let mut ring = Vec::new();
215
216 for shard in &config.shards {
217 if !shard.active {
218 continue;
219 }
220 let vnodes = (config.virtual_shards * shard.weight) / 100;
222 for i in 0..vnodes {
223 let key = format!("{}:{}", shard.id, i);
224 let hash = Self::hash(&key);
225 ring.push((hash, shard.id.clone()));
226 }
227 }
228
229 ring.sort_by_key(|(hash, _)| *hash);
230 Self { ring }
231 }
232
233 fn hash(key: &str) -> u64 {
234 let mut hasher = std::collections::hash_map::DefaultHasher::new();
235 key.hash(&mut hasher);
236 hasher.finish()
237 }
238
239 fn get_node(&self, key: &str) -> String {
240 if self.ring.is_empty() {
241 return "default".to_string();
242 }
243
244 let hash = Self::hash(key);
245
246 let idx = match self.ring.binary_search_by_key(&hash, |(h, _)| *h) {
248 Ok(i) => i,
249 Err(i) => {
250 if i >= self.ring.len() {
251 0 } else {
253 i
254 }
255 }
256 };
257
258 self.ring[idx].1.clone()
259 }
260}
261
/// Configuration for a single read replica.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicaConfig {
    /// Stable identifier keyed into the health-status map.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Database connection string for this replica.
    pub connection_string: String,
    /// Region tag used for region-preferred read routing.
    pub region: Option<String>,
    /// Selection priority; lower values are preferred.
    pub priority: u32,
    /// Maximum tolerated replication lag before the replica is skipped.
    pub max_lag_seconds: u32,
    /// Inactive replicas are never considered for reads.
    pub active: bool,
    /// Last observed lag; runtime-only, never serialized.
    #[serde(skip)]
    pub current_lag_ms: u64,
}
287
/// Chooses read endpoints among a primary and its replicas based on
/// recorded health and replication lag.
pub struct ReplicaManager {
    // Connection string of the write primary (also the read fallback).
    primary: String,
    replicas: Vec<ReplicaConfig>,
    // Health observations keyed by replica id; shared across tasks.
    health_status: Arc<RwLock<HashMap<String, ReplicaHealth>>>,
}
294
/// Most recent health observation for one replica.
#[derive(Debug, Clone)]
struct ReplicaHealth {
    is_healthy: bool,
    last_check: DateTime<Utc>,
    lag_ms: u64,
    // Consecutive failed checks; reset to 0 on any healthy report.
    error_count: u32,
}
302
303impl ReplicaManager {
304 pub fn new(primary: String, replicas: Vec<ReplicaConfig>) -> Self {
306 Self {
307 primary,
308 replicas,
309 health_status: Arc::new(RwLock::new(HashMap::new())),
310 }
311 }
312
313 pub async fn get_read_replica(&self, preferred_region: Option<&str>) -> String {
315 let health = self.health_status.read().await;
316
317 let mut candidates: Vec<_> = self
319 .replicas
320 .iter()
321 .filter(|r| {
322 r.active
323 && health
324 .get(&r.id)
325 .map(|h| h.is_healthy && h.lag_ms < (r.max_lag_seconds as u64 * 1000))
326 .unwrap_or(false)
327 })
328 .collect();
329
330 if candidates.is_empty() {
331 return self.primary.clone();
332 }
333
334 if let Some(region) = preferred_region {
336 let regional: Vec<_> = candidates
337 .iter()
338 .filter(|r| r.region.as_deref() == Some(region))
339 .copied()
340 .collect();
341 if !regional.is_empty() {
342 candidates = regional;
343 }
344 }
345
346 candidates.sort_by(|a, b| {
348 let lag_a = health.get(&a.id).map(|h| h.lag_ms).unwrap_or(u64::MAX);
349 let lag_b = health.get(&b.id).map(|h| h.lag_ms).unwrap_or(u64::MAX);
350 a.priority.cmp(&b.priority).then(lag_a.cmp(&lag_b))
351 });
352
353 candidates
354 .first()
355 .map(|r| r.connection_string.clone())
356 .unwrap_or_else(|| self.primary.clone())
357 }
358
359 pub fn get_primary(&self) -> &str {
361 &self.primary
362 }
363
364 pub async fn update_health(&self, replica_id: &str, is_healthy: bool, lag_ms: u64) {
366 let mut health = self.health_status.write().await;
367 let entry = health.entry(replica_id.to_string()).or_insert(ReplicaHealth {
368 is_healthy: true,
369 last_check: Utc::now(),
370 lag_ms: 0,
371 error_count: 0,
372 });
373
374 entry.is_healthy = is_healthy;
375 entry.last_check = Utc::now();
376 entry.lag_ms = lag_ms;
377 if !is_healthy {
378 entry.error_count += 1;
379 } else {
380 entry.error_count = 0;
381 }
382 }
383
384 pub async fn health_check_all(&self) {
386 for replica in &self.replicas {
387 if !replica.active {
388 continue;
389 }
390
391 let (is_healthy, lag_ms) = self.check_replica_health(&replica.connection_string).await;
393 self.update_health(&replica.id, is_healthy, lag_ms).await;
394 }
395 }
396
397 async fn check_replica_health(&self, _connection_string: &str) -> (bool, u64) {
398 (true, 50)
400 }
401}
402
/// Connection-pool sizing and timeout settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolConfig {
    /// Connections kept open even when idle.
    pub min_connections: u32,
    /// Hard cap on concurrently open connections.
    pub max_connections: u32,
    /// Timeout for establishing a new connection.
    pub connect_timeout_seconds: u32,
    /// Idle time after which a connection may be closed.
    pub idle_timeout_seconds: u32,
    /// Maximum age of a connection before it is recycled.
    pub max_lifetime_seconds: u32,
}
421
422impl Default for PoolConfig {
423 fn default() -> Self {
424 Self {
425 min_connections: 5,
426 max_connections: 20,
427 connect_timeout_seconds: 30,
428 idle_timeout_seconds: 300,
429 max_lifetime_seconds: 1800,
430 }
431 }
432}
433
/// Facade combining shard routing, replica selection, and pool settings.
pub struct ScalingManager {
    shard_router: ShardRouter,
    replica_manager: ReplicaManager,
    // Stored for consumers; not read by any method in this module.
    pool_config: PoolConfig,
}
444
445impl ScalingManager {
446 pub fn new(
448 sharding_config: ShardingConfig,
449 primary: String,
450 replicas: Vec<ReplicaConfig>,
451 pool_config: PoolConfig,
452 ) -> Self {
453 Self {
454 shard_router: ShardRouter::new(sharding_config),
455 replica_manager: ReplicaManager::new(primary, replicas),
456 pool_config,
457 }
458 }
459
460 pub fn get_write_connection(&self, key: &str) -> &str {
462 let shard = self.shard_router.get_shard(key);
463 &shard.connection_string
464 }
465
466 pub async fn get_read_connection(&self, key: &str, region: Option<&str>) -> String {
468 let _shard = self.shard_router.get_shard(key);
470 self.replica_manager.get_read_replica(region).await
472 }
473
474 pub fn get_all_shards(&self) -> Vec<&ShardConfig> {
476 self.shard_router.get_all_shards()
477 }
478
479 pub async fn health_check(&self) {
481 self.replica_manager.health_check_all().await;
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an active hash-strategy shard named `shardN` with weight 50.
    fn make_shard(n: u32) -> ShardConfig {
        ShardConfig {
            id: format!("shard{}", n),
            name: format!("Shard {}", n),
            connection_string: format!("sqlite://shard{}.db", n),
            weight: 50,
            active: true,
            region: None,
            range_min: None,
            range_max: None,
            tenant_ids: vec![],
        }
    }

    #[test]
    fn test_consistent_hash_ring() {
        // A fixed key must always resolve to the same ring node.
        let ring = ConsistentHashRing::new(&ShardingConfig::default());
        assert_eq!(ring.get_node("test_key"), ring.get_node("test_key"));
    }

    #[test]
    fn test_shard_router() {
        let config = ShardingConfig {
            strategy: ShardingStrategy::Hash,
            shards: vec![make_shard(1), make_shard(2)],
            default_shard: "shard1".to_string(),
            ..Default::default()
        };

        // Whatever the hash picks, it must be one of the two configured shards.
        let router = ShardRouter::new(config);
        let routed = router.get_shard("some_key");
        assert!(routed.id == "shard1" || routed.id == "shard2");
    }
}