use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};
/// Configuration shared by the sharding backends in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardingConfig {
/// Number of shards; shard ids are `0..num_shards`.
pub num_shards: u32,
/// Upper bound on nodes assigned per shard by `rebalance` (further capped by the node count).
pub replication_factor: u32,
/// Which routing backend `ShardManager::new` constructs.
pub strategy: ShardingStrategy,
/// Ring points per shard used by `ConsistentHashRing`; `RangeSharder` does not read it.
pub virtual_nodes: u32,
}
impl Default for ShardingConfig {
    /// Four shards with two replicas each, routed by consistent hashing with 150
    /// virtual nodes per shard.
    fn default() -> Self {
        let strategy = ShardingStrategy::ConsistentHash;
        Self {
            strategy,
            virtual_nodes: 150,
            replication_factor: 2,
            num_shards: 4,
        }
    }
}
/// How `ShardManager` routes vector ids to shards.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardingStrategy {
/// Hash ids onto a ring of virtual nodes (`ConsistentHashRing`).
ConsistentHash,
/// Split the 64-bit hash space into roughly equal contiguous ranges (`RangeSharder`).
Range,
/// NOTE(review): `ShardManager` currently serves this through `ConsistentHashRing`,
/// exactly like `ConsistentHash`; no `hash % num_shards` routing exists in this file.
/// Confirm the alias is intended.
Modulo,
}
/// Per-shard status snapshot produced by the `get_partition_info` methods.
///
/// NOTE(review): every producer in this file reports `is_healthy: true`,
/// `vector_count: 0` and `memory_bytes: 0`; these look like placeholders. Confirm
/// whether another component fills them in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartitionInfo {
/// Shard this entry describes.
pub shard_id: u32,
/// Nodes assigned to the shard (registered, rebalanced, or a `node-{shard_id}` placeholder).
pub node_ids: Vec<String>,
/// First entry of `node_ids`, or an empty string when there are none.
pub primary_node: String,
/// Health flag; always `true` as produced in this file.
pub is_healthy: bool,
/// Vector count for the shard; always 0 as produced in this file.
pub vector_count: u64,
/// Memory figure for the shard (presumably bytes, per the name); always 0 as produced here.
pub memory_bytes: u64,
}
/// Routing result for a single vector id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardAssignment {
/// Shard that owns the id.
pub shard_id: u32,
/// Nodes serving that shard; may be empty (e.g. after rebalancing to zero nodes).
pub nodes: Vec<String>,
/// First entry of `nodes`, or an empty string when `nodes` is empty.
pub preferred_node: String,
}
/// Consistent-hash router: each shard owns `virtual_nodes` points on a 64-bit hash ring.
///
/// A vector id belongs to the shard owning the first ring point at or after the id's
/// hash, wrapping around to the lowest point. An empty ring (zero shards or zero
/// virtual nodes) routes everything to shard 0. Per-shard node lists come from
/// `register_shard_nodes` / `rebalance`; shards without one use the single placeholder
/// node `node-{shard_id}`.
///
/// NOTE(review): hashing uses `std::collections::hash_map::DefaultHasher`, whose
/// algorithm the standard library documents as unspecified across Rust releases. If
/// placements must agree between binaries built with different toolchains, or persist
/// across upgrades, a fixed hash function is required. Confirm the deployment assumptions.
#[derive(Debug, Clone)]
pub struct ConsistentHashRing {
    /// Ring point (hash) -> owning shard id.
    ring: BTreeMap<u64, u32>,
    config: ShardingConfig,
    /// Explicit node lists per shard; absent shards use the placeholder node.
    shard_nodes: HashMap<u32, Vec<String>>,
}

