mockforge_http/middleware/
drift_tracking.rs1#![allow(deprecated)]
8
9use axum::{body::Body, extract::Request, http::Response, middleware::Next};
10use mockforge_contracts::consumer_contracts::{ConsumerBreakingChangeDetector, UsageRecorder};
11use mockforge_core::{
12 ai_contract_diff::ContractDiffAnalyzer,
13 contract_drift::DriftBudgetEngine,
14 incidents::{IncidentManager, IncidentSeverity, IncidentType},
15 openapi::OpenApiSpec,
16};
17use mockforge_foundation::contract_drift_types::DriftResult;
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tracing::{debug, warn};
22
/// Largest request body (in bytes, 1 MiB) the drift middleware buffers for contract analysis.
const MAX_DRIFT_BODY_SIZE: usize = 1 << 20;
25
26#[derive(Clone)]
28pub struct DriftTrackingState {
29 pub diff_analyzer: Option<Arc<ContractDiffAnalyzer>>,
31 pub spec: Option<Arc<OpenApiSpec>>,
33 pub drift_engine: Arc<DriftBudgetEngine>,
35 pub incident_manager: Arc<IncidentManager>,
37 pub usage_recorder: Arc<UsageRecorder>,
39 pub consumer_detector: Arc<ConsumerBreakingChangeDetector>,
41 pub enabled: bool,
43}
44
45pub async fn drift_tracking_middleware_with_extensions(
50 req: Request<Body>,
51 next: Next,
52) -> Response<Body> {
53 let state = req.extensions().get::<DriftTrackingState>().cloned();
55
56 let state = if let Some(state) = state {
57 state
58 } else {
59 return next.run(req).await;
61 };
62
63 if !state.enabled {
64 return next.run(req).await;
65 }
66
67 let method = req.method().to_string();
68 let path = req.uri().path().to_string();
69
70 let consumer_id = extract_consumer_id(&req);
72
73 let captured_headers = extract_headers_for_capture(&req);
75
76 let (parts, body) = req.into_parts();
78 let body_bytes = match axum::body::to_bytes(body, MAX_DRIFT_BODY_SIZE).await {
79 Ok(b) => b,
80 Err(_) => {
81 let rebuilt = Request::from_parts(parts, Body::empty());
83 return next.run(rebuilt).await;
84 }
85 };
86
87 let captured_body = if !body_bytes.is_empty() {
89 serde_json::from_slice::<Value>(&body_bytes).ok()
90 } else {
91 None
92 };
93
94 let req = Request::from_parts(parts, Body::from(body_bytes));
96
97 let response = next.run(req).await;
99
100 let response_body = extract_response_body(&response);
102
103 if let Some(ref consumer_id) = consumer_id {
105 if let Some(body) = &response_body {
106 state.usage_recorder.record_usage(consumer_id, &path, &method, Some(body)).await;
107 }
108 }
109
110 if let (Some(ref analyzer), Some(ref spec)) = (&state.diff_analyzer, &state.spec) {
112 let mut captured = mockforge_core::ai_contract_diff::CapturedRequest::new(
114 &method,
115 &path,
116 "drift_tracking",
117 )
118 .with_headers(captured_headers)
119 .with_response(response.status().as_u16(), response_body.clone());
120
121 if let Some(body_value) = captured_body {
122 captured = captured.with_body(body_value);
123 }
124
125 match analyzer.analyze(&captured, spec).await {
127 Ok(diff_result) => {
128 let drift_result = state.drift_engine.evaluate(&diff_result, &path, &method);
130
131 mockforge_core::pillar_tracking::record_contracts_usage(
133 None, None,
135 "drift_detection",
136 serde_json::json!({
137 "endpoint": path,
138 "method": method,
139 "breaking_changes": drift_result.breaking_changes,
140 "non_breaking_changes": drift_result.non_breaking_changes,
141 "incident": drift_result.should_create_incident
142 }),
143 )
144 .await;
145
146 let endpoint_key = format!("{} {}", method, path);
149 let budget_config = state.drift_engine.config();
150 if budget_config.enabled
151 && (budget_config.per_endpoint_budgets.contains_key(&endpoint_key)
152 || budget_config.default_budget.is_some())
153 {
154 mockforge_core::pillar_tracking::record_contracts_usage(
155 None,
156 None,
157 "drift_budget_configured",
158 serde_json::json!({
159 "endpoint": endpoint_key,
160 }),
161 )
162 .await;
163 }
164
165 if drift_result.should_create_incident {
167 let incident_type = if drift_result.breaking_changes > 0 {
168 IncidentType::BreakingChange
169 } else {
170 IncidentType::ThresholdExceeded
171 };
172
173 let severity = determine_severity(&drift_result);
174
175 let details = serde_json::json!({
176 "breaking_changes": drift_result.breaking_changes,
177 "non_breaking_changes": drift_result.non_breaking_changes,
178 "breaking_mismatches": drift_result.breaking_mismatches,
179 "non_breaking_mismatches": drift_result.non_breaking_mismatches,
180 "budget_exceeded": drift_result.budget_exceeded,
181 });
182
183 let before_sample = Some(serde_json::json!({
186 "contract_format": diff_result.metadata.contract_format,
187 "contract_version": diff_result.metadata.contract_version,
188 "endpoint": path,
189 "method": method,
190 }));
191
192 let after_sample = Some(serde_json::json!({
193 "mismatches": diff_result.mismatches,
194 "recommendations": diff_result.recommendations,
195 "corrections": diff_result.corrections,
196 }));
197
198 let _incident = state
199 .incident_manager
200 .create_incident_with_samples(
201 path.clone(),
202 method.clone(),
203 incident_type,
204 severity,
205 details,
206 None, None, None, None, before_sample,
211 after_sample,
212 Some(drift_result.fitness_test_results.clone()), drift_result.consumer_impact.clone(), Some(mockforge_foundation::protocol::Protocol::Http), )
216 .await;
217
218 warn!(
219 "Drift incident created: {} {} - {} breaking changes, {} non-breaking changes",
220 method, path, drift_result.breaking_changes, drift_result.non_breaking_changes
221 );
222 }
223
224 if let Some(ref consumer_id) = consumer_id {
226 let violations = state
227 .consumer_detector
228 .detect_violations(consumer_id, &path, &method, &diff_result, None)
229 .await;
230
231 if !violations.is_empty() {
232 warn!(
233 "Consumer {} has {} violations on {} {}",
234 consumer_id,
235 violations.len(),
236 method,
237 path
238 );
239 }
240 }
241 }
242 Err(e) => {
243 debug!("Contract diff analysis failed: {}", e);
244 }
245 }
246 }
247
248 response
249}
250
251fn extract_consumer_id(req: &Request<Body>) -> Option<String> {
253 if let Some(consumer_id) = req.headers().get("x-consumer-id").and_then(|h| h.to_str().ok()) {
256 return Some(consumer_id.to_string());
257 }
258
259 if let Some(workspace_id) = req.headers().get("x-workspace-id").and_then(|h| h.to_str().ok()) {
261 return Some(format!("workspace:{}", workspace_id));
262 }
263
264 if let Some(api_key) = req
266 .headers()
267 .get("x-api-key")
268 .or_else(|| req.headers().get("authorization"))
269 .and_then(|h| h.to_str().ok())
270 {
271 use sha2::{Digest, Sha256};
273 let mut hasher = Sha256::new();
274 hasher.update(api_key.as_bytes());
275 let hash = format!("{:x}", hasher.finalize());
276 return Some(format!("api_key:{}", hash));
277 }
278
279 None
280}
281
282fn extract_headers_for_capture(req: &Request<Body>) -> HashMap<String, String> {
284 let safe_headers = [
285 "accept",
286 "accept-encoding",
287 "content-type",
288 "content-length",
289 "user-agent",
290 ];
291 let mut captured = HashMap::new();
292 for name in safe_headers {
293 if let Some(value) = req.headers().get(name).and_then(|v| v.to_str().ok()) {
294 captured.insert(name.to_string(), value.to_string());
295 }
296 }
297 captured
298}
299
300fn extract_response_body(response: &Response<Body>) -> Option<Value> {
302 if let Some(buffered) = crate::middleware::get_buffered_response(response) {
304 return buffered.json();
305 }
306
307 None
310}
311
312fn determine_severity(drift_result: &DriftResult) -> IncidentSeverity {
314 if drift_result.breaking_changes > 0 {
315 if drift_result
317 .breaking_mismatches
318 .iter()
319 .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::Critical)
320 {
321 return IncidentSeverity::Critical;
322 }
323 if drift_result
325 .breaking_mismatches
326 .iter()
327 .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::High)
328 {
329 return IncidentSeverity::High;
330 }
331 return IncidentSeverity::Medium;
332 }
333
334 if drift_result.non_breaking_changes > 5 {
336 IncidentSeverity::Medium
337 } else {
338 IncidentSeverity::Low
339 }
340}