1use crate::node::NodeId;
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeMap;
11use std::hash::{Hash, Hasher};
12
13#[derive(Debug, Clone)]
19pub struct HashRing {
20 ring: BTreeMap<u64, VirtualNode>,
21 virtual_nodes_per_node: usize,
22 nodes: Vec<NodeId>,
23}
24
25impl HashRing {
26 pub fn new(virtual_nodes_per_node: usize) -> Self {
28 Self {
29 ring: BTreeMap::new(),
30 virtual_nodes_per_node,
31 nodes: Vec::new(),
32 }
33 }
34
35 pub fn default_ring() -> Self {
37 Self::new(150)
38 }
39
40 pub fn add_node(&mut self, node_id: NodeId) {
42 if self.nodes.contains(&node_id) {
43 return;
44 }
45
46 for i in 0..self.virtual_nodes_per_node {
47 let vnode = VirtualNode {
48 node_id: node_id.clone(),
49 replica_index: i,
50 };
51 let hash = vnode.hash_position();
52 self.ring.insert(hash, vnode);
53 }
54
55 self.nodes.push(node_id);
56 }
57
58 pub fn remove_node(&mut self, node_id: &NodeId) {
60 self.ring.retain(|_, vnode| &vnode.node_id != node_id);
61 self.nodes.retain(|n| n != node_id);
62 }
63
64 pub fn get_node(&self, key: &str) -> Option<&NodeId> {
66 if self.ring.is_empty() {
67 return None;
68 }
69
70 let hash = hash_key(key);
71
72 if let Some((_, vnode)) = self.ring.range(hash..).next() {
74 return Some(&vnode.node_id);
75 }
76
77 self.ring.values().next().map(|vnode| &vnode.node_id)
79 }
80
81 pub fn get_nodes(&self, key: &str, count: usize) -> Vec<&NodeId> {
83 if self.ring.is_empty() || count == 0 {
84 return Vec::new();
85 }
86
87 let hash = hash_key(key);
88 let mut result = Vec::with_capacity(count);
89 let mut seen = std::collections::HashSet::new();
90
91 for (_, vnode) in self.ring.range(hash..) {
93 if seen.insert(&vnode.node_id) {
94 result.push(&vnode.node_id);
95 if result.len() >= count {
96 return result;
97 }
98 }
99 }
100
101 for (_, vnode) in self.ring.iter() {
103 if seen.insert(&vnode.node_id) {
104 result.push(&vnode.node_id);
105 if result.len() >= count {
106 return result;
107 }
108 }
109 }
110
111 result
112 }
113
114 pub fn nodes(&self) -> &[NodeId] {
116 &self.nodes
117 }
118
119 pub fn node_count(&self) -> usize {
121 self.nodes.len()
122 }
123
124 pub fn virtual_node_count(&self) -> usize {
126 self.ring.len()
127 }
128
129 pub fn is_empty(&self) -> bool {
131 self.nodes.is_empty()
132 }
133
134 pub fn key_position(&self, key: &str) -> u64 {
136 hash_key(key)
137 }
138}
139
140impl Default for HashRing {
141 fn default() -> Self {
142 Self::default_ring()
143 }
144}
145
/// One of a physical node's positions on a [`HashRing`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VirtualNode {
    /// The physical node this virtual node belongs to.
    pub node_id: NodeId,
    /// Which of the node's virtual nodes this is. It is combined with `node_id`
    /// to derive the ring position.
    pub replica_index: usize,
}
156
157impl VirtualNode {
158 pub fn hash_position(&self) -> u64 {
160 let key = format!("{}:{}", self.node_id.as_str(), self.replica_index);
161 hash_key(&key)
162 }
163}
164
/// Common interface for key-to-node routing strategies that return owned node ids.
pub trait ConsistentHash {
    /// Returns the node responsible for `key`, or `None` if no node can be chosen.
    fn route(&self, key: &str) -> Option<NodeId>;

    /// Returns up to `count` nodes for `key`.
    ///
    /// The `HashRing` implementation returns distinct nodes, primary first.
    fn route_replicas(&self, key: &str, count: usize) -> Vec<NodeId>;

    /// Adds `node` to the set of routing targets.
    fn add(&mut self, node: NodeId);

