Skip to main content

reinhardt_test/
websocket.rs

1//! WebSocket test client and utilities for integration testing
2//!
3//! Provides WebSocket test client for end-to-end WebSocket testing with
4//! support for authentication, connection management, and message assertions.
5//!
6//! ## Usage Examples
7//!
8//! ### Basic WebSocket Connection
9//!
10//! ```rust,no_run
11//! use reinhardt_test::websocket::WebSocketTestClient;
12//! use rstest::*;
13//!
14//! #[rstest]
15//! #[tokio::test]
16//! async fn test_websocket_connection() {
//!     let mut client = WebSocketTestClient::connect("ws://localhost:8080/ws").await.unwrap();
18//!     client.send_text("Hello").await.unwrap();
19//!     let response = client.receive_text().await.unwrap();
20//!     assert_eq!(response, "Hello");
21//! }
22//! ```
23//!
24//! ### WebSocket with Authentication
25//!
26//! ```rust,no_run
27//! use reinhardt_test::websocket::WebSocketTestClient;
28//!
29//! #[tokio::test]
30//! async fn test_websocket_auth() {
31//!     let client = WebSocketTestClient::connect_with_token(
32//!         "ws://localhost:8080/ws",
33//!         "my-auth-token"
34//!     ).await.unwrap();
35//!     // ...
36//! }
37//! ```
38
39use futures::{SinkExt, StreamExt};
40use std::io::{Error as IoError, ErrorKind};
41use tokio::time::{Duration, timeout};
42use tokio_tungstenite::{
43	MaybeTlsStream, WebSocketStream, connect_async,
44	tungstenite::{Error as WsError, Message},
45};
46
47/// WebSocket test client for integration testing
48///
49/// Provides high-level API for WebSocket connection management, message sending/receiving,
50/// and authentication.
51pub struct WebSocketTestClient {
52	/// WebSocket connection stream
53	stream: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
54	/// WebSocket URL
55	url: String,
56}
57
58impl WebSocketTestClient {
59	/// Connect to WebSocket server
60	///
61	/// # Example
62	/// ```rust,no_run
63	/// use reinhardt_test::websocket::WebSocketTestClient;
64	///
65	/// #[tokio::test]
66	/// async fn test_connect() {
67	///     let client = WebSocketTestClient::connect("ws://localhost:8080/ws")
68	///         .await
69	///         .unwrap();
70	/// }
71	/// ```
72	pub async fn connect(url: &str) -> Result<Self, WsError> {
73		let (stream, _response) = connect_async(url).await?;
74		Ok(Self {
75			stream,
76			url: url.to_string(),
77		})
78	}
79
80	/// Connect to WebSocket server with Bearer token authentication
81	///
82	/// Adds `Authorization: Bearer <token>` header to the WebSocket handshake request.
83	///
84	/// # Example
85	/// ```rust,no_run
86	/// use reinhardt_test::websocket::WebSocketTestClient;
87	///
88	/// #[tokio::test]
89	/// async fn test_auth() {
90	///     let client = WebSocketTestClient::connect_with_token(
91	///         "ws://localhost:8080/ws",
92	///         "my-secret-token"
93	///     )
94	///     .await
95	///     .unwrap();
96	/// }
97	/// ```
98	pub async fn connect_with_token(url: &str, token: &str) -> Result<Self, WsError> {
99		use tokio_tungstenite::tungstenite::http::Request;
100
101		let request = Request::builder()
102			.uri(url)
103			.header("Authorization", format!("Bearer {}", token))
104			.body(())
105			.expect("Failed to build WebSocket request");
106
107		let (stream, _response) = connect_async(request).await?;
108		Ok(Self {
109			stream,
110			url: url.to_string(),
111		})
112	}
113
114	/// Connect to WebSocket server with query parameter authentication
115	///
116	/// Appends `?token=<token>` to the URL.
117	///
118	/// # Example
119	/// ```rust,no_run
120	/// use reinhardt_test::websocket::WebSocketTestClient;
121	///
122	/// #[tokio::test]
123	/// async fn test_query_auth() {
124	///     let client = WebSocketTestClient::connect_with_query_token(
125	///         "ws://localhost:8080/ws",
126	///         "my-token"
127	///     )
128	///     .await
129	///     .unwrap();
130	/// }
131	/// ```
132	// Fixes #880: URL-encode token to prevent injection via query parameter
133	pub async fn connect_with_query_token(url: &str, token: &str) -> Result<Self, WsError> {
134		let url_with_token = format!("{}?token={}", url, urlencoding::encode(token));
135		Self::connect(&url_with_token).await
136	}
137
138	/// Connect to WebSocket server with cookie authentication
139	///
140	/// Adds `Cookie: <cookie_name>=<cookie_value>` header to the WebSocket handshake request.
141	///
142	/// # Example
143	/// ```rust,no_run
144	/// use reinhardt_test::websocket::WebSocketTestClient;
145	///
146	/// #[tokio::test]
147	/// async fn test_cookie_auth() {
148	///     let client = WebSocketTestClient::connect_with_cookie(
149	///         "ws://localhost:8080/ws",
150	///         "session_id",
151	///         "abc123"
152	///     )
153	///     .await
154	///     .unwrap();
155	/// }
156	/// ```
157	pub async fn connect_with_cookie(
158		url: &str,
159		cookie_name: &str,
160		cookie_value: &str,
161	) -> Result<Self, WsError> {
162		use tokio_tungstenite::tungstenite::http::Request;
163
164		let request = Request::builder()
165			.uri(url)
166			.header("Cookie", format!("{}={}", cookie_name, cookie_value))
167			.body(())
168			.expect("Failed to build WebSocket request");
169
170		let (stream, _response) = connect_async(request).await?;
171		Ok(Self {
172			stream,
173			url: url.to_string(),
174		})
175	}
176
177	/// Send text message to WebSocket server
178	///
179	/// # Example
180	/// ```rust,no_run
181	/// use reinhardt_test::websocket::WebSocketTestClient;
182	///
183	/// #[tokio::test]
184	/// async fn test_send() {
185	///     let mut client = WebSocketTestClient::connect("ws://localhost:8080/ws")
186	///         .await
187	///         .unwrap();
188	///     client.send_text("Hello").await.unwrap();
189	/// }
190	/// ```
191	pub async fn send_text(&mut self, text: &str) -> Result<(), WsError> {
192		self.stream.send(Message::Text(text.to_string())).await
193	}
194
195	/// Send binary message to WebSocket server
196	pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), WsError> {
197		self.stream.send(Message::Binary(data.to_vec())).await
198	}
199
200	/// Send ping message to WebSocket server
201	pub async fn send_ping(&mut self, payload: &[u8]) -> Result<(), WsError> {
202		self.stream.send(Message::Ping(payload.to_vec())).await
203	}
204
205	/// Send pong message to WebSocket server
206	pub async fn send_pong(&mut self, payload: &[u8]) -> Result<(), WsError> {
207		self.stream.send(Message::Pong(payload.to_vec())).await
208	}
209
210	/// Receive next message from WebSocket server
211	///
212	/// Returns `None` if connection is closed.
213	pub async fn receive(&mut self) -> Option<Result<Message, WsError>> {
214		self.stream.next().await
215	}
216
217	/// Receive text message from WebSocket server with timeout
218	///
219	/// # Example
220	/// ```rust,no_run
221	/// use reinhardt_test::websocket::WebSocketTestClient;
222	///
223	/// #[tokio::test]
224	/// async fn test_receive() {
225	///     let mut client = WebSocketTestClient::connect("ws://localhost:8080/ws")
226	///         .await
227	///         .unwrap();
228	///     let text = client.receive_text().await.unwrap();
229	///     assert_eq!(text, "Welcome");
230	/// }
231	/// ```
232	pub async fn receive_text(&mut self) -> Result<String, WsError> {
233		self.receive_text_with_timeout(Duration::from_secs(5)).await
234	}
235
236	/// Receive text message with custom timeout
237	pub async fn receive_text_with_timeout(
238		&mut self,
239		duration: Duration,
240	) -> Result<String, WsError> {
241		match timeout(duration, self.stream.next()).await {
242			Ok(Some(Ok(Message::Text(text)))) => Ok(text),
243			Ok(Some(Ok(msg))) => Err(WsError::Io(IoError::new(
244				ErrorKind::InvalidData,
245				format!("Expected text message, got {:?}", msg),
246			))),
247			Ok(Some(Err(e))) => Err(e),
248			Ok(None) => Err(WsError::ConnectionClosed),
249			Err(_) => Err(WsError::Io(IoError::new(
250				ErrorKind::TimedOut,
251				"Receive timeout",
252			))),
253		}
254	}
255
256	/// Receive binary message from WebSocket server with timeout
257	pub async fn receive_binary(&mut self) -> Result<Vec<u8>, WsError> {
258		self.receive_binary_with_timeout(Duration::from_secs(5))
259			.await
260	}
261
262	/// Receive binary message with custom timeout
263	pub async fn receive_binary_with_timeout(
264		&mut self,
265		duration: Duration,
266	) -> Result<Vec<u8>, WsError> {
267		match timeout(duration, self.stream.next()).await {
268			Ok(Some(Ok(Message::Binary(data)))) => Ok(data),
269			Ok(Some(Ok(msg))) => Err(WsError::Io(IoError::new(
270				ErrorKind::InvalidData,
271				format!("Expected binary message, got {:?}", msg),
272			))),
273			Ok(Some(Err(e))) => Err(e),
274			Ok(None) => Err(WsError::ConnectionClosed),
275			Err(_) => Err(WsError::Io(IoError::new(
276				ErrorKind::TimedOut,
277				"Receive timeout",
278			))),
279		}
280	}
281
282	/// Close WebSocket connection
283	pub async fn close(mut self) -> Result<(), WsError> {
284		self.stream.close(None).await
285	}
286
287	/// Get WebSocket URL
288	pub fn url(&self) -> &str {
289		&self.url
290	}
291}
292
293/// WebSocket message assertion utilities
294pub mod assertions {
295	use tokio_tungstenite::tungstenite::Message;
296
297	/// Assert that WebSocket message is text with expected content
298	///
299	/// # Example
300	/// ```rust,no_run
301	/// use reinhardt_test::websocket::assertions::assert_message_text;
302	/// use tokio_tungstenite::tungstenite::Message;
303	///
304	/// let msg = Message::Text("Hello".to_string());
305	/// assert_message_text(&msg, "Hello");
306	/// ```
307	pub fn assert_message_text(msg: &Message, expected: &str) {
308		match msg {
309			Message::Text(text) => assert_eq!(text, expected),
310			_ => panic!("Expected text message, got {:?}", msg),
311		}
312	}
313
314	/// Assert that WebSocket message is text containing substring
315	pub fn assert_message_contains(msg: &Message, substring: &str) {
316		match msg {
317			Message::Text(text) => assert!(
318				text.contains(substring),
319				"Message '{}' does not contain '{}'",
320				text,
321				substring
322			),
323			_ => panic!("Expected text message, got {:?}", msg),
324		}
325	}
326
327	/// Assert that WebSocket message is binary with expected data
328	pub fn assert_message_binary(msg: &Message, expected: &[u8]) {
329		match msg {
330			Message::Binary(data) => assert_eq!(data, expected),
331			_ => panic!("Expected binary message, got {:?}", msg),
332		}
333	}
334
335	/// Assert that WebSocket message is ping
336	pub fn assert_message_ping(msg: &Message) {
337		match msg {
338			Message::Ping(_) => {}
339			_ => panic!("Expected ping message, got {:?}", msg),
340		}
341	}
342
343	/// Assert that WebSocket message is pong
344	pub fn assert_message_pong(msg: &Message) {
345		match msg {
346			Message::Pong(_) => {}
347			_ => panic!("Expected pong message, got {:?}", msg),
348		}
349	}
350}
351
#[cfg(test)]
mod tests {
	use super::*;

