1use axum::{
6 extract::{
7 State,
8 ws::{Message, WebSocket, WebSocketUpgrade},
9 },
10 response::IntoResponse,
11};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, error, info, warn};
15
/// Write `message` to stderr, prefixed with `[spikard-ws]`, when WebSocket tracing is on.
///
/// Tracing is enabled by setting the environment variable `SPIKARD_WS_TRACE=1`. The variable
/// is read once, on the first call, and the result is cached for the rest of the process.
/// This function is called several times per received frame, and re-reading the environment
/// each time allocates a `String` on every call. As a consequence, changing the variable after
/// the first trace call has no effect.
fn trace_ws(message: &str) {
    static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

    let enabled =
        *ENABLED.get_or_init(|| std::env::var("SPIKARD_WS_TRACE").is_ok_and(|v| v == "1"));
    if enabled {
        eprintln!("[spikard-ws] {message}");
    }
}
21
22pub trait WebSocketHandler: Send + Sync {
58 fn handle_message(&self, message: Value) -> impl std::future::Future<Output = Option<Value>> + Send;
70
71 fn on_connect(&self) -> impl std::future::Future<Output = ()> + Send {
76 async {}
77 }
78
79 fn on_disconnect(&self) -> impl std::future::Future<Output = ()> + Send {
84 async {}
85 }
86}
87
88pub struct WebSocketState<H: WebSocketHandler> {
94 handler: Arc<H>,
96 handler_factory: Arc<dyn Fn() -> Result<Arc<H>, String> + Send + Sync>,
98 message_schema: Option<Arc<jsonschema::Validator>>,
100 response_schema: Option<Arc<jsonschema::Validator>>,
102}
103
104impl<H: WebSocketHandler> std::fmt::Debug for WebSocketState<H> {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 f.debug_struct("WebSocketState")
107 .field("message_schema", &self.message_schema.is_some())
108 .field("response_schema", &self.response_schema.is_some())
109 .finish()
110 }
111}
112
113impl<H: WebSocketHandler> Clone for WebSocketState<H> {
114 fn clone(&self) -> Self {
115 Self {
116 handler: Arc::clone(&self.handler),
117 handler_factory: Arc::clone(&self.handler_factory),
118 message_schema: self.message_schema.clone(),
119 response_schema: self.response_schema.clone(),
120 }
121 }
122}
123
124impl<H: WebSocketHandler + 'static> WebSocketState<H> {
125 pub fn new(handler: H) -> Self {
139 let handler = Arc::new(handler);
140 Self {
141 handler_factory: Arc::new({
142 let handler = Arc::clone(&handler);
143 move || Ok(Arc::clone(&handler))
144 }),
145 handler,
146 message_schema: None,
147 response_schema: None,
148 }
149 }
150
151 pub fn with_schemas(
186 handler: H,
187 message_schema: Option<serde_json::Value>,
188 response_schema: Option<serde_json::Value>,
189 ) -> Result<Self, String> {
190 let message_validator = if let Some(schema) = message_schema {
191 Some(Arc::new(
192 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
193 ))
194 } else {
195 None
196 };
197
198 let response_validator = if let Some(schema) = response_schema {
199 Some(Arc::new(
200 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
201 ))
202 } else {
203 None
204 };
205
206 let handler = Arc::new(handler);
207 Ok(Self {
208 handler_factory: Arc::new({
209 let handler = Arc::clone(&handler);
210 move || Ok(Arc::clone(&handler))
211 }),
212 handler,
213 message_schema: message_validator,
214 response_schema: response_validator,
215 })
216 }
217
218 pub fn with_factory<F>(
222 factory: F,
223 message_schema: Option<serde_json::Value>,
224 response_schema: Option<serde_json::Value>,
225 ) -> Result<Self, String>
226 where
227 F: Fn() -> Result<H, String> + Send + Sync + 'static,
228 {
229 let message_validator = if let Some(schema) = message_schema {
230 Some(Arc::new(
231 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
232 ))
233 } else {
234 None
235 };
236
237 let response_validator = if let Some(schema) = response_schema {
238 Some(Arc::new(
239 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
240 ))
241 } else {
242 None
243 };
244
245 let factory = Arc::new(factory);
246 let handler = factory()
247 .map(Arc::new)
248 .map_err(|e| format!("Failed to build WebSocket handler: {}", e))?;
249
250 Ok(Self {
251 handler_factory: Arc::new({
252 let factory = Arc::clone(&factory);
253 move || factory().map(Arc::new)
254 }),
255 handler,
256 message_schema: message_validator,
257 response_schema: response_validator,
258 })
259 }
260
261 pub async fn on_connect(&self) {
263 self.handler.on_connect().await;
264 }
265
266 pub async fn on_disconnect(&self) {
268 self.handler.on_disconnect().await;
269 }
270
271 pub async fn handle_message_validated(&self, message: Value) -> Result<Option<Value>, String> {
273 if let Some(validator) = &self.message_schema
274 && !validator.is_valid(&message)
275 {
276 return Err("Message validation failed".to_string());
277 }
278
279 let response = self.handler.handle_message(message).await;
280 if let Some(ref value) = response
281 && let Some(validator) = &self.response_schema
282 && !validator.is_valid(value)
283 {
284 return Ok(None);
285 }
286
287 Ok(response)
288 }
289}
290
291pub async fn websocket_handler<H: WebSocketHandler + 'static>(
314 ws: WebSocketUpgrade,
315 State(state): State<WebSocketState<H>>,
316) -> impl IntoResponse {
317 ws.on_upgrade(move |socket| handle_socket(socket, state))
318}
319
320async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSocketState<H>) {
322 info!("WebSocket client connected");
323 trace_ws("socket:connected");
324
325 let handler = match (state.handler_factory)() {
326 Ok(handler) => handler,
327 Err(err) => {
328 error!("Failed to create WebSocket handler: {}", err);
329 trace_ws("socket:handler-factory:error");
330 return;
331 }
332 };
333
334 handler.on_connect().await;
335 trace_ws("socket:on_connect:done");
336
337 while let Some(msg) = socket.recv().await {
338 match msg {
339 Ok(Message::Text(text)) => {
340 debug!("Received text message: {}", text);
341 trace_ws(&format!("recv:text len={}", text.len()));
342
343 match serde_json::from_str::<Value>(&text) {
344 Ok(json_msg) => {
345 trace_ws("recv:text:json-ok");
346 if let Some(validator) = &state.message_schema
347 && !validator.is_valid(&json_msg)
348 {
349 error!("Message validation failed");
350 trace_ws("recv:text:validation-failed");
351 let error_response = serde_json::json!({
352 "error": "Message validation failed"
353 });
354 if let Ok(error_text) = serde_json::to_string(&error_response) {
355 trace_ws(&format!("send:validation-error len={}", error_text.len()));
356 let _ = socket.send(Message::Text(error_text.into())).await;
357 }
358 continue;
359 }
360
361 if let Some(response) = handler.handle_message(json_msg).await {
362 trace_ws("handler:response:some");
363 if let Some(validator) = &state.response_schema
364 && !validator.is_valid(&response)
365 {
366 error!("Response validation failed");
367 trace_ws("send:response:validation-failed");
368 continue;
369 }
370
371 let response_text = serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
372 let response_len = response_text.len();
373
374 if let Err(e) = socket.send(Message::Text(response_text.into())).await {
375 error!("Failed to send response: {}", e);
376 trace_ws("send:response:error");
377 break;
378 }
379 trace_ws(&format!("send:response len={}", response_len));
380 } else {
381 trace_ws("handler:response:none");
382 }
383 }
384 Err(e) => {
385 warn!("Failed to parse JSON message: {}", e);
386 trace_ws("recv:text:json-error");
387 let error_msg = serde_json::json!({
388 "type": "error",
389 "message": "Invalid JSON"
390 });
391 let error_text = serde_json::to_string(&error_msg).unwrap_or_else(|_| "{}".to_string());
392 trace_ws(&format!("send:json-error len={}", error_text.len()));
393 let _ = socket.send(Message::Text(error_text.into())).await;
394 }
395 }
396 }
397 Ok(Message::Binary(data)) => {
398 debug!("Received binary message: {} bytes", data.len());
399 trace_ws(&format!("recv:binary len={}", data.len()));
400 if let Err(e) = socket.send(Message::Binary(data)).await {
401 error!("Failed to send binary response: {}", e);
402 trace_ws("send:binary:error");
403 break;
404 }
405 trace_ws("send:binary:ok");
406 }
407 Ok(Message::Ping(data)) => {
408 debug!("Received ping");
409 trace_ws(&format!("recv:ping len={}", data.len()));
410 if let Err(e) = socket.send(Message::Pong(data)).await {
411 error!("Failed to send pong: {}", e);
412 trace_ws("send:pong:error");
413 break;
414 }
415 trace_ws("send:pong:ok");
416 }
417 Ok(Message::Pong(_)) => {
418 debug!("Received pong");
419 trace_ws("recv:pong");
420 }
421 Ok(Message::Close(_)) => {
422 info!("Client closed connection");
423 trace_ws("recv:close");
424 break;
425 }
426 Err(e) => {
427 error!("WebSocket error: {}", e);
428 trace_ws(&format!("recv:error {}", e));
429 break;
430 }
431 }
432 }
433
434 handler.on_disconnect().await;
435 trace_ws("socket:on_disconnect:done");
436 info!("WebSocket client disconnected");
437}
438
439#[cfg(test)]
440mod tests {
441 use super::*;
442 use std::sync::Mutex;
443 use std::sync::atomic::{AtomicUsize, Ordering};
444
445 #[derive(Debug)]
446 struct EchoHandler;
447
448 impl WebSocketHandler for EchoHandler {
449 async fn handle_message(&self, message: Value) -> Option<Value> {
450 Some(message)
451 }
452 }
453
454 #[derive(Debug)]
455 struct TrackingHandler {
456 connect_count: Arc<AtomicUsize>,
457 disconnect_count: Arc<AtomicUsize>,
458 message_count: Arc<AtomicUsize>,
459 messages: Arc<Mutex<Vec<Value>>>,
460 }
461
462 impl TrackingHandler {
463 fn new() -> Self {
464 Self {
465 connect_count: Arc::new(AtomicUsize::new(0)),
466 disconnect_count: Arc::new(AtomicUsize::new(0)),
467 message_count: Arc::new(AtomicUsize::new(0)),
468 messages: Arc::new(Mutex::new(Vec::new())),
469 }
470 }
471 }
472
473 impl WebSocketHandler for TrackingHandler {
474 async fn handle_message(&self, message: Value) -> Option<Value> {
475 self.message_count.fetch_add(1, Ordering::SeqCst);
476 self.messages.lock().unwrap().push(message.clone());
477 Some(message)
478 }
479
480 async fn on_connect(&self) {
481 self.connect_count.fetch_add(1, Ordering::SeqCst);
482 }
483
484 async fn on_disconnect(&self) {
485 self.disconnect_count.fetch_add(1, Ordering::SeqCst);
486 }
487 }
488
489 #[derive(Debug)]
490 struct SelectiveHandler;
491
492 impl WebSocketHandler for SelectiveHandler {
493 async fn handle_message(&self, message: Value) -> Option<Value> {
494 if message.get("respond").is_some_and(|v| v.as_bool().unwrap_or(false)) {
495 Some(serde_json::json!({"response": "acknowledged"}))
496 } else {
497 None
498 }
499 }
500 }
501
502 #[derive(Debug)]
503 struct TransformHandler;
504
505 impl WebSocketHandler for TransformHandler {
506 async fn handle_message(&self, message: Value) -> Option<Value> {
507 message.as_object().map_or(None, |obj| {
508 let mut resp = obj.clone();
509 resp.insert("processed".to_string(), Value::Bool(true));
510 Some(Value::Object(resp))
511 })
512 }
513 }
514
515 #[test]
516 fn test_websocket_state_creation() {
517 let handler: EchoHandler = EchoHandler;
518 let state: WebSocketState<EchoHandler> = WebSocketState::new(handler);
519 let cloned: WebSocketState<EchoHandler> = state.clone();
520 assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
521 }
522
523 #[test]
524 fn test_websocket_state_with_valid_schema() {
525 let handler: EchoHandler = EchoHandler;
526 let schema: serde_json::Value = serde_json::json!({
527 "type": "object",
528 "properties": {
529 "type": {"type": "string"}
530 }
531 });
532
533 let result: Result<WebSocketState<EchoHandler>, String> =
534 WebSocketState::with_schemas(handler, Some(schema), None);
535 assert!(result.is_ok());
536 }
537
538 #[test]
539 fn test_websocket_state_with_invalid_schema() {
540 let handler: EchoHandler = EchoHandler;
541 let invalid_schema: serde_json::Value = serde_json::json!({
542 "type": "not_a_real_type",
543 "invalid": "schema"
544 });
545
546 let result: Result<WebSocketState<EchoHandler>, String> =
547 WebSocketState::with_schemas(handler, Some(invalid_schema), None);
548 assert!(result.is_err());
549 if let Err(error_msg) = result {
550 assert!(error_msg.contains("Invalid message schema"));
551 }
552 }
553
554 #[test]
555 fn test_websocket_state_with_both_schemas() {
556 let handler: EchoHandler = EchoHandler;
557 let message_schema: serde_json::Value = serde_json::json!({
558 "type": "object",
559 "properties": {"action": {"type": "string"}}
560 });
561 let response_schema: serde_json::Value = serde_json::json!({
562 "type": "object",
563 "properties": {"result": {"type": "string"}}
564 });
565
566 let result: Result<WebSocketState<EchoHandler>, String> =
567 WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema));
568 assert!(result.is_ok());
569 let state: WebSocketState<EchoHandler> = result.unwrap();
570 assert!(state.message_schema.is_some());
571 assert!(state.response_schema.is_some());
572 }
573
574 #[test]
575 fn test_websocket_state_cloning_preserves_schemas() {
576 let handler: EchoHandler = EchoHandler;
577 let schema: serde_json::Value = serde_json::json!({
578 "type": "object",
579 "properties": {"id": {"type": "integer"}}
580 });
581
582 let state: WebSocketState<EchoHandler> = WebSocketState::with_schemas(handler, Some(schema), None).unwrap();
583 let cloned: WebSocketState<EchoHandler> = state.clone();
584
585 assert!(cloned.message_schema.is_some());
586 assert!(cloned.response_schema.is_none());
587 assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
588 }
589
590 #[tokio::test]
591 async fn test_tracking_handler_lifecycle() {
592 let handler: TrackingHandler = TrackingHandler::new();
593 handler.on_connect().await;
594 assert_eq!(handler.connect_count.load(Ordering::SeqCst), 1);
595
596 let msg: Value = serde_json::json!({"test": "data"});
597 let _response: Option<Value> = handler.handle_message(msg).await;
598 assert_eq!(handler.message_count.load(Ordering::SeqCst), 1);
599
600 handler.on_disconnect().await;
601 assert_eq!(handler.disconnect_count.load(Ordering::SeqCst), 1);
602 }
603
604 #[tokio::test]
605 async fn test_selective_handler_responds_conditionally() {
606 let handler: SelectiveHandler = SelectiveHandler;
607
608 let respond_msg: Value = serde_json::json!({"respond": true});
609 let response1: Option<Value> = handler.handle_message(respond_msg).await;
610 assert!(response1.is_some());
611 assert_eq!(response1.unwrap(), serde_json::json!({"response": "acknowledged"}));
612
613 let no_respond_msg: Value = serde_json::json!({"respond": false});
614 let response2: Option<Value> = handler.handle_message(no_respond_msg).await;
615 assert!(response2.is_none());
616 }
617
618 #[tokio::test]
619 async fn test_transform_handler_modifies_message() {
620 let handler: TransformHandler = TransformHandler;
621 let original: Value = serde_json::json!({"name": "test"});
622 let transformed: Option<Value> = handler.handle_message(original).await;
623
624 assert!(transformed.is_some());
625 let resp: Value = transformed.unwrap();
626 assert_eq!(resp.get("name").unwrap(), "test");
627 assert_eq!(resp.get("processed").unwrap(), true);
628 }
629
630 #[tokio::test]
631 async fn test_echo_handler_preserves_json_types() {
632 let handler: EchoHandler = EchoHandler;
633
634 let messages: Vec<Value> = vec![
635 serde_json::json!({"string": "value"}),
636 serde_json::json!({"number": 42}),
637 serde_json::json!({"float": 3.14}),
638 serde_json::json!({"bool": true}),
639 serde_json::json!({"null": null}),
640 serde_json::json!({"array": [1, 2, 3]}),
641 ];
642
643 for msg in messages {
644 let response: Option<Value> = handler.handle_message(msg.clone()).await;
645 assert!(response.is_some());
646 assert_eq!(response.unwrap(), msg);
647 }
648 }
649
650 #[tokio::test]
651 async fn test_tracking_handler_accumulates_messages() {
652 let handler: TrackingHandler = TrackingHandler::new();
653
654 let messages: Vec<Value> = vec![
655 serde_json::json!({"id": 1}),
656 serde_json::json!({"id": 2}),
657 serde_json::json!({"id": 3}),
658 ];
659
660 for msg in messages {
661 let _: Option<Value> = handler.handle_message(msg).await;
662 }
663
664 assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
665 let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
666 assert_eq!(stored.len(), 3);
667 assert_eq!(stored[0].get("id").unwrap(), 1);
668 assert_eq!(stored[1].get("id").unwrap(), 2);
669 assert_eq!(stored[2].get("id").unwrap(), 3);
670 }
671
672 #[tokio::test]
673 async fn test_echo_handler_with_nested_json() {
674 let handler: EchoHandler = EchoHandler;
675 let nested: Value = serde_json::json!({
676 "level1": {
677 "level2": {
678 "level3": {
679 "value": "deeply nested"
680 }
681 }
682 }
683 });
684
685 let response: Option<Value> = handler.handle_message(nested.clone()).await;
686 assert!(response.is_some());
687 assert_eq!(response.unwrap(), nested);
688 }
689
690 #[tokio::test]
691 async fn test_echo_handler_with_large_array() {
692 let handler: EchoHandler = EchoHandler;
693 let large_array: Value = serde_json::json!({
694 "items": (0..1000).collect::<Vec<i32>>()
695 });
696
697 let response: Option<Value> = handler.handle_message(large_array.clone()).await;
698 assert!(response.is_some());
699 assert_eq!(response.unwrap(), large_array);
700 }
701
702 #[tokio::test]
703 async fn test_echo_handler_with_unicode() {
704 let handler: EchoHandler = EchoHandler;
705 let unicode_msg: Value = serde_json::json!({
706 "emoji": "🚀",
707 "chinese": "你好",
708 "arabic": "مرحبا",
709 "mixed": "Hello 世界 🌍"
710 });
711
712 let response: Option<Value> = handler.handle_message(unicode_msg.clone()).await;
713 assert!(response.is_some());
714 assert_eq!(response.unwrap(), unicode_msg);
715 }
716
717 #[test]
718 fn test_websocket_state_schemas_are_independent() {
719 let handler: EchoHandler = EchoHandler;
720 let message_schema: serde_json::Value = serde_json::json!({"type": "object"});
721 let response_schema: serde_json::Value = serde_json::json!({"type": "array"});
722
723 let state: WebSocketState<EchoHandler> =
724 WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
725
726 let cloned: WebSocketState<EchoHandler> = state.clone();
727
728 assert!(state.message_schema.is_some());
729 assert!(state.response_schema.is_some());
730 assert!(cloned.message_schema.is_some());
731 assert!(cloned.response_schema.is_some());
732 }
733
734 #[test]
735 fn test_message_schema_validation_with_required_field() {
736 let handler: EchoHandler = EchoHandler;
737 let message_schema: serde_json::Value = serde_json::json!({
738 "type": "object",
739 "properties": {"type": {"type": "string"}},
740 "required": ["type"]
741 });
742
743 let state: WebSocketState<EchoHandler> =
744 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
745
746 assert!(state.message_schema.is_some());
747 assert!(state.response_schema.is_none());
748
749 let valid_msg: Value = serde_json::json!({"type": "test"});
750 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
751 assert!(validator.is_valid(&valid_msg));
752
753 let invalid_msg: Value = serde_json::json!({"other": "field"});
754 assert!(!validator.is_valid(&invalid_msg));
755 }
756
757 #[test]
758 fn test_response_schema_validation_with_required_field() {
759 let handler: EchoHandler = EchoHandler;
760 let response_schema: serde_json::Value = serde_json::json!({
761 "type": "object",
762 "properties": {"status": {"type": "string"}},
763 "required": ["status"]
764 });
765
766 let state: WebSocketState<EchoHandler> =
767 WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
768
769 assert!(state.message_schema.is_none());
770 assert!(state.response_schema.is_some());
771
772 let valid_response: Value = serde_json::json!({"status": "ok"});
773 let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
774 assert!(validator.is_valid(&valid_response));
775
776 let invalid_response: Value = serde_json::json!({"other": "field"});
777 assert!(!validator.is_valid(&invalid_response));
778 }
779
780 #[test]
781 fn test_invalid_message_schema_returns_error() {
782 let handler: EchoHandler = EchoHandler;
783 let invalid_schema: serde_json::Value = serde_json::json!({
784 "type": "invalid_type_value",
785 "properties": {}
786 });
787
788 let result: Result<WebSocketState<EchoHandler>, String> =
789 WebSocketState::with_schemas(handler, Some(invalid_schema), None);
790
791 assert!(result.is_err());
792 match result {
793 Err(error_msg) => assert!(error_msg.contains("Invalid message schema")),
794 Ok(_) => panic!("Expected error but got Ok"),
795 }
796 }
797
798 #[test]
799 fn test_invalid_response_schema_returns_error() {
800 let handler: EchoHandler = EchoHandler;
801 let invalid_schema: serde_json::Value = serde_json::json!({
802 "type": "definitely_not_valid"
803 });
804
805 let result: Result<WebSocketState<EchoHandler>, String> =
806 WebSocketState::with_schemas(handler, None, Some(invalid_schema));
807
808 assert!(result.is_err());
809 match result {
810 Err(error_msg) => assert!(error_msg.contains("Invalid response schema")),
811 Ok(_) => panic!("Expected error but got Ok"),
812 }
813 }
814
815 #[tokio::test]
816 async fn test_handler_returning_none_response() {
817 let handler: SelectiveHandler = SelectiveHandler;
818
819 let no_response_msg: Value = serde_json::json!({"respond": false});
820 let result: Option<Value> = handler.handle_message(no_response_msg).await;
821
822 assert!(result.is_none());
823 }
824
825 #[tokio::test]
826 async fn test_handler_with_complex_schema_validation() {
827 let handler: EchoHandler = EchoHandler;
828 let message_schema: serde_json::Value = serde_json::json!({
829 "type": "object",
830 "properties": {
831 "user": {
832 "type": "object",
833 "properties": {
834 "id": {"type": "integer"},
835 "name": {"type": "string"}
836 },
837 "required": ["id", "name"]
838 },
839 "action": {"type": "string"}
840 },
841 "required": ["user", "action"]
842 });
843
844 let state: WebSocketState<EchoHandler> =
845 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
846
847 let valid_msg: Value = serde_json::json!({
848 "user": {"id": 123, "name": "Alice"},
849 "action": "create"
850 });
851 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
852 assert!(validator.is_valid(&valid_msg));
853
854 let invalid_msg: Value = serde_json::json!({
855 "user": {"id": "not_an_int", "name": "Bob"},
856 "action": "create"
857 });
858 assert!(!validator.is_valid(&invalid_msg));
859 }
860
861 #[tokio::test]
862 async fn test_tracking_handler_with_multiple_message_types() {
863 let handler: TrackingHandler = TrackingHandler::new();
864
865 let messages: Vec<Value> = vec![
866 serde_json::json!({"type": "text", "content": "hello"}),
867 serde_json::json!({"type": "image", "url": "http://example.com/image.png"}),
868 serde_json::json!({"type": "video", "duration": 120}),
869 ];
870
871 for msg in messages {
872 let _: Option<Value> = handler.handle_message(msg).await;
873 }
874
875 assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
876 let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
877 assert_eq!(stored.len(), 3);
878 assert_eq!(stored[0].get("type").unwrap(), "text");
879 assert_eq!(stored[1].get("type").unwrap(), "image");
880 assert_eq!(stored[2].get("type").unwrap(), "video");
881 }
882
883 #[tokio::test]
884 async fn test_selective_handler_with_explicit_false() {
885 let handler: SelectiveHandler = SelectiveHandler;
886
887 let msg: Value = serde_json::json!({"respond": false, "data": "test"});
888 let response: Option<Value> = handler.handle_message(msg).await;
889
890 assert!(response.is_none());
891 }
892
893 #[tokio::test]
894 async fn test_selective_handler_without_respond_field() {
895 let handler: SelectiveHandler = SelectiveHandler;
896
897 let msg: Value = serde_json::json!({"data": "test"});
898 let response: Option<Value> = handler.handle_message(msg).await;
899
900 assert!(response.is_none());
901 }
902
903 #[tokio::test]
904 async fn test_transform_handler_with_empty_object() {
905 let handler: TransformHandler = TransformHandler;
906 let original: Value = serde_json::json!({});
907 let transformed: Option<Value> = handler.handle_message(original).await;
908
909 assert!(transformed.is_some());
910 let resp: Value = transformed.unwrap();
911 assert_eq!(resp.get("processed").unwrap(), true);
912 assert_eq!(resp.as_object().unwrap().len(), 1);
913 }
914
915 #[tokio::test]
916 async fn test_transform_handler_preserves_all_fields() {
917 let handler: TransformHandler = TransformHandler;
918 let original: Value = serde_json::json!({
919 "field1": "value1",
920 "field2": 42,
921 "field3": true,
922 "nested": {"key": "value"}
923 });
924 let transformed: Option<Value> = handler.handle_message(original.clone()).await;
925
926 assert!(transformed.is_some());
927 let resp: Value = transformed.unwrap();
928 assert_eq!(resp.get("field1").unwrap(), "value1");
929 assert_eq!(resp.get("field2").unwrap(), 42);
930 assert_eq!(resp.get("field3").unwrap(), true);
931 assert_eq!(resp.get("nested").unwrap(), &serde_json::json!({"key": "value"}));
932 assert_eq!(resp.get("processed").unwrap(), true);
933 }
934
935 #[tokio::test]
936 async fn test_transform_handler_with_non_object_input() {
937 let handler: TransformHandler = TransformHandler;
938
939 let array: Value = serde_json::json!([1, 2, 3]);
940 let response1: Option<Value> = handler.handle_message(array).await;
941 assert!(response1.is_none());
942
943 let string: Value = serde_json::json!("not an object");
944 let response2: Option<Value> = handler.handle_message(string).await;
945 assert!(response2.is_none());
946
947 let number: Value = serde_json::json!(42);
948 let response3: Option<Value> = handler.handle_message(number).await;
949 assert!(response3.is_none());
950 }
951
952 #[test]
954 fn test_message_schema_rejects_wrong_type() {
955 let handler: EchoHandler = EchoHandler;
956 let message_schema: serde_json::Value = serde_json::json!({
957 "type": "object",
958 "properties": {"id": {"type": "integer"}},
959 "required": ["id"]
960 });
961
962 let state: WebSocketState<EchoHandler> =
963 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
964
965 let invalid_msg: Value = serde_json::json!({"id": "not_an_integer"});
966 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
967 assert!(!validator.is_valid(&invalid_msg));
968 }
969
970 #[test]
972 fn test_response_schema_rejects_invalid_type() {
973 let handler: EchoHandler = EchoHandler;
974 let response_schema: serde_json::Value = serde_json::json!({
975 "type": "object",
976 "properties": {"count": {"type": "integer"}},
977 "required": ["count"]
978 });
979
980 let state: WebSocketState<EchoHandler> =
981 WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
982
983 let invalid_response: Value = serde_json::json!([1, 2, 3]);
984 let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
985 assert!(!validator.is_valid(&invalid_response));
986 }
987
988 #[test]
990 fn test_message_missing_multiple_required_fields() {
991 let handler: EchoHandler = EchoHandler;
992 let message_schema: serde_json::Value = serde_json::json!({
993 "type": "object",
994 "properties": {
995 "user_id": {"type": "integer"},
996 "action": {"type": "string"},
997 "timestamp": {"type": "string"}
998 },
999 "required": ["user_id", "action", "timestamp"]
1000 });
1001
1002 let state: WebSocketState<EchoHandler> =
1003 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1004
1005 let invalid_msg: Value = serde_json::json!({"other": "value"});
1006 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1007 assert!(!validator.is_valid(&invalid_msg));
1008
1009 let partial_msg: Value = serde_json::json!({"user_id": 123});
1010 assert!(!validator.is_valid(&partial_msg));
1011 }
1012
1013 #[test]
1015 fn test_deeply_nested_schema_validation_failure() {
1016 let handler: EchoHandler = EchoHandler;
1017 let message_schema: serde_json::Value = serde_json::json!({
1018 "type": "object",
1019 "properties": {
1020 "metadata": {
1021 "type": "object",
1022 "properties": {
1023 "request": {
1024 "type": "object",
1025 "properties": {
1026 "id": {"type": "string"}
1027 },
1028 "required": ["id"]
1029 }
1030 },
1031 "required": ["request"]
1032 }
1033 },
1034 "required": ["metadata"]
1035 });
1036
1037 let state: WebSocketState<EchoHandler> =
1038 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1039
1040 let invalid_msg: Value = serde_json::json!({
1041 "metadata": {
1042 "request": {}
1043 }
1044 });
1045 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1046 assert!(!validator.is_valid(&invalid_msg));
1047 }
1048
1049 #[test]
1051 fn test_array_property_type_validation() {
1052 let handler: EchoHandler = EchoHandler;
1053 let message_schema: serde_json::Value = serde_json::json!({
1054 "type": "object",
1055 "properties": {
1056 "ids": {
1057 "type": "array",
1058 "items": {"type": "integer"}
1059 }
1060 }
1061 });
1062
1063 let state: WebSocketState<EchoHandler> =
1064 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1065
1066 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1067
1068 let valid_msg: Value = serde_json::json!({"ids": [1, 2, 3]});
1069 assert!(validator.is_valid(&valid_msg));
1070
1071 let invalid_msg: Value = serde_json::json!({"ids": [1, "two", 3]});
1072 assert!(!validator.is_valid(&invalid_msg));
1073
1074 let invalid_msg2: Value = serde_json::json!({"ids": "not_an_array"});
1075 assert!(!validator.is_valid(&invalid_msg2));
1076 }
1077
1078 #[test]
1080 fn test_enum_property_validation() {
1081 let handler: EchoHandler = EchoHandler;
1082 let message_schema: serde_json::Value = serde_json::json!({
1083 "type": "object",
1084 "properties": {
1085 "status": {
1086 "type": "string",
1087 "enum": ["pending", "active", "completed"]
1088 }
1089 },
1090 "required": ["status"]
1091 });
1092
1093 let state: WebSocketState<EchoHandler> =
1094 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1095
1096 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1097
1098 let valid_msg: Value = serde_json::json!({"status": "active"});
1099 assert!(validator.is_valid(&valid_msg));
1100
1101 let invalid_msg: Value = serde_json::json!({"status": "unknown"});
1102 assert!(!validator.is_valid(&invalid_msg));
1103 }
1104
1105 #[test]
1107 fn test_number_range_validation() {
1108 let handler: EchoHandler = EchoHandler;
1109 let message_schema: serde_json::Value = serde_json::json!({
1110 "type": "object",
1111 "properties": {
1112 "age": {
1113 "type": "integer",
1114 "minimum": 0,
1115 "maximum": 150
1116 }
1117 },
1118 "required": ["age"]
1119 });
1120
1121 let state: WebSocketState<EchoHandler> =
1122 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1123
1124 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1125
1126 let valid_msg: Value = serde_json::json!({"age": 25});
1127 assert!(validator.is_valid(&valid_msg));
1128
1129 let invalid_msg: Value = serde_json::json!({"age": -1});
1130 assert!(!validator.is_valid(&invalid_msg));
1131
1132 let invalid_msg2: Value = serde_json::json!({"age": 200});
1133 assert!(!validator.is_valid(&invalid_msg2));
1134 }
1135
1136 #[test]
1138 fn test_string_length_validation() {
1139 let handler: EchoHandler = EchoHandler;
1140 let message_schema: serde_json::Value = serde_json::json!({
1141 "type": "object",
1142 "properties": {
1143 "username": {
1144 "type": "string",
1145 "minLength": 3,
1146 "maxLength": 20
1147 }
1148 },
1149 "required": ["username"]
1150 });
1151
1152 let state: WebSocketState<EchoHandler> =
1153 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1154
1155 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1156
1157 let valid_msg: Value = serde_json::json!({"username": "alice"});
1158 assert!(validator.is_valid(&valid_msg));
1159
1160 let invalid_msg: Value = serde_json::json!({"username": "ab"});
1161 assert!(!validator.is_valid(&invalid_msg));
1162
1163 let invalid_msg2: Value =
1164 serde_json::json!({"username": "this_is_a_very_long_username_over_twenty_characters"});
1165 assert!(!validator.is_valid(&invalid_msg2));
1166 }
1167
1168 #[test]
1170 fn test_pattern_validation() {
1171 let handler: EchoHandler = EchoHandler;
1172 let message_schema: serde_json::Value = serde_json::json!({
1173 "type": "object",
1174 "properties": {
1175 "email": {
1176 "type": "string",
1177 "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
1178 }
1179 },
1180 "required": ["email"]
1181 });
1182
1183 let state: WebSocketState<EchoHandler> =
1184 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1185
1186 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1187
1188 let valid_msg: Value = serde_json::json!({"email": "user@example.com"});
1189 assert!(validator.is_valid(&valid_msg));
1190
1191 let invalid_msg: Value = serde_json::json!({"email": "user@example"});
1192 assert!(!validator.is_valid(&invalid_msg));
1193
1194 let invalid_msg2: Value = serde_json::json!({"email": "userexample.com"});
1195 assert!(!validator.is_valid(&invalid_msg2));
1196 }
1197
1198 #[test]
1200 fn test_additional_properties_validation() {
1201 let handler: EchoHandler = EchoHandler;
1202 let message_schema: serde_json::Value = serde_json::json!({
1203 "type": "object",
1204 "properties": {
1205 "name": {"type": "string"}
1206 },
1207 "additionalProperties": false
1208 });
1209
1210 let state: WebSocketState<EchoHandler> =
1211 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1212
1213 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1214
1215 let valid_msg: Value = serde_json::json!({"name": "Alice"});
1216 assert!(validator.is_valid(&valid_msg));
1217
1218 let invalid_msg: Value = serde_json::json!({"name": "Bob", "age": 30});
1219 assert!(!validator.is_valid(&invalid_msg));
1220 }
1221
1222 #[test]
1224 fn test_one_of_constraint() {
1225 let handler: EchoHandler = EchoHandler;
1226 let message_schema: serde_json::Value = serde_json::json!({
1227 "type": "object",
1228 "oneOf": [
1229 {
1230 "properties": {"type": {"const": "text"}},
1231 "required": ["type"]
1232 },
1233 {
1234 "properties": {"type": {"const": "number"}},
1235 "required": ["type"]
1236 }
1237 ]
1238 });
1239
1240 let state: WebSocketState<EchoHandler> =
1241 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1242
1243 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1244
1245 let valid_msg: Value = serde_json::json!({"type": "text"});
1246 assert!(validator.is_valid(&valid_msg));
1247
1248 let invalid_msg: Value = serde_json::json!({"type": "unknown"});
1249 assert!(!validator.is_valid(&invalid_msg));
1250 }
1251
1252 #[test]
1254 fn test_any_of_constraint() {
1255 let handler: EchoHandler = EchoHandler;
1256 let message_schema: serde_json::Value = serde_json::json!({
1257 "type": "object",
1258 "properties": {
1259 "value": {"type": ["string", "integer"]}
1260 },
1261 "required": ["value"]
1262 });
1263
1264 let state: WebSocketState<EchoHandler> =
1265 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1266
1267 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1268
1269 let msg1: Value = serde_json::json!({"value": "text"});
1270 assert!(validator.is_valid(&msg1));
1271
1272 let msg2: Value = serde_json::json!({"value": 42});
1273 assert!(validator.is_valid(&msg2));
1274
1275 let invalid_msg: Value = serde_json::json!({"value": true});
1276 assert!(!validator.is_valid(&invalid_msg));
1277 }
1278
1279 #[test]
1281 fn test_response_schema_with_multiple_constraints() {
1282 let handler: EchoHandler = EchoHandler;
1283 let response_schema: serde_json::Value = serde_json::json!({
1284 "type": "object",
1285 "properties": {
1286 "success": {"type": "boolean"},
1287 "data": {
1288 "type": "object",
1289 "properties": {
1290 "items": {
1291 "type": "array",
1292 "items": {"type": "object"},
1293 "minItems": 1
1294 }
1295 },
1296 "required": ["items"]
1297 }
1298 },
1299 "required": ["success", "data"]
1300 });
1301
1302 let state: WebSocketState<EchoHandler> =
1303 WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
1304
1305 let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
1306
1307 let valid_response: Value = serde_json::json!({
1308 "success": true,
1309 "data": {
1310 "items": [{"id": 1}]
1311 }
1312 });
1313 assert!(validator.is_valid(&valid_response));
1314
1315 let invalid_response: Value = serde_json::json!({
1316 "success": true,
1317 "data": {
1318 "items": []
1319 }
1320 });
1321 assert!(!validator.is_valid(&invalid_response));
1322
1323 let invalid_response2: Value = serde_json::json!({
1324 "success": true
1325 });
1326 assert!(!validator.is_valid(&invalid_response2));
1327 }
1328
1329 #[test]
1331 fn test_null_value_validation() {
1332 let handler: EchoHandler = EchoHandler;
1333 let message_schema: serde_json::Value = serde_json::json!({
1334 "type": "object",
1335 "properties": {
1336 "optional_field": {"type": ["string", "null"]},
1337 "required_field": {"type": "string"}
1338 },
1339 "required": ["required_field"]
1340 });
1341
1342 let state: WebSocketState<EchoHandler> =
1343 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1344
1345 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1346
1347 let msg1: Value = serde_json::json!({
1348 "optional_field": null,
1349 "required_field": "value"
1350 });
1351 assert!(validator.is_valid(&msg1));
1352
1353 let msg2: Value = serde_json::json!({"required_field": "value"});
1354 assert!(validator.is_valid(&msg2));
1355
1356 let invalid_msg: Value = serde_json::json!({"required_field": null});
1357 assert!(!validator.is_valid(&invalid_msg));
1358 }
1359
1360 #[test]
1362 fn test_schema_with_defaults_still_validates() {
1363 let handler: EchoHandler = EchoHandler;
1364 let message_schema: serde_json::Value = serde_json::json!({
1365 "type": "object",
1366 "properties": {
1367 "status": {
1368 "type": "string",
1369 "default": "pending"
1370 }
1371 }
1372 });
1373
1374 let state: WebSocketState<EchoHandler> =
1375 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1376
1377 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1378
1379 let msg: Value = serde_json::json!({});
1380 assert!(validator.is_valid(&msg));
1381 }
1382
1383 #[test]
1385 fn test_both_schemas_validate_independently() {
1386 let handler: EchoHandler = EchoHandler;
1387 let message_schema: serde_json::Value = serde_json::json!({
1388 "type": "object",
1389 "properties": {"action": {"type": "string"}},
1390 "required": ["action"]
1391 });
1392 let response_schema: serde_json::Value = serde_json::json!({
1393 "type": "object",
1394 "properties": {"result": {"type": "string"}},
1395 "required": ["result"]
1396 });
1397
1398 let state: WebSocketState<EchoHandler> =
1399 WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
1400
1401 let msg_validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1402 let resp_validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
1403
1404 let valid_msg: Value = serde_json::json!({"action": "test"});
1405 let invalid_response: Value = serde_json::json!({"data": "oops"});
1406
1407 assert!(msg_validator.is_valid(&valid_msg));
1408 assert!(!resp_validator.is_valid(&invalid_response));
1409
1410 let invalid_msg: Value = serde_json::json!({"data": "oops"});
1411 let valid_response: Value = serde_json::json!({"result": "ok"});
1412
1413 assert!(!msg_validator.is_valid(&invalid_msg));
1414 assert!(resp_validator.is_valid(&valid_response));
1415 }
1416
1417 #[test]
1419 fn test_validation_with_large_payload() {
1420 let handler: EchoHandler = EchoHandler;
1421 let message_schema: serde_json::Value = serde_json::json!({
1422 "type": "object",
1423 "properties": {
1424 "items": {
1425 "type": "array",
1426 "items": {"type": "integer"}
1427 }
1428 },
1429 "required": ["items"]
1430 });
1431
1432 let state: WebSocketState<EchoHandler> =
1433 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1434
1435 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1436
1437 let mut items = Vec::new();
1438 for i in 0..10_000 {
1439 items.push(i);
1440 }
1441 let large_msg: Value = serde_json::json!({"items": items});
1442
1443 assert!(validator.is_valid(&large_msg));
1444 }
1445
1446 #[test]
1448 fn test_mutually_exclusive_schema_properties() {
1449 let handler: EchoHandler = EchoHandler;
1450
1451 let message_schema: serde_json::Value = serde_json::json!({
1452 "allOf": [
1453 {
1454 "type": "object",
1455 "properties": {"a": {"type": "string"}},
1456 "required": ["a"]
1457 },
1458 {
1459 "type": "object",
1460 "properties": {"b": {"type": "integer"}},
1461 "required": ["b"]
1462 }
1463 ]
1464 });
1465
1466 let state: WebSocketState<EchoHandler> =
1467 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1468
1469 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1470
1471 let valid_msg: Value = serde_json::json!({"a": "text", "b": 42});
1472 assert!(validator.is_valid(&valid_msg));
1473
1474 let invalid_msg: Value = serde_json::json!({"a": "text"});
1475 assert!(!validator.is_valid(&invalid_msg));
1476 }
1477}