use crate::ai_contract_diff::ContractDiffResult;
use crate::contract_drift::consumer_mapping::ConsumerImpactAnalyzer;
use crate::contract_drift::field_tracking::FieldCountTracker;
use crate::contract_drift::fitness::FitnessFunctionRegistry;
use crate::contract_drift::types::{DriftBudget, DriftBudgetConfig, DriftResult};
use crate::openapi::OpenApiSpec;
use std::sync::Arc;
/// Evaluates contract diffs against configured drift budgets.
///
/// The engine always owns a [`DriftBudgetConfig`]. The three optional
/// collaborators start out unset (`None`) and can be attached through the
/// constructors and `set_*` methods; evaluation simply skips any
/// collaborator that has not been attached.
#[derive(Debug)]
pub struct DriftBudgetEngine {
/// Budget configuration: the global enable flag, the per-workspace,
/// per-service, per-tag and per-endpoint budgets, the default budget and
/// the breaking-change rules handed to `DriftResult::from_diff_result`.
config: DriftBudgetConfig,
/// Shared field-count history. Read during evaluation to obtain an average
/// field count when the resolved budget sets both `max_field_churn_percent`
/// and `time_window_days`.
field_tracker: Option<Arc<tokio::sync::RwLock<FieldCountTracker>>>,
/// Shared fitness-function registry. Only `evaluate_with_specs` consults it.
fitness_registry: Option<Arc<tokio::sync::RwLock<FitnessFunctionRegistry>>>,
/// Shared consumer-impact analyzer, queried with the endpoint and method
/// under evaluation.
consumer_analyzer: Option<Arc<tokio::sync::RwLock<ConsumerImpactAnalyzer>>>,
}
impl DriftBudgetEngine {
/// Creates an engine that only knows about `config`.
///
/// No field tracker, fitness registry, or consumer analyzer is attached;
/// use the `set_*` methods to add them afterwards.
pub fn new(config: DriftBudgetConfig) -> Self {
    // Name the empty collaborators once so the struct literal can use
    // field-init shorthand throughout.
    let (field_tracker, fitness_registry, consumer_analyzer) = (None, None, None);
    Self {
        config,
        field_tracker,
        fitness_registry,
        consumer_analyzer,
    }
}
pub fn new_with_tracker(
config: DriftBudgetConfig,
field_tracker: Arc<tokio::sync::RwLock<FieldCountTracker>>,
) -> Self {
Self {
config,
field_tracker: Some(field_tracker),
fitness_registry: None,
consumer_analyzer: None,
}
}
/// Attaches (or replaces) the fitness-function registry.
pub fn set_fitness_registry(
    &mut self,
    fitness_registry: Arc<tokio::sync::RwLock<FitnessFunctionRegistry>>,
) {
    // `Option<T>: From<T>` wraps the handle in `Some`.
    self.fitness_registry = Option::from(fitness_registry);
}
/// Attaches (or replaces) the consumer-impact analyzer.
pub fn set_consumer_analyzer(
    &mut self,
    consumer_analyzer: Arc<tokio::sync::RwLock<ConsumerImpactAnalyzer>>,
) {
    let slot = &mut self.consumer_analyzer;
    *slot = Some(consumer_analyzer);
}
/// Attaches (or replaces) the field-count tracker.
pub fn set_field_tracker(
    &mut self,
    field_tracker: Arc<tokio::sync::RwLock<FieldCountTracker>>,
) {
    let attached = Some(field_tracker);
    self.field_tracker = attached;
}
/// Evaluates `diff_result` for `method endpoint` without any workspace,
/// service, or tag context.
///
/// Equivalent to calling [`Self::evaluate_with_context`] with all three
/// optional context arguments set to `None`.
pub fn evaluate(
    &self,
    diff_result: &ContractDiffResult,
    endpoint: &str,
    method: &str,
) -> DriftResult {
    let no_workspace: Option<&str> = None;
    let no_service: Option<&str> = None;
    let no_tags: Option<&[String]> = None;
    self.evaluate_with_context(
        diff_result,
        endpoint,
        method,
        no_workspace,
        no_service,
        no_tags,
    )
}
pub fn evaluate_with_context(
&self,
diff_result: &ContractDiffResult,
endpoint: &str,
method: &str,
workspace_id: Option<&str>,
service_name: Option<&str>,
tags: Option<&[String]>,
) -> DriftResult {
if !self.config.enabled {
return DriftResult {
budget_exceeded: false,
breaking_changes: 0,
potentially_breaking_changes: 0,
non_breaking_changes: 0,
breaking_mismatches: vec![],
potentially_breaking_mismatches: vec![],
non_breaking_mismatches: diff_result.mismatches.clone(),
metrics: crate::contract_drift::types::DriftMetrics {
endpoint: endpoint.to_string(),
method: method.to_string(),
breaking_changes: 0,
non_breaking_changes: diff_result.mismatches.len() as u32,
total_changes: diff_result.mismatches.len() as u32,
budget_exceeded: false,
last_updated: chrono::Utc::now().timestamp(),
},
should_create_incident: false,
fitness_test_results: Vec::new(),
consumer_impact: None,
};
}
let budget =
self.get_budget_for_endpoint(endpoint, method, workspace_id, service_name, tags);
if !budget.enabled {
return DriftResult {
budget_exceeded: false,
breaking_changes: 0,
potentially_breaking_changes: 0,
non_breaking_changes: 0,
breaking_mismatches: vec![],
potentially_breaking_mismatches: vec![],
non_breaking_mismatches: diff_result.mismatches.clone(),
metrics: crate::contract_drift::types::DriftMetrics {
endpoint: endpoint.to_string(),
method: method.to_string(),
breaking_changes: 0,
non_breaking_changes: diff_result.mismatches.len() as u32,
total_changes: diff_result.mismatches.len() as u32,
budget_exceeded: false,
last_updated: chrono::Utc::now().timestamp(),
},
should_create_incident: false,
fitness_test_results: Vec::new(),
consumer_impact: None,
};
}
let baseline_field_count = if budget.max_field_churn_percent.is_some() {
if let Some(ref tracker) = self.field_tracker {
let rt = tokio::runtime::Handle::try_current();
if let Ok(handle) = rt {
if let Some(time_window) = budget.time_window_days {
handle.block_on(async {
let guard = tracker.read().await;
guard.get_average_count(
None, endpoint,
method,
time_window,
)
})
} else {
None
}
} else {
None
}
} else {
None
}
} else {
None
};
let mut result = DriftResult::from_diff_result(
diff_result,
endpoint.to_string(),
method.to_string(),
&budget,
&self.config.breaking_change_rules,
baseline_field_count,
);
if let Some(ref analyzer) = self.consumer_analyzer {
let rt = tokio::runtime::Handle::try_current();
if let Ok(handle) = rt {
let impact = handle.block_on(async {
let guard = analyzer.read().await;
guard.analyze_impact(endpoint, method)
});
if let Some(impact) = impact {
result.consumer_impact = Some(impact);
}
}
}
result
}
#[allow(clippy::too_many_arguments)]
pub fn evaluate_with_specs(
&self,
old_spec: Option<&OpenApiSpec>,
new_spec: &OpenApiSpec,
diff_result: &ContractDiffResult,
endpoint: &str,
method: &str,
workspace_id: Option<&str>,
service_name: Option<&str>,
tags: Option<&[String]>,
) -> DriftResult {
if !self.config.enabled {
return DriftResult {
budget_exceeded: false,
breaking_changes: 0,
potentially_breaking_changes: 0,
non_breaking_changes: 0,
breaking_mismatches: vec![],
potentially_breaking_mismatches: vec![],
non_breaking_mismatches: diff_result.mismatches.clone(),
metrics: crate::contract_drift::types::DriftMetrics {
endpoint: endpoint.to_string(),
method: method.to_string(),
breaking_changes: 0,
non_breaking_changes: diff_result.mismatches.len() as u32,
total_changes: diff_result.mismatches.len() as u32,
budget_exceeded: false,
last_updated: chrono::Utc::now().timestamp(),
},
should_create_incident: false,
fitness_test_results: Vec::new(),
consumer_impact: None,
};
}
let budget =
self.get_budget_for_endpoint(endpoint, method, workspace_id, service_name, tags);
if !budget.enabled {
return DriftResult {
budget_exceeded: false,
breaking_changes: 0,
potentially_breaking_changes: 0,
non_breaking_changes: 0,
breaking_mismatches: vec![],
potentially_breaking_mismatches: vec![],
non_breaking_mismatches: diff_result.mismatches.clone(),
metrics: crate::contract_drift::types::DriftMetrics {
endpoint: endpoint.to_string(),
method: method.to_string(),
breaking_changes: 0,
non_breaking_changes: diff_result.mismatches.len() as u32,
total_changes: diff_result.mismatches.len() as u32,
budget_exceeded: false,
last_updated: chrono::Utc::now().timestamp(),
},
should_create_incident: false,
fitness_test_results: Vec::new(),
consumer_impact: None,
};
}
let baseline_field_count = if budget.max_field_churn_percent.is_some() {
if let Some(ref tracker) = self.field_tracker {
let rt = tokio::runtime::Handle::try_current();
if let Ok(handle) = rt {
if let Some(time_window) = budget.time_window_days {
handle.block_on(async {
let guard = tracker.read().await;
guard.get_average_count(
None, endpoint,
method,
time_window,
)
})
} else {
None
}
} else {
None
}
} else {
None
}
} else {
None
};
let mut result = DriftResult::from_diff_result(
diff_result,
endpoint.to_string(),
method.to_string(),
&budget,
&self.config.breaking_change_rules,
baseline_field_count,
);
if let Some(ref registry) = self.fitness_registry {
let rt = tokio::runtime::Handle::try_current();
if let Ok(handle) = rt {
let fitness_results = handle.block_on(async {
let guard = registry.read().await;
guard.evaluate_all(
old_spec,
new_spec,
diff_result,
endpoint,
method,
workspace_id,
service_name,
)
});
if let Ok(results) = fitness_results {
result.fitness_test_results = results;
if result.fitness_test_results.iter().any(|r| !r.passed) {
result.should_create_incident = true;
}
}
}
}
if let Some(ref analyzer) = self.consumer_analyzer {
let rt = tokio::runtime::Handle::try_current();
if let Ok(handle) = rt {
let impact = handle.block_on(async {
let guard = analyzer.read().await;
guard.analyze_impact(endpoint, method)
});
if let Some(impact) = impact {
result.consumer_impact = Some(impact);
}
}
}
result
}
pub fn get_budget_for_endpoint(
&self,
endpoint: &str,
method: &str,
workspace_id: Option<&str>,
service_name: Option<&str>,
tags: Option<&[String]>,
) -> DriftBudget {
if let Some(workspace_id) = workspace_id {
if let Some(budget) = self.config.per_workspace_budgets.get(workspace_id) {
return budget.clone();
}
}
if let Some(service_name) = service_name {
if let Some(budget) = self.config.per_service_budgets.get(service_name) {
return budget.clone();
}
}
if let Some(tags) = tags {
for tag in tags {
if let Some(budget) = self.config.per_tag_budgets.get(tag) {
return budget.clone();
}
if let Some(budget) = self.config.per_service_budgets.get(tag) {
return budget.clone();
}
}
}
let key = format!("{} {}", method, endpoint);
if let Some(budget) = self.config.per_endpoint_budgets.get(&key) {
return budget.clone();
}
self.config.default_budget.clone().unwrap_or_default()
}
pub fn config(&self) -> &DriftBudgetConfig {
&self.config
}
/// Replaces the engine's configuration with `config`.
///
/// Attached collaborators (tracker, registry, analyzer) are left untouched.
pub fn update_config(&mut self, config: DriftBudgetConfig) {
    // The previous configuration is dropped when this binding goes out of scope.
    let _previous = std::mem::replace(&mut self.config, config);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai_contract_diff::{
        ContractDiffResult, DiffMetadata, Mismatch, MismatchSeverity, MismatchType,
    };

    /// Wraps `mismatches` in a diff result with fixed OpenAPI metadata.
    /// `matches` is true exactly when there are no mismatches.
    fn create_test_diff_result(mismatches: Vec<Mismatch>) -> ContractDiffResult {
        let metadata = DiffMetadata {
            analyzed_at: chrono::Utc::now(),
            request_source: "budget_engine".to_string(),
            contract_version: Some("3.0.0".to_string()),
            contract_format: "openapi".to_string(),
            endpoint_path: String::new(),
            http_method: String::new(),
            request_count: 0,
            llm_provider: None,
            llm_model: None,
        };
        let matches = mismatches.is_empty();
        ContractDiffResult {
            matches,
            confidence: 1.0,
            mismatches,
            recommendations: Vec::new(),
            corrections: Vec::new(),
            metadata,
        }
    }

    /// Evaluates `mismatches` for `method /api/users` with a default-configured engine.
    fn evaluate_users(method: &str, mismatches: Vec<Mismatch>) -> DriftResult {
        let engine = DriftBudgetEngine::new(DriftBudgetConfig::default());
        let diff_result = create_test_diff_result(mismatches);
        engine.evaluate(&diff_result, "/api/users", method)
    }

    /// Builds a fully-confident POST mismatch with an empty context map.
    fn post_mismatch(
        mismatch_type: MismatchType,
        path: &str,
        expected: Option<&str>,
        actual: Option<&str>,
        description: &str,
        severity: MismatchSeverity,
    ) -> Mismatch {
        Mismatch {
            mismatch_type,
            path: path.to_string(),
            method: Some("POST".to_string()),
            expected: expected.map(str::to_string),
            actual: actual.map(str::to_string),
            description: description.to_string(),
            severity,
            confidence: 1.0,
            context: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_budget_evaluation_no_mismatches() {
        let result = evaluate_users("GET", Vec::new());

        assert!(!result.budget_exceeded);
        assert_eq!(result.breaking_changes, 0);
        assert_eq!(result.non_breaking_changes, 0);
        assert!(!result.should_create_incident);
    }

    #[test]
    fn test_budget_evaluation_breaking_change() {
        let missing_email = post_mismatch(
            MismatchType::MissingRequiredField,
            "body.email",
            Some("string"),
            None,
            "Missing required field: email",
            MismatchSeverity::Critical,
        );

        let result = evaluate_users("POST", vec![missing_email]);

        assert!(result.breaking_changes > 0);
        assert!(result.should_create_incident);
    }

    #[test]
    fn test_budget_evaluation_non_breaking_change() {
        let extra_field = post_mismatch(
            MismatchType::UnexpectedField,
            "body.extra_field",
            None,
            Some("value"),
            "Unexpected field: extra_field",
            MismatchSeverity::Low,
        );

        let result = evaluate_users("POST", vec![extra_field]);

        assert_eq!(result.breaking_changes, 0);
        assert!(result.non_breaking_changes > 0);
        assert!(!result.should_create_incident);
    }
}