// barter_data/streams/reconnect/stream.rs

1use crate::streams::{consumer::StreamKey, reconnect::Event};
2use barter_integration::channel::Tx;
3use derive_more::Constructor;
4use futures::Stream;
5use futures_util::StreamExt;
6use serde::{Deserialize, Serialize};
7use std::{convert, fmt::Debug, future, future::Future};
8use tracing::{error, info, warn};
9
10/// Utilities for handling a continually reconnecting [`Stream`] initialised via the
11/// [`init_reconnecting_stream`] function.
12pub trait ReconnectingStream
13where
14    Self: Stream + Sized,
15{
16    /// Add an exponential backoff policy to an initialised [`ReconnectingStream`] using the
17    /// provided [`ReconnectionBackoffPolicy`].
18    fn with_reconnect_backoff<St, InitError>(
19        self,
20        policy: ReconnectionBackoffPolicy,
21        stream_key: StreamKey,
22    ) -> impl Stream<Item = St>
23    where
24        Self: Stream<Item = Result<St, InitError>>,
25        St: Stream,
26        InitError: Debug,
27    {
28        self.enumerate()
29            .scan(
30                ReconnectionState::from(policy),
31                move |state, (attempt, result)| match result {
32                    Ok(stream) => {
33                        info!(attempt, ?stream_key, "successfully initialised Stream");
34                        state.reset_backoff();
35                        futures::future::Either::Left(future::ready(Some(Ok(stream))))
36                    }
37                    Err(error) => {
38                        warn!(
39                            attempt,
40                            ?stream_key,
41                            ?error,
42                            "failed to re-initialise Stream"
43                        );
44                        let sleep_fut = state.generate_sleep_future();
45                        state.multiply_backoff();
46                        futures::future::Either::Right(Box::pin(async move {
47                            sleep_fut.await;
48                            Some(Err(error))
49                        }))
50                    }
51                },
52            )
53            .filter_map(|result| future::ready(result.ok()))
54    }
55
56    /// Terminates the inner [`Stream`] if the encountered error is determined to be unrecoverable
57    /// by the provided closure. This will cause the [`ReconnectingStream`] to re-initialise the
58    /// inner [`Stream`].
59    fn with_termination_on_error<St, T, E, FnIsTerminal>(
60        self,
61        is_terminal: FnIsTerminal,
62        stream_key: StreamKey,
63    ) -> impl Stream<Item = impl Stream<Item = Result<T, E>>>
64    where
65        Self: Stream<Item = St>,
66        St: Stream<Item = Result<T, E>>,
67        FnIsTerminal: Fn(&E) -> bool + Copy,
68    {
69        self.map(move |stream| {
70            tokio_stream::StreamExt::map_while(stream, {
71                move |result| match result {
72                    Ok(item) => Some(Ok(item)),
73                    Err(error) if is_terminal(&error) => {
74                        error!(
75                            ?stream_key,
76                            "MarketStream encountered terminal error that requires reconnecting"
77                        );
78                        None
79                    }
80                    Err(error) => Some(Err(error)),
81                }
82            })
83        })
84    }
85
86    /// Maps every [`ReconnectingStream`] `Stream::Item` into an [`reconnect::Event::Item`](Event),
87    /// and chain a [`reconnect::Event::Reconnecting`](Event)
88    fn with_reconnection_events<St, Origin>(
89        self,
90        origin: Origin,
91    ) -> impl Stream<Item = Event<Origin, St::Item>>
92    where
93        Self: Stream<Item = St>,
94        St: Stream,
95        Origin: Clone + 'static,
96    {
97        self.map(move |stream| {
98            stream
99                .map(Event::Item)
100                .chain(futures::stream::once(future::ready(Event::Reconnecting(
101                    origin.clone(),
102                ))))
103        })
104        .flatten()
105    }
106
107    /// Handles all encountered errors with the provided closure before filtering them out,
108    /// returning a [`Stream`] of the Ok values. Useful for logging recoverable errors before
109    /// continuing.
110    fn with_error_handler<FnOnErr, Origin, T, E>(
111        self,
112        op: FnOnErr,
113    ) -> impl Stream<Item = Event<Origin, T>>
114    where
115        Self: Stream<Item = Event<Origin, Result<T, E>>>,
116        FnOnErr: Fn(E) + 'static,
117    {
118        self.filter_map(move |event| {
119            std::future::ready(match event {
120                Event::Reconnecting(origin) => Some(Event::Reconnecting(origin)),
121                Event::Item(Ok(item)) => Some(Event::Item(item)),
122                Event::Item(Err(error)) => {
123                    op(error);
124                    None
125                }
126            })
127        })
128    }
129
130    /// Future for forwarding items in [`Self`] to the provided channel [`Tx`].
131    fn forward_to<Transmitter>(self, tx: Transmitter) -> impl Future<Output = ()> + Send
132    where
133        Self: Stream + Sized + Send,
134        Self::Item: Into<Transmitter::Item>,
135        Transmitter: Tx + Send + 'static,
136    {
137        tokio_stream::StreamExt::map_while(self, move |event| tx.send(event.into()).ok()).collect()
138    }
139}
140
141impl<T> ReconnectingStream for T where T: Stream {}
142
143/// Initialise a [`ReconnectingStream`] using the provided initialisation closure.
144pub async fn init_reconnecting_stream<FnInit, St, FnInitError, FnInitFut>(
145    init_stream: FnInit,
146) -> Result<impl Stream<Item = Result<St, FnInitError>>, FnInitError>
147where
148    FnInit: Fn() -> FnInitFut,
149    FnInitFut: Future<Output = Result<St, FnInitError>>,
150{
151    let initial = init_stream().await?;
152    let reconnections = futures::stream::repeat_with(init_stream).then(convert::identity);
153
154    Ok(futures::stream::once(future::ready(Ok(initial))).chain(reconnections))
155}
156
157/// Reconnection backoff policy for a [`ReconnectingStream::with_reconnect_backoff`].
158#[derive(
159    Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize, Constructor,
160)]
161pub struct ReconnectionBackoffPolicy {
162    /// Initial backoff millisecond duration after the first `Stream` disconnection.
163    ///
164    /// This value then scales with the `backoff_multiplier` in the case of repeated failed
165    /// `Stream` reconnection attempts.
166    pub backoff_ms_initial: u64,
167
168    /// Scaling factor for the backoff duration in the case of repeated `Stream` reconnection
169    /// attempts.
170    pub backoff_multiplier: u8,
171
172    /// Maximum possible backoff duration between reconnection attempts.
173    pub backoff_ms_max: u64,
174}
175
176#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
177struct ReconnectionState {
178    policy: ReconnectionBackoffPolicy,
179    backoff_ms_current: u64,
180}
181
182impl From<ReconnectionBackoffPolicy> for ReconnectionState {
183    fn from(policy: ReconnectionBackoffPolicy) -> Self {
184        Self {
185            backoff_ms_current: policy.backoff_ms_initial,
186            policy,
187        }
188    }
189}
190
191impl ReconnectionState {
192    fn reset_backoff(&mut self) {
193        self.backoff_ms_current = self.policy.backoff_ms_initial;
194    }
195
196    fn multiply_backoff(&mut self) {
197        let next = self.backoff_ms_current * self.policy.backoff_multiplier as u64;
198        let next_capped = std::cmp::min(next, self.policy.backoff_ms_max);
199        self.backoff_ms_current = next_capped;
200    }
201
202    fn generate_sleep_future(&self) -> tokio::time::Sleep {
203        let sleep_duration = std::time::Duration::from_millis(self.backoff_ms_current);
204        tokio::time::sleep(sleep_duration)
205    }
206}