deribit_websocket/connection/
ws_connection.rs1use crate::error::WebSocketError;
4use futures_util::{SinkExt, StreamExt};
5use tokio::net::TcpStream;
6use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
7use url::Url;
8
9#[derive(Debug)]
11pub struct WebSocketConnection {
12 url: Url,
13 stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
14}
15
16impl WebSocketConnection {
17 pub fn new(url: Url) -> Self {
19 Self { url, stream: None }
20 }
21
22 pub async fn connect(&mut self) -> Result<(), WebSocketError> {
24 match connect_async(self.url.as_str()).await {
25 Ok((stream, _response)) => {
26 self.stream = Some(stream);
27 Ok(())
28 }
29 Err(e) => Err(WebSocketError::ConnectionFailed(format!(
30 "Failed to connect: {}",
31 e
32 ))),
33 }
34 }
35
36 pub async fn disconnect(&mut self) -> Result<(), WebSocketError> {
38 self.stream = None;
39 Ok(())
40 }
41
42 pub fn is_connected(&self) -> bool {
44 self.stream.is_some()
45 }
46
47 pub async fn send(&mut self, message: String) -> Result<(), WebSocketError> {
49 if let Some(stream) = &mut self.stream {
50 match stream.send(Message::Text(message.into())).await {
51 Ok(()) => Ok(()),
52 Err(e) => {
53 self.stream = None;
54 Err(WebSocketError::ConnectionFailed(format!(
55 "Failed to send message: {}",
56 e
57 )))
58 }
59 }
60 } else {
61 Err(WebSocketError::ConnectionClosed)
62 }
63 }
64
65 pub async fn receive(&mut self) -> Result<String, WebSocketError> {
67 if let Some(stream) = &mut self.stream {
68 loop {
69 match stream.next().await {
70 Some(Ok(Message::Text(text))) => return Ok(text.to_string()),
71 Some(Ok(
72 Message::Binary(_)
73 | Message::Ping(_)
74 | Message::Pong(_)
75 | Message::Frame(_),
76 )) => {
77 continue;
79 }
80 Some(Ok(Message::Close(_))) => {
81 self.stream = None;
82 return Err(WebSocketError::ConnectionClosed);
83 }
84 Some(Err(e)) => {
85 self.stream = None;
86 return Err(WebSocketError::ConnectionFailed(format!(
87 "Failed to receive message: {}",
88 e
89 )));
90 }
91 None => {
92 self.stream = None;
93 return Err(WebSocketError::ConnectionClosed);
94 }
95 }
96 }
97 } else {
98 Err(WebSocketError::ConnectionClosed)
99 }
100 }
101
102 pub fn url(&self) -> &Url {
104 &self.url
105 }
106}
107
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use std::net::SocketAddr;
    use tokio::net::TcpListener;
    use tokio::task::JoinHandle;
    use tokio_tungstenite::accept_async;
    use tokio_tungstenite::tungstenite::Message;

    /// Write half of the server-side socket handed to each test scenario.
    type ServerSink =
        futures_util::stream::SplitSink<WebSocketStream<tokio::net::TcpStream>, Message>;

    /// Starts a one-shot WebSocket server on an ephemeral localhost port.
    ///
    /// The server accepts a single client and hands the write half to
    /// `send_frames`. Meanwhile it drains whatever the client sends until the
    /// client stream ends or errors. The returned handle completes after
    /// `send_frames` finishes and that drain has stopped.
    async fn spawn_mock_server<F, Fut>(send_frames: F) -> (SocketAddr, JoinHandle<()>)
    where
        F: FnOnce(ServerSink) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send,
    {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind localhost ephemeral port");
        let addr = listener
            .local_addr()
            .expect("read local addr of bound listener");

        let server = tokio::spawn(async move {
            let Ok((tcp, _)) = listener.accept().await else {
                return;
            };
            let Ok(websocket) = accept_async(tcp).await else {
                return;
            };
            let (sink, mut incoming) = websocket.split();
            let reader = tokio::spawn(async move {
                while let Some(Ok(_)) = incoming.next().await {}
            });
            send_frames(sink).await;
            let _ = reader.await;
        });

        (addr, server)
    }

    /// Sends `rounds` repetitions of the frames built by `batch`, then a final
    /// `"payload"` text frame. Gives up quietly as soon as a send fails.
    async fn send_noise_then_payload(
        mut sink: ServerSink,
        rounds: usize,
        batch: fn() -> Vec<Message>,
    ) {
        for _ in 0..rounds {
            for frame in batch() {
                if sink.send(frame).await.is_err() {
                    return;
                }
            }
        }
        let _ = sink.send(Message::Text("payload".into())).await;
    }

    /// Builds the `ws://` URL for a mock server address.
    fn ws_url(addr: SocketAddr) -> Url {
        let raw = format!("ws://{}/", addr);
        Url::parse(&raw).expect("valid ws url")
    }

    /// Creates a client for `addr` and connects it, panicking on failure.
    async fn connect_client(addr: SocketAddr) -> WebSocketConnection {
        let mut conn = WebSocketConnection::new(ws_url(addr));
        conn.connect()
            .await
            .expect("client connects to mock server");
        conn
    }

    /// Connects a client, asserts that its first `receive` yields `"payload"`,
    /// then drops the client and waits for the server task to finish.
    async fn expect_payload(addr: SocketAddr, server: JoinHandle<()>) {
        let mut client = connect_client(addr).await;
        let got = client.receive().await.expect("receive returns the text");
        assert_eq!(got, "payload");
        drop(client);
        server.await.expect("server task did not panic");
    }

    #[tokio::test]
    async fn test_receive_skips_ping_frames_then_returns_text() {
        let (addr, server) = spawn_mock_server(|sink| {
            send_noise_then_payload(sink, 10_000, || vec![Message::Ping(Vec::new().into())])
        })
        .await;
        expect_payload(addr, server).await;
    }

    #[tokio::test]
    async fn test_receive_skips_binary_frames_then_returns_text() {
        let (addr, server) = spawn_mock_server(|sink| {
            send_noise_then_payload(sink, 100, || vec![Message::Binary(vec![1, 2, 3].into())])
        })
        .await;
        expect_payload(addr, server).await;
    }

    #[tokio::test]
    async fn test_receive_skips_pong_frames_then_returns_text() {
        let (addr, server) = spawn_mock_server(|sink| {
            send_noise_then_payload(sink, 100, || vec![Message::Pong(Vec::new().into())])
        })
        .await;
        expect_payload(addr, server).await;
    }

    #[tokio::test]
    async fn test_receive_returns_closed_on_close_frame() {
        let (addr, server) = spawn_mock_server(|mut sink| async move {
            let _ = sink.send(Message::Close(None)).await;
            let _ = sink.close().await;
        })
        .await;

        let mut client = connect_client(addr).await;
        let outcome = client.receive().await;
        assert!(
            matches!(outcome, Err(WebSocketError::ConnectionClosed)),
            "expected ConnectionClosed, got {:?}",
            outcome
        );
        assert!(
            !client.is_connected(),
            "stream should be cleared after close frame"
        );
        drop(client);
        server.await.expect("server task did not panic");
    }

    #[tokio::test]
    async fn test_receive_skips_mixed_non_text_frames() {
        let (addr, server) = spawn_mock_server(|sink| {
            send_noise_then_payload(sink, 200, || {
                vec![
                    Message::Ping(Vec::new().into()),
                    Message::Binary(vec![9, 9, 9].into()),
                    Message::Pong(Vec::new().into()),
                ]
            })
        })
        .await;
        expect_payload(addr, server).await;
    }
}