reconnecting_websocket/socket.rs
1use std::{
2 convert,
3 fmt::{self, Debug},
4 marker::PhantomData,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use cfg_if::cfg_if;
10use exponential_backoff::Backoff;
11use futures::{
12 channel::mpsc::{self, SendError, TrySendError, UnboundedReceiver, UnboundedSender},
13 ready,
14 stream::{self, Fuse, FusedStream},
15 Sink, Stream, StreamExt,
16};
17use gloo::{
18 net::websocket::{futures::WebSocket, Message, WebSocketError},
19 timers::future::TimeoutFuture,
20};
21
22use crate::{
23 constants::DEFAULT_STABLE_CONNECTION_TIMEOUT,
24 debug, error,
25 event::{map_err, map_poll},
26 info, trace, Error, Event, SocketInput, SocketOutput, State, DEFAULT_BACKOFF_MAX,
27 DEFAULT_BACKOFF_MIN, DEFAULT_MAX_RETRIES,
28};
29
/// Identifies which of the two sub streams (the websocket itself or the input message channel)
/// gets polled first on the next pass
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum NextPoll {
    Socket,
    Channel,
}

impl Default for NextPoll {
    /// Polling starts with the socket
    fn default() -> Self {
        NextPoll::Socket
    }
}

impl NextPoll {
    /// The other poll target, i.e. the one to start with once `self` has been polled
    fn next(self) -> NextPoll {
        match self {
            Self::Socket => Self::Channel,
            Self::Channel => Self::Socket,
        }
    }
}

impl IntoIterator for NextPoll {
    type IntoIter = NextPollIter;
    type Item = NextPoll;

    /// Yields both poll targets exactly once, beginning with `self`
    fn into_iter(self) -> Self::IntoIter {
        NextPollIter { i: 0, items: [self, self.next()] }
    }
}

/// An iterator over every poll target, in the order they have to be polled
pub(crate) struct NextPollIter {
    /// Index of the next entry of `items` to hand out
    i: usize,
    /// The poll targets in the order they are yielded
    items: [NextPoll; 2],
}

impl Iterator for NextPollIter {
    type Item = NextPoll;

    fn next(&mut self) -> Option<Self::Item> {
        // `get` returns None once the cursor has moved past the last entry, and the `?`
        // leaves the cursor untouched in that case
        let item = self.items.get(self.i).copied()?;
        self.i += 1;
        Some(item)
    }
}
85
86/// A handle that implements [`Sink`] for sending messages from the client to the server
87///
88/// Cheap and safe to clone (internally it's a channel sender)
89#[derive(Debug, Clone)]
90pub struct SocketSink<I> {
91 sender: UnboundedSender<I>,
92}
93
94impl<I> From<UnboundedSender<I>> for SocketSink<I> {
95 fn from(sender: UnboundedSender<I>) -> Self {
96 Self { sender }
97 }
98}
99
100impl<I> Sink<I> for SocketSink<I>
101where
102 I: SocketInput,
103 Message: TryFrom<I>,
104 <Message as TryFrom<I>>::Error: Debug,
105{
106 type Error = SendError;
107
108 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109 UnboundedSender::poll_ready(&self.sender, cx)
110 }
111
112 fn start_send(self: Pin<&mut Self>, msg: I) -> Result<(), Self::Error> {
113 self.sender.unbounded_send(msg).map_err(TrySendError::into_send_error)
114 }
115
116 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117 Poll::Ready(Ok(()))
118 }
119
120 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121 self.sender.close_channel();
122 Poll::Ready(Ok(()))
123 }
124}
125
/// A wrapper around [`WebSocket`] that reconnects when the socket
/// drops. Uses [`Backoff`] to determine the delay between reconnects
///
/// See the [`crate`] documentation for usage and examples
///
/// Errors returned by the [`Stream`] aren't necessarily fatal. Check [`Error`] for more detail.
/// `Poll::Ready(None)` is the main fatal case that requires a new instance of [`Socket`]
pub struct Socket<I, O> {
    /// The server URL to connect to on reconnect
    pub(crate) url: String,
    /// The sending end of the input message channel
    /// Retained to implement [`Self::get_sink`] and [`Self::send`]
    pub(crate) sink_sender: UnboundedSender<I>,
    /// The receiving side of the input message channel
    /// Polled by the [`Stream`] implementation
    pub(crate) sink_receiver: UnboundedReceiver<I>,
    /// The inner socket, None when a reconnect is pending
    pub(crate) socket: Option<WebSocket>,
    /// A queued message that needs to be sent as soon as the socket is [`State::Open`].
    ///
    /// This happens when the inner socket exists but hasn't yet fully connected. When in this
    /// state the [`WebSocket`] [`Sink`] implementation returns [`Poll::Pending`], but we can't
    /// know that with any certainty until we've already created the [`Message`] from the input
    /// channel and called [`Sink::poll_ready`]. Calling [`Sink::poll_ready`] before creating
    /// the [`Message`] isn't really an option because we have no way of undoing anything the
    /// Sink does to prepare a slot for us to send to - in this specific case, [`WebSocket`]
    /// doesn't actually do anything that needs to be reversed but we can't rely on that always
    /// being the case. See <https://github.com/rust-lang/futures-rs/issues/2109> for a
    /// discussion about this problem. So what we do is take the [`Message`] but don't try and
    /// send it directly, instead calling [`Sink::poll_ready`] and only sending it if this
    /// returns [`Poll::Ready`]. Otherwise it is parked here until the next poll
    pub(crate) queued_message: Option<Message>,
    /// Our copy of the connection state. Refreshed from the inner socket whenever the
    /// [`Stream`] is polled and set directly when we open or close the inner socket
    pub(crate) state: State,
    /// Supplies the delay to wait before the next reconnect, looked up with the current `retry`
    pub(crate) backoff: Backoff,
    /// Once `retry` is greater than this value the stream closes itself instead of reconnecting
    pub(crate) max_retries: u32,
    /// Connection attempt counter. Incremented before every `WebSocket::open` made by the
    /// [`Stream`] implementation and reset to 0 once the stable connection timeout fires
    pub(crate) retry: u32,
    /// When socket.is_none this is a reconnect timeout
    /// When socket.is_some this is a connection stable after retry timeout
    pub(crate) timeout: Fuse<stream::Once<TimeoutFuture>>,
    /// Which sub stream the [`Stream`] implementation polls first on its next pass
    pub(crate) next_poll: NextPoll,
    /// `true` once [`Self::close`] has been called. From then on the stream only yields `None`
    pub(crate) closed: bool,
    /// How long to wait after reconnecting before resetting retries to 0
    pub(crate) stable_timeout_millis: u32,
    /// Carries the `I` and `O` type parameters
    pub(crate) _phantom: PhantomData<(I, O)>,
}
170
171impl<I, O> Default for Socket<I, O>
172where
173 I: SocketInput,
174 O: SocketOutput,
175 Message: TryFrom<I>,
176 <Message as TryFrom<I>>::Error: Debug,
177 <O as TryFrom<Message>>::Error: Debug,
178{
179 fn default() -> Self {
180 let (sender, receiver) = mpsc::unbounded();
181 Self {
182 url: String::new(),
183 sink_sender: sender,
184 sink_receiver: receiver,
185 socket: None,
186 queued_message: None,
187 state: State::Connecting,
188 backoff: Backoff::new(DEFAULT_MAX_RETRIES, DEFAULT_BACKOFF_MIN, DEFAULT_BACKOFF_MAX),
189 max_retries: DEFAULT_MAX_RETRIES,
190 retry: 0,
191 timeout: stream::once(TimeoutFuture::new(0)).fuse(),
192 next_poll: NextPoll::Socket,
193 closed: false,
194 stable_timeout_millis: DEFAULT_STABLE_CONNECTION_TIMEOUT.as_millis() as u32,
195 _phantom: PhantomData,
196 }
197 }
198}
199
200impl<I, O> fmt::Debug for Socket<I, O> {
201 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202 f.debug_struct("Socket")
203 .field("url", &self.url)
204 .field("sink_sender", &self.sink_sender)
205 .field("sink_receiver", &self.sink_receiver)
206 .field("socket.is_some", &self.socket.is_some())
207 .field("state", &self.state)
208 .field("backoff", &self.backoff)
209 .field("max_retries", &self.max_retries)
210 .field("retry", &self.retry)
211 .field("timeout", &self.timeout)
212 .field("next_poll", &self.next_poll)
213 .field("closed", &self.closed)
214 .finish()
215 }
216}
217
218impl<I, O> Socket<I, O>
219where
220 I: SocketInput,
221 O: SocketOutput,
222 Message: TryFrom<I>,
223 <Message as TryFrom<I>>::Error: Debug,
224 <O as TryFrom<Message>>::Error: Debug,
225{
226 /// Send the given `message` for sending
227 ///
228 /// Internally it is added to a channel which is polled by the [`Stream`] implementation
229 /// when the underlying [`WebSocket`] is open and ready to transmit it
230 pub async fn send(&mut self, message: I) -> Result<(), TrySendError<I>> {
231 self.sink_sender.unbounded_send(message)
232 }
233
234 /// Get a sink handle for sending messages from the client to the server
235 pub fn get_sink(&self) -> SocketSink<I> {
236 self.sink_sender.clone().into()
237 }
238
239 /// Close the inner socket with the given `code` and `reason`
240 ///
241 /// The socket will try and reconnect after a timeout if there are sufficient retries remaining
242 ///
243 /// This is mainly an implementation detail but it's exposed so it can be used in test code
244 /// to force a reconnect. If used in this way it's worth noting that the Closing/Closed state
245 /// events won't be emitted
246 pub fn close_socket(&mut self, code: Option<u16>, reason: Option<&str>) {
247 // Take and drop the socket
248 if let Some(socket) = self.socket.take() {
249 // Attempt to send the close but don't fail if it can't be sent (the socket could be
250 // dead already)
251 let _ = socket.close(code, reason);
252 }
253
254 // Update our state
255 self.state = State::Closed;
256
257 if let Some(timeout) = self.backoff.next(self.retry) {
258 debug!("Backoff retry: {}, timeout: {:.3}s", self.retry, timeout.as_secs_f32());
259 let millis = timeout.as_millis() as u32;
260 self.timeout = stream::once(TimeoutFuture::new(millis)).fuse();
261 } else {
262 // If we have exceeded our retries the next poll of the stream will close it and error
263 // no need to have a timeout in that case
264 self.timeout = Self::default().timeout;
265 }
266 }
267
268 /// Permanently close the reconnecting socket. No further reconnects will be possible
269 ///
270 /// The socket implements [`FusedStream`] so polling it after close won't panic
271 pub fn close(&mut self, code: Option<u16>, reason: Option<&str>) {
272 self.closed = true;
273 let _ = self.close_socket(code, reason);
274 }
275
276 fn map_socket_output(
277 output: Option<Result<Message, WebSocketError>>,
278 ) -> Option<Result<O, Error<I, O>>> {
279 output.map(|result| {
280 result
281 // Map the gloo socket error
282 .map_err(Error::from)
283 // Convert the return value into the consumers type
284 .map(|message| {
285 debug!("Got output message: {message:?}");
286 O::try_from(message)
287 // Map the consumers try_from error into our error so we can
288 // flatten the result
289 .map_err(Error::<I, O>::from_output)
290 })
291 // Equivalent to .flatten unstable feature
292 .and_then(convert::identity)
293 })
294 }
295
296 fn map_channel_input(input: Option<I>) -> Option<Result<Message, Error<I, O>>> {
297 input.map(|input| {
298 debug!("Got input message: {input:?}");
299 Message::try_from(input)
300 // Map the consumers try_from error into our error
301 .map_err(Error::<I, O>::from_input)
302 })
303 }
304}
305
306impl<I, O> FusedStream for Socket<I, O>
307where
308 I: SocketInput,
309 O: SocketOutput,
310 Message: TryFrom<I>,
311 <Message as TryFrom<I>>::Error: Debug,
312 <O as TryFrom<Message>>::Error: Debug,
313{
314 fn is_terminated(&self) -> bool {
315 self.closed
316 }
317}
318
319impl<I, O> Stream for Socket<I, O>
320where
321 I: SocketInput,
322 O: SocketOutput,
323 Message: TryFrom<I>,
324 <Message as TryFrom<I>>::Error: Debug,
325 <O as TryFrom<Message>>::Error: Debug,
326{
327 type Item = Event<I, O>;
328
329 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
330 if self.closed {
331 trace!("polled when closed");
332 return Poll::Ready(None);
333 }
334
335 // Reconnect & queue loop
336 // Loops in two cases
337 // 1. When we disconnected and need to reconnect: socket is none && !self.closed
338 // 2. When we sent a queued message and need to re-poll the channel: queued == true &&
339 // self.queued_message.is_none()
340 while !self.closed {
341 // Check we have a socket first
342 if let Some(socket) = self.socket.as_ref() {
343 // Update our copy of the state and notify if it's changed
344 let current_state = socket.state().into();
345 if self.state != current_state {
346 self.state = current_state;
347
348 #[cfg(feature = "state-events")]
349 return Poll::Ready(Some(self.state.into()));
350 }
351
352 // Check if the connection has become stable
353 if self.retry > 0 && Pin::new(&mut self.timeout).poll_next(cx).is_ready() {
354 trace!("connection is stable. Resetting retries ({} -> 0)", self.retry);
355 self.retry = 0;
356 }
357 } else {
358 trace!("socket is none");
359 ready!(Pin::new(&mut self.timeout).poll_next(cx));
360
361 if self.retry > self.max_retries {
362 error!("retries exceeded. Closing");
363 self.close(None, None);
364 return Poll::Ready(None);
365 }
366
367 info!("Reconnecting socket...");
368 self.retry += 1;
369 match WebSocket::open(&self.url).map_err(Error::<I, O>::from) {
370 Ok(v) => self.socket = Some(v),
371 Err(e) => {
372 error!("WebSocket::open err: {e:?}");
373 // Reset the connection and set the next retry timeout (although this kind
374 // of error is likely fatal)
375 self.close_socket(None, None);
376 return map_err(e);
377 },
378 }
379
380 // Update our state
381 self.state = State::Connecting;
382
383 // Set the stable timeout
384 self.timeout = stream::once(TimeoutFuture::new(self.stable_timeout_millis)).fuse();
385
386 // Announce it if state events are turned on
387 #[cfg(feature = "state-events")]
388 return Poll::Ready(Some(self.state.into()));
389 }
390
391 let next_poll_iter = if self.state == State::Open {
392 // If the socket is established we need to poll each future in turn even if we
393 // return in between If we return Pending before polling each future, we won't get
394 // woken when the unpolled future wakes
395 self.next_poll.into_iter()
396 } else {
397 // If the socket is not established, we want to poll the socket first and if it is
398 // still !Open skip polling the incomming message channel since there is nothing we
399 // can do with any messages we unqueue from there at this point. The socket has
400 // extra waker logic to make sure it wakes up after the socket opens even though it
401 // doesn't produce any values at that point so we'll also get woken up and go back
402 // to normal polling logic
403 NextPoll::Socket.into_iter()
404 };
405
406 // Stash if we have a queued item so we can work out if we need to loop again before
407 // returning Poll::Pending
408 let queued = self.queued_message.is_some();
409
410 for next in next_poll_iter {
411 // Update so if we return Ready we resume with the right future
412 self.next_poll = next.next();
413
414 use NextPoll::*;
415 match next {
416 Socket => {
417 // Unwrap ok because we assigned it above if one didn't exist
418 let mut socket = self.socket.as_mut().unwrap();
419
420 let poll = Pin::new(&mut socket).poll_next(cx).map(Self::map_socket_output);
421 match poll {
422 // Just continue to poll the next thing if this is pending
423 Poll::Pending => {},
424 // If it's None (closed) disconnect the socket
425 Poll::Ready(None) => {
426 self.close_socket(None, None);
427
428 cfg_if! {
429 if #[cfg(feature = "state-events")] {
430 // Announce it if state events are turned on
431 return Poll::Ready(Some(self.state.into()));
432 } else {
433 // If not break the next_poll loop to go back to the top of the retry loop
434 break;
435 }
436 }
437 },
438 other @ Poll::Ready(Some(_)) => return map_poll(other),
439 }
440 },
441
442 Channel => {
443 // Get the value directly from socket here because plausibly this could be
444 // the 2nd poll of the loop and it could have updated in between
445
446 // Unwrap ok because we assigned it above if one didn't exist
447 if State::Open != self.socket.as_mut().unwrap().state().into() {
448 // Don't take anything off the incomming message channel if the socket
449 // isn't open because messages sent to WebSocket when it's not yet open
450 // are lost Don't poll the channel because the next time we want to be
451 // woken is when the socket is established, there's no point being woken
452 // if the consumer keeps adding data to the channel
453 trace!("socket not open, skipping channel poll");
454 continue;
455 }
456
457 let message_poll = self
458 .queued_message
459 // Take the queued message if there is one
460 .take()
461 // Map it into a poll result to match the stream result
462 .map(|m| {
463 trace!("attempting to send queued message: {m:?}");
464 Poll::Ready(Some(Ok(m)))
465 })
466 // If there isn't one, poll the stream
467 .unwrap_or_else(|| {
468 Pin::new(&mut self.sink_receiver)
469 .poll_next(cx)
470 .map(Self::map_channel_input)
471 });
472
473 if let Poll::Ready(message_result) = message_poll {
474 if let Some(try_from_result) = message_result {
475 let message = match try_from_result {
476 Err(e) => return map_err(e),
477 Ok(payload) => payload,
478 };
479
480 // Unwrap ok because we assigned it above if one didn't exist
481 let mut socket = self.socket.as_mut().unwrap();
482
483 // Check that the Sink is ready to receive the message before trying
484 // to send it because otherwise we'd have to clone the Message when
485 // the send fails See [`Socket::queued_message`] for some more
486 // context
487 match Pin::new(&mut socket)
488 .poll_ready(cx)
489 .map_err(Error::<I, O>::from)
490 {
491 Poll::Pending => {
492 // We don't need to register a waker for the channel here
493 // because we can't do anything if it wakes us when we
494 // already have a queued message. We will next be woken by
495 // the socket when it is ready and it's already queued to
496 // wake because of the poll_ready
497 trace!(
498 "socket Sink::poll_ready == Poll::Pending. Queuing \
499 message: {message:?}"
500 );
501 self.queued_message = Some(message);
502 },
503 Poll::Ready(ready) => {
504 trace!("socket Sink::poll_ready == Poll::Ready");
505 match ready {
506 Err(e) => {
507 error!("socket Sink::poll_ready err: {e:?}");
508 return map_err(e);
509 },
510 Ok(()) => match Pin::new(&mut socket)
511 .start_send(message)
512 .map_err(Error::<I, O>::from)
513 {
514 Ok(()) => {
515 trace!("socket Sink::start_send Ok");
516 if let Err(e) =
517 ready!(Pin::new(&mut socket).poll_flush(cx))
518 .map_err(Error::<I, O>::from)
519 {
520 error!(
521 "socket Sink::poll_flush err: {e:?}"
522 );
523 return map_err(e);
524 }
525 },
526 Err(e) => {
527 error!("socket Sink::start_send err: {e:?}");
528 return map_err(e);
529 },
530 },
531 }
532 },
533 }
534 } else {
535 info!("Input channel closed. Closing");
536 self.close(None, None);
537 return Poll::Ready(None);
538 }
539 }
540 },
541 }
542 }
543
544 // Break out of loop if we have a socket and don't need to reconnect
545 if self.socket.is_some()
546 // and we didn't dispatch a queued message
547 && !(queued && self.queued_message.is_none())
548 {
549 break;
550 }
551 }
552
553 Poll::Pending
554 }
555}