use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};

use crate::sampler::Sampler;

use super::config::{ReportFormat, TestConfig};
use super::generators::{
    default_generators, FinanceTestGenerator, LogisticsTestGenerator, ManufacturingTestGenerator,
};
use super::reports;
use super::results::{ConvergenceData, MemoryStats, PerformanceData, RuntimeStats, TestResults};
use super::types::{
    CIReport, CIStatus, Difficulty, FailureType, GeneratorConfig, ProblemType, RegressionIssue,
    RegressionReport, SamplerComparison, TestCase, TestCategory, TestComparison, TestEnvironment,
    TestFailure, TestGenerator, TestResult, TestSuite, ValidationResult, Validator,
};
use super::validators::default_validators;
pub struct TestingFramework {
    /// Framework configuration (samplers, problem sizes, output options).
    pub config: TestConfig,
    /// The generated suite of categories, test cases, and benchmarks.
    pub suite: TestSuite,
    /// Accumulated results from the most recent run.
    pub results: TestResults,
    /// Validators applied to every solution.
    validators: Vec<Box<dyn Validator>>,
    /// Generators that produce test cases per problem type.
    generators: Vec<Box<dyn TestGenerator>>,
}

impl TestingFramework {
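    /// Creates a framework with the given configuration and the default
    /// validators and generators.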
    pub fn new(config: TestConfig) -> Self {
        Self {
            config,
            suite: TestSuite {
                categories: Vec::new(),
                test_cases: Vec::new(),
                benchmarks: Vec::new(),
            },
            results: TestResults::default(),
            validators: default_validators(),
            generators: default_generators(),
        }
    }

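    /// Adds a test category to the suite.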
    pub fn add_category(&mut self, category: TestCategory) {
        self.suite.categories.push(category);
    }

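    /// Registers an additional test-case generator.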
    pub fn add_generator(&mut self, generator: Box<dyn TestGenerator>) {
        self.generators.push(generator);
    }

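    /// Registers an additional solution validator.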
    pub fn add_validator(&mut self, validator: Box<dyn Validator>) {
        self.validators.push(validator);
    }

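    /// Generates test cases for every combination of category, problem type,
    /// difficulty, and configured problem size, delegating to the first
    /// registered generator that supports each problem type.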
    pub fn generate_suite(&mut self) -> Result<(), String> {
        let start_time = Instant::now();

        // Enumerate every (problem type, difficulty, size) combination per category.
        for category in &self.suite.categories {
            for problem_type in &category.problem_types {
                for difficulty in &category.difficulties {
                    for size in &self.config.problem_sizes {
                        let config = GeneratorConfig {
                            problem_type: problem_type.clone(),
                            size: *size,
                            difficulty: difficulty.clone(),
                            seed: self.config.seed,
                            parameters: HashMap::new(),
                        };

                        // Use the first registered generator that supports this problem type.
                        for generator in &self.generators {
                            if generator.supported_types().contains(problem_type) {
                                let test_cases = generator.generate(&config)?;
                                self.suite.test_cases.extend(test_cases);
                                break;
                            }
                        }
                    }
                }
            }
        }

        self.results.performance.runtime_stats.qubo_generation_time = start_time.elapsed();

        Ok(())
    }

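    /// Runs every generated test case against `sampler`, recording results,
    /// failures, and per-test timings, then recomputes the summary.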
    pub fn run_suite<S: Sampler>(&mut self, sampler: &S) -> Result<(), String> {
        let total_start = Instant::now();

        // Clone the test cases so `self` can be borrowed mutably inside the loop.
        let test_cases = self.suite.test_cases.clone();
        for test_case in &test_cases {
            let test_start = Instant::now();

            match self.run_single_test(test_case, sampler) {
                Ok(result) => {
                    self.results.test_results.push(result);
                    self.results.summary.passed += 1;
                }
                Err(e) => {
                    self.results.failures.push(TestFailure {
                        test_id: test_case.id.clone(),
                        failure_type: FailureType::SamplerError,
                        message: e,
                        stack_trace: None,
                        context: HashMap::new(),
                    });
                    self.results.summary.failed += 1;
                }
            }

            let test_time = test_start.elapsed();
            self.results
                .performance
                .runtime_stats
                .time_per_test
                .push((test_case.id.clone(), test_time));

            self.results.summary.total_tests += 1;
        }

        self.results.performance.runtime_stats.total_time = total_start.elapsed();
        self.calculate_summary();

        Ok(())
    }

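    /// Solves one test case, keeps the lowest-energy sample, and checks it
    /// against all registered validators.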
    fn run_single_test<S: Sampler>(
        &mut self,
        test_case: &TestCase,
        sampler: &S,
    ) -> Result<TestResult, String> {
        let solve_start = Instant::now();

        // Sample count comes from the first configured sampler.
        let sample_result = sampler
            .run_qubo(
                &(test_case.qubo.clone(), test_case.var_map.clone()),
                self.config.samplers[0].num_samples,
            )
            .map_err(|e| format!("Sampler error: {e:?}"))?;

        let solve_time = solve_start.elapsed();

        // Pick the lowest-energy sample as the candidate solution.
        let best_sample = sample_result
            .iter()
            .min_by(|a, b| {
                a.energy
                    .partial_cmp(&b.energy)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .ok_or("No samples returned")?;

        let solution = best_sample.assignments.clone();

        let validation_start = Instant::now();
        let mut validation = ValidationResult {
            is_valid: true,
            checks: Vec::new(),
            warnings: Vec::new(),
        };

        // Run every registered validator and merge their findings.
        for validator in &self.validators {
            let result = validator.validate(
                test_case,
                &TestResult {
                    test_id: test_case.id.clone(),
                    // Placeholder name; the real sampler name is set below.
                    sampler: "test".to_string(),
                    solution: solution.clone(),
                    objective_value: best_sample.energy,
                    constraints_satisfied: true,
                    validation: validation.clone(),
                    runtime: solve_time,
                    metrics: HashMap::new(),
                },
            );

            validation.checks.extend(result.checks);
            validation.warnings.extend(result.warnings);
            validation.is_valid &= result.is_valid;
        }

        let validation_time = validation_start.elapsed();
        self.results.performance.runtime_stats.solving_time += solve_time;
        self.results.performance.runtime_stats.validation_time += validation_time;

        Ok(TestResult {
            test_id: test_case.id.clone(),
            sampler: self.config.samplers[0].name.clone(),
            solution,
            objective_value: best_sample.energy,
            constraints_satisfied: validation.is_valid,
            validation,
            runtime: solve_time + validation_time,
            metrics: HashMap::new(),
        })
    }

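    /// Aggregates success rate, average runtime, and solution-quality
    /// statistics over the collected test results.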
    fn calculate_summary(&mut self) {
        if self.results.test_results.is_empty() {
            return;
        }

        self.results.summary.success_rate =
            self.results.summary.passed as f64 / self.results.summary.total_tests as f64;

        let total_runtime: Duration = self.results.test_results.iter().map(|r| r.runtime).sum();
        self.results.summary.avg_runtime = total_runtime / self.results.test_results.len() as u32;

        let qualities: Vec<f64> = self
            .results
            .test_results
            .iter()
            .map(|r| r.objective_value)
            .collect();

        self.results.summary.quality_metrics.avg_quality =
            qualities.iter().sum::<f64>() / qualities.len() as f64;

        // Lower objective values are better (QUBO minimization), so the
        // minimum is the best quality and the maximum the worst.
        self.results.summary.quality_metrics.best_quality = *qualities
            .iter()
            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(&0.0);

        self.results.summary.quality_metrics.worst_quality = *qualities
            .iter()
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(&0.0);

        // Population standard deviation of the objective values.
        let mean = self.results.summary.quality_metrics.avg_quality;
        let variance =
            qualities.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / qualities.len() as f64;
        self.results.summary.quality_metrics.std_dev = variance.sqrt();

        let satisfied = self
            .results
            .test_results
            .iter()
            .filter(|r| r.constraints_satisfied)
            .count();
        self.results
            .summary
            .quality_metrics
            .constraint_satisfaction_rate =
            satisfied as f64 / self.results.test_results.len() as f64;
    }

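    /// Renders the collected results in the configured report format.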
    pub fn generate_report(&self) -> Result<String, String> {
        reports::generate_report(&self.config.output.format, &self.results, &self.suite)
    }

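    /// Generates a report and writes it to `filename`.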
    pub fn save_report(&self, filename: &str) -> Result<(), String> {
        let report = self.generate_report()?;
        reports::save_report(&report, filename)
    }

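    /// Runs the suite and compares each result against a stored baseline,
    /// flagging regressions and improvements in quality and runtime.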
    pub fn run_regression_tests<S: Sampler>(
        &mut self,
        sampler: &S,
        baseline_file: &str,
    ) -> Result<RegressionReport, String> {
        let baseline = self.load_baseline(baseline_file)?;

        self.run_suite(sampler)?;

        let mut regressions = Vec::new();
        let mut improvements = Vec::new();

        for current_result in &self.results.test_results {
            if let Some(baseline_result) = baseline
                .iter()
                .find(|b| b.test_id == current_result.test_id)
            {
                // Relative changes; positive means worse (higher objective, longer runtime).
                let quality_change = (current_result.objective_value
                    - baseline_result.objective_value)
                    / baseline_result.objective_value.abs();
                let runtime_change = (current_result.runtime.as_secs_f64()
                    - baseline_result.runtime.as_secs_f64())
                    / baseline_result.runtime.as_secs_f64();

                // Flag regressions beyond a 5% quality or 20% runtime degradation.
                if quality_change > 0.05 || runtime_change > 0.2 {
                    regressions.push(RegressionIssue {
                        test_id: current_result.test_id.clone(),
                        metric: if quality_change > 0.05 {
                            "quality".to_string()
                        } else {
                            "runtime".to_string()
                        },
                        baseline_value: if quality_change > 0.05 {
                            baseline_result.objective_value
                        } else {
                            baseline_result.runtime.as_secs_f64()
                        },
                        current_value: if quality_change > 0.05 {
                            current_result.objective_value
                        } else {
                            current_result.runtime.as_secs_f64()
                        },
                        change_percent: if quality_change > 0.05 {
                            quality_change * 100.0
                        } else {
                            runtime_change * 100.0
                        },
                    });
                } else if quality_change < -0.05 || runtime_change < -0.2 {
                    // Symmetric thresholds mark improvements.
                    improvements.push(RegressionIssue {
                        test_id: current_result.test_id.clone(),
                        metric: if quality_change < -0.05 {
                            "quality".to_string()
                        } else {
                            "runtime".to_string()
                        },
                        baseline_value: if quality_change < -0.05 {
                            baseline_result.objective_value
                        } else {
                            baseline_result.runtime.as_secs_f64()
                        },
                        current_value: if quality_change < -0.05 {
                            current_result.objective_value
                        } else {
                            current_result.runtime.as_secs_f64()
                        },
                        change_percent: if quality_change < -0.05 {
                            quality_change * 100.0
                        } else {
                            runtime_change * 100.0
                        },
                    });
                }
            }
        }

        Ok(RegressionReport {
            regressions,
            improvements,
            baseline_tests: baseline.len(),
            current_tests: self.results.test_results.len(),
        })
    }

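    /// Loads baseline results for regression comparison.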
    const fn load_baseline(&self, _filename: &str) -> Result<Vec<TestResult>, String> {
        // Placeholder: baseline persistence is not implemented yet, so every
        // run compares against an empty baseline.
        Ok(Vec::new())
    }

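    /// Runs the suite across `num_threads` worker threads. Unlike
    /// `run_suite`, this path skips validation (see `run_single_test_static`).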
    pub fn run_suite_parallel<S: Sampler + Clone + Send + Sync + 'static>(
        &mut self,
        sampler: &S,
        num_threads: usize,
    ) -> Result<(), String> {
        let test_cases = Arc::new(self.suite.test_cases.clone());
        let results = Arc::new(Mutex::new(Vec::new()));
        let failures = Arc::new(Mutex::new(Vec::new()));

        let total_start = Instant::now();
        // Split the test cases into contiguous chunks, one per thread.
        let chunk_size = test_cases.len().div_ceil(num_threads);

        let mut handles = Vec::new();

        for thread_id in 0..num_threads {
            let start_idx = thread_id * chunk_size;
            let end_idx = ((thread_id + 1) * chunk_size).min(test_cases.len());

            if start_idx >= test_cases.len() {
                break;
            }

            let test_cases_clone = Arc::clone(&test_cases);
            let results_clone = Arc::clone(&results);
            let failures_clone = Arc::clone(&failures);
            let sampler_clone = sampler.clone();

            let handle = thread::spawn(move || {
                for idx in start_idx..end_idx {
                    let test_case = &test_cases_clone[idx];

                    match Self::run_single_test_static(test_case, &sampler_clone) {
                        Ok(result) => {
                            if let Ok(mut guard) = results_clone.lock() {
                                guard.push(result);
                            }
                        }
                        Err(e) => {
                            if let Ok(mut guard) = failures_clone.lock() {
                                guard.push(TestFailure {
                                    test_id: test_case.id.clone(),
                                    failure_type: FailureType::SamplerError,
                                    message: e,
                                    stack_trace: None,
                                    context: HashMap::new(),
                                });
                            }
                        }
                    }
                }
            });

            handles.push(handle);
        }

        for handle in handles {
            handle.join().map_err(|_| "Thread panic")?;
        }

        // Collect results; a poisoned mutex yields empty results rather than a panic.
        self.results.test_results = results
            .lock()
            .map(|guard| guard.clone())
            .unwrap_or_default();
        self.results.failures = failures
            .lock()
            .map(|guard| guard.clone())
            .unwrap_or_default();

        self.results.performance.runtime_stats.total_time = total_start.elapsed();
        self.results.summary.passed = self.results.test_results.len();
        self.results.summary.failed = self.results.failures.len();
        self.results.summary.total_tests =
            self.results.summary.passed + self.results.summary.failed;

        self.calculate_summary();

        Ok(())
    }

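    /// Thread-safe variant of `run_single_test` used by the parallel runner.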
    fn run_single_test_static<S: Sampler>(
        test_case: &TestCase,
        sampler: &S,
    ) -> Result<TestResult, String> {
        let solve_start = Instant::now();

        // Fixed sample count: the per-sampler configuration is not available here.
        let sample_result = sampler
            .run_qubo(&(test_case.qubo.clone(), test_case.var_map.clone()), 100)
            .map_err(|e| format!("Sampler error: {e:?}"))?;

        let solve_time = solve_start.elapsed();

        let best_sample = sample_result
            .iter()
            .min_by(|a, b| {
                a.energy
                    .partial_cmp(&b.energy)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .ok_or("No samples returned")?;

        let solution = best_sample.assignments.clone();

        // Validators are not run on this path, so validation is trivially "valid".
        Ok(TestResult {
            test_id: test_case.id.clone(),
            sampler: "parallel".to_string(),
            solution,
            objective_value: best_sample.energy,
            constraints_satisfied: true,
            validation: ValidationResult {
                is_valid: true,
                checks: Vec::new(),
                warnings: Vec::new(),
            },
            runtime: solve_time,
            metrics: HashMap::new(),
        })
    }

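    /// Produces a CI-oriented summary whose status is derived from the pass
    /// rate and critical failure count.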
    pub fn generate_ci_report(&self) -> Result<CIReport, String> {
        let passed_rate = if self.results.summary.total_tests > 0 {
            self.results.summary.passed as f64 / self.results.summary.total_tests as f64
        } else {
            0.0
        };

        // Pass at >= 95%, warn at >= 80%, fail below that.
        let status = if passed_rate >= 0.95 {
            CIStatus::Pass
        } else if passed_rate >= 0.8 {
            CIStatus::Warning
        } else {
            CIStatus::Fail
        };

        Ok(CIReport {
            status,
            passed_rate,
            total_tests: self.results.summary.total_tests,
            failed_tests: self.results.summary.failed,
            critical_failures: self
                .results
                .failures
                .iter()
                .filter(|f| {
                    matches!(
                        f.failure_type,
                        FailureType::Timeout | FailureType::SamplerError
                    )
                })
                .count(),
            avg_runtime: self.results.summary.avg_runtime,
            quality_score: self.calculate_quality_score(),
        })
    }

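    /// Combines constraint satisfaction, success rate, and solution quality
    /// into a single score from 0 to 100.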
    fn calculate_quality_score(&self) -> f64 {
        if self.results.test_results.is_empty() {
            return 0.0;
        }

        let constraint_score = self
            .results
            .summary
            .quality_metrics
            .constraint_satisfaction_rate;
        let success_score = self.results.summary.success_rate;
        let quality_score = if self
            .results
            .summary
            .quality_metrics
            .best_quality
            .is_finite()
        {
            0.8
        } else {
            0.0
        };

        // Weighted blend: 40% constraints, 40% success rate, 20% quality, scaled to 0-100.
        (constraint_score.mul_add(0.4, success_score * 0.4) + quality_score * 0.2) * 100.0
    }

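    /// Adds predefined stress-test categories targeting problem size,
    /// memory usage, and runtime.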
    pub fn add_stress_tests(&mut self) {
        let stress_categories = vec![
            TestCategory {
                name: "Large Scale Tests".to_string(),
                description: "Tests with large problem sizes".to_string(),
                problem_types: vec![ProblemType::MaxCut, ProblemType::TSP],
                difficulties: vec![Difficulty::Extreme],
                tags: vec!["stress".to_string(), "large".to_string()],
            },
            TestCategory {
                name: "Memory Stress Tests".to_string(),
                description: "Tests designed to stress memory usage".to_string(),
                problem_types: vec![ProblemType::Knapsack],
                difficulties: vec![Difficulty::VeryHard, Difficulty::Extreme],
                tags: vec!["stress".to_string(), "memory".to_string()],
            },
            TestCategory {
                name: "Runtime Stress Tests".to_string(),
                description: "Tests with challenging runtime requirements".to_string(),
                problem_types: vec![ProblemType::GraphColoring],
                difficulties: vec![Difficulty::Extreme],
                tags: vec!["stress".to_string(), "runtime".to_string()],
            },
        ];

        self.suite.categories.extend(stress_categories);
    }

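    /// Describes the test environment. The OS is taken from compile-time
    /// constants and the Rust version from the `RUSTC_VERSION` environment
    /// variable; the remaining fields are static placeholders.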
    pub fn detect_environment(&self) -> TestEnvironment {
        TestEnvironment {
            os: std::env::consts::OS.to_string(),
            // Hardware detection is not implemented; these are placeholder values.
            cpu_model: "Unknown".to_string(),
            memory_gb: 8.0,
            gpu_info: None,
            rust_version: std::env::var("RUSTC_VERSION").unwrap_or_else(|_| "unknown".to_string()),
            compile_flags: vec!["--release".to_string()],
        }
    }

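    /// Exports the collected results as CSV, JSON, or XML.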
    pub fn export_results(&self, format: &str) -> Result<String, String> {
        match format {
            "csv" => reports::export_csv(&self.results, &self.suite),
            "json" => reports::generate_json_report(&self.results),
            "xml" => reports::export_xml(&self.results),
            _ => Err(format!("Unsupported export format: {format}")),
        }
    }

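    /// Registers the industry-specific generators (finance, logistics,
    /// manufacturing).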
    pub fn add_industry_generators(&mut self) {
        self.generators.push(Box::new(FinanceTestGenerator));
        self.generators.push(Box::new(LogisticsTestGenerator));
        self.generators.push(Box::new(ManufacturingTestGenerator));
    }

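    /// Runs the full suite with two samplers and compares per-test solution
    /// quality and runtime, declaring an overall winner.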
    pub fn compare_samplers<S1: Sampler, S2: Sampler>(
        &mut self,
        sampler1: &S1,
        sampler2: &S2,
        sampler1_name: &str,
        sampler2_name: &str,
    ) -> Result<SamplerComparison, String> {
        self.run_suite(sampler1)?;
        let results1 = self.results.test_results.clone();

        // Reset all results (including summary counters) before the second
        // run, so the second suite's summary is not skewed by the first.
        self.results = TestResults::default();
        self.run_suite(sampler2)?;
        let results2 = self.results.test_results.clone();

        let mut comparisons = Vec::new();

        for r1 in &results1 {
            if let Some(r2) = results2.iter().find(|r| r.test_id == r1.test_id) {
                let quality_diff = r2.objective_value - r1.objective_value;
                let runtime_ratio = r2.runtime.as_secs_f64() / r1.runtime.as_secs_f64();

                comparisons.push(TestComparison {
                    test_id: r1.test_id.clone(),
                    sampler1_quality: r1.objective_value,
                    sampler2_quality: r2.objective_value,
                    // Lower energy is better, so a negative diff is an improvement.
                    quality_improvement: -quality_diff,
                    sampler1_runtime: r1.runtime,
                    sampler2_runtime: r2.runtime,
                    runtime_ratio,
                });
            }
        }

        // Avoid NaN averages when no test IDs matched between the two runs.
        if comparisons.is_empty() {
            return Err("No matching test results to compare".to_string());
        }

        let avg_quality_improvement = comparisons
            .iter()
            .map(|c| c.quality_improvement)
            .sum::<f64>()
            / comparisons.len() as f64;
        let avg_runtime_ratio =
            comparisons.iter().map(|c| c.runtime_ratio).sum::<f64>() / comparisons.len() as f64;

        Ok(SamplerComparison {
            sampler1_name: sampler1_name.to_string(),
            sampler2_name: sampler2_name.to_string(),
            test_comparisons: comparisons,
            avg_quality_improvement,
            avg_runtime_ratio,
            winner: if avg_quality_improvement > 0.0 {
                sampler2_name.to_string()
            } else {
                sampler1_name.to_string()
            },
        })
    }
}