mockforge_http/middleware/
drift_tracking.rs1#![allow(deprecated)]
8
9use axum::{body::Body, extract::Request, http::Response, middleware::Next};
10use mockforge_contracts::consumer_contracts::{ConsumerBreakingChangeDetector, UsageRecorder};
11use mockforge_core::{
12 ai_contract_diff::ContractDiffAnalyzer,
13 contract_drift::DriftBudgetEngine,
14 incidents::{IncidentManager, IncidentSeverity, IncidentType},
15 openapi::OpenApiSpec,
16};
17use mockforge_foundation::contract_drift_types::DriftResult;
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tracing::{debug, warn};
22
/// Upper bound, in bytes, on request bodies that drift tracking will buffer for analysis.
///
/// Configured in whole mebibytes via the `MOCKFORGE_DRIFT_MAX_BODY_MB` environment variable;
/// surrounding whitespace in the value is ignored. Unset, unparsable, or zero values fall back
/// to the 10 MiB default — a zero cap would make every non-empty body sent without a
/// `Content-Length` fail drift-tracking capture. The MiB-to-bytes conversion saturates at
/// `usize::MAX` instead of overflowing.
///
/// The variable is read on every call (nothing is cached).
fn max_drift_body_size() -> usize {
    const DEFAULT_MB: usize = 10;
    const BYTES_PER_MB: usize = 1024 * 1024;

    let megabytes = std::env::var("MOCKFORGE_DRIFT_MAX_BODY_MB")
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        // A zero cap is never useful; treat it like an invalid value.
        .filter(|&mb| mb > 0)
        .unwrap_or(DEFAULT_MB);
    megabytes.saturating_mul(BYTES_PER_MB)
}
37
38#[derive(Clone)]
40pub struct DriftTrackingState {
41 pub diff_analyzer: Option<Arc<ContractDiffAnalyzer>>,
43 pub spec: Option<Arc<OpenApiSpec>>,
45 pub drift_engine: Arc<DriftBudgetEngine>,
47 pub incident_manager: Arc<IncidentManager>,
49 pub usage_recorder: Arc<UsageRecorder>,
51 pub consumer_detector: Arc<ConsumerBreakingChangeDetector>,
53 pub enabled: bool,
55}
56
57pub async fn drift_tracking_middleware_with_extensions(
62 req: Request<Body>,
63 next: Next,
64) -> Response<Body> {
65 let state = req.extensions().get::<DriftTrackingState>().cloned();
67
68 let state = if let Some(state) = state {
69 state
70 } else {
71 return next.run(req).await;
73 };
74
75 if !state.enabled {
76 return next.run(req).await;
77 }
78
79 let method = req.method().to_string();
80 let path = req.uri().path().to_string();
81 let max_body = max_drift_body_size();
82
83 let content_length = req
87 .headers()
88 .get(http::header::CONTENT_LENGTH)
89 .and_then(|v| v.to_str().ok())
90 .and_then(|s| s.parse::<usize>().ok());
91 if let Some(len) = content_length {
92 if len > max_body {
93 debug!(
94 "drift_tracking: skipping capture for {} {} — content-length {} > cap {}",
95 method, path, len, max_body
96 );
97 return next.run(req).await;
98 }
99 }
100
101 let consumer_id = extract_consumer_id(&req);
103
104 let captured_headers = extract_headers_for_capture(&req);
106
107 let (parts, body) = req.into_parts();
109 let body_bytes = match axum::body::to_bytes(body, max_body).await {
110 Ok(b) => b,
111 Err(_) => {
112 return Response::builder()
116 .status(http::StatusCode::PAYLOAD_TOO_LARGE)
117 .header(
118 http::header::CONTENT_TYPE,
119 "application/json",
120 )
121 .body(Body::from(format!(
122 r#"{{"error":"PAYLOAD_TOO_LARGE","message":"chunked request body exceeded drift_tracking capture cap (~{} MiB); raise MOCKFORGE_DRIFT_MAX_BODY_MB or send Content-Length"}}"#,
123 max_body / (1024 * 1024)
124 )))
125 .unwrap_or_else(|_| Response::new(Body::from("payload too large")));
126 }
127 };
128
129 let captured_body = if !body_bytes.is_empty() {
131 serde_json::from_slice::<Value>(&body_bytes).ok()
132 } else {
133 None
134 };
135
136 let req = Request::from_parts(parts, Body::from(body_bytes));
138
139 let response = next.run(req).await;
141
142 let response_body = extract_response_body(&response);
144
145 if let Some(ref consumer_id) = consumer_id {
147 if let Some(body) = &response_body {
148 state.usage_recorder.record_usage(consumer_id, &path, &method, Some(body)).await;
149 }
150 }
151
152 if let (Some(ref analyzer), Some(ref spec)) = (&state.diff_analyzer, &state.spec) {
154 let mut captured = mockforge_core::ai_contract_diff::CapturedRequest::new(
156 &method,
157 &path,
158 "drift_tracking",
159 )
160 .with_headers(captured_headers)
161 .with_response(response.status().as_u16(), response_body.clone());
162
163 if let Some(body_value) = captured_body {
164 captured = captured.with_body(body_value);
165 }
166
167 match analyzer.analyze(&captured, spec).await {
169 Ok(diff_result) => {
170 let drift_result = state.drift_engine.evaluate(&diff_result, &path, &method);
172
173 mockforge_observability::get_global_registry().record_drift_evaluation(
180 mockforge_observability::DriftEvaluationSample {
181 workspace_id: "",
182 endpoint: &path,
183 method: &method,
184 total: drift_result.metrics.total_changes,
185 breaking: drift_result.breaking_changes,
186 potentially_breaking: drift_result.potentially_breaking_changes,
187 budget_exceeded: drift_result.budget_exceeded,
188 },
189 );
190
191 let total = i64::from(drift_result.metrics.total_changes);
202 let drifting = i64::from(drift_result.breaking_changes)
203 + i64::from(drift_result.potentially_breaking_changes);
204 mockforge_analytics::record_drift_percentage_async(
205 String::new(), None,
207 total,
208 drifting,
209 );
210
211 mockforge_core::pillar_tracking::record_contracts_usage(
213 None, None,
215 "drift_detection",
216 serde_json::json!({
217 "endpoint": path,
218 "method": method,
219 "breaking_changes": drift_result.breaking_changes,
220 "non_breaking_changes": drift_result.non_breaking_changes,
221 "incident": drift_result.should_create_incident
222 }),
223 )
224 .await;
225
226 let endpoint_key = format!("{} {}", method, path);
229 let budget_config = state.drift_engine.config();
230 if budget_config.enabled
231 && (budget_config.per_endpoint_budgets.contains_key(&endpoint_key)
232 || budget_config.default_budget.is_some())
233 {
234 mockforge_core::pillar_tracking::record_contracts_usage(
235 None,
236 None,
237 "drift_budget_configured",
238 serde_json::json!({
239 "endpoint": endpoint_key,
240 }),
241 )
242 .await;
243 }
244
245 if drift_result.should_create_incident {
247 let incident_type = if drift_result.breaking_changes > 0 {
248 IncidentType::BreakingChange
249 } else {
250 IncidentType::ThresholdExceeded
251 };
252
253 let severity = determine_severity(&drift_result);
254
255 let details = serde_json::json!({
256 "breaking_changes": drift_result.breaking_changes,
257 "non_breaking_changes": drift_result.non_breaking_changes,
258 "breaking_mismatches": drift_result.breaking_mismatches,
259 "non_breaking_mismatches": drift_result.non_breaking_mismatches,
260 "budget_exceeded": drift_result.budget_exceeded,
261 });
262
263 let before_sample = Some(serde_json::json!({
266 "contract_format": diff_result.metadata.contract_format,
267 "contract_version": diff_result.metadata.contract_version,
268 "endpoint": path,
269 "method": method,
270 }));
271
272 let after_sample = Some(serde_json::json!({
273 "mismatches": diff_result.mismatches,
274 "recommendations": diff_result.recommendations,
275 "corrections": diff_result.corrections,
276 }));
277
278 let _incident = state
279 .incident_manager
280 .create_incident_with_samples(
281 path.clone(),
282 method.clone(),
283 incident_type,
284 severity,
285 details,
286 None, None, None, None, before_sample,
291 after_sample,
292 Some(drift_result.fitness_test_results.clone()), drift_result.consumer_impact.clone(), Some(mockforge_foundation::protocol::Protocol::Http), )
296 .await;
297
298 warn!(
299 "Drift incident created: {} {} - {} breaking changes, {} non-breaking changes",
300 method, path, drift_result.breaking_changes, drift_result.non_breaking_changes
301 );
302 }
303
304 if let Some(ref consumer_id) = consumer_id {
306 let violations = state
307 .consumer_detector
308 .detect_violations(consumer_id, &path, &method, &diff_result, None)
309 .await;
310
311 if !violations.is_empty() {
312 warn!(
313 "Consumer {} has {} violations on {} {}",
314 consumer_id,
315 violations.len(),
316 method,
317 path
318 );
319 }
320 }
321 }
322 Err(e) => {
323 debug!("Contract diff analysis failed: {}", e);
324 }
325 }
326 }
327
328 response
329}
330
331fn extract_consumer_id(req: &Request<Body>) -> Option<String> {
333 if let Some(consumer_id) = req.headers().get("x-consumer-id").and_then(|h| h.to_str().ok()) {
336 return Some(consumer_id.to_string());
337 }
338
339 if let Some(workspace_id) = req.headers().get("x-workspace-id").and_then(|h| h.to_str().ok()) {
341 return Some(format!("workspace:{}", workspace_id));
342 }
343
344 if let Some(api_key) = req
346 .headers()
347 .get("x-api-key")
348 .or_else(|| req.headers().get("authorization"))
349 .and_then(|h| h.to_str().ok())
350 {
351 use sha2::{Digest, Sha256};
353 let mut hasher = Sha256::new();
354 hasher.update(api_key.as_bytes());
355 let hash = format!("{:x}", hasher.finalize());
356 return Some(format!("api_key:{}", hash));
357 }
358
359 None
360}
361
362fn extract_headers_for_capture(req: &Request<Body>) -> HashMap<String, String> {
364 let safe_headers = [
365 "accept",
366 "accept-encoding",
367 "content-type",
368 "content-length",
369 "user-agent",
370 ];
371 let mut captured = HashMap::new();
372 for name in safe_headers {
373 if let Some(value) = req.headers().get(name).and_then(|v| v.to_str().ok()) {
374 captured.insert(name.to_string(), value.to_string());
375 }
376 }
377 captured
378}
379
380fn extract_response_body(response: &Response<Body>) -> Option<Value> {
382 if let Some(buffered) = crate::middleware::get_buffered_response(response) {
384 return buffered.json();
385 }
386
387 None
390}
391
392fn determine_severity(drift_result: &DriftResult) -> IncidentSeverity {
394 if drift_result.breaking_changes > 0 {
395 if drift_result
397 .breaking_mismatches
398 .iter()
399 .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::Critical)
400 {
401 return IncidentSeverity::Critical;
402 }
403 if drift_result
405 .breaking_mismatches
406 .iter()
407 .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::High)
408 {
409 return IncidentSeverity::High;
410 }
411 return IncidentSeverity::Medium;
412 }
413
414 if drift_result.non_breaking_changes > 5 {
416 IncidentSeverity::Medium
417 } else {
418 IncidentSeverity::Low
419 }
420}