spikard_http/
websocket.rs

1//! WebSocket support for Spikard
2//!
3//! Provides WebSocket connection handling with message validation and routing.
4
5use axum::{
6    extract::{
7        State,
8        ws::{Message, WebSocket, WebSocketUpgrade},
9    },
10    response::IntoResponse,
11};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, error, info, warn};
15
16/// WebSocket message handler trait
17///
18/// Implement this trait to create custom WebSocket message handlers for your application.
19/// The handler processes JSON messages received from WebSocket clients and can optionally
20/// send responses back.
21///
22/// # Implementing the Trait
23///
24/// You must implement the `handle_message` method. The `on_connect` and `on_disconnect`
25/// methods are optional and provide lifecycle hooks.
26///
27/// # Example
28///
29/// ```ignore
30/// use spikard_http::websocket::WebSocketHandler;
31/// use serde_json::{json, Value};
32///
33/// struct EchoHandler;
34///
35/// #[async_trait]
36/// impl WebSocketHandler for EchoHandler {
37///     async fn handle_message(&self, message: Value) -> Option<Value> {
38///         // Echo the message back to the client
39///         Some(message)
40///     }
41///
42///     async fn on_connect(&self) {
43///         println!("Client connected");
44///     }
45///
46///     async fn on_disconnect(&self) {
47///         println!("Client disconnected");
48///     }
49/// }
50/// ```
51pub trait WebSocketHandler: Send + Sync {
52    /// Handle incoming WebSocket message
53    ///
54    /// Called whenever a text message is received from a WebSocket client.
55    /// Messages are automatically parsed as JSON.
56    ///
57    /// # Arguments
58    /// * `message` - JSON value received from the client
59    ///
60    /// # Returns
61    /// * `Some(value)` - JSON value to send back to the client
62    /// * `None` - No response to send
63    fn handle_message(&self, message: Value) -> impl std::future::Future<Output = Option<Value>> + Send;
64
65    /// Called when a client connects to the WebSocket
66    ///
67    /// Optional lifecycle hook invoked when a new WebSocket connection is established.
68    /// Default implementation does nothing.
69    fn on_connect(&self) -> impl std::future::Future<Output = ()> + Send {
70        async {}
71    }
72
73    /// Called when a client disconnects from the WebSocket
74    ///
75    /// Optional lifecycle hook invoked when a WebSocket connection is closed
76    /// (either by the client or due to an error). Default implementation does nothing.
77    fn on_disconnect(&self) -> impl std::future::Future<Output = ()> + Send {
78        async {}
79    }
80}
81
82/// WebSocket state shared across connections
83///
84/// Contains the message handler and optional JSON schemas for validating
85/// incoming and outgoing messages. This state is shared among all connections
86/// to the same WebSocket endpoint.
87#[derive(Debug)]
88pub struct WebSocketState<H: WebSocketHandler> {
89    /// The message handler implementation
90    handler: Arc<H>,
91    /// Optional JSON Schema for validating incoming messages
92    message_schema: Option<Arc<jsonschema::Validator>>,
93    /// Optional JSON Schema for validating outgoing responses
94    response_schema: Option<Arc<jsonschema::Validator>>,
95}
96
97impl<H: WebSocketHandler> Clone for WebSocketState<H> {
98    fn clone(&self) -> Self {
99        Self {
100            handler: Arc::clone(&self.handler),
101            message_schema: self.message_schema.clone(),
102            response_schema: self.response_schema.clone(),
103        }
104    }
105}
106
107impl<H: WebSocketHandler + 'static> WebSocketState<H> {
108    /// Create new WebSocket state with a handler
109    ///
110    /// Creates a new state without message or response validation schemas.
111    /// Messages and responses are not validated.
112    ///
113    /// # Arguments
114    /// * `handler` - The message handler implementation
115    ///
116    /// # Example
117    ///
118    /// ```ignore
119    /// let state = WebSocketState::new(MyHandler);
120    /// ```
121    pub fn new(handler: H) -> Self {
122        Self {
123            handler: Arc::new(handler),
124            message_schema: None,
125            response_schema: None,
126        }
127    }
128
129    /// Create new WebSocket state with a handler and optional validation schemas
130    ///
131    /// Creates a new state with optional JSON schemas for validating incoming messages
132    /// and outgoing responses. If a schema is provided and validation fails, the message
133    /// or response is rejected.
134    ///
135    /// # Arguments
136    /// * `handler` - The message handler implementation
137    /// * `message_schema` - Optional JSON schema for validating client messages
138    /// * `response_schema` - Optional JSON schema for validating handler responses
139    ///
140    /// # Returns
141    /// * `Ok(state)` - Successfully created state
142    /// * `Err(msg)` - Invalid schema provided
143    ///
144    /// # Example
145    ///
146    /// ```ignore
147    /// use serde_json::json;
148    ///
149    /// let message_schema = json!({
150    ///     "type": "object",
151    ///     "properties": {
152    ///         "type": {"type": "string"},
153    ///         "data": {"type": "string"}
154    ///     }
155    /// });
156    ///
157    /// let state = WebSocketState::with_schemas(
158    ///     MyHandler,
159    ///     Some(message_schema),
160    ///     None,
161    /// )?;
162    /// ```
163    pub fn with_schemas(
164        handler: H,
165        message_schema: Option<serde_json::Value>,
166        response_schema: Option<serde_json::Value>,
167    ) -> Result<Self, String> {
168        let message_validator = if let Some(schema) = message_schema {
169            Some(Arc::new(
170                jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
171            ))
172        } else {
173            None
174        };
175
176        let response_validator = if let Some(schema) = response_schema {
177            Some(Arc::new(
178                jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
179            ))
180        } else {
181            None
182        };
183
184        Ok(Self {
185            handler: Arc::new(handler),
186            message_schema: message_validator,
187            response_schema: response_validator,
188        })
189    }
190}
191
192/// WebSocket upgrade handler
193///
194/// This is the main entry point for WebSocket connections. Use this as an Axum route
195/// handler by passing it to an Axum router's `.route()` method with `get()`.
196///
197/// # Arguments
198/// * `ws` - WebSocket upgrade from Axum
199/// * `State(state)` - Application state containing the handler and optional schemas
200///
201/// # Returns
202/// An Axum response that upgrades the connection to WebSocket
203///
204/// # Example
205///
206/// ```ignore
207/// use axum::{Router, routing::get, extract::State};
208///
209/// let state = WebSocketState::new(MyHandler);
210/// let router = Router::new()
211///     .route("/ws", get(websocket_handler::<MyHandler>))
212///     .with_state(state);
213/// ```
214pub async fn websocket_handler<H: WebSocketHandler + 'static>(
215    ws: WebSocketUpgrade,
216    State(state): State<WebSocketState<H>>,
217) -> impl IntoResponse {
218    ws.on_upgrade(move |socket| handle_socket(socket, state))
219}
220
221/// Handle an individual WebSocket connection
222async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSocketState<H>) {
223    info!("WebSocket client connected");
224
225    state.handler.on_connect().await;
226
227    while let Some(msg) = socket.recv().await {
228        match msg {
229            Ok(Message::Text(text)) => {
230                debug!("Received text message: {}", text);
231
232                match serde_json::from_str::<Value>(&text) {
233                    Ok(json_msg) => {
234                        if let Some(validator) = &state.message_schema
235                            && !validator.is_valid(&json_msg)
236                        {
237                            error!("Message validation failed");
238                            let error_response = serde_json::json!({
239                                "error": "Message validation failed"
240                            });
241                            if let Ok(error_text) = serde_json::to_string(&error_response) {
242                                let _ = socket.send(Message::Text(error_text.into())).await;
243                            }
244                            continue;
245                        }
246
247                        if let Some(response) = state.handler.handle_message(json_msg).await {
248                            if let Some(validator) = &state.response_schema
249                                && !validator.is_valid(&response)
250                            {
251                                error!("Response validation failed");
252                                continue;
253                            }
254
255                            let response_text = serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
256
257                            if let Err(e) = socket.send(Message::Text(response_text.into())).await {
258                                error!("Failed to send response: {}", e);
259                                break;
260                            }
261                        }
262                    }
263                    Err(e) => {
264                        warn!("Failed to parse JSON message: {}", e);
265                        let error_msg = serde_json::json!({
266                            "type": "error",
267                            "message": "Invalid JSON"
268                        });
269                        let error_text = serde_json::to_string(&error_msg).unwrap_or_else(|_| "{}".to_string());
270                        let _ = socket.send(Message::Text(error_text.into())).await;
271                    }
272                }
273            }
274            Ok(Message::Binary(data)) => {
275                debug!("Received binary message: {} bytes", data.len());
276                if let Err(e) = socket.send(Message::Binary(data)).await {
277                    error!("Failed to send binary response: {}", e);
278                    break;
279                }
280            }
281            Ok(Message::Ping(data)) => {
282                debug!("Received ping");
283                if let Err(e) = socket.send(Message::Pong(data)).await {
284                    error!("Failed to send pong: {}", e);
285                    break;
286                }
287            }
288            Ok(Message::Pong(_)) => {
289                debug!("Received pong");
290            }
291            Ok(Message::Close(_)) => {
292                info!("Client closed connection");
293                break;
294            }
295            Err(e) => {
296                error!("WebSocket error: {}", e);
297                break;
298            }
299        }
300    }
301
302    state.handler.on_disconnect().await;
303    info!("WebSocket client disconnected");
304}
305
306#[cfg(test)]
307mod tests {
308    use super::*;
309    use std::sync::Mutex;
310    use std::sync::atomic::{AtomicUsize, Ordering};
311
312    #[derive(Debug)]
313    struct EchoHandler;
314
315    impl WebSocketHandler for EchoHandler {
316        async fn handle_message(&self, message: Value) -> Option<Value> {
317            Some(message)
318        }
319    }
320
321    #[derive(Debug)]
322    struct TrackingHandler {
323        connect_count: Arc<AtomicUsize>,
324        disconnect_count: Arc<AtomicUsize>,
325        message_count: Arc<AtomicUsize>,
326        messages: Arc<Mutex<Vec<Value>>>,
327    }
328
329    impl TrackingHandler {
330        fn new() -> Self {
331            Self {
332                connect_count: Arc::new(AtomicUsize::new(0)),
333                disconnect_count: Arc::new(AtomicUsize::new(0)),
334                message_count: Arc::new(AtomicUsize::new(0)),
335                messages: Arc::new(Mutex::new(Vec::new())),
336            }
337        }
338    }
339
340    impl WebSocketHandler for TrackingHandler {
341        async fn handle_message(&self, message: Value) -> Option<Value> {
342            self.message_count.fetch_add(1, Ordering::SeqCst);
343            self.messages.lock().unwrap().push(message.clone());
344            Some(message)
345        }
346
347        async fn on_connect(&self) {
348            self.connect_count.fetch_add(1, Ordering::SeqCst);
349        }
350
351        async fn on_disconnect(&self) {
352            self.disconnect_count.fetch_add(1, Ordering::SeqCst);
353        }
354    }
355
356    #[derive(Debug)]
357    struct SelectiveHandler;
358
359    impl WebSocketHandler for SelectiveHandler {
360        async fn handle_message(&self, message: Value) -> Option<Value> {
361            if message.get("respond").is_some_and(|v| v.as_bool().unwrap_or(false)) {
362                Some(serde_json::json!({"response": "acknowledged"}))
363            } else {
364                None
365            }
366        }
367    }
368
369    #[derive(Debug)]
370    struct TransformHandler;
371
372    impl WebSocketHandler for TransformHandler {
373        async fn handle_message(&self, message: Value) -> Option<Value> {
374            message.as_object().map_or(None, |obj| {
375                let mut resp = obj.clone();
376                resp.insert("processed".to_string(), Value::Bool(true));
377                Some(Value::Object(resp))
378            })
379        }
380    }
381
382    #[test]
383    fn test_websocket_state_creation() {
384        let handler: EchoHandler = EchoHandler;
385        let state: WebSocketState<EchoHandler> = WebSocketState::new(handler);
386        let cloned: WebSocketState<EchoHandler> = state.clone();
387        assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
388    }
389
390    #[test]
391    fn test_websocket_state_with_valid_schema() {
392        let handler: EchoHandler = EchoHandler;
393        let schema: serde_json::Value = serde_json::json!({
394            "type": "object",
395            "properties": {
396                "type": {"type": "string"}
397            }
398        });
399
400        let result: Result<WebSocketState<EchoHandler>, String> =
401            WebSocketState::with_schemas(handler, Some(schema), None);
402        assert!(result.is_ok());
403    }
404
405    #[test]
406    fn test_websocket_state_with_invalid_schema() {
407        let handler: EchoHandler = EchoHandler;
408        let invalid_schema: serde_json::Value = serde_json::json!({
409            "type": "not_a_real_type",
410            "invalid": "schema"
411        });
412
413        let result: Result<WebSocketState<EchoHandler>, String> =
414            WebSocketState::with_schemas(handler, Some(invalid_schema), None);
415        assert!(result.is_err());
416        if let Err(error_msg) = result {
417            assert!(error_msg.contains("Invalid message schema"));
418        }
419    }
420
421    #[test]
422    fn test_websocket_state_with_both_schemas() {
423        let handler: EchoHandler = EchoHandler;
424        let message_schema: serde_json::Value = serde_json::json!({
425            "type": "object",
426            "properties": {"action": {"type": "string"}}
427        });
428        let response_schema: serde_json::Value = serde_json::json!({
429            "type": "object",
430            "properties": {"result": {"type": "string"}}
431        });
432
433        let result: Result<WebSocketState<EchoHandler>, String> =
434            WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema));
435        assert!(result.is_ok());
436        let state: WebSocketState<EchoHandler> = result.unwrap();
437        assert!(state.message_schema.is_some());
438        assert!(state.response_schema.is_some());
439    }
440
441    #[test]
442    fn test_websocket_state_cloning_preserves_schemas() {
443        let handler: EchoHandler = EchoHandler;
444        let schema: serde_json::Value = serde_json::json!({
445            "type": "object",
446            "properties": {"id": {"type": "integer"}}
447        });
448
449        let state: WebSocketState<EchoHandler> = WebSocketState::with_schemas(handler, Some(schema), None).unwrap();
450        let cloned: WebSocketState<EchoHandler> = state.clone();
451
452        assert!(cloned.message_schema.is_some());
453        assert!(cloned.response_schema.is_none());
454        assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
455    }
456
457    #[tokio::test]
458    async fn test_tracking_handler_lifecycle() {
459        let handler: TrackingHandler = TrackingHandler::new();
460        handler.on_connect().await;
461        assert_eq!(handler.connect_count.load(Ordering::SeqCst), 1);
462
463        let msg: Value = serde_json::json!({"test": "data"});
464        let _response: Option<Value> = handler.handle_message(msg).await;
465        assert_eq!(handler.message_count.load(Ordering::SeqCst), 1);
466
467        handler.on_disconnect().await;
468        assert_eq!(handler.disconnect_count.load(Ordering::SeqCst), 1);
469    }
470
471    #[tokio::test]
472    async fn test_selective_handler_responds_conditionally() {
473        let handler: SelectiveHandler = SelectiveHandler;
474
475        let respond_msg: Value = serde_json::json!({"respond": true});
476        let response1: Option<Value> = handler.handle_message(respond_msg).await;
477        assert!(response1.is_some());
478        assert_eq!(response1.unwrap(), serde_json::json!({"response": "acknowledged"}));
479
480        let no_respond_msg: Value = serde_json::json!({"respond": false});
481        let response2: Option<Value> = handler.handle_message(no_respond_msg).await;
482        assert!(response2.is_none());
483    }
484
485    #[tokio::test]
486    async fn test_transform_handler_modifies_message() {
487        let handler: TransformHandler = TransformHandler;
488        let original: Value = serde_json::json!({"name": "test"});
489        let transformed: Option<Value> = handler.handle_message(original).await;
490
491        assert!(transformed.is_some());
492        let resp: Value = transformed.unwrap();
493        assert_eq!(resp.get("name").unwrap(), "test");
494        assert_eq!(resp.get("processed").unwrap(), true);
495    }
496
497    #[tokio::test]
498    async fn test_echo_handler_preserves_json_types() {
499        let handler: EchoHandler = EchoHandler;
500
501        let messages: Vec<Value> = vec![
502            serde_json::json!({"string": "value"}),
503            serde_json::json!({"number": 42}),
504            serde_json::json!({"float": 3.14}),
505            serde_json::json!({"bool": true}),
506            serde_json::json!({"null": null}),
507            serde_json::json!({"array": [1, 2, 3]}),
508        ];
509
510        for msg in messages {
511            let response: Option<Value> = handler.handle_message(msg.clone()).await;
512            assert!(response.is_some());
513            assert_eq!(response.unwrap(), msg);
514        }
515    }
516
517    #[tokio::test]
518    async fn test_tracking_handler_accumulates_messages() {
519        let handler: TrackingHandler = TrackingHandler::new();
520
521        let messages: Vec<Value> = vec![
522            serde_json::json!({"id": 1}),
523            serde_json::json!({"id": 2}),
524            serde_json::json!({"id": 3}),
525        ];
526
527        for msg in messages {
528            let _: Option<Value> = handler.handle_message(msg).await;
529        }
530
531        assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
532        let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
533        assert_eq!(stored.len(), 3);
534        assert_eq!(stored[0].get("id").unwrap(), 1);
535        assert_eq!(stored[1].get("id").unwrap(), 2);
536        assert_eq!(stored[2].get("id").unwrap(), 3);
537    }
538
539    #[tokio::test]
540    async fn test_echo_handler_with_nested_json() {
541        let handler: EchoHandler = EchoHandler;
542        let nested: Value = serde_json::json!({
543            "level1": {
544                "level2": {
545                    "level3": {
546                        "value": "deeply nested"
547                    }
548                }
549            }
550        });
551
552        let response: Option<Value> = handler.handle_message(nested.clone()).await;
553        assert!(response.is_some());
554        assert_eq!(response.unwrap(), nested);
555    }
556
557    #[tokio::test]
558    async fn test_echo_handler_with_large_array() {
559        let handler: EchoHandler = EchoHandler;
560        let large_array: Value = serde_json::json!({
561            "items": (0..1000).collect::<Vec<i32>>()
562        });
563
564        let response: Option<Value> = handler.handle_message(large_array.clone()).await;
565        assert!(response.is_some());
566        assert_eq!(response.unwrap(), large_array);
567    }
568
569    #[tokio::test]
570    async fn test_echo_handler_with_unicode() {
571        let handler: EchoHandler = EchoHandler;
572        let unicode_msg: Value = serde_json::json!({
573            "emoji": "🚀",
574            "chinese": "你好",
575            "arabic": "مرحبا",
576            "mixed": "Hello 世界 🌍"
577        });
578
579        let response: Option<Value> = handler.handle_message(unicode_msg.clone()).await;
580        assert!(response.is_some());
581        assert_eq!(response.unwrap(), unicode_msg);
582    }
583
584    #[test]
585    fn test_websocket_state_schemas_are_independent() {
586        let handler: EchoHandler = EchoHandler;
587        let message_schema: serde_json::Value = serde_json::json!({"type": "object"});
588        let response_schema: serde_json::Value = serde_json::json!({"type": "array"});
589
590        let state: WebSocketState<EchoHandler> =
591            WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
592
593        let cloned: WebSocketState<EchoHandler> = state.clone();
594
595        assert!(state.message_schema.is_some());
596        assert!(state.response_schema.is_some());
597        assert!(cloned.message_schema.is_some());
598        assert!(cloned.response_schema.is_some());
599    }
600
601    #[test]
602    fn test_message_schema_validation_with_required_field() {
603        let handler: EchoHandler = EchoHandler;
604        let message_schema: serde_json::Value = serde_json::json!({
605            "type": "object",
606            "properties": {"type": {"type": "string"}},
607            "required": ["type"]
608        });
609
610        let state: WebSocketState<EchoHandler> =
611            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
612
613        assert!(state.message_schema.is_some());
614        assert!(state.response_schema.is_none());
615
616        let valid_msg: Value = serde_json::json!({"type": "test"});
617        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
618        assert!(validator.is_valid(&valid_msg));
619
620        let invalid_msg: Value = serde_json::json!({"other": "field"});
621        assert!(!validator.is_valid(&invalid_msg));
622    }
623
624    #[test]
625    fn test_response_schema_validation_with_required_field() {
626        let handler: EchoHandler = EchoHandler;
627        let response_schema: serde_json::Value = serde_json::json!({
628            "type": "object",
629            "properties": {"status": {"type": "string"}},
630            "required": ["status"]
631        });
632
633        let state: WebSocketState<EchoHandler> =
634            WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
635
636        assert!(state.message_schema.is_none());
637        assert!(state.response_schema.is_some());
638
639        let valid_response: Value = serde_json::json!({"status": "ok"});
640        let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
641        assert!(validator.is_valid(&valid_response));
642
643        let invalid_response: Value = serde_json::json!({"other": "field"});
644        assert!(!validator.is_valid(&invalid_response));
645    }
646
647    #[test]
648    fn test_invalid_message_schema_returns_error() {
649        let handler: EchoHandler = EchoHandler;
650        let invalid_schema: serde_json::Value = serde_json::json!({
651            "type": "invalid_type_value",
652            "properties": {}
653        });
654
655        let result: Result<WebSocketState<EchoHandler>, String> =
656            WebSocketState::with_schemas(handler, Some(invalid_schema), None);
657
658        assert!(result.is_err());
659        match result {
660            Err(error_msg) => assert!(error_msg.contains("Invalid message schema")),
661            Ok(_) => panic!("Expected error but got Ok"),
662        }
663    }
664
665    #[test]
666    fn test_invalid_response_schema_returns_error() {
667        let handler: EchoHandler = EchoHandler;
668        let invalid_schema: serde_json::Value = serde_json::json!({
669            "type": "definitely_not_valid"
670        });
671
672        let result: Result<WebSocketState<EchoHandler>, String> =
673            WebSocketState::with_schemas(handler, None, Some(invalid_schema));
674
675        assert!(result.is_err());
676        match result {
677            Err(error_msg) => assert!(error_msg.contains("Invalid response schema")),
678            Ok(_) => panic!("Expected error but got Ok"),
679        }
680    }
681
682    #[tokio::test]
683    async fn test_handler_returning_none_response() {
684        let handler: SelectiveHandler = SelectiveHandler;
685
686        let no_response_msg: Value = serde_json::json!({"respond": false});
687        let result: Option<Value> = handler.handle_message(no_response_msg).await;
688
689        assert!(result.is_none());
690    }
691
692    #[tokio::test]
693    async fn test_handler_with_complex_schema_validation() {
694        let handler: EchoHandler = EchoHandler;
695        let message_schema: serde_json::Value = serde_json::json!({
696            "type": "object",
697            "properties": {
698                "user": {
699                    "type": "object",
700                    "properties": {
701                        "id": {"type": "integer"},
702                        "name": {"type": "string"}
703                    },
704                    "required": ["id", "name"]
705                },
706                "action": {"type": "string"}
707            },
708            "required": ["user", "action"]
709        });
710
711        let state: WebSocketState<EchoHandler> =
712            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
713
714        let valid_msg: Value = serde_json::json!({
715            "user": {"id": 123, "name": "Alice"},
716            "action": "create"
717        });
718        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
719        assert!(validator.is_valid(&valid_msg));
720
721        let invalid_msg: Value = serde_json::json!({
722            "user": {"id": "not_an_int", "name": "Bob"},
723            "action": "create"
724        });
725        assert!(!validator.is_valid(&invalid_msg));
726    }
727
728    #[tokio::test]
729    async fn test_tracking_handler_with_multiple_message_types() {
730        let handler: TrackingHandler = TrackingHandler::new();
731
732        let messages: Vec<Value> = vec![
733            serde_json::json!({"type": "text", "content": "hello"}),
734            serde_json::json!({"type": "image", "url": "http://example.com/image.png"}),
735            serde_json::json!({"type": "video", "duration": 120}),
736        ];
737
738        for msg in messages {
739            let _: Option<Value> = handler.handle_message(msg).await;
740        }
741
742        assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
743        let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
744        assert_eq!(stored.len(), 3);
745        assert_eq!(stored[0].get("type").unwrap(), "text");
746        assert_eq!(stored[1].get("type").unwrap(), "image");
747        assert_eq!(stored[2].get("type").unwrap(), "video");
748    }
749
750    #[tokio::test]
751    async fn test_selective_handler_with_explicit_false() {
752        let handler: SelectiveHandler = SelectiveHandler;
753
754        let msg: Value = serde_json::json!({"respond": false, "data": "test"});
755        let response: Option<Value> = handler.handle_message(msg).await;
756
757        assert!(response.is_none());
758    }
759
760    #[tokio::test]
761    async fn test_selective_handler_without_respond_field() {
762        let handler: SelectiveHandler = SelectiveHandler;
763
764        let msg: Value = serde_json::json!({"data": "test"});
765        let response: Option<Value> = handler.handle_message(msg).await;
766
767        assert!(response.is_none());
768    }
769
770    #[tokio::test]
771    async fn test_transform_handler_with_empty_object() {
772        let handler: TransformHandler = TransformHandler;
773        let original: Value = serde_json::json!({});
774        let transformed: Option<Value> = handler.handle_message(original).await;
775
776        assert!(transformed.is_some());
777        let resp: Value = transformed.unwrap();
778        assert_eq!(resp.get("processed").unwrap(), true);
779        assert_eq!(resp.as_object().unwrap().len(), 1);
780    }
781
782    #[tokio::test]
783    async fn test_transform_handler_preserves_all_fields() {
784        let handler: TransformHandler = TransformHandler;
785        let original: Value = serde_json::json!({
786            "field1": "value1",
787            "field2": 42,
788            "field3": true,
789            "nested": {"key": "value"}
790        });
791        let transformed: Option<Value> = handler.handle_message(original.clone()).await;
792
793        assert!(transformed.is_some());
794        let resp: Value = transformed.unwrap();
795        assert_eq!(resp.get("field1").unwrap(), "value1");
796        assert_eq!(resp.get("field2").unwrap(), 42);
797        assert_eq!(resp.get("field3").unwrap(), true);
798        assert_eq!(resp.get("nested").unwrap(), &serde_json::json!({"key": "value"}));
799        assert_eq!(resp.get("processed").unwrap(), true);
800    }
801
802    #[tokio::test]
803    async fn test_transform_handler_with_non_object_input() {
804        let handler: TransformHandler = TransformHandler;
805
806        let array: Value = serde_json::json!([1, 2, 3]);
807        let response1: Option<Value> = handler.handle_message(array).await;
808        assert!(response1.is_none());
809
810        let string: Value = serde_json::json!("not an object");
811        let response2: Option<Value> = handler.handle_message(string).await;
812        assert!(response2.is_none());
813
814        let number: Value = serde_json::json!(42);
815        let response3: Option<Value> = handler.handle_message(number).await;
816        assert!(response3.is_none());
817    }
818
819    /// Test message validation failure with schema constraint
820    #[test]
821    fn test_message_schema_rejects_wrong_type() {
822        let handler: EchoHandler = EchoHandler;
823        let message_schema: serde_json::Value = serde_json::json!({
824            "type": "object",
825            "properties": {"id": {"type": "integer"}},
826            "required": ["id"]
827        });
828
829        let state: WebSocketState<EchoHandler> =
830            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
831
832        let invalid_msg: Value = serde_json::json!({"id": "not_an_integer"});
833        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
834        assert!(!validator.is_valid(&invalid_msg));
835    }
836
837    /// Test response schema validation failure
838    #[test]
839    fn test_response_schema_rejects_invalid_type() {
840        let handler: EchoHandler = EchoHandler;
841        let response_schema: serde_json::Value = serde_json::json!({
842            "type": "object",
843            "properties": {"count": {"type": "integer"}},
844            "required": ["count"]
845        });
846
847        let state: WebSocketState<EchoHandler> =
848            WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
849
850        let invalid_response: Value = serde_json::json!([1, 2, 3]);
851        let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
852        assert!(!validator.is_valid(&invalid_response));
853    }
854
855    /// Test message with multiple required fields missing
856    #[test]
857    fn test_message_missing_multiple_required_fields() {
858        let handler: EchoHandler = EchoHandler;
859        let message_schema: serde_json::Value = serde_json::json!({
860            "type": "object",
861            "properties": {
862                "user_id": {"type": "integer"},
863                "action": {"type": "string"},
864                "timestamp": {"type": "string"}
865            },
866            "required": ["user_id", "action", "timestamp"]
867        });
868
869        let state: WebSocketState<EchoHandler> =
870            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
871
872        let invalid_msg: Value = serde_json::json!({"other": "value"});
873        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
874        assert!(!validator.is_valid(&invalid_msg));
875
876        let partial_msg: Value = serde_json::json!({"user_id": 123});
877        assert!(!validator.is_valid(&partial_msg));
878    }
879
880    /// Test deeply nested schema validation with required nested properties
881    #[test]
882    fn test_deeply_nested_schema_validation_failure() {
883        let handler: EchoHandler = EchoHandler;
884        let message_schema: serde_json::Value = serde_json::json!({
885            "type": "object",
886            "properties": {
887                "metadata": {
888                    "type": "object",
889                    "properties": {
890                        "request": {
891                            "type": "object",
892                            "properties": {
893                                "id": {"type": "string"}
894                            },
895                            "required": ["id"]
896                        }
897                    },
898                    "required": ["request"]
899                }
900            },
901            "required": ["metadata"]
902        });
903
904        let state: WebSocketState<EchoHandler> =
905            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
906
907        let invalid_msg: Value = serde_json::json!({
908            "metadata": {
909                "request": {}
910            }
911        });
912        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
913        assert!(!validator.is_valid(&invalid_msg));
914    }
915
916    /// Test array property validation with items constraint
917    #[test]
918    fn test_array_property_type_validation() {
919        let handler: EchoHandler = EchoHandler;
920        let message_schema: serde_json::Value = serde_json::json!({
921            "type": "object",
922            "properties": {
923                "ids": {
924                    "type": "array",
925                    "items": {"type": "integer"}
926                }
927            }
928        });
929
930        let state: WebSocketState<EchoHandler> =
931            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
932
933        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
934
935        let valid_msg: Value = serde_json::json!({"ids": [1, 2, 3]});
936        assert!(validator.is_valid(&valid_msg));
937
938        let invalid_msg: Value = serde_json::json!({"ids": [1, "two", 3]});
939        assert!(!validator.is_valid(&invalid_msg));
940
941        let invalid_msg2: Value = serde_json::json!({"ids": "not_an_array"});
942        assert!(!validator.is_valid(&invalid_msg2));
943    }
944
945    /// Test enum/const property validation
946    #[test]
947    fn test_enum_property_validation() {
948        let handler: EchoHandler = EchoHandler;
949        let message_schema: serde_json::Value = serde_json::json!({
950            "type": "object",
951            "properties": {
952                "status": {
953                    "type": "string",
954                    "enum": ["pending", "active", "completed"]
955                }
956            },
957            "required": ["status"]
958        });
959
960        let state: WebSocketState<EchoHandler> =
961            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
962
963        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
964
965        let valid_msg: Value = serde_json::json!({"status": "active"});
966        assert!(validator.is_valid(&valid_msg));
967
968        let invalid_msg: Value = serde_json::json!({"status": "unknown"});
969        assert!(!validator.is_valid(&invalid_msg));
970    }
971
972    /// Test minimum/maximum constraints on numbers
973    #[test]
974    fn test_number_range_validation() {
975        let handler: EchoHandler = EchoHandler;
976        let message_schema: serde_json::Value = serde_json::json!({
977            "type": "object",
978            "properties": {
979                "age": {
980                    "type": "integer",
981                    "minimum": 0,
982                    "maximum": 150
983                }
984            },
985            "required": ["age"]
986        });
987
988        let state: WebSocketState<EchoHandler> =
989            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
990
991        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
992
993        let valid_msg: Value = serde_json::json!({"age": 25});
994        assert!(validator.is_valid(&valid_msg));
995
996        let invalid_msg: Value = serde_json::json!({"age": -1});
997        assert!(!validator.is_valid(&invalid_msg));
998
999        let invalid_msg2: Value = serde_json::json!({"age": 200});
1000        assert!(!validator.is_valid(&invalid_msg2));
1001    }
1002
1003    /// Test string length constraints
1004    #[test]
1005    fn test_string_length_validation() {
1006        let handler: EchoHandler = EchoHandler;
1007        let message_schema: serde_json::Value = serde_json::json!({
1008            "type": "object",
1009            "properties": {
1010                "username": {
1011                    "type": "string",
1012                    "minLength": 3,
1013                    "maxLength": 20
1014                }
1015            },
1016            "required": ["username"]
1017        });
1018
1019        let state: WebSocketState<EchoHandler> =
1020            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1021
1022        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1023
1024        let valid_msg: Value = serde_json::json!({"username": "alice"});
1025        assert!(validator.is_valid(&valid_msg));
1026
1027        let invalid_msg: Value = serde_json::json!({"username": "ab"});
1028        assert!(!validator.is_valid(&invalid_msg));
1029
1030        let invalid_msg2: Value =
1031            serde_json::json!({"username": "this_is_a_very_long_username_over_twenty_characters"});
1032        assert!(!validator.is_valid(&invalid_msg2));
1033    }
1034
1035    /// Test pattern (regex) validation
1036    #[test]
1037    fn test_pattern_validation() {
1038        let handler: EchoHandler = EchoHandler;
1039        let message_schema: serde_json::Value = serde_json::json!({
1040            "type": "object",
1041            "properties": {
1042                "email": {
1043                    "type": "string",
1044                    "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
1045                }
1046            },
1047            "required": ["email"]
1048        });
1049
1050        let state: WebSocketState<EchoHandler> =
1051            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1052
1053        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1054
1055        let valid_msg: Value = serde_json::json!({"email": "user@example.com"});
1056        assert!(validator.is_valid(&valid_msg));
1057
1058        let invalid_msg: Value = serde_json::json!({"email": "user@example"});
1059        assert!(!validator.is_valid(&invalid_msg));
1060
1061        let invalid_msg2: Value = serde_json::json!({"email": "userexample.com"});
1062        assert!(!validator.is_valid(&invalid_msg2));
1063    }
1064
1065    /// Test additionalProperties constraint
1066    #[test]
1067    fn test_additional_properties_validation() {
1068        let handler: EchoHandler = EchoHandler;
1069        let message_schema: serde_json::Value = serde_json::json!({
1070            "type": "object",
1071            "properties": {
1072                "name": {"type": "string"}
1073            },
1074            "additionalProperties": false
1075        });
1076
1077        let state: WebSocketState<EchoHandler> =
1078            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1079
1080        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1081
1082        let valid_msg: Value = serde_json::json!({"name": "Alice"});
1083        assert!(validator.is_valid(&valid_msg));
1084
1085        let invalid_msg: Value = serde_json::json!({"name": "Bob", "age": 30});
1086        assert!(!validator.is_valid(&invalid_msg));
1087    }
1088
1089    /// Test oneOf constraint (mutually exclusive properties)
1090    #[test]
1091    fn test_one_of_constraint() {
1092        let handler: EchoHandler = EchoHandler;
1093        let message_schema: serde_json::Value = serde_json::json!({
1094            "type": "object",
1095            "oneOf": [
1096                {
1097                    "properties": {"type": {"const": "text"}},
1098                    "required": ["type"]
1099                },
1100                {
1101                    "properties": {"type": {"const": "number"}},
1102                    "required": ["type"]
1103                }
1104            ]
1105        });
1106
1107        let state: WebSocketState<EchoHandler> =
1108            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1109
1110        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1111
1112        let valid_msg: Value = serde_json::json!({"type": "text"});
1113        assert!(validator.is_valid(&valid_msg));
1114
1115        let invalid_msg: Value = serde_json::json!({"type": "unknown"});
1116        assert!(!validator.is_valid(&invalid_msg));
1117    }
1118
1119    /// Test anyOf constraint (at least one match)
1120    #[test]
1121    fn test_any_of_constraint() {
1122        let handler: EchoHandler = EchoHandler;
1123        let message_schema: serde_json::Value = serde_json::json!({
1124            "type": "object",
1125            "properties": {
1126                "value": {"type": ["string", "integer"]}
1127            },
1128            "required": ["value"]
1129        });
1130
1131        let state: WebSocketState<EchoHandler> =
1132            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1133
1134        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1135
1136        let msg1: Value = serde_json::json!({"value": "text"});
1137        assert!(validator.is_valid(&msg1));
1138
1139        let msg2: Value = serde_json::json!({"value": 42});
1140        assert!(validator.is_valid(&msg2));
1141
1142        let invalid_msg: Value = serde_json::json!({"value": true});
1143        assert!(!validator.is_valid(&invalid_msg));
1144    }
1145
1146    /// Test response validation with complex constraints
1147    #[test]
1148    fn test_response_schema_with_multiple_constraints() {
1149        let handler: EchoHandler = EchoHandler;
1150        let response_schema: serde_json::Value = serde_json::json!({
1151            "type": "object",
1152            "properties": {
1153                "success": {"type": "boolean"},
1154                "data": {
1155                    "type": "object",
1156                    "properties": {
1157                        "items": {
1158                            "type": "array",
1159                            "items": {"type": "object"},
1160                            "minItems": 1
1161                        }
1162                    },
1163                    "required": ["items"]
1164                }
1165            },
1166            "required": ["success", "data"]
1167        });
1168
1169        let state: WebSocketState<EchoHandler> =
1170            WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
1171
1172        let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
1173
1174        let valid_response: Value = serde_json::json!({
1175            "success": true,
1176            "data": {
1177                "items": [{"id": 1}]
1178            }
1179        });
1180        assert!(validator.is_valid(&valid_response));
1181
1182        let invalid_response: Value = serde_json::json!({
1183            "success": true,
1184            "data": {
1185                "items": []
1186            }
1187        });
1188        assert!(!validator.is_valid(&invalid_response));
1189
1190        let invalid_response2: Value = serde_json::json!({
1191            "success": true
1192        });
1193        assert!(!validator.is_valid(&invalid_response2));
1194    }
1195
1196    /// Test null type validation
1197    #[test]
1198    fn test_null_value_validation() {
1199        let handler: EchoHandler = EchoHandler;
1200        let message_schema: serde_json::Value = serde_json::json!({
1201            "type": "object",
1202            "properties": {
1203                "optional_field": {"type": ["string", "null"]},
1204                "required_field": {"type": "string"}
1205            },
1206            "required": ["required_field"]
1207        });
1208
1209        let state: WebSocketState<EchoHandler> =
1210            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1211
1212        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1213
1214        let msg1: Value = serde_json::json!({
1215            "optional_field": null,
1216            "required_field": "value"
1217        });
1218        assert!(validator.is_valid(&msg1));
1219
1220        let msg2: Value = serde_json::json!({"required_field": "value"});
1221        assert!(validator.is_valid(&msg2));
1222
1223        let invalid_msg: Value = serde_json::json!({"required_field": null});
1224        assert!(!validator.is_valid(&invalid_msg));
1225    }
1226
1227    /// Test schema with default values (they don't change validation)
1228    #[test]
1229    fn test_schema_with_defaults_still_validates() {
1230        let handler: EchoHandler = EchoHandler;
1231        let message_schema: serde_json::Value = serde_json::json!({
1232            "type": "object",
1233            "properties": {
1234                "status": {
1235                    "type": "string",
1236                    "default": "pending"
1237                }
1238            }
1239        });
1240
1241        let state: WebSocketState<EchoHandler> =
1242            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1243
1244        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1245
1246        let msg: Value = serde_json::json!({});
1247        assert!(validator.is_valid(&msg));
1248    }
1249
1250    /// Test both message and response schema validation together
1251    #[test]
1252    fn test_both_schemas_validate_independently() {
1253        let handler: EchoHandler = EchoHandler;
1254        let message_schema: serde_json::Value = serde_json::json!({
1255            "type": "object",
1256            "properties": {"action": {"type": "string"}},
1257            "required": ["action"]
1258        });
1259        let response_schema: serde_json::Value = serde_json::json!({
1260            "type": "object",
1261            "properties": {"result": {"type": "string"}},
1262            "required": ["result"]
1263        });
1264
1265        let state: WebSocketState<EchoHandler> =
1266            WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
1267
1268        let msg_validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1269        let resp_validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
1270
1271        let valid_msg: Value = serde_json::json!({"action": "test"});
1272        let invalid_response: Value = serde_json::json!({"data": "oops"});
1273
1274        assert!(msg_validator.is_valid(&valid_msg));
1275        assert!(!resp_validator.is_valid(&invalid_response));
1276
1277        let invalid_msg: Value = serde_json::json!({"data": "oops"});
1278        let valid_response: Value = serde_json::json!({"result": "ok"});
1279
1280        assert!(!msg_validator.is_valid(&invalid_msg));
1281        assert!(resp_validator.is_valid(&valid_response));
1282    }
1283
1284    /// Test validation with very long/large payload
1285    #[test]
1286    fn test_validation_with_large_payload() {
1287        let handler: EchoHandler = EchoHandler;
1288        let message_schema: serde_json::Value = serde_json::json!({
1289            "type": "object",
1290            "properties": {
1291                "items": {
1292                    "type": "array",
1293                    "items": {"type": "integer"}
1294                }
1295            },
1296            "required": ["items"]
1297        });
1298
1299        let state: WebSocketState<EchoHandler> =
1300            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1301
1302        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1303
1304        let mut items = Vec::new();
1305        for i in 0..10_000 {
1306            items.push(i);
1307        }
1308        let large_msg: Value = serde_json::json!({"items": items});
1309
1310        assert!(validator.is_valid(&large_msg));
1311    }
1312
1313    /// Test validation error doesn't panic with invalid schema combinations
1314    #[test]
1315    fn test_mutually_exclusive_schema_properties() {
1316        let handler: EchoHandler = EchoHandler;
1317
1318        let message_schema: serde_json::Value = serde_json::json!({
1319            "allOf": [
1320                {
1321                    "type": "object",
1322                    "properties": {"a": {"type": "string"}},
1323                    "required": ["a"]
1324                },
1325                {
1326                    "type": "object",
1327                    "properties": {"b": {"type": "integer"}},
1328                    "required": ["b"]
1329                }
1330            ]
1331        });
1332
1333        let state: WebSocketState<EchoHandler> =
1334            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1335
1336        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1337
1338        let valid_msg: Value = serde_json::json!({"a": "text", "b": 42});
1339        assert!(validator.is_valid(&valid_msg));
1340
1341        let invalid_msg: Value = serde_json::json!({"a": "text"});
1342        assert!(!validator.is_valid(&invalid_msg));
1343    }
1344}