Skip to main content

mockforge_http/middleware/
drift_tracking.rs

1//! Drift tracking middleware
2//!
3//! This middleware integrates drift budget evaluation and consumer usage tracking
4//! with contract diff analysis.
5
6use axum::{body::Body, extract::Request, http::Response, middleware::Next};
7use mockforge_core::{
8    ai_contract_diff::ContractDiffAnalyzer,
9    consumer_contracts::{ConsumerBreakingChangeDetector, UsageRecorder},
10    contract_drift::{DriftBudgetEngine, DriftResult},
11    incidents::{IncidentManager, IncidentSeverity, IncidentType},
12    openapi::OpenApiSpec,
13};
14use serde_json::Value;
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{debug, warn};
18
/// Upper bound, in bytes, on how much of a request body is buffered for drift
/// tracking: 2^20 bytes (1 MiB).
const MAX_DRIFT_BODY_SIZE: usize = 1 << 20;
21
22/// State for drift tracking middleware
23#[derive(Clone)]
24pub struct DriftTrackingState {
25    /// Contract diff analyzer
26    pub diff_analyzer: Option<Arc<ContractDiffAnalyzer>>,
27    /// OpenAPI spec (if available)
28    pub spec: Option<Arc<OpenApiSpec>>,
29    /// Drift budget engine
30    pub drift_engine: Arc<DriftBudgetEngine>,
31    /// Incident manager
32    pub incident_manager: Arc<IncidentManager>,
33    /// Usage recorder for consumer contracts
34    pub usage_recorder: Arc<UsageRecorder>,
35    /// Consumer breaking change detector
36    pub consumer_detector: Arc<ConsumerBreakingChangeDetector>,
37    /// Whether drift tracking is enabled
38    pub enabled: bool,
39}
40
41/// Middleware to track drift and consumer usage (with state from extensions)
42///
43/// This middleware requires response body buffering middleware to be applied first.
44/// The response body is extracted from buffered response extensions.
45pub async fn drift_tracking_middleware_with_extensions(
46    req: Request<Body>,
47    next: Next,
48) -> Response<Body> {
49    // Extract state from request extensions
50    let state = req.extensions().get::<DriftTrackingState>().cloned();
51
52    let state = if let Some(state) = state {
53        state
54    } else {
55        // No state available, skip drift tracking
56        return next.run(req).await;
57    };
58
59    if !state.enabled {
60        return next.run(req).await;
61    }
62
63    let method = req.method().to_string();
64    let path = req.uri().path().to_string();
65
66    // Extract consumer identifier and headers from request
67    let consumer_id = extract_consumer_id(&req);
68
69    // Extract headers for capture
70    let captured_headers = extract_headers_for_capture(&req);
71
72    // Buffer the request body so we can capture it and still forward it
73    let (parts, body) = req.into_parts();
74    let body_bytes = match axum::body::to_bytes(body, MAX_DRIFT_BODY_SIZE).await {
75        Ok(b) => b,
76        Err(_) => {
77            // Body too large or read error — forward without capturing body
78            let rebuilt = Request::from_parts(parts, Body::empty());
79            return next.run(rebuilt).await;
80        }
81    };
82
83    // Try to parse body as JSON for structured capture
84    let captured_body = if !body_bytes.is_empty() {
85        serde_json::from_slice::<serde_json::Value>(&body_bytes).ok()
86    } else {
87        None
88    };
89
90    // Reconstruct the request with the buffered body
91    let req = Request::from_parts(parts, Body::from(body_bytes));
92
93    // Process request and get response
94    let response = next.run(req).await;
95
96    // Extract response body for consumer usage tracking
97    let response_body = extract_response_body(&response);
98
99    // Record consumer usage if consumer is identified
100    if let Some(ref consumer_id) = consumer_id {
101        if let Some(body) = &response_body {
102            state.usage_recorder.record_usage(consumer_id, &path, &method, Some(body)).await;
103        }
104    }
105
106    // Perform contract diff analysis if analyzer and spec are available
107    if let (Some(ref analyzer), Some(ref spec)) = (&state.diff_analyzer, &state.spec) {
108        // Create captured request with body and headers
109        let mut captured = mockforge_core::ai_contract_diff::CapturedRequest::new(
110            &method,
111            &path,
112            "drift_tracking",
113        )
114        .with_headers(captured_headers)
115        .with_response(response.status().as_u16(), response_body.clone());
116
117        if let Some(body_value) = captured_body {
118            captured = captured.with_body(body_value);
119        }
120
121        // Analyze request against contract
122        match analyzer.analyze(&captured, spec).await {
123            Ok(diff_result) => {
124                // Evaluate drift budget
125                let drift_result = state.drift_engine.evaluate(&diff_result, &path, &method);
126
127                // Record contracts pillar usage for drift detection
128                mockforge_core::pillar_tracking::record_contracts_usage(
129                    None, // workspace_id could be extracted from request if available
130                    None,
131                    "drift_detection",
132                    serde_json::json!({
133                        "endpoint": path,
134                        "method": method,
135                        "breaking_changes": drift_result.breaking_changes,
136                        "non_breaking_changes": drift_result.non_breaking_changes,
137                        "incident": drift_result.should_create_incident
138                    }),
139                )
140                .await;
141
142                // Create incident if budget is exceeded or breaking changes detected
143                if drift_result.should_create_incident {
144                    let incident_type = if drift_result.breaking_changes > 0 {
145                        IncidentType::BreakingChange
146                    } else {
147                        IncidentType::ThresholdExceeded
148                    };
149
150                    let severity = determine_severity(&drift_result);
151
152                    let details = serde_json::json!({
153                        "breaking_changes": drift_result.breaking_changes,
154                        "non_breaking_changes": drift_result.non_breaking_changes,
155                        "breaking_mismatches": drift_result.breaking_mismatches,
156                        "non_breaking_mismatches": drift_result.non_breaking_mismatches,
157                        "budget_exceeded": drift_result.budget_exceeded,
158                    });
159
160                    // Create incident with before/after samples
161                    // Extract before/after samples from diff result if available
162                    let before_sample = Some(serde_json::json!({
163                        "contract_format": diff_result.metadata.contract_format,
164                        "contract_version": diff_result.metadata.contract_version,
165                        "endpoint": path,
166                        "method": method,
167                    }));
168
169                    let after_sample = Some(serde_json::json!({
170                        "mismatches": diff_result.mismatches,
171                        "recommendations": diff_result.recommendations,
172                        "corrections": diff_result.corrections,
173                    }));
174
175                    let _incident = state
176                        .incident_manager
177                        .create_incident_with_samples(
178                            path.clone(),
179                            method.clone(),
180                            incident_type,
181                            severity,
182                            details,
183                            None, // budget_id
184                            None, // workspace_id
185                            None, // sync_cycle_id
186                            None, // contract_diff_id (could be generated from diff_result)
187                            before_sample,
188                            after_sample,
189                            Some(drift_result.fitness_test_results.clone()), // fitness_test_results
190                            drift_result.consumer_impact.clone(),            // affected_consumers
191                            Some(mockforge_core::protocol_abstraction::Protocol::Http), // protocol
192                        )
193                        .await;
194
195                    warn!(
196                        "Drift incident created: {} {} - {} breaking changes, {} non-breaking changes",
197                        method, path, drift_result.breaking_changes, drift_result.non_breaking_changes
198                    );
199                }
200
201                // Check for consumer-specific violations
202                if let Some(ref consumer_id) = consumer_id {
203                    let violations = state
204                        .consumer_detector
205                        .detect_violations(consumer_id, &path, &method, &diff_result, None)
206                        .await;
207
208                    if !violations.is_empty() {
209                        warn!(
210                            "Consumer {} has {} violations on {} {}",
211                            consumer_id,
212                            violations.len(),
213                            method,
214                            path
215                        );
216                    }
217                }
218            }
219            Err(e) => {
220                debug!("Contract diff analysis failed: {}", e);
221            }
222        }
223    }
224
225    response
226}
227
228/// Extract consumer identifier from request
229fn extract_consumer_id(req: &Request<Body>) -> Option<String> {
230    // Try to extract from various sources:
231    // 1. X-Consumer-ID header
232    if let Some(consumer_id) = req.headers().get("x-consumer-id").and_then(|h| h.to_str().ok()) {
233        return Some(consumer_id.to_string());
234    }
235
236    // 2. X-Workspace-ID header (for workspace-based consumers)
237    if let Some(workspace_id) = req.headers().get("x-workspace-id").and_then(|h| h.to_str().ok()) {
238        return Some(format!("workspace:{}", workspace_id));
239    }
240
241    // 3. API key from header
242    if let Some(api_key) = req
243        .headers()
244        .get("x-api-key")
245        .or_else(|| req.headers().get("authorization"))
246        .and_then(|h| h.to_str().ok())
247    {
248        // Hash the API key for privacy
249        use sha2::{Digest, Sha256};
250        let mut hasher = Sha256::new();
251        hasher.update(api_key.as_bytes());
252        let hash = format!("{:x}", hasher.finalize());
253        return Some(format!("api_key:{}", hash));
254    }
255
256    None
257}
258
259/// Extract safe headers for drift capture
260fn extract_headers_for_capture(req: &Request<Body>) -> HashMap<String, String> {
261    let safe_headers = [
262        "accept",
263        "accept-encoding",
264        "content-type",
265        "content-length",
266        "user-agent",
267    ];
268    let mut captured = HashMap::new();
269    for name in safe_headers {
270        if let Some(value) = req.headers().get(name).and_then(|v| v.to_str().ok()) {
271            captured.insert(name.to_string(), value.to_string());
272        }
273    }
274    captured
275}
276
277/// Extract response body as JSON value
278fn extract_response_body(response: &Response<Body>) -> Option<Value> {
279    // Try to get buffered response from extensions
280    if let Some(buffered) = crate::middleware::get_buffered_response(response) {
281        return buffered.json();
282    }
283
284    // If not buffered, try to parse from response body
285    // Note: This requires the response body to be buffered by upstream middleware
286    None
287}
288
289/// Determine incident severity from drift result
290fn determine_severity(drift_result: &DriftResult) -> IncidentSeverity {
291    if drift_result.breaking_changes > 0 {
292        // Check if any breaking mismatch is critical
293        if drift_result
294            .breaking_mismatches
295            .iter()
296            .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::Critical)
297        {
298            return IncidentSeverity::Critical;
299        }
300        // Check if any breaking mismatch is high
301        if drift_result
302            .breaking_mismatches
303            .iter()
304            .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::High)
305        {
306            return IncidentSeverity::High;
307        }
308        return IncidentSeverity::Medium;
309    }
310
311    // Non-breaking changes are lower severity
312    if drift_result.non_breaking_changes > 5 {
313        IncidentSeverity::Medium
314    } else {
315        IncidentSeverity::Low
316    }
317}