1use crate::error::{BenchError, Result};
4use crate::executor::K6Executor;
5use crate::k6_gen::{K6Config, K6ScriptGenerator};
6use crate::parallel_executor::{AggregatedResults, ParallelExecutor};
7use crate::param_overrides::ParameterOverrides;
8use crate::reporter::TerminalReporter;
9use crate::request_gen::RequestGenerator;
10use crate::scenarios::LoadScenario;
11use crate::spec_parser::SpecParser;
12use crate::target_parser::parse_targets_file;
13use std::collections::HashMap;
14use std::path::PathBuf;
15use std::str::FromStr;
16
/// Settings for a single `bench` invocation: what to test, how hard, and where
/// to put the generated k6 script and results.
///
/// `Clone` is derived so the multi-target runner can reuse these settings per
/// target without copying every field by hand.
// NOTE(review): `Debug` is intentionally not derived because `auth` may hold a
// credential; add a redacting `Debug` impl if one is needed.
#[derive(Clone)]
pub struct BenchCommand {
    /// Path to the OpenAPI specification to load.
    pub spec: PathBuf,
    /// Base URL of the service under test (single-target mode).
    pub target: String,
    /// Test duration: `"<n>s"`, `"<n>m"`, `"<n>h"`, or a bare number of seconds.
    pub duration: String,
    /// Maximum number of k6 virtual users.
    pub vus: u32,
    /// Load scenario name, parsed with `LoadScenario::from_str`.
    pub scenario: String,
    /// Optional filter selecting which operations to include.
    pub operations: Option<String>,
    /// Optional pattern of operations to exclude after filtering.
    pub exclude_operations: Option<String>,
    /// Optional value used as the auth header in generated requests.
    pub auth: Option<String>,
    /// Extra headers as comma-separated `Key:Value` pairs.
    pub headers: Option<String>,
    /// Directory for results (and the default script location).
    pub output: PathBuf,
    /// When true, only write the k6 script; do not run it.
    pub generate_only: bool,
    /// Explicit script path; defaults to `<output>/k6-script.js`.
    pub script_output: Option<PathBuf>,
    /// Latency percentile used for the threshold (e.g. `"p(95)"`).
    pub threshold_percentile: String,
    /// Latency threshold in milliseconds for `threshold_percentile`.
    pub threshold_ms: u64,
    /// Maximum tolerated error rate (e.g. `0.05`).
    pub max_error_rate: f64,
    /// Forwarded to the k6 executor to control output verbosity.
    pub verbose: bool,
    /// Disable TLS certificate verification in the generated script.
    pub skip_tls_verify: bool,
    /// When set, run against every target listed in this file instead of `target`.
    pub targets_file: Option<PathBuf>,
    /// Maximum number of targets benchmarked in parallel (multi-target mode).
    pub max_concurrency: Option<u32>,
    /// Multi-target result format; `"aggregated"` or `"both"` also writes a JSON summary.
    pub results_format: String,
    /// Optional file with per-operation parameter overrides.
    pub params_file: Option<PathBuf>,
}
51
52impl BenchCommand {
53 pub async fn execute(&self) -> Result<()> {
55 if let Some(targets_file) = &self.targets_file {
57 return self.execute_multi_target(targets_file).await;
58 }
59
60 TerminalReporter::print_header(
63 self.spec.to_str().unwrap(),
64 &self.target,
65 0, &self.scenario,
67 Self::parse_duration(&self.duration)?,
68 );
69
70 if !K6Executor::is_k6_installed() {
72 TerminalReporter::print_error("k6 is not installed");
73 TerminalReporter::print_warning(
74 "Install k6 from: https://k6.io/docs/get-started/installation/",
75 );
76 return Err(BenchError::K6NotFound);
77 }
78
79 TerminalReporter::print_progress("Loading OpenAPI specification...");
81 let parser = SpecParser::from_file(&self.spec).await?;
82 TerminalReporter::print_success("Specification loaded");
83
84 TerminalReporter::print_progress("Extracting API operations...");
86 let mut operations = if let Some(filter) = &self.operations {
87 parser.filter_operations(filter)?
88 } else {
89 parser.get_operations()
90 };
91
92 if let Some(exclude) = &self.exclude_operations {
94 let before_count = operations.len();
95 operations = parser.exclude_operations(operations, exclude)?;
96 let excluded_count = before_count - operations.len();
97 if excluded_count > 0 {
98 TerminalReporter::print_progress(&format!(
99 "Excluded {} operations matching '{}'",
100 excluded_count, exclude
101 ));
102 }
103 }
104
105 if operations.is_empty() {
106 return Err(BenchError::Other("No operations found in spec".to_string()));
107 }
108
109 TerminalReporter::print_success(&format!("Found {} operations", operations.len()));
110
111 let param_overrides = if let Some(params_file) = &self.params_file {
113 TerminalReporter::print_progress("Loading parameter overrides...");
114 let overrides = ParameterOverrides::from_file(params_file)?;
115 TerminalReporter::print_success(&format!(
116 "Loaded parameter overrides ({} operation-specific, {} defaults)",
117 overrides.operations.len(),
118 if overrides.defaults.is_empty() { 0 } else { 1 }
119 ));
120 Some(overrides)
121 } else {
122 None
123 };
124
125 TerminalReporter::print_progress("Generating request templates...");
127 let templates: Vec<_> = operations
128 .iter()
129 .map(|op| {
130 let op_overrides = param_overrides.as_ref().map(|po| {
131 po.get_for_operation(op.operation_id.as_deref(), &op.method, &op.path)
132 });
133 RequestGenerator::generate_template_with_overrides(op, op_overrides.as_ref())
134 })
135 .collect::<Result<Vec<_>>>()?;
136 TerminalReporter::print_success("Request templates generated");
137
138 let custom_headers = self.parse_headers()?;
140
141 TerminalReporter::print_progress("Generating k6 load test script...");
143 let scenario =
144 LoadScenario::from_str(&self.scenario).map_err(BenchError::InvalidScenario)?;
145
146 let k6_config = K6Config {
147 target_url: self.target.clone(),
148 scenario,
149 duration_secs: Self::parse_duration(&self.duration)?,
150 max_vus: self.vus,
151 threshold_percentile: self.threshold_percentile.clone(),
152 threshold_ms: self.threshold_ms,
153 max_error_rate: self.max_error_rate,
154 auth_header: self.auth.clone(),
155 custom_headers,
156 skip_tls_verify: self.skip_tls_verify,
157 };
158
159 let generator = K6ScriptGenerator::new(k6_config, templates);
160 let script = generator.generate()?;
161 TerminalReporter::print_success("k6 script generated");
162
163 TerminalReporter::print_progress("Validating k6 script...");
165 let validation_errors = K6ScriptGenerator::validate_script(&script);
166 if !validation_errors.is_empty() {
167 TerminalReporter::print_error("Script validation failed");
168 for error in &validation_errors {
169 eprintln!(" {}", error);
170 }
171 return Err(BenchError::Other(format!(
172 "Generated k6 script has {} validation error(s). Please check the output above.",
173 validation_errors.len()
174 )));
175 }
176 TerminalReporter::print_success("Script validation passed");
177
178 let script_path = if let Some(output) = &self.script_output {
180 output.clone()
181 } else {
182 self.output.join("k6-script.js")
183 };
184
185 std::fs::create_dir_all(script_path.parent().unwrap())?;
186 std::fs::write(&script_path, script)?;
187 TerminalReporter::print_success(&format!("Script written to: {}", script_path.display()));
188
189 if self.generate_only {
191 println!("\nScript generated successfully. Run it with:");
192 println!(" k6 run {}", script_path.display());
193 return Ok(());
194 }
195
196 TerminalReporter::print_progress("Executing load test...");
198 let executor = K6Executor::new()?;
199
200 std::fs::create_dir_all(&self.output)?;
201
202 let results = executor.execute(&script_path, Some(&self.output), self.verbose).await?;
203
204 let duration_secs = Self::parse_duration(&self.duration)?;
206 TerminalReporter::print_summary(&results, duration_secs);
207
208 println!("\nResults saved to: {}", self.output.display());
209
210 Ok(())
211 }
212
213 async fn execute_multi_target(&self, targets_file: &PathBuf) -> Result<()> {
215 TerminalReporter::print_progress("Parsing targets file...");
216 let targets = parse_targets_file(targets_file)?;
217 let num_targets = targets.len();
218 TerminalReporter::print_success(&format!("Loaded {} targets", num_targets));
219
220 if targets.is_empty() {
221 return Err(BenchError::Other("No targets found in file".to_string()));
222 }
223
224 let max_concurrency = self.max_concurrency.unwrap_or(10) as usize;
226 let max_concurrency = max_concurrency.min(num_targets); TerminalReporter::print_header(
230 self.spec.to_str().unwrap(),
231 &format!("{} targets", num_targets),
232 0,
233 &self.scenario,
234 Self::parse_duration(&self.duration)?,
235 );
236
237 let executor = ParallelExecutor::new(
239 BenchCommand {
240 spec: self.spec.clone(),
242 target: self.target.clone(), duration: self.duration.clone(),
244 vus: self.vus,
245 scenario: self.scenario.clone(),
246 operations: self.operations.clone(),
247 exclude_operations: self.exclude_operations.clone(),
248 auth: self.auth.clone(),
249 headers: self.headers.clone(),
250 output: self.output.clone(),
251 generate_only: self.generate_only,
252 script_output: self.script_output.clone(),
253 threshold_percentile: self.threshold_percentile.clone(),
254 threshold_ms: self.threshold_ms,
255 max_error_rate: self.max_error_rate,
256 verbose: self.verbose,
257 skip_tls_verify: self.skip_tls_verify,
258 targets_file: None,
259 max_concurrency: None,
260 results_format: self.results_format.clone(),
261 params_file: self.params_file.clone(),
262 },
263 targets,
264 max_concurrency,
265 );
266
267 let aggregated_results = executor.execute_all().await?;
269
270 self.report_multi_target_results(&aggregated_results)?;
272
273 Ok(())
274 }
275
276 fn report_multi_target_results(&self, results: &AggregatedResults) -> Result<()> {
278 TerminalReporter::print_multi_target_summary(results);
280
281 if self.results_format == "aggregated" || self.results_format == "both" {
283 let summary_path = self.output.join("aggregated_summary.json");
284 let summary_json = serde_json::json!({
285 "total_targets": results.total_targets,
286 "successful_targets": results.successful_targets,
287 "failed_targets": results.failed_targets,
288 "aggregated_metrics": {
289 "total_requests": results.aggregated_metrics.total_requests,
290 "total_failed_requests": results.aggregated_metrics.total_failed_requests,
291 "avg_duration_ms": results.aggregated_metrics.avg_duration_ms,
292 "p95_duration_ms": results.aggregated_metrics.p95_duration_ms,
293 "p99_duration_ms": results.aggregated_metrics.p99_duration_ms,
294 "error_rate": results.aggregated_metrics.error_rate,
295 },
296 "target_results": results.target_results.iter().map(|r| {
297 serde_json::json!({
298 "target_url": r.target_url,
299 "target_index": r.target_index,
300 "success": r.success,
301 "error": r.error,
302 "total_requests": r.results.total_requests,
303 "failed_requests": r.results.failed_requests,
304 "avg_duration_ms": r.results.avg_duration_ms,
305 "p95_duration_ms": r.results.p95_duration_ms,
306 "p99_duration_ms": r.results.p99_duration_ms,
307 "output_dir": r.output_dir.to_string_lossy(),
308 })
309 }).collect::<Vec<_>>(),
310 });
311
312 std::fs::write(&summary_path, serde_json::to_string_pretty(&summary_json)?)?;
313 TerminalReporter::print_success(&format!(
314 "Aggregated summary saved to: {}",
315 summary_path.display()
316 ));
317 }
318
319 println!("\nResults saved to: {}", self.output.display());
320 println!(" - Per-target results: {}", self.output.join("target_*").display());
321 if self.results_format == "aggregated" || self.results_format == "both" {
322 println!(
323 " - Aggregated summary: {}",
324 self.output.join("aggregated_summary.json").display()
325 );
326 }
327
328 Ok(())
329 }
330
331 pub fn parse_duration(duration: &str) -> Result<u64> {
333 let duration = duration.trim();
334
335 if let Some(secs) = duration.strip_suffix('s') {
336 secs.parse::<u64>()
337 .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
338 } else if let Some(mins) = duration.strip_suffix('m') {
339 mins.parse::<u64>()
340 .map(|m| m * 60)
341 .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
342 } else if let Some(hours) = duration.strip_suffix('h') {
343 hours
344 .parse::<u64>()
345 .map(|h| h * 3600)
346 .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
347 } else {
348 duration
350 .parse::<u64>()
351 .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
352 }
353 }
354
355 pub fn parse_headers(&self) -> Result<HashMap<String, String>> {
357 let mut headers = HashMap::new();
358
359 if let Some(header_str) = &self.headers {
360 for pair in header_str.split(',') {
361 let parts: Vec<&str> = pair.splitn(2, ':').collect();
362 if parts.len() != 2 {
363 return Err(BenchError::Other(format!(
364 "Invalid header format: '{}'. Expected 'Key:Value'",
365 pair
366 )));
367 }
368 headers.insert(parts[0].trim().to_string(), parts[1].trim().to_string());
369 }
370 }
371
372 Ok(headers)
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a command with fixed test settings and the given `--headers` value.
    fn command_with_headers(headers: Option<&str>) -> BenchCommand {
        BenchCommand {
            spec: PathBuf::from("test.yaml"),
            target: "http://localhost".to_string(),
            duration: "1m".to_string(),
            vus: 10,
            scenario: "ramp-up".to_string(),
            operations: None,
            exclude_operations: None,
            auth: None,
            headers: headers.map(|h| h.to_string()),
            output: PathBuf::from("output"),
            generate_only: false,
            script_output: None,
            threshold_percentile: "p(95)".to_string(),
            threshold_ms: 500,
            max_error_rate: 0.05,
            verbose: false,
            skip_tls_verify: false,
            targets_file: None,
            max_concurrency: None,
            results_format: "both".to_string(),
            params_file: None,
        }
    }

    #[test]
    fn test_parse_duration() {
        assert_eq!(BenchCommand::parse_duration("30s").unwrap(), 30);
        assert_eq!(BenchCommand::parse_duration("5m").unwrap(), 300);
        assert_eq!(BenchCommand::parse_duration("1h").unwrap(), 3600);
        assert_eq!(BenchCommand::parse_duration("60").unwrap(), 60);
        assert_eq!(BenchCommand::parse_duration("0").unwrap(), 0);
    }

    #[test]
    fn test_parse_duration_trims_whitespace() {
        assert_eq!(BenchCommand::parse_duration("  45s \n").unwrap(), 45);
    }

    #[test]
    fn test_parse_duration_invalid() {
        assert!(BenchCommand::parse_duration("invalid").is_err());
        assert!(BenchCommand::parse_duration("30x").is_err());
        assert!(BenchCommand::parse_duration("").is_err());
        assert!(BenchCommand::parse_duration("s").is_err());
        assert!(BenchCommand::parse_duration("-5s").is_err());
    }

    #[test]
    fn test_parse_headers() {
        let cmd = command_with_headers(Some("X-API-Key:test123,X-Client-ID:client456"));

        let headers = cmd.parse_headers().unwrap();
        assert_eq!(headers.get("X-API-Key"), Some(&"test123".to_string()));
        assert_eq!(headers.get("X-Client-ID"), Some(&"client456".to_string()));
    }

    #[test]
    fn test_parse_headers_value_may_contain_colons() {
        let cmd = command_with_headers(Some(" Authorization : Bearer a:b "));

        let headers = cmd.parse_headers().unwrap();
        assert_eq!(headers.get("Authorization"), Some(&"Bearer a:b".to_string()));
    }

    #[test]
    fn test_parse_headers_rejects_missing_colon() {
        assert!(command_with_headers(Some("NoColonHere")).parse_headers().is_err());
    }

    #[test]
    fn test_parse_headers_none_is_empty() {
        assert!(command_with_headers(None).parse_headers().unwrap().is_empty());
    }
}