ezsockets/
socket.rs

1use bytes::Bytes;
2use futures::lock::Mutex;
3use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt};
4use std::marker::PhantomData;
5use std::sync::atomic::{AtomicU8, Ordering};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio_tungstenite_wasm::Error as WSError;
9use tungstenite::Utf8Bytes;
10
11#[cfg(not(target_family = "wasm"))]
12use std::time::{Instant, SystemTime, UNIX_EPOCH};
13
14#[cfg(target_family = "wasm")]
15use wasmtimer::std::{Instant, SystemTime, UNIX_EPOCH};
16
17/// Wrapper trait for `Fn(Duration) -> RawMessage`.
18pub trait SocketHeartbeatPingFn: Fn(Duration) -> RawMessage + Sync + Send {}
19impl<F> SocketHeartbeatPingFn for F where F: Fn(Duration) -> RawMessage + Sync + Send {}
20pub type SocketHeartbeatPingFnT = dyn SocketHeartbeatPingFn<Output = RawMessage>;
21
22impl std::fmt::Debug for SocketHeartbeatPingFnT {
23    fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        Ok(())
25    }
26}
27
28/// Socket configuration.
29#[derive(Debug, Clone)]
30pub struct SocketConfig {
31    /// Duration between each heartbeat check.
32    pub heartbeat: Duration,
33    /// Duration before the keep-alive will fail if there are no new stream messages.
34    pub timeout: Duration,
35    /// Convert 'current time as duration since UNIX_EPOCH' into a ping message (on heartbeat).
36    /// This may be useful for manually implementing Ping/Pong messages via `RawMessage::Text` or `RawMessage::Binary`
37    /// if Ping/Pong are not available for your socket (e.g. in browser).
38    /// The default function outputs a standard `RawMessage::Ping`, with the payload set to the timestamp in milliseconds in
39    /// big-endian bytes.
40    pub heartbeat_ping_msg_fn: Arc<dyn SocketHeartbeatPingFn>,
41}
42
43impl Default for SocketConfig {
44    fn default() -> Self {
45        Self {
46            heartbeat: Duration::from_secs(5),
47            timeout: Duration::from_secs(10),
48            heartbeat_ping_msg_fn: Arc::new(|timestamp: Duration| {
49                let timestamp = timestamp.as_millis();
50                let bytes = timestamp.to_be_bytes();
51                RawMessage::Ping(bytes.to_vec().into())
52            }),
53        }
54    }
55}
56
/// Status code carried by a websocket close frame.
///
/// Named variants cover the standard codes. The payload-carrying variants keep the raw number for everything else,
/// so converting a `u16` into a `CloseCode` and back always yields the original number. Equality compares variants,
/// not numbers: `CloseCode::Reserved(1000) != CloseCode::Normal` even though both convert to `1000`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CloseCode {
    /// Indicates a normal closure, meaning that the purpose for
    /// which the connection was established has been fulfilled.
    Normal,
    /// Indicates that an endpoint is "going away", such as a server
    /// going down or a browser having navigated away from a page.
    Away,
    /// Indicates that an endpoint is terminating the connection due
    /// to a protocol error.
    Protocol,
    /// Indicates that an endpoint is terminating the connection
    /// because it has received a type of data it cannot accept (e.g., an
    /// endpoint that understands only text data MAY send this if it
    /// receives a binary message).
    Unsupported,
    /// Indicates that no status code was included in a closing frame. This
    /// close code makes it possible to use a single method, `on_close` to
    /// handle even cases where no close code was provided.
    ///
    /// Per RFC 6455 this code must not be sent in a close frame by an endpoint.
    Status,
    /// Indicates an abnormal closure. If the abnormal closure was due to an
    /// error, this close code will not be used. Instead, the `on_error` method
    /// of the handler will be called with the error. However, if the connection
    /// is simply dropped, without an error, this close code will be sent to the
    /// handler.
    ///
    /// Per RFC 6455 this code must not be sent in a close frame by an endpoint.
    Abnormal,
    /// Indicates that an endpoint is terminating the connection
    /// because it has received data within a message that was not
    /// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\]
    /// data within a text message).
    Invalid,
    /// Indicates that an endpoint is terminating the connection
    /// because it has received a message that violates its policy.  This
    /// is a generic status code that can be returned when there is no
    /// other more suitable status code (e.g., Unsupported or Size) or if there
    /// is a need to hide specific details about the policy.
    Policy,
    /// Indicates that an endpoint is terminating the connection
    /// because it has received a message that is too big for it to
    /// process.
    Size,
    /// Indicates that an endpoint (client) is terminating the
    /// connection because it has expected the server to negotiate one or
    /// more extension, but the server didn't return them in the response
    /// message of the WebSocket handshake.  The list of extensions that
    /// are needed should be given as the reason for closing.
    /// Note that this status code is not used by the server, because it
    /// can fail the WebSocket handshake instead.
    Extension,
    /// Indicates that a server is terminating the connection because
    /// it encountered an unexpected condition that prevented it from
    /// fulfilling the request.
    Error,
    /// Indicates that the server is restarting. A client may choose to reconnect,
    /// and if it does, it should use a randomized delay of 5-30 seconds between attempts.
    Restart,
    /// Indicates that the server is overloaded and the client should either connect
    /// to a different IP (when multiple targets exist), or reconnect to the same IP
    /// when a user has performed an action.
    Again,
    /// TLS handshake failure (1015). Per RFC 6455 this code must not be sent in a close frame.
    #[doc(hidden)]
    Tls,
    /// Unassigned code in `1016..=2999`.
    #[doc(hidden)]
    Reserved(u16),
    /// Code in `3000..=3999`.
    #[doc(hidden)]
    Iana(u16),
    /// Code in `4000..=4999`.
    #[doc(hidden)]
    Library(u16),
    /// Any code not covered by the other variants.
    #[doc(hidden)]
    Bad(u16),
}

