Skip to main content

mockforge_http/middleware/
drift_tracking.rs

1//! Drift tracking middleware
2//!
3//! This middleware integrates drift budget evaluation and consumer usage tracking
4//! with contract diff analysis.
5
6// Uses DriftBudgetEngine + ContractDiffAnalyzer which stay in core.
7#![allow(deprecated)]
8
9use axum::{body::Body, extract::Request, http::Response, middleware::Next};
10use mockforge_contracts::consumer_contracts::{ConsumerBreakingChangeDetector, UsageRecorder};
11use mockforge_core::{
12    ai_contract_diff::ContractDiffAnalyzer,
13    contract_drift::DriftBudgetEngine,
14    incidents::{IncidentManager, IncidentSeverity, IncidentType},
15    openapi::OpenApiSpec,
16};
17use mockforge_foundation::contract_drift_types::DriftResult;
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tracing::{debug, warn};
22
/// Largest request body, in bytes, that is buffered for drift capture (1 MiB = 2^20 bytes).
const MAX_DRIFT_BODY_SIZE: usize = 1 << 20;
25
26/// State for drift tracking middleware
27#[derive(Clone)]
28pub struct DriftTrackingState {
29    /// Contract diff analyzer
30    pub diff_analyzer: Option<Arc<ContractDiffAnalyzer>>,
31    /// OpenAPI spec (if available)
32    pub spec: Option<Arc<OpenApiSpec>>,
33    /// Drift budget engine
34    pub drift_engine: Arc<DriftBudgetEngine>,
35    /// Incident manager
36    pub incident_manager: Arc<IncidentManager>,
37    /// Usage recorder for consumer contracts
38    pub usage_recorder: Arc<UsageRecorder>,
39    /// Consumer breaking change detector
40    pub consumer_detector: Arc<ConsumerBreakingChangeDetector>,
41    /// Whether drift tracking is enabled
42    pub enabled: bool,
43}
44
45/// Middleware to track drift and consumer usage (with state from extensions)
46///
47/// This middleware requires response body buffering middleware to be applied first.
48/// The response body is extracted from buffered response extensions.
49pub async fn drift_tracking_middleware_with_extensions(
50    req: Request<Body>,
51    next: Next,
52) -> Response<Body> {
53    // Extract state from request extensions
54    let state = req.extensions().get::<DriftTrackingState>().cloned();
55
56    let state = if let Some(state) = state {
57        state
58    } else {
59        // No state available, skip drift tracking
60        return next.run(req).await;
61    };
62
63    if !state.enabled {
64        return next.run(req).await;
65    }
66
67    let method = req.method().to_string();
68    let path = req.uri().path().to_string();
69
70    // Extract consumer identifier and headers from request
71    let consumer_id = extract_consumer_id(&req);
72
73    // Extract headers for capture
74    let captured_headers = extract_headers_for_capture(&req);
75
76    // Buffer the request body so we can capture it and still forward it
77    let (parts, body) = req.into_parts();
78    let body_bytes = match axum::body::to_bytes(body, MAX_DRIFT_BODY_SIZE).await {
79        Ok(b) => b,
80        Err(_) => {
81            // Body too large or read error — forward without capturing body
82            let rebuilt = Request::from_parts(parts, Body::empty());
83            return next.run(rebuilt).await;
84        }
85    };
86
87    // Try to parse body as JSON for structured capture
88    let captured_body = if !body_bytes.is_empty() {
89        serde_json::from_slice::<Value>(&body_bytes).ok()
90    } else {
91        None
92    };
93
94    // Reconstruct the request with the buffered body
95    let req = Request::from_parts(parts, Body::from(body_bytes));
96
97    // Process request and get response
98    let response = next.run(req).await;
99
100    // Extract response body for consumer usage tracking
101    let response_body = extract_response_body(&response);
102
103    // Record consumer usage if consumer is identified
104    if let Some(ref consumer_id) = consumer_id {
105        if let Some(body) = &response_body {
106            state.usage_recorder.record_usage(consumer_id, &path, &method, Some(body)).await;
107        }
108    }
109
110    // Perform contract diff analysis if analyzer and spec are available
111    if let (Some(ref analyzer), Some(ref spec)) = (&state.diff_analyzer, &state.spec) {
112        // Create captured request with body and headers
113        let mut captured = mockforge_core::ai_contract_diff::CapturedRequest::new(
114            &method,
115            &path,
116            "drift_tracking",
117        )
118        .with_headers(captured_headers)
119        .with_response(response.status().as_u16(), response_body.clone());
120
121        if let Some(body_value) = captured_body {
122            captured = captured.with_body(body_value);
123        }
124
125        // Analyze request against contract
126        match analyzer.analyze(&captured, spec).await {
127            Ok(diff_result) => {
128                // Evaluate drift budget
129                let drift_result = state.drift_engine.evaluate(&diff_result, &path, &method);
130
131                // Record contracts pillar usage for drift detection
132                mockforge_core::pillar_tracking::record_contracts_usage(
133                    None, // workspace_id could be extracted from request if available
134                    None,
135                    "drift_detection",
136                    serde_json::json!({
137                        "endpoint": path,
138                        "method": method,
139                        "breaking_changes": drift_result.breaking_changes,
140                        "non_breaking_changes": drift_result.non_breaking_changes,
141                        "incident": drift_result.should_create_incident
142                    }),
143                )
144                .await;
145
146                // If this endpoint has a drift budget configured, record that so
147                // the Contracts pillar dashboard can count distinct budgeted endpoints.
148                let endpoint_key = format!("{} {}", method, path);
149                let budget_config = state.drift_engine.config();
150                if budget_config.enabled
151                    && (budget_config.per_endpoint_budgets.contains_key(&endpoint_key)
152                        || budget_config.default_budget.is_some())
153                {
154                    mockforge_core::pillar_tracking::record_contracts_usage(
155                        None,
156                        None,
157                        "drift_budget_configured",
158                        serde_json::json!({
159                            "endpoint": endpoint_key,
160                        }),
161                    )
162                    .await;
163                }
164
165                // Create incident if budget is exceeded or breaking changes detected
166                if drift_result.should_create_incident {
167                    let incident_type = if drift_result.breaking_changes > 0 {
168                        IncidentType::BreakingChange
169                    } else {
170                        IncidentType::ThresholdExceeded
171                    };
172
173                    let severity = determine_severity(&drift_result);
174
175                    let details = serde_json::json!({
176                        "breaking_changes": drift_result.breaking_changes,
177                        "non_breaking_changes": drift_result.non_breaking_changes,
178                        "breaking_mismatches": drift_result.breaking_mismatches,
179                        "non_breaking_mismatches": drift_result.non_breaking_mismatches,
180                        "budget_exceeded": drift_result.budget_exceeded,
181                    });
182
183                    // Create incident with before/after samples
184                    // Extract before/after samples from diff result if available
185                    let before_sample = Some(serde_json::json!({
186                        "contract_format": diff_result.metadata.contract_format,
187                        "contract_version": diff_result.metadata.contract_version,
188                        "endpoint": path,
189                        "method": method,
190                    }));
191
192                    let after_sample = Some(serde_json::json!({
193                        "mismatches": diff_result.mismatches,
194                        "recommendations": diff_result.recommendations,
195                        "corrections": diff_result.corrections,
196                    }));
197
198                    let _incident = state
199                        .incident_manager
200                        .create_incident_with_samples(
201                            path.clone(),
202                            method.clone(),
203                            incident_type,
204                            severity,
205                            details,
206                            None, // budget_id
207                            None, // workspace_id
208                            None, // sync_cycle_id
209                            None, // contract_diff_id (could be generated from diff_result)
210                            before_sample,
211                            after_sample,
212                            Some(drift_result.fitness_test_results.clone()), // fitness_test_results
213                            drift_result.consumer_impact.clone(),            // affected_consumers
214                            Some(mockforge_foundation::protocol::Protocol::Http), // protocol
215                        )
216                        .await;
217
218                    warn!(
219                        "Drift incident created: {} {} - {} breaking changes, {} non-breaking changes",
220                        method, path, drift_result.breaking_changes, drift_result.non_breaking_changes
221                    );
222                }
223
224                // Check for consumer-specific violations
225                if let Some(ref consumer_id) = consumer_id {
226                    let violations = state
227                        .consumer_detector
228                        .detect_violations(consumer_id, &path, &method, &diff_result, None)
229                        .await;
230
231                    if !violations.is_empty() {
232                        warn!(
233                            "Consumer {} has {} violations on {} {}",
234                            consumer_id,
235                            violations.len(),
236                            method,
237                            path
238                        );
239                    }
240                }
241            }
242            Err(e) => {
243                debug!("Contract diff analysis failed: {}", e);
244            }
245        }
246    }
247
248    response
249}
250
251/// Extract consumer identifier from request
252fn extract_consumer_id(req: &Request<Body>) -> Option<String> {
253    // Try to extract from various sources:
254    // 1. X-Consumer-ID header
255    if let Some(consumer_id) = req.headers().get("x-consumer-id").and_then(|h| h.to_str().ok()) {
256        return Some(consumer_id.to_string());
257    }
258
259    // 2. X-Workspace-ID header (for workspace-based consumers)
260    if let Some(workspace_id) = req.headers().get("x-workspace-id").and_then(|h| h.to_str().ok()) {
261        return Some(format!("workspace:{}", workspace_id));
262    }
263
264    // 3. API key from header
265    if let Some(api_key) = req
266        .headers()
267        .get("x-api-key")
268        .or_else(|| req.headers().get("authorization"))
269        .and_then(|h| h.to_str().ok())
270    {
271        // Hash the API key for privacy
272        use sha2::{Digest, Sha256};
273        let mut hasher = Sha256::new();
274        hasher.update(api_key.as_bytes());
275        let hash = format!("{:x}", hasher.finalize());
276        return Some(format!("api_key:{}", hash));
277    }
278
279    None
280}
281
282/// Extract safe headers for drift capture
283fn extract_headers_for_capture(req: &Request<Body>) -> HashMap<String, String> {
284    let safe_headers = [
285        "accept",
286        "accept-encoding",
287        "content-type",
288        "content-length",
289        "user-agent",
290    ];
291    let mut captured = HashMap::new();
292    for name in safe_headers {
293        if let Some(value) = req.headers().get(name).and_then(|v| v.to_str().ok()) {
294            captured.insert(name.to_string(), value.to_string());
295        }
296    }
297    captured
298}
299
300/// Extract response body as JSON value
301fn extract_response_body(response: &Response<Body>) -> Option<Value> {
302    // Try to get buffered response from extensions
303    if let Some(buffered) = crate::middleware::get_buffered_response(response) {
304        return buffered.json();
305    }
306
307    // If not buffered, try to parse from response body
308    // Note: This requires the response body to be buffered by upstream middleware
309    None
310}
311
312/// Determine incident severity from drift result
313fn determine_severity(drift_result: &DriftResult) -> IncidentSeverity {
314    if drift_result.breaking_changes > 0 {
315        // Check if any breaking mismatch is critical
316        if drift_result
317            .breaking_mismatches
318            .iter()
319            .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::Critical)
320        {
321            return IncidentSeverity::Critical;
322        }
323        // Check if any breaking mismatch is high
324        if drift_result
325            .breaking_mismatches
326            .iter()
327            .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::High)
328        {
329            return IncidentSeverity::High;
330        }
331        return IncidentSeverity::Medium;
332    }
333
334    // Non-breaking changes are lower severity
335    if drift_result.non_breaking_changes > 5 {
336        IncidentSeverity::Medium
337    } else {
338        IncidentSeverity::Low
339    }
340}