use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use tokio::sync::RwLock;
/// How a routing key is mapped to a shard by `ShardRouter::get_shard`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ShardingStrategy {
    /// Consistent hashing over a virtual-node ring.
    Hash,
    /// Lexicographic key-range routing via `range_min`/`range_max`.
    Range,
    /// Routing by tenant-id membership in `ShardConfig::tenant_ids`.
    Tenant,
    /// Time-based routing; implemented as range routing over a time-sortable key.
    Temporal,
    /// Routing by matching `ShardConfig::region`.
    Geographic,
    /// Rotate across active shards.
    RoundRobin,
    /// Named custom strategy; currently not implemented — the router falls
    /// back to the default shard for this variant.
    Custom(String),
}
/// Configuration for one physical shard.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardConfig {
    /// Unique shard identifier, referenced by the router and the hash ring.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Database connection string returned to writers routed here.
    pub connection_string: String,
    /// Relative weight: the hash ring grants this shard
    /// `virtual_shards * weight / 100` virtual nodes.
    pub weight: u32,
    /// Inactive shards are skipped by every routing strategy.
    pub active: bool,
    /// Region tag matched by the `Geographic` strategy.
    pub region: Option<String>,
    /// Inclusive lower bound for `Range`/`Temporal` routing; `None` = unbounded.
    pub range_min: Option<String>,
    /// Exclusive upper bound for `Range`/`Temporal` routing; `None` = unbounded.
    pub range_max: Option<String>,
    /// Tenant ids pinned to this shard for the `Tenant` strategy.
    pub tenant_ids: Vec<String>,
}
/// Top-level sharding configuration consumed by `ShardRouter`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardingConfig {
    /// Routing strategy applied by `ShardRouter::get_shard`.
    pub strategy: ShardingStrategy,
    /// Virtual-node budget for the consistent-hash ring (scaled per shard by weight).
    pub virtual_shards: u32,
    /// Intended replica count per shard.
    /// NOTE(review): not read anywhere in this file — confirm consumers exist.
    pub replication_factor: u32,
    /// All physical shards, active or not.
    pub shards: Vec<ShardConfig>,
    /// Id of the shard used when no strategy-specific match is found.
    pub default_shard: String,
}
impl Default for ShardingConfig {
    /// Single-shard hash setup backed by a local SQLite file.
    fn default() -> Self {
        let default_shard = ShardConfig {
            id: "default".to_string(),
            name: "Default Shard".to_string(),
            connection_string: "sqlite://chasm.db".to_string(),
            weight: 100,
            active: true,
            region: None,
            range_min: None,
            range_max: None,
            tenant_ids: vec![],
        };
        Self {
            strategy: ShardingStrategy::Hash,
            virtual_shards: 256,
            replication_factor: 1,
            default_shard: default_shard.id.clone(),
            shards: vec![default_shard],
        }
    }
}
/// Routes keys to shards according to the configured `ShardingStrategy`.
pub struct ShardRouter {
    /// Sharding configuration, including the full shard list.
    config: ShardingConfig,
    /// Consistent-hash ring built once from the active shards at construction.
    ring: ConsistentHashRing,
}
impl ShardRouter {
    /// Builds a router and its consistent-hash ring from `config`.
    pub fn new(config: ShardingConfig) -> Self {
        let ring = ConsistentHashRing::new(&config);
        Self { config, ring }
    }

    /// Resolves `key` to a shard using the configured strategy.
    ///
    /// `Custom` strategies are not implemented and fall back to the
    /// default shard.
    pub fn get_shard(&self, key: &str) -> &ShardConfig {
        match self.config.strategy {
            ShardingStrategy::Hash => self.get_shard_by_hash(key),
            ShardingStrategy::Range => self.get_shard_by_range(key),
            ShardingStrategy::Tenant => self.get_shard_by_tenant(key),
            ShardingStrategy::Temporal => self.get_shard_by_time(key),
            ShardingStrategy::Geographic => self.get_shard_by_region(key),
            ShardingStrategy::RoundRobin => self.get_shard_round_robin(),
            ShardingStrategy::Custom(_) => self.get_default_shard(),
        }
    }

    /// Consistent-hash lookup; falls back to the default shard if the ring
    /// names a shard that is missing or inactive in the config.
    fn get_shard_by_hash(&self, key: &str) -> &ShardConfig {
        let shard_id = self.ring.get_node(key);
        self.config
            .shards
            .iter()
            .find(|s| s.id == shard_id && s.active)
            .unwrap_or_else(|| self.get_default_shard())
    }