    /// Removes `node` from the set of routing targets.
    fn remove(&mut self, node: &NodeId);
}
183
184impl ConsistentHash for HashRing {
185 fn route(&self, key: &str) -> Option<NodeId> {
186 self.get_node(key).cloned()
187 }
188
189 fn route_replicas(&self, key: &str, count: usize) -> Vec<NodeId> {
190 self.get_nodes(key, count).into_iter().cloned().collect()
191 }
192
193 fn add(&mut self, node: NodeId) {
194 self.add_node(node);
195 }
196
197 fn remove(&mut self, node: &NodeId) {
198 self.remove_node(node);
199 }
200}
201
202pub struct JumpHash {
208 num_buckets: u32,
209}
210
211impl JumpHash {
212 pub fn new(num_buckets: u32) -> Self {
214 Self { num_buckets }
215 }
216
217 pub fn bucket(&self, key: &str) -> u32 {
219 let hash = hash_key(key);
220 jump_consistent_hash(hash, self.num_buckets)
221 }
222
223 pub fn bucket_u64(&self, key: u64) -> u32 {
225 jump_consistent_hash(key, self.num_buckets)
226 }
227}
228
/// Jump consistent hash (Lamping & Veach): returns a bucket in `0..num_buckets` for `key`.
///
/// The sequence of candidate jumps depends only on `key`. Therefore growing
/// `num_buckets` from `n` to `n + 1` only moves keys into the new bucket `n`;
/// every other key keeps its bucket.
///
/// # Panics
///
/// Panics if `num_buckets` is zero, since there is no bucket to return.
fn jump_consistent_hash(mut key: u64, num_buckets: u32) -> u32 {
    assert!(
        num_buckets > 0,
        "jump consistent hash requires at least one bucket"
    );
    let num_buckets = i64::from(num_buckets);

    // `b` is the last bucket the key jumped to and `j` is the next candidate.
    // Because `num_buckets > 0`, the loop runs at least once, so the initial
    // value of `b` is never returned.
    let mut b: i64 = 0;
    let mut j: i64 = 0;

    while j < num_buckets {
        b = j;
        // 64-bit linear congruential step; its top 31 bits drive the jump length.
        key = key.wrapping_mul(2862933555777941757).wrapping_add(1);
        j = ((b + 1) as f64 * (1i64 << 31) as f64 / ((key >> 33) + 1) as f64) as i64;
    }

    // `0 <= b < num_buckets <= u32::MAX`, so this cast is lossless.
    b as u32
}
243
244pub struct RendezvousHash {
250 nodes: Vec<NodeId>,
251}
252
253impl RendezvousHash {
254 pub fn new() -> Self {
256 Self { nodes: Vec::new() }
257 }
258
259 pub fn add_node(&mut self, node: NodeId) {
261 if !self.nodes.contains(&node) {
262 self.nodes.push(node);
263 }
264 }
265
266 pub fn remove_node(&mut self, node: &NodeId) {
268 self.nodes.retain(|n| n != node);
269 }
270
271 pub fn get_node(&self, key: &str) -> Option<&NodeId> {
273 self.nodes.iter().max_by_key(|node| {
274 let combined = format!("{}:{}", key, node.as_str());
275 hash_key(&combined)
276 })
277 }
278
279 pub fn get_nodes(&self, key: &str, count: usize) -> Vec<&NodeId> {
281 let mut weighted: Vec<_> = self
282 .nodes
283 .iter()
284 .map(|node| {
285 let combined = format!("{}:{}", key, node.as_str());
286 (hash_key(&combined), node)
287 })
288 .collect();
289
290 weighted.sort_by(|a, b| b.0.cmp(&a.0));
291
292 weighted
293 .into_iter()
294 .take(count)
295 .map(|(_, node)| node)
296 .collect()
297 }
298}
299
300impl Default for RendezvousHash {
301 fn default() -> Self {
302 Self::new()
303 }
304}
305
/// Hashes `key` to the 64-bit value used in this module for ring positions,
/// jump-hash input and rendezvous scores.
///
/// `key` is fed through the standard `Hash` impl for `str` into [`XxHasher`].
/// NOTE(review): the exact bytes fed to the hasher come from std's
/// `Hash for str` impl. If placements must stay stable across toolchains,
/// hashing `key.as_bytes()` directly would remove that dependency.
fn hash_key(key: &str) -> u64 {
    let mut state = XxHasher::new();
    Hash::hash(key, &mut state);
    state.finish()
}

/// Minimal [`Hasher`] used by [`hash_key`].
///
/// Despite the name, this is 64-bit FNV-1a, not xxHash: offset basis
/// `0xcbf29ce484222325`, prime `0x100000001b3`, and XOR-then-multiply per byte.
/// It has no per-process seed, so equal input always yields an equal hash.
struct XxHasher {
    state: u64,
}

