mockforge_contracts/contract_drift/
protocol_contracts.rs1use mockforge_foundation::contract_diff_types::{ContractDiffResult, Mismatch};
12use mockforge_foundation::protocol::Protocol;
13pub use mockforge_foundation::protocol_contract_types::{
14 classify_change, extract_breaking_changes, ChangeClassification, ContractError,
15 ContractMetadata, ContractOperation, ContractRequest, OperationType, ProtocolContract,
16 ValidationError, ValidationResult,
17};
18use std::collections::HashMap;
19
20pub struct ProtocolContractRegistry {
22 contracts: HashMap<String, Box<dyn ProtocolContract>>,
23}
24
25impl ProtocolContractRegistry {
26 pub fn new() -> Self {
28 Self {
29 contracts: HashMap::new(),
30 }
31 }
32
33 pub fn register(&mut self, contract: Box<dyn ProtocolContract>) {
35 let id = contract.contract_id().to_string();
36 self.contracts.insert(id, contract);
37 }
38
39 pub fn get(&self, contract_id: &str) -> Option<&dyn ProtocolContract> {
41 self.contracts.get(contract_id).map(|c| c.as_ref())
42 }
43
44 pub fn list(&self) -> Vec<&dyn ProtocolContract> {
46 self.contracts.values().map(|c| c.as_ref()).collect()
47 }
48
49 pub fn list_by_protocol(&self, protocol: Protocol) -> Vec<&dyn ProtocolContract> {
51 self.contracts
52 .values()
53 .filter(|c| c.protocol() == protocol)
54 .map(|c| c.as_ref())
55 .collect()
56 }
57
58 pub fn remove(&mut self, contract_id: &str) -> Option<Box<dyn ProtocolContract>> {
60 self.contracts.remove(contract_id)
61 }
62}
63
64impl Default for ProtocolContractRegistry {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70pub async fn compare_contracts(
72 old_contract: &dyn ProtocolContract,
73 new_contract: &dyn ProtocolContract,
74) -> Result<ContractDiffResult, ContractError> {
75 if old_contract.protocol() != new_contract.protocol() {
77 return Err(ContractError::Other(format!(
78 "Cannot compare contracts of different protocols: {:?} vs {:?}",
79 old_contract.protocol(),
80 new_contract.protocol()
81 )));
82 }
83
84 old_contract.diff(new_contract).await
86}
87
88pub fn generate_grpc_drift_report(diff: &ContractDiffResult) -> serde_json::Value {
92 use std::collections::HashMap;
93
94 let mut service_reports: HashMap<String, HashMap<String, Vec<&Mismatch>>> = HashMap::new();
96
97 for mismatch in &diff.mismatches {
98 let service = mismatch
100 .context
101 .get("service")
102 .and_then(|v| v.as_str())
103 .map(|s| s.to_string())
104 .or_else(|| {
105 mismatch.path.split('.').next().map(|s| s.to_string())
107 })
108 .unwrap_or_else(|| "unknown".to_string());
109
110 let method = mismatch
111 .context
112 .get("method")
113 .and_then(|v| v.as_str())
114 .map(|s| s.to_string())
115 .or_else(|| mismatch.method.clone())
116 .or_else(|| {
117 mismatch.path.split('.').nth(1).map(|s| s.to_string())
119 })
120 .unwrap_or_else(|| "unknown".to_string());
121
122 service_reports
123 .entry(service)
124 .or_default()
125 .entry(method)
126 .or_default()
127 .push(mismatch);
128 }
129
130 let mut report = serde_json::Map::new();
132 let mut services_json = serde_json::Map::new();
133
134 for (service_name, methods) in service_reports {
135 let mut service_json = serde_json::Map::new();
136 let mut methods_json = serde_json::Map::new();
137 let mut service_additive = 0;
138 let mut service_breaking = 0;
139
140 for (method_name, mismatches) in methods {
141 let mut method_json = serde_json::Map::new();
142 let mut method_additive = 0;
143 let mut method_breaking = 0;
144 let mut changes = Vec::new();
145
146 let total_changes = mismatches.len();
148
149 for mismatch in mismatches {
150 let classification = classify_change(mismatch);
151 if classification.is_additive {
152 method_additive += 1;
153 }
154 if classification.is_breaking {
155 method_breaking += 1;
156 }
157
158 changes.push(serde_json::json!({
159 "description": mismatch.description,
160 "path": mismatch.path,
161 "severity": format!("{:?}", mismatch.severity),
162 "is_additive": classification.is_additive,
163 "is_breaking": classification.is_breaking,
164 "change_category": classification.change_category,
165 }));
166 }
167
168 method_json.insert("additive_changes".to_string(), serde_json::json!(method_additive));
169 method_json.insert("breaking_changes".to_string(), serde_json::json!(method_breaking));
170 method_json.insert("total_changes".to_string(), serde_json::json!(total_changes));
171 method_json.insert("changes".to_string(), serde_json::json!(changes));
172
173 service_additive += method_additive;
174 service_breaking += method_breaking;
175
176 methods_json.insert(method_name, serde_json::Value::Object(method_json));
177 }
178
179 service_json.insert("additive_changes".to_string(), serde_json::json!(service_additive));
180 service_json.insert("breaking_changes".to_string(), serde_json::json!(service_breaking));
181 service_json.insert("methods".to_string(), serde_json::Value::Object(methods_json));
182
183 services_json.insert(service_name, serde_json::Value::Object(service_json));
184 }
185
186 report.insert("services".to_string(), serde_json::Value::Object(services_json));
187 report.insert("total_mismatches".to_string(), serde_json::json!(diff.mismatches.len()));
188
189 serde_json::Value::Object(report)
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_operation_type_serialization() {
        let op_type = OperationType::HttpEndpoint {
            method: "GET".to_string(),
            path: "/api/users".to_string(),
        };
        let json = serde_json::to_string(&op_type).unwrap();
        assert!(json.contains("http_endpoint"));
        assert!(json.contains("GET"));
        assert!(json.contains("/api/users"));
    }

    #[test]
    fn test_contract_registry() {
        let mut registry = ProtocolContractRegistry::new();
        assert!(registry.list().is_empty());

        // Lookups and removals of unknown ids report absence instead of panicking,
        // and leave the registry unchanged.
        assert!(registry.get("missing-contract").is_none());
        assert!(registry.remove("missing-contract").is_none());
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_contract_registry_default_is_empty() {
        let registry = ProtocolContractRegistry::default();
        assert!(registry.list().is_empty());
        assert!(registry.get("missing-contract").is_none());
    }
}