1#[allow(unused_imports)]
42use crate::connection::{Message, WebSocketConnection, WebSocketError, WebSocketResult};
43use async_trait::async_trait;
44use std::sync::Arc;
45
46#[cfg(feature = "di")]
47use reinhardt_di::{Injectable, Injected, InjectionContext};
48
49pub struct ConsumerContext {
75 pub connection: Arc<WebSocketConnection>,
77 pub headers: std::collections::HashMap<String, String>,
79 pub metadata: std::collections::HashMap<String, String>,
81 #[cfg(feature = "di")]
83 di_context: Option<Arc<InjectionContext>>,
84}
85
86impl ConsumerContext {
87 pub fn new(connection: Arc<WebSocketConnection>) -> Self {
102 Self {
103 connection,
104 headers: std::collections::HashMap::new(),
105 metadata: std::collections::HashMap::new(),
106 #[cfg(feature = "di")]
107 di_context: None,
108 }
109 }
110
111 #[cfg(feature = "di")]
127 pub fn with_di_context(
128 connection: Arc<WebSocketConnection>,
129 di_context: Arc<InjectionContext>,
130 ) -> Self {
131 Self {
132 connection,
133 headers: std::collections::HashMap::new(),
134 metadata: std::collections::HashMap::new(),
135 di_context: Some(di_context),
136 }
137 }
138
139 pub fn with_header(mut self, key: String, value: String) -> Self {
144 self.headers.insert(key, value);
145 self
146 }
147
148 pub fn get_header(&self, key: &str) -> Option<&String> {
150 self.headers.get(key)
151 }
152
153 pub fn cookie_header(&self) -> Option<&str> {
157 self.headers.get("cookie").map(|s| s.as_str())
158 }
159
160 pub fn with_metadata(mut self, key: String, value: String) -> Self {
162 self.metadata.insert(key, value);
163 self
164 }
165
166 pub fn get_metadata(&self, key: &str) -> Option<&String> {
168 self.metadata.get(key)
169 }
170
171 #[cfg(feature = "di")]
173 pub fn di_context(&self) -> Option<&Arc<InjectionContext>> {
174 self.di_context.as_ref()
175 }
176
177 #[cfg(feature = "di")]
179 pub fn set_di_context(&mut self, ctx: Arc<InjectionContext>) {
180 self.di_context = Some(ctx);
181 }
182
183 #[cfg(feature = "di")]
200 pub async fn resolve<T>(&self) -> WebSocketResult<T>
201 where
202 T: Injectable + Clone + Send + Sync + 'static,
203 {
204 let ctx = self
205 .di_context
206 .as_ref()
207 .ok_or_else(|| WebSocketError::Internal("DI context not available".to_string()))?;
208
209 Injected::<T>::resolve(ctx)
210 .await
211 .map(|injected| injected.into_inner())
212 .map_err(|_| WebSocketError::Internal("dependency resolution failed".to_string()))
213 }
214
215 #[cfg(feature = "di")]
232 pub async fn resolve_uncached<T>(&self) -> WebSocketResult<T>
233 where
234 T: Injectable + Clone + Send + Sync + 'static,
235 {
236 let ctx = self
237 .di_context
238 .as_ref()
239 .ok_or_else(|| WebSocketError::Internal("DI context not available".to_string()))?;
240
241 Injected::<T>::resolve_uncached(ctx)
242 .await
243 .map(|injected| injected.into_inner())
244 .map_err(|_| WebSocketError::Internal("dependency resolution failed".to_string()))
245 }
246
247 #[cfg(feature = "di")]
262 pub async fn try_resolve<T>(&self) -> Option<T>
263 where
264 T: Injectable + Clone + Send + Sync + 'static,
265 {
266 let ctx = self.di_context.as_ref()?;
267
268 Injected::<T>::resolve(ctx)
269 .await
270 .ok()
271 .map(|injected| injected.into_inner())
272 }
273}
274
275#[async_trait]
279pub trait WebSocketConsumer: Send + Sync {
280 async fn on_connect(&self, context: &mut ConsumerContext) -> WebSocketResult<()>;
282
283 async fn on_message(
285 &self,
286 context: &mut ConsumerContext,
287 message: Message,
288 ) -> WebSocketResult<()>;
289
290 async fn on_disconnect(&self, context: &mut ConsumerContext) -> WebSocketResult<()>;
292}
293
/// Consumer that sends every inbound message back to its sender, tagged with a prefix.
pub struct EchoConsumer {
    /// Label written before each echoed payload, as `"<prefix>: <payload>"`.
    prefix: String,
}

