use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use tracing::debug;
/// Number of virtual points each physical node contributes to the ring.
const VIRTUAL_NODE_COUNT: usize = 150;

/// A consistent hash ring mapping string keys to node ids.
///
/// Each physical node is inserted as `VIRTUAL_NODE_COUNT` virtual points so
/// keys distribute evenly and removing a node only remaps its own keys.
#[derive(Debug)]
pub struct ConsistentHashRing {
    // hash position -> owning node id; BTreeMap gives ordered ring traversal.
    ring: BTreeMap<u64, String>,
    // node id -> number of virtual points inserted for it.
    nodes: HashMap<String, usize>,
    // NOTE(review): stored but never consulted by any method; callers pass an
    // explicit `count` to `get_nodes` instead — confirm whether this should
    // drive a default replica count.
    replication_factor: usize,
}

impl ConsistentHashRing {
    /// Creates an empty ring with the given (currently advisory) replication
    /// factor.
    pub fn new(replication_factor: usize) -> Self {
        Self {
            ring: BTreeMap::new(),
            nodes: HashMap::new(),
            replication_factor,
        }
    }

    /// Adds a node, inserting `VIRTUAL_NODE_COUNT` virtual points for it.
    /// Adding an already-present node is a no-op.
    pub fn add_node(&mut self, node_id: String) {
        if self.nodes.contains_key(&node_id) {
            return;
        }
        for i in 0..VIRTUAL_NODE_COUNT {
            let virtual_key = format!("{}:{}", node_id, i);
            let hash = Self::hash_key(&virtual_key);
            self.ring.insert(hash, node_id.clone());
        }
        self.nodes.insert(node_id, VIRTUAL_NODE_COUNT);
        debug!(
            "Added node to hash ring with {} virtual nodes",
            VIRTUAL_NODE_COUNT
        );
    }

    /// Removes a node and all of its virtual points. Unknown ids are a no-op.
    pub fn remove_node(&mut self, node_id: &str) {
        if !self.nodes.contains_key(node_id) {
            return;
        }
        self.ring.retain(|_, v| v != node_id);
        self.nodes.remove(node_id);
        debug!("Removed node from hash ring");
    }

    /// Returns up to `count` distinct nodes responsible for `key`, starting at
    /// the first virtual point at or after the key's hash and wrapping around
    /// the ring. Returns fewer than `count` nodes when the ring holds fewer
    /// distinct nodes, and an empty vec for an empty ring or `count == 0`.
    pub fn get_nodes(&self, key: &str, count: usize) -> Vec<String> {
        // Bug fix: previously `count == 0` still returned one node, because
        // the `nodes.len() >= count` check only ran after the first push.
        if self.ring.is_empty() || count == 0 {
            return Vec::new();
        }
        let hash = Self::hash_key(key);
        let mut nodes = Vec::with_capacity(count.min(self.nodes.len()));
        let mut seen = std::collections::HashSet::new();
        // Walk clockwise from `hash`, then wrap to the start of the ring;
        // the two ranges together visit every virtual point exactly once
        // (the original second pass re-scanned the whole ring and relied on
        // the `seen` set to skip duplicates).
        for (_, node_id) in self.ring.range(hash..).chain(self.ring.range(..hash)) {
            if seen.insert(node_id.clone()) {
                nodes.push(node_id.clone());
                if nodes.len() >= count {
                    break;
                }
            }
        }
        nodes
    }

    /// Returns the single node owning `key`, or `None` if the ring is empty.
    pub fn get_primary_node(&self, key: &str) -> Option<String> {
        self.get_nodes(key, 1).first().cloned()
    }

