// pollen_router/hashring.rs
use pollen_types::NodeId;
use std::collections::{BTreeMap, HashSet};
use std::hash::{Hash, Hasher};

7pub struct HashRing {
9 ring: BTreeMap<u64, NodeId>,
11 replicas: usize,
13 nodes: Vec<NodeId>,
15}
16
17impl HashRing {
18 pub fn new(replicas: usize) -> Self {
20 Self {
21 ring: BTreeMap::new(),
22 replicas,
23 nodes: Vec::new(),
24 }
25 }
26
27 pub fn add(&mut self, node: NodeId) {
29 if self.nodes.contains(&node) {
30 return;
31 }
32
33 self.nodes.push(node);
34
35 for i in 0..self.replicas {
36 let key = format!("{}:{}", node, i);
37 let hash = self.hash(key.as_bytes());
38 self.ring.insert(hash, node);
39 }
40 }
41
42 pub fn remove(&mut self, node: NodeId) {
44 self.nodes.retain(|n| *n != node);
45
46 for i in 0..self.replicas {
47 let key = format!("{}:{}", node, i);
48 let hash = self.hash(key.as_bytes());
49 self.ring.remove(&hash);
50 }
51 }
52
53 pub fn clear(&mut self) {
55 self.ring.clear();
56 self.nodes.clear();
57 }
58
59 pub fn get(&self, key: &[u8]) -> Option<&NodeId> {
61 if self.ring.is_empty() {
62 return None;
63 }
64
65 let hash = self.hash(key);
66
67 if let Some((_, node)) = self.ring.range(hash..).next() {
69 return Some(node);
70 }
71
72 self.ring.values().next()
74 }
75
76 pub fn get_n(&self, key: &[u8], n: usize) -> Vec<NodeId> {
78 if self.ring.is_empty() || n == 0 {
79 return vec![];
80 }
81
82 let hash = self.hash(key);
83 let mut result = Vec::with_capacity(n.min(self.nodes.len()));
84 let mut seen = std::collections::HashSet::new();
85
86 for (_, node) in self.ring.range(hash..).chain(self.ring.range(..hash)) {
88 if seen.insert(*node) {
89 result.push(*node);
90 if result.len() >= n || result.len() >= self.nodes.len() {
91 break;
92 }
93 }
94 }
95
96 result
97 }
98
99 pub fn is_empty(&self) -> bool {
101 self.ring.is_empty()
102 }
103
104 pub fn len(&self) -> usize {
106 self.nodes.len()
107 }
108
109 pub fn nodes(&self) -> &[NodeId] {
111 &self.nodes
112 }
113
114 fn hash(&self, key: &[u8]) -> u64 {
116 use std::collections::hash_map::DefaultHasher;
117 let mut hasher = DefaultHasher::new();
118 key.hash(&mut hasher);
119 hasher.finish()
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    /// Builds a ring with `replicas` virtual points per node and adds every node in `nodes`.
    fn ring_of(replicas: usize, nodes: &[NodeId]) -> HashRing {
        let mut ring = HashRing::new(replicas);
        for &node in nodes {
            ring.add(node);
        }
        ring
    }

    #[test]
    fn test_empty_ring() {
        let ring = HashRing::new(10);
        assert!(ring.get(b"test").is_none());
        assert!(ring.is_empty());
    }

    #[test]
    fn test_single_node() {
        let mut ring = HashRing::new(10);
        let node = NodeId::from_raw(1);
        ring.add(node);

        assert_eq!(ring.get(b"key1"), Some(&node));
        assert_eq!(ring.get(b"key2"), Some(&node));
        assert_eq!(ring.get(b"key3"), Some(&node));
    }

    #[test]
    fn test_multiple_nodes() {
        let nodes: Vec<_> = (1..=3).map(NodeId::from_raw).collect();
        let ring = ring_of(100, &nodes);

        assert_eq!(ring.len(), 3);

        let mut distribution: HashMap<NodeId, usize> = HashMap::new();
        for i in 0..1000 {
            let key = format!("key{}", i);
            if let Some(node) = ring.get(key.as_bytes()) {
                *distribution.entry(*node).or_insert(0) += 1;
            }
        }

        // Every node should own at least one of the 1000 keys.
        for node in &nodes {
            assert!(
                distribution.contains_key(node),
                "node {:?} received no keys",
                node
            );
        }
    }

    #[test]
    fn test_get_n() {
        let nodes: Vec<_> = (1..=5).map(NodeId::from_raw).collect();
        let ring = ring_of(100, &nodes);

        let replicas = ring.get_n(b"test", 3);
        assert_eq!(replicas.len(), 3);

        let unique: HashSet<_> = replicas.iter().collect();
        assert_eq!(unique.len(), 3);
    }

    #[test]
    fn test_get_n_caps_at_node_count() {
        let nodes: Vec<_> = (1..=3).map(NodeId::from_raw).collect();
        let ring = ring_of(50, &nodes);

        // Asking for more owners than members yields each member exactly once.
        let owners = ring.get_n(b"test", 10);
        assert_eq!(owners.len(), 3);
        let unique: HashSet<_> = owners.iter().collect();
        assert_eq!(unique.len(), 3);

        assert!(ring.get_n(b"test", 0).is_empty());
    }

    #[test]
    fn test_get_n_starts_with_get() {
        let nodes: Vec<_> = (1..=5).map(NodeId::from_raw).collect();
        let ring = ring_of(100, &nodes);

        for i in 0..50 {
            let key = format!("key{}", i);
            let primary = ring.get(key.as_bytes()).copied();
            assert_eq!(ring.get_n(key.as_bytes(), 1).first().copied(), primary);
        }
    }

    #[test]
    fn test_add_is_idempotent() {
        let mut ring = HashRing::new(10);
        let node = NodeId::from_raw(1);
        ring.add(node);
        ring.add(node);

        assert_eq!(ring.len(), 1);
        assert_eq!(ring.nodes(), &[node][..]);
    }

    #[test]
    fn test_node_removal() {
        let mut ring = HashRing::new(100);
        let node1 = NodeId::from_raw(1);
        let node2 = NodeId::from_raw(2);

        ring.add(node1);
        ring.add(node2);

        ring.remove(node1);
        assert_eq!(ring.len(), 1);

        assert_eq!(ring.get(b"test"), Some(&node2));
    }

    #[test]
    fn test_remove_last_node_and_clear() {
        let mut ring = HashRing::new(10);
        let node = NodeId::from_raw(1);

        ring.add(node);
        ring.remove(node);
        assert!(ring.is_empty());
        assert!(ring.get(b"test").is_none());

        ring.add(node);
        ring.clear();
        assert!(ring.is_empty());
        assert_eq!(ring.len(), 0);
        assert!(ring.get_n(b"test", 3).is_empty());
    }

    #[test]
    fn test_consistent_hashing() {
        let nodes: Vec<_> = (1..=3).map(NodeId::from_raw).collect();
        let mut ring = ring_of(100, &nodes);

        let mut assignments: HashMap<String, NodeId> = HashMap::new();
        for i in 0..100 {
            let key = format!("key{}", i);
            if let Some(node) = ring.get(key.as_bytes()) {
                assignments.insert(key, *node);
            }
        }

        ring.add(NodeId::from_raw(4));

        // Adding a fourth node should move roughly a quarter of the keys, not reshuffle them all.
        let unchanged = assignments
            .iter()
            .filter(|(key, old_node)| ring.get(key.as_bytes()) == Some(*old_node))
            .count();

        assert!(
            unchanged >= 70,
            "only {} of {} keys kept their node",
            unchanged,
            assignments.len()
        );
    }
}