Skip to main content

mockforge_contracts/contract_drift/
protocol_contracts.rs

//! Protocol-agnostic contract abstractions for multi-protocol drift detection.
//!
//! This module provides a unified interface for contract definitions across different
//! protocols (HTTP/OpenAPI, gRPC, WebSocket, MQTT, Kafka), enabling consistent drift
//! detection and analysis regardless of the transport layer.
//!
//! The trait and data types are re-exported from
//! `mockforge_foundation::protocol_contract_types` so consumers can implement
//! and use them without depending on deprecated core modules.
10
11use mockforge_foundation::contract_diff_types::{ContractDiffResult, Mismatch};
12use mockforge_foundation::protocol::Protocol;
13pub use mockforge_foundation::protocol_contract_types::{
14    classify_change, extract_breaking_changes, ChangeClassification, ContractError,
15    ContractMetadata, ContractOperation, ContractRequest, OperationType, ProtocolContract,
16    ValidationError, ValidationResult,
17};
18use std::collections::HashMap;
19
20/// Registry for managing protocol contracts
21pub struct ProtocolContractRegistry {
22    contracts: HashMap<String, Box<dyn ProtocolContract>>,
23}
24
25impl ProtocolContractRegistry {
26    /// Create a new contract registry
27    pub fn new() -> Self {
28        Self {
29            contracts: HashMap::new(),
30        }
31    }
32
33    /// Register a contract
34    pub fn register(&mut self, contract: Box<dyn ProtocolContract>) {
35        let id = contract.contract_id().to_string();
36        self.contracts.insert(id, contract);
37    }
38
39    /// Get a contract by ID
40    pub fn get(&self, contract_id: &str) -> Option<&dyn ProtocolContract> {
41        self.contracts.get(contract_id).map(|c| c.as_ref())
42    }
43
44    /// List all contracts
45    pub fn list(&self) -> Vec<&dyn ProtocolContract> {
46        self.contracts.values().map(|c| c.as_ref()).collect()
47    }
48
49    /// List contracts by protocol
50    pub fn list_by_protocol(&self, protocol: Protocol) -> Vec<&dyn ProtocolContract> {
51        self.contracts
52            .values()
53            .filter(|c| c.protocol() == protocol)
54            .map(|c| c.as_ref())
55            .collect()
56    }
57
58    /// Remove a contract
59    pub fn remove(&mut self, contract_id: &str) -> Option<Box<dyn ProtocolContract>> {
60        self.contracts.remove(contract_id)
61    }
62}
63
64impl Default for ProtocolContractRegistry {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70/// Helper function to compare two contracts and generate drift analysis
71pub async fn compare_contracts(
72    old_contract: &dyn ProtocolContract,
73    new_contract: &dyn ProtocolContract,
74) -> Result<ContractDiffResult, ContractError> {
75    // Ensure protocols match
76    if old_contract.protocol() != new_contract.protocol() {
77        return Err(ContractError::Other(format!(
78            "Cannot compare contracts of different protocols: {:?} vs {:?}",
79            old_contract.protocol(),
80            new_contract.protocol()
81        )));
82    }
83
84    // Use the contract's diff method
85    old_contract.diff(new_contract).await
86}
87
88/// Generate a per-service+method drift report for gRPC contracts
89///
90/// Groups mismatches by service and method, showing additive vs breaking changes
91pub fn generate_grpc_drift_report(diff: &ContractDiffResult) -> serde_json::Value {
92    use std::collections::HashMap;
93
94    // Group mismatches by service and method
95    let mut service_reports: HashMap<String, HashMap<String, Vec<&Mismatch>>> = HashMap::new();
96
97    for mismatch in &diff.mismatches {
98        // Extract service and method from context or path
99        let service = mismatch
100            .context
101            .get("service")
102            .and_then(|v| v.as_str())
103            .map(|s| s.to_string())
104            .or_else(|| {
105                // Fallback: try to extract from path (format: "service.method" or "service")
106                mismatch.path.split('.').next().map(|s| s.to_string())
107            })
108            .unwrap_or_else(|| "unknown".to_string());
109
110        let method = mismatch
111            .context
112            .get("method")
113            .and_then(|v| v.as_str())
114            .map(|s| s.to_string())
115            .or_else(|| mismatch.method.clone())
116            .or_else(|| {
117                // Fallback: try to extract from path
118                mismatch.path.split('.').nth(1).map(|s| s.to_string())
119            })
120            .unwrap_or_else(|| "unknown".to_string());
121
122        service_reports
123            .entry(service)
124            .or_default()
125            .entry(method)
126            .or_default()
127            .push(mismatch);
128    }
129
130    // Build report structure
131    let mut report = serde_json::Map::new();
132    let mut services_json = serde_json::Map::new();
133
134    for (service_name, methods) in service_reports {
135        let mut service_json = serde_json::Map::new();
136        let mut methods_json = serde_json::Map::new();
137        let mut service_additive = 0;
138        let mut service_breaking = 0;
139
140        for (method_name, mismatches) in methods {
141            let mut method_json = serde_json::Map::new();
142            let mut method_additive = 0;
143            let mut method_breaking = 0;
144            let mut changes = Vec::new();
145
146            // Save length before consuming mismatches in the loop
147            let total_changes = mismatches.len();
148
149            for mismatch in mismatches {
150                let classification = classify_change(mismatch);
151                if classification.is_additive {
152                    method_additive += 1;
153                }
154                if classification.is_breaking {
155                    method_breaking += 1;
156                }
157
158                changes.push(serde_json::json!({
159                    "description": mismatch.description,
160                    "path": mismatch.path,
161                    "severity": format!("{:?}", mismatch.severity),
162                    "is_additive": classification.is_additive,
163                    "is_breaking": classification.is_breaking,
164                    "change_category": classification.change_category,
165                }));
166            }
167
168            method_json.insert("additive_changes".to_string(), serde_json::json!(method_additive));
169            method_json.insert("breaking_changes".to_string(), serde_json::json!(method_breaking));
170            method_json.insert("total_changes".to_string(), serde_json::json!(total_changes));
171            method_json.insert("changes".to_string(), serde_json::json!(changes));
172
173            service_additive += method_additive;
174            service_breaking += method_breaking;
175
176            methods_json.insert(method_name, serde_json::Value::Object(method_json));
177        }
178
179        service_json.insert("additive_changes".to_string(), serde_json::json!(service_additive));
180        service_json.insert("breaking_changes".to_string(), serde_json::json!(service_breaking));
181        service_json.insert("methods".to_string(), serde_json::Value::Object(methods_json));
182
183        services_json.insert(service_name, serde_json::Value::Object(service_json));
184    }
185
186    report.insert("services".to_string(), serde_json::Value::Object(services_json));
187    report.insert("total_mismatches".to_string(), serde_json::json!(diff.mismatches.len()));
188
189    serde_json::Value::Object(report)
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_operation_type_serialization() {
        let op_type = OperationType::HttpEndpoint {
            method: "GET".to_string(),
            path: "/api/users".to_string(),
        };
        let json = serde_json::to_string(&op_type).unwrap();
        assert!(json.contains("http_endpoint"));
        assert!(json.contains("GET"));
        assert!(json.contains("/api/users"));
    }

    #[test]
    fn test_contract_registry() {
        // Exercising `register`/`list_by_protocol` needs a mock `ProtocolContract`
        // implementation; until one exists, pin down the empty-registry behavior.
        let mut registry = ProtocolContractRegistry::new();
        assert!(registry.list().is_empty());
        assert!(registry.get("missing").is_none());
        assert!(registry.remove("missing").is_none());
    }

    #[test]
    fn test_default_registry_is_empty() {
        let registry = ProtocolContractRegistry::default();
        assert!(registry.list().is_empty());
    }
}