    /// Hashes a key to a position on the ring.
    /// NOTE(review): `DefaultHasher`'s algorithm is not specified to be stable
    /// across Rust releases; if ring positions must agree between binaries
    /// built with different toolchains, switch to a fixed hash (e.g. FNV-1a).
    fn hash_key(key: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Number of distinct physical nodes in the ring.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Lists the physical node ids (unordered).
    pub fn list_nodes(&self) -> Vec<String> {
        self.nodes.keys().cloned().collect()
    }
}
/// Routes keys to a fixed number of shards using jump consistent hashing,
/// memoizing key -> shard lookups in an in-memory cache.
pub struct ShardRouter {
    // Total shard count; shard ids fall in `0..shard_count`.
    shard_count: u32,
    // Memoized key -> shard assignments.
    // NOTE(review): grows without bound; `clear_cache` is the only eviction.
    cache: Arc<RwLock<HashMap<String, u32>>>,
}
impl ShardRouter {
    /// Creates a router for `shard_count` shards (ids `0..shard_count`).
    pub fn new(shard_count: u32) -> Self {
        Self {
            shard_count,
            cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns the shard id for `key`, consulting the memoization cache first.
    pub fn get_shard(&self, key: &str) -> u32 {
        {
            // Fast path: read lock only, released before any write.
            let cache = self.cache.read();
            if let Some(&shard_id) = cache.get(key) {
                return shard_id;
            }
        }
        let shard_id = self.jump_consistent_hash(key, self.shard_count);
        {
            // Benign race: two threads may compute and insert concurrently,
            // but the value is deterministic so both writes agree.
            let mut cache = self.cache.write();
            cache.insert(key.to_string(), shard_id);
        }
        shard_id
    }

    /// Jump consistent hash (Lamping & Veach, 2014): maps `key` to a bucket
    /// in `0..num_buckets` such that growing the bucket count only remaps
    /// ~1/num_buckets of the keys.
    fn jump_consistent_hash(&self, key: &str, num_buckets: u32) -> u32 {
        use std::collections::hash_map::DefaultHasher;
        // Bug fix: with zero buckets the loop never ran and the function
        // returned `-1 as u32` (u32::MAX), an out-of-range shard id; return
        // the degenerate bucket 0 instead.
        if num_buckets == 0 {
            return 0;
        }
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        let mut hash = hasher.finish();
        let mut b: i64 = -1;
        let mut j: i64 = 0;
        while j < num_buckets as i64 {
            b = j;
            // 64-bit LCG step from the reference implementation.
            hash = hash.wrapping_mul(2862933555777941757).wrapping_add(1);
            j = ((b.wrapping_add(1) as f64)
                * ((1i64 << 31) as f64 / ((hash >> 33).wrapping_add(1) as f64)))
                as i64;
        }
        b as u32
    }

    /// Convenience alias of [`ShardRouter::get_shard`] for vector ids.
    pub fn get_shard_for_vector(&self, vector_id: &str) -> u32 {
        self.get_shard(vector_id)
    }

    /// Shards that may hold keys in `[_start, _end]`.
    /// NOTE(review): returns every shard because hashing destroys key
    /// ordering, so a range scan must fan out to all shards.
    pub fn get_shards_for_range(&self, _start: &str, _end: &str) -> Vec<u32> {
        (0..self.shard_count).collect()
    }

    /// Drops all memoized key -> shard assignments.
    pub fn clear_cache(&self) {
        let mut cache = self.cache.write();
        cache.clear();
    }

    /// Reports cache occupancy and the configured shard count.
    pub fn cache_stats(&self) -> CacheStats {
        let cache = self.cache.read();
        CacheStats {
            entries: cache.len(),
            shard_count: self.shard_count as usize,
        }
    }
}
/// Snapshot of the `ShardRouter` memoization cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of memoized key -> shard entries currently cached.
    pub entries: usize,
    /// Total shard count the router was configured with.
    pub shard_count: usize,
}
/// Tracks progress of moving keys from one shard to another.
pub struct ShardMigration {
    /// Shard the keys are being moved from.
    pub source_shard: u32,
    /// Shard the keys are being moved to.
    pub target_shard: u32,
    /// Completed fraction, kept in `[0.0, 1.0]`.
    pub progress: f64,
    /// Keys migrated so far, as last reported via `update_progress`.
    pub keys_migrated: usize,
    /// Total keys to migrate; 0 means a trivially complete migration.
    pub total_keys: usize,
}

impl ShardMigration {
    /// Starts a migration of `total_keys` keys from `source_shard` to
    /// `target_shard` with zero progress.
    pub fn new(source_shard: u32, target_shard: u32, total_keys: usize) -> Self {
        Self {
            source_shard,
            target_shard,
            progress: 0.0,
            keys_migrated: 0,
            total_keys,
        }
    }

    /// Records that `keys_migrated` keys are done and recomputes `progress`.
    /// A zero-key migration reports progress 1.0.
    pub fn update_progress(&mut self, keys_migrated: usize) {
        self.keys_migrated = keys_migrated;
        self.progress = if self.total_keys > 0 {
            // Bug fix: clamp so over-counting (keys_migrated > total_keys)
            // cannot report progress above 100%.
            (keys_migrated as f64 / self.total_keys as f64).min(1.0)
        } else {
            1.0
        };
    }

