use crate::ai_contract_diff::{ContractDiffResult, Mismatch, MismatchSeverity};
use crate::protocol_abstraction::Protocol;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A versioned contract for a single [`Protocol`].
///
/// The contract exposes its operations and their schemas, can be diffed against another
/// contract, and can validate requests.
///
/// Implementors must be `Send + Sync`, so contracts can be stored as
/// `Box<dyn ProtocolContract>` (as [`ProtocolContractRegistry`] does) and shared across threads.
#[async_trait::async_trait]
pub trait ProtocolContract: Send + Sync {
    /// The protocol this contract describes.
    fn protocol(&self) -> Protocol;

    /// Identifier of this contract; [`ProtocolContractRegistry::register`] uses it as the key.
    fn contract_id(&self) -> &str;

    /// Version string of this contract.
    fn version(&self) -> &str;

    /// All operations defined by the contract, returned as owned values.
    fn operations(&self) -> Vec<ContractOperation>;

    /// Looks up one operation by `operation_id`, returning `None` when it is absent.
    fn get_operation(&self, operation_id: &str) -> Option<&ContractOperation>;

    /// Compares this contract with `other` and reports the differences.
    ///
    /// [`compare_contracts`] invokes this on the *old* contract, passing the *new* one as `other`.
    async fn diff(
        &self,
        other: &dyn ProtocolContract,
    ) -> Result<ContractDiffResult, ContractError>;

    /// Validates `request` against the operation identified by `operation_id`.
    async fn validate(&self, operation_id: &str, request: &ContractRequest)
        -> Result<ValidationResult, ContractError>;

    /// JSON schema associated with `operation_id`, if any.
    ///
    /// NOTE(review): which schema (input vs output) is returned is implementation-defined — confirm.
    fn get_schema(&self, operation_id: &str) -> Option<serde_json::Value>;

    /// Serializes the whole contract to a JSON value.
    fn to_json(&self) -> Result<serde_json::Value, ContractError>;
}
/// One operation exposed by a contract, with optional JSON schemas for its input and output.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ContractOperation {
    /// Operation identifier (presumably the value passed to `get_operation` — confirm).
    pub id: String,
    /// Human-readable operation name.
    pub name: String,
    /// Protocol-specific description of how the operation is addressed.
    pub operation_type: OperationType,
    /// JSON schema of the operation's input, when known.
    pub input_schema: Option<serde_json::Value>,
    /// JSON schema of the operation's output, when known.
    pub output_schema: Option<serde_json::Value>,
    /// Free-form string key/value annotations.
    pub metadata: HashMap<String, String>,
}
/// How an operation is addressed in its protocol.
///
/// Serialized with serde's default externally tagged representation, using `snake_case`
/// variant names (e.g. `{"http_endpoint": {"method": "GET", "path": "/"}}`).
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum OperationType {
    /// An HTTP route, identified by method and path.
    HttpEndpoint {
        method: String,
        path: String,
    },
    /// A method on a gRPC service.
    GrpcMethod {
        service: String,
        method: String,
    },
    /// A WebSocket message type, optionally scoped to a topic.
    WebSocketMessage {
        message_type: String,
        topic: Option<String>,
    },
    /// An MQTT topic; `qos` presumably holds the MQTT QoS level — confirm.
    MqttTopic {
        topic: String,
        qos: Option<u8>,
    },
    /// A Kafka topic with optional key/value schema references
    /// (NOTE(review): whether these are schema ids or inline definitions is not shown here).
    KafkaTopic {
        topic: String,
        key_schema: Option<String>,
        value_schema: Option<String>,
    },
}
/// A request to validate against one operation of a contract.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ContractRequest {
    /// Protocol the request was made over.
    pub protocol: Protocol,
    /// Identifier of the targeted operation.
    pub operation_id: String,
    /// Raw request body bytes.
    pub payload: Vec<u8>,
    /// Optional content type of `payload` (e.g. a MIME type; not enforced by this type).
    pub content_type: Option<String>,
    /// Free-form string key/value annotations (e.g. headers — confirm with producers).
    pub metadata: HashMap<String, String>,
}
/// Outcome of [`ProtocolContract::validate`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ValidationResult {
    /// Overall verdict for the request.
    pub valid: bool,
    /// Structured validation errors.
    pub errors: Vec<ValidationError>,
    /// Non-fatal findings as free-form messages.
    pub warnings: Vec<String>,
}
/// A single validation problem.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ValidationError {
    /// Human-readable description of the problem.
    pub message: String,
    /// Location of the problem within the request, when known.
    pub path: Option<String>,
    /// Optional machine-readable error code.
    pub code: Option<String>,
}
#[derive(Debug, thiserror::Error)]
pub enum ContractError {
#[error("Contract not found: {0}")]
NotFound(String),
#[error("Invalid contract format: {0}")]
InvalidFormat(String),
#[error("Unsupported protocol: {0:?}")]
UnsupportedProtocol(Protocol),
#[error("Operation not found: {0}")]
OperationNotFound(String),
#[error("Schema validation error: {0}")]
SchemaValidation(String),
#[error("Serialization error: {0}")]
Serialization(#[from] serde_json::Error),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Other error: {0}")]
Other(String),
}
/// Descriptive metadata about a contract.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ContractMetadata {
    /// Contract name.
    pub name: String,
    /// Contract version string.
    pub version: String,
    /// Protocol the contract describes.
    pub protocol: Protocol,
    /// Optional free-text description.
    pub description: Option<String>,
    /// Free-form tags.
    pub tags: Vec<String>,
    /// Creation time (presumably a Unix timestamp — confirm the unit).
    pub created_at: Option<i64>,
    /// Last-update time (presumably a Unix timestamp — confirm the unit).
    pub updated_at: Option<i64>,
}
/// In-memory collection of contracts, keyed by [`ProtocolContract::contract_id`].
///
/// Iteration order of [`list`](Self::list) and [`list_by_protocol`](Self::list_by_protocol)
/// follows the underlying `HashMap` and is therefore unspecified.
pub struct ProtocolContractRegistry {
    /// Registered contracts, keyed by contract id.
    contracts: HashMap<String, Box<dyn ProtocolContract>>,
}

