Skip to main content

reinhardt_websockets/
consumers.rs

1//! WebSocket consumers for advanced message handling patterns
2//!
//! This module provides consumer types inspired by Django Channels,
4//! enabling structured WebSocket message handling with lifecycle hooks.
5//!
6//! # Dependency Injection Support
7//!
8//! When the `di` feature is enabled, `ConsumerContext` supports dependency injection:
9//!
10//! ```rust,no_run
11//! use reinhardt_websockets::consumers::{ConsumerContext, WebSocketConsumer};
12//! use reinhardt_websockets::{Message, WebSocketResult};
13//! use async_trait::async_trait;
14//! use std::sync::Arc;
15//!
16//! # type DatabaseConnection = ();
17//! # type CacheService = ();
18//! # struct MyConsumer;
19//! #
20//! # #[async_trait]
21//! # impl WebSocketConsumer for MyConsumer {
22//! #     async fn on_connect(&self, _ctx: &mut ConsumerContext) -> WebSocketResult<()> {
23//! #         Ok(())
24//! #     }
25//! #
26//! async fn on_message(&self, ctx: &mut ConsumerContext, msg: Message) -> WebSocketResult<()> {
27//!     // Resolve dependencies from DI context
28//!     // let db: Arc<DatabaseConnection> = ctx.resolve().await?;
29//!     // let cache: CacheService = ctx.resolve_uncached().await?;
30//!
31//!     // Use the dependencies...
32//!     Ok(())
33//! }
34//! #
35//! #     async fn on_disconnect(&self, _ctx: &mut ConsumerContext) -> WebSocketResult<()> {
36//! #         Ok(())
37//! #     }
38//! # }
39//! ```
40// WebSocketError is used in #[cfg(feature = "di")] code
41#[allow(unused_imports)]
42use crate::connection::{Message, WebSocketConnection, WebSocketError, WebSocketResult};
43use async_trait::async_trait;
44use std::sync::Arc;
45
46#[cfg(feature = "di")]
47use reinhardt_di::{Injectable, Injected, InjectionContext};
48
49/// Consumer context containing connection and message information
50///
51/// This context is passed to WebSocket consumer methods and provides access to:
52/// - The WebSocket connection for sending messages
53/// - HTTP handshake headers (e.g., Cookie, Origin)
54/// - Metadata for storing request-scoped data
55/// - Dependency injection (when the `di` feature is enabled)
56///
57/// # HTTP Headers
58///
59/// Headers from the WebSocket HTTP handshake can be stored and retrieved:
60///
61/// ```
62/// # use reinhardt_websockets::consumers::ConsumerContext;
63/// # use reinhardt_websockets::WebSocketConnection;
64/// # use tokio::sync::mpsc;
65/// # use std::sync::Arc;
66/// #
67/// # let (tx, _rx) = mpsc::unbounded_channel();
68/// # let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));
69/// let context = ConsumerContext::new(conn)
70///     .with_header("cookie".to_string(), "sessionid=abc123".to_string());
71///
72/// assert_eq!(context.cookie_header(), Some("sessionid=abc123"));
73/// ```
74pub struct ConsumerContext {
75	/// The WebSocket connection
76	pub connection: Arc<WebSocketConnection>,
77	/// HTTP handshake headers (e.g., Cookie, Origin)
78	pub headers: std::collections::HashMap<String, String>,
79	/// Additional metadata
80	pub metadata: std::collections::HashMap<String, String>,
81	/// DI context for dependency injection (when `di` feature is enabled)
82	#[cfg(feature = "di")]
83	di_context: Option<Arc<InjectionContext>>,
84}
85
86impl ConsumerContext {
87	/// Create a new consumer context
88	///
89	/// # Examples
90	///
91	/// ```
92	/// use reinhardt_websockets::consumers::ConsumerContext;
93	/// use reinhardt_websockets::WebSocketConnection;
94	/// use tokio::sync::mpsc;
95	/// use std::sync::Arc;
96	///
97	/// let (tx, _rx) = mpsc::unbounded_channel();
98	/// let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));
99	/// let context = ConsumerContext::new(conn);
100	/// ```
101	pub fn new(connection: Arc<WebSocketConnection>) -> Self {
102		Self {
103			connection,
104			headers: std::collections::HashMap::new(),
105			metadata: std::collections::HashMap::new(),
106			#[cfg(feature = "di")]
107			di_context: None,
108		}
109	}
110
111	/// Create a new consumer context with DI context
112	///
113	/// This constructor is used when dependency injection is needed in WebSocket handlers.
114	///
115	/// # Examples
116	///
117	/// ```ignore
118	/// use reinhardt_websockets::consumers::ConsumerContext;
119	/// use reinhardt_di::{InjectionContext, SingletonScope};
120	/// use std::sync::Arc;
121	///
122	/// let singleton = Arc::new(SingletonScope::new());
123	/// let di_ctx = Arc::new(InjectionContext::builder(singleton).build());
124	/// let context = ConsumerContext::with_di_context(connection, di_ctx);
125	/// ```
126	#[cfg(feature = "di")]
127	pub fn with_di_context(
128		connection: Arc<WebSocketConnection>,
129		di_context: Arc<InjectionContext>,
130	) -> Self {
131		Self {
132			connection,
133			headers: std::collections::HashMap::new(),
134			metadata: std::collections::HashMap::new(),
135			di_context: Some(di_context),
136		}
137	}
138
139	/// Add an HTTP handshake header to the context
140	///
141	/// Header names are stored as-is. For case-insensitive lookup,
142	/// callers should normalize keys to lowercase before insertion.
143	pub fn with_header(mut self, key: String, value: String) -> Self {
144		self.headers.insert(key, value);
145		self
146	}
147
148	/// Get an HTTP handshake header value
149	pub fn get_header(&self, key: &str) -> Option<&String> {
150		self.headers.get(key)
151	}
152
153	/// Get the Cookie header from the HTTP handshake
154	///
155	/// Convenience method equivalent to `get_header("cookie")`.
156	pub fn cookie_header(&self) -> Option<&str> {
157		self.headers.get("cookie").map(|s| s.as_str())
158	}
159
160	/// Add metadata to the context
161	pub fn with_metadata(mut self, key: String, value: String) -> Self {
162		self.metadata.insert(key, value);
163		self
164	}
165
166	/// Get metadata value
167	pub fn get_metadata(&self, key: &str) -> Option<&String> {
168		self.metadata.get(key)
169	}
170
171	/// Get the DI context if available
172	#[cfg(feature = "di")]
173	pub fn di_context(&self) -> Option<&Arc<InjectionContext>> {
174		self.di_context.as_ref()
175	}
176
177	/// Set the DI context
178	#[cfg(feature = "di")]
179	pub fn set_di_context(&mut self, ctx: Arc<InjectionContext>) {
180		self.di_context = Some(ctx);
181	}
182
183	/// Resolve a dependency with caching
184	///
185	/// This method extracts the dependency from the DI context. The resolved
186	/// dependency is cached for the duration of the connection.
187	///
188	/// # Errors
189	///
190	/// Returns an error if:
191	/// - The DI context is not set
192	/// - The dependency cannot be resolved
193	///
194	/// # Examples
195	///
196	/// ```ignore
197	/// let db: Arc<DatabaseConnection> = ctx.resolve().await?;
198	/// ```
199	#[cfg(feature = "di")]
200	pub async fn resolve<T>(&self) -> WebSocketResult<T>
201	where
202		T: Injectable + Clone + Send + Sync + 'static,
203	{
204		let ctx = self
205			.di_context
206			.as_ref()
207			.ok_or_else(|| WebSocketError::Internal("DI context not available".to_string()))?;
208
209		Injected::<T>::resolve(ctx)
210			.await
211			.map(|injected| injected.into_inner())
212			.map_err(|_| WebSocketError::Internal("dependency resolution failed".to_string()))
213	}
214
215	/// Resolve a dependency without caching
216	///
217	/// This method is similar to `resolve()` but creates a fresh instance
218	/// of the dependency each time.
219	///
220	/// # Errors
221	///
222	/// Returns an error if:
223	/// - The DI context is not set
224	/// - The dependency cannot be resolved
225	///
226	/// # Examples
227	///
228	/// ```ignore
229	/// let fresh_service: MyService = ctx.resolve_uncached().await?;
230	/// ```
231	#[cfg(feature = "di")]
232	pub async fn resolve_uncached<T>(&self) -> WebSocketResult<T>
233	where
234		T: Injectable + Clone + Send + Sync + 'static,
235	{
236		let ctx = self
237			.di_context
238			.as_ref()
239			.ok_or_else(|| WebSocketError::Internal("DI context not available".to_string()))?;
240
241		Injected::<T>::resolve_uncached(ctx)
242			.await
243			.map(|injected| injected.into_inner())
244			.map_err(|_| WebSocketError::Internal("dependency resolution failed".to_string()))
245	}
246
247	/// Try to resolve a dependency, returning None if DI context is not available
248	///
249	/// This is useful for optional dependencies or when you want to gracefully
250	/// handle the case where DI is not configured.
251	///
252	/// # Examples
253	///
254	/// ```ignore
255	/// if let Some(cache) = ctx.try_resolve::<CacheService>().await {
256	///     // Use cache
257	/// } else {
258	///     // Fallback without cache
259	/// }
260	/// ```
261	#[cfg(feature = "di")]
262	pub async fn try_resolve<T>(&self) -> Option<T>
263	where
264		T: Injectable + Clone + Send + Sync + 'static,
265	{
266		let ctx = self.di_context.as_ref()?;
267
268		Injected::<T>::resolve(ctx)
269			.await
270			.ok()
271			.map(|injected| injected.into_inner())
272	}
273}
274
275/// WebSocket consumer trait
276///
277/// Consumers handle the lifecycle of WebSocket connections and messages.
278#[async_trait]
279pub trait WebSocketConsumer: Send + Sync {
280	/// Called when a WebSocket connection is established
281	async fn on_connect(&self, context: &mut ConsumerContext) -> WebSocketResult<()>;
282
283	/// Called when a message is received
284	async fn on_message(
285		&self,
286		context: &mut ConsumerContext,
287		message: Message,
288	) -> WebSocketResult<()>;
289
290	/// Called when a WebSocket connection is closed
291	async fn on_disconnect(&self, context: &mut ConsumerContext) -> WebSocketResult<()>;
292}
293
/// Consumer that replies to every message with a prefixed echo
///
/// # Examples
///
/// ```
/// use reinhardt_websockets::consumers::{EchoConsumer, WebSocketConsumer, ConsumerContext};
/// use reinhardt_websockets::{Message, WebSocketConnection};
/// use tokio::sync::mpsc;
/// use std::sync::Arc;
///
/// # tokio_test::block_on(async {
/// let consumer = EchoConsumer::new();
/// let (tx, mut rx) = mpsc::unbounded_channel();
/// let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
/// let mut context = ConsumerContext::new(conn);
///
/// let msg = Message::text("Hello".to_string());
/// consumer.on_message(&mut context, msg).await.unwrap();
///
/// let received = rx.recv().await.unwrap();
/// match received {
///     Message::Text { data } => assert_eq!(data, "Echo: Hello"),
///     _ => panic!("Expected text message"),
/// }
/// # });
/// ```
pub struct EchoConsumer {
	/// Label placed in front of every reply (followed by `": "`)
	prefix: String,
}

