Skip to main content

amaters_cluster/
partitioner.rs

1//! Key range partitioning and query routing
2//!
3//! This module provides partitioning strategies for distributing keys across shards
4//! and routing queries to the correct shard(s).
5
6use crate::error::{RaftError, RaftResult};
7use crate::shard::{KeyRange, ShardId, ShardMetadata, ShardRegistry};
8use crate::types::NodeId;
9use amaters_core::Key;
10use std::collections::{BinaryHeap, HashMap, HashSet};
11use std::hash::{Hash, Hasher};
12use std::sync::Arc;
13
/// How keys are mapped onto shards.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PartitionStrategy {
    /// Contiguous key ranges: each shard owns a sorted slice of the key space.
    Range,
    /// Modulo hashing: `hash(key) % shard_count` selects the shard.
    Hash,
    /// Consistent hashing: every shard owns many virtual nodes on a hash ring.
    ConsistentHash,
}
24
25/// Hash function for key partitioning
26fn hash_key(key: &Key) -> u64 {
27    let mut hasher = std::collections::hash_map::DefaultHasher::new();
28    key.hash(&mut hasher);
29    hasher.finish()
30}
31
32/// A maintained consistent hash ring backed by a `BTreeMap`.
33///
34/// Virtual nodes for each shard are inserted at construction and maintained
35/// incrementally via `add_shard`/`remove_shard`, avoiding the O(S·V·log(S·V))
36/// rebuild on every lookup.
37#[derive(Clone)]
38pub struct HashRing {
39    /// Sorted map from virtual-node hash → ShardId.
40    ring: std::collections::BTreeMap<u64, ShardId>,
41    /// Number of virtual nodes per shard.
42    virtual_nodes: usize,
43}
44
45impl HashRing {
46    /// Create an empty ring with the given number of virtual nodes per shard.
47    pub fn new(virtual_nodes: usize) -> Self {
48        Self {
49            ring: std::collections::BTreeMap::new(),
50            virtual_nodes,
51        }
52    }
53
54    /// Insert `virtual_nodes` entries for the given shard.
55    pub fn add_shard(&mut self, id: ShardId) {
56        for i in 0..self.virtual_nodes {
57            let hash = Self::virtual_node_hash(id, i);
58            self.ring.insert(hash, id);
59        }
60    }
61
62    /// Remove all virtual-node entries for the given shard.
63    pub fn remove_shard(&mut self, id: ShardId) {
64        for i in 0..self.virtual_nodes {
65            let hash = Self::virtual_node_hash(id, i);
66            // Only remove if this slot still maps to our shard (no hash collision replacement).
67            if self.ring.get(&hash) == Some(&id) {
68                self.ring.remove(&hash);
69            }
70        }
71    }
72
73    /// Route a key to the responsible shard.
74    ///
75    /// Returns `None` if the ring is empty.
76    pub fn get_shard_for_key(&self, key: &Key) -> Option<ShardId> {
77        if self.ring.is_empty() {
78            return None;
79        }
80        let key_hash = hash_key(key);
81        // Find the first virtual node with hash >= key_hash (successor).
82        self.ring
83            .range(key_hash..)
84            .next()
85            .or_else(|| self.ring.iter().next()) // wrap-around to first entry
86            .map(|(_, &id)| id)
87    }
88
89    /// Hash function for virtual node `(shard_id, i)`.
90    fn virtual_node_hash(shard_id: ShardId, i: usize) -> u64 {
91        use std::hash::{Hash, Hasher};
92        let virtual_key = format!("{}:{}", shard_id, i);
93        let mut hasher = std::collections::hash_map::DefaultHasher::new();
94        virtual_key.hash(&mut hasher);
95        hasher.finish()
96    }
97}
98
99/// Partitioner handles key-to-shard routing
100#[derive(Clone)]
101pub struct Partitioner {
102    /// Shard registry
103    registry: Arc<ShardRegistry>,
104    /// Partitioning strategy
105    strategy: PartitionStrategy,
106    /// Number of virtual nodes for consistent hashing
107    virtual_nodes: usize,
108    /// Maintained consistent hash ring (initialized from registry at construction).
109    hash_ring: HashRing,
110}
111
112impl Partitioner {
113    /// Create a new partitioner
114    pub fn new(registry: Arc<ShardRegistry>, strategy: PartitionStrategy) -> Self {
115        let virtual_nodes = 100;
116        let mut hash_ring = HashRing::new(virtual_nodes);
117        for shard in registry.get_all() {
118            hash_ring.add_shard(shard.id);
119        }
120        Self {
121            registry,
122            strategy,
123            virtual_nodes,
124            hash_ring,
125        }
126    }
127
128    /// Set the number of virtual nodes for consistent hashing
129    pub fn with_virtual_nodes(mut self, count: usize) -> Self {
130        self.virtual_nodes = count;
131        // Rebuild the hash ring with the new virtual node count.
132        let mut new_ring = HashRing::new(count);
133        for shard in self.registry.get_all() {
134            new_ring.add_shard(shard.id);
135        }
136        self.hash_ring = new_ring;
137        self
138    }
139
140    /// Route a key to the responsible shard
141    pub fn route_key(&self, key: &Key) -> RaftResult<ShardMetadata> {
142        match self.strategy {
143            PartitionStrategy::Range => self.route_by_range(key),
144            PartitionStrategy::Hash => self.route_by_hash(key),
145            PartitionStrategy::ConsistentHash => self.route_by_consistent_hash(key),
146        }
147    }
148
149    /// Route by range partitioning
150    fn route_by_range(&self, key: &Key) -> RaftResult<ShardMetadata> {
151        self.registry
152            .find_shard_for_key(key)
153            .ok_or_else(|| RaftError::ConfigError {
154                message: format!("No shard found for key: {:?}", key),
155            })
156    }
157
158    /// Route by hash partitioning
159    fn route_by_hash(&self, key: &Key) -> RaftResult<ShardMetadata> {
160        let shards = self.registry.get_all();
161        if shards.is_empty() {
162            return Err(RaftError::ConfigError {
163                message: "No shards available".to_string(),
164            });
165        }
166
167        let hash = hash_key(key);
168        let index = (hash % shards.len() as u64) as usize;
169        Ok(shards[index].clone())
170    }
171
172    /// Route by consistent hashing using the maintained hash ring.
173    fn route_by_consistent_hash(&self, key: &Key) -> RaftResult<ShardMetadata> {
174        let shard_id =
175            self.hash_ring
176                .get_shard_for_key(key)
177                .ok_or_else(|| RaftError::ConfigError {
178                    message: "Consistent hash ring is empty — no shards registered".to_string(),
179                })?;
180
181        self.registry
182            .get(shard_id)
183            .ok_or_else(|| RaftError::ConfigError {
184                message: format!("Shard {} not found in registry", shard_id),
185            })
186    }
187
188    /// Route a key range query to all relevant shards
189    pub fn route_range(&self, start: &Key, end: &Key) -> RaftResult<Vec<ShardMetadata>> {
190        let query_range = KeyRange::new(start.clone(), end.clone())?;
191        let shards = self.registry.get_all();
192
193        let relevant_shards: Vec<ShardMetadata> = shards
194            .into_iter()
195            .filter(|shard| shard.range.overlaps(&query_range))
196            .collect();
197
198        if relevant_shards.is_empty() {
199            return Err(RaftError::ConfigError {
200                message: format!("No shards found for range {:?} to {:?}", start, end),
201            });
202        }
203
204        Ok(relevant_shards)
205    }
206
207    /// Get all shards on a specific node
208    pub fn get_shards_on_node(&self, node_id: NodeId) -> Vec<ShardMetadata> {
209        self.registry.get_by_node(node_id)
210    }
211
212    /// Get all shards in the cluster
213    pub fn get_all_shards(&self) -> Vec<ShardMetadata> {
214        self.registry.get_all()
215    }
216}
217
218/// Query router for distributed queries
219pub struct QueryRouter {
220    partitioner: Partitioner,
221}
222
223impl QueryRouter {
224    /// Create a new query router
225    pub fn new(partitioner: Partitioner) -> Self {
226        Self { partitioner }
227    }
228
229    /// Route a point query (single key lookup)
230    pub fn route_point_query(&self, key: &Key) -> RaftResult<QueryPlan> {
231        let shard = self.partitioner.route_key(key)?;
232        Ok(QueryPlan::Single {
233            shard_id: shard.id,
234            node_id: shard.node_id,
235        })
236    }
237
238    /// Route a range query (multiple keys)
239    pub fn route_range_query(&self, start: &Key, end: &Key) -> RaftResult<QueryPlan> {
240        let shards = self.partitioner.route_range(start, end)?;
241
242        let mut targets: HashMap<NodeId, Vec<ShardId>> = HashMap::new();
243        for shard in shards {
244            targets.entry(shard.node_id).or_default().push(shard.id);
245        }
246
247        Ok(QueryPlan::Scatter {
248            targets,
249            merge_required: true,
250        })
251    }
252
253    /// Route a full scan query (all shards)
254    pub fn route_scan_query(&self) -> RaftResult<QueryPlan> {
255        let shards = self.partitioner.get_all_shards();
256        if shards.is_empty() {
257            return Err(RaftError::ConfigError {
258                message: "No shards available for scan".to_string(),
259            });
260        }
261
262        let mut targets: HashMap<NodeId, Vec<ShardId>> = HashMap::new();
263        for shard in shards {
264            targets.entry(shard.node_id).or_default().push(shard.id);
265        }
266
267        Ok(QueryPlan::Scatter {
268            targets,
269            merge_required: true,
270        })
271    }
272
273    /// Get statistics for query planning
274    pub fn get_query_stats(&self) -> QueryStats {
275        let shards = self.partitioner.get_all_shards();
276        let total_shards = shards.len();
277        let nodes: HashSet<NodeId> = shards.iter().map(|s| s.node_id).collect();
278        let total_nodes = nodes.len();
279
280        let total_keys: u64 = shards.iter().map(|s| s.estimated_keys).sum();
281        let total_size: u64 = shards.iter().map(|s| s.estimated_size_bytes).sum();
282
283        QueryStats {
284            total_shards,
285            total_nodes,
286            total_keys,
287            total_size_bytes: total_size,
288        }
289    }
290}
291
292/// Query execution plan
293#[derive(Debug, Clone)]
294pub enum QueryPlan {
295    /// Single shard query
296    Single {
297        /// Target shard ID
298        shard_id: ShardId,
299        /// Target node ID
300        node_id: NodeId,
301    },
302    /// Multi-shard scatter-gather query
303    Scatter {
304        /// Map of node ID to shard IDs
305        targets: HashMap<NodeId, Vec<ShardId>>,
306        /// Whether results need to be merged
307        merge_required: bool,
308    },
309}
310
311impl QueryPlan {
312    /// Get all nodes involved in the query
313    pub fn get_nodes(&self) -> Vec<NodeId> {
314        match self {
315            QueryPlan::Single { node_id, .. } => vec![*node_id],
316            QueryPlan::Scatter { targets, .. } => targets.keys().copied().collect(),
317        }
318    }
319
320    /// Get all shards involved in the query
321    pub fn get_shards(&self) -> Vec<ShardId> {
322        match self {
323            QueryPlan::Single { shard_id, .. } => vec![*shard_id],
324            QueryPlan::Scatter { targets, .. } => targets.values().flatten().copied().collect(),
325        }
326    }
327
328    /// Check if the query requires result merging
329    pub fn requires_merge(&self) -> bool {
330        match self {
331            QueryPlan::Single { .. } => false,
332            QueryPlan::Scatter { merge_required, .. } => *merge_required,
333        }
334    }
335}
336
/// Cluster-wide totals used to guide query planning.
#[derive(Clone, Debug)]
pub struct QueryStats {
    /// Number of shards counted.
    pub total_shards: usize,
    /// Number of distinct nodes hosting the counted shards.
    pub total_nodes: usize,
    /// Sum of the shards' estimated key counts.
    pub total_keys: u64,
    /// Sum of the shards' estimated sizes, in bytes.
    pub total_size_bytes: u64,
}