impl EchoConsumer {
    /// Creates an echo consumer using the default `"Echo"` prefix.
    pub fn new() -> Self {
        Self::with_prefix(String::from("Echo"))
    }

    /// Creates an echo consumer that tags replies with `prefix`.
    pub fn with_prefix(prefix: String) -> Self {
        EchoConsumer { prefix }
    }
}

impl Default for EchoConsumer {
    /// Same as [`EchoConsumer::new`].
    fn default() -> Self {
        EchoConsumer::new()
    }
}
343
344#[async_trait]
345impl WebSocketConsumer for EchoConsumer {
346 async fn on_connect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
347 context
348 .connection
349 .send_text(format!("{}: Connection established", self.prefix))
350 .await
351 }
352
353 async fn on_message(
354 &self,
355 context: &mut ConsumerContext,
356 message: Message,
357 ) -> WebSocketResult<()> {
358 match message {
359 Message::Text { data } => {
360 context
361 .connection
362 .send_text(format!("{}: {}", self.prefix, data))
363 .await
364 }
365 Message::Binary { data } => {
366 match String::from_utf8(data.clone()) {
368 Ok(text) => {
369 context
370 .connection
371 .send_text(format!("{}: {}", self.prefix, text))
372 .await
373 }
374 Err(_) => {
375 context
377 .connection
378 .send_text(format!("{}: binary({} bytes)", self.prefix, data.len()))
379 .await
380 }
381 }
382 }
383 Message::Close { code, reason } => {
384 context
386 .connection
387 .close_with_reason(code, reason)
388 .await
389 .ok();
390 Ok(())
391 }
392 _ => Ok(()),
393 }
394 }
395
396 async fn on_disconnect(&self, _context: &mut ConsumerContext) -> WebSocketResult<()> {
397 Ok(())
398 }
399}
400
401pub struct BroadcastConsumer {
436 room: Arc<crate::room::Room>,
437}
438
439impl BroadcastConsumer {
440 pub fn new(room: Arc<crate::room::Room>) -> Self {
442 Self { room }
443 }
444}
445
446#[async_trait]
447impl WebSocketConsumer for BroadcastConsumer {
448 async fn on_connect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
449 let client_id = context.connection.id().to_string();
450 self.room
451 .join(client_id.clone(), context.connection.clone())
452 .await
453 .map_err(|e| crate::connection::WebSocketError::Connection(e.to_string()))?;
454
455 context
456 .connection
457 .send_text("Joined broadcast room".to_string())
458 .await
459 }
460
461 async fn on_message(
462 &self,
463 _context: &mut ConsumerContext,
464 message: Message,
465 ) -> WebSocketResult<()> {
466 let result = self.room.broadcast(message).await;
467 if result.is_complete_failure() {
468 return Err(crate::connection::WebSocketError::Send(
469 "broadcast failed for all clients".to_string(),
470 ));
471 }
472 Ok(())
473 }
474
475 async fn on_disconnect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
476 let client_id = context.connection.id();
477 let _ = self.room.leave(client_id).await;
480
481 context.connection.force_close().await;
483
484 Ok(())
485 }
486}
487
/// Stateless consumer that parses inbound data frames as JSON and echoes them
/// back inside an envelope.
#[derive(Default)]
pub struct JsonConsumer;

