burn_central_client/
websocket.rs1use std::{thread, time::Duration};
2
3use reqwest::header::COOKIE;
4use serde::Serialize;
5
6use thiserror::Error;
7
8use tungstenite::{
9 Message, Utf8Bytes, WebSocket, client::IntoClientRequest, connect, stream::MaybeTlsStream,
10};
11
12pub use crate::experiment::websocket::*;
13
14#[derive(Error, Debug)]
15#[allow(clippy::enum_variant_names)]
16pub enum WebSocketError {
17 #[error("Failed to connect WebSocket: {0}")]
18 ConnectionError(String),
19 #[error("WebSocket send error: {0}")]
20 SendError(String),
21 #[error("WebSocket is not connected")]
22 NotConnected,
23 #[error("WebSocket cannot reconnect: {0}")]
24 CannotReconnect(String),
25}
26
/// Pause observed between a failed send and the reconnection attempt that
/// follows it (one second).
const DEFAULT_RECONNECT_DELAY: Duration = Duration::from_secs(1);
28
29type Socket = WebSocket<MaybeTlsStream<std::net::TcpStream>>;
30struct ConnectedSocket {
31 socket: Socket,
32 url: String,
33 cookie: String,
34}
35
36#[derive(Default)]
37pub struct WebSocketClient {
38 state: Option<ConnectedSocket>,
39}
40
41impl WebSocketClient {
42 pub fn new() -> Self {
43 Self::default()
44 }
45
46 #[allow(dead_code)]
47 pub fn is_connected(&self) -> bool {
48 self.state.is_some()
49 }
50
51 pub fn connect(&mut self, url: String, session_cookie: &str) -> Result<(), WebSocketError> {
52 let mut req = url
53 .clone()
54 .into_client_request()
55 .expect("Should be able to create a client request from the URL");
56
57 req.headers_mut().append(
58 COOKIE,
59 session_cookie
60 .parse()
61 .expect("Should be able to parse cookie header"),
62 );
63
64 let (socket, _) =
65 connect(req).map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
66
67 self.state = Some(ConnectedSocket {
68 socket,
69 url,
70 cookie: session_cookie.to_string(),
71 });
72 Ok(())
73 }
74
75 pub fn reconnect(&mut self) -> Result<(), WebSocketError> {
76 if let Some(socket) = self.state.take() {
77 self.connect(socket.url, &socket.cookie)
78 } else {
79 Err(WebSocketError::CannotReconnect(
80 "The websocket was never opened so it cannot be reconnected".to_string(),
81 ))
82 }
83 }
84
85 pub fn send<I: Serialize>(&mut self, message: I) -> Result<(), WebSocketError> {
86 let socket = self.active_socket()?;
87
88 let json = serde_json::to_string(&message)
89 .map_err(|e| WebSocketError::SendError(e.to_string()))?;
90
91 match Self::attempt_send(socket, &json) {
92 Ok(_) => Ok(()),
93 Err(_) => {
94 tracing::debug!("WebSocket send failed, attempting to reconnect...");
95 thread::sleep(DEFAULT_RECONNECT_DELAY);
96 self.reconnect()?;
97
98 let socket = self.active_socket()?;
99 Self::attempt_send(socket, &json)
100 }
101 }
102 }
103
104 fn attempt_send(socket: &mut Socket, payload: &str) -> Result<(), WebSocketError> {
105 socket
106 .send(Message::Text(Utf8Bytes::from(payload)))
107 .map_err(|e| WebSocketError::SendError(e.to_string()))
108 }
109
110 pub fn close(&mut self) -> Result<(), WebSocketError> {
111 let socket = self.active_socket()?;
112 socket
113 .close(None)
114 .map_err(|e| WebSocketError::SendError(e.to_string()))
115 }
116
117 pub fn wait_until_closed(&mut self) -> Result<(), WebSocketError> {
118 let socket = self.active_socket()?;
119 loop {
120 match socket.read() {
121 Ok(_) => {}
122 Err(tungstenite::Error::ConnectionClosed | tungstenite::Error::AlreadyClosed) => {
123 tracing::debug!("WebSocket connection closed");
124 break;
125 }
126 Err(e) => {
127 tracing::error!("WebSocket read error while waiting until closed: {e}");
128 return Err(WebSocketError::SendError(e.to_string()));
129 }
130 }
131 }
132 Ok(())
133 }
134
135 fn active_socket(&mut self) -> Result<&mut Socket, WebSocketError> {
136 if let Some(socket) = self.state.as_mut() {
137 Ok(&mut socket.socket)
138 } else {
139 Err(WebSocketError::NotConnected)
140 }
141 }
142}
143
144impl Drop for WebSocketClient {
145 fn drop(&mut self) {
146 _ = self.close();
147 }
148}