burn_central_client/
websocket.rs

1use std::{thread, time::Duration};
2
3use reqwest::header::COOKIE;
4use serde::Serialize;
5
6use thiserror::Error;
7
8use tungstenite::{
9    Message, Utf8Bytes, WebSocket, client::IntoClientRequest, connect, stream::MaybeTlsStream,
10};
11
12pub use crate::experiment::websocket::*;
13
14#[derive(Error, Debug)]
15#[allow(clippy::enum_variant_names)]
16pub enum WebSocketError {
17    #[error("Failed to connect WebSocket: {0}")]
18    ConnectionError(String),
19    #[error("WebSocket send error: {0}")]
20    SendError(String),
21    #[error("WebSocket is not connected")]
22    NotConnected,
23    #[error("WebSocket cannot reconnect: {0}")]
24    CannotReconnect(String),
25}
26
/// How long `WebSocketClient::send` sleeps after a failed send before it
/// tries to reconnect (one second).
const DEFAULT_RECONNECT_DELAY: Duration = Duration::from_secs(1);
28
29type Socket = WebSocket<MaybeTlsStream<std::net::TcpStream>>;
30struct ConnectedSocket {
31    socket: Socket,
32    url: String,
33    cookie: String,
34}
35
36#[derive(Default)]
37pub struct WebSocketClient {
38    state: Option<ConnectedSocket>,
39}
40
41impl WebSocketClient {
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    #[allow(dead_code)]
47    pub fn is_connected(&self) -> bool {
48        self.state.is_some()
49    }
50
51    pub fn connect(&mut self, url: String, session_cookie: &str) -> Result<(), WebSocketError> {
52        let mut req = url
53            .clone()
54            .into_client_request()
55            .expect("Should be able to create a client request from the URL");
56
57        req.headers_mut().append(
58            COOKIE,
59            session_cookie
60                .parse()
61                .expect("Should be able to parse cookie header"),
62        );
63
64        let (socket, _) =
65            connect(req).map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
66
67        self.state = Some(ConnectedSocket {
68            socket,
69            url,
70            cookie: session_cookie.to_string(),
71        });
72        Ok(())
73    }
74
75    pub fn reconnect(&mut self) -> Result<(), WebSocketError> {
76        if let Some(socket) = self.state.take() {
77            self.connect(socket.url, &socket.cookie)
78        } else {
79            Err(WebSocketError::CannotReconnect(
80                "The websocket was never opened so it cannot be reconnected".to_string(),
81            ))
82        }
83    }
84
85    pub fn send<I: Serialize>(&mut self, message: I) -> Result<(), WebSocketError> {
86        let socket = self.active_socket()?;
87
88        let json = serde_json::to_string(&message)
89            .map_err(|e| WebSocketError::SendError(e.to_string()))?;
90
91        match Self::attempt_send(socket, &json) {
92            Ok(_) => Ok(()),
93            Err(_) => {
94                thread::sleep(DEFAULT_RECONNECT_DELAY);
95                self.reconnect()?;
96
97                let socket = self.active_socket()?;
98                Self::attempt_send(socket, &json)
99            }
100        }
101    }
102
103    fn attempt_send(socket: &mut Socket, payload: &str) -> Result<(), WebSocketError> {
104        socket
105            .send(Message::Text(Utf8Bytes::from(payload)))
106            .map_err(|e| WebSocketError::SendError(e.to_string()))
107    }
108
109    pub fn close(&mut self) -> Result<(), WebSocketError> {
110        let socket = self.active_socket()?;
111        socket
112            .close(None)
113            .map_err(|e| WebSocketError::SendError(e.to_string()))
114    }
115
116    fn active_socket(&mut self) -> Result<&mut Socket, WebSocketError> {
117        if let Some(socket) = self.state.as_mut() {
118            Ok(&mut socket.socket)
119        } else {
120            Err(WebSocketError::NotConnected)
121        }
122    }
123}
124
125impl Drop for WebSocketClient {
126    fn drop(&mut self) {
127        _ = self.close();
128    }
129}