1use super::{Address, SingleConn};
2use crate::cmd;
3use futures::future::{join_all, TryFutureExt};
4use rand::{distributions::Uniform, prelude::*};
5use redis::{
6 from_redis_value, parse_redis_value, Cmd, ConnectionAddr, ConnectionInfo, ErrorKind,
7 FromRedisValue, RedisError, RedisResult, ToRedisArgs, Value,
8};
9use slotmap::SlotMap;
10use std::{
11 collections::{BTreeMap, HashMap, HashSet},
12 fmt,
13};
14
/// Retry bound used by the routing and ping loops before they give up and
/// return the last error.
const RETRIES: usize = 3;
/// Modulus applied to the CRC16 of a key to obtain its cluster slot number.
const SLOT_SIZE: u16 = 16384;
17
18impl fmt::Display for Address {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 write!(f, "{}:{}", self.host, self.port)
21 }
22}
23
/// Handle identifying one node connection inside the `connections` slot map.
type Key = slotmap::DefaultKey;
25
/// A connection to a Redis cluster, holding one `SingleConn` per known node
/// and routing each command to the node that owns the command's slot.
pub struct ClusterConn {
    /// One connection per known cluster node.
    connections: SlotMap<Key, SingleConn>,
    /// Password taken from the first seed `ConnectionInfo`; reused when
    /// connecting to nodes discovered through `CLUSTER SLOTS`.
    password: Option<String>,
    /// Slot-range and address lookup tables pointing into `connections`.
    node_slots: NodeSlots,
    /// Uniform range over connection positions, sampled when a command has
    /// no specific slot. `None` until the first slot refresh has run.
    distribution: Option<Uniform<usize>>,
    /// Whether the first seed address was TLS; forwarded when building
    /// connection info for discovered nodes.
    is_tls: bool,
}
33
/// Lookup tables mapping slots and node addresses to connection keys.
#[derive(Debug, Clone)]
struct NodeSlots {
    /// First slot of each reported range -> key of the connection serving it.
    /// Lookups take the greatest start that is <= the wanted slot.
    pub slots: BTreeMap<u16, Key>,
    /// Node address -> key of the connection opened for that node.
    pub addresses: HashMap<Address, Key>,
}
39
40fn cluster_error(msg: impl Into<String>) -> RedisError {
41 RedisError::from((ErrorKind::ExtensionError, "cluster error", msg.into()))
42}
43
44fn slot_resp_error(msg: impl Into<String>) -> RedisError {
45 RedisError::from((ErrorKind::TypeError, "error parsing slots", msg.into()))
46}
47
48fn partition_error(msg: impl Into<String>) -> RedisError {
49 RedisError::from((
50 ErrorKind::ExtensionError,
51 "error partitioning keys by cluster node",
52 msg.into(),
53 ))
54}
55
/// One entry of a `CLUSTER SLOTS` reply, reduced to its first three elements.
#[derive(Debug)]
struct SlotResp {
    /// First element of the entry; used as the key of the slot table.
    start: u16,
    /// Second element of the entry; in this file it is only used for logging.
    end: u16,
    /// Third element of the entry, parsed as a host/port pair.
    address: Address,
}
62
63impl FromRedisValue for SlotResp {
64 fn from_redis_value(v: &Value) -> RedisResult<Self> {
65 match v {
66 Value::Bulk(arr) => {
67 if arr.len() < 3 {
68 return Err(slot_resp_error("not enough elements for slot record"));
69 }
70 let start: u16 = from_redis_value(&arr[0])?;
71 let end: u16 = from_redis_value(&arr[1])?;
72 let address: Address = from_redis_value(&arr[2])?;
74 Ok(SlotResp {
75 start,
76 end,
77 address,
78 })
79 }
80 _ => Err(slot_resp_error("expecting bulk for slot resp")),
81 }
82 }
83}
84
85impl FromRedisValue for Address {
86 fn from_redis_value(v: &Value) -> RedisResult<Self> {
87 match v {
88 Value::Bulk(arr) => {
89 if arr.len() < 2 {
90 return Err(slot_resp_error("not enough elements for host record"));
91 }
92
93 let host: String = from_redis_value(&arr[0])?;
94 let port: u16 = from_redis_value(&arr[1])?;
95
96 Ok(Address { host, port })
97 }
98 _ => Err(slot_resp_error("expecting bulk for slot host")),
99 }
100 }
101}
102
103impl ClusterConn {
104 fn get_slot_conn(&mut self, slot: Option<u16>) -> Option<&mut SingleConn> {
105 let range = slot
106 .and_then(|slot| self.node_slots.slots.range(..=slot).next_back())
107 .map(|(_, key)| *key);
108
109 if let Some(key) = range {
110 self.connections.get_mut(key)
111 } else {
112 self.random_conn()
113 }
114 }
115
116 fn random_conn(&mut self) -> Option<&mut SingleConn> {
117 let idx = if let Some(distribution) = &self.distribution {
118 let mut rng = rand::thread_rng();
119 distribution.sample(&mut rng)
120 } else {
121 0
122 };
123 self.connections.values_mut().nth(idx)
124 }
125
126 async fn refresh_slots(&mut self) -> Result<(), RedisError> {
127 let conn = self
130 .random_conn()
131 .ok_or_else(|| cluster_error("no connections left"))?;
132 let slot_resp: Vec<SlotResp> = conn.query(cmd!["CLUSTER", "SLOTS"]).await?;
133 self.connect_slots(slot_resp).await?;
134 Ok(())
135 }
136
    /// Runs `cmd` against the cluster and converts the reply into `T`.
    ///
    /// This is a by-value convenience wrapper around `req_packed_command`,
    /// which performs the routing and retry handling.
    pub async fn query<T>(&mut self, cmd: Cmd) -> Result<T, RedisError>
    where
        T: FromRedisValue + Send + 'static,
    {
        self.req_packed_command(&cmd).await
    }
143
144 pub async fn execute_script<T>(
145 &mut self,
146 eval_command: &Cmd,
147 load_command: &Cmd,
148 ) -> Result<T, RedisError>
149 where
150 T: FromRedisValue + Send + 'static,
151 {
152 let mut tries = 0;
153 let slot = match RoutingInfo::for_packed_command(&eval_command.get_packed_command()) {
155 Some(routing) => routing.slot(),
156 None => {
157 return Err((
158 ErrorKind::ClientError,
159 "this command cannot be safely routed in cluster mode",
160 )
161 .into());
162 }
163 };
164
165 loop {
166 let conn = self.get_slot_conn(slot);
167
168 let error = if let Some(conn) = conn {
169 if conn.is_alive() {
170 match conn.execute_script(eval_command, load_command).await {
171 Ok(res) => return Ok(res),
172 Err(e) => {
173 if !e.is_io_error() && e.code() != Some("MOVED") {
177 return Err(e);
178 } else {
179 e
180 }
181 }
182 }
183 } else {
184 cluster_error("fetched connection for slot was not alive")
185 }
186 } else {
187 cluster_error("couldn't fetch a connection for slot")
188 };
189
190 if tries <= RETRIES {
191 tries += 1;
192 tracing::warn!(
193 "Failed to fetch a connection for execute_script: {}, retrying. i={} max={}",
194 error,
195 tries,
196 RETRIES
197 );
198 self.refresh_slots().await?;
199 } else {
200 return Err(error);
201 }
202 }
203 }
204
205 pub async fn req_packed_command<T>(&mut self, cmd: &Cmd) -> Result<T, RedisError>
206 where
207 T: FromRedisValue + Send + 'static,
208 {
209 let mut tries = 0;
210 let slot = match RoutingInfo::for_packed_command(&cmd.get_packed_command()) {
211 Some(routing) => routing.slot(),
212 None => {
213 return Err((
214 ErrorKind::ClientError,
215 "this command cannot be safely routed in cluster mode",
216 )
217 .into());
218 }
219 };
220
221 loop {
222 let conn = self.get_slot_conn(slot);
223 let error = if let Some(conn) = conn {
224 if conn.is_alive() {
225 match conn.req_packed_command(cmd).await {
226 Ok(res) => return Ok(res),
227 Err(e) => {
228 if !e.is_io_error() && e.code() != Some("MOVED") {
232 return Err(e);
233 } else {
234 e
235 }
236 }
237 }
238 } else {
239 cluster_error("fetched connection for slot was not alive")
240 }
241 } else {
242 cluster_error("couldn't fetch a connection for slot")
243 };
244
245 if tries <= RETRIES {
246 tries += 1;
247 tracing::warn!("Failed to fetch a connection for req_packed_command: {}, retrying. i={} max={}", error, tries, RETRIES);
248 self.refresh_slots().await?;
249 } else {
250 return Err(error);
251 }
252 }
253 }
254
255 pub fn is_alive(&self) -> bool {
256 self.connections.values().all(SingleConn::is_alive)
257 }
258
259 pub async fn try_connect(infos: Vec<ConnectionInfo>) -> Result<Self, RedisError> {
260 if infos.is_empty() {
261 return Err(cluster_error("no connection info provided"));
262 }
263
264 let password = infos[0].redis.password.as_ref().cloned();
265 let is_tls = match infos[0].addr.clone() {
266 ConnectionAddr::TcpTls {
267 host: _,
268 port: _,
269 insecure: _,
270 } => true,
271 _ => false,
272 };
273
274 let mut addresses = HashMap::new();
275 let mut connections = SlotMap::new();
276
277 for info in infos {
278 let address = match &info.addr {
279 ConnectionAddr::Tcp(host, port) => Address {
280 host: host.clone(),
281 port: *port,
282 },
283 ConnectionAddr::TcpTls { host, port, .. } => Address {
284 host: host.clone(),
285 port: *port,
286 },
287 ConnectionAddr::Unix(path) => Address {
288 host: path.to_str().unwrap_or("").to_owned(),
289 port: 0,
290 },
291 };
292 let conn = match SingleConn::try_connect(info).await {
293 Ok(conn) => conn,
294 Err(_) => continue,
295 };
296
297 let key = connections.insert(conn);
298 addresses.insert(address, key);
299 break;
300 }
301
302 let mut cluster = ClusterConn {
303 connections,
304 node_slots: NodeSlots {
305 addresses,
306 slots: BTreeMap::new(),
307 },
308 password,
309 distribution: None,
310 is_tls,
311 };
312
313 cluster.refresh_slots().await?;
314
315 Ok(cluster)
316 }
317
318 async fn connect_multiple<'a, I>(
319 &self,
320 addresses: I,
321 ) -> Result<Vec<(&'a Address, SingleConn)>, RedisError>
322 where
323 I: Iterator<Item = &'a Address>,
324 {
325 let connections = addresses.map(|address| {
326 SingleConn::try_connect(super::build_info(
327 &address.host,
328 address.port,
329 self.password.as_deref(),
330 self.is_tls,
331 ))
332 .map_ok(move |conn| (address, conn))
333 });
334
335 join_all(connections).await.into_iter().collect()
336 }
337
338 async fn connect_slots(&mut self, slots: Vec<SlotResp>) -> Result<(), RedisError> {
339 let previous_connections = self.connections.len();
340 let addresses = unique_addresses(&slots);
341
342 let (mut remaining, removed): (HashMap<_, _>, HashMap<_, _>) = self
343 .node_slots
344 .addresses
345 .drain()
346 .partition(|(address, _)| addresses.contains(address));
347
348 for (_, key) in removed {
349 self.connections.remove(key);
350 }
351
352 remaining.retain(|_, key| {
354 let conn = match self.connections.get(*key) {
355 Some(conn) => conn,
356 None => return false,
357 };
358
359 if conn.is_alive() {
360 true
361 } else {
362 self.connections.remove(*key);
363 false
364 }
365 });
366
367 self.node_slots.addresses = remaining;
368
369 let added = addresses
370 .into_iter()
371 .filter(|address| !self.node_slots.addresses.contains_key(*address));
372
373 let new_connections = self.connect_multiple(added).await?;
374
375 for (address, connection) in new_connections {
376 let key = self.connections.insert(connection);
377 self.node_slots.addresses.insert(address.clone(), key);
378 }
379
380 let mut new_slots = BTreeMap::new();
381 for slot in slots {
382 if let Some(key) = self.node_slots.addresses.get(&slot.address) {
383 new_slots.insert(slot.start, *key);
384 } else {
385 tracing::warn!(
387 start = slot.start,
388 end = slot.end,
389 address = format!("{}", slot.address),
390 "Redis cluster: missing address for slot connection",
391 );
392 }
393 }
394 self.node_slots.slots = new_slots;
395
396 if self.connections.len() != previous_connections || self.distribution.is_none() {
397 self.distribution = Some(Uniform::new(0, self.connections.len()));
398 }
399
400 Ok(())
401 }
402
403 pub async fn ping(&mut self) -> Result<(), RedisError> {
404 let mut tries = 0;
405 'retry: loop {
406 let results = futures::future::join_all(
408 self.connections
409 .values_mut()
410 .filter(|c| c.is_alive())
411 .map(|c| c.ping()),
412 )
413 .await;
414
415 for res in results {
416 if let Err(e) = res {
419 if tries <= RETRIES {
420 tries += 1;
421 self.refresh_slots().await?;
422 continue 'retry;
423 } else {
424 return Err(e);
425 }
426 }
427 }
428
429 return Ok(());
430 }
431 }
432
433 pub fn partition_keys_by_node<'a, I, K>(
434 &self,
435 keys: I,
436 ) -> Result<HashMap<Address, Vec<&'a K>>, RedisError>
437 where
438 &'a K: ToRedisArgs,
439 I: Iterator<Item = &'a K>,
440 {
441 let mut res = HashMap::new();
442
443 for key in keys {
444 let args = key.to_redis_args();
445 let bytes = if args.len() != 1 {
446 Err(partition_error("multiple args for key"))
447 } else {
448 Ok(&args[0])
449 }?;
450 let target_slot = RoutingInfo::for_key(bytes)
451 .and_then(|routing_info| routing_info.slot())
452 .ok_or_else(|| partition_error("no routing info for key"))?;
453 let target_key = self
454 .node_slots
455 .slots
456 .range(0..=target_slot)
457 .next_back()
458 .map(|(_, key)| *key)
459 .ok_or_else(|| partition_error("unknown slot"))?;
460 let address = self
461 .node_slots
462 .addresses
463 .iter()
464 .find(|(_, &key)| target_key == key)
465 .map(|(address, _)| address)
466 .ok_or_else(|| partition_error("unknown address"))?;
467
468 let entry = res.entry(address.clone()).or_insert_with(Vec::new);
469 entry.push(key);
470 }
471
472 Ok(res)
473 }
474}
475
476fn unique_addresses(slots: &[SlotResp]) -> HashSet<&Address> {
477 slots.iter().map(|slot| &slot.address).collect()
478}
479
/// Extracts the hash tag of `key`: the bytes between the first `{` and the
/// first `}` that follows it.
///
/// Returns `None` when there is no `{`, no `}` after it, or nothing between
/// the two, in which case callers hash the whole key.
fn get_hashtag(key: &[u8]) -> Option<&[u8]> {
    let open = key.iter().position(|&byte| byte == b'{')?;
    let after_open = &key[open + 1..];
    let tag_len = after_open.iter().position(|&byte| byte == b'}')?;

    let tag = &after_open[..tag_len];
    if tag.is_empty() {
        None
    } else {
        Some(tag)
    }
}
500
/// Where a command should be sent within the cluster.
#[derive(Debug, Clone, Copy)]
enum RoutingInfo {
    /// No particular slot; `slot()` yields `None` and a random node is used.
    Random,
    /// The command must go to the node serving this slot.
    Slot(u16),
}
507
508fn get_arg(values: &[Value], idx: usize) -> Option<&[u8]> {
509 match values.get(idx) {
510 Some(Value::Data(ref data)) => Some(&data[..]),
511 _ => None,
512 }
513}
514
515fn get_command_arg(values: &[Value], idx: usize) -> Option<Vec<u8>> {
516 get_arg(values, idx).map(|x| x.to_ascii_uppercase())
517}
518
519fn get_u64_arg(values: &[Value], idx: usize) -> Option<u64> {
520 get_arg(values, idx)
521 .and_then(|x| std::str::from_utf8(x).ok())
522 .and_then(|x| x.parse().ok())
523}
524
525impl RoutingInfo {
526 pub fn slot(&self) -> Option<u16> {
527 match self {
528 RoutingInfo::Random => None,
529 RoutingInfo::Slot(slot) => Some(*slot),
530 }
531 }
532
533 pub fn for_packed_command(cmd: &[u8]) -> Option<RoutingInfo> {
534 parse_redis_value(cmd).ok().and_then(RoutingInfo::for_value)
535 }
536
537 pub fn for_value(value: Value) -> Option<RoutingInfo> {
538 let args = match value {
539 Value::Bulk(args) => args,
540 _ => return None,
541 };
542
543 match &get_command_arg(&args, 0)?[..] {
544 b"SCAN" | b"CLIENT SETNAME" | b"SHUTDOWN" | b"SLAVEOF" | b"REPLICAOF"
545 | b"SCRIPT KILL" | b"MOVE" | b"BITOP" => None,
546 b"EVALSHA" | b"EVAL" => {
547 let key_count = get_u64_arg(&args, 2)?;
548 if key_count == 0 {
549 Some(RoutingInfo::Random)
550 } else {
551 get_arg(&args, 3).and_then(RoutingInfo::for_key)
552 }
553 }
554 b"XGROUP" | b"XINFO" => get_arg(&args, 2).and_then(RoutingInfo::for_key),
555 b"XREAD" | b"XREADGROUP" => {
556 let streams_position = args.iter().position(|a| match a {
557 Value::Data(a) => a == b"STREAMS",
558 _ => false,
559 })?;
560 get_arg(&args, streams_position + 1).and_then(RoutingInfo::for_key)
561 }
562 _ => match get_arg(&args, 1) {
563 Some(key) => RoutingInfo::for_key(key),
564 None => Some(RoutingInfo::Random),
565 },
566 }
567 }
568
569 pub fn for_key(key: &[u8]) -> Option<RoutingInfo> {
570 let key = match get_hashtag(key) {
571 Some(tag) => tag,
572 None => key,
573 };
574 Some(RoutingInfo::Slot(
575 crc16::State::<crc16::XMODEM>::calculate(key) % SLOT_SIZE,
576 ))
577 }
578}
579
#[cfg(test)]
mod test {
    use super::*;

    /// A binary key must map to the same slot whether it is routed directly
    /// or extracted from a packed `GET` command.
    #[test]
    fn test_routing() {
        // The key contains no `{`, so the whole byte string is hashed.
        let key: &[u8] = b"[dbreq.approvedeviceemail]\0\0\0\0\0\nP\x08\x01";
        let slot = match RoutingInfo::for_key(key) {
            Some(RoutingInfo::Slot(x)) => x,
            _ => panic!("Expected slot"),
        };
        assert_eq!(8505, slot);

        // Same key, wrapped in a two-element packed command (`GET <key>`).
        let cmd: &[u8] =
            b"*2\r\n$3\r\nGET\r\n$35\r\n[dbreq.approvedeviceemail]\0\0\0\0\0\nP\x08\x01\r\n";
        let slot = match RoutingInfo::for_packed_command(cmd) {
            Some(RoutingInfo::Slot(x)) => x,
            _ => panic!("Expected slot"),
        };
        assert_eq!(8505, slot);
    }
}