// Source: barter_integration/socket/on_connect_err.rs

1use futures::{Sink, Stream};
2use pin_project::pin_project;
3use std::{
4    pin::Pin,
5    task::{Context, Poll, ready},
6};
7
8/// Handles connection errors and determines the appropriate [`ConnectErrorAction`].
9pub trait ConnectErrorHandler<Err> {
10    /// Handles a connection error and returns the action to take.
11    fn handle(&mut self, error: &ConnectError<Err>) -> ConnectErrorAction;
12}
13
14impl<Err, F> ConnectErrorHandler<Err> for F
15where
16    F: FnMut(&ConnectError<Err>) -> ConnectErrorAction,
17{
18    #[inline]
19    fn handle(&mut self, error: &ConnectError<Err>) -> ConnectErrorAction {
20        self(error)
21    }
22}
23
/// Connection error with reconnection attempt count.
///
/// `Eq` and `Hash` are derived alongside `PartialEq`, so values can be used as set
/// members or map keys whenever `ErrConnect` itself implements those traits.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct ConnectError<ErrConnect> {
    /// Reconnection attempt count associated with this error.
    pub reconnection_attempt: u32,
    /// What went wrong with the connection attempt.
    pub kind: ConnectErrorKind<ErrConnect>,
}

/// Connection error variants.
///
/// `Eq` and `Hash` are available whenever `ErrConnect` implements them.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ConnectErrorKind<ErrConnect> {
    /// Connection attempt failed with the wrapped underlying error.
    Connect(ErrConnect),
    /// Connection attempt timed out.
    Timeout,
}
39
/// Action to take in response to a connection error.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq`, so actions can be compared
/// exhaustively and stored in hashed collections.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ConnectErrorAction {
    /// Attempt to reconnect.
    Reconnect,
    /// Terminate the stream.
    Terminate,
}
48
49/// Stream adapter that handles connection errors using a custom error handler.
50///
51/// - Ok(socket) items are passed through to the output stream
52/// - Err(error) items are handled by the error handler:
53///   - ConnectErrorAction::Reconnect: Filter out the error and continue polling
54///   - ConnectErrorAction::Terminate: End the stream
55#[derive(Debug)]
56#[pin_project]
57pub struct OnConnectErr<S, ErrHandler> {
58    #[pin]
59    socket: S,
60    on_err: ErrHandler,
61}
62
63impl<S, ErrHandler> OnConnectErr<S, ErrHandler> {
64    pub fn new(socket: S, on_err: ErrHandler) -> Self {
65        Self { socket, on_err }
66    }
67}
68
69impl<S, Socket, ErrConnect, ErrHandler> Stream for OnConnectErr<S, ErrHandler>
70where
71    S: Stream<Item = Result<Socket, ConnectError<ErrConnect>>>,
72    ErrHandler: ConnectErrorHandler<ErrConnect>,
73{
74    type Item = Socket;
75
76    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
77        let mut this = self.project();
78
79        loop {
80            let next_ready = ready!(this.socket.as_mut().poll_next(cx));
81
82            let Some(result) = next_ready else {
83                return Poll::Ready(None);
84            };
85
86            match result {
87                Ok(socket) => {
88                    return Poll::Ready(Some(socket));
89                }
90                Err(error) => {
91                    match this.on_err.handle(&error) {
92                        ConnectErrorAction::Reconnect => {
93                            // Continue polling for the next item
94                        }
95                        ConnectErrorAction::Terminate => {
96                            return Poll::Ready(None);
97                        }
98                    }
99                }
100            }
101        }
102    }
103}
104
105impl<S, ErrHandler, Item> Sink<Item> for OnConnectErr<S, ErrHandler>
106where
107    S: Sink<Item>,
108{
109    type Error = S::Error;
110
111    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112        self.project().socket.poll_ready(cx)
113    }
114
115    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
116        self.project().socket.start_send(item)
117    }
118
119    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120        self.project().socket.poll_flush(cx)
121    }
122
123    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
124        self.project().socket.poll_close(cx)
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::socket::ReconnectingSocket;
    use futures::StreamExt;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready_eq};

    type TestSocket = i32;
    type TestError = &'static str;
    type TestConnectResult = Result<TestSocket, ConnectError<TestError>>;

    /// Builds an unbounded channel of connection results, with the receiving half
    /// wrapped as a stream.
    fn connect_result_channel() -> (
        mpsc::UnboundedSender<TestConnectResult>,
        UnboundedReceiverStream<TestConnectResult>,
    ) {
        let (tx, rx) = mpsc::unbounded_channel();
        (tx, UnboundedReceiverStream::new(rx))
    }

    /// Builds an `Err` connection result with the given attempt count and kind.
    fn connect_err(
        reconnection_attempt: u32,
        kind: ConnectErrorKind<TestError>,
    ) -> TestConnectResult {
        Err(ConnectError {
            reconnection_attempt,
            kind,
        })
    }

    #[tokio::test]
    async fn test_on_connect_err_passes_through_success() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let (tx, rx) = connect_result_channel();

        let mut stream =
            rx.on_connect_err(|_error: &ConnectError<TestError>| ConnectErrorAction::Reconnect);

        // Nothing has been sent yet.
        assert_pending!(stream.poll_next_unpin(&mut cx));

        // Successful connections pass straight through, in order.
        for socket in [1, 2] {
            tx.send(Ok(socket)).unwrap();
            assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(socket));
        }

        // Closing the sender ends the stream.
        drop(tx);
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_on_connect_err_reconnect_action() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let (tx, rx) = connect_result_channel();

        let mut stream =
            rx.on_connect_err(|_error: &ConnectError<TestError>| ConnectErrorAction::Reconnect);

        // Each error is filtered out, leaving the stream pending.
        let errors = [
            connect_err(1, ConnectErrorKind::Connect("network error")),
            connect_err(2, ConnectErrorKind::Timeout),
        ];
        for error in errors {
            tx.send(error).unwrap();
            assert_pending!(stream.poll_next_unpin(&mut cx));
        }

        // A later successful connection is still delivered.
        tx.send(Ok(42)).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(42));

        drop(tx);
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_on_connect_err_terminate_action() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let (tx, rx) = connect_result_channel();

        // Handler gives up once the attempt count reaches three.
        let mut stream = rx.on_connect_err(|error: &ConnectError<TestError>| {
            if error.reconnection_attempt >= 3 {
                ConnectErrorAction::Terminate
            } else {
                ConnectErrorAction::Reconnect
            }
        });

        tx.send(Ok(1)).unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), Some(1));

        // Below the threshold: the error is swallowed and the stream stays pending.
        tx.send(connect_err(1, ConnectErrorKind::Connect("error")))
            .unwrap();
        assert_pending!(stream.poll_next_unpin(&mut cx));

        // At the threshold: the stream ends even though the sender is still alive.
        tx.send(connect_err(3, ConnectErrorKind::Connect("error")))
            .unwrap();
        assert_ready_eq!(stream.poll_next_unpin(&mut cx), None);
    }
}