impl From<CloseCode> for u16 {
    fn from(code: CloseCode) -> u16 {
        match code {
            CloseCode::Normal => 1000,
            CloseCode::Away => 1001,
            CloseCode::Protocol => 1002,
            CloseCode::Unsupported => 1003,
            CloseCode::Status => 1005,
            CloseCode::Abnormal => 1006,
            CloseCode::Invalid => 1007,
            CloseCode::Policy => 1008,
            CloseCode::Size => 1009,
            CloseCode::Extension => 1010,
            CloseCode::Error => 1011,
            CloseCode::Restart => 1012,
            CloseCode::Again => 1013,
            CloseCode::Tls => 1015,
            // The payload variants carry their raw number verbatim.
            CloseCode::Reserved(raw)
            | CloseCode::Iana(raw)
            | CloseCode::Library(raw)
            | CloseCode::Bad(raw) => raw,
        }
    }
}

impl From<u16> for CloseCode {
    fn from(code: u16) -> Self {
        match code {
            1000 => CloseCode::Normal,
            1001 => CloseCode::Away,
            1002 => CloseCode::Protocol,
            1003 => CloseCode::Unsupported,
            1005 => CloseCode::Status,
            1006 => CloseCode::Abnormal,
            1007 => CloseCode::Invalid,
            1008 => CloseCode::Policy,
            1009 => CloseCode::Size,
            1010 => CloseCode::Extension,
            1011 => CloseCode::Error,
            1012 => CloseCode::Restart,
            1013 => CloseCode::Again,
            1015 => CloseCode::Tls,
            1..=999 => CloseCode::Bad(code),
            1016..=2999 => CloseCode::Reserved(code),
            3000..=3999 => CloseCode::Iana(code),
            4000..=4999 => CloseCode::Library(code),
            // 0, 1004, 1014 and 5000.. all end up here.
            _ => CloseCode::Bad(code),
        }
    }
}
182
183#[derive(Debug, Clone)]
184pub struct CloseFrame {
185    pub code: CloseCode,
186    pub reason: Utf8Bytes,
187}
188
189#[derive(Debug, Clone)]
190pub enum Message {
191    Text(Utf8Bytes),
192    Binary(Bytes),
193    Close(Option<CloseFrame>),
194}
195
196#[derive(Debug, Clone)]
197pub enum RawMessage {
198    Text(Utf8Bytes),
199    Binary(Bytes),
200    Ping(Bytes),
201    Pong(Bytes),
202    Close(Option<CloseFrame>),
203}
204
205impl From<Message> for RawMessage {
206    fn from(message: Message) -> Self {
207        match message {
208            Message::Text(text) => Self::Text(text),
209            Message::Binary(bytes) => Self::Binary(bytes),
210            Message::Close(frame) => Self::Close(frame.map(CloseFrame::from)),
211        }
212    }
213}
214
/// Possible states of a submitted message.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum MessageStatus {
    /// Message is in the process of being sent.
    Sending,
    /// Message was successfully sent.
    Sent,
    /// Message failed sending.
    Failed,
}

