// quantrs2_tytan/testing_framework/framework.rs

1//! Core testing framework implementation.
2//!
3//! This module provides the main TestingFramework struct and its implementation
4//! for running tests, managing test suites, and generating reports.
5
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use std::thread;
9use std::time::{Duration, Instant};
10
11use crate::sampler::Sampler;
12
13use super::config::{ReportFormat, TestConfig};
14use super::generators::{
15    default_generators, FinanceTestGenerator, LogisticsTestGenerator, ManufacturingTestGenerator,
16};
17use super::reports;
18use super::results::{ConvergenceData, MemoryStats, PerformanceData, RuntimeStats, TestResults};
19use super::types::{
20    CIReport, CIStatus, Difficulty, FailureType, GeneratorConfig, ProblemType, RegressionIssue,
21    RegressionReport, SamplerComparison, TestCase, TestCategory, TestComparison, TestEnvironment,
22    TestFailure, TestGenerator, TestResult, TestSuite, ValidationResult, Validator,
23};
24use super::validators::default_validators;
25
/// Automated testing framework
///
/// Holds the configuration, the generated test suite, the accumulated
/// results, and the pluggable validators/generators used when running tests.
pub struct TestingFramework {
    /// Test configuration (samplers, problem sizes, seed, output format)
    pub config: TestConfig,
    /// Test suite: categories plus the generated test cases and benchmarks
    pub suite: TestSuite,
    /// Accumulated results of the most recent run
    pub results: TestResults,
    /// Validators applied to every test result; extend via `add_validator`
    validators: Vec<Box<dyn Validator>>,
    /// Test-case generators, matched by supported problem type; extend via `add_generator`
    generators: Vec<Box<dyn TestGenerator>>,
}
39
40impl TestingFramework {
41    /// Create new testing framework
42    pub fn new(config: TestConfig) -> Self {
43        Self {
44            config,
45            suite: TestSuite {
46                categories: Vec::new(),
47                test_cases: Vec::new(),
48                benchmarks: Vec::new(),
49            },
50            results: TestResults::default(),
51            validators: default_validators(),
52            generators: default_generators(),
53        }
54    }
55
56    /// Add test category
57    pub fn add_category(&mut self, category: TestCategory) {
58        self.suite.categories.push(category);
59    }
60
61    /// Add custom generator
62    pub fn add_generator(&mut self, generator: Box<dyn TestGenerator>) {
63        self.generators.push(generator);
64    }
65
66    /// Add custom validator
67    pub fn add_validator(&mut self, validator: Box<dyn Validator>) {
68        self.validators.push(validator);
69    }
70
71    /// Generate test suite
72    pub fn generate_suite(&mut self) -> Result<(), String> {
73        let start_time = Instant::now();
74
75        // Generate tests for each category
76        for category in &self.suite.categories {
77            for problem_type in &category.problem_types {
78                for difficulty in &category.difficulties {
79                    for size in &self.config.problem_sizes {
80                        let config = GeneratorConfig {
81                            problem_type: problem_type.clone(),
82                            size: *size,
83                            difficulty: difficulty.clone(),
84                            seed: self.config.seed,
85                            parameters: HashMap::new(),
86                        };
87
88                        // Find suitable generator
89                        for generator in &self.generators {
90                            if generator.supported_types().contains(problem_type) {
91                                let test_cases = generator.generate(&config)?;
92                                self.suite.test_cases.extend(test_cases);
93                                break;
94                            }
95                        }
96                    }
97                }
98            }
99        }
100
101        self.results.performance.runtime_stats.qubo_generation_time = start_time.elapsed();
102
103        Ok(())
104    }
105
106    /// Run test suite
107    pub fn run_suite<S: Sampler>(&mut self, sampler: &S) -> Result<(), String> {
108        let total_start = Instant::now();
109
110        let test_cases = self.suite.test_cases.clone();
111        for test_case in &test_cases {
112            let test_start = Instant::now();
113
114            // Run test with timeout
115            match self.run_single_test(test_case, sampler) {
116                Ok(result) => {
117                    self.results.test_results.push(result);
118                    self.results.summary.passed += 1;
119                }
120                Err(e) => {
121                    self.results.failures.push(TestFailure {
122                        test_id: test_case.id.clone(),
123                        failure_type: FailureType::SamplerError,
124                        message: e,
125                        stack_trace: None,
126                        context: HashMap::new(),
127                    });
128                    self.results.summary.failed += 1;
129                }
130            }
131
132            let test_time = test_start.elapsed();
133            self.results
134                .performance
135                .runtime_stats
136                .time_per_test
137                .push((test_case.id.clone(), test_time));
138
139            self.results.summary.total_tests += 1;
140        }
141
142        self.results.performance.runtime_stats.total_time = total_start.elapsed();
143        self.calculate_summary();
144
145        Ok(())
146    }
147
148    /// Run single test
149    fn run_single_test<S: Sampler>(
150        &mut self,
151        test_case: &TestCase,
152        sampler: &S,
153    ) -> Result<TestResult, String> {
154        let solve_start = Instant::now();
155
156        // Run sampler
157        let sample_result = sampler
158            .run_qubo(
159                &(test_case.qubo.clone(), test_case.var_map.clone()),
160                self.config.samplers[0].num_samples,
161            )
162            .map_err(|e| format!("Sampler error: {e:?}"))?;
163
164        let solve_time = solve_start.elapsed();
165
166        // Get best solution
167        let best_sample = sample_result
168            .iter()
169            .min_by(|a, b| {
170                a.energy
171                    .partial_cmp(&b.energy)
172                    .unwrap_or(std::cmp::Ordering::Equal)
173            })
174            .ok_or("No samples returned")?;
175
176        // Use the assignments directly (already decoded)
177        let solution = best_sample.assignments.clone();
178
179        // Validate
180        let validation_start = Instant::now();
181        let mut validation = ValidationResult {
182            is_valid: true,
183            checks: Vec::new(),
184            warnings: Vec::new(),
185        };
186
187        for validator in &self.validators {
188            let result = validator.validate(
189                test_case,
190                &TestResult {
191                    test_id: test_case.id.clone(),
192                    sampler: "test".to_string(),
193                    solution: solution.clone(),
194                    objective_value: best_sample.energy,
195                    constraints_satisfied: true,
196                    validation: validation.clone(),
197                    runtime: solve_time,
198                    metrics: HashMap::new(),
199                },
200            );
201
202            validation.checks.extend(result.checks);
203            validation.warnings.extend(result.warnings);
204            validation.is_valid &= result.is_valid;
205        }
206
207        let validation_time = validation_start.elapsed();
208        self.results.performance.runtime_stats.solving_time += solve_time;
209        self.results.performance.runtime_stats.validation_time += validation_time;
210
211        Ok(TestResult {
212            test_id: test_case.id.clone(),
213            sampler: self.config.samplers[0].name.clone(),
214            solution,
215            objective_value: best_sample.energy,
216            constraints_satisfied: validation.is_valid,
217            validation,
218            runtime: solve_time + validation_time,
219            metrics: HashMap::new(),
220        })
221    }
222
223    /// Calculate summary statistics
224    fn calculate_summary(&mut self) {
225        if self.results.test_results.is_empty() {
226            return;
227        }
228
229        // Success rate
230        self.results.summary.success_rate =
231            self.results.summary.passed as f64 / self.results.summary.total_tests as f64;
232
233        // Average runtime
234        let total_runtime: Duration = self.results.test_results.iter().map(|r| r.runtime).sum();
235        self.results.summary.avg_runtime = total_runtime / self.results.test_results.len() as u32;
236
237        // Quality metrics
238        let qualities: Vec<f64> = self
239            .results
240            .test_results
241            .iter()
242            .map(|r| r.objective_value)
243            .collect();
244
245        self.results.summary.quality_metrics.avg_quality =
246            qualities.iter().sum::<f64>() / qualities.len() as f64;
247
248        self.results.summary.quality_metrics.best_quality = *qualities
249            .iter()
250            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
251            .unwrap_or(&0.0);
252
253        self.results.summary.quality_metrics.worst_quality = *qualities
254            .iter()
255            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
256            .unwrap_or(&0.0);
257
258        // Standard deviation
259        let mean = self.results.summary.quality_metrics.avg_quality;
260        let variance =
261            qualities.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / qualities.len() as f64;
262        self.results.summary.quality_metrics.std_dev = variance.sqrt();
263
264        // Constraint satisfaction rate
265        let satisfied = self
266            .results
267            .test_results
268            .iter()
269            .filter(|r| r.constraints_satisfied)
270            .count();
271        self.results
272            .summary
273            .quality_metrics
274            .constraint_satisfaction_rate =
275            satisfied as f64 / self.results.test_results.len() as f64;
276    }
277
278    /// Generate report
279    pub fn generate_report(&self) -> Result<String, String> {
280        reports::generate_report(&self.config.output.format, &self.results, &self.suite)
281    }
282
283    /// Save report to file
284    pub fn save_report(&self, filename: &str) -> Result<(), String> {
285        let report = self.generate_report()?;
286        reports::save_report(&report, filename)
287    }
288
289    /// Run regression tests against baseline
290    pub fn run_regression_tests<S: Sampler>(
291        &mut self,
292        sampler: &S,
293        baseline_file: &str,
294    ) -> Result<RegressionReport, String> {
295        // Load baseline results
296        let baseline = self.load_baseline(baseline_file)?;
297
298        // Run current tests
299        self.run_suite(sampler)?;
300
301        // Compare with baseline
302        let mut regressions = Vec::new();
303        let mut improvements = Vec::new();
304
305        for current_result in &self.results.test_results {
306            if let Some(baseline_result) = baseline
307                .iter()
308                .find(|b| b.test_id == current_result.test_id)
309            {
310                let quality_change = (current_result.objective_value
311                    - baseline_result.objective_value)
312                    / baseline_result.objective_value.abs();
313                let runtime_change = (current_result.runtime.as_secs_f64()
314                    - baseline_result.runtime.as_secs_f64())
315                    / baseline_result.runtime.as_secs_f64();
316
317                if quality_change > 0.05 || runtime_change > 0.2 {
318                    regressions.push(RegressionIssue {
319                        test_id: current_result.test_id.clone(),
320                        metric: if quality_change > 0.05 {
321                            "quality".to_string()
322                        } else {
323                            "runtime".to_string()
324                        },
325                        baseline_value: if quality_change > 0.05 {
326                            baseline_result.objective_value
327                        } else {
328                            baseline_result.runtime.as_secs_f64()
329                        },
330                        current_value: if quality_change > 0.05 {
331                            current_result.objective_value
332                        } else {
333                            current_result.runtime.as_secs_f64()
334                        },
335                        change_percent: if quality_change > 0.05 {
336                            quality_change * 100.0
337                        } else {
338                            runtime_change * 100.0
339                        },
340                    });
341                } else if quality_change < -0.05 || runtime_change < -0.2 {
342                    improvements.push(RegressionIssue {
343                        test_id: current_result.test_id.clone(),
344                        metric: if quality_change < -0.05 {
345                            "quality".to_string()
346                        } else {
347                            "runtime".to_string()
348                        },
349                        baseline_value: if quality_change < -0.05 {
350                            baseline_result.objective_value
351                        } else {
352                            baseline_result.runtime.as_secs_f64()
353                        },
354                        current_value: if quality_change < -0.05 {
355                            current_result.objective_value
356                        } else {
357                            current_result.runtime.as_secs_f64()
358                        },
359                        change_percent: if quality_change < -0.05 {
360                            quality_change * 100.0
361                        } else {
362                            runtime_change * 100.0
363                        },
364                    });
365                }
366            }
367        }
368
369        Ok(RegressionReport {
370            regressions,
371            improvements,
372            baseline_tests: baseline.len(),
373            current_tests: self.results.test_results.len(),
374        })
375    }
376
    /// Load baseline results from file.
    ///
    /// NOTE(review): placeholder — always returns an empty baseline and
    /// ignores `_filename`; a real implementation would deserialize results
    /// previously saved as JSON/CSV.
    const fn load_baseline(&self, _filename: &str) -> Result<Vec<TestResult>, String> {
        // Simplified implementation - in practice would load from JSON/CSV
        Ok(Vec::new())
    }
382
383    /// Run test suite in parallel
384    pub fn run_suite_parallel<S: Sampler + Clone + Send + Sync + 'static>(
385        &mut self,
386        sampler: &S,
387        num_threads: usize,
388    ) -> Result<(), String> {
389        let test_cases = Arc::new(self.suite.test_cases.clone());
390        let results = Arc::new(Mutex::new(Vec::new()));
391        let failures = Arc::new(Mutex::new(Vec::new()));
392
393        let total_start = Instant::now();
394        let chunk_size = test_cases.len().div_ceil(num_threads);
395
396        let mut handles = Vec::new();
397
398        for thread_id in 0..num_threads {
399            let start_idx = thread_id * chunk_size;
400            let end_idx = ((thread_id + 1) * chunk_size).min(test_cases.len());
401
402            if start_idx >= test_cases.len() {
403                break;
404            }
405
406            let test_cases_clone = Arc::clone(&test_cases);
407            let results_clone = Arc::clone(&results);
408            let failures_clone = Arc::clone(&failures);
409            let sampler_clone = sampler.clone();
410
411            let handle = thread::spawn(move || {
412                for idx in start_idx..end_idx {
413                    let test_case = &test_cases_clone[idx];
414
415                    match Self::run_single_test_static(test_case, &sampler_clone) {
416                        Ok(result) => {
417                            if let Ok(mut guard) = results_clone.lock() {
418                                guard.push(result);
419                            }
420                        }
421                        Err(e) => {
422                            if let Ok(mut guard) = failures_clone.lock() {
423                                guard.push(TestFailure {
424                                    test_id: test_case.id.clone(),
425                                    failure_type: FailureType::SamplerError,
426                                    message: e,
427                                    stack_trace: None,
428                                    context: HashMap::new(),
429                                });
430                            }
431                        }
432                    }
433                }
434            });
435
436            handles.push(handle);
437        }
438
439        // Wait for all threads to complete
440        for handle in handles {
441            handle.join().map_err(|_| "Thread panic")?;
442        }
443
444        // Collect results
445        self.results.test_results = results
446            .lock()
447            .map(|guard| guard.clone())
448            .unwrap_or_default();
449        self.results.failures = failures
450            .lock()
451            .map(|guard| guard.clone())
452            .unwrap_or_default();
453
454        self.results.performance.runtime_stats.total_time = total_start.elapsed();
455        self.results.summary.passed = self.results.test_results.len();
456        self.results.summary.failed = self.results.failures.len();
457        self.results.summary.total_tests =
458            self.results.summary.passed + self.results.summary.failed;
459
460        self.calculate_summary();
461
462        Ok(())
463    }
464
465    /// Static version of run_single_test for parallel execution
466    fn run_single_test_static<S: Sampler>(
467        test_case: &TestCase,
468        sampler: &S,
469    ) -> Result<TestResult, String> {
470        let solve_start = Instant::now();
471
472        // Run sampler
473        let sample_result = sampler
474            .run_qubo(&(test_case.qubo.clone(), test_case.var_map.clone()), 100)
475            .map_err(|e| format!("Sampler error: {e:?}"))?;
476
477        let solve_time = solve_start.elapsed();
478
479        // Get best solution
480        let best_sample = sample_result
481            .iter()
482            .min_by(|a, b| {
483                a.energy
484                    .partial_cmp(&b.energy)
485                    .unwrap_or(std::cmp::Ordering::Equal)
486            })
487            .ok_or("No samples returned")?;
488
489        let solution = best_sample.assignments.clone();
490
491        Ok(TestResult {
492            test_id: test_case.id.clone(),
493            sampler: "parallel".to_string(),
494            solution,
495            objective_value: best_sample.energy,
496            constraints_satisfied: true,
497            validation: ValidationResult {
498                is_valid: true,
499                checks: Vec::new(),
500                warnings: Vec::new(),
501            },
502            runtime: solve_time,
503            metrics: HashMap::new(),
504        })
505    }
506
507    /// Generate CI/CD report
508    pub fn generate_ci_report(&self) -> Result<CIReport, String> {
509        let passed_rate = if self.results.summary.total_tests > 0 {
510            self.results.summary.passed as f64 / self.results.summary.total_tests as f64
511        } else {
512            0.0
513        };
514
515        let status = if passed_rate >= 0.95 {
516            CIStatus::Pass
517        } else if passed_rate >= 0.8 {
518            CIStatus::Warning
519        } else {
520            CIStatus::Fail
521        };
522
523        Ok(CIReport {
524            status,
525            passed_rate,
526            total_tests: self.results.summary.total_tests,
527            failed_tests: self.results.summary.failed,
528            critical_failures: self
529                .results
530                .failures
531                .iter()
532                .filter(|f| {
533                    matches!(
534                        f.failure_type,
535                        FailureType::Timeout | FailureType::SamplerError
536                    )
537                })
538                .count(),
539            avg_runtime: self.results.summary.avg_runtime,
540            quality_score: self.calculate_quality_score(),
541        })
542    }
543
544    /// Calculate overall quality score
545    fn calculate_quality_score(&self) -> f64 {
546        if self.results.test_results.is_empty() {
547            return 0.0;
548        }
549
550        let constraint_score = self
551            .results
552            .summary
553            .quality_metrics
554            .constraint_satisfaction_rate;
555        let success_score = self.results.summary.success_rate;
556        let quality_score = if self
557            .results
558            .summary
559            .quality_metrics
560            .best_quality
561            .is_finite()
562        {
563            0.8 // Base score for having finite solutions
564        } else {
565            0.0
566        };
567
568        (constraint_score.mul_add(0.4, success_score * 0.4) + quality_score * 0.2) * 100.0
569    }
570
571    /// Add stress test cases
572    pub fn add_stress_tests(&mut self) {
573        let stress_categories = vec![
574            TestCategory {
575                name: "Large Scale Tests".to_string(),
576                description: "Tests with large problem sizes".to_string(),
577                problem_types: vec![ProblemType::MaxCut, ProblemType::TSP],
578                difficulties: vec![Difficulty::Extreme],
579                tags: vec!["stress".to_string(), "large".to_string()],
580            },
581            TestCategory {
582                name: "Memory Stress Tests".to_string(),
583                description: "Tests designed to stress memory usage".to_string(),
584                problem_types: vec![ProblemType::Knapsack],
585                difficulties: vec![Difficulty::VeryHard, Difficulty::Extreme],
586                tags: vec!["stress".to_string(), "memory".to_string()],
587            },
588            TestCategory {
589                name: "Runtime Stress Tests".to_string(),
590                description: "Tests with challenging runtime requirements".to_string(),
591                problem_types: vec![ProblemType::GraphColoring],
592                difficulties: vec![Difficulty::Extreme],
593                tags: vec!["stress".to_string(), "runtime".to_string()],
594            },
595        ];
596
597        for category in stress_categories {
598            self.suite.categories.push(category);
599        }
600    }
601
602    /// Detect test environment
603    pub fn detect_environment(&self) -> TestEnvironment {
604        TestEnvironment {
605            os: std::env::consts::OS.to_string(),
606            cpu_model: "Unknown".to_string(), // Would need OS-specific detection
607            memory_gb: 8.0,                   // Simplified - would need system detection
608            gpu_info: None,
609            rust_version: std::env::var("RUSTC_VERSION").unwrap_or_else(|_| "unknown".to_string()),
610            compile_flags: vec!["--release".to_string()],
611        }
612    }
613
614    /// Export test results for external analysis
615    pub fn export_results(&self, format: &str) -> Result<String, String> {
616        match format {
617            "csv" => reports::export_csv(&self.results, &self.suite),
618            "json" => reports::generate_json_report(&self.results),
619            "xml" => reports::export_xml(&self.results),
620            _ => Err(format!("Unsupported export format: {format}")),
621        }
622    }
623
624    /// Add industry-specific test generators
625    pub fn add_industry_generators(&mut self) {
626        // Add finance test generator
627        self.generators.push(Box::new(FinanceTestGenerator));
628
629        // Add logistics test generator
630        self.generators.push(Box::new(LogisticsTestGenerator));
631
632        // Add manufacturing test generator
633        self.generators.push(Box::new(ManufacturingTestGenerator));
634    }
635
636    /// Generate performance comparison report
637    pub fn compare_samplers<S1: Sampler, S2: Sampler>(
638        &mut self,
639        sampler1: &S1,
640        sampler2: &S2,
641        sampler1_name: &str,
642        sampler2_name: &str,
643    ) -> Result<SamplerComparison, String> {
644        // Run tests with first sampler
645        self.run_suite(sampler1)?;
646        let results1 = self.results.test_results.clone();
647
648        // Clear results and run with second sampler
649        self.results.test_results.clear();
650        self.run_suite(sampler2)?;
651        let results2 = self.results.test_results.clone();
652
653        // Compare results
654        let mut comparisons = Vec::new();
655
656        for r1 in &results1 {
657            if let Some(r2) = results2.iter().find(|r| r.test_id == r1.test_id) {
658                let quality_diff = r2.objective_value - r1.objective_value;
659                let runtime_ratio = r2.runtime.as_secs_f64() / r1.runtime.as_secs_f64();
660
661                comparisons.push(TestComparison {
662                    test_id: r1.test_id.clone(),
663                    sampler1_quality: r1.objective_value,
664                    sampler2_quality: r2.objective_value,
665                    quality_improvement: -quality_diff, // Negative because lower is better
666                    sampler1_runtime: r1.runtime,
667                    sampler2_runtime: r2.runtime,
668                    runtime_ratio,
669                });
670            }
671        }
672
673        let avg_quality_improvement = comparisons
674            .iter()
675            .map(|c| c.quality_improvement)
676            .sum::<f64>()
677            / comparisons.len() as f64;
678        let avg_runtime_ratio =
679            comparisons.iter().map(|c| c.runtime_ratio).sum::<f64>() / comparisons.len() as f64;
680
681        Ok(SamplerComparison {
682            sampler1_name: sampler1_name.to_string(),
683            sampler2_name: sampler2_name.to_string(),
684            test_comparisons: comparisons,
685            avg_quality_improvement,
686            avg_runtime_ratio,
687            winner: if avg_quality_improvement > 0.0 {
688                sampler2_name.to_string()
689            } else {
690                sampler1_name.to_string()
691            },
692        })
693    }
694}