mockforge_bench/
command.rs

1//! Bench command implementation
2
3use crate::error::{BenchError, Result};
4use crate::executor::K6Executor;
5use crate::k6_gen::{K6Config, K6ScriptGenerator};
6use crate::reporter::TerminalReporter;
7use crate::request_gen::RequestGenerator;
8use crate::scenarios::LoadScenario;
9use crate::spec_parser::SpecParser;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::str::FromStr;
13
/// Bench command configuration.
///
/// Holds every option of a single load-test run. [`BenchCommand::execute`]
/// reads these fields to generate a k6 script from an API specification and,
/// unless `generate_only` is set, to run it.
#[derive(Debug, Clone)]
pub struct BenchCommand {
    /// Path to the API specification file to load.
    pub spec: PathBuf,
    /// Target URL, forwarded to the k6 configuration as `target_url`.
    pub target: String,
    /// Test duration: a whole number with an optional `s`, `m` or `h` suffix
    /// (e.g. `"30s"`, `"5m"`, `"1h"`); a bare number is taken as seconds.
    pub duration: String,
    /// Virtual-user count, forwarded to the k6 configuration as `max_vus`.
    pub vus: u32,
    /// Load scenario name; parsed into a `LoadScenario` when the script is generated.
    pub scenario: String,
    /// Optional operation filter; when `None`, every operation in the spec is used.
    pub operations: Option<String>,
    /// Optional value forwarded to the k6 configuration as `auth_header`.
    pub auth: Option<String>,
    /// Optional extra headers in `Key:Value,Key2:Value2` format.
    pub headers: Option<String>,
    /// Output directory; also holds `k6-script.js` when `script_output` is `None`.
    pub output: PathBuf,
    /// When `true`, write the script and return without running k6.
    pub generate_only: bool,
    /// Optional explicit path for the generated script.
    pub script_output: Option<PathBuf>,
    /// Threshold percentile (e.g. `"p95"`), forwarded to the k6 configuration.
    pub threshold_percentile: String,
    /// Threshold in milliseconds, forwarded to the k6 configuration.
    pub threshold_ms: u64,
    /// Maximum error rate, forwarded to the k6 configuration.
    pub max_error_rate: f64,
    /// Verbosity flag passed to the k6 executor.
    pub verbose: bool,
}
32
33impl BenchCommand {
34    /// Execute the bench command
35    pub async fn execute(&self) -> Result<()> {
36        // Print header
37        TerminalReporter::print_header(
38            self.spec.to_str().unwrap(),
39            &self.target,
40            0, // Will be updated later
41            &self.scenario,
42            Self::parse_duration(&self.duration)?,
43        );
44
45        // Validate k6 installation
46        if !K6Executor::is_k6_installed() {
47            TerminalReporter::print_error("k6 is not installed");
48            TerminalReporter::print_warning(
49                "Install k6 from: https://k6.io/docs/get-started/installation/",
50            );
51            return Err(BenchError::K6NotFound);
52        }
53
54        // Load and parse spec
55        TerminalReporter::print_progress("Loading OpenAPI specification...");
56        let parser = SpecParser::from_file(&self.spec).await?;
57        TerminalReporter::print_success("Specification loaded");
58
59        // Get operations
60        TerminalReporter::print_progress("Extracting API operations...");
61        let operations = if let Some(filter) = &self.operations {
62            parser.filter_operations(filter)?
63        } else {
64            parser.get_operations()
65        };
66
67        if operations.is_empty() {
68            return Err(BenchError::Other("No operations found in spec".to_string()));
69        }
70
71        TerminalReporter::print_success(&format!("Found {} operations", operations.len()));
72
73        // Generate request templates
74        TerminalReporter::print_progress("Generating request templates...");
75        let templates: Vec<_> = operations
76            .iter()
77            .map(RequestGenerator::generate_template)
78            .collect::<Result<Vec<_>>>()?;
79        TerminalReporter::print_success("Request templates generated");
80
81        // Parse headers
82        let custom_headers = self.parse_headers()?;
83
84        // Generate k6 script
85        TerminalReporter::print_progress("Generating k6 load test script...");
86        let scenario =
87            LoadScenario::from_str(&self.scenario).map_err(BenchError::InvalidScenario)?;
88
89        let k6_config = K6Config {
90            target_url: self.target.clone(),
91            scenario,
92            duration_secs: Self::parse_duration(&self.duration)?,
93            max_vus: self.vus,
94            threshold_percentile: self.threshold_percentile.clone(),
95            threshold_ms: self.threshold_ms,
96            max_error_rate: self.max_error_rate,
97            auth_header: self.auth.clone(),
98            custom_headers,
99        };
100
101        let generator = K6ScriptGenerator::new(k6_config, templates);
102        let script = generator.generate()?;
103        TerminalReporter::print_success("k6 script generated");
104
105        // Write script to file
106        let script_path = if let Some(output) = &self.script_output {
107            output.clone()
108        } else {
109            self.output.join("k6-script.js")
110        };
111
112        std::fs::create_dir_all(script_path.parent().unwrap())?;
113        std::fs::write(&script_path, script)?;
114        TerminalReporter::print_success(&format!("Script written to: {}", script_path.display()));
115
116        // If generate-only mode, exit here
117        if self.generate_only {
118            println!("\nScript generated successfully. Run it with:");
119            println!("  k6 run {}", script_path.display());
120            return Ok(());
121        }
122
123        // Execute k6
124        TerminalReporter::print_progress("Executing load test...");
125        let executor = K6Executor::new()?;
126
127        std::fs::create_dir_all(&self.output)?;
128
129        let results = executor.execute(&script_path, Some(&self.output), self.verbose).await?;
130
131        // Print results
132        let duration_secs = Self::parse_duration(&self.duration)?;
133        TerminalReporter::print_summary(&results, duration_secs);
134
135        println!("\nResults saved to: {}", self.output.display());
136
137        Ok(())
138    }
139
140    /// Parse duration string (e.g., "30s", "5m", "1h") to seconds
141    fn parse_duration(duration: &str) -> Result<u64> {
142        let duration = duration.trim();
143
144        if let Some(secs) = duration.strip_suffix('s') {
145            secs.parse::<u64>()
146                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
147        } else if let Some(mins) = duration.strip_suffix('m') {
148            mins.parse::<u64>()
149                .map(|m| m * 60)
150                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
151        } else if let Some(hours) = duration.strip_suffix('h') {
152            hours
153                .parse::<u64>()
154                .map(|h| h * 3600)
155                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
156        } else {
157            // Try parsing as seconds without suffix
158            duration
159                .parse::<u64>()
160                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
161        }
162    }
163
164    /// Parse headers from command line format (Key:Value,Key2:Value2)
165    fn parse_headers(&self) -> Result<HashMap<String, String>> {
166        let mut headers = HashMap::new();
167
168        if let Some(header_str) = &self.headers {
169            for pair in header_str.split(',') {
170                let parts: Vec<&str> = pair.splitn(2, ':').collect();
171                if parts.len() != 2 {
172                    return Err(BenchError::Other(format!(
173                        "Invalid header format: '{}'. Expected 'Key:Value'",
174                        pair
175                    )));
176                }
177                headers.insert(parts[0].trim().to_string(), parts[1].trim().to_string());
178            }
179        }
180
181        Ok(headers)
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a command with fixed settings, varying only the raw headers string.
    fn command_with_headers(raw_headers: &str) -> BenchCommand {
        BenchCommand {
            spec: PathBuf::from("test.yaml"),
            target: "http://localhost".to_string(),
            duration: "1m".to_string(),
            vus: 10,
            scenario: "ramp-up".to_string(),
            operations: None,
            auth: None,
            headers: Some(raw_headers.to_string()),
            output: PathBuf::from("output"),
            generate_only: false,
            script_output: None,
            threshold_percentile: "p95".to_string(),
            threshold_ms: 500,
            max_error_rate: 0.05,
            verbose: false,
        }
    }

    /// Every supported suffix, plus a bare number, converts to seconds.
    #[test]
    fn test_parse_duration() {
        let cases: [(&str, u64); 4] = [("30s", 30), ("5m", 300), ("1h", 3600), ("60", 60)];

        for &(input, expected_secs) in &cases {
            assert_eq!(BenchCommand::parse_duration(input).unwrap(), expected_secs);
        }
    }

    /// Non-numeric input and unknown suffixes are rejected.
    #[test]
    fn test_parse_duration_invalid() {
        for &input in &["invalid", "30x"] {
            assert!(BenchCommand::parse_duration(input).is_err());
        }
    }

    /// A comma-separated list of `Key:Value` pairs becomes a header map.
    #[test]
    fn test_parse_headers() {
        let parsed = command_with_headers("X-API-Key:test123,X-Client-ID:client456")
            .parse_headers()
            .unwrap();

        let expected = [("X-API-Key", "test123"), ("X-Client-ID", "client456")];
        for &(name, value) in &expected {
            assert_eq!(parsed.get(name), Some(&value.to_string()));
        }
    }
}
227}