impl QueryStats {
    /// Mean estimated keys per shard (integer division); `0` with no shards.
    pub fn avg_keys_per_shard(&self) -> u64 {
        self.total_keys
            .checked_div(self.total_shards as u64)
            .unwrap_or(0)
    }

    /// Mean estimated bytes per shard (integer division); `0` with no shards.
    pub fn avg_size_per_shard(&self) -> u64 {
        self.total_size_bytes
            .checked_div(self.total_shards as u64)
            .unwrap_or(0)
    }

    /// Mean shards per node; `0.0` with no nodes.
    pub fn avg_shards_per_node(&self) -> f64 {
        match self.total_nodes {
            0 => 0.0,
            nodes => self.total_shards as f64 / nodes as f64,
        }
    }
}
378
379/// Range-based partitioner backed by a sorted `BTreeMap`.
380///
381/// Maps key-space ranges to shards via range-start bytes.  Each shard owns
382/// the half-open interval `[start, next_start)`.  The shard whose range start
383/// is the largest value ≤ the lookup key is selected.
384pub struct RangePartitioner {
385    /// BTreeMap from range-start bytes (inclusive) → ShardId.
386    /// The owning shard covers [start, next_start).
387    ranges: std::collections::BTreeMap<Vec<u8>, ShardId>,
388}
389
390impl RangePartitioner {
391    /// Create an empty range partitioner.
392    pub fn new() -> Self {
393        Self {
394            ranges: std::collections::BTreeMap::new(),
395        }
396    }
397
398    /// Register a shard that owns the range starting at `start` (inclusive).
399    pub fn add_range(&mut self, start: Vec<u8>, shard_id: ShardId) {
400        self.ranges.insert(start, shard_id);
401    }
402
403    /// Remove the range entry whose start equals `start`.
404    pub fn remove_range(&mut self, start: &[u8]) {
405        self.ranges.remove(start);
406    }
407
408    /// Route a key to its responsible shard.
409    ///
410    /// Finds the shard whose range start is the greatest value ≤ the key.
411    /// Returns `None` if no range covers the key (i.e., key is before all ranges).
412    pub fn get_shard_for_key(&self, key: &Key) -> Option<ShardId> {
413        let key_bytes = key.as_bytes().to_vec();
414        self.ranges
415            .range(..=key_bytes)
416            .next_back()
417            .map(|(_, &id)| id)
418    }
419
420    /// Build a `RangePartitioner` from a `ShardRegistry`, using each shard's
421    /// range start as the routing key.
422    pub fn from_registry(registry: &ShardRegistry) -> Self {
423        let mut rp = Self::new();
424        for shard in registry.get_all() {
425            rp.add_range(shard.range.start.as_bytes().to_vec(), shard.id);
426        }
427        rp
428    }
429}
430
431impl Default for RangePartitioner {
432    fn default() -> Self {
433        Self::new()
434    }
435}
436
/// Heap entry for the k-way merges in `ResultMerger`.
///
/// Carries the value together with the index of the shard it came from and its
/// position within that shard, so the merge can pull that shard's next element
/// after this one is popped.
struct MergeItem<T> {
    value: T,
    shard_idx: usize,
    item_idx: usize,
}

