barter_integration/socket/
on_stream_err.rs1use futures::{Sink, Stream};
2use pin_project::pin_project;
3use std::{
4 pin::Pin,
5 task::{Context, Poll, ready},
6};
7
8pub trait StreamErrorHandler<Err> {
10 fn handle(&mut self, error: &Err) -> StreamErrorAction;
12}
13
14impl<Err, F> StreamErrorHandler<Err> for F
15where
16 F: FnMut(&Err) -> StreamErrorAction,
17{
18 #[inline]
19 fn handle(&mut self, error: &Err) -> StreamErrorAction {
20 self(error)
21 }
22}
23
/// Decision returned by a [`StreamErrorHandler`] after inspecting a stream error.
///
/// The enum is field-less, so equality is total: it derives `Eq` and `Hash` in
/// addition to `PartialEq`, which lets it be used as a map/set key.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum StreamErrorAction {
    /// Keep the stream alive: `OnStreamErr` yields the error to the consumer as
    /// an `Err` item and later polls continue to read from the inner stream.
    Continue,
    /// Terminate the stream: `OnStreamErr` swallows the error and yields `None`.
    ///
    /// NOTE(review): presumably the caller reacts to the stream ending by
    /// re-establishing the connection — confirm against the call site.
    Reconnect,
}
32
33#[derive(Debug)]
39#[pin_project]
40pub struct OnStreamErr<S, ErrHandler> {
41 #[pin]
42 socket: S,
43 on_err: ErrHandler,
44}
45
46impl<S, ErrHandler> OnStreamErr<S, ErrHandler> {
47 pub fn new(socket: S, on_err: ErrHandler) -> Self {
48 Self { socket, on_err }
49 }
50}
51
52impl<S, StOk, StErr, ErrHandler> Stream for OnStreamErr<S, ErrHandler>
53where
54 S: Stream<Item = Result<StOk, StErr>>,
55 ErrHandler: StreamErrorHandler<StErr>,
56{
57 type Item = Result<StOk, StErr>;
58
59 #[inline]
60 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
61 let mut this = self.project();
62
63 let next_ready = ready!(this.socket.as_mut().poll_next(cx));
64
65 let Some(result) = next_ready else {
66 return Poll::Ready(None);
67 };
68
69 match result {
70 Ok(item) => Poll::Ready(Some(Ok(item))),
71 Err(error) => match (this.on_err).handle(&error) {
72 StreamErrorAction::Continue => Poll::Ready(Some(Err(error))),
73 StreamErrorAction::Reconnect => Poll::Ready(None),
74 },
75 }
76 }
77}
78
79impl<St, ErrHandler, Item> Sink<Item> for OnStreamErr<St, ErrHandler>
80where
81 St: Sink<Item>,
82{
83 type Error = St::Error;
84
85 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86 self.project().socket.poll_ready(cx)
87 }
88
89 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
90 self.project().socket.start_send(item)
91 }
92
93 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
94 self.project().socket.poll_flush(cx)
95 }
96
97 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98 self.project().socket.poll_close(cx)
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready};

    type TestError = &'static str;
    type TestItem = Result<i32, TestError>;

    /// Creates an unbounded channel and exposes its receiving half as a `Stream`.
    fn test_channel() -> (
        mpsc::UnboundedSender<TestItem>,
        UnboundedReceiverStream<TestItem>,
    ) {
        let (tx, rx) = mpsc::unbounded_channel();
        (tx, UnboundedReceiverStream::new(rx))
    }

    /// Handler that never ends the stream.
    fn always_continue(_error: &TestError) -> StreamErrorAction {
        StreamErrorAction::Continue
    }

    /// Handler that ends the stream only for the `"fatal"` error.
    fn reconnect_on_fatal(error: &TestError) -> StreamErrorAction {
        match *error {
            "fatal" => StreamErrorAction::Reconnect,
            _ => StreamErrorAction::Continue,
        }
    }

    // `Ok` items pass through unchanged and the end of the inner stream is forwarded.
    #[tokio::test]
    async fn test_on_stream_err_passes_through_ok() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let (tx, rx) = test_channel();
        let mut stream = OnStreamErr::new(rx, always_continue);

        // Nothing has been sent yet.
        assert_pending!(stream.poll_next_unpin(&mut cx));

        for value in [1, 2] {
            tx.send(Ok(value)).unwrap();
            assert_eq!(
                assert_ready!(stream.poll_next_unpin(&mut cx)),
                Some(Ok(value))
            );
        }

        // Dropping the sender closes the channel, which ends the stream.
        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    // With `Continue`, errors are yielded and the stream keeps producing items.
    #[tokio::test]
    async fn test_on_stream_err_continue_action() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let (tx, rx) = test_channel();
        let mut stream = OnStreamErr::new(rx, always_continue);

        let sequence: [TestItem; 3] = [Ok(1), Err("error1"), Ok(2)];
        for message in sequence {
            tx.send(message).unwrap();
            assert_eq!(
                assert_ready!(stream.poll_next_unpin(&mut cx)),
                Some(message)
            );
        }

        drop(tx);
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }

    // With `Reconnect`, the offending error is swallowed and the stream ends.
    #[tokio::test]
    async fn test_on_stream_err_reconnect_action() {
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let (tx, rx) = test_channel();
        let mut stream = OnStreamErr::new(rx, reconnect_on_fatal);

        let passed_through: [TestItem; 2] = [Ok(1), Err("non-fatal")];
        for message in passed_through {
            tx.send(message).unwrap();
            assert_eq!(
                assert_ready!(stream.poll_next_unpin(&mut cx)),
                Some(message)
            );
        }

        // The sender is still alive here, so `None` comes from the handler's decision.
        tx.send(Err("fatal")).unwrap();
        assert_eq!(assert_ready!(stream.poll_next_unpin(&mut cx)), None);
    }
}