// burn_central_client/websocket.rs

use std::{thread, time::Duration};

use reqwest::header::{COOKIE, HeaderValue};
use serde::Serialize;

use thiserror::Error;

use tungstenite::{
    Message, Utf8Bytes, WebSocket, client::IntoClientRequest, connect, stream::MaybeTlsStream,
};

pub use crate::experiment::websocket::*;

14#[derive(Error, Debug)]
15#[allow(clippy::enum_variant_names)]
16pub enum WebSocketError {
17    #[error("Failed to connect WebSocket: {0}")]
18    ConnectionError(String),
19    #[error("WebSocket send error: {0}")]
20    SendError(String),
21    #[error("WebSocket is not connected")]
22    NotConnected,
23    #[error("WebSocket cannot reconnect: {0}")]
24    CannotReconnect(String),
25}
26
/// Pause inserted after a failed send, before the client tries to reconnect (one second).
const DEFAULT_RECONNECT_DELAY: Duration = Duration::from_secs(1);

29type Socket = WebSocket<MaybeTlsStream<std::net::TcpStream>>;
30struct ConnectedSocket {
31    socket: Socket,
32    url: String,
33    cookie: String,
34}
35
36#[derive(Default)]
37pub struct WebSocketClient {
38    state: Option<ConnectedSocket>,
39}
40
41impl WebSocketClient {
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    #[allow(dead_code)]
47    pub fn is_connected(&self) -> bool {
48        self.state.is_some()
49    }
50
51    pub fn connect(&mut self, url: String, session_cookie: &str) -> Result<(), WebSocketError> {
52        let mut req = url
53            .clone()
54            .into_client_request()
55            .expect("Should be able to create a client request from the URL");
56
57        req.headers_mut().append(
58            COOKIE,
59            session_cookie
60                .parse()
61                .expect("Should be able to parse cookie header"),
62        );
63
64        let (socket, _) =
65            connect(req).map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
66
67        self.state = Some(ConnectedSocket {
68            socket,
69            url,
70            cookie: session_cookie.to_string(),
71        });
72        Ok(())
73    }
74
75    pub fn reconnect(&mut self) -> Result<(), WebSocketError> {
76        if let Some(socket) = self.state.take() {
77            self.connect(socket.url, &socket.cookie)
78        } else {
79            Err(WebSocketError::CannotReconnect(
80                "The websocket was never opened so it cannot be reconnected".to_string(),
81            ))
82        }
83    }
84
85    pub fn send<I: Serialize>(&mut self, message: I) -> Result<(), WebSocketError> {
86        let socket = self.active_socket()?;
87
88        let json = serde_json::to_string(&message)
89            .map_err(|e| WebSocketError::SendError(e.to_string()))?;
90
91        match Self::attempt_send(socket, &json) {
92            Ok(_) => Ok(()),
93            Err(_) => {
94                tracing::debug!("WebSocket send failed, attempting to reconnect...");
95                thread::sleep(DEFAULT_RECONNECT_DELAY);
96                self.reconnect()?;
97
98                let socket = self.active_socket()?;
99                Self::attempt_send(socket, &json)
100            }
101        }
102    }
103
104    fn attempt_send(socket: &mut Socket, payload: &str) -> Result<(), WebSocketError> {
105        socket
106            .send(Message::Text(Utf8Bytes::from(payload)))
107            .map_err(|e| WebSocketError::SendError(e.to_string()))
108    }
109
110    pub fn close(&mut self) -> Result<(), WebSocketError> {
111        let socket = self.active_socket()?;
112        socket
113            .close(None)
114            .map_err(|e| WebSocketError::SendError(e.to_string()))
115    }
116
117    pub fn wait_until_closed(&mut self) -> Result<(), WebSocketError> {
118        let socket = self.active_socket()?;
119        loop {
120            match socket.read() {
121                Ok(_) => {}
122                Err(tungstenite::Error::ConnectionClosed | tungstenite::Error::AlreadyClosed) => {
123                    tracing::debug!("WebSocket connection closed");
124                    break;
125                }
126                Err(e) => {
127                    tracing::error!("WebSocket read error while waiting until closed: {e}");
128                    return Err(WebSocketError::SendError(e.to_string()));
129                }
130            }
131        }
132        Ok(())
133    }
134
135    fn active_socket(&mut self) -> Result<&mut Socket, WebSocketError> {
136        if let Some(socket) = self.state.as_mut() {
137            Ok(&mut socket.socket)
138        } else {
139            Err(WebSocketError::NotConnected)
140        }
141    }
142}
143
144impl Drop for WebSocketClient {
145    fn drop(&mut self) {
146        _ = self.close();
147    }
148}