barter_integration/socket/
on_connect_err.rs1use futures::{Sink, Stream};
2use pin_project::pin_project;
3use std::{
4 pin::Pin,
5 task::{Context, Poll, ready},
6};
7
8pub trait ConnectErrorHandler<Err> {
10 fn handle(&mut self, error: &ConnectError<Err>) -> ConnectErrorAction;
12}
13
14impl<Err, F> ConnectErrorHandler<Err> for F
15where
16 F: FnMut(&ConnectError<Err>) -> ConnectErrorAction,
17{
18 #[inline]
19 fn handle(&mut self, error: &ConnectError<Err>) -> ConnectErrorAction {
20 self(error)
21 }
22}
23
24#[derive(Debug, Copy, Clone, PartialEq)]
26pub struct ConnectError<ErrConnect> {
27 pub reconnection_attempt: u32,
28 pub kind: ConnectErrorKind<ErrConnect>,
29}
30
/// Reason a connection attempt failed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind<ErrConnect> {
    /// The connect operation returned the contained error.
    Connect(ErrConnect),
    /// The attempt timed out (the timeout policy belongs to the error's producer).
    Timeout,
}
39
/// What a [`ConnectErrorHandler`] asks an [`OnConnectErr`] stream to do after an error.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorAction {
    /// Drop the error and keep polling the inner stream for the next attempt.
    Reconnect,
    /// End the stream: `poll_next` returns `Poll::Ready(None)`.
    Terminate,
}
48
49#[derive(Debug)]
56#[pin_project]
57pub struct OnConnectErr<S, ErrHandler> {
58 #[pin]
59 socket: S,
60 on_err: ErrHandler,
61}
62
63impl<S, ErrHandler> OnConnectErr<S, ErrHandler> {
64 pub fn new(socket: S, on_err: ErrHandler) -> Self {
65 Self { socket, on_err }
66 }
67}
68
69impl<S, Socket, ErrConnect, ErrHandler> Stream for OnConnectErr<S, ErrHandler>
70where
71 S: Stream<Item = Result<Socket, ConnectError<ErrConnect>>>,
72 ErrHandler: ConnectErrorHandler<ErrConnect>,
73{
74 type Item = Socket;
75
76 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
77 let mut this = self.project();
78
79 loop {
80 let next_ready = ready!(this.socket.as_mut().poll_next(cx));
81
82 let Some(result) = next_ready else {
83 return Poll::Ready(None);
84 };
85
86 match result {
87 Ok(socket) => {
88 return Poll::Ready(Some(socket));
89 }
90 Err(error) => {
91 match this.on_err.handle(&error) {
92 ConnectErrorAction::Reconnect => {
93 }
95 ConnectErrorAction::Terminate => {
96 return Poll::Ready(None);
97 }
98 }
99 }
100 }
101 }
102 }
103}
104
105impl<S, ErrHandler, Item> Sink<Item> for OnConnectErr<S, ErrHandler>
106where
107 S: Sink<Item>,
108{
109 type Error = S::Error;
110
111 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112 self.project().socket.poll_ready(cx)
113 }
114
115 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
116 self.project().socket.start_send(item)
117 }
118
119 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120 self.project().socket.poll_flush(cx)
121 }
122
123 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
124 self.project().socket.poll_close(cx)
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::socket::ReconnectingSocket;
    use futures::StreamExt;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;
    use tokio_test::{assert_pending, assert_ready_eq};

    type TestSocket = i32;
    type TestError = &'static str;
    type TestItem = Result<TestSocket, ConnectError<TestError>>;

    /// Polling context whose waker does nothing; the tests drive the stream by hand.
    fn noop_context() -> Context<'static> {
        Context::from_waker(futures::task::noop_waker_ref())
    }

    /// Unbounded channel whose receiving half is exposed as a `Stream`.
    fn channel() -> (
        mpsc::UnboundedSender<TestItem>,
        UnboundedReceiverStream<TestItem>,
    ) {
        let (sender, receiver) = mpsc::unbounded_channel::<TestItem>();
        (sender, UnboundedReceiverStream::new(receiver))
    }

    /// Builds the `Err` item for a failed attempt.
    fn failure(reconnection_attempt: u32, kind: ConnectErrorKind<TestError>) -> TestItem {
        Err(ConnectError {
            reconnection_attempt,
            kind,
        })
    }

    #[tokio::test]
    async fn test_on_connect_err_passes_through_success() {
        let mut cx = noop_context();
        let (sender, receiver) = channel();
        let mut adapted = receiver
            .on_connect_err(|_error: &ConnectError<TestError>| ConnectErrorAction::Reconnect);

        assert_pending!(adapted.poll_next_unpin(&mut cx));

        for value in [1, 2] {
            sender.send(Ok(value)).unwrap();
            assert_ready_eq!(adapted.poll_next_unpin(&mut cx), Some(value));
        }

        drop(sender);
        assert_ready_eq!(adapted.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_on_connect_err_reconnect_action() {
        let mut cx = noop_context();
        let (sender, receiver) = channel();
        let mut adapted = receiver
            .on_connect_err(|_error: &ConnectError<TestError>| ConnectErrorAction::Reconnect);

        sender
            .send(failure(1, ConnectErrorKind::Connect("network error")))
            .unwrap();
        assert_pending!(adapted.poll_next_unpin(&mut cx));

        sender.send(failure(2, ConnectErrorKind::Timeout)).unwrap();
        assert_pending!(adapted.poll_next_unpin(&mut cx));

        sender.send(Ok(42)).unwrap();
        assert_ready_eq!(adapted.poll_next_unpin(&mut cx), Some(42));

        drop(sender);
        assert_ready_eq!(adapted.poll_next_unpin(&mut cx), None);
    }

    #[tokio::test]
    async fn test_on_connect_err_terminate_action() {
        let mut cx = noop_context();
        let (sender, receiver) = channel();
        let mut adapted = receiver.on_connect_err(|error: &ConnectError<TestError>| {
            match error.reconnection_attempt {
                attempt if attempt >= 3 => ConnectErrorAction::Terminate,
                _ => ConnectErrorAction::Reconnect,
            }
        });

        sender.send(Ok(1)).unwrap();
        assert_ready_eq!(adapted.poll_next_unpin(&mut cx), Some(1));

        sender
            .send(failure(1, ConnectErrorKind::Connect("error")))
            .unwrap();
        assert_pending!(adapted.poll_next_unpin(&mut cx));

        sender
            .send(failure(3, ConnectErrorKind::Connect("error")))
            .unwrap();
        assert_ready_eq!(adapted.poll_next_unpin(&mut cx), None);
    }
}
228}