Skip to main content

mockforge_bench/
parallel_executor.rs

1//! Parallel execution engine for multi-target bench testing
2//!
3//! Executes load tests against multiple targets in parallel with configurable
4//! concurrency limits. Uses tokio for async execution and semaphores for
5//! backpressure control.
6
7use crate::command::BenchCommand;
8use crate::error::{BenchError, Result};
9use crate::executor::{K6Executor, K6Results};
10use crate::k6_gen::{K6Config, K6ScriptGenerator};
11use crate::reporter::TerminalReporter;
12use crate::request_gen::RequestGenerator;
13use crate::scenarios::LoadScenario;
14use crate::spec_parser::SpecParser;
15use crate::target_parser::TargetConfig;
16use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
17use std::collections::HashMap;
18use std::path::{Path, PathBuf};
19use std::str::FromStr;
20use std::sync::Arc;
21use tokio::sync::Semaphore;
22use tokio::task::JoinHandle;
23
24/// Result for a single target execution
25#[derive(Debug, Clone)]
26pub struct TargetResult {
27    /// Target URL that was tested
28    pub target_url: String,
29    /// Index of the target (for ordering)
30    pub target_index: usize,
31    /// k6 test results
32    pub results: K6Results,
33    /// Output directory for this target
34    pub output_dir: PathBuf,
35    /// Whether the test succeeded
36    pub success: bool,
37    /// Error message if test failed
38    pub error: Option<String>,
39}
40
41/// Aggregated results from all target executions
42#[derive(Debug, Clone)]
43pub struct AggregatedResults {
44    /// Results for each target
45    pub target_results: Vec<TargetResult>,
46    /// Overall statistics
47    pub total_targets: usize,
48    pub successful_targets: usize,
49    pub failed_targets: usize,
50    /// Aggregated metrics across all targets
51    pub aggregated_metrics: AggregatedMetrics,
52}
53
/// Metrics combined across the targets whose run succeeded.
///
/// Failed targets are excluded from every field. Latency figures are derived
/// from the per-target summaries k6 reports, not from raw request samples, so
/// the percentile fields are percentiles *of per-target percentiles*.
#[derive(Debug, Clone)]
pub struct AggregatedMetrics {
    /// Sum of `total_requests` over successful targets.
    pub total_requests: u64,
    /// Sum of `failed_requests` over successful targets.
    pub total_failed_requests: u64,
    /// Unweighted mean of the per-target average response times (ms).
    pub avg_duration_ms: f64,
    /// Nearest-rank 95th percentile of the per-target p95 response times (ms).
    pub p95_duration_ms: f64,
    /// Nearest-rank 99th percentile of the per-target p99 response times (ms).
    pub p99_duration_ms: f64,
    /// `total_failed_requests / total_requests * 100`, or `0.0` when there were no requests.
    pub error_rate: f64,
}
70
71impl AggregatedMetrics {
72    /// Calculate aggregated metrics from target results
73    fn from_results(results: &[TargetResult]) -> Self {
74        let mut total_requests = 0u64;
75        let mut total_failed_requests = 0u64;
76        let mut durations = Vec::new();
77        let mut p95_values = Vec::new();
78        let mut p99_values = Vec::new();
79
80        for result in results {
81            if result.success {
82                total_requests += result.results.total_requests;
83                total_failed_requests += result.results.failed_requests;
84                durations.push(result.results.avg_duration_ms);
85                p95_values.push(result.results.p95_duration_ms);
86                p99_values.push(result.results.p99_duration_ms);
87            }
88        }
89
90        let avg_duration_ms = if !durations.is_empty() {
91            durations.iter().sum::<f64>() / durations.len() as f64
92        } else {
93            0.0
94        };
95
96        let p95_duration_ms = if !p95_values.is_empty() {
97            p95_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
98            let index = (p95_values.len() as f64 * 0.95).ceil() as usize - 1;
99            p95_values[index.min(p95_values.len() - 1)]
100        } else {
101            0.0
102        };
103
104        let p99_duration_ms = if !p99_values.is_empty() {
105            p99_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
106            let index = (p99_values.len() as f64 * 0.99).ceil() as usize - 1;
107            p99_values[index.min(p99_values.len() - 1)]
108        } else {
109            0.0
110        };
111
112        let error_rate = if total_requests > 0 {
113            (total_failed_requests as f64 / total_requests as f64) * 100.0
114        } else {
115            0.0
116        };
117
118        Self {
119            total_requests,
120            total_failed_requests,
121            avg_duration_ms,
122            p95_duration_ms,
123            p99_duration_ms,
124            error_rate,
125        }
126    }
127}
128
129/// Parallel executor for multi-target bench testing
130pub struct ParallelExecutor {
131    /// Base command configuration (shared across all targets)
132    base_command: BenchCommand,
133    /// List of targets to test
134    targets: Vec<TargetConfig>,
135    /// Maximum number of concurrent executions
136    max_concurrency: usize,
137    /// Base output directory
138    base_output: PathBuf,
139}
140
141impl ParallelExecutor {
142    /// Create a new parallel executor
143    pub fn new(
144        base_command: BenchCommand,
145        targets: Vec<TargetConfig>,
146        max_concurrency: usize,
147    ) -> Self {
148        let base_output = base_command.output.clone();
149        Self {
150            base_command,
151            targets,
152            max_concurrency,
153            base_output,
154        }
155    }
156
157    /// Execute tests against all targets in parallel
158    pub async fn execute_all(&self) -> Result<AggregatedResults> {
159        let total_targets = self.targets.len();
160        TerminalReporter::print_progress(&format!(
161            "Starting parallel execution for {} targets (max concurrency: {})",
162            total_targets, self.max_concurrency
163        ));
164
165        // Validate k6 installation
166        if !K6Executor::is_k6_installed() {
167            TerminalReporter::print_error("k6 is not installed");
168            TerminalReporter::print_warning(
169                "Install k6 from: https://k6.io/docs/get-started/installation/",
170            );
171            return Err(BenchError::K6NotFound);
172        }
173
174        // Load and parse spec(s) (shared across all targets)
175        TerminalReporter::print_progress("Loading OpenAPI specification(s)...");
176        let merged_spec = self.base_command.load_and_merge_specs().await?;
177        let parser = SpecParser::from_spec(merged_spec);
178        TerminalReporter::print_success("Specification(s) loaded");
179
180        // Get operations
181        let operations = if let Some(filter) = &self.base_command.operations {
182            parser.filter_operations(filter)?
183        } else {
184            parser.get_operations()
185        };
186
187        if operations.is_empty() {
188            return Err(BenchError::Other("No operations found in spec".to_string()));
189        }
190
191        TerminalReporter::print_success(&format!("Found {} operations", operations.len()));
192
193        // Generate request templates (shared across all targets)
194        TerminalReporter::print_progress("Generating request templates...");
195        let templates: Vec<_> = operations
196            .iter()
197            .map(RequestGenerator::generate_template)
198            .collect::<Result<Vec<_>>>()?;
199        TerminalReporter::print_success("Request templates generated");
200
201        // Parse base headers
202        let base_headers = self.base_command.parse_headers()?;
203
204        // Resolve base path (CLI option takes priority over spec's servers URL)
205        let base_path = self.resolve_base_path(&parser);
206        if let Some(ref bp) = base_path {
207            TerminalReporter::print_progress(&format!("Using base path: {}", bp));
208        }
209
210        // Parse scenario
211        let scenario = LoadScenario::from_str(&self.base_command.scenario)
212            .map_err(BenchError::InvalidScenario)?;
213
214        let duration_secs = BenchCommand::parse_duration(&self.base_command.duration)?;
215
216        // Compute security testing flag
217        let security_testing_enabled =
218            self.base_command.security_test || self.base_command.wafbench_dir.is_some();
219
220        // Pre-compute enhancement code once (same for all targets)
221        let has_advanced_features = self.base_command.data_file.is_some()
222            || self.base_command.error_rate.is_some()
223            || self.base_command.security_test
224            || self.base_command.parallel_create.is_some()
225            || self.base_command.wafbench_dir.is_some();
226
227        let enhancement_code = if has_advanced_features {
228            let dummy_script = "export const options = {};";
229            let enhanced = self.base_command.generate_enhanced_script(dummy_script)?;
230            if let Some(pos) = enhanced.find("export const options") {
231                enhanced[..pos].to_string()
232            } else {
233                String::new()
234            }
235        } else {
236            String::new()
237        };
238
239        // Create semaphore for concurrency control
240        let semaphore = Arc::new(Semaphore::new(self.max_concurrency));
241        let multi_progress = MultiProgress::new();
242
243        // Create progress bars for each target
244        let progress_bars: Vec<ProgressBar> = (0..total_targets)
245            .map(|i| {
246                let pb = multi_progress.add(ProgressBar::new(1));
247                pb.set_style(
248                    ProgressStyle::default_bar()
249                        .template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} {msg}")
250                        .unwrap(),
251                );
252                pb.set_message(format!("Target {}", i + 1));
253                pb
254            })
255            .collect();
256
257        // Spawn tasks for each target
258        let mut handles: Vec<JoinHandle<Result<TargetResult>>> = Vec::new();
259
260        for (index, target) in self.targets.iter().enumerate() {
261            let target = target.clone();
262            // Clone necessary fields from base_command instead of passing reference
263            let duration = self.base_command.duration.clone();
264            let vus = self.base_command.vus;
265            let scenario_str = self.base_command.scenario.clone();
266            let operations = self.base_command.operations.clone();
267            let auth = self.base_command.auth.clone();
268            let headers = self.base_command.headers.clone();
269            let threshold_percentile = self.base_command.threshold_percentile.clone();
270            let threshold_ms = self.base_command.threshold_ms;
271            let max_error_rate = self.base_command.max_error_rate;
272            let verbose = self.base_command.verbose;
273            let skip_tls_verify = self.base_command.skip_tls_verify;
274
275            let templates = templates.clone();
276            let base_headers = base_headers.clone();
277            let scenario = scenario.clone();
278            let duration_secs = duration_secs;
279            let base_output = self.base_output.clone();
280            let semaphore = semaphore.clone();
281            let progress_bar = progress_bars[index].clone();
282            let target_index = index;
283            let base_path = base_path.clone();
284            let security_testing_enabled = security_testing_enabled;
285            let enhancement_code = enhancement_code.clone();
286
287            let handle = tokio::spawn(async move {
288                // Acquire semaphore permit
289                let _permit = semaphore.acquire().await.map_err(|e| {
290                    BenchError::Other(format!("Failed to acquire semaphore: {}", e))
291                })?;
292
293                progress_bar.set_message(format!("Testing {}", target.url));
294
295                // Execute test for this target
296                let result = Self::execute_single_target_internal(
297                    &duration,
298                    vus,
299                    &scenario_str,
300                    &operations,
301                    &auth,
302                    &headers,
303                    &threshold_percentile,
304                    threshold_ms,
305                    max_error_rate,
306                    verbose,
307                    skip_tls_verify,
308                    base_path.as_ref(),
309                    &target,
310                    target_index,
311                    &templates,
312                    &base_headers,
313                    &scenario,
314                    duration_secs,
315                    &base_output,
316                    security_testing_enabled,
317                    &enhancement_code,
318                )
319                .await;
320
321                progress_bar.inc(1);
322                progress_bar.finish_with_message(format!("Completed {}", target.url));
323
324                result
325            });
326
327            handles.push(handle);
328        }
329
330        // Wait for all tasks to complete and collect results
331        let mut target_results = Vec::new();
332        for (index, handle) in handles.into_iter().enumerate() {
333            match handle.await {
334                Ok(Ok(result)) => {
335                    target_results.push(result);
336                }
337                Ok(Err(e)) => {
338                    // Create error result
339                    let target_url = self.targets[index].url.clone();
340                    target_results.push(TargetResult {
341                        target_url: target_url.clone(),
342                        target_index: index,
343                        results: K6Results::default(),
344                        output_dir: self.base_output.join(format!("target_{}", index + 1)),
345                        success: false,
346                        error: Some(e.to_string()),
347                    });
348                }
349                Err(e) => {
350                    // Join error
351                    let target_url = self.targets[index].url.clone();
352                    target_results.push(TargetResult {
353                        target_url: target_url.clone(),
354                        target_index: index,
355                        results: K6Results::default(),
356                        output_dir: self.base_output.join(format!("target_{}", index + 1)),
357                        success: false,
358                        error: Some(format!("Task join error: {}", e)),
359                    });
360                }
361            }
362        }
363
364        // Sort results by target index
365        target_results.sort_by_key(|r| r.target_index);
366
367        // Calculate aggregated metrics
368        let aggregated_metrics = AggregatedMetrics::from_results(&target_results);
369
370        let successful_targets = target_results.iter().filter(|r| r.success).count();
371        let failed_targets = total_targets - successful_targets;
372
373        Ok(AggregatedResults {
374            target_results,
375            total_targets,
376            successful_targets,
377            failed_targets,
378            aggregated_metrics,
379        })
380    }
381
382    /// Resolve the effective base path for API endpoints
383    fn resolve_base_path(&self, parser: &SpecParser) -> Option<String> {
384        // CLI option takes priority (including empty string to disable)
385        if let Some(cli_base_path) = &self.base_command.base_path {
386            if cli_base_path.is_empty() {
387                return None;
388            }
389            return Some(cli_base_path.clone());
390        }
391        // Fall back to spec's base path
392        parser.get_base_path()
393    }
394
395    /// Execute a single target test (internal method that doesn't require BenchCommand)
396    #[allow(clippy::too_many_arguments)]
397    async fn execute_single_target_internal(
398        duration: &str,
399        vus: u32,
400        scenario_str: &str,
401        operations: &Option<String>,
402        auth: &Option<String>,
403        headers: &Option<String>,
404        threshold_percentile: &str,
405        threshold_ms: u64,
406        max_error_rate: f64,
407        verbose: bool,
408        skip_tls_verify: bool,
409        base_path: Option<&String>,
410        target: &TargetConfig,
411        target_index: usize,
412        templates: &[crate::request_gen::RequestTemplate],
413        base_headers: &HashMap<String, String>,
414        scenario: &LoadScenario,
415        duration_secs: u64,
416        base_output: &Path,
417        security_testing_enabled: bool,
418        enhancement_code: &str,
419    ) -> Result<TargetResult> {
420        // Merge target-specific headers with base headers
421        let mut custom_headers = base_headers.clone();
422        if let Some(target_headers) = &target.headers {
423            custom_headers.extend(target_headers.clone());
424        }
425
426        // Use target-specific auth if provided, otherwise use base auth
427        let auth_header = target.auth.as_ref().or(auth.as_ref()).cloned();
428
429        // Create k6 config for this target
430        let k6_config = K6Config {
431            target_url: target.url.clone(),
432            base_path: base_path.cloned(),
433            scenario: scenario.clone(),
434            duration_secs,
435            max_vus: vus,
436            threshold_percentile: threshold_percentile.to_string(),
437            threshold_ms,
438            max_error_rate,
439            auth_header,
440            custom_headers,
441            skip_tls_verify,
442            security_testing_enabled,
443        };
444
445        // Generate k6 script
446        let generator = K6ScriptGenerator::new(k6_config, templates.to_vec());
447        let mut script = generator.generate()?;
448
449        // Apply pre-computed enhancement code (security definitions, etc.)
450        if !enhancement_code.is_empty() {
451            if let Some(pos) = script.find("export const options") {
452                script.insert_str(pos, enhancement_code);
453            }
454        }
455
456        // Validate script
457        let validation_errors = K6ScriptGenerator::validate_script(&script);
458        if !validation_errors.is_empty() {
459            return Err(BenchError::Other(format!(
460                "Script validation failed for target {}: {}",
461                target.url,
462                validation_errors.join(", ")
463            )));
464        }
465
466        // Create output directory for this target
467        let output_dir = base_output.join(format!("target_{}", target_index + 1));
468        std::fs::create_dir_all(&output_dir)?;
469
470        // Write script to file
471        let script_path = output_dir.join("k6-script.js");
472        std::fs::write(&script_path, script)?;
473
474        // Execute k6
475        let executor = K6Executor::new()?;
476        let results = executor.execute(&script_path, Some(&output_dir), verbose).await;
477
478        match results {
479            Ok(k6_results) => Ok(TargetResult {
480                target_url: target.url.clone(),
481                target_index,
482                results: k6_results,
483                output_dir,
484                success: true,
485                error: None,
486            }),
487            Err(e) => Ok(TargetResult {
488                target_url: target.url.clone(),
489                target_index,
490                results: K6Results::default(),
491                output_dir,
492                success: false,
493                error: Some(e.to_string()),
494            }),
495        }
496    }
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a successful `TargetResult` carrying the given k6 summary figures.
    fn successful_target(
        index: usize,
        total_requests: u64,
        failed_requests: u64,
        avg_duration_ms: f64,
        p95_duration_ms: f64,
        p99_duration_ms: f64,
    ) -> TargetResult {
        TargetResult {
            target_url: format!("http://api{}.com", index + 1),
            target_index: index,
            results: K6Results {
                total_requests,
                failed_requests,
                avg_duration_ms,
                p95_duration_ms,
                p99_duration_ms,
            },
            output_dir: PathBuf::from(format!("output{}", index + 1)),
            success: true,
            error: None,
        }
    }

    /// Build a failed `TargetResult` with default (empty) k6 results.
    fn failed_target(index: usize, error: &str) -> TargetResult {
        TargetResult {
            target_url: format!("http://api{}.com", index + 1),
            target_index: index,
            results: K6Results::default(),
            output_dir: PathBuf::from(format!("output{}", index + 1)),
            success: false,
            error: Some(error.to_string()),
        }
    }

    #[test]
    fn test_aggregated_metrics_from_results() {
        let results = vec![
            successful_target(0, 100, 5, 100.0, 200.0, 300.0),
            successful_target(1, 200, 10, 150.0, 250.0, 350.0),
        ];

        let metrics = AggregatedMetrics::from_results(&results);
        assert_eq!(metrics.total_requests, 300);
        assert_eq!(metrics.total_failed_requests, 15);
        assert_eq!(metrics.avg_duration_ms, 125.0); // (100 + 150) / 2, unweighted
    }

    #[test]
    fn test_aggregated_metrics_with_failed_targets() {
        let results = vec![
            successful_target(0, 100, 5, 100.0, 200.0, 300.0),
            failed_target(1, "Network error"),
        ];

        let metrics = AggregatedMetrics::from_results(&results);
        // Only the successful target is counted.
        assert_eq!(metrics.total_requests, 100);
        assert_eq!(metrics.total_failed_requests, 5);
        assert_eq!(metrics.avg_duration_ms, 100.0);
    }

    #[test]
    fn test_aggregated_metrics_all_targets_failed() {
        let results = vec![failed_target(0, "boom"), failed_target(1, "boom")];

        let metrics = AggregatedMetrics::from_results(&results);
        assert_eq!(metrics.total_requests, 0);
        assert_eq!(metrics.avg_duration_ms, 0.0);
        assert_eq!(metrics.p95_duration_ms, 0.0);
        assert_eq!(metrics.p99_duration_ms, 0.0);
        assert_eq!(metrics.error_rate, 0.0);
    }

    #[test]
    fn test_aggregated_metrics_empty_results() {
        let results: Vec<TargetResult> = Vec::new();
        let metrics = AggregatedMetrics::from_results(&results);
        assert_eq!(metrics.total_requests, 0);
        assert_eq!(metrics.total_failed_requests, 0);
        assert_eq!(metrics.avg_duration_ms, 0.0);
        assert_eq!(metrics.error_rate, 0.0);
    }

    #[test]
    fn test_aggregated_metrics_error_rate_calculation() {
        let results = vec![successful_target(0, 1000, 50, 100.0, 200.0, 300.0)];

        let metrics = AggregatedMetrics::from_results(&results);
        assert_eq!(metrics.error_rate, 5.0); // 50 / 1000 * 100
    }

    #[test]
    fn test_aggregated_metrics_p95_p99_calculation() {
        let results = vec![
            successful_target(0, 100, 0, 100.0, 150.0, 200.0),
            successful_target(1, 100, 0, 200.0, 250.0, 300.0),
            successful_target(2, 100, 0, 300.0, 350.0, 400.0),
        ];

        let metrics = AggregatedMetrics::from_results(&results);
        // Nearest rank ceil(3 * 0.95) = ceil(3 * 0.99) = 3 -> the largest value.
        assert_eq!(metrics.p95_duration_ms, 350.0);
        assert_eq!(metrics.p99_duration_ms, 400.0);
    }

    #[test]
    fn test_aggregated_metrics_percentiles_use_nearest_rank_on_unsorted_input() {
        // 20 targets whose p95/p99 values are 20, 19, ..., 1 (i.e. unsorted).
        let results: Vec<TargetResult> = (0..20)
            .map(|i| {
                let value = (20 - i) as f64;
                successful_target(i, 10, 0, 10.0, value, value)
            })
            .collect();

        let metrics = AggregatedMetrics::from_results(&results);
        // p95: rank ceil(20 * 0.95) = 19 -> 19th smallest value.
        assert_eq!(metrics.p95_duration_ms, 19.0);
        // p99: rank ceil(20 * 0.99) = 20 -> largest value.
        assert_eq!(metrics.p99_duration_ms, 20.0);
        assert_eq!(metrics.total_requests, 200);
    }
}