redis_cacher/
cluster.rs

1use super::{Address, SingleConn};
2use crate::cmd;
3use futures::future::{join_all, TryFutureExt};
4use rand::{distributions::Uniform, prelude::*};
5use redis::{
6    from_redis_value, parse_redis_value, Cmd, ConnectionAddr, ConnectionInfo, ErrorKind,
7    FromRedisValue, RedisError, RedisResult, ToRedisArgs, Value,
8};
9use slotmap::SlotMap;
10use std::{
11    collections::{BTreeMap, HashMap, HashSet},
12    fmt,
13};
14
/// Retry budget for cluster operations; each retry first refreshes the slot map.
const RETRIES: usize = 3;
/// Number of hash slots in a Redis cluster; a key's slot is its CRC16 modulo this value.
const SLOT_SIZE: u16 = 1 << 14;
17
18impl fmt::Display for Address {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        write!(f, "{}:{}", self.host, self.port)
21    }
22}
23
/// Handle identifying a node connection stored in `ClusterConn::connections`.
type Key = slotmap::DefaultKey;
25
/// Redis Cluster client: keeps one connection per slot-owning node and routes
/// each command to the node that owns its key's hash slot.
pub struct ClusterConn {
    /// Node connections; referenced by key from `node_slots`.
    connections: SlotMap<Key, SingleConn>,
    /// Password reused when connecting to nodes discovered via `CLUSTER SLOTS`.
    password: Option<String>,
    /// Slot-range and address lookup tables pointing into `connections`.
    node_slots: NodeSlots,
    /// Cached uniform distribution over connection indices used for random
    /// routing; `None` until the first slot refresh builds it.
    distribution: Option<Uniform<usize>>,
    /// Whether node connections use TLS (derived from the first seed address).
    is_tls: bool,
}
33
/// Routing tables describing the current cluster topology.
#[derive(Debug, Clone)]
struct NodeSlots {
    /// First slot of each slot range -> connection serving that range.
    pub slots: BTreeMap<u16, Key>,
    /// Node address -> connection to that node.
    pub addresses: HashMap<Address, Key>,
}
39
40fn cluster_error(msg: impl Into<String>) -> RedisError {
41    RedisError::from((ErrorKind::ExtensionError, "cluster error", msg.into()))
42}
43
44fn slot_resp_error(msg: impl Into<String>) -> RedisError {
45    RedisError::from((ErrorKind::TypeError, "error parsing slots", msg.into()))
46}
47
48fn partition_error(msg: impl Into<String>) -> RedisError {
49    RedisError::from((
50        ErrorKind::ExtensionError,
51        "error partitioning keys by cluster node",
52        msg.into(),
53    ))
54}
55
/// One slot-range entry of a `CLUSTER SLOTS` reply.
#[derive(Debug)]
struct SlotResp {
    /// First slot of the range (inclusive).
    start: u16,
    /// Last slot of the range as reported by the server; only logged here.
    end: u16,
    /// Address of the range's master node.
    address: Address,
}
62
63impl FromRedisValue for SlotResp {
64    fn from_redis_value(v: &Value) -> RedisResult<Self> {
65        match v {
66            Value::Bulk(arr) => {
67                if arr.len() < 3 {
68                    return Err(slot_resp_error("not enough elements for slot record"));
69                }
70                let start: u16 = from_redis_value(&arr[0])?;
71                let end: u16 = from_redis_value(&arr[1])?;
72                // We only connect to the slot master
73                let address: Address = from_redis_value(&arr[2])?;
74                Ok(SlotResp {
75                    start,
76                    end,
77                    address,
78                })
79            }
80            _ => Err(slot_resp_error("expecting bulk for slot resp")),
81        }
82    }
83}
84
85impl FromRedisValue for Address {
86    fn from_redis_value(v: &Value) -> RedisResult<Self> {
87        match v {
88            Value::Bulk(arr) => {
89                if arr.len() < 2 {
90                    return Err(slot_resp_error("not enough elements for host record"));
91                }
92
93                let host: String = from_redis_value(&arr[0])?;
94                let port: u16 = from_redis_value(&arr[1])?;
95
96                Ok(Address { host, port })
97            }
98            _ => Err(slot_resp_error("expecting bulk for slot host")),
99        }
100    }
101}
102
103impl ClusterConn {
104    fn get_slot_conn(&mut self, slot: Option<u16>) -> Option<&mut SingleConn> {
105        let range = slot
106            .and_then(|slot| self.node_slots.slots.range(..=slot).next_back())
107            .map(|(_, key)| *key);
108
109        if let Some(key) = range {
110            self.connections.get_mut(key)
111        } else {
112            self.random_conn()
113        }
114    }
115
116    fn random_conn(&mut self) -> Option<&mut SingleConn> {
117        let idx = if let Some(distribution) = &self.distribution {
118            let mut rng = rand::thread_rng();
119            distribution.sample(&mut rng)
120        } else {
121            0
122        };
123        self.connections.values_mut().nth(idx)
124    }
125
126    async fn refresh_slots(&mut self) -> Result<(), RedisError> {
127        // FIXME: We should cache current connections to avoid unnecessary
128        // re-connecting
129        let conn = self
130            .random_conn()
131            .ok_or_else(|| cluster_error("no connections left"))?;
132        let slot_resp: Vec<SlotResp> = conn.query(cmd!["CLUSTER", "SLOTS"]).await?;
133        self.connect_slots(slot_resp).await?;
134        Ok(())
135    }
136
137    pub async fn query<T>(&mut self, cmd: Cmd) -> Result<T, RedisError>
138    where
139        T: FromRedisValue + Send + 'static,
140    {
141        self.req_packed_command(&cmd).await
142    }
143
144    pub async fn execute_script<T>(
145        &mut self,
146        eval_command: &Cmd,
147        load_command: &Cmd,
148    ) -> Result<T, RedisError>
149    where
150        T: FromRedisValue + Send + 'static,
151    {
152        let mut tries = 0;
153        // Get routing info from the eval command
154        let slot = match RoutingInfo::for_packed_command(&eval_command.get_packed_command()) {
155            Some(routing) => routing.slot(),
156            None => {
157                return Err((
158                    ErrorKind::ClientError,
159                    "this command cannot be safely routed in cluster mode",
160                )
161                    .into());
162            }
163        };
164
165        loop {
166            let conn = self.get_slot_conn(slot);
167
168            let error = if let Some(conn) = conn {
169                if conn.is_alive() {
170                    match conn.execute_script(eval_command, load_command).await {
171                        Ok(res) => return Ok(res),
172                        Err(e) => {
173                            // IO errors and MOVED responses indicate that the
174                            // cluster may have been shuffled and our connections
175                            // need refreshing. Everything else should be surfaced.
176                            if !e.is_io_error() && e.code() != Some("MOVED") {
177                                return Err(e);
178                            } else {
179                                e
180                            }
181                        }
182                    }
183                } else {
184                    cluster_error("fetched connection for slot was not alive")
185                }
186            } else {
187                cluster_error("couldn't fetch a connection for slot")
188            };
189
190            if tries <= RETRIES {
191                tries += 1;
192                tracing::warn!(
193                    "Failed to fetch a connection for execute_script: {}, retrying. i={} max={}",
194                    error,
195                    tries,
196                    RETRIES
197                );
198                self.refresh_slots().await?;
199            } else {
200                return Err(error);
201            }
202        }
203    }
204
205    pub async fn req_packed_command<T>(&mut self, cmd: &Cmd) -> Result<T, RedisError>
206    where
207        T: FromRedisValue + Send + 'static,
208    {
209        let mut tries = 0;
210        let slot = match RoutingInfo::for_packed_command(&cmd.get_packed_command()) {
211            Some(routing) => routing.slot(),
212            None => {
213                return Err((
214                    ErrorKind::ClientError,
215                    "this command cannot be safely routed in cluster mode",
216                )
217                    .into());
218            }
219        };
220
221        loop {
222            let conn = self.get_slot_conn(slot);
223            let error = if let Some(conn) = conn {
224                if conn.is_alive() {
225                    match conn.req_packed_command(cmd).await {
226                        Ok(res) => return Ok(res),
227                        Err(e) => {
228                            // IO errors and MOVED responses indicate that the
229                            // cluster may have been shuffled and our connections
230                            // need refreshing. Everything else should be surfaced.
231                            if !e.is_io_error() && e.code() != Some("MOVED") {
232                                return Err(e);
233                            } else {
234                                e
235                            }
236                        }
237                    }
238                } else {
239                    cluster_error("fetched connection for slot was not alive")
240                }
241            } else {
242                cluster_error("couldn't fetch a connection for slot")
243            };
244
245            if tries <= RETRIES {
246                tries += 1;
247                tracing::warn!("Failed to fetch a connection for req_packed_command: {}, retrying. i={} max={}", error, tries, RETRIES);
248                self.refresh_slots().await?;
249            } else {
250                return Err(error);
251            }
252        }
253    }
254
255    pub fn is_alive(&self) -> bool {
256        self.connections.values().all(SingleConn::is_alive)
257    }
258
259    pub async fn try_connect(infos: Vec<ConnectionInfo>) -> Result<Self, RedisError> {
260        if infos.is_empty() {
261            return Err(cluster_error("no connection info provided"));
262        }
263
264        let password = infos[0].redis.password.as_ref().cloned();
265        let is_tls = match infos[0].addr.clone() {
266            ConnectionAddr::TcpTls {
267                host: _,
268                port: _,
269                insecure: _,
270            } => true,
271            _ => false,
272        };
273
274        let mut addresses = HashMap::new();
275        let mut connections = SlotMap::new();
276
277        for info in infos {
278            let address = match &info.addr {
279                ConnectionAddr::Tcp(host, port) => Address {
280                    host: host.clone(),
281                    port: *port,
282                },
283                ConnectionAddr::TcpTls { host, port, .. } => Address {
284                    host: host.clone(),
285                    port: *port,
286                },
287                ConnectionAddr::Unix(path) => Address {
288                    host: path.to_str().unwrap_or("").to_owned(),
289                    port: 0,
290                },
291            };
292            let conn = match SingleConn::try_connect(info).await {
293                Ok(conn) => conn,
294                Err(_) => continue,
295            };
296
297            let key = connections.insert(conn);
298            addresses.insert(address, key);
299            break;
300        }
301
302        let mut cluster = ClusterConn {
303            connections,
304            node_slots: NodeSlots {
305                addresses,
306                slots: BTreeMap::new(),
307            },
308            password,
309            distribution: None,
310            is_tls,
311        };
312
313        cluster.refresh_slots().await?;
314
315        Ok(cluster)
316    }
317
318    async fn connect_multiple<'a, I>(
319        &self,
320        addresses: I,
321    ) -> Result<Vec<(&'a Address, SingleConn)>, RedisError>
322    where
323        I: Iterator<Item = &'a Address>,
324    {
325        let connections = addresses.map(|address| {
326            SingleConn::try_connect(super::build_info(
327                &address.host,
328                address.port,
329                self.password.as_deref(),
330                self.is_tls,
331            ))
332            .map_ok(move |conn| (address, conn))
333        });
334
335        join_all(connections).await.into_iter().collect()
336    }
337
338    async fn connect_slots(&mut self, slots: Vec<SlotResp>) -> Result<(), RedisError> {
339        let previous_connections = self.connections.len();
340        let addresses = unique_addresses(&slots);
341
342        let (mut remaining, removed): (HashMap<_, _>, HashMap<_, _>) = self
343            .node_slots
344            .addresses
345            .drain()
346            .partition(|(address, _)| addresses.contains(address));
347
348        for (_, key) in removed {
349            self.connections.remove(key);
350        }
351
352        // Drop dead connections so that they are reconnected
353        remaining.retain(|_, key| {
354            let conn = match self.connections.get(*key) {
355                Some(conn) => conn,
356                None => return false,
357            };
358
359            if conn.is_alive() {
360                true
361            } else {
362                self.connections.remove(*key);
363                false
364            }
365        });
366
367        self.node_slots.addresses = remaining;
368
369        let added = addresses
370            .into_iter()
371            .filter(|address| !self.node_slots.addresses.contains_key(*address));
372
373        let new_connections = self.connect_multiple(added).await?;
374
375        for (address, connection) in new_connections {
376            let key = self.connections.insert(connection);
377            self.node_slots.addresses.insert(address.clone(), key);
378        }
379
380        let mut new_slots = BTreeMap::new();
381        for slot in slots {
382            if let Some(key) = self.node_slots.addresses.get(&slot.address) {
383                new_slots.insert(slot.start, *key);
384            } else {
385                // This is a programming error and should not happen
386                tracing::warn!(
387                    start = slot.start,
388                    end = slot.end,
389                    address = format!("{}", slot.address),
390                    "Redis cluster: missing address for slot connection",
391                );
392            }
393        }
394        self.node_slots.slots = new_slots;
395
396        if self.connections.len() != previous_connections || self.distribution.is_none() {
397            self.distribution = Some(Uniform::new(0, self.connections.len()));
398        }
399
400        Ok(())
401    }
402
403    pub async fn ping(&mut self) -> Result<(), RedisError> {
404        let mut tries = 0;
405        'retry: loop {
406            // Ping all connections concurrently
407            let results = futures::future::join_all(
408                self.connections
409                    .values_mut()
410                    .filter(|c| c.is_alive())
411                    .map(|c| c.ping()),
412            )
413            .await;
414
415            for res in results {
416                // If there's an error, it could simply mean that the cluster has been shuffled.
417                // Refresh and try again.
418                if let Err(e) = res {
419                    if tries <= RETRIES {
420                        tries += 1;
421                        self.refresh_slots().await?;
422                        continue 'retry;
423                    } else {
424                        return Err(e);
425                    }
426                }
427            }
428
429            return Ok(());
430        }
431    }
432
433    pub fn partition_keys_by_node<'a, I, K>(
434        &self,
435        keys: I,
436    ) -> Result<HashMap<Address, Vec<&'a K>>, RedisError>
437    where
438        &'a K: ToRedisArgs,
439        I: Iterator<Item = &'a K>,
440    {
441        let mut res = HashMap::new();
442
443        for key in keys {
444            let args = key.to_redis_args();
445            let bytes = if args.len() != 1 {
446                Err(partition_error("multiple args for key"))
447            } else {
448                Ok(&args[0])
449            }?;
450            let target_slot = RoutingInfo::for_key(bytes)
451                .and_then(|routing_info| routing_info.slot())
452                .ok_or_else(|| partition_error("no routing info for key"))?;
453            let target_key = self
454                .node_slots
455                .slots
456                .range(0..=target_slot)
457                .next_back()
458                .map(|(_, key)| *key)
459                .ok_or_else(|| partition_error("unknown slot"))?;
460            let address = self
461                .node_slots
462                .addresses
463                .iter()
464                .find(|(_, &key)| target_key == key)
465                .map(|(address, _)| address)
466                .ok_or_else(|| partition_error("unknown address"))?;
467
468            let entry = res.entry(address.clone()).or_insert_with(Vec::new);
469            entry.push(key);
470        }
471
472        Ok(res)
473    }
474}
475
476fn unique_addresses(slots: &[SlotResp]) -> HashSet<&Address> {
477    slots.iter().map(|slot| &slot.address).collect()
478}
479
/// Extracts the Redis Cluster hash tag of `key`.
///
/// The tag is the text between the first `{` and the first `}` after it. If
/// there is no such pair, or the braces enclose nothing, `None` is returned and
/// the whole key is hashed.
fn get_hashtag(key: &[u8]) -> Option<&[u8]> {
    let open = key.iter().position(|&byte| byte == b'{')?;
    let after_open = &key[open + 1..];
    let tag_len = after_open.iter().position(|&byte| byte == b'}')?;
    let tag = &after_open[..tag_len];

    if tag.is_empty() {
        None
    } else {
        Some(tag)
    }
}
500
/// Where a command should be sent in the cluster.
///
/// Taken from redis-rs cluster support
#[derive(Debug, Clone, Copy)]
enum RoutingInfo {
    /// No slot constraint: any node may serve the command.
    Random,
    /// The command must go to the node owning this hash slot.
    Slot(u16),
}
507
508fn get_arg(values: &[Value], idx: usize) -> Option<&[u8]> {
509    match values.get(idx) {
510        Some(Value::Data(ref data)) => Some(&data[..]),
511        _ => None,
512    }
513}
514
515fn get_command_arg(values: &[Value], idx: usize) -> Option<Vec<u8>> {
516    get_arg(values, idx).map(|x| x.to_ascii_uppercase())
517}
518
519fn get_u64_arg(values: &[Value], idx: usize) -> Option<u64> {
520    get_arg(values, idx)
521        .and_then(|x| std::str::from_utf8(x).ok())
522        .and_then(|x| x.parse().ok())
523}
524
525impl RoutingInfo {
526    pub fn slot(&self) -> Option<u16> {
527        match self {
528            RoutingInfo::Random => None,
529            RoutingInfo::Slot(slot) => Some(*slot),
530        }
531    }
532
533    pub fn for_packed_command(cmd: &[u8]) -> Option<RoutingInfo> {
534        parse_redis_value(cmd).ok().and_then(RoutingInfo::for_value)
535    }
536
537    pub fn for_value(value: Value) -> Option<RoutingInfo> {
538        let args = match value {
539            Value::Bulk(args) => args,
540            _ => return None,
541        };
542
543        match &get_command_arg(&args, 0)?[..] {
544            b"SCAN" | b"CLIENT SETNAME" | b"SHUTDOWN" | b"SLAVEOF" | b"REPLICAOF"
545            | b"SCRIPT KILL" | b"MOVE" | b"BITOP" => None,
546            b"EVALSHA" | b"EVAL" => {
547                let key_count = get_u64_arg(&args, 2)?;
548                if key_count == 0 {
549                    Some(RoutingInfo::Random)
550                } else {
551                    get_arg(&args, 3).and_then(RoutingInfo::for_key)
552                }
553            }
554            b"XGROUP" | b"XINFO" => get_arg(&args, 2).and_then(RoutingInfo::for_key),
555            b"XREAD" | b"XREADGROUP" => {
556                let streams_position = args.iter().position(|a| match a {
557                    Value::Data(a) => a == b"STREAMS",
558                    _ => false,
559                })?;
560                get_arg(&args, streams_position + 1).and_then(RoutingInfo::for_key)
561            }
562            _ => match get_arg(&args, 1) {
563                Some(key) => RoutingInfo::for_key(key),
564                None => Some(RoutingInfo::Random),
565            },
566        }
567    }
568
569    pub fn for_key(key: &[u8]) -> Option<RoutingInfo> {
570        let key = match get_hashtag(key) {
571            Some(tag) => tag,
572            None => key,
573        };
574        Some(RoutingInfo::Slot(
575            crc16::State::<crc16::XMODEM>::calculate(key) % SLOT_SIZE,
576        ))
577    }
578}
579
#[cfg(test)]
mod test {
    use super::*;