impl ConsistentHashRing {
    /// Builds the ring with `config.virtual_nodes` points for each of `config.num_shards`
    /// shards. Each point is the hash of `"shard-{id}-vnode-{n}"`; if two points collide,
    /// the one inserted later wins.
    pub fn new(config: ShardingConfig) -> Self {
        let vnodes_per_shard = config.virtual_nodes;
        let mut ring = BTreeMap::new();
        ring.extend((0..config.num_shards).flat_map(|shard_id| {
            (0..vnodes_per_shard).map(move |vnode| {
                let label = format!("shard-{}-vnode-{}", shard_id, vnode);
                (Self::hash_key(&label), shard_id)
            })
        }));
        Self {
            ring,
            config,
            shard_nodes: HashMap::new(),
        }
    }

    /// Hashes `key` with a freshly constructed `DefaultHasher`. Every `DefaultHasher::new()`
    /// starts from the same state, so results repeat within a build.
    fn hash_key(key: &str) -> u64 {
        let mut state = DefaultHasher::new();
        key.hash(&mut state);
        state.finish()
    }

    /// Node list for `shard_id`: the registered list if present, else `["node-{shard_id}"]`.
    fn nodes_for(&self, shard_id: u32) -> Vec<String> {
        match self.shard_nodes.get(&shard_id) {
            Some(registered) => registered.clone(),
            None => vec![format!("node-{}", shard_id)],
        }
    }

    /// Routes `vector_id` to its owning shard and that shard's nodes.
    pub fn get_shard(&self, vector_id: &str) -> ShardAssignment {
        let point = Self::hash_key(vector_id);
        // Successor of `point`; chaining the whole ring supplies the wrap-around point.
        let shard_id = self
            .ring
            .range(point..)
            .chain(self.ring.iter())
            .next()
            .map_or(0, |(_, &shard)| shard);
        let nodes = self.nodes_for(shard_id);
        let preferred_node = nodes.first().map_or_else(String::new, String::clone);
        ShardAssignment {
            shard_id,
            nodes,
            preferred_node,
        }
    }

    /// Groups `vector_ids` by owning shard; each group keeps the input order.
    pub fn get_shards_batch(&self, vector_ids: &[String]) -> HashMap<u32, Vec<String>> {
        vector_ids.iter().fold(HashMap::new(), |mut groups, id| {
            let owner = self.get_shard(id).shard_id;
            groups.entry(owner).or_insert_with(Vec::new).push(id.clone());
            groups
        })
    }

    /// Replaces the node list for `shard_id`.
    pub fn register_shard_nodes(&mut self, shard_id: u32, node_ids: Vec<String>) {
        self.shard_nodes.insert(shard_id, node_ids);
    }

    /// All shard ids, `0..num_shards`.
    pub fn get_all_shards(&self) -> Vec<u32> {
        let shard_count = self.config.num_shards;
        (0..shard_count).collect()
    }

    /// One `PartitionInfo` per shard, in shard-id order.
    pub fn get_partition_info(&self) -> Vec<PartitionInfo> {
        let mut infos = Vec::with_capacity(self.config.num_shards as usize);
        for shard_id in 0..self.config.num_shards {
            let node_ids = self.nodes_for(shard_id);
            let primary_node = node_ids.first().cloned().unwrap_or_default();
            infos.push(PartitionInfo {
                shard_id,
                node_ids,
                primary_node,
                is_healthy: true,
                vector_count: 0,
                memory_bytes: 0,
            });
        }
        infos
    }

