quantrs2_tytan/testing_framework/
validators.rs1use scirs2_core::ndarray::Array2;
7use scirs2_core::random::prelude::*;
8use std::collections::HashMap;
9
10use super::types::{
11 Constraint, ConstraintType, TestCase, TestResult, ValidationCheck, ValidationResult, Validator,
12};
13
14pub struct ConstraintValidator;
16
17impl Validator for ConstraintValidator {
18 fn validate(&self, test_case: &TestCase, result: &TestResult) -> ValidationResult {
19 let mut checks = Vec::new();
20 let mut is_valid = true;
21
22 for constraint in &test_case.constraints {
23 let satisfied = self.check_constraint(constraint, &result.solution);
24
25 checks.push(ValidationCheck {
26 name: format!("Constraint {:?}", constraint.constraint_type),
27 passed: satisfied,
28 message: if satisfied {
29 "Constraint satisfied".to_string()
30 } else {
31 "Constraint violated".to_string()
32 },
33 details: None,
34 });
35
36 is_valid &= satisfied;
37 }
38
39 ValidationResult {
40 is_valid,
41 checks,
42 warnings: Vec::new(),
43 }
44 }
45
46 fn name(&self) -> &'static str {
47 "ConstraintValidator"
48 }
49}
50
51impl ConstraintValidator {
52 fn check_constraint(&self, constraint: &Constraint, solution: &HashMap<String, bool>) -> bool {
53 match &constraint.constraint_type {
54 ConstraintType::OneHot => {
55 let active = constraint
56 .variables
57 .iter()
58 .filter(|v| *solution.get(*v).unwrap_or(&false))
59 .count();
60 active == 1
61 }
62 ConstraintType::AtMostK { k } => {
63 let active = constraint
64 .variables
65 .iter()
66 .filter(|v| *solution.get(*v).unwrap_or(&false))
67 .count();
68 active <= *k
69 }
70 ConstraintType::AtLeastK { k } => {
71 let active = constraint
72 .variables
73 .iter()
74 .filter(|v| *solution.get(*v).unwrap_or(&false))
75 .count();
76 active >= *k
77 }
78 ConstraintType::ExactlyK { k } => {
79 let active = constraint
80 .variables
81 .iter()
82 .filter(|v| *solution.get(*v).unwrap_or(&false))
83 .count();
84 active == *k
85 }
86 _ => true, }
88 }
89}
90
91pub struct ObjectiveValidator;
93
94impl Validator for ObjectiveValidator {
95 fn validate(&self, test_case: &TestCase, result: &TestResult) -> ValidationResult {
96 let mut checks = Vec::new();
97
98 let random_value = self.estimate_random_objective(&test_case.qubo);
100 let improvement = (random_value - result.objective_value) / random_value.abs();
101
102 checks.push(ValidationCheck {
103 name: "Objective improvement".to_string(),
104 passed: improvement > 0.0,
105 message: format!("Improvement over random: {:.2}%", improvement * 100.0),
106 details: Some(format!(
107 "Random: {:.4}, Found: {:.4}",
108 random_value, result.objective_value
109 )),
110 });
111
112 if let Some(optimal_value) = test_case.optimal_value {
114 let gap = (result.objective_value - optimal_value).abs() / optimal_value.abs();
115 let acceptable_gap = 0.05; checks.push(ValidationCheck {
118 name: "Optimality gap".to_string(),
119 passed: gap <= acceptable_gap,
120 message: format!("Gap to optimal: {:.2}%", gap * 100.0),
121 details: Some(format!(
122 "Optimal: {:.4}, Found: {:.4}",
123 optimal_value, result.objective_value
124 )),
125 });
126 }
127
128 ValidationResult {
129 is_valid: checks.iter().all(|c| c.passed),
130 checks,
131 warnings: Vec::new(),
132 }
133 }
134
135 fn name(&self) -> &'static str {
136 "ObjectiveValidator"
137 }
138}
139
140impl ObjectiveValidator {
141 fn estimate_random_objective(&self, qubo: &Array2<f64>) -> f64 {
142 let n = qubo.shape()[0];
143 let mut rng = thread_rng();
144 let mut total = 0.0;
145 let samples = 100;
146
147 for _ in 0..samples {
148 let mut x = vec![0.0; n];
149 for x_item in x.iter_mut().take(n) {
150 *x_item = if rng.random::<bool>() { 1.0 } else { 0.0 };
151 }
152
153 let mut value = 0.0;
154 for i in 0..n {
155 for j in 0..n {
156 value += qubo[[i, j]] * x[i] * x[j];
157 }
158 }
159
160 total += value;
161 }
162
163 total / samples as f64
164 }
165}
166
167pub struct BoundsValidator;
169
170impl Validator for BoundsValidator {
171 fn validate(&self, test_case: &TestCase, result: &TestResult) -> ValidationResult {
172 let mut checks = Vec::new();
173
174 let all_binary = true;
176
177 checks.push(ValidationCheck {
178 name: "Binary variables".to_string(),
179 passed: all_binary,
180 message: if all_binary {
181 "All variables are binary".to_string()
182 } else {
183 "Non-binary values found".to_string()
184 },
185 details: None,
186 });
187
188 let expected_vars = test_case.var_map.len();
190 let actual_vars = result.solution.len();
191
192 checks.push(ValidationCheck {
193 name: "Variable count".to_string(),
194 passed: expected_vars == actual_vars,
195 message: format!("Expected {expected_vars} variables, found {actual_vars}"),
196 details: None,
197 });
198
199 ValidationResult {
200 is_valid: checks.iter().all(|c| c.passed),
201 checks,
202 warnings: Vec::new(),
203 }
204 }
205
206 fn name(&self) -> &'static str {
207 "BoundsValidator"
208 }
209}
210
211pub struct SymmetryValidator;
213
214impl Validator for SymmetryValidator {
215 fn validate(&self, test_case: &TestCase, _result: &TestResult) -> ValidationResult {
216 let mut warnings = Vec::new();
217
218 if self.is_symmetric(&test_case.qubo) {
220 warnings.push("QUBO matrix has symmetries that might not be broken".to_string());
221 }
222
223 ValidationResult {
224 is_valid: true,
225 checks: Vec::new(),
226 warnings,
227 }
228 }
229
230 fn name(&self) -> &'static str {
231 "SymmetryValidator"
232 }
233}
234
235impl SymmetryValidator {
236 fn is_symmetric(&self, qubo: &Array2<f64>) -> bool {
237 let n = qubo.shape()[0];
238
239 for i in 0..n {
240 for j in i + 1..n {
241 if (qubo[[i, j]] - qubo[[j, i]]).abs() > 1e-10 {
242 return false;
243 }
244 }
245 }
246
247 true
248 }
249}
250
251pub fn default_validators() -> Vec<Box<dyn Validator>> {
253 vec![
254 Box::new(ConstraintValidator),
255 Box::new(ObjectiveValidator),
256 Box::new(BoundsValidator),
257 Box::new(SymmetryValidator),
258 ]
259}