    /// Resolves `key` to its slot, failing the test if it is not slot-routed.
    fn slot_for_key(key: &[u8]) -> u16 {
        match RoutingInfo::for_key(key) {
            Some(RoutingInfo::Slot(slot)) => slot,
            other => panic!("Expected slot, got {:?}", other),
        }
    }

    #[test]
    fn test_routing() {
        let key: &[u8] = b"[dbreq.approvedeviceemail]\0\0\0\0\0\nP\x08\x01";
        assert_eq!(8505, slot_for_key(key));

        let cmd: &[u8] =
            b"*2\r\n$3\r\nGET\r\n$35\r\n[dbreq.approvedeviceemail]\0\0\0\0\0\nP\x08\x01\r\n";
        let slot = match RoutingInfo::for_packed_command(cmd) {
            Some(RoutingInfo::Slot(x)) => x,
            _ => panic!("Expected slot"),
        };
        assert_eq!(8505, slot);
    }

    #[test]
    fn test_hashtag() {
        // Examples from the Redis cluster specification.
        assert_eq!(get_hashtag(b"{user1000}.following"), Some(&b"user1000"[..]));
        assert_eq!(get_hashtag(b"foo{}{bar}"), None);
        assert_eq!(get_hashtag(b"foo{{bar}}zap"), Some(&b"{bar"[..]));
        assert_eq!(get_hashtag(b"foo{bar}{zap}"), Some(&b"bar"[..]));
        assert_eq!(get_hashtag(b"nobraces"), None);
    }

    #[test]
    fn test_hashtag_keys_share_slot() {
        let following = slot_for_key(b"{user1000}.following");
        assert_eq!(following, slot_for_key(b"{user1000}.followers"));
        assert_eq!(following, slot_for_key(b"user1000"));
    }

    #[test]
    fn test_eval_without_keys_is_random() {
        let cmd: &[u8] = b"*3\r\n$4\r\nEVAL\r\n$8\r\nreturn 1\r\n$1\r\n0\r\n";
        assert!(matches!(
            RoutingInfo::for_packed_command(cmd),
            Some(RoutingInfo::Random)
        ));
    }
}