1use crate::ai_contract_diff::{ContractDiffResult, MismatchType};
8use crate::openapi::OpenApiSpec;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12
/// A configured fitness function: a named rule that is evaluated against contract changes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FitnessFunction {
    /// Unique identifier; the registry stores functions keyed by this value.
    pub id: String,
    /// Human-readable name, copied into each `FitnessTestResult` by the registry.
    pub name: String,
    /// Free-form description of what the function checks.
    pub description: String,
    /// The kind of check to run; selects which evaluator the registry invokes.
    pub function_type: FitnessFunctionType,
    /// Evaluator configuration as JSON. The built-in evaluators read keys such as
    /// `max_increase_percent`, `path_pattern`, `allow_new_required`, `max_fields`
    /// and `max_depth`.
    pub config: serde_json::Value,
    /// Where this function applies (global, workspace, service, or endpoint pattern).
    pub scope: FitnessScope,
    /// Disabled functions are skipped when the registry selects functions for a scope.
    pub enabled: bool,
    /// Creation time; deserializes to `0` when the field is absent.
    /// NOTE(review): presumably a Unix timestamp — confirm the unit with callers.
    #[serde(default)]
    pub created_at: i64,
    /// Last-update time; deserializes to `0` when the field is absent.
    /// NOTE(review): presumably a Unix timestamp — confirm the unit with callers.
    #[serde(default)]
    pub updated_at: i64,
}
37
/// Where a fitness function applies.
///
/// Uses serde's default (externally tagged) enum representation with snake_case
/// variant names, e.g. `"global"` or `{"workspace": {"workspace_id": "..."}}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum FitnessScope {
    /// Applies to every endpoint.
    Global,
    /// Applies only when the caller supplies an equal workspace id.
    Workspace {
        /// Workspace id compared for equality with the caller's workspace id.
        workspace_id: String,
    },
    /// Applies only when the caller supplies an equal service name.
    Service {
        /// Service name compared for equality with the caller's service name.
        service_name: String,
    },
    /// Applies to endpoints whose path matches `pattern`.
    Endpoint {
        /// Endpoint path pattern in which `*` acts as a wildcard.
        pattern: String,
    },
}
60
/// The kind of check a fitness function performs.
///
/// Serialized internally tagged by a `"type"` field with snake_case names, e.g.
/// `{"type": "response_size", "max_increase_percent": 25.0}`. The built-in evaluators
/// read their thresholds from config keys whose names match these fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FitnessFunctionType {
    /// Limits growth of the estimated response field count.
    ResponseSize {
        /// Maximum allowed increase, in percent.
        max_increase_percent: f64,
    },
    /// Limits newly detected required fields.
    RequiredField {
        /// Endpoint pattern (with `*` wildcards) the check applies to.
        path_pattern: String,
        /// When `true`, new required fields are reported but do not fail the check.
        allow_new_required: bool,
    },
    /// Limits the estimated number of fields.
    FieldCount {
        /// Maximum allowed field count.
        max_fields: u32,
    },
    /// Limits schema nesting depth.
    SchemaComplexity {
        /// Maximum allowed schema depth.
        max_depth: u32,
    },
    /// Delegates to an evaluator registered under a custom name.
    Custom {
        /// Name the evaluator was registered under in the registry.
        evaluator: String,
    },
}
93
/// Outcome of evaluating one fitness function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FitnessTestResult {
    /// Id of the evaluated function; the built-in evaluators leave it empty and the
    /// registry fills it in.
    pub function_id: String,
    /// Name of the evaluated function; the registry overwrites the evaluator's default.
    pub function_name: String,
    /// Whether the check passed.
    pub passed: bool,
    /// Human-readable explanation of the outcome.
    pub message: String,
    /// Named numeric measurements gathered during evaluation (e.g. counts and limits).
    pub metrics: HashMap<String, f64>,
}
108
/// Evaluates one kind of fitness check against a contract change.
///
/// Implementations must be `Send + Sync`; the registry stores them as
/// `Arc<dyn FitnessEvaluator>`.
pub trait FitnessEvaluator: Send + Sync {
    /// Runs the check for `method` on `endpoint`.
    ///
    /// # Parameters
    /// - `old_spec`: the previous spec, when one is available.
    /// - `new_spec`: the current spec.
    /// - `diff_result`: contract diff whose mismatches the check may inspect.
    /// - `endpoint` / `method`: the operation being evaluated.
    /// - `config`: evaluator-specific JSON configuration.
    ///
    /// The returned `function_id` / `function_name` may be placeholders; the registry's
    /// `evaluate_all` overwrites both with the configured function's values.
    ///
    /// # Errors
    /// Implementations may return an error; the built-in evaluators always return `Ok`.
    fn evaluate(
        &self,
        old_spec: Option<&OpenApiSpec>,
        new_spec: &OpenApiSpec,
        diff_result: &ContractDiffResult,
        endpoint: &str,
        method: &str,
        config: &serde_json::Value,
    ) -> crate::Result<FitnessTestResult>;
}
135
136pub struct ResponseSizeFitnessEvaluator;
138
139impl FitnessEvaluator for ResponseSizeFitnessEvaluator {
140 fn evaluate(
141 &self,
142 old_spec: Option<&OpenApiSpec>,
143 _new_spec: &OpenApiSpec,
144 diff_result: &ContractDiffResult,
145 endpoint: &str,
146 method: &str,
147 config: &serde_json::Value,
148 ) -> crate::Result<FitnessTestResult> {
149 let max_increase_percent =
151 config.get("max_increase_percent").and_then(|v| v.as_f64()).unwrap_or(25.0);
152
153 let old_field_count = if let Some(old) = old_spec {
156 estimate_response_field_count(old, endpoint, method)
158 } else {
159 diff_result.mismatches.len() as f64
161 };
162
163 let new_field_count =
164 estimate_response_field_count_from_diff(diff_result, endpoint, method);
165
166 let increase_percent = if old_field_count > 0.0 {
167 ((new_field_count - old_field_count) / old_field_count) * 100.0
168 } else if new_field_count > 0.0 {
169 100.0 } else {
171 0.0 };
173
174 let passed = increase_percent <= max_increase_percent;
175 let message = if passed {
176 format!(
177 "Response size increase ({:.1}%) is within allowed limit ({:.1}%)",
178 increase_percent, max_increase_percent
179 )
180 } else {
181 format!(
182 "Response size increase ({:.1}%) exceeds allowed limit ({:.1}%)",
183 increase_percent, max_increase_percent
184 )
185 };
186
187 let mut metrics = HashMap::new();
188 metrics.insert("old_field_count".to_string(), old_field_count);
189 metrics.insert("new_field_count".to_string(), new_field_count);
190 metrics.insert("increase_percent".to_string(), increase_percent);
191 metrics.insert("max_increase_percent".to_string(), max_increase_percent);
192
193 Ok(FitnessTestResult {
194 function_id: String::new(), function_name: "Response Size".to_string(),
196 passed,
197 message,
198 metrics,
199 })
200 }
201}
202
203pub struct RequiredFieldFitnessEvaluator;
205
206impl FitnessEvaluator for RequiredFieldFitnessEvaluator {
207 fn evaluate(
208 &self,
209 _old_spec: Option<&OpenApiSpec>,
210 _new_spec: &OpenApiSpec,
211 diff_result: &ContractDiffResult,
212 endpoint: &str,
213 method: &str,
214 config: &serde_json::Value,
215 ) -> crate::Result<FitnessTestResult> {
216 let path_pattern = config.get("path_pattern").and_then(|v| v.as_str()).unwrap_or("*");
218 let allow_new_required =
219 config.get("allow_new_required").and_then(|v| v.as_bool()).unwrap_or(false);
220
221 let matches_pattern = matches_pattern(endpoint, path_pattern);
223
224 if !matches_pattern {
225 return Ok(FitnessTestResult {
227 function_id: String::new(),
228 function_name: "Required Field".to_string(),
229 passed: true,
230 message: format!("Endpoint {} does not match pattern {}", endpoint, path_pattern),
231 metrics: HashMap::new(),
232 });
233 }
234
235 let new_required_fields = diff_result
237 .mismatches
238 .iter()
239 .filter(|m| {
240 m.mismatch_type == MismatchType::MissingRequiredField
241 && m.method.as_ref().map(|m| m.as_str()) == Some(method)
242 })
243 .count();
244
245 let passed = allow_new_required || new_required_fields == 0;
246 let message = if passed {
247 if allow_new_required {
248 format!("Found {} new required fields, which is allowed", new_required_fields)
249 } else {
250 "No new required fields detected".to_string()
251 }
252 } else {
253 format!(
254 "Found {} new required fields, which violates the fitness function",
255 new_required_fields
256 )
257 };
258
259 let mut metrics = HashMap::new();
260 metrics.insert("new_required_fields".to_string(), new_required_fields as f64);
261 metrics
262 .insert("allow_new_required".to_string(), if allow_new_required { 1.0 } else { 0.0 });
263
264 Ok(FitnessTestResult {
265 function_id: String::new(),
266 function_name: "Required Field".to_string(),
267 passed,
268 message,
269 metrics,
270 })
271 }
272}
273
274pub struct FieldCountFitnessEvaluator;
276
277impl FitnessEvaluator for FieldCountFitnessEvaluator {
278 fn evaluate(
279 &self,
280 _old_spec: Option<&OpenApiSpec>,
281 _new_spec: &OpenApiSpec,
282 diff_result: &ContractDiffResult,
283 endpoint: &str,
284 method: &str,
285 config: &serde_json::Value,
286 ) -> crate::Result<FitnessTestResult> {
287 let max_fields = config
289 .get("max_fields")
290 .and_then(|v| v.as_u64())
291 .map(|v| v as u32)
292 .unwrap_or(100);
293
294 let field_count = estimate_field_count_from_diff(diff_result, endpoint, method);
296
297 let passed = field_count <= max_fields as f64;
298 let message = if passed {
299 format!("Field count ({}) is within allowed limit ({})", field_count as u32, max_fields)
300 } else {
301 format!("Field count ({}) exceeds allowed limit ({})", field_count as u32, max_fields)
302 };
303
304 let mut metrics = HashMap::new();
305 metrics.insert("field_count".to_string(), field_count);
306 metrics.insert("max_fields".to_string(), max_fields as f64);
307
308 Ok(FitnessTestResult {
309 function_id: String::new(),
310 function_name: "Field Count".to_string(),
311 passed,
312 message,
313 metrics,
314 })
315 }
316}
317
318pub struct SchemaComplexityFitnessEvaluator;
320
321impl FitnessEvaluator for SchemaComplexityFitnessEvaluator {
322 fn evaluate(
323 &self,
324 _old_spec: Option<&OpenApiSpec>,
325 new_spec: &OpenApiSpec,
326 _diff_result: &ContractDiffResult,
327 endpoint: &str,
328 method: &str,
329 config: &serde_json::Value,
330 ) -> crate::Result<FitnessTestResult> {
331 let max_depth =
333 config.get("max_depth").and_then(|v| v.as_u64()).map(|v| v as u32).unwrap_or(10);
334
335 let depth = calculate_schema_depth(new_spec, endpoint, method);
337
338 let passed = depth <= max_depth;
339 let message = if passed {
340 format!("Schema depth ({}) is within allowed limit ({})", depth, max_depth)
341 } else {
342 format!("Schema depth ({}) exceeds allowed limit ({})", depth, max_depth)
343 };
344
345 let mut metrics = HashMap::new();
346 metrics.insert("schema_depth".to_string(), depth as f64);
347 metrics.insert("max_depth".to_string(), max_depth as f64);
348
349 Ok(FitnessTestResult {
350 function_id: String::new(),
351 function_name: "Schema Complexity".to_string(),
352 passed,
353 message,
354 metrics,
355 })
356 }
357}
358
/// Registry of configured fitness functions and the evaluators that run them.
pub struct FitnessFunctionRegistry {
    /// Configured fitness functions, keyed by `FitnessFunction::id`.
    functions: HashMap<String, FitnessFunction>,
    /// Evaluators keyed by name (e.g. `"response_size"`), looked up per function type.
    evaluators: HashMap<String, Arc<dyn FitnessEvaluator>>,
}
366
367impl std::fmt::Debug for FitnessFunctionRegistry {
368 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369 f.debug_struct("FitnessFunctionRegistry")
370 .field("functions", &self.functions)
371 .field("evaluators_count", &self.evaluators.len())
372 .finish()
373 }
374}
375
376impl FitnessFunctionRegistry {
377 pub fn new() -> Self {
379 let mut registry = Self {
380 functions: HashMap::new(),
381 evaluators: HashMap::new(),
382 };
383
384 registry.register_evaluator(
386 "response_size",
387 Arc::new(ResponseSizeFitnessEvaluator) as Arc<dyn FitnessEvaluator>,
388 );
389 registry.register_evaluator(
390 "required_field",
391 Arc::new(RequiredFieldFitnessEvaluator) as Arc<dyn FitnessEvaluator>,
392 );
393 registry.register_evaluator(
394 "field_count",
395 Arc::new(FieldCountFitnessEvaluator) as Arc<dyn FitnessEvaluator>,
396 );
397 registry.register_evaluator(
398 "schema_complexity",
399 Arc::new(SchemaComplexityFitnessEvaluator) as Arc<dyn FitnessEvaluator>,
400 );
401
402 registry
403 }
404
405 pub fn register_evaluator(&mut self, name: &str, evaluator: Arc<dyn FitnessEvaluator>) {
407 self.evaluators.insert(name.to_string(), evaluator);
408 }
409
410 pub fn add_function(&mut self, function: FitnessFunction) {
412 self.functions.insert(function.id.clone(), function);
413 }
414
415 pub fn get_function(&self, id: &str) -> Option<&FitnessFunction> {
417 self.functions.get(id)
418 }
419
420 pub fn list_functions(&self) -> Vec<&FitnessFunction> {
422 self.functions.values().collect()
423 }
424
425 pub fn get_functions_for_scope(
427 &self,
428 endpoint: &str,
429 method: &str,
430 workspace_id: Option<&str>,
431 service_name: Option<&str>,
432 ) -> Vec<&FitnessFunction> {
433 self.functions
434 .values()
435 .filter(|f| {
436 f.enabled && self.matches_scope(f, endpoint, method, workspace_id, service_name)
437 })
438 .collect()
439 }
440
441 pub fn evaluate_all(
443 &self,
444 old_spec: Option<&OpenApiSpec>,
445 new_spec: &OpenApiSpec,
446 diff_result: &ContractDiffResult,
447 endpoint: &str,
448 method: &str,
449 workspace_id: Option<&str>,
450 service_name: Option<&str>,
451 ) -> crate::Result<Vec<FitnessTestResult>> {
452 let functions = self.get_functions_for_scope(endpoint, method, workspace_id, service_name);
453 let mut results = Vec::new();
454
455 for function in functions {
456 let evaluator_name = match &function.function_type {
457 FitnessFunctionType::ResponseSize { .. } => "response_size",
458 FitnessFunctionType::RequiredField { .. } => "required_field",
459 FitnessFunctionType::FieldCount { .. } => "field_count",
460 FitnessFunctionType::SchemaComplexity { .. } => "schema_complexity",
461 FitnessFunctionType::Custom { evaluator } => evaluator.as_str(),
462 };
463
464 if let Some(evaluator) = self.evaluators.get(evaluator_name) {
465 let mut result = evaluator.evaluate(
466 old_spec,
467 new_spec,
468 diff_result,
469 endpoint,
470 method,
471 &function.config,
472 )?;
473 result.function_id = function.id.clone();
474 result.function_name = function.name.clone();
475 results.push(result);
476 }
477 }
478
479 Ok(results)
480 }
481
482 fn matches_scope(
484 &self,
485 function: &FitnessFunction,
486 endpoint: &str,
487 method: &str,
488 workspace_id: Option<&str>,
489 service_name: Option<&str>,
490 ) -> bool {
491 match &function.scope {
492 FitnessScope::Global => true,
493 FitnessScope::Workspace {
494 workspace_id: ws_id,
495 } => workspace_id.map(|id| id == ws_id).unwrap_or(false),
496 FitnessScope::Service {
497 service_name: svc_name,
498 } => service_name.map(|name| name == svc_name).unwrap_or(false),
499 FitnessScope::Endpoint { pattern } => matches_pattern(endpoint, pattern),
500 }
501 }
502
503 pub fn remove_function(&mut self, id: &str) -> Option<FitnessFunction> {
505 self.functions.remove(id)
506 }
507
508 pub fn update_function(&mut self, function: FitnessFunction) {
510 self.functions.insert(function.id.clone(), function);
511 }
512}
513
514impl Default for FitnessFunctionRegistry {
515 fn default() -> Self {
516 Self::new()
517 }
518}
519
/// Glob-style match of `endpoint` against `pattern`, where each `*` matches any
/// (possibly empty) run of characters.
///
/// A pattern without `*` must equal the endpoint exactly. Literal pieces between
/// wildcards must appear in order and may not overlap, so `/a/*/b/*/c` does not match
/// `/a/x/c`, and `/a*a` does not match `/a`.
fn matches_pattern(endpoint: &str, pattern: &str) -> bool {
    if pattern == "*" {
        return true;
    }

    let mut pieces = pattern.split('*');
    // `split` always yields at least one item, even for an empty pattern.
    let prefix = pieces.next().unwrap_or("");
    let mut rest = match endpoint.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };

    let remaining: Vec<&str> = pieces.collect();
    let (suffix, middle) = match remaining.split_last() {
        Some((suffix, middle)) => (*suffix, middle),
        // No `*` in the pattern: the prefix was the whole pattern, so nothing may remain.
        None => return rest.is_empty(),
    };

    // Consume each middle piece at its earliest occurrence. Leftmost matching leaves
    // the most room for the later pieces and the suffix.
    for &piece in middle {
        match rest.find(piece) {
            Some(index) => rest = &rest[index + piece.len()..],
            None => return false,
        }
    }

    // The suffix must fit in what is left, so it cannot overlap consumed text.
    rest.ends_with(suffix)
}
542
543fn estimate_response_field_count(_spec: &OpenApiSpec, _endpoint: &str, _method: &str) -> f64 {
545 10.0
549}
550
551fn estimate_response_field_count_from_diff(
553 diff_result: &ContractDiffResult,
554 _endpoint: &str,
555 _method: &str,
556) -> f64 {
557 let base_count = 10.0;
560 let mismatch_count = diff_result.mismatches.len() as f64;
561 base_count + mismatch_count
562}
563
564fn estimate_field_count_from_diff(
566 diff_result: &ContractDiffResult,
567 _endpoint: &str,
568 _method: &str,
569) -> f64 {
570 let unique_paths: std::collections::HashSet<String> = diff_result
572 .mismatches
573 .iter()
574 .map(|m| {
575 m.path.split('.').next().unwrap_or("").to_string()
577 })
578 .collect();
579
580 unique_paths.len() as f64 + 10.0 }
582
583fn calculate_schema_depth(_spec: &OpenApiSpec, _endpoint: &str, _method: &str) -> u32 {
585 5
589}
590
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of the wildcard, exact, prefix-glob and non-matching cases.
    #[test]
    fn test_matches_pattern() {
        let cases = [
            ("/api/users", "*", true),
            ("/api/users", "/api/users", true),
            ("/api/users/123", "/api/users/*", true),
            ("/v1/mobile/users", "/v1/mobile/*", true),
            ("/api/users", "/api/orders", false),
        ];
        for &(endpoint, pattern, expected) in cases.iter() {
            assert_eq!(
                matches_pattern(endpoint, pattern),
                expected,
                "endpoint={} pattern={}",
                endpoint,
                pattern
            );
        }
    }

    /// A freshly added function is visible through `list_functions`.
    #[test]
    fn test_fitness_function_registry() {
        let mut registry = FitnessFunctionRegistry::new();

        registry.add_function(FitnessFunction {
            id: "test-1".to_string(),
            name: "Test Function".to_string(),
            description: "Test".to_string(),
            function_type: FitnessFunctionType::ResponseSize {
                max_increase_percent: 25.0,
            },
            config: serde_json::json!({"max_increase_percent": 25.0}),
            scope: FitnessScope::Global,
            enabled: true,
            created_at: 0,
            updated_at: 0,
        });

        let listed = registry.list_functions();
        assert_eq!(listed.len(), 1);
    }
}