// engine/distributed/sharding.rs
//! Sharding Strategies for Distributed Vector Storage
//!
//! Supports multiple partitioning strategies:
//! - Consistent hashing for even distribution
//! - Range-based for locality-aware placement
//! - Custom strategies for specialized workloads

use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};
13/// Configuration for sharding behavior
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ShardingConfig {
16    /// Number of shards (partitions)
17    pub num_shards: u32,
18    /// Replication factor for each shard
19    pub replication_factor: u32,
20    /// Sharding strategy to use
21    pub strategy: ShardingStrategy,
22    /// Number of virtual nodes for consistent hashing
23    pub virtual_nodes: u32,
24}
25
26impl Default for ShardingConfig {
27    fn default() -> Self {
28        Self {
29            num_shards: 4,
30            replication_factor: 2,
31            strategy: ShardingStrategy::ConsistentHash,
32            virtual_nodes: 150,
33        }
34    }
35}
36
37/// Strategy for distributing data across shards
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
39pub enum ShardingStrategy {
40    /// Consistent hashing with virtual nodes
41    ConsistentHash,
42    /// Range-based partitioning
43    Range,
44    /// Simple modulo-based hashing
45    Modulo,
46}
47
48/// Information about a single partition/shard
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct PartitionInfo {
51    /// Unique shard identifier
52    pub shard_id: u32,
53    /// Node IDs hosting this shard (primary + replicas)
54    pub node_ids: Vec<String>,
55    /// Primary node for writes
56    pub primary_node: String,
57    /// Whether this shard is healthy
58    pub is_healthy: bool,
59    /// Number of vectors in this shard
60    pub vector_count: u64,
61    /// Approximate memory usage in bytes
62    pub memory_bytes: u64,
63}
64
65/// Assignment of a key to a shard
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct ShardAssignment {
68    /// The shard this key belongs to
69    pub shard_id: u32,
70    /// Node IDs that can serve this shard
71    pub nodes: Vec<String>,
72    /// Preferred node for this request
73    pub preferred_node: String,
74}
75
76/// Consistent hash ring for shard assignment
77#[derive(Debug, Clone)]
78pub struct ConsistentHashRing {
79    /// Ring of virtual node positions to shard IDs
80    ring: BTreeMap<u64, u32>,
81    /// Configuration
82    config: ShardingConfig,
83    /// Shard to nodes mapping
84    shard_nodes: HashMap<u32, Vec<String>>,
85}
86
87impl ConsistentHashRing {
88    /// Create a new consistent hash ring
89    pub fn new(config: ShardingConfig) -> Self {
90        let mut ring = BTreeMap::new();
91
92        // Add virtual nodes for each shard
93        for shard_id in 0..config.num_shards {
94            for vnode in 0..config.virtual_nodes {
95                let key = format!("shard-{}-vnode-{}", shard_id, vnode);
96                let hash = Self::hash_key(&key);
97                ring.insert(hash, shard_id);
98            }
99        }
100
101        Self {
102            ring,
103            config,
104            shard_nodes: HashMap::new(),
105        }
106    }
107
108    /// Hash a key to a u64 position on the ring
109    fn hash_key(key: &str) -> u64 {
110        let mut hasher = DefaultHasher::new();
111        key.hash(&mut hasher);
112        hasher.finish()
113    }
114
115    /// Get the shard assignment for a vector ID
116    pub fn get_shard(&self, vector_id: &str) -> ShardAssignment {
117        let hash = Self::hash_key(vector_id);
118
119        // Find the first node position >= hash
120        let shard_id = self
121            .ring
122            .range(hash..)
123            .next()
124            .or_else(|| self.ring.iter().next())
125            .map(|(_, &shard)| shard)
126            .unwrap_or(0);
127
128        let nodes = self
129            .shard_nodes
130            .get(&shard_id)
131            .cloned()
132            .unwrap_or_else(|| vec![format!("node-{}", shard_id)]);
133
134        let preferred_node = nodes.first().cloned().unwrap_or_default();
135
136        ShardAssignment {
137            shard_id,
138            nodes,
139            preferred_node,
140        }
141    }
142
143    /// Get shards for a batch of vector IDs
144    pub fn get_shards_batch(&self, vector_ids: &[String]) -> HashMap<u32, Vec<String>> {
145        let mut shard_vectors: HashMap<u32, Vec<String>> = HashMap::new();
146
147        for id in vector_ids {
148            let assignment = self.get_shard(id);
149            shard_vectors
150                .entry(assignment.shard_id)
151                .or_default()
152                .push(id.clone());
153        }
154
155        shard_vectors
156    }
157
158    /// Register nodes for a shard
159    pub fn register_shard_nodes(&mut self, shard_id: u32, node_ids: Vec<String>) {
160        self.shard_nodes.insert(shard_id, node_ids);
161    }
162
163    /// Get all shard IDs
164    pub fn get_all_shards(&self) -> Vec<u32> {
165        (0..self.config.num_shards).collect()
166    }
167
168    /// Get partition info for all shards
169    pub fn get_partition_info(&self) -> Vec<PartitionInfo> {
170        (0..self.config.num_shards)
171            .map(|shard_id| {
172                let nodes = self
173                    .shard_nodes
174                    .get(&shard_id)
175                    .cloned()
176                    .unwrap_or_else(|| vec![format!("node-{}", shard_id)]);
177                let primary = nodes.first().cloned().unwrap_or_default();
178
179                PartitionInfo {
180                    shard_id,
181                    node_ids: nodes,
182                    primary_node: primary,
183                    is_healthy: true,
184                    vector_count: 0,
185                    memory_bytes: 0,
186                }
187            })
188            .collect()
189    }
190
191    /// Rebalance when nodes change
192    pub fn rebalance(&mut self, new_node_count: u32) {
193        // Redistribute shards across new node count
194        for shard_id in 0..self.config.num_shards {
195            let mut nodes = Vec::new();
196            for replica in 0..self.config.replication_factor.min(new_node_count) {
197                let node_idx = (shard_id + replica) % new_node_count;
198                nodes.push(format!("node-{}", node_idx));
199            }
200            self.shard_nodes.insert(shard_id, nodes);
201        }
202    }
203}
204
205/// Range-based sharding for ordered data
206#[derive(Debug, Clone)]
207pub struct RangeSharder {
208    /// Boundaries for each shard (exclusive upper bounds)
209    boundaries: Vec<u64>,
210    /// Configuration
211    config: ShardingConfig,
212    /// Shard to nodes mapping
213    shard_nodes: HashMap<u32, Vec<String>>,
214}
215
216impl RangeSharder {
217    /// Create a new range sharder with even distribution
218    pub fn new(config: ShardingConfig) -> Self {
219        let step = u64::MAX / config.num_shards as u64;
220        let boundaries: Vec<u64> = (1..config.num_shards).map(|i| step * i as u64).collect();
221
222        Self {
223            boundaries,
224            config,
225            shard_nodes: HashMap::new(),
226        }
227    }
228
229    /// Get shard for a vector ID using range partitioning
230    pub fn get_shard(&self, vector_id: &str) -> ShardAssignment {
231        let hash = {
232            let mut hasher = DefaultHasher::new();
233            vector_id.hash(&mut hasher);
234            hasher.finish()
235        };
236
237        // Find which range this hash falls into
238        let shard_id = self
239            .boundaries
240            .iter()
241            .position(|&b| hash < b)
242            .unwrap_or(self.config.num_shards as usize - 1) as u32;
243
244        let nodes = self
245            .shard_nodes
246            .get(&shard_id)
247            .cloned()
248            .unwrap_or_else(|| vec![format!("node-{}", shard_id)]);
249
250        let preferred_node = nodes.first().cloned().unwrap_or_default();
251
252        ShardAssignment {
253            shard_id,
254            nodes,
255            preferred_node,
256        }
257    }
258
259    /// Register nodes for a shard
260    pub fn register_shard_nodes(&mut self, shard_id: u32, node_ids: Vec<String>) {
261        self.shard_nodes.insert(shard_id, node_ids);
262    }
263}
264
265/// Unified shard manager supporting multiple strategies
266pub struct ShardManager {
267    config: ShardingConfig,
268    consistent_ring: Option<ConsistentHashRing>,
269    range_sharder: Option<RangeSharder>,
270}
271
272impl ShardManager {
273    /// Create a new shard manager
274    pub fn new(config: ShardingConfig) -> Self {
275        let (consistent_ring, range_sharder) = match config.strategy {
276            ShardingStrategy::ConsistentHash | ShardingStrategy::Modulo => {
277                (Some(ConsistentHashRing::new(config.clone())), None)
278            }
279            ShardingStrategy::Range => (None, Some(RangeSharder::new(config.clone()))),
280        };
281
282        Self {
283            config,
284            consistent_ring,
285            range_sharder,
286        }
287    }
288
289    /// Get shard assignment for a vector ID
290    pub fn get_shard(&self, vector_id: &str) -> ShardAssignment {
291        match self.config.strategy {
292            ShardingStrategy::ConsistentHash | ShardingStrategy::Modulo => {
293                match self.consistent_ring.as_ref() {
294                    Some(ring) => ring.get_shard(vector_id),
295                    None => {
296                        tracing::error!("consistent_ring not initialized for ConsistentHash/Modulo strategy — falling back to shard 0");
297                        ShardAssignment {
298                            shard_id: 0,
299                            nodes: vec![],
300                            preferred_node: String::new(),
301                        }
302                    }
303                }
304            }
305            ShardingStrategy::Range => match self.range_sharder.as_ref() {
306                Some(sharder) => sharder.get_shard(vector_id),
307                None => {
308                    tracing::error!("range_sharder not initialized for Range strategy — falling back to shard 0");
309                    ShardAssignment {
310                        shard_id: 0,
311                        nodes: vec![],
312                        preferred_node: String::new(),
313                    }
314                }
315            },
316        }
317    }
318
319    /// Get shards for a batch of vector IDs
320    pub fn get_shards_batch(&self, vector_ids: &[String]) -> HashMap<u32, Vec<String>> {
321        let mut shard_vectors: HashMap<u32, Vec<String>> = HashMap::new();
322
323        for id in vector_ids {
324            let assignment = self.get_shard(id);
325            shard_vectors
326                .entry(assignment.shard_id)
327                .or_default()
328                .push(id.clone());
329        }
330
331        shard_vectors
332    }
333
334    /// Get all shard IDs for scatter queries
335    pub fn get_all_shards(&self) -> Vec<u32> {
336        (0..self.config.num_shards).collect()
337    }
338
339    /// Register nodes for a shard
340    pub fn register_shard_nodes(&mut self, shard_id: u32, node_ids: Vec<String>) {
341        if let Some(ref mut ring) = self.consistent_ring {
342            ring.register_shard_nodes(shard_id, node_ids);
343        } else if let Some(ref mut sharder) = self.range_sharder {
344            sharder.register_shard_nodes(shard_id, node_ids);
345        }
346    }
347
348    /// Get partition information for all shards
349    pub fn get_partition_info(&self) -> Vec<PartitionInfo> {
350        if let Some(ref ring) = self.consistent_ring {
351            ring.get_partition_info()
352        } else {
353            (0..self.config.num_shards)
354                .map(|shard_id| PartitionInfo {
355                    shard_id,
356                    node_ids: vec![format!("node-{}", shard_id)],
357                    primary_node: format!("node-{}", shard_id),
358                    is_healthy: true,
359                    vector_count: 0,
360                    memory_bytes: 0,
361                })
362                .collect()
363        }
364    }
365
366    /// Rebalance shards across nodes
367    pub fn rebalance(&mut self, node_count: u32) {
368        if let Some(ref mut ring) = self.consistent_ring {
369            ring.rebalance(node_count);
370        }
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Same key maps to the same shard, and shard IDs stay in range.
    #[test]
    fn test_consistent_hash_ring() {
        let config = ShardingConfig {
            num_shards: 4,
            replication_factor: 2,
            strategy: ShardingStrategy::ConsistentHash,
            virtual_nodes: 100,
        };

        let ring = ConsistentHashRing::new(config);

        let assignment1 = ring.get_shard("vector-123");
        let assignment2 = ring.get_shard("vector-123");
        assert_eq!(assignment1.shard_id, assignment2.shard_id);

        for i in 0..100 {
            let assignment = ring.get_shard(&format!("test-{}", i));
            assert!(assignment.shard_id < 4);
        }
    }

    /// Keys should spread reasonably evenly (within 50% of the average).
    #[test]
    fn test_consistent_hash_distribution() {
        let config = ShardingConfig {
            num_shards: 4,
            replication_factor: 2,
            strategy: ShardingStrategy::ConsistentHash,
            virtual_nodes: 150,
        };

        let ring = ConsistentHashRing::new(config);

        let mut counts = [0u32; 4];
        for i in 0..1000 {
            let assignment = ring.get_shard(&format!("vector-{}", i));
            counts[assignment.shard_id as usize] += 1;
        }

        let avg = 250.0;
        for count in counts {
            assert!(count as f64 > avg * 0.5);
            assert!((count as f64) < avg * 1.5);
        }
    }

    /// Batch grouping must account for every input ID exactly once.
    #[test]
    fn test_batch_sharding() {
        let config = ShardingConfig::default();
        let ring = ConsistentHashRing::new(config);

        let ids: Vec<String> = (0..100).map(|i| format!("vec-{}", i)).collect();
        let shard_batches = ring.get_shards_batch(&ids);

        let total: usize = shard_batches.values().map(|v| v.len()).sum();
        assert_eq!(total, 100);
    }

    /// Range sharding is deterministic and stays within the shard range.
    #[test]
    fn test_range_sharder() {
        let config = ShardingConfig {
            num_shards: 4,
            replication_factor: 1,
            strategy: ShardingStrategy::Range,
            virtual_nodes: 0, // Not used for range
        };

        let sharder = RangeSharder::new(config);

        let a1 = sharder.get_shard("test-key");
        let a2 = sharder.get_shard("test-key");
        assert_eq!(a1.shard_id, a2.shard_id);

        for i in 0..100 {
            let assignment = sharder.get_shard(&format!("key-{}", i));
            assert!(assignment.shard_id < 4);
        }
    }

    /// Manager end-to-end: registration, assignment, enumeration.
    #[test]
    fn test_shard_manager() {
        let config = ShardingConfig::default();
        let mut manager = ShardManager::new(config);

        manager.register_shard_nodes(0, vec!["node-a".to_string(), "node-b".to_string()]);

        let assignment = manager.get_shard("my-vector");
        assert!(assignment.shard_id < 4);

        let shards = manager.get_all_shards();
        assert_eq!(shards.len(), 4);

        let partitions = manager.get_partition_info();
        assert_eq!(partitions.len(), 4);
    }

    /// After rebalancing, every shard has 1..=replication_factor nodes.
    #[test]
    fn test_rebalance() {
        let config = ShardingConfig {
            num_shards: 4,
            replication_factor: 2,
            ..Default::default()
        };

        let mut ring = ConsistentHashRing::new(config);
        ring.rebalance(3);

        let partitions = ring.get_partition_info();
        for partition in partitions {
            assert!(!partition.node_ids.is_empty());
            assert!(partition.node_ids.len() <= 2); // replication_factor
        }
    }
}