use crate::ai_contract_diff::{ContractDiffResult, MismatchType};
use crate::openapi::OpenApiSpec;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// A configured fitness function: a named rule evaluated against contract changes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FitnessFunction {
/// Unique identifier; the registry stores functions keyed by this value.
pub id: String,
/// Human-readable name; copied into each `FitnessTestResult::function_name`.
pub name: String,
/// Free-form description.
pub description: String,
/// Kind of check; selects which registered evaluator the registry runs.
pub function_type: FitnessFunctionType,
/// Parameters handed to the evaluator as its `config`. Evaluators read their
/// thresholds (e.g. `max_increase_percent`) from this JSON, not from the fields
/// of `function_type`.
pub config: serde_json::Value,
/// Where this function applies (global, workspace, service, or endpoint pattern).
pub scope: FitnessScope,
/// Disabled functions are excluded from scope lookups and therefore never evaluated.
pub enabled: bool,
/// Creation time; defaults to 0 when missing during deserialization. Rules loaded
/// from config set it from `chrono::Utc::now().timestamp()` (Unix seconds).
#[serde(default)]
pub created_at: i64,
/// Last-update time; same defaulting and source as `created_at`.
#[serde(default)]
pub updated_at: i64,
}
/// Where a fitness function applies.
///
/// Serialized with serde's default externally tagged enum representation and
/// snake_case variant names (e.g. `"global"` or
/// `{"workspace": {"workspace_id": "..."}}`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum FitnessScope {
/// Applies to every endpoint / operation.
Global,
/// Applies only when the evaluating caller supplies this workspace id.
Workspace {
workspace_id: String,
},
/// Applies only when the evaluating caller supplies this service name.
Service {
service_name: String,
},
/// Applies to endpoints matching `pattern`, in which `*` acts as a wildcard
/// (see `matches_pattern`).
Endpoint {
pattern: String,
},
}
/// The kind of check a fitness function performs.
///
/// Serialized internally tagged: the snake_case variant name is stored under a
/// `"type"` key next to the variant's fields, e.g.
/// `{"type": "field_count", "max_fields": 50}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FitnessFunctionType {
/// Limits percentage growth of the response size (run by `"response_size"`).
ResponseSize {
max_increase_percent: f64,
},
/// Flags new required fields on endpoints matching `path_pattern`, unless
/// `allow_new_required` is set (run by `"required_field"`).
RequiredField {
path_pattern: String,
allow_new_required: bool,
},
/// Caps the number of schema fields (run by `"field_count"`).
FieldCount {
max_fields: u32,
},
/// Caps schema nesting depth (run by `"schema_complexity"`).
SchemaComplexity {
max_depth: u32,
},
/// Runs the evaluator registered under the name `evaluator`.
Custom {
evaluator: String,
},
}
/// Outcome of evaluating one fitness function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FitnessTestResult {
/// Id of the evaluated function. Evaluators leave this empty; the registry
/// overwrites it with the function's id.
pub function_id: String,
/// Name of the evaluated function. Evaluators fill in a generic label; the
/// registry overwrites it with the function's configured name.
pub function_name: String,
/// Whether the change satisfied the fitness function.
pub passed: bool,
/// Human-readable explanation of the verdict.
pub message: String,
/// Evaluator-specific numeric measurements (names vary per evaluator).
pub metrics: HashMap<String, f64>,
}
/// A pluggable check that turns a contract change into a pass/fail
/// [`FitnessTestResult`].
///
/// Implementations read their thresholds from the `config` JSON the registry
/// passes in (the owning function's `config`).
pub trait FitnessEvaluator: Send + Sync {
    /// Evaluates an OpenAPI change for `method` on `endpoint`.
    ///
    /// `old_spec` is `None` when no previous spec is available; `diff_result`
    /// holds the mismatches detected for this change.
    ///
    /// # Errors
    /// Implementation-defined; the registry propagates any error.
    fn evaluate(
        &self,
        old_spec: Option<&OpenApiSpec>,
        new_spec: &OpenApiSpec,
        diff_result: &ContractDiffResult,
        endpoint: &str,
        method: &str,
        config: &serde_json::Value,
    ) -> crate::Result<FitnessTestResult>;