    /// Reassigns every shard to `min(replication_factor, new_node_count)` nodes named
    /// `node-{(shard_id + replica) % new_node_count}`, replacing any registered lists.
    /// With `new_node_count == 0` every shard ends up with an empty node list.
    pub fn rebalance(&mut self, new_node_count: u32) {
        let replicas = self.config.replication_factor.min(new_node_count);
        for shard_id in 0..self.config.num_shards {
            let nodes = (0..replicas)
                .map(|replica| format!("node-{}", (shard_id + replica) % new_node_count))
                .collect();
            self.shard_nodes.insert(shard_id, nodes);
        }
    }
}
/// Range-partitioned router: the 64-bit hash space is split into `num_shards`
/// contiguous ranges of (nearly) equal width.
///
/// With `step = u64::MAX / num_shards`, shard `k` owns hashes in `[k*step, (k+1)*step)`.
/// The last shard also absorbs the remainder up to `u64::MAX`. Shards without
/// registered nodes report the single placeholder node `node-{shard_id}`.
///
/// NOTE(review): ids are hashed with std's `DefaultHasher`, whose algorithm is not
/// guaranteed stable across Rust releases. Confirm that is acceptable if placements must
/// persist or agree across differently-built binaries.
#[derive(Debug, Clone)]
pub struct RangeSharder {
    /// Ascending exclusive upper bounds of shards `0..num_shards - 1`
    /// (`num_shards - 1` entries; empty for zero or one shard).
    boundaries: Vec<u64>,
    config: ShardingConfig,
    /// Explicit node lists per shard; absent shards use the placeholder node.
    shard_nodes: HashMap<u32, Vec<String>>,
}

impl RangeSharder {
    /// Builds equal-width boundaries for `config.num_shards` shards.
    ///
    /// `num_shards == 0` is treated like a single shard (no boundaries). Every id then
    /// routes to shard 0, matching how an empty `ConsistentHashRing` behaves. The
    /// original code panicked here with a division by zero.
    pub fn new(config: ShardingConfig) -> Self {
        // `max(1)` guards the division; for 0 or 1 shards the range below is empty anyway.
        let step = u64::MAX / u64::from(config.num_shards.max(1));
        // `step * i` cannot overflow: i <= num_shards - 1, so the product is < u64::MAX.
        let boundaries: Vec<u64> = (1..config.num_shards)
            .map(|i| step * u64::from(i))
            .collect();
        Self {
            boundaries,
            config,
            shard_nodes: HashMap::new(),
        }
    }

    /// Hashes `key` with a freshly constructed `DefaultHasher`.
    fn hash_key(key: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    /// Node list for `shard_id`: the registered list if present, else `["node-{shard_id}"]`.
    fn nodes_for(&self, shard_id: u32) -> Vec<String> {
        self.shard_nodes
            .get(&shard_id)
            .cloned()
            .unwrap_or_else(|| vec![format!("node-{}", shard_id)])
    }

    /// Routes `vector_id` to the shard whose range contains its hash.
    pub fn get_shard(&self, vector_id: &str) -> ShardAssignment {
        let hash = Self::hash_key(vector_id);
        // Count of boundaries <= hash is the index of the first boundary > hash. Past the
        // last boundary it is `boundaries.len()`, i.e. the last shard (or 0 with none).
        // The value is at most num_shards - 1, so the cast to u32 is lossless.
        let shard_id = self.boundaries.partition_point(|&b| b <= hash) as u32;
        let nodes = self.nodes_for(shard_id);
        let preferred_node = nodes.first().cloned().unwrap_or_default();
        ShardAssignment {
            shard_id,
            nodes,
            preferred_node,
        }
    }

    /// Replaces the node list for `shard_id`.
    pub fn register_shard_nodes(&mut self, shard_id: u32, node_ids: Vec<String>) {
        self.shard_nodes.insert(shard_id, node_ids);
    }

    /// All shard ids, `0..num_shards` (mirrors `ConsistentHashRing::get_all_shards`).
    pub fn get_all_shards(&self) -> Vec<u32> {
        (0..self.config.num_shards).collect()
    }
}
/// Strategy-dispatching facade over `ConsistentHashRing` and `RangeSharder`.
///
/// `ShardManager::new` builds exactly one backend, chosen from `config.strategy`.
///
/// NOTE(review): `ShardingStrategy::Modulo` is served by the consistent-hash ring, the
/// same as `ConsistentHash`; no modulo routing exists in this file. Confirm that alias
/// is intended.
#[derive(Debug, Clone)]
pub struct ShardManager {
    config: ShardingConfig,
    consistent_ring: Option<ConsistentHashRing>,
    range_sharder: Option<RangeSharder>,
}

impl ShardManager {
    /// Builds the backend matching `config.strategy`; the other slot stays `None`.
    pub fn new(config: ShardingConfig) -> Self {
        let (consistent_ring, range_sharder) = match config.strategy {
            ShardingStrategy::ConsistentHash | ShardingStrategy::Modulo => {
                (Some(ConsistentHashRing::new(config.clone())), None)
            }
            ShardingStrategy::Range => (None, Some(RangeSharder::new(config.clone()))),
        };
        Self {
            config,
            consistent_ring,
            range_sharder,
        }
    }