impl EchoConsumer {
	/// Build an echo consumer that uses the default `"Echo"` prefix
	pub fn new() -> Self {
		Self::with_prefix(String::from("Echo"))
	}

	/// Build an echo consumer that uses `prefix` as the reply label
	pub fn with_prefix(prefix: String) -> Self {
		EchoConsumer { prefix }
	}
}

impl Default for EchoConsumer {
	/// Same as [`EchoConsumer::new`]
	fn default() -> Self {
		EchoConsumer::new()
	}
}
343
344#[async_trait]
345impl WebSocketConsumer for EchoConsumer {
346	async fn on_connect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
347		context
348			.connection
349			.send_text(format!("{}: Connection established", self.prefix))
350			.await
351	}
352
353	async fn on_message(
354		&self,
355		context: &mut ConsumerContext,
356		message: Message,
357	) -> WebSocketResult<()> {
358		match message {
359			Message::Text { data } => {
360				context
361					.connection
362					.send_text(format!("{}: {}", self.prefix, data))
363					.await
364			}
365			Message::Binary { data } => {
366				// Validate binary payload and attempt UTF-8 conversion
367				match String::from_utf8(data.clone()) {
368					Ok(text) => {
369						context
370							.connection
371							.send_text(format!("{}: {}", self.prefix, text))
372							.await
373					}
374					Err(_) => {
375						// Non-UTF-8 binary: echo back a summary with byte count
376						context
377							.connection
378							.send_text(format!("{}: binary({} bytes)", self.prefix, data.len()))
379							.await
380					}
381				}
382			}
383			Message::Close { code, reason } => {
384				// Acknowledge close and ensure cleanup
385				context
386					.connection
387					.close_with_reason(code, reason)
388					.await
389					.ok();
390				Ok(())
391			}
392			_ => Ok(()),
393		}
394	}
395
396	async fn on_disconnect(&self, _context: &mut ConsumerContext) -> WebSocketResult<()> {
397		Ok(())
398	}
399}
400
401/// Broadcast consumer that broadcasts messages to all connections in a group
402///
403/// # Examples
404///
405/// ```
406/// use reinhardt_websockets::consumers::{BroadcastConsumer, WebSocketConsumer, ConsumerContext};
407/// use reinhardt_websockets::{Message, WebSocketConnection};
408/// use reinhardt_websockets::room::Room;
409/// use tokio::sync::mpsc;
410/// use std::sync::Arc;
411///
412/// # tokio_test::block_on(async {
413/// let room = Arc::new(Room::new("chat".to_string()));
414/// let consumer = BroadcastConsumer::new(room.clone());
415///
416/// let (tx1, mut rx1) = mpsc::unbounded_channel();
417/// let (tx2, mut rx2) = mpsc::unbounded_channel();
418///
419/// let conn1 = Arc::new(WebSocketConnection::new("user1".to_string(), tx1));
420/// let conn2 = Arc::new(WebSocketConnection::new("user2".to_string(), tx2));
421///
422/// room.join("user1".to_string(), conn1.clone()).await.unwrap();
423/// room.join("user2".to_string(), conn2.clone()).await.unwrap();
424///
425/// let mut context = ConsumerContext::new(conn1);
426/// let msg = Message::text("Hello everyone".to_string());
427///
428/// consumer.on_message(&mut context, msg).await.unwrap();
429///
430/// // Both connections should receive the broadcast
431/// assert!(rx1.try_recv().is_ok());
432/// assert!(rx2.try_recv().is_ok());
433/// # });
434/// ```
435pub struct BroadcastConsumer {
436	room: Arc<crate::room::Room>,
437}
438
439impl BroadcastConsumer {
440	/// Create a new broadcast consumer
441	pub fn new(room: Arc<crate::room::Room>) -> Self {
442		Self { room }
443	}
444}
445
446#[async_trait]
447impl WebSocketConsumer for BroadcastConsumer {
448	async fn on_connect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
449		let client_id = context.connection.id().to_string();
450		self.room
451			.join(client_id.clone(), context.connection.clone())
452			.await
453			.map_err(|e| crate::connection::WebSocketError::Connection(e.to_string()))?;
454
455		context
456			.connection
457			.send_text("Joined broadcast room".to_string())
458			.await
459	}
460
461	async fn on_message(
462		&self,
463		_context: &mut ConsumerContext,
464		message: Message,
465	) -> WebSocketResult<()> {
466		let result = self.room.broadcast(message).await;
467		if result.is_complete_failure() {
468			return Err(crate::connection::WebSocketError::Send(
469				"broadcast failed for all clients".to_string(),
470			));
471		}
472		Ok(())
473	}
474
475	async fn on_disconnect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
476		let client_id = context.connection.id();
477		// Best-effort leave: the client may already have been removed by
478		// broadcast failure cleanup, so ignore ClientNotFound errors.
479		let _ = self.room.leave(client_id).await;
480
481		// Ensure the connection is marked as closed even on abnormal disconnect
482		context.connection.force_close().await;
483
484		Ok(())
485	}
486}
487
/// Consumer that parses incoming messages as JSON and echoes them back in a JSON envelope
///
/// # Examples
///
/// ```
/// use reinhardt_websockets::consumers::{JsonConsumer, WebSocketConsumer, ConsumerContext};
/// use reinhardt_websockets::{Message, WebSocketConnection};
/// use tokio::sync::mpsc;
/// use std::sync::Arc;
/// use serde::{Serialize, Deserialize};
///
/// #[derive(Serialize, Deserialize, Debug, PartialEq)]
/// struct ChatMessage {
///     user: String,
///     text: String,
/// }
///
/// # tokio_test::block_on(async {
/// let consumer = JsonConsumer::new();
/// let (tx, mut rx) = mpsc::unbounded_channel();
/// let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
/// let mut context = ConsumerContext::new(conn);
///
/// let msg = ChatMessage {
///     user: "Alice".to_string(),
///     text: "Hello".to_string(),
/// };
///
/// let json_msg = Message::json(&msg).unwrap();
/// consumer.on_message(&mut context, json_msg).await.unwrap();
/// # });
/// ```
#[derive(Default)]
pub struct JsonConsumer;

