1use crate::core::{
10    error::{RedisError, RedisResult},
11    types::{NodeInfo, SlotRange},
12};
13use crc16::{State, XMODEM};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17
/// Number of hash slots in the cluster key space.
///
/// `calculate_slot` reduces the CRC16 of a key modulo this value, so every
/// slot number it produces is strictly below `CLUSTER_SLOTS`.
pub const CLUSTER_SLOTS: u16 = 16384;
20
21pub fn calculate_slot(key: &[u8]) -> u16 {
28    let hash_key = extract_hash_tag(key);
29    State::<XMODEM>::calculate(hash_key) % CLUSTER_SLOTS
30}
31
/// Returns the portion of `key` that should be hashed.
///
/// If the key contains a `{`, and the first `}` after that first `{` encloses
/// at least one byte, the enclosed bytes are returned. In every other case
/// (no `{`, no `}` after it, or an empty `{}` pair) the whole key is returned.
fn extract_hash_tag(key: &[u8]) -> &[u8] {
    // Only the first opening brace is ever considered.
    let open = match key.iter().position(|&b| b == b'{') {
        Some(idx) => idx,
        None => return key,
    };

    let after_open = &key[open + 1..];
    match after_open.iter().position(|&b| b == b'}') {
        // `len` is the number of bytes between the braces; an empty tag is ignored.
        Some(len) if len > 0 => &after_open[..len],
        _ => key,
    }
}
50
/// Shared record of which node serves each hash slot.
///
/// Both maps sit behind `Arc<RwLock<..>>`, so clones of this struct read and
/// write the same underlying data.
#[derive(Clone)]
pub struct ClusterTopology {
    /// Slot number -> `(host, port)` of the node recorded for that slot.
    slot_map: Arc<RwLock<HashMap<u16, (String, u16)>>>,
    /// Node records keyed by the `"host:port"` string built in
    /// `update_from_cluster_slots`.
    nodes: Arc<RwLock<HashMap<String, NodeInfo>>>,
}
59
60impl ClusterTopology {
61    pub fn new() -> Self {
63        Self {
64            slot_map: Arc::new(RwLock::new(HashMap::new())),
65            nodes: Arc::new(RwLock::new(HashMap::new())),
66        }
67    }
68
69    pub async fn get_node_for_slot(&self, slot: u16) -> Option<(String, u16)> {
71        let slot_map = self.slot_map.read().await;
72        slot_map.get(&slot).cloned()
73    }
74
75    pub async fn get_node_for_key(&self, key: &[u8]) -> Option<(String, u16)> {
77        let slot = calculate_slot(key);
78        self.get_node_for_slot(slot).await
79    }
80
81    pub async fn update_slot_mapping(&self, slot: u16, host: String, port: u16) {
83        let mut slot_map = self.slot_map.write().await;
84        slot_map.insert(slot, (host, port));
85    }
86
87    pub async fn clear_slots(&self) {
89        let mut slot_map = self.slot_map.write().await;
90        slot_map.clear();
91    }
92
93    pub async fn update_from_cluster_slots(
98        &self,
99        slots_data: Vec<Vec<(i64, String, i64)>>,
100    ) -> RedisResult<()> {
101        let mut slot_map = self.slot_map.write().await;
102        let mut nodes = self.nodes.write().await;
103
104        slot_map.clear();
105        nodes.clear();
106
107        for slot_info in slots_data {
108            if slot_info.len() < 3 {
109                continue;
110            }
111
112            let start_slot = slot_info[0].0 as u16;
113            let end_slot = slot_info[1].0 as u16;
114            let master_host = slot_info[2].1.clone();
115            let master_port = slot_info[2].2 as u16;
116
117            for slot in start_slot..=end_slot {
119                slot_map.insert(slot, (master_host.clone(), master_port));
120            }
121
122            let node_key = format!("{}:{}", master_host, master_port);
124            let mut node = NodeInfo::new(node_key.clone(), master_host, master_port);
125            node.slots.push(SlotRange::new(start_slot, end_slot));
126            node.is_master = true;
127            nodes.insert(node_key, node);
128        }
129
130        Ok(())
131    }
132
133    pub async fn get_all_nodes(&self) -> Vec<NodeInfo> {
135        let nodes = self.nodes.read().await;
136        nodes.values().cloned().collect()
137    }
138
139    pub async fn mapped_slots_count(&self) -> usize {
141        let slot_map = self.slot_map.read().await;
142        slot_map.len()
143    }
144}
145
146impl Default for ClusterTopology {
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
/// Turns `Moved` / `Ask` redirect errors into the address a command should be
/// retried against.
#[derive(Clone)]
pub struct RedirectHandler {
    /// Topology whose slot map is updated when a `Moved` error is handled.
    topology: ClusterTopology,
    /// Redirect limit returned by `max_redirects()`; this type only stores it,
    /// enforcement is left to the caller.
    max_redirects: usize,
}
158
159impl RedirectHandler {
160    pub fn new(topology: ClusterTopology, max_redirects: usize) -> Self {
162        Self {
163            topology,
164            max_redirects,
165        }
166    }
167
168    pub async fn handle_redirect(&self, error: &RedisError) -> RedisResult<(String, u16, bool)> {
170        match error {
171            RedisError::Moved { slot, host, port } => {
172                self.topology
174                    .update_slot_mapping(*slot, host.clone(), *port)
175                    .await;
176                Ok((host.clone(), *port, false))
177            }
178            RedisError::Ask { host, port, .. } => {
179                Ok((host.clone(), *port, true))
181            }
182            _ => Err(RedisError::Cluster(format!(
183                "Not a redirect error: {:?}",
184                error
185            ))),
186        }
187    }
188
189    pub fn max_redirects(&self) -> usize {
191        self.max_redirects
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Slot numbers stay in range, honour hash tags, and match the known
    /// CRC16 reference value.
    #[test]
    fn test_calculate_slot() {
        let slot = calculate_slot(b"mykey");
        assert!(slot < CLUSTER_SLOTS);

        let slot1 = calculate_slot(b"{user1000}.following");
        let slot2 = calculate_slot(b"{user1000}.followers");
        assert_eq!(
            slot1, slot2,
            "Keys with same hash tag should map to same slot"
        );

        // Only the tag is hashed, so a tagged key lands on the same slot as
        // the bare tag itself.
        assert_eq!(
            calculate_slot(b"{user1001}.following"),
            calculate_slot(b"user1001")
        );

        // CRC16/XMODEM check value for "123456789" is 0x31C3 (12739), which
        // is already below the slot count.
        assert_eq!(calculate_slot(b"123456789"), 12739);
    }

    /// Hash-tag extraction, including the cases where the whole key is used.
    #[test]
    fn test_extract_hash_tag() {
        assert_eq!(extract_hash_tag(b"key"), b"key");
        assert_eq!(extract_hash_tag(b"{user}key"), b"user");
        assert_eq!(extract_hash_tag(b"prefix{user}key"), b"user");
        assert_eq!(extract_hash_tag(b"{user}"), b"user");

        // Empty or unterminated tags fall back to the whole key.
        assert_eq!(extract_hash_tag(b"{}"), b"{}");
        assert_eq!(extract_hash_tag(b"{"), b"{");
        assert_eq!(extract_hash_tag(b"no{hash"), b"no{hash");

        // Only the first `{` and the first `}` after it are considered.
        assert_eq!(extract_hash_tag(b"foo{}{bar}"), b"foo{}{bar}");
        assert_eq!(extract_hash_tag(b"foo{{bar}}zap"), b"{bar");
        assert_eq!(extract_hash_tag(b"foo{bar}{zap}"), b"bar");
    }

    /// Single-slot insert, lookup and clear round trip.
    #[tokio::test]
    async fn test_cluster_topology() {
        let topology = ClusterTopology::new();

        assert!(topology.get_node_for_slot(100).await.is_none());

        topology
            .update_slot_mapping(100, "localhost".to_string(), 6379)
            .await;

        let node = topology.get_node_for_slot(100).await;
        assert_eq!(node, Some(("localhost".to_string(), 6379)));

        topology.clear_slots().await;
        assert!(topology.get_node_for_slot(100).await.is_none());
    }

    /// Key lookup resolves through the slot computed for that key.
    #[tokio::test]
    async fn test_get_node_for_key() {
        let topology = ClusterTopology::new();

        let key = b"mykey";
        let slot = calculate_slot(key);

        topology
            .update_slot_mapping(slot, "localhost".to_string(), 6379)
            .await;

        let node = topology.get_node_for_key(key).await;
        assert_eq!(node, Some(("localhost".to_string(), 6379)));
    }

    /// `Moved` updates the shared topology; `Ask` does not.
    #[tokio::test]
    async fn test_redirect_handler() {
        let topology = ClusterTopology::new();
        let handler = RedirectHandler::new(topology.clone(), 3);

        let error = RedisError::Moved {
            slot: 9916,
            host: "10.90.6.213".to_string(),
            port: 6002,
        };

        let (host, port, is_ask) = handler.handle_redirect(&error).await.unwrap();
        assert_eq!(host, "10.90.6.213");
        assert_eq!(port, 6002);
        assert!(!is_ask);

        // The handler holds a clone, yet the original sees the new mapping.
        let node = topology.get_node_for_slot(9916).await;
        assert_eq!(node, Some(("10.90.6.213".to_string(), 6002)));

        let error = RedisError::Ask {
            slot: 100,
            host: "localhost".to_string(),
            port: 7000,
        };

        let (host, port, is_ask) = handler.handle_redirect(&error).await.unwrap();
        assert_eq!(host, "localhost");
        assert_eq!(port, 7000);
        assert!(is_ask);

        assert!(topology.get_node_for_slot(100).await.is_none());
    }
}