    /// Evaluates a change to a protocol-contract operation (`operation_id`).
    ///
    /// The default implementation is a size check. It compares the estimated
    /// schema size in the old and new contracts, falling back to a diff-based
    /// estimate when `old_contract` is `None`. It passes when the percentage
    /// increase is at most `config["max_increase_percent"]` (default 50%).
    /// A non-positive baseline counts as a 100% increase if the new size is
    /// positive, and as 0% otherwise.
    fn evaluate_protocol(
        &self,
        old_contract: Option<&dyn crate::contract_drift::protocol_contracts::ProtocolContract>,
        new_contract: &dyn crate::contract_drift::protocol_contracts::ProtocolContract,
        diff_result: &ContractDiffResult,
        operation_id: &str,
        config: &serde_json::Value,
    ) -> crate::Result<FitnessTestResult> {
        let limit = config
            .get("max_increase_percent")
            .and_then(serde_json::Value::as_f64)
            .unwrap_or(50.0);
        let baseline = match old_contract {
            Some(previous) => estimate_protocol_schema_size(previous, operation_id),
            None => estimate_size_from_diff(diff_result),
        };
        let current = estimate_protocol_schema_size(new_contract, operation_id);
        let change = match (baseline > 0.0, current > 0.0) {
            (true, _) => ((current - baseline) / baseline) * 100.0,
            (false, true) => 100.0,
            (false, false) => 0.0,
        };
        let passed = change <= limit;
        let message = if passed {
            format!(
                "Protocol schema change ({:.1}%) is within allowed limit ({:.1}%)",
                change, limit
            )
        } else {
            format!(
                "Protocol schema change ({:.1}%) exceeds allowed limit ({:.1}%)",
                change, limit
            )
        };
        let metrics: HashMap<String, f64> = [
            ("old_schema_size", baseline),
            ("new_schema_size", current),
            ("increase_percent", change),
            ("max_increase_percent", limit),
            ("mismatch_count", diff_result.mismatches.len() as f64),
        ]
        .iter()
        .map(|&(key, value)| (key.to_string(), value))
        .collect();
        Ok(FitnessTestResult {
            function_id: String::new(),
            function_name: "Protocol Contract Evaluation".to_string(),
            passed,
            message,
            metrics,
        })
    }
}
/// Checks that a response does not grow by more than `max_increase_percent`
/// (config key, default 25%) between contract versions.
pub struct ResponseSizeFitnessEvaluator;
impl FitnessEvaluator for ResponseSizeFitnessEvaluator {
    /// Compares the estimated response field count of the old and new specs.
    ///
    /// With an old spec, both sides are measured by
    /// `estimate_response_field_count`, so the percentage compares like with
    /// like. Without an old spec, both sides come from the diff heuristic
    /// (`base + mismatches`), using its zero-mismatch value as the baseline.
    fn evaluate(
        &self,
        old_spec: Option<&OpenApiSpec>,
        new_spec: &OpenApiSpec,
        diff_result: &ContractDiffResult,
        endpoint: &str,
        method: &str,
        config: &serde_json::Value,
    ) -> crate::Result<FitnessTestResult> {
        let max_increase_percent =
            config.get("max_increase_percent").and_then(|v| v.as_f64()).unwrap_or(25.0);
        let (old_field_count, new_field_count) = match old_spec {
            Some(old) => (
                estimate_response_field_count(old, endpoint, method),
                estimate_response_field_count(new_spec, endpoint, method),
            ),
            None => {
                let estimated_new =
                    estimate_response_field_count_from_diff(diff_result, endpoint, method);
                // The heuristic adds one per mismatch on top of a fixed base, so the
                // base alone is the best available "before" value. (Previously the
                // raw mismatch count was used, which reported a 100% increase even
                // when nothing changed.)
                let mismatch_count = diff_result.mismatches.len() as f64;
                ((estimated_new - mismatch_count).max(0.0), estimated_new)
            }
        };
        let increase_percent = if old_field_count > 0.0 {
            ((new_field_count - old_field_count) / old_field_count) * 100.0
        } else if new_field_count > 0.0 {
            100.0
        } else {
            0.0
        };
        let passed = increase_percent <= max_increase_percent;
        let message = if passed {
            format!(
                "Response size increase ({:.1}%) is within allowed limit ({:.1}%)",
                increase_percent, max_increase_percent
            )
        } else {
            format!(
                "Response size increase ({:.1}%) exceeds allowed limit ({:.1}%)",
                increase_percent, max_increase_percent
            )
        };
        let mut metrics = HashMap::new();
        metrics.insert("old_field_count".to_string(), old_field_count);
        metrics.insert("new_field_count".to_string(), new_field_count);
        metrics.insert("increase_percent".to_string(), increase_percent);
        metrics.insert("max_increase_percent".to_string(), max_increase_percent);
        Ok(FitnessTestResult {
            function_id: String::new(),
            function_name: "Response Size".to_string(),
            passed,
            message,
            metrics,
        })
    }

    /// Protocol-contract variant: compares `estimate_protocol_schema_size` of the
    /// old and new contracts. Without an old contract, it uses
    /// `estimate_size_from_diff` as the baseline.
    fn evaluate_protocol(
        &self,
        old_contract: Option<&dyn crate::contract_drift::protocol_contracts::ProtocolContract>,
        new_contract: &dyn crate::contract_drift::protocol_contracts::ProtocolContract,
        diff_result: &ContractDiffResult,
        operation_id: &str,
        config: &serde_json::Value,
    ) -> crate::Result<FitnessTestResult> {
        let max_increase_percent =
            config.get("max_increase_percent").and_then(|v| v.as_f64()).unwrap_or(25.0);
        let old_size = match old_contract {
            Some(old) => estimate_protocol_schema_size(old, operation_id),
            None => estimate_size_from_diff(diff_result),
        };
        let new_size = estimate_protocol_schema_size(new_contract, operation_id);
        let increase_percent = if old_size > 0.0 {
            ((new_size - old_size) / old_size) * 100.0
        } else if new_size > 0.0 {
            100.0
        } else {
            0.0
        };
        let passed = increase_percent <= max_increase_percent;
        let message = if passed {
            format!(
                "Protocol contract response size increase ({:.1}%) is within allowed limit ({:.1}%)",
                increase_percent, max_increase_percent
            )
        } else {
            format!(
                "Protocol contract response size increase ({:.1}%) exceeds allowed limit ({:.1}%)",
                increase_percent, max_increase_percent
            )
        };
        let mut metrics = HashMap::new();
        metrics.insert("old_size".to_string(), old_size);
        metrics.insert("new_size".to_string(), new_size);
        metrics.insert("increase_percent".to_string(), increase_percent);
        metrics.insert("max_increase_percent".to_string(), max_increase_percent);
        Ok(FitnessTestResult {
            function_id: String::new(),
            function_name: "Response Size".to_string(),
            passed,
            message,
            metrics,
        })
    }
}
/// Fails when a change introduces new required fields, unless the config sets
/// `allow_new_required`. Only endpoints/operations matching `path_pattern`
/// (config key, default `"*"`) are checked; others pass trivially.
pub struct RequiredFieldFitnessEvaluator;
impl FitnessEvaluator for RequiredFieldFitnessEvaluator {
    /// Counts `MissingRequiredField` mismatches reported for `method`.
    fn evaluate(
        &self,
        _old_spec: Option<&OpenApiSpec>,
        _new_spec: &OpenApiSpec,
        diff_result: &ContractDiffResult,
        endpoint: &str,
        method: &str,
        config: &serde_json::Value,
    ) -> crate::Result<FitnessTestResult> {
        let path_pattern = config.get("path_pattern").and_then(|v| v.as_str()).unwrap_or("*");
        let allow_new_required =
            config.get("allow_new_required").and_then(|v| v.as_bool()).unwrap_or(false);
        if !matches_pattern(endpoint, path_pattern) {
            return Ok(FitnessTestResult {
                function_id: String::new(),
                function_name: "Required Field".to_string(),
                passed: true,
                message: format!("Endpoint {} does not match pattern {}", endpoint, path_pattern),
                metrics: HashMap::new(),
            });
        }
        let new_required_fields = diff_result
            .mismatches
            .iter()
            .filter(|m| {
                m.mismatch_type == MismatchType::MissingRequiredField
                    && m.method.as_deref() == Some(method)
            })
            .count();
        let passed = allow_new_required || new_required_fields == 0;
        let message = if passed {
            if allow_new_required {
                format!("Found {} new required fields, which is allowed", new_required_fields)
            } else {
                "No new required fields detected".to_string()
            }
        } else {
            format!(
                "Found {} new required fields, which violates the fitness function",
                new_required_fields
            )
        };
        let mut metrics = HashMap::new();
        metrics.insert("new_required_fields".to_string(), new_required_fields as f64);
        metrics
            .insert("allow_new_required".to_string(), if allow_new_required { 1.0 } else { 0.0 });
        Ok(FitnessTestResult {
            function_id: String::new(),
            function_name: "Required Field".to_string(),
            passed,
            message,
            metrics,
        })
    }

