use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::Json,
};
use mockforge_core::contract_drift::protocol_contracts::{
compare_contracts, ProtocolContractRegistry,
};
use mockforge_core::contract_drift::{
GrpcContract, KafkaContract, KafkaTopicSchema, MqttContract, MqttTopicSchema, SchemaFormat,
TopicSchema, WebSocketContract, WebSocketMessageType,
};
use mockforge_core::protocol_abstraction::Protocol;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use base64::{engine::general_purpose, Engine as _};
/// Shared state handed to every protocol-contract handler in this module.
///
/// Only `registry` is mandatory. The optional components are consulted by
/// `compare_contracts_handler`: drift evaluation runs only when both
/// `drift_engine` and `incident_manager` are present.
#[derive(Clone)]
pub struct ProtocolContractState {
/// Registry of contracts, behind an async read/write lock.
pub registry: Arc<RwLock<ProtocolContractRegistry>>,
/// Optional drift budget engine used to evaluate a contract diff per operation.
pub drift_engine: Option<Arc<mockforge_core::contract_drift::DriftBudgetEngine>>,
/// Optional incident manager used to record incidents raised by drift evaluation.
pub incident_manager: Option<Arc<mockforge_core::incidents::IncidentManager>>,
/// Optional fitness-function registry evaluated during contract comparison.
pub fitness_registry:
Option<Arc<RwLock<mockforge_core::contract_drift::FitnessFunctionRegistry>>>,
/// Optional consumer-impact analyzer queried during contract comparison.
pub consumer_analyzer:
Option<Arc<RwLock<mockforge_core::contract_drift::ConsumerImpactAnalyzer>>>,
}
/// Request body for creating a gRPC contract.
#[derive(Debug, Deserialize)]
pub struct CreateGrpcContractRequest {
/// Identifier under which the contract is registered.
pub contract_id: String,
/// Contract version string.
pub version: String,
/// Descriptor set as base64 text; `create_grpc_contract` decodes it with
/// the standard base64 engine before building the contract.
pub descriptor_set: String,
}
/// Request body for creating a WebSocket contract.
#[derive(Debug, Deserialize)]
pub struct CreateWebSocketContractRequest {
/// Identifier under which the contract is registered.
pub contract_id: String,
/// Contract version string.
pub version: String,
/// Message types to add to the contract, in request order.
pub message_types: Vec<WebSocketMessageTypeRequest>,
}
/// One WebSocket message type inside a [`CreateWebSocketContractRequest`].
#[derive(Debug, Deserialize)]
pub struct WebSocketMessageTypeRequest {
/// Name of the message type.
pub message_type: String,
/// Optional topic associated with the message type.
pub topic: Option<String>,
/// Schema for the message, passed through as arbitrary JSON.
pub schema: serde_json::Value,
/// Message direction; the create handler accepts `inbound`, `outbound`,
/// or `bidirectional` and rejects anything else with 400.
pub direction: String,
/// Optional human-readable description.
pub description: Option<String>,
/// Optional example message.
pub example: Option<serde_json::Value>,
}
/// Request body for creating an MQTT contract.
#[derive(Debug, Deserialize)]
pub struct CreateMqttContractRequest {
/// Identifier under which the contract is registered.
pub contract_id: String,
/// Contract version string.
pub version: String,
/// Topic schemas to add to the contract, in request order.
pub topics: Vec<MqttTopicSchemaRequest>,
}
/// One MQTT topic schema inside a [`CreateMqttContractRequest`].
///
/// Every field is copied unchanged into an `MqttTopicSchema` by the create
/// handler.
#[derive(Debug, Deserialize)]
pub struct MqttTopicSchemaRequest {
/// Topic name.
pub topic: String,
/// Optional QoS value (not range-checked by the handler in this file).
pub qos: Option<u8>,
/// Schema for the topic payload, passed through as arbitrary JSON.
pub schema: serde_json::Value,
/// Optional retained flag.
pub retained: Option<bool>,
/// Optional human-readable description.
pub description: Option<String>,
/// Optional example payload.
pub example: Option<serde_json::Value>,
}
/// Request body for creating a Kafka contract.
#[derive(Debug, Deserialize)]
pub struct CreateKafkaContractRequest {
/// Identifier under which the contract is registered.
pub contract_id: String,
/// Contract version string.
pub version: String,
/// Topic schemas to add to the contract, in request order.
pub topics: Vec<KafkaTopicSchemaRequest>,
}
/// One Kafka topic definition inside a [`CreateKafkaContractRequest`].
#[derive(Debug, Deserialize)]
pub struct KafkaTopicSchemaRequest {
/// Topic name.
pub topic: String,
/// Optional schema for the record key.
pub key_schema: Option<TopicSchemaRequest>,
/// Schema for the record value (required).
pub value_schema: TopicSchemaRequest,
/// Optional partition count, passed through to the topic schema.
pub partitions: Option<u32>,
/// Optional replication factor, passed through to the topic schema.
pub replication_factor: Option<u16>,
/// Optional human-readable description.
pub description: Option<String>,
/// Optional schema evolution rules for the topic.
pub evolution_rules: Option<EvolutionRulesRequest>,
}
/// Schema description for a Kafka record key or value.
#[derive(Debug, Deserialize)]
pub struct TopicSchemaRequest {
/// Schema format name; the Kafka create handler recognizes `json`,
/// `avro`, and `protobuf`.
pub format: String,
/// The schema itself, passed through as arbitrary JSON.
pub schema: serde_json::Value,
/// Optional schema identifier, passed through unchanged.
pub schema_id: Option<String>,
/// Optional schema version, passed through unchanged.
pub version: Option<String>,
}
/// Schema evolution flags for a Kafka topic.
///
/// The three flags are copied field-for-field into
/// `mockforge_core::contract_drift::EvolutionRules` by the create handler.
#[derive(Debug, Deserialize)]
pub struct EvolutionRulesRequest {
/// Copied to `EvolutionRules::allow_backward_compatible`.
pub allow_backward_compatible: bool,
/// Copied to `EvolutionRules::allow_forward_compatible`.
pub allow_forward_compatible: bool,
/// Copied to `EvolutionRules::require_version_bump`.
pub require_version_bump: bool,
}
/// Response describing a single registered contract.
#[derive(Debug, Serialize)]
pub struct ProtocolContractResponse {
/// Identifier of the contract.
pub contract_id: String,
/// Contract version string.
pub version: String,
/// Lowercase protocol name (for example `grpc` or `kafka`).
pub protocol: String,
/// JSON form of the contract as produced by its `to_json()` method.
pub contract: serde_json::Value,
}
/// Response for the contract listing endpoint.
#[derive(Debug, Serialize)]
pub struct ListContractsResponse {
/// The contracts that matched the (optional) protocol filter.
pub contracts: Vec<ProtocolContractResponse>,
/// Number of entries in `contracts`.
pub total: usize,
}
/// Request body for comparing two contracts already in the registry.
#[derive(Debug, Deserialize)]
pub struct CompareContractsRequest {
/// Registry id of the contract used as the baseline ("old") side.
pub old_contract_id: String,
/// Registry id of the contract used as the candidate ("new") side.
pub new_contract_id: String,
}
/// Request body for validating a message against a registered contract.
#[derive(Debug, Deserialize)]
pub struct ValidateMessageRequest {
/// Id of the contract operation to validate against.
pub operation_id: String,
/// Message payload. `validate_message` first tries to base64-decode a JSON
/// string and falls back to its raw bytes; any other JSON value is
/// serialized to bytes.
pub payload: serde_json::Value,
/// Optional content type, forwarded to the contract request.
pub content_type: Option<String>,
/// Optional metadata; treated as an empty map when absent.
pub metadata: Option<HashMap<String, String>>,
}
/// Lists registered contracts, optionally restricted to a single protocol.
///
/// The `protocol` query parameter accepts `grpc`, `websocket`, `mqtt`, or
/// `kafka`. A missing or unrecognized value applies no filter, so every
/// registered contract is returned.
///
/// A contract whose `to_json()` fails is still listed, with an empty JSON
/// object in its `contract` field, so this handler never returns `Err`.
pub async fn list_contracts(
    State(state): State<ProtocolContractState>,
    Query(params): Query<HashMap<String, String>>,
) -> Result<Json<ListContractsResponse>, (StatusCode, Json<serde_json::Value>)> {
    let guard = state.registry.read().await;

    let wanted = match params.get("protocol").map(String::as_str) {
        Some("grpc") => Some(Protocol::Grpc),
        Some("websocket") => Some(Protocol::WebSocket),
        Some("mqtt") => Some(Protocol::Mqtt),
        Some("kafka") => Some(Protocol::Kafka),
        _ => None,
    };

    let mut contracts: Vec<ProtocolContractResponse> = Vec::new();
    match wanted {
        Some(protocol) => {
            for entry in guard.list_by_protocol(protocol).iter() {
                // Serialization failures degrade to `{}` instead of aborting the listing.
                let body = match entry.to_json() {
                    Ok(value) => value,
                    Err(_) => serde_json::json!({}),
                };
                contracts.push(ProtocolContractResponse {
                    contract_id: entry.contract_id().to_string(),
                    version: entry.version().to_string(),
                    protocol: format!("{:?}", entry.protocol()).to_lowercase(),
                    contract: body,
                });
            }
        }
        None => {
            for entry in guard.list().iter() {
                // Serialization failures degrade to `{}` instead of aborting the listing.
                let body = match entry.to_json() {
                    Ok(value) => value,
                    Err(_) => serde_json::json!({}),
                };
                contracts.push(ProtocolContractResponse {
                    contract_id: entry.contract_id().to_string(),
                    version: entry.version().to_string(),
                    protocol: format!("{:?}", entry.protocol()).to_lowercase(),
                    contract: body,
                });
            }
        }
    }

    let total = contracts.len();
    Ok(Json(ListContractsResponse { contracts, total }))
}
/// Returns a single contract by id.
///
/// # Errors
/// * `404 Not Found` when no contract with `contract_id` is registered.
/// * `500 Internal Server Error` when the contract cannot be serialized.
pub async fn get_contract(
    State(state): State<ProtocolContractState>,
    Path(contract_id): Path<String>,
) -> Result<Json<ProtocolContractResponse>, (StatusCode, Json<serde_json::Value>)> {
    let guard = state.registry.read().await;

    let Some(found) = guard.get(&contract_id) else {
        return Err((
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({
                "error": "Contract not found",
                "contract_id": contract_id
            })),
        ));
    };

    let body = match found.to_json() {
        Ok(value) => value,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": "Failed to serialize contract",
                    "message": e.to_string()
                })),
            ));
        }
    };

    let response = ProtocolContractResponse {
        contract_id: found.contract_id().to_string(),
        version: found.version().to_string(),
        protocol: format!("{:?}", found.protocol()).to_lowercase(),
        contract: body,
    };
    Ok(Json(response))
}
/// Registers a gRPC contract built from a base64-encoded descriptor set.
///
/// # Errors
/// * `400 Bad Request` when `descriptor_set` is not valid base64, or when
///   the contract cannot be built from the decoded bytes.
/// * `500 Internal Server Error` when the contract cannot be read back from
///   the registry after registration, or cannot be serialized.
pub async fn create_grpc_contract(
    State(state): State<ProtocolContractState>,
    Json(request): Json<CreateGrpcContractRequest>,
) -> Result<Json<ProtocolContractResponse>, (StatusCode, Json<serde_json::Value>)> {
    let decoded = match general_purpose::STANDARD.decode(&request.descriptor_set) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "Invalid base64 descriptor set",
                    "message": e.to_string()
                })),
            ));
        }
    };

    let built = GrpcContract::from_descriptor_set(
        request.contract_id.clone(),
        request.version.clone(),
        &decoded,
    );
    let grpc_contract = match built {
        Ok(contract) => contract,
        Err(e) => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "Failed to create gRPC contract",
                    "message": e.to_string()
                })),
            ));
        }
    };

    let mut guard = state.registry.write().await;
    guard.register(Box::new(grpc_contract));

    // Read the contract back so the response reflects what the registry holds.
    let Some(stored) = guard.get(&request.contract_id) else {
        return Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({
                "error": "Failed to retrieve registered contract",
                "contract_id": request.contract_id
            })),
        ));
    };

    let body = match stored.to_json() {
        Ok(value) => value,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": "Failed to serialize contract",
                    "message": e.to_string()
                })),
            ));
        }
    };

    Ok(Json(ProtocolContractResponse {
        contract_id: request.contract_id,
        version: request.version,
        protocol: "grpc".to_string(),
        contract: body,
    }))
}
/// Registers a WebSocket contract assembled from the requested message types.
///
/// Message types are added in request order; the first failure aborts the
/// request before anything is registered.
///
/// # Errors
/// * `400 Bad Request` when a `direction` is not `inbound`, `outbound`, or
///   `bidirectional`, or when the contract rejects a message type.
/// * `500 Internal Server Error` when the contract cannot be read back from
///   the registry after registration, or cannot be serialized.
pub async fn create_websocket_contract(
    State(state): State<ProtocolContractState>,
    Json(request): Json<CreateWebSocketContractRequest>,
) -> Result<Json<ProtocolContractResponse>, (StatusCode, Json<serde_json::Value>)> {
    use mockforge_core::contract_drift::MessageDirection;

    let mut ws_contract =
        WebSocketContract::new(request.contract_id.clone(), request.version.clone());

    for spec in request.message_types {
        let direction = match spec.direction.as_str() {
            "inbound" => MessageDirection::Inbound,
            "outbound" => MessageDirection::Outbound,
            "bidirectional" => MessageDirection::Bidirectional,
            _ => {
                return Err((
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": "Invalid direction",
                        "message": "Direction must be 'inbound', 'outbound', or 'bidirectional'"
                    })),
                ));
            }
        };

        let added = ws_contract.add_message_type(WebSocketMessageType {
            message_type: spec.message_type,
            topic: spec.topic,
            schema: spec.schema,
            direction,
            description: spec.description,
            example: spec.example,
        });
        if let Err(e) = added {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "Failed to add message type",
                    "message": e.to_string()
                })),
            ));
        }
    }

    let mut guard = state.registry.write().await;
    guard.register(Box::new(ws_contract));

    // Read the contract back so the response reflects what the registry holds.
    let Some(stored) = guard.get(&request.contract_id) else {
        return Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({
                "error": "Failed to retrieve registered contract",
                "contract_id": request.contract_id
            })),
        ));
    };

    let body = match stored.to_json() {
        Ok(value) => value,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": "Failed to serialize contract",
                    "message": e.to_string()
                })),
            ));
        }
    };

    Ok(Json(ProtocolContractResponse {
        contract_id: request.contract_id,
        version: request.version,
        protocol: "websocket".to_string(),
        contract: body,
    }))
}
/// Registers an MQTT contract assembled from the requested topic schemas.
///
/// Topics are added in request order; the first failure aborts the request
/// before anything is registered.
///
/// # Errors
/// * `400 Bad Request` when the contract rejects a topic.
/// * `500 Internal Server Error` when the contract cannot be read back from
///   the registry after registration, or cannot be serialized.
pub async fn create_mqtt_contract(
    State(state): State<ProtocolContractState>,
    Json(request): Json<CreateMqttContractRequest>,
) -> Result<Json<ProtocolContractResponse>, (StatusCode, Json<serde_json::Value>)> {
    let mut mqtt_contract =
        MqttContract::new(request.contract_id.clone(), request.version.clone());

    for spec in request.topics {
        let added = mqtt_contract.add_topic(MqttTopicSchema {
            topic: spec.topic,
            qos: spec.qos,
            schema: spec.schema,
            retained: spec.retained,
            description: spec.description,
            example: spec.example,
        });
        if let Err(e) = added {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "Failed to add topic",
                    "message": e.to_string()
                })),
            ));
        }
    }

    let mut guard = state.registry.write().await;
    guard.register(Box::new(mqtt_contract));

    // Read the contract back so the response reflects what the registry holds.
    let Some(stored) = guard.get(&request.contract_id) else {
        return Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({
                "error": "Failed to retrieve registered contract",
                "contract_id": request.contract_id
            })),
        ));
    };

    let body = match stored.to_json() {
        Ok(value) => value,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": "Failed to serialize contract",
                    "message": e.to_string()
                })),
            ));
        }
    };

    Ok(Json(ProtocolContractResponse {
        contract_id: request.contract_id,
        version: request.version,
        protocol: "mqtt".to_string(),
        contract: body,
    }))
}
pub async fn create_kafka_contract(
State(state): State<ProtocolContractState>,
Json(request): Json<CreateKafkaContractRequest>,
) -> Result<Json<ProtocolContractResponse>, (StatusCode, Json<serde_json::Value>)> {
let mut contract = KafkaContract::new(request.contract_id.clone(), request.version.clone());
for topic_req in request.topics {
let format = match topic_req.value_schema.format.as_str() {
"json" => SchemaFormat::Json,
"avro" => SchemaFormat::Avro,
"protobuf" => SchemaFormat::Protobuf,
_ => {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "Invalid schema format",
"message": "Format must be 'json', 'avro', or 'protobuf'"
})),
));
}
};
let value_schema = TopicSchema {
format,
schema: topic_req.value_schema.schema,
schema_id: topic_req.value_schema.schema_id,
version: topic_req.value_schema.version,
};
let key_schema = topic_req.key_schema.map(|ks_req| {
let format = match ks_req.format.as_str() {
"json" => SchemaFormat::Json,
"avro" => SchemaFormat::Avro,
"protobuf" => SchemaFormat::Protobuf,
_ => SchemaFormat::Json, };
TopicSchema {
format,
schema: ks_req.schema,
schema_id: ks_req.schema_id,
version: ks_req.version,
}
});
let evolution_rules = topic_req.evolution_rules.map(|er_req| {
mockforge_core::contract_drift::EvolutionRules {
allow_backward_compatible: er_req.allow_backward_compatible,
allow_forward_compatible: er_req.allow_forward_compatible,
require_version_bump: er_req.require_version_bump,
}
});
let topic_schema = KafkaTopicSchema {
topic: topic_req.topic,
key_schema,
value_schema,
partitions: topic_req.partitions,
replication_factor: topic_req.replication_factor,
description: topic_req.description,
evolution_rules,
};
contract.add_topic(topic_schema).map_err(|e| {
(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "Failed to add topic",
"message": e.to_string()
})),
)
})?;
}
let mut registry = state.registry.write().await;
registry.register(Box::new(contract));
let contract = registry.get(&request.contract_id).ok_or_else(|| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": "Failed to retrieve registered contract",
"contract_id": request.contract_id
})),
)
})?;
let contract_json = contract.to_json().map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": "Failed to serialize contract",
"message": e.to_string()
})),
)
})?;
Ok(Json(ProtocolContractResponse {
contract_id: request.contract_id,
version: request.version,
protocol: "kafka".to_string(),
contract: contract_json,
}))
}
/// Removes a contract from the registry.
///
/// Returns a JSON confirmation carrying the deleted `contract_id`.
///
/// # Errors
/// `404 Not Found` when no contract with `contract_id` is registered.
pub async fn delete_contract(
    State(state): State<ProtocolContractState>,
    Path(contract_id): Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let mut guard = state.registry.write().await;

    if guard.remove(&contract_id).is_none() {
        return Err((
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({
                "error": "Contract not found",
                "contract_id": contract_id
            })),
        ));
    }

    let confirmation = serde_json::json!({
        "message": "Contract deleted",
        "contract_id": contract_id
    });
    Ok(Json(confirmation))
}
/// Compares two registered contracts and returns the diff as JSON.
///
/// Both contracts are looked up in the registry by id and passed to
/// `compare_contracts`. When the state carries both a drift engine and an
/// incident manager, each operation of the new contract is additionally
/// evaluated for drift, optionally enriched with fitness-function results
/// and consumer-impact analysis, and an incident is created when the
/// evaluation's `should_create_incident` flag is set.
///
/// The response contains `matches`, `confidence`, `mismatches`,
/// `recommendations`, `corrections`, and `drift_evaluation` (which is `null`
/// when drift evaluation did not run or the new contract has no operations).
///
/// # Errors
/// * `404 Not Found` when either contract id is not registered.
/// * `400 Bad Request` when `compare_contracts` fails.
///
/// NOTE(review): `drift_evaluation` is reassigned on every loop iteration,
/// so the response only carries the evaluation of the last operation —
/// confirm this is intended.
/// NOTE(review): the registry read guard lives for the whole handler,
/// including every `.await` below.
pub async fn compare_contracts_handler(
State(state): State<ProtocolContractState>,
Json(request): Json<CompareContractsRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
let registry = state.registry.read().await;
// Resolve both sides of the comparison; either one missing is a 404.
let old_contract = registry.get(&request.old_contract_id).ok_or_else(|| {
(
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": "Old contract not found",
"contract_id": request.old_contract_id
})),
)
})?;
let new_contract = registry.get(&request.new_contract_id).ok_or_else(|| {
(
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": "New contract not found",
"contract_id": request.new_contract_id
})),
)
})?;
// The diff is computed once and reused for every operation below.
let diff_result = compare_contracts(old_contract, new_contract).await.map_err(|e| {
(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "Failed to compare contracts",
"message": e.to_string()
})),
)
})?;
let mut drift_evaluation = None;
// Drift evaluation needs BOTH the drift engine and the incident manager.
if let (Some(ref drift_engine), Some(ref incident_manager)) =
(&state.drift_engine, &state.incident_manager)
{
let protocol = new_contract.protocol();
let operations = new_contract.operations();
for operation in operations {
let operation_id = &operation.id;
// Derive the (endpoint, method) pair handed to the drift engine. For
// non-HTTP operations the protocol name stands in for the method.
let (endpoint, method) = match &operation.operation_type {
mockforge_core::contract_drift::protocol_contracts::OperationType::HttpEndpoint { path, method } => {
(path.clone(), method.clone())
}
mockforge_core::contract_drift::protocol_contracts::OperationType::GrpcMethod { service, method } => {
(format!("{}.{}", service, method), "grpc".to_string())
}
mockforge_core::contract_drift::protocol_contracts::OperationType::WebSocketMessage { message_type, .. } => {
(message_type.clone(), "websocket".to_string())
}
mockforge_core::contract_drift::protocol_contracts::OperationType::MqttTopic { topic, qos: _ } => {
(topic.clone(), "mqtt".to_string())
}
mockforge_core::contract_drift::protocol_contracts::OperationType::KafkaTopic { topic, key_schema: _, value_schema: _ } => {
(topic.clone(), "kafka".to_string())
}
};
let drift_result = drift_engine.evaluate(&diff_result, &endpoint, &method);
// Work on a copy that is enriched with fitness and consumer-impact data.
let mut drift_result_with_fitness = drift_result.clone();
if let Some(ref fitness_registry) = state.fitness_registry {
let guard = fitness_registry.read().await;
// A failed fitness evaluation (Err) is ignored; the drift result is left as is.
if let Ok(results) = guard.evaluate_all_protocol(
Some(old_contract),
new_contract,
&diff_result,
operation_id,
None, None, ) {
drift_result_with_fitness.fitness_test_results = results;
// Any failing fitness test forces incident creation.
if drift_result_with_fitness.fitness_test_results.iter().any(|r| !r.passed) {
drift_result_with_fitness.should_create_incident = true;
}
}
}
if let Some(ref consumer_analyzer) = state.consumer_analyzer {
let guard = consumer_analyzer.read().await;
let impact =
guard.analyze_impact_with_operation_id(&endpoint, &method, Some(operation_id));
if let Some(impact) = impact {
drift_result_with_fitness.consumer_impact = Some(impact);
}
}
if drift_result_with_fitness.should_create_incident {
// Breaking changes make this a BreakingChange incident; otherwise it
// is recorded as ThresholdExceeded.
let incident_type = if drift_result_with_fitness.breaking_changes > 0 {
mockforge_core::incidents::types::IncidentType::BreakingChange
} else {
mockforge_core::incidents::types::IncidentType::ThresholdExceeded
};
// Severity: High for breaking, Medium for potentially breaking, else Low.
let severity = if drift_result_with_fitness.breaking_changes > 0 {
mockforge_core::incidents::types::IncidentSeverity::High
} else if drift_result_with_fitness.potentially_breaking_changes > 0 {
mockforge_core::incidents::types::IncidentSeverity::Medium
} else {
mockforge_core::incidents::types::IncidentSeverity::Low
};
let details = serde_json::json!({
"breaking_changes": drift_result_with_fitness.breaking_changes,
"potentially_breaking_changes": drift_result_with_fitness.potentially_breaking_changes,
"non_breaking_changes": drift_result_with_fitness.non_breaking_changes,
"budget_exceeded": drift_result_with_fitness.budget_exceeded,
"operation_id": operation_id,
"operation_type": format!("{:?}", operation.operation_type),
});
// Before/after samples identify the two contracts; the "after" sample
// also carries the full mismatch list from the diff.
let before_sample = Some(serde_json::json!({
"contract_id": old_contract.contract_id(),
"version": old_contract.version(),
"protocol": format!("{:?}", old_contract.protocol()),
"operation_id": operation_id,
}));
let after_sample = Some(serde_json::json!({
"contract_id": new_contract.contract_id(),
"version": new_contract.version(),
"protocol": format!("{:?}", new_contract.protocol()),
"operation_id": operation_id,
"mismatches": diff_result.mismatches,
}));
// The value returned by incident creation is bound to `_incident` and
// not inspected or included in the response.
let _incident = incident_manager
.create_incident_with_samples(
endpoint.clone(),
method.clone(),
incident_type,
severity,
details,
None, None, None, None, before_sample,
after_sample,
Some(drift_result_with_fitness.fitness_test_results.clone()),
drift_result_with_fitness.consumer_impact.clone(),
Some(protocol),
)
.await;
}
// NOTE(review): overwritten on each iteration — only the last operation's
// evaluation reaches the response.
drift_evaluation = Some(serde_json::json!({
"operation_id": operation_id,
"endpoint": endpoint,
"method": method,
"budget_exceeded": drift_result_with_fitness.budget_exceeded,
"breaking_changes": drift_result_with_fitness.breaking_changes,
"fitness_test_results": drift_result_with_fitness.fitness_test_results,
"consumer_impact": drift_result_with_fitness.consumer_impact,
}));
}
}
Ok(Json(serde_json::json!({
"matches": diff_result.matches,
"confidence": diff_result.confidence,
"mismatches": diff_result.mismatches,
"recommendations": diff_result.recommendations,
"corrections": diff_result.corrections,
"drift_evaluation": drift_evaluation,
})))
}
/// Validates a message payload against one operation of a registered contract.
///
/// Payload handling: a JSON string is first tried as standard base64; if
/// that fails, its raw bytes are used instead. Any other JSON value is
/// serialized to bytes. Note that a plain-text string which happens to be
/// valid base64 is therefore decoded rather than used verbatim.
///
/// The response carries `valid`, `errors`, and `warnings` from the
/// contract's validation result.
///
/// # Errors
/// * `404 Not Found` when no contract with `contract_id` is registered.
/// * `400 Bad Request` when the payload cannot be serialized, or when the
///   contract's `validate` call itself returns an error.
pub async fn validate_message(
    State(state): State<ProtocolContractState>,
    Path(contract_id): Path<String>,
    Json(request): Json<ValidateMessageRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let guard = state.registry.read().await;

    let Some(target) = guard.get(&contract_id) else {
        return Err((
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({
                "error": "Contract not found",
                "contract_id": contract_id
            })),
        ));
    };

    let payload_bytes = match request.payload {
        // Strings: base64 if it decodes, otherwise the raw bytes of the text.
        serde_json::Value::String(text) => match general_purpose::STANDARD.decode(&text) {
            Ok(decoded) => decoded,
            Err(_) => text.into_bytes(),
        },
        other => match serde_json::to_vec(&other) {
            Ok(bytes) => bytes,
            Err(e) => {
                return Err((
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": "Failed to serialize payload",
                        "message": e.to_string()
                    })),
                ));
            }
        },
    };

    let contract_request = mockforge_core::contract_drift::protocol_contracts::ContractRequest {
        protocol: target.protocol(),
        operation_id: request.operation_id.clone(),
        payload: payload_bytes,
        content_type: request.content_type,
        metadata: request.metadata.unwrap_or_default(),
    };

    let outcome = match target.validate(&request.operation_id, &contract_request).await {
        Ok(result) => result,
        Err(e) => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": "Validation failed",
                    "message": e.to_string()
                })),
            ));
        }
    };

    Ok(Json(serde_json::json!({
        "valid": outcome.valid,
        "errors": outcome.errors,
        "warnings": outcome.warnings,
    })))
}
/// Builds the router exposing every protocol-contract endpoint.
///
/// Routes (relative to wherever the router is mounted):
/// * `GET /` — list contracts
/// * `GET /{contract_id}`, `DELETE /{contract_id}` — fetch / delete one contract
/// * `POST /grpc`, `/websocket`, `/mqtt`, `/kafka` — create a contract
/// * `POST /compare` — compare two contracts
/// * `POST /{contract_id}/validate` — validate a message
pub fn protocol_contracts_router(state: ProtocolContractState) -> axum::Router {
    use axum::routing::{get, post};

    axum::Router::new()
        .route("/", get(list_contracts))
        // GET and DELETE share one path, so they are chained on one method router.
        .route(
            "/{contract_id}",
            get(get_contract).delete(delete_contract),
        )
        .route("/grpc", post(create_grpc_contract))
        .route("/websocket", post(create_websocket_contract))
        .route("/mqtt", post(create_mqtt_contract))
        .route("/kafka", post(create_kafka_contract))
        .route("/compare", post(compare_contracts_handler))
        .route("/{contract_id}/validate", post(validate_message))
        .with_state(state)
}