1use axum::{
6 extract::{
7 State,
8 ws::{Message, WebSocket, WebSocketUpgrade},
9 },
10 response::IntoResponse,
11};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, error, info, warn};
15
/// Returns whether low-level WebSocket tracing is enabled.
///
/// Tracing is on only when the `SPIKARD_WS_TRACE` environment variable is set
/// to exactly `"1"`. The variable is read once, on first use, and cached for
/// the rest of the process. `trace_ws` is called several times per received
/// frame, and re-reading the environment each time allocates and takes the
/// process-wide env lock for no benefit.
fn trace_enabled() -> bool {
    static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *ENABLED.get_or_init(|| std::env::var("SPIKARD_WS_TRACE").ok().as_deref() == Some("1"))
}

/// Writes `[spikard-ws] <message>` to stderr when tracing is enabled
/// (see `trace_enabled`); otherwise does nothing.
fn trace_ws(message: &str) {
    if trace_enabled() {
        eprintln!("[spikard-ws] {message}");
    }
}
21
22pub trait WebSocketHandler: Send + Sync {
58 fn handle_message(&self, message: Value) -> impl std::future::Future<Output = Option<Value>> + Send;
70
71 fn on_connect(&self) -> impl std::future::Future<Output = ()> + Send {
76 async {}
77 }
78
79 fn on_disconnect(&self) -> impl std::future::Future<Output = ()> + Send {
84 async {}
85 }
86}
87
88#[derive(Debug)]
94pub struct WebSocketState<H: WebSocketHandler> {
95 handler: Arc<H>,
97 message_schema: Option<Arc<jsonschema::Validator>>,
99 response_schema: Option<Arc<jsonschema::Validator>>,
101}
102
103impl<H: WebSocketHandler> Clone for WebSocketState<H> {
104 fn clone(&self) -> Self {
105 Self {
106 handler: Arc::clone(&self.handler),
107 message_schema: self.message_schema.clone(),
108 response_schema: self.response_schema.clone(),
109 }
110 }
111}
112
113impl<H: WebSocketHandler + 'static> WebSocketState<H> {
114 pub fn new(handler: H) -> Self {
128 Self {
129 handler: Arc::new(handler),
130 message_schema: None,
131 response_schema: None,
132 }
133 }
134
135 pub fn with_schemas(
170 handler: H,
171 message_schema: Option<serde_json::Value>,
172 response_schema: Option<serde_json::Value>,
173 ) -> Result<Self, String> {
174 let message_validator = if let Some(schema) = message_schema {
175 Some(Arc::new(
176 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
177 ))
178 } else {
179 None
180 };
181
182 let response_validator = if let Some(schema) = response_schema {
183 Some(Arc::new(
184 jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
185 ))
186 } else {
187 None
188 };
189
190 Ok(Self {
191 handler: Arc::new(handler),
192 message_schema: message_validator,
193 response_schema: response_validator,
194 })
195 }
196
197 pub async fn on_connect(&self) {
199 self.handler.on_connect().await;
200 }
201
202 pub async fn on_disconnect(&self) {
204 self.handler.on_disconnect().await;
205 }
206
207 pub async fn handle_message_validated(&self, message: Value) -> Result<Option<Value>, String> {
209 if let Some(validator) = &self.message_schema
210 && !validator.is_valid(&message)
211 {
212 return Err("Message validation failed".to_string());
213 }
214
215 let response = self.handler.handle_message(message).await;
216 if let Some(ref value) = response
217 && let Some(validator) = &self.response_schema
218 && !validator.is_valid(value)
219 {
220 return Ok(None);
221 }
222
223 Ok(response)
224 }
225}
226
227pub async fn websocket_handler<H: WebSocketHandler + 'static>(
250 ws: WebSocketUpgrade,
251 State(state): State<WebSocketState<H>>,
252) -> impl IntoResponse {
253 ws.on_upgrade(move |socket| handle_socket(socket, state))
254}
255
256async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSocketState<H>) {
258 info!("WebSocket client connected");
259 trace_ws("socket:connected");
260
261 state.handler.on_connect().await;
262 trace_ws("socket:on_connect:done");
263
264 while let Some(msg) = socket.recv().await {
265 match msg {
266 Ok(Message::Text(text)) => {
267 debug!("Received text message: {}", text);
268 trace_ws(&format!("recv:text len={}", text.len()));
269
270 match serde_json::from_str::<Value>(&text) {
271 Ok(json_msg) => {
272 trace_ws("recv:text:json-ok");
273 if let Some(validator) = &state.message_schema
274 && !validator.is_valid(&json_msg)
275 {
276 error!("Message validation failed");
277 trace_ws("recv:text:validation-failed");
278 let error_response = serde_json::json!({
279 "error": "Message validation failed"
280 });
281 if let Ok(error_text) = serde_json::to_string(&error_response) {
282 trace_ws(&format!("send:validation-error len={}", error_text.len()));
283 let _ = socket.send(Message::Text(error_text.into())).await;
284 }
285 continue;
286 }
287
288 if let Some(response) = state.handler.handle_message(json_msg).await {
289 trace_ws("handler:response:some");
290 if let Some(validator) = &state.response_schema
291 && !validator.is_valid(&response)
292 {
293 error!("Response validation failed");
294 trace_ws("send:response:validation-failed");
295 continue;
296 }
297
298 let response_text = serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
299 let response_len = response_text.len();
300
301 if let Err(e) = socket.send(Message::Text(response_text.into())).await {
302 error!("Failed to send response: {}", e);
303 trace_ws("send:response:error");
304 break;
305 }
306 trace_ws(&format!("send:response len={}", response_len));
307 } else {
308 trace_ws("handler:response:none");
309 }
310 }
311 Err(e) => {
312 warn!("Failed to parse JSON message: {}", e);
313 trace_ws("recv:text:json-error");
314 let error_msg = serde_json::json!({
315 "type": "error",
316 "message": "Invalid JSON"
317 });
318 let error_text = serde_json::to_string(&error_msg).unwrap_or_else(|_| "{}".to_string());
319 trace_ws(&format!("send:json-error len={}", error_text.len()));
320 let _ = socket.send(Message::Text(error_text.into())).await;
321 }
322 }
323 }
324 Ok(Message::Binary(data)) => {
325 debug!("Received binary message: {} bytes", data.len());
326 trace_ws(&format!("recv:binary len={}", data.len()));
327 if let Err(e) = socket.send(Message::Binary(data)).await {
328 error!("Failed to send binary response: {}", e);
329 trace_ws("send:binary:error");
330 break;
331 }
332 trace_ws("send:binary:ok");
333 }
334 Ok(Message::Ping(data)) => {
335 debug!("Received ping");
336 trace_ws(&format!("recv:ping len={}", data.len()));
337 if let Err(e) = socket.send(Message::Pong(data)).await {
338 error!("Failed to send pong: {}", e);
339 trace_ws("send:pong:error");
340 break;
341 }
342 trace_ws("send:pong:ok");
343 }
344 Ok(Message::Pong(_)) => {
345 debug!("Received pong");
346 trace_ws("recv:pong");
347 }
348 Ok(Message::Close(_)) => {
349 info!("Client closed connection");
350 trace_ws("recv:close");
351 break;
352 }
353 Err(e) => {
354 error!("WebSocket error: {}", e);
355 trace_ws(&format!("recv:error {}", e));
356 break;
357 }
358 }
359 }
360
361 state.handler.on_disconnect().await;
362 trace_ws("socket:on_disconnect:done");
363 info!("WebSocket client disconnected");
364}
365
366#[cfg(test)]
367mod tests {
368 use super::*;
369 use std::sync::Mutex;
370 use std::sync::atomic::{AtomicUsize, Ordering};
371
372 #[derive(Debug)]
373 struct EchoHandler;
374
375 impl WebSocketHandler for EchoHandler {
376 async fn handle_message(&self, message: Value) -> Option<Value> {
377 Some(message)
378 }
379 }
380
381 #[derive(Debug)]
382 struct TrackingHandler {
383 connect_count: Arc<AtomicUsize>,
384 disconnect_count: Arc<AtomicUsize>,
385 message_count: Arc<AtomicUsize>,
386 messages: Arc<Mutex<Vec<Value>>>,
387 }
388
389 impl TrackingHandler {
390 fn new() -> Self {
391 Self {
392 connect_count: Arc::new(AtomicUsize::new(0)),
393 disconnect_count: Arc::new(AtomicUsize::new(0)),
394 message_count: Arc::new(AtomicUsize::new(0)),
395 messages: Arc::new(Mutex::new(Vec::new())),
396 }
397 }
398 }
399
400 impl WebSocketHandler for TrackingHandler {
401 async fn handle_message(&self, message: Value) -> Option<Value> {
402 self.message_count.fetch_add(1, Ordering::SeqCst);
403 self.messages.lock().unwrap().push(message.clone());
404 Some(message)
405 }
406
407 async fn on_connect(&self) {
408 self.connect_count.fetch_add(1, Ordering::SeqCst);
409 }
410
411 async fn on_disconnect(&self) {
412 self.disconnect_count.fetch_add(1, Ordering::SeqCst);
413 }
414 }
415
416 #[derive(Debug)]
417 struct SelectiveHandler;
418
419 impl WebSocketHandler for SelectiveHandler {
420 async fn handle_message(&self, message: Value) -> Option<Value> {
421 if message.get("respond").is_some_and(|v| v.as_bool().unwrap_or(false)) {
422 Some(serde_json::json!({"response": "acknowledged"}))
423 } else {
424 None
425 }
426 }
427 }
428
429 #[derive(Debug)]
430 struct TransformHandler;
431
432 impl WebSocketHandler for TransformHandler {
433 async fn handle_message(&self, message: Value) -> Option<Value> {
434 message.as_object().map_or(None, |obj| {
435 let mut resp = obj.clone();
436 resp.insert("processed".to_string(), Value::Bool(true));
437 Some(Value::Object(resp))
438 })
439 }
440 }
441
442 #[test]
443 fn test_websocket_state_creation() {
444 let handler: EchoHandler = EchoHandler;
445 let state: WebSocketState<EchoHandler> = WebSocketState::new(handler);
446 let cloned: WebSocketState<EchoHandler> = state.clone();
447 assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
448 }
449
450 #[test]
451 fn test_websocket_state_with_valid_schema() {
452 let handler: EchoHandler = EchoHandler;
453 let schema: serde_json::Value = serde_json::json!({
454 "type": "object",
455 "properties": {
456 "type": {"type": "string"}
457 }
458 });
459
460 let result: Result<WebSocketState<EchoHandler>, String> =
461 WebSocketState::with_schemas(handler, Some(schema), None);
462 assert!(result.is_ok());
463 }
464
465 #[test]
466 fn test_websocket_state_with_invalid_schema() {
467 let handler: EchoHandler = EchoHandler;
468 let invalid_schema: serde_json::Value = serde_json::json!({
469 "type": "not_a_real_type",
470 "invalid": "schema"
471 });
472
473 let result: Result<WebSocketState<EchoHandler>, String> =
474 WebSocketState::with_schemas(handler, Some(invalid_schema), None);
475 assert!(result.is_err());
476 if let Err(error_msg) = result {
477 assert!(error_msg.contains("Invalid message schema"));
478 }
479 }
480
481 #[test]
482 fn test_websocket_state_with_both_schemas() {
483 let handler: EchoHandler = EchoHandler;
484 let message_schema: serde_json::Value = serde_json::json!({
485 "type": "object",
486 "properties": {"action": {"type": "string"}}
487 });
488 let response_schema: serde_json::Value = serde_json::json!({
489 "type": "object",
490 "properties": {"result": {"type": "string"}}
491 });
492
493 let result: Result<WebSocketState<EchoHandler>, String> =
494 WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema));
495 assert!(result.is_ok());
496 let state: WebSocketState<EchoHandler> = result.unwrap();
497 assert!(state.message_schema.is_some());
498 assert!(state.response_schema.is_some());
499 }
500
501 #[test]
502 fn test_websocket_state_cloning_preserves_schemas() {
503 let handler: EchoHandler = EchoHandler;
504 let schema: serde_json::Value = serde_json::json!({
505 "type": "object",
506 "properties": {"id": {"type": "integer"}}
507 });
508
509 let state: WebSocketState<EchoHandler> = WebSocketState::with_schemas(handler, Some(schema), None).unwrap();
510 let cloned: WebSocketState<EchoHandler> = state.clone();
511
512 assert!(cloned.message_schema.is_some());
513 assert!(cloned.response_schema.is_none());
514 assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
515 }
516
517 #[tokio::test]
518 async fn test_tracking_handler_lifecycle() {
519 let handler: TrackingHandler = TrackingHandler::new();
520 handler.on_connect().await;
521 assert_eq!(handler.connect_count.load(Ordering::SeqCst), 1);
522
523 let msg: Value = serde_json::json!({"test": "data"});
524 let _response: Option<Value> = handler.handle_message(msg).await;
525 assert_eq!(handler.message_count.load(Ordering::SeqCst), 1);
526
527 handler.on_disconnect().await;
528 assert_eq!(handler.disconnect_count.load(Ordering::SeqCst), 1);
529 }
530
531 #[tokio::test]
532 async fn test_selective_handler_responds_conditionally() {
533 let handler: SelectiveHandler = SelectiveHandler;
534
535 let respond_msg: Value = serde_json::json!({"respond": true});
536 let response1: Option<Value> = handler.handle_message(respond_msg).await;
537 assert!(response1.is_some());
538 assert_eq!(response1.unwrap(), serde_json::json!({"response": "acknowledged"}));
539
540 let no_respond_msg: Value = serde_json::json!({"respond": false});
541 let response2: Option<Value> = handler.handle_message(no_respond_msg).await;
542 assert!(response2.is_none());
543 }
544
545 #[tokio::test]
546 async fn test_transform_handler_modifies_message() {
547 let handler: TransformHandler = TransformHandler;
548 let original: Value = serde_json::json!({"name": "test"});
549 let transformed: Option<Value> = handler.handle_message(original).await;
550
551 assert!(transformed.is_some());
552 let resp: Value = transformed.unwrap();
553 assert_eq!(resp.get("name").unwrap(), "test");
554 assert_eq!(resp.get("processed").unwrap(), true);
555 }
556
557 #[tokio::test]
558 async fn test_echo_handler_preserves_json_types() {
559 let handler: EchoHandler = EchoHandler;
560
561 let messages: Vec<Value> = vec![
562 serde_json::json!({"string": "value"}),
563 serde_json::json!({"number": 42}),
564 serde_json::json!({"float": 3.14}),
565 serde_json::json!({"bool": true}),
566 serde_json::json!({"null": null}),
567 serde_json::json!({"array": [1, 2, 3]}),
568 ];
569
570 for msg in messages {
571 let response: Option<Value> = handler.handle_message(msg.clone()).await;
572 assert!(response.is_some());
573 assert_eq!(response.unwrap(), msg);
574 }
575 }
576
577 #[tokio::test]
578 async fn test_tracking_handler_accumulates_messages() {
579 let handler: TrackingHandler = TrackingHandler::new();
580
581 let messages: Vec<Value> = vec![
582 serde_json::json!({"id": 1}),
583 serde_json::json!({"id": 2}),
584 serde_json::json!({"id": 3}),
585 ];
586
587 for msg in messages {
588 let _: Option<Value> = handler.handle_message(msg).await;
589 }
590
591 assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
592 let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
593 assert_eq!(stored.len(), 3);
594 assert_eq!(stored[0].get("id").unwrap(), 1);
595 assert_eq!(stored[1].get("id").unwrap(), 2);
596 assert_eq!(stored[2].get("id").unwrap(), 3);
597 }
598
599 #[tokio::test]
600 async fn test_echo_handler_with_nested_json() {
601 let handler: EchoHandler = EchoHandler;
602 let nested: Value = serde_json::json!({
603 "level1": {
604 "level2": {
605 "level3": {
606 "value": "deeply nested"
607 }
608 }
609 }
610 });
611
612 let response: Option<Value> = handler.handle_message(nested.clone()).await;
613 assert!(response.is_some());
614 assert_eq!(response.unwrap(), nested);
615 }
616
617 #[tokio::test]
618 async fn test_echo_handler_with_large_array() {
619 let handler: EchoHandler = EchoHandler;
620 let large_array: Value = serde_json::json!({
621 "items": (0..1000).collect::<Vec<i32>>()
622 });
623
624 let response: Option<Value> = handler.handle_message(large_array.clone()).await;
625 assert!(response.is_some());
626 assert_eq!(response.unwrap(), large_array);
627 }
628
629 #[tokio::test]
630 async fn test_echo_handler_with_unicode() {
631 let handler: EchoHandler = EchoHandler;
632 let unicode_msg: Value = serde_json::json!({
633 "emoji": "🚀",
634 "chinese": "你好",
635 "arabic": "مرحبا",
636 "mixed": "Hello 世界 🌍"
637 });
638
639 let response: Option<Value> = handler.handle_message(unicode_msg.clone()).await;
640 assert!(response.is_some());
641 assert_eq!(response.unwrap(), unicode_msg);
642 }
643
644 #[test]
645 fn test_websocket_state_schemas_are_independent() {
646 let handler: EchoHandler = EchoHandler;
647 let message_schema: serde_json::Value = serde_json::json!({"type": "object"});
648 let response_schema: serde_json::Value = serde_json::json!({"type": "array"});
649
650 let state: WebSocketState<EchoHandler> =
651 WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
652
653 let cloned: WebSocketState<EchoHandler> = state.clone();
654
655 assert!(state.message_schema.is_some());
656 assert!(state.response_schema.is_some());
657 assert!(cloned.message_schema.is_some());
658 assert!(cloned.response_schema.is_some());
659 }
660
661 #[test]
662 fn test_message_schema_validation_with_required_field() {
663 let handler: EchoHandler = EchoHandler;
664 let message_schema: serde_json::Value = serde_json::json!({
665 "type": "object",
666 "properties": {"type": {"type": "string"}},
667 "required": ["type"]
668 });
669
670 let state: WebSocketState<EchoHandler> =
671 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
672
673 assert!(state.message_schema.is_some());
674 assert!(state.response_schema.is_none());
675
676 let valid_msg: Value = serde_json::json!({"type": "test"});
677 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
678 assert!(validator.is_valid(&valid_msg));
679
680 let invalid_msg: Value = serde_json::json!({"other": "field"});
681 assert!(!validator.is_valid(&invalid_msg));
682 }
683
684 #[test]
685 fn test_response_schema_validation_with_required_field() {
686 let handler: EchoHandler = EchoHandler;
687 let response_schema: serde_json::Value = serde_json::json!({
688 "type": "object",
689 "properties": {"status": {"type": "string"}},
690 "required": ["status"]
691 });
692
693 let state: WebSocketState<EchoHandler> =
694 WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
695
696 assert!(state.message_schema.is_none());
697 assert!(state.response_schema.is_some());
698
699 let valid_response: Value = serde_json::json!({"status": "ok"});
700 let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
701 assert!(validator.is_valid(&valid_response));
702
703 let invalid_response: Value = serde_json::json!({"other": "field"});
704 assert!(!validator.is_valid(&invalid_response));
705 }
706
707 #[test]
708 fn test_invalid_message_schema_returns_error() {
709 let handler: EchoHandler = EchoHandler;
710 let invalid_schema: serde_json::Value = serde_json::json!({
711 "type": "invalid_type_value",
712 "properties": {}
713 });
714
715 let result: Result<WebSocketState<EchoHandler>, String> =
716 WebSocketState::with_schemas(handler, Some(invalid_schema), None);
717
718 assert!(result.is_err());
719 match result {
720 Err(error_msg) => assert!(error_msg.contains("Invalid message schema")),
721 Ok(_) => panic!("Expected error but got Ok"),
722 }
723 }
724
725 #[test]
726 fn test_invalid_response_schema_returns_error() {
727 let handler: EchoHandler = EchoHandler;
728 let invalid_schema: serde_json::Value = serde_json::json!({
729 "type": "definitely_not_valid"
730 });
731
732 let result: Result<WebSocketState<EchoHandler>, String> =
733 WebSocketState::with_schemas(handler, None, Some(invalid_schema));
734
735 assert!(result.is_err());
736 match result {
737 Err(error_msg) => assert!(error_msg.contains("Invalid response schema")),
738 Ok(_) => panic!("Expected error but got Ok"),
739 }
740 }
741
742 #[tokio::test]
743 async fn test_handler_returning_none_response() {
744 let handler: SelectiveHandler = SelectiveHandler;
745
746 let no_response_msg: Value = serde_json::json!({"respond": false});
747 let result: Option<Value> = handler.handle_message(no_response_msg).await;
748
749 assert!(result.is_none());
750 }
751
752 #[tokio::test]
753 async fn test_handler_with_complex_schema_validation() {
754 let handler: EchoHandler = EchoHandler;
755 let message_schema: serde_json::Value = serde_json::json!({
756 "type": "object",
757 "properties": {
758 "user": {
759 "type": "object",
760 "properties": {
761 "id": {"type": "integer"},
762 "name": {"type": "string"}
763 },
764 "required": ["id", "name"]
765 },
766 "action": {"type": "string"}
767 },
768 "required": ["user", "action"]
769 });
770
771 let state: WebSocketState<EchoHandler> =
772 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
773
774 let valid_msg: Value = serde_json::json!({
775 "user": {"id": 123, "name": "Alice"},
776 "action": "create"
777 });
778 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
779 assert!(validator.is_valid(&valid_msg));
780
781 let invalid_msg: Value = serde_json::json!({
782 "user": {"id": "not_an_int", "name": "Bob"},
783 "action": "create"
784 });
785 assert!(!validator.is_valid(&invalid_msg));
786 }
787
788 #[tokio::test]
789 async fn test_tracking_handler_with_multiple_message_types() {
790 let handler: TrackingHandler = TrackingHandler::new();
791
792 let messages: Vec<Value> = vec![
793 serde_json::json!({"type": "text", "content": "hello"}),
794 serde_json::json!({"type": "image", "url": "http://example.com/image.png"}),
795 serde_json::json!({"type": "video", "duration": 120}),
796 ];
797
798 for msg in messages {
799 let _: Option<Value> = handler.handle_message(msg).await;
800 }
801
802 assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
803 let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
804 assert_eq!(stored.len(), 3);
805 assert_eq!(stored[0].get("type").unwrap(), "text");
806 assert_eq!(stored[1].get("type").unwrap(), "image");
807 assert_eq!(stored[2].get("type").unwrap(), "video");
808 }
809
810 #[tokio::test]
811 async fn test_selective_handler_with_explicit_false() {
812 let handler: SelectiveHandler = SelectiveHandler;
813
814 let msg: Value = serde_json::json!({"respond": false, "data": "test"});
815 let response: Option<Value> = handler.handle_message(msg).await;
816
817 assert!(response.is_none());
818 }
819
820 #[tokio::test]
821 async fn test_selective_handler_without_respond_field() {
822 let handler: SelectiveHandler = SelectiveHandler;
823
824 let msg: Value = serde_json::json!({"data": "test"});
825 let response: Option<Value> = handler.handle_message(msg).await;
826
827 assert!(response.is_none());
828 }
829
830 #[tokio::test]
831 async fn test_transform_handler_with_empty_object() {
832 let handler: TransformHandler = TransformHandler;
833 let original: Value = serde_json::json!({});
834 let transformed: Option<Value> = handler.handle_message(original).await;
835
836 assert!(transformed.is_some());
837 let resp: Value = transformed.unwrap();
838 assert_eq!(resp.get("processed").unwrap(), true);
839 assert_eq!(resp.as_object().unwrap().len(), 1);
840 }
841
842 #[tokio::test]
843 async fn test_transform_handler_preserves_all_fields() {
844 let handler: TransformHandler = TransformHandler;
845 let original: Value = serde_json::json!({
846 "field1": "value1",
847 "field2": 42,
848 "field3": true,
849 "nested": {"key": "value"}
850 });
851 let transformed: Option<Value> = handler.handle_message(original.clone()).await;
852
853 assert!(transformed.is_some());
854 let resp: Value = transformed.unwrap();
855 assert_eq!(resp.get("field1").unwrap(), "value1");
856 assert_eq!(resp.get("field2").unwrap(), 42);
857 assert_eq!(resp.get("field3").unwrap(), true);
858 assert_eq!(resp.get("nested").unwrap(), &serde_json::json!({"key": "value"}));
859 assert_eq!(resp.get("processed").unwrap(), true);
860 }
861
862 #[tokio::test]
863 async fn test_transform_handler_with_non_object_input() {
864 let handler: TransformHandler = TransformHandler;
865
866 let array: Value = serde_json::json!([1, 2, 3]);
867 let response1: Option<Value> = handler.handle_message(array).await;
868 assert!(response1.is_none());
869
870 let string: Value = serde_json::json!("not an object");
871 let response2: Option<Value> = handler.handle_message(string).await;
872 assert!(response2.is_none());
873
874 let number: Value = serde_json::json!(42);
875 let response3: Option<Value> = handler.handle_message(number).await;
876 assert!(response3.is_none());
877 }
878
879 #[test]
881 fn test_message_schema_rejects_wrong_type() {
882 let handler: EchoHandler = EchoHandler;
883 let message_schema: serde_json::Value = serde_json::json!({
884 "type": "object",
885 "properties": {"id": {"type": "integer"}},
886 "required": ["id"]
887 });
888
889 let state: WebSocketState<EchoHandler> =
890 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
891
892 let invalid_msg: Value = serde_json::json!({"id": "not_an_integer"});
893 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
894 assert!(!validator.is_valid(&invalid_msg));
895 }
896
897 #[test]
899 fn test_response_schema_rejects_invalid_type() {
900 let handler: EchoHandler = EchoHandler;
901 let response_schema: serde_json::Value = serde_json::json!({
902 "type": "object",
903 "properties": {"count": {"type": "integer"}},
904 "required": ["count"]
905 });
906
907 let state: WebSocketState<EchoHandler> =
908 WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
909
910 let invalid_response: Value = serde_json::json!([1, 2, 3]);
911 let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
912 assert!(!validator.is_valid(&invalid_response));
913 }
914
915 #[test]
917 fn test_message_missing_multiple_required_fields() {
918 let handler: EchoHandler = EchoHandler;
919 let message_schema: serde_json::Value = serde_json::json!({
920 "type": "object",
921 "properties": {
922 "user_id": {"type": "integer"},
923 "action": {"type": "string"},
924 "timestamp": {"type": "string"}
925 },
926 "required": ["user_id", "action", "timestamp"]
927 });
928
929 let state: WebSocketState<EchoHandler> =
930 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
931
932 let invalid_msg: Value = serde_json::json!({"other": "value"});
933 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
934 assert!(!validator.is_valid(&invalid_msg));
935
936 let partial_msg: Value = serde_json::json!({"user_id": 123});
937 assert!(!validator.is_valid(&partial_msg));
938 }
939
940 #[test]
942 fn test_deeply_nested_schema_validation_failure() {
943 let handler: EchoHandler = EchoHandler;
944 let message_schema: serde_json::Value = serde_json::json!({
945 "type": "object",
946 "properties": {
947 "metadata": {
948 "type": "object",
949 "properties": {
950 "request": {
951 "type": "object",
952 "properties": {
953 "id": {"type": "string"}
954 },
955 "required": ["id"]
956 }
957 },
958 "required": ["request"]
959 }
960 },
961 "required": ["metadata"]
962 });
963
964 let state: WebSocketState<EchoHandler> =
965 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
966
967 let invalid_msg: Value = serde_json::json!({
968 "metadata": {
969 "request": {}
970 }
971 });
972 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
973 assert!(!validator.is_valid(&invalid_msg));
974 }
975
976 #[test]
978 fn test_array_property_type_validation() {
979 let handler: EchoHandler = EchoHandler;
980 let message_schema: serde_json::Value = serde_json::json!({
981 "type": "object",
982 "properties": {
983 "ids": {
984 "type": "array",
985 "items": {"type": "integer"}
986 }
987 }
988 });
989
990 let state: WebSocketState<EchoHandler> =
991 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
992
993 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
994
995 let valid_msg: Value = serde_json::json!({"ids": [1, 2, 3]});
996 assert!(validator.is_valid(&valid_msg));
997
998 let invalid_msg: Value = serde_json::json!({"ids": [1, "two", 3]});
999 assert!(!validator.is_valid(&invalid_msg));
1000
1001 let invalid_msg2: Value = serde_json::json!({"ids": "not_an_array"});
1002 assert!(!validator.is_valid(&invalid_msg2));
1003 }
1004
1005 #[test]
1007 fn test_enum_property_validation() {
1008 let handler: EchoHandler = EchoHandler;
1009 let message_schema: serde_json::Value = serde_json::json!({
1010 "type": "object",
1011 "properties": {
1012 "status": {
1013 "type": "string",
1014 "enum": ["pending", "active", "completed"]
1015 }
1016 },
1017 "required": ["status"]
1018 });
1019
1020 let state: WebSocketState<EchoHandler> =
1021 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1022
1023 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1024
1025 let valid_msg: Value = serde_json::json!({"status": "active"});
1026 assert!(validator.is_valid(&valid_msg));
1027
1028 let invalid_msg: Value = serde_json::json!({"status": "unknown"});
1029 assert!(!validator.is_valid(&invalid_msg));
1030 }
1031
1032 #[test]
1034 fn test_number_range_validation() {
1035 let handler: EchoHandler = EchoHandler;
1036 let message_schema: serde_json::Value = serde_json::json!({
1037 "type": "object",
1038 "properties": {
1039 "age": {
1040 "type": "integer",
1041 "minimum": 0,
1042 "maximum": 150
1043 }
1044 },
1045 "required": ["age"]
1046 });
1047
1048 let state: WebSocketState<EchoHandler> =
1049 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1050
1051 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1052
1053 let valid_msg: Value = serde_json::json!({"age": 25});
1054 assert!(validator.is_valid(&valid_msg));
1055
1056 let invalid_msg: Value = serde_json::json!({"age": -1});
1057 assert!(!validator.is_valid(&invalid_msg));
1058
1059 let invalid_msg2: Value = serde_json::json!({"age": 200});
1060 assert!(!validator.is_valid(&invalid_msg2));
1061 }
1062
1063 #[test]
1065 fn test_string_length_validation() {
1066 let handler: EchoHandler = EchoHandler;
1067 let message_schema: serde_json::Value = serde_json::json!({
1068 "type": "object",
1069 "properties": {
1070 "username": {
1071 "type": "string",
1072 "minLength": 3,
1073 "maxLength": 20
1074 }
1075 },
1076 "required": ["username"]
1077 });
1078
1079 let state: WebSocketState<EchoHandler> =
1080 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1081
1082 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1083
1084 let valid_msg: Value = serde_json::json!({"username": "alice"});
1085 assert!(validator.is_valid(&valid_msg));
1086
1087 let invalid_msg: Value = serde_json::json!({"username": "ab"});
1088 assert!(!validator.is_valid(&invalid_msg));
1089
1090 let invalid_msg2: Value =
1091 serde_json::json!({"username": "this_is_a_very_long_username_over_twenty_characters"});
1092 assert!(!validator.is_valid(&invalid_msg2));
1093 }
1094
1095 #[test]
1097 fn test_pattern_validation() {
1098 let handler: EchoHandler = EchoHandler;
1099 let message_schema: serde_json::Value = serde_json::json!({
1100 "type": "object",
1101 "properties": {
1102 "email": {
1103 "type": "string",
1104 "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
1105 }
1106 },
1107 "required": ["email"]
1108 });
1109
1110 let state: WebSocketState<EchoHandler> =
1111 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1112
1113 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1114
1115 let valid_msg: Value = serde_json::json!({"email": "user@example.com"});
1116 assert!(validator.is_valid(&valid_msg));
1117
1118 let invalid_msg: Value = serde_json::json!({"email": "user@example"});
1119 assert!(!validator.is_valid(&invalid_msg));
1120
1121 let invalid_msg2: Value = serde_json::json!({"email": "userexample.com"});
1122 assert!(!validator.is_valid(&invalid_msg2));
1123 }
1124
1125 #[test]
1127 fn test_additional_properties_validation() {
1128 let handler: EchoHandler = EchoHandler;
1129 let message_schema: serde_json::Value = serde_json::json!({
1130 "type": "object",
1131 "properties": {
1132 "name": {"type": "string"}
1133 },
1134 "additionalProperties": false
1135 });
1136
1137 let state: WebSocketState<EchoHandler> =
1138 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1139
1140 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1141
1142 let valid_msg: Value = serde_json::json!({"name": "Alice"});
1143 assert!(validator.is_valid(&valid_msg));
1144
1145 let invalid_msg: Value = serde_json::json!({"name": "Bob", "age": 30});
1146 assert!(!validator.is_valid(&invalid_msg));
1147 }
1148
1149 #[test]
1151 fn test_one_of_constraint() {
1152 let handler: EchoHandler = EchoHandler;
1153 let message_schema: serde_json::Value = serde_json::json!({
1154 "type": "object",
1155 "oneOf": [
1156 {
1157 "properties": {"type": {"const": "text"}},
1158 "required": ["type"]
1159 },
1160 {
1161 "properties": {"type": {"const": "number"}},
1162 "required": ["type"]
1163 }
1164 ]
1165 });
1166
1167 let state: WebSocketState<EchoHandler> =
1168 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1169
1170 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1171
1172 let valid_msg: Value = serde_json::json!({"type": "text"});
1173 assert!(validator.is_valid(&valid_msg));
1174
1175 let invalid_msg: Value = serde_json::json!({"type": "unknown"});
1176 assert!(!validator.is_valid(&invalid_msg));
1177 }
1178
1179 #[test]
1181 fn test_any_of_constraint() {
1182 let handler: EchoHandler = EchoHandler;
1183 let message_schema: serde_json::Value = serde_json::json!({
1184 "type": "object",
1185 "properties": {
1186 "value": {"type": ["string", "integer"]}
1187 },
1188 "required": ["value"]
1189 });
1190
1191 let state: WebSocketState<EchoHandler> =
1192 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1193
1194 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1195
1196 let msg1: Value = serde_json::json!({"value": "text"});
1197 assert!(validator.is_valid(&msg1));
1198
1199 let msg2: Value = serde_json::json!({"value": 42});
1200 assert!(validator.is_valid(&msg2));
1201
1202 let invalid_msg: Value = serde_json::json!({"value": true});
1203 assert!(!validator.is_valid(&invalid_msg));
1204 }
1205
1206 #[test]
1208 fn test_response_schema_with_multiple_constraints() {
1209 let handler: EchoHandler = EchoHandler;
1210 let response_schema: serde_json::Value = serde_json::json!({
1211 "type": "object",
1212 "properties": {
1213 "success": {"type": "boolean"},
1214 "data": {
1215 "type": "object",
1216 "properties": {
1217 "items": {
1218 "type": "array",
1219 "items": {"type": "object"},
1220 "minItems": 1
1221 }
1222 },
1223 "required": ["items"]
1224 }
1225 },
1226 "required": ["success", "data"]
1227 });
1228
1229 let state: WebSocketState<EchoHandler> =
1230 WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
1231
1232 let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
1233
1234 let valid_response: Value = serde_json::json!({
1235 "success": true,
1236 "data": {
1237 "items": [{"id": 1}]
1238 }
1239 });
1240 assert!(validator.is_valid(&valid_response));
1241
1242 let invalid_response: Value = serde_json::json!({
1243 "success": true,
1244 "data": {
1245 "items": []
1246 }
1247 });
1248 assert!(!validator.is_valid(&invalid_response));
1249
1250 let invalid_response2: Value = serde_json::json!({
1251 "success": true
1252 });
1253 assert!(!validator.is_valid(&invalid_response2));
1254 }
1255
1256 #[test]
1258 fn test_null_value_validation() {
1259 let handler: EchoHandler = EchoHandler;
1260 let message_schema: serde_json::Value = serde_json::json!({
1261 "type": "object",
1262 "properties": {
1263 "optional_field": {"type": ["string", "null"]},
1264 "required_field": {"type": "string"}
1265 },
1266 "required": ["required_field"]
1267 });
1268
1269 let state: WebSocketState<EchoHandler> =
1270 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1271
1272 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1273
1274 let msg1: Value = serde_json::json!({
1275 "optional_field": null,
1276 "required_field": "value"
1277 });
1278 assert!(validator.is_valid(&msg1));
1279
1280 let msg2: Value = serde_json::json!({"required_field": "value"});
1281 assert!(validator.is_valid(&msg2));
1282
1283 let invalid_msg: Value = serde_json::json!({"required_field": null});
1284 assert!(!validator.is_valid(&invalid_msg));
1285 }
1286
1287 #[test]
1289 fn test_schema_with_defaults_still_validates() {
1290 let handler: EchoHandler = EchoHandler;
1291 let message_schema: serde_json::Value = serde_json::json!({
1292 "type": "object",
1293 "properties": {
1294 "status": {
1295 "type": "string",
1296 "default": "pending"
1297 }
1298 }
1299 });
1300
1301 let state: WebSocketState<EchoHandler> =
1302 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1303
1304 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1305
1306 let msg: Value = serde_json::json!({});
1307 assert!(validator.is_valid(&msg));
1308 }
1309
1310 #[test]
1312 fn test_both_schemas_validate_independently() {
1313 let handler: EchoHandler = EchoHandler;
1314 let message_schema: serde_json::Value = serde_json::json!({
1315 "type": "object",
1316 "properties": {"action": {"type": "string"}},
1317 "required": ["action"]
1318 });
1319 let response_schema: serde_json::Value = serde_json::json!({
1320 "type": "object",
1321 "properties": {"result": {"type": "string"}},
1322 "required": ["result"]
1323 });
1324
1325 let state: WebSocketState<EchoHandler> =
1326 WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
1327
1328 let msg_validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1329 let resp_validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
1330
1331 let valid_msg: Value = serde_json::json!({"action": "test"});
1332 let invalid_response: Value = serde_json::json!({"data": "oops"});
1333
1334 assert!(msg_validator.is_valid(&valid_msg));
1335 assert!(!resp_validator.is_valid(&invalid_response));
1336
1337 let invalid_msg: Value = serde_json::json!({"data": "oops"});
1338 let valid_response: Value = serde_json::json!({"result": "ok"});
1339
1340 assert!(!msg_validator.is_valid(&invalid_msg));
1341 assert!(resp_validator.is_valid(&valid_response));
1342 }
1343
1344 #[test]
1346 fn test_validation_with_large_payload() {
1347 let handler: EchoHandler = EchoHandler;
1348 let message_schema: serde_json::Value = serde_json::json!({
1349 "type": "object",
1350 "properties": {
1351 "items": {
1352 "type": "array",
1353 "items": {"type": "integer"}
1354 }
1355 },
1356 "required": ["items"]
1357 });
1358
1359 let state: WebSocketState<EchoHandler> =
1360 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1361
1362 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1363
1364 let mut items = Vec::new();
1365 for i in 0..10_000 {
1366 items.push(i);
1367 }
1368 let large_msg: Value = serde_json::json!({"items": items});
1369
1370 assert!(validator.is_valid(&large_msg));
1371 }
1372
1373 #[test]
1375 fn test_mutually_exclusive_schema_properties() {
1376 let handler: EchoHandler = EchoHandler;
1377
1378 let message_schema: serde_json::Value = serde_json::json!({
1379 "allOf": [
1380 {
1381 "type": "object",
1382 "properties": {"a": {"type": "string"}},
1383 "required": ["a"]
1384 },
1385 {
1386 "type": "object",
1387 "properties": {"b": {"type": "integer"}},
1388 "required": ["b"]
1389 }
1390 ]
1391 });
1392
1393 let state: WebSocketState<EchoHandler> =
1394 WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1395
1396 let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1397
1398 let valid_msg: Value = serde_json::json!({"a": "text", "b": 42});
1399 assert!(validator.is_valid(&valid_msg));
1400
1401 let invalid_msg: Value = serde_json::json!({"a": "text"});
1402 assert!(!validator.is_valid(&invalid_msg));
1403 }
1404}