1use axum::{
6 extract::{Path, Query, State},
7 http::StatusCode,
8 response::Json,
9};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15use mockforge_core::consumer_contracts::{
16 ConsumerBreakingChangeDetector, ConsumerIdentifier, ConsumerRegistry, ConsumerType,
17 ConsumerUsage, ConsumerViolation, UsageRecorder,
18};
19
20#[derive(Clone)]
22pub struct ConsumerContractsState {
23 pub registry: Arc<ConsumerRegistry>,
25 pub usage_recorder: Arc<UsageRecorder>,
27 pub detector: Arc<ConsumerBreakingChangeDetector>,
29 pub violations: Arc<RwLock<HashMap<String, Vec<ConsumerViolation>>>>,
31}
32
33#[derive(Debug, Deserialize, Serialize)]
35pub struct RegisterConsumerRequest {
36 pub name: String,
38 pub consumer_type: String,
40 pub identifier: String,
42 pub workspace_id: Option<String>,
44 pub metadata: Option<HashMap<String, serde_json::Value>>,
46}
47
48#[derive(Debug, Serialize)]
50pub struct ConsumerResponse {
51 pub id: String,
53 pub name: String,
55 pub consumer_type: String,
57 pub identifier: String,
59 pub workspace_id: Option<String>,
61 pub created_at: i64,
63}
64
65#[derive(Debug, Deserialize)]
67pub struct ListConsumersRequest {
68 pub workspace_id: Option<String>,
70 pub consumer_type: Option<String>,
72 pub limit: Option<usize>,
74 pub offset: Option<usize>,
76}
77
78#[derive(Debug, Serialize)]
80pub struct ListConsumersResponse {
81 pub consumers: Vec<ConsumerResponse>,
83 pub total: usize,
85}
86
87#[derive(Debug, Serialize)]
89pub struct ConsumerUsageResponse {
90 pub consumer_id: String,
92 pub usage: Vec<ConsumerUsage>,
94}
95
96#[derive(Debug, Serialize)]
98pub struct ConsumerViolationsResponse {
99 pub consumer_id: String,
101 pub violations: Vec<ConsumerViolation>,
103}
104
105pub async fn register_consumer(
109 State(state): State<ConsumerContractsState>,
110 Json(request): Json<RegisterConsumerRequest>,
111) -> Result<Json<ConsumerResponse>, StatusCode> {
112 let consumer_type = match request.consumer_type.as_str() {
113 "workspace" => ConsumerType::Workspace,
114 "custom" => ConsumerType::Custom,
115 "api_key" => ConsumerType::ApiKey,
116 "auth_token" => ConsumerType::AuthToken,
117 _ => return Err(StatusCode::BAD_REQUEST),
118 };
119
120 let identifier = match consumer_type {
121 ConsumerType::Workspace => ConsumerIdentifier::workspace(request.identifier),
122 ConsumerType::Custom => ConsumerIdentifier::custom(request.identifier),
123 ConsumerType::ApiKey => ConsumerIdentifier::api_key(request.identifier),
124 ConsumerType::AuthToken => ConsumerIdentifier::auth_token(request.identifier),
125 };
126
127 let consumer = state
128 .registry
129 .get_or_create(identifier, request.name.clone(), request.workspace_id.clone())
130 .await;
131
132 Ok(Json(ConsumerResponse {
133 id: consumer.id,
134 name: consumer.name,
135 consumer_type: format!("{:?}", consumer.identifier.consumer_type),
136 identifier: consumer.identifier.value,
137 workspace_id: consumer.workspace_id,
138 created_at: consumer.created_at,
139 }))
140}
141
142pub async fn list_consumers(
146 State(state): State<ConsumerContractsState>,
147 Query(params): Query<HashMap<String, String>>,
148) -> Result<Json<ListConsumersResponse>, StatusCode> {
149 let mut consumers = state.registry.list_all().await;
150
151 if let Some(workspace_id) = params.get("workspace_id") {
153 consumers.retain(|c| c.workspace_id.as_ref().map(|w| w == workspace_id).unwrap_or(false));
154 }
155
156 if let Some(consumer_type_str) = params.get("consumer_type") {
157 let consumer_type = match consumer_type_str.as_str() {
158 "workspace" => ConsumerType::Workspace,
159 "custom" => ConsumerType::Custom,
160 "api_key" => ConsumerType::ApiKey,
161 "auth_token" => ConsumerType::AuthToken,
162 _ => return Err(StatusCode::BAD_REQUEST),
163 };
164 consumers.retain(|c| c.identifier.consumer_type == consumer_type);
165 }
166
167 let total = consumers.len();
168
169 let offset = params.get("offset").and_then(|s| s.parse().ok()).unwrap_or(0);
171 let limit = params.get("limit").and_then(|s| s.parse().ok()).unwrap_or(100);
172
173 consumers = consumers.into_iter().skip(offset).take(limit).collect();
174
175 let consumer_responses: Vec<ConsumerResponse> = consumers
176 .into_iter()
177 .map(|c| ConsumerResponse {
178 id: c.id,
179 name: c.name,
180 consumer_type: format!("{:?}", c.identifier.consumer_type),
181 identifier: c.identifier.value,
182 workspace_id: c.workspace_id,
183 created_at: c.created_at,
184 })
185 .collect();
186
187 Ok(Json(ListConsumersResponse {
188 consumers: consumer_responses,
189 total,
190 }))
191}
192
193pub async fn get_consumer(
197 State(state): State<ConsumerContractsState>,
198 Path(id): Path<String>,
199) -> Result<Json<ConsumerResponse>, StatusCode> {
200 let consumer = state.registry.get_by_id(&id).await.ok_or(StatusCode::NOT_FOUND)?;
201
202 Ok(Json(ConsumerResponse {
203 id: consumer.id,
204 name: consumer.name,
205 consumer_type: format!("{:?}", consumer.identifier.consumer_type),
206 identifier: consumer.identifier.value,
207 workspace_id: consumer.workspace_id,
208 created_at: consumer.created_at,
209 }))
210}
211
212pub async fn get_consumer_usage(
216 State(state): State<ConsumerContractsState>,
217 Path(id): Path<String>,
218) -> Result<Json<ConsumerUsageResponse>, StatusCode> {
219 state.registry.get_by_id(&id).await.ok_or(StatusCode::NOT_FOUND)?;
221
222 let usage = state.usage_recorder.get_usage(&id).await;
223
224 Ok(Json(ConsumerUsageResponse {
225 consumer_id: id,
226 usage,
227 }))
228}
229
230pub async fn get_consumer_violations(
234 State(state): State<ConsumerContractsState>,
235 Path(id): Path<String>,
236) -> Result<Json<ConsumerViolationsResponse>, StatusCode> {
237 state.registry.get_by_id(&id).await.ok_or(StatusCode::NOT_FOUND)?;
239
240 let violations_store = state.violations.read().await;
241 let violations = violations_store.get(&id).cloned().unwrap_or_default();
242
243 Ok(Json(ConsumerViolationsResponse {
244 consumer_id: id,
245 violations,
246 }))
247}
248
249#[derive(Debug, Deserialize)]
251pub struct RecordViolationsRequest {
252 pub endpoint: String,
254 pub method: String,
256 pub diff_result: mockforge_core::ai_contract_diff::ContractDiffResult,
258 pub incident_id: Option<String>,
260}
261
262pub async fn record_consumer_violations(
266 State(state): State<ConsumerContractsState>,
267 Path(id): Path<String>,
268 Json(request): Json<RecordViolationsRequest>,
269) -> Result<Json<ConsumerViolationsResponse>, StatusCode> {
270 state.registry.get_by_id(&id).await.ok_or(StatusCode::NOT_FOUND)?;
272
273 let new_violations = state
275 .detector
276 .detect_violations(
277 &id,
278 &request.endpoint,
279 &request.method,
280 &request.diff_result,
281 request.incident_id,
282 )
283 .await;
284
285 let mut violations_store = state.violations.write().await;
287 let entry = violations_store.entry(id.clone()).or_default();
288 entry.extend(new_violations);
289 let all_violations = entry.clone();
290
291 Ok(Json(ConsumerViolationsResponse {
292 consumer_id: id,
293 violations: all_violations,
294 }))
295}
296
297pub fn consumer_contracts_router(state: ConsumerContractsState) -> axum::Router {
299 use axum::routing::{get, post};
300
301 axum::Router::new()
302 .route("/api/v1/consumers", post(register_consumer))
303 .route("/api/v1/consumers", get(list_consumers))
304 .route("/api/v1/consumers/{id}", get(get_consumer))
305 .route("/api/v1/consumers/{id}/usage", get(get_consumer_usage))
306 .route(
307 "/api/v1/consumers/{id}/violations",
308 get(get_consumer_violations).post(record_consumer_violations),
309 )
310 .with_state(state)
311}