use crate::ai_contract_diff::{ContractDiffResult, Mismatch};
use mockforge_foundation::protocol::Protocol;
pub use mockforge_foundation::protocol_contract_types::{
classify_change, extract_breaking_changes, ChangeClassification, ContractError,
ContractMetadata, ContractOperation, ContractRequest, OperationType, ProtocolContract,
ValidationError, ValidationResult,
};
use std::collections::HashMap;
/// In-memory registry of protocol contracts, keyed by each contract's `contract_id()`.
///
/// Listing methods ([`ProtocolContractRegistry::list`],
/// [`ProtocolContractRegistry::list_by_protocol`]) return contracts sorted by contract id,
/// so their output does not depend on hash-map iteration order.
#[derive(Default)]
pub struct ProtocolContractRegistry {
    /// Registered contracts, keyed by their contract id.
    contracts: HashMap<String, Box<dyn ProtocolContract>>,
}

impl ProtocolContractRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `contract` under its `contract_id()`.
    ///
    /// If a contract is already registered under the same id, it is replaced and the
    /// previous contract is dropped.
    pub fn register(&mut self, contract: Box<dyn ProtocolContract>) {
        let id = contract.contract_id().to_string();
        self.contracts.insert(id, contract);
    }

    /// Returns the contract registered under `contract_id`, if any.
    pub fn get(&self, contract_id: &str) -> Option<&dyn ProtocolContract> {
        self.contracts.get(contract_id).map(|c| c.as_ref())
    }

    /// Returns `true` if a contract is registered under `contract_id`.
    pub fn contains(&self, contract_id: &str) -> bool {
        self.contracts.contains_key(contract_id)
    }

    /// Returns the number of registered contracts.
    pub fn len(&self) -> usize {
        self.contracts.len()
    }

    /// Returns `true` if no contracts are registered.
    pub fn is_empty(&self) -> bool {
        self.contracts.is_empty()
    }

    /// Returns every registered contract, sorted by contract id.
    pub fn list(&self) -> Vec<&dyn ProtocolContract> {
        self.sorted_contracts()
    }

    /// Returns the registered contracts whose `protocol()` equals `protocol`,
    /// sorted by contract id.
    pub fn list_by_protocol(&self, protocol: Protocol) -> Vec<&dyn ProtocolContract> {
        let mut contracts = self.sorted_contracts();
        contracts.retain(|c| c.protocol() == protocol);
        contracts
    }

    /// Removes and returns the contract registered under `contract_id`, if any.
    pub fn remove(&mut self, contract_id: &str) -> Option<Box<dyn ProtocolContract>> {
        self.contracts.remove(contract_id)
    }

    /// Collects all contracts ordered by id, so listings are stable across calls and runs
    /// instead of following `HashMap`'s unspecified iteration order.
    fn sorted_contracts(&self) -> Vec<&dyn ProtocolContract> {
        let mut entries: Vec<_> = self.contracts.iter().collect();
        // Ids are unique map keys, so an unstable sort is fine.
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        entries.into_iter().map(|(_, contract)| contract.as_ref()).collect()
    }
}
/// Computes the diff between an old and a new version of a contract.
///
/// Both contracts must use the same protocol. The comparison itself is delegated to
/// `old_contract.diff(new_contract)`.
///
/// # Errors
///
/// Returns `ContractError::Other` if the two contracts report different protocols.
/// Otherwise, returns whatever the underlying `diff` call returns.
pub async fn compare_contracts(
    old_contract: &dyn ProtocolContract,
    new_contract: &dyn ProtocolContract,
) -> Result<ContractDiffResult, ContractError> {
    // Keep the protocol values scoped to the check so the future does not hold them
    // across the `.await` below.
    {
        let old_protocol = old_contract.protocol();
        let new_protocol = new_contract.protocol();
        if old_protocol != new_protocol {
            return Err(ContractError::Other(format!(
                "Cannot compare contracts of different protocols: {:?} vs {:?}",
                old_protocol, new_protocol
            )));
        }
    }
    old_contract.diff(new_contract).await
}
/// Builds a JSON drift report for a gRPC contract diff, grouped by service and method.
///
/// Each mismatch in `diff.mismatches` is attributed to a service and a method. It is
/// then classified with [`classify_change`] and tallied as additive and/or breaking;
/// the two flags are counted independently.
///
/// Attribution falls back in this order:
///
/// * service: the `"service"` context entry, else the first `.`-separated segment of
///   `mismatch.path`, else `"unknown"`;
/// * method: the `"method"` context entry, else `mismatch.method`, else the second
///   `.`-separated segment of `mismatch.path`, else `"unknown"`.
///
/// Empty candidates are skipped at every step. An empty path or an empty context string
/// therefore falls through to the next source instead of creating a group named `""`.
///
/// The resulting shape is:
///
/// ```text
/// {
///   "services": {
///     "<service>": {
///       "additive_changes": <n>,
///       "breaking_changes": <n>,
///       "methods": {
///         "<method>": {
///           "additive_changes": <n>,
///           "breaking_changes": <n>,
///           "total_changes": <n>,
///           "changes": [{ "description", "path", "severity", "is_additive",
///                         "is_breaking", "change_category" }, ...]
///         }
///       }
///     }
///   },
///   "total_mismatches": <n>
/// }
/// ```
///
/// Services and methods are visited in sorted order, so the report is identical across
/// runs even if serde_json's `preserve_order` feature is enabled. Changes within a
/// method keep the order of `diff.mismatches`.
pub fn generate_grpc_drift_report(diff: &ContractDiffResult) -> serde_json::Value {
    use std::collections::BTreeMap;

    // Ordered maps make the grouping, and therefore the emitted JSON, deterministic.
    let mut by_service: BTreeMap<String, BTreeMap<String, Vec<&Mismatch>>> = BTreeMap::new();
    for mismatch in &diff.mismatches {
        by_service
            .entry(grpc_service_name(mismatch))
            .or_default()
            .entry(grpc_method_name(mismatch))
            .or_default()
            .push(mismatch);
    }

    let mut services_json = serde_json::Map::new();
    for (service_name, methods) in by_service {
        let mut methods_json = serde_json::Map::new();
        let mut service_additive: usize = 0;
        let mut service_breaking: usize = 0;

        for (method_name, mismatches) in methods {
            let total_changes = mismatches.len();
            let mut method_additive: usize = 0;
            let mut method_breaking: usize = 0;
            let mut changes = Vec::with_capacity(total_changes);

            for mismatch in mismatches {
                let classification = classify_change(mismatch);
                if classification.is_additive {
                    method_additive += 1;
                }
                if classification.is_breaking {
                    method_breaking += 1;
                }
                changes.push(serde_json::json!({
                    "description": mismatch.description,
                    "path": mismatch.path,
                    "severity": format!("{:?}", mismatch.severity),
                    "is_additive": classification.is_additive,
                    "is_breaking": classification.is_breaking,
                    "change_category": classification.change_category,
                }));
            }

            service_additive += method_additive;
            service_breaking += method_breaking;

            let mut method_json = serde_json::Map::new();
            method_json.insert("additive_changes".to_string(), serde_json::Value::from(method_additive));
            method_json.insert("breaking_changes".to_string(), serde_json::Value::from(method_breaking));
            method_json.insert("total_changes".to_string(), serde_json::Value::from(total_changes));
            // Move the vector in rather than serializing a copy of it.
            method_json.insert("changes".to_string(), serde_json::Value::Array(changes));
            methods_json.insert(method_name, serde_json::Value::Object(method_json));
        }

        let mut service_json = serde_json::Map::new();
        service_json.insert("additive_changes".to_string(), serde_json::Value::from(service_additive));
        service_json.insert("breaking_changes".to_string(), serde_json::Value::from(service_breaking));
        service_json.insert("methods".to_string(), serde_json::Value::Object(methods_json));
        services_json.insert(service_name, serde_json::Value::Object(service_json));
    }

    let mut report = serde_json::Map::new();
    report.insert("services".to_string(), serde_json::Value::Object(services_json));
    report.insert(
        "total_mismatches".to_string(),
        serde_json::Value::from(diff.mismatches.len()),
    );
    serde_json::Value::Object(report)
}