impl JsonConsumer {
    /// Creates a JSON echo consumer.
    pub fn new() -> Self {
        JsonConsumer
    }
}
534
535#[async_trait]
536impl WebSocketConsumer for JsonConsumer {
537 async fn on_connect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
538 context
539 .connection
540 .send_json(&serde_json::json!({
541 "type": "connection",
542 "status": "connected"
543 }))
544 .await
545 }
546
547 async fn on_message(
548 &self,
549 context: &mut ConsumerContext,
550 message: Message,
551 ) -> WebSocketResult<()> {
552 match message {
553 Message::Text { data } => {
554 let json: serde_json::Value = serde_json::from_str(&data)
556 .map_err(|e| crate::connection::WebSocketError::Protocol(e.to_string()))?;
557
558 let response = serde_json::json!({
560 "type": "echo",
561 "data": json,
562 "timestamp": chrono::Utc::now().to_rfc3339()
563 });
564
565 context.connection.send_json(&response).await
566 }
567 Message::Binary { data } => {
568 let text = String::from_utf8(data).map_err(|e| {
570 crate::connection::WebSocketError::BinaryPayload(format!(
571 "binary payload is not valid UTF-8: {}",
572 e
573 ))
574 })?;
575
576 let json: serde_json::Value = serde_json::from_str(&text).map_err(|e| {
577 crate::connection::WebSocketError::BinaryPayload(format!(
578 "binary payload is not valid JSON: {}",
579 e
580 ))
581 })?;
582
583 let response = serde_json::json!({
584 "type": "echo",
585 "data": json,
586 "source": "binary",
587 "timestamp": chrono::Utc::now().to_rfc3339()
588 });
589
590 context.connection.send_json(&response).await
591 }
592 _ => Ok(()),
593 }
594 }
595
596 async fn on_disconnect(&self, _context: &mut ConsumerContext) -> WebSocketResult<()> {
597 Ok(())
598 }
599}
600
601pub struct ConsumerChain {
623 consumers: Vec<Box<dyn WebSocketConsumer>>,
624}
625
626impl ConsumerChain {
627 pub fn new() -> Self {
629 Self {
630 consumers: Vec::new(),
631 }
632 }
633
634 pub fn add_consumer(&mut self, consumer: Box<dyn WebSocketConsumer>) {
636 self.consumers.push(consumer);
637 }
638}
639
640impl Default for ConsumerChain {
641 fn default() -> Self {
642 Self::new()
643 }
644}
645
646#[async_trait]
647impl WebSocketConsumer for ConsumerChain {
648 async fn on_connect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
649 for consumer in &self.consumers {
650 consumer.on_connect(context).await?;
651 }
652 Ok(())
653 }
654
655 async fn on_message(
656 &self,
657 context: &mut ConsumerContext,
658 message: Message,
659 ) -> WebSocketResult<()> {
660 for consumer in &self.consumers {
661 consumer.on_message(context, message.clone()).await?;
662 }
663 Ok(())
664 }
665
666 async fn on_disconnect(&self, context: &mut ConsumerContext) -> WebSocketResult<()> {
667 for consumer in &self.consumers {
668 consumer.on_disconnect(context).await?;
669 }
670 Ok(())
671 }
672}
673
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use tokio::sync::mpsc;

    /// Creates a connection with the given id plus the receiver that observes every
    /// frame the connection sends.
    ///
    /// Bind the receiver (e.g. `_rx`) for the whole test; keeping it alive keeps the
    /// channel open.
    fn connection(id: &str) -> (Arc<WebSocketConnection>, mpsc::UnboundedReceiver<Message>) {
        let (tx, rx) = mpsc::unbounded_channel();
        (Arc::new(WebSocketConnection::new(id.to_string(), tx)), rx)
    }

    /// Awaits the next frame sent on the connection and returns its text payload.
    ///
    /// Panics if the channel closed or the frame is not text.
    async fn recv_text(rx: &mut mpsc::UnboundedReceiver<Message>) -> String {
        match rx.recv().await {
            Some(Message::Text { data }) => data,
            Some(_) => panic!("Expected text message"),
            None => panic!("connection channel closed before a message arrived"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_consumer_context_creation() {
        let (conn, _rx) = connection("test");

        let context = ConsumerContext::new(conn);

        assert_eq!(context.connection.id(), "test");
    }

    #[rstest]
    #[tokio::test]
    async fn test_consumer_context_metadata() {
        let (conn, _rx) = connection("test");

        let context =
            ConsumerContext::new(conn).with_metadata("user_id".to_string(), "123".to_string());

        assert_eq!(context.get_metadata("user_id").unwrap(), "123");
    }

    #[rstest]
    #[tokio::test]
    async fn test_echo_consumer_connect() {
        let consumer = EchoConsumer::new();
        let (conn, mut rx) = connection("test");
        let mut context = ConsumerContext::new(conn);

        consumer.on_connect(&mut context).await.unwrap();

        assert!(recv_text(&mut rx).await.contains("Connection established"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_echo_consumer_message() {
        let consumer = EchoConsumer::new();
        let (conn, mut rx) = connection("test");
        let mut context = ConsumerContext::new(conn);

        let msg = Message::text("Hello".to_string());
        consumer.on_message(&mut context, msg).await.unwrap();

        assert_eq!(recv_text(&mut rx).await, "Echo: Hello");
    }

    #[rstest]
    #[tokio::test]
    async fn test_echo_consumer_binary_utf8_message() {
        let consumer = EchoConsumer::new();
        let (conn, mut rx) = connection("test");
        let mut context = ConsumerContext::new(conn);

        let msg = Message::binary(b"Hello binary".to_vec());
        consumer.on_message(&mut context, msg).await.unwrap();

        assert_eq!(recv_text(&mut rx).await, "Echo: Hello binary");
    }

    #[rstest]
    #[tokio::test]
    async fn test_echo_consumer_binary_non_utf8_message() {
        let consumer = EchoConsumer::new();
        let (conn, mut rx) = connection("test");
        let mut context = ConsumerContext::new(conn);

        let msg = Message::binary(vec![0xFF, 0xFE, 0xFD]);
        consumer.on_message(&mut context, msg).await.unwrap();

        assert_eq!(recv_text(&mut rx).await, "Echo: binary(3 bytes)");
    }

    #[rstest]
    #[tokio::test]
    async fn test_echo_consumer_handles_close_message() {
        let consumer = EchoConsumer::new();
        let (conn, mut rx) = connection("test");
        let mut context = ConsumerContext::new(Arc::clone(&conn));

        let msg = Message::Close {
            code: 1000,
            reason: "Normal closure".to_string(),
        };
        consumer.on_message(&mut context, msg).await.unwrap();

        assert!(conn.is_closed().await);

        let received = rx.recv().await.unwrap();
        assert!(matches!(received, Message::Close { code: 1000, .. }));
    }

    #[rstest]
    #[tokio::test]
    async fn test_json_consumer_connect() {
        let consumer = JsonConsumer::new();
        let (conn, mut rx) = connection("test");
        let mut context = ConsumerContext::new(conn);

        consumer.on_connect(&mut context).await.unwrap();

        let json: serde_json::Value = serde_json::from_str(&recv_text(&mut rx).await).unwrap();
        assert_eq!(json["status"], "connected");
    }

    #[rstest]
    #[tokio::test]
    async fn test_json_consumer_binary_valid_json() {
        let consumer = JsonConsumer::new();
        let (conn, mut rx) = connection("test");
        let mut context = ConsumerContext::new(conn);

        let msg = Message::binary(br#"{"key":"value"}"#.to_vec());
        consumer.on_message(&mut context, msg).await.unwrap();

        let json: serde_json::Value = serde_json::from_str(&recv_text(&mut rx).await).unwrap();
        assert_eq!(json["source"], "binary");
        assert_eq!(json["data"]["key"], "value");
    }

    #[rstest]
    #[tokio::test]
    async fn test_json_consumer_binary_invalid_utf8_returns_error() {
        let consumer = JsonConsumer::new();
        let (conn, _rx) = connection("test");
        let mut context = ConsumerContext::new(conn);

        let msg = Message::binary(vec![0xFF, 0xFE]);
        let result = consumer.on_message(&mut context, msg).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, WebSocketError::BinaryPayload(_)));
        assert!(err.to_string().contains("not valid UTF-8"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_json_consumer_binary_invalid_json_returns_error() {
        let consumer = JsonConsumer::new();
        let (conn, _rx) = connection("test");
        let mut context = ConsumerContext::new(conn);

        let msg = Message::binary(b"not json at all".to_vec());
        let result = consumer.on_message(&mut context, msg).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, WebSocketError::BinaryPayload(_)));
        assert!(err.to_string().contains("not valid JSON"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_broadcast_consumer_disconnect_cleanup() {
        let room = Arc::new(crate::room::Room::new("cleanup".to_string()));
        let consumer = BroadcastConsumer::new(Arc::clone(&room));
        let (conn, _rx) = connection("user1");
        room.join("user1".to_string(), Arc::clone(&conn))
            .await
            .unwrap();
        let mut context = ConsumerContext::new(Arc::clone(&conn));

        consumer.on_disconnect(&mut context).await.unwrap();

        assert!(conn.is_closed().await);
        assert!(!room.has_client("user1").await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_broadcast_consumer_disconnect_tolerates_already_removed() {
        let room = Arc::new(crate::room::Room::new("tolerant".to_string()));
        let consumer = BroadcastConsumer::new(Arc::clone(&room));
        let (conn, _rx) = connection("ghost");
        let mut context = ConsumerContext::new(Arc::clone(&conn));

        let result = consumer.on_disconnect(&mut context).await;

        assert!(result.is_ok());
        assert!(conn.is_closed().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_consumer_context_headers() {
        let (conn, _rx) = connection("test");

        let context = ConsumerContext::new(conn)
            .with_header("cookie".to_string(), "sessionid=abc123".to_string())
            .with_header("origin".to_string(), "https://example.com".to_string());

        assert_eq!(context.get_header("cookie").unwrap(), "sessionid=abc123");
        assert_eq!(context.get_header("origin").unwrap(), "https://example.com");
    }

    #[rstest]
    #[tokio::test]
    async fn test_consumer_context_cookie_header() {
        let (conn, _rx) = connection("test");

        let context = ConsumerContext::new(conn).with_header(
            "cookie".to_string(),
            "sessionid=abc123; csrftoken=xyz".to_string(),
        );

        assert_eq!(
            context.cookie_header(),
            Some("sessionid=abc123; csrftoken=xyz")
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_consumer_context_cookie_header_missing() {
        let (conn, _rx) = connection("test");

        let context = ConsumerContext::new(conn);

        assert_eq!(context.cookie_header(), None);
    }

    #[rstest]
    #[tokio::test]
    async fn test_consumer_context_headers_default_empty() {
        let (conn, _rx) = connection("test");

        let context = ConsumerContext::new(conn);

        assert!(context.headers.is_empty());
    }

    #[rstest]
    #[tokio::test]
    async fn test_consumer_chain() {
        let mut chain = ConsumerChain::new();
        chain.add_consumer(Box::new(EchoConsumer::with_prefix("Consumer1".to_string())));
        chain.add_consumer(Box::new(EchoConsumer::with_prefix("Consumer2".to_string())));
        let (conn, mut rx) = connection("test");
        let mut context = ConsumerContext::new(conn);

        chain.on_connect(&mut context).await.unwrap();

        // Consumers run in insertion order.
        assert_eq!(recv_text(&mut rx).await, "Consumer1: Connection established");
        assert_eq!(recv_text(&mut rx).await, "Consumer2: Connection established");
    }

    #[rstest]
    #[tokio::test]
    async fn test_consumer_chain_message_fan_out() {
        let mut chain = ConsumerChain::new();
        chain.add_consumer(Box::new(EchoConsumer::with_prefix("Consumer1".to_string())));
        chain.add_consumer(Box::new(EchoConsumer::with_prefix("Consumer2".to_string())));
        let (conn, mut rx) = connection("test");
        let mut context = ConsumerContext::new(conn);

        chain
            .on_message(&mut context, Message::text("Hi".to_string()))
            .await
            .unwrap();

        assert_eq!(recv_text(&mut rx).await, "Consumer1: Hi");
        assert_eq!(recv_text(&mut rx).await, "Consumer2: Hi");
    }

    #[rstest]
    #[tokio::test]
    async fn test_empty_consumer_chain_is_noop() {
        let chain = ConsumerChain::new();
        let (conn, _rx) = connection("test");
        let mut context = ConsumerContext::new(conn);

        assert!(chain.on_connect(&mut context).await.is_ok());
        assert!(chain
            .on_message(&mut context, Message::text("Hi".to_string()))
            .await
            .is_ok());
        assert!(chain.on_disconnect(&mut context).await.is_ok());
    }
}