mockforge_bench/
command.rs

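//! Bench command implementation.
//!
//! Generates a k6 load-test script from an OpenAPI specification and runs it against a
//! single target or, via a targets file, against many targets in parallel.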
2
3use crate::error::{BenchError, Result};
4use crate::executor::K6Executor;
5use crate::k6_gen::{K6Config, K6ScriptGenerator};
6use crate::parallel_executor::{AggregatedResults, ParallelExecutor};
7use crate::param_overrides::ParameterOverrides;
8use crate::reporter::TerminalReporter;
9use crate::request_gen::RequestGenerator;
10use crate::scenarios::LoadScenario;
11use crate::spec_parser::SpecParser;
12use crate::target_parser::parse_targets_file;
13use std::collections::HashMap;
14use std::path::PathBuf;
15use std::str::FromStr;
16
/// Bench command configuration.
///
/// In single-target mode the command loads `spec`, generates a k6 script for the
/// selected operations and runs it against `target`. When `targets_file` is set, the
/// same settings are applied to every target listed in that file instead.
///
/// `Clone` is derived so the configuration can be copied for per-target runs; `Debug`
/// is intentionally not derived because `auth` may carry credentials.
#[derive(Clone)]
pub struct BenchCommand {
    /// Path to the OpenAPI specification that requests are generated from.
    pub spec: PathBuf,
    /// Base URL of the system under test; passed to the k6 script as its target URL.
    pub target: String,
    /// Test duration such as `"30s"`, `"5m"`, `"1h"`, or a bare number of seconds.
    pub duration: String,
    /// Maximum number of k6 virtual users.
    pub vus: u32,
    /// Load scenario name, parsed with `LoadScenario::from_str`.
    pub scenario: String,
    /// Optional filter selecting which operations to test; handed to
    /// `SpecParser::filter_operations` (see that method for the accepted syntax).
    pub operations: Option<String>,
    /// Exclude operations from testing (comma-separated)
    ///
    /// Supports "METHOD /path" or just "METHOD" to exclude all operations of that type.
    pub exclude_operations: Option<String>,
    /// Optional authentication header value forwarded to the generated k6 script.
    pub auth: Option<String>,
    /// Extra request headers in `Key:Value,Key2:Value2` form.
    pub headers: Option<String>,
    /// Directory for results; the script is also written here as `k6-script.js`
    /// unless `script_output` is set.
    pub output: PathBuf,
    /// Only write the k6 script and print how to run it; do not execute it.
    pub generate_only: bool,
    /// Explicit path for the generated k6 script.
    pub script_output: Option<PathBuf>,
    /// Percentile used for the latency threshold, e.g. `"p(95)"`.
    pub threshold_percentile: String,
    /// Latency threshold, in milliseconds, applied at `threshold_percentile`.
    pub threshold_ms: u64,
    /// Maximum acceptable error rate (presumably a fraction, e.g. `0.05` for 5% — confirm
    /// against the k6 script generator).
    pub max_error_rate: f64,
    /// Verbose flag forwarded to the k6 executor.
    pub verbose: bool,
    /// Forwarded to the k6 configuration as `skip_tls_verify`.
    pub skip_tls_verify: bool,
    /// Optional file containing multiple targets
    pub targets_file: Option<PathBuf>,
    /// Maximum number of parallel executions (for multi-target mode).
    ///
    /// Defaults to 10 and is capped at the number of targets.
    pub max_concurrency: Option<u32>,
    /// Results format: "per-target", "aggregated", or "both"
    pub results_format: String,
    /// Optional file containing parameter value overrides (JSON or YAML)
    ///
    /// Allows users to provide custom values for path parameters, query parameters,
    /// headers, and request bodies instead of auto-generated placeholder values.
    pub params_file: Option<PathBuf>,
}
51
52impl BenchCommand {
53    /// Execute the bench command
54    pub async fn execute(&self) -> Result<()> {
55        // Check if we're in multi-target mode
56        if let Some(targets_file) = &self.targets_file {
57            return self.execute_multi_target(targets_file).await;
58        }
59
60        // Single target mode (existing behavior)
61        // Print header
62        TerminalReporter::print_header(
63            self.spec.to_str().unwrap(),
64            &self.target,
65            0, // Will be updated later
66            &self.scenario,
67            Self::parse_duration(&self.duration)?,
68        );
69
70        // Validate k6 installation
71        if !K6Executor::is_k6_installed() {
72            TerminalReporter::print_error("k6 is not installed");
73            TerminalReporter::print_warning(
74                "Install k6 from: https://k6.io/docs/get-started/installation/",
75            );
76            return Err(BenchError::K6NotFound);
77        }
78
79        // Load and parse spec
80        TerminalReporter::print_progress("Loading OpenAPI specification...");
81        let parser = SpecParser::from_file(&self.spec).await?;
82        TerminalReporter::print_success("Specification loaded");
83
84        // Get operations
85        TerminalReporter::print_progress("Extracting API operations...");
86        let mut operations = if let Some(filter) = &self.operations {
87            parser.filter_operations(filter)?
88        } else {
89            parser.get_operations()
90        };
91
92        // Apply exclusions if provided
93        if let Some(exclude) = &self.exclude_operations {
94            let before_count = operations.len();
95            operations = parser.exclude_operations(operations, exclude)?;
96            let excluded_count = before_count - operations.len();
97            if excluded_count > 0 {
98                TerminalReporter::print_progress(&format!(
99                    "Excluded {} operations matching '{}'",
100                    excluded_count, exclude
101                ));
102            }
103        }
104
105        if operations.is_empty() {
106            return Err(BenchError::Other("No operations found in spec".to_string()));
107        }
108
109        TerminalReporter::print_success(&format!("Found {} operations", operations.len()));
110
111        // Load parameter overrides if provided
112        let param_overrides = if let Some(params_file) = &self.params_file {
113            TerminalReporter::print_progress("Loading parameter overrides...");
114            let overrides = ParameterOverrides::from_file(params_file)?;
115            TerminalReporter::print_success(&format!(
116                "Loaded parameter overrides ({} operation-specific, {} defaults)",
117                overrides.operations.len(),
118                if overrides.defaults.is_empty() { 0 } else { 1 }
119            ));
120            Some(overrides)
121        } else {
122            None
123        };
124
125        // Generate request templates
126        TerminalReporter::print_progress("Generating request templates...");
127        let templates: Vec<_> = operations
128            .iter()
129            .map(|op| {
130                let op_overrides = param_overrides.as_ref().map(|po| {
131                    po.get_for_operation(op.operation_id.as_deref(), &op.method, &op.path)
132                });
133                RequestGenerator::generate_template_with_overrides(op, op_overrides.as_ref())
134            })
135            .collect::<Result<Vec<_>>>()?;
136        TerminalReporter::print_success("Request templates generated");
137
138        // Parse headers
139        let custom_headers = self.parse_headers()?;
140
141        // Generate k6 script
142        TerminalReporter::print_progress("Generating k6 load test script...");
143        let scenario =
144            LoadScenario::from_str(&self.scenario).map_err(BenchError::InvalidScenario)?;
145
146        let k6_config = K6Config {
147            target_url: self.target.clone(),
148            scenario,
149            duration_secs: Self::parse_duration(&self.duration)?,
150            max_vus: self.vus,
151            threshold_percentile: self.threshold_percentile.clone(),
152            threshold_ms: self.threshold_ms,
153            max_error_rate: self.max_error_rate,
154            auth_header: self.auth.clone(),
155            custom_headers,
156            skip_tls_verify: self.skip_tls_verify,
157        };
158
159        let generator = K6ScriptGenerator::new(k6_config, templates);
160        let script = generator.generate()?;
161        TerminalReporter::print_success("k6 script generated");
162
163        // Validate the generated script
164        TerminalReporter::print_progress("Validating k6 script...");
165        let validation_errors = K6ScriptGenerator::validate_script(&script);
166        if !validation_errors.is_empty() {
167            TerminalReporter::print_error("Script validation failed");
168            for error in &validation_errors {
169                eprintln!("  {}", error);
170            }
171            return Err(BenchError::Other(format!(
172                "Generated k6 script has {} validation error(s). Please check the output above.",
173                validation_errors.len()
174            )));
175        }
176        TerminalReporter::print_success("Script validation passed");
177
178        // Write script to file
179        let script_path = if let Some(output) = &self.script_output {
180            output.clone()
181        } else {
182            self.output.join("k6-script.js")
183        };
184
185        std::fs::create_dir_all(script_path.parent().unwrap())?;
186        std::fs::write(&script_path, script)?;
187        TerminalReporter::print_success(&format!("Script written to: {}", script_path.display()));
188
189        // If generate-only mode, exit here
190        if self.generate_only {
191            println!("\nScript generated successfully. Run it with:");
192            println!("  k6 run {}", script_path.display());
193            return Ok(());
194        }
195
196        // Execute k6
197        TerminalReporter::print_progress("Executing load test...");
198        let executor = K6Executor::new()?;
199
200        std::fs::create_dir_all(&self.output)?;
201
202        let results = executor.execute(&script_path, Some(&self.output), self.verbose).await?;
203
204        // Print results
205        let duration_secs = Self::parse_duration(&self.duration)?;
206        TerminalReporter::print_summary(&results, duration_secs);
207
208        println!("\nResults saved to: {}", self.output.display());
209
210        Ok(())
211    }
212
213    /// Execute multi-target bench testing
214    async fn execute_multi_target(&self, targets_file: &PathBuf) -> Result<()> {
215        TerminalReporter::print_progress("Parsing targets file...");
216        let targets = parse_targets_file(targets_file)?;
217        let num_targets = targets.len();
218        TerminalReporter::print_success(&format!("Loaded {} targets", num_targets));
219
220        if targets.is_empty() {
221            return Err(BenchError::Other("No targets found in file".to_string()));
222        }
223
224        // Determine max concurrency
225        let max_concurrency = self.max_concurrency.unwrap_or(10) as usize;
226        let max_concurrency = max_concurrency.min(num_targets); // Don't exceed number of targets
227
228        // Print header for multi-target mode
229        TerminalReporter::print_header(
230            self.spec.to_str().unwrap(),
231            &format!("{} targets", num_targets),
232            0,
233            &self.scenario,
234            Self::parse_duration(&self.duration)?,
235        );
236
237        // Create parallel executor
238        let executor = ParallelExecutor::new(
239            BenchCommand {
240                // Clone all fields except targets_file (we don't need it in the executor)
241                spec: self.spec.clone(),
242                target: self.target.clone(), // Not used in multi-target mode, but kept for compatibility
243                duration: self.duration.clone(),
244                vus: self.vus,
245                scenario: self.scenario.clone(),
246                operations: self.operations.clone(),
247                exclude_operations: self.exclude_operations.clone(),
248                auth: self.auth.clone(),
249                headers: self.headers.clone(),
250                output: self.output.clone(),
251                generate_only: self.generate_only,
252                script_output: self.script_output.clone(),
253                threshold_percentile: self.threshold_percentile.clone(),
254                threshold_ms: self.threshold_ms,
255                max_error_rate: self.max_error_rate,
256                verbose: self.verbose,
257                skip_tls_verify: self.skip_tls_verify,
258                targets_file: None,
259                max_concurrency: None,
260                results_format: self.results_format.clone(),
261                params_file: self.params_file.clone(),
262            },
263            targets,
264            max_concurrency,
265        );
266
267        // Execute all targets
268        let aggregated_results = executor.execute_all().await?;
269
270        // Organize and report results
271        self.report_multi_target_results(&aggregated_results)?;
272
273        Ok(())
274    }
275
276    /// Report results for multi-target execution
277    fn report_multi_target_results(&self, results: &AggregatedResults) -> Result<()> {
278        // Print summary
279        TerminalReporter::print_multi_target_summary(results);
280
281        // Save aggregated summary if requested
282        if self.results_format == "aggregated" || self.results_format == "both" {
283            let summary_path = self.output.join("aggregated_summary.json");
284            let summary_json = serde_json::json!({
285                "total_targets": results.total_targets,
286                "successful_targets": results.successful_targets,
287                "failed_targets": results.failed_targets,
288                "aggregated_metrics": {
289                    "total_requests": results.aggregated_metrics.total_requests,
290                    "total_failed_requests": results.aggregated_metrics.total_failed_requests,
291                    "avg_duration_ms": results.aggregated_metrics.avg_duration_ms,
292                    "p95_duration_ms": results.aggregated_metrics.p95_duration_ms,
293                    "p99_duration_ms": results.aggregated_metrics.p99_duration_ms,
294                    "error_rate": results.aggregated_metrics.error_rate,
295                },
296                "target_results": results.target_results.iter().map(|r| {
297                    serde_json::json!({
298                        "target_url": r.target_url,
299                        "target_index": r.target_index,
300                        "success": r.success,
301                        "error": r.error,
302                        "total_requests": r.results.total_requests,
303                        "failed_requests": r.results.failed_requests,
304                        "avg_duration_ms": r.results.avg_duration_ms,
305                        "p95_duration_ms": r.results.p95_duration_ms,
306                        "p99_duration_ms": r.results.p99_duration_ms,
307                        "output_dir": r.output_dir.to_string_lossy(),
308                    })
309                }).collect::<Vec<_>>(),
310            });
311
312            std::fs::write(&summary_path, serde_json::to_string_pretty(&summary_json)?)?;
313            TerminalReporter::print_success(&format!(
314                "Aggregated summary saved to: {}",
315                summary_path.display()
316            ));
317        }
318
319        println!("\nResults saved to: {}", self.output.display());
320        println!("  - Per-target results: {}", self.output.join("target_*").display());
321        if self.results_format == "aggregated" || self.results_format == "both" {
322            println!(
323                "  - Aggregated summary: {}",
324                self.output.join("aggregated_summary.json").display()
325            );
326        }
327
328        Ok(())
329    }
330
331    /// Parse duration string (e.g., "30s", "5m", "1h") to seconds
332    pub fn parse_duration(duration: &str) -> Result<u64> {
333        let duration = duration.trim();
334
335        if let Some(secs) = duration.strip_suffix('s') {
336            secs.parse::<u64>()
337                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
338        } else if let Some(mins) = duration.strip_suffix('m') {
339            mins.parse::<u64>()
340                .map(|m| m * 60)
341                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
342        } else if let Some(hours) = duration.strip_suffix('h') {
343            hours
344                .parse::<u64>()
345                .map(|h| h * 3600)
346                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
347        } else {
348            // Try parsing as seconds without suffix
349            duration
350                .parse::<u64>()
351                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
352        }
353    }
354
355    /// Parse headers from command line format (Key:Value,Key2:Value2)
356    pub fn parse_headers(&self) -> Result<HashMap<String, String>> {
357        let mut headers = HashMap::new();
358
359        if let Some(header_str) = &self.headers {
360            for pair in header_str.split(',') {
361                let parts: Vec<&str> = pair.splitn(2, ':').collect();
362                if parts.len() != 2 {
363                    return Err(BenchError::Other(format!(
364                        "Invalid header format: '{}'. Expected 'Key:Value'",
365                        pair
366                    )));
367                }
368                headers.insert(parts[0].trim().to_string(), parts[1].trim().to_string());
369            }
370        }
371
372        Ok(headers)
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-target command with fixed test settings; only the `headers`
    /// value varies between the tests below.
    fn command_with_headers(headers: Option<&str>) -> BenchCommand {
        BenchCommand {
            spec: PathBuf::from("test.yaml"),
            target: "http://localhost".to_string(),
            duration: "1m".to_string(),
            vus: 10,
            scenario: "ramp-up".to_string(),
            operations: None,
            exclude_operations: None,
            auth: None,
            headers: headers.map(str::to_string),
            output: PathBuf::from("output"),
            generate_only: false,
            script_output: None,
            threshold_percentile: "p(95)".to_string(),
            threshold_ms: 500,
            max_error_rate: 0.05,
            verbose: false,
            skip_tls_verify: false,
            targets_file: None,
            max_concurrency: None,
            results_format: "both".to_string(),
            params_file: None,
        }
    }

    #[test]
    fn test_parse_duration() {
        assert_eq!(BenchCommand::parse_duration("30s").unwrap(), 30);
        assert_eq!(BenchCommand::parse_duration("5m").unwrap(), 300);
        assert_eq!(BenchCommand::parse_duration("1h").unwrap(), 3600);
        assert_eq!(BenchCommand::parse_duration("60").unwrap(), 60);
        assert_eq!(BenchCommand::parse_duration("0s").unwrap(), 0);
        // Surrounding whitespace is ignored.
        assert_eq!(BenchCommand::parse_duration("  2m ").unwrap(), 120);
    }

    #[test]
    fn test_parse_duration_invalid() {
        for input in ["invalid", "30x", "", "s", "-5s", "1.5m"].iter() {
            assert!(
                BenchCommand::parse_duration(input).is_err(),
                "expected an error for {:?}",
                input
            );
        }
    }

    #[test]
    fn test_parse_headers() {
        let cmd = command_with_headers(Some("X-API-Key:test123,X-Client-ID:client456"));

        let headers = cmd.parse_headers().unwrap();
        assert_eq!(headers.get("X-API-Key"), Some(&"test123".to_string()));
        assert_eq!(headers.get("X-Client-ID"), Some(&"client456".to_string()));
    }

    #[test]
    fn test_parse_headers_trims_and_keeps_colons_in_values() {
        let cmd = command_with_headers(Some(" X-Trace : abc , Authorization:Bearer a:b"));

        let headers = cmd.parse_headers().unwrap();
        assert_eq!(headers.len(), 2);
        assert_eq!(headers.get("X-Trace"), Some(&"abc".to_string()));
        assert_eq!(headers.get("Authorization"), Some(&"Bearer a:b".to_string()));
    }

    #[test]
    fn test_parse_headers_none_is_empty() {
        assert!(command_with_headers(None).parse_headers().unwrap().is_empty());
    }

    #[test]
    fn test_parse_headers_rejects_missing_colon() {
        assert!(command_with_headers(Some("X-A:1,NoColon")).parse_headers().is_err());
    }
}