    /// Assignment returned when the backend for the configured strategy is missing.
    fn unassigned() -> ShardAssignment {
        ShardAssignment {
            shard_id: 0,
            nodes: vec![],
            preferred_node: String::new(),
        }
    }

    /// Routes `vector_id` through the configured backend.
    ///
    /// If that backend is absent (not reachable via `new`), logs an error and returns
    /// shard 0 with no nodes.
    pub fn get_shard(&self, vector_id: &str) -> ShardAssignment {
        match self.config.strategy {
            ShardingStrategy::ConsistentHash | ShardingStrategy::Modulo => {
                match self.consistent_ring.as_ref() {
                    Some(ring) => ring.get_shard(vector_id),
                    None => {
                        tracing::error!("consistent_ring not initialized for ConsistentHash/Modulo strategy — falling back to shard 0");
                        Self::unassigned()
                    }
                }
            }
            ShardingStrategy::Range => match self.range_sharder.as_ref() {
                Some(sharder) => sharder.get_shard(vector_id),
                None => {
                    tracing::error!("range_sharder not initialized for Range strategy — falling back to shard 0");
                    Self::unassigned()
                }
            },
        }
    }

    /// Groups `vector_ids` by owning shard; each group keeps the input order.
    pub fn get_shards_batch(&self, vector_ids: &[String]) -> HashMap<u32, Vec<String>> {
        let mut shard_vectors: HashMap<u32, Vec<String>> = HashMap::new();
        for id in vector_ids {
            let assignment = self.get_shard(id);
            shard_vectors
                .entry(assignment.shard_id)
                .or_default()
                .push(id.clone());
        }
        shard_vectors
    }

    /// All shard ids, `0..num_shards`.
    pub fn get_all_shards(&self) -> Vec<u32> {
        (0..self.config.num_shards).collect()
    }

    /// Replaces the node list for `shard_id` on whichever backend exists.
    pub fn register_shard_nodes(&mut self, shard_id: u32, node_ids: Vec<String>) {
        if let Some(ref mut ring) = self.consistent_ring {
            ring.register_shard_nodes(shard_id, node_ids);
        } else if let Some(ref mut sharder) = self.range_sharder {
            sharder.register_shard_nodes(shard_id, node_ids);
        }
    }

    /// One `PartitionInfo` per shard, in shard-id order.
    ///
    /// On the Range backend this now reports nodes registered through
    /// `register_shard_nodes` / `rebalance`, so it agrees with `get_shard`. Previously it
    /// always reported the `node-{id}` placeholder.
    pub fn get_partition_info(&self) -> Vec<PartitionInfo> {
        if let Some(ref ring) = self.consistent_ring {
            return ring.get_partition_info();
        }
        let registered = self.range_sharder.as_ref().map(|sharder| &sharder.shard_nodes);
        (0..self.config.num_shards)
            .map(|shard_id| {
                let node_ids = registered
                    .and_then(|nodes| nodes.get(&shard_id))
                    .cloned()
                    .unwrap_or_else(|| vec![format!("node-{}", shard_id)]);
                let primary_node = node_ids.first().cloned().unwrap_or_default();
                PartitionInfo {
                    shard_id,
                    node_ids,
                    primary_node,
                    is_healthy: true,
                    vector_count: 0,
                    memory_bytes: 0,
                }
            })
            .collect()
    }

