redis_async/
reconnect.rs

1/*
2 * Copyright 2018-2025 Ben Ashford
3 *
4 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5 * http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 * <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7 * option. This file may not be copied, modified, or distributed
8 * except according to those terms.
9 */
10
11use std::fmt;
12use std::future::Future;
13use std::mem;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex, MutexGuard};
16use std::time::Duration;
17
18use futures_util::{
19    future::{self, Either},
20    TryFutureExt,
21};
22
23use tokio::time::timeout;
24
25use crate::error::{self, ConnectionReason};
26
27type WorkFn<T, A> = dyn Fn(&T, A) -> Result<(), error::Error> + Send + Sync;
28type ConnFn<T> =
29    dyn Fn() -> Pin<Box<dyn Future<Output = Result<T, error::Error>> + Send + Sync>> + Send + Sync;
30
/// Default timeout for a single connection attempt, in whole seconds.
const CONNECTION_TIMEOUT_SECONDS: u64 = 1;
/// Default timeout for a single connection attempt, as a `Duration`.
const CONNECTION_TIMEOUT: Duration = Duration::from_secs(CONNECTION_TIMEOUT_SECONDS);

/// Default number of connection attempts made before a reconnect cycle gives up.
const MAX_CONNECTION_ATTEMPTS: u64 = 10;
34
/// Tunables for one reconnect cycle.
#[derive(Clone, Copy, Debug)]
pub(crate) struct ReconnectOptions {
    /// Upper bound on how long a single connection attempt may take.
    pub(crate) connection_timeout: Duration,
    /// How many connection attempts are made before the cycle is reported as failed.
    pub(crate) max_connection_attempts: u64,
}
40
41impl Default for ReconnectOptions {
42    #[inline]
43    fn default() -> Self {
44        ReconnectOptions {
45            connection_timeout: CONNECTION_TIMEOUT,
46            max_connection_attempts: MAX_CONNECTION_ATTEMPTS,
47        }
48    }
49}
50
51struct ReconnectInner<A, T> {
52    state: Mutex<ReconnectState<T>>,
53    work_fn: Box<WorkFn<T, A>>,
54    conn_fn: Box<ConnFn<T>>,
55    reconnect_options: ReconnectOptions,
56}
57
58impl<A, T> fmt::Debug for ReconnectInner<A, T> {
59    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
60        let struct_name = format!(
61            "ReconnectInner<{}, {}>",
62            std::any::type_name::<A>(),
63            std::any::type_name::<T>()
64        );
65
66        let work_fn_d = format!("@{:p}", self.work_fn.as_ref());
67        let conn_fn_d = format!("@{:p}", self.conn_fn.as_ref());
68
69        f.debug_struct(&struct_name)
70            .field("state", &self.state)
71            .field("work_fn", &work_fn_d)
72            .field("conn_fn", &conn_fn_d)
73            .finish()
74    }
75}
76
77#[derive(Debug)]
78pub(crate) struct Reconnect<A, T>(Arc<ReconnectInner<A, T>>);
79
80impl<A, T> Clone for Reconnect<A, T> {
81    fn clone(&self) -> Self {
82        Reconnect(self.0.clone())
83    }
84}
85
86pub(crate) async fn reconnect<A, T, W, C>(
87    w: W,
88    c: C,
89    options: ReconnectOptions,
90) -> Result<Reconnect<A, T>, error::Error>
91where
92    A: Send + 'static,
93    W: Fn(&T, A) -> Result<(), error::Error> + Send + Sync + 'static,
94    C: Fn() -> Pin<Box<dyn Future<Output = Result<T, error::Error>> + Send + Sync>>
95        + Send
96        + Sync
97        + 'static,
98    T: Clone + Send + Sync + 'static,
99{
100    let r = Reconnect(Arc::new(ReconnectInner {
101        state: Mutex::new(ReconnectState::NotConnected),
102
103        work_fn: Box::new(w),
104        conn_fn: Box::new(c),
105
106        reconnect_options: options,
107    }));
108    let rf = {
109        let state = r.0.state.lock().expect("Poisoned lock");
110        r.reconnect(state)
111    };
112    rf.await?;
113    Ok(r)
114}
115
116enum ReconnectState<T> {
117    NotConnected,
118    Connected(T),
119    ConnectionFailed(Mutex<Option<error::Error>>),
120    Connecting,
121}
122
123impl<T> fmt::Debug for ReconnectState<T> {
124    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
125        write!(f, "ReconnectState::")?;
126        match self {
127            Self::NotConnected => write!(f, "NotConnected"),
128            Self::Connected(_) => write!(f, "Connected"),
129            Self::ConnectionFailed(_) => write!(f, "ConnectionFailed"),
130            Self::Connecting => write!(f, "Connecting"),
131        }
132    }
133}
134
135impl<A, T> Reconnect<A, T>
136where
137    A: Send + 'static,
138    T: Clone + Send + Sync + 'static,
139{
140    fn call_work(&self, t: &T, a: A) -> Result<bool, error::Error> {
141        if let Err(e) = (self.0.work_fn)(t, a) {
142            match e {
143                error::Error::IO(_) | error::Error::Unexpected(_) => {
144                    log::error!("Error in work_fn will force connection closed, next command will attempt to re-establish connection: {}", e);
145                    return Ok(false);
146                }
147                _ => (),
148            }
149            Err(e)
150        } else {
151            Ok(true)
152        }
153    }
154
155    pub(crate) fn do_work(&self, a: A) -> Result<(), error::Error> {
156        let mut state = self.0.state.lock().expect("Cannot obtain read lock");
157        match *state {
158            ReconnectState::NotConnected => {
159                self.reconnect_spawn(state);
160                Err(error::Error::Connection(ConnectionReason::NotConnected))
161            }
162            ReconnectState::Connected(ref t) => {
163                let success = self.call_work(t, a)?;
164                if !success {
165                    *state = ReconnectState::NotConnected;
166                    self.reconnect_spawn(state);
167                }
168                Ok(())
169            }
170            ReconnectState::ConnectionFailed(ref e) => {
171                let mut lock = e.lock().expect("Poisioned lock");
172                let e = match lock.take() {
173                    Some(e) => e,
174                    None => error::Error::Connection(ConnectionReason::NotConnected),
175                };
176                mem::drop(lock);
177
178                *state = ReconnectState::NotConnected;
179                self.reconnect_spawn(state);
180                Err(e)
181            }
182            ReconnectState::Connecting => {
183                Err(error::Error::Connection(ConnectionReason::Connecting))
184            }
185        }
186    }
187
188    /// Returns a future that completes when the connection is established or failed to establish
189    /// used only for timing.
190    fn reconnect(
191        &self,
192        mut state: MutexGuard<ReconnectState<T>>,
193    ) -> impl Future<Output = Result<(), error::Error>> + Send {
194        log::info!("Attempting to reconnect, current state: {:?}", *state);
195
196        match *state {
197            ReconnectState::Connected(_) => {
198                return Either::Right(future::err(error::Error::Connection(
199                    ConnectionReason::Connected,
200                )));
201            }
202            ReconnectState::Connecting => {
203                return Either::Right(future::err(error::Error::Connection(
204                    ConnectionReason::Connecting,
205                )));
206            }
207            ReconnectState::NotConnected | ReconnectState::ConnectionFailed(_) => (),
208        }
209        *state = ReconnectState::Connecting;
210
211        mem::drop(state);
212
213        let reconnect = self.clone();
214
215        let connection_f = async move {
216            let mut connection_result = Err(error::internal("Initial connection failed"));
217            for i in 0..reconnect.0.reconnect_options.max_connection_attempts {
218                let connection_count = i + 1;
219                log::debug!(
220                    "Connection attempt {}/{}",
221                    connection_count,
222                    reconnect.0.reconnect_options.max_connection_attempts
223                );
224                connection_result = match timeout(
225                    reconnect.0.reconnect_options.connection_timeout,
226                    (reconnect.0.conn_fn)(),
227                )
228                .await
229                {
230                    Ok(con_r) => con_r,
231                    Err(_) => Err(error::internal(format!(
232                        "Connection timed-out after {} seconds",
233                        reconnect.0.reconnect_options.connection_timeout.as_secs()
234                            * connection_count
235                    ))),
236                };
237                if connection_result.is_ok() {
238                    break;
239                }
240            }
241
242            let mut state = reconnect.0.state.lock().expect("Cannot obtain write lock");
243
244            match *state {
245                ReconnectState::NotConnected | ReconnectState::Connecting => {
246                    match connection_result {
247                        Ok(t) => {
248                            log::info!("Connection established");
249                            *state = ReconnectState::Connected(t);
250                            Ok(())
251                        }
252                        Err(e) => {
253                            log::error!("Connection cannot be established: {}", e);
254                            *state = ReconnectState::ConnectionFailed(Mutex::new(Some(e)));
255                            Err(error::Error::Connection(ConnectionReason::ConnectionFailed))
256                        }
257                    }
258                }
259                ReconnectState::ConnectionFailed(_) => {
260                    panic!("The connection state wasn't reset before connecting")
261                }
262                ReconnectState::Connected(_) => {
263                    panic!("A connected state shouldn't be attempting to reconnect")
264                }
265            }
266        };
267
268        Either::Left(connection_f)
269    }
270
271    fn reconnect_spawn(&self, state: MutexGuard<ReconnectState<T>>) {
272        let reconnect_f = self
273            .reconnect(state)
274            .map_err(|e| log::error!("Error asynchronously reconnecting: {}", e));
275
276        tokio::spawn(reconnect_f);
277    }
278}