reconnecting_websocket/
socket.rs

1use std::{
2    convert,
3    fmt::{self, Debug},
4    marker::PhantomData,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use cfg_if::cfg_if;
10use exponential_backoff::Backoff;
11use futures::{
12    channel::mpsc::{self, SendError, TrySendError, UnboundedReceiver, UnboundedSender},
13    ready,
14    stream::{self, Fuse, FusedStream},
15    Sink, Stream, StreamExt,
16};
17use gloo::{
18    net::websocket::{futures::WebSocket, Message, WebSocketError},
19    timers::future::TimeoutFuture,
20};
21
22use crate::{
23    constants::DEFAULT_STABLE_CONNECTION_TIMEOUT,
24    debug, error,
25    event::{map_err, map_poll},
26    info, trace, Error, Event, SocketInput, SocketOutput, State, DEFAULT_BACKOFF_MAX,
27    DEFAULT_BACKOFF_MIN, DEFAULT_MAX_RETRIES,
28};
29
/// Records which of the two inner sources (the websocket stream or the outgoing message
/// channel) should be polled first the next time the [`Socket`] stream is polled
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum NextPoll {
    Socket,
    Channel,
}

impl Default for NextPoll {
    /// Start with the socket
    fn default() -> Self {
        NextPoll::Socket
    }
}

impl NextPoll {
    /// The source that should be polled after `self`, i.e. the other variant
    fn next(self) -> NextPoll {
        match self {
            Self::Socket => Self::Channel,
            Self::Channel => Self::Socket,
        }
    }
}

impl IntoIterator for NextPoll {
    type IntoIter = NextPollIter;
    type Item = NextPoll;

    /// Yields `self` first and then the other variant, so every source is visited exactly once
    fn into_iter(self) -> Self::IntoIter {
        NextPollIter {
            i: 0,
            items: [self, self.next()],
        }
    }
}

/// An iterator over both [`NextPoll`] variants, in the order they should be polled
pub(crate) struct NextPollIter {
    /// Position of the next item to yield
    i: usize,
    /// Both variants, already arranged in polling order
    items: [NextPoll; 2],
}

impl Iterator for NextPollIter {
    type Item = NextPoll;

    fn next(&mut self) -> Option<Self::Item> {
        // Once exhausted, `get` keeps returning None and the index stops advancing
        let item = self.items.get(self.i).copied()?;
        self.i += 1;
        Some(item)
    }
}
85
86/// A handle that implements [`Sink`] for sending messages from the client to the server
87///
88/// Cheap and safe to clone (internally it's a channel sender)
89#[derive(Debug, Clone)]
90pub struct SocketSink<I> {
91    sender: UnboundedSender<I>,
92}
93
94impl<I> From<UnboundedSender<I>> for SocketSink<I> {
95    fn from(sender: UnboundedSender<I>) -> Self {
96        Self { sender }
97    }
98}
99
100impl<I> Sink<I> for SocketSink<I>
101where
102    I: SocketInput,
103    Message: TryFrom<I>,
104    <Message as TryFrom<I>>::Error: Debug,
105{
106    type Error = SendError;
107
108    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109        UnboundedSender::poll_ready(&self.sender, cx)
110    }
111
112    fn start_send(self: Pin<&mut Self>, msg: I) -> Result<(), Self::Error> {
113        self.sender.unbounded_send(msg).map_err(TrySendError::into_send_error)
114    }
115
116    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117        Poll::Ready(Ok(()))
118    }
119
120    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121        self.sender.close_channel();
122        Poll::Ready(Ok(()))
123    }
124}
125
126/// A wrapper around [`WebSocket`] that reconnects when the socket
127/// drops. Uses [`Backoff`] to determine the delay between reconnects
128///
129/// See the [`crate`] documentation for usage and examples
130///
131/// An error returned by the [`Stream`] aren't necessarily fatal. Check [`Error`] for more detail.
132/// `Poll::Ready(None)` is the main fatal case that requires a new instance of [`Socket`]
133pub struct Socket<I, O> {
134    /// The server URL to connect to on reconnect
135    pub(crate) url: String,
136    /// The sending end of the input message channel
137    /// Retained to implement [`Self::get_sink`] and [`Self::send`]
138    pub(crate) sink_sender: UnboundedSender<I>,
139    /// The receiving side of the input message channel
140    /// Polled by the [`Stream`] implementation
141    pub(crate) sink_receiver: UnboundedReceiver<I>,
142    /// The inner socket, None when a reconnect is pending
143    pub(crate) socket: Option<WebSocket>,
144    /// A queued message that needs to be sent as soon as the socket is [`State::Open`] This
145    /// happens when the inner socket exists but hasn't yet fully connected. When in this
146    /// state the [`WebSocket`] [`Sink`] implementation returns [`Poll::Pending`]. Since we
147    /// can't reliably know that with any certainty until we've already created the
148    /// [`Message`] from the input channel and called [`Sink::poll_ready`]. Calling
149    /// [`Sink::poll_ready`] before creating the [`Message`] isn't really an option because we
150    ///  have no way of undoing anything the Sink does to prepare a slot for us to send to -
151    /// in this specific case, [`WebSocket`] doesn't actually do anything that needs to be
152    /// reversed but we can't rely on that always being the case. See
153    /// <https://github.com/rust-lang/futures-rs/issues/2109> for a discussion about this
154    /// problem. So what we do is take the [`Message`] but don't try and send it directly,
155    /// instead calling [`Sink::poll_ready`] and only sending it if this returns [`Poll::Ready`]
156    pub(crate) queued_message: Option<Message>,
157    pub(crate) state: State,
158    pub(crate) backoff: Backoff,
159    pub(crate) max_retries: u32,
160    pub(crate) retry: u32,
161    /// When socket.is_none this is a reconnect timeout
162    /// When socket.is_some this is a connection stable after retry timeout
163    pub(crate) timeout: Fuse<stream::Once<TimeoutFuture>>,
164    pub(crate) next_poll: NextPoll,
165    pub(crate) closed: bool,
166    /// How long to wait after reconnecting before resetting retries to 0
167    pub(crate) stable_timeout_millis: u32,
168    pub(crate) _phantom: PhantomData<(I, O)>,
169}
170
171impl<I, O> Default for Socket<I, O>
172where
173    I: SocketInput,
174    O: SocketOutput,
175    Message: TryFrom<I>,
176    <Message as TryFrom<I>>::Error: Debug,
177    <O as TryFrom<Message>>::Error: Debug,
178{
179    fn default() -> Self {
180        let (sender, receiver) = mpsc::unbounded();
181        Self {
182            url: String::new(),
183            sink_sender: sender,
184            sink_receiver: receiver,
185            socket: None,
186            queued_message: None,
187            state: State::Connecting,
188            backoff: Backoff::new(DEFAULT_MAX_RETRIES, DEFAULT_BACKOFF_MIN, DEFAULT_BACKOFF_MAX),
189            max_retries: DEFAULT_MAX_RETRIES,
190            retry: 0,
191            timeout: stream::once(TimeoutFuture::new(0)).fuse(),
192            next_poll: NextPoll::Socket,
193            closed: false,
194            stable_timeout_millis: DEFAULT_STABLE_CONNECTION_TIMEOUT.as_millis() as u32,
195            _phantom: PhantomData,
196        }
197    }
198}
199
200impl<I, O> fmt::Debug for Socket<I, O> {
201    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202        f.debug_struct("Socket")
203            .field("url", &self.url)
204            .field("sink_sender", &self.sink_sender)
205            .field("sink_receiver", &self.sink_receiver)
206            .field("socket.is_some", &self.socket.is_some())
207            .field("state", &self.state)
208            .field("backoff", &self.backoff)
209            .field("max_retries", &self.max_retries)
210            .field("retry", &self.retry)
211            .field("timeout", &self.timeout)
212            .field("next_poll", &self.next_poll)
213            .field("closed", &self.closed)
214            .finish()
215    }
216}
217
218impl<I, O> Socket<I, O>
219where
220    I: SocketInput,
221    O: SocketOutput,
222    Message: TryFrom<I>,
223    <Message as TryFrom<I>>::Error: Debug,
224    <O as TryFrom<Message>>::Error: Debug,
225{
226    /// Send the given `message` for sending
227    ///
228    /// Internally it is added to a channel which is polled by the [`Stream`] implementation
229    /// when the underlying [`WebSocket`] is open and ready to transmit it
230    pub async fn send(&mut self, message: I) -> Result<(), TrySendError<I>> {
231        self.sink_sender.unbounded_send(message)
232    }
233
234    /// Get a sink handle for sending messages from the client to the server
235    pub fn get_sink(&self) -> SocketSink<I> {
236        self.sink_sender.clone().into()
237    }
238
239    /// Close the inner socket with the given `code` and `reason`
240    ///
241    /// The socket will try and reconnect after a timeout if there are sufficient retries remaining
242    ///
243    /// This is mainly an implementation detail but it's exposed so it can be used in test code
244    /// to force a reconnect. If used in this way it's worth noting that the Closing/Closed state
245    /// events won't be emitted
246    pub fn close_socket(&mut self, code: Option<u16>, reason: Option<&str>) {
247        // Take and drop the socket
248        if let Some(socket) = self.socket.take() {
249            // Attempt to send the close but don't fail if it can't be sent (the socket could be
250            // dead already)
251            let _ = socket.close(code, reason);
252        }
253
254        // Update our state
255        self.state = State::Closed;
256
257        if let Some(timeout) = self.backoff.next(self.retry) {
258            debug!("Backoff retry: {}, timeout: {:.3}s", self.retry, timeout.as_secs_f32());
259            let millis = timeout.as_millis() as u32;
260            self.timeout = stream::once(TimeoutFuture::new(millis)).fuse();
261        } else {
262            // If we have exceeded our retries the next poll of the stream will close it and error
263            // no need to have a timeout in that case
264            self.timeout = Self::default().timeout;
265        }
266    }
267
268    /// Permanently close the reconnecting socket. No further reconnects will be possible
269    ///
270    /// The socket implements [`FusedStream`] so polling it after close won't panic
271    pub fn close(&mut self, code: Option<u16>, reason: Option<&str>) {
272        self.closed = true;
273        let _ = self.close_socket(code, reason);
274    }
275
276    fn map_socket_output(
277        output: Option<Result<Message, WebSocketError>>,
278    ) -> Option<Result<O, Error<I, O>>> {
279        output.map(|result| {
280            result
281                // Map the gloo socket error
282                .map_err(Error::from)
283                // Convert the return value into the consumers type
284                .map(|message| {
285                    debug!("Got output message: {message:?}");
286                    O::try_from(message)
287                        // Map the consumers try_from error into our error so we can
288                        // flatten the result
289                        .map_err(Error::<I, O>::from_output)
290                })
291                // Equivalent to .flatten unstable feature
292                .and_then(convert::identity)
293        })
294    }
295
296    fn map_channel_input(input: Option<I>) -> Option<Result<Message, Error<I, O>>> {
297        input.map(|input| {
298            debug!("Got input message: {input:?}");
299            Message::try_from(input)
300                // Map the consumers try_from error into our error
301                .map_err(Error::<I, O>::from_input)
302        })
303    }
304}
305
306impl<I, O> FusedStream for Socket<I, O>
307where
308    I: SocketInput,
309    O: SocketOutput,
310    Message: TryFrom<I>,
311    <Message as TryFrom<I>>::Error: Debug,
312    <O as TryFrom<Message>>::Error: Debug,
313{
314    fn is_terminated(&self) -> bool {
315        self.closed
316    }
317}
318
319impl<I, O> Stream for Socket<I, O>
320where
321    I: SocketInput,
322    O: SocketOutput,
323    Message: TryFrom<I>,
324    <Message as TryFrom<I>>::Error: Debug,
325    <O as TryFrom<Message>>::Error: Debug,
326{
327    type Item = Event<I, O>;
328
329    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
330        if self.closed {
331            trace!("polled when closed");
332            return Poll::Ready(None);
333        }
334
335        // Reconnect & queue loop
336        // Loops in two cases
337        // 1. When we disconnected and need to reconnect: socket is none && !self.closed
338        // 2. When we sent a queued message and need to re-poll the channel: queued == true &&
339        //    self.queued_message.is_none()
340        while !self.closed {
341            // Check we have a socket first
342            if let Some(socket) = self.socket.as_ref() {
343                // Update our copy of the state and notify if it's changed
344                let current_state = socket.state().into();
345                if self.state != current_state {
346                    self.state = current_state;
347
348                    #[cfg(feature = "state-events")]
349                    return Poll::Ready(Some(self.state.into()));
350                }
351
352                // Check if the connection has become stable
353                if self.retry > 0 && Pin::new(&mut self.timeout).poll_next(cx).is_ready() {
354                    trace!("connection is stable. Resetting retries ({} -> 0)", self.retry);
355                    self.retry = 0;
356                }
357            } else {
358                trace!("socket is none");
359                ready!(Pin::new(&mut self.timeout).poll_next(cx));
360
361                if self.retry > self.max_retries {
362                    error!("retries exceeded. Closing");
363                    self.close(None, None);
364                    return Poll::Ready(None);
365                }
366
367                info!("Reconnecting socket...");
368                self.retry += 1;
369                match WebSocket::open(&self.url).map_err(Error::<I, O>::from) {
370                    Ok(v) => self.socket = Some(v),
371                    Err(e) => {
372                        error!("WebSocket::open err: {e:?}");
373                        // Reset the connection and set the next retry timeout (although this kind
374                        // of error is likely fatal)
375                        self.close_socket(None, None);
376                        return map_err(e);
377                    },
378                }
379
380                // Update our state
381                self.state = State::Connecting;
382
383                // Set the stable timeout
384                self.timeout = stream::once(TimeoutFuture::new(self.stable_timeout_millis)).fuse();
385
386                // Announce it if state events are turned on
387                #[cfg(feature = "state-events")]
388                return Poll::Ready(Some(self.state.into()));
389            }
390
391            let next_poll_iter = if self.state == State::Open {
392                // If the socket is established we need to poll each future in turn even if we
393                // return in between If we return Pending before polling each future, we won't get
394                // woken when the unpolled future wakes
395                self.next_poll.into_iter()
396            } else {
397                // If the socket is not established, we want to poll the socket first and if it is
398                // still !Open skip polling the incomming message channel since there is nothing we
399                // can do with any messages we unqueue from there at this point. The socket has
400                // extra waker logic to make sure it wakes up after the socket opens even though it
401                // doesn't produce any values at that point so we'll also get woken up and go back
402                // to normal polling logic
403                NextPoll::Socket.into_iter()
404            };
405
406            // Stash if we have a queued item so we can work out if we need to loop again before
407            // returning Poll::Pending
408            let queued = self.queued_message.is_some();
409
410            for next in next_poll_iter {
411                // Update so if we return Ready we resume with the right future
412                self.next_poll = next.next();
413
414                use NextPoll::*;
415                match next {
416                    Socket => {
417                        // Unwrap ok because we assigned it above if one didn't exist
418                        let mut socket = self.socket.as_mut().unwrap();
419
420                        let poll = Pin::new(&mut socket).poll_next(cx).map(Self::map_socket_output);
421                        match poll {
422                            // Just continue to poll the next thing if this is pending
423                            Poll::Pending => {},
424                            // If it's None (closed) disconnect the socket
425                            Poll::Ready(None) => {
426                                self.close_socket(None, None);
427
428                                cfg_if! {
429                                    if #[cfg(feature = "state-events")] {
430                                        // Announce it if state events are turned on
431                                        return Poll::Ready(Some(self.state.into()));
432                                    } else {
433                                        // If not break the next_poll loop to go back to the top of the retry loop
434                                        break;
435                                    }
436                                }
437                            },
438                            other @ Poll::Ready(Some(_)) => return map_poll(other),
439                        }
440                    },
441
442                    Channel => {
443                        // Get the value directly from socket here because plausibly this could be
444                        // the 2nd poll of the loop and it could have updated in between
445
446                        // Unwrap ok because we assigned it above if one didn't exist
447                        if State::Open != self.socket.as_mut().unwrap().state().into() {
448                            // Don't take anything off the incomming message channel if the socket
449                            // isn't open because messages sent to WebSocket when it's not yet open
450                            // are lost Don't poll the channel because the next time we want to be
451                            // woken is when the socket is established, there's no point being woken
452                            // if the consumer keeps adding data to the channel
453                            trace!("socket not open, skipping channel poll");
454                            continue;
455                        }
456
457                        let message_poll = self
458                            .queued_message
459                            // Take the queued message if there is one
460                            .take()
461                            // Map it into a poll result to match the stream result
462                            .map(|m| {
463                                trace!("attempting to send queued message: {m:?}");
464                                Poll::Ready(Some(Ok(m)))
465                            })
466                            // If there isn't one, poll the stream
467                            .unwrap_or_else(|| {
468                                Pin::new(&mut self.sink_receiver)
469                                    .poll_next(cx)
470                                    .map(Self::map_channel_input)
471                            });
472
473                        if let Poll::Ready(message_result) = message_poll {
474                            if let Some(try_from_result) = message_result {
475                                let message = match try_from_result {
476                                    Err(e) => return map_err(e),
477                                    Ok(payload) => payload,
478                                };
479
480                                // Unwrap ok because we assigned it above if one didn't exist
481                                let mut socket = self.socket.as_mut().unwrap();
482
483                                // Check that the Sink is ready to receive the message before trying
484                                // to send it because otherwise we'd have to clone the Message when
485                                // the send fails See [`Socket::queued_message`] for some more
486                                // context
487                                match Pin::new(&mut socket)
488                                    .poll_ready(cx)
489                                    .map_err(Error::<I, O>::from)
490                                {
491                                    Poll::Pending => {
492                                        // We don't need to register a waker for the channel here
493                                        // because we can't do anything if it wakes us when we
494                                        // already have a queued message. We will next be woken by
495                                        // the socket when it is ready and it's already queued to
496                                        // wake because of the poll_ready
497                                        trace!(
498                                            "socket Sink::poll_ready == Poll::Pending. Queuing \
499                                             message: {message:?}"
500                                        );
501                                        self.queued_message = Some(message);
502                                    },
503                                    Poll::Ready(ready) => {
504                                        trace!("socket Sink::poll_ready == Poll::Ready");
505                                        match ready {
506                                            Err(e) => {
507                                                error!("socket Sink::poll_ready err: {e:?}");
508                                                return map_err(e);
509                                            },
510                                            Ok(()) => match Pin::new(&mut socket)
511                                                .start_send(message)
512                                                .map_err(Error::<I, O>::from)
513                                            {
514                                                Ok(()) => {
515                                                    trace!("socket Sink::start_send Ok");
516                                                    if let Err(e) =
517                                                        ready!(Pin::new(&mut socket).poll_flush(cx))
518                                                            .map_err(Error::<I, O>::from)
519                                                    {
520                                                        error!(
521                                                            "socket Sink::poll_flush err: {e:?}"
522                                                        );
523                                                        return map_err(e);
524                                                    }
525                                                },
526                                                Err(e) => {
527                                                    error!("socket Sink::start_send err: {e:?}");
528                                                    return map_err(e);
529                                                },
530                                            },
531                                        }
532                                    },
533                                }
534                            } else {
535                                info!("Input channel closed. Closing");
536                                self.close(None, None);
537                                return Poll::Ready(None);
538                            }
539                        }
540                    },
541                }
542            }
543
544            // Break out of loop if we have a socket and don't need to reconnect
545            if self.socket.is_some()
546            // and we didn't dispatch a queued message 
547            && !(queued && self.queued_message.is_none())
548            {
549                break;
550            }
551        }
552
553        Poll::Pending
554    }
555}