1use crate::contract_drift::protocol_contracts::{
8 ContractError, ContractOperation, ContractRequest, OperationType, ProtocolContract,
9 ValidationError, ValidationResult,
10};
11use jsonschema::{self, Draft, Validator as JSONSchema};
12use mockforge_foundation::contract_diff_types::{
13 ContractDiffResult, Mismatch, MismatchSeverity, MismatchType,
14};
15use mockforge_foundation::protocol::Protocol;
16use serde_json::Value;
17use std::collections::HashMap;
18
19pub use mockforge_foundation::protocol_contract_types::{MessageDirection, WebSocketMessageType};
21
22pub struct WebSocketContract {
27 contract_id: String,
29 version: String,
31 message_types: HashMap<String, WebSocketMessageType>,
33 topics: HashMap<String, Vec<String>>,
35 schema_cache: HashMap<String, JSONSchema>,
37 operations_cache: HashMap<String, ContractOperation>,
39 metadata: HashMap<String, String>,
41}
42
43impl WebSocketContract {
44 pub fn new(contract_id: String, version: String) -> Self {
46 Self {
47 contract_id,
48 version,
49 message_types: HashMap::new(),
50 topics: HashMap::new(),
51 schema_cache: HashMap::new(),
52 operations_cache: HashMap::new(),
53 metadata: HashMap::new(),
54 }
55 }
56
57 pub fn add_message_type(
59 &mut self,
60 message_type: WebSocketMessageType,
61 ) -> Result<(), ContractError> {
62 let message_type_id = message_type.message_type.clone();
63
64 let schema = JSONSchema::options()
66 .with_draft(Draft::Draft7)
67 .build(&message_type.schema)
68 .map_err(|e| ContractError::SchemaValidation(format!("Invalid JSON schema: {}", e)))?;
69 self.schema_cache.insert(message_type_id.clone(), schema);
70
71 self.message_types.insert(message_type_id.clone(), message_type.clone());
73
74 let operation_id = if let Some(ref topic) = message_type.topic {
76 format!("{}:{}", topic, message_type_id)
77 } else {
78 message_type_id.clone()
79 };
80
81 let operation = ContractOperation {
83 id: operation_id.clone(),
84 name: message_type.message_type.clone(),
85 operation_type: OperationType::WebSocketMessage {
86 message_type: message_type.message_type.clone(),
87 topic: message_type.topic.clone(),
88 },
89 input_schema: Some(message_type.schema.clone()),
90 output_schema: Some(message_type.schema.clone()), metadata: {
92 let mut meta = HashMap::new();
93 meta.insert("direction".to_string(), format!("{:?}", message_type.direction));
94 if let Some(ref desc) = message_type.description {
95 meta.insert("description".to_string(), desc.clone());
96 }
97 meta
98 },
99 };
100 self.operations_cache.insert(operation_id, operation);
101
102 if let Some(topic) = &message_type.topic {
104 self.topics.entry(topic.clone()).or_default().push(message_type_id);
105 }
106
107 Ok(())
108 }
109
110 pub fn remove_message_type(&mut self, message_type_id: &str) {
112 if let Some(message_type) = self.message_types.remove(message_type_id) {
113 self.schema_cache.remove(message_type_id);
114
115 if let Some(topic) = &message_type.topic {
117 if let Some(message_types) = self.topics.get_mut(topic) {
118 message_types.retain(|id| id != message_type_id);
119 if message_types.is_empty() {
120 self.topics.remove(topic);
121 }
122 }
123 }
124
125 let topic = message_type.topic.clone();
127
128 let operation_id = if let Some(ref topic_name) = topic {
130 format!("{}:{}", topic_name, message_type_id)
131 } else {
132 message_type_id.to_string()
133 };
134 self.operations_cache.remove(&operation_id);
135 }
136 }
137
138 pub fn get_message_types_for_topic(&self, topic: &str) -> Vec<&WebSocketMessageType> {
140 self.topics
141 .get(topic)
142 .map(|ids| ids.iter().filter_map(|id| self.message_types.get(id)).collect())
143 .unwrap_or_default()
144 }
145
146 fn diff_contracts(
148 &self,
149 other: &WebSocketContract,
150 ) -> Result<ContractDiffResult, ContractError> {
151 let mut mismatches = Vec::new();
152
153 let all_message_types: std::collections::HashSet<String> =
155 self.message_types.keys().chain(other.message_types.keys()).cloned().collect();
156
157 for message_type_id in &all_message_types {
159 if self.message_types.contains_key(message_type_id)
160 && !other.message_types.contains_key(message_type_id)
161 {
162 let mut context = HashMap::new();
163 context.insert("is_additive".to_string(), serde_json::json!(false));
164 context.insert("is_breaking".to_string(), serde_json::json!(true));
165 context.insert(
166 "change_category".to_string(),
167 serde_json::json!("message_type_removed"),
168 );
169 context.insert("message_type".to_string(), serde_json::json!(message_type_id));
170
171 mismatches.push(Mismatch {
172 mismatch_type: MismatchType::EndpointNotFound,
173 path: message_type_id.clone(),
174 method: None,
175 expected: Some(format!("Message type {} should exist", message_type_id)),
176 actual: Some("Message type removed".to_string()),
177 description: format!("Message type {} was removed", message_type_id),
178 severity: MismatchSeverity::Critical,
179 confidence: 1.0,
180 context,
181 });
182 }
183 }
184
185 for message_type_id in &all_message_types {
187 if !self.message_types.contains_key(message_type_id)
188 && other.message_types.contains_key(message_type_id)
189 {
190 let mut context = HashMap::new();
191 context.insert("is_additive".to_string(), serde_json::json!(true));
192 context.insert("is_breaking".to_string(), serde_json::json!(false));
193 context
194 .insert("change_category".to_string(), serde_json::json!("message_type_added"));
195 context.insert("message_type".to_string(), serde_json::json!(message_type_id));
196
197 mismatches.push(Mismatch {
198 mismatch_type: MismatchType::UnexpectedField,
199 path: message_type_id.clone(),
200 method: None,
201 expected: None,
202 actual: Some(format!("New message type {}", message_type_id)),
203 description: format!("New message type {} was added", message_type_id),
204 severity: MismatchSeverity::Low,
205 confidence: 1.0,
206 context,
207 });
208 }
209 }
210
211 for message_type_id in all_message_types.intersection(
213 &self.message_types.keys().cloned().collect::<std::collections::HashSet<_>>(),
214 ) {
215 if let (Some(old_type), Some(new_type)) = (
216 self.message_types.get(message_type_id),
217 other.message_types.get(message_type_id),
218 ) {
219 let schema_mismatches =
220 Self::diff_message_type_schemas(message_type_id, old_type, new_type)?;
221 mismatches.extend(schema_mismatches);
222
223 if old_type.topic != new_type.topic {
225 let mut context = HashMap::new();
226 context.insert("is_additive".to_string(), serde_json::json!(false));
227 context.insert("is_breaking".to_string(), serde_json::json!(true));
228 context
229 .insert("change_category".to_string(), serde_json::json!("topic_changed"));
230 context.insert("message_type".to_string(), serde_json::json!(message_type_id));
231 context.insert("old_topic".to_string(), serde_json::json!(old_type.topic));
232 context.insert("new_topic".to_string(), serde_json::json!(new_type.topic));
233
234 mismatches.push(Mismatch {
235 mismatch_type: MismatchType::SchemaMismatch,
236 path: format!("{}.topic", message_type_id),
237 method: None,
238 expected: old_type.topic.clone().map(|t| format!("Topic: {}", t)),
239 actual: new_type.topic.clone().map(|t| format!("Topic: {}", t)),
240 description: format!(
241 "Topic changed for message type {}: {:?} -> {:?}",
242 message_type_id, old_type.topic, new_type.topic
243 ),
244 severity: MismatchSeverity::High,
245 confidence: 1.0,
246 context,
247 });
248 }
249
250 if old_type.direction != new_type.direction {
252 let mut context = HashMap::new();
253 context.insert("is_additive".to_string(), serde_json::json!(false));
254 context.insert("is_breaking".to_string(), serde_json::json!(true));
255 context.insert(
256 "change_category".to_string(),
257 serde_json::json!("direction_changed"),
258 );
259 context.insert("message_type".to_string(), serde_json::json!(message_type_id));
260 context.insert(
261 "old_direction".to_string(),
262 serde_json::json!(format!("{:?}", old_type.direction)),
263 );
264 context.insert(
265 "new_direction".to_string(),
266 serde_json::json!(format!("{:?}", new_type.direction)),
267 );
268
269 mismatches.push(Mismatch {
270 mismatch_type: MismatchType::SchemaMismatch,
271 path: format!("{}.direction", message_type_id),
272 method: None,
273 expected: Some(format!("Direction: {:?}", old_type.direction)),
274 actual: Some(format!("Direction: {:?}", new_type.direction)),
275 description: format!(
276 "Direction changed for message type {}: {:?} -> {:?}",
277 message_type_id, old_type.direction, new_type.direction
278 ),
279 severity: MismatchSeverity::High,
280 confidence: 1.0,
281 context,
282 });
283 }
284 }
285 }
286
287 let all_topics: std::collections::HashSet<String> =
289 self.topics.keys().chain(other.topics.keys()).cloned().collect();
290
291 for topic in &all_topics {
292 let old_message_types = self.get_message_types_for_topic(topic);
293 let new_message_types = other.get_message_types_for_topic(topic);
294
295 let old_ids: std::collections::HashSet<String> =
296 old_message_types.iter().map(|mt| mt.message_type.clone()).collect();
297 let new_ids: std::collections::HashSet<String> =
298 new_message_types.iter().map(|mt| mt.message_type.clone()).collect();
299
300 for removed_id in old_ids.difference(&new_ids) {
302 mismatches.push(Mismatch {
303 mismatch_type: MismatchType::SchemaMismatch,
304 path: format!("topic:{}.{}", topic, removed_id),
305 method: None,
306 expected: Some(format!(
307 "Message type {} should be available on topic {}",
308 removed_id, topic
309 )),
310 actual: Some("Message type removed from topic".to_string()),
311 description: format!(
312 "Message type {} was removed from topic {}",
313 removed_id, topic
314 ),
315 severity: MismatchSeverity::High,
316 confidence: 1.0,
317 context: HashMap::new(),
318 });
319 }
320 }
321
322 let matches = mismatches.is_empty();
323 let confidence = if matches { 1.0 } else { 0.8 };
324
325 Ok(ContractDiffResult {
326 matches,
327 confidence,
328 mismatches,
329 recommendations: Vec::new(),
330 corrections: Vec::new(),
331 metadata: mockforge_foundation::contract_diff_types::DiffMetadata {
332 analyzed_at: chrono::Utc::now(),
333 request_source: "websocket_contract_diff".to_string(),
334 contract_version: Some(self.version.clone()),
335 contract_format: "websocket_schema".to_string(),
336 endpoint_path: "".to_string(),
337 http_method: "".to_string(),
338 request_count: 1,
339 llm_provider: None,
340 llm_model: None,
341 },
342 })
343 }
344
345 fn diff_message_type_schemas(
347 message_type_id: &str,
348 old_type: &WebSocketMessageType,
349 new_type: &WebSocketMessageType,
350 ) -> Result<Vec<Mismatch>, ContractError> {
351 let mut mismatches = Vec::new();
352
353 let old_format = Self::detect_schema_format(&old_type.schema);
355 let new_format = Self::detect_schema_format(&new_type.schema);
356
357 if old_format != new_format {
359 let mut context = HashMap::new();
360 context.insert("is_additive".to_string(), serde_json::json!(false));
361 context.insert("is_breaking".to_string(), serde_json::json!(true));
362 context
363 .insert("change_category".to_string(), serde_json::json!("schema_format_changed"));
364 context.insert("message_type".to_string(), serde_json::json!(message_type_id));
365 context.insert("old_format".to_string(), serde_json::json!(old_format));
366 context.insert("new_format".to_string(), serde_json::json!(new_format));
367
368 mismatches.push(Mismatch {
369 mismatch_type: MismatchType::SchemaMismatch,
370 path: format!("{}.schema_format", message_type_id),
371 method: None,
372 expected: Some(format!("Schema format: {}", old_format)),
373 actual: Some(format!("Schema format: {}", new_format)),
374 description: format!(
375 "Schema format changed from {} to {} for message type {}",
376 old_format, new_format, message_type_id
377 ),
378 severity: MismatchSeverity::High,
379 confidence: 1.0,
380 context,
381 });
382 }
383
384 if old_type.schema != new_type.schema {
386 match (old_format.as_str(), new_format.as_str()) {
387 ("json_schema", "json_schema") => {
388 let schema_diff = Self::compare_json_schemas(
389 &old_type.schema,
390 &new_type.schema,
391 message_type_id,
392 );
393 mismatches.extend(schema_diff);
394 }
395 ("avro", "avro") => {
396 let schema_diff = Self::compare_avro_schemas(
397 &old_type.schema,
398 &new_type.schema,
399 message_type_id,
400 )?;
401 mismatches.extend(schema_diff);
402 }
403 ("json_shape", "json_shape") => {
404 let schema_diff = Self::compare_json_shape_schemas(
405 &old_type.schema,
406 &new_type.schema,
407 message_type_id,
408 );
409 mismatches.extend(schema_diff);
410 }
411 _ => {
412 }
414 }
415 }
416
417 Ok(mismatches)
418 }
419
420 fn detect_schema_format(schema: &Value) -> String {
422 if schema.get("type").and_then(|v| v.as_str()) == Some("record")
424 || schema.get("fields").is_some()
425 {
426 return "avro".to_string();
427 }
428
429 if schema.get("$schema").is_some()
431 || (schema.get("type").is_some() && schema.get("properties").is_some())
432 || schema.get("required").is_some()
433 {
434 return "json_schema".to_string();
435 }
436
437 if let Some(obj) = schema.as_object() {
439 let all_strings = obj.values().all(|v| {
440 v.as_str().is_some()
441 || (v.is_object() && v.get("type").and_then(|t| t.as_str()).is_some())
442 });
443 if all_strings && !obj.is_empty() {
444 return "json_shape".to_string();
445 }
446 }
447
448 "json_schema".to_string()
450 }
451
452 fn compare_avro_schemas(
454 old_schema: &Value,
455 new_schema: &Value,
456 path_prefix: &str,
457 ) -> Result<Vec<Mismatch>, ContractError> {
458 let mut mismatches = Vec::new();
459
460 let old_fields = old_schema.get("fields").and_then(|v| v.as_array()).ok_or_else(|| {
462 ContractError::SchemaValidation("Invalid Avro schema: missing fields".to_string())
463 })?;
464 let new_fields = new_schema.get("fields").and_then(|v| v.as_array()).ok_or_else(|| {
465 ContractError::SchemaValidation("Invalid Avro schema: missing fields".to_string())
466 })?;
467
468 let old_fields_map: HashMap<String, &Value> = old_fields
470 .iter()
471 .filter_map(|f| {
472 f.get("name").and_then(|n| n.as_str()).map(|name| (name.to_string(), f))
473 })
474 .collect();
475 let new_fields_map: HashMap<String, &Value> = new_fields
476 .iter()
477 .filter_map(|f| {
478 f.get("name").and_then(|n| n.as_str()).map(|name| (name.to_string(), f))
479 })
480 .collect();
481
482 for field_name in old_fields_map.keys() {
484 if !new_fields_map.contains_key(field_name) {
485 let mut context = HashMap::new();
486 context.insert("is_additive".to_string(), serde_json::json!(false));
487 context.insert("is_breaking".to_string(), serde_json::json!(true));
488 context.insert("change_category".to_string(), serde_json::json!("field_removed"));
489 context.insert("field_name".to_string(), serde_json::json!(field_name));
490 context.insert("schema_format".to_string(), serde_json::json!("avro"));
491
492 mismatches.push(Mismatch {
493 mismatch_type: MismatchType::EndpointNotFound,
494 path: format!("{}.{}", path_prefix, field_name),
495 method: None,
496 expected: Some(format!("Field {} should exist", field_name)),
497 actual: Some("Field removed".to_string()),
498 description: format!("Avro field {} was removed", field_name),
499 severity: MismatchSeverity::High,
500 confidence: 1.0,
501 context,
502 });
503 }
504 }
505
506 for (field_name, new_field) in &new_fields_map {
508 if !old_fields_map.contains_key(field_name) {
509 let has_default = new_field.get("default").is_some();
511 let is_required = !has_default;
512
513 let mut context = HashMap::new();
514 context.insert("is_additive".to_string(), serde_json::json!(!is_required));
515 context.insert("is_breaking".to_string(), serde_json::json!(is_required));
516 context.insert(
517 "change_category".to_string(),
518 serde_json::json!(if is_required {
519 "required_field_added"
520 } else {
521 "field_added"
522 }),
523 );
524 context.insert("field_name".to_string(), serde_json::json!(field_name));
525 context.insert("schema_format".to_string(), serde_json::json!("avro"));
526 context.insert("has_default".to_string(), serde_json::json!(has_default));
527
528 mismatches.push(Mismatch {
529 mismatch_type: if is_required {
530 MismatchType::MissingRequiredField
531 } else {
532 MismatchType::UnexpectedField
533 },
534 path: format!("{}.{}", path_prefix, field_name),
535 method: None,
536 expected: None,
537 actual: Some(format!(
538 "New Avro field {} ({})",
539 field_name,
540 if is_required { "required" } else { "optional" }
541 )),
542 description: format!(
543 "New Avro field {} was added ({})",
544 field_name,
545 if is_required {
546 "required - breaking"
547 } else {
548 "optional - additive"
549 }
550 ),
551 severity: if is_required {
552 MismatchSeverity::High
553 } else {
554 MismatchSeverity::Low
555 },
556 confidence: 1.0,
557 context,
558 });
559 } else {
560 let old_field = old_fields_map[field_name];
562 let old_type = old_field.get("type");
563 let new_type = new_field.get("type");
564
565 if old_type != new_type {
566 let mut context = HashMap::new();
567 context.insert("is_additive".to_string(), serde_json::json!(false));
568 context.insert("is_breaking".to_string(), serde_json::json!(true));
569 context.insert(
570 "change_category".to_string(),
571 serde_json::json!("field_type_changed"),
572 );
573 context.insert("field_name".to_string(), serde_json::json!(field_name));
574 context.insert("schema_format".to_string(), serde_json::json!("avro"));
575 context.insert("old_type".to_string(), serde_json::json!(old_type));
576 context.insert("new_type".to_string(), serde_json::json!(new_type));
577
578 mismatches.push(Mismatch {
579 mismatch_type: MismatchType::TypeMismatch,
580 path: format!("{}.{}", path_prefix, field_name),
581 method: None,
582 expected: Some(format!("Type: {:?}", old_type)),
583 actual: Some(format!("Type: {:?}", new_type)),
584 description: format!("Avro field {} type changed", field_name),
585 severity: MismatchSeverity::High,
586 confidence: 1.0,
587 context,
588 });
589 }
590 }
591 }
592
593 Ok(mismatches)
594 }
595
596 fn compare_json_shape_schemas(
598 old_schema: &Value,
599 new_schema: &Value,
600 path_prefix: &str,
601 ) -> Vec<Mismatch> {
602 let mut mismatches = Vec::new();
603
604 if let (Some(old_obj), Some(new_obj)) = (old_schema.as_object(), new_schema.as_object()) {
605 for (prop_name, _) in old_obj {
607 if !new_obj.contains_key(prop_name) {
608 let mut context = HashMap::new();
609 context.insert("is_additive".to_string(), serde_json::json!(false));
610 context.insert("is_breaking".to_string(), serde_json::json!(true));
611 context.insert(
612 "change_category".to_string(),
613 serde_json::json!("property_removed"),
614 );
615 context.insert("field_name".to_string(), serde_json::json!(prop_name));
616 context.insert("schema_format".to_string(), serde_json::json!("json_shape"));
617
618 mismatches.push(Mismatch {
619 mismatch_type: MismatchType::UnexpectedField,
620 path: format!("{}.{}", path_prefix, prop_name),
621 method: None,
622 expected: Some(format!("Property {} should exist", prop_name)),
623 actual: Some("Property removed".to_string()),
624 description: format!("Property {} was removed", prop_name),
625 severity: MismatchSeverity::High,
626 confidence: 1.0,
627 context,
628 });
629 }
630 }
631
632 for (prop_name, _) in new_obj {
634 if !old_obj.contains_key(prop_name) {
635 let mut context = HashMap::new();
636 context.insert("is_additive".to_string(), serde_json::json!(true));
637 context.insert("is_breaking".to_string(), serde_json::json!(false));
638 context
639 .insert("change_category".to_string(), serde_json::json!("property_added"));
640 context.insert("field_name".to_string(), serde_json::json!(prop_name));
641 context.insert("schema_format".to_string(), serde_json::json!("json_shape"));
642
643 mismatches.push(Mismatch {
644 mismatch_type: MismatchType::UnexpectedField,
645 path: format!("{}.{}", path_prefix, prop_name),
646 method: None,
647 expected: None,
648 actual: Some(format!("New property {}", prop_name)),
649 description: format!("New property {} was added", prop_name),
650 severity: MismatchSeverity::Low,
651 confidence: 1.0,
652 context,
653 });
654 } else {
655 let old_type = old_obj[prop_name]
657 .as_str()
658 .or_else(|| old_obj[prop_name].get("type").and_then(|t| t.as_str()));
659 let new_type = new_obj[prop_name]
660 .as_str()
661 .or_else(|| new_obj[prop_name].get("type").and_then(|t| t.as_str()));
662
663 if old_type != new_type {
664 let mut context = HashMap::new();
665 context.insert("is_additive".to_string(), serde_json::json!(false));
666 context.insert("is_breaking".to_string(), serde_json::json!(true));
667 context.insert(
668 "change_category".to_string(),
669 serde_json::json!("property_type_changed"),
670 );
671 context.insert("field_name".to_string(), serde_json::json!(prop_name));
672 context
673 .insert("schema_format".to_string(), serde_json::json!("json_shape"));
674 context.insert("old_type".to_string(), serde_json::json!(old_type));
675 context.insert("new_type".to_string(), serde_json::json!(new_type));
676
677 mismatches.push(Mismatch {
678 mismatch_type: MismatchType::TypeMismatch,
679 path: format!("{}.{}", path_prefix, prop_name),
680 method: None,
681 expected: old_type.map(|t| format!("Type: {}", t)),
682 actual: new_type.map(|t| format!("Type: {}", t)),
683 description: format!("Property {} type changed", prop_name),
684 severity: MismatchSeverity::High,
685 confidence: 1.0,
686 context,
687 });
688 }
689 }
690 }
691 }
692
693 mismatches
694 }
695
696 fn compare_json_schemas(
698 old_schema: &Value,
699 new_schema: &Value,
700 path_prefix: &str,
701 ) -> Vec<Mismatch> {
702 let mut mismatches = Vec::new();
703
704 if let (Some(old_required), Some(new_required)) = (
706 old_schema.get("required").and_then(|v| v.as_array()),
707 new_schema.get("required").and_then(|v| v.as_array()),
708 ) {
709 let old_required_set: std::collections::HashSet<&str> =
710 old_required.iter().filter_map(|v| v.as_str()).collect();
711 let new_required_set: std::collections::HashSet<&str> =
712 new_required.iter().filter_map(|v| v.as_str()).collect();
713
714 for new_req in new_required_set.difference(&old_required_set) {
716 let mut context = HashMap::new();
717 context.insert("is_additive".to_string(), serde_json::json!(false));
718 context.insert("is_breaking".to_string(), serde_json::json!(true));
719 context.insert(
720 "change_category".to_string(),
721 serde_json::json!("required_field_added"),
722 );
723 context.insert("field_name".to_string(), serde_json::json!(new_req));
724
725 mismatches.push(Mismatch {
726 mismatch_type: MismatchType::MissingRequiredField,
727 path: format!("{}.{}", path_prefix, new_req),
728 method: None,
729 expected: Some(format!("Field {} should be optional", new_req)),
730 actual: Some(format!("Field {} is now required", new_req)),
731 description: format!("Field {} became required", new_req),
732 severity: MismatchSeverity::Critical,
733 confidence: 1.0,
734 context,
735 });
736 }
737
738 for removed_req in old_required_set.difference(&new_required_set) {
740 let mut context = HashMap::new();
741 context.insert("is_additive".to_string(), serde_json::json!(true));
742 context.insert("is_breaking".to_string(), serde_json::json!(false));
743 context.insert(
744 "change_category".to_string(),
745 serde_json::json!("required_field_removed"),
746 );
747 context.insert("field_name".to_string(), serde_json::json!(removed_req));
748
749 mismatches.push(Mismatch {
750 mismatch_type: MismatchType::UnexpectedField,
751 path: format!("{}.{}", path_prefix, removed_req),
752 method: None,
753 expected: Some(format!("Field {} was required", removed_req)),
754 actual: Some(format!("Field {} is now optional", removed_req)),
755 description: format!("Field {} is no longer required", removed_req),
756 severity: MismatchSeverity::Low,
757 confidence: 1.0,
758 context,
759 });
760 }
761 }
762
763 if let (Some(old_props), Some(new_props)) = (
765 old_schema.get("properties").and_then(|v| v.as_object()),
766 new_schema.get("properties").and_then(|v| v.as_object()),
767 ) {
768 for (prop_name, new_prop_schema) in new_props {
769 if let Some(old_prop_schema) = old_props.get(prop_name) {
770 if let (Some(old_type), Some(new_type)) = (
771 old_prop_schema.get("type").and_then(|v| v.as_str()),
772 new_prop_schema.get("type").and_then(|v| v.as_str()),
773 ) {
774 if old_type != new_type {
775 let mut context = HashMap::new();
776 context.insert("is_additive".to_string(), serde_json::json!(false));
777 context.insert("is_breaking".to_string(), serde_json::json!(true));
778 context.insert(
779 "change_category".to_string(),
780 serde_json::json!("property_type_changed"),
781 );
782 context.insert("field_name".to_string(), serde_json::json!(prop_name));
783 context.insert("old_type".to_string(), serde_json::json!(old_type));
784 context.insert("new_type".to_string(), serde_json::json!(new_type));
785
786 mismatches.push(Mismatch {
787 mismatch_type: MismatchType::TypeMismatch,
788 path: format!("{}.{}", path_prefix, prop_name),
789 method: None,
790 expected: Some(format!("Type: {}", old_type)),
791 actual: Some(format!("Type: {}", new_type)),
792 description: format!(
793 "Property {} type changed from {} to {}",
794 prop_name, old_type, new_type
795 ),
796 severity: MismatchSeverity::High,
797 confidence: 1.0,
798 context,
799 });
800 }
801 }
802 } else {
803 let mut context = HashMap::new();
805 context.insert("is_additive".to_string(), serde_json::json!(true));
806 context.insert("is_breaking".to_string(), serde_json::json!(false));
807 context
808 .insert("change_category".to_string(), serde_json::json!("property_added"));
809 context.insert("field_name".to_string(), serde_json::json!(prop_name));
810
811 mismatches.push(Mismatch {
812 mismatch_type: MismatchType::UnexpectedField,
813 path: format!("{}.{}", path_prefix, prop_name),
814 method: None,
815 expected: None,
816 actual: Some(format!("New property {}", prop_name)),
817 description: format!("New property {} was added", prop_name),
818 severity: MismatchSeverity::Low,
819 confidence: 1.0,
820 context,
821 });
822 }
823 }
824
825 for prop_name in old_props.keys() {
827 if !new_props.contains_key(prop_name) {
828 let mut context = HashMap::new();
829 context.insert("is_additive".to_string(), serde_json::json!(false));
830 context.insert("is_breaking".to_string(), serde_json::json!(true));
831 context.insert(
832 "change_category".to_string(),
833 serde_json::json!("property_removed"),
834 );
835 context.insert("field_name".to_string(), serde_json::json!(prop_name));
836
837 mismatches.push(Mismatch {
838 mismatch_type: MismatchType::UnexpectedField,
839 path: format!("{}.{}", path_prefix, prop_name),
840 method: None,
841 expected: Some(format!("Property {} should exist", prop_name)),
842 actual: Some("Property removed".to_string()),
843 description: format!("Property {} was removed", prop_name),
844 severity: MismatchSeverity::High,
845 confidence: 1.0,
846 context,
847 });
848 }
849 }
850 }
851
852 mismatches
853 }
854
855 fn validate_message_against_schema(
857 &self,
858 message_type_id: &str,
859 message: &Value,
860 ) -> Result<ValidationResult, ContractError> {
861 let schema = self
862 .schema_cache
863 .get(message_type_id)
864 .ok_or_else(|| ContractError::OperationNotFound(message_type_id.to_string()))?;
865
866 let mut validation_errors = Vec::new();
868 for error in schema.iter_errors(message) {
869 validation_errors.push(ValidationError {
870 message: error.to_string(),
871 path: Some(error.instance_path.to_string()),
872 code: Some("SCHEMA_VALIDATION_ERROR".to_string()),
873 });
874 }
875
876 Ok(ValidationResult {
877 valid: validation_errors.is_empty(),
878 errors: validation_errors,
879 warnings: Vec::new(),
880 })
881 }
882}
883
884#[async_trait::async_trait]
885impl ProtocolContract for WebSocketContract {
886 fn protocol(&self) -> Protocol {
887 Protocol::WebSocket
888 }
889
890 fn contract_id(&self) -> &str {
891 &self.contract_id
892 }
893
894 fn version(&self) -> &str {
895 &self.version
896 }
897
898 fn operations(&self) -> Vec<ContractOperation> {
899 self.operations_cache.values().cloned().collect()
900 }
901
902 fn get_operation(&self, operation_id: &str) -> Option<&ContractOperation> {
903 if let Some(operation) = self.operations_cache.get(operation_id) {
905 return Some(operation);
906 }
907
908 if !operation_id.contains(':') {
910 for operation in self.operations_cache.values() {
912 if let OperationType::WebSocketMessage { message_type, .. } =
913 &operation.operation_type
914 {
915 if message_type == operation_id {
916 return Some(operation);
917 }
918 }
919 }
920 }
921
922 None
923 }
924
925 async fn diff(
926 &self,
927 other: &dyn ProtocolContract,
928 ) -> Result<ContractDiffResult, ContractError> {
929 if other.protocol() != Protocol::WebSocket {
931 return Err(ContractError::UnsupportedProtocol(other.protocol()));
932 }
933
934 Err(ContractError::Other(
936 "Direct comparison of WebSocketContract instances requires type information. \
937 Use WebSocketContract::diff_contracts() for comparing two WebSocketContract instances."
938 .to_string(),
939 ))
940 }
941
942 async fn validate(
943 &self,
944 operation_id: &str,
945 request: &ContractRequest,
946 ) -> Result<ValidationResult, ContractError> {
947 let message: Value = serde_json::from_slice(&request.payload)
949 .map_err(|e| ContractError::SchemaValidation(format!("Invalid JSON: {}", e)))?;
950
951 let message_type_id = if let Some((_, message_type)) = operation_id.split_once(':') {
953 message_type
954 } else {
955 operation_id
956 };
957
958 self.validate_message_against_schema(message_type_id, &message)
960 }
961
962 fn get_schema(&self, operation_id: &str) -> Option<Value> {
963 let message_type_id = if let Some((_, message_type)) = operation_id.split_once(':') {
965 message_type
966 } else {
967 operation_id
968 };
969
970 self.message_types.get(message_type_id).map(|mt| mt.schema.clone())
971 }
972
973 fn to_json(&self) -> Result<Value, ContractError> {
974 let message_types: Vec<Value> = self
975 .message_types
976 .values()
977 .map(|mt| {
978 serde_json::json!({
979 "message_type": mt.message_type,
980 "topic": mt.topic,
981 "schema": mt.schema,
982 "direction": mt.direction,
983 "description": mt.description,
984 "example": mt.example,
985 })
986 })
987 .collect();
988
989 Ok(serde_json::json!({
990 "contract_id": self.contract_id,
991 "version": self.version,
992 "protocol": "websocket",
993 "message_types": message_types,
994 "topics": self.topics.keys().collect::<Vec<_>>(),
995 "metadata": self.metadata,
996 }))
997 }
998}
999
1000pub fn diff_websocket_contracts(
1002 old_contract: &WebSocketContract,
1003 new_contract: &WebSocketContract,
1004) -> Result<ContractDiffResult, ContractError> {
1005 old_contract.diff_contracts(new_contract)
1006}
1007
#[cfg(test)]
mod tests {
    use super::*;

    /// A bidirectional chat message on the `chat` topic requiring `text` and `user`.
    fn chat_message_type() -> WebSocketMessageType {
        WebSocketMessageType {
            message_type: String::from("chat_message"),
            topic: Some(String::from("chat")),
            schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "text": {"type": "string"},
                    "user": {"type": "string"}
                },
                "required": ["text", "user"]
            }),
            direction: MessageDirection::Bidirectional,
            description: Some(String::from("Chat message")),
            example: None,
        }
    }

    #[test]
    fn test_websocket_contract_creation() {
        let contract = WebSocketContract::new(String::from("test-contract"), String::from("1.0.0"));

        assert_eq!(contract.contract_id(), "test-contract");
        assert_eq!(contract.version(), "1.0.0");
    }

    #[test]
    fn test_add_message_type() {
        let mut contract = WebSocketContract::new(String::from("test"), String::from("1.0.0"));

        let outcome = contract.add_message_type(chat_message_type());

        assert!(outcome.is_ok());
        assert_eq!(contract.message_types.len(), 1);
    }
}
1041}