1use axum::{
6 extract::{
7 State,
8 ws::{Message, WebSocket, WebSocketUpgrade},
9 },
10 response::IntoResponse,
11};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, error, info, warn};
15
/// Reports whether verbose WebSocket tracing is switched on (`SPIKARD_WS_TRACE=1`).
///
/// The environment is consulted only once, on first use, and the answer is cached for the
/// rest of the process. `trace_ws` runs several times per frame on every connection, and
/// re-reading the environment each time takes a process-wide lock for no benefit. As a
/// consequence, changing the variable after the first trace call has no effect.
fn ws_trace_enabled() -> bool {
    static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *ENABLED.get_or_init(|| std::env::var("SPIKARD_WS_TRACE").ok().as_deref() == Some("1"))
}

/// Writes `message` to stderr, prefixed with `[spikard-ws]`, when tracing is enabled.
fn trace_ws(message: &str) {
    if ws_trace_enabled() {
        eprintln!("[spikard-ws] {message}");
    }
}
21
22pub trait WebSocketHandler: Send + Sync {
58 fn handle_message(&self, message: Value) -> impl std::future::Future<Output = Option<Value>> + Send;
70
71 fn on_connect(&self) -> impl std::future::Future<Output = ()> + Send {
76 async {}
77 }
78
79 fn on_disconnect(&self) -> impl std::future::Future<Output = ()> + Send {
84 async {}
85 }
86}
87
/// Route state handed to [`websocket_handler`] through axum's `State` extractor.
///
/// Bundles the user handler with optional compiled JSON Schema validators. Every field sits
/// behind an `Arc`, so the hand-written `Clone` impl only bumps reference counts and does not
/// require `H: Clone`.
#[derive(Debug)]
pub struct WebSocketState<H: WebSocketHandler> {
    // Handler shared by every clone of this state.
    handler: Arc<H>,
    // Checked against each inbound JSON message; `None` disables inbound validation.
    message_schema: Option<Arc<jsonschema::Validator>>,
    // Checked against each handler response before it is sent; `None` disables it.
    response_schema: Option<Arc<jsonschema::Validator>>,
}
102
103impl<H: WebSocketHandler> Clone for WebSocketState<H> {
104 fn clone(&self) -> Self {
105 Self {
106 handler: Arc::clone(&self.handler),
107 message_schema: self.message_schema.clone(),
108 response_schema: self.response_schema.clone(),
109 }
110 }
111}
112
113impl<H: WebSocketHandler + 'static> WebSocketState<H> {
114 pub fn new(handler: H) -> Self {
128 Self {
129 handler: Arc::new(handler),
130 message_schema: None,
131 response_schema: None,
132 }
133 }
134
135 pub fn with_schemas(
170 handler: H,
171 message_schema: Option<serde_json::Value>,
172 response_schema: Option<serde_json::Value>,
173 ) -> Result<Self, String> {
174 let message_validator = if let Some(schema) = message_schema {
175 Some(Arc::new(
176 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
177 ))
178 } else {
179 None
180 };
181
182 let response_validator = if let Some(schema) = response_schema {
183 Some(Arc::new(
184 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
185 ))
186 } else {
187 None
188 };
189
190 Ok(Self {
191 handler: Arc::new(handler),
192 message_schema: message_validator,
193 response_schema: response_validator,
194 })
195 }
196}
197
198pub async fn websocket_handler<H: WebSocketHandler + 'static>(
221 ws: WebSocketUpgrade,
222 State(state): State<WebSocketState<H>>,
223) -> impl IntoResponse {
224 ws.on_upgrade(move |socket| handle_socket(socket, state))
225}
226
227async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSocketState<H>) {
229 info!("WebSocket client connected");
230 trace_ws("socket:connected");
231
232 state.handler.on_connect().await;
233 trace_ws("socket:on_connect:done");
234
235 while let Some(msg) = socket.recv().await {
236 match msg {
237 Ok(Message::Text(text)) => {
238 debug!("Received text message: {}", text);
239 trace_ws(&format!("recv:text len={}", text.len()));
240
241 match serde_json::from_str::<Value>(&text) {
242 Ok(json_msg) => {
243 trace_ws("recv:text:json-ok");
244 if let Some(validator) = &state.message_schema
245 && !validator.is_valid(&json_msg)
246 {
247 error!("Message validation failed");
248 trace_ws("recv:text:validation-failed");
249 let error_response = serde_json::json!({
250 "error": "Message validation failed"
251 });
252 if let Ok(error_text) = serde_json::to_string(&error_response) {
253 trace_ws(&format!("send:validation-error len={}", error_text.len()));
254 let _ = socket.send(Message::Text(error_text.into())).await;
255 }
256 continue;
257 }
258
259 if let Some(response) = state.handler.handle_message(json_msg).await {
260 trace_ws("handler:response:some");
261 if let Some(validator) = &state.response_schema
262 && !validator.is_valid(&response)
263 {
264 error!("Response validation failed");
265 trace_ws("send:response:validation-failed");
266 continue;
267 }
268
269 let response_text = serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
270 let response_len = response_text.len();
271
272 if let Err(e) = socket.send(Message::Text(response_text.into())).await {
273 error!("Failed to send response: {}", e);
274 trace_ws("send:response:error");
275 break;
276 }
277 trace_ws(&format!("send:response len={}", response_len));
278 } else {
279 trace_ws("handler:response:none");
280 }
281 }
282 Err(e) => {
283 warn!("Failed to parse JSON message: {}", e);
284 trace_ws("recv:text:json-error");
285 let error_msg = serde_json::json!({
286 "type": "error",
287 "message": "Invalid JSON"
288 });
289 let error_text = serde_json::to_string(&error_msg).unwrap_or_else(|_| "{}".to_string());
290 trace_ws(&format!("send:json-error len={}", error_text.len()));
291 let _ = socket.send(Message::Text(error_text.into())).await;
292 }
293 }
294 }
295 Ok(Message::Binary(data)) => {
296 debug!("Received binary message: {} bytes", data.len());
297 trace_ws(&format!("recv:binary len={}", data.len()));
298 if let Err(e) = socket.send(Message::Binary(data)).await {
299 error!("Failed to send binary response: {}", e);
300 trace_ws("send:binary:error");
301 break;
302 }
303 trace_ws("send:binary:ok");
304 }
305 Ok(Message::Ping(data)) => {
306 debug!("Received ping");
307 trace_ws(&format!("recv:ping len={}", data.len()));
308 if let Err(e) = socket.send(Message::Pong(data)).await {
309 error!("Failed to send pong: {}", e);
310 trace_ws("send:pong:error");
311 break;
312 }
313 trace_ws("send:pong:ok");
314 }
315 Ok(Message::Pong(_)) => {
316 debug!("Received pong");
317 trace_ws("recv:pong");
318 }
319 Ok(Message::Close(_)) => {
320 info!("Client closed connection");
321 trace_ws("recv:close");
322 break;
323 }
324 Err(e) => {
325 error!("WebSocket error: {}", e);
326 trace_ws(&format!("recv:error {}", e));
327 break;
328 }
329 }
330 }
331
332 state.handler.on_disconnect().await;
333 trace_ws("socket:on_disconnect:done");
334 info!("WebSocket client disconnected");
335}
336
337#[cfg(test)]
338mod tests {
339 use super::*;
340 use std::sync::Mutex;
341 use std::sync::atomic::{AtomicUsize, Ordering};
342
343 #[derive(Debug)]
344 struct EchoHandler;
345
346 impl WebSocketHandler for EchoHandler {
347 async fn handle_message(&self, message: Value) -> Option<Value> {
348 Some(message)
349 }
350 }
351
352 #[derive(Debug)]
353 struct TrackingHandler {
354 connect_count: Arc<AtomicUsize>,
355 disconnect_count: Arc<AtomicUsize>,
356 message_count: Arc<AtomicUsize>,
357 messages: Arc<Mutex<Vec<Value>>>,
358 }
359
360 impl TrackingHandler {
361 fn new() -> Self {
362 Self {
363 connect_count: Arc::new(AtomicUsize::new(0)),
364 disconnect_count: Arc::new(AtomicUsize::new(0)),
365 message_count: Arc::new(AtomicUsize::new(0)),
366 messages: Arc::new(Mutex::new(Vec::new())),
367 }
368 }
369 }
370
371 impl WebSocketHandler for TrackingHandler {
372 async fn handle_message(&self, message: Value) -> Option<Value> {
373 self.message_count.fetch_add(1, Ordering::SeqCst);
374 self.messages.lock().unwrap().push(message.clone());
375 Some(message)
376 }
377
378 async fn on_connect(&self) {
379 self.connect_count.fetch_add(1, Ordering::SeqCst);
380 }
381
382 async fn on_disconnect(&self) {
383 self.disconnect_count.fetch_add(1, Ordering::SeqCst);
384 }
385 }
386
387 #[derive(Debug)]
388 struct SelectiveHandler;
389
390 impl WebSocketHandler for SelectiveHandler {
391 async fn handle_message(&self, message: Value) -> Option<Value> {
392 if message.get("respond").is_some_and(|v| v.as_bool().unwrap_or(false)) {
393 Some(serde_json::json!({"response": "acknowledged"}))
394 } else {
395 None
396 }
397 }
398 }
399
400 #[derive(Debug)]
401 struct TransformHandler;
402
403 impl WebSocketHandler for TransformHandler {
404 async fn handle_message(&self, message: Value) -> Option<Value> {
405 message.as_object().map_or(None, |obj| {
406 let mut resp = obj.clone();
407 resp.insert("processed".to_string(), Value::Bool(true));
408 Some(Value::Object(resp))
409 })
410 }
411 }
412
413 #[test]
414 fn test_websocket_state_creation() {
415 let handler: EchoHandler = EchoHandler;
416 let state: WebSocketState<EchoHandler> = WebSocketState::new(handler);
417 let cloned: WebSocketState<EchoHandler> = state.clone();
418 assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
419 }
420
421 #[test]
422 fn test_websocket_state_with_valid_schema() {
423 let handler: EchoHandler = EchoHandler;
424 let schema: serde_json::Value = serde_json::json!({
425 "type": "object",
426 "properties": {
427 "type": {"type": "string"}
428 }
429 });
430
431 let result: Result<WebSocketState<EchoHandler>, String> =
432 WebSocketState::with_schemas(handler, Some(schema), None);
433 assert!(result.is_ok());
434 }
435
436 #[test]
437 fn test_websocket_state_with_invalid_schema() {
438 let handler: EchoHandler = EchoHandler;
439 let invalid_schema: serde_json::Value = serde_json::json!({
440 "type": "not_a_real_type",
441 "invalid": "schema"
442 });
443
444 let result: Result<WebSocketState<EchoHandler>, String> =
445 WebSocketState::with_schemas(handler, Some(invalid_schema), None);
446 assert!(result.is_err());
447 if let Err(error_msg) = result {
448 assert!(error_msg.contains("Invalid message schema"));
449 }
450 }
451
452 #[test]
453 fn test_websocket_state_with_both_schemas() {
454 let handler: EchoHandler = EchoHandler;
455 let message_schema: serde_json::Value = serde_json::json!({
456 "type": "object",
457 "properties": {"action": {"type": "string"}}
458 });
459 let response_schema: serde_json::Value = serde_json::json!({
460 "type": "object",
461 "properties": {"result": {"type": "string"}}
462 });
463
464 let result: Result<WebSocketState<EchoHandler>, String> =
465 WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema));
466 assert!(result.is_ok());
467 let state: WebSocketState<EchoHandler> = result.unwrap();
468 assert!(state.message_schema.is_some());
469 assert!(state.response_schema.is_some());
470 }
471
472 #[test]
473 fn test_websocket_state_cloning_preserves_schemas() {
474 let handler: EchoHandler = EchoHandler;
475 let schema: serde_json::Value = serde_json::json!({
476 "type": "object",
477 "properties": {"id": {"type": "integer"}}
478 });
479
480 let state: WebSocketState<EchoHandler> = WebSocketState::with_schemas(handler, Some(schema), None).unwrap();
481 let cloned: WebSocketState<EchoHandler> = state.clone();
482
483 assert!(cloned.message_schema.is_some());
484 assert!(cloned.response_schema.is_none());
485 assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
486 }
487
488 #[tokio::test]
489 async fn test_tracking_handler_lifecycle() {
490 let handler: TrackingHandler = TrackingHandler::new();
491 handler.on_connect().await;
492 assert_eq!(handler.connect_count.load(Ordering::SeqCst), 1);
493
494 let msg: Value = serde_json::json!({"test": "data"});
495 let _response: Option<Value> = handler.handle_message(msg).await;
496 assert_eq!(handler.message_count.load(Ordering::SeqCst), 1);
497
498 handler.on_disconnect().await;
499 assert_eq!(handler.disconnect_count.load(Ordering::SeqCst), 1);
500 }
501
502 #[tokio::test]
503 async fn test_selective_handler_responds_conditionally() {
504 let handler: SelectiveHandler = SelectiveHandler;
505
506 let respond_msg: Value = serde_json::json!({"respond": true});
507 let response1: Option<Value> = handler.handle_message(respond_msg).await;
508 assert!(response1.is_some());
509 assert_eq!(response1.unwrap(), serde_json::json!({"response": "acknowledged"}));
510
511 let no_respond_msg: Value = serde_json::json!({"respond": false});
512 let response2: Option<Value> = handler.handle_message(no_respond_msg).await;
513 assert!(response2.is_none());
514 }
515
516 #[tokio::test]
517 async fn test_transform_handler_modifies_message() {
518 let handler: TransformHandler = TransformHandler;
519 let original: Value = serde_json::json!({"name": "test"});
520 let transformed: Option<Value> = handler.handle_message(original).await;
521
522 assert!(transformed.is_some());
523 let resp: Value = transformed.unwrap();
524 assert_eq!(resp.get("name").unwrap(), "test");
525 assert_eq!(resp.get("processed").unwrap(), true);
526 }
527
528 #[tokio::test]
529 async fn test_echo_handler_preserves_json_types() {
530 let handler: EchoHandler = EchoHandler;
531
532 let messages: Vec<Value> = vec![
533 serde_json::json!({"string": "value"}),
534 serde_json::json!({"number": 42}),
535 serde_json::json!({"float": 3.14}),
536 serde_json::json!({"bool": true}),
537 serde_json::json!({"null": null}),
538 serde_json::json!({"array": [1, 2, 3]}),
539 ];
540
541 for msg in messages {
542 let response: Option<Value> = handler.handle_message(msg.clone()).await;
543 assert!(response.is_some());
544 assert_eq!(response.unwrap(), msg);
545 }
546 }
547
548 #[tokio::test]
549 async fn test_tracking_handler_accumulates_messages() {
550 let handler: TrackingHandler = TrackingHandler::new();
551
552 let messages: Vec<Value> = vec![
553 serde_json::json!({"id": 1}),
554 serde_json::json!({"id": 2}),
555 serde_json::json!({"id": 3}),
556 ];
557
558 for msg in messages {
559 let _: Option<Value> = handler.handle_message(msg).await;
560 }
561
562 assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
563 let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
564 assert_eq!(stored.len(), 3);
565 assert_eq!(stored[0].get("id").unwrap(), 1);
566 assert_eq!(stored[1].get("id").unwrap(), 2);
567 assert_eq!(stored[2].get("id").unwrap(), 3);
568 }
569
570 #[tokio::test]
571 async fn test_echo_handler_with_nested_json() {
572 let handler: EchoHandler = EchoHandler;
573 let nested: Value = serde_json::json!({
574 "level1": {
575 "level2": {
576 "level3": {
577 "value": "deeply nested"
578 }
579 }
580 }
581 });
582
583 let response: Option<Value> = handler.handle_message(nested.clone()).await;
584 assert!(response.is_some());
585 assert_eq!(response.unwrap(), nested);
586 }
587
588 #[tokio::test]
589 async fn test_echo_handler_with_large_array() {
590 let handler: EchoHandler = EchoHandler;
591 let large_array: Value = serde_json::json!({
592 "items": (0..1000).collect::<Vec<i32>>()
593 });
594
595 let response: Option<Value> = handler.handle_message(large_array.clone()).await;
596 assert!(response.is_some());
597 assert_eq!(response.unwrap(), large_array);
598 }
599
600 #[tokio::test]
601 async fn test_echo_handler_with_unicode() {
602 let handler: EchoHandler = EchoHandler;
603 let unicode_msg: Value = serde_json::json!({
604 "emoji": "🚀",
605 "chinese": "你好",
606 "arabic": "مرحبا",
607 "mixed": "Hello 世界 🌍"
608 });
609
610 let response: Option<Value> = handler.handle_message(unicode_msg.clone()).await;
611 assert!(response.is_some());
612 assert_eq!(response.unwrap(), unicode_msg);
613 }
614
615 #[test]
616 fn test_websocket_state_schemas_are_independent() {
617 let handler: EchoHandler = EchoHandler;
618 let message_schema: serde_json::Value = serde_json::json!({"type": "object"});
619 let response_schema: serde_json::Value = serde_json::json!({"type": "array"});
620
621 let state: WebSocketState<EchoHandler> =
622 WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
623
624 let cloned: WebSocketState<EchoHandler> = state.clone();
625
626 assert!(state.message_schema.is_some());
627 assert!(state.response_schema.is_some());
628 assert!(cloned.message_schema.is_some());
629 assert!(cloned.response_schema.is_some());
630 }
631
632 #[test]
633 fn test_message_schema_validation_with_required_field() {
634 let handler: EchoHandler = EchoHandler;
635 let message_schema: serde_json::Value = serde_json::json!({
636 "type": "object",
637 "properties": {"type": {"type": "string"}},
638 "required": ["type"]
639 });
640
641 let state: WebSocketState<EchoHandler> =
642 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
643
644 assert!(state.message_schema.is_some());
645 assert!(state.response_schema.is_none());
646
647 let valid_msg: Value = serde_json::json!({"type": "test"});
648 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
649 assert!(validator.is_valid(&valid_msg));
650
651 let invalid_msg: Value = serde_json::json!({"other": "field"});
652 assert!(!validator.is_valid(&invalid_msg));
653 }
654
655 #[test]
656 fn test_response_schema_validation_with_required_field() {
657 let handler: EchoHandler = EchoHandler;
658 let response_schema: serde_json::Value = serde_json::json!({
659 "type": "object",
660 "properties": {"status": {"type": "string"}},
661 "required": ["status"]
662 });
663
664 let state: WebSocketState<EchoHandler> =
665 WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
666
667 assert!(state.message_schema.is_none());
668 assert!(state.response_schema.is_some());
669
670 let valid_response: Value = serde_json::json!({"status": "ok"});
671 let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
672 assert!(validator.is_valid(&valid_response));
673
674 let invalid_response: Value = serde_json::json!({"other": "field"});
675 assert!(!validator.is_valid(&invalid_response));
676 }
677
678 #[test]
679 fn test_invalid_message_schema_returns_error() {
680 let handler: EchoHandler = EchoHandler;
681 let invalid_schema: serde_json::Value = serde_json::json!({
682 "type": "invalid_type_value",
683 "properties": {}
684 });
685
686 let result: Result<WebSocketState<EchoHandler>, String> =
687 WebSocketState::with_schemas(handler, Some(invalid_schema), None);
688
689 assert!(result.is_err());
690 match result {
691 Err(error_msg) => assert!(error_msg.contains("Invalid message schema")),
692 Ok(_) => panic!("Expected error but got Ok"),
693 }
694 }
695
696 #[test]
697 fn test_invalid_response_schema_returns_error() {
698 let handler: EchoHandler = EchoHandler;
699 let invalid_schema: serde_json::Value = serde_json::json!({
700 "type": "definitely_not_valid"
701 });
702
703 let result: Result<WebSocketState<EchoHandler>, String> =
704 WebSocketState::with_schemas(handler, None, Some(invalid_schema));
705
706 assert!(result.is_err());
707 match result {
708 Err(error_msg) => assert!(error_msg.contains("Invalid response schema")),
709 Ok(_) => panic!("Expected error but got Ok"),
710 }
711 }
712
713 #[tokio::test]
714 async fn test_handler_returning_none_response() {
715 let handler: SelectiveHandler = SelectiveHandler;
716
717 let no_response_msg: Value = serde_json::json!({"respond": false});
718 let result: Option<Value> = handler.handle_message(no_response_msg).await;
719
720 assert!(result.is_none());
721 }
722
723 #[tokio::test]
724 async fn test_handler_with_complex_schema_validation() {
725 let handler: EchoHandler = EchoHandler;
726 let message_schema: serde_json::Value = serde_json::json!({
727 "type": "object",
728 "properties": {
729 "user": {
730 "type": "object",
731 "properties": {
732 "id": {"type": "integer"},
733 "name": {"type": "string"}
734 },
735 "required": ["id", "name"]
736 },
737 "action": {"type": "string"}
738 },
739 "required": ["user", "action"]
740 });
741
742 let state: WebSocketState<EchoHandler> =
743 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
744
745 let valid_msg: Value = serde_json::json!({
746 "user": {"id": 123, "name": "Alice"},
747 "action": "create"
748 });
749 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
750 assert!(validator.is_valid(&valid_msg));
751
752 let invalid_msg: Value = serde_json::json!({
753 "user": {"id": "not_an_int", "name": "Bob"},
754 "action": "create"
755 });
756 assert!(!validator.is_valid(&invalid_msg));
757 }
758
759 #[tokio::test]
760 async fn test_tracking_handler_with_multiple_message_types() {
761 let handler: TrackingHandler = TrackingHandler::new();
762
763 let messages: Vec<Value> = vec![
764 serde_json::json!({"type": "text", "content": "hello"}),
765 serde_json::json!({"type": "image", "url": "http://example.com/image.png"}),
766 serde_json::json!({"type": "video", "duration": 120}),
767 ];
768
769 for msg in messages {
770 let _: Option<Value> = handler.handle_message(msg).await;
771 }
772
773 assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
774 let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
775 assert_eq!(stored.len(), 3);
776 assert_eq!(stored[0].get("type").unwrap(), "text");
777 assert_eq!(stored[1].get("type").unwrap(), "image");
778 assert_eq!(stored[2].get("type").unwrap(), "video");
779 }
780
781 #[tokio::test]
782 async fn test_selective_handler_with_explicit_false() {
783 let handler: SelectiveHandler = SelectiveHandler;
784
785 let msg: Value = serde_json::json!({"respond": false, "data": "test"});
786 let response: Option<Value> = handler.handle_message(msg).await;
787
788 assert!(response.is_none());
789 }
790
791 #[tokio::test]
792 async fn test_selective_handler_without_respond_field() {
793 let handler: SelectiveHandler = SelectiveHandler;
794
795 let msg: Value = serde_json::json!({"data": "test"});
796 let response: Option<Value> = handler.handle_message(msg).await;
797
798 assert!(response.is_none());
799 }
800
801 #[tokio::test]
802 async fn test_transform_handler_with_empty_object() {
803 let handler: TransformHandler = TransformHandler;
804 let original: Value = serde_json::json!({});
805 let transformed: Option<Value> = handler.handle_message(original).await;
806
807 assert!(transformed.is_some());
808 let resp: Value = transformed.unwrap();
809 assert_eq!(resp.get("processed").unwrap(), true);
810 assert_eq!(resp.as_object().unwrap().len(), 1);
811 }
812
813 #[tokio::test]
814 async fn test_transform_handler_preserves_all_fields() {
815 let handler: TransformHandler = TransformHandler;
816 let original: Value = serde_json::json!({
817 "field1": "value1",
818 "field2": 42,
819 "field3": true,
820 "nested": {"key": "value"}
821 });
822 let transformed: Option<Value> = handler.handle_message(original.clone()).await;
823
824 assert!(transformed.is_some());
825 let resp: Value = transformed.unwrap();
826 assert_eq!(resp.get("field1").unwrap(), "value1");
827 assert_eq!(resp.get("field2").unwrap(), 42);
828 assert_eq!(resp.get("field3").unwrap(), true);
829 assert_eq!(resp.get("nested").unwrap(), &serde_json::json!({"key": "value"}));
830 assert_eq!(resp.get("processed").unwrap(), true);
831 }
832
833 #[tokio::test]
834 async fn test_transform_handler_with_non_object_input() {
835 let handler: TransformHandler = TransformHandler;
836
837 let array: Value = serde_json::json!([1, 2, 3]);
838 let response1: Option<Value> = handler.handle_message(array).await;
839 assert!(response1.is_none());
840
841 let string: Value = serde_json::json!("not an object");
842 let response2: Option<Value> = handler.handle_message(string).await;
843 assert!(response2.is_none());
844
845 let number: Value = serde_json::json!(42);
846 let response3: Option<Value> = handler.handle_message(number).await;
847 assert!(response3.is_none());
848 }
849
850 #[test]
852 fn test_message_schema_rejects_wrong_type() {
853 let handler: EchoHandler = EchoHandler;
854 let message_schema: serde_json::Value = serde_json::json!({
855 "type": "object",
856 "properties": {"id": {"type": "integer"}},
857 "required": ["id"]
858 });
859
860 let state: WebSocketState<EchoHandler> =
861 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
862
863 let invalid_msg: Value = serde_json::json!({"id": "not_an_integer"});
864 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
865 assert!(!validator.is_valid(&invalid_msg));
866 }
867
868 #[test]
870 fn test_response_schema_rejects_invalid_type() {
871 let handler: EchoHandler = EchoHandler;
872 let response_schema: serde_json::Value = serde_json::json!({
873 "type": "object",
874 "properties": {"count": {"type": "integer"}},
875 "required": ["count"]
876 });
877
878 let state: WebSocketState<EchoHandler> =
879 WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
880
881 let invalid_response: Value = serde_json::json!([1, 2, 3]);
882 let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
883 assert!(!validator.is_valid(&invalid_response));
884 }
885
886 #[test]
888 fn test_message_missing_multiple_required_fields() {
889 let handler: EchoHandler = EchoHandler;
890 let message_schema: serde_json::Value = serde_json::json!({
891 "type": "object",
892 "properties": {
893 "user_id": {"type": "integer"},
894 "action": {"type": "string"},
895 "timestamp": {"type": "string"}
896 },
897 "required": ["user_id", "action", "timestamp"]
898 });
899
900 let state: WebSocketState<EchoHandler> =
901 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
902
903 let invalid_msg: Value = serde_json::json!({"other": "value"});
904 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
905 assert!(!validator.is_valid(&invalid_msg));
906
907 let partial_msg: Value = serde_json::json!({"user_id": 123});
908 assert!(!validator.is_valid(&partial_msg));
909 }
910
911 #[test]
913 fn test_deeply_nested_schema_validation_failure() {
914 let handler: EchoHandler = EchoHandler;
915 let message_schema: serde_json::Value = serde_json::json!({
916 "type": "object",
917 "properties": {
918 "metadata": {
919 "type": "object",
920 "properties": {
921 "request": {
922 "type": "object",
923 "properties": {
924 "id": {"type": "string"}
925 },
926 "required": ["id"]
927 }
928 },
929 "required": ["request"]
930 }
931 },
932 "required": ["metadata"]
933 });
934
935 let state: WebSocketState<EchoHandler> =
936 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
937
938 let invalid_msg: Value = serde_json::json!({
939 "metadata": {
940 "request": {}
941 }
942 });
943 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
944 assert!(!validator.is_valid(&invalid_msg));
945 }
946
947 #[test]
949 fn test_array_property_type_validation() {
950 let handler: EchoHandler = EchoHandler;
951 let message_schema: serde_json::Value = serde_json::json!({
952 "type": "object",
953 "properties": {
954 "ids": {
955 "type": "array",
956 "items": {"type": "integer"}
957 }
958 }
959 });
960
961 let state: WebSocketState<EchoHandler> =
962 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
963
964 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
965
966 let valid_msg: Value = serde_json::json!({"ids": [1, 2, 3]});
967 assert!(validator.is_valid(&valid_msg));
968
969 let invalid_msg: Value = serde_json::json!({"ids": [1, "two", 3]});
970 assert!(!validator.is_valid(&invalid_msg));
971
972 let invalid_msg2: Value = serde_json::json!({"ids": "not_an_array"});
973 assert!(!validator.is_valid(&invalid_msg2));
974 }
975
976 #[test]
978 fn test_enum_property_validation() {
979 let handler: EchoHandler = EchoHandler;
980 let message_schema: serde_json::Value = serde_json::json!({
981 "type": "object",
982 "properties": {
983 "status": {
984 "type": "string",
985 "enum": ["pending", "active", "completed"]
986 }
987 },
988 "required": ["status"]
989 });
990
991 let state: WebSocketState<EchoHandler> =
992 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
993
994 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
995
996 let valid_msg: Value = serde_json::json!({"status": "active"});
997 assert!(validator.is_valid(&valid_msg));
998
999 let invalid_msg: Value = serde_json::json!({"status": "unknown"});
1000 assert!(!validator.is_valid(&invalid_msg));
1001 }
1002
1003 #[test]
1005 fn test_number_range_validation() {
1006 let handler: EchoHandler = EchoHandler;
1007 let message_schema: serde_json::Value = serde_json::json!({
1008 "type": "object",
1009 "properties": {
1010 "age": {
1011 "type": "integer",
1012 "minimum": 0,
1013 "maximum": 150
1014 }
1015 },
1016 "required": ["age"]
1017 });
1018
1019 let state: WebSocketState<EchoHandler> =
1020 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1021
1022 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1023
1024 let valid_msg: Value = serde_json::json!({"age": 25});
1025 assert!(validator.is_valid(&valid_msg));
1026
1027 let invalid_msg: Value = serde_json::json!({"age": -1});
1028 assert!(!validator.is_valid(&invalid_msg));
1029
1030 let invalid_msg2: Value = serde_json::json!({"age": 200});
1031 assert!(!validator.is_valid(&invalid_msg2));
1032 }
1033
1034 #[test]
1036 fn test_string_length_validation() {
1037 let handler: EchoHandler = EchoHandler;
1038 let message_schema: serde_json::Value = serde_json::json!({
1039 "type": "object",
1040 "properties": {
1041 "username": {
1042 "type": "string",
1043 "minLength": 3,
1044 "maxLength": 20
1045 }
1046 },
1047 "required": ["username"]
1048 });
1049
1050 let state: WebSocketState<EchoHandler> =
1051 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1052
1053 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1054
1055 let valid_msg: Value = serde_json::json!({"username": "alice"});
1056 assert!(validator.is_valid(&valid_msg));
1057
1058 let invalid_msg: Value = serde_json::json!({"username": "ab"});
1059 assert!(!validator.is_valid(&invalid_msg));
1060
1061 let invalid_msg2: Value =
1062 serde_json::json!({"username": "this_is_a_very_long_username_over_twenty_characters"});
1063 assert!(!validator.is_valid(&invalid_msg2));
1064 }
1065
1066 #[test]
1068 fn test_pattern_validation() {
1069 let handler: EchoHandler = EchoHandler;
1070 let message_schema: serde_json::Value = serde_json::json!({
1071 "type": "object",
1072 "properties": {
1073 "email": {
1074 "type": "string",
1075 "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
1076 }
1077 },
1078 "required": ["email"]
1079 });
1080
1081 let state: WebSocketState<EchoHandler> =
1082 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1083
1084 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1085
1086 let valid_msg: Value = serde_json::json!({"email": "user@example.com"});
1087 assert!(validator.is_valid(&valid_msg));
1088
1089 let invalid_msg: Value = serde_json::json!({"email": "user@example"});
1090 assert!(!validator.is_valid(&invalid_msg));
1091
1092 let invalid_msg2: Value = serde_json::json!({"email": "userexample.com"});
1093 assert!(!validator.is_valid(&invalid_msg2));
1094 }
1095
1096 #[test]
1098 fn test_additional_properties_validation() {
1099 let handler: EchoHandler = EchoHandler;
1100 let message_schema: serde_json::Value = serde_json::json!({
1101 "type": "object",
1102 "properties": {
1103 "name": {"type": "string"}
1104 },
1105 "additionalProperties": false
1106 });
1107
1108 let state: WebSocketState<EchoHandler> =
1109 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1110
1111 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1112
1113 let valid_msg: Value = serde_json::json!({"name": "Alice"});
1114 assert!(validator.is_valid(&valid_msg));
1115
1116 let invalid_msg: Value = serde_json::json!({"name": "Bob", "age": 30});
1117 assert!(!validator.is_valid(&invalid_msg));
1118 }
1119
#[test]
/// `oneOf` over two branches keyed on a `type` constant: a message whose `type`
/// matches one of the constants is accepted, while an unknown `type` matches no
/// branch and is rejected.
fn test_one_of_constraint() {
    let text_branch = serde_json::json!({
        "properties": {"type": {"const": "text"}},
        "required": ["type"]
    });
    let number_branch = serde_json::json!({
        "properties": {"type": {"const": "number"}},
        "required": ["type"]
    });
    let schema = serde_json::json!({
        "type": "object",
        "oneOf": [text_branch, number_branch]
    });

    let state = WebSocketState::<EchoHandler>::with_schemas(EchoHandler, Some(schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    let is_accepted = |msg: Value| validator.is_valid(&msg);
    assert!(is_accepted(serde_json::json!({"type": "text"})));
    assert!(!is_accepted(serde_json::json!({"type": "unknown"})));
}
1149
#[test]
/// `anyOf` accepts a value that satisfies at least one subschema, including a
/// value that satisfies several at once (which `oneOf` would reject), and rejects
/// a value that satisfies none.
fn test_any_of_constraint() {
    let handler: EchoHandler = EchoHandler;
    // Previously this used a `"type": [...]` union, which never exercised the
    // `anyOf` keyword the test is named after.
    let message_schema: serde_json::Value = serde_json::json!({
        "type": "object",
        "properties": {
            "value": {
                "anyOf": [
                    {"type": "string"},
                    {"type": "integer"},
                    {"type": "number", "minimum": 0}
                ]
            }
        },
        "required": ["value"]
    });

    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();

    let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();

    // Matches only the string branch.
    let msg1: Value = serde_json::json!({"value": "text"});
    assert!(validator.is_valid(&msg1));

    // Matches both the integer and the non-negative number branch; `anyOf` allows overlap.
    let msg2: Value = serde_json::json!({"value": 42});
    assert!(validator.is_valid(&msg2));

    // Matches only the non-negative number branch.
    let msg3: Value = serde_json::json!({"value": 2.5});
    assert!(validator.is_valid(&msg3));

    // A negative non-integer matches no branch.
    let invalid_negative: Value = serde_json::json!({"value": -1.5});
    assert!(!validator.is_valid(&invalid_negative));

    // A boolean matches no branch.
    let invalid_msg: Value = serde_json::json!({"value": true});
    assert!(!validator.is_valid(&invalid_msg));

    // `anyOf` on the property does not relax the outer `required`.
    let missing_value: Value = serde_json::json!({});
    assert!(!validator.is_valid(&missing_value));
}
1176
#[test]
/// A response schema that combines nesting, `required`, and `minItems` rejects
/// any response violating one of those constraints.
fn test_response_schema_with_multiple_constraints() {
    let items_schema = serde_json::json!({
        "type": "array",
        "items": {"type": "object"},
        "minItems": 1
    });
    let data_schema = serde_json::json!({
        "type": "object",
        "properties": {"items": items_schema},
        "required": ["items"]
    });
    let schema = serde_json::json!({
        "type": "object",
        "properties": {
            "success": {"type": "boolean"},
            "data": data_schema
        },
        "required": ["success", "data"]
    });

    let state = WebSocketState::<EchoHandler>::with_schemas(EchoHandler, None, Some(schema)).unwrap();
    let validator = state.response_schema.as_deref().unwrap();

    let well_formed = serde_json::json!({"success": true, "data": {"items": [{"id": 1}]}});
    let empty_items = serde_json::json!({"success": true, "data": {"items": []}});
    let missing_data = serde_json::json!({"success": true});

    assert!(validator.is_valid(&well_formed));
    // Violates `minItems: 1` on the nested array.
    assert!(!validator.is_valid(&empty_items));
    // Violates the top-level `required` list.
    assert!(!validator.is_valid(&missing_data));
}
1226
#[test]
/// A `["string", "null"]` type union admits an explicit `null`, and an optional
/// key may be absent entirely, but a field typed plain `"string"` rejects `null`.
fn test_null_value_validation() {
    let schema = serde_json::json!({
        "type": "object",
        "properties": {
            "optional_field": {"type": ["string", "null"]},
            "required_field": {"type": "string"}
        },
        "required": ["required_field"]
    });

    let state = WebSocketState::<EchoHandler>::with_schemas(EchoHandler, Some(schema), None).unwrap();
    let validator = state.message_schema.as_deref().unwrap();

    // (message, expected verdict)
    let cases = [
        (serde_json::json!({"optional_field": null, "required_field": "value"}), true),
        (serde_json::json!({"required_field": "value"}), true),
        (serde_json::json!({"required_field": null}), false),
    ];
    for (msg, expected) in &cases {
        assert_eq!(validator.is_valid(msg), *expected);
    }
}
1257
#[test]
/// In JSON Schema, `default` is an annotation, not a constraint. An object that
/// omits `status` is accepted, but a `status` that is supplied must still satisfy
/// its declared `"type": "string"`.
fn test_schema_with_defaults_still_validates() {
    let handler: EchoHandler = EchoHandler;
    let message_schema: serde_json::Value = serde_json::json!({
        "type": "object",
        "properties": {
            "status": {
                "type": "string",
                "default": "pending"
            }
        }
    });

    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();

    let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();

    // `status` is not required, so omitting it is fine.
    let msg: Value = serde_json::json!({});
    assert!(validator.is_valid(&msg));

    // An explicit value of the declared type is accepted.
    let explicit: Value = serde_json::json!({"status": "done"});
    assert!(validator.is_valid(&explicit));

    // The presence of a `default` must not disable type checking on supplied values.
    let wrong_type: Value = serde_json::json!({"status": 42});
    assert!(!validator.is_valid(&wrong_type));
}
1280
#[test]
/// The message and response validators are built from separate schemas. Each
/// accepts only its own shape, and neither inherits the other's constraints.
fn test_both_schemas_validate_independently() {
    let state = WebSocketState::<EchoHandler>::with_schemas(
        EchoHandler,
        Some(serde_json::json!({
            "type": "object",
            "properties": {"action": {"type": "string"}},
            "required": ["action"]
        })),
        Some(serde_json::json!({
            "type": "object",
            "properties": {"result": {"type": "string"}},
            "required": ["result"]
        })),
    )
    .unwrap();

    let inbound = state.message_schema.as_deref().unwrap();
    let outbound = state.response_schema.as_deref().unwrap();
    let stray = serde_json::json!({"data": "oops"});

    // Each validator accepts its own shape and rejects an unrelated payload.
    assert!(inbound.is_valid(&serde_json::json!({"action": "test"})));
    assert!(!outbound.is_valid(&stray));
    assert!(!inbound.is_valid(&stray));
    assert!(outbound.is_valid(&serde_json::json!({"result": "ok"})));
}
1314
#[test]
/// Validating a large array must still inspect every element. 10,000 integers
/// are accepted, and the same array with one non-integer in the final slot is
/// rejected.
fn test_validation_with_large_payload() {
    const ITEM_COUNT: i64 = 10_000;

    let handler: EchoHandler = EchoHandler;
    let message_schema: serde_json::Value = serde_json::json!({
        "type": "object",
        "properties": {
            "items": {
                "type": "array",
                "items": {"type": "integer"}
            }
        },
        "required": ["items"]
    });

    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();

    let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();

    let items: Vec<i64> = (0..ITEM_COUNT).collect();
    let large_msg: Value = serde_json::json!({"items": items});
    assert!(validator.is_valid(&large_msg));

    // Corrupt only the last element, so a validator that stopped early would miss it.
    let mut tampered: Vec<Value> = (0..ITEM_COUNT).map(Value::from).collect();
    *tampered.last_mut().expect("ITEM_COUNT is non-zero") = Value::from("not an integer");
    let tampered_msg: Value = serde_json::json!({"items": tampered});
    assert!(!validator.is_valid(&tampered_msg));
}
1343
#[test]
/// Despite the name, this exercises `allOf` composition. The two subschemas are
/// conjunctive, not mutually exclusive: a message must satisfy both of them.
/// Omitting either required key, or giving a key the wrong type, is rejected.
fn test_mutually_exclusive_schema_properties() {
    let handler: EchoHandler = EchoHandler;

    let message_schema: serde_json::Value = serde_json::json!({
        "allOf": [
            {
                "type": "object",
                "properties": {"a": {"type": "string"}},
                "required": ["a"]
            },
            {
                "type": "object",
                "properties": {"b": {"type": "integer"}},
                "required": ["b"]
            }
        ]
    });

    let state: WebSocketState<EchoHandler> =
        WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();

    let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();

    let valid_msg: Value = serde_json::json!({"a": "text", "b": 42});
    assert!(validator.is_valid(&valid_msg));

    // Fails the second subschema: `b` is missing.
    let invalid_msg: Value = serde_json::json!({"a": "text"});
    assert!(!validator.is_valid(&invalid_msg));

    // Fails the first subschema: `a` is missing.
    let missing_a: Value = serde_json::json!({"b": 42});
    assert!(!validator.is_valid(&missing_a));

    // Both keys are present, but `b` has the wrong type.
    let wrong_type_b: Value = serde_json::json!({"a": "text", "b": "42"});
    assert!(!validator.is_valid(&wrong_type_b));
}
1375}