1use crate::node::NodeId;
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeMap;
11use std::hash::{Hash, Hasher};
12
/// Consistent-hash ring that maps string keys to nodes through virtual nodes.
///
/// Each physical node is represented by `virtual_nodes_per_node` points on a 64-bit ring; a key
/// is owned by the first point at or after its own hash position, wrapping around at the end.
#[derive(Debug, Clone)]
pub struct HashRing {
    /// Ring position (the virtual node's `hash_position()`) -> the virtual node placed there.
    ring: BTreeMap<u64, VirtualNode>,
    /// How many virtual nodes `add_node` inserts for each physical node.
    virtual_nodes_per_node: usize,
    /// Physical nodes in insertion order; `add_node` skips ids that are already present.
    nodes: Vec<NodeId>,
}
24
25impl HashRing {
26 pub fn new(virtual_nodes_per_node: usize) -> Self {
28 Self {
29 ring: BTreeMap::new(),
30 virtual_nodes_per_node,
31 nodes: Vec::new(),
32 }
33 }
34
35 pub fn default_ring() -> Self {
37 Self::new(150)
38 }
39
40 pub fn add_node(&mut self, node_id: NodeId) {
42 if self.nodes.contains(&node_id) {
43 return;
44 }
45
46 for i in 0..self.virtual_nodes_per_node {
47 let vnode = VirtualNode {
48 node_id: node_id.clone(),
49 replica_index: i,
50 };
51 let hash = vnode.hash_position();
52 self.ring.insert(hash, vnode);
53 }
54
55 self.nodes.push(node_id);
56 }
57
58 pub fn remove_node(&mut self, node_id: &NodeId) {
60 self.ring.retain(|_, vnode| &vnode.node_id != node_id);
61 self.nodes.retain(|n| n != node_id);
62 }
63
64 pub fn get_node(&self, key: &str) -> Option<&NodeId> {
66 if self.ring.is_empty() {
67 return None;
68 }
69
70 let hash = hash_key(key);
71
72 if let Some((_, vnode)) = self.ring.range(hash..).next() {
74 return Some(&vnode.node_id);
75 }
76
77 self.ring.values().next().map(|vnode| &vnode.node_id)
79 }
80
81 pub fn get_nodes(&self, key: &str, count: usize) -> Vec<&NodeId> {
83 if self.ring.is_empty() || count == 0 {
84 return Vec::new();
85 }
86
87 let hash = hash_key(key);
88 let mut result = Vec::with_capacity(count);
89 let mut seen = std::collections::HashSet::new();
90
91 for (_, vnode) in self.ring.range(hash..) {
93 if seen.insert(&vnode.node_id) {
94 result.push(&vnode.node_id);
95 if result.len() >= count {
96 return result;
97 }
98 }
99 }
100
101 for (_, vnode) in self.ring.iter() {
103 if seen.insert(&vnode.node_id) {
104 result.push(&vnode.node_id);
105 if result.len() >= count {
106 return result;
107 }
108 }
109 }
110
111 result
112 }
113
114 pub fn nodes(&self) -> &[NodeId] {
116 &self.nodes
117 }
118
119 pub fn node_count(&self) -> usize {
121 self.nodes.len()
122 }
123
124 pub fn virtual_node_count(&self) -> usize {
126 self.ring.len()
127 }
128
129 pub fn is_empty(&self) -> bool {
131 self.nodes.is_empty()
132 }
133
134 pub fn key_position(&self, key: &str) -> u64 {
136 hash_key(key)
137 }
138}
139
140impl Default for HashRing {
141 fn default() -> Self {
142 Self::default_ring()
143 }
144}
145
/// One point on a [`HashRing`], owned by a physical node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VirtualNode {
    /// The physical node that keys landing on this point are routed to.
    pub node_id: NodeId,
    /// Distinguishes this node's points; `HashRing::add_node` assigns `0..virtual_nodes_per_node`.
    pub replica_index: usize,
}
156
157impl VirtualNode {
158 pub fn hash_position(&self) -> u64 {
160 let key = format!("{}:{}", self.node_id.as_str(), self.replica_index);
161 hash_key(&key)
162 }
163}
164
/// Common interface for routing string keys to nodes.
///
/// Methods return owned [`NodeId`]s, so results do not borrow the router.
pub trait ConsistentHash {
    /// Returns the node responsible for `key`, or `None` when none can be chosen
    /// (the `HashRing` impl returns `None` when its ring has no points).
    fn route(&self, key: &str) -> Option<NodeId>;