    /// Protocol-contract variant. The number of new required fields is the sum of:
    /// - the `MissingRequiredField` mismatches in `diff_result`, and
    /// - the growth in required fields declared by the operation's schema from
    ///   the old contract to the new one.
    ///
    /// Without an old contract, only the diff mismatches count, because existing
    /// required fields cannot be told apart from new ones. (Previously every
    /// required field in the new schema was counted as "new", so any operation
    /// with a required field failed.)
    fn evaluate_protocol(
        &self,
        old_contract: Option<&dyn crate::contract_drift::protocol_contracts::ProtocolContract>,
        new_contract: &dyn crate::contract_drift::protocol_contracts::ProtocolContract,
        diff_result: &ContractDiffResult,
        operation_id: &str,
        config: &serde_json::Value,
    ) -> crate::Result<FitnessTestResult> {
        let path_pattern = config.get("path_pattern").and_then(|v| v.as_str()).unwrap_or("*");
        let allow_new_required =
            config.get("allow_new_required").and_then(|v| v.as_bool()).unwrap_or(false);
        // `matches_pattern` already treats "*" as match-all.
        if !matches_pattern(operation_id, path_pattern) {
            return Ok(FitnessTestResult {
                function_id: String::new(),
                function_name: "Required Field".to_string(),
                passed: true,
                message: format!(
                    "Operation {} does not match pattern {}",
                    operation_id, path_pattern
                ),
                metrics: HashMap::new(),
            });
        }
        let new_required_fields = diff_result
            .mismatches
            .iter()
            .filter(|m| m.mismatch_type == MismatchType::MissingRequiredField)
            .count();
        let schema_new_required = match old_contract {
            Some(old) => {
                let new_schema_required = new_contract
                    .get_schema(operation_id)
                    .map(|schema| count_required_fields_in_schema(&schema))
                    .unwrap_or(0);
                let old_schema_required = old
                    .get_schema(operation_id)
                    .map(|schema| count_required_fields_in_schema(&schema))
                    .unwrap_or(0);
                new_schema_required.saturating_sub(old_schema_required)
            }
            None => 0,
        };
        let total_new_required = new_required_fields + schema_new_required;
        let passed = allow_new_required || total_new_required == 0;
        let message = if passed {
            if allow_new_required {
                format!("Found {} new required fields, which is allowed", total_new_required)
            } else {
                "No new required fields detected in protocol contract".to_string()
            }
        } else {
            format!(
                "Found {} new required fields in protocol contract, which violates the fitness function",
                total_new_required
            )
        };
        let mut metrics = HashMap::new();
        metrics.insert("new_required_fields".to_string(), total_new_required as f64);
        metrics
            .insert("allow_new_required".to_string(), if allow_new_required { 1.0 } else { 0.0 });
        Ok(FitnessTestResult {
            function_id: String::new(),
            function_name: "Required Field".to_string(),
            passed,
            message,
            metrics,
        })
    }
}
pub struct FieldCountFitnessEvaluator;
impl FitnessEvaluator for FieldCountFitnessEvaluator {
fn evaluate(
&self,
_old_spec: Option<&OpenApiSpec>,
_new_spec: &OpenApiSpec,
diff_result: &ContractDiffResult,
endpoint: &str,
method: &str,
config: &serde_json::Value,
) -> crate::Result<FitnessTestResult> {
let max_fields = config
.get("max_fields")
.and_then(|v| v.as_u64())
.map(|v| v as u32)
.unwrap_or(100);
let field_count = estimate_field_count_from_diff(diff_result, endpoint, method);
let passed = field_count <= max_fields as f64;
let message = if passed {
format!("Field count ({}) is within allowed limit ({})", field_count as u32, max_fields)
} else {
format!("Field count ({}) exceeds allowed limit ({})", field_count as u32, max_fields)
};
let mut metrics = HashMap::new();
metrics.insert("field_count".to_string(), field_count);
metrics.insert("max_fields".to_string(), max_fields as f64);
Ok(FitnessTestResult {
function_id: String::new(),
function_name: "Field Count".to_string(),
passed,
message,
metrics,
})
}
fn evaluate_protocol(
&self,
_old_contract: Option<&dyn crate::contract_drift::protocol_contracts::ProtocolContract>,
new_contract: &dyn crate::contract_drift::protocol_contracts::ProtocolContract,
_diff_result: &ContractDiffResult,
operation_id: &str,
config: &serde_json::Value,
) -> crate::Result<FitnessTestResult> {
let max_fields = config
.get("max_fields")
.and_then(|v| v.as_u64())
.map(|v| v as u32)
.unwrap_or(100);
let field_count = if let Some(schema) = new_contract.get_schema(operation_id) {
count_fields_in_schema(&schema)
} else {
0.0
};
let passed = field_count <= max_fields as f64;
let message = if passed {
format!(
"Protocol contract field count ({}) is within allowed limit ({})",
field_count as u32, max_fields
)
} else {
format!(
"Protocol contract field count ({}) exceeds allowed limit ({})",
field_count as u32, max_fields
)
};
let mut metrics = HashMap::new();
metrics.insert("field_count".to_string(), field_count);
metrics.insert("max_fields".to_string(), max_fields as f64);
Ok(FitnessTestResult {
function_id: String::new(),
function_name: "Field Count".to_string(),
passed,
message,
metrics,
})
}
}
pub struct SchemaComplexityFitnessEvaluator;
impl FitnessEvaluator for SchemaComplexityFitnessEvaluator {
fn evaluate(
&self,
_old_spec: Option<&OpenApiSpec>,
new_spec: &OpenApiSpec,
_diff_result: &ContractDiffResult,
endpoint: &str,
method: &str,
config: &serde_json::Value,
) -> crate::Result<FitnessTestResult> {
let max_depth =
config.get("max_depth").and_then(|v| v.as_u64()).map(|v| v as u32).unwrap_or(10);
let depth = calculate_schema_depth(new_spec, endpoint, method);
let passed = depth <= max_depth;
let message = if passed {
format!("Schema depth ({}) is within allowed limit ({})", depth, max_depth)
} else {
format!("Schema depth ({}) exceeds allowed limit ({})", depth, max_depth)
};
let mut metrics = HashMap::new();
metrics.insert("schema_depth".to_string(), depth as f64);
metrics.insert("max_depth".to_string(), max_depth as f64);
Ok(FitnessTestResult {
function_id: String::new(),
function_name: "Schema Complexity".to_string(),
passed,
message,
metrics,
})
}
fn evaluate_protocol(
&self,
_old_contract: Option<&dyn crate::contract_drift::protocol_contracts::ProtocolContract>,
new_contract: &dyn crate::contract_drift::protocol_contracts::ProtocolContract,
_diff_result: &ContractDiffResult,
operation_id: &str,
config: &serde_json::Value,
) -> crate::Result<FitnessTestResult> {
let max_depth =
config.get("max_depth").and_then(|v| v.as_u64()).map(|v| v as u32).unwrap_or(10);
let depth = if let Some(schema) = new_contract.get_schema(operation_id) {
calculate_protocol_schema_depth(&schema)
} else {
0
};
let passed = depth <= max_depth;
let message = if passed {
format!(
"Protocol contract schema depth ({}) is within allowed limit ({})",
depth, max_depth
)
} else {
format!(
"Protocol contract schema depth ({}) exceeds allowed limit ({})",
depth, max_depth
)
};
let mut metrics = HashMap::new();
metrics.insert("schema_depth".to_string(), depth as f64);
metrics.insert("max_depth".to_string(), max_depth as f64);
Ok(FitnessTestResult {
function_id: String::new(),
function_name: "Schema Complexity".to_string(),
passed,
message,
metrics,
})
}
}
/// Registry of configured fitness functions and the evaluators that run them.
///
/// `functions` is keyed by `FitnessFunction::id`. `evaluators` is keyed by the
/// evaluator name that a function's `function_type` maps to (e.g.
/// `"response_size"`, or the name carried by `Custom`).
pub struct FitnessFunctionRegistry {
functions: HashMap<String, FitnessFunction>,
evaluators: HashMap<String, Arc<dyn FitnessEvaluator>>,
}
/// Debug output lists the registered functions. Evaluators are trait objects
/// without a `Debug` bound, so only their count is shown.
impl std::fmt::Debug for FitnessFunctionRegistry {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let evaluator_total = self.evaluators.len();
        let mut builder = formatter.debug_struct("FitnessFunctionRegistry");
        builder.field("functions", &self.functions);
        builder.field("evaluators_count", &evaluator_total);
        builder.finish()
    }
}
impl Default for FitnessFunctionRegistry {
fn default() -> Self {
Self::new()
}
}
impl FitnessFunctionRegistry {
pub fn new() -> Self {
let mut registry = Self {
functions: HashMap::new(),
evaluators: HashMap::new(),
};
registry.register_evaluator(
"response_size",
Arc::new(ResponseSizeFitnessEvaluator) as Arc<dyn FitnessEvaluator>,
);
registry.register_evaluator(
"required_field",
Arc::new(RequiredFieldFitnessEvaluator) as Arc<dyn FitnessEvaluator>,
);
registry.register_evaluator(
"field_count",
Arc::new(FieldCountFitnessEvaluator) as Arc<dyn FitnessEvaluator>,
);
registry.register_evaluator(
"schema_complexity",
Arc::new(SchemaComplexityFitnessEvaluator) as Arc<dyn FitnessEvaluator>,
);
registry
}
pub fn register_evaluator(&mut self, name: &str, evaluator: Arc<dyn FitnessEvaluator>) {
self.evaluators.insert(name.to_string(), evaluator);
}
pub fn add_function(&mut self, function: FitnessFunction) {
self.functions.insert(function.id.clone(), function);
}
pub fn get_function(&self, id: &str) -> Option<&FitnessFunction> {
self.functions.get(id)
}
pub fn list_functions(&self) -> Vec<&FitnessFunction> {
self.functions.values().collect()
}
pub fn get_functions_for_scope(
&self,
endpoint: &str,
method: &str,
workspace_id: Option<&str>,
service_name: Option<&str>,
) -> Vec<&FitnessFunction> {
self.functions
.values()
.filter(|f| {
f.enabled && self.matches_scope(f, endpoint, method, workspace_id, service_name)
})
.collect()
}
#[allow(clippy::too_many_arguments)]
pub fn evaluate_all(
&self,
old_spec: Option<&OpenApiSpec>,
new_spec: &OpenApiSpec,
diff_result: &ContractDiffResult,
endpoint: &str,
method: &str,
workspace_id: Option<&str>,
service_name: Option<&str>,
) -> crate::Result<Vec<FitnessTestResult>> {
let functions = self.get_functions_for_scope(endpoint, method, workspace_id, service_name);
let mut results = Vec::new();
for function in functions {
let evaluator_name = match &function.function_type {
FitnessFunctionType::ResponseSize { .. } => "response_size",
FitnessFunctionType::RequiredField { .. } => "required_field",
FitnessFunctionType::FieldCount { .. } => "field_count",
FitnessFunctionType::SchemaComplexity { .. } => "schema_complexity",
FitnessFunctionType::Custom { evaluator } => evaluator.as_str(),
};
if let Some(evaluator) = self.evaluators.get(evaluator_name) {
let mut result = evaluator.evaluate(
old_spec,
new_spec,
diff_result,
endpoint,
method,
&function.config,
)?;
result.function_id = function.id.clone();
result.function_name = function.name.clone();
results.push(result);
}
}
Ok(results)
}
pub fn evaluate_all_protocol(
&self,
old_contract: Option<&dyn crate::contract_drift::protocol_contracts::ProtocolContract>,
new_contract: &dyn crate::contract_drift::protocol_contracts::ProtocolContract,
diff_result: &ContractDiffResult,
operation_id: &str,
workspace_id: Option<&str>,
service_name: Option<&str>,
) -> crate::Result<Vec<FitnessTestResult>> {
let operation = new_contract.get_operation(operation_id);
let (endpoint, method) = if let Some(op) = operation {
match &op.operation_type {
crate::contract_drift::protocol_contracts::OperationType::HttpEndpoint {
path,
method,
} => (path.clone(), method.clone()),
crate::contract_drift::protocol_contracts::OperationType::GrpcMethod {
service,
method,
} => {
(format!("{}.{}", service, method), "grpc".to_string())
}
crate::contract_drift::protocol_contracts::OperationType::WebSocketMessage {
message_type,
..
} => (message_type.clone(), "websocket".to_string()),
crate::contract_drift::protocol_contracts::OperationType::MqttTopic {
topic,
qos: _,
} => (topic.clone(), "mqtt".to_string()),
crate::contract_drift::protocol_contracts::OperationType::KafkaTopic {
topic,
key_schema: _,
value_schema: _,
} => (topic.clone(), "kafka".to_string()),
}
} else {
(operation_id.to_string(), "unknown".to_string())
};
let functions =
self.get_functions_for_scope(&endpoint, &method, workspace_id, service_name);
let mut results = Vec::new();
for function in functions {
let evaluator_name = match &function.function_type {
FitnessFunctionType::ResponseSize { .. } => "response_size",
FitnessFunctionType::RequiredField { .. } => "required_field",
FitnessFunctionType::FieldCount { .. } => "field_count",
FitnessFunctionType::SchemaComplexity { .. } => "schema_complexity",
FitnessFunctionType::Custom { evaluator } => evaluator.as_str(),
};
if let Some(evaluator) = self.evaluators.get(evaluator_name) {
let mut result = evaluator.evaluate_protocol(
old_contract,
new_contract,
diff_result,
operation_id,
&function.config,
)?;
result.function_id = function.id.clone();
result.function_name = function.name.clone();
results.push(result);
}
}
Ok(results)
}
fn matches_scope(
&self,
function: &FitnessFunction,
endpoint: &str,
_method: &str,
workspace_id: Option<&str>,
service_name: Option<&str>,
) -> bool {
match &function.scope {
FitnessScope::Global => true,
FitnessScope::Workspace {
workspace_id: ws_id,
} => workspace_id.map(|id| id == ws_id).unwrap_or(false),
FitnessScope::Service {
service_name: svc_name,
} => service_name.map(|name| name == svc_name).unwrap_or(false),
FitnessScope::Endpoint { pattern } => matches_pattern(endpoint, pattern),
}
}
pub fn remove_function(&mut self, id: &str) -> Option<FitnessFunction> {
self.functions.remove(id)
}
pub fn update_function(&mut self, function: FitnessFunction) {
self.functions.insert(function.id.clone(), function);
}
pub fn load_from_config(
&mut self,
config_rules: &[crate::config::FitnessRuleConfig],
) -> crate::Result<()> {
use crate::config::FitnessRuleType;
for (idx, rule_config) in config_rules.iter().enumerate() {
let id = format!("config-rule-{}", idx);
let scope = parse_scope(&rule_config.scope)?;
let function_type = match rule_config.rule_type {
FitnessRuleType::ResponseSizeDelta => {
let max_increase = rule_config
.max_percent_increase
.ok_or_else(|| {
crate::Error::validation(format!(
"Fitness rule '{}' (type: response_size_delta) requires 'max_percent_increase' field. \
Example: max_percent_increase: 25.0",
rule_config.name
))
})?;
if max_increase < 0.0 {
return Err(crate::Error::validation(format!(
"Fitness rule '{}' (type: response_size_delta): 'max_percent_increase' must be >= 0, got {}",
rule_config.name, max_increase
)));
}
if rule_config.max_fields.is_some() {
tracing::warn!(
"Fitness rule '{}' (type: response_size_delta): 'max_fields' is not used for this rule type",
rule_config.name
);
}
if rule_config.max_depth.is_some() {
tracing::warn!(
"Fitness rule '{}' (type: response_size_delta): 'max_depth' is not used for this rule type",
rule_config.name
);
}
FitnessFunctionType::ResponseSize {
max_increase_percent: max_increase,
}
}
FitnessRuleType::NoNewRequiredFields => {
let path_pattern = match &scope {
FitnessScope::Endpoint { pattern } => pattern.clone(),
_ => "*".to_string(), };
if rule_config.max_percent_increase.is_some() {
tracing::warn!(
"Fitness rule '{}' (type: no_new_required_fields): 'max_percent_increase' is not used for this rule type",
rule_config.name
);
}
if rule_config.max_fields.is_some() {
tracing::warn!(
"Fitness rule '{}' (type: no_new_required_fields): 'max_fields' is not used for this rule type",
rule_config.name
);
}
if rule_config.max_depth.is_some() {
tracing::warn!(
"Fitness rule '{}' (type: no_new_required_fields): 'max_depth' is not used for this rule type",
rule_config.name
);
}
FitnessFunctionType::RequiredField {
path_pattern,
allow_new_required: false,
}
}
FitnessRuleType::FieldCount => {
let max_fields = rule_config.max_fields.ok_or_else(|| {
crate::Error::validation(format!(
"Fitness rule '{}' (type: field_count) requires 'max_fields' field. \
Example: max_fields: 50",
rule_config.name
))
})?;
if max_fields == 0 {
return Err(crate::Error::validation(format!(
"Fitness rule '{}' (type: field_count): 'max_fields' must be > 0, got {}",
rule_config.name, max_fields
)));
}
if rule_config.max_percent_increase.is_some() {
tracing::warn!(
"Fitness rule '{}' (type: field_count): 'max_percent_increase' is not used for this rule type",
rule_config.name
);
}
if rule_config.max_depth.is_some() {
tracing::warn!(
"Fitness rule '{}' (type: field_count): 'max_depth' is not used for this rule type",
rule_config.name
);
}
FitnessFunctionType::FieldCount { max_fields }
}
FitnessRuleType::SchemaComplexity => {
let max_depth = rule_config.max_depth.ok_or_else(|| {
crate::Error::validation(format!(
"Fitness rule '{}' (type: schema_complexity) requires 'max_depth' field. \
Example: max_depth: 5",
rule_config.name
))
})?;
if max_depth == 0 {
return Err(crate::Error::validation(format!(
"Fitness rule '{}' (type: schema_complexity): 'max_depth' must be > 0, got {}",
rule_config.name, max_depth
)));
}
if rule_config.max_percent_increase.is_some() {
tracing::warn!(
"Fitness rule '{}' (type: schema_complexity): 'max_percent_increase' is not used for this rule type",
rule_config.name
);
}
if rule_config.max_fields.is_some() {
tracing::warn!(
"Fitness rule '{}' (type: schema_complexity): 'max_fields' is not used for this rule type",
rule_config.name
);
}
FitnessFunctionType::SchemaComplexity { max_depth }
}
};
let config_json = match &function_type {
FitnessFunctionType::ResponseSize {
max_increase_percent,
} => {
serde_json::json!({
"max_increase_percent": max_increase_percent
})
}
FitnessFunctionType::RequiredField {
path_pattern,
allow_new_required,
} => {
serde_json::json!({
"path_pattern": path_pattern,
"allow_new_required": allow_new_required
})
}
FitnessFunctionType::FieldCount { max_fields } => {
serde_json::json!({
"max_fields": max_fields
})
}
FitnessFunctionType::SchemaComplexity { max_depth } => {
serde_json::json!({
"max_depth": max_depth
})
}
FitnessFunctionType::Custom { .. } => {
serde_json::json!({})
}
};
let function = FitnessFunction {
id,
name: rule_config.name.clone(),
description: format!("Fitness rule: {}", rule_config.name),
function_type,
config: config_json,
scope,
enabled: true,
created_at: chrono::Utc::now().timestamp(),
updated_at: chrono::Utc::now().timestamp(),
};
self.add_function(function);
}
Ok(())
}
}
/// Parses a scope string from config (surrounding whitespace is ignored):
/// - `"global"` → [`FitnessScope::Global`]
/// - `"workspace:<id>"` → [`FitnessScope::Workspace`]
/// - `"service:<name>"` → [`FitnessScope::Service`]
/// - anything else → [`FitnessScope::Endpoint`] with the string as the pattern.
///
/// Currently never returns an error.
fn parse_scope(scope_str: &str) -> crate::Result<FitnessScope> {
    let trimmed = scope_str.trim();
    let scope = if trimmed == "global" {
        FitnessScope::Global
    } else if let Some(id) = trimmed.strip_prefix("workspace:") {
        FitnessScope::Workspace {
            workspace_id: id.to_string(),
        }
    } else if let Some(name) = trimmed.strip_prefix("service:") {
        FitnessScope::Service {
            service_name: name.to_string(),
        }
    } else {
        FitnessScope::Endpoint {
            pattern: trimmed.to_string(),
        }
    };
    Ok(scope)
}
/// Glob-style match of `endpoint` against `pattern`, where each `*` matches any
/// (possibly empty) substring. A pattern without `*` must equal the endpoint.
///
/// The literal segments between wildcards must appear in order and must not
/// overlap. Previously only the first and last segments were checked, so
/// `"a*b*c"` matched `"ac"` and `"/api/*/users"` matched `"/api/users"`.
fn matches_pattern(endpoint: &str, pattern: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    let mut segments = pattern.split('*');
    // `split` always yields at least one item.
    let first = segments.next().unwrap_or("");
    let mut rest = match endpoint.strip_prefix(first) {
        Some(rest) => rest,
        None => return false,
    };
    let mut middle: Vec<&str> = segments.collect();
    let last = match middle.pop() {
        Some(last) => last,
        // No wildcard at all: the prefix must have consumed the whole endpoint.
        None => return rest.is_empty(),
    };
    // Leftmost matching of each middle segment leaves the most room for the rest.
    for segment in middle {
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }
    rest.ends_with(last)
}
/// Counts fields in the response schema for `endpoint`/`method` in `spec`
/// (local `$ref`s resolved). Returns 10.0 when there is no schema or the count
/// is not positive.
fn estimate_response_field_count(spec: &OpenApiSpec, endpoint: &str, method: &str) -> f64 {
    const FALLBACK_FIELD_COUNT: f64 = 10.0;
    match extract_response_schema(spec, endpoint, method) {
        Some(schema) => {
            let mut seen_refs = std::collections::HashSet::new();
            let counted = count_fields_in_schema_resolved(schema, spec, &mut seen_refs);
            if counted > 0.0 {
                counted
            } else {
                FALLBACK_FIELD_COUNT
            }
        }
        None => FALLBACK_FIELD_COUNT,
    }
}
/// Heuristic response field count: a base of 10 plus one per mismatch in
/// `diff_result`. `endpoint` and `method` are currently unused.
fn estimate_response_field_count_from_diff(
    diff_result: &ContractDiffResult,
    _endpoint: &str,
    _method: &str,
) -> f64 {
    const BASE_FIELD_COUNT: f64 = 10.0;
    BASE_FIELD_COUNT + diff_result.mismatches.len() as f64
}
/// Heuristic field count: the number of distinct top-level path segments (text
/// before the first `.`) among mismatch paths, plus a base of 10.
/// `endpoint` and `method` are currently unused.
fn estimate_field_count_from_diff(
    diff_result: &ContractDiffResult,
    _endpoint: &str,
    _method: &str,
) -> f64 {
    let mut roots = std::collections::HashSet::new();
    for mismatch in diff_result.mismatches.iter() {
        roots.insert(mismatch.path.split('.').next().unwrap_or(""));
    }
    roots.len() as f64 + 10.0
}
/// Nesting depth of the response schema for `endpoint`/`method` in `spec`
/// (local `$ref`s resolved). Returns 5 when there is no schema or the depth is 0.
fn calculate_schema_depth(spec: &OpenApiSpec, endpoint: &str, method: &str) -> u32 {
    const FALLBACK_DEPTH: u32 = 5;
    let schema = match extract_response_schema(spec, endpoint, method) {
        Some(schema) => schema,
        None => return FALLBACK_DEPTH,
    };
    let mut seen_refs = std::collections::HashSet::new();
    match schema_depth_resolved(schema, spec, &mut seen_refs) {
        0 => FALLBACK_DEPTH,
        depth => depth,
    }
}
fn extract_response_schema<'a>(
spec: &'a OpenApiSpec,
endpoint: &str,
method: &str,
) -> Option<&'a serde_json::Value> {
let raw = spec.raw_document.as_ref()?;
let paths = raw.get("paths")?.as_object()?;
let path_item = paths.get(endpoint)?;
let operation = path_item.get(method.to_lowercase())?;
let responses = operation.get("responses")?.as_object()?;
let response = responses
.get("200")
.or_else(|| responses.get("201"))
.or_else(|| responses.get("default"))
.or_else(|| responses.values().next())?;
let content = response.get("content")?.as_object()?;
let media_type = content.get("application/json").or_else(|| content.values().next())?;
media_type.get("schema")
}
/// Resolves a document-local reference such as `#/components/schemas/User`.
///
/// The part after `#` is used as a JSON Pointer into the spec's raw document.
/// Returns `None` for references without a leading `#` (e.g. external files),
/// when the spec has no raw document, or when the pointer does not resolve.
fn resolve_local_ref<'a>(spec: &'a OpenApiSpec, reference: &str) -> Option<&'a serde_json::Value> {
    match (reference.strip_prefix('#'), spec.raw_document.as_ref()) {
        (Some(pointer), Some(document)) => document.pointer(pointer),
        _ => None,
    }
}
/// Recursively counts fields in an OpenAPI schema, resolving local `$ref`s.
///
/// Each `properties` entry and each `fields` element counts as one. Fields
/// nested within them, within `items`, and within every `oneOf`/`anyOf`/`allOf`
/// variant are added. Unresolvable references count as 0.
///
/// `visited` holds the references on the current resolution path only. Each is
/// removed once its subtree is counted, so a cyclic reference stops recursion,
/// while a component referenced from several places is counted at every use.
/// Previously the set was never cleared, so repeated references were counted
/// once. Callers should pass an empty set; on return it is as it was passed in.
fn count_fields_in_schema_resolved(
    schema: &serde_json::Value,
    spec: &OpenApiSpec,
    visited: &mut std::collections::HashSet<String>,
) -> f64 {
    if let Some(reference) = schema.get("$ref").and_then(|r| r.as_str()) {
        if !visited.insert(reference.to_string()) {
            // Cycle back to a reference already being expanded.
            return 0.0;
        }
        let count = resolve_local_ref(spec, reference)
            .map(|resolved| count_fields_in_schema_resolved(resolved, spec, visited))
            .unwrap_or(0.0);
        visited.remove(reference);
        return count;
    }
    let mut total = 0.0;
    if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) {
        total += properties.len() as f64;
        for prop in properties.values() {
            total += count_fields_in_schema_resolved(prop, spec, visited);
        }
    }
    if let Some(fields) = schema.get("fields").and_then(|f| f.as_array()) {
        total += fields.len() as f64;
        for field in fields {
            total += count_fields_in_schema_resolved(field, spec, visited);
        }
    }
    if let Some(items) = schema.get("items") {
        total += count_fields_in_schema_resolved(items, spec, visited);
    }
    for key in ["oneOf", "anyOf", "allOf"] {
        if let Some(variants) = schema.get(key).and_then(|v| v.as_array()) {
            for variant in variants {
                total += count_fields_in_schema_resolved(variant, spec, visited);
            }
        }
    }
    total
}
/// Nesting depth of an OpenAPI schema, resolving local `$ref`s.
///
/// An object-like node (has `properties`, or `type` is `object`/`array`) adds
/// one level on top of its deepest child. Children are drawn from `properties`,
/// `items` and `oneOf`/`anyOf`/`allOf` variants. Other nodes pass their deepest
/// child's depth through. Unresolvable references have depth 0.
///
/// `visited` holds the references on the current resolution path only. Each is
/// removed after its subtree is measured, so a cyclic reference stops recursion,
/// while a reference reused at a deeper position is measured there too.
/// Previously a second use returned 0, which under-reported depth. Callers
/// should pass an empty set; on return it is as it was passed in.
fn schema_depth_resolved(
    schema: &serde_json::Value,
    spec: &OpenApiSpec,
    visited: &mut std::collections::HashSet<String>,
) -> u32 {
    if let Some(reference) = schema.get("$ref").and_then(|r| r.as_str()) {
        if !visited.insert(reference.to_string()) {
            // Cycle back to a reference already being expanded.
            return 0;
        }
        let depth = resolve_local_ref(spec, reference)
            .map(|resolved| schema_depth_resolved(resolved, spec, visited))
            .unwrap_or(0);
        visited.remove(reference);
        return depth;
    }
    let mut max_child = 0u32;
    if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) {
        for prop in properties.values() {
            max_child = max_child.max(schema_depth_resolved(prop, spec, visited));
        }
    }
    if let Some(items) = schema.get("items") {
        max_child = max_child.max(schema_depth_resolved(items, spec, visited));
    }
    for key in ["oneOf", "anyOf", "allOf"] {
        if let Some(variants) = schema.get(key).and_then(|v| v.as_array()) {
            for variant in variants {
                max_child = max_child.max(schema_depth_resolved(variant, spec, visited));
            }
        }
    }
    let is_object_like = schema.get("properties").is_some()
        || schema
            .get("type")
            .and_then(|t| t.as_str())
            .map(|t| t == "object" || t == "array")
            .unwrap_or(false);
    if is_object_like {
        1 + max_child
    } else {
        max_child
    }
}
fn estimate_protocol_schema_size(
contract: &dyn crate::contract_drift::protocol_contracts::ProtocolContract,
operation_id: &str,
) -> f64 {
if let Some(schema) = contract.get_schema(operation_id) {
if let Some(output_schema) = schema.get("output_schema") {
count_fields_in_schema(output_schema)
} else if let Some(input_schema) = schema.get("input_schema") {
count_fields_in_schema(input_schema)
} else {
10.0
}
} else {
10.0
}
}
fn count_fields_in_schema(schema: &serde_json::Value) -> f64 {
match schema {
serde_json::Value::Object(map) => {
let mut count = 0.0;
if let Some(properties) = map.get("properties") {
if let Some(props) = properties.as_object() {
count += props.len() as f64;
for prop_value in props.values() {
count += count_fields_in_schema(prop_value);
}
}
}
if let Some(fields) = map.get("fields") {
if let Some(fields_array) = fields.as_array() {
count += fields_array.len() as f64;
for field in fields_array {
if let Some(field_obj) = field.as_object() {
if let Some(field_type) = field_obj.get("type") {
count += count_fields_in_schema(field_type);
}
}
}
}
}
if let Some(item_type) = map.get("items") {
count += count_fields_in_schema(item_type);
}
count
}
_ => 0.0,
}
}
/// Heuristic schema size when no previous contract exists: a base of 10 plus
/// two per mismatch in `diff_result`.
fn estimate_size_from_diff(diff_result: &ContractDiffResult) -> f64 {
    const BASE_SIZE: f64 = 10.0;
    const PER_MISMATCH: f64 = 2.0;
    BASE_SIZE + PER_MISMATCH * diff_result.mismatches.len() as f64
}
fn count_required_fields_in_schema(schema: &serde_json::Value) -> usize {
match schema {
serde_json::Value::Object(map) => {
let mut count = 0;
if let Some(required) = map.get("required") {
if let Some(required_array) = required.as_array() {
count += required_array.len();
}
}
if let Some(properties) = map.get("properties") {
if let Some(props) = properties.as_object() {
for prop_value in props.values() {
count += count_required_fields_in_schema(prop_value);
}
}
}
if let Some(fields) = map.get("fields") {
if let Some(fields_array) = fields.as_array() {
for field in fields_array {
if let Some(field_obj) = field.as_object() {
if !field_obj.contains_key("default") {
count += 1;
}
if let Some(field_type) = field_obj.get("type") {
count += count_required_fields_in_schema(field_type);
}
}
}
}
}
count
}
_ => 0,
}
}
/// Nesting depth of a protocol-contract schema: one more than the deepest child,
/// or 0 when there are no children.
///
/// Children are `properties` values, the `type` of each `fields` element, and
/// `items`. Non-object values have depth 0.
fn calculate_protocol_schema_depth(schema: &serde_json::Value) -> u32 {
    let map = match schema.as_object() {
        Some(map) => map,
        None => return 0,
    };
    let mut children: Vec<&serde_json::Value> = Vec::new();
    if let Some(props) = map.get("properties").and_then(|p| p.as_object()) {
        children.extend(props.values());
    }
    if let Some(fields) = map.get("fields").and_then(|f| f.as_array()) {
        children.extend(
            fields
                .iter()
                .filter_map(|field| field.as_object())
                .filter_map(|field_obj| field_obj.get("type")),
        );
    }
    if let Some(items) = map.get("items") {
        children.push(items);
    }
    children
        .into_iter()
        .map(|child| calculate_protocol_schema_depth(child) + 1)
        .max()
        .unwrap_or(0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact, wildcard-suffix and catch-all patterns; a differing path must not match.
    #[test]
    fn test_matches_pattern() {
        let matching = [
            ("/api/users", "*"),
            ("/api/users", "/api/users"),
            ("/api/users/123", "/api/users/*"),
            ("/v1/mobile/users", "/v1/mobile/*"),
        ];
        for (endpoint, pattern) in matching.iter() {
            assert!(matches_pattern(endpoint, pattern), "{} should match {}", endpoint, pattern);
        }
        assert!(!matches_pattern("/api/users", "/api/orders"));
    }

    /// A freshly added function shows up in the registry listing.
    #[test]
    fn test_fitness_function_registry() {
        let mut registry = FitnessFunctionRegistry::new();
        registry.add_function(FitnessFunction {
            id: "test-1".to_string(),
            name: "Test Function".to_string(),
            description: "Test".to_string(),
            function_type: FitnessFunctionType::ResponseSize {
                max_increase_percent: 25.0,
            },
            config: serde_json::json!({"max_increase_percent": 25.0}),
            scope: FitnessScope::Global,
            enabled: true,
            created_at: 0,
            updated_at: 0,
        });
        let listed = registry.list_functions();
        assert_eq!(listed.len(), 1);
    }
}