mockforge_http/middleware/
drift_tracking.rs1use axum::{body::Body, extract::Request, http::Response, middleware::Next};
7use mockforge_core::{
8 ai_contract_diff::ContractDiffAnalyzer,
9 consumer_contracts::{ConsumerBreakingChangeDetector, UsageRecorder},
10 contract_drift::{DriftBudgetEngine, DriftResult},
11 incidents::{IncidentManager, IncidentSeverity, IncidentType},
12 openapi::OpenApiSpec,
13};
14use serde_json::Value;
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{debug, warn};
18
/// Upper bound, in bytes (1 MiB), on how much of a request body is buffered for drift capture.
const MAX_DRIFT_BODY_SIZE: usize = 1 << 20;

22#[derive(Clone)]
24pub struct DriftTrackingState {
25 pub diff_analyzer: Option<Arc<ContractDiffAnalyzer>>,
27 pub spec: Option<Arc<OpenApiSpec>>,
29 pub drift_engine: Arc<DriftBudgetEngine>,
31 pub incident_manager: Arc<IncidentManager>,
33 pub usage_recorder: Arc<UsageRecorder>,
35 pub consumer_detector: Arc<ConsumerBreakingChangeDetector>,
37 pub enabled: bool,
39}
40
41pub async fn drift_tracking_middleware_with_extensions(
46 req: Request<Body>,
47 next: Next,
48) -> Response<Body> {
49 let state = req.extensions().get::<DriftTrackingState>().cloned();
51
52 let state = if let Some(state) = state {
53 state
54 } else {
55 return next.run(req).await;
57 };
58
59 if !state.enabled {
60 return next.run(req).await;
61 }
62
63 let method = req.method().to_string();
64 let path = req.uri().path().to_string();
65
66 let consumer_id = extract_consumer_id(&req);
68
69 let captured_headers = extract_headers_for_capture(&req);
71
72 let (parts, body) = req.into_parts();
74 let body_bytes = match axum::body::to_bytes(body, MAX_DRIFT_BODY_SIZE).await {
75 Ok(b) => b,
76 Err(_) => {
77 let rebuilt = Request::from_parts(parts, Body::empty());
79 return next.run(rebuilt).await;
80 }
81 };
82
83 let captured_body = if !body_bytes.is_empty() {
85 serde_json::from_slice::<serde_json::Value>(&body_bytes).ok()
86 } else {
87 None
88 };
89
90 let req = Request::from_parts(parts, Body::from(body_bytes));
92
93 let response = next.run(req).await;
95
96 let response_body = extract_response_body(&response);
98
99 if let Some(ref consumer_id) = consumer_id {
101 if let Some(body) = &response_body {
102 state.usage_recorder.record_usage(consumer_id, &path, &method, Some(body)).await;
103 }
104 }
105
106 if let (Some(ref analyzer), Some(ref spec)) = (&state.diff_analyzer, &state.spec) {
108 let mut captured = mockforge_core::ai_contract_diff::CapturedRequest::new(
110 &method,
111 &path,
112 "drift_tracking",
113 )
114 .with_headers(captured_headers)
115 .with_response(response.status().as_u16(), response_body.clone());
116
117 if let Some(body_value) = captured_body {
118 captured = captured.with_body(body_value);
119 }
120
121 match analyzer.analyze(&captured, spec).await {
123 Ok(diff_result) => {
124 let drift_result = state.drift_engine.evaluate(&diff_result, &path, &method);
126
127 mockforge_core::pillar_tracking::record_contracts_usage(
129 None, None,
131 "drift_detection",
132 serde_json::json!({
133 "endpoint": path,
134 "method": method,
135 "breaking_changes": drift_result.breaking_changes,
136 "non_breaking_changes": drift_result.non_breaking_changes,
137 "incident": drift_result.should_create_incident
138 }),
139 )
140 .await;
141
142 if drift_result.should_create_incident {
144 let incident_type = if drift_result.breaking_changes > 0 {
145 IncidentType::BreakingChange
146 } else {
147 IncidentType::ThresholdExceeded
148 };
149
150 let severity = determine_severity(&drift_result);
151
152 let details = serde_json::json!({
153 "breaking_changes": drift_result.breaking_changes,
154 "non_breaking_changes": drift_result.non_breaking_changes,
155 "breaking_mismatches": drift_result.breaking_mismatches,
156 "non_breaking_mismatches": drift_result.non_breaking_mismatches,
157 "budget_exceeded": drift_result.budget_exceeded,
158 });
159
160 let before_sample = Some(serde_json::json!({
163 "contract_format": diff_result.metadata.contract_format,
164 "contract_version": diff_result.metadata.contract_version,
165 "endpoint": path,
166 "method": method,
167 }));
168
169 let after_sample = Some(serde_json::json!({
170 "mismatches": diff_result.mismatches,
171 "recommendations": diff_result.recommendations,
172 "corrections": diff_result.corrections,
173 }));
174
175 let _incident = state
176 .incident_manager
177 .create_incident_with_samples(
178 path.clone(),
179 method.clone(),
180 incident_type,
181 severity,
182 details,
183 None, None, None, None, before_sample,
188 after_sample,
189 Some(drift_result.fitness_test_results.clone()), drift_result.consumer_impact.clone(), Some(mockforge_core::protocol_abstraction::Protocol::Http), )
193 .await;
194
195 warn!(
196 "Drift incident created: {} {} - {} breaking changes, {} non-breaking changes",
197 method, path, drift_result.breaking_changes, drift_result.non_breaking_changes
198 );
199 }
200
201 if let Some(ref consumer_id) = consumer_id {
203 let violations = state
204 .consumer_detector
205 .detect_violations(consumer_id, &path, &method, &diff_result, None)
206 .await;
207
208 if !violations.is_empty() {
209 warn!(
210 "Consumer {} has {} violations on {} {}",
211 consumer_id,
212 violations.len(),
213 method,
214 path
215 );
216 }
217 }
218 }
219 Err(e) => {
220 debug!("Contract diff analysis failed: {}", e);
221 }
222 }
223 }
224
225 response
226}
227
228fn extract_consumer_id(req: &Request<Body>) -> Option<String> {
230 if let Some(consumer_id) = req.headers().get("x-consumer-id").and_then(|h| h.to_str().ok()) {
233 return Some(consumer_id.to_string());
234 }
235
236 if let Some(workspace_id) = req.headers().get("x-workspace-id").and_then(|h| h.to_str().ok()) {
238 return Some(format!("workspace:{}", workspace_id));
239 }
240
241 if let Some(api_key) = req
243 .headers()
244 .get("x-api-key")
245 .or_else(|| req.headers().get("authorization"))
246 .and_then(|h| h.to_str().ok())
247 {
248 use sha2::{Digest, Sha256};
250 let mut hasher = Sha256::new();
251 hasher.update(api_key.as_bytes());
252 let hash = format!("{:x}", hasher.finalize());
253 return Some(format!("api_key:{}", hash));
254 }
255
256 None
257}
258
259fn extract_headers_for_capture(req: &Request<Body>) -> HashMap<String, String> {
261 let safe_headers = [
262 "accept",
263 "accept-encoding",
264 "content-type",
265 "content-length",
266 "user-agent",
267 ];
268 let mut captured = HashMap::new();
269 for name in safe_headers {
270 if let Some(value) = req.headers().get(name).and_then(|v| v.to_str().ok()) {
271 captured.insert(name.to_string(), value.to_string());
272 }
273 }
274 captured
275}
276
277fn extract_response_body(response: &Response<Body>) -> Option<Value> {
279 if let Some(buffered) = crate::middleware::get_buffered_response(response) {
281 return buffered.json();
282 }
283
284 None
287}
288
289fn determine_severity(drift_result: &DriftResult) -> IncidentSeverity {
291 if drift_result.breaking_changes > 0 {
292 if drift_result
294 .breaking_mismatches
295 .iter()
296 .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::Critical)
297 {
298 return IncidentSeverity::Critical;
299 }
300 if drift_result
302 .breaking_mismatches
303 .iter()
304 .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::High)
305 {
306 return IncidentSeverity::High;
307 }
308 return IncidentSeverity::Medium;
309 }
310
311 if drift_result.non_breaking_changes > 5 {
313 IncidentSeverity::Medium
314 } else {
315 IncidentSeverity::Low
316 }
317}