spikard_http/
websocket.rs

1//! WebSocket support for Spikard
2//!
3//! Provides WebSocket connection handling with message validation and routing.
4
5use axum::{
6    extract::{
7        State,
8        ws::{Message, WebSocket, WebSocketUpgrade},
9    },
10    response::IntoResponse,
11};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, error, info, warn};
15
/// Returns whether connection tracing is enabled (`SPIKARD_WS_TRACE=1`).
///
/// The environment variable is read once, on first use, and the result is cached for
/// the rest of the process.
fn trace_enabled() -> bool {
    static ENABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *ENABLED.get_or_init(|| std::env::var("SPIKARD_WS_TRACE").ok().as_deref() == Some("1"))
}

/// Write `message` to stderr, prefixed with `[spikard-ws]`, when tracing is enabled.
///
/// This helper runs several times per received frame. The environment lookup is therefore
/// cached by [`trace_enabled`] instead of calling `std::env::var` (which allocates a
/// `String`) on every invocation. Consequence: changing `SPIKARD_WS_TRACE` after the first
/// trace call has no effect.
fn trace_ws(message: &str) {
    if trace_enabled() {
        eprintln!("[spikard-ws] {message}");
    }
}
21
22/// WebSocket message handler trait
23///
24/// Implement this trait to create custom WebSocket message handlers for your application.
25/// The handler processes JSON messages received from WebSocket clients and can optionally
26/// send responses back.
27///
28/// # Implementing the Trait
29///
30/// You must implement the `handle_message` method. The `on_connect` and `on_disconnect`
31/// methods are optional and provide lifecycle hooks.
32///
33/// # Example
34///
35/// ```ignore
36/// use spikard_http::websocket::WebSocketHandler;
37/// use serde_json::{json, Value};
38///
39/// struct EchoHandler;
40///
41/// #[async_trait]
42/// impl WebSocketHandler for EchoHandler {
43///     async fn handle_message(&self, message: Value) -> Option<Value> {
44///         // Echo the message back to the client
45///         Some(message)
46///     }
47///
48///     async fn on_connect(&self) {
49///         println!("Client connected");
50///     }
51///
52///     async fn on_disconnect(&self) {
53///         println!("Client disconnected");
54///     }
55/// }
56/// ```
57pub trait WebSocketHandler: Send + Sync {
58    /// Handle incoming WebSocket message
59    ///
60    /// Called whenever a text message is received from a WebSocket client.
61    /// Messages are automatically parsed as JSON.
62    ///
63    /// # Arguments
64    /// * `message` - JSON value received from the client
65    ///
66    /// # Returns
67    /// * `Some(value)` - JSON value to send back to the client
68    /// * `None` - No response to send
69    fn handle_message(&self, message: Value) -> impl std::future::Future<Output = Option<Value>> + Send;
70
71    /// Called when a client connects to the WebSocket
72    ///
73    /// Optional lifecycle hook invoked when a new WebSocket connection is established.
74    /// Default implementation does nothing.
75    fn on_connect(&self) -> impl std::future::Future<Output = ()> + Send {
76        async {}
77    }
78
79    /// Called when a client disconnects from the WebSocket
80    ///
81    /// Optional lifecycle hook invoked when a WebSocket connection is closed
82    /// (either by the client or due to an error). Default implementation does nothing.
83    fn on_disconnect(&self) -> impl std::future::Future<Output = ()> + Send {
84        async {}
85    }
86}
87
88/// WebSocket state shared across connections
89///
90/// Contains the message handler and optional JSON schemas for validating
91/// incoming and outgoing messages. This state is shared among all connections
92/// to the same WebSocket endpoint.
93#[derive(Debug)]
94pub struct WebSocketState<H: WebSocketHandler> {
95    /// The message handler implementation
96    handler: Arc<H>,
97    /// Optional JSON Schema for validating incoming messages
98    message_schema: Option<Arc<jsonschema::Validator>>,
99    /// Optional JSON Schema for validating outgoing responses
100    response_schema: Option<Arc<jsonschema::Validator>>,
101}
102
103impl<H: WebSocketHandler> Clone for WebSocketState<H> {
104    fn clone(&self) -> Self {
105        Self {
106            handler: Arc::clone(&self.handler),
107            message_schema: self.message_schema.clone(),
108            response_schema: self.response_schema.clone(),
109        }
110    }
111}
112
113impl<H: WebSocketHandler + 'static> WebSocketState<H> {
114    /// Create new WebSocket state with a handler
115    ///
116    /// Creates a new state without message or response validation schemas.
117    /// Messages and responses are not validated.
118    ///
119    /// # Arguments
120    /// * `handler` - The message handler implementation
121    ///
122    /// # Example
123    ///
124    /// ```ignore
125    /// let state = WebSocketState::new(MyHandler);
126    /// ```
127    pub fn new(handler: H) -> Self {
128        Self {
129            handler: Arc::new(handler),
130            message_schema: None,
131            response_schema: None,
132        }
133    }
134
135    /// Create new WebSocket state with a handler and optional validation schemas
136    ///
137    /// Creates a new state with optional JSON schemas for validating incoming messages
138    /// and outgoing responses. If a schema is provided and validation fails, the message
139    /// or response is rejected.
140    ///
141    /// # Arguments
142    /// * `handler` - The message handler implementation
143    /// * `message_schema` - Optional JSON schema for validating client messages
144    /// * `response_schema` - Optional JSON schema for validating handler responses
145    ///
146    /// # Returns
147    /// * `Ok(state)` - Successfully created state
148    /// * `Err(msg)` - Invalid schema provided
149    ///
150    /// # Example
151    ///
152    /// ```ignore
153    /// use serde_json::json;
154    ///
155    /// let message_schema = json!({
156    ///     "type": "object",
157    ///     "properties": {
158    ///         "type": {"type": "string"},
159    ///         "data": {"type": "string"}
160    ///     }
161    /// });
162    ///
163    /// let state = WebSocketState::with_schemas(
164    ///     MyHandler,
165    ///     Some(message_schema),
166    ///     None,
167    /// )?;
168    /// ```
169    pub fn with_schemas(
170        handler: H,
171        message_schema: Option<serde_json::Value>,
172        response_schema: Option<serde_json::Value>,
173    ) -> Result<Self, String> {
174        let message_validator = if let Some(schema) = message_schema {
175            Some(Arc::new(
176                jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
177            ))
178        } else {
179            None
180        };
181
182        let response_validator = if let Some(schema) = response_schema {
183            Some(Arc::new(
184                jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
185            ))
186        } else {
187            None
188        };
189
190        Ok(Self {
191            handler: Arc::new(handler),
192            message_schema: message_validator,
193            response_schema: response_validator,
194        })
195    }
196
197    /// Invoke the connection hook for testing.
198    pub async fn on_connect(&self) {
199        self.handler.on_connect().await;
200    }
201
202    /// Invoke the disconnect hook for testing.
203    pub async fn on_disconnect(&self) {
204        self.handler.on_disconnect().await;
205    }
206
207    /// Validate and handle an incoming message without a socket transport.
208    pub async fn handle_message_validated(&self, message: Value) -> Result<Option<Value>, String> {
209        if let Some(validator) = &self.message_schema
210            && !validator.is_valid(&message)
211        {
212            return Err("Message validation failed".to_string());
213        }
214
215        let response = self.handler.handle_message(message).await;
216        if let Some(ref value) = response
217            && let Some(validator) = &self.response_schema
218            && !validator.is_valid(value)
219        {
220            return Ok(None);
221        }
222
223        Ok(response)
224    }
225}
226
227/// WebSocket upgrade handler
228///
229/// This is the main entry point for WebSocket connections. Use this as an Axum route
230/// handler by passing it to an Axum router's `.route()` method with `get()`.
231///
232/// # Arguments
233/// * `ws` - WebSocket upgrade from Axum
234/// * `State(state)` - Application state containing the handler and optional schemas
235///
236/// # Returns
237/// An Axum response that upgrades the connection to WebSocket
238///
239/// # Example
240///
241/// ```ignore
242/// use axum::{Router, routing::get, extract::State};
243///
244/// let state = WebSocketState::new(MyHandler);
245/// let router = Router::new()
246///     .route("/ws", get(websocket_handler::<MyHandler>))
247///     .with_state(state);
248/// ```
249pub async fn websocket_handler<H: WebSocketHandler + 'static>(
250    ws: WebSocketUpgrade,
251    State(state): State<WebSocketState<H>>,
252) -> impl IntoResponse {
253    ws.on_upgrade(move |socket| handle_socket(socket, state))
254}
255
256/// Handle an individual WebSocket connection
257async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSocketState<H>) {
258    info!("WebSocket client connected");
259    trace_ws("socket:connected");
260
261    state.handler.on_connect().await;
262    trace_ws("socket:on_connect:done");
263
264    while let Some(msg) = socket.recv().await {
265        match msg {
266            Ok(Message::Text(text)) => {
267                debug!("Received text message: {}", text);
268                trace_ws(&format!("recv:text len={}", text.len()));
269
270                match serde_json::from_str::<Value>(&text) {
271                    Ok(json_msg) => {
272                        trace_ws("recv:text:json-ok");
273                        if let Some(validator) = &state.message_schema
274                            && !validator.is_valid(&json_msg)
275                        {
276                            error!("Message validation failed");
277                            trace_ws("recv:text:validation-failed");
278                            let error_response = serde_json::json!({
279                                "error": "Message validation failed"
280                            });
281                            if let Ok(error_text) = serde_json::to_string(&error_response) {
282                                trace_ws(&format!("send:validation-error len={}", error_text.len()));
283                                let _ = socket.send(Message::Text(error_text.into())).await;
284                            }
285                            continue;
286                        }
287
288                        if let Some(response) = state.handler.handle_message(json_msg).await {
289                            trace_ws("handler:response:some");
290                            if let Some(validator) = &state.response_schema
291                                && !validator.is_valid(&response)
292                            {
293                                error!("Response validation failed");
294                                trace_ws("send:response:validation-failed");
295                                continue;
296                            }
297
298                            let response_text = serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
299                            let response_len = response_text.len();
300
301                            if let Err(e) = socket.send(Message::Text(response_text.into())).await {
302                                error!("Failed to send response: {}", e);
303                                trace_ws("send:response:error");
304                                break;
305                            }
306                            trace_ws(&format!("send:response len={}", response_len));
307                        } else {
308                            trace_ws("handler:response:none");
309                        }
310                    }
311                    Err(e) => {
312                        warn!("Failed to parse JSON message: {}", e);
313                        trace_ws("recv:text:json-error");
314                        let error_msg = serde_json::json!({
315                            "type": "error",
316                            "message": "Invalid JSON"
317                        });
318                        let error_text = serde_json::to_string(&error_msg).unwrap_or_else(|_| "{}".to_string());
319                        trace_ws(&format!("send:json-error len={}", error_text.len()));
320                        let _ = socket.send(Message::Text(error_text.into())).await;
321                    }
322                }
323            }
324            Ok(Message::Binary(data)) => {
325                debug!("Received binary message: {} bytes", data.len());
326                trace_ws(&format!("recv:binary len={}", data.len()));
327                if let Err(e) = socket.send(Message::Binary(data)).await {
328                    error!("Failed to send binary response: {}", e);
329                    trace_ws("send:binary:error");
330                    break;
331                }
332                trace_ws("send:binary:ok");
333            }
334            Ok(Message::Ping(data)) => {
335                debug!("Received ping");
336                trace_ws(&format!("recv:ping len={}", data.len()));
337                if let Err(e) = socket.send(Message::Pong(data)).await {
338                    error!("Failed to send pong: {}", e);
339                    trace_ws("send:pong:error");
340                    break;
341                }
342                trace_ws("send:pong:ok");
343            }
344            Ok(Message::Pong(_)) => {
345                debug!("Received pong");
346                trace_ws("recv:pong");
347            }
348            Ok(Message::Close(_)) => {
349                info!("Client closed connection");
350                trace_ws("recv:close");
351                break;
352            }
353            Err(e) => {
354                error!("WebSocket error: {}", e);
355                trace_ws(&format!("recv:error {}", e));
356                break;
357            }
358        }
359    }
360
361    state.handler.on_disconnect().await;
362    trace_ws("socket:on_disconnect:done");
363    info!("WebSocket client disconnected");
364}
365
366#[cfg(test)]
367mod tests {
368    use super::*;
369    use std::sync::Mutex;
370    use std::sync::atomic::{AtomicUsize, Ordering};
371
372    #[derive(Debug)]
373    struct EchoHandler;
374
375    impl WebSocketHandler for EchoHandler {
376        async fn handle_message(&self, message: Value) -> Option<Value> {
377            Some(message)
378        }
379    }
380
381    #[derive(Debug)]
382    struct TrackingHandler {
383        connect_count: Arc<AtomicUsize>,
384        disconnect_count: Arc<AtomicUsize>,
385        message_count: Arc<AtomicUsize>,
386        messages: Arc<Mutex<Vec<Value>>>,
387    }
388
389    impl TrackingHandler {
390        fn new() -> Self {
391            Self {
392                connect_count: Arc::new(AtomicUsize::new(0)),
393                disconnect_count: Arc::new(AtomicUsize::new(0)),
394                message_count: Arc::new(AtomicUsize::new(0)),
395                messages: Arc::new(Mutex::new(Vec::new())),
396            }
397        }
398    }
399
400    impl WebSocketHandler for TrackingHandler {
401        async fn handle_message(&self, message: Value) -> Option<Value> {
402            self.message_count.fetch_add(1, Ordering::SeqCst);
403            self.messages.lock().unwrap().push(message.clone());
404            Some(message)
405        }
406
407        async fn on_connect(&self) {
408            self.connect_count.fetch_add(1, Ordering::SeqCst);
409        }
410
411        async fn on_disconnect(&self) {
412            self.disconnect_count.fetch_add(1, Ordering::SeqCst);
413        }
414    }
415
416    #[derive(Debug)]
417    struct SelectiveHandler;
418
419    impl WebSocketHandler for SelectiveHandler {
420        async fn handle_message(&self, message: Value) -> Option<Value> {
421            if message.get("respond").is_some_and(|v| v.as_bool().unwrap_or(false)) {
422                Some(serde_json::json!({"response": "acknowledged"}))
423            } else {
424                None
425            }
426        }
427    }
428
429    #[derive(Debug)]
430    struct TransformHandler;
431
432    impl WebSocketHandler for TransformHandler {
433        async fn handle_message(&self, message: Value) -> Option<Value> {
434            message.as_object().map_or(None, |obj| {
435                let mut resp = obj.clone();
436                resp.insert("processed".to_string(), Value::Bool(true));
437                Some(Value::Object(resp))
438            })
439        }
440    }
441
442    #[test]
443    fn test_websocket_state_creation() {
444        let handler: EchoHandler = EchoHandler;
445        let state: WebSocketState<EchoHandler> = WebSocketState::new(handler);
446        let cloned: WebSocketState<EchoHandler> = state.clone();
447        assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
448    }
449
450    #[test]
451    fn test_websocket_state_with_valid_schema() {
452        let handler: EchoHandler = EchoHandler;
453        let schema: serde_json::Value = serde_json::json!({
454            "type": "object",
455            "properties": {
456                "type": {"type": "string"}
457            }
458        });
459
460        let result: Result<WebSocketState<EchoHandler>, String> =
461            WebSocketState::with_schemas(handler, Some(schema), None);
462        assert!(result.is_ok());
463    }
464
465    #[test]
466    fn test_websocket_state_with_invalid_schema() {
467        let handler: EchoHandler = EchoHandler;
468        let invalid_schema: serde_json::Value = serde_json::json!({
469            "type": "not_a_real_type",
470            "invalid": "schema"
471        });
472
473        let result: Result<WebSocketState<EchoHandler>, String> =
474            WebSocketState::with_schemas(handler, Some(invalid_schema), None);
475        assert!(result.is_err());
476        if let Err(error_msg) = result {
477            assert!(error_msg.contains("Invalid message schema"));
478        }
479    }
480
481    #[test]
482    fn test_websocket_state_with_both_schemas() {
483        let handler: EchoHandler = EchoHandler;
484        let message_schema: serde_json::Value = serde_json::json!({
485            "type": "object",
486            "properties": {"action": {"type": "string"}}
487        });
488        let response_schema: serde_json::Value = serde_json::json!({
489            "type": "object",
490            "properties": {"result": {"type": "string"}}
491        });
492
493        let result: Result<WebSocketState<EchoHandler>, String> =
494            WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema));
495        assert!(result.is_ok());
496        let state: WebSocketState<EchoHandler> = result.unwrap();
497        assert!(state.message_schema.is_some());
498        assert!(state.response_schema.is_some());
499    }
500
501    #[test]
502    fn test_websocket_state_cloning_preserves_schemas() {
503        let handler: EchoHandler = EchoHandler;
504        let schema: serde_json::Value = serde_json::json!({
505            "type": "object",
506            "properties": {"id": {"type": "integer"}}
507        });
508
509        let state: WebSocketState<EchoHandler> = WebSocketState::with_schemas(handler, Some(schema), None).unwrap();
510        let cloned: WebSocketState<EchoHandler> = state.clone();
511
512        assert!(cloned.message_schema.is_some());
513        assert!(cloned.response_schema.is_none());
514        assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
515    }
516
517    #[tokio::test]
518    async fn test_tracking_handler_lifecycle() {
519        let handler: TrackingHandler = TrackingHandler::new();
520        handler.on_connect().await;
521        assert_eq!(handler.connect_count.load(Ordering::SeqCst), 1);
522
523        let msg: Value = serde_json::json!({"test": "data"});
524        let _response: Option<Value> = handler.handle_message(msg).await;
525        assert_eq!(handler.message_count.load(Ordering::SeqCst), 1);
526
527        handler.on_disconnect().await;
528        assert_eq!(handler.disconnect_count.load(Ordering::SeqCst), 1);
529    }
530
531    #[tokio::test]
532    async fn test_selective_handler_responds_conditionally() {
533        let handler: SelectiveHandler = SelectiveHandler;
534
535        let respond_msg: Value = serde_json::json!({"respond": true});
536        let response1: Option<Value> = handler.handle_message(respond_msg).await;
537        assert!(response1.is_some());
538        assert_eq!(response1.unwrap(), serde_json::json!({"response": "acknowledged"}));
539
540        let no_respond_msg: Value = serde_json::json!({"respond": false});
541        let response2: Option<Value> = handler.handle_message(no_respond_msg).await;
542        assert!(response2.is_none());
543    }
544
545    #[tokio::test]
546    async fn test_transform_handler_modifies_message() {
547        let handler: TransformHandler = TransformHandler;
548        let original: Value = serde_json::json!({"name": "test"});
549        let transformed: Option<Value> = handler.handle_message(original).await;
550
551        assert!(transformed.is_some());
552        let resp: Value = transformed.unwrap();
553        assert_eq!(resp.get("name").unwrap(), "test");
554        assert_eq!(resp.get("processed").unwrap(), true);
555    }
556
557    #[tokio::test]
558    async fn test_echo_handler_preserves_json_types() {
559        let handler: EchoHandler = EchoHandler;
560
561        let messages: Vec<Value> = vec![
562            serde_json::json!({"string": "value"}),
563            serde_json::json!({"number": 42}),
564            serde_json::json!({"float": 3.14}),
565            serde_json::json!({"bool": true}),
566            serde_json::json!({"null": null}),
567            serde_json::json!({"array": [1, 2, 3]}),
568        ];
569
570        for msg in messages {
571            let response: Option<Value> = handler.handle_message(msg.clone()).await;
572            assert!(response.is_some());
573            assert_eq!(response.unwrap(), msg);
574        }
575    }
576
577    #[tokio::test]
578    async fn test_tracking_handler_accumulates_messages() {
579        let handler: TrackingHandler = TrackingHandler::new();
580
581        let messages: Vec<Value> = vec![
582            serde_json::json!({"id": 1}),
583            serde_json::json!({"id": 2}),
584            serde_json::json!({"id": 3}),
585        ];
586
587        for msg in messages {
588            let _: Option<Value> = handler.handle_message(msg).await;
589        }
590
591        assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
592        let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
593        assert_eq!(stored.len(), 3);
594        assert_eq!(stored[0].get("id").unwrap(), 1);
595        assert_eq!(stored[1].get("id").unwrap(), 2);
596        assert_eq!(stored[2].get("id").unwrap(), 3);
597    }
598
599    #[tokio::test]
600    async fn test_echo_handler_with_nested_json() {
601        let handler: EchoHandler = EchoHandler;
602        let nested: Value = serde_json::json!({
603            "level1": {
604                "level2": {
605                    "level3": {
606                        "value": "deeply nested"
607                    }
608                }
609            }
610        });
611
612        let response: Option<Value> = handler.handle_message(nested.clone()).await;
613        assert!(response.is_some());
614        assert_eq!(response.unwrap(), nested);
615    }
616
617    #[tokio::test]
618    async fn test_echo_handler_with_large_array() {
619        let handler: EchoHandler = EchoHandler;
620        let large_array: Value = serde_json::json!({
621            "items": (0..1000).collect::<Vec<i32>>()
622        });
623
624        let response: Option<Value> = handler.handle_message(large_array.clone()).await;
625        assert!(response.is_some());
626        assert_eq!(response.unwrap(), large_array);
627    }
628
629    #[tokio::test]
630    async fn test_echo_handler_with_unicode() {
631        let handler: EchoHandler = EchoHandler;
632        let unicode_msg: Value = serde_json::json!({
633            "emoji": "🚀",
634            "chinese": "你好",
635            "arabic": "مرحبا",
636            "mixed": "Hello 世界 🌍"
637        });
638
639        let response: Option<Value> = handler.handle_message(unicode_msg.clone()).await;
640        assert!(response.is_some());
641        assert_eq!(response.unwrap(), unicode_msg);
642    }
643
644    #[test]
645    fn test_websocket_state_schemas_are_independent() {
646        let handler: EchoHandler = EchoHandler;
647        let message_schema: serde_json::Value = serde_json::json!({"type": "object"});
648        let response_schema: serde_json::Value = serde_json::json!({"type": "array"});
649
650        let state: WebSocketState<EchoHandler> =
651            WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
652
653        let cloned: WebSocketState<EchoHandler> = state.clone();
654
655        assert!(state.message_schema.is_some());
656        assert!(state.response_schema.is_some());
657        assert!(cloned.message_schema.is_some());
658        assert!(cloned.response_schema.is_some());
659    }
660
661    #[test]
662    fn test_message_schema_validation_with_required_field() {
663        let handler: EchoHandler = EchoHandler;
664        let message_schema: serde_json::Value = serde_json::json!({
665            "type": "object",
666            "properties": {"type": {"type": "string"}},
667            "required": ["type"]
668        });
669
670        let state: WebSocketState<EchoHandler> =
671            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
672
673        assert!(state.message_schema.is_some());
674        assert!(state.response_schema.is_none());
675
676        let valid_msg: Value = serde_json::json!({"type": "test"});
677        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
678        assert!(validator.is_valid(&valid_msg));
679
680        let invalid_msg: Value = serde_json::json!({"other": "field"});
681        assert!(!validator.is_valid(&invalid_msg));
682    }
683
684    #[test]
685    fn test_response_schema_validation_with_required_field() {
686        let handler: EchoHandler = EchoHandler;
687        let response_schema: serde_json::Value = serde_json::json!({
688            "type": "object",
689            "properties": {"status": {"type": "string"}},
690            "required": ["status"]
691        });
692
693        let state: WebSocketState<EchoHandler> =
694            WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
695
696        assert!(state.message_schema.is_none());
697        assert!(state.response_schema.is_some());
698
699        let valid_response: Value = serde_json::json!({"status": "ok"});
700        let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
701        assert!(validator.is_valid(&valid_response));
702
703        let invalid_response: Value = serde_json::json!({"other": "field"});
704        assert!(!validator.is_valid(&invalid_response));
705    }
706
707    #[test]
708    fn test_invalid_message_schema_returns_error() {
709        let handler: EchoHandler = EchoHandler;
710        let invalid_schema: serde_json::Value = serde_json::json!({
711            "type": "invalid_type_value",
712            "properties": {}
713        });
714
715        let result: Result<WebSocketState<EchoHandler>, String> =
716            WebSocketState::with_schemas(handler, Some(invalid_schema), None);
717
718        assert!(result.is_err());
719        match result {
720            Err(error_msg) => assert!(error_msg.contains("Invalid message schema")),
721            Ok(_) => panic!("Expected error but got Ok"),
722        }
723    }
724
725    #[test]
726    fn test_invalid_response_schema_returns_error() {
727        let handler: EchoHandler = EchoHandler;
728        let invalid_schema: serde_json::Value = serde_json::json!({
729            "type": "definitely_not_valid"
730        });
731
732        let result: Result<WebSocketState<EchoHandler>, String> =
733            WebSocketState::with_schemas(handler, None, Some(invalid_schema));
734
735        assert!(result.is_err());
736        match result {
737            Err(error_msg) => assert!(error_msg.contains("Invalid response schema")),
738            Ok(_) => panic!("Expected error but got Ok"),
739        }
740    }
741
742    #[tokio::test]
743    async fn test_handler_returning_none_response() {
744        let handler: SelectiveHandler = SelectiveHandler;
745
746        let no_response_msg: Value = serde_json::json!({"respond": false});
747        let result: Option<Value> = handler.handle_message(no_response_msg).await;
748
749        assert!(result.is_none());
750    }
751
752    #[tokio::test]
753    async fn test_handler_with_complex_schema_validation() {
754        let handler: EchoHandler = EchoHandler;
755        let message_schema: serde_json::Value = serde_json::json!({
756            "type": "object",
757            "properties": {
758                "user": {
759                    "type": "object",
760                    "properties": {
761                        "id": {"type": "integer"},
762                        "name": {"type": "string"}
763                    },
764                    "required": ["id", "name"]
765                },
766                "action": {"type": "string"}
767            },
768            "required": ["user", "action"]
769        });
770
771        let state: WebSocketState<EchoHandler> =
772            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
773
774        let valid_msg: Value = serde_json::json!({
775            "user": {"id": 123, "name": "Alice"},
776            "action": "create"
777        });
778        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
779        assert!(validator.is_valid(&valid_msg));
780
781        let invalid_msg: Value = serde_json::json!({
782            "user": {"id": "not_an_int", "name": "Bob"},
783            "action": "create"
784        });
785        assert!(!validator.is_valid(&invalid_msg));
786    }
787
788    #[tokio::test]
789    async fn test_tracking_handler_with_multiple_message_types() {
790        let handler: TrackingHandler = TrackingHandler::new();
791
792        let messages: Vec<Value> = vec![
793            serde_json::json!({"type": "text", "content": "hello"}),
794            serde_json::json!({"type": "image", "url": "http://example.com/image.png"}),
795            serde_json::json!({"type": "video", "duration": 120}),
796        ];
797
798        for msg in messages {
799            let _: Option<Value> = handler.handle_message(msg).await;
800        }
801
802        assert_eq!(handler.message_count.load(Ordering::SeqCst), 3);
803        let stored: Vec<Value> = handler.messages.lock().unwrap().clone();
804        assert_eq!(stored.len(), 3);
805        assert_eq!(stored[0].get("type").unwrap(), "text");
806        assert_eq!(stored[1].get("type").unwrap(), "image");
807        assert_eq!(stored[2].get("type").unwrap(), "video");
808    }
809
810    #[tokio::test]
811    async fn test_selective_handler_with_explicit_false() {
812        let handler: SelectiveHandler = SelectiveHandler;
813
814        let msg: Value = serde_json::json!({"respond": false, "data": "test"});
815        let response: Option<Value> = handler.handle_message(msg).await;
816
817        assert!(response.is_none());
818    }
819
820    #[tokio::test]
821    async fn test_selective_handler_without_respond_field() {
822        let handler: SelectiveHandler = SelectiveHandler;
823
824        let msg: Value = serde_json::json!({"data": "test"});
825        let response: Option<Value> = handler.handle_message(msg).await;
826
827        assert!(response.is_none());
828    }
829
830    #[tokio::test]
831    async fn test_transform_handler_with_empty_object() {
832        let handler: TransformHandler = TransformHandler;
833        let original: Value = serde_json::json!({});
834        let transformed: Option<Value> = handler.handle_message(original).await;
835
836        assert!(transformed.is_some());
837        let resp: Value = transformed.unwrap();
838        assert_eq!(resp.get("processed").unwrap(), true);
839        assert_eq!(resp.as_object().unwrap().len(), 1);
840    }
841
842    #[tokio::test]
843    async fn test_transform_handler_preserves_all_fields() {
844        let handler: TransformHandler = TransformHandler;
845        let original: Value = serde_json::json!({
846            "field1": "value1",
847            "field2": 42,
848            "field3": true,
849            "nested": {"key": "value"}
850        });
851        let transformed: Option<Value> = handler.handle_message(original.clone()).await;
852
853        assert!(transformed.is_some());
854        let resp: Value = transformed.unwrap();
855        assert_eq!(resp.get("field1").unwrap(), "value1");
856        assert_eq!(resp.get("field2").unwrap(), 42);
857        assert_eq!(resp.get("field3").unwrap(), true);
858        assert_eq!(resp.get("nested").unwrap(), &serde_json::json!({"key": "value"}));
859        assert_eq!(resp.get("processed").unwrap(), true);
860    }
861
862    #[tokio::test]
863    async fn test_transform_handler_with_non_object_input() {
864        let handler: TransformHandler = TransformHandler;
865
866        let array: Value = serde_json::json!([1, 2, 3]);
867        let response1: Option<Value> = handler.handle_message(array).await;
868        assert!(response1.is_none());
869
870        let string: Value = serde_json::json!("not an object");
871        let response2: Option<Value> = handler.handle_message(string).await;
872        assert!(response2.is_none());
873
874        let number: Value = serde_json::json!(42);
875        let response3: Option<Value> = handler.handle_message(number).await;
876        assert!(response3.is_none());
877    }
878
879    /// Test message validation failure with schema constraint
880    #[test]
881    fn test_message_schema_rejects_wrong_type() {
882        let handler: EchoHandler = EchoHandler;
883        let message_schema: serde_json::Value = serde_json::json!({
884            "type": "object",
885            "properties": {"id": {"type": "integer"}},
886            "required": ["id"]
887        });
888
889        let state: WebSocketState<EchoHandler> =
890            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
891
892        let invalid_msg: Value = serde_json::json!({"id": "not_an_integer"});
893        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
894        assert!(!validator.is_valid(&invalid_msg));
895    }
896
897    /// Test response schema validation failure
898    #[test]
899    fn test_response_schema_rejects_invalid_type() {
900        let handler: EchoHandler = EchoHandler;
901        let response_schema: serde_json::Value = serde_json::json!({
902            "type": "object",
903            "properties": {"count": {"type": "integer"}},
904            "required": ["count"]
905        });
906
907        let state: WebSocketState<EchoHandler> =
908            WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
909
910        let invalid_response: Value = serde_json::json!([1, 2, 3]);
911        let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
912        assert!(!validator.is_valid(&invalid_response));
913    }
914
915    /// Test message with multiple required fields missing
916    #[test]
917    fn test_message_missing_multiple_required_fields() {
918        let handler: EchoHandler = EchoHandler;
919        let message_schema: serde_json::Value = serde_json::json!({
920            "type": "object",
921            "properties": {
922                "user_id": {"type": "integer"},
923                "action": {"type": "string"},
924                "timestamp": {"type": "string"}
925            },
926            "required": ["user_id", "action", "timestamp"]
927        });
928
929        let state: WebSocketState<EchoHandler> =
930            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
931
932        let invalid_msg: Value = serde_json::json!({"other": "value"});
933        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
934        assert!(!validator.is_valid(&invalid_msg));
935
936        let partial_msg: Value = serde_json::json!({"user_id": 123});
937        assert!(!validator.is_valid(&partial_msg));
938    }
939
940    /// Test deeply nested schema validation with required nested properties
941    #[test]
942    fn test_deeply_nested_schema_validation_failure() {
943        let handler: EchoHandler = EchoHandler;
944        let message_schema: serde_json::Value = serde_json::json!({
945            "type": "object",
946            "properties": {
947                "metadata": {
948                    "type": "object",
949                    "properties": {
950                        "request": {
951                            "type": "object",
952                            "properties": {
953                                "id": {"type": "string"}
954                            },
955                            "required": ["id"]
956                        }
957                    },
958                    "required": ["request"]
959                }
960            },
961            "required": ["metadata"]
962        });
963
964        let state: WebSocketState<EchoHandler> =
965            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
966
967        let invalid_msg: Value = serde_json::json!({
968            "metadata": {
969                "request": {}
970            }
971        });
972        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
973        assert!(!validator.is_valid(&invalid_msg));
974    }
975
976    /// Test array property validation with items constraint
977    #[test]
978    fn test_array_property_type_validation() {
979        let handler: EchoHandler = EchoHandler;
980        let message_schema: serde_json::Value = serde_json::json!({
981            "type": "object",
982            "properties": {
983                "ids": {
984                    "type": "array",
985                    "items": {"type": "integer"}
986                }
987            }
988        });
989
990        let state: WebSocketState<EchoHandler> =
991            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
992
993        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
994
995        let valid_msg: Value = serde_json::json!({"ids": [1, 2, 3]});
996        assert!(validator.is_valid(&valid_msg));
997
998        let invalid_msg: Value = serde_json::json!({"ids": [1, "two", 3]});
999        assert!(!validator.is_valid(&invalid_msg));
1000
1001        let invalid_msg2: Value = serde_json::json!({"ids": "not_an_array"});
1002        assert!(!validator.is_valid(&invalid_msg2));
1003    }
1004
1005    /// Test enum/const property validation
1006    #[test]
1007    fn test_enum_property_validation() {
1008        let handler: EchoHandler = EchoHandler;
1009        let message_schema: serde_json::Value = serde_json::json!({
1010            "type": "object",
1011            "properties": {
1012                "status": {
1013                    "type": "string",
1014                    "enum": ["pending", "active", "completed"]
1015                }
1016            },
1017            "required": ["status"]
1018        });
1019
1020        let state: WebSocketState<EchoHandler> =
1021            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1022
1023        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1024
1025        let valid_msg: Value = serde_json::json!({"status": "active"});
1026        assert!(validator.is_valid(&valid_msg));
1027
1028        let invalid_msg: Value = serde_json::json!({"status": "unknown"});
1029        assert!(!validator.is_valid(&invalid_msg));
1030    }
1031
1032    /// Test minimum/maximum constraints on numbers
1033    #[test]
1034    fn test_number_range_validation() {
1035        let handler: EchoHandler = EchoHandler;
1036        let message_schema: serde_json::Value = serde_json::json!({
1037            "type": "object",
1038            "properties": {
1039                "age": {
1040                    "type": "integer",
1041                    "minimum": 0,
1042                    "maximum": 150
1043                }
1044            },
1045            "required": ["age"]
1046        });
1047
1048        let state: WebSocketState<EchoHandler> =
1049            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1050
1051        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1052
1053        let valid_msg: Value = serde_json::json!({"age": 25});
1054        assert!(validator.is_valid(&valid_msg));
1055
1056        let invalid_msg: Value = serde_json::json!({"age": -1});
1057        assert!(!validator.is_valid(&invalid_msg));
1058
1059        let invalid_msg2: Value = serde_json::json!({"age": 200});
1060        assert!(!validator.is_valid(&invalid_msg2));
1061    }
1062
1063    /// Test string length constraints
1064    #[test]
1065    fn test_string_length_validation() {
1066        let handler: EchoHandler = EchoHandler;
1067        let message_schema: serde_json::Value = serde_json::json!({
1068            "type": "object",
1069            "properties": {
1070                "username": {
1071                    "type": "string",
1072                    "minLength": 3,
1073                    "maxLength": 20
1074                }
1075            },
1076            "required": ["username"]
1077        });
1078
1079        let state: WebSocketState<EchoHandler> =
1080            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1081
1082        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1083
1084        let valid_msg: Value = serde_json::json!({"username": "alice"});
1085        assert!(validator.is_valid(&valid_msg));
1086
1087        let invalid_msg: Value = serde_json::json!({"username": "ab"});
1088        assert!(!validator.is_valid(&invalid_msg));
1089
1090        let invalid_msg2: Value =
1091            serde_json::json!({"username": "this_is_a_very_long_username_over_twenty_characters"});
1092        assert!(!validator.is_valid(&invalid_msg2));
1093    }
1094
1095    /// Test pattern (regex) validation
1096    #[test]
1097    fn test_pattern_validation() {
1098        let handler: EchoHandler = EchoHandler;
1099        let message_schema: serde_json::Value = serde_json::json!({
1100            "type": "object",
1101            "properties": {
1102                "email": {
1103                    "type": "string",
1104                    "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
1105                }
1106            },
1107            "required": ["email"]
1108        });
1109
1110        let state: WebSocketState<EchoHandler> =
1111            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1112
1113        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1114
1115        let valid_msg: Value = serde_json::json!({"email": "user@example.com"});
1116        assert!(validator.is_valid(&valid_msg));
1117
1118        let invalid_msg: Value = serde_json::json!({"email": "user@example"});
1119        assert!(!validator.is_valid(&invalid_msg));
1120
1121        let invalid_msg2: Value = serde_json::json!({"email": "userexample.com"});
1122        assert!(!validator.is_valid(&invalid_msg2));
1123    }
1124
1125    /// Test additionalProperties constraint
1126    #[test]
1127    fn test_additional_properties_validation() {
1128        let handler: EchoHandler = EchoHandler;
1129        let message_schema: serde_json::Value = serde_json::json!({
1130            "type": "object",
1131            "properties": {
1132                "name": {"type": "string"}
1133            },
1134            "additionalProperties": false
1135        });
1136
1137        let state: WebSocketState<EchoHandler> =
1138            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1139
1140        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1141
1142        let valid_msg: Value = serde_json::json!({"name": "Alice"});
1143        assert!(validator.is_valid(&valid_msg));
1144
1145        let invalid_msg: Value = serde_json::json!({"name": "Bob", "age": 30});
1146        assert!(!validator.is_valid(&invalid_msg));
1147    }
1148
1149    /// Test oneOf constraint (mutually exclusive properties)
1150    #[test]
1151    fn test_one_of_constraint() {
1152        let handler: EchoHandler = EchoHandler;
1153        let message_schema: serde_json::Value = serde_json::json!({
1154            "type": "object",
1155            "oneOf": [
1156                {
1157                    "properties": {"type": {"const": "text"}},
1158                    "required": ["type"]
1159                },
1160                {
1161                    "properties": {"type": {"const": "number"}},
1162                    "required": ["type"]
1163                }
1164            ]
1165        });
1166
1167        let state: WebSocketState<EchoHandler> =
1168            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1169
1170        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1171
1172        let valid_msg: Value = serde_json::json!({"type": "text"});
1173        assert!(validator.is_valid(&valid_msg));
1174
1175        let invalid_msg: Value = serde_json::json!({"type": "unknown"});
1176        assert!(!validator.is_valid(&invalid_msg));
1177    }
1178
1179    /// Test anyOf constraint (at least one match)
1180    #[test]
1181    fn test_any_of_constraint() {
1182        let handler: EchoHandler = EchoHandler;
1183        let message_schema: serde_json::Value = serde_json::json!({
1184            "type": "object",
1185            "properties": {
1186                "value": {"type": ["string", "integer"]}
1187            },
1188            "required": ["value"]
1189        });
1190
1191        let state: WebSocketState<EchoHandler> =
1192            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1193
1194        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1195
1196        let msg1: Value = serde_json::json!({"value": "text"});
1197        assert!(validator.is_valid(&msg1));
1198
1199        let msg2: Value = serde_json::json!({"value": 42});
1200        assert!(validator.is_valid(&msg2));
1201
1202        let invalid_msg: Value = serde_json::json!({"value": true});
1203        assert!(!validator.is_valid(&invalid_msg));
1204    }
1205
1206    /// Test response validation with complex constraints
1207    #[test]
1208    fn test_response_schema_with_multiple_constraints() {
1209        let handler: EchoHandler = EchoHandler;
1210        let response_schema: serde_json::Value = serde_json::json!({
1211            "type": "object",
1212            "properties": {
1213                "success": {"type": "boolean"},
1214                "data": {
1215                    "type": "object",
1216                    "properties": {
1217                        "items": {
1218                            "type": "array",
1219                            "items": {"type": "object"},
1220                            "minItems": 1
1221                        }
1222                    },
1223                    "required": ["items"]
1224                }
1225            },
1226            "required": ["success", "data"]
1227        });
1228
1229        let state: WebSocketState<EchoHandler> =
1230            WebSocketState::with_schemas(handler, None, Some(response_schema)).unwrap();
1231
1232        let validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
1233
1234        let valid_response: Value = serde_json::json!({
1235            "success": true,
1236            "data": {
1237                "items": [{"id": 1}]
1238            }
1239        });
1240        assert!(validator.is_valid(&valid_response));
1241
1242        let invalid_response: Value = serde_json::json!({
1243            "success": true,
1244            "data": {
1245                "items": []
1246            }
1247        });
1248        assert!(!validator.is_valid(&invalid_response));
1249
1250        let invalid_response2: Value = serde_json::json!({
1251            "success": true
1252        });
1253        assert!(!validator.is_valid(&invalid_response2));
1254    }
1255
1256    /// Test null type validation
1257    #[test]
1258    fn test_null_value_validation() {
1259        let handler: EchoHandler = EchoHandler;
1260        let message_schema: serde_json::Value = serde_json::json!({
1261            "type": "object",
1262            "properties": {
1263                "optional_field": {"type": ["string", "null"]},
1264                "required_field": {"type": "string"}
1265            },
1266            "required": ["required_field"]
1267        });
1268
1269        let state: WebSocketState<EchoHandler> =
1270            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1271
1272        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1273
1274        let msg1: Value = serde_json::json!({
1275            "optional_field": null,
1276            "required_field": "value"
1277        });
1278        assert!(validator.is_valid(&msg1));
1279
1280        let msg2: Value = serde_json::json!({"required_field": "value"});
1281        assert!(validator.is_valid(&msg2));
1282
1283        let invalid_msg: Value = serde_json::json!({"required_field": null});
1284        assert!(!validator.is_valid(&invalid_msg));
1285    }
1286
1287    /// Test schema with default values (they don't change validation)
1288    #[test]
1289    fn test_schema_with_defaults_still_validates() {
1290        let handler: EchoHandler = EchoHandler;
1291        let message_schema: serde_json::Value = serde_json::json!({
1292            "type": "object",
1293            "properties": {
1294                "status": {
1295                    "type": "string",
1296                    "default": "pending"
1297                }
1298            }
1299        });
1300
1301        let state: WebSocketState<EchoHandler> =
1302            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1303
1304        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1305
1306        let msg: Value = serde_json::json!({});
1307        assert!(validator.is_valid(&msg));
1308    }
1309
1310    /// Test both message and response schema validation together
1311    #[test]
1312    fn test_both_schemas_validate_independently() {
1313        let handler: EchoHandler = EchoHandler;
1314        let message_schema: serde_json::Value = serde_json::json!({
1315            "type": "object",
1316            "properties": {"action": {"type": "string"}},
1317            "required": ["action"]
1318        });
1319        let response_schema: serde_json::Value = serde_json::json!({
1320            "type": "object",
1321            "properties": {"result": {"type": "string"}},
1322            "required": ["result"]
1323        });
1324
1325        let state: WebSocketState<EchoHandler> =
1326            WebSocketState::with_schemas(handler, Some(message_schema), Some(response_schema)).unwrap();
1327
1328        let msg_validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1329        let resp_validator: &jsonschema::Validator = state.response_schema.as_ref().unwrap();
1330
1331        let valid_msg: Value = serde_json::json!({"action": "test"});
1332        let invalid_response: Value = serde_json::json!({"data": "oops"});
1333
1334        assert!(msg_validator.is_valid(&valid_msg));
1335        assert!(!resp_validator.is_valid(&invalid_response));
1336
1337        let invalid_msg: Value = serde_json::json!({"data": "oops"});
1338        let valid_response: Value = serde_json::json!({"result": "ok"});
1339
1340        assert!(!msg_validator.is_valid(&invalid_msg));
1341        assert!(resp_validator.is_valid(&valid_response));
1342    }
1343
1344    /// Test validation with very long/large payload
1345    #[test]
1346    fn test_validation_with_large_payload() {
1347        let handler: EchoHandler = EchoHandler;
1348        let message_schema: serde_json::Value = serde_json::json!({
1349            "type": "object",
1350            "properties": {
1351                "items": {
1352                    "type": "array",
1353                    "items": {"type": "integer"}
1354                }
1355            },
1356            "required": ["items"]
1357        });
1358
1359        let state: WebSocketState<EchoHandler> =
1360            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1361
1362        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1363
1364        let mut items = Vec::new();
1365        for i in 0..10_000 {
1366            items.push(i);
1367        }
1368        let large_msg: Value = serde_json::json!({"items": items});
1369
1370        assert!(validator.is_valid(&large_msg));
1371    }
1372
1373    /// Test validation error doesn't panic with invalid schema combinations
1374    #[test]
1375    fn test_mutually_exclusive_schema_properties() {
1376        let handler: EchoHandler = EchoHandler;
1377
1378        let message_schema: serde_json::Value = serde_json::json!({
1379            "allOf": [
1380                {
1381                    "type": "object",
1382                    "properties": {"a": {"type": "string"}},
1383                    "required": ["a"]
1384                },
1385                {
1386                    "type": "object",
1387                    "properties": {"b": {"type": "integer"}},
1388                    "required": ["b"]
1389                }
1390            ]
1391        });
1392
1393        let state: WebSocketState<EchoHandler> =
1394            WebSocketState::with_schemas(handler, Some(message_schema), None).unwrap();
1395
1396        let validator: &jsonschema::Validator = state.message_schema.as_ref().unwrap();
1397
1398        let valid_msg: Value = serde_json::json!({"a": "text", "b": 42});
1399        assert!(validator.is_valid(&valid_msg));
1400
1401        let invalid_msg: Value = serde_json::json!({"a": "text"});
1402        assert!(!validator.is_valid(&invalid_msg));
1403    }
1404}