/// Shared, cloneable view of a submitted message's [`MessageStatus`].
///
/// Every clone points at the same atomic cell, so a status stored through one clone is observed by all of them.
#[derive(Debug, Clone)]
pub struct MessageSignal {
    signal: Arc<AtomicU8>,
}

impl MessageSignal {
    /// Raw encoding of [`MessageStatus::Sending`]; also the initial value of a default signal.
    const SENDING: u8 = 0;
    /// Raw encoding of [`MessageStatus::Sent`].
    const SENT: u8 = 1;
    /// Raw encoding of [`MessageStatus::Failed`]; any unknown raw value also decodes as failed.
    const FAILED: u8 = 2;

    /// Creates a signal that already holds `status`.
    ///
    /// Handy for producing [`MessageStatus::Failed`] signals without attempting to send anything.
    pub fn new(status: MessageStatus) -> Self {
        let fresh = Self::default();
        fresh.set(status);
        fresh
    }

    /// Returns the status most recently stored in this signal.
    pub fn status(&self) -> MessageStatus {
        match self.signal.load(Ordering::Acquire) {
            Self::SENDING => MessageStatus::Sending,
            Self::SENT => MessageStatus::Sent,
            _ => MessageStatus::Failed,
        }
    }

    /// Stores `status` so every clone of this signal observes it.
    ///
    /// Kept crate-private so that clones handed out to users cannot contradict one another.
    pub(crate) fn set(&self, status: MessageStatus) {
        let raw = match status {
            MessageStatus::Sending => Self::SENDING,
            MessageStatus::Sent => Self::SENT,
            MessageStatus::Failed => Self::FAILED,
        };
        self.signal.store(raw, Ordering::Release);
    }
}

