mockforge_bench/
command.rs

1//! Bench command implementation
2
3use crate::error::{BenchError, Result};
4use crate::executor::K6Executor;
5use crate::k6_gen::{K6Config, K6ScriptGenerator};
6use crate::parallel_executor::{AggregatedResults, ParallelExecutor};
7use crate::reporter::TerminalReporter;
8use crate::request_gen::RequestGenerator;
9use crate::scenarios::LoadScenario;
10use crate::spec_parser::SpecParser;
11use crate::target_parser::parse_targets_file;
12use std::collections::HashMap;
13use std::path::PathBuf;
14use std::str::FromStr;
15
/// Configuration for the `bench` command.
///
/// One value describes either a single-target run against `target` or, when
/// `targets_file` is set, a multi-target run over every target listed in that file.
pub struct BenchCommand {
    /// Path to the OpenAPI specification whose operations are benchmarked.
    pub spec: PathBuf,
    /// Base URL of the system under test (used as the k6 `target_url` in single-target mode).
    pub target: String,
    /// Test duration such as `"30s"`, `"5m"`, `"1h"`, or a bare number of seconds.
    pub duration: String,
    /// Maximum number of virtual users handed to k6 (`max_vus`).
    pub vus: u32,
    /// Load scenario name, parsed into a `LoadScenario`.
    pub scenario: String,
    /// Optional filter restricting which spec operations are exercised.
    pub operations: Option<String>,
    /// Optional auth header value forwarded to the generated script (`auth_header`).
    pub auth: Option<String>,
    /// Optional extra headers in `Key:Value,Key2:Value2` form.
    pub headers: Option<String>,
    /// Output directory for results (and, by default, for `k6-script.js`).
    pub output: PathBuf,
    /// When `true`, only write the k6 script instead of running it.
    pub generate_only: bool,
    /// Explicit location for the generated script, overriding `<output>/k6-script.js`.
    pub script_output: Option<PathBuf>,
    /// Latency percentile used for the threshold, e.g. `"p(95)"`.
    pub threshold_percentile: String,
    /// Latency threshold, in milliseconds, for `threshold_percentile`.
    pub threshold_ms: u64,
    /// Maximum tolerated error rate (presumably a fraction such as `0.05` — confirm in k6_gen).
    pub max_error_rate: f64,
    /// Forwarded to the k6 executor as its verbose flag.
    pub verbose: bool,
    /// Forwarded to the generated script configuration (`skip_tls_verify`).
    pub skip_tls_verify: bool,
    /// Optional file containing multiple targets; enables multi-target mode.
    pub targets_file: Option<PathBuf>,
    /// Upper bound on parallel executions in multi-target mode.
    pub max_concurrency: Option<u32>,
    /// Results format: `"per-target"`, `"aggregated"`, or `"both"`.
    pub results_format: String,
}
41
42impl BenchCommand {
43    /// Execute the bench command
44    pub async fn execute(&self) -> Result<()> {
45        // Check if we're in multi-target mode
46        if let Some(targets_file) = &self.targets_file {
47            return self.execute_multi_target(targets_file).await;
48        }
49
50        // Single target mode (existing behavior)
51        // Print header
52        TerminalReporter::print_header(
53            self.spec.to_str().unwrap(),
54            &self.target,
55            0, // Will be updated later
56            &self.scenario,
57            Self::parse_duration(&self.duration)?,
58        );
59
60        // Validate k6 installation
61        if !K6Executor::is_k6_installed() {
62            TerminalReporter::print_error("k6 is not installed");
63            TerminalReporter::print_warning(
64                "Install k6 from: https://k6.io/docs/get-started/installation/",
65            );
66            return Err(BenchError::K6NotFound);
67        }
68
69        // Load and parse spec
70        TerminalReporter::print_progress("Loading OpenAPI specification...");
71        let parser = SpecParser::from_file(&self.spec).await?;
72        TerminalReporter::print_success("Specification loaded");
73
74        // Get operations
75        TerminalReporter::print_progress("Extracting API operations...");
76        let operations = if let Some(filter) = &self.operations {
77            parser.filter_operations(filter)?
78        } else {
79            parser.get_operations()
80        };
81
82        if operations.is_empty() {
83            return Err(BenchError::Other("No operations found in spec".to_string()));
84        }
85
86        TerminalReporter::print_success(&format!("Found {} operations", operations.len()));
87
88        // Generate request templates
89        TerminalReporter::print_progress("Generating request templates...");
90        let templates: Vec<_> = operations
91            .iter()
92            .map(RequestGenerator::generate_template)
93            .collect::<Result<Vec<_>>>()?;
94        TerminalReporter::print_success("Request templates generated");
95
96        // Parse headers
97        let custom_headers = self.parse_headers()?;
98
99        // Generate k6 script
100        TerminalReporter::print_progress("Generating k6 load test script...");
101        let scenario =
102            LoadScenario::from_str(&self.scenario).map_err(BenchError::InvalidScenario)?;
103
104        let k6_config = K6Config {
105            target_url: self.target.clone(),
106            scenario,
107            duration_secs: Self::parse_duration(&self.duration)?,
108            max_vus: self.vus,
109            threshold_percentile: self.threshold_percentile.clone(),
110            threshold_ms: self.threshold_ms,
111            max_error_rate: self.max_error_rate,
112            auth_header: self.auth.clone(),
113            custom_headers,
114            skip_tls_verify: self.skip_tls_verify,
115        };
116
117        let generator = K6ScriptGenerator::new(k6_config, templates);
118        let script = generator.generate()?;
119        TerminalReporter::print_success("k6 script generated");
120
121        // Validate the generated script
122        TerminalReporter::print_progress("Validating k6 script...");
123        let validation_errors = K6ScriptGenerator::validate_script(&script);
124        if !validation_errors.is_empty() {
125            TerminalReporter::print_error("Script validation failed");
126            for error in &validation_errors {
127                eprintln!("  {}", error);
128            }
129            return Err(BenchError::Other(format!(
130                "Generated k6 script has {} validation error(s). Please check the output above.",
131                validation_errors.len()
132            )));
133        }
134        TerminalReporter::print_success("Script validation passed");
135
136        // Write script to file
137        let script_path = if let Some(output) = &self.script_output {
138            output.clone()
139        } else {
140            self.output.join("k6-script.js")
141        };
142
143        std::fs::create_dir_all(script_path.parent().unwrap())?;
144        std::fs::write(&script_path, script)?;
145        TerminalReporter::print_success(&format!("Script written to: {}", script_path.display()));
146
147        // If generate-only mode, exit here
148        if self.generate_only {
149            println!("\nScript generated successfully. Run it with:");
150            println!("  k6 run {}", script_path.display());
151            return Ok(());
152        }
153
154        // Execute k6
155        TerminalReporter::print_progress("Executing load test...");
156        let executor = K6Executor::new()?;
157
158        std::fs::create_dir_all(&self.output)?;
159
160        let results = executor.execute(&script_path, Some(&self.output), self.verbose).await?;
161
162        // Print results
163        let duration_secs = Self::parse_duration(&self.duration)?;
164        TerminalReporter::print_summary(&results, duration_secs);
165
166        println!("\nResults saved to: {}", self.output.display());
167
168        Ok(())
169    }
170
171    /// Execute multi-target bench testing
172    async fn execute_multi_target(&self, targets_file: &PathBuf) -> Result<()> {
173        TerminalReporter::print_progress("Parsing targets file...");
174        let targets = parse_targets_file(targets_file)?;
175        let num_targets = targets.len();
176        TerminalReporter::print_success(&format!("Loaded {} targets", num_targets));
177
178        if targets.is_empty() {
179            return Err(BenchError::Other("No targets found in file".to_string()));
180        }
181
182        // Determine max concurrency
183        let max_concurrency = self.max_concurrency.unwrap_or(10) as usize;
184        let max_concurrency = max_concurrency.min(num_targets); // Don't exceed number of targets
185
186        // Print header for multi-target mode
187        TerminalReporter::print_header(
188            self.spec.to_str().unwrap(),
189            &format!("{} targets", num_targets),
190            0,
191            &self.scenario,
192            Self::parse_duration(&self.duration)?,
193        );
194
195        // Create parallel executor
196        let executor = ParallelExecutor::new(
197            BenchCommand {
198                // Clone all fields except targets_file (we don't need it in the executor)
199                spec: self.spec.clone(),
200                target: self.target.clone(), // Not used in multi-target mode, but kept for compatibility
201                duration: self.duration.clone(),
202                vus: self.vus,
203                scenario: self.scenario.clone(),
204                operations: self.operations.clone(),
205                auth: self.auth.clone(),
206                headers: self.headers.clone(),
207                output: self.output.clone(),
208                generate_only: self.generate_only,
209                script_output: self.script_output.clone(),
210                threshold_percentile: self.threshold_percentile.clone(),
211                threshold_ms: self.threshold_ms,
212                max_error_rate: self.max_error_rate,
213                verbose: self.verbose,
214                skip_tls_verify: self.skip_tls_verify,
215                targets_file: None,
216                max_concurrency: None,
217                results_format: self.results_format.clone(),
218            },
219            targets,
220            max_concurrency,
221        );
222
223        // Execute all targets
224        let aggregated_results = executor.execute_all().await?;
225
226        // Organize and report results
227        self.report_multi_target_results(&aggregated_results)?;
228
229        Ok(())
230    }
231
232    /// Report results for multi-target execution
233    fn report_multi_target_results(&self, results: &AggregatedResults) -> Result<()> {
234        // Print summary
235        TerminalReporter::print_multi_target_summary(results);
236
237        // Save aggregated summary if requested
238        if self.results_format == "aggregated" || self.results_format == "both" {
239            let summary_path = self.output.join("aggregated_summary.json");
240            let summary_json = serde_json::json!({
241                "total_targets": results.total_targets,
242                "successful_targets": results.successful_targets,
243                "failed_targets": results.failed_targets,
244                "aggregated_metrics": {
245                    "total_requests": results.aggregated_metrics.total_requests,
246                    "total_failed_requests": results.aggregated_metrics.total_failed_requests,
247                    "avg_duration_ms": results.aggregated_metrics.avg_duration_ms,
248                    "p95_duration_ms": results.aggregated_metrics.p95_duration_ms,
249                    "p99_duration_ms": results.aggregated_metrics.p99_duration_ms,
250                    "error_rate": results.aggregated_metrics.error_rate,
251                },
252                "target_results": results.target_results.iter().map(|r| {
253                    serde_json::json!({
254                        "target_url": r.target_url,
255                        "target_index": r.target_index,
256                        "success": r.success,
257                        "error": r.error,
258                        "total_requests": r.results.total_requests,
259                        "failed_requests": r.results.failed_requests,
260                        "avg_duration_ms": r.results.avg_duration_ms,
261                        "p95_duration_ms": r.results.p95_duration_ms,
262                        "p99_duration_ms": r.results.p99_duration_ms,
263                        "output_dir": r.output_dir.to_string_lossy(),
264                    })
265                }).collect::<Vec<_>>(),
266            });
267
268            std::fs::write(&summary_path, serde_json::to_string_pretty(&summary_json)?)?;
269            TerminalReporter::print_success(&format!(
270                "Aggregated summary saved to: {}",
271                summary_path.display()
272            ));
273        }
274
275        println!("\nResults saved to: {}", self.output.display());
276        println!("  - Per-target results: {}", self.output.join("target_*").display());
277        if self.results_format == "aggregated" || self.results_format == "both" {
278            println!(
279                "  - Aggregated summary: {}",
280                self.output.join("aggregated_summary.json").display()
281            );
282        }
283
284        Ok(())
285    }
286
287    /// Parse duration string (e.g., "30s", "5m", "1h") to seconds
288    pub fn parse_duration(duration: &str) -> Result<u64> {
289        let duration = duration.trim();
290
291        if let Some(secs) = duration.strip_suffix('s') {
292            secs.parse::<u64>()
293                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
294        } else if let Some(mins) = duration.strip_suffix('m') {
295            mins.parse::<u64>()
296                .map(|m| m * 60)
297                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
298        } else if let Some(hours) = duration.strip_suffix('h') {
299            hours
300                .parse::<u64>()
301                .map(|h| h * 3600)
302                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
303        } else {
304            // Try parsing as seconds without suffix
305            duration
306                .parse::<u64>()
307                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
308        }
309    }
310
311    /// Parse headers from command line format (Key:Value,Key2:Value2)
312    pub fn parse_headers(&self) -> Result<HashMap<String, String>> {
313        let mut headers = HashMap::new();
314
315        if let Some(header_str) = &self.headers {
316            for pair in header_str.split(',') {
317                let parts: Vec<&str> = pair.splitn(2, ':').collect();
318                if parts.len() != 2 {
319                    return Err(BenchError::Other(format!(
320                        "Invalid header format: '{}'. Expected 'Key:Value'",
321                        pair
322                    )));
323                }
324                headers.insert(parts[0].trim().to_string(), parts[1].trim().to_string());
325            }
326        }
327
328        Ok(headers)
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// A single-target command carrying two custom headers; other fields are
    /// representative defaults that the header parser does not read.
    fn command_with_headers() -> BenchCommand {
        BenchCommand {
            spec: PathBuf::from("test.yaml"),
            target: "http://localhost".to_string(),
            duration: "1m".to_string(),
            vus: 10,
            scenario: "ramp-up".to_string(),
            operations: None,
            auth: None,
            headers: Some("X-API-Key:test123,X-Client-ID:client456".to_string()),
            output: PathBuf::from("output"),
            generate_only: false,
            script_output: None,
            threshold_percentile: "p(95)".to_string(),
            threshold_ms: 500,
            max_error_rate: 0.05,
            verbose: false,
            skip_tls_verify: false,
            targets_file: None,
            max_concurrency: None,
            results_format: "both".to_string(),
        }
    }

    #[test]
    fn test_parse_duration() {
        let cases: [(&str, u64); 4] = [("30s", 30), ("5m", 300), ("1h", 3600), ("60", 60)];
        for &(input, expected_secs) in &cases {
            assert_eq!(BenchCommand::parse_duration(input).unwrap(), expected_secs);
        }
    }

    #[test]
    fn test_parse_duration_invalid() {
        for input in &["invalid", "30x"] {
            assert!(BenchCommand::parse_duration(input).is_err());
        }
    }

    #[test]
    fn test_parse_headers() {
        let parsed = command_with_headers().parse_headers().unwrap();
        assert_eq!(parsed.get("X-API-Key").map(String::as_str), Some("test123"));
        assert_eq!(parsed.get("X-Client-ID").map(String::as_str), Some("client456"));
    }
}