// burn_central_client/websocket.rs
use std::{io, thread, time::Duration};

use reqwest::header::COOKIE;
use serde::{Serialize, de::DeserializeOwned};

use thiserror::Error;

use tungstenite::{
    Message, Utf8Bytes, WebSocket, client::IntoClientRequest, connect, stream::MaybeTlsStream,
};

pub use crate::experiment::websocket::*;
13
14#[derive(Error, Debug)]
15#[allow(clippy::enum_variant_names)]
16pub enum WebSocketError {
17 #[error("Failed to connect WebSocket: {0}")]
18 ConnectionError(String),
19 #[error("WebSocket send error: {0}")]
20 SendError(String),
21 #[error("WebSocket receive error: {0}")]
22 ReceiveError(String),
23 #[error("WebSocket is not connected")]
24 NotConnected,
25 #[error("WebSocket cannot reconnect: {0}")]
26 CannotReconnect(String),
27 #[error("Serialization error: {0}")]
28 SerializationError(String),
29}
30
/// Pause taken after a failed send, before reconnecting and retrying once.
const DEFAULT_RECONNECT_DELAY: Duration = Duration::from_secs(1);
32
33type Socket = WebSocket<MaybeTlsStream<std::net::TcpStream>>;
34struct ConnectedSocket {
35 socket: Socket,
36 url: String,
37 cookie: String,
38}
39
40#[derive(Default)]
41pub struct WebSocketClient {
42 state: Option<ConnectedSocket>,
43}
44
45impl WebSocketClient {
46 pub fn new() -> Self {
47 Self::default()
48 }
49
50 #[allow(dead_code)]
51 pub fn is_connected(&self) -> bool {
52 self.state.is_some()
53 }
54
55 pub fn connect(&mut self, url: String, session_cookie: &str) -> Result<(), WebSocketError> {
56 let mut req = url
57 .clone()
58 .into_client_request()
59 .expect("Should be able to create a client request from the URL");
60
61 req.headers_mut().append(
62 COOKIE,
63 session_cookie
64 .parse()
65 .expect("Should be able to parse cookie header"),
66 );
67
68 let (mut socket, _) =
69 connect(req).map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
70
71 match socket.get_mut() {
72 MaybeTlsStream::Plain(stream) => stream.set_nonblocking(true),
73 MaybeTlsStream::NativeTls(stream) => stream.get_mut().set_nonblocking(true),
74 _ => unimplemented!("Other TLS streams are not supported"),
75 }
76 .map_err(|e| {
77 WebSocketError::ConnectionError(format!("Failed to set non-blocking mode: {e}"))
78 })?;
79
80 self.state = Some(ConnectedSocket {
81 socket,
82 url,
83 cookie: session_cookie.to_string(),
84 });
85 Ok(())
86 }
87
88 pub fn reconnect(&mut self) -> Result<(), WebSocketError> {
89 if let Some(socket) = self.state.take() {
90 self.connect(socket.url, &socket.cookie)
91 } else {
92 Err(WebSocketError::CannotReconnect(
93 "The websocket was never opened so it cannot be reconnected".to_string(),
94 ))
95 }
96 }
97
98 pub fn send<I: Serialize>(&mut self, message: I) -> Result<(), WebSocketError> {
102 let socket = self.active_socket()?;
103
104 let json = serde_json::to_string(&message)
105 .map_err(|e| WebSocketError::SerializationError(e.to_string()))?;
106
107 match Self::attempt_send(socket, &json) {
108 Ok(_) => Ok(()),
109 Err(_) => {
110 tracing::debug!("WebSocket send failed, attempting to reconnect...");
111 thread::sleep(DEFAULT_RECONNECT_DELAY);
112 self.reconnect()?;
113
114 let socket = self.active_socket()?;
115 Self::attempt_send(socket, &json)
116 }
117 }
118 }
119
120 pub fn receive<T: DeserializeOwned>(&mut self) -> Result<Option<T>, WebSocketError> {
123 let socket = self.active_socket()?;
124
125 match socket.read() {
126 Ok(msg) => match msg {
127 Message::Text(text) => {
128 let deserialized: T = serde_json::from_str(&text)
129 .map_err(|e| WebSocketError::SerializationError(e.to_string()))?;
130 Ok(Some(deserialized))
131 }
132 Message::Binary(_) => {
133 tracing::warn!("Received unexpected binary message");
134 Ok(None)
135 }
136 Message::Ping(_) | Message::Pong(_) | Message::Close(_) => Ok(None),
137 Message::Frame(frame) => {
138 tracing::warn!("Received unexpected frame message: {:?}", frame);
139 Ok(None)
140 }
141 },
142 Err(tungstenite::Error::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
143 Ok(None)
145 }
146 Err(e) => Err(WebSocketError::ReceiveError(e.to_string())),
147 }
148 }
149
150 fn attempt_send(socket: &mut Socket, payload: &str) -> Result<(), WebSocketError> {
151 socket
152 .send(Message::Text(Utf8Bytes::from(payload)))
153 .map_err(|e| WebSocketError::SendError(e.to_string()))
154 }
155
156 pub fn close(&mut self) -> Result<(), WebSocketError> {
158 let socket = self.active_socket()?;
159 socket
160 .close(None)
161 .map_err(|e| WebSocketError::SendError(e.to_string()))
162 }
163
164 pub fn wait_until_closed(&mut self) -> Result<(), WebSocketError> {
166 let socket = self.active_socket()?;
167 match socket.get_mut() {
168 MaybeTlsStream::Plain(stream) => stream.set_nonblocking(false),
169 MaybeTlsStream::NativeTls(stream) => stream.get_mut().set_nonblocking(false),
170 _ => unimplemented!("Other TLS streams are not supported"),
171 }
172 .map_err(|e| {
173 WebSocketError::ConnectionError(format!("Failed to set blocking mode: {e}"))
174 })?;
175 loop {
176 match socket.read() {
177 Ok(_) => {}
178 Err(tungstenite::Error::ConnectionClosed | tungstenite::Error::AlreadyClosed) => {
179 tracing::debug!("WebSocket connection closed");
180 break;
181 }
182 Err(e) => {
183 tracing::error!("WebSocket read error while waiting until closed: {e}");
184 return Err(WebSocketError::SendError(e.to_string()));
185 }
186 }
187 }
188 Ok(())
189 }
190
191 fn active_socket(&mut self) -> Result<&mut Socket, WebSocketError> {
192 if let Some(socket) = self.state.as_mut() {
193 Ok(&mut socket.socket)
194 } else {
195 Err(WebSocketError::NotConnected)
196 }
197 }
198}
199
200impl Drop for WebSocketClient {
201 fn drop(&mut self) {
202 _ = self.close();
203 }
204}