generic_api_client/
websocket.rs

1use std::{
2    sync::{Arc, atomic::{AtomicBool, Ordering}},
3    collections::hash_map::{HashMap, Entry},
4    time::Duration,
5    mem,
6};
7use tokio::{
8    sync::{mpsc as tokio_mpsc, Mutex as AsyncMutex, Notify},
9    task::JoinHandle,
10    net::TcpStream,
11    time::{MissedTickBehavior, timeout},
12};
13use tokio_tungstenite::{
14    tungstenite,
15    MaybeTlsStream,
16};
17pub use tungstenite::Error as TungsteniteError;
18use futures_util::{
19    sink::SinkExt,
20    stream::{StreamExt, SplitSink},
21};
22use parking_lot::Mutex as SyncMutex;
23
/// The stream type returned by `tokio_tungstenite::connect_async` in `start_connection`:
/// a WebSocket stream running over `MaybeTlsStream<TcpStream>`.
type WebSocketStream = tokio_tungstenite::WebSocketStream<MaybeTlsStream<TcpStream>>;
/// The sending half obtained by splitting a [WebSocketStream]; it accepts [tungstenite::Message]s.
type WebSocketSplitSink = SplitSink<WebSocketStream, tungstenite::Message>;
26
/// A `struct` that holds a websocket connection.
///
/// Dropping this `struct` terminates the connection.
///
/// # Reconnecting
/// `WebSocketConnection` automatically reconnects when a [TungsteniteError] occurs while receiving a message,
/// when the server closes the connection, or when no message arrives within the configured message timeout.
/// Note that during reconnection, it is **possible** that the [WebSocketHandler] receives multiple identical messages
/// even though the message was sent only once by the server, or receives only one message even though
/// multiple identical messages were sent by the server, because there could be a time difference between the new connection and
/// the old connection.
///
/// You can use the [reconnect_state()][Self::reconnect_state()] method to check if the connection is under
/// a reconnection, or manually request a reconnection.
#[derive(Debug)]
#[must_use = "dropping WebSocketConnection closes the connection"]
pub struct WebSocketConnection<H: WebSocketHandler> {
    /// Handle of the background task running the reconnection loop; aborted when `self` is dropped.
    task_reconnect: JoinHandle<()>,
    /// Sending half of the connection currently in use. It is shared with the background tasks,
    /// and its content is swapped for the new sink on every successful reconnection.
    sink: Arc<AsyncMutex<WebSocketSplitSink>>,
    /// State shared with the background tasks (url, handler, message channel, connection id).
    inner: Arc<ConnectionInner<H>>,
    /// Used to query and request reconnections; clones are handed out by `reconnect_state()`.
    reconnect_state: ReconnectState,
}
48
// There are two ways a connection ends:
// - The user drops the WebSocketConnection
//     1. feed_handler receives a DropConnectionRequest, closes the connection, then terminates
//     2. the task spawned by start_connection notices that the connection is closed, attempts to notify
//        feed_handler, then terminates
// - Reconnection
//     This happens when:
//     - the user requests so
//     - message timeout
//     - the server closes the connection
//     - some kind of error occurs while receiving the message
//
//     1. task_reconnect starts a new connection
//     2. task_reconnect closes the old connection
//     3. the task spawned by start_connection (old) notices that the connection is closed, and notifies
//        feed_handler, then terminates
//     4. feed_handler receives the notification, but ignores it because it is from the old connection
#[derive(Debug)]
struct ConnectionInner<H: WebSocketHandler> {
    /// The url to connect to, with the configured prefix already prepended.
    url: String,
    /// The user-provided handler, locked whenever one of its callbacks is invoked.
    handler: Arc<SyncMutex<H>>,
    /// Sender of the channel read by feed_handler. Every item is tagged with the id of the
    /// connection it originates from.
    message_tx: tokio_mpsc::UnboundedSender<(bool, FeederMessage)>,
    /// The id that the next connection will receive. It is flipped each time a connection has been
    /// started successfully, so the id of the connection currently in use is the negation of this value.
    next_connection_id: AtomicBool,
}
71
/// An event delivered to the task running `feed_handler`.
enum FeederMessage {
    /// An item yielded by the WebSocket stream: either a received message or a receive error.
    Message(tungstenite::Result<tungstenite::Message>),
    /// The WebSocket stream of a connection has ended.
    ConnectionClosed,
    /// The `WebSocketConnection` was dropped, so the connection should be closed for good.
    DropConnectionRequest,
}
77
78impl<H: WebSocketHandler> WebSocketConnection<H> {
79    /// Starts a new `WebSocketConnection` to the given url using the given [handler][WebSocketHandler].
80    pub async fn new(url: &str, handler: H) -> Result<Self, TungsteniteError> {
81        let config = handler.websocket_config();
82        let handler = Arc::new(SyncMutex::new(handler));
83        let url = config.url_prefix.clone() + url;
84
85        let (message_tx, message_rx) = tokio_mpsc::unbounded_channel();
86        let reconnect_manager = ReconnectState::new();
87
88        let connection = Arc::new(ConnectionInner {
89            url,
90            handler: Arc::clone(&handler),
91            message_tx,
92            next_connection_id: AtomicBool::new(false),
93        });
94
95        async fn feed_handler(
96            connection: Arc<ConnectionInner<impl WebSocketHandler>>,
97            mut message_rx: tokio_mpsc::UnboundedReceiver<(bool, FeederMessage)>,
98            reconnect_manager: ReconnectState,
99            config: WebSocketConfig,
100            sink: Arc<AsyncMutex<WebSocketSplitSink>>,
101        ) {
102            let mut messages: HashMap<WebSocketMessage, isize> = HashMap::new();
103
104            let timeout_duration = if config.message_timeout.is_zero() {
105                Duration::MAX
106            } else {
107                config.message_timeout
108            };
109
110            loop {
111                match timeout(timeout_duration, message_rx.recv()).await {
112                    // message successfully received
113                    Ok(Some((id, FeederMessage::Message(Ok(message))))) => {
114                        // message successfully received
115                        if let Some(message) = WebSocketMessage::from_message(message) {
116                            if reconnect_manager.is_reconnecting() {
117                                // reconnecting
118                                let id_sign: isize = if id {
119                                    1
120                                } else {
121                                    -1
122                                };
123                                let entry = messages.entry(message.clone());
124                                match entry {
125                                    Entry::Occupied(mut occupied) => {
126                                        if config.ignore_duplicate_during_reconnection {
127                                            log::debug!("Skipping duplicate message.");
128                                            continue;
129                                        }
130
131                                        *occupied.get_mut() += id_sign;
132                                        if id_sign != occupied.get().signum() {
133                                            // same message which comes from different connections, so we assume it's a duplicate.
134                                            log::debug!("Skipping duplicate message.");
135                                            continue;
136                                        }
137                                        // comes from the same connection, which means the message was sent twice.
138                                    },
139                                    Entry::Vacant(vacant) => {
140                                        // new message
141                                        vacant.insert(id_sign);
142                                    }
143                                }
144                            } else {
145                                messages.clear();
146                            }
147                            let messages = connection.handler.lock().handle_message(message);
148                            let mut sink_lock = sink.lock().await;
149                            for message in messages {
150                                if let Err(error) = sink_lock.send(message.into_message()).await {
151                                    log::error!("Failed to send message because of an error: {}", error);
152                                };
153                            }
154                            if let Err(error) = sink_lock.flush().await {
155                                log::error!("An error occurred while flushing WebSocket sink: {error:?}");
156                            }
157                        }
158                    },
159                    // failed to receive message
160                    Ok(Some((_, FeederMessage::Message(Err(error))))) => {
161                        log::error!("Failed to receive message because of an error: {error:?}");
162                        if reconnect_manager.request_reconnect() {
163                            log::info!("Reconnecting WebSocket because there was an error while receiving a message");
164                        }
165                    },
166                    // timeout
167                    Err(_) => {
168                        log::debug!("WebSocket message timeout");
169                        if reconnect_manager.request_reconnect() {
170                            log::info!("Reconnecting WebSocket because of timeout");
171                        }
172                    },
173                    // connection was closed
174                    Ok(Some((id, FeederMessage::ConnectionClosed))) => {
175                        let current_id = !connection.next_connection_id.load(Ordering::SeqCst);
176                        if id != current_id {
177                            // old connection, ignore
178                            continue;
179                        }
180                        log::debug!("WebSocket connection closed by server");
181                        if reconnect_manager.request_reconnect() {
182                            log::info!("Reconnecting WebSocket because it was disconnected by the server");
183                        }
184                    },
185                    // the connection is no longer needed because WebSocketConnection was dropped
186                    Ok(Some((_, FeederMessage::DropConnectionRequest))) => {
187                        if let Err(error) = sink.lock().await.close().await {
188                            log::debug!("Failed to close WebSocket connection: {error:?}");
189                        }
190                        break;
191                    }
192                    // message_tx has been dropped, which should never happen because it's always accessible by connection.message_tx.
193                    Ok(None) => unreachable!("message_rx should never be closed"),
194                }
195            }
196            connection.handler.lock().handle_close(false);
197        }
198
199        async fn reconnect<H: WebSocketHandler>(
200            interval: Duration,
201            cooldown: Duration,
202            connection: Arc<ConnectionInner<H>>,
203            sink: Arc<AsyncMutex<WebSocketSplitSink>>,
204            reconnect_manager: ReconnectState,
205            no_duplicate: bool,
206            wait: Duration,
207        ) {
208            let mut cooldown = tokio::time::interval(cooldown);
209            cooldown.set_missed_tick_behavior(MissedTickBehavior::Delay);
210            loop {
211                let timer = if interval.is_zero() {
212                    // never completes
213                    tokio::time::sleep(Duration::MAX)
214                } else {
215                    tokio::time::sleep(interval)
216                };
217                tokio::select! {
218                    _ = reconnect_manager.inner.reconnect_notify.notified() => {},
219                    _ = timer => {},
220                }
221                log::debug!("Reconnection requested");
222                cooldown.tick().await;
223                reconnect_manager.inner.reconnecting.store(true, Ordering::SeqCst);
224
225                // reconnect_notify might have been notified while waiting the cooldown,
226                // so we consume any existing permits on reconnect_notify
227                reconnect_manager.inner.reconnect_notify.notify_one();
228                // this completes immediately because we just added a permit
229                reconnect_manager.inner.reconnect_notify.notified().await;
230
231                log::debug!("Starting reconnection process ...");
232                if no_duplicate {
233                    tokio::time::sleep(wait).await;
234                }
235
236                // start a new connection
237                match WebSocketConnection::<H>::start_connection(Arc::clone(&connection)).await {
238                    Ok(new_sink) => {
239                        // replace the sink with the new one
240                        let mut old_sink = mem::replace(&mut *sink.lock().await, new_sink);
241                        log::debug!("New connection established");
242
243                        if no_duplicate {
244                            tokio::time::sleep(wait).await;
245                        }
246
247                        if let Err(error) = old_sink.close().await {
248                            log::debug!("An error occurred while closing old connection: {}", error);
249                        }
250                        connection.handler.lock().handle_close(true);
251                        log::debug!("Old connection closed");
252                    },
253                    Err(error) => {
254                        // try reconnecting again
255                        log::error!("Failed to reconnect because of an error: {}, trying again ...", error);
256                        reconnect_manager.inner.reconnect_notify.notify_one();
257                    },
258                }
259
260                if no_duplicate {
261                    tokio::time::sleep(wait).await;
262                }
263
264                reconnect_manager.inner.reconnecting.store(false, Ordering::SeqCst);
265                log::debug!("Reconnection process complete");
266            }
267        }
268
269        let sink_inner = Self::start_connection(Arc::clone(&connection)).await?;
270        let sink = Arc::new(AsyncMutex::new(sink_inner));
271
272        tokio::spawn(
273            feed_handler(
274                Arc::clone(&connection),
275                message_rx,
276                reconnect_manager.clone(),
277                config.clone(),
278                Arc::clone(&sink),
279            )
280        );
281
282        let task_reconnect = tokio::spawn(reconnect(
283            config.refresh_after,
284            config.connect_cooldown,
285            Arc::clone(&connection),
286            Arc::clone(&sink),
287            reconnect_manager.clone(),
288            config.ignore_duplicate_during_reconnection,
289            config.reconnection_wait,
290        ));
291
292        Ok(Self {
293            task_reconnect,
294            sink,
295            inner: connection,
296            reconnect_state: reconnect_manager,
297        })
298    }
299
300    async fn start_connection(connection: Arc<ConnectionInner<impl WebSocketHandler>>) -> Result<WebSocketSplitSink, TungsteniteError> {
301        let (websocket_stream, _) = tokio_tungstenite::connect_async(connection.url.clone()).await?;
302        let (mut sink, mut stream) = websocket_stream.split();
303
304        let messages = connection.handler.lock().handle_start();
305        for message in messages {
306            sink.send(message.into_message()).await?;
307        }
308        sink.flush().await?;
309
310        // fetch_not is unstable so we use fetch_xor
311        let id = connection.next_connection_id.fetch_xor(true, Ordering::SeqCst);
312
313        // pass messages to task_feed_handler
314        tokio::spawn(async move {
315            while let Some(message) = stream.next().await {
316                // send the received message to the task running feed_handler
317                if connection.message_tx.send((id, FeederMessage::Message(message))).is_err() {
318                    // the channel is closed. we can't disconnect because we don't have the sink
319                    log::debug!("WebSocket message receiver is closed; abandon connection");
320                    return;
321                }
322            }
323            // the underlying WebSocket connection was closed
324
325            drop(connection.message_tx.send((id, FeederMessage::ConnectionClosed))); // this may be Err
326            log::debug!("WebSocket stream closed");
327        });
328        Ok(sink)
329    }
330
331    /// Sends a message to the connection.
332    pub async fn send_message(&self, message: WebSocketMessage) -> Result<(), TungsteniteError> {
333        let mut sink_lock = self.sink.lock().await;
334        sink_lock.send(message.into_message()).await?;
335        sink_lock.flush().await
336    }
337
338    /// Returns a [ReconnectState] for this connection.
339    ///
340    /// See [ReconnectState] for more information.
341    pub fn reconnect_state(&self) -> ReconnectState {
342        self.reconnect_state.clone()
343    }
344}
345
346impl<H: WebSocketHandler> Drop for WebSocketConnection<H> {
347    fn drop(&mut self) {
348        self.task_reconnect.abort();
349        // sending None tells the feeder to close
350        let current_id = !self.inner.next_connection_id.load(Ordering::SeqCst);
351        self.inner.message_tx.send((current_id, FeederMessage::DropConnectionRequest)).ok();
352    }
353}
354
/// A `struct` to request the [WebSocketConnection] to perform a reconnect,
/// and to check whether a reconnection is in progress.
///
/// This `struct` uses an [Arc] internally, so you can obtain multiple
/// `ReconnectState`s for a single [WebSocketConnection] by [cloning][Clone].
#[derive(Debug, Clone)]
pub struct ReconnectState {
    /// State shared between all clones and the background tasks of the connection.
    inner: Arc<ReconnectMangerInner>,
}
363
/// The state shared by every clone of a [ReconnectState].
#[derive(Debug)]
struct ReconnectMangerInner {
    /// Notified to request a reconnection; the reconnection task waits on it.
    reconnect_notify: Notify,
    /// `true` while the reconnection task is performing a reconnection.
    reconnecting: AtomicBool,
}
369
370impl ReconnectState {
371    fn new() -> Self {
372        Self {
373            inner: Arc::new(ReconnectMangerInner {
374                reconnect_notify: Notify::new(),
375                reconnecting: AtomicBool::new(false),
376            })
377        }
378    }
379
380    /// Returns `true` iff the [WebSocketConnection] is undergoing a reconnection process.
381    pub fn is_reconnecting(&self) -> bool {
382        self.inner.reconnecting.load(Ordering::SeqCst)
383    }
384
385    /// Request the [WebSocketConnection] to perform a reconnect.
386    ///
387    /// Will return `false` if it is already in a reconnection process.
388    pub fn request_reconnect(&self) -> bool {
389        if self.is_reconnecting() {
390            false
391        } else {
392            self.inner.reconnect_notify.notify_one();
393            true
394        }
395    }
396}
397
/// An enum that represents a websocket message.
///
/// Only text, binary, ping and pong messages are represented. Close messages and raw frames
/// have no counterpart here and are never passed to the [WebSocketHandler].
///
/// See also [tungstenite::Message].
#[derive(Debug, Eq, PartialEq, Clone, Hash)]
pub enum WebSocketMessage {
    /// A text message
    Text(String),
    /// A binary message
    Binary(Vec<u8>),
    /// A ping message, with its payload
    Ping(Vec<u8>),
    /// A pong message, with its payload
    Pong(Vec<u8>),
}
412
413impl WebSocketMessage {
414    fn from_message(message: tungstenite::Message) -> Option<Self> {
415        match message {
416            tungstenite::Message::Text(text) => Some(Self::Text(text)),
417            tungstenite::Message::Binary(data) => Some(Self::Binary(data)),
418            tungstenite::Message::Ping(data) => Some(Self::Ping(data)),
419            tungstenite::Message::Pong(data) => Some(Self::Pong(data)),
420            tungstenite::Message::Close(_) | tungstenite::Message::Frame(_) => None,
421        }
422    }
423
424    fn into_message(self) -> tungstenite::Message {
425        match self {
426            WebSocketMessage::Text(text) => tungstenite::Message::Text(text),
427            WebSocketMessage::Binary(data) => tungstenite::Message::Binary(data),
428            WebSocketMessage::Ping(data) => tungstenite::Message::Ping(data),
429            WebSocketMessage::Pong(data) => tungstenite::Message::Pong(data),
430        }
431    }
432}
433
434/// A `trait` which is used to handle events on the [WebSocketConnection].
435///
436/// The `struct` implementing this `trait` is required to be [Send] and `'static` because
437/// it will be sent between threads.
438pub trait WebSocketHandler: Send + 'static {
439    /// Returns a [WebSocketConfig] that will be applied for all WebSocket connections handled by this handler.
440    fn websocket_config(&self) -> WebSocketConfig {
441        WebSocketConfig::default()
442    }
443
444    /// Called when a new connection has been started, and returns messages that should be sent to the server.
445    ///
446    /// This could be called multiple times because the connection can be reconnected.
447    fn handle_start(&mut self) -> Vec<WebSocketMessage> {
448        log::debug!("WebSocket connection started");
449        vec![]
450    }
451
452    /// Called when the [WebSocketConnection] received a message, returns messages to be sent to the server.
453    fn handle_message(&mut self, message: WebSocketMessage) -> Vec<WebSocketMessage>;
454
455    /// Called when a websocket connection is closed.
456    ///
457    /// If the parameter `reconnect` is:
458    /// - `true`, it means that the connection is being reconnected for some reason.
459    /// - `false`, it means that the connection will not be reconnected, because the [WebSocketConnection] was dropped.
460    #[allow(unused_variables)]
461    fn handle_close(&mut self, reconnect: bool) {
462        log::debug!("WebSocket connection closed; reconnect: {}", reconnect);
463    }
464}
465
/// Configuration for [WebSocketHandler].
///
/// Should be returned by [WebSocketHandler::websocket_config()].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct WebSocketConfig {
    /// Duration that should elapse between each attempt to start a new connection.
    ///
    /// This matters because the [WebSocketConnection] reconnects on error. If the error
    /// continues to happen, it could spam the server if `connect_cooldown` is too short. [Default]s to 3000ms.
    pub connect_cooldown: Duration,
    /// The [WebSocketConnection] will automatically reconnect when `refresh_after` has elapsed since
    /// the last connection started. If you don't want this feature, set it to [Duration::ZERO]. [Default]s to [Duration::ZERO].
    pub refresh_after: Duration,
    /// Prefix which will be used for connections that started using this `WebSocketConfig`. [Default]s to `""`.
    ///
    /// Example usage: `"wss://example.com"`
    pub url_prefix: String,
    /// During reconnection, [WebSocketHandler] might receive two identical messages
    /// even though the server sent only one message. By setting this to `true`, [WebSocketConnection]
    /// will not send duplicate messages to the [WebSocketHandler]. You should set this option to `true`
    /// when messages contain some sort of ID and are distinguishable. [Default]s to `false`.
    ///
    /// Note that [WebSocketConnection] will **not** check for duplicate messages when it is not under reconnection,
    /// even if this option is set to `true`.
    pub ignore_duplicate_during_reconnection: bool,
    /// When `ignore_duplicate_during_reconnection` is set to `true`, [WebSocketConnection] will wait for
    /// this amount of time between the steps of a reconnection to make sure no message is lost.
    /// [Default]s to 300ms.
    pub reconnection_wait: Duration,
    /// A reconnection will be triggered if no messages are received within this amount of time.
    /// [Default]s to [Duration::ZERO], which means no timeout will be applied.
    pub message_timeout: Duration,
}
499
500impl WebSocketConfig {
501    /// Constructs a new `WebSocketConfig` with its fields set to [default][WebSocketConfig::default()].
502    pub fn new() -> Self {
503        Self::default()
504    }
505}
506
507impl Default for WebSocketConfig {
508    fn default() -> Self {
509        Self {
510            connect_cooldown: Duration::from_millis(3000),
511            refresh_after: Duration::ZERO,
512            url_prefix: String::new(),
513            ignore_duplicate_during_reconnection: false,
514            reconnection_wait: Duration::from_millis(300),
515            message_timeout: Duration::ZERO,
516        }
517    }
518}