Skip to main content

mockforge_http/middleware/
drift_tracking.rs

1//! Drift tracking middleware
2//!
3//! This middleware integrates drift budget evaluation and consumer usage tracking
4//! with contract diff analysis.
5
6// Uses DriftBudgetEngine + ContractDiffAnalyzer which stay in core.
7#![allow(deprecated)]
8
9use axum::{body::Body, extract::Request, http::Response, middleware::Next};
10use mockforge_contracts::consumer_contracts::{ConsumerBreakingChangeDetector, UsageRecorder};
11use mockforge_core::{
12    ai_contract_diff::ContractDiffAnalyzer,
13    contract_drift::DriftBudgetEngine,
14    incidents::{IncidentManager, IncidentSeverity, IncidentType},
15    openapi::OpenApiSpec,
16};
17use mockforge_foundation::contract_drift_types::DriftResult;
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use tracing::{debug, warn};
22
/// Upper bound, in bytes, on how much of a request body drift tracking will buffer.
///
/// The limit is read from the `MOCKFORGE_DRIFT_MAX_BODY_MB` environment variable
/// (interpreted as whole mebibytes). When the variable is unset or does not parse
/// as a `usize`, a default of 10 MiB applies. The multiplication saturates, so an
/// absurdly large setting yields `usize::MAX` instead of overflowing.
///
/// Issue #79 — see `contract_diff_middleware`. The previous hard-coded 1 MiB cap,
/// combined with substituting `Body::empty()` for over-cap bodies, broke
/// downstream handlers on large chunked uploads.
fn max_drift_body_size() -> usize {
    const DEFAULT_MB: usize = 10;
    const BYTES_PER_MIB: usize = 1024 * 1024;

    let megabytes = match std::env::var("MOCKFORGE_DRIFT_MAX_BODY_MB") {
        Ok(raw) => raw.parse::<usize>().unwrap_or(DEFAULT_MB),
        Err(_) => DEFAULT_MB,
    };

    megabytes.saturating_mul(BYTES_PER_MIB)
}
37
38/// State for drift tracking middleware
39#[derive(Clone)]
40pub struct DriftTrackingState {
41    /// Contract diff analyzer
42    pub diff_analyzer: Option<Arc<ContractDiffAnalyzer>>,
43    /// OpenAPI spec (if available)
44    pub spec: Option<Arc<OpenApiSpec>>,
45    /// Drift budget engine
46    pub drift_engine: Arc<DriftBudgetEngine>,
47    /// Incident manager
48    pub incident_manager: Arc<IncidentManager>,
49    /// Usage recorder for consumer contracts
50    pub usage_recorder: Arc<UsageRecorder>,
51    /// Consumer breaking change detector
52    pub consumer_detector: Arc<ConsumerBreakingChangeDetector>,
53    /// Whether drift tracking is enabled
54    pub enabled: bool,
55}
56
57/// Middleware to track drift and consumer usage (with state from extensions)
58///
59/// This middleware requires response body buffering middleware to be applied first.
60/// The response body is extracted from buffered response extensions.
61pub async fn drift_tracking_middleware_with_extensions(
62    req: Request<Body>,
63    next: Next,
64) -> Response<Body> {
65    // Extract state from request extensions
66    let state = req.extensions().get::<DriftTrackingState>().cloned();
67
68    let state = if let Some(state) = state {
69        state
70    } else {
71        // No state available, skip drift tracking
72        return next.run(req).await;
73    };
74
75    if !state.enabled {
76        return next.run(req).await;
77    }
78
79    let method = req.method().to_string();
80    let path = req.uri().path().to_string();
81    let max_body = max_drift_body_size();
82
83    // Issue #79 — pre-check Content-Length so we skip capture cleanly
84    // for over-cap requests instead of consuming the body and then
85    // substituting `Body::empty()` (which broke downstream handlers).
86    let content_length = req
87        .headers()
88        .get(http::header::CONTENT_LENGTH)
89        .and_then(|v| v.to_str().ok())
90        .and_then(|s| s.parse::<usize>().ok());
91    if let Some(len) = content_length {
92        if len > max_body {
93            debug!(
94                "drift_tracking: skipping capture for {} {} — content-length {} > cap {}",
95                method, path, len, max_body
96            );
97            return next.run(req).await;
98        }
99    }
100
101    // Extract consumer identifier and headers from request
102    let consumer_id = extract_consumer_id(&req);
103
104    // Extract headers for capture
105    let captured_headers = extract_headers_for_capture(&req);
106
107    // Buffer the request body so we can capture it and still forward it
108    let (parts, body) = req.into_parts();
109    let body_bytes = match axum::body::to_bytes(body, max_body).await {
110        Ok(b) => b,
111        Err(_) => {
112            // Chunked over-cap body — body partially consumed and we
113            // cannot rebuild. 413 PayloadTooLarge instead of the old
114            // silent `Body::empty()` substitution. Issue #79.
115            return Response::builder()
116                .status(http::StatusCode::PAYLOAD_TOO_LARGE)
117                .header(
118                    http::header::CONTENT_TYPE,
119                    "application/json",
120                )
121                .body(Body::from(format!(
122                    r#"{{"error":"PAYLOAD_TOO_LARGE","message":"chunked request body exceeded drift_tracking capture cap (~{} MiB); raise MOCKFORGE_DRIFT_MAX_BODY_MB or send Content-Length"}}"#,
123                    max_body / (1024 * 1024)
124                )))
125                .unwrap_or_else(|_| Response::new(Body::from("payload too large")));
126        }
127    };
128
129    // Try to parse body as JSON for structured capture
130    let captured_body = if !body_bytes.is_empty() {
131        serde_json::from_slice::<Value>(&body_bytes).ok()
132    } else {
133        None
134    };
135
136    // Reconstruct the request with the buffered body
137    let req = Request::from_parts(parts, Body::from(body_bytes));
138
139    // Process request and get response
140    let response = next.run(req).await;
141
142    // Extract response body for consumer usage tracking
143    let response_body = extract_response_body(&response);
144
145    // Record consumer usage if consumer is identified
146    if let Some(ref consumer_id) = consumer_id {
147        if let Some(body) = &response_body {
148            state.usage_recorder.record_usage(consumer_id, &path, &method, Some(body)).await;
149        }
150    }
151
152    // Perform contract diff analysis if analyzer and spec are available
153    if let (Some(ref analyzer), Some(ref spec)) = (&state.diff_analyzer, &state.spec) {
154        // Create captured request with body and headers
155        let mut captured = mockforge_core::ai_contract_diff::CapturedRequest::new(
156            &method,
157            &path,
158            "drift_tracking",
159        )
160        .with_headers(captured_headers)
161        .with_response(response.status().as_u16(), response_body.clone());
162
163        if let Some(body_value) = captured_body {
164            captured = captured.with_body(body_value);
165        }
166
167        // Analyze request against contract
168        match analyzer.analyze(&captured, spec).await {
169            Ok(diff_result) => {
170                // Evaluate drift budget
171                let drift_result = state.drift_engine.evaluate(&diff_result, &path, &method);
172
173                // Emit Prometheus gauges for the drift dashboard (#678). The
174                // workspace_id is not yet plumbed through the middleware —
175                // pass an empty label so a single global series records,
176                // matching the pattern used for the non-workspace
177                // `pillar_tracking` calls below. When workspace identification
178                // lands we can switch this label to the resolved id.
179                mockforge_observability::get_global_registry().record_drift_evaluation(
180                    mockforge_observability::DriftEvaluationSample {
181                        workspace_id: "",
182                        endpoint: &path,
183                        method: &method,
184                        total: drift_result.metrics.total_changes,
185                        breaking: drift_result.breaking_changes,
186                        potentially_breaking: drift_result.potentially_breaking_changes,
187                        budget_exceeded: drift_result.budget_exceeded,
188                    },
189                );
190
191                // #677 — when analytics is installed, fire a drift_percentage
192                // sample at the analytics sqlite. Stays a no-op when the
193                // global DB hasn't been initialised, so OSS quick-start
194                // doesn't accidentally create a sqlite file.
195                //
196                // Heuristic for total/drifting: total_changes counts the
197                // observed mismatches (the "denominator" for how many things
198                // we looked at this request), and breaking + potentially-
199                // breaking is what counts as drift. Recorded per-request so
200                // the dashboard's time-series view shows the live trend.
201                let total = i64::from(drift_result.metrics.total_changes);
202                let drifting = i64::from(drift_result.breaking_changes)
203                    + i64::from(drift_result.potentially_breaking_changes);
204                mockforge_analytics::record_drift_percentage_async(
205                    String::new(), // workspace_id — see Prometheus emission note
206                    None,
207                    total,
208                    drifting,
209                );
210
211                // Record contracts pillar usage for drift detection
212                mockforge_core::pillar_tracking::record_contracts_usage(
213                    None, // workspace_id could be extracted from request if available
214                    None,
215                    "drift_detection",
216                    serde_json::json!({
217                        "endpoint": path,
218                        "method": method,
219                        "breaking_changes": drift_result.breaking_changes,
220                        "non_breaking_changes": drift_result.non_breaking_changes,
221                        "incident": drift_result.should_create_incident
222                    }),
223                )
224                .await;
225
226                // If this endpoint has a drift budget configured, record that so
227                // the Contracts pillar dashboard can count distinct budgeted endpoints.
228                let endpoint_key = format!("{} {}", method, path);
229                let budget_config = state.drift_engine.config();
230                if budget_config.enabled
231                    && (budget_config.per_endpoint_budgets.contains_key(&endpoint_key)
232                        || budget_config.default_budget.is_some())
233                {
234                    mockforge_core::pillar_tracking::record_contracts_usage(
235                        None,
236                        None,
237                        "drift_budget_configured",
238                        serde_json::json!({
239                            "endpoint": endpoint_key,
240                        }),
241                    )
242                    .await;
243                }
244
245                // Create incident if budget is exceeded or breaking changes detected
246                if drift_result.should_create_incident {
247                    let incident_type = if drift_result.breaking_changes > 0 {
248                        IncidentType::BreakingChange
249                    } else {
250                        IncidentType::ThresholdExceeded
251                    };
252
253                    let severity = determine_severity(&drift_result);
254
255                    let details = serde_json::json!({
256                        "breaking_changes": drift_result.breaking_changes,
257                        "non_breaking_changes": drift_result.non_breaking_changes,
258                        "breaking_mismatches": drift_result.breaking_mismatches,
259                        "non_breaking_mismatches": drift_result.non_breaking_mismatches,
260                        "budget_exceeded": drift_result.budget_exceeded,
261                    });
262
263                    // Create incident with before/after samples
264                    // Extract before/after samples from diff result if available
265                    let before_sample = Some(serde_json::json!({
266                        "contract_format": diff_result.metadata.contract_format,
267                        "contract_version": diff_result.metadata.contract_version,
268                        "endpoint": path,
269                        "method": method,
270                    }));
271
272                    let after_sample = Some(serde_json::json!({
273                        "mismatches": diff_result.mismatches,
274                        "recommendations": diff_result.recommendations,
275                        "corrections": diff_result.corrections,
276                    }));
277
278                    let _incident = state
279                        .incident_manager
280                        .create_incident_with_samples(
281                            path.clone(),
282                            method.clone(),
283                            incident_type,
284                            severity,
285                            details,
286                            None, // budget_id
287                            None, // workspace_id
288                            None, // sync_cycle_id
289                            None, // contract_diff_id (could be generated from diff_result)
290                            before_sample,
291                            after_sample,
292                            Some(drift_result.fitness_test_results.clone()), // fitness_test_results
293                            drift_result.consumer_impact.clone(),            // affected_consumers
294                            Some(mockforge_foundation::protocol::Protocol::Http), // protocol
295                        )
296                        .await;
297
298                    warn!(
299                        "Drift incident created: {} {} - {} breaking changes, {} non-breaking changes",
300                        method, path, drift_result.breaking_changes, drift_result.non_breaking_changes
301                    );
302                }
303
304                // Check for consumer-specific violations
305                if let Some(ref consumer_id) = consumer_id {
306                    let violations = state
307                        .consumer_detector
308                        .detect_violations(consumer_id, &path, &method, &diff_result, None)
309                        .await;
310
311                    if !violations.is_empty() {
312                        warn!(
313                            "Consumer {} has {} violations on {} {}",
314                            consumer_id,
315                            violations.len(),
316                            method,
317                            path
318                        );
319                    }
320                }
321            }
322            Err(e) => {
323                debug!("Contract diff analysis failed: {}", e);
324            }
325        }
326    }
327
328    response
329}
330
331/// Extract consumer identifier from request
332fn extract_consumer_id(req: &Request<Body>) -> Option<String> {
333    // Try to extract from various sources:
334    // 1. X-Consumer-ID header
335    if let Some(consumer_id) = req.headers().get("x-consumer-id").and_then(|h| h.to_str().ok()) {
336        return Some(consumer_id.to_string());
337    }
338
339    // 2. X-Workspace-ID header (for workspace-based consumers)
340    if let Some(workspace_id) = req.headers().get("x-workspace-id").and_then(|h| h.to_str().ok()) {
341        return Some(format!("workspace:{}", workspace_id));
342    }
343
344    // 3. API key from header
345    if let Some(api_key) = req
346        .headers()
347        .get("x-api-key")
348        .or_else(|| req.headers().get("authorization"))
349        .and_then(|h| h.to_str().ok())
350    {
351        // Hash the API key for privacy
352        use sha2::{Digest, Sha256};
353        let mut hasher = Sha256::new();
354        hasher.update(api_key.as_bytes());
355        let hash = format!("{:x}", hasher.finalize());
356        return Some(format!("api_key:{}", hash));
357    }
358
359    None
360}
361
362/// Extract safe headers for drift capture
363fn extract_headers_for_capture(req: &Request<Body>) -> HashMap<String, String> {
364    let safe_headers = [
365        "accept",
366        "accept-encoding",
367        "content-type",
368        "content-length",
369        "user-agent",
370    ];
371    let mut captured = HashMap::new();
372    for name in safe_headers {
373        if let Some(value) = req.headers().get(name).and_then(|v| v.to_str().ok()) {
374            captured.insert(name.to_string(), value.to_string());
375        }
376    }
377    captured
378}
379
380/// Extract response body as JSON value
381fn extract_response_body(response: &Response<Body>) -> Option<Value> {
382    // Try to get buffered response from extensions
383    if let Some(buffered) = crate::middleware::get_buffered_response(response) {
384        return buffered.json();
385    }
386
387    // If not buffered, try to parse from response body
388    // Note: This requires the response body to be buffered by upstream middleware
389    None
390}
391
392/// Determine incident severity from drift result
393fn determine_severity(drift_result: &DriftResult) -> IncidentSeverity {
394    if drift_result.breaking_changes > 0 {
395        // Check if any breaking mismatch is critical
396        if drift_result
397            .breaking_mismatches
398            .iter()
399            .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::Critical)
400        {
401            return IncidentSeverity::Critical;
402        }
403        // Check if any breaking mismatch is high
404        if drift_result
405            .breaking_mismatches
406            .iter()
407            .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::High)
408        {
409            return IncidentSeverity::High;
410        }
411        return IncidentSeverity::Medium;
412    }
413
414    // Non-breaking changes are lower severity
415    if drift_result.non_breaking_changes > 5 {
416        IncidentSeverity::Medium
417    } else {
418        IncidentSeverity::Low
419    }
420}