use axum::{body::Body, extract::Request, http::Response, middleware::Next};
use mockforge_core::{
ai_contract_diff::ContractDiffAnalyzer,
consumer_contracts::{ConsumerBreakingChangeDetector, UsageRecorder},
contract_drift::{DriftBudgetEngine, DriftResult},
incidents::{IncidentManager, IncidentSeverity, IncidentType},
openapi::OpenApiSpec,
};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, warn};
/// Upper bound, in bytes, on how much of a request body is buffered for drift capture (1 MiB).
const MAX_DRIFT_BODY_SIZE: usize = 1 << 20;
/// Services and configuration used by `drift_tracking_middleware_with_extensions`.
///
/// The middleware looks this value up in the request extensions
/// (`req.extensions().get::<DriftTrackingState>()`), so it must be inserted there for
/// tracking to run. `Clone` is required because the middleware clones it out of the
/// extensions.
#[derive(Clone)]
pub struct DriftTrackingState {
    /// Contract diff analyzer. Drift analysis runs only when both this and `spec` are `Some`.
    pub diff_analyzer: Option<Arc<ContractDiffAnalyzer>>,
    /// OpenAPI spec that captured requests/responses are diffed against.
    pub spec: Option<Arc<OpenApiSpec>>,
    /// Evaluates a diff result for an endpoint. Its result decides whether an incident is created.
    pub drift_engine: Arc<DriftBudgetEngine>,
    /// Creates incidents when the drift evaluation sets `should_create_incident`.
    pub incident_manager: Arc<IncidentManager>,
    /// Records per-consumer usage of an endpoint together with the response body.
    pub usage_recorder: Arc<UsageRecorder>,
    /// Detects contract violations that affect a specific consumer.
    pub consumer_detector: Arc<ConsumerBreakingChangeDetector>,
    /// When `false`, the middleware forwards requests without any tracking.
    pub enabled: bool,
}
pub async fn drift_tracking_middleware_with_extensions(
req: Request<Body>,
next: Next,
) -> Response<Body> {
let state = req.extensions().get::<DriftTrackingState>().cloned();
let state = if let Some(state) = state {
state
} else {
return next.run(req).await;
};
if !state.enabled {
return next.run(req).await;
}
let method = req.method().to_string();
let path = req.uri().path().to_string();
let consumer_id = extract_consumer_id(&req);
let captured_headers = extract_headers_for_capture(&req);
let (parts, body) = req.into_parts();
let body_bytes = match axum::body::to_bytes(body, MAX_DRIFT_BODY_SIZE).await {
Ok(b) => b,
Err(_) => {
let rebuilt = Request::from_parts(parts, Body::empty());
return next.run(rebuilt).await;
}
};
let captured_body = if !body_bytes.is_empty() {
serde_json::from_slice::<serde_json::Value>(&body_bytes).ok()
} else {
None
};
let req = Request::from_parts(parts, Body::from(body_bytes));
let response = next.run(req).await;
let response_body = extract_response_body(&response);
if let Some(ref consumer_id) = consumer_id {
if let Some(body) = &response_body {
state.usage_recorder.record_usage(consumer_id, &path, &method, Some(body)).await;
}
}
if let (Some(ref analyzer), Some(ref spec)) = (&state.diff_analyzer, &state.spec) {
let mut captured = mockforge_core::ai_contract_diff::CapturedRequest::new(
&method,
&path,
"drift_tracking",
)
.with_headers(captured_headers)
.with_response(response.status().as_u16(), response_body.clone());
if let Some(body_value) = captured_body {
captured = captured.with_body(body_value);
}
match analyzer.analyze(&captured, spec).await {
Ok(diff_result) => {
let drift_result = state.drift_engine.evaluate(&diff_result, &path, &method);
mockforge_core::pillar_tracking::record_contracts_usage(
None, None,
"drift_detection",
serde_json::json!({
"endpoint": path,
"method": method,
"breaking_changes": drift_result.breaking_changes,
"non_breaking_changes": drift_result.non_breaking_changes,
"incident": drift_result.should_create_incident
}),
)
.await;
if drift_result.should_create_incident {
let incident_type = if drift_result.breaking_changes > 0 {
IncidentType::BreakingChange
} else {
IncidentType::ThresholdExceeded
};
let severity = determine_severity(&drift_result);
let details = serde_json::json!({
"breaking_changes": drift_result.breaking_changes,
"non_breaking_changes": drift_result.non_breaking_changes,
"breaking_mismatches": drift_result.breaking_mismatches,
"non_breaking_mismatches": drift_result.non_breaking_mismatches,
"budget_exceeded": drift_result.budget_exceeded,
});
let before_sample = Some(serde_json::json!({
"contract_format": diff_result.metadata.contract_format,
"contract_version": diff_result.metadata.contract_version,
"endpoint": path,
"method": method,
}));
let after_sample = Some(serde_json::json!({
"mismatches": diff_result.mismatches,
"recommendations": diff_result.recommendations,
"corrections": diff_result.corrections,
}));
let _incident = state
.incident_manager
.create_incident_with_samples(
path.clone(),
method.clone(),
incident_type,
severity,
details,
None, None, None, None, before_sample,
after_sample,
Some(drift_result.fitness_test_results.clone()), drift_result.consumer_impact.clone(), Some(mockforge_core::protocol_abstraction::Protocol::Http), )
.await;
warn!(
"Drift incident created: {} {} - {} breaking changes, {} non-breaking changes",
method, path, drift_result.breaking_changes, drift_result.non_breaking_changes
);
}
if let Some(ref consumer_id) = consumer_id {
let violations = state
.consumer_detector
.detect_violations(consumer_id, &path, &method, &diff_result, None)
.await;
if !violations.is_empty() {
warn!(
"Consumer {} has {} violations on {} {}",
consumer_id,
violations.len(),
method,
path
);
}
}
}
Err(e) => {
debug!("Contract diff analysis failed: {}", e);
}
}
}
response
}
/// Derives a consumer identifier for usage tracking from request headers.
///
/// Sources are tried in priority order. A header is skipped when it is missing, when
/// `HeaderValue::to_str` rejects it, or when its value is blank:
///
/// 1. `x-consumer-id`, used verbatim.
/// 2. `x-workspace-id`, prefixed with `workspace:`.
/// 3. `x-api-key`, then `authorization`. The value is SHA-256 hashed, formatted as
///    lower-case hex and prefixed with `api_key:`, so the raw credential is not used as
///    the identifier.
///
/// Returns `None` when no source yields a usable value.
fn extract_consumer_id(req: &Request<Body>) -> Option<String> {
    use sha2::{Digest, Sha256};

    /// Text value of header `name`, or `None` if absent, not representable as `&str`, or
    /// blank. A blank value would otherwise become an empty consumer id and shadow the
    /// lower-priority sources.
    fn header<'a>(req: &'a Request<Body>, name: &str) -> Option<&'a str> {
        req.headers()
            .get(name)
            .and_then(|h| h.to_str().ok())
            .filter(|v| !v.trim().is_empty())
    }

    if let Some(consumer_id) = header(req, "x-consumer-id") {
        return Some(consumer_id.to_string());
    }
    if let Some(workspace_id) = header(req, "x-workspace-id") {
        return Some(format!("workspace:{}", workspace_id));
    }

    // Fall back to `authorization` whenever `x-api-key` is unusable, not only when absent.
    ["x-api-key", "authorization"]
        .into_iter()
        .find_map(|name| header(req, name))
        .map(|api_key| {
            let mut hasher = Sha256::new();
            hasher.update(api_key.as_bytes());
            format!("api_key:{:x}", hasher.finalize())
        })
}
/// Copies an allow-list of non-sensitive request headers into the drift capture.
///
/// Only `accept`, `accept-encoding`, `content-type`, `content-length` and `user-agent`
/// are considered, so authentication headers are never captured. Values that
/// `HeaderValue::to_str` rejects are skipped.
fn extract_headers_for_capture(req: &Request<Body>) -> HashMap<String, String> {
    const ALLOWED_HEADERS: [&str; 5] = [
        "accept",
        "accept-encoding",
        "content-type",
        "content-length",
        "user-agent",
    ];

    let headers = req.headers();
    ALLOWED_HEADERS
        .iter()
        .filter_map(|&name| {
            let value = headers.get(name)?.to_str().ok()?;
            Some((name.to_owned(), value.to_owned()))
        })
        .collect()
}
/// Returns the response body as JSON, taken from the buffered copy attached to the response.
///
/// Delegates to `crate::middleware::get_buffered_response`. Returns `None` when no buffered
/// copy is available or when its `json()` conversion yields `None`.
fn extract_response_body(response: &Response<Body>) -> Option<Value> {
    crate::middleware::get_buffered_response(response).and_then(|buffered| buffered.json())
}
/// Maps a drift evaluation to an incident severity.
///
/// With breaking changes, the result is `Critical` if any breaking mismatch is critical,
/// else `High` if any is high, else `Medium`. Without breaking changes, the result is
/// `Medium` for more than five non-breaking changes, and `Low` otherwise.
fn determine_severity(drift_result: &DriftResult) -> IncidentSeverity {
    use mockforge_core::ai_contract_diff::MismatchSeverity;

    let has_breaking = drift_result.breaking_changes > 0;
    if !has_breaking {
        return if drift_result.non_breaking_changes > 5 {
            IncidentSeverity::Medium
        } else {
            IncidentSeverity::Low
        };
    }

    // True when at least one breaking mismatch carries exactly `level`.
    let has_breaking_at = |level: MismatchSeverity| {
        drift_result.breaking_mismatches.iter().any(|m| m.severity == level)
    };

    if has_breaking_at(MismatchSeverity::Critical) {
        IncidentSeverity::Critical
    } else if has_breaking_at(MismatchSeverity::High) {
        IncidentSeverity::High
    } else {
        IncidentSeverity::Medium
    }
}