impl ProtocolContractRegistry {
    /// Creates a registry with no contracts.
    pub fn new() -> Self {
        let contracts = HashMap::new();
        Self { contracts }
    }

    /// Adds `contract` under its `contract_id()`.
    ///
    /// A contract previously registered under the same id is replaced and dropped.
    pub fn register(&mut self, contract: Box<dyn ProtocolContract>) {
        let key = String::from(contract.contract_id());
        self.contracts.insert(key, contract);
    }

    /// Returns the contract registered under `contract_id`, if any.
    pub fn get(&self, contract_id: &str) -> Option<&dyn ProtocolContract> {
        let entry = self.contracts.get(contract_id)?;
        Some(entry.as_ref())
    }

    /// Returns every registered contract.
    pub fn list(&self) -> Vec<&dyn ProtocolContract> {
        let mut all = Vec::with_capacity(self.contracts.len());
        for entry in self.contracts.values() {
            all.push(entry.as_ref());
        }
        all
    }

    /// Returns the registered contracts whose `protocol()` equals `protocol`.
    pub fn list_by_protocol(&self, protocol: Protocol) -> Vec<&dyn ProtocolContract> {
        let mut matching = Vec::new();
        for entry in self.contracts.values() {
            if entry.protocol() == protocol {
                matching.push(entry.as_ref());
            }
        }
        matching
    }

    /// Unregisters and returns the contract stored under `contract_id`, if present.
    pub fn remove(&mut self, contract_id: &str) -> Option<Box<dyn ProtocolContract>> {
        self.contracts.remove(contract_id)
    }
}

