mockforge_bench/
command.rs

1//! Bench command implementation
2
3use crate::error::{BenchError, Result};
4use crate::executor::K6Executor;
5use crate::k6_gen::{K6Config, K6ScriptGenerator};
6use crate::reporter::TerminalReporter;
7use crate::request_gen::RequestGenerator;
8use crate::scenarios::LoadScenario;
9use crate::spec_parser::SpecParser;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::str::FromStr;
13
/// Bench command configuration.
///
/// Holds the options for a single `bench` run; [`BenchCommand::execute`]
/// uses them to generate (and, unless `generate_only` is set, run) a k6
/// load-test script.
#[derive(Debug, Clone)]
pub struct BenchCommand {
    /// Path to the OpenAPI specification to load.
    pub spec: PathBuf,
    /// Base URL of the system under test; forwarded to k6 as the target URL.
    pub target: String,
    /// Test duration such as `"30s"`, `"5m"`, `"1h"`, or bare seconds (`"60"`).
    pub duration: String,
    /// Maximum number of k6 virtual users.
    pub vus: u32,
    /// Load scenario name, parsed into `LoadScenario` (e.g. `"ramp-up"`).
    pub scenario: String,
    /// Optional filter restricting which spec operations are exercised; its
    /// syntax is whatever `SpecParser::filter_operations` accepts.
    pub operations: Option<String>,
    /// Optional value forwarded to the k6 config as `auth_header`.
    pub auth: Option<String>,
    /// Optional extra headers in `Key:Value,Key2:Value2` form.
    pub headers: Option<String>,
    /// Directory for k6 results and, when `script_output` is not set, the
    /// generated `k6-script.js`.
    pub output: PathBuf,
    /// When `true`, only write the script and skip running k6.
    pub generate_only: bool,
    /// Explicit path for the generated script, overriding `output/k6-script.js`.
    pub script_output: Option<PathBuf>,
    /// Percentile the latency threshold applies to, e.g. `"p(95)"`.
    pub threshold_percentile: String,
    /// Latency threshold in milliseconds for `threshold_percentile`.
    pub threshold_ms: u64,
    /// Maximum tolerated error rate passed to the k6 config (tests use `0.05`,
    /// presumably a fraction rather than a percentage — confirm).
    pub max_error_rate: f64,
    /// Passed through to the k6 executor's `execute` call.
    pub verbose: bool,
}
32
33impl BenchCommand {
34    /// Execute the bench command
35    pub async fn execute(&self) -> Result<()> {
36        // Print header
37        TerminalReporter::print_header(
38            self.spec.to_str().unwrap(),
39            &self.target,
40            0, // Will be updated later
41            &self.scenario,
42            Self::parse_duration(&self.duration)?,
43        );
44
45        // Validate k6 installation
46        if !K6Executor::is_k6_installed() {
47            TerminalReporter::print_error("k6 is not installed");
48            TerminalReporter::print_warning(
49                "Install k6 from: https://k6.io/docs/get-started/installation/",
50            );
51            return Err(BenchError::K6NotFound);
52        }
53
54        // Load and parse spec
55        TerminalReporter::print_progress("Loading OpenAPI specification...");
56        let parser = SpecParser::from_file(&self.spec).await?;
57        TerminalReporter::print_success("Specification loaded");
58
59        // Get operations
60        TerminalReporter::print_progress("Extracting API operations...");
61        let operations = if let Some(filter) = &self.operations {
62            parser.filter_operations(filter)?
63        } else {
64            parser.get_operations()
65        };
66
67        if operations.is_empty() {
68            return Err(BenchError::Other("No operations found in spec".to_string()));
69        }
70
71        TerminalReporter::print_success(&format!("Found {} operations", operations.len()));
72
73        // Generate request templates
74        TerminalReporter::print_progress("Generating request templates...");
75        let templates: Vec<_> = operations
76            .iter()
77            .map(RequestGenerator::generate_template)
78            .collect::<Result<Vec<_>>>()?;
79        TerminalReporter::print_success("Request templates generated");
80
81        // Parse headers
82        let custom_headers = self.parse_headers()?;
83
84        // Generate k6 script
85        TerminalReporter::print_progress("Generating k6 load test script...");
86        let scenario =
87            LoadScenario::from_str(&self.scenario).map_err(BenchError::InvalidScenario)?;
88
89        let k6_config = K6Config {
90            target_url: self.target.clone(),
91            scenario,
92            duration_secs: Self::parse_duration(&self.duration)?,
93            max_vus: self.vus,
94            threshold_percentile: self.threshold_percentile.clone(),
95            threshold_ms: self.threshold_ms,
96            max_error_rate: self.max_error_rate,
97            auth_header: self.auth.clone(),
98            custom_headers,
99        };
100
101        let generator = K6ScriptGenerator::new(k6_config, templates);
102        let script = generator.generate()?;
103        TerminalReporter::print_success("k6 script generated");
104
105        // Validate the generated script
106        TerminalReporter::print_progress("Validating k6 script...");
107        let validation_errors = K6ScriptGenerator::validate_script(&script);
108        if !validation_errors.is_empty() {
109            TerminalReporter::print_error("Script validation failed");
110            for error in &validation_errors {
111                eprintln!("  {}", error);
112            }
113            return Err(BenchError::Other(format!(
114                "Generated k6 script has {} validation error(s). Please check the output above.",
115                validation_errors.len()
116            )));
117        }
118        TerminalReporter::print_success("Script validation passed");
119
120        // Write script to file
121        let script_path = if let Some(output) = &self.script_output {
122            output.clone()
123        } else {
124            self.output.join("k6-script.js")
125        };
126
127        std::fs::create_dir_all(script_path.parent().unwrap())?;
128        std::fs::write(&script_path, script)?;
129        TerminalReporter::print_success(&format!("Script written to: {}", script_path.display()));
130
131        // If generate-only mode, exit here
132        if self.generate_only {
133            println!("\nScript generated successfully. Run it with:");
134            println!("  k6 run {}", script_path.display());
135            return Ok(());
136        }
137
138        // Execute k6
139        TerminalReporter::print_progress("Executing load test...");
140        let executor = K6Executor::new()?;
141
142        std::fs::create_dir_all(&self.output)?;
143
144        let results = executor.execute(&script_path, Some(&self.output), self.verbose).await?;
145
146        // Print results
147        let duration_secs = Self::parse_duration(&self.duration)?;
148        TerminalReporter::print_summary(&results, duration_secs);
149
150        println!("\nResults saved to: {}", self.output.display());
151
152        Ok(())
153    }
154
155    /// Parse duration string (e.g., "30s", "5m", "1h") to seconds
156    fn parse_duration(duration: &str) -> Result<u64> {
157        let duration = duration.trim();
158
159        if let Some(secs) = duration.strip_suffix('s') {
160            secs.parse::<u64>()
161                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
162        } else if let Some(mins) = duration.strip_suffix('m') {
163            mins.parse::<u64>()
164                .map(|m| m * 60)
165                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
166        } else if let Some(hours) = duration.strip_suffix('h') {
167            hours
168                .parse::<u64>()
169                .map(|h| h * 3600)
170                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
171        } else {
172            // Try parsing as seconds without suffix
173            duration
174                .parse::<u64>()
175                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
176        }
177    }
178
179    /// Parse headers from command line format (Key:Value,Key2:Value2)
180    fn parse_headers(&self) -> Result<HashMap<String, String>> {
181        let mut headers = HashMap::new();
182
183        if let Some(header_str) = &self.headers {
184            for pair in header_str.split(',') {
185                let parts: Vec<&str> = pair.splitn(2, ':').collect();
186                if parts.len() != 2 {
187                    return Err(BenchError::Other(format!(
188                        "Invalid header format: '{}'. Expected 'Key:Value'",
189                        pair
190                    )));
191                }
192                headers.insert(parts[0].trim().to_string(), parts[1].trim().to_string());
193            }
194        }
195
196        Ok(headers)
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a representative command, varying only the `headers` option.
    fn command_with_headers(headers: Option<&str>) -> BenchCommand {
        BenchCommand {
            spec: PathBuf::from("test.yaml"),
            target: "http://localhost".to_string(),
            duration: "1m".to_string(),
            vus: 10,
            scenario: "ramp-up".to_string(),
            operations: None,
            auth: None,
            headers: headers.map(str::to_string),
            output: PathBuf::from("output"),
            generate_only: false,
            script_output: None,
            threshold_percentile: "p(95)".to_string(),
            threshold_ms: 500,
            max_error_rate: 0.05,
            verbose: false,
        }
    }

    #[test]
    fn test_parse_duration() {
        assert_eq!(BenchCommand::parse_duration("30s").unwrap(), 30);
        assert_eq!(BenchCommand::parse_duration("5m").unwrap(), 300);
        assert_eq!(BenchCommand::parse_duration("1h").unwrap(), 3600);
        assert_eq!(BenchCommand::parse_duration("60").unwrap(), 60);
    }

    #[test]
    fn test_parse_duration_trims_whitespace() {
        assert_eq!(BenchCommand::parse_duration("  30s ").unwrap(), 30);
    }

    #[test]
    fn test_parse_duration_invalid() {
        assert!(BenchCommand::parse_duration("invalid").is_err());
        assert!(BenchCommand::parse_duration("30x").is_err());
        assert!(BenchCommand::parse_duration("").is_err());
        assert!(BenchCommand::parse_duration("s").is_err());
        assert!(BenchCommand::parse_duration("5M").is_err());
    }

    #[test]
    fn test_parse_headers() {
        let cmd = command_with_headers(Some("X-API-Key:test123,X-Client-ID:client456"));

        let headers = cmd.parse_headers().unwrap();
        assert_eq!(headers.get("X-API-Key"), Some(&"test123".to_string()));
        assert_eq!(headers.get("X-Client-ID"), Some(&"client456".to_string()));
    }

    #[test]
    fn test_parse_headers_none_is_empty() {
        let cmd = command_with_headers(None);
        assert!(cmd.parse_headers().unwrap().is_empty());
    }

    #[test]
    fn test_parse_headers_value_may_contain_colon() {
        let cmd = command_with_headers(Some("Authorization: Bearer abc:def"));
        let headers = cmd.parse_headers().unwrap();
        assert_eq!(headers.get("Authorization"), Some(&"Bearer abc:def".to_string()));
    }

    #[test]
    fn test_parse_headers_trims_whitespace() {
        let cmd = command_with_headers(Some(" X-A : 1 , X-B:2 "));
        let headers = cmd.parse_headers().unwrap();
        assert_eq!(headers.get("X-A"), Some(&"1".to_string()));
        assert_eq!(headers.get("X-B"), Some(&"2".to_string()));
    }

    #[test]
    fn test_parse_headers_missing_colon_is_error() {
        let cmd = command_with_headers(Some("NoColonHere"));
        assert!(cmd.parse_headers().is_err());
    }
}