mockforge_http/middleware/
drift_tracking.rs

1//! Drift tracking middleware
2//!
3//! This middleware integrates drift budget evaluation and consumer usage tracking
4//! with contract diff analysis.
5
6use axum::{
7    body::Body,
8    extract::{Request, State},
9    http::Response,
10    middleware::Next,
11};
12use mockforge_core::{
13    ai_contract_diff::ContractDiffAnalyzer,
14    consumer_contracts::{ConsumerBreakingChangeDetector, UsageRecorder},
15    contract_drift::{DriftBudgetEngine, DriftResult},
16    incidents::{IncidentManager, IncidentSeverity, IncidentType},
17    openapi::OpenApiSpec,
18};
19use serde_json::Value;
20use std::sync::Arc;
21use tracing::{debug, warn};
22
23/// State for drift tracking middleware
24#[derive(Clone)]
25pub struct DriftTrackingState {
26    /// Contract diff analyzer
27    pub diff_analyzer: Option<Arc<ContractDiffAnalyzer>>,
28    /// OpenAPI spec (if available)
29    pub spec: Option<Arc<OpenApiSpec>>,
30    /// Drift budget engine
31    pub drift_engine: Arc<DriftBudgetEngine>,
32    /// Incident manager
33    pub incident_manager: Arc<IncidentManager>,
34    /// Usage recorder for consumer contracts
35    pub usage_recorder: Arc<UsageRecorder>,
36    /// Consumer breaking change detector
37    pub consumer_detector: Arc<ConsumerBreakingChangeDetector>,
38    /// Whether drift tracking is enabled
39    pub enabled: bool,
40}
41
42/// Middleware to track drift and consumer usage (with state from extensions)
43///
44/// This middleware requires response body buffering middleware to be applied first.
45/// The response body is extracted from buffered response extensions.
46pub async fn drift_tracking_middleware_with_extensions(
47    req: Request<Body>,
48    next: Next,
49) -> Response<Body> {
50    // Extract state from request extensions
51    let state = req.extensions().get::<DriftTrackingState>().cloned();
52
53    let state = if let Some(state) = state {
54        state
55    } else {
56        // No state available, skip drift tracking
57        return next.run(req).await;
58    };
59
60    if !state.enabled {
61        return next.run(req).await;
62    }
63
64    let method = req.method().to_string();
65    let path = req.uri().path().to_string();
66
67    // Extract consumer identifier from request
68    let consumer_id = extract_consumer_id(&req);
69
70    // Process request and get response
71    let response = next.run(req).await;
72
73    // Extract response body for consumer usage tracking
74    let response_body = extract_response_body(&response);
75
76    // Record consumer usage if consumer is identified
77    if let Some(ref consumer_id) = consumer_id {
78        if let Some(body) = &response_body {
79            state.usage_recorder.record_usage(consumer_id, &path, &method, Some(body)).await;
80        }
81    }
82
83    // Perform contract diff analysis if analyzer and spec are available
84    if let (Some(ref analyzer), Some(ref spec)) = (&state.diff_analyzer, &state.spec) {
85        // Create captured request from the actual request
86        // Note: In a full implementation, we'd need to capture the request body
87        // For now, we'll analyze based on path and method
88        let captured = mockforge_core::ai_contract_diff::CapturedRequest::new(
89            &method,
90            &path,
91            "drift_tracking",
92        )
93        .with_response(response.status().as_u16(), response_body.clone());
94
95        // Analyze request against contract
96        match analyzer.analyze(&captured, spec).await {
97            Ok(diff_result) => {
98                // Evaluate drift budget
99                let drift_result = state.drift_engine.evaluate(&diff_result, &path, &method);
100
101                // Create incident if budget is exceeded or breaking changes detected
102                if drift_result.should_create_incident {
103                    let incident_type = if drift_result.breaking_changes > 0 {
104                        IncidentType::BreakingChange
105                    } else {
106                        IncidentType::ThresholdExceeded
107                    };
108
109                    let severity = determine_severity(&drift_result);
110
111                    let details = serde_json::json!({
112                        "breaking_changes": drift_result.breaking_changes,
113                        "non_breaking_changes": drift_result.non_breaking_changes,
114                        "breaking_mismatches": drift_result.breaking_mismatches,
115                        "non_breaking_mismatches": drift_result.non_breaking_mismatches,
116                        "budget_exceeded": drift_result.budget_exceeded,
117                    });
118
119                    // Create incident with before/after samples
120                    // Extract before/after samples from diff result if available
121                    let before_sample = Some(serde_json::json!({
122                        "contract_format": diff_result.metadata.contract_format,
123                        "contract_version": diff_result.metadata.contract_version,
124                        "endpoint": path,
125                        "method": method,
126                    }));
127
128                    let after_sample = Some(serde_json::json!({
129                        "mismatches": diff_result.mismatches,
130                        "recommendations": diff_result.recommendations,
131                        "corrections": diff_result.corrections,
132                    }));
133
134                    let _incident = state
135                        .incident_manager
136                        .create_incident_with_samples(
137                            path.clone(),
138                            method.clone(),
139                            incident_type,
140                            severity,
141                            details,
142                            None, // budget_id
143                            None, // workspace_id
144                            None, // sync_cycle_id
145                            None, // contract_diff_id (could be generated from diff_result)
146                            before_sample,
147                            after_sample,
148                        )
149                        .await;
150
151                    warn!(
152                        "Drift incident created: {} {} - {} breaking changes, {} non-breaking changes",
153                        method, path, drift_result.breaking_changes, drift_result.non_breaking_changes
154                    );
155                }
156
157                // Check for consumer-specific violations
158                if let Some(ref consumer_id) = consumer_id {
159                    let violations = state
160                        .consumer_detector
161                        .detect_violations(consumer_id, &path, &method, &diff_result, None)
162                        .await;
163
164                    if !violations.is_empty() {
165                        warn!(
166                            "Consumer {} has {} violations on {} {}",
167                            consumer_id,
168                            violations.len(),
169                            method,
170                            path
171                        );
172                    }
173                }
174            }
175            Err(e) => {
176                debug!("Contract diff analysis failed: {}", e);
177            }
178        }
179    }
180
181    response
182}
183
184/// Extract consumer identifier from request
185fn extract_consumer_id(req: &Request<Body>) -> Option<String> {
186    // Try to extract from various sources:
187    // 1. X-Consumer-ID header
188    if let Some(consumer_id) = req.headers().get("x-consumer-id").and_then(|h| h.to_str().ok()) {
189        return Some(consumer_id.to_string());
190    }
191
192    // 2. X-Workspace-ID header (for workspace-based consumers)
193    if let Some(workspace_id) = req.headers().get("x-workspace-id").and_then(|h| h.to_str().ok()) {
194        return Some(format!("workspace:{}", workspace_id));
195    }
196
197    // 3. API key from header
198    if let Some(api_key) = req
199        .headers()
200        .get("x-api-key")
201        .or_else(|| req.headers().get("authorization"))
202        .and_then(|h| h.to_str().ok())
203    {
204        // Hash the API key for privacy
205        use sha2::{Digest, Sha256};
206        let mut hasher = Sha256::new();
207        hasher.update(api_key.as_bytes());
208        let hash = format!("{:x}", hasher.finalize());
209        return Some(format!("api_key:{}", hash));
210    }
211
212    None
213}
214
215/// Extract response body as JSON value
216fn extract_response_body(response: &Response<Body>) -> Option<Value> {
217    // Try to get buffered response from extensions
218    if let Some(buffered) = crate::middleware::get_buffered_response(response) {
219        return buffered.json();
220    }
221
222    // If not buffered, try to parse from response body
223    // Note: This requires the response body to be buffered by upstream middleware
224    None
225}
226
227/// Determine incident severity from drift result
228fn determine_severity(drift_result: &DriftResult) -> IncidentSeverity {
229    if drift_result.breaking_changes > 0 {
230        // Check if any breaking mismatch is critical
231        if drift_result
232            .breaking_mismatches
233            .iter()
234            .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::Critical)
235        {
236            return IncidentSeverity::Critical;
237        }
238        // Check if any breaking mismatch is high
239        if drift_result
240            .breaking_mismatches
241            .iter()
242            .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::High)
243        {
244            return IncidentSeverity::High;
245        }
246        return IncidentSeverity::Medium;
247    }
248
249    // Non-breaking changes are lower severity
250    if drift_result.non_breaking_changes > 5 {
251        IncidentSeverity::Medium
252    } else {
253        IncidentSeverity::Low
254    }
255}