1#![allow(rustdoc::private_intra_doc_links)]
29#[doc(hidden)]
30pub extern crate arc_swap;
31extern crate self as redis_swapplex;
32
33pub use into_bytes::IntoBytes;
34
35use arc_swap::{ArcSwapAny, ArcSwapOption, Cache};
36pub use derive_redis_swapplex::ConnectionManagerContext;
37use env_url::*;
38use futures_util::{future::FutureExt, stream::unfold, Stream};
39use once_cell::sync::Lazy;
40use redis::{
41 aio::{ConnectionLike, MultiplexedConnection},
42 Client, Cmd, ErrorKind, Pipeline, RedisError, RedisFuture, RedisResult, Value,
43};
44use stack_queue::{
45 assignment::{CompletionReceipt, PendingAssignment},
46 local_queue, TaskQueue,
47};
48use std::{
49 cell::RefCell,
50 iter,
51 marker::PhantomData,
52 ops::Deref,
53 ptr::addr_of,
54 sync::Arc,
55 task::Poll,
56 thread::LocalKey,
57 time::{Duration, SystemTime},
58};
59use tokio::sync::Notify;
60
61pub trait ConnectionInfo: Send + Sync + Sized {
63 fn new(client: RedisResult<Client>, db_index: i64) -> Self;
64 fn parse_index(url: &Url) -> Option<i64> {
65 let mut segments = url.path_segments()?;
66 let db_index: i64 = segments.next()?.parse().ok()?;
67
68 Some(db_index)
69 }
70
71 fn from_url(url: &Url) -> Self {
72 let db_index = <Self as ConnectionInfo>::parse_index(url).unwrap_or(0);
73 let client = redis::Client::open(url.as_str());
74
75 <Self as ConnectionInfo>::new(client, db_index)
76 }
77
78 fn get_db(&self) -> i64;
79 fn client(&self) -> &RedisResult<Client>;
80}
81
/// Connection context used by the crate-level helpers such as [`get_connection`].
///
/// The URL is supplied by the `EnvURL` derive, configured here with the `REDIS` environment
/// prefix and a default of `redis://127.0.0.1:6379`.
// NOTE(review): the exact variable name is decided by the derive; presumably `REDIS_URL`,
// which is what the test setup in this file sets — confirm against `env_url`.
#[derive(EnvURL, ConnectionManagerContext)]
#[env_url(env_prefix = "REDIS", default = "redis://127.0.0.1:6379")]
pub struct EnvConnection;
86
/// Connection info for the context type `T`: a client construction result plus the database
/// index.
#[doc(hidden)]
pub struct RedisDB<T: Send + Sync + Sized> {
  /// The Redis client, or the error produced while constructing it.
  client: RedisResult<Client>,
  /// Database index reported by `get_db`.
  db_index: i64,
  /// Ties the value to `T` without storing a `T`.
  _marker: PhantomData<fn() -> T>,
}
93
94impl<T> RedisDB<T>
95where
96 T: Send + Sync + 'static + Sized,
97{
98 pub fn new(client: RedisResult<Client>, db_index: i64) -> Self {
99 RedisDB {
100 client,
101 db_index,
102 _marker: PhantomData,
103 }
104 }
105}
106
107impl<T> ConnectionInfo for RedisDB<T>
108where
109 T: ServiceURL + Send + Sync + 'static + Sized,
110{
111 fn new(client: RedisResult<Client>, db_index: i64) -> Self {
112 RedisDB::new(client, db_index)
113 }
114
115 fn get_db(&self) -> i64 {
116 self.db_index
117 }
118
119 fn client(&self) -> &RedisResult<Client> {
120 &self.client
121 }
122}
123
124impl<T> Default for RedisDB<T>
125where
126 T: ServiceURL + Send + Sync + 'static + Sized,
127 Self: ConnectionInfo,
128{
129 fn default() -> Self {
130 match <T as ServiceURL>::service_url() {
131 Ok(url) => <Self as ConnectionInfo>::from_url(&url),
132 Err(_) => {
133 let client = Err(RedisError::from((
134 ErrorKind::InvalidClientConfig,
135 "Invalid Redis connection URL",
136 )));
137
138 Self {
139 client,
140 db_index: 0,
141 _marker: PhantomData,
142 }
143 }
144 }
145 }
146}
147
/// State of the shared connection as stored in a [`ConnectionManager`].
#[doc(hidden)]
pub enum ConnectionState {
  /// A connection attempt has been started and has not finished yet.
  Connecting,
  /// The client itself could not be constructed; carries the kind of that error.
  ClientError(ErrorKind),
  /// The last connection attempt failed; carries the error kind and when the failure was
  /// recorded.
  ConnectionError(ErrorKind, SystemTime),
  /// A multiplexed connection is available.
  Connected(MultiplexedConnection),
}
155
/// Holds the shared connection state for one connection context.
#[doc(hidden)]
pub struct ConnectionManager<T: ConnectionInfo> {
  /// Current connection state; empty until a first connection attempt is made.
  state: ArcSwapOption<ConnectionState>,
  /// Signalled via `notify_waiters` every time a new state is stored.
  notify: Notify,
  /// Lazily initialized connection info (client and database index).
  connection_info: Lazy<T>,
}
162
163impl<T> ConnectionManager<T>
164where
165 T: ConnectionInfo,
166{
167 pub const fn new(connection_info: fn() -> T) -> ConnectionManager<T> {
168 ConnectionManager {
169 state: ArcSwapOption::const_empty(),
170 notify: Notify::const_new(),
171 connection_info: Lazy::new(connection_info),
172 }
173 }
174
175 fn store_and_notify(&self, state: Option<Arc<ConnectionState>>) {
176 self.state.store(state);
177 self.notify.notify_waiters();
178 }
179
180 pub fn client(&self) -> &RedisResult<Client> {
181 self.connection_info.client()
182 }
183
184 pub fn get_db(&self) -> i64 {
185 self.connection_info.get_db()
186 }
187}
188
189impl<T> Deref for ConnectionManager<T>
190where
191 T: ConnectionInfo,
192{
193 type Target = ArcSwapAny<Option<Arc<ConnectionState>>>;
194
195 fn deref(&self) -> &Self::Target {
196 &self.state
197 }
198}
199
/// Address of a `MultiplexedConnection` stored inside a `ConnectionState::Connected`, used as
/// an identity token to tell whether the shared connection has been replaced.
#[derive(PartialEq)]
struct ConnectionAddr(*const MultiplexedConnection);
202
203impl PartialEq<Option<ConnectionAddr>> for ConnectionAddr {
204 fn eq(&self, other: &Option<ConnectionAddr>) -> bool {
205 if let Some(addr) = other {
206 self.0 == addr.0
207 } else {
208 false
209 }
210 }
211}
212
// SAFETY: within this file the wrapped pointer is only ever compared for equality and is
// never dereferenced, so moving or sharing it across threads cannot cause a data race.
// NOTE(review): this holds only as long as no code dereferences `ConnectionAddr.0`.
unsafe impl Send for ConnectionAddr {}
unsafe impl Sync for ConnectionAddr {}
215
216pub trait ConnectionManagerContext: Send + Sync + 'static + Sized {
217 type ConnectionInfo: ConnectionInfo;
218
219 fn get_connection() -> ManagedConnection<Self> {
220 ManagedConnection::new()
221 }
222
223 fn connection_manager() -> &'static ConnectionManager<Self::ConnectionInfo>;
224
225 fn client() -> &'static RedisResult<Client> {
226 Self::connection_manager().client()
227 }
228
229 fn get_db() -> i64 {
230 Self::connection_manager().get_db()
231 }
232
233 fn state_cache() -> &'static LocalKey<
234 RefCell<Cache<&'static ArcSwapOption<ConnectionState>, Option<Arc<ConnectionState>>>>,
235 >;
236
237 fn with_state<T>(with_fn: fn(&Option<Arc<ConnectionState>>) -> T) -> T {
238 <Self as ConnectionManagerContext>::state_cache()
239 .with(|cache| with_fn(cache.borrow_mut().load()))
240 }
241}
242
243impl<T> RedisDB<T>
244where
245 T: ConnectionManagerContext,
246{
247 async fn get_multiplexed_connection() -> RedisResult<(MultiplexedConnection, ConnectionAddr)> {
248 let connection = T::with_state(|connection_state| match connection_state.as_deref() {
249 None => {
250 Self::establish_connection(None);
251 None
252 }
253 Some(ConnectionState::Connecting) => None,
254 Some(ConnectionState::ClientError(kind)) => Some(Err(RedisError::from((
255 kind.to_owned(),
256 "Invalid Redis connection URL",
257 )))),
258 Some(ConnectionState::ConnectionError(
259 ErrorKind::IoError | ErrorKind::ClusterDown | ErrorKind::BusyLoadingError,
260 time,
261 )) if SystemTime::now()
262 .duration_since(*time)
263 .unwrap()
264 .gt(&Duration::from_millis(1500)) =>
265 {
266 Self::establish_connection(None);
267 None
268 }
269 Some(ConnectionState::ConnectionError(kind, _)) => Some(Err(RedisError::from((
270 kind.to_owned(),
271 "Unable to establish Redis connection",
272 )))),
273 Some(ConnectionState::Connected(connection)) => {
274 let conn_addr = ConnectionAddr(addr_of!(*connection));
275 Some(Ok((connection.clone(), conn_addr)))
276 }
277 });
278
279 match connection {
280 Some(connection) => connection,
281 None => {
282 T::connection_manager().notify.notified().await;
283
284 T::with_state(|connection_state| match connection_state.as_deref() {
285 None => unreachable!(),
286 Some(ConnectionState::Connecting) => unreachable!(),
287 Some(ConnectionState::ClientError(kind)) => Err(RedisError::from((
288 kind.to_owned(),
289 "Invalid Redis connection URL",
290 ))),
291 Some(ConnectionState::ConnectionError(kind, _timestamp)) => Err(RedisError::from((
292 kind.to_owned(),
293 "Unable to establish Redis connection",
294 ))),
295 Some(ConnectionState::Connected(connection)) => {
296 let conn_addr = ConnectionAddr(addr_of!(*connection));
297 Ok((connection.clone(), conn_addr))
298 }
299 })
300 }
301 }
302 }
303
304 fn establish_connection(conn_addr: Option<ConnectionAddr>) {
305 let state = T::connection_manager().state.load();
306
307 let should_connect = match state.as_deref() {
308 None => true,
309 Some(ConnectionState::Connecting) => false,
310 Some(ConnectionState::ClientError(_)) => false,
312 Some(ConnectionState::ConnectionError(
313 ErrorKind::AuthenticationFailed | ErrorKind::InvalidClientConfig,
314 _,
315 )) => false,
316 Some(ConnectionState::ConnectionError(_, time))
317 if SystemTime::now()
318 .duration_since(*time)
319 .unwrap()
320 .gt(&Duration::from_millis(1500)) =>
321 {
322 true
323 }
324 Some(ConnectionState::ConnectionError(_, _)) => false,
325 Some(ConnectionState::Connected(connection)) => {
326 if let Some(conn_addr) = conn_addr {
327 let current_addr = ConnectionAddr(addr_of!(*connection));
328
329 conn_addr.eq(¤t_addr)
331 } else {
332 false
333 }
334 }
335 };
336
337 if should_connect {
338 let prev = T::connection_manager()
339 .state
340 .compare_and_swap(&state, Some(Arc::new(ConnectionState::Connecting)));
341
342 if match (prev.as_ref(), state.as_ref()) {
343 (None, None) => true,
344 (Some(prev), Some(state)) => Arc::ptr_eq(prev, state),
345 _ => false,
346 } {
347 tokio::task::spawn(async move {
348 match T::client() {
349 Ok(client) => match client.get_multiplexed_tokio_connection().await {
350 Ok(conn) => {
351 T::connection_manager()
352 .store_and_notify(Some(Arc::new(ConnectionState::Connected(conn))));
353 }
354 Err(err) => T::connection_manager().store_and_notify(Some(Arc::new(
355 ConnectionState::ConnectionError(err.kind(), SystemTime::now()),
356 ))),
357 },
358 Err(err) => T::connection_manager()
359 .store_and_notify(Some(Arc::new(ConnectionState::ClientError(err.kind())))),
360 }
361 });
362 }
363 }
364 }
365
366 pub async fn on_connected() -> RedisResult<()> {
367 loop {
368 T::connection_manager().notify.notified().await;
369
370 let poll = T::with_state(|connection_state| match connection_state.as_deref() {
371 Some(ConnectionState::ClientError(kind)) => Poll::Ready(Err(RedisError::from((
372 kind.to_owned(),
373 "Invalid Redis connection URL",
374 )))),
375 Some(ConnectionState::ConnectionError(
376 ErrorKind::BusyLoadingError | ErrorKind::ClusterDown | ErrorKind::IoError,
377 _,
378 )) => Poll::Pending,
379 Some(ConnectionState::ConnectionError(kind, _)) => Poll::Ready(Err(RedisError::from((
380 kind.to_owned(),
381 "Unable to establish Redis connection",
382 )))),
383 Some(ConnectionState::Connected(_)) => Poll::Ready(Ok(())),
384 _ => Poll::Pending,
385 });
386
387 match poll {
388 Poll::Pending => continue,
389 Poll::Ready(result) => return result,
390 }
391 }
392 }
393}
394
/// Connection handle for the context `T`.
///
/// The handle stores no connection of its own; each request obtains the shared connection of
/// `T` and asks for a reconnect when that connection turns out to be dropped.
pub struct ManagedConnection<T: ConnectionManagerContext> {
  /// Binds the handle to its connection context.
  _marker: PhantomData<T>,
}
399
400impl<T> ManagedConnection<T>
401where
402 T: ConnectionManagerContext,
403{
404 pub fn new() -> Self {
405 ManagedConnection {
406 _marker: PhantomData,
407 }
408 }
409}
410
411impl<T> Default for ManagedConnection<T>
412where
413 T: ConnectionManagerContext,
414{
415 fn default() -> Self {
416 ManagedConnection::new()
417 }
418}
419
420impl<T> ConnectionLike for ManagedConnection<T>
421where
422 T: ConnectionManagerContext,
423{
424 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
425 (async move {
426 loop {
427 let (mut conn, addr) = <RedisDB<T>>::get_multiplexed_connection().await?;
428
429 match conn.req_packed_command(cmd).await {
430 Ok(result) => break Ok(result),
431 Err(err) => {
432 if err.is_connection_dropped() {
433 <RedisDB<T>>::establish_connection(Some(addr));
434 continue;
435 }
436
437 break Err(err);
438 }
439 }
440 }
441 })
442 .boxed()
443 }
444
445 fn req_packed_commands<'a>(
446 &'a mut self,
447 cmd: &'a Pipeline,
448 offset: usize,
449 count: usize,
450 ) -> RedisFuture<'a, Vec<Value>> {
451 (async move {
452 loop {
453 let (mut conn, addr) = <RedisDB<T>>::get_multiplexed_connection().await?;
454
455 match conn.req_packed_commands(cmd, offset, count).await {
456 Ok(result) => break Ok(result),
457 Err(err) => {
458 if err.is_connection_dropped() {
459 <RedisDB<T>>::establish_connection(Some(addr));
460 continue;
461 }
462
463 break Err(err);
464 }
465 }
466 }
467 })
468 .boxed()
469 }
470
471 fn get_db(&self) -> i64 {
472 T::get_db()
473 }
474}
475
/// Returns a [`ManagedConnection`] for the [`EnvConnection`] context.
pub fn get_connection() -> ManagedConnection<EnvConnection> {
  EnvConnection::get_connection()
}
480
481pub async fn on_connected<T>() -> RedisResult<()>
483where
484 T: ConnectionManagerContext,
485{
486 <RedisDB<T>>::on_connected().await
487}
488
489fn connection_addr<T>() -> Option<ConnectionAddr>
490where
491 T: ConnectionManagerContext,
492{
493 T::with_state(|connect_state| {
494 if let Some(ConnectionState::Connected(connection)) = connect_state.as_deref() {
495 let conn_addr = ConnectionAddr(addr_of!(*connection));
496
497 Some(conn_addr)
498 } else {
499 None
500 }
501 })
502}
503
504pub fn connection_stream<T>() -> impl Stream<Item = ()>
506where
507 T: ConnectionManagerContext,
508{
509 unfold(None, |conn_addr| async move {
510 loop {
511 if let Some(current_addr) = connection_addr::<T>() {
512 if current_addr.ne(&conn_addr) {
513 break Some(((), Some(current_addr)));
514 }
515 }
516
517 T::connection_manager().notify.notified().await
518 }
519 })
520}
521
522pub async fn get<K: IntoBytes>(key: K) -> Result<Option<Vec<u8>>, ErrorKind> {
524 struct MGetQueue;
525
526 #[local_queue(buffer_size = 2048)]
527 impl TaskQueue for MGetQueue {
528 type Task = Vec<u8>;
529 type Value = Result<Option<Vec<u8>>, ErrorKind>;
530
531 async fn batch_process<const N: usize>(
532 batch: PendingAssignment<'_, Self, N>,
533 ) -> CompletionReceipt<Self> {
534 let mut conn = get_connection();
535 let assignment = batch.into_assignment();
536 let (front, back) = assignment.as_slices();
537
538 let data: Result<Vec<Option<Vec<u8>>>, RedisError> = redis::cmd("MGET")
539 .arg(front)
540 .arg(back)
541 .query_async(&mut conn)
542 .await;
543
544 match data {
545 Ok(data) => assignment.resolve_with_iter(data.into_iter().map(Result::Ok)),
546 Err(err) => assignment.resolve_with_iter(iter::repeat(Result::Err(err.kind()))),
547 }
548 }
549 }
550
551 MGetQueue::auto_batch(key.into_bytes()).await
552}
553
554pub async fn set<K: IntoBytes, V: IntoBytes>(key: K, value: V) -> Result<(), ErrorKind> {
556 struct MSetQueue;
557
558 #[local_queue(buffer_size = 2048)]
559 impl TaskQueue for MSetQueue {
560 type Task = [Vec<u8>; 2];
561 type Value = Result<(), ErrorKind>;
562
563 async fn batch_process<const N: usize>(
564 batch: PendingAssignment<'_, Self, N>,
565 ) -> CompletionReceipt<Self> {
566 let mut conn = get_connection();
567 let assignment = batch.into_assignment();
568
569 let mut cmd = redis::cmd("MSET");
570
571 for kv in assignment.tasks() {
572 cmd.arg(kv.deref());
573 }
574
575 match cmd.query_async(&mut conn).await {
576 Ok(()) => assignment.resolve_with_iter(iter::repeat(Ok(()))),
577 Err(err) => assignment.resolve_with_iter(iter::repeat(Result::Err(err.kind()))),
578 }
579 }
580 }
581
582 MSetQueue::auto_batch([key.into_bytes(), value.into_bytes()]).await
583}
584
/// Points `REDIS_URL` at a local Redis instance for test builds.
#[cfg(test)]
#[ctor::ctor]
fn setup_test_env() {
  let (name, url) = ("REDIS_URL", "redis://127.0.0.1:6379");

  std::env::set_var(name, url);
}
// Integration tests: every test sends real commands through `get_connection()`, so they
// need a reachable Redis server.
#[cfg(all(test))]
mod tests {
  use std::collections::HashSet;

