Skip to main content

chasm/
scaling.rs

1// Copyright (c) 2024-2028 Nervosys LLC
2// SPDX-License-Identifier: AGPL-3.0-only
3//! Database Scaling Module
4//!
5//! Provides sharding, read replicas, and advanced database scaling capabilities.
6
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::hash::{Hash, Hasher};
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14// ============================================================================
15// Sharding Configuration
16// ============================================================================
17
18/// Sharding strategy
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
20#[serde(rename_all = "snake_case")]
21pub enum ShardingStrategy {
22    /// Shard by hash of key
23    Hash,
24    /// Shard by range of values
25    Range,
26    /// Shard by tenant ID
27    Tenant,
28    /// Shard by time period
29    Temporal,
30    /// Shard by geographic region
31    Geographic,
32    /// Round-robin distribution
33    RoundRobin,
34    /// Custom sharding function
35    Custom(String),
36}
37
/// Configuration for a single physical shard.
///
/// Which optional fields are meaningful depends on the active
/// [`ShardingStrategy`]: `range_min`/`range_max` apply to range sharding,
/// `tenant_ids` to tenant sharding, and `region` to geographic sharding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardConfig {
    /// Unique shard ID, referenced by `ShardingConfig::default_shard`.
    pub id: String,
    /// Human-readable shard name.
    pub name: String,
    /// Database connection string for this shard.
    pub connection_string: String,
    /// Shard weight for load balancing; scales the number of virtual nodes
    /// this shard gets on the consistent-hash ring (100 = baseline).
    pub weight: u32,
    /// Whether the shard is active; inactive shards receive no traffic.
    pub active: bool,
    /// Geographic region (used by `ShardingStrategy::Geographic`).
    pub region: Option<String>,
    /// Inclusive lower bound of the key range (for range sharding);
    /// `None` means unbounded below.
    pub range_min: Option<String>,
    /// Exclusive upper bound of the key range (for range sharding);
    /// `None` means unbounded above.
    pub range_max: Option<String>,
    /// Tenant IDs served by this shard (for tenant sharding).
    pub tenant_ids: Vec<String>,
}
60
/// Top-level sharding configuration consumed by [`ShardRouter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardingConfig {
    /// Strategy used to route keys to shards.
    pub strategy: ShardingStrategy,
    /// Baseline number of virtual nodes per shard on the consistent-hash
    /// ring; each shard's actual count is scaled by `weight / 100`.
    pub virtual_shards: u32,
    /// Replication factor.
    // NOTE(review): not consumed by any code visible in this module —
    // confirm intended use before relying on it.
    pub replication_factor: u32,
    /// All configured shards (active and inactive).
    pub shards: Vec<ShardConfig>,
    /// ID of the shard used when no other routing rule matches.
    pub default_shard: String,
}
75
76impl Default for ShardingConfig {
77    fn default() -> Self {
78        Self {
79            strategy: ShardingStrategy::Hash,
80            virtual_shards: 256,
81            replication_factor: 1,
82            shards: vec![ShardConfig {
83                id: "default".to_string(),
84                name: "Default Shard".to_string(),
85                connection_string: "sqlite://chasm.db".to_string(),
86                weight: 100,
87                active: true,
88                region: None,
89                range_min: None,
90                range_max: None,
91                tenant_ids: vec![],
92            }],
93            default_shard: "default".to_string(),
94        }
95    }
96}
97
98// ============================================================================
99// Shard Router
100// ============================================================================
101
/// Routes queries to appropriate shards.
///
/// Wraps a [`ShardingConfig`] together with a consistent-hash ring built
/// from it at construction time; the ring is consulted only for the
/// `Hash` strategy.
pub struct ShardRouter {
    config: ShardingConfig,
    // Built once in `new` from the active shards; never rebuilt afterwards.
    ring: ConsistentHashRing,
}
107
108impl ShardRouter {
109    /// Create a new shard router
110    pub fn new(config: ShardingConfig) -> Self {
111        let ring = ConsistentHashRing::new(&config);
112        Self { config, ring }
113    }
114
115    /// Get shard for a key
116    pub fn get_shard(&self, key: &str) -> &ShardConfig {
117        match self.config.strategy {
118            ShardingStrategy::Hash => self.get_shard_by_hash(key),
119            ShardingStrategy::Range => self.get_shard_by_range(key),
120            ShardingStrategy::Tenant => self.get_shard_by_tenant(key),
121            ShardingStrategy::Temporal => self.get_shard_by_time(key),
122            ShardingStrategy::Geographic => self.get_shard_by_region(key),
123            ShardingStrategy::RoundRobin => self.get_shard_round_robin(),
124            ShardingStrategy::Custom(_) => self.get_default_shard(),
125        }
126    }
127
128    /// Get shard by consistent hash
129    fn get_shard_by_hash(&self, key: &str) -> &ShardConfig {
130        let shard_id = self.ring.get_node(key);
131        self.config
132            .shards
133            .iter()
134            .find(|s| s.id == shard_id && s.active)
135            .unwrap_or_else(|| self.get_default_shard())
136    }
137
138    /// Get shard by range
139    fn get_shard_by_range(&self, key: &str) -> &ShardConfig {
140        for shard in &self.config.shards {
141            if !shard.active {
142                continue;
143            }
144            let in_min = shard.range_min.as_ref().map(|m| key >= m.as_str()).unwrap_or(true);
145            let in_max = shard.range_max.as_ref().map(|m| key < m.as_str()).unwrap_or(true);
146            if in_min && in_max {
147                return shard;
148            }
149        }
150        self.get_default_shard()
151    }
152
153    /// Get shard by tenant ID
154    fn get_shard_by_tenant(&self, tenant_id: &str) -> &ShardConfig {
155        self.config
156            .shards
157            .iter()
158            .find(|s| s.active && s.tenant_ids.contains(&tenant_id.to_string()))
159            .unwrap_or_else(|| self.get_default_shard())
160    }
161
162    /// Get shard by time
163    fn get_shard_by_time(&self, time_key: &str) -> &ShardConfig {
164        // Parse time and route to appropriate shard
165        // For simplicity, use range-based routing
166        self.get_shard_by_range(time_key)
167    }
168
169    /// Get shard by region
170    fn get_shard_by_region(&self, region: &str) -> &ShardConfig {
171        self.config
172            .shards
173            .iter()
174            .find(|s| s.active && s.region.as_deref() == Some(region))
175            .unwrap_or_else(|| self.get_default_shard())
176    }
177
178    /// Get shard by round robin (stateless, uses current time)
179    fn get_shard_round_robin(&self) -> &ShardConfig {
180        let active_shards: Vec<_> = self.config.shards.iter().filter(|s| s.active).collect();
181        if active_shards.is_empty() {
182            return self.get_default_shard();
183        }
184        let idx = (Utc::now().timestamp_millis() as usize) % active_shards.len();
185        active_shards[idx]
186    }
187
188    /// Get default shard
189    fn get_default_shard(&self) -> &ShardConfig {
190        self.config
191            .shards
192            .iter()
193            .find(|s| s.id == self.config.default_shard)
194            .unwrap_or(&self.config.shards[0])
195    }
196
197    /// Get all shards for scatter-gather query
198    pub fn get_all_shards(&self) -> Vec<&ShardConfig> {
199        self.config.shards.iter().filter(|s| s.active).collect()
200    }
201}
202
203// ============================================================================
204// Consistent Hash Ring
205// ============================================================================
206
/// Consistent hash ring for shard distribution.
///
/// Stores a `(hash, shard_id)` pair per virtual node, kept sorted by hash
/// so lookups can binary-search and wrap around the ring.
struct ConsistentHashRing {
    ring: Vec<(u64, String)>,
}
211
212impl ConsistentHashRing {
213    fn new(config: &ShardingConfig) -> Self {
214        let mut ring = Vec::new();
215
216        for shard in &config.shards {
217            if !shard.active {
218                continue;
219            }
220            // Add virtual nodes for each shard
221            let vnodes = (config.virtual_shards * shard.weight) / 100;
222            for i in 0..vnodes {
223                let key = format!("{}:{}", shard.id, i);
224                let hash = Self::hash(&key);
225                ring.push((hash, shard.id.clone()));
226            }
227        }
228
229        ring.sort_by_key(|(hash, _)| *hash);
230        Self { ring }
231    }
232
233    fn hash(key: &str) -> u64 {
234        let mut hasher = std::collections::hash_map::DefaultHasher::new();
235        key.hash(&mut hasher);
236        hasher.finish()
237    }
238
239    fn get_node(&self, key: &str) -> String {
240        if self.ring.is_empty() {
241            return "default".to_string();
242        }
243
244        let hash = Self::hash(key);
245
246        // Binary search for the first node >= hash
247        let idx = match self.ring.binary_search_by_key(&hash, |(h, _)| *h) {
248            Ok(i) => i,
249            Err(i) => {
250                if i >= self.ring.len() {
251                    0 // Wrap around
252                } else {
253                    i
254                }
255            }
256        };
257
258        self.ring[idx].1.clone()
259    }
260}
261
262// ============================================================================
263// Read Replica Configuration
264// ============================================================================
265
/// Read replica configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicaConfig {
    /// Replica ID (key into the manager's health-status map).
    pub id: String,
    /// Human-readable replica name.
    pub name: String,
    /// Database connection string for this replica.
    pub connection_string: String,
    /// Geographic region, used to prefer same-region reads.
    pub region: Option<String>,
    /// Selection priority; lower values are preferred.
    pub priority: u32,
    /// Maximum replication lag tolerated before the replica is skipped
    /// for reads (seconds).
    pub max_lag_seconds: u32,
    /// Whether the replica is active; inactive replicas are never selected.
    pub active: bool,
    /// Current lag in milliseconds; not serialized.
    // NOTE(review): lag tracking in this module goes through
    // `ReplicaManager::health_status`, and this field is never written
    // here — confirm whether it is updated elsewhere or dead.
    #[serde(skip)]
    pub current_lag_ms: u64,
}
287
/// Read replica manager: selects replicas for reads and tracks their
/// health and replication lag.
pub struct ReplicaManager {
    // Primary connection string; fallback when no replica qualifies.
    primary: String,
    // Static replica configuration; dynamic health lives in the map below.
    replicas: Vec<ReplicaConfig>,
    // Per-replica health keyed by replica ID, behind an async RwLock so
    // reads (replica selection) don't block each other.
    health_status: Arc<RwLock<HashMap<String, ReplicaHealth>>>,
}
294
/// Dynamic health snapshot for one replica.
#[derive(Debug, Clone)]
struct ReplicaHealth {
    // Result of the most recent health check.
    is_healthy: bool,
    // When the status was last updated.
    last_check: DateTime<Utc>,
    // Measured replication lag in milliseconds.
    lag_ms: u64,
    // Consecutive failed checks; reset to 0 on a healthy check.
    error_count: u32,
}
302
303impl ReplicaManager {
304    /// Create a new replica manager
305    pub fn new(primary: String, replicas: Vec<ReplicaConfig>) -> Self {
306        Self {
307            primary,
308            replicas,
309            health_status: Arc::new(RwLock::new(HashMap::new())),
310        }
311    }
312
313    /// Get the best replica for read operations
314    pub async fn get_read_replica(&self, preferred_region: Option<&str>) -> String {
315        let health = self.health_status.read().await;
316
317        // Filter healthy, active replicas
318        let mut candidates: Vec<_> = self
319            .replicas
320            .iter()
321            .filter(|r| {
322                r.active
323                    && health
324                        .get(&r.id)
325                        .map(|h| h.is_healthy && h.lag_ms < (r.max_lag_seconds as u64 * 1000))
326                        .unwrap_or(false)
327            })
328            .collect();
329
330        if candidates.is_empty() {
331            return self.primary.clone();
332        }
333
334        // Prefer same region
335        if let Some(region) = preferred_region {
336            let regional: Vec<_> = candidates
337                .iter()
338                .filter(|r| r.region.as_deref() == Some(region))
339                .copied()
340                .collect();
341            if !regional.is_empty() {
342                candidates = regional;
343            }
344        }
345
346        // Sort by priority and lag
347        candidates.sort_by(|a, b| {
348            let lag_a = health.get(&a.id).map(|h| h.lag_ms).unwrap_or(u64::MAX);
349            let lag_b = health.get(&b.id).map(|h| h.lag_ms).unwrap_or(u64::MAX);
350            a.priority.cmp(&b.priority).then(lag_a.cmp(&lag_b))
351        });
352
353        candidates
354            .first()
355            .map(|r| r.connection_string.clone())
356            .unwrap_or_else(|| self.primary.clone())
357    }
358
359    /// Get primary connection for write operations
360    pub fn get_primary(&self) -> &str {
361        &self.primary
362    }
363
364    /// Update replica health status
365    pub async fn update_health(&self, replica_id: &str, is_healthy: bool, lag_ms: u64) {
366        let mut health = self.health_status.write().await;
367        let entry = health.entry(replica_id.to_string()).or_insert(ReplicaHealth {
368            is_healthy: true,
369            last_check: Utc::now(),
370            lag_ms: 0,
371            error_count: 0,
372        });
373
374        entry.is_healthy = is_healthy;
375        entry.last_check = Utc::now();
376        entry.lag_ms = lag_ms;
377        if !is_healthy {
378            entry.error_count += 1;
379        } else {
380            entry.error_count = 0;
381        }
382    }
383
384    /// Health check all replicas
385    pub async fn health_check_all(&self) {
386        for replica in &self.replicas {
387            if !replica.active {
388                continue;
389            }
390
391            // In a real implementation, ping the replica and measure lag
392            let (is_healthy, lag_ms) = self.check_replica_health(&replica.connection_string).await;
393            self.update_health(&replica.id, is_healthy, lag_ms).await;
394        }
395    }
396
397    async fn check_replica_health(&self, _connection_string: &str) -> (bool, u64) {
398        // Simulate health check
399        (true, 50)
400    }
401}
402
403// ============================================================================
404// Connection Pool
405// ============================================================================
406
/// Connection pool configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolConfig {
    /// Minimum number of connections kept open.
    pub min_connections: u32,
    /// Maximum number of connections allowed.
    pub max_connections: u32,
    /// Timeout for establishing a new connection (seconds).
    pub connect_timeout_seconds: u32,
    /// How long an idle connection may sit before being closed (seconds).
    pub idle_timeout_seconds: u32,
    /// Maximum lifetime of a connection before it is recycled (seconds).
    pub max_lifetime_seconds: u32,
}
421
422impl Default for PoolConfig {
423    fn default() -> Self {
424        Self {
425            min_connections: 5,
426            max_connections: 20,
427            connect_timeout_seconds: 30,
428            idle_timeout_seconds: 300,
429            max_lifetime_seconds: 1800,
430        }
431    }
432}
433
434// ============================================================================
435// Database Scaling Manager
436// ============================================================================
437
/// Manages all database scaling features: sharding (via [`ShardRouter`]),
/// read replicas (via [`ReplicaManager`]), and pool configuration.
pub struct ScalingManager {
    shard_router: ShardRouter,
    replica_manager: ReplicaManager,
    // Stored pool settings for whoever constructs the actual pools.
    // NOTE(review): not read anywhere in this module — confirm wiring.
    pool_config: PoolConfig,
}
444
445impl ScalingManager {
446    /// Create a new scaling manager
447    pub fn new(
448        sharding_config: ShardingConfig,
449        primary: String,
450        replicas: Vec<ReplicaConfig>,
451        pool_config: PoolConfig,
452    ) -> Self {
453        Self {
454            shard_router: ShardRouter::new(sharding_config),
455            replica_manager: ReplicaManager::new(primary, replicas),
456            pool_config,
457        }
458    }
459
460    /// Get connection for write operation
461    pub fn get_write_connection(&self, key: &str) -> &str {
462        let shard = self.shard_router.get_shard(key);
463        &shard.connection_string
464    }
465
466    /// Get connection for read operation
467    pub async fn get_read_connection(&self, key: &str, region: Option<&str>) -> String {
468        // For sharded data, get the shard first
469        let _shard = self.shard_router.get_shard(key);
470        // Then get a replica if available
471        self.replica_manager.get_read_replica(region).await
472    }
473
474    /// Get all shards for scatter-gather
475    pub fn get_all_shards(&self) -> Vec<&ShardConfig> {
476        self.shard_router.get_all_shards()
477    }
478
479    /// Health check
480    pub async fn health_check(&self) {
481        self.replica_manager.health_check_all().await;
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an active, weight-50 shard with no ranges, region, or tenants.
    fn hash_shard(id: &str, name: &str, connection: &str) -> ShardConfig {
        ShardConfig {
            id: id.to_string(),
            name: name.to_string(),
            connection_string: connection.to_string(),
            weight: 50,
            active: true,
            region: None,
            range_min: None,
            range_max: None,
            tenant_ids: vec![],
        }
    }

    #[test]
    fn test_consistent_hash_ring() {
        let ring = ConsistentHashRing::new(&ShardingConfig::default());

        // Lookups must be deterministic: identical keys map to one node.
        let first = ring.get_node("test_key");
        let second = ring.get_node("test_key");
        assert_eq!(first, second);
    }

    #[test]
    fn test_shard_router() {
        let config = ShardingConfig {
            strategy: ShardingStrategy::Hash,
            shards: vec![
                hash_shard("shard1", "Shard 1", "sqlite://shard1.db"),
                hash_shard("shard2", "Shard 2", "sqlite://shard2.db"),
            ],
            default_shard: "shard1".to_string(),
            ..Default::default()
        };

        // Hash routing must land on one of the two configured shards.
        let router = ShardRouter::new(config);
        let routed = router.get_shard("some_key");
        assert!(matches!(routed.id.as_str(), "shard1" | "shard2"));
    }
}