	/// Builds `<base>?token=<percent-encoded token>`.
	fn build_query_url(base: &str, token: &str) -> String {
		format!("{}?token={}", base, urlencoding::encode(token))
	}

	#[test]
	fn test_url_with_query_token() {
		assert_eq!(
			build_query_url("ws://localhost:8080/ws", "my-token"),
			"ws://localhost:8080/ws?token=my-token"
		);
	}

	#[test]
	fn test_url_with_query_token_special_chars() {
		// Spaces, `&` and `=` must be percent-encoded so they cannot start new parameters.
		let encoded = build_query_url("ws://localhost:8080/ws", "token with spaces&special=chars");
		assert_eq!(
			encoded,
			"ws://localhost:8080/ws?token=token%20with%20spaces%26special%3Dchars"
		);
	}

	#[test]
	fn test_message_assertions() {
		use assertions::*;

		// Text frames: exact match and substring match.
		let greeting = Message::Text("Hello".to_string());
		assert_message_text(&greeting, "Hello");
		assert_message_contains(&greeting, "ell");

		// Binary frames compare byte-for-byte.
		assert_message_binary(&Message::Binary(vec![1, 2, 3]), &[1, 2, 3]);

		// Control frames: only the frame kind is asserted here.
		assert_message_ping(&Message::Ping(vec![]));
		assert_message_pong(&Message::Pong(vec![]));
	}
}