1use crate::config::ReconnectOptions;
2use crate::log::{error, info};
3use std::future::Future;
4use std::io::{self, ErrorKind, IoSlice};
5use std::marker::PhantomData;
6use std::ops::{Deref, DerefMut};
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9use std::task::{Context, Poll};
10use std::time::Duration;
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12use tokio::time::sleep;
13
/// An I/O object that `StubbornIo` knows how to (re)create and wrap.
///
/// `C` is the constructor argument passed to [`UnderlyingIo::establish`] for the
/// initial connection and for every reconnect attempt.
pub trait UnderlyingIo<C>: Sized + Unpin
where
    C: Clone + Send + Unpin,
{
    /// Creates a new instance of the I/O object from `ctor_arg`.
    fn establish(ctor_arg: C) -> Pin<Box<dyn Future<Output = io::Result<Self>> + Send>>;

    /// Returns `true` if `err` should be treated as a lost connection.
    ///
    /// The default recognises the error kinds in `DISCONNECT_KINDS` below and
    /// returns `false` for every other kind.
    fn is_disconnect_error(&self, err: &io::Error) -> bool {
        // Error kinds that count as a dropped connection by default.
        const DISCONNECT_KINDS: [ErrorKind; 10] = [
            ErrorKind::NotFound,
            ErrorKind::PermissionDenied,
            ErrorKind::ConnectionRefused,
            ErrorKind::ConnectionReset,
            ErrorKind::ConnectionAborted,
            ErrorKind::NotConnected,
            ErrorKind::AddrInUse,
            ErrorKind::AddrNotAvailable,
            ErrorKind::BrokenPipe,
            ErrorKind::AlreadyExists,
        ];

        DISCONNECT_KINDS.contains(&err.kind())
    }

    /// Returns `true` if a successful read of `bytes_read` bytes means the peer
    /// closed the stream. The default treats a zero-byte read as final.
    fn is_final_read(&self, bytes_read: usize) -> bool {
        bytes_read == 0
    }
}
54
/// Bookkeeping for the reconnect attempts made after a single disconnect.
struct AttemptsTracker {
    /// Number of reconnect attempts scheduled so far (starts at 0).
    attempt_num: usize,
    /// Remaining retry schedule; each item is the delay before the next attempt.
    retries_remaining: Box<dyn Iterator<Item = Duration> + Send + Sync>,
}
59
60struct ReconnectStatus<T, C> {
61 attempts_tracker: AttemptsTracker,
62 #[allow(clippy::type_complexity)]
63 reconnect_attempt: Arc<Mutex<Pin<Box<dyn Future<Output = io::Result<T>> + Send>>>>,
64 _phantom_data: PhantomData<C>,
65}
66
67impl<T, C> ReconnectStatus<T, C>
68where
69 T: UnderlyingIo<C>,
70 C: Clone + Send + Unpin + 'static,
71{
72 pub fn new(options: &ReconnectOptions) -> Self {
73 ReconnectStatus {
74 attempts_tracker: AttemptsTracker {
75 attempt_num: 0,
76 retries_remaining: (options.retries_to_attempt_fn)(),
77 },
78 reconnect_attempt: Arc::new(Mutex::new(Box::pin(async {
79 unreachable!("Not going to happen")
80 }))),
81 _phantom_data: PhantomData,
82 }
83 }
84}
85
86pub struct StubbornIo<T, C> {
90 status: Status<T, C>,
91 underlying_io: T,
92 options: ReconnectOptions,
93 ctor_arg: C,
94}
95
96enum Status<T, C> {
97 Connected,
98 Disconnected(ReconnectStatus<T, C>),
99 FailedAndExhausted, }
101
/// Builds an immediately-ready `io::Error` of the given `kind` carrying `reason`.
#[inline]
fn poll_err<T>(
    kind: ErrorKind,
    reason: impl Into<Box<dyn std::error::Error + Send + Sync>>,
) -> Poll<io::Result<T>> {
    Poll::Ready(Err(io::Error::new(kind, reason)))
}
110
111fn exhausted_err<T>() -> Poll<io::Result<T>> {
112 poll_err(
113 ErrorKind::NotConnected,
114 "Disconnected. Connection attempts have been exhausted.",
115 )
116}
117
118fn disconnected_err<T>() -> Poll<io::Result<T>> {
119 poll_err(ErrorKind::NotConnected, "Underlying I/O is disconnected.")
120}
121
122impl<T, C> Deref for StubbornIo<T, C> {
123 type Target = T;
124
125 fn deref(&self) -> &Self::Target {
126 &self.underlying_io
127 }
128}
129
130impl<T, C> DerefMut for StubbornIo<T, C> {
131 fn deref_mut(&mut self) -> &mut Self::Target {
132 &mut self.underlying_io
133 }
134}
135
impl<T, C> StubbornIo<T, C>
where
    T: UnderlyingIo<C>,
    C: Clone + Send + Unpin + 'static,
{
    /// Connects using the default options (`ReconnectOptions::new()`).
    ///
    /// See [`StubbornIo::connect_with_options`] for retry and error behavior.
    pub async fn connect(ctor_arg: C) -> io::Result<Self> {
        let options = ReconnectOptions::new();
        Self::connect_with_options(ctor_arg, options).await
    }

    /// Performs the initial connection, retrying according to `options`.
    ///
    /// Every attempt calls `T::establish` with a clone of `ctor_arg`. A successful
    /// attempt fires `on_connect_callback`; each failed attempt fires
    /// `on_connect_fail_callback`.
    ///
    /// # Errors
    ///
    /// If `options.exit_if_first_connect_fails` is set, the first attempt's error
    /// is returned immediately. Otherwise the error of the last attempt is returned
    /// once the iterator from `retries_to_attempt_fn` yields no more delays.
    pub async fn connect_with_options(ctor_arg: C, options: ReconnectOptions) -> io::Result<Self> {
        let tcp = match T::establish(ctor_arg.clone()).await {
            Ok(tcp) => {
                info!("Initial connection succeeded.");
                (options.on_connect_callback)();
                tcp
            }
            Err(e) => {
                error!("Initial connection failed due to: {:?}.", e);
                (options.on_connect_fail_callback)();

                if options.exit_if_first_connect_fails {
                    error!("Bailing after initial connection failure.");
                    return Err(e);
                }

                // Outcome of the most recent attempt. It starts as the initial
                // failure, so an empty retry schedule still returns that error.
                let mut result = Err(e);

                // Each item from the retry schedule is the delay before the next
                // attempt. The loop ends on the first success or when the schedule
                // runs out.
                for (i, duration) in (options.retries_to_attempt_fn)().enumerate() {
                    let reconnect_num = i + 1;

                    info!(
                        "Will re-perform initial connect attempt #{} in {:?}.",
                        reconnect_num, duration
                    );

                    sleep(duration).await;

                    info!("Attempting reconnect #{} now.", reconnect_num);

                    match T::establish(ctor_arg.clone()).await {
                        Ok(tcp) => {
                            result = Ok(tcp);
                            (options.on_connect_callback)();
                            info!("Initial connection successfully established.");
                            break;
                        }
                        Err(e) => {
                            (options.on_connect_fail_callback)();
                            result = Err(e);
                        }
                    }
                }

                match result {
                    Ok(tcp) => tcp,
                    Err(e) => {
                        error!("No more re-connect retries remaining. Never able to establish initial connection.");
                        return Err(e);
                    }
                }
            }
        };

        Ok(StubbornIo {
            status: Status::Connected,
            ctor_arg,
            underlying_io: tcp,
            options,
        })
    }

    /// Records a disconnect, or a failed reconnect attempt, and schedules the next
    /// reconnect attempt.
    ///
    /// - From `Connected`: fires `on_disconnect_callback` and switches to
    ///   `Disconnected` with a fresh `ReconnectStatus`.
    /// - From `Disconnected`: fires `on_connect_fail_callback`.
    ///
    /// It then takes the next delay from the retry schedule:
    /// - If none remains, the status becomes `FailedAndExhausted`.
    /// - Otherwise the stored attempt is replaced by a new future that sleeps for
    ///   that delay and then calls `T::establish`.
    ///
    /// In both cases the task's waker is woken.
    ///
    /// # Panics
    ///
    /// Panics if called while the status is `FailedAndExhausted`.
    fn on_disconnect(mut self: Pin<&mut Self>, cx: &mut Context) {
        match &mut self.status {
            Status::Connected => {
                error!("Disconnect occurred");
                (self.options.on_disconnect_callback)();
                self.status = Status::Disconnected(ReconnectStatus::new(&self.options));
            }
            Status::Disconnected(_) => {
                (self.options.on_connect_fail_callback)();
            }
            Status::FailedAndExhausted => {
                unreachable!("on_disconnect will not occur for already exhausted state.")
            }
        };

        // Cloned before `self.status` is mutably borrowed below; this clone is moved
        // into the reconnect future.
        let ctor_arg = self.ctor_arg.clone();

        if let Status::Disconnected(reconnect_status) = &mut self.status {
            let next_duration = match reconnect_status.attempts_tracker.retries_remaining.next() {
                Some(duration) => duration,
                None => {
                    error!("No more re-connect retries remaining. Giving up.");
                    self.status = Status::FailedAndExhausted;
                    // Wake the task so it is polled again and observes the exhausted state.
                    cx.waker().wake_by_ref();
                    return;
                }
            };

            // The sleep is created here; the attempt future below awaits it before
            // calling `T::establish`.
            let future_instant = sleep(next_duration);

            reconnect_status.attempts_tracker.attempt_num += 1;
            let cur_num = reconnect_status.attempts_tracker.attempt_num;

            let reconnect_attempt = async move {
                future_instant.await;
                info!("Attempting reconnect #{} now.", cur_num);
                T::establish(ctor_arg).await
            };

            reconnect_status.reconnect_attempt = Arc::new(Mutex::new(Box::pin(reconnect_attempt)));

            info!(
                "Will perform reconnect attempt #{} in {:?}.",
                reconnect_status.attempts_tracker.attempt_num, next_duration
            );

            // Wake the task so the new attempt future gets polled.
            cx.waker().wake_by_ref();
        }
    }

    /// Polls the reconnect attempt stored in the `Disconnected` state.
    ///
    /// - On success: the new I/O object replaces `underlying_io`, the status returns
    ///   to `Connected`, `on_connect_callback` fires, and the task is woken.
    /// - On failure: the error is logged and [`Self::on_disconnect`] schedules the
    ///   next attempt (or gives up).
    /// - While the attempt is still pending: nothing changes.
    ///
    /// # Panics
    ///
    /// Panics if the status is not `Disconnected`, or if the attempt's mutex is poisoned.
    fn poll_disconnect(mut self: Pin<&mut Self>, cx: &mut Context) {
        // The `Arc` is cloned, so the lock guard below does not borrow `self`. That
        // lets `self.status` be reassigned while the attempt is still locked.
        let (attempt, attempt_num) = match self.status {
            Status::Connected => unreachable!(),
            Status::Disconnected(ref mut status) => (
                status.reconnect_attempt.clone(),
                status.attempts_tracker.attempt_num,
            ),
            Status::FailedAndExhausted => unreachable!(),
        };

        let mut attempt = attempt.lock().unwrap();

        match attempt.as_mut().poll(cx) {
            Poll::Ready(Ok(underlying_io)) => {
                info!("Connection re-established");
                cx.waker().wake_by_ref();
                self.status = Status::Connected;
                (self.options.on_connect_callback)();
                self.underlying_io = underlying_io;
            }
            Poll::Ready(Err(err)) => {
                error!("Connection attempt #{} failed: {:?}", attempt_num, err);
                self.on_disconnect(cx);
            }
            Poll::Pending => {}
        }
    }

    /// Returns `true` if a read result indicates a lost connection.
    ///
    /// That is the case for:
    /// - a successful read that `is_final_read(bytes_read)` deems final, or
    /// - an error that `is_disconnect_error` classifies as a disconnect.
    ///
    /// `Pending` is never a disconnect.
    fn is_read_disconnect_detected(
        &self,
        poll_result: &Poll<io::Result<()>>,
        bytes_read: usize,
    ) -> bool {
        match poll_result {
            Poll::Ready(Ok(())) if self.is_final_read(bytes_read) => true,
            Poll::Ready(Err(err)) => self.is_disconnect_error(err),
            _ => false,
        }
    }

    /// Returns `true` if a write/flush result is an error that `is_disconnect_error`
    /// classifies as a disconnect. Successes and `Pending` are never disconnects.
    fn is_write_disconnect_detected<X>(&self, poll_result: &Poll<io::Result<X>>) -> bool {
        match poll_result {
            Poll::Ready(Err(err)) => self.is_disconnect_error(err),
            _ => false,
        }
    }
}
309
310impl<T, C> AsyncRead for StubbornIo<T, C>
311where
312 T: UnderlyingIo<C> + AsyncRead,
313 C: Clone + Send + Unpin + 'static,
314{
315 fn poll_read(
316 mut self: Pin<&mut Self>,
317 cx: &mut Context<'_>,
318 buf: &mut ReadBuf<'_>,
319 ) -> Poll<io::Result<()>> {
320 match &mut self.status {
321 Status::Connected => {
322 let pre_len = buf.filled().len();
323 let poll = AsyncRead::poll_read(Pin::new(&mut self.underlying_io), cx, buf);
324 let post_len = buf.filled().len();
325 let bytes_read = post_len - pre_len;
326 if self.is_read_disconnect_detected(&poll, bytes_read) {
327 self.on_disconnect(cx);
328 Poll::Pending
329 } else {
330 poll
331 }
332 }
333 Status::Disconnected(_) => {
334 self.poll_disconnect(cx);
335 Poll::Pending
336 }
337 Status::FailedAndExhausted => exhausted_err(),
338 }
339 }
340}
341
342impl<T, C> AsyncWrite for StubbornIo<T, C>
343where
344 T: UnderlyingIo<C> + AsyncWrite,
345 C: Clone + Send + Unpin + 'static,
346{
347 fn poll_write(
348 mut self: Pin<&mut Self>,
349 cx: &mut Context<'_>,
350 buf: &[u8],
351 ) -> Poll<io::Result<usize>> {
352 match &mut self.status {
353 Status::Connected => {
354 let poll = AsyncWrite::poll_write(Pin::new(&mut self.underlying_io), cx, buf);
355
356 if self.is_write_disconnect_detected(&poll) {
357 self.on_disconnect(cx);
358 Poll::Pending
359 } else {
360 poll
361 }
362 }
363 Status::Disconnected(_) => {
364 self.poll_disconnect(cx);
365 Poll::Pending
366 }
367 Status::FailedAndExhausted => exhausted_err(),
368 }
369 }
370
371 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
372 match &mut self.status {
373 Status::Connected => {
374 let poll = AsyncWrite::poll_flush(Pin::new(&mut self.underlying_io), cx);
375
376 if self.is_write_disconnect_detected(&poll) {
377 self.on_disconnect(cx);
378 Poll::Pending
379 } else {
380 poll
381 }
382 }
383 Status::Disconnected(_) => {
384 self.poll_disconnect(cx);
385 Poll::Pending
386 }
387 Status::FailedAndExhausted => exhausted_err(),
388 }
389 }
390
391 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
392 match &mut self.status {
393 Status::Connected => {
394 let poll = AsyncWrite::poll_shutdown(Pin::new(&mut self.underlying_io), cx);
395 if poll.is_ready() {
396 self.on_disconnect(cx);
398 }
399
400 poll
401 }
402 Status::Disconnected(_) => disconnected_err(),
403 Status::FailedAndExhausted => exhausted_err(),
404 }
405 }
406
407 fn poll_write_vectored(
408 mut self: Pin<&mut Self>,
409 cx: &mut Context<'_>,
410 bufs: &[IoSlice<'_>],
411 ) -> Poll<io::Result<usize>> {
412 match &mut self.status {
413 Status::Connected => {
414 let poll =
415 AsyncWrite::poll_write_vectored(Pin::new(&mut self.underlying_io), cx, bufs);
416
417 if self.is_write_disconnect_detected(&poll) {
418 self.on_disconnect(cx);
419 Poll::Pending
420 } else {
421 poll
422 }
423 }
424 Status::Disconnected(_) => {
425 self.poll_disconnect(cx);
426 Poll::Pending
427 }
428 Status::FailedAndExhausted => exhausted_err(),
429 }
430 }
431
432 fn is_write_vectored(&self) -> bool {
433 self.underlying_io.is_write_vectored()
434 }
435}