use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::Json,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use mockforge_core::consumer_contracts::{
ConsumerBreakingChangeDetector, ConsumerIdentifier, ConsumerRegistry, ConsumerType,
ConsumerUsage, ConsumerViolation, UsageRecorder,
};
/// Shared state handed to every consumer-contract handler through axum's `State` extractor.
///
/// Every field is wrapped in an `Arc`, so cloning this struct only bumps reference counts.
#[derive(Clone)]
pub struct ConsumerContractsState {
    /// Registry used to register (`get_or_create`), list (`list_all`) and look up
    /// (`get_by_id`) consumers.
    pub registry: Arc<ConsumerRegistry>,
    /// Supplies the per-consumer usage returned by `GET /api/v1/consumers/{id}/usage`.
    pub usage_recorder: Arc<UsageRecorder>,
    /// Turns a contract diff into consumer violations when violations are recorded.
    pub detector: Arc<ConsumerBreakingChangeDetector>,
    /// In-memory violation history keyed by consumer id. Handlers in this file only append
    /// to it; nothing here removes entries, so it grows for the lifetime of the state.
    pub violations: Arc<RwLock<HashMap<String, Vec<ConsumerViolation>>>>,
}
/// JSON body accepted by `POST /api/v1/consumers` (see `register_consumer`).
#[derive(Debug, Deserialize, Serialize)]
pub struct RegisterConsumerRequest {
    /// Human-readable name passed to the registry.
    pub name: String,
    /// Identifier kind. `register_consumer` accepts `"workspace"`, `"custom"`, `"api_key"`
    /// or `"auth_token"` and rejects anything else with 400 Bad Request.
    pub consumer_type: String,
    /// Raw identifier value, wrapped according to `consumer_type`.
    pub identifier: String,
    /// Optional workspace id stored on the consumer; `list_consumers` can filter on it.
    pub workspace_id: Option<String>,
    /// Arbitrary extra metadata.
    /// NOTE(review): `register_consumer` in this file never reads this field — confirm
    /// whether it is meant to be persisted.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// JSON representation of a single consumer, returned by the consumer endpoints.
#[derive(Debug, Serialize)]
pub struct ConsumerResponse {
    /// Consumer id as returned by the registry; this is the `{id}` used by the
    /// per-consumer routes.
    pub id: String,
    /// Human-readable consumer name.
    pub name: String,
    /// `Debug` rendering (`{:?}`) of the consumer's `ConsumerType`.
    /// NOTE(review): presumably the variant name (e.g. `ApiKey`), which would not match the
    /// snake_case values (e.g. `api_key`) accepted on input — confirm clients can round-trip it.
    pub consumer_type: String,
    /// Raw identifier value.
    pub identifier: String,
    /// Workspace id stored on the consumer, if any.
    pub workspace_id: Option<String>,
    /// Creation time as reported by the registry.
    /// NOTE(review): presumably a Unix timestamp — the unit is not visible here, confirm.
    pub created_at: i64,
}
/// Typed form of the query parameters understood by `GET /api/v1/consumers`.
///
/// NOTE(review): `list_consumers` in this file reads its query string as a raw
/// `HashMap<String, String>` instead of deserializing into this struct — confirm whether
/// this type is used elsewhere or should be wired into the handler.
#[derive(Debug, Deserialize)]
pub struct ListConsumersRequest {
    /// Only return consumers whose `workspace_id` equals this value.
    pub workspace_id: Option<String>,
    /// Only return consumers of this type (`"workspace"`, `"custom"`, `"api_key"`,
    /// `"auth_token"`).
    pub consumer_type: Option<String>,
    /// Maximum number of consumers to return (the handler defaults to 100).
    pub limit: Option<usize>,
    /// Number of matching consumers to skip (the handler defaults to 0).
    pub offset: Option<usize>,
}
/// Response body of `GET /api/v1/consumers`.
#[derive(Debug, Serialize)]
pub struct ListConsumersResponse {
    /// The requested page of consumers, after filtering and offset/limit.
    pub consumers: Vec<ConsumerResponse>,
    /// Number of consumers matching the filters *before* offset/limit were applied.
    pub total: usize,
}
/// Response body of `GET /api/v1/consumers/{id}/usage`.
#[derive(Debug, Serialize)]
pub struct ConsumerUsageResponse {
    /// The consumer id taken from the request path.
    pub consumer_id: String,
    /// Usage entries returned by the usage recorder for this consumer.
    pub usage: Vec<ConsumerUsage>,
}
/// Response body of both `GET` and `POST /api/v1/consumers/{id}/violations`.
#[derive(Debug, Serialize)]
pub struct ConsumerViolationsResponse {
    /// The consumer id taken from the request path.
    pub consumer_id: String,
    /// Every violation currently stored for this consumer (the full history, not just the
    /// ones added by the current request).
    pub violations: Vec<ConsumerViolation>,
}
pub async fn register_consumer(
State(state): State<ConsumerContractsState>,
Json(request): Json<RegisterConsumerRequest>,
) -> Result<Json<ConsumerResponse>, StatusCode> {
let consumer_type = match request.consumer_type.as_str() {
"workspace" => ConsumerType::Workspace,
"custom" => ConsumerType::Custom,
"api_key" => ConsumerType::ApiKey,
"auth_token" => ConsumerType::AuthToken,
_ => return Err(StatusCode::BAD_REQUEST),
};
let identifier = match consumer_type {
ConsumerType::Workspace => ConsumerIdentifier::workspace(request.identifier),
ConsumerType::Custom => ConsumerIdentifier::custom(request.identifier),
ConsumerType::ApiKey => ConsumerIdentifier::api_key(request.identifier),
ConsumerType::AuthToken => ConsumerIdentifier::auth_token(request.identifier),
};
let consumer = state
.registry
.get_or_create(identifier, request.name.clone(), request.workspace_id.clone())
.await;
Ok(Json(ConsumerResponse {
id: consumer.id,
name: consumer.name,
consumer_type: format!("{:?}", consumer.identifier.consumer_type),
identifier: consumer.identifier.value,
workspace_id: consumer.workspace_id,
created_at: consumer.created_at,
}))
}
/// Lists registered consumers, with optional filtering and offset/limit pagination.
///
/// `GET /api/v1/consumers`. Recognised query parameters:
/// - `workspace_id`: keep only consumers whose `workspace_id` equals this value.
/// - `consumer_type`: one of `"workspace"`, `"custom"`, `"api_key"`, `"auth_token"`.
/// - `offset`: number of matching consumers to skip (default `0`).
/// - `limit`: maximum number of consumers to return (default `100`).
///
/// `total` in the response counts matches *before* pagination. `offset`/`limit` values that
/// do not parse as `usize` silently fall back to their defaults.
///
/// # Errors
///
/// Returns [`StatusCode::BAD_REQUEST`] if `consumer_type` is present but unrecognised.
pub async fn list_consumers(
    State(state): State<ConsumerContractsState>,
    Query(params): Query<HashMap<String, String>>,
) -> Result<Json<ListConsumersResponse>, StatusCode> {
    // Validate the type filter before touching the registry so malformed requests are
    // rejected without listing every consumer first.
    let type_filter = match params.get("consumer_type").map(String::as_str) {
        None => None,
        Some("workspace") => Some(ConsumerType::Workspace),
        Some("custom") => Some(ConsumerType::Custom),
        Some("api_key") => Some(ConsumerType::ApiKey),
        Some("auth_token") => Some(ConsumerType::AuthToken),
        Some(_) => return Err(StatusCode::BAD_REQUEST),
    };
    let workspace_filter = params.get("workspace_id").map(String::as_str);
    // Lenient parsing: malformed values fall back to the defaults rather than erroring.
    let offset: usize = params.get("offset").and_then(|s| s.parse().ok()).unwrap_or(0);
    let limit: usize = params.get("limit").and_then(|s| s.parse().ok()).unwrap_or(100);

    let mut consumers = state.registry.list_all().await;
    if let Some(ws) = workspace_filter {
        // A consumer without a workspace never matches a workspace filter.
        consumers.retain(|c| c.workspace_id.as_deref() == Some(ws));
    }
    if let Some(wanted) = type_filter {
        consumers.retain(|c| c.identifier.consumer_type == wanted);
    }
    let total = consumers.len();

    // Paginate and convert in a single pass; no intermediate Vec is needed.
    let page: Vec<ConsumerResponse> = consumers
        .into_iter()
        .skip(offset)
        .take(limit)
        .map(|c| ConsumerResponse {
            id: c.id,
            name: c.name,
            consumer_type: format!("{:?}", c.identifier.consumer_type),
            identifier: c.identifier.value,
            workspace_id: c.workspace_id,
            created_at: c.created_at,
        })
        .collect();

    Ok(Json(ListConsumersResponse {
        consumers: page,
        total,
    }))
}
/// Fetches one consumer by id.
///
/// `GET /api/v1/consumers/{id}`.
///
/// # Errors
///
/// Returns [`StatusCode::NOT_FOUND`] when the registry has no consumer with this id.
pub async fn get_consumer(
    State(state): State<ConsumerContractsState>,
    Path(id): Path<String>,
) -> Result<Json<ConsumerResponse>, StatusCode> {
    let Some(found) = state.registry.get_by_id(&id).await else {
        return Err(StatusCode::NOT_FOUND);
    };

    let response = ConsumerResponse {
        id: found.id,
        name: found.name,
        consumer_type: format!("{:?}", found.identifier.consumer_type),
        identifier: found.identifier.value,
        workspace_id: found.workspace_id,
        created_at: found.created_at,
    };
    Ok(Json(response))
}
/// Returns the recorded usage for one consumer.
///
/// `GET /api/v1/consumers/{id}/usage`.
///
/// # Errors
///
/// Returns [`StatusCode::NOT_FOUND`] when the registry has no consumer with this id.
pub async fn get_consumer_usage(
    State(state): State<ConsumerContractsState>,
    Path(id): Path<String>,
) -> Result<Json<ConsumerUsageResponse>, StatusCode> {
    // Existence check only; the consumer record itself is not needed.
    if state.registry.get_by_id(&id).await.is_none() {
        return Err(StatusCode::NOT_FOUND);
    }

    let recorded = state.usage_recorder.get_usage(&id).await;
    Ok(Json(ConsumerUsageResponse {
        consumer_id: id,
        usage: recorded,
    }))
}
/// Returns every violation stored for one consumer (an empty list if none were recorded).
///
/// `GET /api/v1/consumers/{id}/violations`.
///
/// # Errors
///
/// Returns [`StatusCode::NOT_FOUND`] when the registry has no consumer with this id.
pub async fn get_consumer_violations(
    State(state): State<ConsumerContractsState>,
    Path(id): Path<String>,
) -> Result<Json<ConsumerViolationsResponse>, StatusCode> {
    if state.registry.get_by_id(&id).await.is_none() {
        return Err(StatusCode::NOT_FOUND);
    }

    // Copy the stored list out under the read lock, then release the guard explicitly
    // before building the response.
    let store = state.violations.read().await;
    let history = store.get(&id).cloned().unwrap_or_default();
    drop(store);

    Ok(Json(ConsumerViolationsResponse {
        consumer_id: id,
        violations: history,
    }))
}
/// JSON body accepted by `POST /api/v1/consumers/{id}/violations`.
#[derive(Debug, Deserialize)]
pub struct RecordViolationsRequest {
    /// Endpoint passed to the detector (presumably the request path the diff applies to).
    pub endpoint: String,
    /// Method passed to the detector (presumably the HTTP method of `endpoint`).
    pub method: String,
    /// Contract diff from which the detector derives violations.
    pub diff_result: mockforge_core::ai_contract_diff::ContractDiffResult,
    /// Optional incident id forwarded unchanged to the detector.
    pub incident_id: Option<String>,
}
pub async fn record_consumer_violations(
State(state): State<ConsumerContractsState>,
Path(id): Path<String>,
Json(request): Json<RecordViolationsRequest>,
) -> Result<Json<ConsumerViolationsResponse>, StatusCode> {
state.registry.get_by_id(&id).await.ok_or(StatusCode::NOT_FOUND)?;
let new_violations = state
.detector
.detect_violations(
&id,
&request.endpoint,
&request.method,
&request.diff_result,
request.incident_id,
)
.await;
let mut violations_store = state.violations.write().await;
let entry = violations_store.entry(id.clone()).or_default();
entry.extend(new_violations);
let all_violations = entry.clone();
Ok(Json(ConsumerViolationsResponse {
consumer_id: id,
violations: all_violations,
}))
}
/// Builds the router for the consumer-contract API, with `state` attached to every handler.
///
/// Routes:
/// - `GET  /api/v1/consumers`: `list_consumers`
/// - `POST /api/v1/consumers`: `register_consumer`
/// - `GET  /api/v1/consumers/{id}`: `get_consumer`
/// - `GET  /api/v1/consumers/{id}/usage`: `get_consumer_usage`
/// - `GET  /api/v1/consumers/{id}/violations`: `get_consumer_violations`
/// - `POST /api/v1/consumers/{id}/violations`: `record_consumer_violations`
pub fn consumer_contracts_router(state: ConsumerContractsState) -> axum::Router {
    use axum::routing::{get, post};

    axum::Router::new()
        // One MethodRouter per path; methods for the same path are chained together.
        .route(
            "/api/v1/consumers",
            get(list_consumers).post(register_consumer),
        )
        .route("/api/v1/consumers/{id}", get(get_consumer))
        .route("/api/v1/consumers/{id}/usage", get(get_consumer_usage))
        .route(
            "/api/v1/consumers/{id}/violations",
            get(get_consumer_violations).post(record_consumer_violations),
        )
        .with_state(state)
}