1use futures::{SinkExt, StreamExt};
40use std::io::{Error as IoError, ErrorKind};
41use tokio::time::{Duration, timeout};
42use tokio_tungstenite::{
43 MaybeTlsStream, WebSocketStream, connect_async,
44 tungstenite::{Error as WsError, Message},
45};
46
/// A connected WebSocket client wrapper.
///
/// Bundles the `tokio-tungstenite` stream returned by `connect_async` with the
/// URL string that was used to open it.
pub struct WebSocketTestClient {
    /// The established WebSocket stream; all sends and receives go through it.
    stream: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
    /// Copy of the URL string the connection was opened with (see `url()`).
    url: String,
}
57
58impl WebSocketTestClient {
59 pub async fn connect(url: &str) -> Result<Self, WsError> {
73 let (stream, _response) = connect_async(url).await?;
74 Ok(Self {
75 stream,
76 url: url.to_string(),
77 })
78 }
79
80 pub async fn connect_with_token(url: &str, token: &str) -> Result<Self, WsError> {
99 use tokio_tungstenite::tungstenite::http::Request;
100
101 let request = Request::builder()
102 .uri(url)
103 .header("Authorization", format!("Bearer {}", token))
104 .body(())
105 .expect("Failed to build WebSocket request");
106
107 let (stream, _response) = connect_async(request).await?;
108 Ok(Self {
109 stream,
110 url: url.to_string(),
111 })
112 }
113
114 pub async fn connect_with_query_token(url: &str, token: &str) -> Result<Self, WsError> {
134 let url_with_token = format!("{}?token={}", url, urlencoding::encode(token));
135 Self::connect(&url_with_token).await
136 }
137
138 pub async fn connect_with_cookie(
158 url: &str,
159 cookie_name: &str,
160 cookie_value: &str,
161 ) -> Result<Self, WsError> {
162 use tokio_tungstenite::tungstenite::http::Request;
163
164 let request = Request::builder()
165 .uri(url)
166 .header("Cookie", format!("{}={}", cookie_name, cookie_value))
167 .body(())
168 .expect("Failed to build WebSocket request");
169
170 let (stream, _response) = connect_async(request).await?;
171 Ok(Self {
172 stream,
173 url: url.to_string(),
174 })
175 }
176
177 pub async fn send_text(&mut self, text: &str) -> Result<(), WsError> {
192 self.stream.send(Message::Text(text.to_string())).await
193 }
194
195 pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), WsError> {
197 self.stream.send(Message::Binary(data.to_vec())).await
198 }
199
200 pub async fn send_ping(&mut self, payload: &[u8]) -> Result<(), WsError> {
202 self.stream.send(Message::Ping(payload.to_vec())).await
203 }
204
205 pub async fn send_pong(&mut self, payload: &[u8]) -> Result<(), WsError> {
207 self.stream.send(Message::Pong(payload.to_vec())).await
208 }
209
210 pub async fn receive(&mut self) -> Option<Result<Message, WsError>> {
214 self.stream.next().await
215 }
216
217 pub async fn receive_text(&mut self) -> Result<String, WsError> {
233 self.receive_text_with_timeout(Duration::from_secs(5)).await
234 }
235
236 pub async fn receive_text_with_timeout(
238 &mut self,
239 duration: Duration,
240 ) -> Result<String, WsError> {
241 match timeout(duration, self.stream.next()).await {
242 Ok(Some(Ok(Message::Text(text)))) => Ok(text),
243 Ok(Some(Ok(msg))) => Err(WsError::Io(IoError::new(
244 ErrorKind::InvalidData,
245 format!("Expected text message, got {:?}", msg),
246 ))),
247 Ok(Some(Err(e))) => Err(e),
248 Ok(None) => Err(WsError::ConnectionClosed),
249 Err(_) => Err(WsError::Io(IoError::new(
250 ErrorKind::TimedOut,
251 "Receive timeout",
252 ))),
253 }
254 }
255
256 pub async fn receive_binary(&mut self) -> Result<Vec<u8>, WsError> {
258 self.receive_binary_with_timeout(Duration::from_secs(5))
259 .await
260 }
261
262 pub async fn receive_binary_with_timeout(
264 &mut self,
265 duration: Duration,
266 ) -> Result<Vec<u8>, WsError> {
267 match timeout(duration, self.stream.next()).await {
268 Ok(Some(Ok(Message::Binary(data)))) => Ok(data),
269 Ok(Some(Ok(msg))) => Err(WsError::Io(IoError::new(
270 ErrorKind::InvalidData,
271 format!("Expected binary message, got {:?}", msg),
272 ))),
273 Ok(Some(Err(e))) => Err(e),
274 Ok(None) => Err(WsError::ConnectionClosed),
275 Err(_) => Err(WsError::Io(IoError::new(
276 ErrorKind::TimedOut,
277 "Receive timeout",
278 ))),
279 }
280 }
281
282 pub async fn close(mut self) -> Result<(), WsError> {
284 self.stream.close(None).await
285 }
286
287 pub fn url(&self) -> &str {
289 &self.url
290 }
291}
292
293pub mod assertions {
295 use tokio_tungstenite::tungstenite::Message;
296
297 pub fn assert_message_text(msg: &Message, expected: &str) {
308 match msg {
309 Message::Text(text) => assert_eq!(text, expected),
310 _ => panic!("Expected text message, got {:?}", msg),
311 }
312 }
313
314 pub fn assert_message_contains(msg: &Message, substring: &str) {
316 match msg {
317 Message::Text(text) => assert!(
318 text.contains(substring),
319 "Message '{}' does not contain '{}'",
320 text,
321 substring
322 ),
323 _ => panic!("Expected text message, got {:?}", msg),
324 }
325 }
326
327 pub fn assert_message_binary(msg: &Message, expected: &[u8]) {
329 match msg {
330 Message::Binary(data) => assert_eq!(data, expected),
331 _ => panic!("Expected binary message, got {:?}", msg),
332 }
333 }
334
335 pub fn assert_message_ping(msg: &Message) {
337 match msg {
338 Message::Ping(_) => {}
339 _ => panic!("Expected ping message, got {:?}", msg),
340 }
341 }
342
343 pub fn assert_message_pong(msg: &Message) {
345 match msg {
346 Message::Pong(_) => {}
347 _ => panic!("Expected pong message, got {:?}", msg),
348 }
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint shared by the URL-formatting tests.
    const BASE_URL: &str = "ws://localhost:8080/ws";

    /// Builds `<url>?token=<percent-encoded token>`, the format under test.
    fn append_query_token(url: &str, token: &str) -> String {
        format!("{}?token={}", url, urlencoding::encode(token))
    }

    /// A token made only of unreserved characters is appended verbatim.
    #[test]
    fn test_url_with_query_token() {
        let actual = append_query_token(BASE_URL, "my-token");
        assert_eq!(actual, "ws://localhost:8080/ws?token=my-token");
    }

    /// Spaces, `&` and `=` in the token are percent-encoded.
    #[test]
    fn test_url_with_query_token_special_chars() {
        let actual = append_query_token(BASE_URL, "token with spaces&special=chars");
        assert_eq!(
            actual,
            "ws://localhost:8080/ws?token=token%20with%20spaces%26special%3Dchars"
        );
    }

    /// Each assertion helper accepts a message of its matching kind.
    #[test]
    fn test_message_assertions() {
        use super::assertions::{
            assert_message_binary, assert_message_contains, assert_message_ping,
            assert_message_pong, assert_message_text,
        };

        let text = Message::Text("Hello".to_string());
        assert_message_text(&text, "Hello");
        assert_message_contains(&text, "ell");

        let binary = Message::Binary(vec![1, 2, 3]);
        assert_message_binary(&binary, &[1, 2, 3]);

        let ping = Message::Ping(vec![]);
        assert_message_ping(&ping);

        let pong = Message::Pong(vec![]);
        assert_message_pong(&pong);
    }
}
394}