1use crate::core::{
10 error::{RedisError, RedisResult},
11 types::{NodeInfo, SlotRange},
12};
13use crc16::{State, XMODEM};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17
/// Total number of hash slots in the cluster key space.
///
/// [`calculate_slot`] reduces every key's checksum modulo this value, so
/// valid slot numbers are `0..CLUSTER_SLOTS`.
pub const CLUSTER_SLOTS: u16 = 16384;
20
21pub fn calculate_slot(key: &[u8]) -> u16 {
28 let hash_key = extract_hash_tag(key);
29 State::<XMODEM>::calculate(hash_key) % CLUSTER_SLOTS
30}
31
/// Returns the part of `key` that should be hashed.
///
/// If the key contains a `{` followed later by a `}` with at least one byte
/// between them, the bytes between the *first* `{` and the *first* `}` after
/// it are returned. In every other case (no `{`, no closing `}`, or an empty
/// `{}` pair) the whole key is returned unchanged.
fn extract_hash_tag(key: &[u8]) -> &[u8] {
    // Locate the first opening brace; without one there is no tag at all.
    let open = match key.iter().position(|&byte| byte == b'{') {
        Some(index) => index,
        None => return key,
    };

    // Search for the closing brace only in the bytes after the opening one.
    let after_open = &key[open + 1..];
    match after_open.iter().position(|&byte| byte == b'}') {
        // A zero-length tag ("{}") does not count: fall back to the whole key.
        Some(tag_len) if tag_len > 0 => &after_open[..tag_len],
        _ => key,
    }
}
50
/// Cached view of which node serves which hash slot.
///
/// Both maps live behind `Arc<RwLock<..>>`, so cloning a `ClusterTopology`
/// produces another handle to the same underlying data rather than a copy.
#[derive(Clone)]
pub struct ClusterTopology {
    /// Slot number -> `(host, port)` of the node recorded for that slot.
    slot_map: Arc<RwLock<HashMap<u16, (String, u16)>>>,
    /// Known nodes, keyed by a `"host:port"` string.
    nodes: Arc<RwLock<HashMap<String, NodeInfo>>>,
}
59
60impl ClusterTopology {
61 pub fn new() -> Self {
63 Self {
64 slot_map: Arc::new(RwLock::new(HashMap::new())),
65 nodes: Arc::new(RwLock::new(HashMap::new())),
66 }
67 }
68
69 pub async fn get_node_for_slot(&self, slot: u16) -> Option<(String, u16)> {
71 let slot_map = self.slot_map.read().await;
72 slot_map.get(&slot).cloned()
73 }
74
75 pub async fn get_node_for_key(&self, key: &[u8]) -> Option<(String, u16)> {
77 let slot = calculate_slot(key);
78 self.get_node_for_slot(slot).await
79 }
80
81 pub async fn update_slot_mapping(&self, slot: u16, host: String, port: u16) {
83 let mut slot_map = self.slot_map.write().await;
84 slot_map.insert(slot, (host, port));
85 }
86
87 pub async fn clear_slots(&self) {
89 let mut slot_map = self.slot_map.write().await;
90 slot_map.clear();
91 }
92
93 pub async fn update_from_cluster_slots(
98 &self,
99 slots_data: Vec<Vec<(i64, String, i64)>>,
100 ) -> RedisResult<()> {
101 let mut slot_map = self.slot_map.write().await;
102 let mut nodes = self.nodes.write().await;
103
104 slot_map.clear();
105 nodes.clear();
106
107 for slot_info in slots_data {
108 if slot_info.len() < 3 {
109 continue;
110 }
111
112 let start_slot = slot_info[0].0 as u16;
113 let end_slot = slot_info[1].0 as u16;
114 let master_host = slot_info[2].1.clone();
115 let master_port = slot_info[2].2 as u16;
116
117 for slot in start_slot..=end_slot {
119 slot_map.insert(slot, (master_host.clone(), master_port));
120 }
121
122 let node_key = format!("{}:{}", master_host, master_port);
124 let mut node = NodeInfo::new(node_key.clone(), master_host, master_port);
125 node.slots.push(SlotRange::new(start_slot, end_slot));
126 node.is_master = true;
127 nodes.insert(node_key, node);
128 }
129
130 Ok(())
131 }
132
133 pub async fn get_all_nodes(&self) -> Vec<NodeInfo> {
135 let nodes = self.nodes.read().await;
136 nodes.values().cloned().collect()
137 }
138
139 pub async fn mapped_slots_count(&self) -> usize {
141 let slot_map = self.slot_map.read().await;
142 slot_map.len()
143 }
144}
145
146impl Default for ClusterTopology {
147 fn default() -> Self {
148 Self::new()
149 }
150}
151
/// Interprets cluster redirect errors and keeps a [`ClusterTopology`] in sync
/// with them.
#[derive(Clone)]
pub struct RedirectHandler {
    /// Topology handle that is updated when a `Moved` redirect is handled.
    topology: ClusterTopology,
    /// Redirect limit exposed through `max_redirects()`; this type only stores
    /// the value, it does not enforce it.
    max_redirects: usize,
}
158
159impl RedirectHandler {
160 pub fn new(topology: ClusterTopology, max_redirects: usize) -> Self {
162 Self {
163 topology,
164 max_redirects,
165 }
166 }
167
168 pub async fn handle_redirect(&self, error: &RedisError) -> RedisResult<(String, u16, bool)> {
170 match error {
171 RedisError::Moved { slot, host, port } => {
172 self.topology
174 .update_slot_mapping(*slot, host.clone(), *port)
175 .await;
176 Ok((host.clone(), *port, false))
177 }
178 RedisError::Ask { host, port, .. } => {
179 Ok((host.clone(), *port, true))
181 }
182 _ => Err(RedisError::Cluster(format!(
183 "Not a redirect error: {:?}",
184 error
185 ))),
186 }
187 }
188
189 pub fn max_redirects(&self) -> usize {
191 self.max_redirects
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Slot calculation stays in range, honors hash tags, and matches a
    /// known-answer value.
    #[test]
    fn test_calculate_slot() {
        // Any key must land inside the slot space.
        let slot = calculate_slot(b"mykey");
        assert!(slot < CLUSTER_SLOTS);

        // Keys sharing the same `{...}` tag must hash to the same slot.
        let slot1 = calculate_slot(b"{user1000}.following");
        let slot2 = calculate_slot(b"{user1000}.followers");
        assert_eq!(
            slot1, slot2,
            "Keys with same hash tag should map to same slot"
        );

        // A different tag is only exercised here; nothing is asserted about it.
        let _slot3 = calculate_slot(b"{user1001}.following");
        // Known-answer check for a fixed input.
        let slot = calculate_slot(b"123456789");
        assert_eq!(slot, 12739);
    }

    /// Hash-tag extraction: normal tags, and the fall-back-to-whole-key cases.
    #[test]
    fn test_extract_hash_tag() {
        assert_eq!(extract_hash_tag(b"key"), b"key");
        assert_eq!(extract_hash_tag(b"{user}key"), b"user");
        assert_eq!(extract_hash_tag(b"prefix{user}key"), b"user");
        assert_eq!(extract_hash_tag(b"{user}"), b"user");
        // Empty tag: the whole key is used.
        assert_eq!(extract_hash_tag(b"{}"), b"{}");
        // Unterminated tag: the whole key is used.
        assert_eq!(extract_hash_tag(b"{"), b"{");
        assert_eq!(extract_hash_tag(b"no{hash"), b"no{hash");
    }

    /// Single-slot insert, lookup and clear on the topology.
    #[tokio::test]
    async fn test_cluster_topology() {
        let topology = ClusterTopology::new();

        // A fresh topology has no mappings.
        assert!(topology.get_node_for_slot(100).await.is_none());

        topology
            .update_slot_mapping(100, "localhost".to_string(), 6379)
            .await;

        let node = topology.get_node_for_slot(100).await;
        assert_eq!(node, Some(("localhost".to_string(), 6379)));

        // Clearing removes the mapping again.
        topology.clear_slots().await;
        assert!(topology.get_node_for_slot(100).await.is_none());
    }

    /// Key lookup resolves through the key's computed slot.
    #[tokio::test]
    async fn test_get_node_for_key() {
        let topology = ClusterTopology::new();

        let key = b"mykey";
        let slot = calculate_slot(key);

        topology
            .update_slot_mapping(slot, "localhost".to_string(), 6379)
            .await;

        let node = topology.get_node_for_key(key).await;
        assert_eq!(node, Some(("localhost".to_string(), 6379)));
    }

    /// `Moved` updates the shared topology; `Ask` does not.
    #[tokio::test]
    async fn test_redirect_handler() {
        let topology = ClusterTopology::new();
        // The handler gets a clone, which shares state with `topology`.
        let handler = RedirectHandler::new(topology.clone(), 3);

        let error = RedisError::Moved {
            slot: 9916,
            host: "10.90.6.213".to_string(),
            port: 6002,
        };

        let (host, port, is_ask) = handler.handle_redirect(&error).await.unwrap();
        assert_eq!(host, "10.90.6.213");
        assert_eq!(port, 6002);
        assert!(!is_ask);

        // The Moved redirect must be visible through the original handle.
        let node = topology.get_node_for_slot(9916).await;
        assert_eq!(node, Some(("10.90.6.213".to_string(), 6002)));

        let error = RedisError::Ask {
            slot: 100,
            host: "localhost".to_string(),
            port: 7000,
        };

        let (host, port, is_ask) = handler.handle_redirect(&error).await.unwrap();
        assert_eq!(host, "localhost");
        assert_eq!(port, 7000);
        assert!(is_ask);

        // An Ask redirect must not be recorded in the slot map.
        assert!(topology.get_node_for_slot(100).await.is_none());
    }
}