1#![warn(
3 anonymous_parameters,
4 bare_trait_objects,
5 elided_lifetimes_in_paths,
6 rust_2018_idioms,
7 trivial_casts,
8 trivial_numeric_casts,
9 unsafe_code,
10 unused_extern_crates,
11 unused_import_braces
12)]
13#![warn(
15 clippy::all,
16 clippy::cargo,
17 clippy::dbg_macro,
18 clippy::float_cmp_const,
19 clippy::get_unwrap,
20 clippy::mem_forget,
21 clippy::nursery,
22 clippy::pedantic,
23 clippy::todo,
24 clippy::unwrap_used
25)]
26#![allow(
28 clippy::default_trait_access,
29 clippy::doc_markdown,
30 clippy::if_not_else,
31 clippy::module_name_repetitions,
32 clippy::multiple_crate_versions,
33 clippy::must_use_candidate,
34 clippy::needless_pass_by_value,
35 clippy::use_self,
36 clippy::cargo_common_metadata,
37 clippy::missing_errors_doc,
38 clippy::enum_glob_use,
39 clippy::struct_excessive_bools
40)]
41#![cfg_attr(test, allow(clippy::non_ascii_literal, clippy::unwrap_used))]
43
44mod cluster;
45pub mod config;
46pub mod error;
47mod metrics;
48pub mod mock;
49pub mod pool;
50mod single;
51pub mod stream;
52
53use self::{
54 cluster::ClusterConn, metrics::MeteredConn, single::SingleConn,
55};
56use crate::pool::{ConnectionManager, PoolConnection};
57pub use crate::{
58 config::{Config, ConnectionMode},
59 metrics::RedisMetrics,
60};
61use async_trait::async_trait;
62use displaydoc::Display;
63use redis::{self, ConnectionAddr, ConnectionInfo, RedisConnectionInfo};
64pub use redis::{
65 ErrorKind, FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs, Value,
66};
67use std::{collections::HashMap, hash::Hash, sync::Arc, time::Duration};
68use stream::{
69 MessageId, ReadStream, ReadStreamOptions, StreamItem, StreamReadReply, WriteStreamOptions,
70};
71
72struct VecWrapper<'a>(&'a Vec<u8>);
73impl<'a> ToRedisArgs for VecWrapper<'a> {
74 fn write_redis_args<W: ?Sized + RedisWrite>(&self, out: &mut W) {
75 out.write_arg(self.0);
76 }
77}
78
79#[derive(Debug, Display, thiserror::Error)]
80pub enum Error {
81 KeyDoesNotExist,
83 RedisError(#[from] RedisError),
85 RedisInsertionError(String),
87 HostParseError(&'static str),
89 IoError(#[from] std::io::Error),
91 MissingConnectionInfo,
93}
94
95#[async_trait]
96pub trait Store {
97 async fn get_key<K, T>(&self, key: K) -> Result<T, Error>
98 where
99 K: ToRedisArgs + Send,
100 T: FromRedisValue + Send + 'static;
101
102 async fn set_key<K, D>(&self, key: K, ttl: Duration, data: D) -> Result<(), Error>
103 where
104 K: ToRedisArgs + Send,
105 D: ToRedisArgs + Send;
106
107 async fn del_key<T: ToRedisArgs + Send>(&self, key: T) -> Result<Option<u64>, Error>;
108
109 async fn del_keys<'a, T>(&self, keys: &'a [T]) -> Result<Vec<Option<u64>>, Error>
110 where
111 T: Sync,
112 &'a T: ToRedisArgs;
113
114 async fn set_if_not_exists<K, D>(&self, key: K, ttl: Duration, data: D) -> Result<bool, Error>
118 where
119 K: ToRedisArgs + Send,
120 D: ToRedisArgs + Send;
121
122 async fn sorted_set_add_one<D>(&self, key: &str, score: i64, data: D) -> Result<(), Error>
124 where
125 D: ToRedisArgs + Send;
126}
127
128#[derive(Clone)]
129pub struct Redis {
130 pool: pool::Pool<Connection>,
131 pub metrics: Arc<RedisMetrics>,
132}
133
134impl Redis {
135 pub async fn new(config: &Config) -> Self {
136 let metrics = Arc::new(RedisMetrics::default());
137 let hosts = config
138 .hosts
139 .iter()
140 .map(|host| (host.0.as_str(), host.1))
141 .collect::<Vec<_>>();
142 let config = pool::ConfigBuilder::new()
143 .address(Client::with_mode(
144 &hosts,
145 config.password.as_deref(),
146 metrics.clone(),
147 config.mode,
148 config.is_tls,
149 ))
150 .max_size(config.pool_max_size)
151 .min_size(config.pool_min_size)
152 .build();
153
154 Self {
155 pool: pool::Pool::new(config).await,
156 metrics,
157 }
158 }
159
160 pub async fn get_conn(&self) -> Result<PoolConnection<Connection>, Error> {
161 self.pool.get_connection().await
162 }
163
164 pub async fn try_get_conn(&self) -> Result<PoolConnection<Connection>, Error> {
165 self.pool.try_get_connection().await
166 }
167
168 pub async fn get_conn_timeout(
169 &self,
170 timeout: Duration,
171 ) -> Result<PoolConnection<Connection>, Error> {
172 self.pool.get_connection_timeout(timeout).await
173 }
174
175 pub async fn stream_write_one<'s, D, S>(
179 &self,
180 stream_id: &'s S,
181 data: D,
182 options: &WriteStreamOptions,
183 ) -> Result<MessageId, Error>
184 where
185 D: ToRedisArgs,
186 &'s S: ToRedisArgs,
187 {
188 self.get_conn().await?.xadd(stream_id, data, options).await
189 }
190
191 pub async fn stream_message_exists<'s, S>(
193 &self,
194 stream_id: &'s S,
195 id: MessageId,
196 ) -> Result<bool, Error>
197 where
198 &'s S: ToRedisArgs,
199 S: FromRedisValue + Eq + Hash + Send + 'static,
200 {
201 let messages = self
202 .get_conn()
203 .await?
204 .xrange::<S, Value>(stream_id, id, id, Some(1))
205 .await?;
206
207 Ok(!messages.is_empty())
208 }
209}
210
211#[async_trait]
212impl Store for Redis {
213 async fn get_key<K, T>(&self, key: K) -> Result<T, Error>
214 where
215 K: ToRedisArgs + Send,
216 T: FromRedisValue + Send + 'static,
217 {
218 self.get_conn().await?.get(key).await
219 }
220
221 async fn set_key<K, D>(&self, key: K, ttl: Duration, data: D) -> Result<(), Error>
222 where
223 K: ToRedisArgs + Send,
224 D: ToRedisArgs + Send,
225 {
226 self.get_conn().await?.set_expiry(key, ttl, data).await
227 }
228
229 async fn del_key<T: ToRedisArgs + Send>(&self, key: T) -> Result<Option<u64>, Error> {
230 self.get_conn().await?.del(key).await
231 }
232
233 async fn del_keys<'a, T>(&self, keys: &'a [T]) -> Result<Vec<Option<u64>>, Error>
234 where
235 T: Sync,
236 &'a T: ToRedisArgs,
237 {
238 let mut pool = self.get_conn().await?;
239 let mut conn = pool.multi().await?;
240
241 for key in keys {
242 conn.del(key).await?;
243 }
244
245 conn.exec().await
246 }
247
248 async fn set_if_not_exists<K, D>(&self, key: K, ttl: Duration, data: D) -> Result<bool, Error>
252 where
253 K: ToRedisArgs + Send,
254 D: ToRedisArgs + Send,
255 {
256 self.get_conn().await?.set_nx_ex(key, ttl, data).await
257 }
258
259 async fn sorted_set_add_one<D>(&self, key: &str, score: i64, data: D) -> Result<(), Error>
261 where
262 D: ToRedisArgs + Send,
263 {
264 let mut conn = self.get_conn().await?;
265
266 let inserted = conn.z_add(key, score, data).await?;
267
268 if inserted == 1 {
269 Ok(())
270 } else {
271 Err(Error::RedisInsertionError(key.to_string()))
272 }
273 }
274}
275
276pub enum Connection {
277 Single(SingleConn),
278 Clustered(ClusterConn),
279 Metered(MeteredConn),
280}
281
282#[async_trait]
283impl ConnectionManager for Connection {
284 type Address = Client;
285 type Connection = Self;
286 type Error = Error;
287
288 async fn connect(address: &Client) -> Result<Connection, Error> {
289 Ok(address.get_connection().await?)
290 }
291
292 fn check_alive(connection: &Self::Connection) -> Option<bool> {
293 match connection {
294 Connection::Single(ref conn) => Some(conn.is_alive()),
295 Connection::Clustered(ref conn) => Some(conn.is_alive()),
296 Connection::Metered(ref conn) => Some(conn.is_alive()),
297 }
298 }
299
300 async fn ping(connection: &mut Self::Connection) -> Result<(), Self::Error> {
301 connection.ping().await
302 }
303}
304
305impl Connection {
306 pub async fn query<T>(&mut self, cmd: redis::Cmd) -> Result<T, Error>
307 where
308 T: FromRedisValue + Send + 'static,
309 {
310 match self {
311 Self::Single(ref mut single) => Ok(single.query(cmd).await?),
312 Self::Clustered(ref mut cluster) => Ok(cluster.query(cmd).await?),
313 Self::Metered(ref mut metered) => Ok(metered.metered_query(cmd).await?),
314 }
315 }
316
317 pub fn partition_keys_by_node<'a, I, K>(
318 &self,
319 keys: I,
320 ) -> Result<HashMap<Address, Vec<&'a K>>, Error>
321 where
322 &'a K: ToRedisArgs,
323 I: Iterator<Item = &'a K>,
324 {
325 match self {
326 Self::Single(ref single) => Ok(single.partition_keys_by_node(keys)?),
327 Self::Clustered(ref cluster) => Ok(cluster.partition_keys_by_node(keys)?),
328 Self::Metered(ref metered) => Ok(metered.partition_keys_by_node(keys)?),
329 }
330 }
331
332 pub async fn exists<K>(&mut self, key: K) -> Result<u64, Error>
336 where
337 K: ToRedisArgs,
338 {
339 Ok(self.query(cmd!["EXISTS", key]).await?)
340 }
341
342 pub async fn get<K, T>(&mut self, key: K) -> Result<T, Error>
343 where
344 K: ToRedisArgs,
345 T: FromRedisValue + Send + 'static,
346 {
347 self.query(cmd!["GET", key]).await
348 }
349
350 pub async fn hget<H, K, T>(&mut self, hash: H, key: K) -> Result<T, Error>
351 where
352 H: ToRedisArgs,
353 K: ToRedisArgs,
354 T: FromRedisValue + Send + 'static,
355 {
356 self.query(cmd!["HGET", hash, key]).await
357 }
358
359 pub async fn hget_all<H, T>(&mut self, hash: H) -> Result<T, Error>
360 where
361 H: ToRedisArgs,
362 T: FromRedisValue + Send + 'static,
363 {
364 self.query(cmd!["HGETALL", hash]).await
365 }
366
367 pub async fn ttl<T: ToRedisArgs>(&mut self, key: T) -> Result<Option<i64>, Error> {
368 self.query(cmd!["TTL", key]).await
369 }
370
371 pub async fn pttl<T: ToRedisArgs>(&mut self, key: T) -> Result<Option<i64>, Error> {
372 self.query(cmd!["PTTL", key]).await
373 }
374
375 pub async fn del<T: ToRedisArgs>(&mut self, key: T) -> Result<Option<u64>, Error> {
376 self.query(cmd!["DEL", key]).await
377 }
378
379 pub async fn ping(&mut self) -> Result<(), Error> {
380 match self {
381 Self::Single(ref mut single) => Ok(single.ping().await?),
382 Self::Clustered(ref mut cluster) => Ok(cluster.ping().await?),
383 Self::Metered(ref mut metered) => Ok(metered.metered_ping().await?),
384 }
385 }
386
387 pub async fn multi(&mut self) -> Result<ConnectionMulti<'_>, Error> {
388 self.query(cmd!["MULTI"]).await?;
389
390 Ok(ConnectionMulti(self))
391 }
392
393 pub async fn set<K, D>(&mut self, key: K, data: D) -> Result<(), Error>
394 where
395 K: ToRedisArgs,
396 D: ToRedisArgs,
397 {
398 self.query(cmd!["SET", key, data]).await
399 }
400
401 pub async fn set_expiry<K, D>(&mut self, key: K, ttl: Duration, data: D) -> Result<(), Error>
402 where
403 K: ToRedisArgs,
404 D: ToRedisArgs,
405 {
406 self.query(cmd!["SETEX", key, ttl.as_secs(), data]).await
407 }
408
409 pub async fn z_add<K, D>(&mut self, key: K, score: i64, data: D) -> Result<u64, Error>
410 where
411 K: ToRedisArgs,
412 D: ToRedisArgs,
413 {
414 self.query(cmd!["ZADD", key, score, data]).await
415 }
416
417 pub async fn zadd_binary<K, T>(&mut self, key: K, score: i64, member: T) -> Result<i64, Error>
418 where
419 K: ToRedisArgs,
420 T: AsRef<[u8]>,
421 {
422 self.query(cmd![
423 "ZADD",
424 key,
425 score,
426 VecWrapper(&member.as_ref().to_vec())
427 ])
428 .await
429 }
430
431 pub async fn z_range<K, T>(
432 &mut self,
433 key: K,
434 min: i64,
435 max: i64,
436 is_reversed: bool,
437 ) -> Result<T, Error>
438 where
439 K: ToRedisArgs,
440 T: FromRedisValue + Send + 'static,
441 {
442 if is_reversed {
443 return self.query(cmd!["ZREVRANGE", key, min, max]).await;
444 }
445 self.query(cmd!["ZRANGE", key, min, max]).await
446 }
447
448 pub async fn z_rem<K, T>(&mut self, key: K, members: T) -> Result<u64, Error>
449 where
450 K: ToRedisArgs,
451 T: ToRedisArgs,
452 {
453 self.query(cmd!["ZREM", key, members]).await
454 }
455
456 pub async fn zrem_all<K>(&mut self, key: K) -> Result<i64, Error>
457 where
458 K: ToRedisArgs,
459 {
460 self.query(cmd!["ZREMRANGEBYRANK", key, 0, -1]).await
461 }
462
463 pub async fn z_card<K>(&mut self, key: K) -> Result<i64, Error>
464 where
465 K: ToRedisArgs,
466 {
467 self.query(cmd!["ZCARD", key]).await
468 }
469
470 pub async fn incr_by<T: ToRedisArgs>(&mut self, key: T, increment: i64) -> Result<i64, Error> {
471 self.query(cmd!["INCRBY", key, increment]).await
472 }
473
474 pub async fn hset<H, D>(&mut self, hash: H, data: D) -> Result<i64, Error>
475 where
476 H: ToRedisArgs,
477 D: ToRedisArgs,
478 {
479 self.query(cmd!["HSET", hash, data]).await
480 }
481
482 pub async fn hdel<H, D>(&mut self, hash: H, data: D) -> Result<i64, Error>
483 where
484 H: ToRedisArgs,
485 D: ToRedisArgs,
486 {
487 self.query(cmd!["HDEL", hash, data]).await
488 }
489
490 pub async fn expire<T: ToRedisArgs>(&mut self, key: T, ttl: Duration) -> Result<(), Error> {
491 match self.query(cmd!["EXPIRE", key, ttl.as_secs()]).await? {
492 1 => Ok(()),
493 _ => Err(Error::KeyDoesNotExist),
494 }
495 }
496
497 pub async fn set_nx_ex<K, D>(&mut self, key: K, ttl: Duration, data: D) -> Result<bool, Error>
498 where
499 K: ToRedisArgs,
500 D: ToRedisArgs,
501 {
502 self.query(cmd!["SET", key, data, "NX", "EX", ttl.as_secs()])
503 .await
504 }
505
506 async fn xadd<'s, D, S>(
507 &mut self,
508 stream_id: &'s S,
509 data: D,
510 options: &WriteStreamOptions,
511 ) -> Result<MessageId, Error>
512 where
513 D: ToRedisArgs,
514 &'s S: ToRedisArgs,
515 {
516 let cmd = cmd!["XADD", stream_id, options, "*", data];
517 self.query(cmd).await
518 }
519
520 async fn xrange<'s, S, T>(
521 &mut self,
522 stream_id: &'s S,
523 start: MessageId,
524 end: MessageId,
525 limit: Option<u64>,
526 ) -> Result<Vec<StreamItem<T>>, Error>
527 where
528 &'s S: ToRedisArgs,
529 S: FromRedisValue + Eq + Hash + Send + 'static,
530 T: FromRedisValue + Send + 'static,
531 {
532 let mut cmd = cmd!["XRANGE", stream_id, start, end];
533 if let Some(limit) = limit {
534 cmd.arg("COUNT");
535 cmd.arg(limit);
536 }
537 self.query(cmd).await
538 }
539
540 pub async fn xread<'s, T, S>(
541 &mut self,
542 streams: &[ReadStream<'s, S>],
543 options: &ReadStreamOptions,
544 ) -> Result<StreamReadReply<S, T>, Error>
545 where
546 T: FromRedisValue + Send + 'static,
547 S: FromRedisValue + Eq + Hash + Send + 'static,
548 &'s S: ToRedisArgs,
549 {
550 let mut cmd = cmd!["XREAD", options, "STREAMS"];
551 for stream in streams {
552 cmd.arg(stream.id);
553 }
554 for stream in streams {
555 cmd.arg(stream.offset);
556 }
557 self.query(cmd).await
558 }
559
560 pub async fn xlen<'s, S>(&mut self, stream_id: &'s S) -> Result<u64, Error>
561 where
562 &'s S: ToRedisArgs,
563 {
564 let cmd = cmd!["XLEN", stream_id];
565 self.query(cmd).await
566 }
567}
568
569pub struct ConnectionMulti<'a>(&'a mut Connection);
570
571impl<'a> ConnectionMulti<'a> {
572 pub async fn del<T: ToRedisArgs>(&mut self, key: T) -> Result<(), Error> {
573 self.0.query(cmd!["DEL", key]).await?;
574 Ok(())
575 }
576
577 pub async fn exec<T>(&mut self) -> Result<Vec<T>, Error>
578 where
579 T: FromRedisValue + Send + 'static,
580 {
581 self.0.query(cmd!["EXEC"]).await
582 }
583}
584
585#[derive(Clone, Debug)]
586pub struct Client {
587 connections: Vec<ConnectionInfo>,
588 metrics: Arc<RedisMetrics>,
589 mode: ConnectionMode,
590}
591
/// Host/port pair naming a Redis node; the grouping key produced by
/// `Connection::partition_keys_by_node`.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct Address {
    /// Hostname or IP address of the node.
    pub host: String,
    /// TCP port of the node.
    pub port: u16,
}
597
598impl Client {
599 pub fn new(
600 hosts: &[(&str, u16)],
601 password: Option<&str>,
602 metrics: Arc<RedisMetrics>,
603 is_tls: bool,
604 ) -> Self {
605 Self::with_mode(hosts, password, metrics, ConnectionMode::default(), is_tls)
606 }
607
608 pub fn with_mode(
609 hosts: &[(&str, u16)],
610 password: Option<&str>,
611 metrics: Arc<RedisMetrics>,
612 mode: ConnectionMode,
613 is_tls: bool,
614 ) -> Self {
615 let connections = hosts
616 .iter()
617 .map(|(host, port)| build_info(host, *port, password, is_tls))
618 .collect();
619 Self {
620 connections,
621 metrics,
622 mode,
623 }
624 }
625
626 async fn connect_cluster(&self) -> Result<Connection, Error> {
627 let conn = ClusterConn::try_connect(self.connections.clone()).await?;
628 tracing::info!("initiated clustered redis connection");
629 Ok(Connection::Metered(MeteredConn::new(
630 Connection::Clustered(conn),
631 self.metrics.clone(),
632 )))
633 }
634
635 async fn connect_single(&self) -> Result<Connection, Error> {
636 if self.connections.is_empty() {
637 return Err(Error::MissingConnectionInfo);
638 }
639
640 let conn = SingleConn::try_connect(self.connections[0].clone()).await?;
641 tracing::info!("initiated single redis connection");
642 Ok(Connection::Metered(MeteredConn::new(
643 Connection::Single(conn),
644 self.metrics.clone(),
645 )))
646 }
647
648 pub async fn get_connection(&self) -> Result<Connection, Error> {
649 let addresses = self
650 .connections
651 .iter()
652 .map(|info| format!("{:?}", info.addr))
653 .collect::<Vec<_>>()
654 .join(", ");
655 tracing::info!("initiating redis connection with addresses {:?}", addresses);
656
657 match self.mode {
658 ConnectionMode::Detect => match self.connect_cluster().await {
659 Ok(conn) => Ok(conn),
660 Err(_) if self.connections.len() == 1 => self.connect_single().await,
663 Err(e) => Err(e),
664 },
665 ConnectionMode::Single => self.connect_single().await,
666 ConnectionMode::Cluster => self.connect_cluster().await,
667 }
668 }
669}
670
671fn build_info(host: &str, port: u16, password: Option<&str>, is_tls: bool) -> ConnectionInfo {
672 let addr = if is_tls {
673 ConnectionAddr::TcpTls {
674 host: host.to_owned(),
675 port,
676 insecure: true,
677 }
678 } else {
679 ConnectionAddr::Tcp(host.to_owned(), port)
680 };
681 ConnectionInfo {
682 addr,
683 redis: RedisConnectionInfo {
684 db: 0,
685 username: None,
686 password: password.filter(|p| !p.is_empty()).map(String::from),
687 },
688 }
689}
690
/// Builds a `redis::Cmd` by calling `Cmd::arg` on each argument in order.
///
/// `cmd!["SET", key, value]` expands to a fresh `redis::Cmd::new()` followed by
/// `.arg("SET")`, `.arg(key)`, `.arg(value)`. Each argument may be followed by any number of
/// commas (the matcher is `$($arg:expr $(,)*)*`).
///
/// NOTE(review): the expansion names `redis::Cmd` without `$crate`, so crates invoking this
/// exported macro must themselves depend on `redis`.
#[macro_export]
macro_rules! cmd {
    [$($arg:expr $(,)*)*] => {{
        let mut cmd = redis::Cmd::new();
        $(cmd.arg($arg);)*
        cmd
    }}
}
699
700pub mod integration_test {
701
702 use super::*;
703 use std::sync::Arc;
704
705 pub struct TestRedis {
706 pub single: Arc<Redis>,
707 pub cluster: Arc<Redis>,
708 }
709
710 impl TestRedis {
716 pub async fn new() -> Self {
717 let single_config = Config::new(&["localhost:6379"]).unwrap();
718 let cluster_config = Config::new(&["localhost:7000"]).unwrap();
719 let single_redis = Redis::new(&single_config).await;
720 let cluster_redis = Redis::new(&cluster_config).await;
721
722 Self {
723 single: Arc::new(single_redis),
724 cluster: Arc::new(cluster_redis),
725 }
726 }
727 }
728
729 #[macro_export]
731 macro_rules! test_using_redis {
732 (async fn $func:ident ($param:ident: Arc<Redis>) $code:block) => {
733 paste::paste! {
734 #[tokio::test]
735 async fn [<$func _single>]() {
736 let redis = crate::integration_test::TestRedis::new().await;
737 let $param = redis.single;
738 $code
739 }
740
741 #[tokio::test]
742 async fn [<$func _cluster>]() {
743 let redis = crate::integration_test::TestRedis::new().await;
744 let $param = redis.cluster;
745 $code
746 }
747 }
748 };
749 }
750
751 test_using_redis! {
752 async fn test_redis(redis: Arc<Redis>) {
753 let mut connection = redis.get_conn().await.unwrap();
754 println!("connected to redis");
755
756 connection.set("foo", "bar").await.expect("set foo");
757 let data: Option<String> = connection.get("foo").await.expect("data");
758 assert_eq!(data.expect("foo"), "bar");
759 }
760 }
761}