    /// Lexicographic range routing: `range_min` is inclusive, `range_max`
    /// exclusive; a missing bound is treated as unbounded. First matching
    /// active shard wins.
    fn get_shard_by_range(&self, key: &str) -> &ShardConfig {
        for shard in &self.config.shards {
            if !shard.active {
                continue;
            }
            let in_min = shard.range_min.as_ref().map(|m| key >= m.as_str()).unwrap_or(true);
            let in_max = shard.range_max.as_ref().map(|m| key < m.as_str()).unwrap_or(true);
            if in_min && in_max {
                return shard;
            }
        }
        self.get_default_shard()
    }

    /// Routes by exact tenant-id membership.
    fn get_shard_by_tenant(&self, tenant_id: &str) -> &ShardConfig {
        self.config
            .shards
            .iter()
            // `any` compares &str directly; the original allocated a String
            // per shard per lookup just to call `contains`.
            .find(|s| s.active && s.tenant_ids.iter().any(|t| t == tenant_id))
            .unwrap_or_else(|| self.get_default_shard())
    }

    /// Temporal routing: delegates to range routing, so `time_key` must be
    /// lexicographically time-sortable (e.g. RFC 3339 / ISO-8601).
    fn get_shard_by_time(&self, time_key: &str) -> &ShardConfig {
        self.get_shard_by_range(time_key)
    }

    /// Routes to the first active shard whose region tag matches exactly.
    fn get_shard_by_region(&self, region: &str) -> &ShardConfig {
        self.config
            .shards
            .iter()
            .find(|s| s.active && s.region.as_deref() == Some(region))
            .unwrap_or_else(|| self.get_default_shard())
    }

    /// Cycles across active shards in order.
    ///
    /// Uses a monotonic process-wide counter: the original implementation
    /// indexed by `Utc::now().timestamp_millis()`, which sent every request
    /// within the same millisecond to the same shard (burst hot-spotting)
    /// and was not actually round-robin.
    fn get_shard_round_robin(&self) -> &ShardConfig {
        static NEXT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
        let active_shards: Vec<_> = self.config.shards.iter().filter(|s| s.active).collect();
        if active_shards.is_empty() {
            return self.get_default_shard();
        }
        let idx = NEXT.fetch_add(1, std::sync::atomic::Ordering::Relaxed) % active_shards.len();
        active_shards[idx]
    }

    /// Returns the shard named by `default_shard`, or the first shard if
    /// that id is unknown.
    ///
    /// # Panics
    /// Panics if the configuration contains no shards at all (the original
    /// code also panicked here, but with an opaque index-out-of-bounds).
    fn get_default_shard(&self) -> &ShardConfig {
        self.config
            .shards
            .iter()
            .find(|s| s.id == self.config.default_shard)
            .or_else(|| self.config.shards.first())
            .expect("sharding config must contain at least one shard")
    }

    /// All currently active shards.
    pub fn get_all_shards(&self) -> Vec<&ShardConfig> {
        self.config.shards.iter().filter(|s| s.active).collect()
    }
}
/// Consistent-hash ring of virtual nodes, kept sorted by hash so lookups
/// can binary-search.
///
/// NOTE(review): `DefaultHasher` is deterministic within one build of std,
/// but its algorithm is not guaranteed stable across Rust releases, so
/// key-to-shard assignments may shift after a toolchain upgrade.
struct ConsistentHashRing {
    /// Sorted `(hash, shard_id)` pairs; one entry per virtual node.
    ring: Vec<(u64, String)>,
}

impl ConsistentHashRing {
    /// Builds the ring from all active shards. Each shard contributes
    /// `virtual_shards * weight / 100` virtual nodes, i.e. weight acts as a
    /// percentage of the configured virtual-node budget. A shard whose share
    /// rounds to zero gets no nodes and is unreachable via hashing.
    fn new(config: &ShardingConfig) -> Self {
        let mut ring = Vec::new();
        for shard in &config.shards {
            if !shard.active {
                continue;
            }
            // Widen to u64 before multiplying: `virtual_shards * weight`
            // can overflow u32 for large configured weights.
            let vnodes = (config.virtual_shards as u64 * shard.weight as u64) / 100;
            for i in 0..vnodes {
                let key = format!("{}:{}", shard.id, i);
                let hash = Self::hash(&key);
                ring.push((hash, shard.id.clone()));
            }
        }
        ring.sort_by_key(|(hash, _)| *hash);
        Self { ring }
    }

