mockforge_bench/
command.rs

1//! Bench command implementation
2
3use crate::crud_flow::{CrudFlowConfig, CrudFlowDetector};
4use crate::data_driven::{DataDistribution, DataDrivenConfig, DataDrivenGenerator, DataMapping};
5use crate::error::{BenchError, Result};
6use crate::executor::K6Executor;
7use crate::invalid_data::{InvalidDataConfig, InvalidDataGenerator, InvalidDataType};
8use crate::k6_gen::{K6Config, K6ScriptGenerator};
9use crate::mock_integration::{
10    MockIntegrationConfig, MockIntegrationGenerator, MockServerDetector,
11};
12use crate::parallel_executor::{AggregatedResults, ParallelExecutor};
13use crate::parallel_requests::{ParallelConfig, ParallelRequestGenerator};
14use crate::param_overrides::ParameterOverrides;
15use crate::reporter::TerminalReporter;
16use crate::request_gen::RequestGenerator;
17use crate::scenarios::LoadScenario;
18use crate::security_payloads::{
19    SecurityCategory, SecurityPayloads, SecurityTestConfig, SecurityTestGenerator,
20};
21use crate::spec_parser::SpecParser;
22use crate::target_parser::parse_targets_file;
23use std::collections::HashMap;
24use std::path::PathBuf;
25use std::str::FromStr;
26
/// Bench command configuration.
///
/// Holds every option of the `bench` command. `BenchCommand::execute` reads these fields to
/// generate a k6 load-test script from an OpenAPI spec and, unless `generate_only` is set,
/// run it against a single target or against every target listed in `targets_file`.
#[derive(Debug, Clone)]
pub struct BenchCommand {
    /// Path to the OpenAPI specification to load.
    pub spec: PathBuf,
    /// Base URL of the API under test (single-target mode).
    pub target: String,
    /// Test duration: `"30s"`, `"5m"`, `"1h"`, or a bare number of seconds.
    pub duration: String,
    /// Maximum number of k6 virtual users.
    pub vus: u32,
    /// Load scenario name, parsed with `LoadScenario::from_str`.
    pub scenario: String,
    /// Optional operation filter, passed to `SpecParser::filter_operations`.
    pub operations: Option<String>,
    /// Exclude operations from testing (comma-separated).
    ///
    /// Supports "METHOD /path" or just "METHOD" to exclude all operations of that type.
    pub exclude_operations: Option<String>,
    /// Value forwarded as the auth header of the generated script.
    pub auth: Option<String>,
    /// Extra request headers in the form `Key:Value,Key2:Value2`.
    pub headers: Option<String>,
    /// Directory for k6 results; also the default location of the generated script.
    pub output: PathBuf,
    /// Only write the k6 script; do not execute it.
    pub generate_only: bool,
    /// Explicit path for the generated script (overrides the default under `output`).
    pub script_output: Option<PathBuf>,
    /// Latency percentile checked by the k6 threshold (format as expected by `K6Config`).
    pub threshold_percentile: String,
    /// Latency threshold, in milliseconds, for `threshold_percentile`.
    pub threshold_ms: u64,
    /// Maximum acceptable error rate for the k6 threshold.
    pub max_error_rate: f64,
    /// Forward verbose output when running k6.
    pub verbose: bool,
    /// Skip TLS certificate verification (forwarded to the k6 configuration).
    pub skip_tls_verify: bool,
    /// Optional file containing multiple targets; enables multi-target mode.
    pub targets_file: Option<PathBuf>,
    /// Maximum number of parallel executions in multi-target mode (defaults to 10).
    pub max_concurrency: Option<u32>,
    /// Results format: "per-target", "aggregated", or "both".
    pub results_format: String,
    /// Optional file containing parameter value overrides (JSON or YAML).
    ///
    /// Allows users to provide custom values for path parameters, query parameters,
    /// headers, and request bodies instead of auto-generated placeholder values.
    pub params_file: Option<PathBuf>,

