use super::types::ShardId;
use std::collections::HashMap;
use std::time::Duration;
/// Static description of one shard: its id, primary endpoint, the labels it
/// owns, optional replica endpoints and a relative weight.
#[derive(Debug, Clone)]
pub struct ShardDefinition {
    /// Identifier of this shard.
    pub id: ShardId,
    /// Primary endpoint address (the examples in this module use `"host:port"`).
    pub endpoint: String,
    /// Labels assigned to this shard.
    pub labels: Vec<String>,
    /// Replica endpoints; empty unless set with [`ShardDefinition::with_replicas`].
    pub replicas: Vec<String>,
    /// Relative weight; `100` unless set with [`ShardDefinition::with_weight`].
    pub weight: u32,
}

impl ShardDefinition {
    /// Creates a definition with no replicas and a weight of `100`.
    ///
    /// The raw `id` is wrapped with `ShardId::new_unchecked`, so no range check
    /// happens here — NOTE(review): confirm callers never pass an id that the
    /// fallible `ShardId::new` would reject.
    pub fn new<S: Into<String>>(id: u16, endpoint: S, labels: Vec<&str>) -> Self {
        let owned_labels: Vec<String> = labels.into_iter().map(String::from).collect();
        Self {
            id: ShardId::new_unchecked(id),
            endpoint: endpoint.into(),
            labels: owned_labels,
            replicas: Vec::new(),
            weight: 100,
        }
    }

    /// Replaces the replica list with owned copies of `replicas`.
    pub fn with_replicas(self, replicas: Vec<&str>) -> Self {
        Self {
            replicas: replicas.into_iter().map(String::from).collect(),
            ..self
        }
    }

    /// Replaces the relative weight.
    pub fn with_weight(self, weight: u32) -> Self {
        Self { weight, ..self }
    }
}
/// Source from which the shard topology is obtained.
///
/// The default is [`ShardDiscovery::Static`] holding an empty list.
#[derive(Debug, Clone)]
pub enum ShardDiscovery {
    /// A fixed list of strings — NOTE(review): presumably shard endpoints; confirm with callers.
    Static(Vec<String>),
    /// etcd-backed discovery.
    Etcd {
        /// etcd endpoints.
        endpoints: Vec<String>,
        /// Key prefix — NOTE(review): assumed to scope shard entries; confirm.
        prefix: String,
    },
    /// Consul-backed discovery.
    Consul {
        /// Consul address.
        address: String,
        /// Consul service name.
        service: String,
    },
}

impl Default for ShardDiscovery {
    /// Returns `Static` with no entries.
    fn default() -> Self {
        Self::Static(vec![])
    }
}
/// Top-level sharding configuration: the shard list, the discovery mode, and
/// connection, failover, retry and write-ahead-log settings.
///
/// Start from [`ShardConfig::new`] or [`ShardConfig::single_shard`], adjust with
/// the `with_*` builders, then call [`ShardConfig::validate`]. The constructors
/// and builders themselves perform no validation.
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// All configured shards, in declaration order.
    pub shards: Vec<ShardDefinition>,
    /// Designated default shard; [`ShardConfig::validate`] requires it to be in `shards`.
    pub default_shard: ShardId,
    /// Discovery mode (default from [`ShardConfig::new`]: empty static list).
    pub discovery: ShardDiscovery,
    /// Connection timeout (default: 5 s).
    pub connection_timeout: Duration,
    /// Request timeout (default: 30 s).
    pub request_timeout: Duration,
    /// Maximum connections per shard (default: 10).
    pub max_connections_per_shard: usize,
    /// Whether automatic failover is enabled (default: `true`).
    pub auto_failover: bool,
    /// Interval between health checks (default: 10 s).
    pub health_check_interval: Duration,
    /// Maximum number of retries (default: 3).
    pub max_retries: u32,
    /// Base retry delay (default: 100 ms) — NOTE(review): any backoff scaling
    /// is applied elsewhere; confirm.
    pub retry_base_delay: Duration,
    /// Optional write-ahead-log location (default: `None`).
    pub wal_path: Option<std::path::PathBuf>,
}

impl ShardConfig {
    /// Creates a configuration for `shards` with default settings.
    ///
    /// The default shard is the id of the first entry. If `shards` is empty,
    /// shard `0` is used, and such a configuration is rejected by
    /// [`ShardConfig::validate`].
    pub fn new(shards: Vec<ShardDefinition>) -> Self {
        let default_shard = shards
            .first()
            .map(|s| s.id)
            .unwrap_or_else(|| ShardId::new_unchecked(0));
        Self {
            shards,
            default_shard,
            discovery: ShardDiscovery::default(),
            connection_timeout: Duration::from_secs(5),
            request_timeout: Duration::from_secs(30),
            max_connections_per_shard: 10,
            auto_failover: true,
            health_check_interval: Duration::from_secs(10),
            max_retries: 3,
            retry_base_delay: Duration::from_millis(100),
            wal_path: None,
        }
    }

