redis_swapplex/
lib.rs

1//! Redis multiplexing with reconnection notifications and MGET auto-batching. Connection configuration is provided by [env-url](https://crates.io/crates/env-url).
2//!
3//! Why use this instead of [redis::aio::ConnectionManager](https://docs.rs/redis/latest/redis/aio/struct.ConnectionManager.html)?
//! - Error-free reconnection behavior: when a command would otherwise fail because the connection was dropped, this library immediately reconnects and retries when able, without producing an otherwise avoidable `IoError`; subsequent reconnection attempts are debounced by 1500ms
5//! - ENV configuration simplifies kubernetes usage
6//! - Reconnects can be observed allowing for Redis [server-assisted client-side caching](https://redis.io/docs/manual/client-side-caching/) using client tracking redirection
7//! - Integrated MGET auto-batching (up to 180x more performant than GET)
8//!
//! Composable connection URLs are provided by environment variables using [env-url](https://crates.io/crates/env-url) with the `REDIS` prefix:
10//!
11//! ```text
12//! REDIS_URL=redis://127.0.0.1:6379
13//! # Override env mapping for easy kubernetes config
14//! REDIS_HOST_ENV=MONOLITH_STAGE_REDIS_MASTER_PORT_6379_TCP_ADDR
15//! REDIS_PORT_ENV=MONOLITH_STAGE_REDIS_MASTER_SERVICE_PORT_REDIS
16//! ```
17//!
18//! ```rust
19//! use redis::{AsyncCommands, RedisResult};
20//! use redis_swapplex::get_connection;
21//!
22//! async fn get_value(key: &str) -> RedisResult<String> {
23//!   let mut conn = get_connection();
24//!   conn.get(key).await
25//! }
26//! ```
27
28#![allow(rustdoc::private_intra_doc_links)]
29#[doc(hidden)]
30pub extern crate arc_swap;
31extern crate self as redis_swapplex;
32
33pub use into_bytes::IntoBytes;
34
35use arc_swap::{ArcSwapAny, ArcSwapOption, Cache};
36pub use derive_redis_swapplex::ConnectionManagerContext;
37use env_url::*;
38use futures_util::{future::FutureExt, stream::unfold, Stream};
39use once_cell::sync::Lazy;
40use redis::{
41  aio::{ConnectionLike, MultiplexedConnection},
42  Client, Cmd, ErrorKind, Pipeline, RedisError, RedisFuture, RedisResult, Value,
43};
44use stack_queue::{
45  assignment::{CompletionReceipt, PendingAssignment},
46  local_queue, TaskQueue,
47};
48use std::{
49  cell::RefCell,
50  iter,
51  marker::PhantomData,
52  ops::Deref,
53  ptr::addr_of,
54  sync::Arc,
55  task::Poll,
56  thread::LocalKey,
57  time::{Duration, SystemTime},
58};
59use tokio::sync::Notify;
60
61/// Trait for defining redis client creation and db selection
62pub trait ConnectionInfo: Send + Sync + Sized {
63  fn new(client: RedisResult<Client>, db_index: i64) -> Self;
64  fn parse_index(url: &Url) -> Option<i64> {
65    let mut segments = url.path_segments()?;
66    let db_index: i64 = segments.next()?.parse().ok()?;
67
68    Some(db_index)
69  }
70
71  fn from_url(url: &Url) -> Self {
72    let db_index = <Self as ConnectionInfo>::parse_index(url).unwrap_or(0);
73    let client = redis::Client::open(url.as_str());
74
75    <Self as ConnectionInfo>::new(client, db_index)
76  }
77
78  fn get_db(&self) -> i64;
79  fn client(&self) -> &RedisResult<Client>;
80}
81
#[derive(EnvURL, ConnectionManagerContext)]
#[env_url(env_prefix = "REDIS", default = "redis://127.0.0.1:6379")]
/// Default env-configured Redis connection manager
///
/// Configured from the `REDIS`-prefixed environment variables via env-url, falling back to
/// `redis://127.0.0.1:6379`. This is the context used by [`get_connection`], [`get`] and [`set`].
pub struct EnvConnection;
86
#[doc(hidden)]
/// Client and selected database index for the connection context `T`.
///
/// `T` is only a type-level tag distinguishing contexts; the `PhantomData<fn() -> T>` form
/// means no `T` value is stored or owned.
pub struct RedisDB<T: Send + Sync + Sized> {
  // Result of `redis::Client::open`, or a synthesized error for an unusable URL.
  client: RedisResult<Client>,
  // Database index parsed from the URL path (`0` when absent).
  db_index: i64,
  _marker: PhantomData<fn() -> T>,
}
93
94impl<T> RedisDB<T>
95where
96  T: Send + Sync + 'static + Sized,
97{
98  pub fn new(client: RedisResult<Client>, db_index: i64) -> Self {
99    RedisDB {
100      client,
101      db_index,
102      _marker: PhantomData,
103    }
104  }
105}
106
107impl<T> ConnectionInfo for RedisDB<T>
108where
109  T: ServiceURL + Send + Sync + 'static + Sized,
110{
111  fn new(client: RedisResult<Client>, db_index: i64) -> Self {
112    RedisDB::new(client, db_index)
113  }
114
115  fn get_db(&self) -> i64 {
116    self.db_index
117  }
118
119  fn client(&self) -> &RedisResult<Client> {
120    &self.client
121  }
122}
123
124impl<T> Default for RedisDB<T>
125where
126  T: ServiceURL + Send + Sync + 'static + Sized,
127  Self: ConnectionInfo,
128{
129  fn default() -> Self {
130    match <T as ServiceURL>::service_url() {
131      Ok(url) => <Self as ConnectionInfo>::from_url(&url),
132      Err(_) => {
133        let client = Err(RedisError::from((
134          ErrorKind::InvalidClientConfig,
135          "Invalid Redis connection URL",
136        )));
137
138        Self {
139          client,
140          db_index: 0,
141          _marker: PhantomData,
142        }
143      }
144    }
145  }
146}
147
#[doc(hidden)]
/// Lifecycle of a managed connection, published through `ConnectionManager`'s `state`.
pub enum ConnectionState {
  /// A connection attempt has been spawned and has not finished yet.
  Connecting,
  /// Creating the `redis::Client` failed; the manager never attempts to reconnect from here.
  ClientError(ErrorKind),
  /// The last connection attempt failed with this error kind at this wall-clock time; the
  /// timestamp is used to debounce further attempts.
  ConnectionError(ErrorKind, SystemTime),
  /// An established multiplexed connection, shared (by clone) with every caller.
  Connected(MultiplexedConnection),
}
155
#[doc(hidden)]
/// Shared connection state for one connection context.
pub struct ConnectionManager<T: ConnectionInfo> {
  // `None` until the first connection attempt; each transition stores a fresh `Arc`.
  state: ArcSwapOption<ConnectionState>,
  // Notified (via `notify_waiters` in `store_and_notify`) whenever a connection attempt finishes.
  notify: Notify,
  // Client and database index, built on first access.
  connection_info: Lazy<T>,
}
162
163impl<T> ConnectionManager<T>
164where
165  T: ConnectionInfo,
166{
167  pub const fn new(connection_info: fn() -> T) -> ConnectionManager<T> {
168    ConnectionManager {
169      state: ArcSwapOption::const_empty(),
170      notify: Notify::const_new(),
171      connection_info: Lazy::new(connection_info),
172    }
173  }
174
175  fn store_and_notify(&self, state: Option<Arc<ConnectionState>>) {
176    self.state.store(state);
177    self.notify.notify_waiters();
178  }
179
180  pub fn client(&self) -> &RedisResult<Client> {
181    self.connection_info.client()
182  }
183
184  pub fn get_db(&self) -> i64 {
185    self.connection_info.get_db()
186  }
187}
188
189impl<T> Deref for ConnectionManager<T>
190where
191  T: ConnectionInfo,
192{
193  type Target = ArcSwapAny<Option<Arc<ConnectionState>>>;
194
195  fn deref(&self) -> &Self::Target {
196    &self.state
197  }
198}
199
/// Identity of a specific `MultiplexedConnection` instance, taken from its address inside the
/// current `ConnectionState`.
///
/// It tells whether the connection a caller saw is still the current one. The pointer is only
/// compared, never dereferenced. NOTE(review): an address may be reused once the old
/// `Arc<ConnectionState>` is freed, so equality is a best-effort identity check.
#[derive(PartialEq)]
struct ConnectionAddr(*const MultiplexedConnection);
202
203impl PartialEq<Option<ConnectionAddr>> for ConnectionAddr {
204  fn eq(&self, other: &Option<ConnectionAddr>) -> bool {
205    if let Some(addr) = other {
206      self.0 == addr.0
207    } else {
208      false
209    }
210  }
211}
212
// SAFETY: `ConnectionAddr` is an opaque identity token. Within this file the raw pointer is only
// ever compared for equality and never dereferenced, so sending or sharing it across threads
// cannot cause a data race or a use-after-free.
unsafe impl Send for ConnectionAddr {}
unsafe impl Sync for ConnectionAddr {}
215
216pub trait ConnectionManagerContext: Send + Sync + 'static + Sized {
217  type ConnectionInfo: ConnectionInfo;
218
219  fn get_connection() -> ManagedConnection<Self> {
220    ManagedConnection::new()
221  }
222
223  fn connection_manager() -> &'static ConnectionManager<Self::ConnectionInfo>;
224
225  fn client() -> &'static RedisResult<Client> {
226    Self::connection_manager().client()
227  }
228
229  fn get_db() -> i64 {
230    Self::connection_manager().get_db()
231  }
232
233  fn state_cache() -> &'static LocalKey<
234    RefCell<Cache<&'static ArcSwapOption<ConnectionState>, Option<Arc<ConnectionState>>>>,
235  >;
236
237  fn with_state<T>(with_fn: fn(&Option<Arc<ConnectionState>>) -> T) -> T {
238    <Self as ConnectionManagerContext>::state_cache()
239      .with(|cache| with_fn(cache.borrow_mut().load()))
240  }
241}
242
243impl<T> RedisDB<T>
244where
245  T: ConnectionManagerContext,
246{
247  async fn get_multiplexed_connection() -> RedisResult<(MultiplexedConnection, ConnectionAddr)> {
248    let connection = T::with_state(|connection_state| match connection_state.as_deref() {
249      None => {
250        Self::establish_connection(None);
251        None
252      }
253      Some(ConnectionState::Connecting) => None,
254      Some(ConnectionState::ClientError(kind)) => Some(Err(RedisError::from((
255        kind.to_owned(),
256        "Invalid Redis connection URL",
257      )))),
258      Some(ConnectionState::ConnectionError(
259        ErrorKind::IoError | ErrorKind::ClusterDown | ErrorKind::BusyLoadingError,
260        time,
261      )) if SystemTime::now()
262        .duration_since(*time)
263        .unwrap()
264        .gt(&Duration::from_millis(1500)) =>
265      {
266        Self::establish_connection(None);
267        None
268      }
269      Some(ConnectionState::ConnectionError(kind, _)) => Some(Err(RedisError::from((
270        kind.to_owned(),
271        "Unable to establish Redis connection",
272      )))),
273      Some(ConnectionState::Connected(connection)) => {
274        let conn_addr = ConnectionAddr(addr_of!(*connection));
275        Some(Ok((connection.clone(), conn_addr)))
276      }
277    });
278
279    match connection {
280      Some(connection) => connection,
281      None => {
282        T::connection_manager().notify.notified().await;
283
284        T::with_state(|connection_state| match connection_state.as_deref() {
285          None => unreachable!(),
286          Some(ConnectionState::Connecting) => unreachable!(),
287          Some(ConnectionState::ClientError(kind)) => Err(RedisError::from((
288            kind.to_owned(),
289            "Invalid Redis connection URL",
290          ))),
291          Some(ConnectionState::ConnectionError(kind, _timestamp)) => Err(RedisError::from((
292            kind.to_owned(),
293            "Unable to establish Redis connection",
294          ))),
295          Some(ConnectionState::Connected(connection)) => {
296            let conn_addr = ConnectionAddr(addr_of!(*connection));
297            Ok((connection.clone(), conn_addr))
298          }
299        })
300      }
301    }
302  }
303
304  fn establish_connection(conn_addr: Option<ConnectionAddr>) {
305    let state = T::connection_manager().state.load();
306
307    let should_connect = match state.as_deref() {
308      None => true,
309      Some(ConnectionState::Connecting) => false,
310      // Never reconnect if there's been a client error; treat as poisoned
311      Some(ConnectionState::ClientError(_)) => false,
312      Some(ConnectionState::ConnectionError(
313        ErrorKind::AuthenticationFailed | ErrorKind::InvalidClientConfig,
314        _,
315      )) => false,
316      Some(ConnectionState::ConnectionError(_, time))
317        if SystemTime::now()
318          .duration_since(*time)
319          .unwrap()
320          .gt(&Duration::from_millis(1500)) =>
321      {
322        true
323      }
324      Some(ConnectionState::ConnectionError(_, _)) => false,
325      Some(ConnectionState::Connected(connection)) => {
326        if let Some(conn_addr) = conn_addr {
327          let current_addr = ConnectionAddr(addr_of!(*connection));
328
329          // Only reconnect if conn_addr hasn't changed
330          conn_addr.eq(&current_addr)
331        } else {
332          false
333        }
334      }
335    };
336
337    if should_connect {
338      let prev = T::connection_manager()
339        .state
340        .compare_and_swap(&state, Some(Arc::new(ConnectionState::Connecting)));
341
342      if match (prev.as_ref(), state.as_ref()) {
343        (None, None) => true,
344        (Some(prev), Some(state)) => Arc::ptr_eq(prev, state),
345        _ => false,
346      } {
347        tokio::task::spawn(async move {
348          match T::client() {
349            Ok(client) => match client.get_multiplexed_tokio_connection().await {
350              Ok(conn) => {
351                T::connection_manager()
352                  .store_and_notify(Some(Arc::new(ConnectionState::Connected(conn))));
353              }
354              Err(err) => T::connection_manager().store_and_notify(Some(Arc::new(
355                ConnectionState::ConnectionError(err.kind(), SystemTime::now()),
356              ))),
357            },
358            Err(err) => T::connection_manager()
359              .store_and_notify(Some(Arc::new(ConnectionState::ClientError(err.kind())))),
360          }
361        });
362      }
363    }
364  }
365
366  pub async fn on_connected() -> RedisResult<()> {
367    loop {
368      T::connection_manager().notify.notified().await;
369
370      let poll = T::with_state(|connection_state| match connection_state.as_deref() {
371        Some(ConnectionState::ClientError(kind)) => Poll::Ready(Err(RedisError::from((
372          kind.to_owned(),
373          "Invalid Redis connection URL",
374        )))),
375        Some(ConnectionState::ConnectionError(
376          ErrorKind::BusyLoadingError | ErrorKind::ClusterDown | ErrorKind::IoError,
377          _,
378        )) => Poll::Pending,
379        Some(ConnectionState::ConnectionError(kind, _)) => Poll::Ready(Err(RedisError::from((
380          kind.to_owned(),
381          "Unable to establish Redis connection",
382        )))),
383        Some(ConnectionState::Connected(_)) => Poll::Ready(Ok(())),
384        _ => Poll::Pending,
385      });
386
387      match poll {
388        Poll::Pending => continue,
389        Poll::Ready(result) => return result,
390      }
391    }
392  }
393}
394
/// A multiplexed connection utilizing the respective connection manager
///
/// This is a zero-sized handle that stores no connection itself. Each request resolves the
/// current multiplexed connection from `T`'s connection manager (see its `ConnectionLike` impl).
pub struct ManagedConnection<T: ConnectionManagerContext> {
  _marker: PhantomData<T>,
}
399
400impl<T> ManagedConnection<T>
401where
402  T: ConnectionManagerContext,
403{
404  pub fn new() -> Self {
405    ManagedConnection {
406      _marker: PhantomData,
407    }
408  }
409}
410
411impl<T> Default for ManagedConnection<T>
412where
413  T: ConnectionManagerContext,
414{
415  fn default() -> Self {
416    ManagedConnection::new()
417  }
418}
419
420impl<T> ConnectionLike for ManagedConnection<T>
421where
422  T: ConnectionManagerContext,
423{
424  fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
425    (async move {
426      loop {
427        let (mut conn, addr) = <RedisDB<T>>::get_multiplexed_connection().await?;
428
429        match conn.req_packed_command(cmd).await {
430          Ok(result) => break Ok(result),
431          Err(err) => {
432            if err.is_connection_dropped() {
433              <RedisDB<T>>::establish_connection(Some(addr));
434              continue;
435            }
436
437            break Err(err);
438          }
439        }
440      }
441    })
442    .boxed()
443  }
444
445  fn req_packed_commands<'a>(
446    &'a mut self,
447    cmd: &'a Pipeline,
448    offset: usize,
449    count: usize,
450  ) -> RedisFuture<'a, Vec<Value>> {
451    (async move {
452      loop {
453        let (mut conn, addr) = <RedisDB<T>>::get_multiplexed_connection().await?;
454
455        match conn.req_packed_commands(cmd, offset, count).await {
456          Ok(result) => break Ok(result),
457          Err(err) => {
458            if err.is_connection_dropped() {
459              <RedisDB<T>>::establish_connection(Some(addr));
460              continue;
461            }
462
463            break Err(err);
464          }
465        }
466      }
467    })
468    .boxed()
469  }
470
471  fn get_db(&self) -> i64 {
472    T::get_db()
473  }
474}
475
/// Get a managed multiplexed connection for the default env-configured Redis database
///
/// Equivalent to `EnvConnection::get_connection()`; the target is configured through the
/// `REDIS`-prefixed environment variables described in the crate docs.
pub fn get_connection() -> ManagedConnection<EnvConnection> {
  EnvConnection::get_connection()
}
480
481/// Notify the next time a connection is established
482pub async fn on_connected<T>() -> RedisResult<()>
483where
484  T: ConnectionManagerContext,
485{
486  <RedisDB<T>>::on_connected().await
487}
488
489fn connection_addr<T>() -> Option<ConnectionAddr>
490where
491  T: ConnectionManagerContext,
492{
493  T::with_state(|connect_state| {
494    if let Some(ConnectionState::Connected(connection)) = connect_state.as_deref() {
495      let conn_addr = ConnectionAddr(addr_of!(*connection));
496
497      Some(conn_addr)
498    } else {
499      None
500    }
501  })
502}
503
504/// A stream notifying whenever the current or a new connection is connected; useful for client tracking redirection
505pub fn connection_stream<T>() -> impl Stream<Item = ()>
506where
507  T: ConnectionManagerContext,
508{
509  unfold(None, |conn_addr| async move {
510    loop {
511      if let Some(current_addr) = connection_addr::<T>() {
512        if current_addr.ne(&conn_addr) {
513          break Some(((), Some(current_addr)));
514        }
515      }
516
517      T::connection_manager().notify.notified().await
518    }
519  })
520}
521
/// Get the value of a key using auto-batched MGET commands
///
/// Keys are enqueued on `MGetQueue`, and `batch_process` issues one `MGET` (over the default
/// [`get_connection`]) covering every key in the batch. Returns `Ok(None)` when Redis returns
/// nil for the key.
///
/// # Errors
///
/// If the `MGET` fails, every key in that batch receives the same `ErrorKind`.
pub async fn get<K: IntoBytes>(key: K) -> Result<Option<Vec<u8>>, ErrorKind> {
  struct MGetQueue;

  #[local_queue(buffer_size = 2048)]
  impl TaskQueue for MGetQueue {
    type Task = Vec<u8>;
    type Value = Result<Option<Vec<u8>>, ErrorKind>;

    async fn batch_process<const N: usize>(
      batch: PendingAssignment<'_, Self, N>,
    ) -> CompletionReceipt<Self> {
      let mut conn = get_connection();
      let assignment = batch.into_assignment();
      // The batch is exposed as two slices (presumably the two halves of a wrapped ring
      // buffer); both are appended so the MGET covers every queued key in order.
      let (front, back) = assignment.as_slices();

      let data: Result<Vec<Option<Vec<u8>>>, RedisError> = redis::cmd("MGET")
        .arg(front)
        .arg(back)
        .query_async(&mut conn)
        .await;

      match data {
        // MGET replies positionally, one value per requested key
        Ok(data) => assignment.resolve_with_iter(data.into_iter().map(Result::Ok)),
        Err(err) => assignment.resolve_with_iter(iter::repeat(Result::Err(err.kind()))),
      }
    }
  }

  MGetQueue::auto_batch(key.into_bytes()).await
}
553
/// Set the value of a key using auto-batched MSET commands
///
/// Key/value pairs are enqueued on `MSetQueue`, and `batch_process` writes the whole batch with
/// a single `MSET` over the default [`get_connection`].
///
/// # Errors
///
/// If the `MSET` fails, every pair in that batch receives the same `ErrorKind`.
pub async fn set<K: IntoBytes, V: IntoBytes>(key: K, value: V) -> Result<(), ErrorKind> {
  struct MSetQueue;

  #[local_queue(buffer_size = 2048)]
  impl TaskQueue for MSetQueue {
    // [key, value]
    type Task = [Vec<u8>; 2];
    type Value = Result<(), ErrorKind>;

    async fn batch_process<const N: usize>(
      batch: PendingAssignment<'_, Self, N>,
    ) -> CompletionReceipt<Self> {
      let mut conn = get_connection();
      let assignment = batch.into_assignment();

      let mut cmd = redis::cmd("MSET");

      // Append each pair as `key value`, producing `MSET k1 v1 k2 v2 ...`
      for kv in assignment.tasks() {
        cmd.arg(kv.deref());
      }

      match cmd.query_async(&mut conn).await {
        Ok(()) => assignment.resolve_with_iter(iter::repeat(Ok(()))),
        Err(err) => assignment.resolve_with_iter(iter::repeat(Result::Err(err.kind()))),
      }
    }
  }

  MSetQueue::auto_batch([key.into_bytes(), value.into_bytes()]).await
}
584
// Point the default `EnvConnection` at a local Redis for the test suite; registered with
// `ctor` so it runs at load time, presumably before any test executes.
#[cfg(test)]
#[ctor::ctor]
fn setup_test_env() {
  std::env::set_var("REDIS_URL", "redis://127.0.0.1:6379");
}
// Integration tests: these talk to a live Redis server at `REDIS_URL`.
#[cfg(test)]
mod tests {
  use std::collections::HashSet;