    /// Hashes a key with std's default (SipHash-based) hasher.
    fn hash(key: &str) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns the shard id owning `key`: the first node at or clockwise
    /// from the key's hash, wrapping to the start of the ring past the end.
    /// Returns `"default"` when the ring is empty (no active shards).
    fn get_node(&self, key: &str) -> String {
        if self.ring.is_empty() {
            return "default".to_string();
        }
        let hash = Self::hash(key);
        let idx = match self.ring.binary_search_by_key(&hash, |(h, _)| *h) {
            Ok(i) => i,
            // Err(i) is the insertion point; wrap around past the last node.
            Err(i) if i >= self.ring.len() => 0,
            Err(i) => i,
        };
        self.ring[idx].1.clone()
    }
}
/// Configuration for one read replica.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicaConfig {
    /// Unique replica identifier; key into the health map.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Connection string returned to readers when this replica is chosen.
    pub connection_string: String,
    /// Region tag matched against a caller's preferred region.
    pub region: Option<String>,
    /// Selection priority; lower values are preferred.
    pub priority: u32,
    /// Replicas whose observed lag reaches this many seconds are excluded
    /// from read routing.
    pub max_lag_seconds: u32,
    /// Inactive replicas are never selected or health-checked.
    pub active: bool,
    /// Runtime-only lag value, excluded from serialization.
    /// NOTE(review): never read in this file — health tracking uses the
    /// separate `ReplicaHealth` map instead; confirm this field is needed.
    #[serde(skip)]
    pub current_lag_ms: u64,
}
/// Selects read replicas by health, lag, region, and priority, falling back
/// to the primary when no replica qualifies.
pub struct ReplicaManager {
    /// Primary connection string; the fallback read target.
    primary: String,
    /// All configured replicas, active or not.
    replicas: Vec<ReplicaConfig>,
    /// Shared health state keyed by replica id, written by health checks
    /// and read during replica selection.
    health_status: Arc<RwLock<HashMap<String, ReplicaHealth>>>,
}
/// Mutable health record for one replica.
#[derive(Debug, Clone)]
struct ReplicaHealth {
    /// Result of the most recent probe.
    is_healthy: bool,
    /// Timestamp of the most recent probe.
    last_check: DateTime<Utc>,
    /// Replication lag reported by the most recent probe.
    lag_ms: u64,
    /// Consecutive failed probes; reset to zero on success.
    error_count: u32,
}
impl ReplicaManager {
    /// Creates a manager for one primary and a set of read replicas.
    /// Health state starts empty, so no replica is eligible for reads until
    /// `update_health`/`health_check_all` has recorded a healthy probe.
    pub fn new(primary: String, replicas: Vec<ReplicaConfig>) -> Self {
        Self {
            primary,
            replicas,
            health_status: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns the connection string of the best read target.
    ///
    /// Eligible replicas are active, healthy, and within their lag budget;
    /// replicas with no recorded health are excluded. If any eligible
    /// replica matches `preferred_region`, selection is restricted to that
    /// region. Lowest `priority` wins, ties broken by lowest lag. Falls
    /// back to the primary when nothing qualifies.
    pub async fn get_read_replica(&self, preferred_region: Option<&str>) -> String {
        let health = self.health_status.read().await;
        let mut candidates: Vec<&ReplicaConfig> = self
            .replicas
            .iter()
            .filter(|r| {
                r.active
                    && health
                        .get(&r.id)
                        .map(|h| h.is_healthy && h.lag_ms < (r.max_lag_seconds as u64 * 1000))
                        .unwrap_or(false)
            })
            .collect();
        if candidates.is_empty() {
            return self.primary.clone();
        }
        if let Some(region) = preferred_region {
            let regional: Vec<&ReplicaConfig> = candidates
                .iter()
                .filter(|r| r.region.as_deref() == Some(region))
                .copied()
                .collect();
            if !regional.is_empty() {
                candidates = regional;
            }
        }
        // Stable sort on (priority, lag) — equivalent to the original
        // comparator but computed in one place per element.
        candidates.sort_by_key(|r| {
            (
                r.priority,
                health.get(&r.id).map(|h| h.lag_ms).unwrap_or(u64::MAX),
            )
        });
        candidates
            .first()
            .map(|r| r.connection_string.clone())
            .unwrap_or_else(|| self.primary.clone())
    }

    /// Connection string of the primary (write) node.
    pub fn get_primary(&self) -> &str {
        &self.primary
    }

    /// Records one probe result for `replica_id`, creating the health entry
    /// on first sight. Consecutive failures are counted; a success resets
    /// the counter.
    pub async fn update_health(&self, replica_id: &str, is_healthy: bool, lag_ms: u64) {
        let mut health = self.health_status.write().await;
        // or_insert_with: don't evaluate Utc::now() for the placeholder when
        // the entry already exists (clippy: or_fun_call).
        let entry = health
            .entry(replica_id.to_string())
            .or_insert_with(|| ReplicaHealth {
                is_healthy: true,
                last_check: Utc::now(),
                lag_ms: 0,
                error_count: 0,
            });
        entry.is_healthy = is_healthy;
        entry.last_check = Utc::now();
        entry.lag_ms = lag_ms;
        if is_healthy {
            entry.error_count = 0;
        } else {
            entry.error_count += 1;
        }
    }

    /// Probes every active replica sequentially and records each result.
    pub async fn health_check_all(&self) {
        for replica in &self.replicas {
            if !replica.active {
                continue;
            }
            let (is_healthy, lag_ms) = self.check_replica_health(&replica.connection_string).await;
            self.update_health(&replica.id, is_healthy, lag_ms).await;
        }
    }

    /// Stub probe: always reports healthy with 50 ms lag.
    /// TODO(review): replace with a real connectivity/lag check.
    async fn check_replica_health(&self, _connection_string: &str) -> (bool, u64) {
        (true, 50)
    }
}
/// Connection-pool sizing and timeout settings, all timeouts in seconds.
/// NOTE(review): these fields are stored but not consumed anywhere in this
/// file — presumably read by the pool implementation; verify against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolConfig {
    /// Minimum number of connections to keep open.
    pub min_connections: u32,
    /// Hard upper bound on open connections.
    pub max_connections: u32,
    /// Timeout when establishing a new connection.
    pub connect_timeout_seconds: u32,
    /// How long a connection may sit idle before being closed.
    pub idle_timeout_seconds: u32,
    /// Maximum total lifetime of a connection before recycling.
    pub max_lifetime_seconds: u32,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
min_connections: 5,
max_connections: 20,
connect_timeout_seconds: 30,
idle_timeout_seconds: 300,
max_lifetime_seconds: 1800,
}
}
}
/// Facade combining shard routing (writes), replica selection (reads),
/// and connection-pool settings.
pub struct ScalingManager {
    /// Resolves keys to shards for write traffic.
    shard_router: ShardRouter,
    /// Picks healthy replicas for read traffic.
    replica_manager: ReplicaManager,
    /// Pool settings. NOTE(review): stored but not read in this file.
    pool_config: PoolConfig,
}
impl ScalingManager {
    /// Wires together shard routing, replica management, and pool settings.
    pub fn new(
        sharding_config: ShardingConfig,
        primary: String,
        replicas: Vec<ReplicaConfig>,
        pool_config: PoolConfig,
    ) -> Self {
        let shard_router = ShardRouter::new(sharding_config);
        let replica_manager = ReplicaManager::new(primary, replicas);
        Self {
            shard_router,
            replica_manager,
            pool_config,
        }
    }

