// burn_central_client/websocket.rs

1use std::{thread, time::Duration};
2
3use reqwest::header::COOKIE;
4use serde::{Serialize, de::DeserializeOwned};
5
6use thiserror::Error;
7
8use tungstenite::{
9    Message, Utf8Bytes, WebSocket, client::IntoClientRequest, connect, stream::MaybeTlsStream,
10};
11
12pub use crate::experiment::websocket::*;
13
14#[derive(Error, Debug)]
15#[allow(clippy::enum_variant_names)]
16pub enum WebSocketError {
17    #[error("Failed to connect WebSocket: {0}")]
18    ConnectionError(String),
19    #[error("WebSocket send error: {0}")]
20    SendError(String),
21    #[error("WebSocket receive error: {0}")]
22    ReceiveError(String),
23    #[error("WebSocket is not connected")]
24    NotConnected,
25    #[error("WebSocket cannot reconnect: {0}")]
26    CannotReconnect(String),
27    #[error("Serialization error: {0}")]
28    SerializationError(String),
29}
30
/// Pause inserted between a failed send and the reconnect attempt that follows it.
const DEFAULT_RECONNECT_DELAY: Duration = Duration::from_secs(1);
32
33type Socket = WebSocket<MaybeTlsStream<std::net::TcpStream>>;
34struct ConnectedSocket {
35    socket: Socket,
36    url: String,
37    cookie: String,
38}
39
40#[derive(Default)]
41pub struct WebSocketClient {
42    state: Option<ConnectedSocket>,
43}
44
45impl WebSocketClient {
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    #[allow(dead_code)]
51    pub fn is_connected(&self) -> bool {
52        self.state.is_some()
53    }
54
55    pub fn connect(&mut self, url: String, session_cookie: &str) -> Result<(), WebSocketError> {
56        let mut req = url
57            .clone()
58            .into_client_request()
59            .expect("Should be able to create a client request from the URL");
60
61        req.headers_mut().append(
62            COOKIE,
63            session_cookie
64                .parse()
65                .expect("Should be able to parse cookie header"),
66        );
67
68        let (mut socket, _) =
69            connect(req).map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
70
71        match socket.get_mut() {
72            MaybeTlsStream::Plain(stream) => stream.set_nonblocking(true),
73            MaybeTlsStream::NativeTls(stream) => stream.get_mut().set_nonblocking(true),
74            _ => unimplemented!("Other TLS streams are not supported"),
75        }
76        .map_err(|e| {
77            WebSocketError::ConnectionError(format!("Failed to set non-blocking mode: {e}"))
78        })?;
79
80        self.state = Some(ConnectedSocket {
81            socket,
82            url,
83            cookie: session_cookie.to_string(),
84        });
85        Ok(())
86    }
87
88    pub fn reconnect(&mut self) -> Result<(), WebSocketError> {
89        if let Some(socket) = self.state.take() {
90            self.connect(socket.url, &socket.cookie)
91        } else {
92            Err(WebSocketError::CannotReconnect(
93                "The websocket was never opened so it cannot be reconnected".to_string(),
94            ))
95        }
96    }
97
98    /// Sends a message over the WebSocket connection. This is a non-blocking call.
99    /// If sending fails, it attempts to reconnect and resend the message.
100    /// Returns an error if both attempts fail.
101    pub fn send<I: Serialize>(&mut self, message: I) -> Result<(), WebSocketError> {
102        let socket = self.active_socket()?;
103
104        let json = serde_json::to_string(&message)
105            .map_err(|e| WebSocketError::SerializationError(e.to_string()))?;
106
107        match Self::attempt_send(socket, &json) {
108            Ok(_) => Ok(()),
109            Err(_) => {
110                tracing::debug!("WebSocket send failed, attempting to reconnect...");
111                thread::sleep(DEFAULT_RECONNECT_DELAY);
112                self.reconnect()?;
113
114                let socket = self.active_socket()?;
115                Self::attempt_send(socket, &json)
116            }
117        }
118    }
119
120    /// Attempts to receive a message from the WebSocket. This is a non-blocking call.
121    /// Returns `Ok(None)` if no message is available.
122    pub fn receive<T: DeserializeOwned>(&mut self) -> Result<Option<T>, WebSocketError> {
123        let socket = self.active_socket()?;
124
125        match socket.read() {
126            Ok(msg) => match msg {
127                Message::Text(text) => {
128                    let deserialized: T = serde_json::from_str(&text)
129                        .map_err(|e| WebSocketError::SerializationError(e.to_string()))?;
130                    Ok(Some(deserialized))
131                }
132                Message::Binary(_) => {
133                    tracing::warn!("Received unexpected binary message");
134                    Ok(None)
135                }
136                Message::Ping(_) | Message::Pong(_) | Message::Close(_) => Ok(None),
137                Message::Frame(frame) => {
138                    tracing::warn!("Received unexpected frame message: {:?}", frame);
139                    Ok(None)
140                }
141            },
142            Err(tungstenite::Error::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
143                // No messages available
144                Ok(None)
145            }
146            Err(e) => Err(WebSocketError::ReceiveError(e.to_string())),
147        }
148    }
149
150    fn attempt_send(socket: &mut Socket, payload: &str) -> Result<(), WebSocketError> {
151        socket
152            .send(Message::Text(Utf8Bytes::from(payload)))
153            .map_err(|e| WebSocketError::SendError(e.to_string()))
154    }
155
156    /// Closes the WebSocket connection gracefully. This is a non-blocking call.
157    pub fn close(&mut self) -> Result<(), WebSocketError> {
158        let socket = self.active_socket()?;
159        socket
160            .close(None)
161            .map_err(|e| WebSocketError::SendError(e.to_string()))
162    }
163
164    /// Waits until the WebSocket connection is fully closed. This is a blocking call that will return once the connection is closed.
165    pub fn wait_until_closed(&mut self) -> Result<(), WebSocketError> {
166        let socket = self.active_socket()?;
167        match socket.get_mut() {
168            MaybeTlsStream::Plain(stream) => stream.set_nonblocking(false),
169            MaybeTlsStream::NativeTls(stream) => stream.get_mut().set_nonblocking(false),
170            _ => unimplemented!("Other TLS streams are not supported"),
171        }
172        .map_err(|e| {
173            WebSocketError::ConnectionError(format!("Failed to set blocking mode: {e}"))
174        })?;
175        loop {
176            match socket.read() {
177                Ok(_) => {}
178                Err(tungstenite::Error::ConnectionClosed | tungstenite::Error::AlreadyClosed) => {
179                    tracing::debug!("WebSocket connection closed");
180                    break;
181                }
182                Err(e) => {
183                    tracing::error!("WebSocket read error while waiting until closed: {e}");
184                    return Err(WebSocketError::SendError(e.to_string()));
185                }
186            }
187        }
188        Ok(())
189    }
190
191    fn active_socket(&mut self) -> Result<&mut Socket, WebSocketError> {
192        if let Some(socket) = self.state.as_mut() {
193            Ok(&mut socket.socket)
194        } else {
195            Err(WebSocketError::NotConnected)
196        }
197    }
198}
199
200impl Drop for WebSocketClient {
201    fn drop(&mut self) {
202        _ = self.close();
203    }
204}