Skip to main content

mockforge_bench/
spec_dependencies.rs

//! Cross-spec dependency detection and configuration for multi-spec benchmarking.
//!
//! This module provides:
//! - Auto-detection of dependencies between specs, based on request-body fields whose names
//!   reference schemas defined in other specs (e.g. `pool_id` → `Pool`)
//! - Manual dependency configuration via YAML/JSON files
//! - Topological sorting to determine a valid execution order
//! - Types for extracting values from one spec group and injecting them into the next
8
9use crate::error::{BenchError, Result};
10use mockforge_core::openapi::spec::OpenApiSpec;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use std::path::{Path, PathBuf};
14
15/// Cross-spec dependency configuration (optional override)
16#[derive(Debug, Clone, Serialize, Deserialize, Default)]
17pub struct SpecDependencyConfig {
18    /// Ordered list of spec groups to execute
19    #[serde(default)]
20    pub execution_order: Vec<SpecGroup>,
21    /// Disable auto-detection of dependencies
22    #[serde(default)]
23    pub disable_auto_detect: bool,
24}
25
26impl SpecDependencyConfig {
27    /// Load dependency configuration from a file (YAML or JSON)
28    pub fn from_file(path: &Path) -> Result<Self> {
29        let content = std::fs::read_to_string(path)
30            .map_err(|e| BenchError::Other(format!("Failed to read dependency config: {}", e)))?;
31
32        let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
33        match ext {
34            "yaml" | "yml" => serde_yaml::from_str(&content).map_err(|e| {
35                BenchError::Other(format!("Failed to parse YAML dependency config: {}", e))
36            }),
37            "json" => serde_json::from_str(&content).map_err(|e| {
38                BenchError::Other(format!("Failed to parse JSON dependency config: {}", e))
39            }),
40            _ => Err(BenchError::Other(format!(
41                "Unsupported dependency config format: {}. Use .yaml, .yml, or .json",
42                ext
43            ))),
44        }
45    }
46}
47
48/// A group of specs to execute together
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SpecGroup {
51    /// Name for this group (e.g., "infrastructure", "services")
52    pub name: String,
53    /// Spec files in this group
54    pub specs: Vec<PathBuf>,
55    /// Fields to extract from responses (JSONPath-like syntax)
56    #[serde(default)]
57    pub extract: HashMap<String, String>,
58    /// Fields to inject into next group's requests
59    #[serde(default)]
60    pub inject: HashMap<String, String>,
61}
62
/// A dependency found between two specs: a field in one spec refers to a schema that is
/// defined in another.
#[derive(Debug, Clone)]
pub struct SpecDependency {
    /// Spec that contains the referencing field (the one that depends on the other).
    pub dependent_spec: PathBuf,
    /// Spec that defines the referenced schema (the one being depended upon).
    pub dependency_spec: PathBuf,
    /// Name of the field that creates the dependency, e.g. `"pool_ref"`.
    pub field_name: String,
    /// Name of the schema the field refers to, e.g. `"Pool"`.
    pub referenced_schema: String,
    /// Path used to extract the dependency value, e.g. `"$.pool_ref"`.
    pub extraction_path: String,
}
77
78/// Dependency detector for analyzing specs
79pub struct DependencyDetector {
80    /// Schemas available in each spec (spec_path -> schema_names)
81    schema_registry: HashMap<PathBuf, HashSet<String>>,
82    /// Detected dependencies
83    dependencies: Vec<SpecDependency>,
84}
85
86impl DependencyDetector {
87    /// Create a new dependency detector
88    pub fn new() -> Self {
89        Self {
90            schema_registry: HashMap::new(),
91            dependencies: Vec::new(),
92        }
93    }
94
95    /// Detect dependencies between specs by analyzing schema references
96    pub fn detect_dependencies(&mut self, specs: &[(PathBuf, OpenApiSpec)]) -> Vec<SpecDependency> {
97        // Build schema registry - collect all schemas from each spec
98        for (path, spec) in specs {
99            let schemas = self.extract_schema_names(spec);
100            self.schema_registry.insert(path.clone(), schemas);
101        }
102
103        // Analyze each spec's request bodies for references to other specs' schemas
104        for (path, spec) in specs {
105            self.analyze_spec_references(path, spec, specs);
106        }
107
108        self.dependencies.clone()
109    }
110
111    /// Extract all schema names from a spec
112    fn extract_schema_names(&self, spec: &OpenApiSpec) -> HashSet<String> {
113        let mut schemas = HashSet::new();
114
115        if let Some(components) = &spec.spec.components {
116            for (name, _) in &components.schemas {
117                schemas.insert(name.clone());
118                // Also add common variations
119                schemas.insert(name.to_lowercase());
120                schemas.insert(to_snake_case(name));
121            }
122        }
123
124        schemas
125    }
126
127    /// Analyze a spec's references to detect dependencies
128    fn analyze_spec_references(
129        &mut self,
130        current_path: &PathBuf,
131        spec: &OpenApiSpec,
132        all_specs: &[(PathBuf, OpenApiSpec)],
133    ) {
134        // Analyze request body schemas for reference patterns
135        for (path, path_item) in &spec.spec.paths.paths {
136            if let openapiv3::ReferenceOr::Item(item) = path_item {
137                // Check POST operations (most common for creating resources with refs)
138                if let Some(op) = &item.post {
139                    self.analyze_operation_refs(current_path, op, all_specs, path);
140                }
141                if let Some(op) = &item.put {
142                    self.analyze_operation_refs(current_path, op, all_specs, path);
143                }
144                if let Some(op) = &item.patch {
145                    self.analyze_operation_refs(current_path, op, all_specs, path);
146                }
147            }
148        }
149    }
150
151    /// Analyze operation request body for reference fields
152    fn analyze_operation_refs(
153        &mut self,
154        current_path: &PathBuf,
155        operation: &openapiv3::Operation,
156        all_specs: &[(PathBuf, OpenApiSpec)],
157        _api_path: &str,
158    ) {
159        if let Some(openapiv3::ReferenceOr::Item(body)) = &operation.request_body {
160            // Check JSON content
161            if let Some(media_type) = body.content.get("application/json") {
162                if let Some(schema_ref) = &media_type.schema {
163                    self.analyze_schema_for_refs(current_path, schema_ref, all_specs, "");
164                }
165            }
166        }
167    }
168
169    /// Recursively analyze schema for reference patterns
170    fn analyze_schema_for_refs(
171        &mut self,
172        current_path: &PathBuf,
173        schema_ref: &openapiv3::ReferenceOr<openapiv3::Schema>,
174        all_specs: &[(PathBuf, OpenApiSpec)],
175        field_prefix: &str,
176    ) {
177        match schema_ref {
178            openapiv3::ReferenceOr::Item(schema) => {
179                self.analyze_schema(current_path, schema, all_specs, field_prefix);
180            }
181            openapiv3::ReferenceOr::Reference { reference } => {
182                // Could analyze $ref to other schemas here
183                let _ = reference; // Silence unused warning for now
184            }
185        }
186    }
187
188    /// Analyze schema for reference patterns (handles both Box<Schema> and Schema)
189    fn analyze_schema(
190        &mut self,
191        current_path: &PathBuf,
192        schema: &openapiv3::Schema,
193        all_specs: &[(PathBuf, OpenApiSpec)],
194        field_prefix: &str,
195    ) {
196        match &schema.schema_kind {
197            openapiv3::SchemaKind::Type(openapiv3::Type::Object(obj)) => {
198                for (prop_name, prop_schema) in &obj.properties {
199                    let full_path = if field_prefix.is_empty() {
200                        prop_name.clone()
201                    } else {
202                        format!("{}.{}", field_prefix, prop_name)
203                    };
204
205                    // Check for reference patterns in field names
206                    if let Some(dep) = self.detect_ref_field(current_path, prop_name, all_specs) {
207                        self.dependencies.push(SpecDependency {
208                            dependent_spec: current_path.clone(),
209                            dependency_spec: dep.0,
210                            field_name: prop_name.clone(),
211                            referenced_schema: dep.1,
212                            extraction_path: format!("$.{}", full_path),
213                        });
214                    }
215
216                    // Recursively check nested schemas
217                    self.analyze_boxed_schema_ref(current_path, prop_schema, all_specs, &full_path);
218                }
219            }
220            openapiv3::SchemaKind::AllOf { all_of } => {
221                for sub_schema in all_of {
222                    self.analyze_schema_for_refs(current_path, sub_schema, all_specs, field_prefix);
223                }
224            }
225            openapiv3::SchemaKind::OneOf { one_of } => {
226                for sub_schema in one_of {
227                    self.analyze_schema_for_refs(current_path, sub_schema, all_specs, field_prefix);
228                }
229            }
230            openapiv3::SchemaKind::AnyOf { any_of } => {
231                for sub_schema in any_of {
232                    self.analyze_schema_for_refs(current_path, sub_schema, all_specs, field_prefix);
233                }
234            }
235            _ => {}
236        }
237    }
238
239    /// Handle ReferenceOr<Box<Schema>> which is used in object properties
240    fn analyze_boxed_schema_ref(
241        &mut self,
242        current_path: &PathBuf,
243        schema_ref: &openapiv3::ReferenceOr<Box<openapiv3::Schema>>,
244        all_specs: &[(PathBuf, OpenApiSpec)],
245        field_prefix: &str,
246    ) {
247        match schema_ref {
248            openapiv3::ReferenceOr::Item(boxed_schema) => {
249                self.analyze_schema(current_path, boxed_schema.as_ref(), all_specs, field_prefix);
250            }
251            openapiv3::ReferenceOr::Reference { reference } => {
252                let _ = reference; // Could analyze $ref here
253            }
254        }
255    }
256
257    /// Detect if a field name references another spec's schema
258    fn detect_ref_field(
259        &self,
260        current_path: &PathBuf,
261        field_name: &str,
262        all_specs: &[(PathBuf, OpenApiSpec)],
263    ) -> Option<(PathBuf, String)> {
264        // Common patterns for reference fields
265        let ref_patterns = [
266            ("_ref", ""),       // pool_ref -> Pool
267            ("_id", ""),        // pool_id -> Pool
268            ("Id", ""),         // poolId -> pool
269            ("_uuid", ""),      // pool_uuid -> Pool
270            ("Uuid", ""),       // poolUuid -> pool
271            ("_reference", ""), // pool_reference -> Pool
272        ];
273
274        for (suffix, _) in ref_patterns.iter() {
275            if field_name.ends_with(suffix) {
276                // Extract the schema name from the field
277                let schema_base = field_name.trim_end_matches(suffix).trim_end_matches('_');
278
279                // Search for this schema in other specs
280                for (other_path, _) in all_specs {
281                    if other_path == current_path {
282                        continue;
283                    }
284
285                    if let Some(schemas) = self.schema_registry.get(other_path) {
286                        // Check various name formats
287                        let schema_pascal = to_pascal_case(schema_base);
288                        let schema_lower = schema_base.to_lowercase();
289
290                        for schema_name in schemas {
291                            if schema_name == &schema_pascal
292                                || schema_name == &schema_lower
293                                || schema_name.to_lowercase() == schema_lower
294                            {
295                                return Some((other_path.clone(), schema_name.clone()));
296                            }
297                        }
298                    }
299                }
300            }
301        }
302
303        None
304    }
305}
306
307impl Default for DependencyDetector {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
313/// Topologically sort specs based on dependencies
314pub fn topological_sort(
315    specs: &[(PathBuf, OpenApiSpec)],
316    dependencies: &[SpecDependency],
317) -> Result<Vec<PathBuf>> {
318    let spec_paths: Vec<PathBuf> = specs.iter().map(|(p, _)| p.clone()).collect();
319
320    // Build adjacency list (dependency -> dependent)
321    let mut adj: HashMap<PathBuf, Vec<PathBuf>> = HashMap::new();
322    let mut in_degree: HashMap<PathBuf, usize> = HashMap::new();
323
324    for path in &spec_paths {
325        adj.insert(path.clone(), Vec::new());
326        in_degree.insert(path.clone(), 0);
327    }
328
329    for dep in dependencies {
330        adj.entry(dep.dependency_spec.clone())
331            .or_default()
332            .push(dep.dependent_spec.clone());
333        *in_degree.entry(dep.dependent_spec.clone()).or_insert(0) += 1;
334    }
335
336    // Kahn's algorithm
337    let mut queue: Vec<PathBuf> = in_degree
338        .iter()
339        .filter(|(_, &deg)| deg == 0)
340        .map(|(path, _)| path.clone())
341        .collect();
342
343    let mut result = Vec::new();
344
345    while let Some(path) = queue.pop() {
346        result.push(path.clone());
347
348        if let Some(dependents) = adj.get(&path) {
349            for dependent in dependents {
350                if let Some(deg) = in_degree.get_mut(dependent) {
351                    *deg -= 1;
352                    if *deg == 0 {
353                        queue.push(dependent.clone());
354                    }
355                }
356            }
357        }
358    }
359
360    if result.len() != spec_paths.len() {
361        return Err(BenchError::Other("Circular dependency detected between specs".to_string()));
362    }
363
364    Ok(result)
365}
366
/// Convert a PascalCase or camelCase identifier to snake_case.
///
/// A `_` is inserted before an uppercase letter in two cases:
/// - when it follows a lowercase letter or a digit;
/// - when it is the last capital of an acronym followed by a lowercase letter.
///
/// So `VirtualService` becomes `virtual_service` and `HTTPServer` becomes `http_server`.
/// Characters other than uppercase letters are copied unchanged. Uppercase letters are
/// replaced by their full lowercase mapping, which may be more than one char.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + chars.len() / 2);

    for (i, &c) in chars.iter().enumerate() {
        if !c.is_uppercase() {
            result.push(c);
            continue;
        }

        let prev = i.checked_sub(1).map(|j| chars[j]);
        let next = chars.get(i + 1).copied();
        let starts_word = match prev {
            Some(p) if p.is_lowercase() || p.is_numeric() => true,
            // Last capital of an acronym, e.g. the `S` in `HTTPServer`.
            Some(p) if p.is_uppercase() => matches!(next, Some(n) if n.is_lowercase()),
            _ => false,
        };

        if starts_word {
            result.push('_');
        }
        result.extend(c.to_lowercase());
    }

    result
}
382
/// Convert a snake_case or kebab-case identifier to PascalCase.
///
/// `_` and `-` separate words and are dropped. The first character of each word is
/// replaced by its full uppercase mapping, which may be more than one char (e.g. `ß` →
/// `SS`). All other characters are copied unchanged. So `virtual_service` and
/// `virtualService` both become `VirtualService`.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;

    for c in s.chars() {
        match c {
            '_' | '-' => capitalize_next = true,
            _ if capitalize_next => {
                result.extend(c.to_uppercase());
                capitalize_next = false;
            }
            _ => result.push(c),
        }
    }

    result
}
401
402/// Extracted values from spec execution for passing to dependent specs
403#[derive(Debug, Clone, Default)]
404pub struct ExtractedValues {
405    /// Values extracted by variable name
406    pub values: HashMap<String, serde_json::Value>,
407}
408
409impl ExtractedValues {
410    /// Create new empty extracted values
411    pub fn new() -> Self {
412        Self::default()
413    }
414
415    /// Set a value
416    pub fn set(&mut self, key: String, value: serde_json::Value) {
417        self.values.insert(key, value);
418    }
419
420    /// Get a value
421    pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
422        self.values.get(key)
423    }
424
425    /// Merge values from another ExtractedValues
426    pub fn merge(&mut self, other: &ExtractedValues) {
427        for (key, value) in &other.values {
428            self.values.insert(key.clone(), value.clone());
429        }
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("PascalCase", "pascal_case"),
            ("camelCase", "camel_case"),
            ("Pool", "pool"),
            ("VirtualService", "virtual_service"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(to_snake_case(input), expected, "to_snake_case({:?})", input);
        }
    }

    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("snake_case", "SnakeCase"),
            ("pool", "Pool"),
            ("virtual_service", "VirtualService"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(to_pascal_case(input), expected, "to_pascal_case({:?})", input);
        }
    }

    #[test]
    fn test_extracted_values() {
        let pool_id = serde_json::json!("abc123");

        let mut values = ExtractedValues::new();
        values.set(String::from("pool_id"), pool_id.clone());
        values.set(String::from("name"), serde_json::json!("test-pool"));

        assert_eq!(values.get("pool_id"), Some(&pool_id));
        assert!(values.get("missing").is_none());
    }

    #[test]
    fn test_spec_dependency_config_default() {
        let config: SpecDependencyConfig = Default::default();
        assert!(config.execution_order.is_empty());
        assert!(!config.disable_auto_detect);
    }
}