1use std::fmt;
12use std::future::Future;
13use std::mem;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex, MutexGuard};
16use std::time::Duration;
17
18use futures_util::{
19 future::{self, Either},
20 TryFutureExt,
21};
22
23use tokio::time::timeout;
24
25use crate::error::{self, ConnectionReason};
26
27type WorkFn<T, A> = dyn Fn(&T, A) -> Result<(), error::Error> + Send + Sync;
28type ConnFn<T> =
29 dyn Fn() -> Pin<Box<dyn Future<Output = Result<T, error::Error>> + Send + Sync>> + Send + Sync;
30
31struct ReconnectInner<A, T> {
32 state: Mutex<ReconnectState<T>>,
33 work_fn: Box<WorkFn<T, A>>,
34 conn_fn: Box<ConnFn<T>>,
35}
36
37impl<A, T> fmt::Debug for ReconnectInner<A, T> {
38 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39 let struct_name = format!(
40 "ReconnectInner<{}, {}>",
41 std::any::type_name::<A>(),
42 std::any::type_name::<T>()
43 );
44
45 let work_fn_d = format!("@{:p}", self.work_fn.as_ref());
46 let conn_fn_d = format!("@{:p}", self.conn_fn.as_ref());
47
48 f.debug_struct(&struct_name)
49 .field("state", &self.state)
50 .field("work_fn", &work_fn_d)
51 .field("conn_fn", &conn_fn_d)
52 .finish()
53 }
54}
55
56#[derive(Debug)]
57pub(crate) struct Reconnect<A, T>(Arc<ReconnectInner<A, T>>);
58
59impl<A, T> Clone for Reconnect<A, T> {
60 fn clone(&self) -> Self {
61 Reconnect(self.0.clone())
62 }
63}
64
65pub(crate) async fn reconnect<A, T, W, C>(w: W, c: C) -> Result<Reconnect<A, T>, error::Error>
66where
67 A: Send + 'static,
68 W: Fn(&T, A) -> Result<(), error::Error> + Send + Sync + 'static,
69 C: Fn() -> Pin<Box<dyn Future<Output = Result<T, error::Error>> + Send + Sync>>
70 + Send
71 + Sync
72 + 'static,
73 T: Clone + Send + Sync + 'static,
74{
75 let r = Reconnect(Arc::new(ReconnectInner {
76 state: Mutex::new(ReconnectState::NotConnected),
77
78 work_fn: Box::new(w),
79 conn_fn: Box::new(c),
80 }));
81 let rf = {
82 let state = r.0.state.lock().expect("Poisoned lock");
83 r.reconnect(state)
84 };
85 rf.await?;
86 Ok(r)
87}
88
89enum ReconnectState<T> {
90 NotConnected,
91 Connected(T),
92 ConnectionFailed(Mutex<Option<error::Error>>),
93 Connecting,
94}
95
96impl<T> fmt::Debug for ReconnectState<T> {
97 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
98 write!(f, "ReconnectState::")?;
99 match self {
100 NotConnected => write!(f, "NotConnected"),
101 Connected(_) => write!(f, "Connected"),
102 ConnectionFailed(_) => write!(f, "ConnectionFailed"),
103 Connecting => write!(f, "Connecting"),
104 }
105 }
106}
107
108use self::ReconnectState::*;
109
/// Number of seconds a single connection attempt may take; also quoted in the timeout error
/// message.
const CONNECTION_TIMEOUT_SECONDS: u64 = 10;

/// [`CONNECTION_TIMEOUT_SECONDS`] expressed as a [`Duration`], as passed to the timeout wrapper.
const CONNECTION_TIMEOUT: Duration = Duration::from_secs(CONNECTION_TIMEOUT_SECONDS);
112
113impl<A, T> Reconnect<A, T>
114where
115 A: Send + 'static,
116 T: Clone + Send + Sync + 'static,
117{
118 fn call_work(&self, t: &T, a: A) -> Result<bool, error::Error> {
119 if let Err(e) = (self.0.work_fn)(t, a) {
120 match e {
121 error::Error::IO(_) | error::Error::Unexpected(_) => {
122 log::error!("Error in work_fn will force connection closed, next command will attempt to re-establish connection: {}", e);
123 return Ok(false);
124 }
125 _ => (),
126 }
127 Err(e)
128 } else {
129 Ok(true)
130 }
131 }
132
133 pub(crate) fn do_work(&self, a: A) -> Result<(), error::Error> {
134 let mut state = self.0.state.lock().expect("Cannot obtain read lock");
135 match *state {
136 NotConnected => {
137 self.reconnect_spawn(state);
138 Err(error::Error::Connection(ConnectionReason::NotConnected))
139 }
140 Connected(ref t) => {
141 let success = self.call_work(t, a)?;
142 if !success {
143 *state = NotConnected;
144 self.reconnect_spawn(state);
145 }
146 Ok(())
147 }
148 ConnectionFailed(ref e) => {
149 let mut lock = e.lock().expect("Poisioned lock");
150 let e = match lock.take() {
151 Some(e) => e,
152 None => error::Error::Connection(ConnectionReason::NotConnected),
153 };
154 mem::drop(lock);
155
156 *state = NotConnected;
157 self.reconnect_spawn(state);
158 Err(e)
159 }
160 Connecting => Err(error::Error::Connection(ConnectionReason::Connecting)),
161 }
162 }
163
164 fn reconnect(
167 &self,
168 mut state: MutexGuard<ReconnectState<T>>,
169 ) -> impl Future<Output = Result<(), error::Error>> + Send {
170 log::info!("Attempting to reconnect, current state: {:?}", *state);
171
172 match *state {
173 Connected(_) => {
174 return Either::Right(future::err(error::Error::Connection(
175 ConnectionReason::Connected,
176 )));
177 }
178 Connecting => {
179 return Either::Right(future::err(error::Error::Connection(
180 ConnectionReason::Connecting,
181 )));
182 }
183 NotConnected | ConnectionFailed(_) => (),
184 }
185 *state = ReconnectState::Connecting;
186
187 mem::drop(state);
188
189 let reconnect = self.clone();
190
191 let connection_f = async move {
192 let connection = match timeout(CONNECTION_TIMEOUT, (reconnect.0.conn_fn)()).await {
193 Ok(con_r) => con_r,
194 Err(_) => Err(error::internal(format!(
195 "Connection timed-out after {} seconds",
196 CONNECTION_TIMEOUT_SECONDS
197 ))),
198 };
199
200 let mut state = reconnect.0.state.lock().expect("Cannot obtain write lock");
201
202 match *state {
203 NotConnected | Connecting => match connection {
204 Ok(t) => {
205 log::info!("Connection established");
206 *state = Connected(t);
207 Ok(())
208 }
209 Err(e) => {
210 log::error!("Connection cannot be established: {}", e);
211 *state = ConnectionFailed(Mutex::new(Some(e)));
212 Err(error::Error::Connection(ConnectionReason::ConnectionFailed))
213 }
214 },
215 ConnectionFailed(_) => {
216 panic!("The connection state wasn't reset before connecting")
217 }
218 Connected(_) => panic!("A connected state shouldn't be attempting to reconnect"),
219 }
220 };
221
222 Either::Left(connection_f)
223 }
224
225 fn reconnect_spawn(&self, state: MutexGuard<ReconnectState<T>>) {
226 let reconnect_f = self
227 .reconnect(state)
228 .map_err(|e| log::error!("Error asynchronously reconnecting: {}", e));
229
230 tokio::spawn(reconnect_f);
231 }
232}