impl JsonConsumer {
	/// Build a JSON consumer (stateless; equivalent to `JsonConsumer::default()`)
	pub fn new() -> Self {
		Self::default()
	}
}
534
535#[async_trait]
536impl WebSocketConsumer for JsonConsumer {
537	async fn on_connect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
538		context
539			.connection
540			.send_json(&serde_json::json!({
541				"type": "connection",
542				"status": "connected"
543			}))
544			.await
545	}
546
547	async fn on_message(
548		&self,
549		context: &mut ConsumerContext,
550		message: Message,
551	) -> WebSocketResult<()> {
552		match message {
553			Message::Text { data } => {
554				// Try to parse as JSON
555				let json: serde_json::Value = serde_json::from_str(&data)
556					.map_err(|e| crate::connection::WebSocketError::Protocol(e.to_string()))?;
557
558				// Echo back with metadata
559				let response = serde_json::json!({
560					"type": "echo",
561					"data": json,
562					"timestamp": chrono::Utc::now().to_rfc3339()
563				});
564
565				context.connection.send_json(&response).await
566			}
567			Message::Binary { data } => {
568				// Validate that binary data is valid UTF-8 JSON
569				let text = String::from_utf8(data).map_err(|e| {
570					crate::connection::WebSocketError::BinaryPayload(format!(
571						"binary payload is not valid UTF-8: {}",
572						e
573					))
574				})?;
575
576				let json: serde_json::Value = serde_json::from_str(&text).map_err(|e| {
577					crate::connection::WebSocketError::BinaryPayload(format!(
578						"binary payload is not valid JSON: {}",
579						e
580					))
581				})?;
582
583				let response = serde_json::json!({
584					"type": "echo",
585					"data": json,
586					"source": "binary",
587					"timestamp": chrono::Utc::now().to_rfc3339()
588				});
589
590				context.connection.send_json(&response).await
591			}
592			_ => Ok(()),
593		}
594	}
595
596	async fn on_disconnect(&self, _context: &mut ConsumerContext) -> WebSocketResult<()> {
597		Ok(())
598	}
599}
600
601/// Consumer chain for composing multiple consumers
602///
603/// # Examples
604///
605/// ```
606/// use reinhardt_websockets::consumers::{ConsumerChain, EchoConsumer, ConsumerContext, WebSocketConsumer};
607/// use reinhardt_websockets::WebSocketConnection;
608/// use tokio::sync::mpsc;
609/// use std::sync::Arc;
610///
611/// # tokio_test::block_on(async {
612/// let mut chain = ConsumerChain::new();
613/// chain.add_consumer(Box::new(EchoConsumer::new()));
614///
615/// let (tx, _rx) = mpsc::unbounded_channel();
616/// let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
617/// let mut context = ConsumerContext::new(conn);
618///
619/// assert!(chain.on_connect(&mut context).await.is_ok());
620/// # });
621/// ```
622pub struct ConsumerChain {
623	consumers: Vec<Box<dyn WebSocketConsumer>>,
624}
625
626impl ConsumerChain {
627	/// Create a new consumer chain
628	pub fn new() -> Self {
629		Self {
630			consumers: Vec::new(),
631		}
632	}
633
634	/// Add a consumer to the chain
635	pub fn add_consumer(&mut self, consumer: Box<dyn WebSocketConsumer>) {
636		self.consumers.push(consumer);
637	}
638}
639
640impl Default for ConsumerChain {
641	fn default() -> Self {
642		Self::new()
643	}
644}
645
646#[async_trait]
647impl WebSocketConsumer for ConsumerChain {
648	async fn on_connect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
649		for consumer in &self.consumers {
650			consumer.on_connect(context).await?;
651		}
652		Ok(())
653	}
654
655	async fn on_message(
656		&self,
657		context: &mut ConsumerContext,
658		message: Message,
659	) -> WebSocketResult<()> {
660		for consumer in &self.consumers {
661			consumer.on_message(context, message.clone()).await?;
662		}
663		Ok(())
664	}
665
666	async fn on_disconnect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
667		for consumer in &self.consumers {
668			consumer.on_disconnect(context).await?;
669		}
670		Ok(())
671	}
672}
673
#[cfg(test)]
mod tests {
	use super::*;
	use rstest::rstest;
	use tokio::sync::mpsc;