  use futures_util::StreamExt;
  use redis::AsyncCommands;

  use super::*;

  /// After `QUIT` closes the connection, the next command must transparently reconnect. The
  /// connection stream must report both the initial and the replacement connection.
  #[tokio::test]
  async fn reconnects_on_error() -> RedisResult<()> {
    let conn_stream = connection_stream::<EnvConnection>();

    tokio::pin!(conn_stream);

    let mut conn = get_connection();

    let mut pipe = redis::pipe();

    pipe
      .atomic()
      .del("test::stream")
      .xgroup_create_mkstream("test::stream", "rustc", "0");

    let _: (i64, String) = pipe.query_async(&mut conn).await?;

    // Initial connection
    conn_stream.next().await;

    let _: () = redis::cmd("QUIT").query_async(&mut conn).await?;

    // The group already exists, so once the retry reaches Redis it must fail with BUSYGROUP
    // rather than an IoError from the dropped connection.
    let result: RedisResult<String> = conn
      .xgroup_create_mkstream("test::stream", "rustc", "0")
      .await;

    match result {
      Err(err) if err.kind().eq(&ErrorKind::ExtensionError) => {
        assert_eq!(err.code(), Some("BUSYGROUP"));
      }
      _ => panic!("Expected BUSYGROUP error"),
    };

    // Replacement connection
    conn_stream.next().await;

    conn.del("test::stream").await?;

    Ok(())
  }

