1use axum::{
6 extract::{
7 State,
8 ws::{Message, WebSocket, WebSocketUpgrade},
9 },
10 response::IntoResponse,
11};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, error, info, warn};
15
16pub trait WebSocketHandler: Send + Sync {
52 fn handle_message(&self, message: Value) -> impl std::future::Future<Output = Option<Value>> + Send;
64
65 fn on_connect(&self) -> impl std::future::Future<Output = ()> + Send {
70 async {}
71 }
72
73 fn on_disconnect(&self) -> impl std::future::Future<Output = ()> + Send {
78 async {}
79 }
80}
81
/// Per-route state handed to the WebSocket upgrade handler.
///
/// Bundles the user handler with the optional JSON Schema validators that the
/// connection loop applies to incoming messages and outgoing responses.
/// Every field is behind an `Arc`, so cloning the state shares the underlying
/// values rather than copying them.
#[derive(Debug)]
pub struct WebSocketState<H: WebSocketHandler> {
    /// The user-supplied message handler.
    handler: Arc<H>,
    /// Validator for incoming messages; `None` means incoming messages are not validated.
    message_schema: Option<Arc<jsonschema::Validator>>,
    /// Validator for handler responses; `None` means responses are not validated.
    response_schema: Option<Arc<jsonschema::Validator>>,
}
96
97impl<H: WebSocketHandler> Clone for WebSocketState<H> {
98 fn clone(&self) -> Self {
99 Self {
100 handler: Arc::clone(&self.handler),
101 message_schema: self.message_schema.clone(),
102 response_schema: self.response_schema.clone(),
103 }
104 }
105}
106
107impl<H: WebSocketHandler + 'static> WebSocketState<H> {
108 pub fn new(handler: H) -> Self {
122 Self {
123 handler: Arc::new(handler),
124 message_schema: None,
125 response_schema: None,
126 }
127 }
128
129 pub fn with_schemas(
164 handler: H,
165 message_schema: Option<serde_json::Value>,
166 response_schema: Option<serde_json::Value>,
167 ) -> Result<Self, String> {
168 let message_validator = if let Some(schema) = message_schema {
169 Some(Arc::new(
170 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
171 ))
172 } else {
173 None
174 };
175
176 let response_validator = if let Some(schema) = response_schema {
177 Some(Arc::new(
178 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
179 ))
180 } else {
181 None
182 };
183
184 Ok(Self {
185 handler: Arc::new(handler),
186 message_schema: message_validator,
187 response_schema: response_validator,
188 })
189 }
190}
191
192pub async fn websocket_handler<H: WebSocketHandler + 'static>(
215 ws: WebSocketUpgrade,
216 State(state): State<WebSocketState<H>>,
217) -> impl IntoResponse {
218 ws.on_upgrade(move |socket| handle_socket(socket, state))
219}
220
221async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSocketState<H>) {
223 info!("WebSocket client connected");
224
225 state.handler.on_connect().await;
226
227 while let Some(msg) = socket.recv().await {
228 match msg {
229 Ok(Message::Text(text)) => {
230 debug!("Received text message: {}", text);
231
232 match serde_json::from_str::<Value>(&text) {
233 Ok(json_msg) => {
234 if let Some(validator) = &state.message_schema
235 && !validator.is_valid(&json_msg)
236 {
237 error!("Message validation failed");
238 let error_response = serde_json::json!({
239 "error": "Message validation failed"
240 });
241 if let Ok(error_text) = serde_json::to_string(&error_response) {
242 let _ = socket.send(Message::Text(error_text.into())).await;
243 }
244 continue;
245 }
246
247 if let Some(response) = state.handler.handle_message(json_msg).await {
248 if let Some(validator) = &state.response_schema
249 && !validator.is_valid(&response)
250 {
251 error!("Response validation failed");
252 continue;
253 }
254
255 let response_text = serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
256
257 if let Err(e) = socket.send(Message::Text(response_text.into())).await {
258 error!("Failed to send response: {}", e);
259 break;
260 }
261 }
262 }
263 Err(e) => {
264 warn!("Failed to parse JSON message: {}", e);
265 let error_msg = serde_json::json!({
266 "type": "error",
267 "message": "Invalid JSON"
268 });
269 let error_text = serde_json::to_string(&error_msg).unwrap_or_else(|_| "{}".to_string());
270 let _ = socket.send(Message::Text(error_text.into())).await;
271 }
272 }
273 }
274 Ok(Message::Binary(data)) => {
275 debug!("Received binary message: {} bytes", data.len());
276 if let Err(e) = socket.send(Message::Binary(data)).await {
277 error!("Failed to send binary response: {}", e);
278 break;
279 }
280 }
281 Ok(Message::Ping(data)) => {
282 debug!("Received ping");
283 if let Err(e) = socket.send(Message::Pong(data)).await {
284 error!("Failed to send pong: {}", e);
285 break;
286 }
287 }
288 Ok(Message::Pong(_)) => {
289 debug!("Received pong");
290 }
291 Ok(Message::Close(_)) => {
292 info!("Client closed connection");
293 break;
294 }
295 Err(e) => {
296 error!("WebSocket error: {}", e);
297 break;
298 }
299 }
300 }
301
302 state.handler.on_disconnect().await;
303 info!("WebSocket client disconnected");
304}
305
306#[cfg(test)]
307mod tests {
308 use super::*;
309 use std::sync::Mutex;
310 use std::sync::atomic::{AtomicUsize, Ordering};
311
/// Test handler that answers every message with the same message.
#[derive(Debug)]
struct EchoHandler;

impl WebSocketHandler for EchoHandler {
    /// Returns `message` unchanged as the response.
    async fn handle_message(&self, message: Value) -> Option<Value> {
        Some(message)
    }
}
320
321 #[derive(Debug)]
322 struct TrackingHandler {
323 connect_count: Arc<AtomicUsize>,
324 disconnect_count: Arc<AtomicUsize>,
325 message_count: Arc<AtomicUsize>,
326 messages: Arc<Mutex<Vec<Value>>>,
327 }
328
329 impl TrackingHandler {
330 fn new() -> Self {
331 Self {
332 connect_count: Arc::new(AtomicUsize::new(0)),
333 disconnect_count: Arc::new(AtomicUsize::new(0)),
334 message_count: Arc::new(AtomicUsize::new(0)),
335 messages: Arc::new(Mutex::new(Vec::new())),
336 }
337 }
338 }
339
340 impl WebSocketHandler for TrackingHandler {
341 async fn handle_message(&self, message: Value) -> Option<Value> {
342 self.message_count.fetch_add(1, Ordering::SeqCst);
343 self.messages.lock().unwrap().push(message.clone());
344 Some(message)
345 }
346
347 async fn on_connect(&self) {
348 self.connect_count.fetch_add(1, Ordering::SeqCst);
349 }
350
351 async fn on_disconnect(&self) {
352 self.disconnect_count.fetch_add(1, Ordering::SeqCst);
353 }
354 }
355
356 #[derive(Debug)]
357 struct SelectiveHandler;
358
359 impl WebSocketHandler for SelectiveHandler {
360 async fn handle_message(&self, message: Value) -> Option<Value> {
361 if message.get("respond").is_some_and(|v| v.as_bool().unwrap_or(false)) {
362 Some(serde_json::json!({"response": "acknowledged"}))
363 } else {
364 None
365 }
366 }
367 }
368
369 #[derive(Debug)]
370 struct TransformHandler;
371
372 impl WebSocketHandler for TransformHandler {
373 async fn handle_message(&self, message: Value) -> Option<Value> {
374 message.as_object().map_or(None, |obj| {
375 let mut resp = obj.clone();
376 resp.insert("processed".to_string(), Value::Bool(true));
377 Some(Value::Object(resp))
378 })
379 }
380 }
381
/// Cloning the state shares the handler `Arc` rather than duplicating the handler.
#[test]
fn test_websocket_state_creation() {
    let state = WebSocketState::new(EchoHandler);
    let duplicate = state.clone();
    assert!(Arc::ptr_eq(&state.handler, &duplicate.handler));
}
389
/// A well-formed message schema compiles successfully.
#[test]
fn test_websocket_state_with_valid_schema() {
    let schema = serde_json::json!({
        "type": "object",
        "properties": {
            "type": {"type": "string"}
        }
    });

    let outcome = WebSocketState::with_schemas(EchoHandler, Some(schema), None);
    assert!(outcome.is_ok());
}
404
/// A message schema with an unknown `type` is rejected with the message-schema error prefix.
#[test]
fn test_websocket_state_with_invalid_schema() {
    let invalid_schema = serde_json::json!({
        "type": "not_a_real_type",
        "invalid": "schema"
    });

    let outcome = WebSocketState::with_schemas(EchoHandler, Some(invalid_schema), None);
    assert!(matches!(&outcome, Err(msg) if msg.contains("Invalid message schema")));
}
420
/// Supplying both schemas yields a state that holds both validators.
#[test]
fn test_websocket_state_with_both_schemas() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {"action": {"type": "string"}}
    });
    let response_schema = serde_json::json!({
        "type": "object",
        "properties": {"result": {"type": "string"}}
    });

    // `unwrap` fails the test if construction returned an error.
    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), Some(response_schema)).unwrap();
    assert!(state.message_schema.is_some());
    assert!(state.response_schema.is_some());
}
440
/// A clone keeps the same schema presence and shares the handler.
#[test]
fn test_websocket_state_cloning_preserves_schemas() {
    let schema = serde_json::json!({
        "type": "object",
        "properties": {"id": {"type": "integer"}}
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(schema), None).unwrap();
    let duplicate = state.clone();

    assert!(duplicate.message_schema.is_some());
    assert!(duplicate.response_schema.is_none());
    assert!(Arc::ptr_eq(&state.handler, &duplicate.handler));
}
456
457 #[tokio::test]
458 async fn test_tracking_handler_lifecycle() {
459 let handler: TrackingHandler = TrackingHandler::new();
460 handler.on_connect().await;
461 assert_eq!(handler.connect_count.load(Ordering::SeqCst), 1);
462
463 let msg: Value = serde_json::json!({"test": "data"});
464 let _response: Option<Value> = handler.handle_message(msg).await;
465 assert_eq!(handler.message_count.load(Ordering::SeqCst), 1);
466
467 handler.on_disconnect().await;
468 assert_eq!(handler.disconnect_count.load(Ordering::SeqCst), 1);
469 }
470
471 #[tokio::test]
472 async fn test_selective_handler_responds_conditionally() {
473 let handler: SelectiveHandler = SelectiveHandler;
474
475 let respond_msg: Value = serde_json::json!({"respond": true});
476 let response1: Option<Value> = handler.handle_message(respond_msg).await;
477 assert!(response1.is_some());
478 assert_eq!(response1.unwrap(), serde_json::json!({"response": "acknowledged"}));
479
480 let no_respond_msg: Value = serde_json::json!({"respond": false});
481 let response2: Option<Value> = handler.handle_message(no_respond_msg).await;
482 assert!(response2.is_none());
483 }
484
485 #[tokio::test]
486 async fn test_transform_handler_modifies_message() {
487 let handler: TransformHandler = TransformHandler;
488 let original: Value = serde_json::json!({"name": "test"});
489 let transformed: Option<Value> = handler.handle_message(original).await;
490
491 assert!(transformed.is_some());
492 let resp: Value = transformed.unwrap();
493 assert_eq!(resp.get("name").unwrap(), "test");
494 assert_eq!(resp.get("processed").unwrap(), true);
495 }
496
497 #[tokio::test]
498 async fn test_echo_handler_preserves_json_types() {
499 let handler: EchoHandler = EchoHandler;
500
501 let messages: Vec<Value> = vec![
502 serde_json::json!({"string": "value"}),
503 serde_json::json!({"number": 42}),
504 serde_json::json!({"float": 3.14}),
505 serde_json::json!({"bool": true}),
506 serde_json::json!({"null": null}),
507 serde_json::json!({"array": [1, 2, 3]}),
508 ];
509
510 for msg in messages {
511 let response: Option<Value> = handler.handle_message(msg.clone()).await;
512 assert!(response.is_some());
513 assert_eq!(response.unwrap(), msg);
514 }
515 }
516
517 #[tokio::test]
518 async fn test_tracking_handler_accumulates_messages() {
519 let handler: TrackingHandler = TrackingHandler::new();
520
521 let messages: Vec<Value> = vec![
522 serde_json::json!({"id": 1}),
523 serde_json::json!({"id": 2}),
524 serde_json::json!({"id": 3}),
525 ];
526
527 for msg in messages {
528 let _: Option<Value> = handler.handle_message(msg).await;
529 }
530
531 assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
532 let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
533 assert_eq!(stored.len(), 3);
534 assert_eq!(stored[0].get("id").unwrap(), 1);
535 assert_eq!(stored[1].get("id").unwrap(), 2);
536 assert_eq!(stored[2].get("id").unwrap(), 3);
537 }
538
539 #[tokio::test]
540 async fn test_echo_handler_with_nested_json() {
541 let handler: EchoHandler = EchoHandler;
542 let nested: Value = serde_json::json!({
543 "level1": {
544 "level2": {
545 "level3": {
546 "value": "deeply nested"
547 }
548 }
549 }
550 });
551
552 let response: Option<Value> = handler.handle_message(nested.clone()).await;
553 assert!(response.is_some());
554 assert_eq!(response.unwrap(), nested);
555 }
556
557 #[tokio::test]
558 async fn test_echo_handler_with_large_array() {
559 let handler: EchoHandler = EchoHandler;
560 let large_array: Value = serde_json::json!({
561 "items": (0..1000).collect::<Vec<i32>>()
562 });
563
564 let response: Option<Value> = handler.handle_message(large_array.clone()).await;
565 assert!(response.is_some());
566 assert_eq!(response.unwrap(), large_array);
567 }
568
569 #[tokio::test]
570 async fn test_echo_handler_with_unicode() {
571 let handler: EchoHandler = EchoHandler;
572 let unicode_msg: Value = serde_json::json!({
573 "emoji": "🚀",
574 "chinese": "你好",
575 "arabic": "مرحبا",
576 "mixed": "Hello 世界 🌍"
577 });
578
579 let response: Option<Value> = handler.handle_message(unicode_msg.clone()).await;
580 assert!(response.is_some());
581 assert_eq!(response.unwrap(), unicode_msg);
582 }
583
/// The message and response validators each enforce their own schema, both
/// on the original state and on a clone.
#[test]
fn test_websocket_state_schemas_are_independent() {
    let message_schema = serde_json::json!({"type": "object"});
    let response_schema = serde_json::json!({"type": "array"});

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), Some(response_schema)).unwrap();
    let cloned = state.clone();

    let object = serde_json::json!({});
    let array = serde_json::json!([]);

    for candidate in [&state, &cloned] {
        // `unwrap` doubles as the presence check for each validator.
        let message_validator = candidate.message_schema.as_deref().unwrap();
        let response_validator = candidate.response_schema.as_deref().unwrap();

        // Each validator accepts only the shape its own schema describes.
        assert!(message_validator.is_valid(&object));
        assert!(!message_validator.is_valid(&array));
        assert!(response_validator.is_valid(&array));
        assert!(!response_validator.is_valid(&object));
    }
}
600
/// The message validator accepts a message with the required field and rejects one without it.
#[test]
fn test_message_schema_validation_with_required_field() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {"type": {"type": "string"}},
        "required": ["type"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();

    assert!(state.message_schema.is_some());
    assert!(state.response_schema.is_none());

    let validator = state.message_schema.as_deref().unwrap();
    assert!(validator.is_valid(&serde_json::json!({"type": "test"})));
    assert!(!validator.is_valid(&serde_json::json!({"other": "field"})));
}
623
/// The response validator accepts a response with the required field and rejects one without it.
#[test]
fn test_response_schema_validation_with_required_field() {
    let response_schema = serde_json::json!({
        "type": "object",
        "properties": {"status": {"type": "string"}},
        "required": ["status"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, None, Some(response_schema)).unwrap();

    assert!(state.message_schema.is_none());
    assert!(state.response_schema.is_some());

    let validator = state.response_schema.as_deref().unwrap();
    assert!(validator.is_valid(&serde_json::json!({"status": "ok"})));
    assert!(!validator.is_valid(&serde_json::json!({"other": "field"})));
}
646
/// An uncompilable message schema yields an error carrying the message-schema prefix.
#[test]
fn test_invalid_message_schema_returns_error() {
    let invalid_schema = serde_json::json!({
        "type": "invalid_type_value",
        "properties": {}
    });

    let outcome = WebSocketState::with_schemas(EchoHandler, Some(invalid_schema), None);

    assert!(matches!(&outcome, Err(msg) if msg.contains("Invalid message schema")));
}
664
/// An uncompilable response schema yields an error carrying the response-schema prefix.
#[test]
fn test_invalid_response_schema_returns_error() {
    let invalid_schema = serde_json::json!({
        "type": "definitely_not_valid"
    });

    let outcome = WebSocketState::with_schemas(EchoHandler, None, Some(invalid_schema));

    assert!(matches!(&outcome, Err(msg) if msg.contains("Invalid response schema")));
}
681
682 #[tokio::test]
683 async fn test_handler_returning_none_response() {
684 let handler: SelectiveHandler = SelectiveHandler;
685
686 let no_response_msg: Value = serde_json::json!({"respond": false});
687 let result: Option<Value> = handler.handle_message(no_response_msg).await;
688
689 assert!(result.is_none());
690 }
691
/// A schema with a nested required object accepts well-typed input and
/// rejects a wrongly typed nested field.
///
/// Nothing here awaits, so this is a plain synchronous test.
#[test]
fn test_handler_with_complex_schema_validation() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "user": {
                "type": "object",
                "properties": {
                    "id": {"type": "integer"},
                    "name": {"type": "string"}
                },
                "required": ["id", "name"]
            },
            "action": {"type": "string"}
        },
        "required": ["user", "action"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    let valid_msg = serde_json::json!({
        "user": {"id": 123, "name": "Alice"},
        "action": "create"
    });
    assert!(validator.is_valid(&valid_msg));

    let invalid_msg = serde_json::json!({
        "user": {"id": "not_an_int", "name": "Bob"},
        "action": "create"
    });
    assert!(!validator.is_valid(&invalid_msg));
}
727
728 #[tokio::test]
729 async fn test_tracking_handler_with_multiple_message_types() {
730 let handler: TrackingHandler = TrackingHandler::new();
731
732 let messages: Vec<Value> = vec![
733 serde_json::json!({"type": "text", "content": "hello"}),
734 serde_json::json!({"type": "image", "url": "http://example.com/image.png"}),
735 serde_json::json!({"type": "video", "duration": 120}),
736 ];
737
738 for msg in messages {
739 let _: Option<Value> = handler.handle_message(msg).await;
740 }
741
742 assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
743 let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
744 assert_eq!(stored.len(), 3);
745 assert_eq!(stored[0].get("type").unwrap(), "text");
746 assert_eq!(stored[1].get("type").unwrap(), "image");
747 assert_eq!(stored[2].get("type").unwrap(), "video");
748 }
749
750 #[tokio::test]
751 async fn test_selective_handler_with_explicit_false() {
752 let handler: SelectiveHandler = SelectiveHandler;
753
754 let msg: Value = serde_json::json!({"respond": false, "data": "test"});
755 let response: Option<Value> = handler.handle_message(msg).await;
756
757 assert!(response.is_none());
758 }
759
760 #[tokio::test]
761 async fn test_selective_handler_without_respond_field() {
762 let handler: SelectiveHandler = SelectiveHandler;
763
764 let msg: Value = serde_json::json!({"data": "test"});
765 let response: Option<Value> = handler.handle_message(msg).await;
766
767 assert!(response.is_none());
768 }
769
770 #[tokio::test]
771 async fn test_transform_handler_with_empty_object() {
772 let handler: TransformHandler = TransformHandler;
773 let original: Value = serde_json::json!({});
774 let transformed: Option<Value> = handler.handle_message(original).await;
775
776 assert!(transformed.is_some());
777 let resp: Value = transformed.unwrap();
778 assert_eq!(resp.get("processed").unwrap(), true);
779 assert_eq!(resp.as_object().unwrap().len(), 1);
780 }
781
782 #[tokio::test]
783 async fn test_transform_handler_preserves_all_fields() {
784 let handler: TransformHandler = TransformHandler;
785 let original: Value = serde_json::json!({
786 "field1": "value1",
787 "field2": 42,
788 "field3": true,
789 "nested": {"key": "value"}
790 });
791 let transformed: Option<Value> = handler.handle_message(original.clone()).await;
792
793 assert!(transformed.is_some());
794 let resp: Value = transformed.unwrap();
795 assert_eq!(resp.get("field1").unwrap(), "value1");
796 assert_eq!(resp.get("field2").unwrap(), 42);
797 assert_eq!(resp.get("field3").unwrap(), true);
798 assert_eq!(resp.get("nested").unwrap(), &serde_json::json!({"key": "value"}));
799 assert_eq!(resp.get("processed").unwrap(), true);
800 }
801
802 #[tokio::test]
803 async fn test_transform_handler_with_non_object_input() {
804 let handler: TransformHandler = TransformHandler;
805
806 let array: Value = serde_json::json!([1, 2, 3]);
807 let response1: Option<Value> = handler.handle_message(array).await;
808 assert!(response1.is_none());
809
810 let string: Value = serde_json::json!("not an object");
811 let response2: Option<Value> = handler.handle_message(string).await;
812 assert!(response2.is_none());
813
814 let number: Value = serde_json::json!(42);
815 let response3: Option<Value> = handler.handle_message(number).await;
816 assert!(response3.is_none());
817 }
818
/// A string where the schema demands an integer is rejected.
#[test]
fn test_message_schema_rejects_wrong_type() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {"id": {"type": "integer"}},
        "required": ["id"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    assert!(!validator.is_valid(&serde_json::json!({"id": "not_an_integer"})));
}
836
/// An array is rejected by a response schema that requires an object.
#[test]
fn test_response_schema_rejects_invalid_type() {
    let response_schema = serde_json::json!({
        "type": "object",
        "properties": {"count": {"type": "integer"}},
        "required": ["count"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, None, Some(response_schema)).unwrap();
    let validator = state.response_schema.as_deref().unwrap();

    assert!(!validator.is_valid(&serde_json::json!([1, 2, 3])));
}
854
/// Messages missing all or some of the required fields are rejected.
#[test]
fn test_message_missing_multiple_required_fields() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "user_id": {"type": "integer"},
            "action": {"type": "string"},
            "timestamp": {"type": "string"}
        },
        "required": ["user_id", "action", "timestamp"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    let incomplete = [
        serde_json::json!({"other": "value"}),
        serde_json::json!({"user_id": 123}),
    ];
    for msg in &incomplete {
        assert!(!validator.is_valid(msg));
    }
}
879
/// A required field missing three levels deep makes the whole message invalid.
#[test]
fn test_deeply_nested_schema_validation_failure() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "metadata": {
                "type": "object",
                "properties": {
                    "request": {
                        "type": "object",
                        "properties": {
                            "id": {"type": "string"}
                        },
                        "required": ["id"]
                    }
                },
                "required": ["request"]
            }
        },
        "required": ["metadata"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    let missing_inner_id = serde_json::json!({
        "metadata": {
            "request": {}
        }
    });
    assert!(!validator.is_valid(&missing_inner_id));
}
915
/// An integer-array property rejects mixed-type arrays and non-array values.
#[test]
fn test_array_property_type_validation() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "ids": {
                "type": "array",
                "items": {"type": "integer"}
            }
        }
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    assert!(validator.is_valid(&serde_json::json!({"ids": [1, 2, 3]})));

    let rejected = [
        serde_json::json!({"ids": [1, "two", 3]}),
        serde_json::json!({"ids": "not_an_array"}),
    ];
    for msg in &rejected {
        assert!(!validator.is_valid(msg));
    }
}
944
/// An `enum` property accepts a listed value and rejects an unlisted one.
#[test]
fn test_enum_property_validation() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "status": {
                "type": "string",
                "enum": ["pending", "active", "completed"]
            }
        },
        "required": ["status"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    assert!(validator.is_valid(&serde_json::json!({"status": "active"})));
    assert!(!validator.is_valid(&serde_json::json!({"status": "unknown"})));
}
971
/// `minimum`/`maximum` bounds reject values below and above the range.
#[test]
fn test_number_range_validation() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "age": {
                "type": "integer",
                "minimum": 0,
                "maximum": 150
            }
        },
        "required": ["age"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    assert!(validator.is_valid(&serde_json::json!({"age": 25})));

    let out_of_range = [
        serde_json::json!({"age": -1}),
        serde_json::json!({"age": 200}),
    ];
    for msg in &out_of_range {
        assert!(!validator.is_valid(msg));
    }
}
1002
/// `minLength`/`maxLength` reject strings that are too short or too long.
#[test]
fn test_string_length_validation() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "username": {
                "type": "string",
                "minLength": 3,
                "maxLength": 20
            }
        },
        "required": ["username"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    assert!(validator.is_valid(&serde_json::json!({"username": "alice"})));

    let wrong_length = [
        serde_json::json!({"username": "ab"}),
        serde_json::json!({"username": "this_is_a_very_long_username_over_twenty_characters"}),
    ];
    for msg in &wrong_length {
        assert!(!validator.is_valid(msg));
    }
}
1034
/// A `pattern` constraint accepts a matching address and rejects malformed ones.
#[test]
fn test_pattern_validation() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "email": {
                "type": "string",
                "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
            }
        },
        "required": ["email"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    assert!(validator.is_valid(&serde_json::json!({"email": "user@example.com"})));

    let malformed = [
        serde_json::json!({"email": "user@example"}),
        serde_json::json!({"email": "userexample.com"}),
    ];
    for msg in &malformed {
        assert!(!validator.is_valid(msg));
    }
}
1064
/// With `additionalProperties: false`, an undeclared field makes the message invalid.
#[test]
fn test_additional_properties_validation() {
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "name": {"type": "string"}
        },
        "additionalProperties": false
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    assert!(validator.is_valid(&serde_json::json!({"name": "Alice"})));
    assert!(!validator.is_valid(&serde_json::json!({"name": "Bob", "age": 30})));
}
1088
/// A `oneOf` over two `const` discriminators accepts a listed value and rejects others.
#[test]
fn test_one_of_constraint() {
    let message_schema = serde_json::json!({
        "type": "object",
        "oneOf": [
            {
                "properties": {"type": {"const": "text"}},
                "required": ["type"]
            },
            {
                "properties": {"type": {"const": "number"}},
                "required": ["type"]
            }
        ]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    assert!(validator.is_valid(&serde_json::json!({"type": "text"})));
    assert!(!validator.is_valid(&serde_json::json!({"type": "unknown"})));
}
1118
/// An `anyOf` over string and integer accepts either type and rejects a boolean.
#[test]
fn test_any_of_constraint() {
    // Uses the `anyOf` keyword itself so the test exercises what its name says.
    let message_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "value": {
                "anyOf": [
                    {"type": "string"},
                    {"type": "integer"}
                ]
            }
        },
        "required": ["value"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(message_schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    assert!(validator.is_valid(&serde_json::json!({"value": "text"})));
    assert!(validator.is_valid(&serde_json::json!({"value": 42})));
    assert!(!validator.is_valid(&serde_json::json!({"value": true})));
}
1145
/// A response schema combining required fields, nesting and `minItems`
/// rejects an empty item list and a missing `data` object.
#[test]
fn test_response_schema_with_multiple_constraints() {
    let response_schema = serde_json::json!({
        "type": "object",
        "properties": {
            "success": {"type": "boolean"},
            "data": {
                "type": "object",
                "properties": {
                    "items": {
                        "type": "array",
                        "items": {"type": "object"},
                        "minItems": 1
                    }
                },
                "required": ["items"]
            }
        },
        "required": ["success", "data"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, None, Some(response_schema)).unwrap();
    let validator = state.response_schema.as_deref().unwrap();

    let accepted = serde_json::json!({
        "success": true,
        "data": {
            "items": [{"id": 1}]
        }
    });
    assert!(validator.is_valid(&accepted));

    let rejected = [
        serde_json::json!({
            "success": true,
            "data": {
                "items": []
            }
        }),
        serde_json::json!({
            "success": true
        }),
    ];
    for response in &rejected {
        assert!(!validator.is_valid(response));
    }
}
1195
/// Checks how `null` is treated for a nullable optional property versus a
/// required string property.
///
/// `optional_field` is typed `["string", "null"]`, while `required_field`
/// must be a string and must be present.
#[test]
fn test_null_value_validation() {
    let schema = serde_json::json!({
        "type": "object",
        "properties": {
            "optional_field": {"type": ["string", "null"]},
            "required_field": {"type": "string"}
        },
        "required": ["required_field"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(schema), None).unwrap();
    let checker: &jsonschema::Validator = state.message_schema.as_deref().unwrap();

    // An explicit null is permitted for the nullable field.
    let explicit_null = serde_json::json!({
        "optional_field": null,
        "required_field": "value"
    });
    assert!(checker.is_valid(&explicit_null));

    // Omitting the optional field altogether is also fine.
    let field_omitted = serde_json::json!({"required_field": "value"});
    assert!(checker.is_valid(&field_omitted));

    // The required field only accepts strings, so null is rejected.
    let null_required = serde_json::json!({"required_field": null});
    assert!(!checker.is_valid(&null_required));
}
1226
/// Confirms that a schema declaring a `default` for a non-required property
/// still accepts a message that omits that property.
#[test]
fn test_schema_with_defaults_still_validates() {
    let schema = serde_json::json!({
        "type": "object",
        "properties": {
            "status": {
                "type": "string",
                "default": "pending"
            }
        }
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(schema), None).unwrap();
    let checker: &jsonschema::Validator = state.message_schema.as_deref().unwrap();

    // `status` is not listed as required, so an empty object passes.
    let empty_object = serde_json::json!({});
    assert!(checker.is_valid(&empty_object));
}
1249
/// Supplies both a message schema and a response schema and checks that each
/// compiled validator enforces only its own rules.
///
/// The message schema requires `action`; the response schema requires
/// `result`. A payload carrying neither key is rejected by both.
#[test]
fn test_both_schemas_validate_independently() {
    let inbound_schema = serde_json::json!({
        "type": "object",
        "properties": {"action": {"type": "string"}},
        "required": ["action"]
    });
    let outbound_schema = serde_json::json!({
        "type": "object",
        "properties": {"result": {"type": "string"}},
        "required": ["result"]
    });

    let state =
        WebSocketState::with_schemas(EchoHandler, Some(inbound_schema), Some(outbound_schema)).unwrap();

    let inbound: &jsonschema::Validator = state.message_schema.as_deref().unwrap();
    let outbound: &jsonschema::Validator = state.response_schema.as_deref().unwrap();

    // A good message alongside a bad response.
    let good_message = serde_json::json!({"action": "test"});
    let bad_response = serde_json::json!({"data": "oops"});
    assert!(inbound.is_valid(&good_message));
    assert!(!outbound.is_valid(&bad_response));

    // The mirror case: a bad message alongside a good response.
    let bad_message = serde_json::json!({"data": "oops"});
    let good_response = serde_json::json!({"result": "ok"});
    assert!(!inbound.is_valid(&bad_message));
    assert!(outbound.is_valid(&good_response));
}
1283
/// Validates a message whose `items` array holds 10,000 integers against a
/// schema requiring an array of integers.
#[test]
fn test_validation_with_large_payload() {
    let schema = serde_json::json!({
        "type": "object",
        "properties": {
            "items": {
                "type": "array",
                "items": {"type": "integer"}
            }
        },
        "required": ["items"]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(schema), None).unwrap();
    let checker: &jsonschema::Validator = state.message_schema.as_deref().unwrap();

    // The integers 0 through 9_999, in ascending order.
    let numbers: Vec<i32> = (0..10_000).collect();
    let payload = serde_json::json!({"items": numbers});

    assert!(checker.is_valid(&payload));
}
1312
/// Checks an `allOf` schema whose two branches each require a different
/// property.
///
/// Despite the test's name, the schema is a conjunction: a message must
/// satisfy both branches, so `a` (string) and `b` (integer) must both be
/// present.
#[test]
fn test_mutually_exclusive_schema_properties() {
    let schema = serde_json::json!({
        "allOf": [
            {
                "type": "object",
                "properties": {"a": {"type": "string"}},
                "required": ["a"]
            },
            {
                "type": "object",
                "properties": {"b": {"type": "integer"}},
                "required": ["b"]
            }
        ]
    });

    let state = WebSocketState::with_schemas(EchoHandler, Some(schema), None).unwrap();
    let checker: &jsonschema::Validator = state.message_schema.as_deref().unwrap();

    // Both branches are satisfied.
    let both_present = serde_json::json!({"a": "text", "b": 42});
    assert!(checker.is_valid(&both_present));

    // The second branch fails because `b` is missing.
    let only_a = serde_json::json!({"a": "text"});
    assert!(!checker.is_valid(&only_a));
}
1344}