	/// Build a connection with the given id, returning it together with the
	/// receiving end of its outgoing channel.
	///
	/// Callers keep the receiver bound (to `_rx` or `rx`, never `_`) for the whole
	/// test, because a tokio channel whose receiver was dropped rejects sends.
	fn make_connection(
		id: &str,
	) -> (Arc<WebSocketConnection>, mpsc::UnboundedReceiver<Message>) {
		let (tx, rx) = mpsc::unbounded_channel();
		(Arc::new(WebSocketConnection::new(id.to_string(), tx)), rx)
	}

	/// Await the next outgoing frame and return its text payload.
	///
	/// Panics if the channel is closed or the frame is not `Message::Text`.
	async fn next_text(rx: &mut mpsc::UnboundedReceiver<Message>) -> String {
		match rx.recv().await.unwrap() {
			Message::Text { data } => data,
			_ => panic!("Expected text message"),
		}
	}

	#[rstest]
	#[tokio::test]
	async fn test_consumer_context_creation() {
		// Arrange
		let (conn, _rx) = make_connection("test");

		// Act
		let context = ConsumerContext::new(conn);

		// Assert
		assert_eq!(context.connection.id(), "test");
	}

	#[rstest]
	#[tokio::test]
	async fn test_consumer_context_metadata() {
		// Arrange
		let (conn, _rx) = make_connection("test");

		// Act
		let context =
			ConsumerContext::new(conn).with_metadata("user_id".to_string(), "123".to_string());

		// Assert
		assert_eq!(context.get_metadata("user_id").unwrap(), "123");
	}

