Skip to main content

mockforge_contracts/contract_drift/
websocket_contract.rs

1//! WebSocket contract implementation for protocol-agnostic contract drift detection
2//!
3//! This module provides a `WebSocketContract` struct that implements the `ProtocolContract` trait
4//! for WebSocket connections, enabling drift detection and analysis for WebSocket message schemas
5//! and topics.
6
7use crate::contract_drift::protocol_contracts::{
8    ContractError, ContractOperation, ContractRequest, OperationType, ProtocolContract,
9    ValidationError, ValidationResult,
10};
11use jsonschema::{self, Draft, Validator as JSONSchema};
12use mockforge_foundation::contract_diff_types::{
13    ContractDiffResult, Mismatch, MismatchSeverity, MismatchType,
14};
15use mockforge_foundation::protocol::Protocol;
16use serde_json::Value;
17use std::collections::HashMap;
18
19// WebSocketMessageType + MessageDirection re-exported from foundation.
20pub use mockforge_foundation::protocol_contract_types::{MessageDirection, WebSocketMessageType};
21
22/// WebSocket contract implementation
23///
24/// Defines message types and topics for a WebSocket connection, enabling
25/// schema validation and drift detection.
26pub struct WebSocketContract {
27    /// Unique identifier for this contract
28    contract_id: String,
29    /// Contract version
30    version: String,
31    /// Map of message type identifiers to message type definitions
32    message_types: HashMap<String, WebSocketMessageType>,
33    /// Map of topics to message types that can be sent on that topic
34    topics: HashMap<String, Vec<String>>,
35    /// Compiled JSON schemas for validation (cached)
36    schema_cache: HashMap<String, JSONSchema>,
37    /// Cached contract operations for quick lookup
38    operations_cache: HashMap<String, ContractOperation>,
39    /// Contract metadata
40    metadata: HashMap<String, String>,
41}
42
43impl WebSocketContract {
44    /// Create a new WebSocket contract
45    pub fn new(contract_id: String, version: String) -> Self {
46        Self {
47            contract_id,
48            version,
49            message_types: HashMap::new(),
50            topics: HashMap::new(),
51            schema_cache: HashMap::new(),
52            operations_cache: HashMap::new(),
53            metadata: HashMap::new(),
54        }
55    }
56
57    /// Add a message type to the contract
58    pub fn add_message_type(
59        &mut self,
60        message_type: WebSocketMessageType,
61    ) -> Result<(), ContractError> {
62        let message_type_id = message_type.message_type.clone();
63
64        // Compile and cache the JSON schema for validation
65        let schema = JSONSchema::options()
66            .with_draft(Draft::Draft7)
67            .build(&message_type.schema)
68            .map_err(|e| ContractError::SchemaValidation(format!("Invalid JSON schema: {}", e)))?;
69        self.schema_cache.insert(message_type_id.clone(), schema);
70
71        // Add to message types
72        self.message_types.insert(message_type_id.clone(), message_type.clone());
73
74        // Build operation ID (topic:message_type or just message_type)
75        let operation_id = if let Some(ref topic) = message_type.topic {
76            format!("{}:{}", topic, message_type_id)
77        } else {
78            message_type_id.clone()
79        };
80
81        // Cache the contract operation
82        let operation = ContractOperation {
83            id: operation_id.clone(),
84            name: message_type.message_type.clone(),
85            operation_type: OperationType::WebSocketMessage {
86                message_type: message_type.message_type.clone(),
87                topic: message_type.topic.clone(),
88            },
89            input_schema: Some(message_type.schema.clone()),
90            output_schema: Some(message_type.schema.clone()), // WebSocket messages can be bidirectional
91            metadata: {
92                let mut meta = HashMap::new();
93                meta.insert("direction".to_string(), format!("{:?}", message_type.direction));
94                if let Some(ref desc) = message_type.description {
95                    meta.insert("description".to_string(), desc.clone());
96                }
97                meta
98            },
99        };
100        self.operations_cache.insert(operation_id, operation);
101
102        // Index by topic if topic is specified
103        if let Some(topic) = &message_type.topic {
104            self.topics.entry(topic.clone()).or_default().push(message_type_id);
105        }
106
107        Ok(())
108    }
109
110    /// Remove a message type from the contract
111    pub fn remove_message_type(&mut self, message_type_id: &str) {
112        if let Some(message_type) = self.message_types.remove(message_type_id) {
113            self.schema_cache.remove(message_type_id);
114
115            // Remove from topic index
116            if let Some(topic) = &message_type.topic {
117                if let Some(message_types) = self.topics.get_mut(topic) {
118                    message_types.retain(|id| id != message_type_id);
119                    if message_types.is_empty() {
120                        self.topics.remove(topic);
121                    }
122                }
123            }
124
125            // Store topic before moving message_type
126            let topic = message_type.topic.clone();
127
128            // Remove from operations cache
129            let operation_id = if let Some(ref topic_name) = topic {
130                format!("{}:{}", topic_name, message_type_id)
131            } else {
132                message_type_id.to_string()
133            };
134            self.operations_cache.remove(&operation_id);
135        }
136    }
137
138    /// Get message types for a specific topic
139    pub fn get_message_types_for_topic(&self, topic: &str) -> Vec<&WebSocketMessageType> {
140        self.topics
141            .get(topic)
142            .map(|ids| ids.iter().filter_map(|id| self.message_types.get(id)).collect())
143            .unwrap_or_default()
144    }
145
146    /// Compare two WebSocket contracts and detect differences
147    fn diff_contracts(
148        &self,
149        other: &WebSocketContract,
150    ) -> Result<ContractDiffResult, ContractError> {
151        let mut mismatches = Vec::new();
152
153        // Collect all message type IDs
154        let all_message_types: std::collections::HashSet<String> =
155            self.message_types.keys().chain(other.message_types.keys()).cloned().collect();
156
157        // Check for removed message types (breaking change)
158        for message_type_id in &all_message_types {
159            if self.message_types.contains_key(message_type_id)
160                && !other.message_types.contains_key(message_type_id)
161            {
162                let mut context = HashMap::new();
163                context.insert("is_additive".to_string(), serde_json::json!(false));
164                context.insert("is_breaking".to_string(), serde_json::json!(true));
165                context.insert(
166                    "change_category".to_string(),
167                    serde_json::json!("message_type_removed"),
168                );
169                context.insert("message_type".to_string(), serde_json::json!(message_type_id));
170
171                mismatches.push(Mismatch {
172                    mismatch_type: MismatchType::EndpointNotFound,
173                    path: message_type_id.clone(),
174                    method: None,
175                    expected: Some(format!("Message type {} should exist", message_type_id)),
176                    actual: Some("Message type removed".to_string()),
177                    description: format!("Message type {} was removed", message_type_id),
178                    severity: MismatchSeverity::Critical,
179                    confidence: 1.0,
180                    context,
181                });
182            }
183        }
184
185        // Check for added message types (non-breaking, additive)
186        for message_type_id in &all_message_types {
187            if !self.message_types.contains_key(message_type_id)
188                && other.message_types.contains_key(message_type_id)
189            {
190                let mut context = HashMap::new();
191                context.insert("is_additive".to_string(), serde_json::json!(true));
192                context.insert("is_breaking".to_string(), serde_json::json!(false));
193                context
194                    .insert("change_category".to_string(), serde_json::json!("message_type_added"));
195                context.insert("message_type".to_string(), serde_json::json!(message_type_id));
196
197                mismatches.push(Mismatch {
198                    mismatch_type: MismatchType::UnexpectedField,
199                    path: message_type_id.clone(),
200                    method: None,
201                    expected: None,
202                    actual: Some(format!("New message type {}", message_type_id)),
203                    description: format!("New message type {} was added", message_type_id),
204                    severity: MismatchSeverity::Low,
205                    confidence: 1.0,
206                    context,
207                });
208            }
209        }
210
211        // Compare message type schemas for types that exist in both
212        for message_type_id in all_message_types.intersection(
213            &self.message_types.keys().cloned().collect::<std::collections::HashSet<_>>(),
214        ) {
215            if let (Some(old_type), Some(new_type)) = (
216                self.message_types.get(message_type_id),
217                other.message_types.get(message_type_id),
218            ) {
219                let schema_mismatches =
220                    Self::diff_message_type_schemas(message_type_id, old_type, new_type)?;
221                mismatches.extend(schema_mismatches);
222
223                // Check for topic changes (breaking change)
224                if old_type.topic != new_type.topic {
225                    let mut context = HashMap::new();
226                    context.insert("is_additive".to_string(), serde_json::json!(false));
227                    context.insert("is_breaking".to_string(), serde_json::json!(true));
228                    context
229                        .insert("change_category".to_string(), serde_json::json!("topic_changed"));
230                    context.insert("message_type".to_string(), serde_json::json!(message_type_id));
231                    context.insert("old_topic".to_string(), serde_json::json!(old_type.topic));
232                    context.insert("new_topic".to_string(), serde_json::json!(new_type.topic));
233
234                    mismatches.push(Mismatch {
235                        mismatch_type: MismatchType::SchemaMismatch,
236                        path: format!("{}.topic", message_type_id),
237                        method: None,
238                        expected: old_type.topic.clone().map(|t| format!("Topic: {}", t)),
239                        actual: new_type.topic.clone().map(|t| format!("Topic: {}", t)),
240                        description: format!(
241                            "Topic changed for message type {}: {:?} -> {:?}",
242                            message_type_id, old_type.topic, new_type.topic
243                        ),
244                        severity: MismatchSeverity::High,
245                        confidence: 1.0,
246                        context,
247                    });
248                }
249
250                // Check for direction changes (breaking change)
251                if old_type.direction != new_type.direction {
252                    let mut context = HashMap::new();
253                    context.insert("is_additive".to_string(), serde_json::json!(false));
254                    context.insert("is_breaking".to_string(), serde_json::json!(true));
255                    context.insert(
256                        "change_category".to_string(),
257                        serde_json::json!("direction_changed"),
258                    );
259                    context.insert("message_type".to_string(), serde_json::json!(message_type_id));
260                    context.insert(
261                        "old_direction".to_string(),
262                        serde_json::json!(format!("{:?}", old_type.direction)),
263                    );
264                    context.insert(
265                        "new_direction".to_string(),
266                        serde_json::json!(format!("{:?}", new_type.direction)),
267                    );
268
269                    mismatches.push(Mismatch {
270                        mismatch_type: MismatchType::SchemaMismatch,
271                        path: format!("{}.direction", message_type_id),
272                        method: None,
273                        expected: Some(format!("Direction: {:?}", old_type.direction)),
274                        actual: Some(format!("Direction: {:?}", new_type.direction)),
275                        description: format!(
276                            "Direction changed for message type {}: {:?} -> {:?}",
277                            message_type_id, old_type.direction, new_type.direction
278                        ),
279                        severity: MismatchSeverity::High,
280                        confidence: 1.0,
281                        context,
282                    });
283                }
284            }
285        }
286
287        // Compare topics
288        let all_topics: std::collections::HashSet<String> =
289            self.topics.keys().chain(other.topics.keys()).cloned().collect();
290
291        for topic in &all_topics {
292            let old_message_types = self.get_message_types_for_topic(topic);
293            let new_message_types = other.get_message_types_for_topic(topic);
294
295            let old_ids: std::collections::HashSet<String> =
296                old_message_types.iter().map(|mt| mt.message_type.clone()).collect();
297            let new_ids: std::collections::HashSet<String> =
298                new_message_types.iter().map(|mt| mt.message_type.clone()).collect();
299
300            // Check for removed message types from topic
301            for removed_id in old_ids.difference(&new_ids) {
302                mismatches.push(Mismatch {
303                    mismatch_type: MismatchType::SchemaMismatch,
304                    path: format!("topic:{}.{}", topic, removed_id),
305                    method: None,
306                    expected: Some(format!(
307                        "Message type {} should be available on topic {}",
308                        removed_id, topic
309                    )),
310                    actual: Some("Message type removed from topic".to_string()),
311                    description: format!(
312                        "Message type {} was removed from topic {}",
313                        removed_id, topic
314                    ),
315                    severity: MismatchSeverity::High,
316                    confidence: 1.0,
317                    context: HashMap::new(),
318                });
319            }
320        }
321
322        let matches = mismatches.is_empty();
323        let confidence = if matches { 1.0 } else { 0.8 };
324
325        Ok(ContractDiffResult {
326            matches,
327            confidence,
328            mismatches,
329            recommendations: Vec::new(),
330            corrections: Vec::new(),
331            metadata: mockforge_foundation::contract_diff_types::DiffMetadata {
332                analyzed_at: chrono::Utc::now(),
333                request_source: "websocket_contract_diff".to_string(),
334                contract_version: Some(self.version.clone()),
335                contract_format: "websocket_schema".to_string(),
336                endpoint_path: "".to_string(),
337                http_method: "".to_string(),
338                request_count: 1,
339                llm_provider: None,
340                llm_model: None,
341            },
342        })
343    }
344
345    /// Compare message type schemas
346    fn diff_message_type_schemas(
347        message_type_id: &str,
348        old_type: &WebSocketMessageType,
349        new_type: &WebSocketMessageType,
350    ) -> Result<Vec<Mismatch>, ContractError> {
351        let mut mismatches = Vec::new();
352
353        // Detect schema format
354        let old_format = Self::detect_schema_format(&old_type.schema);
355        let new_format = Self::detect_schema_format(&new_type.schema);
356
357        // Check for schema format changes (breaking change)
358        if old_format != new_format {
359            let mut context = HashMap::new();
360            context.insert("is_additive".to_string(), serde_json::json!(false));
361            context.insert("is_breaking".to_string(), serde_json::json!(true));
362            context
363                .insert("change_category".to_string(), serde_json::json!("schema_format_changed"));
364            context.insert("message_type".to_string(), serde_json::json!(message_type_id));
365            context.insert("old_format".to_string(), serde_json::json!(old_format));
366            context.insert("new_format".to_string(), serde_json::json!(new_format));
367
368            mismatches.push(Mismatch {
369                mismatch_type: MismatchType::SchemaMismatch,
370                path: format!("{}.schema_format", message_type_id),
371                method: None,
372                expected: Some(format!("Schema format: {}", old_format)),
373                actual: Some(format!("Schema format: {}", new_format)),
374                description: format!(
375                    "Schema format changed from {} to {} for message type {}",
376                    old_format, new_format, message_type_id
377                ),
378                severity: MismatchSeverity::High,
379                confidence: 1.0,
380                context,
381            });
382        }
383
384        // Compare schemas based on format
385        if old_type.schema != new_type.schema {
386            match (old_format.as_str(), new_format.as_str()) {
387                ("json_schema", "json_schema") => {
388                    let schema_diff = Self::compare_json_schemas(
389                        &old_type.schema,
390                        &new_type.schema,
391                        message_type_id,
392                    );
393                    mismatches.extend(schema_diff);
394                }
395                ("avro", "avro") => {
396                    let schema_diff = Self::compare_avro_schemas(
397                        &old_type.schema,
398                        &new_type.schema,
399                        message_type_id,
400                    )?;
401                    mismatches.extend(schema_diff);
402                }
403                ("json_shape", "json_shape") => {
404                    let schema_diff = Self::compare_json_shape_schemas(
405                        &old_type.schema,
406                        &new_type.schema,
407                        message_type_id,
408                    );
409                    mismatches.extend(schema_diff);
410                }
411                _ => {
412                    // Different formats - already handled above
413                }
414            }
415        }
416
417        Ok(mismatches)
418    }
419
420    /// Detect the schema format (JSON Schema, Avro, or JSON-shape)
421    fn detect_schema_format(schema: &Value) -> String {
422        // Check for Avro schema indicators
423        if schema.get("type").and_then(|v| v.as_str()) == Some("record")
424            || schema.get("fields").is_some()
425        {
426            return "avro".to_string();
427        }
428
429        // Check for JSON Schema indicators
430        if schema.get("$schema").is_some()
431            || (schema.get("type").is_some() && schema.get("properties").is_some())
432            || schema.get("required").is_some()
433        {
434            return "json_schema".to_string();
435        }
436
437        // Check for JSON-shape (simple object with type strings)
438        if let Some(obj) = schema.as_object() {
439            let all_strings = obj.values().all(|v| {
440                v.as_str().is_some()
441                    || (v.is_object() && v.get("type").and_then(|t| t.as_str()).is_some())
442            });
443            if all_strings && !obj.is_empty() {
444                return "json_shape".to_string();
445            }
446        }
447
448        // Default to JSON Schema if unclear
449        "json_schema".to_string()
450    }
451
452    /// Compare Avro schemas and identify differences
453    fn compare_avro_schemas(
454        old_schema: &Value,
455        new_schema: &Value,
456        path_prefix: &str,
457    ) -> Result<Vec<Mismatch>, ContractError> {
458        let mut mismatches = Vec::new();
459
460        // Extract fields from Avro schema
461        let old_fields = old_schema.get("fields").and_then(|v| v.as_array()).ok_or_else(|| {
462            ContractError::SchemaValidation("Invalid Avro schema: missing fields".to_string())
463        })?;
464        let new_fields = new_schema.get("fields").and_then(|v| v.as_array()).ok_or_else(|| {
465            ContractError::SchemaValidation("Invalid Avro schema: missing fields".to_string())
466        })?;
467
468        // Build field maps by name
469        let old_fields_map: HashMap<String, &Value> = old_fields
470            .iter()
471            .filter_map(|f| {
472                f.get("name").and_then(|n| n.as_str()).map(|name| (name.to_string(), f))
473            })
474            .collect();
475        let new_fields_map: HashMap<String, &Value> = new_fields
476            .iter()
477            .filter_map(|f| {
478                f.get("name").and_then(|n| n.as_str()).map(|name| (name.to_string(), f))
479            })
480            .collect();
481
482        // Check for removed fields (breaking change)
483        for field_name in old_fields_map.keys() {
484            if !new_fields_map.contains_key(field_name) {
485                let mut context = HashMap::new();
486                context.insert("is_additive".to_string(), serde_json::json!(false));
487                context.insert("is_breaking".to_string(), serde_json::json!(true));
488                context.insert("change_category".to_string(), serde_json::json!("field_removed"));
489                context.insert("field_name".to_string(), serde_json::json!(field_name));
490                context.insert("schema_format".to_string(), serde_json::json!("avro"));
491
492                mismatches.push(Mismatch {
493                    mismatch_type: MismatchType::EndpointNotFound,
494                    path: format!("{}.{}", path_prefix, field_name),
495                    method: None,
496                    expected: Some(format!("Field {} should exist", field_name)),
497                    actual: Some("Field removed".to_string()),
498                    description: format!("Avro field {} was removed", field_name),
499                    severity: MismatchSeverity::High,
500                    confidence: 1.0,
501                    context,
502                });
503            }
504        }
505
506        // Check for added fields
507        for (field_name, new_field) in &new_fields_map {
508            if !old_fields_map.contains_key(field_name) {
509                // In Avro, fields without defaults are required
510                let has_default = new_field.get("default").is_some();
511                let is_required = !has_default;
512
513                let mut context = HashMap::new();
514                context.insert("is_additive".to_string(), serde_json::json!(!is_required));
515                context.insert("is_breaking".to_string(), serde_json::json!(is_required));
516                context.insert(
517                    "change_category".to_string(),
518                    serde_json::json!(if is_required {
519                        "required_field_added"
520                    } else {
521                        "field_added"
522                    }),
523                );
524                context.insert("field_name".to_string(), serde_json::json!(field_name));
525                context.insert("schema_format".to_string(), serde_json::json!("avro"));
526                context.insert("has_default".to_string(), serde_json::json!(has_default));
527
528                mismatches.push(Mismatch {
529                    mismatch_type: if is_required {
530                        MismatchType::MissingRequiredField
531                    } else {
532                        MismatchType::UnexpectedField
533                    },
534                    path: format!("{}.{}", path_prefix, field_name),
535                    method: None,
536                    expected: None,
537                    actual: Some(format!(
538                        "New Avro field {} ({})",
539                        field_name,
540                        if is_required { "required" } else { "optional" }
541                    )),
542                    description: format!(
543                        "New Avro field {} was added ({})",
544                        field_name,
545                        if is_required {
546                            "required - breaking"
547                        } else {
548                            "optional - additive"
549                        }
550                    ),
551                    severity: if is_required {
552                        MismatchSeverity::High
553                    } else {
554                        MismatchSeverity::Low
555                    },
556                    confidence: 1.0,
557                    context,
558                });
559            } else {
560                // Check for type changes
561                let old_field = old_fields_map[field_name];
562                let old_type = old_field.get("type");
563                let new_type = new_field.get("type");
564
565                if old_type != new_type {
566                    let mut context = HashMap::new();
567                    context.insert("is_additive".to_string(), serde_json::json!(false));
568                    context.insert("is_breaking".to_string(), serde_json::json!(true));
569                    context.insert(
570                        "change_category".to_string(),
571                        serde_json::json!("field_type_changed"),
572                    );
573                    context.insert("field_name".to_string(), serde_json::json!(field_name));
574                    context.insert("schema_format".to_string(), serde_json::json!("avro"));
575                    context.insert("old_type".to_string(), serde_json::json!(old_type));
576                    context.insert("new_type".to_string(), serde_json::json!(new_type));
577
578                    mismatches.push(Mismatch {
579                        mismatch_type: MismatchType::TypeMismatch,
580                        path: format!("{}.{}", path_prefix, field_name),
581                        method: None,
582                        expected: Some(format!("Type: {:?}", old_type)),
583                        actual: Some(format!("Type: {:?}", new_type)),
584                        description: format!("Avro field {} type changed", field_name),
585                        severity: MismatchSeverity::High,
586                        confidence: 1.0,
587                        context,
588                    });
589                }
590            }
591        }
592
593        Ok(mismatches)
594    }
595
596    /// Compare JSON-shape schemas (simplified format)
597    fn compare_json_shape_schemas(
598        old_schema: &Value,
599        new_schema: &Value,
600        path_prefix: &str,
601    ) -> Vec<Mismatch> {
602        let mut mismatches = Vec::new();
603
604        if let (Some(old_obj), Some(new_obj)) = (old_schema.as_object(), new_schema.as_object()) {
605            // Check for removed properties (breaking)
606            for (prop_name, _) in old_obj {
607                if !new_obj.contains_key(prop_name) {
608                    let mut context = HashMap::new();
609                    context.insert("is_additive".to_string(), serde_json::json!(false));
610                    context.insert("is_breaking".to_string(), serde_json::json!(true));
611                    context.insert(
612                        "change_category".to_string(),
613                        serde_json::json!("property_removed"),
614                    );
615                    context.insert("field_name".to_string(), serde_json::json!(prop_name));
616                    context.insert("schema_format".to_string(), serde_json::json!("json_shape"));
617
618                    mismatches.push(Mismatch {
619                        mismatch_type: MismatchType::UnexpectedField,
620                        path: format!("{}.{}", path_prefix, prop_name),
621                        method: None,
622                        expected: Some(format!("Property {} should exist", prop_name)),
623                        actual: Some("Property removed".to_string()),
624                        description: format!("Property {} was removed", prop_name),
625                        severity: MismatchSeverity::High,
626                        confidence: 1.0,
627                        context,
628                    });
629                }
630            }
631
632            // Check for added properties (additive)
633            for (prop_name, _) in new_obj {
634                if !old_obj.contains_key(prop_name) {
635                    let mut context = HashMap::new();
636                    context.insert("is_additive".to_string(), serde_json::json!(true));
637                    context.insert("is_breaking".to_string(), serde_json::json!(false));
638                    context
639                        .insert("change_category".to_string(), serde_json::json!("property_added"));
640                    context.insert("field_name".to_string(), serde_json::json!(prop_name));
641                    context.insert("schema_format".to_string(), serde_json::json!("json_shape"));
642
643                    mismatches.push(Mismatch {
644                        mismatch_type: MismatchType::UnexpectedField,
645                        path: format!("{}.{}", path_prefix, prop_name),
646                        method: None,
647                        expected: None,
648                        actual: Some(format!("New property {}", prop_name)),
649                        description: format!("New property {} was added", prop_name),
650                        severity: MismatchSeverity::Low,
651                        confidence: 1.0,
652                        context,
653                    });
654                } else {
655                    // Check for type changes
656                    let old_type = old_obj[prop_name]
657                        .as_str()
658                        .or_else(|| old_obj[prop_name].get("type").and_then(|t| t.as_str()));
659                    let new_type = new_obj[prop_name]
660                        .as_str()
661                        .or_else(|| new_obj[prop_name].get("type").and_then(|t| t.as_str()));
662
663                    if old_type != new_type {
664                        let mut context = HashMap::new();
665                        context.insert("is_additive".to_string(), serde_json::json!(false));
666                        context.insert("is_breaking".to_string(), serde_json::json!(true));
667                        context.insert(
668                            "change_category".to_string(),
669                            serde_json::json!("property_type_changed"),
670                        );
671                        context.insert("field_name".to_string(), serde_json::json!(prop_name));
672                        context
673                            .insert("schema_format".to_string(), serde_json::json!("json_shape"));
674                        context.insert("old_type".to_string(), serde_json::json!(old_type));
675                        context.insert("new_type".to_string(), serde_json::json!(new_type));
676
677                        mismatches.push(Mismatch {
678                            mismatch_type: MismatchType::TypeMismatch,
679                            path: format!("{}.{}", path_prefix, prop_name),
680                            method: None,
681                            expected: old_type.map(|t| format!("Type: {}", t)),
682                            actual: new_type.map(|t| format!("Type: {}", t)),
683                            description: format!("Property {} type changed", prop_name),
684                            severity: MismatchSeverity::High,
685                            confidence: 1.0,
686                            context,
687                        });
688                    }
689                }
690            }
691        }
692
693        mismatches
694    }
695
696    /// Compare two JSON schemas and identify differences
697    fn compare_json_schemas(
698        old_schema: &Value,
699        new_schema: &Value,
700        path_prefix: &str,
701    ) -> Vec<Mismatch> {
702        let mut mismatches = Vec::new();
703
704        // Check for required fields changes
705        if let (Some(old_required), Some(new_required)) = (
706            old_schema.get("required").and_then(|v| v.as_array()),
707            new_schema.get("required").and_then(|v| v.as_array()),
708        ) {
709            let old_required_set: std::collections::HashSet<&str> =
710                old_required.iter().filter_map(|v| v.as_str()).collect();
711            let new_required_set: std::collections::HashSet<&str> =
712                new_required.iter().filter_map(|v| v.as_str()).collect();
713
714            // Check for newly required fields (breaking change)
715            for new_req in new_required_set.difference(&old_required_set) {
716                let mut context = HashMap::new();
717                context.insert("is_additive".to_string(), serde_json::json!(false));
718                context.insert("is_breaking".to_string(), serde_json::json!(true));
719                context.insert(
720                    "change_category".to_string(),
721                    serde_json::json!("required_field_added"),
722                );
723                context.insert("field_name".to_string(), serde_json::json!(new_req));
724
725                mismatches.push(Mismatch {
726                    mismatch_type: MismatchType::MissingRequiredField,
727                    path: format!("{}.{}", path_prefix, new_req),
728                    method: None,
729                    expected: Some(format!("Field {} should be optional", new_req)),
730                    actual: Some(format!("Field {} is now required", new_req)),
731                    description: format!("Field {} became required", new_req),
732                    severity: MismatchSeverity::Critical,
733                    confidence: 1.0,
734                    context,
735                });
736            }
737
738            // Check for removed required fields (additive - field is now optional)
739            for removed_req in old_required_set.difference(&new_required_set) {
740                let mut context = HashMap::new();
741                context.insert("is_additive".to_string(), serde_json::json!(true));
742                context.insert("is_breaking".to_string(), serde_json::json!(false));
743                context.insert(
744                    "change_category".to_string(),
745                    serde_json::json!("required_field_removed"),
746                );
747                context.insert("field_name".to_string(), serde_json::json!(removed_req));
748
749                mismatches.push(Mismatch {
750                    mismatch_type: MismatchType::UnexpectedField,
751                    path: format!("{}.{}", path_prefix, removed_req),
752                    method: None,
753                    expected: Some(format!("Field {} was required", removed_req)),
754                    actual: Some(format!("Field {} is now optional", removed_req)),
755                    description: format!("Field {} is no longer required", removed_req),
756                    severity: MismatchSeverity::Low,
757                    confidence: 1.0,
758                    context,
759                });
760            }
761        }
762
763        // Check for property type changes
764        if let (Some(old_props), Some(new_props)) = (
765            old_schema.get("properties").and_then(|v| v.as_object()),
766            new_schema.get("properties").and_then(|v| v.as_object()),
767        ) {
768            for (prop_name, new_prop_schema) in new_props {
769                if let Some(old_prop_schema) = old_props.get(prop_name) {
770                    if let (Some(old_type), Some(new_type)) = (
771                        old_prop_schema.get("type").and_then(|v| v.as_str()),
772                        new_prop_schema.get("type").and_then(|v| v.as_str()),
773                    ) {
774                        if old_type != new_type {
775                            let mut context = HashMap::new();
776                            context.insert("is_additive".to_string(), serde_json::json!(false));
777                            context.insert("is_breaking".to_string(), serde_json::json!(true));
778                            context.insert(
779                                "change_category".to_string(),
780                                serde_json::json!("property_type_changed"),
781                            );
782                            context.insert("field_name".to_string(), serde_json::json!(prop_name));
783                            context.insert("old_type".to_string(), serde_json::json!(old_type));
784                            context.insert("new_type".to_string(), serde_json::json!(new_type));
785
786                            mismatches.push(Mismatch {
787                                mismatch_type: MismatchType::TypeMismatch,
788                                path: format!("{}.{}", path_prefix, prop_name),
789                                method: None,
790                                expected: Some(format!("Type: {}", old_type)),
791                                actual: Some(format!("Type: {}", new_type)),
792                                description: format!(
793                                    "Property {} type changed from {} to {}",
794                                    prop_name, old_type, new_type
795                                ),
796                                severity: MismatchSeverity::High,
797                                confidence: 1.0,
798                                context,
799                            });
800                        }
801                    }
802                } else {
803                    // New property added (additive change)
804                    let mut context = HashMap::new();
805                    context.insert("is_additive".to_string(), serde_json::json!(true));
806                    context.insert("is_breaking".to_string(), serde_json::json!(false));
807                    context
808                        .insert("change_category".to_string(), serde_json::json!("property_added"));
809                    context.insert("field_name".to_string(), serde_json::json!(prop_name));
810
811                    mismatches.push(Mismatch {
812                        mismatch_type: MismatchType::UnexpectedField,
813                        path: format!("{}.{}", path_prefix, prop_name),
814                        method: None,
815                        expected: None,
816                        actual: Some(format!("New property {}", prop_name)),
817                        description: format!("New property {} was added", prop_name),
818                        severity: MismatchSeverity::Low,
819                        confidence: 1.0,
820                        context,
821                    });
822                }
823            }
824
825            // Check for removed properties (breaking change)
826            for prop_name in old_props.keys() {
827                if !new_props.contains_key(prop_name) {
828                    let mut context = HashMap::new();
829                    context.insert("is_additive".to_string(), serde_json::json!(false));
830                    context.insert("is_breaking".to_string(), serde_json::json!(true));
831                    context.insert(
832                        "change_category".to_string(),
833                        serde_json::json!("property_removed"),
834                    );
835                    context.insert("field_name".to_string(), serde_json::json!(prop_name));
836
837                    mismatches.push(Mismatch {
838                        mismatch_type: MismatchType::UnexpectedField,
839                        path: format!("{}.{}", path_prefix, prop_name),
840                        method: None,
841                        expected: Some(format!("Property {} should exist", prop_name)),
842                        actual: Some("Property removed".to_string()),
843                        description: format!("Property {} was removed", prop_name),
844                        severity: MismatchSeverity::High,
845                        confidence: 1.0,
846                        context,
847                    });
848                }
849            }
850        }
851
852        mismatches
853    }
854
855    /// Validate a message against a message type schema
856    fn validate_message_against_schema(
857        &self,
858        message_type_id: &str,
859        message: &Value,
860    ) -> Result<ValidationResult, ContractError> {
861        let schema = self
862            .schema_cache
863            .get(message_type_id)
864            .ok_or_else(|| ContractError::OperationNotFound(message_type_id.to_string()))?;
865
866        // Use iter_errors instead of validate which returns Result
867        let mut validation_errors = Vec::new();
868        for error in schema.iter_errors(message) {
869            validation_errors.push(ValidationError {
870                message: error.to_string(),
871                path: Some(error.instance_path.to_string()),
872                code: Some("SCHEMA_VALIDATION_ERROR".to_string()),
873            });
874        }
875
876        Ok(ValidationResult {
877            valid: validation_errors.is_empty(),
878            errors: validation_errors,
879            warnings: Vec::new(),
880        })
881    }
882}
883
884#[async_trait::async_trait]
885impl ProtocolContract for WebSocketContract {
886    fn protocol(&self) -> Protocol {
887        Protocol::WebSocket
888    }
889
890    fn contract_id(&self) -> &str {
891        &self.contract_id
892    }
893
894    fn version(&self) -> &str {
895        &self.version
896    }
897
898    fn operations(&self) -> Vec<ContractOperation> {
899        self.operations_cache.values().cloned().collect()
900    }
901
902    fn get_operation(&self, operation_id: &str) -> Option<&ContractOperation> {
903        // Try direct lookup first
904        if let Some(operation) = self.operations_cache.get(operation_id) {
905            return Some(operation);
906        }
907
908        // Try to find by message type only (if operation_id doesn't include topic)
909        if !operation_id.contains(':') {
910            // Search for operation with this message type
911            for operation in self.operations_cache.values() {
912                if let OperationType::WebSocketMessage { message_type, .. } =
913                    &operation.operation_type
914                {
915                    if message_type == operation_id {
916                        return Some(operation);
917                    }
918                }
919            }
920        }
921
922        None
923    }
924
925    async fn diff(
926        &self,
927        other: &dyn ProtocolContract,
928    ) -> Result<ContractDiffResult, ContractError> {
929        // Ensure the other contract is also a WebSocket contract
930        if other.protocol() != Protocol::WebSocket {
931            return Err(ContractError::UnsupportedProtocol(other.protocol()));
932        }
933
934        // Similar limitation as GrpcContract - we need type information to compare
935        Err(ContractError::Other(
936            "Direct comparison of WebSocketContract instances requires type information. \
937             Use WebSocketContract::diff_contracts() for comparing two WebSocketContract instances."
938                .to_string(),
939        ))
940    }
941
942    async fn validate(
943        &self,
944        operation_id: &str,
945        request: &ContractRequest,
946    ) -> Result<ValidationResult, ContractError> {
947        // Parse the message payload as JSON
948        let message: Value = serde_json::from_slice(&request.payload)
949            .map_err(|e| ContractError::SchemaValidation(format!("Invalid JSON: {}", e)))?;
950
951        // Extract message type from operation_id (could be "topic:message_type" or just "message_type")
952        let message_type_id = if let Some((_, message_type)) = operation_id.split_once(':') {
953            message_type
954        } else {
955            operation_id
956        };
957
958        // Validate against the schema
959        self.validate_message_against_schema(message_type_id, &message)
960    }
961
962    fn get_schema(&self, operation_id: &str) -> Option<Value> {
963        // Extract message type from operation_id
964        let message_type_id = if let Some((_, message_type)) = operation_id.split_once(':') {
965            message_type
966        } else {
967            operation_id
968        };
969
970        self.message_types.get(message_type_id).map(|mt| mt.schema.clone())
971    }
972
973    fn to_json(&self) -> Result<Value, ContractError> {
974        let message_types: Vec<Value> = self
975            .message_types
976            .values()
977            .map(|mt| {
978                serde_json::json!({
979                    "message_type": mt.message_type,
980                    "topic": mt.topic,
981                    "schema": mt.schema,
982                    "direction": mt.direction,
983                    "description": mt.description,
984                    "example": mt.example,
985                })
986            })
987            .collect();
988
989        Ok(serde_json::json!({
990            "contract_id": self.contract_id,
991            "version": self.version,
992            "protocol": "websocket",
993            "message_types": message_types,
994            "topics": self.topics.keys().collect::<Vec<_>>(),
995            "metadata": self.metadata,
996        }))
997    }
998}
999
1000/// Helper function to compare two WebSocketContract instances
1001pub fn diff_websocket_contracts(
1002    old_contract: &WebSocketContract,
1003    new_contract: &WebSocketContract,
1004) -> Result<ContractDiffResult, ContractError> {
1005    old_contract.diff_contracts(new_contract)
1006}
1007
#[cfg(test)]
mod tests {
    use super::*;