/// Resolves a mismatch's gRPC service: the `"service"` context entry, else the first
/// non-empty `.`-separated path segment, else `"unknown"`.
fn grpc_service_name(mismatch: &Mismatch) -> String {
    non_empty_context_str(mismatch, "service")
        .or_else(|| non_empty_path_segment(&mismatch.path, 0))
        .unwrap_or_else(|| "unknown".to_string())
}

/// Resolves a mismatch's gRPC method: the `"method"` context entry, else
/// `mismatch.method`, else the second `.`-separated path segment, else `"unknown"`.
/// Empty values are skipped at each step.
fn grpc_method_name(mismatch: &Mismatch) -> String {
    non_empty_context_str(mismatch, "method")
        .or_else(|| {
            mismatch
                .method
                .as_deref()
                .filter(|m| !m.is_empty())
                .map(|m| m.to_string())
        })
        .or_else(|| non_empty_path_segment(&mismatch.path, 1))
        .unwrap_or_else(|| "unknown".to_string())
}

/// Returns the non-empty string stored under `key` in the mismatch context, if any.
fn non_empty_context_str(mismatch: &Mismatch, key: &str) -> Option<String> {
    mismatch
        .context
        .get(key)
        .and_then(|v| v.as_str())
        .filter(|s| !s.is_empty())
        .map(|s| s.to_string())
}

/// Returns the `index`-th `.`-separated segment of `path`, if it exists and is non-empty.
fn non_empty_path_segment(path: &str, index: usize) -> Option<String> {
    path.split('.')
        .nth(index)
        .filter(|s| !s.is_empty())
        .map(|s| s.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `OperationType` serializes with a snake_case tag and keeps its fields.
    #[test]
    fn test_operation_type_serialization() {
        let op_type = OperationType::HttpEndpoint {
            method: "GET".to_string(),
            path: "/api/users".to_string(),
        };
        let json = serde_json::to_string(&op_type).unwrap();
        assert!(json.contains("http_endpoint"));
        assert!(json.contains("GET"));
        assert!(json.contains("/api/users"));
    }

    /// A freshly constructed registry lists nothing.
    #[test]
    fn test_contract_registry() {
        let registry = ProtocolContractRegistry::new();
        assert_eq!(registry.list().len(), 0);
    }

    /// `Default` produces the same empty registry as `new`.
    #[test]
    fn test_contract_registry_default_is_empty() {
        let registry = ProtocolContractRegistry::default();
        assert!(registry.list().is_empty());
    }

    /// Looking up or removing an id that was never registered yields `None`.
    #[test]
    fn test_contract_registry_missing_id() {
        let mut registry = ProtocolContractRegistry::new();
        assert!(registry.get("missing").is_none());
        assert!(registry.remove("missing").is_none());
    }
}