	#[rstest]
	#[tokio::test]
	async fn test_echo_consumer_connect() {
		// Arrange
		let consumer = EchoConsumer::new();
		let (conn, mut rx) = make_connection("test");
		let mut context = ConsumerContext::new(conn);

		// Act
		consumer.on_connect(&mut context).await.unwrap();

		// Assert
		let greeting = next_text(&mut rx).await;
		assert!(greeting.contains("Connection established"));
	}

	#[rstest]
	#[tokio::test]
	async fn test_echo_consumer_message() {
		// Arrange
		let consumer = EchoConsumer::new();
		let (conn, mut rx) = make_connection("test");
		let mut context = ConsumerContext::new(conn);

		// Act
		let msg = Message::text("Hello".to_string());
		consumer.on_message(&mut context, msg).await.unwrap();

		// Assert
		assert_eq!(next_text(&mut rx).await, "Echo: Hello");
	}

	#[rstest]
	#[tokio::test]
	async fn test_echo_consumer_binary_utf8_message() {
		// Arrange
		let consumer = EchoConsumer::new();
		let (conn, mut rx) = make_connection("test");
		let mut context = ConsumerContext::new(conn);

		// Act - send a valid UTF-8 binary message
		let msg = Message::binary(b"Hello binary".to_vec());
		consumer.on_message(&mut context, msg).await.unwrap();

		// Assert
		assert_eq!(next_text(&mut rx).await, "Echo: Hello binary");
	}

