Skip to main content

pollen_router/
hashring.rs

1//! Consistent hash ring implementation.
2
3use pollen_types::NodeId;
4use std::collections::BTreeMap;
5use std::hash::{Hash, Hasher};
6
7/// Consistent hash ring with virtual nodes.
8pub struct HashRing {
9    /// Ring: position -> node_id
10    ring: BTreeMap<u64, NodeId>,
11    /// Number of virtual nodes per physical node.
12    replicas: usize,
13    /// Physical nodes.
14    nodes: Vec<NodeId>,
15}
16
17impl HashRing {
18    /// Create a new hash ring.
19    pub fn new(replicas: usize) -> Self {
20        Self {
21            ring: BTreeMap::new(),
22            replicas,
23            nodes: Vec::new(),
24        }
25    }
26
27    /// Add a node to the ring.
28    pub fn add(&mut self, node: NodeId) {
29        if self.nodes.contains(&node) {
30            return;
31        }
32
33        self.nodes.push(node);
34
35        for i in 0..self.replicas {
36            let key = format!("{}:{}", node, i);
37            let hash = self.hash(key.as_bytes());
38            self.ring.insert(hash, node);
39        }
40    }
41
42    /// Remove a node from the ring.
43    pub fn remove(&mut self, node: NodeId) {
44        self.nodes.retain(|n| *n != node);
45
46        for i in 0..self.replicas {
47            let key = format!("{}:{}", node, i);
48            let hash = self.hash(key.as_bytes());
49            self.ring.remove(&hash);
50        }
51    }
52
53    /// Clear all nodes from the ring.
54    pub fn clear(&mut self) {
55        self.ring.clear();
56        self.nodes.clear();
57    }
58
59    /// Get the node responsible for a key.
60    pub fn get(&self, key: &[u8]) -> Option<&NodeId> {
61        if self.ring.is_empty() {
62            return None;
63        }
64
65        let hash = self.hash(key);
66
67        // Find the first node with position >= hash
68        if let Some((_, node)) = self.ring.range(hash..).next() {
69            return Some(node);
70        }
71
72        // Wrap around to the first node
73        self.ring.values().next()
74    }
75
76    /// Get N unique nodes for a key (for replication).
77    pub fn get_n(&self, key: &[u8], n: usize) -> Vec<NodeId> {
78        if self.ring.is_empty() || n == 0 {
79            return vec![];
80        }
81
82        let hash = self.hash(key);
83        let mut result = Vec::with_capacity(n.min(self.nodes.len()));
84        let mut seen = std::collections::HashSet::new();
85
86        // Start from the hash position and walk around the ring
87        for (_, node) in self.ring.range(hash..).chain(self.ring.range(..hash)) {
88            if seen.insert(*node) {
89                result.push(*node);
90                if result.len() >= n || result.len() >= self.nodes.len() {
91                    break;
92                }
93            }
94        }
95
96        result
97    }
98
99    /// Check if the ring is empty.
100    pub fn is_empty(&self) -> bool {
101        self.ring.is_empty()
102    }
103
104    /// Get the number of physical nodes.
105    pub fn len(&self) -> usize {
106        self.nodes.len()
107    }
108
109    /// Get all physical nodes.
110    pub fn nodes(&self) -> &[NodeId] {
111        &self.nodes
112    }
113
114    /// Hash a key to a ring position.
115    fn hash(&self, key: &[u8]) -> u64 {
116        use std::collections::hash_map::DefaultHasher;
117        let mut hasher = DefaultHasher::new();
118        key.hash(&mut hasher);
119        hasher.finish()
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    /// Build a ring with `replicas` virtual nodes per node, containing `nodes`.
    fn ring_of(replicas: usize, nodes: &[NodeId]) -> HashRing {
        let mut ring = HashRing::new(replicas);
        for node in nodes {
            ring.add(*node);
        }
        ring
    }

    #[test]
    fn test_empty_ring() {
        let ring = HashRing::new(10);
        assert!(ring.get(b"test").is_none());
        assert!(ring.get_n(b"test", 3).is_empty());
        assert!(ring.is_empty());
        assert_eq!(ring.len(), 0);
    }

    #[test]
    fn test_single_node() {
        let node = NodeId::from_raw(1);
        let ring = ring_of(10, &[node]);

        // All keys should map to the single node.
        assert_eq!(ring.get(b"key1"), Some(&node));
        assert_eq!(ring.get(b"key2"), Some(&node));
        assert_eq!(ring.get(b"key3"), Some(&node));
    }

    #[test]
    fn test_duplicate_add_is_noop() {
        let node = NodeId::from_raw(1);
        let mut ring = ring_of(10, &[node]);
        ring.add(node);
        assert_eq!(ring.len(), 1);
        assert_eq!(ring.nodes().to_vec(), vec![node]);
    }

    #[test]
    fn test_multiple_nodes() {
        let nodes: Vec<_> = (1..=3).map(NodeId::from_raw).collect();
        let ring = ring_of(100, &nodes);
        assert_eq!(ring.len(), 3);

        // Keys should be distributed across nodes.
        let mut distribution: HashMap<NodeId, usize> = HashMap::new();
        for i in 0..1000 {
            let key = format!("key{}", i);
            if let Some(node) = ring.get(key.as_bytes()) {
                *distribution.entry(*node).or_insert(0) += 1;
            }
        }

        // Each node should own at least one key.
        for node in &nodes {
            let owned = distribution.get(node).copied().unwrap_or(0);
            assert!(owned > 0, "node {:?} received no keys", node);
        }
    }

    #[test]
    fn test_get_n() {
        let nodes: Vec<_> = (1..=5).map(NodeId::from_raw).collect();
        let ring = ring_of(100, &nodes);

        // Should get 3 unique nodes.
        let replicas = ring.get_n(b"test", 3);
        assert_eq!(replicas.len(), 3);
        let unique: HashSet<_> = replicas.iter().collect();
        assert_eq!(unique.len(), 3);

        // The primary replica is the node `get` routes the key to.
        assert_eq!(ring.get(b"test"), replicas.first());
    }

    #[test]
    fn test_get_n_bounds() {
        let nodes: Vec<_> = (1..=3).map(NodeId::from_raw).collect();
        let ring = ring_of(100, &nodes);

        assert!(ring.get_n(b"test", 0).is_empty());

        // Asking for more replicas than nodes yields every node exactly once.
        let all = ring.get_n(b"test", 10);
        assert_eq!(all.len(), 3);
        let got: HashSet<NodeId> = all.into_iter().collect();
        let expected: HashSet<NodeId> = nodes.iter().copied().collect();
        assert_eq!(got, expected);
    }

    #[test]
    fn test_node_removal() {
        let node1 = NodeId::from_raw(1);
        let node2 = NodeId::from_raw(2);
        let mut ring = ring_of(100, &[node1, node2]);

        ring.remove(node1);
        assert_eq!(ring.len(), 1);
        assert_eq!(ring.nodes().to_vec(), vec![node2]);

        // All keys should now map to node2.
        for i in 0..100 {
            let key = format!("key{}", i);
            assert_eq!(ring.get(key.as_bytes()), Some(&node2));
        }

        // Removing a node that is no longer in the ring changes nothing.
        ring.remove(node1);
        assert_eq!(ring.len(), 1);
    }

    #[test]
    fn test_clear() {
        let nodes: Vec<_> = (1..=3).map(NodeId::from_raw).collect();
        let mut ring = ring_of(10, &nodes);

        ring.clear();
        assert!(ring.is_empty());
        assert_eq!(ring.len(), 0);
        assert!(ring.get(b"test").is_none());
    }

    #[test]
    fn test_consistent_hashing() {
        let nodes: Vec<_> = (1..=3).map(NodeId::from_raw).collect();
        let mut ring = ring_of(100, &nodes);

        // Record where keys are assigned.
        let mut assignments: HashMap<String, NodeId> = HashMap::new();
        for i in 0..100 {
            let key = format!("key{}", i);
            if let Some(node) = ring.get(key.as_bytes()) {
                assignments.insert(key, *node);
            }
        }

        // Add a new node.
        ring.add(NodeId::from_raw(4));

        // Most keys should stay on the same node.
        let unchanged = assignments
            .iter()
            .filter(|(key, old_node)| ring.get(key.as_bytes()) == Some(*old_node))
            .count();

        // At least 70% should be unchanged (typically much higher).
        assert!(unchanged >= 70, "only {} of 100 keys stayed put", unchanged);
    }
}