// `BinaryHeap` pops its *greatest* element, so the ordering below is the
// reverse of the natural (value, shard_idx) order: the smallest value compares
// greatest. Among equal values the lower shard index pops first, which keeps
// the merge output deterministic.
impl<T: Ord> PartialEq for MergeItem<T> {
    fn eq(&self, other: &Self) -> bool {
        self.shard_idx == other.shard_idx && self.value == other.value
    }
}

impl<T: Ord> Eq for MergeItem<T> {}

impl<T: Ord> PartialOrd for MergeItem<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<T: Ord> Ord for MergeItem<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.value
            .cmp(&other.value)
            .then(self.shard_idx.cmp(&other.shard_idx))
            .reverse()
    }
}
474
/// Heap entry for `merge_sorted_by_key`.
///
/// Like `MergeItem`, but ordered by a separately stored sort key (computed
/// once per item by the caller's extractor) instead of by the value itself,
/// so `T` needs no ordering of its own.
struct MergeItemByKey<T, K> {
    value: T,
    key: K,
    shard_idx: usize,
    item_idx: usize,
}

impl<T, K: Ord> PartialEq for MergeItemByKey<T, K> {
    fn eq(&self, other: &Self) -> bool {
        self.shard_idx == other.shard_idx && self.key == other.key
    }
}

impl<T, K: Ord> Eq for MergeItemByKey<T, K> {}