	#[rstest]
	#[tokio::test]
	async fn test_echo_consumer_binary_non_utf8_message() {
		// Arrange
		let consumer = EchoConsumer::new();
		let (conn, mut rx) = make_connection("test");
		let mut context = ConsumerContext::new(conn);

		// Act - send a non-UTF-8 binary message
		let msg = Message::binary(vec![0xFF, 0xFE, 0xFD]);
		consumer.on_message(&mut context, msg).await.unwrap();

		// Assert
		assert_eq!(next_text(&mut rx).await, "Echo: binary(3 bytes)");
	}

	#[rstest]
	#[tokio::test]
	async fn test_echo_consumer_handles_close_message() {
		// Arrange
		let consumer = EchoConsumer::new();
		let (conn, mut rx) = make_connection("test");
		let mut context = ConsumerContext::new(Arc::clone(&conn));

		// Act
		let close = Message::Close {
			code: 1000,
			reason: "Normal closure".to_string(),
		};
		consumer.on_message(&mut context, close).await.unwrap();

		// Assert - connection should be closed
		assert!(conn.is_closed().await);

		// The close frame should have been sent
		let frame = rx.recv().await.unwrap();
		assert!(matches!(frame, Message::Close { code: 1000, .. }));
	}

	#[rstest]
	#[tokio::test]
	async fn test_json_consumer_connect() {
		// Arrange
		let consumer = JsonConsumer::new();
		let (conn, mut rx) = make_connection("test");
		let mut context = ConsumerContext::new(conn);

		// Act
		consumer.on_connect(&mut context).await.unwrap();

		// Assert
		let payload = next_text(&mut rx).await;
		let json: serde_json::Value = serde_json::from_str(&payload).unwrap();
		assert_eq!(json["status"], "connected");
	}