  /// Each pipeline reports its client id and then QUITs, so every iteration must run on a
  /// brand-new connection: ten distinct ids with no errors.
  #[tokio::test]
  async fn reconnects_immediately() -> RedisResult<()> {
    let mut conn = get_connection();

    let mut client_list: HashSet<i32> = HashSet::new();

    for _ in 0..10 {
      let (client_id, _): (i32, String) = redis::pipe()
        .cmd("CLIENT")
        .arg("ID")
        .cmd("QUIT")
        .query_async(&mut conn)
        .await?;

      client_list.insert(client_id);
    }

    assert_eq!(client_list.len(), 10);

    Ok(())
  }

  /// Exercises the 1500ms reconnect debounce after a server shutdown. NOTE(review): this
  /// presumably relies on Redis being restarted externally within ~1.4s of `SHUTDOWN`.
  #[ignore = "use `cargo test -- --ignored` to test in isolation"]
  #[tokio::test]
  async fn handles_shutdown() -> RedisResult<()> {
    let mut conn = get_connection();

    match redis::cmd("SHUTDOWN").query_async(&mut conn).await {
      Ok(()) => panic!("Redis shutdown should result in IoError"),
      Err(err) if err.kind().eq(&ErrorKind::IoError) => Ok(()),
      Err(err) => Err(err),
    }?;

    match redis::cmd("CLIENT").arg("ID").query_async(&mut conn).await {
      Ok(()) => panic!("Redis server should still be offline"),
      Err(err) if err.kind().eq(&ErrorKind::IoError) => Ok(()),
      Err(err) => Err(err),
    }?;

    // Still inside the debounce window: the cached IoError is returned without reconnecting
    tokio::time::sleep(Duration::from_millis(1400)).await;

    match redis::cmd("CLIENT").arg("ID").query_async(&mut conn).await {
      Ok(()) => panic!("Redis server should be online, but we shouldn't be able to reconnect yet"),
      Err(err) if err.kind().eq(&ErrorKind::IoError) => Ok(()),
      Err(err) => Err(err),
    }?;

    // Past the debounce window: the next command reconnects
    tokio::time::sleep(Duration::from_millis(100)).await;

    redis::cmd("CLIENT")
      .arg("ID")
      .query_async(&mut conn)
      .await?;

    Ok(())
  }
}