  use futures_util::StreamExt;
  use redis::AsyncCommands;

  use super::*;

  /// After the server closes the connection (`QUIT`), the next command is served over a new
  /// connection and `connection_stream` yields again.
  #[tokio::test]
  async fn reconnects_on_error() -> RedisResult<()> {
    let conn_stream = connection_stream::<EnvConnection>();

    tokio::pin!(conn_stream);

    let mut conn = get_connection();

    let mut pipe = redis::pipe();

    // Start from a clean stream with an existing consumer group.
    pipe
      .atomic()
      .del("test::stream")
      .xgroup_create_mkstream("test::stream", "rustc", "0");

    let _: (i64, String) = pipe.query_async(&mut conn).await?;

    // First item: the initial connection.
    conn_stream.next().await;

    let _: () = redis::cmd("QUIT").query_async(&mut conn).await?;

    // Creating the same group again must reach the server and fail with BUSYGROUP, which
    // shows the command went through after the connection was closed.
    let result: RedisResult<String> = conn
      .xgroup_create_mkstream("test::stream", "rustc", "0")
      .await;

    match result {
      Err(err) if err.kind().eq(&ErrorKind::ExtensionError) => {
        assert_eq!(err.code(), Some("BUSYGROUP"));
      }
      _ => panic!("Expected BUSYGROUP error"),
    };

    // Second item: the replacement connection.
    conn_stream.next().await;

    conn.del("test::stream").await?;

    Ok(())
  }