	#[rstest]
	#[tokio::test]
	async fn test_json_consumer_binary_valid_json() {
		// Arrange
		let consumer = JsonConsumer::new();
		let (conn, mut rx) = make_connection("test");
		let mut context = ConsumerContext::new(conn);

		// Act - send valid JSON as binary
		let msg = Message::binary(br#"{"key":"value"}"#.to_vec());
		consumer.on_message(&mut context, msg).await.unwrap();

		// Assert
		let payload = next_text(&mut rx).await;
		let json: serde_json::Value = serde_json::from_str(&payload).unwrap();
		assert_eq!(json["source"], "binary");
		assert_eq!(json["data"]["key"], "value");
	}

	#[rstest]
	#[tokio::test]
	async fn test_json_consumer_binary_invalid_utf8_returns_error() {
		// Arrange
		let consumer = JsonConsumer::new();
		let (conn, _rx) = make_connection("test");
		let mut context = ConsumerContext::new(conn);

		// Act - send non-UTF-8 binary
		let msg = Message::binary(vec![0xFF, 0xFE]);
		let err = consumer.on_message(&mut context, msg).await.unwrap_err();

		// Assert
		assert!(matches!(err, WebSocketError::BinaryPayload(_)));
		assert!(err.to_string().contains("not valid UTF-8"));
	}

	#[rstest]
	#[tokio::test]
	async fn test_json_consumer_binary_invalid_json_returns_error() {
		// Arrange
		let consumer = JsonConsumer::new();
		let (conn, _rx) = make_connection("test");
		let mut context = ConsumerContext::new(conn);

		// Act - send valid UTF-8 but invalid JSON as binary
		let msg = Message::binary(b"not json at all".to_vec());
		let err = consumer.on_message(&mut context, msg).await.unwrap_err();

		// Assert
		assert!(matches!(err, WebSocketError::BinaryPayload(_)));
		assert!(err.to_string().contains("not valid JSON"));
	}

	#[rstest]
	#[tokio::test]
	async fn test_broadcast_consumer_disconnect_cleanup() {
		// Arrange
		let room = Arc::new(crate::room::Room::new("cleanup".to_string()));
		let consumer = BroadcastConsumer::new(Arc::clone(&room));
		let (conn, _rx) = make_connection("user1");
		room.join("user1".to_string(), Arc::clone(&conn)).await.unwrap();
		let mut context = ConsumerContext::new(Arc::clone(&conn));

		// Act
		consumer.on_disconnect(&mut context).await.unwrap();

		// Assert - connection is force-closed and removed from room
		assert!(conn.is_closed().await);
		assert!(!room.has_client("user1").await);
	}

	#[rstest]
	#[tokio::test]
	async fn test_broadcast_consumer_disconnect_tolerates_already_removed() {
		// Arrange - client not in room (e.g., already removed by broadcast cleanup)
		let room = Arc::new(crate::room::Room::new("tolerant".to_string()));
		let consumer = BroadcastConsumer::new(Arc::clone(&room));
		let (conn, _rx) = make_connection("ghost");
		let mut context = ConsumerContext::new(Arc::clone(&conn));

		// Act - should not error even though client is not in the room
		let result = consumer.on_disconnect(&mut context).await;

		// Assert
		assert!(result.is_ok());
		assert!(conn.is_closed().await);
	}

	#[rstest]
	#[tokio::test]
	async fn test_consumer_context_headers() {
		// Arrange
		let (conn, _rx) = make_connection("test");

		// Act
		let context = ConsumerContext::new(conn)
			.with_header("cookie".to_string(), "sessionid=abc123".to_string())
			.with_header("origin".to_string(), "https://example.com".to_string());

		// Assert
		assert_eq!(context.get_header("cookie").unwrap(), "sessionid=abc123");
		assert_eq!(context.get_header("origin").unwrap(), "https://example.com");
	}

	#[rstest]
	#[tokio::test]
	async fn test_consumer_context_cookie_header() {
		// Arrange
		let (conn, _rx) = make_connection("test");

		// Act
		let context = ConsumerContext::new(conn).with_header(
			"cookie".to_string(),
			"sessionid=abc123; csrftoken=xyz".to_string(),
		);

		// Assert
		assert_eq!(
			context.cookie_header(),
			Some("sessionid=abc123; csrftoken=xyz")
		);
	}

	#[rstest]
	#[tokio::test]
	async fn test_consumer_context_cookie_header_missing() {
		// Arrange
		let (conn, _rx) = make_connection("test");

		// Act
		let context = ConsumerContext::new(conn);

		// Assert
		assert_eq!(context.cookie_header(), None);
	}

	#[rstest]
	#[tokio::test]
	async fn test_consumer_context_headers_default_empty() {
		// Arrange
		let (conn, _rx) = make_connection("test");

		// Act
		let context = ConsumerContext::new(conn);

		// Assert
		assert!(context.headers.is_empty());
	}

	#[rstest]
	#[tokio::test]
	async fn test_consumer_chain() {
		// Arrange
		let mut chain = ConsumerChain::new();
		chain.add_consumer(Box::new(EchoConsumer::with_prefix("Consumer1".to_string())));
		let (conn, _rx) = make_connection("test");
		let mut context = ConsumerContext::new(conn);

		// Act & Assert
		assert!(chain.on_connect(&mut context).await.is_ok());
	}
}
1020}