    // === CRUD Flow Options ===
    /// Enable CRUD flow mode.
    pub crud_flow: bool,
    /// Custom CRUD flow configuration file.
    pub flow_config: Option<PathBuf>,
    /// Comma-separated fields to extract from responses (defaults to `id,uuid`).
    pub extract_fields: Option<String>,

    // === Parallel Execution Options ===
    /// Number of resources to create in parallel.
    pub parallel_create: Option<u32>,

    // === Data-Driven Testing Options ===
    /// Test data file (CSV or JSON).
    pub data_file: Option<PathBuf>,
    /// Data distribution strategy; unparsable values fall back to unique-per-VU.
    pub data_distribution: String,
    /// Data column to field mappings.
    pub data_mappings: Option<String>,

    // === Invalid Data Testing Options ===
    /// Rate of requests sent with invalid data. It is displayed multiplied by 100 as a
    /// percentage, so presumably a fraction in `0.0..=1.0` (confirm).
    pub error_rate: Option<f64>,
    /// Types of invalid data to generate.
    pub error_types: Option<String>,

    // === Security Testing Options ===
    /// Enable security testing.
    pub security_test: bool,
    /// Custom security payloads file.
    pub security_payloads: Option<PathBuf>,
    /// Security test categories (defaults to SQL injection and XSS).
    pub security_categories: Option<String>,
    /// Comma-separated fields to target for security injection.
    pub security_target_fields: Option<String>,
}
97
98impl BenchCommand {
99    /// Execute the bench command
100    pub async fn execute(&self) -> Result<()> {
101        // Check if we're in multi-target mode
102        if let Some(targets_file) = &self.targets_file {
103            return self.execute_multi_target(targets_file).await;
104        }
105
106        // Single target mode (existing behavior)
107        // Print header
108        TerminalReporter::print_header(
109            self.spec.to_str().unwrap(),
110            &self.target,
111            0, // Will be updated later
112            &self.scenario,
113            Self::parse_duration(&self.duration)?,
114        );
115
116        // Validate k6 installation
117        if !K6Executor::is_k6_installed() {
118            TerminalReporter::print_error("k6 is not installed");
119            TerminalReporter::print_warning(
120                "Install k6 from: https://k6.io/docs/get-started/installation/",
121            );
122            return Err(BenchError::K6NotFound);
123        }
124
125        // Load and parse spec
126        TerminalReporter::print_progress("Loading OpenAPI specification...");
127        let parser = SpecParser::from_file(&self.spec).await?;
128        TerminalReporter::print_success("Specification loaded");
129
130        // Check for mock server integration
131        let mock_config = self.build_mock_config().await;
132        if mock_config.is_mock_server {
133            TerminalReporter::print_progress("Mock server integration enabled");
134        }
135
136        // Check for CRUD flow mode
137        if self.crud_flow {
138            return self.execute_crud_flow(&parser).await;
139        }
140
141        // Get operations
142        TerminalReporter::print_progress("Extracting API operations...");
143        let mut operations = if let Some(filter) = &self.operations {
144            parser.filter_operations(filter)?
145        } else {
146            parser.get_operations()
147        };
148
149        // Apply exclusions if provided
150        if let Some(exclude) = &self.exclude_operations {
151            let before_count = operations.len();
152            operations = parser.exclude_operations(operations, exclude)?;
153            let excluded_count = before_count - operations.len();
154            if excluded_count > 0 {
155                TerminalReporter::print_progress(&format!(
156                    "Excluded {} operations matching '{}'",
157                    excluded_count, exclude
158                ));
159            }
160        }
161
162        if operations.is_empty() {
163            return Err(BenchError::Other("No operations found in spec".to_string()));
164        }
165
166        TerminalReporter::print_success(&format!("Found {} operations", operations.len()));
167
168        // Load parameter overrides if provided
169        let param_overrides = if let Some(params_file) = &self.params_file {
170            TerminalReporter::print_progress("Loading parameter overrides...");
171            let overrides = ParameterOverrides::from_file(params_file)?;
172            TerminalReporter::print_success(&format!(
173                "Loaded parameter overrides ({} operation-specific, {} defaults)",
174                overrides.operations.len(),
175                if overrides.defaults.is_empty() { 0 } else { 1 }
176            ));
177            Some(overrides)
178        } else {
179            None
180        };
181
182        // Generate request templates
183        TerminalReporter::print_progress("Generating request templates...");
184        let templates: Vec<_> = operations
185            .iter()
186            .map(|op| {
187                let op_overrides = param_overrides.as_ref().map(|po| {
188                    po.get_for_operation(op.operation_id.as_deref(), &op.method, &op.path)
189                });
190                RequestGenerator::generate_template_with_overrides(op, op_overrides.as_ref())
191            })
192            .collect::<Result<Vec<_>>>()?;
193        TerminalReporter::print_success("Request templates generated");
194
195        // Parse headers
196        let custom_headers = self.parse_headers()?;
197
198        // Generate k6 script
199        TerminalReporter::print_progress("Generating k6 load test script...");
200        let scenario =
201            LoadScenario::from_str(&self.scenario).map_err(BenchError::InvalidScenario)?;
202
203        let k6_config = K6Config {
204            target_url: self.target.clone(),
205            scenario,
206            duration_secs: Self::parse_duration(&self.duration)?,
207            max_vus: self.vus,
208            threshold_percentile: self.threshold_percentile.clone(),
209            threshold_ms: self.threshold_ms,
210            max_error_rate: self.max_error_rate,
211            auth_header: self.auth.clone(),
212            custom_headers,
213            skip_tls_verify: self.skip_tls_verify,
214        };
215
216        let generator = K6ScriptGenerator::new(k6_config, templates);
217        let mut script = generator.generate()?;
218        TerminalReporter::print_success("k6 script generated");
219
220        // Check if any advanced features are enabled
221        let has_advanced_features = self.data_file.is_some()
222            || self.error_rate.is_some()
223            || self.security_test
224            || self.parallel_create.is_some();
225
226        // Enhance script with advanced features
227        if has_advanced_features {
228            script = self.generate_enhanced_script(&script)?;
229        }
230
231        // Add mock server integration code
232        if mock_config.is_mock_server {
233            let setup_code = MockIntegrationGenerator::generate_setup(&mock_config);
234            let teardown_code = MockIntegrationGenerator::generate_teardown(&mock_config);
235            let helper_code = MockIntegrationGenerator::generate_vu_id_helper();
236
237            // Insert mock server code after imports
238            if let Some(import_end) = script.find("export const options") {
239                script.insert_str(
240                    import_end,
241                    &format!(
242                        "\n// === Mock Server Integration ===\n{}\n{}\n{}\n",
243                        helper_code, setup_code, teardown_code
244                    ),
245                );
246            }
247        }
248
249        // Validate the generated script
250        TerminalReporter::print_progress("Validating k6 script...");
251        let validation_errors = K6ScriptGenerator::validate_script(&script);
252        if !validation_errors.is_empty() {
253            TerminalReporter::print_error("Script validation failed");
254            for error in &validation_errors {
255                eprintln!("  {}", error);
256            }
257            return Err(BenchError::Other(format!(
258                "Generated k6 script has {} validation error(s). Please check the output above.",
259                validation_errors.len()
260            )));
261        }
262        TerminalReporter::print_success("Script validation passed");
263
264        // Write script to file
265        let script_path = if let Some(output) = &self.script_output {
266            output.clone()
267        } else {
268            self.output.join("k6-script.js")
269        };
270
271        std::fs::create_dir_all(script_path.parent().unwrap())?;
272        std::fs::write(&script_path, &script)?;
273        TerminalReporter::print_success(&format!("Script written to: {}", script_path.display()));
274
275        // If generate-only mode, exit here
276        if self.generate_only {
277            println!("\nScript generated successfully. Run it with:");
278            println!("  k6 run {}", script_path.display());
279            return Ok(());
280        }
281
282        // Execute k6
283        TerminalReporter::print_progress("Executing load test...");
284        let executor = K6Executor::new()?;
285
286        std::fs::create_dir_all(&self.output)?;
287
288        let results = executor.execute(&script_path, Some(&self.output), self.verbose).await?;
289
290        // Print results
291        let duration_secs = Self::parse_duration(&self.duration)?;
292        TerminalReporter::print_summary(&results, duration_secs);
293
294        println!("\nResults saved to: {}", self.output.display());
295
296        Ok(())
297    }
298
299    /// Execute multi-target bench testing
300    async fn execute_multi_target(&self, targets_file: &PathBuf) -> Result<()> {
301        TerminalReporter::print_progress("Parsing targets file...");
302        let targets = parse_targets_file(targets_file)?;
303        let num_targets = targets.len();
304        TerminalReporter::print_success(&format!("Loaded {} targets", num_targets));
305
306        if targets.is_empty() {
307            return Err(BenchError::Other("No targets found in file".to_string()));
308        }
309
310        // Determine max concurrency
311        let max_concurrency = self.max_concurrency.unwrap_or(10) as usize;
312        let max_concurrency = max_concurrency.min(num_targets); // Don't exceed number of targets
313
314        // Print header for multi-target mode
315        TerminalReporter::print_header(
316            self.spec.to_str().unwrap(),
317            &format!("{} targets", num_targets),
318            0,
319            &self.scenario,
320            Self::parse_duration(&self.duration)?,
321        );
322
323        // Create parallel executor
324        let executor = ParallelExecutor::new(
325            BenchCommand {
326                // Clone all fields except targets_file (we don't need it in the executor)
327                spec: self.spec.clone(),
328                target: self.target.clone(), // Not used in multi-target mode, but kept for compatibility
329                duration: self.duration.clone(),
330                vus: self.vus,
331                scenario: self.scenario.clone(),
332                operations: self.operations.clone(),
333                exclude_operations: self.exclude_operations.clone(),
334                auth: self.auth.clone(),
335                headers: self.headers.clone(),
336                output: self.output.clone(),
337                generate_only: self.generate_only,
338                script_output: self.script_output.clone(),
339                threshold_percentile: self.threshold_percentile.clone(),
340                threshold_ms: self.threshold_ms,
341                max_error_rate: self.max_error_rate,
342                verbose: self.verbose,
343                skip_tls_verify: self.skip_tls_verify,
344                targets_file: None,
345                max_concurrency: None,
346                results_format: self.results_format.clone(),
347                params_file: self.params_file.clone(),
348                crud_flow: self.crud_flow,
349                flow_config: self.flow_config.clone(),
350                extract_fields: self.extract_fields.clone(),
351                parallel_create: self.parallel_create,
352                data_file: self.data_file.clone(),
353                data_distribution: self.data_distribution.clone(),
354                data_mappings: self.data_mappings.clone(),
355                error_rate: self.error_rate,
356                error_types: self.error_types.clone(),
357                security_test: self.security_test,
358                security_payloads: self.security_payloads.clone(),
359                security_categories: self.security_categories.clone(),
360                security_target_fields: self.security_target_fields.clone(),
361            },
362            targets,
363            max_concurrency,
364        );
365
366        // Execute all targets
367        let aggregated_results = executor.execute_all().await?;
368
369        // Organize and report results
370        self.report_multi_target_results(&aggregated_results)?;
371
372        Ok(())
373    }
374
375    /// Report results for multi-target execution
376    fn report_multi_target_results(&self, results: &AggregatedResults) -> Result<()> {
377        // Print summary
378        TerminalReporter::print_multi_target_summary(results);
379
380        // Save aggregated summary if requested
381        if self.results_format == "aggregated" || self.results_format == "both" {
382            let summary_path = self.output.join("aggregated_summary.json");
383            let summary_json = serde_json::json!({
384                "total_targets": results.total_targets,
385                "successful_targets": results.successful_targets,
386                "failed_targets": results.failed_targets,
387                "aggregated_metrics": {
388                    "total_requests": results.aggregated_metrics.total_requests,
389                    "total_failed_requests": results.aggregated_metrics.total_failed_requests,
390                    "avg_duration_ms": results.aggregated_metrics.avg_duration_ms,
391                    "p95_duration_ms": results.aggregated_metrics.p95_duration_ms,
392                    "p99_duration_ms": results.aggregated_metrics.p99_duration_ms,
393                    "error_rate": results.aggregated_metrics.error_rate,
394                },
395                "target_results": results.target_results.iter().map(|r| {
396                    serde_json::json!({
397                        "target_url": r.target_url,
398                        "target_index": r.target_index,
399                        "success": r.success,
400                        "error": r.error,
401                        "total_requests": r.results.total_requests,
402                        "failed_requests": r.results.failed_requests,
403                        "avg_duration_ms": r.results.avg_duration_ms,
404                        "p95_duration_ms": r.results.p95_duration_ms,
405                        "p99_duration_ms": r.results.p99_duration_ms,
406                        "output_dir": r.output_dir.to_string_lossy(),
407                    })
408                }).collect::<Vec<_>>(),
409            });
410
411            std::fs::write(&summary_path, serde_json::to_string_pretty(&summary_json)?)?;
412            TerminalReporter::print_success(&format!(
413                "Aggregated summary saved to: {}",
414                summary_path.display()
415            ));
416        }
417
418        println!("\nResults saved to: {}", self.output.display());
419        println!("  - Per-target results: {}", self.output.join("target_*").display());
420        if self.results_format == "aggregated" || self.results_format == "both" {
421            println!(
422                "  - Aggregated summary: {}",
423                self.output.join("aggregated_summary.json").display()
424            );
425        }
426
427        Ok(())
428    }
429
430    /// Parse duration string (e.g., "30s", "5m", "1h") to seconds
431    pub fn parse_duration(duration: &str) -> Result<u64> {
432        let duration = duration.trim();
433
434        if let Some(secs) = duration.strip_suffix('s') {
435            secs.parse::<u64>()
436                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
437        } else if let Some(mins) = duration.strip_suffix('m') {
438            mins.parse::<u64>()
439                .map(|m| m * 60)
440                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
441        } else if let Some(hours) = duration.strip_suffix('h') {
442            hours
443                .parse::<u64>()
444                .map(|h| h * 3600)
445                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
446        } else {
447            // Try parsing as seconds without suffix
448            duration
449                .parse::<u64>()
450                .map_err(|_| BenchError::Other(format!("Invalid duration: {}", duration)))
451        }
452    }
453
454    /// Parse headers from command line format (Key:Value,Key2:Value2)
455    pub fn parse_headers(&self) -> Result<HashMap<String, String>> {
456        let mut headers = HashMap::new();
457
458        if let Some(header_str) = &self.headers {
459            for pair in header_str.split(',') {
460                let parts: Vec<&str> = pair.splitn(2, ':').collect();
461                if parts.len() != 2 {
462                    return Err(BenchError::Other(format!(
463                        "Invalid header format: '{}'. Expected 'Key:Value'",
464                        pair
465                    )));
466                }
467                headers.insert(parts[0].trim().to_string(), parts[1].trim().to_string());
468            }
469        }
470
471        Ok(headers)
472    }
473
474    /// Build mock server integration configuration
475    async fn build_mock_config(&self) -> MockIntegrationConfig {
476        // Check if target looks like a mock server
477        if MockServerDetector::looks_like_mock_server(&self.target) {
478            // Try to detect if it's actually a MockForge server
479            if let Ok(info) = MockServerDetector::detect(&self.target).await {
480                if info.is_mockforge {
481                    TerminalReporter::print_success(&format!(
482                        "Detected MockForge server (version: {})",
483                        info.version.as_deref().unwrap_or("unknown")
484                    ));
485                    return MockIntegrationConfig::mock_server();
486                }
487            }
488        }
489        MockIntegrationConfig::real_api()
490    }
491
492    /// Build CRUD flow configuration
493    fn build_crud_flow_config(&self) -> Option<CrudFlowConfig> {
494        if !self.crud_flow {
495            return None;
496        }
497
498        // If flow_config file is provided, load it
499        if let Some(config_path) = &self.flow_config {
500            match CrudFlowConfig::from_file(config_path) {
501                Ok(config) => return Some(config),
502                Err(e) => {
503                    TerminalReporter::print_warning(&format!(
504                        "Failed to load flow config: {}. Using auto-detection.",
505                        e
506                    ));
507                }
508            }
509        }
510
511        // Parse extract fields
512        let extract_fields = self
513            .extract_fields
514            .as_ref()
515            .map(|f| f.split(',').map(|s| s.trim().to_string()).collect())
516            .unwrap_or_else(|| vec!["id".to_string(), "uuid".to_string()]);
517
518        Some(CrudFlowConfig {
519            flows: Vec::new(), // Will be auto-detected
520            default_extract_fields: extract_fields,
521        })
522    }
523
524    /// Build data-driven testing configuration
525    fn build_data_driven_config(&self) -> Option<DataDrivenConfig> {
526        let data_file = self.data_file.as_ref()?;
527
528        let distribution = DataDistribution::from_str(&self.data_distribution)
529            .unwrap_or(DataDistribution::UniquePerVu);
530
531        let mappings = self
532            .data_mappings
533            .as_ref()
534            .map(|m| DataMapping::parse_mappings(m).unwrap_or_default())
535            .unwrap_or_default();
536
537        Some(DataDrivenConfig {
538            file_path: data_file.to_string_lossy().to_string(),
539            distribution,
540            mappings,
541            csv_has_header: true,
542        })
543    }
544
545    /// Build invalid data testing configuration
546    fn build_invalid_data_config(&self) -> Option<InvalidDataConfig> {
547        let error_rate = self.error_rate?;
548
549        let error_types = self
550            .error_types
551            .as_ref()
552            .map(|types| InvalidDataConfig::parse_error_types(types).unwrap_or_default())
553            .unwrap_or_default();
554
555        Some(InvalidDataConfig {
556            error_rate,
557            error_types,
558            target_fields: Vec::new(),
559        })
560    }
561
562    /// Build security testing configuration
563    fn build_security_config(&self) -> Option<SecurityTestConfig> {
564        if !self.security_test {
565            return None;
566        }
567
568        let categories = self
569            .security_categories
570            .as_ref()
571            .map(|cats| SecurityTestConfig::parse_categories(cats).unwrap_or_default())
572            .unwrap_or_else(|| {
573                let mut default = std::collections::HashSet::new();
574                default.insert(SecurityCategory::SqlInjection);
575                default.insert(SecurityCategory::Xss);
576                default
577            });
578
579        let target_fields = self
580            .security_target_fields
581            .as_ref()
582            .map(|fields| fields.split(',').map(|f| f.trim().to_string()).collect())
583            .unwrap_or_default();
584
585        let custom_payloads_file =
586            self.security_payloads.as_ref().map(|p| p.to_string_lossy().to_string());
587
588        Some(SecurityTestConfig {
589            enabled: true,
590            categories,
591            target_fields,
592            custom_payloads_file,
593            include_high_risk: false,
594        })
595    }
596
597    /// Build parallel execution configuration
598    fn build_parallel_config(&self) -> Option<ParallelConfig> {
599        let count = self.parallel_create?;
600
601        Some(ParallelConfig::new(count))
602    }
603
604    /// Generate enhanced k6 script with advanced features
605    fn generate_enhanced_script(&self, base_script: &str) -> Result<String> {
606        let mut enhanced_script = base_script.to_string();
607        let mut additional_code = String::new();
608
609        // Add data-driven testing code
610        if let Some(config) = self.build_data_driven_config() {
611            TerminalReporter::print_progress("Adding data-driven testing support...");
612            additional_code.push_str(&DataDrivenGenerator::generate_setup(&config));
613            additional_code.push('\n');
614            TerminalReporter::print_success("Data-driven testing enabled");
615        }
616
617        // Add invalid data generation code
618        if let Some(config) = self.build_invalid_data_config() {
619            TerminalReporter::print_progress("Adding invalid data testing support...");
620            additional_code.push_str(&InvalidDataGenerator::generate_invalidation_logic());
621            additional_code.push('\n');
622            additional_code
623                .push_str(&InvalidDataGenerator::generate_should_invalidate(config.error_rate));
624            additional_code.push('\n');
625            additional_code
626                .push_str(&InvalidDataGenerator::generate_type_selection(&config.error_types));
627            additional_code.push('\n');
628            TerminalReporter::print_success(&format!(
629                "Invalid data testing enabled ({}% error rate)",
630                (self.error_rate.unwrap_or(0.0) * 100.0) as u32
631            ));
632        }
633
634        // Add security testing code
635        if let Some(config) = self.build_security_config() {
636            TerminalReporter::print_progress("Adding security testing support...");
637            let payload_list = SecurityPayloads::get_payloads(&config);
638            additional_code
639                .push_str(&SecurityTestGenerator::generate_payload_selection(&payload_list));
640            additional_code.push('\n');
641            additional_code
642                .push_str(&SecurityTestGenerator::generate_apply_payload(&config.target_fields));
643            additional_code.push('\n');
644            additional_code.push_str(&SecurityTestGenerator::generate_security_checks());
645            additional_code.push('\n');
646            TerminalReporter::print_success(&format!(
647                "Security testing enabled ({} payloads)",
648                payload_list.len()
649            ));
650        }
651
652        // Add parallel execution code
653        if let Some(config) = self.build_parallel_config() {
654            TerminalReporter::print_progress("Adding parallel execution support...");
655            additional_code.push_str(&ParallelRequestGenerator::generate_batch_helper(&config));
656            additional_code.push('\n');
657            TerminalReporter::print_success(&format!(
658                "Parallel execution enabled (count: {})",
659                config.count
660            ));
661        }
662
663        // Insert additional code after the imports section
664        if !additional_code.is_empty() {
665            // Find the end of the import section
666            if let Some(import_end) = enhanced_script.find("export const options") {
667                enhanced_script.insert_str(
668                    import_end,
669                    &format!("\n// === Advanced Testing Features ===\n{}\n", additional_code),
670                );
671            }
672        }
673
674        Ok(enhanced_script)
675    }
676
677    /// Execute CRUD flow testing mode
678    async fn execute_crud_flow(&self, parser: &SpecParser) -> Result<()> {
679        TerminalReporter::print_progress("Detecting CRUD operations...");
680
681        let operations = parser.get_operations();
682        let flows = CrudFlowDetector::detect_flows(&operations);
683
684        if flows.is_empty() {
685            return Err(BenchError::Other(
686                "No CRUD flows detected in spec. Ensure spec has POST/GET/PUT/DELETE operations on related paths.".to_string(),
687            ));
688        }
689
690        TerminalReporter::print_success(&format!("Detected {} CRUD flow(s)", flows.len()));
691
692        for flow in &flows {
693            TerminalReporter::print_progress(&format!(
694                "  - {}: {} steps",
695                flow.name,
696                flow.steps.len()
697            ));
698        }
699
700        // Generate CRUD flow script
701        let handlebars = handlebars::Handlebars::new();
702        let template = include_str!("templates/k6_crud_flow.hbs");
703
704        let custom_headers = self.parse_headers()?;
705        let config = self.build_crud_flow_config().unwrap_or_default();
706
707        let data = serde_json::json!({
708            "base_url": self.target,
709            "flows": flows.iter().map(|f| {
710                // Sanitize flow name for use as JavaScript variable and k6 metric names
711                let sanitized_name = K6ScriptGenerator::sanitize_js_identifier(&f.name);
712                serde_json::json!({
713                    "name": sanitized_name.clone(),  // Use sanitized name for variable names
714                    "display_name": f.name,          // Keep original for comments/display
715                    "base_path": f.base_path,
716                    "steps": f.steps.iter().map(|s| {
717                        serde_json::json!({
718                            "operation": s.operation,
719                            "extract": s.extract,
720                            "use_values": s.use_values,
721                            "description": s.description,
722                        })
723                    }).collect::<Vec<_>>(),
724                })
725            }).collect::<Vec<_>>(),
726            "extract_fields": config.default_extract_fields,
727            "duration_secs": Self::parse_duration(&self.duration)?,
728            "max_vus": self.vus,
729            "auth_header": self.auth,
730            "custom_headers": custom_headers,
731            "skip_tls_verify": self.skip_tls_verify,
732        });
733
734        let script = handlebars
735            .render_template(template, &data)
736            .map_err(|e| BenchError::ScriptGenerationFailed(e.to_string()))?;
737
738        // Validate the generated CRUD flow script
739        TerminalReporter::print_progress("Validating CRUD flow script...");
740        let validation_errors = K6ScriptGenerator::validate_script(&script);
741        if !validation_errors.is_empty() {
742            TerminalReporter::print_error("CRUD flow script validation failed");
743            for error in &validation_errors {
744                eprintln!("  {}", error);
745            }
746            return Err(BenchError::Other(format!(
747                "CRUD flow script validation failed with {} error(s)",
748                validation_errors.len()
749            )));
750        }
751
752        TerminalReporter::print_success("CRUD flow script generated");
753
754        // Write and execute script
755        let script_path = if let Some(output) = &self.script_output {
756            output.clone()
757        } else {
758            self.output.join("k6-crud-flow-script.js")
759        };
760
761        std::fs::create_dir_all(script_path.parent().unwrap())?;
762        std::fs::write(&script_path, &script)?;
763        TerminalReporter::print_success(&format!("Script written to: {}", script_path.display()));
764
765        if self.generate_only {
766            println!("\nScript generated successfully. Run it with:");
767            println!("  k6 run {}", script_path.display());
768            return Ok(());
769        }
770
771        // Execute k6
772        TerminalReporter::print_progress("Executing CRUD flow test...");
773        let executor = K6Executor::new()?;
774        std::fs::create_dir_all(&self.output)?;
775
776        let results = executor.execute(&script_path, Some(&self.output), self.verbose).await?;
777
778        let duration_secs = Self::parse_duration(&self.duration)?;
779        TerminalReporter::print_summary(&results, duration_secs);
780
781        Ok(())
782    }
783}
784
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a command whose only interesting setting is the raw `--headers` value;
    /// every other field is a fixed, inert default for these parser tests.
    fn command_with_headers(raw_headers: &str) -> BenchCommand {
        BenchCommand {
            spec: PathBuf::from("test.yaml"),
            target: "http://localhost".to_string(),
            duration: "1m".to_string(),
            vus: 10,
            scenario: "ramp-up".to_string(),
            operations: None,
            exclude_operations: None,
            auth: None,
            headers: Some(raw_headers.to_string()),
            output: PathBuf::from("output"),
            generate_only: false,
            script_output: None,
            threshold_percentile: "p(95)".to_string(),
            threshold_ms: 500,
            max_error_rate: 0.05,
            verbose: false,
            skip_tls_verify: false,
            targets_file: None,
            max_concurrency: None,
            results_format: "both".to_string(),
            params_file: None,
            crud_flow: false,
            flow_config: None,
            extract_fields: None,
            parallel_create: None,
            data_file: None,
            data_distribution: "unique-per-vu".to_string(),
            data_mappings: None,
            error_rate: None,
            error_types: None,
            security_test: false,
            security_payloads: None,
            security_categories: None,
            security_target_fields: None,
        }
    }

    /// Seconds, minutes and hours suffixes are converted to seconds; a bare
    /// number is taken as seconds already.
    #[test]
    fn test_parse_duration() {
        let cases = [("30s", 30), ("5m", 300), ("1h", 3600), ("60", 60)];
        for (input, expected_secs) in cases {
            assert_eq!(BenchCommand::parse_duration(input).unwrap(), expected_secs);
        }
    }

    /// Non-numeric input and unknown unit suffixes are rejected.
    #[test]
    fn test_parse_duration_invalid() {
        for bad_input in ["invalid", "30x"] {
            assert!(BenchCommand::parse_duration(bad_input).is_err());
        }
    }

    /// A comma-separated `Name:value` list yields one map entry per header.
    #[test]
    fn test_parse_headers() {
        let cmd = command_with_headers("X-API-Key:test123,X-Client-ID:client456");

        let parsed = cmd.parse_headers().unwrap();
        assert_eq!(parsed.get("X-API-Key").map(String::as_str), Some("test123"));
        assert_eq!(parsed.get("X-Client-ID").map(String::as_str), Some("client456"));
    }
}