impl<T, K: Ord> PartialOrd for MergeItemByKey<T, K> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<T, K: Ord> Ord for MergeItemByKey<T, K> {
    // Reversed (key, shard_idx) order: `BinaryHeap` then pops the smallest key
    // first, with the lower shard index winning ties.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.key
            .cmp(&other.key)
            .then(self.shard_idx.cmp(&other.shard_idx))
            .reverse()
    }
}
506
507/// Result merger for scatter-gather queries.
508///
509/// Provides several strategies for combining results returned from
510/// multiple shards:
511///
512/// - **`merge`** – simple concatenation (no ordering guarantees).
513/// - **`merge_sorted`** – efficient O(N log K) k-way merge that
514///   assumes each input `Vec` is already sorted.
515/// - **`merge_sorted_by_key`** – same algorithm but sorts by a
516///   caller-supplied key extractor, useful for `(Key, Value)` tuples.
517/// - **`merge_deduplicate`** – unordered merge with duplicate removal.
518/// - **`merge_sorted_deduplicate`** – ordered merge with duplicate removal.
519pub struct ResultMerger;
520
521impl ResultMerger {
522    /// Merge results from multiple shards by simple concatenation.
523    ///
524    /// No ordering guarantees.  O(N) where N is the total number of items.
525    pub fn merge<T>(results: Vec<Vec<T>>) -> Vec<T> {
526        let total_len: usize = results.iter().map(|v| v.len()).sum();
527        let mut merged = Vec::with_capacity(total_len);
528        for batch in results {
529            merged.extend(batch);
530        }
531        merged
532    }
533
534    /// Merge pre-sorted shard results using an efficient k-way merge.
535    ///
536    /// Each input `Vec` **must** be sorted in ascending order.  The output
537    /// is a single sorted `Vec`.
538    ///
539    /// Complexity: O(N log K) where N = total items, K = number of shards.
540    pub fn merge_sorted<T>(results: Vec<Vec<T>>) -> Vec<T>
541    where
542        T: Ord,
543    {
544        let total_len: usize = results.iter().map(|v| v.len()).sum();
545        if total_len == 0 {
546            return Vec::new();
547        }
548
549        // Convert each Vec into an owning iterator so we can pull items
550        // one-by-one without cloning.
551        let mut iterators: Vec<std::vec::IntoIter<T>> =
552            results.into_iter().map(|v| v.into_iter()).collect();
553
554        let mut heap: BinaryHeap<MergeItem<T>> = BinaryHeap::with_capacity(iterators.len());
555
556        // Seed the heap with the first element from each non-empty shard.
557        for (shard_idx, iter) in iterators.iter_mut().enumerate() {
558            if let Some(value) = iter.next() {
559                heap.push(MergeItem {
560                    value,
561                    shard_idx,
562                    item_idx: 0,
563                });
564            }
565        }
566
567        let mut merged = Vec::with_capacity(total_len);
568
569        while let Some(item) = heap.pop() {
570            let next_item_idx = item.item_idx + 1;
571            let shard_idx = item.shard_idx;
572            merged.push(item.value);
573
574            // Push the next element from the same shard, if available.
575            if let Some(value) = iterators[shard_idx].next() {
576                heap.push(MergeItem {
577                    value,
578                    shard_idx,
579                    item_idx: next_item_idx,
580                });
581            }
582        }
583
584        merged
585    }
586
587    /// Merge pre-sorted shard results by a key extracted via `key_fn`.
588    ///
589    /// Each input `Vec` **must** be sorted in ascending order of the key.
590    /// Useful for merging `(Key, CipherBlob)` tuples where sorting is
591    /// done on the `Key` component.
592    ///
593    /// Complexity: O(N log K) where N = total items, K = number of shards.
594    pub fn merge_sorted_by_key<T, K, F>(results: Vec<Vec<T>>, key_fn: F) -> Vec<T>
595    where
596        K: Ord,
597        F: Fn(&T) -> K,
598    {
599        let total_len: usize = results.iter().map(|v| v.len()).sum();
600        if total_len == 0 {
601            return Vec::new();
602        }
603
604        let mut iterators: Vec<std::vec::IntoIter<T>> =
605            results.into_iter().map(|v| v.into_iter()).collect();
606
607        let mut heap: BinaryHeap<MergeItemByKey<T, K>> = BinaryHeap::with_capacity(iterators.len());
608
609        for (shard_idx, iter) in iterators.iter_mut().enumerate() {
610            if let Some(value) = iter.next() {
611                let key = key_fn(&value);
612                heap.push(MergeItemByKey {
613                    value,
614                    key,
615                    shard_idx,
616                    item_idx: 0,
617                });
618            }
619        }
620
621        let mut merged = Vec::with_capacity(total_len);
622
623        while let Some(item) = heap.pop() {
624            let next_item_idx = item.item_idx + 1;
625            let shard_idx = item.shard_idx;
626            merged.push(item.value);
627
628            if let Some(value) = iterators[shard_idx].next() {
629                let key = key_fn(&value);
630                heap.push(MergeItemByKey {
631                    value,
632                    key,
633                    shard_idx,
634                    item_idx: next_item_idx,
635                });
636            }
637        }
638
639        merged
640    }
641
642    /// Merge with deduplication (unordered).
643    pub fn merge_deduplicate<T>(results: Vec<Vec<T>>) -> Vec<T>
644    where
645        T: Eq + Hash,
646    {
647        let mut set: HashSet<T> = HashSet::new();
648        for batch in results {
649            set.extend(batch);
650        }
651        set.into_iter().collect()
652    }
653
654    /// Merge pre-sorted shard results with deduplication.
655    ///
656    /// Uses the k-way merge algorithm and skips consecutive duplicates.
657    /// Each input `Vec` **must** be sorted and should not contain
658    /// duplicates within a single shard for best results.
659    ///
660    /// Complexity: O(N log K) where N = total items, K = number of shards.
661    pub fn merge_sorted_deduplicate<T>(results: Vec<Vec<T>>) -> Vec<T>
662    where
663        T: Ord,
664    {
665        let total_len: usize = results.iter().map(|v| v.len()).sum();
666        if total_len == 0 {
667            return Vec::new();
668        }
669
670        let mut iterators: Vec<std::vec::IntoIter<T>> =
671            results.into_iter().map(|v| v.into_iter()).collect();
672
673        let mut heap: BinaryHeap<MergeItem<T>> = BinaryHeap::with_capacity(iterators.len());
674
675        for (shard_idx, iter) in iterators.iter_mut().enumerate() {
676            if let Some(value) = iter.next() {
677                heap.push(MergeItem {
678                    value,
679                    shard_idx,
680                    item_idx: 0,
681                });
682            }
683        }
684
685        let mut merged = Vec::with_capacity(total_len);
686
687        while let Some(item) = heap.pop() {
688            let next_item_idx = item.item_idx + 1;
689            let shard_idx = item.shard_idx;
690
691            // Skip duplicate if the last pushed element is equal.
692            let is_dup = merged.last().is_some_and(|last: &T| last == &item.value);
693            if !is_dup {
694                merged.push(item.value);
695            }
696
697            if let Some(value) = iterators[shard_idx].next() {
698                heap.push(MergeItem {
699                    value,
700                    shard_idx,
701                    item_idx: next_item_idx,
702                });
703            }
704        }
705
706        merged.shrink_to_fit();
707        merged
708    }
709}
710
711#[cfg(test)]
712mod tests {
713    use super::*;
714
715    fn create_test_registry() -> Arc<ShardRegistry> {
716        let registry = Arc::new(ShardRegistry::new());
717
718        // Create 3 shards with non-overlapping ranges
719        let range1 = KeyRange::new(Key::from_str("a"), Key::from_str("h")).expect("valid range");
720        let shard1 = ShardMetadata::new(1, range1, 100);
721        registry.register(shard1).expect("register shard 1");
722
723        let range2 = KeyRange::new(Key::from_str("h"), Key::from_str("p")).expect("valid range");
724        let shard2 = ShardMetadata::new(2, range2, 101);
725        registry.register(shard2).expect("register shard 2");
726
727        let range3 = KeyRange::new(Key::from_str("p"), Key::from_str("z")).expect("valid range");
728        let shard3 = ShardMetadata::new(3, range3, 102);
729        registry.register(shard3).expect("register shard 3");
730
731        registry
732    }
733
734    #[test]
735    fn test_partitioner_range_routing() -> RaftResult<()> {
736        let registry = create_test_registry();
737        let partitioner = Partitioner::new(registry, PartitionStrategy::Range);
738
739        let shard = partitioner.route_key(&Key::from_str("d"))?;
740        assert_eq!(shard.id, 1);
741
742        let shard = partitioner.route_key(&Key::from_str("m"))?;
743        assert_eq!(shard.id, 2);
744
745        let shard = partitioner.route_key(&Key::from_str("x"))?;
746        assert_eq!(shard.id, 3);
747
748        Ok(())
749    }
750
751    #[test]
752    fn test_partitioner_hash_routing() -> RaftResult<()> {
753        let registry = create_test_registry();
754        let partitioner = Partitioner::new(registry, PartitionStrategy::Hash);
755
756        // Hash routing should be deterministic
757        let shard1 = partitioner.route_key(&Key::from_str("test_key"))?;
758        let shard2 = partitioner.route_key(&Key::from_str("test_key"))?;
759        assert_eq!(shard1.id, shard2.id);
760
761        Ok(())
762    }
763
764    #[test]
765    fn test_partitioner_consistent_hash_routing() -> RaftResult<()> {
766        let registry = create_test_registry();
767        let partitioner =
768            Partitioner::new(registry, PartitionStrategy::ConsistentHash).with_virtual_nodes(50);
769
770        // Consistent hashing should be deterministic
771        let shard1 = partitioner.route_key(&Key::from_str("test_key"))?;
772        let shard2 = partitioner.route_key(&Key::from_str("test_key"))?;
773        assert_eq!(shard1.id, shard2.id);
774
775        Ok(())
776    }
777
778    #[test]
779    fn test_partitioner_range_query() -> RaftResult<()> {
780        let registry = create_test_registry();
781        let partitioner = Partitioner::new(registry, PartitionStrategy::Range);
782
783        // Query spanning two shards
784        let shards = partitioner.route_range(&Key::from_str("d"), &Key::from_str("m"))?;
785        assert_eq!(shards.len(), 2);
786
787        // Query spanning all shards
788        let shards = partitioner.route_range(&Key::from_str("a"), &Key::from_str("z"))?;
789        assert_eq!(shards.len(), 3);
790
791        Ok(())
792    }
793
794    #[test]
795    fn test_query_router_point_query() -> RaftResult<()> {
796        let registry = create_test_registry();
797        let partitioner = Partitioner::new(registry, PartitionStrategy::Range);
798        let router = QueryRouter::new(partitioner);
799
800        let plan = router.route_point_query(&Key::from_str("d"))?;
801        match plan {
802            QueryPlan::Single { shard_id, node_id } => {
803                assert_eq!(shard_id, 1);
804                assert_eq!(node_id, 100);
805            }
806            _ => panic!("Expected single query plan"),
807        }
808
809        Ok(())
810    }
811
812    #[test]
813    fn test_query_router_range_query() -> RaftResult<()> {
814        let registry = create_test_registry();
815        let partitioner = Partitioner::new(registry, PartitionStrategy::Range);
816        let router = QueryRouter::new(partitioner);
817
818        let plan = router.route_range_query(&Key::from_str("d"), &Key::from_str("m"))?;
819        match plan {
820            QueryPlan::Scatter {
821                targets,
822                merge_required,
823            } => {
824                assert!(merge_required);
825                assert_eq!(targets.len(), 2); // Two nodes involved
826            }
827            _ => panic!("Expected scatter query plan"),
828        }
829
830        Ok(())
831    }
832
833    #[test]
834    fn test_query_router_scan_query() -> RaftResult<()> {
835        let registry = create_test_registry();
836        let partitioner = Partitioner::new(registry, PartitionStrategy::Range);
837        let router = QueryRouter::new(partitioner);
838
839        let plan = router.route_scan_query()?;
840        match plan {
841            QueryPlan::Scatter { targets, .. } => {
842                assert_eq!(targets.len(), 3); // All nodes involved
843            }
844            _ => panic!("Expected scatter query plan"),
845        }
846
847        Ok(())
848    }
849
850    #[test]
851    fn test_query_stats() -> RaftResult<()> {
852        let registry = create_test_registry();
853        let partitioner = Partitioner::new(registry, PartitionStrategy::Range);
854        let router = QueryRouter::new(partitioner);
855
856        let stats = router.get_query_stats();
857        assert_eq!(stats.total_shards, 3);
858        assert_eq!(stats.total_nodes, 3);
859
860        Ok(())
861    }
862
863    #[test]
864    fn test_query_plan_methods() -> RaftResult<()> {
865        let mut targets = HashMap::new();
866        targets.insert(100, vec![1, 2]);
867        targets.insert(101, vec![3]);
868
869        let plan = QueryPlan::Scatter {
870            targets,
871            merge_required: true,
872        };
873
874        let nodes = plan.get_nodes();
875        assert_eq!(nodes.len(), 2);
876
877        let shards = plan.get_shards();
878        assert_eq!(shards.len(), 3);
879
880        assert!(plan.requires_merge());
881
882        Ok(())
883    }
884
885    // ---------------------------------------------------------------
886    //  ResultMerger tests
887    // ---------------------------------------------------------------
888
889    #[test]
890    fn test_result_merger_merge_concatenates() {
891        let results = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
892        let merged = ResultMerger::merge(results);
893        assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
894    }
895
896    #[test]
897    fn test_result_merger_merge_empty_inputs() {
898        let results: Vec<Vec<i32>> = vec![];
899        let merged = ResultMerger::merge(results);
900        assert!(merged.is_empty());
901    }
902
903    #[test]
904    fn test_result_merger_merge_some_empty_vecs() {
905        let results: Vec<Vec<i32>> = vec![vec![], vec![1, 2], vec![], vec![3]];
906        let merged = ResultMerger::merge(results);
907        assert_eq!(merged, vec![1, 2, 3]);
908    }
909
910    #[test]
911    fn test_result_merger_merge_all_empty_vecs() {
912        let results: Vec<Vec<i32>> = vec![vec![], vec![], vec![]];
913        let merged = ResultMerger::merge(results);
914        assert!(merged.is_empty());
915    }
916
917    // -- merge_sorted (k-way merge) ------------------------------------
918
919    #[test]
920    fn test_merge_sorted_basic() {
921        let results = vec![vec![1, 5, 9], vec![2, 6, 10], vec![3, 7, 11]];
922        let merged = ResultMerger::merge_sorted(results);
923        assert_eq!(merged, vec![1, 2, 3, 5, 6, 7, 9, 10, 11]);
924    }
925
926    #[test]
927    fn test_merge_sorted_empty_input() {
928        let results: Vec<Vec<i32>> = vec![];
929        let merged = ResultMerger::merge_sorted(results);
930        assert!(merged.is_empty());
931    }
932
933    #[test]
934    fn test_merge_sorted_single_shard() {
935        let results = vec![vec![10, 20, 30]];
936        let merged = ResultMerger::merge_sorted(results);
937        assert_eq!(merged, vec![10, 20, 30]);
938    }
939
940    #[test]
941    fn test_merge_sorted_single_element_shards() {
942        let results = vec![vec![5], vec![1], vec![3]];
943        let merged = ResultMerger::merge_sorted(results);
944        assert_eq!(merged, vec![1, 3, 5]);
945    }
946
947    #[test]
948    fn test_merge_sorted_with_empty_shards() {
949        let results = vec![vec![], vec![1, 3, 5], vec![], vec![2, 4], vec![]];
950        let merged = ResultMerger::merge_sorted(results);
951        assert_eq!(merged, vec![1, 2, 3, 4, 5]);
952    }
953
954    #[test]
955    fn test_merge_sorted_all_empty_shards() {
956        let results: Vec<Vec<i32>> = vec![vec![], vec![], vec![]];
957        let merged = ResultMerger::merge_sorted(results);
958        assert!(merged.is_empty());
959    }
960
961    #[test]
962    fn test_merge_sorted_with_duplicates() {
963        let results = vec![vec![1, 3, 5], vec![1, 3, 5], vec![2, 4, 6]];
964        let merged = ResultMerger::merge_sorted(results);
965        assert_eq!(merged, vec![1, 1, 2, 3, 3, 4, 5, 5, 6]);
966    }
967
968    #[test]
969    fn test_merge_sorted_unequal_lengths() {
970        let results = vec![vec![1], vec![2, 4, 6, 8, 10], vec![3, 5]];
971        let merged = ResultMerger::merge_sorted(results);
972        assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 8, 10]);
973    }
974
975    #[test]
976    fn test_merge_sorted_negative_numbers() {
977        let results = vec![vec![-10, -5, 0], vec![-8, -3, 2], vec![-20, 1]];
978        let merged = ResultMerger::merge_sorted(results);
979        assert_eq!(merged, vec![-20, -10, -8, -5, -3, 0, 1, 2]);
980    }
981
982    #[test]
983    fn test_merge_sorted_strings() {
984        let results = vec![
985            vec!["apple".to_string(), "cherry".to_string()],
986            vec!["banana".to_string(), "date".to_string()],
987        ];
988        let merged = ResultMerger::merge_sorted(results);
989        assert_eq!(
990            merged,
991            vec![
992                "apple".to_string(),
993                "banana".to_string(),
994                "cherry".to_string(),
995                "date".to_string()
996            ]
997        );
998    }
999
1000    #[test]
1001    fn test_merge_sorted_large_scale() {
1002        // 100 shards x 100 items each
1003        let num_shards = 100;
1004        let items_per_shard = 100;
1005        let mut results: Vec<Vec<i64>> = Vec::with_capacity(num_shards);
1006
1007        for shard_idx in 0..num_shards {
1008            let shard: Vec<i64> = (0..items_per_shard)
1009                .map(|i| (shard_idx as i64) + (i as i64) * (num_shards as i64))
1010                .collect();
1011            results.push(shard);
1012        }
1013
1014        let merged = ResultMerger::merge_sorted(results);
1015
1016        // Verify length
1017        assert_eq!(merged.len(), num_shards * items_per_shard);
1018
1019        // Verify sorted
1020        for window in merged.windows(2) {
1021            assert!(
1022                window[0] <= window[1],
1023                "Output not sorted: {} > {}",
1024                window[0],
1025                window[1]
1026            );
1027        }
1028    }
1029
1030    #[test]
1031    fn test_merge_sorted_deterministic_tie_breaking() {
1032        // When values are equal, shard ordering should be deterministic
1033        let results = vec![vec![1, 2, 3], vec![1, 2, 3], vec![1, 2, 3]];
1034        let merged1 = ResultMerger::merge_sorted(results.clone());
1035        let merged2 = ResultMerger::merge_sorted(results);
1036        assert_eq!(merged1, merged2);
1037        assert_eq!(merged1, vec![1, 1, 1, 2, 2, 2, 3, 3, 3]);
1038    }
1039
1040    // -- merge_sorted_by_key -------------------------------------------
1041
1042    #[test]
1043    fn test_merge_sorted_by_key_basic() {
1044        // Simulate (key, value) tuples sorted by key
1045        let results = vec![
1046            vec![(1, "a"), (3, "c"), (5, "e")],
1047            vec![(2, "b"), (4, "d"), (6, "f")],
1048        ];
1049        let merged = ResultMerger::merge_sorted_by_key(results, |item| item.0);
1050        let keys: Vec<i32> = merged.iter().map(|item| item.0).collect();
1051        assert_eq!(keys, vec![1, 2, 3, 4, 5, 6]);
1052    }
1053
1054    #[test]
1055    fn test_merge_sorted_by_key_empty() {
1056        let results: Vec<Vec<(i32, &str)>> = vec![];
1057        let merged = ResultMerger::merge_sorted_by_key(results, |item: &(i32, &str)| item.0);
1058        assert!(merged.is_empty());
1059    }
1060
1061    #[test]
1062    fn test_merge_sorted_by_key_with_string_keys() {
1063        let results = vec![
1064            vec![("apple", 10), ("cherry", 30)],
1065            vec![("banana", 20), ("date", 40)],
1066        ];
1067        let merged = ResultMerger::merge_sorted_by_key(results, |item| item.0);
1068        let keys: Vec<&str> = merged.iter().map(|item| item.0).collect();
1069        assert_eq!(keys, vec!["apple", "banana", "cherry", "date"]);
1070    }
1071
1072    #[test]
1073    fn test_merge_sorted_by_key_reverse_field() {
1074        // Sort by a secondary field (the value, not the first element)
1075        let results = vec![
1076            vec![("x", 1), ("y", 3), ("z", 5)],
1077            vec![("a", 2), ("b", 4), ("c", 6)],
1078        ];
1079        let merged = ResultMerger::merge_sorted_by_key(results, |item| item.1);
1080        let values: Vec<i32> = merged.iter().map(|item| item.1).collect();
1081        assert_eq!(values, vec![1, 2, 3, 4, 5, 6]);
1082    }
1083
1084    // -- merge_deduplicate ---------------------------------------------
1085
1086    #[test]
1087    fn test_result_merger_deduplicate() {
1088        let results = vec![vec![1, 2, 3], vec![2, 3, 4], vec![3, 4, 5]];
1089        let mut merged = ResultMerger::merge_deduplicate(results);
1090        merged.sort();
1091        assert_eq!(merged, vec![1, 2, 3, 4, 5]);
1092    }
1093
1094    #[test]
1095    fn test_result_merger_deduplicate_empty() {
1096        let results: Vec<Vec<i32>> = vec![];
1097        let merged = ResultMerger::merge_deduplicate(results);
1098        assert!(merged.is_empty());
1099    }
1100
1101    // -- merge_sorted_deduplicate --------------------------------------
1102
1103    #[test]
1104    fn test_merge_sorted_deduplicate_basic() {
1105        let results = vec![vec![1, 3, 5], vec![1, 3, 5], vec![2, 4, 6]];
1106        let merged = ResultMerger::merge_sorted_deduplicate(results);
1107        assert_eq!(merged, vec![1, 2, 3, 4, 5, 6]);
1108    }
1109
1110    #[test]
1111    fn test_merge_sorted_deduplicate_no_dups() {
1112        let results = vec![vec![1, 4, 7], vec![2, 5, 8], vec![3, 6, 9]];
1113        let merged = ResultMerger::merge_sorted_deduplicate(results);
1114        assert_eq!(merged, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
1115    }
1116
1117    #[test]
1118    fn test_merge_sorted_deduplicate_all_same() {
1119        let results = vec![vec![1, 1, 1], vec![1, 1], vec![1]];
1120        let merged = ResultMerger::merge_sorted_deduplicate(results);
1121        assert_eq!(merged, vec![1]);
1122    }
1123
1124    #[test]
1125    fn test_merge_sorted_deduplicate_empty() {
1126        let results: Vec<Vec<i32>> = vec![];
1127        let merged = ResultMerger::merge_sorted_deduplicate(results);
1128        assert!(merged.is_empty());
1129    }
1130
1131    // -- property-style randomized test --------------------------------
1132
1133    #[test]
1134    fn test_merge_sorted_random_property() {
1135        // Generate pseudo-random sorted vectors and verify the merge is sorted.
1136        // Uses a simple LCG to avoid depending on `rand`.
1137        let num_shards = 20;
1138        let max_items = 50;
1139        let mut seed: u64 = 0xDEAD_BEEF_CAFE;
1140
1141        let mut results: Vec<Vec<i64>> = Vec::with_capacity(num_shards);
1142        for _ in 0..num_shards {
1143            // Determine shard length
1144            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
1145            let len = (seed % (max_items as u64 + 1)) as usize;
1146
1147            let mut shard = Vec::with_capacity(len);
1148            for _ in 0..len {
1149                seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
1150                shard.push((seed >> 33) as i64); // use upper bits for quality
1151            }
1152            shard.sort();
1153            results.push(shard);
1154        }
1155
1156        let expected_len: usize = results.iter().map(|v| v.len()).sum();
1157        let merged = ResultMerger::merge_sorted(results);
1158        assert_eq!(merged.len(), expected_len);
1159
1160        // Verify sorted
1161        for window in merged.windows(2) {
1162            assert!(
1163                window[0] <= window[1],
1164                "Property violation: {} > {}",
1165                window[0],
1166                window[1]
1167            );
1168        }
1169    }
1170
1171    #[test]
1172    fn test_hash_ring_maintained() {
1173        let mut ring = HashRing::new(10);
1174
1175        // Add two shards and verify keys route to one of them.
1176        ring.add_shard(1u64);
1177        ring.add_shard(2u64);
1178
1179        let key = Key::from_str("hello");
1180        let shard = ring.get_shard_for_key(&key);
1181        assert!(shard.is_some(), "key must route to some shard");
1182        let shard_id = shard.expect("key must route to some shard");
1183        assert!(shard_id == 1 || shard_id == 2);
1184
1185        // Remove shard 1 — all keys must now route to shard 2.
1186        ring.remove_shard(1u64);
1187        let shard_after = ring.get_shard_for_key(&key);
1188        assert_eq!(shard_after, Some(2u64));
1189    }
1190
1191    #[test]
1192    fn test_range_partitioner_routing() {
1193        let mut rp = RangePartitioner::new();
1194        rp.add_range(b"a".to_vec(), 1u64);
1195        rp.add_range(b"m".to_vec(), 2u64);
1196        rp.add_range(b"z".to_vec(), 3u64);
1197
1198        // "apple" >= "a" and < "m" → shard 1
1199        let k_apple = Key::from_str("apple");
1200        assert_eq!(rp.get_shard_for_key(&k_apple), Some(1u64));
1201
1202        // "moon" >= "m" and < "z" → shard 2
1203        let k_moon = Key::from_str("moon");
1204        assert_eq!(rp.get_shard_for_key(&k_moon), Some(2u64));
1205
1206        // "zebra" >= "z" → shard 3
1207        let k_zebra = Key::from_str("zebra");
1208        assert_eq!(rp.get_shard_for_key(&k_zebra), Some(3u64));
1209
1210        // Key before any range → None
1211        let k_zero = Key::from_slice(&[0u8]);
1212        assert_eq!(rp.get_shard_for_key(&k_zero), None);
1213    }
1214
1215    #[test]
1216    fn test_partitioner_consistent_hash_uses_ring() {
1217        // Build a registry with two shards.
1218        let registry = Arc::new(ShardRegistry::new());
1219        let range1 = KeyRange::new(Key::from_str("a"), Key::from_str("m")).expect("valid range");
1220        let range2 = KeyRange::new(Key::from_str("m"), Key::from_str("z")).expect("valid range");
1221        let shard1 = ShardMetadata::new(1, range1, 100);
1222        let shard2 = ShardMetadata::new(2, range2, 101);
1223        registry.register(shard1).expect("register shard 1");
1224        registry.register(shard2).expect("register shard 2");
1225
1226        let partitioner = Partitioner::new(registry, PartitionStrategy::ConsistentHash);
1227
1228        // Routing must succeed for a key in the hash space.
1229        let key = Key::from_str("hello");
1230        let result = partitioner.route_key(&key);
1231        assert!(
1232            result.is_ok(),
1233            "consistent hash routing must succeed: {:?}",
1234            result
1235        );
1236    }
1237}