    /// Writes always target the shard that owns `key`.
    pub fn get_write_connection(&self, key: &str) -> &str {
        self.shard_router.get_shard(key).connection_string.as_str()
    }

    /// Reads go to the best available replica for `region`; the shard
    /// lookup for `key` is performed but its result is currently unused.
    pub async fn get_read_connection(&self, key: &str, region: Option<&str>) -> String {
        let _shard = self.shard_router.get_shard(key);
        self.replica_manager.get_read_replica(region).await
    }

    /// All currently active shards.
    pub fn get_all_shards(&self) -> Vec<&ShardConfig> {
        self.shard_router.get_all_shards()
    }

    /// Probes all active replicas and refreshes their health records.
    pub async fn health_check(&self) {
        self.replica_manager.health_check_all().await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The ring must map the same key to the same node on every lookup.
    #[test]
    fn test_consistent_hash_ring() {
        let ring = ConsistentHashRing::new(&ShardingConfig::default());
        assert_eq!(ring.get_node("test_key"), ring.get_node("test_key"));
    }

    /// Hash routing over two equally-weighted shards must land on one of them.
    #[test]
    fn test_shard_router() {
        let make_shard = |id: &str, name: &str, conn: &str| ShardConfig {
            id: id.to_string(),
            name: name.to_string(),
            connection_string: conn.to_string(),
            weight: 50,
            active: true,
            region: None,
            range_min: None,
            range_max: None,
            tenant_ids: vec![],
        };
        let config = ShardingConfig {
            strategy: ShardingStrategy::Hash,
            shards: vec![
                make_shard("shard1", "Shard 1", "sqlite://shard1.db"),
                make_shard("shard2", "Shard 2", "sqlite://shard2.db"),
            ],
            default_shard: "shard1".to_string(),
            ..Default::default()
        };
        let router = ShardRouter::new(config);
        let chosen = router.get_shard("some_key");
        assert!(["shard1", "shard2"].contains(&chosen.id.as_str()));
    }
}