    /// True once every key has been migrated.
    pub fn is_complete(&self) -> bool {
        self.progress >= 1.0 || self.keys_migrated >= self.total_keys
    }
}
/// Tracks per-shard load figures and selects the least-loaded shard.
pub struct LoadBalancer {
    // shard id -> most recently reported load (units defined by the caller).
    loads: Arc<RwLock<HashMap<u32, f64>>>,
}
impl LoadBalancer {
    /// Creates a balancer with no recorded loads.
    pub fn new() -> Self {
        Self {
            loads: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Records (or overwrites) the current load for `shard_id`.
    pub fn update_load(&self, shard_id: u32, load: f64) {
        self.loads.write().insert(shard_id, load);
    }

    /// Returns the last recorded load for `shard_id`, or 0.0 if unknown.
    pub fn get_load(&self, shard_id: u32) -> f64 {
        self.loads.read().get(&shard_id).copied().unwrap_or(0.0)
    }

    /// Picks the shard with the smallest recorded load among `shard_ids`;
    /// shards with no recorded load count as 0.0. Returns `None` for an
    /// empty slice.
    pub fn get_least_loaded_shard(&self, shard_ids: &[u32]) -> Option<u32> {
        let loads = self.loads.read();
        let load_of = |id: u32| loads.get(&id).copied().unwrap_or(0.0);
        shard_ids.iter().copied().min_by(|&a, &b| {
            load_of(a)
                .partial_cmp(&load_of(b))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }

    /// Aggregate statistics over all recorded shard loads; max/min are
    /// reported as 0.0 when no loads have been recorded.
    pub fn get_stats(&self) -> LoadStats {
        let loads = self.loads.read();
        let count = loads.len();
        let total: f64 = loads.values().sum();
        // Single pass over the values for both extremes.
        let mut max = f64::NEG_INFINITY;
        let mut min = f64::INFINITY;
        for &load in loads.values() {
            max = f64::max(max, load);
            min = f64::min(min, load);
        }
        LoadStats {
            total_load: total,
            avg_load: if count == 0 { 0.0 } else { total / count as f64 },
            max_load: if max.is_finite() { max } else { 0.0 },
            min_load: if min.is_finite() { min } else { 0.0 },
            shard_count: count,
        }
    }
}
impl Default for LoadBalancer {
fn default() -> Self {
Self::new()
}
}
/// Aggregate view of recorded shard loads, as produced by
/// `LoadBalancer::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadStats {
    /// Sum of all recorded loads.
    pub total_load: f64,
    /// Mean load, or 0.0 when no loads are recorded.
    pub avg_load: f64,
    /// Largest recorded load, or 0.0 when no loads are recorded.
    pub max_load: f64,
    /// Smallest recorded load, or 0.0 when no loads are recorded.
    pub min_load: f64,
    /// Number of shards with a recorded load.
    pub shard_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    // A three-node ring resolves a key to a primary and to 3 replicas, and
    // reports the physical node count.
    #[test]
    fn test_consistent_hash_ring() {
        let mut ring = ConsistentHashRing::new(3);
        ring.add_node("node1".to_string());
        ring.add_node("node2".to_string());
        ring.add_node("node3".to_string());
        assert_eq!(ring.node_count(), 3);
        let nodes = ring.get_nodes("test-key", 3);
        assert_eq!(nodes.len(), 3);
        let primary = ring.get_primary_node("test-key");
        assert!(primary.is_some());
    }

    // With 150 virtual nodes per physical node, 1000 keys should spread
    // roughly evenly: each node owns between 20% and 50% of the keys.
    #[test]
    fn test_consistent_hashing_distribution() {
        let mut ring = ConsistentHashRing::new(3);
        ring.add_node("node1".to_string());
        ring.add_node("node2".to_string());
        ring.add_node("node3".to_string());
        let mut distribution: HashMap<String, usize> = HashMap::new();
        for i in 0..1000 {
            let key = format!("key{}", i);
            if let Some(node) = ring.get_primary_node(&key) {
                *distribution.entry(node).or_insert(0) += 1;
            }
        }
        for count in distribution.values() {
            let ratio = *count as f64 / 1000.0;
            assert!(ratio > 0.2 && ratio < 0.5, "Distribution ratio: {}", ratio);
        }
    }

    // Repeated lookups of the same key are stable, in range, and cached.
    #[test]
    fn test_shard_router() {
        let router = ShardRouter::new(16);
        let shard1 = router.get_shard("test-key-1");
        let shard2 = router.get_shard("test-key-1");
        assert_eq!(shard1, shard2);
        assert!(shard1 < 16);
        let stats = router.cache_stats();
        assert_eq!(stats.entries, 1);
    }

    // Jump consistent hashing is deterministic for a fixed key/bucket count.
    #[test]
    fn test_jump_consistent_hash() {
        let router = ShardRouter::new(10);
        let shard1 = router.get_shard("consistent-key");
        let shard2 = router.get_shard("consistent-key");
        assert_eq!(shard1, shard2);
    }

    // Progress tracks the migrated fraction and completion flips at 100%.
    #[test]
    fn test_shard_migration() {
        let mut migration = ShardMigration::new(0, 1, 100);
        assert!(!migration.is_complete());
        assert_eq!(migration.progress, 0.0);
        migration.update_progress(50);
        assert_eq!(migration.progress, 0.5);
        migration.update_progress(100);
        assert!(migration.is_complete());
    }

    // The balancer reports the shard with the lowest recorded load and
    // aggregates stats over all recorded shards.
    #[test]
    fn test_load_balancer() {
        let balancer = LoadBalancer::new();
        balancer.update_load(0, 0.5);
        balancer.update_load(1, 0.8);
        balancer.update_load(2, 0.3);
        let least_loaded = balancer.get_least_loaded_shard(&[0, 1, 2]);
        assert_eq!(least_loaded, Some(2));
        let stats = balancer.get_stats();
        assert_eq!(stats.shard_count, 3);
        assert!(stats.avg_load > 0.0);
    }
}