// birdie/web_socket.rs

1use futures_util::{SinkExt, StreamExt};
2use tokio::sync::mpsc;
3use tokio_tungstenite::{connect_async, tungstenite};
4use tracing::{debug, error, info};
5
/// Connection lifecycle events published by [`WebSocketClient`] on its status channel.
///
/// The type is `Copy` and comparable, so receivers can match on it or test it with `==`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// The websocket handshake completed.
    Connected,
    /// A ping frame was read from the websocket.
    PingReceived,
    /// The client attempted to answer a ping with a pong frame.
    ///
    /// NOTE(review): the name looks like a typo for `PongSent`; it is kept as-is because
    /// renaming a public variant would break existing callers.
    PoingSent,
    /// The connection is no longer active.
    Disconnected,
}
13
14pub struct WebSocketClient {
15    endpoint: String,
16    read_channel: mpsc::Receiver<String>,
17    write_channel: mpsc::Sender<String>,
18    status_channel: mpsc::Sender<ConnectionStatus>,
19}
20
21impl WebSocketClient {
22    pub fn new(
23        endpoint: &str,
24        read_channel: mpsc::Receiver<String>,
25        write_channel: mpsc::Sender<String>,
26        status_channel: mpsc::Sender<ConnectionStatus>,
27    ) -> Self {
28        Self {
29            endpoint: endpoint.to_owned(),
30            read_channel,
31            write_channel,
32            status_channel,
33        }
34    }
35
36    pub async fn connect(mut self) -> Result<(), tungstenite::Error> {
37        let (stream, _) = connect_async(&self.endpoint).await?;
38        let (mut write, mut read) = stream.split();
39        let _ = self.status_channel.send(ConnectionStatus::Connected).await;
40
41        tokio::spawn(async move {
42            loop {
43                tokio::select! {
44                    Some(msg) = self.read_channel.recv() => {
45                        debug!("sending message to websocket: {msg:?}");
46                        let msg = tungstenite::Message::Text(msg);
47                        write.send(msg).await.unwrap_or_else(|err| {
48                            error!("websocket write error: {err}");
49                        })
50                    }
51                    Some(msg) = read.next() => {
52                        debug!("received message from websocket: {msg:?}");
53                        let msg = match msg {
54                            Ok(msg) => msg,
55                            Err(err) => {
56                                error!("websocket read error: {err}");
57                                break;
58                            }
59                        };
60
61                        match msg {
62                            tungstenite::Message::Text(msg) => {
63                                self.write_channel.send(msg).await.unwrap_or_else(|err| {
64                                    error!("write channel error: {err}");
65                                });
66                            }
67                            tungstenite::Message::Ping(payload) => {
68                                info!("ping received");
69                                self.status_channel.send(ConnectionStatus::PingReceived).await.unwrap_or_else(|err| {
70                                    error!("status channel error: {err}");
71                                });
72
73                                write.send(tungstenite::Message::Pong(payload)).await.unwrap_or_else(|err| {
74                                    error!("websocket write error: {err}");
75                                });
76                                info!("pong sent");
77
78                                self.status_channel.send(ConnectionStatus::PoingSent).await.unwrap_or_else(|err| {
79                                    error!("status channel error: {err}");
80                                });
81                            }
82                            tungstenite::Message::Close(_) => {
83                                self.status_channel.send(ConnectionStatus::Disconnected).await.unwrap_or_else(|err| {
84                                    error!("status channel error: {err}");
85                                });
86                                break;
87                            }
88                            _ => {
89                                error!("unexpected message: {msg}");
90                            }
91                        }
92                    }
93                }
94            }
95        });
96
97        Ok(())
98    }
99}