mockforge_http/middleware/
drift_tracking.rs1use axum::{
7 body::Body,
8 extract::{Request, State},
9 http::Response,
10 middleware::Next,
11};
12use mockforge_core::{
13 ai_contract_diff::ContractDiffAnalyzer,
14 consumer_contracts::{ConsumerBreakingChangeDetector, UsageRecorder},
15 contract_drift::{DriftBudgetEngine, DriftResult},
16 incidents::{IncidentManager, IncidentSeverity, IncidentType},
17 openapi::OpenApiSpec,
18};
19use serde_json::Value;
20use std::sync::Arc;
21use tracing::{debug, warn};
22
/// Shared configuration for the drift-tracking middleware.
///
/// The middleware looks this value up in the request extensions and clones it per request,
/// which is why it derives `Clone` (all heavy members are behind `Arc`).
#[derive(Clone)]
pub struct DriftTrackingState {
    /// Contract diff analyzer; contract analysis only runs when this AND `spec` are `Some`.
    pub diff_analyzer: Option<Arc<ContractDiffAnalyzer>>,
    /// OpenAPI spec that captured responses are diffed against; see `diff_analyzer`.
    pub spec: Option<Arc<OpenApiSpec>>,
    /// Turns a diff result into a `DriftResult` (budget evaluation, incident decision).
    pub drift_engine: Arc<DriftBudgetEngine>,
    /// Receives an incident whenever the drift engine sets `should_create_incident`.
    pub incident_manager: Arc<IncidentManager>,
    /// Records per-consumer endpoint usage when a consumer id and a JSON body are available.
    pub usage_recorder: Arc<UsageRecorder>,
    /// Detects breaking-change violations specific to an identified consumer.
    pub consumer_detector: Arc<ConsumerBreakingChangeDetector>,
    /// Master switch: when `false` the middleware passes requests straight through.
    pub enabled: bool,
}
41
42pub async fn drift_tracking_middleware_with_extensions(
47 req: Request<Body>,
48 next: Next,
49) -> Response<Body> {
50 let state = req.extensions().get::<DriftTrackingState>().cloned();
52
53 let state = if let Some(state) = state {
54 state
55 } else {
56 return next.run(req).await;
58 };
59
60 if !state.enabled {
61 return next.run(req).await;
62 }
63
64 let method = req.method().to_string();
65 let path = req.uri().path().to_string();
66
67 let consumer_id = extract_consumer_id(&req);
69
70 let response = next.run(req).await;
72
73 let response_body = extract_response_body(&response);
75
76 if let Some(ref consumer_id) = consumer_id {
78 if let Some(body) = &response_body {
79 state.usage_recorder.record_usage(consumer_id, &path, &method, Some(body)).await;
80 }
81 }
82
83 if let (Some(ref analyzer), Some(ref spec)) = (&state.diff_analyzer, &state.spec) {
85 let captured = mockforge_core::ai_contract_diff::CapturedRequest::new(
89 &method,
90 &path,
91 "drift_tracking",
92 )
93 .with_response(response.status().as_u16(), response_body.clone());
94
95 match analyzer.analyze(&captured, spec).await {
97 Ok(diff_result) => {
98 let drift_result = state.drift_engine.evaluate(&diff_result, &path, &method);
100
101 if drift_result.should_create_incident {
103 let incident_type = if drift_result.breaking_changes > 0 {
104 IncidentType::BreakingChange
105 } else {
106 IncidentType::ThresholdExceeded
107 };
108
109 let severity = determine_severity(&drift_result);
110
111 let details = serde_json::json!({
112 "breaking_changes": drift_result.breaking_changes,
113 "non_breaking_changes": drift_result.non_breaking_changes,
114 "breaking_mismatches": drift_result.breaking_mismatches,
115 "non_breaking_mismatches": drift_result.non_breaking_mismatches,
116 "budget_exceeded": drift_result.budget_exceeded,
117 });
118
119 let before_sample = Some(serde_json::json!({
122 "contract_format": diff_result.metadata.contract_format,
123 "contract_version": diff_result.metadata.contract_version,
124 "endpoint": path,
125 "method": method,
126 }));
127
128 let after_sample = Some(serde_json::json!({
129 "mismatches": diff_result.mismatches,
130 "recommendations": diff_result.recommendations,
131 "corrections": diff_result.corrections,
132 }));
133
134 let _incident = state
135 .incident_manager
136 .create_incident_with_samples(
137 path.clone(),
138 method.clone(),
139 incident_type,
140 severity,
141 details,
142 None, None, None, None, before_sample,
147 after_sample,
148 )
149 .await;
150
151 warn!(
152 "Drift incident created: {} {} - {} breaking changes, {} non-breaking changes",
153 method, path, drift_result.breaking_changes, drift_result.non_breaking_changes
154 );
155 }
156
157 if let Some(ref consumer_id) = consumer_id {
159 let violations = state
160 .consumer_detector
161 .detect_violations(consumer_id, &path, &method, &diff_result, None)
162 .await;
163
164 if !violations.is_empty() {
165 warn!(
166 "Consumer {} has {} violations on {} {}",
167 consumer_id,
168 violations.len(),
169 method,
170 path
171 );
172 }
173 }
174 }
175 Err(e) => {
176 debug!("Contract diff analysis failed: {}", e);
177 }
178 }
179 }
180
181 response
182}
183
184fn extract_consumer_id(req: &Request<Body>) -> Option<String> {
186 if let Some(consumer_id) = req.headers().get("x-consumer-id").and_then(|h| h.to_str().ok()) {
189 return Some(consumer_id.to_string());
190 }
191
192 if let Some(workspace_id) = req.headers().get("x-workspace-id").and_then(|h| h.to_str().ok()) {
194 return Some(format!("workspace:{}", workspace_id));
195 }
196
197 if let Some(api_key) = req
199 .headers()
200 .get("x-api-key")
201 .or_else(|| req.headers().get("authorization"))
202 .and_then(|h| h.to_str().ok())
203 {
204 use sha2::{Digest, Sha256};
206 let mut hasher = Sha256::new();
207 hasher.update(api_key.as_bytes());
208 let hash = format!("{:x}", hasher.finalize());
209 return Some(format!("api_key:{}", hash));
210 }
211
212 None
213}
214
215fn extract_response_body(response: &Response<Body>) -> Option<Value> {
217 if let Some(buffered) = crate::middleware::get_buffered_response(response) {
219 return buffered.json();
220 }
221
222 None
225}
226
227fn determine_severity(drift_result: &DriftResult) -> IncidentSeverity {
229 if drift_result.breaking_changes > 0 {
230 if drift_result
232 .breaking_mismatches
233 .iter()
234 .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::Critical)
235 {
236 return IncidentSeverity::Critical;
237 }
238 if drift_result
240 .breaking_mismatches
241 .iter()
242 .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::High)
243 {
244 return IncidentSeverity::High;
245 }
246 return IncidentSeverity::Medium;
247 }
248
249 if drift_result.non_breaking_changes > 5 {
251 IncidentSeverity::Medium
252 } else {
253 IncidentSeverity::Low
254 }
255}