// barter_integration/socket/mod.rs

1use crate::socket::{
2    backoff::ReconnectBackoff,
3    on_connect_err::{ConnectError, ConnectErrorHandler, ConnectErrorKind, OnConnectErr},
4    on_stream_err::{OnStreamErr, StreamErrorHandler},
5    on_stream_err_filter::OnStreamErrFilter,
6    update::SocketUpdate,
7};
8use futures::{Sink, Stream, stream::SplitSink};
9
/// Backoff strategies for reconnection attempts.
///
/// Provides the `ReconnectBackoff` trait that `init_reconnecting_socket` consults before
/// every connection attempt.
pub mod backoff;

/// Connection error handling.
///
/// Provides `ConnectError`, `ConnectErrorKind`, the `ConnectErrorHandler` trait and the
/// `OnConnectErr` adapter returned by `ReconnectingSocket::on_connect_err`.
pub mod on_connect_err;

/// Stream error handling.
///
/// Provides the `StreamErrorHandler` trait and the `OnStreamErr` wrapper produced by
/// `ReconnectingSocket::on_stream_err`.
pub mod on_stream_err;

/// Stream error handling with filtering.
///
/// Provides the `OnStreamErrFilter` wrapper produced by
/// `ReconnectingSocket::on_stream_err_filter`.
pub mod on_stream_err_filter;

/// Defines the socket lifecycle [`SocketUpdate`] event.
pub mod update;
24
25/// Extension trait providing reconnection utilities for streams.
26pub trait ReconnectingSocket
27where
28    Self: Stream,
29{
30    /// Handles connection errors using the provided [`ConnectErrorHandler`].
31    fn on_connect_err<Socket, ErrConnect, ErrHandler>(
32        self,
33        on_err: ErrHandler,
34    ) -> OnConnectErr<Self, ErrHandler>
35    where
36        Self: Stream<Item = Result<Socket, ConnectError<ErrConnect>>> + Sized,
37        ErrHandler: ConnectErrorHandler<ErrConnect>,
38    {
39        OnConnectErr::new(self, on_err)
40    }
41
42    /// Applies error handling to the inner Stream using the provided [`StreamErrorHandler`].
43    ///
44    /// Errors may be passed through or trigger a reconnecting.
45    fn on_stream_err<Socket, StOk, StErr, ErrHandler>(
46        self,
47        on_err: ErrHandler,
48    ) -> impl Stream<Item = OnStreamErr<Socket, ErrHandler>>
49    where
50        Self: Stream<Item = Socket> + Sized,
51        Socket: Stream<Item = Result<StOk, StErr>>,
52        ErrHandler: StreamErrorHandler<StErr> + Clone + 'static,
53    {
54        use futures::StreamExt;
55        self.map(move |socket| OnStreamErr::new(socket, on_err.clone()))
56    }
57
58    /// Similar to [`ReconnectingSocket::on_stream_err`] but filters all errors after applying
59    /// the provided [`StreamErrorHandler`].
60    fn on_stream_err_filter<Socket, StOk, StErr, ErrHandler>(
61        self,
62        on_err: ErrHandler,
63    ) -> impl Stream<Item = OnStreamErrFilter<Socket, ErrHandler>>
64    where
65        Self: Stream<Item = Socket> + Sized,
66        Socket: Stream<Item = Result<StOk, StErr>>,
67        ErrHandler: StreamErrorHandler<StErr> + Clone + 'static,
68    {
69        use futures::StreamExt;
70        self.map(move |socket| OnStreamErrFilter::new(socket, on_err.clone()))
71    }
72
73    /// Wrap stream items with [`SocketUpdate`] lifecycle events.
74    fn with_socket_updates<Socket, SinkItem>(
75        self,
76    ) -> impl Stream<Item = SocketUpdate<SplitSink<Socket, SinkItem>, Socket::Item>>
77    where
78        Self: Stream<Item = Socket> + Sized,
79        Socket: Sink<SinkItem> + Stream,
80    {
81        use futures::{StreamExt, stream::once};
82        use std::future::ready;
83
84        self.map(move |socket| {
85            let (sink, stream) = socket.split();
86            once(ready(SocketUpdate::Connected(sink))).chain(
87                stream
88                    .map(SocketUpdate::Item)
89                    .chain(once(ready(SocketUpdate::Reconnecting))),
90            )
91        })
92        .flatten()
93    }
94}
95
96impl<St> ReconnectingSocket for St where St: Stream {}
97
98/// Initialises a "reconnecting socket" using the provided connect function.
99///
100/// Upon disconnecting, the [`ReconnectBackoff`] is used to determine how long to wait
101/// between reconnecting attempts.
102///
103/// Returns a `Stream` of `Socket` connection results.
104pub fn init_reconnecting_socket<FnConnect, Backoff, Socket, ErrConnect>(
105    connect: FnConnect,
106    timeout_connect: std::time::Duration,
107    backoff: Backoff,
108) -> impl Stream<Item = Result<Socket, ConnectError<ErrConnect>>>
109where
110    FnConnect: AsyncFnMut() -> Result<Socket, ErrConnect>,
111    Backoff: ReconnectBackoff,
112{
113    struct State<F, B> {
114        connect: F,
115        backoff: B,
116        reconnection_attempt: u32,
117    }
118
119    futures::stream::unfold(
120        State {
121            connect,
122            backoff,
123            reconnection_attempt: 0,
124        },
125        move |mut state| async move {
126            // Apply reconnection backoff
127            let backoff = state.backoff.reconnect_backoff(state.reconnection_attempt);
128            tokio::time::sleep(backoff).await;
129
130            // Connect with timeout
131            let result = match tokio::time::timeout(timeout_connect, (state.connect)()).await {
132                Ok(Ok(socket)) => {
133                    state.reconnection_attempt = 0;
134                    Ok(socket)
135                }
136                Ok(Err(error)) => {
137                    state.reconnection_attempt = state.reconnection_attempt.saturating_add(1);
138                    Err(ConnectError {
139                        reconnection_attempt: state.reconnection_attempt,
140                        kind: ConnectErrorKind::Connect(error),
141                    })
142                }
143                Err(_elapsed) => {
144                    state.reconnection_attempt = state.reconnection_attempt.saturating_add(1);
145                    Err(ConnectError {
146                        reconnection_attempt: state.reconnection_attempt,
147                        kind: ConnectErrorKind::Timeout,
148                    })
149                }
150            };
151
152            Some((result, state))
153        },
154    )
155}