mockforge_http/middleware/
drift_tracking.rs

1//! Drift tracking middleware
2//!
3//! This middleware integrates drift budget evaluation and consumer usage tracking
4//! with contract diff analysis.
5
6use axum::{body::Body, extract::Request, http::Response, middleware::Next};
7use mockforge_core::{
8    ai_contract_diff::ContractDiffAnalyzer,
9    consumer_contracts::{ConsumerBreakingChangeDetector, UsageRecorder},
10    contract_drift::{DriftBudgetEngine, DriftResult},
11    incidents::{IncidentManager, IncidentSeverity, IncidentType},
12    openapi::OpenApiSpec,
13};
14use serde_json::Value;
15use std::sync::Arc;
16use tracing::{debug, warn};
17
18/// State for drift tracking middleware
19#[derive(Clone)]
20pub struct DriftTrackingState {
21    /// Contract diff analyzer
22    pub diff_analyzer: Option<Arc<ContractDiffAnalyzer>>,
23    /// OpenAPI spec (if available)
24    pub spec: Option<Arc<OpenApiSpec>>,
25    /// Drift budget engine
26    pub drift_engine: Arc<DriftBudgetEngine>,
27    /// Incident manager
28    pub incident_manager: Arc<IncidentManager>,
29    /// Usage recorder for consumer contracts
30    pub usage_recorder: Arc<UsageRecorder>,
31    /// Consumer breaking change detector
32    pub consumer_detector: Arc<ConsumerBreakingChangeDetector>,
33    /// Whether drift tracking is enabled
34    pub enabled: bool,
35}
36
37/// Middleware to track drift and consumer usage (with state from extensions)
38///
39/// This middleware requires response body buffering middleware to be applied first.
40/// The response body is extracted from buffered response extensions.
41pub async fn drift_tracking_middleware_with_extensions(
42    req: Request<Body>,
43    next: Next,
44) -> Response<Body> {
45    // Extract state from request extensions
46    let state = req.extensions().get::<DriftTrackingState>().cloned();
47
48    let state = if let Some(state) = state {
49        state
50    } else {
51        // No state available, skip drift tracking
52        return next.run(req).await;
53    };
54
55    if !state.enabled {
56        return next.run(req).await;
57    }
58
59    let method = req.method().to_string();
60    let path = req.uri().path().to_string();
61
62    // Extract consumer identifier from request
63    let consumer_id = extract_consumer_id(&req);
64
65    // Process request and get response
66    let response = next.run(req).await;
67
68    // Extract response body for consumer usage tracking
69    let response_body = extract_response_body(&response);
70
71    // Record consumer usage if consumer is identified
72    if let Some(ref consumer_id) = consumer_id {
73        if let Some(body) = &response_body {
74            state.usage_recorder.record_usage(consumer_id, &path, &method, Some(body)).await;
75        }
76    }
77
78    // Perform contract diff analysis if analyzer and spec are available
79    if let (Some(ref analyzer), Some(ref spec)) = (&state.diff_analyzer, &state.spec) {
80        // Create captured request from the actual request
81        // Note: In a full implementation, we'd need to capture the request body
82        // For now, we'll analyze based on path and method
83        let captured = mockforge_core::ai_contract_diff::CapturedRequest::new(
84            &method,
85            &path,
86            "drift_tracking",
87        )
88        .with_response(response.status().as_u16(), response_body.clone());
89
90        // Analyze request against contract
91        match analyzer.analyze(&captured, spec).await {
92            Ok(diff_result) => {
93                // Evaluate drift budget
94                let drift_result = state.drift_engine.evaluate(&diff_result, &path, &method);
95
96                // Record contracts pillar usage for drift detection
97                mockforge_core::pillar_tracking::record_contracts_usage(
98                    None, // workspace_id could be extracted from request if available
99                    None,
100                    "drift_detection",
101                    serde_json::json!({
102                        "endpoint": path,
103                        "method": method,
104                        "breaking_changes": drift_result.breaking_changes,
105                        "non_breaking_changes": drift_result.non_breaking_changes,
106                        "incident": drift_result.should_create_incident
107                    }),
108                )
109                .await;
110
111                // Create incident if budget is exceeded or breaking changes detected
112                if drift_result.should_create_incident {
113                    let incident_type = if drift_result.breaking_changes > 0 {
114                        IncidentType::BreakingChange
115                    } else {
116                        IncidentType::ThresholdExceeded
117                    };
118
119                    let severity = determine_severity(&drift_result);
120
121                    let details = serde_json::json!({
122                        "breaking_changes": drift_result.breaking_changes,
123                        "non_breaking_changes": drift_result.non_breaking_changes,
124                        "breaking_mismatches": drift_result.breaking_mismatches,
125                        "non_breaking_mismatches": drift_result.non_breaking_mismatches,
126                        "budget_exceeded": drift_result.budget_exceeded,
127                    });
128
129                    // Create incident with before/after samples
130                    // Extract before/after samples from diff result if available
131                    let before_sample = Some(serde_json::json!({
132                        "contract_format": diff_result.metadata.contract_format,
133                        "contract_version": diff_result.metadata.contract_version,
134                        "endpoint": path,
135                        "method": method,
136                    }));
137
138                    let after_sample = Some(serde_json::json!({
139                        "mismatches": diff_result.mismatches,
140                        "recommendations": diff_result.recommendations,
141                        "corrections": diff_result.corrections,
142                    }));
143
144                    let _incident = state
145                        .incident_manager
146                        .create_incident_with_samples(
147                            path.clone(),
148                            method.clone(),
149                            incident_type,
150                            severity,
151                            details,
152                            None, // budget_id
153                            None, // workspace_id
154                            None, // sync_cycle_id
155                            None, // contract_diff_id (could be generated from diff_result)
156                            before_sample,
157                            after_sample,
158                            Some(drift_result.fitness_test_results.clone()), // fitness_test_results
159                            drift_result.consumer_impact.clone(),            // affected_consumers
160                            Some(mockforge_core::protocol_abstraction::Protocol::Http), // protocol
161                        )
162                        .await;
163
164                    warn!(
165                        "Drift incident created: {} {} - {} breaking changes, {} non-breaking changes",
166                        method, path, drift_result.breaking_changes, drift_result.non_breaking_changes
167                    );
168                }
169
170                // Check for consumer-specific violations
171                if let Some(ref consumer_id) = consumer_id {
172                    let violations = state
173                        .consumer_detector
174                        .detect_violations(consumer_id, &path, &method, &diff_result, None)
175                        .await;
176
177                    if !violations.is_empty() {
178                        warn!(
179                            "Consumer {} has {} violations on {} {}",
180                            consumer_id,
181                            violations.len(),
182                            method,
183                            path
184                        );
185                    }
186                }
187            }
188            Err(e) => {
189                debug!("Contract diff analysis failed: {}", e);
190            }
191        }
192    }
193
194    response
195}
196
197/// Extract consumer identifier from request
198fn extract_consumer_id(req: &Request<Body>) -> Option<String> {
199    // Try to extract from various sources:
200    // 1. X-Consumer-ID header
201    if let Some(consumer_id) = req.headers().get("x-consumer-id").and_then(|h| h.to_str().ok()) {
202        return Some(consumer_id.to_string());
203    }
204
205    // 2. X-Workspace-ID header (for workspace-based consumers)
206    if let Some(workspace_id) = req.headers().get("x-workspace-id").and_then(|h| h.to_str().ok()) {
207        return Some(format!("workspace:{}", workspace_id));
208    }
209
210    // 3. API key from header
211    if let Some(api_key) = req
212        .headers()
213        .get("x-api-key")
214        .or_else(|| req.headers().get("authorization"))
215        .and_then(|h| h.to_str().ok())
216    {
217        // Hash the API key for privacy
218        use sha2::{Digest, Sha256};
219        let mut hasher = Sha256::new();
220        hasher.update(api_key.as_bytes());
221        let hash = format!("{:x}", hasher.finalize());
222        return Some(format!("api_key:{}", hash));
223    }
224
225    None
226}
227
228/// Extract response body as JSON value
229fn extract_response_body(response: &Response<Body>) -> Option<Value> {
230    // Try to get buffered response from extensions
231    if let Some(buffered) = crate::middleware::get_buffered_response(response) {
232        return buffered.json();
233    }
234
235    // If not buffered, try to parse from response body
236    // Note: This requires the response body to be buffered by upstream middleware
237    None
238}
239
240/// Determine incident severity from drift result
241fn determine_severity(drift_result: &DriftResult) -> IncidentSeverity {
242    if drift_result.breaking_changes > 0 {
243        // Check if any breaking mismatch is critical
244        if drift_result
245            .breaking_mismatches
246            .iter()
247            .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::Critical)
248        {
249            return IncidentSeverity::Critical;
250        }
251        // Check if any breaking mismatch is high
252        if drift_result
253            .breaking_mismatches
254            .iter()
255            .any(|m| m.severity == mockforge_core::ai_contract_diff::MismatchSeverity::High)
256        {
257            return IncidentSeverity::High;
258        }
259        return IncidentSeverity::Medium;
260    }
261
262    // Non-breaking changes are lower severity
263    if drift_result.non_breaking_changes > 5 {
264        IncidentSeverity::Medium
265    } else {
266        IncidentSeverity::Low
267    }
268}