Skip to main content

reinhardt_testkit/
websocket.rs

//! WebSocket test client and utilities for integration testing
//!
//! Provides a WebSocket test client for end-to-end testing, with support for
//! authentication, connection management, and message assertions.
//!
//! ## Usage Examples
//!
//! ### Basic WebSocket Connection
//!
//! ```rust,no_run
//! use reinhardt_testkit::websocket::WebSocketTestClient;
//! use rstest::*;
//!
//! #[rstest]
//! #[tokio::test]
//! async fn test_websocket_connection() {
//!     // `send_text` / `receive_text` take `&mut self`, so the binding must be mutable.
//!     // This example assumes the server echoes messages back.
//!     let mut client = WebSocketTestClient::connect("ws://localhost:8080/ws").await.unwrap();
//!     client.send_text("Hello").await.unwrap();
//!     let response = client.receive_text().await.unwrap();
//!     assert_eq!(response, "Hello");
//! }
//! ```
//!
//! ### WebSocket with Authentication
//!
//! ```rust,no_run
//! use reinhardt_testkit::websocket::WebSocketTestClient;
//!
//! #[tokio::test]
//! async fn test_websocket_auth() {
//!     let client = WebSocketTestClient::connect_with_token(
//!         "ws://localhost:8080/ws",
//!         "my-auth-token"
//!     ).await.unwrap();
//!     // ...
//! }
//! ```
38
39use futures::{SinkExt, StreamExt};
40use std::io::{Error as IoError, ErrorKind};
41use tokio::time::{Duration, timeout};
42use tokio_tungstenite::{
43	MaybeTlsStream, WebSocketStream, connect_async,
44	tungstenite::{Error as WsError, Message},
45};
46
47/// WebSocket test client for integration testing
48///
49/// Provides high-level API for WebSocket connection management, message sending/receiving,
50/// and authentication.
51pub struct WebSocketTestClient {
52	/// WebSocket connection stream
53	stream: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
54	/// WebSocket URL
55	url: String,
56}
57
58impl WebSocketTestClient {
59	/// Connect to WebSocket server
60	///
61	/// # Example
62	/// ```rust,no_run
63	/// use reinhardt_testkit::websocket::WebSocketTestClient;
64	///
65	/// #[tokio::test]
66	/// async fn test_connect() {
67	///     let client = WebSocketTestClient::connect("ws://localhost:8080/ws")
68	///         .await
69	///         .unwrap();
70	/// }
71	/// ```
72	pub async fn connect(url: &str) -> Result<Self, WsError> {
73		let (stream, _response) = connect_async(url).await?;
74		Ok(Self {
75			stream,
76			url: url.to_string(),
77		})
78	}
79
80	/// Connect to WebSocket server with Bearer token authentication
81	///
82	/// Adds `Authorization: Bearer <token>` header to the WebSocket handshake request.
83	///
84	/// # Example
85	/// ```rust,no_run
86	/// use reinhardt_testkit::websocket::WebSocketTestClient;
87	///
88	/// #[tokio::test]
89	/// async fn test_auth() {
90	///     let client = WebSocketTestClient::connect_with_token(
91	///         "ws://localhost:8080/ws",
92	///         "my-secret-token"
93	///     )
94	///     .await
95	///     .unwrap();
96	/// }
97	/// ```
98	pub async fn connect_with_token(url: &str, token: &str) -> Result<Self, WsError> {
99		use tokio_tungstenite::tungstenite::http::Request;
100
101		let request = Request::builder()
102			.uri(url)
103			.header("Authorization", format!("Bearer {}", token))
104			.body(())
105			.expect("Failed to build WebSocket request");
106
107		let (stream, _response) = connect_async(request).await?;
108		Ok(Self {
109			stream,
110			url: url.to_string(),
111		})
112	}
113
114	/// Connect to WebSocket server with query parameter authentication
115	///
116	/// Appends `?token=<token>` to the URL.
117	///
118	/// # Example
119	/// ```rust,no_run
120	/// use reinhardt_testkit::websocket::WebSocketTestClient;
121	///
122	/// #[tokio::test]
123	/// async fn test_query_auth() {
124	///     let client = WebSocketTestClient::connect_with_query_token(
125	///         "ws://localhost:8080/ws",
126	///         "my-token"
127	///     )
128	///     .await
129	///     .unwrap();
130	/// }
131	/// ```
132	// Fixes #880: URL-encode token to prevent injection via query parameter
133	pub async fn connect_with_query_token(url: &str, token: &str) -> Result<Self, WsError> {
134		let url_with_token = format!("{}?token={}", url, urlencoding::encode(token));
135		Self::connect(&url_with_token).await
136	}
137
138	/// Connect to WebSocket server with cookie authentication
139	///
140	/// Adds `Cookie: <cookie_name>=<cookie_value>` header to the WebSocket handshake request.
141	///
142	/// # Example
143	/// ```rust,no_run
144	/// use reinhardt_testkit::websocket::WebSocketTestClient;
145	///
146	/// #[tokio::test]
147	/// async fn test_cookie_auth() {
148	///     let client = WebSocketTestClient::connect_with_cookie(
149	///         "ws://localhost:8080/ws",
150	///         "session_id",
151	///         "abc123"
152	///     )
153	///     .await
154	///     .unwrap();
155	/// }
156	/// ```
157	pub async fn connect_with_cookie(
158		url: &str,
159		cookie_name: &str,
160		cookie_value: &str,
161	) -> Result<Self, WsError> {
162		use tokio_tungstenite::tungstenite::http::Request;
163
164		let request = Request::builder()
165			.uri(url)
166			.header("Cookie", format!("{}={}", cookie_name, cookie_value))
167			.body(())
168			.expect("Failed to build WebSocket request");
169
170		let (stream, _response) = connect_async(request).await?;
171		Ok(Self {
172			stream,
173			url: url.to_string(),
174		})
175	}
176
177	/// Send text message to WebSocket server
178	///
179	/// # Example
180	/// ```rust,no_run
181	/// use reinhardt_testkit::websocket::WebSocketTestClient;
182	///
183	/// #[tokio::test]
184	/// async fn test_send() {
185	///     let mut client = WebSocketTestClient::connect("ws://localhost:8080/ws")
186	///         .await
187	///         .unwrap();
188	///     client.send_text("Hello").await.unwrap();
189	/// }
190	/// ```
191	pub async fn send_text(&mut self, text: &str) -> Result<(), WsError> {
192		self.stream.send(Message::text(text)).await
193	}
194
195	/// Send binary message to WebSocket server
196	pub async fn send_binary(&mut self, data: &[u8]) -> Result<(), WsError> {
197		self.stream.send(Message::binary(data.to_vec())).await
198	}
199
200	/// Send ping message to WebSocket server
201	pub async fn send_ping(&mut self, payload: &[u8]) -> Result<(), WsError> {
202		self.stream
203			.send(Message::Ping(payload.to_vec().into()))
204			.await
205	}
206
207	/// Send pong message to WebSocket server
208	pub async fn send_pong(&mut self, payload: &[u8]) -> Result<(), WsError> {
209		self.stream
210			.send(Message::Pong(payload.to_vec().into()))
211			.await
212	}
213
214	/// Receive next message from WebSocket server
215	///
216	/// Returns `None` if connection is closed.
217	pub async fn receive(&mut self) -> Option<Result<Message, WsError>> {
218		self.stream.next().await
219	}
220
221	/// Receive text message from WebSocket server with timeout
222	///
223	/// # Example
224	/// ```rust,no_run
225	/// use reinhardt_testkit::websocket::WebSocketTestClient;
226	///
227	/// #[tokio::test]
228	/// async fn test_receive() {
229	///     let mut client = WebSocketTestClient::connect("ws://localhost:8080/ws")
230	///         .await
231	///         .unwrap();
232	///     let text = client.receive_text().await.unwrap();
233	///     assert_eq!(text, "Welcome");
234	/// }
235	/// ```
236	pub async fn receive_text(&mut self) -> Result<String, WsError> {
237		self.receive_text_with_timeout(Duration::from_secs(5)).await
238	}
239
240	/// Receive text message with custom timeout
241	pub async fn receive_text_with_timeout(
242		&mut self,
243		duration: Duration,
244	) -> Result<String, WsError> {
245		match timeout(duration, self.stream.next()).await {
246			Ok(Some(Ok(Message::Text(text)))) => Ok(text.to_string()),
247			Ok(Some(Ok(msg))) => Err(WsError::Io(IoError::new(
248				ErrorKind::InvalidData,
249				format!("Expected text message, got {:?}", msg),
250			))),
251			Ok(Some(Err(e))) => Err(e),
252			Ok(None) => Err(WsError::ConnectionClosed),
253			Err(_) => Err(WsError::Io(IoError::new(
254				ErrorKind::TimedOut,
255				"Receive timeout",
256			))),
257		}
258	}
259
260	/// Receive binary message from WebSocket server with timeout
261	pub async fn receive_binary(&mut self) -> Result<Vec<u8>, WsError> {
262		self.receive_binary_with_timeout(Duration::from_secs(5))
263			.await
264	}
265
266	/// Receive binary message with custom timeout
267	pub async fn receive_binary_with_timeout(
268		&mut self,
269		duration: Duration,
270	) -> Result<Vec<u8>, WsError> {
271		match timeout(duration, self.stream.next()).await {
272			Ok(Some(Ok(Message::Binary(data)))) => Ok(data.to_vec()),
273			Ok(Some(Ok(msg))) => Err(WsError::Io(IoError::new(
274				ErrorKind::InvalidData,
275				format!("Expected binary message, got {:?}", msg),
276			))),
277			Ok(Some(Err(e))) => Err(e),
278			Ok(None) => Err(WsError::ConnectionClosed),
279			Err(_) => Err(WsError::Io(IoError::new(
280				ErrorKind::TimedOut,
281				"Receive timeout",
282			))),
283		}
284	}
285
286	/// Close WebSocket connection
287	pub async fn close(mut self) -> Result<(), WsError> {
288		self.stream.close(None).await
289	}
290
291	/// Get WebSocket URL
292	pub fn url(&self) -> &str {
293		&self.url
294	}
295}
296
297/// WebSocket message assertion utilities
298pub mod assertions {
299	use tokio_tungstenite::tungstenite::Message;
300
301	/// Assert that WebSocket message is text with expected content
302	///
303	/// # Example
304	/// ```rust,no_run
305	/// use reinhardt_testkit::websocket::assertions::assert_message_text;
306	/// use tokio_tungstenite::tungstenite::Message;
307	///
308	/// let msg = Message::text("Hello");
309	/// assert_message_text(&msg, "Hello");
310	/// ```
311	pub fn assert_message_text(msg: &Message, expected: &str) {
312		match msg {
313			Message::Text(text) => assert_eq!(text.as_str(), expected),
314			_ => panic!("Expected text message, got {:?}", msg),
315		}
316	}
317
318	/// Assert that WebSocket message is text containing substring
319	pub fn assert_message_contains(msg: &Message, substring: &str) {
320		match msg {
321			Message::Text(text) => assert!(
322				text.contains(substring),
323				"Message '{}' does not contain '{}'",
324				text,
325				substring
326			),
327			_ => panic!("Expected text message, got {:?}", msg),
328		}
329	}
330
331	/// Assert that WebSocket message is binary with expected data
332	pub fn assert_message_binary(msg: &Message, expected: &[u8]) {
333		match msg {
334			Message::Binary(data) => assert_eq!(data.as_ref(), expected),
335			_ => panic!("Expected binary message, got {:?}", msg),
336		}
337	}
338
339	/// Assert that WebSocket message is ping
340	pub fn assert_message_ping(msg: &Message) {
341		match msg {
342			Message::Ping(_) => {}
343			_ => panic!("Expected ping message, got {:?}", msg),
344		}
345	}
346
347	/// Assert that WebSocket message is pong
348	pub fn assert_message_pong(msg: &Message) {
349		match msg {
350			Message::Pong(_) => {}
351			_ => panic!("Expected pong message, got {:?}", msg),
352		}
353	}
354}
355
#[cfg(test)]
mod tests {
	use super::*;

	/// Builds `<base>?token=<percent-encoded token>`.
	fn query_url(base: &str, token: &str) -> String {
		format!("{}?token={}", base, urlencoding::encode(token))
	}

	#[test]
	fn test_url_with_query_token() {
		let built = query_url("ws://localhost:8080/ws", "my-token");
		assert_eq!(built, "ws://localhost:8080/ws?token=my-token");
	}

	#[test]
	fn test_url_with_query_token_special_chars() {
		// Spaces, '&' and '=' must all be percent-encoded so they cannot alter the query.
		let built = query_url("ws://localhost:8080/ws", "token with spaces&special=chars");
		let want = "ws://localhost:8080/ws?token=token%20with%20spaces%26special%3Dchars";
		assert_eq!(built, want);
	}

	#[test]
	fn test_message_assertions() {
		use assertions::*;

		let greeting = Message::text("Hello");
		assert_message_text(&greeting, "Hello");
		assert_message_contains(&greeting, "ell");

		let bytes = [1u8, 2, 3];
		assert_message_binary(&Message::Binary(bytes.to_vec().into()), &bytes);

		assert_message_ping(&Message::Ping(Vec::<u8>::new().into()));
		assert_message_pong(&Message::Pong(Vec::<u8>::new().into()));
	}
}