barter_data/streams/reconnect/
stream.rs1use crate::streams::{consumer::StreamKey, reconnect::Event};
2use barter_integration::channel::Tx;
3use derive_more::Constructor;
4use futures::Stream;
5use futures_util::StreamExt;
6use serde::{Deserialize, Serialize};
7use std::{convert, fmt::Debug, future, future::Future};
8use tracing::{error, info, warn};
9
10pub trait ReconnectingStream
13where
14 Self: Stream + Sized,
15{
16 fn with_reconnect_backoff<St, InitError>(
19 self,
20 policy: ReconnectionBackoffPolicy,
21 stream_key: StreamKey,
22 ) -> impl Stream<Item = St>
23 where
24 Self: Stream<Item = Result<St, InitError>>,
25 St: Stream,
26 InitError: Debug,
27 {
28 self.enumerate()
29 .scan(
30 ReconnectionState::from(policy),
31 move |state, (attempt, result)| match result {
32 Ok(stream) => {
33 info!(attempt, ?stream_key, "successfully initialised Stream");
34 state.reset_backoff();
35 futures::future::Either::Left(future::ready(Some(Ok(stream))))
36 }
37 Err(error) => {
38 warn!(
39 attempt,
40 ?stream_key,
41 ?error,
42 "failed to re-initialise Stream"
43 );
44 let sleep_fut = state.generate_sleep_future();
45 state.multiply_backoff();
46 futures::future::Either::Right(Box::pin(async move {
47 sleep_fut.await;
48 Some(Err(error))
49 }))
50 }
51 },
52 )
53 .filter_map(|result| future::ready(result.ok()))
54 }
55
56 fn with_termination_on_error<St, T, E, FnIsTerminal>(
60 self,
61 is_terminal: FnIsTerminal,
62 stream_key: StreamKey,
63 ) -> impl Stream<Item = impl Stream<Item = Result<T, E>>>
64 where
65 Self: Stream<Item = St>,
66 St: Stream<Item = Result<T, E>>,
67 FnIsTerminal: Fn(&E) -> bool + Copy,
68 {
69 self.map(move |stream| {
70 tokio_stream::StreamExt::map_while(stream, {
71 move |result| match result {
72 Ok(item) => Some(Ok(item)),
73 Err(error) if is_terminal(&error) => {
74 error!(
75 ?stream_key,
76 "MarketStream encountered terminal error that requires reconnecting"
77 );
78 None
79 }
80 Err(error) => Some(Err(error)),
81 }
82 })
83 })
84 }
85
86 fn with_reconnection_events<St, Origin>(
89 self,
90 origin: Origin,
91 ) -> impl Stream<Item = Event<Origin, St::Item>>
92 where
93 Self: Stream<Item = St>,
94 St: Stream,
95 Origin: Clone + 'static,
96 {
97 self.map(move |stream| {
98 stream
99 .map(Event::Item)
100 .chain(futures::stream::once(future::ready(Event::Reconnecting(
101 origin.clone(),
102 ))))
103 })
104 .flatten()
105 }
106
107 fn with_error_handler<FnOnErr, Origin, T, E>(
111 self,
112 op: FnOnErr,
113 ) -> impl Stream<Item = Event<Origin, T>>
114 where
115 Self: Stream<Item = Event<Origin, Result<T, E>>>,
116 FnOnErr: Fn(E) + 'static,
117 {
118 self.filter_map(move |event| {
119 std::future::ready(match event {
120 Event::Reconnecting(origin) => Some(Event::Reconnecting(origin)),
121 Event::Item(Ok(item)) => Some(Event::Item(item)),
122 Event::Item(Err(error)) => {
123 op(error);
124 None
125 }
126 })
127 })
128 }
129
130 fn forward_to<Transmitter>(self, tx: Transmitter) -> impl Future<Output = ()> + Send
132 where
133 Self: Stream + Sized + Send,
134 Self::Item: Into<Transmitter::Item>,
135 Transmitter: Tx + Send + 'static,
136 {
137 tokio_stream::StreamExt::map_while(self, move |event| tx.send(event.into()).ok()).collect()
138 }
139}
140
141impl<T> ReconnectingStream for T where T: Stream {}
142
143pub async fn init_reconnecting_stream<FnInit, St, FnInitError, FnInitFut>(
145 init_stream: FnInit,
146) -> Result<impl Stream<Item = Result<St, FnInitError>>, FnInitError>
147where
148 FnInit: Fn() -> FnInitFut,
149 FnInitFut: Future<Output = Result<St, FnInitError>>,
150{
151 let initial = init_stream().await?;
152 let reconnections = futures::stream::repeat_with(init_stream).then(convert::identity);
153
154 Ok(futures::stream::once(future::ready(Ok(initial))).chain(reconnections))
155}
156
157#[derive(
159 Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize, Constructor,
160)]
161pub struct ReconnectionBackoffPolicy {
162 pub backoff_ms_initial: u64,
167
168 pub backoff_multiplier: u8,
171
172 pub backoff_ms_max: u64,
174}
175
176#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
177struct ReconnectionState {
178 policy: ReconnectionBackoffPolicy,
179 backoff_ms_current: u64,
180}
181
182impl From<ReconnectionBackoffPolicy> for ReconnectionState {
183 fn from(policy: ReconnectionBackoffPolicy) -> Self {
184 Self {
185 backoff_ms_current: policy.backoff_ms_initial,
186 policy,
187 }
188 }
189}
190
191impl ReconnectionState {
192 fn reset_backoff(&mut self) {
193 self.backoff_ms_current = self.policy.backoff_ms_initial;
194 }
195
196 fn multiply_backoff(&mut self) {
197 let next = self.backoff_ms_current * self.policy.backoff_multiplier as u64;
198 let next_capped = std::cmp::min(next, self.policy.backoff_ms_max);
199 self.backoff_ms_current = next_capped;
200 }
201
202 fn generate_sleep_future(&self) -> tokio::time::Sleep {
203 let sleep_duration = std::time::Duration::from_millis(self.backoff_ms_current);
204 tokio::time::sleep(sleep_duration)
205 }
206}