1use std::fmt;
12use std::future::Future;
13use std::mem;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex, MutexGuard};
16use std::time::Duration;
17
18use futures_util::{
19 future::{self, Either},
20 TryFutureExt,
21};
22
23use tokio::time::timeout;
24
25use crate::error::{self, ConnectionReason};
26
27type WorkFn<T, A> = dyn Fn(&T, A) -> Result<(), error::Error> + Send + Sync;
28type ConnFn<T> =
29 dyn Fn() -> Pin<Box<dyn Future<Output = Result<T, error::Error>> + Send + Sync>> + Send + Sync;
30
/// Default upper bound on connection attempts made by one reconnect cycle.
const MAX_CONNECTION_ATTEMPTS: u64 = 10;
/// Default per-attempt connection timeout, expressed in whole seconds.
const CONNECTION_TIMEOUT_SECONDS: u64 = 1;
/// [`CONNECTION_TIMEOUT_SECONDS`] converted to a `Duration`.
const CONNECTION_TIMEOUT: Duration = Duration::from_secs(CONNECTION_TIMEOUT_SECONDS);
34
/// Tuning knobs for how a connection is (re)established.
#[derive(Debug, Copy, Clone)]
pub(crate) struct ReconnectOptions {
    /// Time allowed for a single connection attempt before it is abandoned.
    pub(crate) connection_timeout: Duration,
    /// How many attempts a single reconnect cycle makes before giving up.
    pub(crate) max_connection_attempts: u64,
}
40
41impl Default for ReconnectOptions {
42 #[inline]
43 fn default() -> Self {
44 ReconnectOptions {
45 connection_timeout: CONNECTION_TIMEOUT,
46 max_connection_attempts: MAX_CONNECTION_ATTEMPTS,
47 }
48 }
49}
50
51struct ReconnectInner<A, T> {
52 state: Mutex<ReconnectState<T>>,
53 work_fn: Box<WorkFn<T, A>>,
54 conn_fn: Box<ConnFn<T>>,
55 reconnect_options: ReconnectOptions,
56}
57
58impl<A, T> fmt::Debug for ReconnectInner<A, T> {
59 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
60 let struct_name = format!(
61 "ReconnectInner<{}, {}>",
62 std::any::type_name::<A>(),
63 std::any::type_name::<T>()
64 );
65
66 let work_fn_d = format!("@{:p}", self.work_fn.as_ref());
67 let conn_fn_d = format!("@{:p}", self.conn_fn.as_ref());
68
69 f.debug_struct(&struct_name)
70 .field("state", &self.state)
71 .field("work_fn", &work_fn_d)
72 .field("conn_fn", &conn_fn_d)
73 .finish()
74 }
75}
76
77#[derive(Debug)]
78pub(crate) struct Reconnect<A, T>(Arc<ReconnectInner<A, T>>);
79
80impl<A, T> Clone for Reconnect<A, T> {
81 fn clone(&self) -> Self {
82 Reconnect(self.0.clone())
83 }
84}
85
86pub(crate) async fn reconnect<A, T, W, C>(
87 w: W,
88 c: C,
89 options: ReconnectOptions,
90) -> Result<Reconnect<A, T>, error::Error>
91where
92 A: Send + 'static,
93 W: Fn(&T, A) -> Result<(), error::Error> + Send + Sync + 'static,
94 C: Fn() -> Pin<Box<dyn Future<Output = Result<T, error::Error>> + Send + Sync>>
95 + Send
96 + Sync
97 + 'static,
98 T: Clone + Send + Sync + 'static,
99{
100 let r = Reconnect(Arc::new(ReconnectInner {
101 state: Mutex::new(ReconnectState::NotConnected),
102
103 work_fn: Box::new(w),
104 conn_fn: Box::new(c),
105
106 reconnect_options: options,
107 }));
108 let rf = {
109 let state = r.0.state.lock().expect("Poisoned lock");
110 r.reconnect(state)
111 };
112 rf.await?;
113 Ok(r)
114}
115
116enum ReconnectState<T> {
117 NotConnected,
118 Connected(T),
119 ConnectionFailed(Mutex<Option<error::Error>>),
120 Connecting,
121}
122
123impl<T> fmt::Debug for ReconnectState<T> {
124 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
125 write!(f, "ReconnectState::")?;
126 match self {
127 Self::NotConnected => write!(f, "NotConnected"),
128 Self::Connected(_) => write!(f, "Connected"),
129 Self::ConnectionFailed(_) => write!(f, "ConnectionFailed"),
130 Self::Connecting => write!(f, "Connecting"),
131 }
132 }
133}
134
135impl<A, T> Reconnect<A, T>
136where
137 A: Send + 'static,
138 T: Clone + Send + Sync + 'static,
139{
140 fn call_work(&self, t: &T, a: A) -> Result<bool, error::Error> {
141 if let Err(e) = (self.0.work_fn)(t, a) {
142 match e {
143 error::Error::IO(_) | error::Error::Unexpected(_) => {
144 log::error!("Error in work_fn will force connection closed, next command will attempt to re-establish connection: {}", e);
145 return Ok(false);
146 }
147 _ => (),
148 }
149 Err(e)
150 } else {
151 Ok(true)
152 }
153 }
154
155 pub(crate) fn do_work(&self, a: A) -> Result<(), error::Error> {
156 let mut state = self.0.state.lock().expect("Cannot obtain read lock");
157 match *state {
158 ReconnectState::NotConnected => {
159 self.reconnect_spawn(state);
160 Err(error::Error::Connection(ConnectionReason::NotConnected))
161 }
162 ReconnectState::Connected(ref t) => {
163 let success = self.call_work(t, a)?;
164 if !success {
165 *state = ReconnectState::NotConnected;
166 self.reconnect_spawn(state);
167 }
168 Ok(())
169 }
170 ReconnectState::ConnectionFailed(ref e) => {
171 let mut lock = e.lock().expect("Poisioned lock");
172 let e = match lock.take() {
173 Some(e) => e,
174 None => error::Error::Connection(ConnectionReason::NotConnected),
175 };
176 mem::drop(lock);
177
178 *state = ReconnectState::NotConnected;
179 self.reconnect_spawn(state);
180 Err(e)
181 }
182 ReconnectState::Connecting => {
183 Err(error::Error::Connection(ConnectionReason::Connecting))
184 }
185 }
186 }
187
188 fn reconnect(
191 &self,
192 mut state: MutexGuard<ReconnectState<T>>,
193 ) -> impl Future<Output = Result<(), error::Error>> + Send {
194 log::info!("Attempting to reconnect, current state: {:?}", *state);
195
196 match *state {
197 ReconnectState::Connected(_) => {
198 return Either::Right(future::err(error::Error::Connection(
199 ConnectionReason::Connected,
200 )));
201 }
202 ReconnectState::Connecting => {
203 return Either::Right(future::err(error::Error::Connection(
204 ConnectionReason::Connecting,
205 )));
206 }
207 ReconnectState::NotConnected | ReconnectState::ConnectionFailed(_) => (),
208 }
209 *state = ReconnectState::Connecting;
210
211 mem::drop(state);
212
213 let reconnect = self.clone();
214
215 let connection_f = async move {
216 let mut connection_result = Err(error::internal("Initial connection failed"));
217 for i in 0..reconnect.0.reconnect_options.max_connection_attempts {
218 let connection_count = i + 1;
219 log::debug!(
220 "Connection attempt {}/{}",
221 connection_count,
222 reconnect.0.reconnect_options.max_connection_attempts
223 );
224 connection_result = match timeout(
225 reconnect.0.reconnect_options.connection_timeout,
226 (reconnect.0.conn_fn)(),
227 )
228 .await
229 {
230 Ok(con_r) => con_r,
231 Err(_) => Err(error::internal(format!(
232 "Connection timed-out after {} seconds",
233 reconnect.0.reconnect_options.connection_timeout.as_secs()
234 * connection_count
235 ))),
236 };
237 if connection_result.is_ok() {
238 break;
239 }
240 }
241
242 let mut state = reconnect.0.state.lock().expect("Cannot obtain write lock");
243
244 match *state {
245 ReconnectState::NotConnected | ReconnectState::Connecting => {
246 match connection_result {
247 Ok(t) => {
248 log::info!("Connection established");
249 *state = ReconnectState::Connected(t);
250 Ok(())
251 }
252 Err(e) => {
253 log::error!("Connection cannot be established: {}", e);
254 *state = ReconnectState::ConnectionFailed(Mutex::new(Some(e)));
255 Err(error::Error::Connection(ConnectionReason::ConnectionFailed))
256 }
257 }
258 }
259 ReconnectState::ConnectionFailed(_) => {
260 panic!("The connection state wasn't reset before connecting")
261 }
262 ReconnectState::Connected(_) => {
263 panic!("A connected state shouldn't be attempting to reconnect")
264 }
265 }
266 };
267
268 Either::Left(connection_f)
269 }
270
271 fn reconnect_spawn(&self, state: MutexGuard<ReconnectState<T>>) {
272 let reconnect_f = self
273 .reconnect(state)
274 .map_err(|e| log::error!("Error asynchronously reconnecting: {}", e));
275
276 tokio::spawn(reconnect_f);
277 }
278}