redis_async/
reconnect.rs

1/*
2 * Copyright 2018-2020 Ben Ashford
3 *
4 * Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5 * http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 * <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7 * option. This file may not be copied, modified, or distributed
8 * except according to those terms.
9 */
10
11use std::fmt;
12use std::future::Future;
13use std::mem;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex, MutexGuard};
16use std::time::Duration;
17
18use futures_util::{
19    future::{self, Either},
20    TryFutureExt,
21};
22
23use tokio::time::timeout;
24
25use crate::error::{self, ConnectionReason};
26
27type WorkFn<T, A> = dyn Fn(&T, A) -> Result<(), error::Error> + Send + Sync;
28type ConnFn<T> =
29    dyn Fn() -> Pin<Box<dyn Future<Output = Result<T, error::Error>> + Send + Sync>> + Send + Sync;
30
31struct ReconnectInner<A, T> {
32    state: Mutex<ReconnectState<T>>,
33    work_fn: Box<WorkFn<T, A>>,
34    conn_fn: Box<ConnFn<T>>,
35}
36
37impl<A, T> fmt::Debug for ReconnectInner<A, T> {
38    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39        let struct_name = format!(
40            "ReconnectInner<{}, {}>",
41            std::any::type_name::<A>(),
42            std::any::type_name::<T>()
43        );
44
45        let work_fn_d = format!("@{:p}", self.work_fn.as_ref());
46        let conn_fn_d = format!("@{:p}", self.conn_fn.as_ref());
47
48        f.debug_struct(&struct_name)
49            .field("state", &self.state)
50            .field("work_fn", &work_fn_d)
51            .field("conn_fn", &conn_fn_d)
52            .finish()
53    }
54}
55
56#[derive(Debug)]
57pub(crate) struct Reconnect<A, T>(Arc<ReconnectInner<A, T>>);
58
59impl<A, T> Clone for Reconnect<A, T> {
60    fn clone(&self) -> Self {
61        Reconnect(self.0.clone())
62    }
63}
64
65pub(crate) async fn reconnect<A, T, W, C>(w: W, c: C) -> Result<Reconnect<A, T>, error::Error>
66where
67    A: Send + 'static,
68    W: Fn(&T, A) -> Result<(), error::Error> + Send + Sync + 'static,
69    C: Fn() -> Pin<Box<dyn Future<Output = Result<T, error::Error>> + Send + Sync>>
70        + Send
71        + Sync
72        + 'static,
73    T: Clone + Send + Sync + 'static,
74{
75    let r = Reconnect(Arc::new(ReconnectInner {
76        state: Mutex::new(ReconnectState::NotConnected),
77
78        work_fn: Box::new(w),
79        conn_fn: Box::new(c),
80    }));
81    let rf = {
82        let state = r.0.state.lock().expect("Poisoned lock");
83        r.reconnect(state)
84    };
85    rf.await?;
86    Ok(r)
87}
88
89enum ReconnectState<T> {
90    NotConnected,
91    Connected(T),
92    ConnectionFailed(Mutex<Option<error::Error>>),
93    Connecting,
94}
95
96impl<T> fmt::Debug for ReconnectState<T> {
97    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
98        write!(f, "ReconnectState::")?;
99        match self {
100            NotConnected => write!(f, "NotConnected"),
101            Connected(_) => write!(f, "Connected"),
102            ConnectionFailed(_) => write!(f, "ConnectionFailed"),
103            Connecting => write!(f, "Connecting"),
104        }
105    }
106}
107
108use self::ReconnectState::*;
109
/// Upper bound, in whole seconds, on a single connection attempt.
const CONNECTION_TIMEOUT_SECONDS: u64 = 10;
/// `CONNECTION_TIMEOUT_SECONDS` expressed as a `Duration`, used to bound each attempt.
const CONNECTION_TIMEOUT: Duration = Duration::from_secs(CONNECTION_TIMEOUT_SECONDS);
112
113impl<A, T> Reconnect<A, T>
114where
115    A: Send + 'static,
116    T: Clone + Send + Sync + 'static,
117{
118    fn call_work(&self, t: &T, a: A) -> Result<bool, error::Error> {
119        if let Err(e) = (self.0.work_fn)(t, a) {
120            match e {
121                error::Error::IO(_) | error::Error::Unexpected(_) => {
122                    log::error!("Error in work_fn will force connection closed, next command will attempt to re-establish connection: {}", e);
123                    return Ok(false);
124                }
125                _ => (),
126            }
127            Err(e)
128        } else {
129            Ok(true)
130        }
131    }
132
133    pub(crate) fn do_work(&self, a: A) -> Result<(), error::Error> {
134        let mut state = self.0.state.lock().expect("Cannot obtain read lock");
135        match *state {
136            NotConnected => {
137                self.reconnect_spawn(state);
138                Err(error::Error::Connection(ConnectionReason::NotConnected))
139            }
140            Connected(ref t) => {
141                let success = self.call_work(t, a)?;
142                if !success {
143                    *state = NotConnected;
144                    self.reconnect_spawn(state);
145                }
146                Ok(())
147            }
148            ConnectionFailed(ref e) => {
149                let mut lock = e.lock().expect("Poisioned lock");
150                let e = match lock.take() {
151                    Some(e) => e,
152                    None => error::Error::Connection(ConnectionReason::NotConnected),
153                };
154                mem::drop(lock);
155
156                *state = NotConnected;
157                self.reconnect_spawn(state);
158                Err(e)
159            }
160            Connecting => Err(error::Error::Connection(ConnectionReason::Connecting)),
161        }
162    }
163
164    /// Returns a future that completes when the connection is established or failed to establish
165    /// used only for timing.
166    fn reconnect(
167        &self,
168        mut state: MutexGuard<ReconnectState<T>>,
169    ) -> impl Future<Output = Result<(), error::Error>> + Send {
170        log::info!("Attempting to reconnect, current state: {:?}", *state);
171
172        match *state {
173            Connected(_) => {
174                return Either::Right(future::err(error::Error::Connection(
175                    ConnectionReason::Connected,
176                )));
177            }
178            Connecting => {
179                return Either::Right(future::err(error::Error::Connection(
180                    ConnectionReason::Connecting,
181                )));
182            }
183            NotConnected | ConnectionFailed(_) => (),
184        }
185        *state = ReconnectState::Connecting;
186
187        mem::drop(state);
188
189        let reconnect = self.clone();
190
191        let connection_f = async move {
192            let connection = match timeout(CONNECTION_TIMEOUT, (reconnect.0.conn_fn)()).await {
193                Ok(con_r) => con_r,
194                Err(_) => Err(error::internal(format!(
195                    "Connection timed-out after {} seconds",
196                    CONNECTION_TIMEOUT_SECONDS
197                ))),
198            };
199
200            let mut state = reconnect.0.state.lock().expect("Cannot obtain write lock");
201
202            match *state {
203                NotConnected | Connecting => match connection {
204                    Ok(t) => {
205                        log::info!("Connection established");
206                        *state = Connected(t);
207                        Ok(())
208                    }
209                    Err(e) => {
210                        log::error!("Connection cannot be established: {}", e);
211                        *state = ConnectionFailed(Mutex::new(Some(e)));
212                        Err(error::Error::Connection(ConnectionReason::ConnectionFailed))
213                    }
214                },
215                ConnectionFailed(_) => {
216                    panic!("The connection state wasn't reset before connecting")
217                }
218                Connected(_) => panic!("A connected state shouldn't be attempting to reconnect"),
219            }
220        };
221
222        Either::Left(connection_f)
223    }
224
225    fn reconnect_spawn(&self, state: MutexGuard<ReconnectState<T>>) {
226        let reconnect_f = self
227            .reconnect(state)
228            .map_err(|e| log::error!("Error asynchronously reconnecting: {}", e));
229
230        tokio::spawn(reconnect_f);
231    }
232}