    /// A single shard `0` at `localhost:9000` with no labels.
    pub fn single_shard() -> Self {
        Self::new(vec![ShardDefinition::new(0, "localhost:9000", vec![])])
    }

    /// Sets the default shard (not checked until [`ShardConfig::validate`]).
    #[must_use]
    pub fn with_default_shard(mut self, shard_id: ShardId) -> Self {
        self.default_shard = shard_id;
        self
    }

    /// Sets the discovery mode.
    #[must_use]
    pub fn with_discovery(mut self, discovery: ShardDiscovery) -> Self {
        self.discovery = discovery;
        self
    }

    /// Sets the connection timeout.
    #[must_use]
    pub fn with_connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }

    /// Sets the request timeout.
    #[must_use]
    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Sets the maximum number of connections per shard.
    #[must_use]
    pub fn with_max_connections(mut self, max: usize) -> Self {
        self.max_connections_per_shard = max;
        self
    }

    /// Enables or disables automatic failover.
    #[must_use]
    pub fn with_auto_failover(mut self, enabled: bool) -> Self {
        self.auto_failover = enabled;
        self
    }

    /// Sets the health-check interval.
    #[must_use]
    pub fn with_health_check_interval(mut self, interval: Duration) -> Self {
        self.health_check_interval = interval;
        self
    }

    /// Sets the maximum number of retries.
    #[must_use]
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Sets the base retry delay.
    #[must_use]
    pub fn with_retry_base_delay(mut self, delay: Duration) -> Self {
        self.retry_base_delay = delay;
        self
    }

    /// Sets the write-ahead-log path.
    #[must_use]
    pub fn with_wal_path<P: Into<std::path::PathBuf>>(mut self, path: P) -> Self {
        self.wal_path = Some(path.into());
        self
    }

    /// Maps every label to the shard that lists it.
    ///
    /// If a label appears on several shards, the shard that comes last in
    /// `shards` wins. [`ShardConfig::validate`] rejects that situation.
    pub fn build_label_map(&self) -> HashMap<String, ShardId> {
        self.shards
            .iter()
            .flat_map(|shard| shard.labels.iter().map(move |label| (label.clone(), shard.id)))
            .collect()
    }

    /// Returns the first shard whose id equals `id`, if any.
    pub fn get_shard(&self, id: ShardId) -> Option<&ShardDefinition> {
        self.shards.iter().find(|s| s.id == id)
    }

    /// Ids of all shards, in declaration order.
    pub fn shard_ids(&self) -> Vec<ShardId> {
        self.shards.iter().map(|s| s.id).collect()
    }

    /// Number of configured shards.
    pub fn num_shards(&self) -> usize {
        self.shards.len()
    }

    /// Checks the configuration for structural consistency.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - no shards are configured,
    /// - two shards share an id,
    /// - one label is assigned to two *different* shards, or
    /// - `default_shard` is not one of the configured shards.
    ///
    /// A label repeated within a single shard is harmless (it maps to the same
    /// shard) and is accepted.
    pub fn validate(&self) -> Result<(), String> {
        if self.shards.is_empty() {
            return Err("No shards configured".to_string());
        }

        let mut seen_ids = std::collections::HashSet::new();
        for shard in &self.shards {
            if !seen_ids.insert(shard.id) {
                return Err(format!("Duplicate shard ID: {}", shard.id));
            }
        }

        // Borrow labels as &str; no need to clone them just to detect conflicts.
        let mut label_owners: HashMap<&str, ShardId> = HashMap::new();
        for shard in &self.shards {
            for label in &shard.labels {
                if let Some(existing) = label_owners.insert(label.as_str(), shard.id) {
                    if existing != shard.id {
                        return Err(format!(
                            "Label '{}' assigned to multiple shards: {} and {}",
                            label, existing, shard.id
                        ));
                    }
                }
            }
        }

        if !self.shards.iter().any(|s| s.id == self.default_shard) {
            return Err(format!(
                "Default shard {} not found in shard list",
                self.default_shard
            ));
        }
        Ok(())
    }
}