impl Default for ProtocolContractRegistry {
    /// Same as [`ProtocolContractRegistry::new`]: an empty registry.
    fn default() -> Self {
        ProtocolContractRegistry::new()
    }
}
/// Diffs `new_contract` against `old_contract` by calling `old_contract.diff(new_contract)`.
///
/// # Errors
///
/// Returns [`ContractError::Other`] when the two contracts declare different protocols.
/// Otherwise it propagates whatever error the contract's `diff` implementation returns.
pub async fn compare_contracts(
    old_contract: &dyn ProtocolContract,
    new_contract: &dyn ProtocolContract,
) -> Result<ContractDiffResult, ContractError> {
    let old_protocol = old_contract.protocol();
    let new_protocol = new_contract.protocol();

    if old_protocol == new_protocol {
        return old_contract.diff(new_contract).await;
    }

    Err(ContractError::Other(format!(
        "Cannot compare contracts of different protocols: {:?} vs {:?}",
        old_protocol, new_protocol
    )))
}
pub fn extract_breaking_changes(diff: &ContractDiffResult) -> Vec<&Mismatch> {
diff.mismatches
.iter()
.filter(|m| {
matches!(m.severity, MismatchSeverity::Critical | MismatchSeverity::High)
&& matches!(
m.mismatch_type,
crate::ai_contract_diff::MismatchType::MissingRequiredField
| crate::ai_contract_diff::MismatchType::TypeMismatch
| crate::ai_contract_diff::MismatchType::EndpointNotFound
| crate::ai_contract_diff::MismatchType::MethodNotAllowed
| crate::ai_contract_diff::MismatchType::SchemaMismatch
)
})
.collect()
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChangeClassification {
pub is_additive: bool,
pub is_breaking: bool,
pub change_category: Option<String>,
}
pub fn classify_change(mismatch: &Mismatch) -> ChangeClassification {
let is_additive =
mismatch.context.get("is_additive").and_then(|v| v.as_bool()).unwrap_or(false);
let is_breaking = mismatch.context.get("is_breaking").and_then(|v| v.as_bool()).unwrap_or({
matches!(mismatch.severity, MismatchSeverity::Critical | MismatchSeverity::High)
&& matches!(
mismatch.mismatch_type,
crate::ai_contract_diff::MismatchType::MissingRequiredField
| crate::ai_contract_diff::MismatchType::TypeMismatch
| crate::ai_contract_diff::MismatchType::EndpointNotFound
| crate::ai_contract_diff::MismatchType::MethodNotAllowed
| crate::ai_contract_diff::MismatchType::SchemaMismatch
)
});
let change_category = mismatch
.context
.get("change_category")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
ChangeClassification {
is_additive,
is_breaking,
change_category,
}
}
/// Mismatches grouped by service name, then by method name.
///
/// Both levels use `BTreeMap`, so groups are visited in sorted order. The report is therefore
/// deterministic even when `serde_json` is built to preserve insertion order.
type GrpcMismatchGroups<'a> =
    std::collections::BTreeMap<String, std::collections::BTreeMap<String, Vec<&'a Mismatch>>>;

/// Builds a JSON drift report for a gRPC contract diff.
///
/// Each mismatch is assigned to a service and a method:
/// - **Service:** the string context entry `"service"`, else the first `.`-separated segment of
///   `path`.
/// - **Method:** the string context entry `"method"`, else `mismatch.method`, else the second
///   `.`-separated segment of `path`.
///
/// Empty path segments count as missing, and anything unresolved is reported as `"unknown"`.
/// Every mismatch is classified with [`classify_change`].
///
/// Report shape:
///
/// ```text
/// {
///   "services": {
///     "<service>": {
///       "additive_changes": n, "breaking_changes": n,
///       "methods": {
///         "<method>": {
///           "additive_changes": n, "breaking_changes": n, "total_changes": n,
///           "changes": [{ "description", "path", "severity", "is_additive",
///                         "is_breaking", "change_category" }, ...]
///         }
///       }
///     }
///   },
///   "total_mismatches": n
/// }
/// ```
pub fn generate_grpc_drift_report(diff: &ContractDiffResult) -> serde_json::Value {
    let mut services_json = serde_json::Map::new();

    for (service_name, methods) in group_grpc_mismatches(diff) {
        let mut methods_json = serde_json::Map::new();
        let mut service_additive = 0usize;
        let mut service_breaking = 0usize;

        for (method_name, mismatches) in methods {
            let (method_json, additive, breaking) = summarize_grpc_method(&mismatches);
            service_additive += additive;
            service_breaking += breaking;
            methods_json.insert(method_name, method_json);
        }

        let mut service_json = serde_json::Map::new();
        service_json.insert("additive_changes".to_string(), serde_json::json!(service_additive));
        service_json.insert("breaking_changes".to_string(), serde_json::json!(service_breaking));
        service_json.insert("methods".to_string(), serde_json::Value::Object(methods_json));
        services_json.insert(service_name, serde_json::Value::Object(service_json));
    }

    let mut report = serde_json::Map::new();
    report.insert("services".to_string(), serde_json::Value::Object(services_json));
    report.insert("total_mismatches".to_string(), serde_json::json!(diff.mismatches.len()));
    serde_json::Value::Object(report)
}

/// Groups the diff's mismatches by (service, method), keeping their original order per group.
fn group_grpc_mismatches(diff: &ContractDiffResult) -> GrpcMismatchGroups<'_> {
    let mut groups: GrpcMismatchGroups<'_> = Default::default();
    for mismatch in &diff.mismatches {
        let (service, method) = grpc_service_and_method(mismatch);
        groups
            .entry(service)
            .or_default()
            .entry(method)
            .or_default()
            .push(mismatch);
    }
    groups
}