    /// A bidirectional `chat_message` on the `chat` topic requiring `text` and `user`.
    fn chat_message_type() -> WebSocketMessageType {
        WebSocketMessageType {
            message_type: "chat_message".to_string(),
            topic: Some("chat".to_string()),
            schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "text": {"type": "string"},
                    "user": {"type": "string"}
                },
                "required": ["text", "user"]
            }),
            direction: MessageDirection::Bidirectional,
            description: Some("Chat message".to_string()),
            example: None,
        }
    }

    #[test]
    fn test_websocket_contract_creation() {
        let contract = WebSocketContract::new("test-contract".to_string(), "1.0.0".to_string());
        assert_eq!(contract.contract_id(), "test-contract");
        assert_eq!(contract.version(), "1.0.0");
    }

    #[test]
    fn test_add_message_type() {
        let mut contract = WebSocketContract::new("test".to_string(), "1.0.0".to_string());
        assert!(contract.add_message_type(chat_message_type()).is_ok());
        assert_eq!(contract.message_types.len(), 1);
    }

    #[test]
    fn test_compare_json_schemas_reports_each_change_kind() {
        let old = serde_json::json!({
            "type": "object",
            "properties": {"a": {"type": "string"}, "b": {"type": "integer"}},
            "required": ["a"]
        });
        let new = serde_json::json!({
            "type": "object",
            "properties": {"a": {"type": "integer"}, "c": {"type": "string"}},
            "required": ["a", "c"]
        });

        let mismatches = WebSocketContract::compare_json_schemas(&old, &new, "payload");
        assert_eq!(mismatches.len(), 4);

        let is_breaking = |m: &Mismatch| m.context.get("is_breaking") == Some(&serde_json::json!(true));

        // `c` became required.
        assert!(mismatches.iter().any(|m| m.path == "payload.c"
            && matches!(m.mismatch_type, MismatchType::MissingRequiredField)
            && matches!(m.severity, MismatchSeverity::Critical)
            && is_breaking(m)));
        // `a` changed type from string to integer.
        assert!(mismatches.iter().any(|m| m.path == "payload.a"
            && matches!(m.mismatch_type, MismatchType::TypeMismatch)
            && is_breaking(m)));
        // `c` was added as a property (additive).
        assert!(mismatches.iter().any(|m| m.path == "payload.c"
            && matches!(m.mismatch_type, MismatchType::UnexpectedField)
            && !is_breaking(m)));
        // `b` was removed (breaking).
        assert!(mismatches.iter().any(|m| m.path == "payload.b"
            && matches!(m.mismatch_type, MismatchType::UnexpectedField)
            && is_breaking(m)));
    }

    #[test]
    fn test_compare_json_schemas_identical_schemas_yield_no_mismatches() {
        let schema = chat_message_type().schema;
        assert!(WebSocketContract::compare_json_schemas(&schema, &schema, "payload").is_empty());
    }

    #[test]
    fn test_to_json_reports_identity_and_message_types() {
        let mut contract = WebSocketContract::new("ws".to_string(), "2.0.0".to_string());
        assert!(contract.add_message_type(chat_message_type()).is_ok());

        let Ok(json) = contract.to_json() else {
            panic!("to_json should succeed");
        };
        assert_eq!(json["contract_id"], "ws");
        assert_eq!(json["version"], "2.0.0");
        assert_eq!(json["protocol"], "websocket");
        assert_eq!(json["message_types"].as_array().map(Vec::len), Some(1));
        assert_eq!(json["message_types"][0]["message_type"], "chat_message");
    }
}
1037
1038        assert!(contract.add_message_type(message_type).is_ok());
1039        assert_eq!(contract.message_types.len(), 1);
1040    }
1041}