Skip to main content

deribit_websocket/connection/
ws_connection.rs

1//! WebSocket connection management
2
3use crate::error::WebSocketError;
4use futures_util::{SinkExt, StreamExt};
5use tokio::net::TcpStream;
6use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
7use url::Url;
8
9/// WebSocket connection wrapper
10#[derive(Debug)]
11pub struct WebSocketConnection {
12    url: Url,
13    stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
14}
15
16impl WebSocketConnection {
17    /// Create a new WebSocket connection
18    pub fn new(url: Url) -> Self {
19        Self { url, stream: None }
20    }
21
22    /// Connect to the WebSocket server
23    pub async fn connect(&mut self) -> Result<(), WebSocketError> {
24        match connect_async(self.url.as_str()).await {
25            Ok((stream, _response)) => {
26                self.stream = Some(stream);
27                Ok(())
28            }
29            Err(e) => Err(WebSocketError::ConnectionFailed(format!(
30                "Failed to connect: {}",
31                e
32            ))),
33        }
34    }
35
36    /// Disconnect from the WebSocket server
37    pub async fn disconnect(&mut self) -> Result<(), WebSocketError> {
38        self.stream = None;
39        Ok(())
40    }
41
42    /// Check if connected
43    pub fn is_connected(&self) -> bool {
44        self.stream.is_some()
45    }
46
47    /// Send a message
48    pub async fn send(&mut self, message: String) -> Result<(), WebSocketError> {
49        if let Some(stream) = &mut self.stream {
50            match stream.send(Message::Text(message.into())).await {
51                Ok(()) => Ok(()),
52                Err(e) => {
53                    self.stream = None;
54                    Err(WebSocketError::ConnectionFailed(format!(
55                        "Failed to send message: {}",
56                        e
57                    )))
58                }
59            }
60        } else {
61            Err(WebSocketError::ConnectionClosed)
62        }
63    }
64
65    /// Receive a message
66    pub async fn receive(&mut self) -> Result<String, WebSocketError> {
67        if let Some(stream) = &mut self.stream {
68            loop {
69                match stream.next().await {
70                    Some(Ok(Message::Text(text))) => return Ok(text.to_string()),
71                    Some(Ok(
72                        Message::Binary(_)
73                        | Message::Ping(_)
74                        | Message::Pong(_)
75                        | Message::Frame(_),
76                    )) => {
77                        // Skip non-text frames and continue draining the stream.
78                        continue;
79                    }
80                    Some(Ok(Message::Close(_))) => {
81                        self.stream = None;
82                        return Err(WebSocketError::ConnectionClosed);
83                    }
84                    Some(Err(e)) => {
85                        self.stream = None;
86                        return Err(WebSocketError::ConnectionFailed(format!(
87                            "Failed to receive message: {}",
88                            e
89                        )));
90                    }
91                    None => {
92                        self.stream = None;
93                        return Err(WebSocketError::ConnectionClosed);
94                    }
95                }
96            }
97        } else {
98            Err(WebSocketError::ConnectionClosed)
99        }
100    }
101
102    /// Get the connection URL
103    pub fn url(&self) -> &Url {
104        &self.url
105    }
106}
107
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use std::net::SocketAddr;
    use tokio::net::TcpListener;
    use tokio::task::JoinHandle;
    use tokio_tungstenite::accept_async;
    use tokio_tungstenite::tungstenite::Message;

    /// Write half of the mock server's WebSocket, as handed to test closures.
    type ServerSink =
        futures_util::stream::SplitSink<WebSocketStream<tokio::net::TcpStream>, Message>;

    /// Start a one-shot WebSocket server on an ephemeral localhost port.
    ///
    /// The server accepts exactly one connection, hands the write half to
    /// `send_frames`, and meanwhile drains the read half in a separate task
    /// so the client's writes are always consumed. Returns the bound address
    /// together with the `JoinHandle` of the acceptor task.
    async fn spawn_mock_server<F, Fut>(send_frames: F) -> (SocketAddr, JoinHandle<()>)
    where
        F: FnOnce(ServerSink) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send,
    {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind localhost ephemeral port");
        let addr = listener
            .local_addr()
            .expect("read local addr of bound listener");

        let handle = tokio::spawn(async move {
            // Any failure while accepting or upgrading simply ends the task.
            let accepted = match listener.accept().await {
                Ok((socket, _peer)) => accept_async(socket).await.ok(),
                Err(_) => None,
            };

            if let Some(ws) = accepted {
                let (sink, mut incoming) = ws.split();
                // Consume everything the client writes until it errors or
                // disconnects, so the client's write side never stalls on a
                // full socket buffer.
                let drain =
                    tokio::spawn(async move { while let Some(Ok(_)) = incoming.next().await {} });
                send_frames(sink).await;
                let _ = drain.await;
            }
        });

        (addr, handle)
    }

    /// Build the `ws://` URL pointing at the mock server.
    fn ws_url(addr: SocketAddr) -> Url {
        let raw = format!("ws://{}/", addr);
        Url::parse(&raw).expect("valid ws url")
    }

    /// Create a client and connect it to the mock server at `addr`.
    async fn connect_client(addr: SocketAddr) -> WebSocketConnection {
        let mut connection = WebSocketConnection::new(ws_url(addr));
        connection
            .connect()
            .await
            .expect("client connects to mock server");
        connection
    }

    /// Send `frames` in order, `rounds` times over, then one text frame
    /// carrying `"payload"`. Stops silently as soon as a send fails.
    async fn send_frames_then_payload(mut sink: ServerSink, frames: Vec<Message>, rounds: usize) {
        for _ in 0..rounds {
            for frame in &frames {
                if sink.send(frame.clone()).await.is_err() {
                    return;
                }
            }
        }
        let _ = sink.send(Message::Text("payload".into())).await;
    }

    /// Connect a client, check that its first `receive` yields `"payload"`,
    /// then drop the client and wait for the server task to finish.
    async fn assert_client_receives_payload(addr: SocketAddr, server: JoinHandle<()>) {
        let mut client = connect_client(addr).await;
        let received = client.receive().await.expect("receive returns the text");
        assert_eq!(received, "payload");
        drop(client);
        server.await.expect("server task did not panic");
    }

    #[tokio::test]
    async fn test_receive_skips_ping_frames_then_returns_text() {
        let (addr, server) = spawn_mock_server(|sink| {
            send_frames_then_payload(sink, vec![Message::Ping(Vec::new().into())], 10_000)
        })
        .await;

        assert_client_receives_payload(addr, server).await;
    }

    #[tokio::test]
    async fn test_receive_skips_binary_frames_then_returns_text() {
        let (addr, server) = spawn_mock_server(|sink| {
            send_frames_then_payload(sink, vec![Message::Binary(vec![1, 2, 3].into())], 100)
        })
        .await;

        assert_client_receives_payload(addr, server).await;
    }

    #[tokio::test]
    async fn test_receive_skips_pong_frames_then_returns_text() {
        let (addr, server) = spawn_mock_server(|sink| {
            send_frames_then_payload(sink, vec![Message::Pong(Vec::new().into())], 100)
        })
        .await;

        assert_client_receives_payload(addr, server).await;
    }

    #[tokio::test]
    async fn test_receive_returns_closed_on_close_frame() {
        let (addr, server) = spawn_mock_server(|mut sink| async move {
            let _ = sink.send(Message::Close(None)).await;
            let _ = sink.close().await;
        })
        .await;

        let mut client = connect_client(addr).await;
        let outcome = client.receive().await;
        assert!(
            matches!(outcome, Err(WebSocketError::ConnectionClosed)),
            "expected ConnectionClosed, got {:?}",
            outcome
        );
        assert!(
            !client.is_connected(),
            "stream should be cleared after close frame"
        );
        drop(client);
        server.await.expect("server task did not panic");
    }

    #[tokio::test]
    async fn test_receive_skips_mixed_non_text_frames() {
        let (addr, server) = spawn_mock_server(|sink| {
            let cycle = vec![
                Message::Ping(Vec::new().into()),
                Message::Binary(vec![9, 9, 9].into()),
                Message::Pong(Vec::new().into()),
            ];
            send_frames_then_payload(sink, cycle, 200)
        })
        .await;

        assert_client_receives_payload(addr, server).await;
    }
}