/// Resolves the (service, method) names a mismatch belongs to; see `generate_grpc_drift_report`.
fn grpc_service_and_method(mismatch: &Mismatch) -> (String, String) {
    let context_str = |key: &str| {
        mismatch
            .context
            .get(key)
            .and_then(|v| v.as_str())
            .map(|s| s.to_string())
    };
    // `split` always yields at least one (possibly empty) segment. Filtering out empty segments
    // keeps paths like "" or "svc." from producing blank service/method names, so the
    // "unknown" fallback below is reachable.
    let path_segment = |index: usize| {
        mismatch
            .path
            .split('.')
            .nth(index)
            .filter(|segment| !segment.is_empty())
            .map(|s| s.to_string())
    };

    let service = context_str("service")
        .or_else(|| path_segment(0))
        .unwrap_or_else(|| "unknown".to_string());
    let method = context_str("method")
        .or_else(|| mismatch.method.clone())
        .or_else(|| path_segment(1))
        .unwrap_or_else(|| "unknown".to_string());

    (service, method)
}

/// Builds one method's JSON summary and returns it with its additive and breaking counts.
fn summarize_grpc_method(mismatches: &[&Mismatch]) -> (serde_json::Value, usize, usize) {
    let mut additive = 0usize;
    let mut breaking = 0usize;
    let mut changes = Vec::with_capacity(mismatches.len());

    for &mismatch in mismatches {
        let classification = classify_change(mismatch);
        if classification.is_additive {
            additive += 1;
        }
        if classification.is_breaking {
            breaking += 1;
        }
        changes.push(serde_json::json!({
            "description": mismatch.description,
            "path": mismatch.path,
            "severity": format!("{:?}", mismatch.severity),
            "is_additive": classification.is_additive,
            "is_breaking": classification.is_breaking,
            "change_category": classification.change_category,
        }));
    }

    let mut method_json = serde_json::Map::new();
    method_json.insert("additive_changes".to_string(), serde_json::json!(additive));
    method_json.insert("breaking_changes".to_string(), serde_json::json!(breaking));
    method_json.insert("total_changes".to_string(), serde_json::json!(mismatches.len()));
    method_json.insert("changes".to_string(), serde_json::Value::Array(changes));

    (serde_json::Value::Object(method_json), additive, breaking)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `OperationType` uses serde's externally tagged form with snake_case variant names,
    /// and must round-trip through JSON.
    #[test]
    fn test_operation_type_serialization() {
        let op_type = OperationType::HttpEndpoint {
            method: "GET".to_string(),
            path: "/api/users".to_string(),
        };
        let json = serde_json::to_value(&op_type).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "http_endpoint": { "method": "GET", "path": "/api/users" } })
        );
        let round_trip: OperationType = serde_json::from_value(json).unwrap();
        assert_eq!(round_trip, op_type);
    }

    /// Optional fields serialize as `null` and deserialize back to `None`.
    #[test]
    fn test_operation_type_optional_fields_round_trip() {
        let op_type = OperationType::MqttTopic {
            topic: "sensors/temp".to_string(),
            qos: None,
        };
        let json = serde_json::to_value(&op_type).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "mqtt_topic": { "topic": "sensors/temp", "qos": null } })
        );
        let round_trip: OperationType = serde_json::from_value(json).unwrap();
        assert_eq!(round_trip, op_type);
    }

    /// A fresh registry is empty, and lookups/removals of unknown ids return `None`.
    #[test]
    fn test_contract_registry() {
        let mut registry = ProtocolContractRegistry::new();
        assert!(registry.list().is_empty());
        assert!(registry.get("missing").is_none());
        assert!(registry.remove("missing").is_none());
        assert!(ProtocolContractRegistry::default().list().is_empty());
    }
}