1use serde::{Deserialize, Serialize};
9use std::collections::hash_map::DefaultHasher;
10use std::collections::{BTreeMap, HashMap};
11use std::hash::{Hash, Hasher};
12
/// Tunables controlling how vectors are partitioned across shards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardingConfig {
    /// Number of logical shards; shard ids are `0..num_shards`.
    pub num_shards: u32,
    /// Desired replicas per shard. `ConsistentHashRing::rebalance` assigns
    /// `min(replication_factor, node_count)` nodes to each shard.
    pub replication_factor: u32,
    /// Routing algorithm that `ShardManager::new` builds a router for.
    pub strategy: ShardingStrategy,
    /// Ring positions inserted per shard by `ConsistentHashRing::new`; not read by `RangeSharder`.
    pub virtual_nodes: u32,
}
25
26impl Default for ShardingConfig {
27 fn default() -> Self {
28 Self {
29 num_shards: 4,
30 replication_factor: 2,
31 strategy: ShardingStrategy::ConsistentHash,
32 virtual_nodes: 150,
33 }
34 }
35}
36
/// Algorithm used to map a vector id to a shard.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ShardingStrategy {
    /// Hash the id onto a ring of virtual nodes (`ConsistentHashRing`).
    ConsistentHash,
    /// Hash the id, then pick the shard whose contiguous slice of the `u64` hash space contains
    /// it (`RangeSharder`). The ranges are over hash values, not over the ids' natural order.
    Range,
    /// NOTE(review): `ShardManager::new` routes this through the same `ConsistentHashRing` as
    /// `ConsistentHash`; no modulo arithmetic is applied in this module — confirm intent.
    Modulo,
}
47
/// Snapshot of one shard's placement, as produced by the `get_partition_info` methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartitionInfo {
    /// Shard this entry describes.
    pub shard_id: u32,
    /// Nodes serving the shard; defaults to `["node-{shard_id}"]` when none are registered.
    pub node_ids: Vec<String>,
    /// First entry of `node_ids`, or an empty string if `node_ids` is empty.
    pub primary_node: String,
    /// Health flag. NOTE(review): this module always reports `true`.
    pub is_healthy: bool,
    /// Stored vector count. NOTE(review): this module always reports 0 — presumably
    /// populated elsewhere.
    pub vector_count: u64,
    /// Memory footprint (bytes, per the field name). NOTE(review): this module always
    /// reports 0.
    pub memory_bytes: u64,
}
64
/// Routing decision for a single vector id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardAssignment {
    /// Shard that owns the id.
    pub shard_id: u32,
    /// Nodes serving that shard.
    pub nodes: Vec<String>,
    /// Node to contact first: the first entry of `nodes`, or an empty string if there are none.
    pub preferred_node: String,
}
75
/// Consistent-hash router: every shard is hashed onto a `u64` ring at several virtual
/// positions, and an id belongs to the first position at or after its own hash.
#[derive(Debug, Clone)]
pub struct ConsistentHashRing {
    /// Ring position (hash of `"shard-{id}-vnode-{n}"`) -> owning shard id.
    ring: BTreeMap<u64, u32>,
    /// Configuration the ring was built from.
    config: ShardingConfig,
    /// Explicit node lists per shard, set by `register_shard_nodes` or `rebalance`.
    shard_nodes: HashMap<u32, Vec<String>>,
}
86
87impl ConsistentHashRing {
88 pub fn new(config: ShardingConfig) -> Self {
90 let mut ring = BTreeMap::new();
91
92 for shard_id in 0..config.num_shards {
94 for vnode in 0..config.virtual_nodes {
95 let key = format!("shard-{}-vnode-{}", shard_id, vnode);
96 let hash = Self::hash_key(&key);
97 ring.insert(hash, shard_id);
98 }
99 }
100
101 Self {
102 ring,
103 config,
104 shard_nodes: HashMap::new(),
105 }
106 }
107
108 fn hash_key(key: &str) -> u64 {
110 let mut hasher = DefaultHasher::new();
111 key.hash(&mut hasher);
112 hasher.finish()
113 }
114
115 pub fn get_shard(&self, vector_id: &str) -> ShardAssignment {
117 let hash = Self::hash_key(vector_id);
118
119 let shard_id = self
121 .ring
122 .range(hash..)
123 .next()
124 .or_else(|| self.ring.iter().next())
125 .map(|(_, &shard)| shard)
126 .unwrap_or(0);
127
128 let nodes = self
129 .shard_nodes
130 .get(&shard_id)
131 .cloned()
132 .unwrap_or_else(|| vec![format!("node-{}", shard_id)]);
133
134 let preferred_node = nodes.first().cloned().unwrap_or_default();
135
136 ShardAssignment {
137 shard_id,
138 nodes,
139 preferred_node,
140 }
141 }
142
143 pub fn get_shards_batch(&self, vector_ids: &[String]) -> HashMap<u32, Vec<String>> {
145 let mut shard_vectors: HashMap<u32, Vec<String>> = HashMap::new();
146
147 for id in vector_ids {
148 let assignment = self.get_shard(id);
149 shard_vectors
150 .entry(assignment.shard_id)
151 .or_default()
152 .push(id.clone());
153 }
154
155 shard_vectors
156 }
157
158 pub fn register_shard_nodes(&mut self, shard_id: u32, node_ids: Vec<String>) {
160 self.shard_nodes.insert(shard_id, node_ids);
161 }
162
163 pub fn get_all_shards(&self) -> Vec<u32> {
165 (0..self.config.num_shards).collect()
166 }
167
168 pub fn get_partition_info(&self) -> Vec<PartitionInfo> {
170 (0..self.config.num_shards)
171 .map(|shard_id| {
172 let nodes = self
173 .shard_nodes
174 .get(&shard_id)
175 .cloned()
176 .unwrap_or_else(|| vec![format!("node-{}", shard_id)]);
177 let primary = nodes.first().cloned().unwrap_or_default();
178
179 PartitionInfo {
180 shard_id,
181 node_ids: nodes,
182 primary_node: primary,
183 is_healthy: true,
184 vector_count: 0,
185 memory_bytes: 0,
186 }
187 })
188 .collect()
189 }
190
191 pub fn rebalance(&mut self, new_node_count: u32) {
193 for shard_id in 0..self.config.num_shards {
195 let mut nodes = Vec::new();
196 for replica in 0..self.config.replication_factor.min(new_node_count) {
197 let node_idx = (shard_id + replica) % new_node_count;
198 nodes.push(format!("node-{}", node_idx));
199 }
200 self.shard_nodes.insert(shard_id, nodes);
201 }
202 }
203}
204
/// Range router: splits the `u64` hash space into contiguous, roughly equal slices, one per
/// shard, and routes an id by the slice containing its hash.
#[derive(Debug, Clone)]
pub struct RangeSharder {
    /// Ascending exclusive upper bounds. Shard `i` owns hashes below `boundaries[i]` and at or
    /// above `boundaries[i - 1]`. The last shard owns everything up to `u64::MAX`.
    boundaries: Vec<u64>,
    /// Configuration the sharder was built from.
    config: ShardingConfig,
    /// Explicit node lists per shard, set by `register_shard_nodes`.
    shard_nodes: HashMap<u32, Vec<String>>,
}
215
216impl RangeSharder {
217 pub fn new(config: ShardingConfig) -> Self {
219 let step = u64::MAX / config.num_shards as u64;
220 let boundaries: Vec<u64> = (1..config.num_shards).map(|i| step * i as u64).collect();
221
222 Self {
223 boundaries,
224 config,
225 shard_nodes: HashMap::new(),
226 }
227 }
228
229 pub fn get_shard(&self, vector_id: &str) -> ShardAssignment {
231 let hash = {
232 let mut hasher = DefaultHasher::new();
233 vector_id.hash(&mut hasher);
234 hasher.finish()
235 };
236
237 let shard_id = self
239 .boundaries
240 .iter()
241 .position(|&b| hash < b)
242 .unwrap_or(self.config.num_shards as usize - 1) as u32;
243
244 let nodes = self
245 .shard_nodes
246 .get(&shard_id)
247 .cloned()
248 .unwrap_or_else(|| vec![format!("node-{}", shard_id)]);
249
250 let preferred_node = nodes.first().cloned().unwrap_or_default();
251
252 ShardAssignment {
253 shard_id,
254 nodes,
255 preferred_node,
256 }
257 }
258
259 pub fn register_shard_nodes(&mut self, shard_id: u32, node_ids: Vec<String>) {
261 self.shard_nodes.insert(shard_id, node_ids);
262 }
263}
264
/// Strategy-agnostic facade over the concrete routers.
///
/// `ShardManager::new` populates exactly one of `consistent_ring` / `range_sharder`, chosen by
/// `config.strategy`.
pub struct ShardManager {
    /// Configuration the manager was built from.
    config: ShardingConfig,
    /// Router for the `ConsistentHash` and `Modulo` strategies.
    consistent_ring: Option<ConsistentHashRing>,
    /// Router for the `Range` strategy.
    range_sharder: Option<RangeSharder>,
}
271
272impl ShardManager {
273 pub fn new(config: ShardingConfig) -> Self {
275 let (consistent_ring, range_sharder) = match config.strategy {
276 ShardingStrategy::ConsistentHash | ShardingStrategy::Modulo => {
277 (Some(ConsistentHashRing::new(config.clone())), None)
278 }
279 ShardingStrategy::Range => (None, Some(RangeSharder::new(config.clone()))),
280 };
281
282 Self {
283 config,
284 consistent_ring,
285 range_sharder,
286 }
287 }
288
289 pub fn get_shard(&self, vector_id: &str) -> ShardAssignment {
291 match self.config.strategy {
292 ShardingStrategy::ConsistentHash | ShardingStrategy::Modulo => {
293 match self.consistent_ring.as_ref() {
294 Some(ring) => ring.get_shard(vector_id),
295 None => {
296 tracing::error!("consistent_ring not initialized for ConsistentHash/Modulo strategy — falling back to shard 0");
297 ShardAssignment {
298 shard_id: 0,
299 nodes: vec![],
300 preferred_node: String::new(),
301 }
302 }
303 }
304 }
305 ShardingStrategy::Range => match self.range_sharder.as_ref() {
306 Some(sharder) => sharder.get_shard(vector_id),
307 None => {
308 tracing::error!("range_sharder not initialized for Range strategy — falling back to shard 0");
309 ShardAssignment {
310 shard_id: 0,
311 nodes: vec![],
312 preferred_node: String::new(),
313 }
314 }
315 },
316 }
317 }
318
319 pub fn get_shards_batch(&self, vector_ids: &[String]) -> HashMap<u32, Vec<String>> {
321 let mut shard_vectors: HashMap<u32, Vec<String>> = HashMap::new();
322
323 for id in vector_ids {
324 let assignment = self.get_shard(id);
325 shard_vectors
326 .entry(assignment.shard_id)
327 .or_default()
328 .push(id.clone());
329 }
330
331 shard_vectors
332 }
333
334 pub fn get_all_shards(&self) -> Vec<u32> {
336 (0..self.config.num_shards).collect()
337 }
338
339 pub fn register_shard_nodes(&mut self, shard_id: u32, node_ids: Vec<String>) {
341 if let Some(ref mut ring) = self.consistent_ring {
342 ring.register_shard_nodes(shard_id, node_ids);
343 } else if let Some(ref mut sharder) = self.range_sharder {
344 sharder.register_shard_nodes(shard_id, node_ids);
345 }
346 }
347
348 pub fn get_partition_info(&self) -> Vec<PartitionInfo> {
350 if let Some(ref ring) = self.consistent_ring {
351 ring.get_partition_info()
352 } else {
353 (0..self.config.num_shards)
354 .map(|shard_id| PartitionInfo {
355 shard_id,
356 node_ids: vec![format!("node-{}", shard_id)],
357 primary_node: format!("node-{}", shard_id),
358 is_healthy: true,
359 vector_count: 0,
360 memory_bytes: 0,
361 })
362 .collect()
363 }
364 }
365
366 pub fn rebalance(&mut self, node_count: u32) {
368 if let Some(ref mut ring) = self.consistent_ring {
369 ring.rebalance(node_count);
370 }
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with every field spelled out.
    fn config_with(
        num_shards: u32,
        replication_factor: u32,
        strategy: ShardingStrategy,
        virtual_nodes: u32,
    ) -> ShardingConfig {
        ShardingConfig {
            num_shards,
            replication_factor,
            strategy,
            virtual_nodes,
        }
    }

    #[test]
    fn test_consistent_hash_ring() {
        let ring = ConsistentHashRing::new(config_with(4, 2, ShardingStrategy::ConsistentHash, 100));

        // Routing must be deterministic for a given key.
        let first = ring.get_shard("vector-123").shard_id;
        let second = ring.get_shard("vector-123").shard_id;
        assert_eq!(first, second);

        // Every key lands on one of the configured shards.
        assert!((0..100).all(|i| ring.get_shard(&format!("test-{}", i)).shard_id < 4));
    }

    #[test]
    fn test_consistent_hash_distribution() {
        let ring = ConsistentHashRing::new(config_with(4, 2, ShardingStrategy::ConsistentHash, 150));

        let counts = (0..1000).fold([0u32; 4], |mut tally, i| {
            tally[ring.get_shard(&format!("vector-{}", i)).shard_id as usize] += 1;
            tally
        });

        // 1000 keys over 4 shards: expect ~250 each, tolerate +/-50%.
        let expected = 250.0;
        for observed in counts.iter().map(|&c| f64::from(c)) {
            assert!(observed > expected * 0.5);
            assert!(observed < expected * 1.5);
        }
    }

    #[test]
    fn test_batch_sharding() {
        let ring = ConsistentHashRing::new(ShardingConfig::default());

        let ids: Vec<String> = (0..100).map(|i| format!("vec-{}", i)).collect();
        let grouped = ring.get_shards_batch(&ids);

        // Batching must neither drop nor duplicate ids.
        assert_eq!(grouped.values().map(Vec::len).sum::<usize>(), 100);
    }

    #[test]
    fn test_range_sharder() {
        let sharder = RangeSharder::new(config_with(4, 1, ShardingStrategy::Range, 0));

        assert_eq!(
            sharder.get_shard("test-key").shard_id,
            sharder.get_shard("test-key").shard_id
        );

        for key in (0..100).map(|i| format!("key-{}", i)) {
            assert!(sharder.get_shard(&key).shard_id < 4);
        }
    }

    #[test]
    fn test_shard_manager() {
        let mut manager = ShardManager::new(ShardingConfig::default());
        manager.register_shard_nodes(0, vec!["node-a".to_string(), "node-b".to_string()]);

        assert!(manager.get_shard("my-vector").shard_id < 4);
        assert_eq!(manager.get_all_shards().len(), 4);
        assert_eq!(manager.get_partition_info().len(), 4);
    }

    #[test]
    fn test_rebalance() {
        let mut ring = ConsistentHashRing::new(ShardingConfig {
            num_shards: 4,
            replication_factor: 2,
            ..Default::default()
        });
        ring.rebalance(3);

        // Each shard gets between one and `replication_factor` nodes.
        for partition in ring.get_partition_info() {
            assert!(!partition.node_ids.is_empty());
            assert!(partition.node_ids.len() <= 2);
        }
    }
}