impl Default for ShardConfig {
    /// Same as [`ShardConfig::single_shard`].
    fn default() -> Self {
        Self::single_shard()
    }
}
/// Settings that govern when and how shard data is rebalanced.
#[derive(Debug, Clone)]
pub struct RebalanceConfig {
    /// [`RebalanceConfig::should_rebalance`] fires only when the measured
    /// imbalance is strictly greater than this value. The builder keeps it in
    /// `[0.0, 1.0]` (default: 0.3).
    pub imbalance_threshold: f64,
    /// Migration batch size (default: 10 000) — NOTE(review): unit not visible here.
    pub batch_size: usize,
    /// Cooldown (default: 1 h) — NOTE(review): presumably the minimum gap
    /// between rebalances; not enforced in this type.
    pub cooldown: Duration,
    /// Maximum concurrent migrations (default: 2).
    pub max_concurrent_migrations: usize,
    /// Master switch for [`RebalanceConfig::should_rebalance`] (default: `true`).
    pub auto_rebalance: bool,
    /// Delay related to migrations (default: 10 s) — NOTE(review): exact use is defined elsewhere.
    pub migration_delay: Duration,
    /// Migration timeout (default: 300 s).
    pub migration_timeout: Duration,
    /// Migration retry count (default: 3).
    pub migration_retries: u32,
}

impl RebalanceConfig {
    /// Same as [`RebalanceConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the imbalance threshold, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN `threshold` is ignored and the current value is kept. The reason:
    /// `f64::clamp` passes NaN through, and a NaN threshold would make
    /// [`RebalanceConfig::should_rebalance`] silently always return `false`.
    #[must_use]
    pub fn with_imbalance_threshold(mut self, threshold: f64) -> Self {
        if !threshold.is_nan() {
            self.imbalance_threshold = threshold.clamp(0.0, 1.0);
        }
        self
    }

    /// Sets the migration batch size.
    #[must_use]
    pub fn with_batch_size(mut self, size: usize) -> Self {
        self.batch_size = size;
        self
    }

    /// Sets the cooldown.
    #[must_use]
    pub fn with_cooldown(mut self, cooldown: Duration) -> Self {
        self.cooldown = cooldown;
        self
    }

    /// Enables or disables automatic rebalancing.
    #[must_use]
    pub fn with_auto_rebalance(mut self, enabled: bool) -> Self {
        self.auto_rebalance = enabled;
        self
    }

    /// Sets the maximum number of concurrent migrations.
    #[must_use]
    pub fn with_max_concurrent_migrations(mut self, max: usize) -> Self {
        self.max_concurrent_migrations = max;
        self
    }

    /// Sets the migration delay.
    #[must_use]
    pub fn with_migration_delay(mut self, delay: Duration) -> Self {
        self.migration_delay = delay;
        self
    }

    /// Sets the migration timeout.
    #[must_use]
    pub fn with_migration_timeout(mut self, timeout: Duration) -> Self {
        self.migration_timeout = timeout;
        self
    }

    /// Sets the migration retry count.
    #[must_use]
    pub fn with_migration_retries(mut self, retries: u32) -> Self {
        self.migration_retries = retries;
        self
    }

    /// Returns `true` when auto-rebalancing is enabled and `imbalance` is
    /// strictly above the threshold. A NaN `imbalance` yields `false`.
    pub fn should_rebalance(&self, imbalance: f64) -> bool {
        self.auto_rebalance && imbalance > self.imbalance_threshold
    }
}

