1use crate::error::{BenchError, Result};
4use crate::executor::K6Executor;
5use crate::k6_gen::{K6Config, K6ScriptGenerator};
6use crate::parallel_executor::{AggregatedResults, ParallelExecutor};
7use crate::reporter::TerminalReporter;
8use crate::request_gen::RequestGenerator;
9use crate::scenarios::LoadScenario;
10use crate::spec_parser::SpecParser;
11use crate::target_parser::parse_targets_file;
12use std::collections::HashMap;
13use std::path::PathBuf;
14use std::str::FromStr;
15
/// Settings for one `bench` run: what to load-test, how hard, and where output goes.
///
/// When `targets_file` is set, [`BenchCommand::execute`] switches to multi-target mode
/// and benchmarks every target listed in that file instead of `target`.
pub struct BenchCommand {
    /// Path to the OpenAPI specification that requests are generated from.
    pub spec: PathBuf,
    /// Base URL of the system under test (single-target mode).
    pub target: String,
    /// Test length: `30s`, `5m`, `1h`, or a bare number of seconds.
    pub duration: String,
    /// Maximum number of k6 virtual users.
    pub vus: u32,
    /// Load scenario name, parsed into a `LoadScenario` (e.g. `ramp-up`).
    pub scenario: String,
    /// Optional filter selecting a subset of the spec's operations.
    pub operations: Option<String>,
    /// Optional value forwarded to the script config as the auth header.
    /// Treat as a secret.
    pub auth: Option<String>,
    /// Extra request headers as comma-separated `Key:Value` pairs.
    pub headers: Option<String>,
    /// Directory receiving the results and, by default, the generated script.
    pub output: PathBuf,
    /// Only generate and write the k6 script; do not run it.
    pub generate_only: bool,
    /// Explicit script location, overriding `<output>/k6-script.js`.
    pub script_output: Option<PathBuf>,
    /// Latency percentile the threshold applies to, e.g. `p(95)`.
    pub threshold_percentile: String,
    /// Latency threshold (milliseconds) for `threshold_percentile`.
    pub threshold_ms: u64,
    /// Maximum acceptable error rate. Presumably a fraction; the tests use `0.05`.
    pub max_error_rate: f64,
    /// Passed to the k6 executor to request verbose output.
    pub verbose: bool,
    /// Forwarded to the generated script config.
    /// Presumably disables TLS certificate checks.
    pub skip_tls_verify: bool,
    /// File listing several targets; when present, multi-target mode is used.
    pub targets_file: Option<PathBuf>,
    /// Multi-target mode only: how many targets run at once (default 10).
    pub max_concurrency: Option<u32>,
    /// Multi-target output format.
    /// `aggregated` or `both` additionally write `aggregated_summary.json`.
    pub results_format: String,
}
41
42impl BenchCommand {
43 pub async fn execute(&self) -> Result<()> {
45 if let Some(targets_file) = &self.targets_file {
47 return self.execute_multi_target(targets_file).await;
48 }
49
50 TerminalReporter::print_header(
53 self.spec.to_str().unwrap(),
54 &self.target,
55 0, &self.scenario,
57 Self::parse_duration(&self.duration)?,
58 );
59
60 if !K6Executor::is_k6_installed() {
62 TerminalReporter::print_error("k6 is not installed");
63 TerminalReporter::print_warning(
64 "Install k6 from: https://k6.io/docs/get-started/installation/",
65 );
66 return Err(BenchError::K6NotFound);
67 }
68
69 TerminalReporter::print_progress("Loading OpenAPI specification...");
71 let parser = SpecParser::from_file(&self.spec).await?;
72 TerminalReporter::print_success("Specification loaded");
73
74 TerminalReporter::print_progress("Extracting API operations...");
76 let operations = if let Some(filter) = &self.operations {
77 parser.filter_operations(filter)?
78 } else {
79 parser.get_operations()
80 };
81
82 if operations.is_empty() {
83 return Err(BenchError::Other("No operations found in spec".to_string()));
84 }
85
86 TerminalReporter::print_success(&format!("Found {} operations", operations.len()));
87
88 TerminalReporter::print_progress("Generating request templates...");
90 let templates: Vec<_> = operations
91 .iter()
92 .map(RequestGenerator::generate_template)
93 .collect::<Result<Vec<_>>>()?;
94 TerminalReporter::print_success("Request templates generated");
95
96 let custom_headers = self.parse_headers()?;
98
99 TerminalReporter::print_progress("Generating k6 load test script...");
101 let scenario =
102 LoadScenario::from_str(&self.scenario).map_err(BenchError::InvalidScenario)?;
103
104 let k6_config = K6Config {
105 target_url: self.target.clone(),
106 scenario,
107 duration_secs: Self::parse_duration(&self.duration)?,
108 max_vus: self.vus,
109 threshold_percentile: self.threshold_percentile.clone(),
110 threshold_ms: self.threshold_ms,
111 max_error_rate: self.max_error_rate,
112 auth_header: self.auth.clone(),
113 custom_headers,
114 skip_tls_verify: self.skip_tls_verify,
115 };
116
117 let generator = K6ScriptGenerator::new(k6_config, templates);
118 let script = generator.generate()?;
119 TerminalReporter::print_success("k6 script generated");
120
121 TerminalReporter::print_progress("Validating k6 script...");
123 let validation_errors = K6ScriptGenerator::validate_script(&script);
124 if !validation_errors.is_empty() {
125 TerminalReporter::print_error("Script validation failed");
126 for error in &validation_errors {
127 eprintln!(" {}", error);
128 }
129 return Err(BenchError::Other(format!(
130 "Generated k6 script has {} validation error(s). Please check the output above.",
131 validation_errors.len()
132 )));
133 }
134 TerminalReporter::print_success("Script validation passed");
135
136 let script_path = if let Some(output) = &self.script_output {
138 output.clone()
139 } else {
140 self.output.join("k6-script.js")
141 };
142
143 std::fs::create_dir_all(script_path.parent().unwrap())?;
144 std::fs::write(&script_path, script)?;
145 TerminalReporter::print_success(&format!("Script written to: {}", script_path.display()));
146
147 if self.generate_only {
149 println!("\nScript generated successfully. Run it with:");
150 println!(" k6 run {}", script_path.display());
151 return Ok(());
152 }
153
154 TerminalReporter::print_progress("Executing load test...");
156 let executor = K6Executor::new()?;
157
158 std::fs::create_dir_all(&self.output)?;
159
160 let results = executor.execute(&script_path, Some(&self.output), self.verbose).await?;
161
162 let duration_secs = Self::parse_duration(&self.duration)?;
164 TerminalReporter::print_summary(&results, duration_secs);
165
166 println!("\nResults saved to: {}", self.output.display());
167
168 Ok(())
169 }
170
171 async fn execute_multi_target(&self, targets_file: &PathBuf) -> Result<()> {
173 TerminalReporter::print_progress("Parsing targets file...");
174 let targets = parse_targets_file(targets_file)?;
175 let num_targets = targets.len();
176 TerminalReporter::print_success(&format!("Loaded {} targets", num_targets));
177
178 if targets.is_empty() {
179 return Err(BenchError::Other("No targets found in file".to_string()));
180 }
181
182 let max_concurrency = self.max_concurrency.unwrap_or(10) as usize;
184 let max_concurrency = max_concurrency.min(num_targets); TerminalReporter::print_header(
188 self.spec.to_str().unwrap(),
189 &format!("{} targets", num_targets),
190 0,
191 &self.scenario,
192 Self::parse_duration(&self.duration)?,
193 );
194
195 let executor = ParallelExecutor::new(
197 BenchCommand {
198 spec: self.spec.clone(),
200 target: self.target.clone(), duration: self.duration.clone(),
202 vus: self.vus,
203 scenario: self.scenario.clone(),
204 operations: self.operations.clone(),
205 auth: self.auth.clone(),
206 headers: self.headers.clone(),
207 output: self.output.clone(),
208 generate_only: self.generate_only,
209 script_output: self.script_output.clone(),
210 threshold_percentile: self.threshold_percentile.clone(),
211 threshold_ms: self.threshold_ms,
212 max_error_rate: self.max_error_rate,
213 verbose: self.verbose,
214 skip_tls_verify: self.skip_tls_verify,
215 targets_file: None,
216 max_concurrency: None,
217 results_format: self.results_format.clone(),
218 },
219 targets,
220 max_concurrency,
221 );
222
223 let aggregated_results = executor.execute_all().await?;
225
226 self.report_multi_target_results(&aggregated_results)?;
228
229 Ok(())
230 }
231
232 fn report_multi_target_results(&self, results: &AggregatedResults) -> Result<()> {
234 TerminalReporter::print_multi_target_summary(results);
236
237 if self.results_format == "aggregated" || self.results_format == "both" {
239 let summary_path = self.output.join("aggregated_summary.json");
240 let summary_json = serde_json::json!({
241 "total_targets": results.total_targets,
242 "successful_targets": results.successful_targets,
243 "failed_targets": results.failed_targets,
244 "aggregated_metrics": {
245 "total_requests": results.aggregated_metrics.total_requests,
246 "total_failed_requests": results.aggregated_metrics.total_failed_requests,
247 "avg_duration_ms": results.aggregated_metrics.avg_duration_ms,
248 "p95_duration_ms": results.aggregated_metrics.p95_duration_ms,
249 "p99_duration_ms": results.aggregated_metrics.p99_duration_ms,
250 "error_rate": results.aggregated_metrics.error_rate,
251 },
252 "target_results": results.target_results.iter().map(|r| {
253 serde_json::json!({
254 "target_url": r.target_url,
255 "target_index": r.target_index,
256 "success": r.success,
257 "error": r.error,
258 "total_requests": r.results.total_requests,
259 "failed_requests": r.results.failed_requests,
260 "avg_duration_ms": r.results.avg_duration_ms,
261 "p95_duration_ms": r.results.p95_duration_ms,
262 "p99_duration_ms": r.results.p99_duration_ms,
263 "output_dir": r.output_dir.to_string_lossy(),
264 })
265 }).collect::<Vec<_>>(),
266 });
267
268 std::fs::write(&summary_path, serde_json::to_string_pretty(&summary_json)?)?;
269 TerminalReporter::print_success(&format!(
270 "Aggregated summary saved to: {}",
271 summary_path.display()
272 ));
273 }
274
275 println!("\nResults saved to: {}", self.output.display());
276 println!(" - Per-target results: {}", self.output.join("target_*").display());
277 if self.results_format == "aggregated" || self.results_format == "both" {
278 println!(
279 " - Aggregated summary: {}",
280 self.output.join("aggregated_summary.json").display()
281 );
282 }
283
284 Ok(())
285 }
286
287 pub fn parse_duration(duration: &str) -> Result<u64> {
289 let duration = duration.trim();
290
291 if let Some(secs) = duration.strip_suffix('s') {
292 secs.parse::<u64>()
293 .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
294 } else if let Some(mins) = duration.strip_suffix('m') {
295 mins.parse::<u64>()
296 .map(|m| m * 60)
297 .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
298 } else if let Some(hours) = duration.strip_suffix('h') {
299 hours
300 .parse::<u64>()
301 .map(|h| h * 3600)
302 .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
303 } else {
304 duration
306 .parse::<u64>()
307 .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
308 }
309 }
310
311 pub fn parse_headers(&self) -> Result<HashMap<String, String>> {
313 let mut headers = HashMap::new();
314
315 if let Some(header_str) = &self.headers {
316 for pair in header_str.split(',') {
317 let parts: Vec<&str> = pair.splitn(2, ':').collect();
318 if parts.len() != 2 {
319 return Err(BenchError::Other(format!(
320 "Invalid header format: '{}'. Expected 'Key:Value'",
321 pair
322 )));
323 }
324 headers.insert(parts[0].trim().to_string(), parts[1].trim().to_string());
325 }
326 }
327
328 Ok(headers)
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-target command with fixed settings and the given `headers` value.
    fn command_with_headers(headers: Option<&str>) -> BenchCommand {
        BenchCommand {
            spec: PathBuf::from("test.yaml"),
            target: "http://localhost".to_string(),
            duration: "1m".to_string(),
            vus: 10,
            scenario: "ramp-up".to_string(),
            operations: None,
            auth: None,
            headers: headers.map(str::to_string),
            output: PathBuf::from("output"),
            generate_only: false,
            script_output: None,
            threshold_percentile: "p(95)".to_string(),
            threshold_ms: 500,
            max_error_rate: 0.05,
            verbose: false,
            skip_tls_verify: false,
            targets_file: None,
            max_concurrency: None,
            results_format: "both".to_string(),
        }
    }

    #[test]
    fn test_parse_duration() {
        assert_eq!(BenchCommand::parse_duration("30s").unwrap(), 30);
        assert_eq!(BenchCommand::parse_duration("5m").unwrap(), 300);
        assert_eq!(BenchCommand::parse_duration("1h").unwrap(), 3600);
        assert_eq!(BenchCommand::parse_duration("60").unwrap(), 60);
    }

    #[test]
    fn test_parse_duration_trims_whitespace() {
        assert_eq!(BenchCommand::parse_duration("  45s ").unwrap(), 45);
        assert_eq!(BenchCommand::parse_duration("\t2m\n").unwrap(), 120);
    }

    #[test]
    fn test_parse_duration_invalid() {
        for input in ["invalid", "30x", "", "s", "-5s", "1.5m", "5ms"] {
            assert!(
                BenchCommand::parse_duration(input).is_err(),
                "expected {:?} to be rejected",
                input
            );
        }
    }

    #[test]
    fn test_parse_headers() {
        let cmd = command_with_headers(Some("X-API-Key:test123,X-Client-ID:client456"));

        let headers = cmd.parse_headers().unwrap();
        assert_eq!(headers.get("X-API-Key"), Some(&"test123".to_string()));
        assert_eq!(headers.get("X-Client-ID"), Some(&"client456".to_string()));
    }

    #[test]
    fn test_parse_headers_absent_is_empty() {
        let cmd = command_with_headers(None);
        assert!(cmd.parse_headers().unwrap().is_empty());
    }

    #[test]
    fn test_parse_headers_trims_and_keeps_colons_in_value() {
        let cmd = command_with_headers(Some(" X-Trace : abc ,Authorization:Bearer a:b"));

        let headers = cmd.parse_headers().unwrap();
        assert_eq!(headers.get("X-Trace"), Some(&"abc".to_string()));
        assert_eq!(headers.get("Authorization"), Some(&"Bearer a:b".to_string()));
    }

    #[test]
    fn test_parse_headers_rejects_pair_without_colon() {
        assert!(command_with_headers(Some("NoColon")).parse_headers().is_err());
        assert!(command_with_headers(Some("A:1,Broken")).parse_headers().is_err());
    }
}