    /// Returns up to `count` nodes for `key`, in the implementation's preference order.
    /// NOTE(review): the `HashRing` impl returns distinct nodes in clockwise ring order with the
    /// `route` result first; other implementations are presumably expected to match — confirm.
    fn route_replicas(&self, key: &str, count: usize) -> Vec<NodeId>;

    /// Registers `node` with the router.
    fn add(&mut self, node: NodeId);

    /// Unregisters `node` from the router.
    fn remove(&mut self, node: &NodeId);
}
183
184impl ConsistentHash for HashRing {
185 fn route(&self, key: &str) -> Option<NodeId> {
186 self.get_node(key).cloned()
187 }
188
189 fn route_replicas(&self, key: &str, count: usize) -> Vec<NodeId> {
190 self.get_nodes(key, count).into_iter().cloned().collect()
191 }
192
193 fn add(&mut self, node: NodeId) {
194 self.add_node(node);
195 }
196
197 fn remove(&mut self, node: &NodeId) {
198 self.remove_node(node);
199 }
200}
201
/// Jump consistent hash over a fixed number of buckets, numbered `0..num_buckets`.
///
/// It keeps no per-node state; it only returns bucket indices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct JumpHash {
    num_buckets: u32,
}
210
211impl JumpHash {
212 pub fn new(num_buckets: u32) -> Self {
214 Self { num_buckets }
215 }
216
217 pub fn bucket(&self, key: &str) -> u32 {
219 let hash = hash_key(key);
220 jump_consistent_hash(hash, self.num_buckets)
221 }
222
223 pub fn bucket_u64(&self, key: u64) -> u32 {
225 jump_consistent_hash(key, self.num_buckets)
226 }
227}
228
/// Maps `key` to a bucket in `0..num_buckets` using Lamping & Veach's jump consistent hash.
///
/// The sequence of candidate buckets depends only on `key`, so growing `num_buckets` from `n`
/// to `n + 1` either keeps a key's bucket or moves it to the new bucket `n`.
///
/// # Panics
///
/// Panics if `num_buckets` is `0`. Without this check the loop never runs and the final cast
/// turns `b == -1` into `u32::MAX`, an out-of-range bucket.
fn jump_consistent_hash(mut key: u64, num_buckets: u32) -> u32 {
    assert!(
        num_buckets > 0,
        "jump consistent hash requires at least one bucket"
    );

    let num_buckets = i64::from(num_buckets);
    let mut b: i64 = -1;
    let mut j: i64 = 0;

    while j < num_buckets {
        b = j;
        // 64-bit linear congruential step from the reference implementation.
        key = key.wrapping_mul(2862933555777941757).wrapping_add(1);
        // Next candidate: (b + 1) / r, where r = ((key >> 33) + 1) / 2^31 lies in (0, 1].
        // Neither `+ 1` can overflow: b < 2^32 and key >> 33 < 2^31.
        j = ((b + 1) as f64 * (1i64 << 31) as f64 / ((key >> 33) + 1) as f64) as i64;
    }

    // The loop ran at least once, so 0 <= b < num_buckets and the cast is lossless.
    b as u32
}
243
244pub struct RendezvousHash {
250 nodes: Vec<NodeId>,
251}
252
253impl RendezvousHash {
254 pub fn new() -> Self {
256 Self { nodes: Vec::new() }
257 }
258
259 pub fn add_node(&mut self, node: NodeId) {
261 if !self.nodes.contains(&node) {
262 self.nodes.push(node);
263 }
264 }
265
266 pub fn remove_node(&mut self, node: &NodeId) {
268 self.nodes.retain(|n| n != node);
269 }
270
271 pub fn get_node(&self, key: &str) -> Option<&NodeId> {
273 self.nodes
274 .iter()
275 .max_by_key(|node| {
276 let combined = format!("{}:{}", key, node.as_str());
277 hash_key(&combined)
278 })
279 }
280
281 pub fn get_nodes(&self, key: &str, count: usize) -> Vec<&NodeId> {
283 let mut weighted: Vec<_> = self
284 .nodes
285 .iter()
286 .map(|node| {
287 let combined = format!("{}:{}", key, node.as_str());
288 (hash_key(&combined), node)
289 })
290 .collect();
291
292 weighted.sort_by(|a, b| b.0.cmp(&a.0));
293
294 weighted.into_iter().take(count).map(|(_, node)| node).collect()
295 }
296}
297
298impl Default for RendezvousHash {
299 fn default() -> Self {
300 Self::new()
301 }
302}
303
/// Hashes `key` to a 64-bit value. Every hashing strategy in this module goes through it.
///
/// The bytes fed to [`XxHasher`] are fixed by [`RingKey`]'s `Hash` impl: the UTF-8 bytes of
/// `key` followed by a `0xff` terminator. They are deliberately not left to `<str as Hash>`.
/// The standard library does not promise that its `Hash` impls write the same bytes across
/// compiler versions, yet ring positions must agree between every process that routes keys.
/// The terminator matches what `str` currently writes, so existing positions are unchanged.
fn hash_key(key: &str) -> u64 {
    let mut hasher = XxHasher::new();
    RingKey(key).hash(&mut hasher);
    hasher.finish()
}