  /// Each iteration closes the connection via `QUIT`; ten distinct client ids show that every
  /// iteration ran on a fresh connection.
  #[tokio::test]
  async fn reconnects_immediately() -> RedisResult<()> {
    let mut conn = get_connection();

    let mut client_list: HashSet<i32> = HashSet::new();

    for _ in 0..10 {
      let (client_id, _): (i32, String) = redis::pipe()
        .cmd("CLIENT")
        .arg("ID")
        .cmd("QUIT")
        .query_async(&mut conn)
        .await?;

      client_list.insert(client_id);
    }

    assert_eq!(client_list.len(), 10);

    Ok(())
  }

  /// Shuts the server down and checks that commands keep failing with `IoError` until a total
  /// of 1500 ms of sleeping has passed, after which a command succeeds again.
  // NOTE(review): the final command can only succeed if something outside this test brings
  // the server back up after SHUTDOWN — confirm the expected test environment.
  #[ignore = "use `cargo test -- --ignored` to test in isolation"]
  #[tokio::test]
  async fn handles_shutdown() -> RedisResult<()> {
    let mut conn = get_connection();

    match redis::cmd("SHUTDOWN").query_async(&mut conn).await {
      Ok(()) => panic!("Redis shutdown should result in IoError"),
      Err(err) if err.kind().eq(&ErrorKind::IoError) => Ok(()),
      Err(err) => Err(err),
    }?;

    match redis::cmd("CLIENT").arg("ID").query_async(&mut conn).await {
      Ok(()) => panic!("Redis server should still be offline"),
      Err(err) if err.kind().eq(&ErrorKind::IoError) => Ok(()),
      Err(err) => Err(err),
    }?;

    tokio::time::sleep(Duration::from_millis(1400)).await;

    match redis::cmd("CLIENT").arg("ID").query_async(&mut conn).await {
      Ok(()) => panic!("Redis server should be online, but we shouldn't be able to reconnect yet"),
      Err(err) if err.kind().eq(&ErrorKind::IoError) => Ok(()),
      Err(err) => Err(err),
    }?;

    tokio::time::sleep(Duration::from_millis(100)).await;

    redis::cmd("CLIENT")
      .arg("ID")
      .query_async(&mut conn)
      .await?;

    Ok(())
  }
}