impl Default for RebalanceConfig {
    fn default() -> Self {
        Self {
            imbalance_threshold: 0.3,
            batch_size: 10_000,
            cooldown: Duration::from_secs(3600),
            max_concurrent_migrations: 2,
            auto_rebalance: true,
            migration_delay: Duration::from_secs(10),
            migration_timeout: Duration::from_secs(300),
            migration_retries: 3,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a validated shard id in assertions.
    fn sid(raw: u16) -> ShardId {
        ShardId::new(raw).unwrap()
    }

    /// Float comparison with the tolerance these tests have always used.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 0.001
    }

    #[test]
    fn test_shard_definition_creation() {
        let shard = ShardDefinition::new(0, "localhost:9000", vec!["Person", "User"]);
        assert_eq!(shard.id.as_u16(), 0);
        assert_eq!(shard.endpoint, "localhost:9000");
        assert_eq!(shard.labels, vec!["Person", "User"]);
        assert!(shard.replicas.is_empty());
        assert_eq!(shard.weight, 100);
    }

    #[test]
    fn test_shard_definition_with_replicas() {
        let shard = ShardDefinition::new(0, "primary:9000", vec!["Person"])
            .with_replicas(vec!["replica1:9000", "replica2:9000"])
            .with_weight(150);
        assert_eq!(shard.replicas.len(), 2);
        assert_eq!(shard.weight, 150);
    }

    #[test]
    fn test_shard_config_creation() {
        let shards = vec![
            ShardDefinition::new(0, "shard0:9000", vec!["Person", "User"]),
            ShardDefinition::new(1, "shard1:9000", vec!["Place", "Location"]),
        ];
        let config = ShardConfig::new(shards);
        assert_eq!(config.num_shards(), 2);
        assert_eq!(config.default_shard.as_u16(), 0);
    }

    #[test]
    fn test_shard_config_single_shard() {
        assert_eq!(ShardConfig::single_shard().num_shards(), 1);
    }

    #[test]
    fn test_shard_config_label_map() {
        let config = ShardConfig::new(vec![
            ShardDefinition::new(0, "shard0:9000", vec!["Person", "User"]),
            ShardDefinition::new(1, "shard1:9000", vec!["Place", "Location"]),
        ]);
        let label_map = config.build_label_map();
        let expected = [("Person", 0), ("User", 0), ("Place", 1), ("Location", 1)];
        for (label, shard) in expected {
            assert_eq!(label_map.get(label).unwrap().as_u16(), shard);
        }
        assert!(!label_map.contains_key("Unknown"));
    }

    #[test]
    fn test_shard_config_get_shard() {
        let config = ShardConfig::new(vec![
            ShardDefinition::new(0, "shard0:9000", vec!["Person"]),
            ShardDefinition::new(1, "shard1:9000", vec!["Place"]),
        ]);
        let found = config.get_shard(sid(1));
        assert!(found.is_some());
        assert_eq!(found.unwrap().endpoint, "shard1:9000");
        assert!(config.get_shard(sid(99)).is_none());
    }

    #[test]
    fn test_shard_config_validation() {
        let valid = ShardConfig::new(vec![
            ShardDefinition::new(0, "shard0:9000", vec!["Person"]),
            ShardDefinition::new(1, "shard1:9000", vec!["Place"]),
        ]);
        assert!(valid.validate().is_ok());

        let duplicate_ids = ShardConfig::new(vec![
            ShardDefinition::new(0, "shard0:9000", vec!["Person"]),
            ShardDefinition::new(0, "shard1:9000", vec!["Place"]),
        ]);
        assert!(duplicate_ids.validate().is_err());

        let shared_label = ShardConfig::new(vec![
            ShardDefinition::new(0, "shard0:9000", vec!["Person"]),
            ShardDefinition::new(1, "shard1:9000", vec!["Person"]),
        ]);
        assert!(shared_label.validate().is_err());
    }

    #[test]
    fn test_shard_config_default_shard_validation() {
        let config = ShardConfig::new(vec![ShardDefinition::new(0, "shard0:9000", vec!["Person"])])
            .with_default_shard(sid(99));
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_shard_discovery_default() {
        assert!(
            matches!(ShardDiscovery::default(), ShardDiscovery::Static(_)),
            "default discovery should be static"
        );
    }

    #[test]
    fn test_rebalance_config_defaults() {
        let config = RebalanceConfig::new();
        assert!(approx(config.imbalance_threshold, 0.3));
        assert_eq!(config.batch_size, 10_000);
        assert!(config.auto_rebalance);
    }

    #[test]
    fn test_rebalance_config_builders() {
        let config = RebalanceConfig::new()
            .with_imbalance_threshold(0.5)
            .with_batch_size(5000)
            .with_auto_rebalance(false);
        assert!(approx(config.imbalance_threshold, 0.5));
        assert_eq!(config.batch_size, 5000);
        assert!(!config.auto_rebalance);
    }

    #[test]
    fn test_rebalance_config_threshold_clamping() {
        let above = RebalanceConfig::new().with_imbalance_threshold(1.5);
        assert!(approx(above.imbalance_threshold, 1.0));
        let below = RebalanceConfig::new().with_imbalance_threshold(-0.5);
        assert!(approx(below.imbalance_threshold, 0.0));
    }

    #[test]
    fn test_rebalance_config_should_rebalance() {
        let enabled = RebalanceConfig::new().with_imbalance_threshold(0.3);
        assert!(!enabled.should_rebalance(0.2));
        assert!(enabled.should_rebalance(0.4));
        let disabled = enabled.with_auto_rebalance(false);
        assert!(!disabled.should_rebalance(0.5));
    }
}