impl Default for MessageSignal {
    /// A signal in the [`MessageStatus::Sending`] state.
    fn default() -> Self {
        Self {
            signal: Arc::new(AtomicU8::new(Self::SENDING)),
        }
    }
}
270
271/// Raw message with associated message signal.
272#[derive(Debug, Clone)]
273pub struct InRawMessage {
274    /// We use an `Option` for the message so that we can both extract messages for sending and implement `Drop`.
275    message: Option<RawMessage>,
276    signal: Option<MessageSignal>,
277}
278
279impl InRawMessage {
280    pub fn new(message: RawMessage) -> Self {
281        Self {
282            message: Some(message),
283            signal: Some(MessageSignal::default()),
284        }
285    }
286
287    pub(crate) fn take_message(&mut self) -> Option<RawMessage> {
288        self.message.take()
289    }
290
291    pub(crate) fn set_signal(&mut self, state: MessageStatus) {
292        let Some(signal) = &self.signal else {
293            return;
294        };
295        signal.set(state);
296        self.signal = None;
297    }
298}
299
300impl Drop for InRawMessage {
301    fn drop(&mut self) {
302        // If the signal is still present in the message when dropping, then we need to mark its state as failed.
303        self.set_signal(MessageStatus::Failed);
304    }
305}
306
307/// Message with associated message signal.
308#[derive(Debug, Clone)]
309pub struct InMessage {
310    /// We use an `Option` for the message so that we can both convert to `InMessage`s and implement `Drop`.
311    pub(crate) message: Option<Message>,
312    signal: Option<MessageSignal>,
313}
314
315impl InMessage {
316    pub fn new(message: Message) -> Self {
317        Self {
318            message: Some(message),
319            signal: Some(MessageSignal::default()),
320        }
321    }
322
323    pub fn clone_signal(&self) -> Option<MessageSignal> {
324        self.signal.clone()
325    }
326}
327
328impl From<InMessage> for InRawMessage {
329    fn from(mut inmessage: InMessage) -> Self {
330        Self {
331            message: inmessage.message.take().map(|msg| msg.into()),
332            signal: inmessage.signal.take(),
333        }
334    }
335}
336
337impl Drop for InMessage {
338    fn drop(&mut self) {
339        // If the signal is still present in the message when dropping, then we need to mark its state as failed.
340        let Some(signal) = self.signal.take() else {
341            return;
342        };
343        signal.set(MessageStatus::Failed);
344    }
345}
346
347#[derive(Debug)]
348struct SinkActor<M, S>
349where
350    M: From<RawMessage>,
351    S: SinkExt<M, Error = WSError> + Unpin,
352{
353    receiver: async_channel::Receiver<InRawMessage>,
354    abort_receiver: async_channel::Receiver<()>,
355    sink: S,
356    phantom: PhantomData<M>,
357}
358
359impl<M, S> SinkActor<M, S>
360where
361    M: From<RawMessage>,
362    S: SinkExt<M, Error = WSError> + Unpin,
363{
364    async fn run(&mut self) -> Result<(), WSError> {
365        loop {
366            futures::select! {
367                res = self.receiver.recv().fuse() => {
368                    let Ok(mut inmessage) = res else {
369                        break;
370                    };
371                    let Some(message) = inmessage.take_message() else {
372                        continue;
373                    };
374                    tracing::trace!("sending message: {:?}", message);
375                    match self.sink.send(M::from(message)).await {
376                        Ok(()) => inmessage.set_signal(MessageStatus::Sent),
377                        Err(err) => {
378                            inmessage.set_signal(MessageStatus::Failed);
379                            tracing::warn!(?err, "sink send failed");
380                            return Err(err);
381                        }
382                    }
383                },
384                _ = &mut self.abort_receiver.recv().fuse() => {
385                    break;
386                },
387            }
388        }
389        Ok(())
390    }
391}
392
393#[derive(Debug, Clone)]
394pub struct Sink {
395    sender: async_channel::Sender<InRawMessage>,
396}
397
398impl Sink {
399    fn new<M, S>(
400        sink: S,
401        abort_receiver: async_channel::Receiver<()>,
402        handle: impl enfync::Handle,
403    ) -> (enfync::PendingResult<Result<(), WSError>>, Self)
404    where
405        M: From<RawMessage> + Send + 'static,
406        S: SinkExt<M, Error = WSError> + Unpin + Send + 'static,
407    {
408        let (sender, receiver) = async_channel::unbounded();
409        let mut actor = SinkActor {
410            receiver,
411            abort_receiver,
412            sink,
413            phantom: Default::default(),
414        };
415        let future = handle.spawn(async move { actor.run().await });
416        (future, Self { sender })
417    }
418
419    pub fn is_closed(&self) -> bool {
420        self.sender.is_closed()
421    }
422
423    pub async fn send(
424        &self,
425        inmessage: InMessage,
426    ) -> Result<(), async_channel::SendError<InRawMessage>> {
427        self.sender.send(inmessage.into()).await
428    }
429
430    pub(crate) async fn send_raw(
431        &self,
432        inmessage: InRawMessage,
433    ) -> Result<(), async_channel::SendError<InRawMessage>> {
434        self.sender.send(inmessage).await
435    }
436}
437
438#[derive(Debug)]
439struct StreamActor<M, S>
440where
441    M: Into<RawMessage>,
442    S: StreamExt<Item = Result<M, WSError>> + Unpin,
443{
444    sender: async_channel::Sender<Result<Message, WSError>>,
445    stream: S,
446    last_alive: Arc<Mutex<Instant>>,
447}
448
449impl<M, S> StreamActor<M, S>
450where
451    M: Into<RawMessage>,
452    S: StreamExt<Item = Result<M, WSError>> + Unpin,
453{
454    async fn run(mut self) {
455        while let Some(result) = self.stream.next().await {
456            let result = result.map(M::into);
457            tracing::trace!("received message: {:?}", result);
458            *self.last_alive.lock().await = Instant::now();
459
460            let mut closing = false;
461            let message = match result {
462                Ok(message) => Ok(match message {
463                    RawMessage::Text(text) => Message::Text(text),
464                    RawMessage::Binary(bytes) => Message::Binary(bytes),
465                    RawMessage::Ping(_bytes) => continue,
466                    RawMessage::Pong(bytes) => {
467                        if let Ok(bytes) = (*bytes).try_into() {
468                            let bytes: [u8; 16] = bytes;
469                            let timestamp = u128::from_be_bytes(bytes);
470                            let timestamp = Duration::from_millis(timestamp as u64); // TODO: handle overflow
471                            let latency = SystemTime::now()
472                                .duration_since(UNIX_EPOCH + timestamp)
473                                .unwrap_or_default();
474                            // TODO: handle time zone
475                            tracing::trace!("latency: {}ms", latency.as_millis());
476                        }
477
478                        continue;
479                    }
480                    RawMessage::Close(frame) => {
481                        closing = true;
482                        Message::Close(frame)
483                    }
484                }),
485                Err(err) => Err(err), // maybe early return here?
486            };
487            if self.sender.send(message).await.is_err() {
488                // In websockets, you always echo a close frame received from your connection partner back to them.
489                // This means a normal close sequence will always end with the following line emitted by the socket of
490                // the client/server that initiated the close sequence (in response to the close frame echoed by their
491                // partner).
492                if closing {
493                    tracing::trace!("stream is closed");
494                } else {
495                    tracing::warn!("failed to forward message, stream is disconnected");
496                }
497                break;
498            };
499        }
500    }
501}
502
503#[derive(Debug)]
504pub struct Stream {
505    receiver: async_channel::Receiver<Result<Message, WSError>>,
506}
507
508impl Stream {
509    fn new<M, S>(
510        stream: S,
511        last_alive: Arc<Mutex<Instant>>,
512        handle: impl enfync::Handle,
513    ) -> (enfync::PendingResult<()>, Self)
514    where
515        M: Into<RawMessage> + std::fmt::Debug + Send + 'static,
516        S: StreamExt<Item = Result<M, WSError>> + Unpin + Send + 'static,
517    {
518        let (sender, receiver) = async_channel::unbounded();
519        let actor = StreamActor {
520            sender,
521            stream,
522            last_alive,
523        };
524        let future = handle.spawn(actor.run());
525
526        (future, Self { receiver })
527    }
528
529    pub async fn recv(&mut self) -> Option<Result<Message, WSError>> {
530        self.receiver.recv().await.ok()
531    }
532}
533
534#[derive(Debug)]
535pub struct Socket {
536    pub sink: Sink,
537    pub stream: Stream,
538    sink_result_receiver: Option<async_channel::Receiver<Result<(), WSError>>>,
539}
540
541impl Socket {
542    pub fn new<M, E, S>(socket: S, config: SocketConfig, handle: impl enfync::Handle) -> Self
543    where
544        M: Into<RawMessage> + From<RawMessage> + std::fmt::Debug + Send + 'static,
545        E: Into<WSError> + std::error::Error,
546        S: SinkExt<M, Error = E> + Unpin + StreamExt<Item = Result<M, E>> + Unpin + Send + 'static,
547    {
548        let last_alive = Instant::now();
549        let last_alive = Arc::new(Mutex::new(last_alive));
550        let (sink, stream) = socket.sink_err_into().err_into().split();
551        let (sink_abort_sender, sink_abort_receiver) = async_channel::bounded(1usize);
552        let ((mut sink_future, sink), (mut stream_future, stream)) = (
553            Sink::new(sink, sink_abort_receiver, handle.clone()),
554            Stream::new(stream, last_alive.clone(), handle.clone()),
555        );
556        let (hearbeat_abort_sender, hearbeat_abort_receiver) = async_channel::bounded(1usize);
557        let sink_clone = sink.clone();
558        handle.spawn(async move {
559            socket_heartbeat(sink_clone, config, hearbeat_abort_receiver, last_alive).await
560        });
561
562        let (sink_result_sender, sink_result_receiver) = async_channel::bounded(1usize);
563        handle.spawn(async move {
564            let _ = stream_future.extract().await;
565            let _ = sink_abort_sender.send_blocking(());
566            let _ = hearbeat_abort_sender.send_blocking(());
567            let _ = sink_result_sender.send_blocking(
568                sink_future
569                    .extract()
570                    .await
571                    .unwrap_or(Err(WSError::AlreadyClosed)),
572            );
573        });
574
575        Self {
576            sink,
577            stream,
578            sink_result_receiver: Some(sink_result_receiver),
579        }
580    }
581
582    pub async fn send(
583        &self,
584        message: InMessage,
585    ) -> Result<(), async_channel::SendError<InRawMessage>> {
586        self.sink.send(message).await
587    }
588
589    pub async fn send_raw(
590        &self,
591        message: InRawMessage,
592    ) -> Result<(), async_channel::SendError<InRawMessage>> {
593        self.sink.send_raw(message).await
594    }
595
596    pub async fn recv(&mut self) -> Option<Result<Message, WSError>> {
597        self.stream.recv().await
598    }
599
600    pub(crate) async fn await_sink_close(&mut self) -> Result<(), WSError> {
601        let Some(sink_result_receiver) = self.sink_result_receiver.take() else {
602            return Err(WSError::AlreadyClosed);
603        };
604        sink_result_receiver
605            .recv()
606            .await
607            .unwrap_or(Err(WSError::AlreadyClosed))
608    }
609}
610
611#[cfg(not(target_family = "wasm"))]
612async fn socket_heartbeat(
613    sink: Sink,
614    config: SocketConfig,
615    abort_receiver: async_channel::Receiver<()>,
616    last_alive: Arc<Mutex<Instant>>,
617) {
618    let sleep = tokio::time::sleep(config.heartbeat);
619    tokio::pin!(sleep);
620
621    loop {
622        tokio::select! {
623            _ = &mut sleep => {
624                let Some(next_sleep_duration) = handle_heartbeat_sleep_elapsed(&sink, &config, &last_alive).await else {
625                    break;
626                };
627                sleep.as_mut().reset(tokio::time::Instant::now() + next_sleep_duration);
628            }
629            _ = abort_receiver.recv() => break,
630        }
631    }
632}
633
/// Heartbeat loop for WASM targets, driven by `wasmtimer`.
///
/// The first check runs `config.heartbeat` after start; each check's return value sets the delay until the next one.
/// The loop ends when a check returns `None` or when `abort_receiver` yields (or is closed).
#[cfg(target_family = "wasm")]
async fn socket_heartbeat(
    sink: Sink,
    config: SocketConfig,
    abort_receiver: async_channel::Receiver<()>,
    last_alive: Arc<Mutex<Instant>>,
) {
    let mut delay = config.heartbeat;

    loop {
        // Resetting one `Sleep` would be nicer, but `fuse()` consumes it, and `futures::select!` is required because
        // tokio's `select!` is not usable on WASM targets. So a fresh fused sleep is built on every iteration.
        let timer = wasmtimer::tokio::sleep(delay).fuse();
        futures::pin_mut!(timer);
        futures::select! {
            _ = timer => {
                match handle_heartbeat_sleep_elapsed(&sink, &config, &last_alive).await {
                    Some(next_delay) => delay = next_delay,
                    None => break,
                }
            }
            _ = &mut abort_receiver.recv().fuse() => break,
        }
    }
}
659
660async fn handle_heartbeat_sleep_elapsed(
661    sink: &Sink,
662    config: &SocketConfig,
663    last_alive: &Arc<Mutex<Instant>>,
664) -> Option<Duration> {
665    // check last alive
666    let elapsed_since_last_alive = last_alive.lock().await.elapsed();
667    if elapsed_since_last_alive > config.timeout {
668        tracing::info!("closing connection due to timeout");
669        let _ = sink
670            .send_raw(InRawMessage::new(RawMessage::Close(Some(CloseFrame {
671                code: CloseCode::Abnormal,
672                reason: "remote partner is inactive".into(),
673            }))))
674            .await;
675        return None;
676    } else if elapsed_since_last_alive < config.heartbeat {
677        // todo: this branch will needlessly fire at least once per heartbeat for idle connections since
678        //       Pongs arrive after some delay
679        return Some(config.heartbeat.saturating_sub(elapsed_since_last_alive));
680    }
681
682    // send ping
683    let timestamp = SystemTime::now()
684        .duration_since(UNIX_EPOCH)
685        .unwrap_or_default();
686    if sink
687        .send_raw(InRawMessage::new((config.heartbeat_ping_msg_fn)(timestamp)))
688        .await
689        .is_err()
690    {
691        return None;
692    }
693
694    Some(config.heartbeat)
695}