    /// Reassigns every shard to `min(replication_factor, node_count)` nodes named
    /// `node-{(shard_id + replica) % node_count}`, replacing registered lists.
    ///
    /// The Range backend now gets the same layout as `ConsistentHashRing::rebalance`;
    /// previously this call was silently a no-op under `ShardingStrategy::Range`.
    pub fn rebalance(&mut self, node_count: u32) {
        if let Some(ref mut ring) = self.consistent_ring {
            ring.rebalance(node_count);
        } else if let Some(ref mut sharder) = self.range_sharder {
            // With node_count == 0 this range is empty, so the modulo below never runs.
            let replicas = self.config.replication_factor.min(node_count);
            for shard_id in 0..self.config.num_shards {
                let nodes = (0..replicas)
                    .map(|replica| format!("node-{}", (shard_id + replica) % node_count))
                    .collect();
                sharder.shard_nodes.insert(shard_id, nodes);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spells out every config field so each test states the shape it exercises.
    fn config(
        num_shards: u32,
        replication_factor: u32,
        strategy: ShardingStrategy,
        virtual_nodes: u32,
    ) -> ShardingConfig {
        ShardingConfig {
            num_shards,
            replication_factor,
            strategy,
            virtual_nodes,
        }
    }

    #[test]
    fn test_consistent_hash_ring() {
        let ring = ConsistentHashRing::new(config(4, 2, ShardingStrategy::ConsistentHash, 100));
        let first = ring.get_shard("vector-123").shard_id;
        let second = ring.get_shard("vector-123").shard_id;
        assert_eq!(first, second);
        (0..100).for_each(|i| {
            let shard = ring.get_shard(&format!("test-{}", i)).shard_id;
            assert!(shard < 4);
        });
    }

    #[test]
    fn test_consistent_hash_distribution() {
        let ring = ConsistentHashRing::new(config(4, 2, ShardingStrategy::ConsistentHash, 150));
        let mut per_shard = [0u32; 4];
        for i in 0..1000 {
            let shard = ring.get_shard(&format!("vector-{}", i)).shard_id;
            per_shard[shard as usize] += 1;
        }
        // Each shard must land within +/-50% of an even 250-per-shard split.
        let even_share = 250.0;
        for &count in per_shard.iter() {
            let count = f64::from(count);
            assert!(count > even_share * 0.5);
            assert!(count < even_share * 1.5);
        }
    }

    #[test]
    fn test_batch_sharding() {
        let ring = ConsistentHashRing::new(ShardingConfig::default());
        let ids: Vec<String> = (0..100).map(|i| format!("vec-{}", i)).collect();
        let grouped = ring.get_shards_batch(&ids);
        let total = grouped.values().fold(0usize, |acc, members| acc + members.len());
        assert_eq!(total, 100);
    }

    #[test]
    fn test_range_sharder() {
        let sharder = RangeSharder::new(config(4, 1, ShardingStrategy::Range, 0));
        let first = sharder.get_shard("test-key").shard_id;
        let second = sharder.get_shard("test-key").shard_id;
        assert_eq!(first, second);
        for i in 0..100 {
            assert!(sharder.get_shard(&format!("key-{}", i)).shard_id < 4);
        }
    }

    #[test]
    fn test_shard_manager() {
        let mut manager = ShardManager::new(ShardingConfig::default());
        manager.register_shard_nodes(0, vec!["node-a".to_string(), "node-b".to_string()]);
        assert!(manager.get_shard("my-vector").shard_id < 4);
        assert_eq!(manager.get_all_shards().len(), 4);
        assert_eq!(manager.get_partition_info().len(), 4);
    }

    #[test]
    fn test_rebalance() {
        let mut ring = ConsistentHashRing::new(ShardingConfig {
            num_shards: 4,
            replication_factor: 2,
            ..Default::default()
        });
        ring.rebalance(3);
        for partition in ring.get_partition_info() {
            // Non-empty and capped at the replication factor of 2.
            assert!((1..=2).contains(&partition.node_ids.len()));
        }
    }
}