impl XxHasher {
    /// FNV-1a 64-bit offset basis.
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    /// FNV-1a 64-bit prime.
    const PRIME: u64 = 0x100000001b3;

    /// Starts a fresh hash at the FNV-1a offset basis.
    fn new() -> Self {
        XxHasher {
            state: Self::OFFSET_BASIS,
        }
    }
}

impl Hasher for XxHasher {
    fn finish(&self) -> u64 {
        self.state
    }

    /// Folds each byte into the state: XOR it in, then multiply by the prime
    /// (wrapping, i.e. mod 2^64).
    fn write(&mut self, bytes: &[u8]) {
        self.state = bytes.iter().fold(self.state, |acc, &byte| {
            (acc ^ u64::from(byte)).wrapping_mul(Self::PRIME)
        });
    }
}
342
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_ring_basic() {
        let mut ring = HashRing::new(10);

        ring.add_node(NodeId::new("node1"));
        ring.add_node(NodeId::new("node2"));
        ring.add_node(NodeId::new("node3"));

        assert_eq!(ring.node_count(), 3);
        assert_eq!(ring.virtual_node_count(), 30);
    }

    #[test]
    fn test_hash_ring_add_is_idempotent() {
        let mut ring = HashRing::new(10);

        ring.add_node(NodeId::new("node1"));
        ring.add_node(NodeId::new("node1"));

        assert_eq!(ring.node_count(), 1);
        assert_eq!(ring.virtual_node_count(), 10);
    }

    #[test]
    fn test_hash_ring_empty() {
        let ring = HashRing::new(10);

        assert!(ring.is_empty());
        assert!(ring.get_node("key").is_none());
        assert!(ring.get_nodes("key", 3).is_empty());
    }

    #[test]
    fn test_hash_ring_get_node() {
        let mut ring = HashRing::new(100);

        ring.add_node(NodeId::new("node1"));
        ring.add_node(NodeId::new("node2"));
        ring.add_node(NodeId::new("node3"));

        let node = ring.get_node("test_key").unwrap();
        assert!(["node1", "node2", "node3"].contains(&node.as_str()));

        // Routing is deterministic: the same key maps to the same node.
        let node2 = ring.get_node("test_key").unwrap();
        assert_eq!(node, node2);
    }

    #[test]
    fn test_hash_ring_distribution() {
        let mut ring = HashRing::new(150);

        ring.add_node(NodeId::new("node1"));
        ring.add_node(NodeId::new("node2"));
        ring.add_node(NodeId::new("node3"));

        let mut counts = std::collections::HashMap::new();

        for i in 0..1000 {
            let key = format!("key_{}", i);
            let node = ring.get_node(&key).unwrap();
            *counts.entry(node.as_str().to_string()).or_insert(0) += 1;
        }

        // With three nodes, each should receive roughly a third (~333) of the keys.
        for count in counts.values() {
            assert!(*count > 200, "Distribution too uneven: {:?}", counts);
            assert!(*count < 500, "Distribution too uneven: {:?}", counts);
        }
    }

    #[test]
    fn test_hash_ring_remove_node() {
        let mut ring = HashRing::new(10);

        ring.add_node(NodeId::new("node1"));
        ring.add_node(NodeId::new("node2"));
        ring.add_node(NodeId::new("node3"));

        let keys: Vec<String> = (0..200).map(|i| format!("key_{}", i)).collect();
        let before: Vec<NodeId> = keys
            .iter()
            .map(|k| ring.get_node(k).unwrap().clone())
            .collect();

        ring.remove_node(&NodeId::new("node1"));
        assert_eq!(ring.node_count(), 2);
        assert_eq!(ring.virtual_node_count(), 20);

        for (key, owner) in keys.iter().zip(&before) {
            let after = ring.get_node(key).unwrap();
            assert_ne!(after.as_str(), "node1");
            // Consistency: keys owned by surviving nodes must not move.
            if owner.as_str() != "node1" {
                assert_eq!(after, owner, "key {} moved", key);
            }
        }
    }

    #[test]
    fn test_hash_ring_get_replicas() {
        let mut ring = HashRing::new(50);

        ring.add_node(NodeId::new("node1"));
        ring.add_node(NodeId::new("node2"));
        ring.add_node(NodeId::new("node3"));

        let nodes = ring.get_nodes("test_key", 2);
        assert_eq!(nodes.len(), 2);
        assert_ne!(nodes[0], nodes[1]);
    }

    #[test]
    fn test_hash_ring_get_nodes_stops_at_node_count() {
        let mut ring = HashRing::new(50);

        ring.add_node(NodeId::new("node1"));
        ring.add_node(NodeId::new("node2"));
        ring.add_node(NodeId::new("node3"));

        let nodes = ring.get_nodes("test_key", 10);
        assert_eq!(nodes.len(), 3);

        let distinct: std::collections::HashSet<_> = nodes.iter().collect();
        assert_eq!(distinct.len(), 3);

        // The primary replica is the node `get_node` routes to.
        assert_eq!(nodes[0], ring.get_node("test_key").unwrap());
    }

    #[test]
    fn test_jump_hash() {
        let hash = JumpHash::new(10);

        let bucket1 = hash.bucket("key1");
        let bucket2 = hash.bucket("key1");

        assert_eq!(bucket1, bucket2);
        assert!(bucket1 < 10);
    }

    #[test]
    fn test_jump_hash_distribution() {
        let hash = JumpHash::new(5);
        let mut counts = vec![0; 5];

        for i in 0..1000 {
            let bucket = hash.bucket(&format!("key_{}", i)) as usize;
            counts[bucket] += 1;
        }

        // With five buckets, each should receive roughly 200 of the keys.
        for count in &counts {
            assert!(*count > 100, "Jump hash distribution uneven: {:?}", counts);
            assert!(*count < 300, "Jump hash distribution uneven: {:?}", counts);
        }
    }

    #[test]
    fn test_jump_hash_growth_only_moves_to_new_bucket() {
        let small = JumpHash::new(10);
        let large = JumpHash::new(11);

        for i in 0..1000 {
            let key = format!("key_{}", i);
            let before = small.bucket(&key);
            let after = large.bucket(&key);
            assert!(
                after == before || after == 10,
                "key {} moved from {} to {}",
                key,
                before,
                after
            );
        }
    }

    #[test]
    fn test_rendezvous_hash() {
        let mut hash = RendezvousHash::new();

        hash.add_node(NodeId::new("node1"));
        hash.add_node(NodeId::new("node2"));
        hash.add_node(NodeId::new("node3"));

        let node = hash.get_node("test_key").unwrap();
        assert!(["node1", "node2", "node3"].contains(&node.as_str()));

        // Routing is deterministic: the same key maps to the same node.
        let node2 = hash.get_node("test_key").unwrap();
        assert_eq!(node, node2);
    }

    #[test]
    fn test_rendezvous_get_multiple() {
        let mut hash = RendezvousHash::new();

        hash.add_node(NodeId::new("node1"));
        hash.add_node(NodeId::new("node2"));
        hash.add_node(NodeId::new("node3"));

        let nodes = hash.get_nodes("test_key", 2);
        assert_eq!(nodes.len(), 2);
        assert_ne!(nodes[0], nodes[1]);
    }

    #[test]
    fn test_rendezvous_get_node_matches_get_nodes() {
        let mut hash = RendezvousHash::new();

        hash.add_node(NodeId::new("node1"));
        hash.add_node(NodeId::new("node2"));
        hash.add_node(NodeId::new("node3"));

        for i in 0..200 {
            let key = format!("key_{}", i);
            assert_eq!(hash.get_node(&key), hash.get_nodes(&key, 1).first().copied());
        }
    }

    #[test]
    fn test_rendezvous_remove_only_moves_removed_keys() {
        let mut hash = RendezvousHash::new();

        hash.add_node(NodeId::new("node1"));
        hash.add_node(NodeId::new("node2"));
        hash.add_node(NodeId::new("node3"));

        let keys: Vec<String> = (0..200).map(|i| format!("key_{}", i)).collect();
        let before: Vec<NodeId> = keys
            .iter()
            .map(|k| hash.get_node(k).unwrap().clone())
            .collect();

        hash.remove_node(&NodeId::new("node1"));

        for (key, owner) in keys.iter().zip(&before) {
            let after = hash.get_node(key).unwrap();
            assert_ne!(after.as_str(), "node1");
            if owner.as_str() != "node1" {
                assert_eq!(after, owner, "key {} moved", key);
            }
        }
    }

    #[test]
    fn test_consistent_hash_trait() {
        let mut ring = HashRing::new(50);

        ring.add(NodeId::new("node1"));
        ring.add(NodeId::new("node2"));

        let node = ring.route("key").unwrap();
        assert!(["node1", "node2"].contains(&node.as_str()));

        let replicas = ring.route_replicas("key", 2);
        assert_eq!(replicas.len(), 2);
    }
}
502}