phago_distributed/hashing/
mod.rs1use crate::types::ShardId;
8use phago_core::types::DocumentId;
9use std::collections::BTreeMap;
10use std::hash::{Hash, Hasher};
11
/// Number of virtual nodes placed on the ring for each shard when the caller
/// does not pick a count explicitly (this is what `ConsistentHashRing::new` uses).
const VIRTUAL_NODES_PER_SHARD: u32 = 150;
14
/// Consistent-hash ring that maps 64-bit hash positions to shards.
///
/// Every shard is inserted at several positions ("virtual nodes"). A key is
/// served by the shard whose virtual node is the first one at or after the
/// key's hash, wrapping around to the lowest position when there is none.
#[derive(Debug, Clone)]
pub struct ConsistentHashRing {
    /// Virtual-node hash position -> owning shard, kept ordered by position.
    ring: BTreeMap<u64, ShardId>,
    /// Shard counter: set at construction, incremented by `add_shard` and
    /// decremented (saturating at zero) by `remove_shard`.
    shard_count: u32,
    /// Number of virtual nodes inserted for each shard.
    virtual_nodes: u32,
}
47
48impl ConsistentHashRing {
49 pub fn new(num_shards: u32) -> Self {
62 assert!(num_shards > 0, "Number of shards must be greater than 0");
63
64 let mut ring = BTreeMap::new();
65
66 for shard_id in 0..num_shards {
67 for vnode in 0..VIRTUAL_NODES_PER_SHARD {
68 let hash = Self::hash_shard_vnode(shard_id, vnode);
69 ring.insert(hash, ShardId::new(shard_id));
70 }
71 }
72
73 Self {
74 ring,
75 shard_count: num_shards,
76 virtual_nodes: VIRTUAL_NODES_PER_SHARD,
77 }
78 }
79
80 pub fn with_virtual_nodes(num_shards: u32, virtual_nodes: u32) -> Self {
90 assert!(num_shards > 0, "Number of shards must be greater than 0");
91 assert!(virtual_nodes > 0, "Virtual nodes must be greater than 0");
92
93 let mut ring = BTreeMap::new();
94
95 for shard_id in 0..num_shards {
96 for vnode in 0..virtual_nodes {
97 let hash = Self::hash_shard_vnode(shard_id, vnode);
98 ring.insert(hash, ShardId::new(shard_id));
99 }
100 }
101
102 Self {
103 ring,
104 shard_count: num_shards,
105 virtual_nodes,
106 }
107 }
108
109 pub fn get_shard(&self, doc_id: &DocumentId) -> ShardId {
121 let hash = Self::hash_document(doc_id);
122
123 if let Some((&_pos, &shard_id)) = self.ring.range(hash..).next() {
125 shard_id
126 } else {
127 *self.ring.values().next().unwrap_or(&ShardId::new(0))
129 }
130 }
131
132 pub fn get_shard_for_key<K: Hash>(&self, key: &K) -> ShardId {
140 let hash = Self::hash_key(key);
141
142 if let Some((&_pos, &shard_id)) = self.ring.range(hash..).next() {
143 shard_id
144 } else {
145 *self.ring.values().next().unwrap_or(&ShardId::new(0))
146 }
147 }
148
149 pub fn add_shard(&mut self, shard_id: ShardId) {
158 for vnode in 0..self.virtual_nodes {
159 let hash = Self::hash_shard_vnode(shard_id.0, vnode);
160 self.ring.insert(hash, shard_id);
161 }
162 self.shard_count += 1;
163 }
164
165 pub fn remove_shard(&mut self, shard_id: ShardId) {
174 self.ring.retain(|_, &mut sid| sid != shard_id);
175 self.shard_count = self.shard_count.saturating_sub(1);
176 }
177
178 pub fn shard_count(&self) -> u32 {
180 self.shard_count
181 }
182
183 pub fn all_shards(&self) -> Vec<ShardId> {
187 let mut shards: Vec<ShardId> = self.ring.values().copied().collect();
188 shards.sort_by_key(|s| s.0);
189 shards.dedup();
190 shards
191 }
192
193 pub fn virtual_nodes_per_shard(&self) -> u32 {
195 self.virtual_nodes
196 }
197
198 pub fn total_virtual_nodes(&self) -> usize {
200 self.ring.len()
201 }
202
203 pub fn get_replica_shards(&self, doc_id: &DocumentId, replica_count: usize) -> Vec<ShardId> {
217 let hash = Self::hash_document(doc_id);
218 let mut shards = Vec::with_capacity(replica_count + 1);
219 let mut seen_shards = std::collections::HashSet::new();
220
221 for (&_pos, &shard_id) in self.ring.range(hash..).chain(self.ring.iter()) {
223 if seen_shards.insert(shard_id) {
224 shards.push(shard_id);
225 if shards.len() > replica_count {
226 break;
227 }
228 }
229 }
230
231 shards
232 }
233
234 fn hash_document(doc_id: &DocumentId) -> u64 {
236 use std::collections::hash_map::DefaultHasher;
237 let mut hasher = DefaultHasher::new();
238 doc_id.0.hash(&mut hasher);
239 hasher.finish()
240 }
241
242 fn hash_key<K: Hash>(key: &K) -> u64 {
244 use std::collections::hash_map::DefaultHasher;
245 let mut hasher = DefaultHasher::new();
246 key.hash(&mut hasher);
247 hasher.finish()
248 }
249
250 fn hash_shard_vnode(shard_id: u32, vnode: u32) -> u64 {
252 use std::collections::hash_map::DefaultHasher;
253 let mut hasher = DefaultHasher::new();
254 shard_id.hash(&mut hasher);
255 vnode.hash(&mut hasher);
256 hasher.finish()
257 }
258}
259
260impl Default for ConsistentHashRing {
261 fn default() -> Self {
262 Self::new(1)
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh ring reports the requested shard count and the default number
    /// of virtual nodes for every shard.
    #[test]
    fn test_new_ring() {
        let ring = ConsistentHashRing::new(3);
        let expected_vnodes = 3 * VIRTUAL_NODES_PER_SHARD as usize;

        assert_eq!(ring.shard_count(), 3);
        assert_eq!(ring.all_shards().len(), 3);
        assert_eq!(ring.total_virtual_nodes(), expected_vnodes);
    }

    /// 100 seeded documents spread over three shards without heavy skew.
    #[test]
    fn test_distribution() {
        let ring = ConsistentHashRing::new(3);

        let mut counts = [0u32; 3];
        (0..100)
            .map(|seed| ring.get_shard(&DocumentId::from_seed(seed)))
            .for_each(|shard| counts[shard.0 as usize] += 1);

        for &count in counts.iter() {
            assert!(
                (20..=50).contains(&count),
                "Distribution skewed: {:?}",
                counts
            );
        }
    }

    /// Looking the same document up twice yields the same shard.
    #[test]
    fn test_consistency() {
        let ring = ConsistentHashRing::new(3);
        let doc_id = DocumentId::from_seed(42);

        let first = ring.get_shard(&doc_id);
        let second = ring.get_shard(&doc_id);
        assert_eq!(first, second);
    }

    /// Growing the ring from three to four shards relocates only a bounded
    /// share of the documents.
    #[test]
    fn test_add_shard_minimal_redistribution() {
        let mut ring = ConsistentHashRing::new(3);

        let before: Vec<ShardId> = (0..100)
            .map(|seed| ring.get_shard(&DocumentId::from_seed(seed)))
            .collect();

        ring.add_shard(ShardId::new(3));

        let moved = (0..100)
            .zip(before.iter())
            .filter(|&(seed, &old)| ring.get_shard(&DocumentId::from_seed(seed)) != old)
            .count();

        assert!(moved <= 35, "Too many documents moved: {}", moved);
    }

    /// After a shard is removed it is no longer counted and receives no documents.
    #[test]
    fn test_remove_shard() {
        let mut ring = ConsistentHashRing::new(3);
        assert_eq!(ring.shard_count(), 3);

        ring.remove_shard(ShardId::new(1));
        assert_eq!(ring.shard_count(), 2);

        let shard = ring.get_shard(&DocumentId::from_seed(42));
        assert!(shard.0 != 1, "Document assigned to removed shard");
    }

    /// Asking for two replicas yields three pairwise distinct shards.
    #[test]
    fn test_replica_shards() {
        let ring = ConsistentHashRing::new(5);
        let doc_id = DocumentId::from_seed(42);

        let replicas = ring.get_replica_shards(&doc_id, 2);
        assert_eq!(replicas.len(), 3);

        let distinct = replicas.iter().collect::<std::collections::HashSet<_>>();
        assert_eq!(distinct.len(), 3);
    }

    /// A custom virtual-node count is honoured for every shard.
    #[test]
    fn test_custom_virtual_nodes() {
        let ring = ConsistentHashRing::with_virtual_nodes(3, 50);

        assert_eq!(ring.virtual_nodes_per_shard(), 50);
        assert_eq!(ring.total_virtual_nodes(), 150);
    }

    /// Arbitrary hashable keys are routed deterministically.
    #[test]
    fn test_get_shard_for_key() {
        let ring = ConsistentHashRing::new(3);

        let first = ring.get_shard_for_key(&"user:123");
        let second = ring.get_shard_for_key(&"user:123");
        assert_eq!(first, second);

        // A different key only has to resolve without panicking; it may or
        // may not land on the same shard.
        let _ = ring.get_shard_for_key(&"user:456");
    }

    /// Constructing a ring with zero shards is rejected.
    #[test]
    #[should_panic(expected = "Number of shards must be greater than 0")]
    fn test_zero_shards_panics() {
        let _ = ConsistentHashRing::new(0);
    }

    /// The default ring has exactly one shard.
    #[test]
    fn test_default() {
        assert_eq!(ConsistentHashRing::default().shard_count(), 1);
    }

    /// `ShardId` renders as `shard-<n>`.
    #[test]
    fn test_shard_id_display() {
        let rendered = ShardId::new(5).to_string();
        assert_eq!(rendered, "shard-5");
    }
}