/// A string key whose [`Hash`] impl is owned by this module, pinning the exact bytes written.
struct RingKey<'a>(&'a str);

impl<'a> Hash for RingKey<'a> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        state.write(self.0.as_bytes());
        // Terminator byte: the same one `str`'s std `Hash` impl appends today.
        state.write_u8(0xff);
    }
}

/// Streaming 64-bit FNV-1a hasher.
///
/// Despite its name, this is not xxHash. The constants are the standard 64-bit FNV offset basis
/// and prime, and each byte is XORed in before multiplying (the "1a" variant). There is no
/// random seed, so results are reproducible across processes.
struct XxHasher {
    state: u64,
}

impl XxHasher {
    /// 64-bit FNV offset basis.
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    /// 64-bit FNV prime.
    const PRIME: u64 = 0x100000001b3;

    fn new() -> Self {
        Self {
            state: Self::OFFSET_BASIS,
        }
    }
}

impl Hasher for XxHasher {
    fn finish(&self) -> u64 {
        self.state
    }

    fn write(&mut self, bytes: &[u8]) {
        for &byte in bytes {
            self.state = (self.state ^ u64::from(byte)).wrapping_mul(Self::PRIME);
        }
    }
}
340
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    /// Builds a `HashRing` with `vnodes` virtual nodes per node containing `names`.
    fn ring_with(vnodes: usize, names: &[&'static str]) -> HashRing {
        let mut ring = HashRing::new(vnodes);
        for name in names {
            ring.add_node(NodeId::new(*name));
        }
        ring
    }

    /// Builds a `RendezvousHash` containing `names`.
    fn rendezvous_with(names: &[&'static str]) -> RendezvousHash {
        let mut hash = RendezvousHash::new();
        for name in names {
            hash.add_node(NodeId::new(*name));
        }
        hash
    }

    /// Deterministic sample keys `key_0 .. key_{n-1}`.
    fn sample_keys(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("key_{}", i)).collect()
    }

    #[test]
    fn test_hash_ring_basic() {
        let ring = ring_with(10, &["node1", "node2", "node3"]);

        assert_eq!(ring.node_count(), 3);
        assert_eq!(ring.virtual_node_count(), 30);
    }

    #[test]
    fn test_hash_ring_get_node() {
        let ring = ring_with(100, &["node1", "node2", "node3"]);

        let node = ring.get_node("test_key").unwrap();
        assert!(["node1", "node2", "node3"].contains(&node.as_str()));

        // Lookups are deterministic.
        let node2 = ring.get_node("test_key").unwrap();
        assert_eq!(node, node2);
    }

    #[test]
    fn test_hash_ring_distribution() {
        let ring = ring_with(150, &["node1", "node2", "node3"]);

        let mut counts = HashMap::new();
        for key in sample_keys(1000) {
            let node = ring.get_node(&key).unwrap();
            *counts.entry(node.as_str().to_string()).or_insert(0) += 1;
        }

        // With 3 nodes and 1000 keys, each node should get a roughly equal share.
        for count in counts.values() {
            assert!(*count > 200, "Distribution too uneven: {:?}", counts);
            assert!(*count < 500, "Distribution too uneven: {:?}", counts);
        }
    }

    #[test]
    fn test_hash_ring_remove_node() {
        let mut ring = ring_with(10, &["node1", "node2"]);

        let key = "test_key";
        assert!(ring.get_node(key).is_some());

        ring.remove_node(&NodeId::new("node1"));
        assert_eq!(ring.node_count(), 1);
        assert_eq!(ring.virtual_node_count(), 10);

        let after = ring.get_node(key).unwrap();
        assert_eq!(after.as_str(), "node2");
    }

    #[test]
    fn test_hash_ring_remove_only_remaps_removed_nodes_keys() {
        let mut ring = ring_with(100, &["node1", "node2", "node3"]);
        let keys = sample_keys(500);
        let before: Vec<NodeId> = keys
            .iter()
            .map(|k| ring.get_node(k).unwrap().clone())
            .collect();

        let removed = NodeId::new("node1");
        ring.remove_node(&removed);

        for (key, old) in keys.iter().zip(&before) {
            let new = ring.get_node(key).unwrap();
            if *old == removed {
                assert_ne!(new, &removed);
            } else {
                assert_eq!(new, old, "key {} moved although its node stayed", key);
            }
        }
    }

    #[test]
    fn test_hash_ring_add_only_moves_keys_to_new_node() {
        let mut ring = ring_with(100, &["node1", "node2", "node3"]);
        let keys = sample_keys(500);
        let before: Vec<NodeId> = keys
            .iter()
            .map(|k| ring.get_node(k).unwrap().clone())
            .collect();

        ring.add_node(NodeId::new("node4"));

        for (key, old) in keys.iter().zip(&before) {
            let new = ring.get_node(key).unwrap();
            if new != old {
                assert_eq!(new.as_str(), "node4", "key {} moved to an old node", key);
            }
        }
    }

    #[test]
    fn test_hash_ring_duplicate_add_is_ignored() {
        let mut ring = HashRing::new(10);
        ring.add_node(NodeId::new("node1"));
        ring.add_node(NodeId::new("node1"));

        assert_eq!(ring.node_count(), 1);
        assert_eq!(ring.virtual_node_count(), 10);
    }

    #[test]
    fn test_hash_ring_empty() {
        let ring = HashRing::default();

        assert!(ring.is_empty());
        assert!(ring.get_node("key").is_none());
        assert!(ring.get_nodes("key", 3).is_empty());
    }

    #[test]
    fn test_hash_ring_get_replicas() {
        let ring = ring_with(50, &["node1", "node2", "node3"]);

        let nodes = ring.get_nodes("test_key", 2);
        assert_eq!(nodes.len(), 2);
        assert_ne!(nodes[0], nodes[1]);
    }

    #[test]
    fn test_hash_ring_replicas_capped_at_node_count() {
        let ring = ring_with(50, &["node1", "node2", "node3"]);

        for key in sample_keys(100) {
            let replicas = ring.get_nodes(&key, 10);
            assert_eq!(replicas.len(), 3);
            let distinct: HashSet<_> = replicas.iter().collect();
            assert_eq!(distinct.len(), 3);
            assert_eq!(replicas[0], ring.get_node(&key).unwrap());
        }
        assert!(ring.get_nodes("test_key", 0).is_empty());
    }

    #[test]
    fn test_jump_hash() {
        let hash = JumpHash::new(10);

        let bucket1 = hash.bucket("key1");
        let bucket2 = hash.bucket("key1");

        assert_eq!(bucket1, bucket2);
        assert!(bucket1 < 10);
    }

    #[test]
    fn test_jump_hash_distribution() {
        let hash = JumpHash::new(5);
        let mut counts = vec![0; 5];

        for key in sample_keys(1000) {
            let bucket = hash.bucket(&key) as usize;
            counts[bucket] += 1;
        }

        // With 5 buckets and 1000 keys, expect ~200 each.
        for count in &counts {
            assert!(*count > 100, "Jump hash distribution uneven: {:?}", counts);
            assert!(*count < 300, "Jump hash distribution uneven: {:?}", counts);
        }
    }

    #[test]
    fn test_jump_hash_growth_moves_keys_only_to_new_bucket() {
        for key in sample_keys(200) {
            let mut prev = JumpHash::new(1).bucket(&key);
            assert_eq!(prev, 0);
            for n in 2..=32u32 {
                let cur = JumpHash::new(n).bucket(&key);
                assert!(
                    cur == prev || cur == n - 1,
                    "key {} jumped from {} to {} at n={}",
                    key,
                    prev,
                    cur,
                    n
                );
                prev = cur;
            }
        }
    }

    #[test]
    fn test_rendezvous_hash() {
        let hash = rendezvous_with(&["node1", "node2", "node3"]);

        let node = hash.get_node("test_key").unwrap();
        assert!(["node1", "node2", "node3"].contains(&node.as_str()));

        // Lookups are deterministic.
        let node2 = hash.get_node("test_key").unwrap();
        assert_eq!(node, node2);
    }

    #[test]
    fn test_rendezvous_get_multiple() {
        let hash = rendezvous_with(&["node1", "node2", "node3"]);

        let nodes = hash.get_nodes("test_key", 2);
        assert_eq!(nodes.len(), 2);
        assert_ne!(nodes[0], nodes[1]);
    }

    #[test]
    fn test_rendezvous_get_node_matches_first_replica() {
        let hash = rendezvous_with(&["node1", "node2", "node3"]);

        for key in sample_keys(200) {
            assert_eq!(hash.get_node(&key), hash.get_nodes(&key, 1).first().copied());
        }
        assert!(RendezvousHash::default().get_node("key").is_none());
    }

    #[test]
    fn test_rendezvous_remove_only_remaps_removed_nodes_keys() {
        let mut hash = rendezvous_with(&["node1", "node2", "node3"]);
        let keys = sample_keys(300);
        let before: Vec<NodeId> = keys
            .iter()
            .map(|k| hash.get_node(k).unwrap().clone())
            .collect();

        let removed = NodeId::new("node2");
        hash.remove_node(&removed);

        for (key, old) in keys.iter().zip(&before) {
            let new = hash.get_node(key).unwrap();
            if *old == removed {
                assert_ne!(new, &removed);
            } else {
                assert_eq!(new, old, "key {} moved although its node stayed", key);
            }
        }
    }

    #[test]
    fn test_consistent_hash_trait() {
        let mut ring = HashRing::new(50);

        ring.add(NodeId::new("node1"));
        ring.add(NodeId::new("node2"));

        let node = ring.route("key").unwrap();
        assert!(["node1", "node2"].contains(&node.as_str()));
        assert_eq!(Some(&node), ring.get_node("key"));

        let replicas = ring.route_replicas("key", 2);
        assert_eq!(replicas.len(), 2);

        ring.remove(&NodeId::new("node1"));
        assert_eq!(ring.route("key").unwrap().as_str(), "node2");
    }
}