Skip to main content

barter_integration/socket/
on_stream_err.rs

1use futures::{Sink, Stream};
2use pin_project::pin_project;
3use std::{
4    pin::Pin,
5    task::{Context, Poll, ready},
6};
7
8/// Handles stream errors and determines the appropriate [`StreamErrorAction`].
9pub trait StreamErrorHandler<Err> {
10    /// Handles a stream error and returns the action to take.
11    fn handle(&mut self, error: &Err) -> StreamErrorAction;
12}
13
14impl<Err, F> StreamErrorHandler<Err> for F
15where
16    F: FnMut(&Err) -> StreamErrorAction,
17{
18    #[inline]
19    fn handle(&mut self, error: &Err) -> StreamErrorAction {
20        self(error)
21    }
22}
23
/// Action to take in response to a stream error.
///
/// This is the value returned by [`StreamErrorHandler::handle`]. The enum is a plain,
/// fieldless `Copy` type, so it also derives `Eq` and `Hash` and can be used as a key in
/// hashed collections or compared with full equality semantics.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StreamErrorAction {
    /// Keep the stream alive.
    Continue,
    /// End the stream and trigger reconnection.
    Reconnect,
}
32
33/// Stream wrapper that applies error handling to a Result stream.
34///
35/// When an error occurs:
36/// - `StreamErrorAction::Continue`: Pass the error through
37/// - `StreamErrorAction::Reconnect`: End the stream (triggers reconnection)
38#[derive(Debug)]
39#[pin_project]
40pub struct OnStreamErr<S, ErrHandler> {
41    #[pin]
42    socket: S,
43    on_err: ErrHandler,
44}
45
46impl<S, ErrHandler> OnStreamErr<S, ErrHandler> {
47    pub fn new(socket: S, on_err: ErrHandler) -> Self {
48        Self { socket, on_err }
49    }
50}
51
52impl<S, StOk, StErr, ErrHandler> Stream for OnStreamErr<S, ErrHandler>
53where
54    S: Stream<Item = Result<StOk, StErr>>,
55    ErrHandler: StreamErrorHandler<StErr>,
56{
57    type Item = Result<StOk, StErr>;
58
59    #[inline]
60    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
61        let mut this = self.project();
62
63        let next_ready = ready!(this.socket.as_mut().poll_next(cx));
64
65        let Some(result) = next_ready else {
66            return Poll::Ready(None);
67        };
68
69        match result {
70            Ok(item) => Poll::Ready(Some(Ok(item))),
71            Err(error) => match (this.on_err).handle(&error) {
72                StreamErrorAction::Continue => Poll::Ready(Some(Err(error))),
73                StreamErrorAction::Reconnect => Poll::Ready(None),
74            },
75        }
76    }
77}
78
79impl<St, ErrHandler, Item> Sink<Item> for OnStreamErr<St, ErrHandler>
80where
81    St: Sink<Item>,
82{
83    type Error = St::Error;
84
85    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        self.project().socket.poll_ready(cx)
87    }
88
89    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
90        self.project().socket.start_send(item)
91    }
92
93    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
94        self.project().socket.poll_flush(cx)
95    }
96
97    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98        self.project().socket.poll_close(cx)
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready};

    type TestError = &'static str;
    type TestItem = Result<i32, TestError>;

    /// Task context backed by a no-op waker, for polling the stream by hand.
    fn noop_context() -> Context<'static> {
        Context::from_waker(futures::task::noop_waker_ref())
    }

    /// Unbounded channel whose receiving half is exposed as a `Stream` of `TestItem`s.
    fn item_channel() -> (mpsc::UnboundedSender<TestItem>, UnboundedReceiverStream<TestItem>) {
        let (sender, receiver) = mpsc::unbounded_channel();
        (sender, UnboundedReceiverStream::new(receiver))
    }

    #[tokio::test]
    async fn test_on_stream_err_passes_through_ok() {
        let mut cx = noop_context();
        let (tx, rx) = item_channel();
        let mut stream = OnStreamErr::new(rx, |_: &TestError| StreamErrorAction::Continue);

        // Nothing has been sent yet, so the wrapper must report `Pending`.
        assert_pending!(stream.poll_next_unpin(&mut cx));

        for value in [1, 2] {
            tx.send(Ok(value)).unwrap();
            assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(Ok(value)));
        }

        // Dropping the sender closes the channel, which ends the stream.
        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    #[tokio::test]
    async fn test_on_stream_err_continue_action() {
        let mut cx = noop_context();
        let (tx, rx) = item_channel();
        let mut stream = OnStreamErr::new(rx, |_: &TestError| StreamErrorAction::Continue);

        // With `Continue`, errors are surfaced like any other item and the stream lives on.
        let sequence: [TestItem; 3] = [Ok(1), Err("error1"), Ok(2)];
        for expected in sequence {
            tx.send(expected).unwrap();
            assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(expected));
        }

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    #[tokio::test]
    async fn test_on_stream_err_reconnect_action() {
        let mut cx = noop_context();
        let (tx, rx) = item_channel();

        // Only the "fatal" error requests a reconnect; everything else continues.
        let mut stream = OnStreamErr::new(rx, |error: &TestError| match *error {
            "fatal" => StreamErrorAction::Reconnect,
            _ => StreamErrorAction::Continue,
        });

        let passed_through: [TestItem; 2] = [Ok(1), Err("non-fatal")];
        for expected in passed_through {
            tx.send(expected).unwrap();
            assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), Some(expected));
        }

        // The sender is still alive, so `None` here comes from the `Reconnect` decision.
        tx.send(Err("fatal")).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }
}