1use futures::{SinkExt, StreamExt};
40use std::io::{Error as IoError, ErrorKind};
41use tokio::time::{Duration, timeout};
42use tokio_tungstenite::{
43 MaybeTlsStream, WebSocketStream, connect_async,
44 tungstenite::{Error as WsError, Message},
45};
46
47pub struct WebSocketTestClient {
52 stream: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
54 url: String,
56}
57
58impl WebSocketTestClient {
59 pub async fn connect(url: &str) -> Result<Self, WsError> {
73 let (stream, _response) = connect_async(url).await?;
74 Ok(Self {
75 stream,
76 url: url.to_string(),
77 })
78 }
79
80 pub async fn connect_with_token(url: &str, token: &str) -> Result<Self, WsError> {
99 use tokio_tungstenite::tungstenite::http::Request;
100
101 let request = Request::builder()
102 .uri(url)
103 .header("Authorization", format!("Bearer {}", token))
104 .body(())
105 .expect("Failed to build WebSocket request");
106
107 let (stream, _response) = connect_async(request).await?;
108 Ok(Self {
109 stream,
110 url: url.to_string(),
111 })
112 }
113
114 pub async fn connect_with_query_token(url: &str, token: &str) -> Result<Self, WsError> {
134 let url_with_token = format!("{}?token={}", url, urlencoding::encode(token));
135 Self::connect(&url_with_token).await
136 }
137
138 pub async fn connect_with_cookie(
158 url: &str,
159 cookie_name: &str,
160 cookie_value: &str,
161 ) -> Result<Self, WsError> {
162 use tokio_tungstenite::tungstenite::http::Request;
163
164 let request = Request::builder()
165 .uri(url)
166 .header("Cookie", format!("{}={}", cookie_name, cookie_value))
167 .body(())
168 .expect("Failed to build WebSocket request");
169
170 let (stream, _response) = connect_async(request).await?;
171 Ok(Self {
172 stream,
173 url: url.to_string(),
174 })
175 }
176
177 pub async fn send_text(&mut self, text: &str) -> Result<(), WsError> {
192 self.stream.send(Message::text(text)).await
193 }
194
195 pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), WsError> {
197 self.stream.send(Message::binary(data.to_vec())).await
198 }
199
200 pub async fn send_ping(&mut self, payload: &[u8]) -> Result<(), WsError> {
202 self.stream
203 .send(Message::Ping(payload.to_vec().into()))
204 .await
205 }
206
207 pub async fn send_pong(&mut self, payload: &[u8]) -> Result<(), WsError> {
209 self.stream
210 .send(Message::Pong(payload.to_vec().into()))
211 .await
212 }
213
214 pub async fn receive(&mut self) -> Option<Result<Message, WsError>> {
218 self.stream.next().await
219 }
220
221 pub async fn receive_text(&mut self) -> Result<String, WsError> {
237 self.receive_text_with_timeout(Duration::from_secs(5)).await
238 }
239
240 pub async fn receive_text_with_timeout(
242 &mut self,
243 duration: Duration,
244 ) -> Result<String, WsError> {
245 match timeout(duration, self.stream.next()).await {
246 Ok(Some(Ok(Message::Text(text)))) => Ok(text.to_string()),
247 Ok(Some(Ok(msg))) => Err(WsError::Io(IoError::new(
248 ErrorKind::InvalidData,
249 format!("Expected text message, got {:?}", msg),
250 ))),
251 Ok(Some(Err(e))) => Err(e),
252 Ok(None) => Err(WsError::ConnectionClosed),
253 Err(_) => Err(WsError::Io(IoError::new(
254 ErrorKind::TimedOut,
255 "Receive timeout",
256 ))),
257 }
258 }
259
260 pub async fn receive_binary(&mut self) -> Result<Vec<u8>, WsError> {
262 self.receive_binary_with_timeout(Duration::from_secs(5))
263 .await
264 }
265
266 pub async fn receive_binary_with_timeout(
268 &mut self,
269 duration: Duration,
270 ) -> Result<Vec<u8>, WsError> {
271 match timeout(duration, self.stream.next()).await {
272 Ok(Some(Ok(Message::Binary(data)))) => Ok(data.to_vec()),
273 Ok(Some(Ok(msg))) => Err(WsError::Io(IoError::new(
274 ErrorKind::InvalidData,
275 format!("Expected binary message, got {:?}", msg),
276 ))),
277 Ok(Some(Err(e))) => Err(e),
278 Ok(None) => Err(WsError::ConnectionClosed),
279 Err(_) => Err(WsError::Io(IoError::new(
280 ErrorKind::TimedOut,
281 "Receive timeout",
282 ))),
283 }
284 }
285
286 pub async fn close(mut self) -> Result<(), WsError> {
288 self.stream.close(None).await
289 }
290
291 pub fn url(&self) -> &str {
293 &self.url
294 }
295}
296
297pub mod assertions {
299 use tokio_tungstenite::tungstenite::Message;
300
301 pub fn assert_message_text(msg: &Message, expected: &str) {
312 match msg {
313 Message::Text(text) => assert_eq!(text.as_str(), expected),
314 _ => panic!("Expected text message, got {:?}", msg),
315 }
316 }
317
318 pub fn assert_message_contains(msg: &Message, substring: &str) {
320 match msg {
321 Message::Text(text) => assert!(
322 text.contains(substring),
323 "Message '{}' does not contain '{}'",
324 text,
325 substring
326 ),
327 _ => panic!("Expected text message, got {:?}", msg),
328 }
329 }
330
331 pub fn assert_message_binary(msg: &Message, expected: &[u8]) {
333 match msg {
334 Message::Binary(data) => assert_eq!(data.as_ref(), expected),
335 _ => panic!("Expected binary message, got {:?}", msg),
336 }
337 }
338
339 pub fn assert_message_ping(msg: &Message) {
341 match msg {
342 Message::Ping(_) => {}
343 _ => panic!("Expected ping message, got {:?}", msg),
344 }
345 }
346
347 pub fn assert_message_pong(msg: &Message) {
349 match msg {
350 Message::Pong(_) => {}
351 _ => panic!("Expected pong message, got {:?}", msg),
352 }
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Formats `<base>?token=<percent-encoded token>`, matching how
    /// `connect_with_query_token` builds its URL for a base without a query string.
    // NOTE(review): this mirrors the URL construction instead of calling it, so these tests
    // do not exercise the client itself.
    fn query_token_url(base: &str, token: &str) -> String {
        let encoded = urlencoding::encode(token);
        format!("{}?token={}", base, encoded)
    }

    #[test]
    fn test_url_with_query_token() {
        assert_eq!(
            query_token_url("ws://localhost:8080/ws", "my-token"),
            "ws://localhost:8080/ws?token=my-token"
        );
    }

    #[test]
    fn test_url_with_query_token_special_chars() {
        // Space, '&' and '=' must all be escaped so they cannot break the query string.
        let built = query_token_url("ws://localhost:8080/ws", "token with spaces&special=chars");
        assert_eq!(
            built,
            "ws://localhost:8080/ws?token=token%20with%20spaces%26special%3Dchars"
        );
    }

    #[test]
    fn test_message_assertions() {
        use assertions::*;

        let greeting = Message::text("Hello");
        assert_message_text(&greeting, "Hello");
        assert_message_contains(&greeting, "ell");

        assert_message_binary(&Message::Binary(vec![1, 2, 3].into()), &[1, 2, 3]);
        assert_message_ping(&Message::Ping(vec![].into()));
        assert_message_pong(&Message::Pong(vec![].into()));
    }
}