// phago_distributed/hashing/mod.rs

//! Consistent hashing for document-to-shard routing.
//!
//! This module implements a consistent hash ring that distributes documents
//! across shards with minimal redistribution when the cluster topology changes.
//! Each shard is placed on the ring as many virtual nodes so that data is
//! spread evenly across shards.
6
7use crate::types::ShardId;
8use phago_core::types::DocumentId;
9use std::collections::BTreeMap;
10use std::hash::{Hash, Hasher};
11
/// Default number of virtual nodes placed on the ring for each physical shard.
///
/// Used by `ConsistentHashRing::new`; `ConsistentHashRing::with_virtual_nodes`
/// lets callers pick a different value.
const VIRTUAL_NODES_PER_SHARD: u32 = 150;
14
15/// A consistent hash ring for routing documents to shards.
16///
17/// The ring uses virtual nodes to achieve better load distribution across
18/// shards. Each physical shard is represented by multiple virtual nodes
19/// on the ring, which helps ensure that documents are distributed evenly.
20///
21/// # Thread Safety
22///
23/// The ring itself is `Clone` and can be wrapped in `Arc<RwLock<_>>` for
24/// thread-safe access with dynamic updates, or `Arc<_>` for read-only
25/// concurrent access.
26///
27/// # Example
28///
29/// ```
30/// use phago_distributed::hashing::ConsistentHashRing;
31/// use phago_core::types::DocumentId;
32///
33/// let ring = ConsistentHashRing::new(3);
34/// let doc_id = DocumentId::from_seed(42);
35/// let shard = ring.get_shard(&doc_id);
36/// println!("Document maps to shard: {}", shard);
37/// ```
38#[derive(Debug, Clone)]
39pub struct ConsistentHashRing {
40    /// Ring mapping hash positions to shard IDs.
41    ring: BTreeMap<u64, ShardId>,
42    /// Number of shards in the ring.
43    shard_count: u32,
44    /// Virtual nodes per shard.
45    virtual_nodes: u32,
46}
47
48impl ConsistentHashRing {
49    /// Create a new hash ring with the specified number of shards.
50    ///
51    /// Each shard will be represented by `VIRTUAL_NODES_PER_SHARD` virtual
52    /// nodes on the ring for better distribution.
53    ///
54    /// # Arguments
55    ///
56    /// * `num_shards` - The number of physical shards to distribute across
57    ///
58    /// # Panics
59    ///
60    /// Panics if `num_shards` is 0.
61    pub fn new(num_shards: u32) -> Self {
62        assert!(num_shards > 0, "Number of shards must be greater than 0");
63
64        let mut ring = BTreeMap::new();
65
66        for shard_id in 0..num_shards {
67            for vnode in 0..VIRTUAL_NODES_PER_SHARD {
68                let hash = Self::hash_shard_vnode(shard_id, vnode);
69                ring.insert(hash, ShardId::new(shard_id));
70            }
71        }
72
73        Self {
74            ring,
75            shard_count: num_shards,
76            virtual_nodes: VIRTUAL_NODES_PER_SHARD,
77        }
78    }
79
80    /// Create a new hash ring with custom virtual nodes per shard.
81    ///
82    /// More virtual nodes generally result in better distribution but
83    /// increase memory usage and lookup time slightly.
84    ///
85    /// # Arguments
86    ///
87    /// * `num_shards` - The number of physical shards
88    /// * `virtual_nodes` - Number of virtual nodes per shard
89    pub fn with_virtual_nodes(num_shards: u32, virtual_nodes: u32) -> Self {
90        assert!(num_shards > 0, "Number of shards must be greater than 0");
91        assert!(virtual_nodes > 0, "Virtual nodes must be greater than 0");
92
93        let mut ring = BTreeMap::new();
94
95        for shard_id in 0..num_shards {
96            for vnode in 0..virtual_nodes {
97                let hash = Self::hash_shard_vnode(shard_id, vnode);
98                ring.insert(hash, ShardId::new(shard_id));
99            }
100        }
101
102        Self {
103            ring,
104            shard_count: num_shards,
105            virtual_nodes,
106        }
107    }
108
109    /// Get the shard ID for a document.
110    ///
111    /// This operation is O(log n) where n is the total number of virtual nodes.
112    ///
113    /// # Arguments
114    ///
115    /// * `doc_id` - The document ID to route
116    ///
117    /// # Returns
118    ///
119    /// The shard ID that should store this document.
120    pub fn get_shard(&self, doc_id: &DocumentId) -> ShardId {
121        let hash = Self::hash_document(doc_id);
122
123        // Find the first shard with a hash >= document hash (clockwise)
124        if let Some((&_pos, &shard_id)) = self.ring.range(hash..).next() {
125            shard_id
126        } else {
127            // Wrap around to the first shard
128            *self.ring.values().next().unwrap_or(&ShardId::new(0))
129        }
130    }
131
132    /// Get the shard ID for an arbitrary key.
133    ///
134    /// This is useful for routing non-document data to shards.
135    ///
136    /// # Arguments
137    ///
138    /// * `key` - Any hashable key
139    pub fn get_shard_for_key<K: Hash>(&self, key: &K) -> ShardId {
140        let hash = Self::hash_key(key);
141
142        if let Some((&_pos, &shard_id)) = self.ring.range(hash..).next() {
143            shard_id
144        } else {
145            *self.ring.values().next().unwrap_or(&ShardId::new(0))
146        }
147    }
148
149    /// Add a new shard to the ring.
150    ///
151    /// This will redistribute approximately `1 / (n+1)` of the keys from
152    /// existing shards to the new shard, where n is the current number of shards.
153    ///
154    /// # Arguments
155    ///
156    /// * `shard_id` - The shard ID to add
157    pub fn add_shard(&mut self, shard_id: ShardId) {
158        for vnode in 0..self.virtual_nodes {
159            let hash = Self::hash_shard_vnode(shard_id.0, vnode);
160            self.ring.insert(hash, shard_id);
161        }
162        self.shard_count += 1;
163    }
164
165    /// Remove a shard from the ring.
166    ///
167    /// Documents previously assigned to this shard will be redistributed
168    /// to the next shard in the ring (clockwise).
169    ///
170    /// # Arguments
171    ///
172    /// * `shard_id` - The shard ID to remove
173    pub fn remove_shard(&mut self, shard_id: ShardId) {
174        self.ring.retain(|_, &mut sid| sid != shard_id);
175        self.shard_count = self.shard_count.saturating_sub(1);
176    }
177
178    /// Get the number of shards.
179    pub fn shard_count(&self) -> u32 {
180        self.shard_count
181    }
182
183    /// Get all shard IDs in the ring.
184    ///
185    /// Returns a sorted, deduplicated list of all shard IDs.
186    pub fn all_shards(&self) -> Vec<ShardId> {
187        let mut shards: Vec<ShardId> = self.ring.values().copied().collect();
188        shards.sort_by_key(|s| s.0);
189        shards.dedup();
190        shards
191    }
192
193    /// Get the number of virtual nodes per shard.
194    pub fn virtual_nodes_per_shard(&self) -> u32 {
195        self.virtual_nodes
196    }
197
198    /// Get the total number of virtual nodes in the ring.
199    pub fn total_virtual_nodes(&self) -> usize {
200        self.ring.len()
201    }
202
203    /// Get replica shards for a document.
204    ///
205    /// Returns the primary shard plus `replica_count` additional shards
206    /// that should store replicas of the document.
207    ///
208    /// # Arguments
209    ///
210    /// * `doc_id` - The document ID
211    /// * `replica_count` - Number of additional replicas (excluding primary)
212    ///
213    /// # Returns
214    ///
215    /// A vector of shard IDs, with the primary shard first.
216    pub fn get_replica_shards(&self, doc_id: &DocumentId, replica_count: usize) -> Vec<ShardId> {
217        let hash = Self::hash_document(doc_id);
218        let mut shards = Vec::with_capacity(replica_count + 1);
219        let mut seen_shards = std::collections::HashSet::new();
220
221        // Start from the document's hash position and walk clockwise
222        for (&_pos, &shard_id) in self.ring.range(hash..).chain(self.ring.iter()) {
223            if seen_shards.insert(shard_id) {
224                shards.push(shard_id);
225                if shards.len() > replica_count {
226                    break;
227                }
228            }
229        }
230
231        shards
232    }
233
234    /// Hash a document ID to a ring position.
235    fn hash_document(doc_id: &DocumentId) -> u64 {
236        use std::collections::hash_map::DefaultHasher;
237        let mut hasher = DefaultHasher::new();
238        doc_id.0.hash(&mut hasher);
239        hasher.finish()
240    }
241
242    /// Hash an arbitrary key to a ring position.
243    fn hash_key<K: Hash>(key: &K) -> u64 {
244        use std::collections::hash_map::DefaultHasher;
245        let mut hasher = DefaultHasher::new();
246        key.hash(&mut hasher);
247        hasher.finish()
248    }
249
250    /// Hash a shard + virtual node combination.
251    fn hash_shard_vnode(shard_id: u32, vnode: u32) -> u64 {
252        use std::collections::hash_map::DefaultHasher;
253        let mut hasher = DefaultHasher::new();
254        shard_id.hash(&mut hasher);
255        vnode.hash(&mut hasher);
256        hasher.finish()
257    }
258}
259
260impl Default for ConsistentHashRing {
261    fn default() -> Self {
262        Self::new(1)
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Route 100 seeded documents through `ring` and return their shards in
    /// seed order.
    fn sample_assignments(ring: &ConsistentHashRing) -> Vec<ShardId> {
        (0..100)
            .map(|i| ring.get_shard(&DocumentId::from_seed(i)))
            .collect()
    }

    #[test]
    fn test_new_ring() {
        let ring = ConsistentHashRing::new(3);
        assert_eq!(ring.shard_count(), 3);
        assert_eq!(ring.all_shards().len(), 3);
        assert_eq!(
            ring.total_virtual_nodes(),
            3 * VIRTUAL_NODES_PER_SHARD as usize
        );
    }

    #[test]
    fn test_distribution() {
        let ring = ConsistentHashRing::new(3);

        // Create 100 documents and check distribution
        let mut counts = [0u32; 3];
        for i in 0..100 {
            let doc_id = DocumentId::from_seed(i);
            let shard = ring.get_shard(&doc_id);
            counts[shard.0 as usize] += 1;
        }

        // Each shard should get roughly 33 documents
        assert!(
            counts.iter().all(|count| (20..=50).contains(count)),
            "Distribution skewed: {:?}",
            counts
        );
    }

    #[test]
    fn test_consistency() {
        let ring = ConsistentHashRing::new(3);
        let doc_id = DocumentId::from_seed(42);

        // Same document should always map to same shard
        let shard1 = ring.get_shard(&doc_id);
        let shard2 = ring.get_shard(&doc_id);
        assert_eq!(shard1, shard2);
    }

    #[test]
    fn test_add_shard_minimal_redistribution() {
        let mut ring = ConsistentHashRing::new(3);
        let new_shard = ShardId::new(3);

        // Record initial assignments, add a fourth shard, then compare.
        let before = sample_assignments(&ring);
        ring.add_shard(new_shard);
        let after = sample_assignments(&ring);

        let mut moved = 0;
        for (old, new) in before.iter().zip(&after) {
            if old != new {
                // Adding a shard may only pull documents onto that shard;
                // it must never shuffle them between existing shards.
                assert_eq!(*new, new_shard, "Document moved between existing shards");
                moved += 1;
            }
        }

        // With consistent hashing, only ~25% should move to the new shard
        assert!(moved <= 35, "Too many documents moved: {}", moved);
    }

    #[test]
    fn test_remove_shard() {
        let mut ring = ConsistentHashRing::new(3);
        assert_eq!(ring.shard_count(), 3);

        let removed = ShardId::new(1);
        ring.remove_shard(removed);
        assert_eq!(ring.shard_count(), 2);
        assert!(
            !ring.all_shards().contains(&removed),
            "Removed shard still listed"
        );

        // Documents should still be assignable, and never to the removed shard
        for shard in sample_assignments(&ring) {
            assert!(shard != removed, "Document assigned to removed shard");
        }
    }

    #[test]
    fn test_replica_shards() {
        let ring = ConsistentHashRing::new(5);
        let doc_id = DocumentId::from_seed(42);

        let replicas = ring.get_replica_shards(&doc_id, 2);
        assert_eq!(replicas.len(), 3); // primary + 2 replicas

        // The primary shard comes first
        assert_eq!(replicas[0], ring.get_shard(&doc_id));

        // All replicas should be unique
        let unique: std::collections::HashSet<_> = replicas.iter().collect();
        assert_eq!(unique.len(), 3);
    }

    #[test]
    fn test_custom_virtual_nodes() {
        let ring = ConsistentHashRing::with_virtual_nodes(3, 50);
        assert_eq!(ring.virtual_nodes_per_shard(), 50);
        assert_eq!(ring.total_virtual_nodes(), 150);
    }

    #[test]
    fn test_get_shard_for_key() {
        let ring = ConsistentHashRing::new(3);

        // String keys should work
        let shard1 = ring.get_shard_for_key(&"user:123");
        let shard2 = ring.get_shard_for_key(&"user:123");
        assert_eq!(shard1, shard2);

        // Different keys may go to different shards; only check that the
        // result is one of the shards on the ring.
        let shard3 = ring.get_shard_for_key(&"user:456");
        assert!(ring.all_shards().contains(&shard3));
    }

    #[test]
    #[should_panic(expected = "Number of shards must be greater than 0")]
    fn test_zero_shards_panics() {
        let _ = ConsistentHashRing::new(0);
    }

    #[test]
    #[should_panic(expected = "Virtual nodes must be greater than 0")]
    fn test_zero_virtual_nodes_panics() {
        let _ = ConsistentHashRing::with_virtual_nodes(3, 0);
    }

    #[test]
    fn test_default() {
        let ring = ConsistentHashRing::default();
        assert_eq!(ring.shard_count(), 1);
    }

    #[test]
    fn test_shard_id_display() {
        let shard = ShardId::new(5);
        assert_eq!(format!("{}", shard), "shard-5");
    }
}