pws/
lib.rs

1use futures::stream::{SplitSink, SplitStream};
2use futures::{SinkExt, StreamExt};
3use log::{error, info};
4use std::time::Duration;
5use thiserror::Error;
6use tokio::net::TcpStream;
7use tokio::sync::broadcast::error::{RecvError as BroadcastRecvError, SendError};
8use tokio::sync::oneshot::error::RecvError as OneshotRecvError;
9use tokio::sync::{broadcast, mpsc, oneshot};
10use tokio::task::JoinError;
11use tokio_tungstenite::tungstenite::protocol::frame::Frame;
12use tokio_tungstenite::tungstenite::protocol::CloseFrame;
13pub use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
14use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
15pub use url::Url;
16
/// Reconnect delay for the first retry, in milliseconds; also the unit that
/// `get_backoff` multiplies by the squared attempt number.
const INITIAL_BACKOFF_MILLIS: u64 = 100;
/// Upper bound on any reconnect delay: five minutes, in milliseconds.
const MAX_BACKOFF_MILLIS: u64 = 5 * 60 * 1_000;
/// Buffer size shared by the outgoing (mpsc) and incoming (broadcast) channels.
const CHANNEL_CAPACITY: usize = 32;
20
21type WsWrite = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, TungsteniteMessage>;
22type WsRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
23type WsError = tokio_tungstenite::tungstenite::Error;
24
25#[derive(Debug, Error)]
26pub enum Error {
27    #[error("tungstenite: {0}")]
28    Tungstenite(WsError),
29    #[error("broadcast::recv: {0}")]
30    BroadcastRecv(BroadcastRecvError),
31    #[error("oneshot::recv: {0}")]
32    OneshotRecv(OneshotRecvError),
33    #[error("join: {0}")]
34    Join(JoinError),
35    #[error("broadcast::send: {0}")]
36    Send(SendError<Message>),
37}
38
39pub type WsMessageSender = mpsc::Sender<Message>;
40pub type WsMessageReceiver = broadcast::Receiver<Message>;
41
42#[derive(Debug, Eq, PartialEq, Clone)]
43pub enum Message {
44    Text(String),
45    Binary(Vec<u8>),
46    Ping(Vec<u8>),
47    Pong(Vec<u8>),
48    Close(Option<CloseFrame<'static>>),
49    Frame(Frame),
50    /// A message to notify that the connection was opened.
51    /// Note that this cannot be sent, it is received only.
52    /// This variant is only part of the Pws crate and not of
53    /// tungstenite.
54    ConnectionOpened,
55    /// A message to notify that the connection was closed.
56    /// Note that this cannot be sent, it is received only.
57    /// This variant is only part of the Pws crate and not of
58    /// tungstenite.
59    ConnectionClosed,
60}
61
62pub async fn connect_persistent_websocket_async(
63    url: Url,
64) -> Result<(WsMessageSender, WsMessageReceiver), Error> {
65    let (msg_tx_out, msg_rx_out) = mpsc::channel::<Message>(CHANNEL_CAPACITY);
66    let (msg_tx_in, msg_rx_in) = broadcast::channel::<Message>(CHANNEL_CAPACITY);
67    let (first_conn_tx, first_conn_rx) = oneshot::channel();
68    tokio::spawn(setup_persistent_websocket(
69        url.into(),
70        msg_tx_in,
71        msg_rx_out,
72        first_conn_tx,
73    ));
74    if let Some(first_conn_error) = first_conn_rx.await? {
75        return Err(Error::Tungstenite(first_conn_error));
76    }
77    Ok((msg_tx_out, msg_rx_in))
78}
79
80async fn setup_persistent_websocket(
81    url: Url,
82    mut msg_tx_in: broadcast::Sender<Message>,
83    mut msg_rx_out: mpsc::Receiver<Message>,
84    first_conn_tx: oneshot::Sender<Option<WsError>>,
85) {
86    let mut connection_count = 0;
87    let mut first_conn_tx = Some(first_conn_tx);
88    loop {
89        info!("connecting to {url}");
90        let result = listen_for_persistent_ws_messages(
91            url.clone(),
92            &mut msg_tx_in,
93            &mut msg_rx_out,
94            &mut first_conn_tx,
95        )
96        .await;
97        if let Err(e) = result {
98            error!("error during ws connection: {e}");
99        }
100        info!("disconnected from {url}");
101        connection_count += 1;
102        tokio::time::sleep(get_backoff(connection_count)).await;
103    }
104}
105
106async fn listen_for_persistent_ws_messages(
107    url: Url,
108    msg_tx_in: &mut broadcast::Sender<Message>,
109    msg_rx_out: &mut mpsc::Receiver<Message>,
110    first_conn_tx: &mut Option<oneshot::Sender<Option<WsError>>>,
111) -> Result<(), Error> {
112    let connection_result = connect_async(url).await;
113    let (socket, _) = match connection_result {
114        Ok(c) => {
115            if let Some(first_conn_tx) = first_conn_tx.take() {
116                first_conn_tx.send(None).expect("first_conn_rx dropped")
117            }
118            c
119        }
120        Err(e) => {
121            if let Some(first_conn_tx) = first_conn_tx.take() {
122                first_conn_tx.send(Some(e)).expect("first_conn_rx dropped");
123            }
124            return Ok(());
125        }
126    };
127    let (mut ws_tx, mut ws_rx) = socket.split();
128    msg_tx_in.send(Message::ConnectionOpened)?;
129    let result = process_connection(&mut ws_tx, &mut ws_rx, msg_tx_in, msg_rx_out).await;
130    msg_tx_in.send(Message::ConnectionClosed)?;
131    result?;
132    Ok(())
133}
134
135async fn process_connection(
136    ws_tx: &mut WsWrite,
137    ws_rx: &mut WsRead,
138    msg_tx_in: &mut broadcast::Sender<Message>,
139    msg_rx_out: &mut mpsc::Receiver<Message>,
140) -> Result<(), Error> {
141    loop {
142        tokio::select! {
143            Some(incoming_msg) = ws_rx.next() => {
144                let should_close = handle_incoming_message(incoming_msg, ws_tx, msg_tx_in).await?;
145                if should_close {
146                    info!("closing connection");
147                    break;
148                }
149            },
150            Some(outgoing_msg) = msg_rx_out.recv() => {
151                let Some(outgoing_msg) = outgoing_msg.to_tungstenite() else {
152                    continue;
153                };
154                ws_tx.send(outgoing_msg).await?;
155            }
156        }
157    }
158    Ok(())
159}
160
161async fn handle_incoming_message(
162    message: Result<TungsteniteMessage, WsError>,
163    ws_tx: &mut WsWrite,
164    msg_tx_in: &mut broadcast::Sender<Message>,
165) -> Result<bool, Error> {
166    let message = match message {
167        Ok(m) => m,
168        Err(e) => {
169            error!("connection error: {e}");
170            return Ok(true);
171        }
172    };
173    match message {
174        TungsteniteMessage::Ping(_) => {
175            #[cfg(feature = "pong")]
176            if let Err(e) = ws_tx.send(TungsteniteMessage::Pong(vec![])).await {
177                error!("error sending pong: {e}");
178            }
179            return Ok(false);
180        }
181        TungsteniteMessage::Close(frame) => {
182            info!("received socket close signal: {:#?}", frame);
183            return Ok(true);
184        }
185        _ => {}
186    }
187    msg_tx_in.send(message.into())?;
188    Ok(false)
189}
190
191fn get_backoff(attempt: u64) -> Duration {
192    let backoff = INITIAL_BACKOFF_MILLIS * attempt.pow(2);
193    Duration::from_millis(backoff.min(MAX_BACKOFF_MILLIS))
194}
195
196impl From<WsError> for Error {
197    fn from(e: WsError) -> Self {
198        Self::Tungstenite(e)
199    }
200}
201
202impl From<BroadcastRecvError> for Error {
203    fn from(e: BroadcastRecvError) -> Self {
204        Self::BroadcastRecv(e)
205    }
206}
207
208impl From<OneshotRecvError> for Error {
209    fn from(e: OneshotRecvError) -> Self {
210        Self::OneshotRecv(e)
211    }
212}
213
214impl From<JoinError> for Error {
215    fn from(e: JoinError) -> Self {
216        Self::Join(e)
217    }
218}
219
220impl From<SendError<Message>> for Error {
221    fn from(e: SendError<Message>) -> Self {
222        Self::Send(e)
223    }
224}
225
226impl Message {
227    fn to_tungstenite(self) -> Option<TungsteniteMessage> {
228        match self {
229            Message::Text(v) => Some(TungsteniteMessage::Text(v)),
230            Message::Binary(v) => Some(TungsteniteMessage::Binary(v)),
231            Message::Ping(v) => Some(TungsteniteMessage::Ping(v)),
232            Message::Pong(v) => Some(TungsteniteMessage::Pong(v)),
233            Message::Close(v) => Some(TungsteniteMessage::Close(v)),
234            Message::Frame(v) => Some(TungsteniteMessage::Frame(v)),
235            Message::ConnectionOpened => None,
236            Message::ConnectionClosed => None,
237        }
238    }
239}
240
241impl From<TungsteniteMessage> for Message {
242    fn from(message: TungsteniteMessage) -> Self {
243        match message {
244            TungsteniteMessage::Text(v) => Message::Text(v),
245            TungsteniteMessage::Binary(v) => Message::Binary(v),
246            TungsteniteMessage::Ping(v) => Message::Ping(v),
247            TungsteniteMessage::Pong(v) => Message::Pong(v),
248            TungsteniteMessage::Close(v) => Message::Close(v),
249            TungsteniteMessage::Frame(v) => Message::Frame(v),
250        }
251    }
252}