1use crate::complex_ext::QuantumComplexExt;
7use crate::error::QuantRS2Error;
8use ndarray::{Array1, Array2};
9use num_complex::Complex64;
10use std::fmt;
11
/// Absolute tolerance used by [`QuantumAssert::default`] when comparing amplitudes,
/// probabilities and matrix entries.
pub const DEFAULT_TOLERANCE: f64 = 1.0e-10;
14
/// Outcome of a single quantum test or assertion.
#[derive(Clone, Debug, PartialEq)]
pub enum TestResult {
    /// The check succeeded.
    Pass,
    /// The check failed; the payload explains why.
    Fail(String),
    /// The check was not performed; the payload explains why.
    Skip(String),
}
25
26impl TestResult {
27 pub fn is_pass(&self) -> bool {
28 matches!(self, TestResult::Pass)
29 }
30}
31
/// Assertion helper for quantum states and operators.
///
/// Every comparison it performs is checked against a single absolute `tolerance`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QuantumAssert {
    /// Absolute error bound applied to every comparison.
    tolerance: f64,
}
36
37impl Default for QuantumAssert {
38 fn default() -> Self {
39 Self {
40 tolerance: DEFAULT_TOLERANCE,
41 }
42 }
43}
44
45impl QuantumAssert {
46 pub fn with_tolerance(tolerance: f64) -> Self {
48 Self { tolerance }
49 }
50
51 pub fn states_equal(
53 &self,
54 state1: &Array1<Complex64>,
55 state2: &Array1<Complex64>,
56 ) -> TestResult {
57 if state1.len() != state2.len() {
58 return TestResult::Fail(format!(
59 "State dimensions mismatch: {} vs {}",
60 state1.len(),
61 state2.len()
62 ));
63 }
64
65 let mut phase_factor = None;
67 for i in 0..state1.len() {
68 if state1[i].norm() > self.tolerance && state2[i].norm() > self.tolerance {
69 phase_factor = Some(state2[i] / state1[i]);
70 break;
71 }
72 }
73
74 let phase = phase_factor.unwrap_or(Complex64::new(1.0, 0.0));
75
76 for i in 0..state1.len() {
78 let expected = state1[i] * phase;
79 if (expected - state2[i]).norm() > self.tolerance {
80 return TestResult::Fail(format!(
81 "States differ at index {}: expected {}, got {}",
82 i, expected, state2[i]
83 ));
84 }
85 }
86
87 TestResult::Pass
88 }
89
90 pub fn state_normalized(&self, state: &Array1<Complex64>) -> TestResult {
92 let norm_squared: f64 = state.iter().map(|c| c.norm_sqr()).sum();
93
94 if (norm_squared - 1.0).abs() > self.tolerance {
95 TestResult::Fail(format!(
96 "State not normalized: norm^2 = {} (expected 1.0)",
97 norm_squared
98 ))
99 } else {
100 TestResult::Pass
101 }
102 }
103
104 pub fn states_orthogonal(
106 &self,
107 state1: &Array1<Complex64>,
108 state2: &Array1<Complex64>,
109 ) -> TestResult {
110 if state1.len() != state2.len() {
111 return TestResult::Fail("State dimensions mismatch".to_string());
112 }
113
114 let inner_product: Complex64 = state1
115 .iter()
116 .zip(state2.iter())
117 .map(|(a, b)| a.conj() * b)
118 .sum();
119
120 if inner_product.norm() > self.tolerance {
121 TestResult::Fail(format!(
122 "States not orthogonal: inner product = {}",
123 inner_product
124 ))
125 } else {
126 TestResult::Pass
127 }
128 }
129
130 pub fn matrix_unitary(&self, matrix: &Array2<Complex64>) -> TestResult {
132 let (rows, cols) = matrix.dim();
133 if rows != cols {
134 return TestResult::Fail(format!("Matrix not square: {}x{}", rows, cols));
135 }
136
137 let conjugate_transpose = matrix.t().mapv(|c| c.conj());
139 let product = conjugate_transpose.dot(matrix);
140
141 for i in 0..rows {
143 for j in 0..cols {
144 let expected = if i == j {
145 Complex64::new(1.0, 0.0)
146 } else {
147 Complex64::new(0.0, 0.0)
148 };
149
150 if (product[[i, j]] - expected).norm() > self.tolerance {
151 return TestResult::Fail(format!(
152 "U†U not identity at ({},{}): got {}",
153 i,
154 j,
155 product[[i, j]]
156 ));
157 }
158 }
159 }
160
161 TestResult::Pass
162 }
163
164 pub fn measurement_probabilities(
166 &self,
167 state: &Array1<Complex64>,
168 expected_probs: &[(usize, f64)],
169 ) -> TestResult {
170 for &(index, expected_prob) in expected_probs {
171 if index >= state.len() {
172 return TestResult::Fail(format!(
173 "Index {} out of bounds for state of length {}",
174 index,
175 state.len()
176 ));
177 }
178
179 let actual_prob = state[index].probability();
180 if (actual_prob - expected_prob).abs() > self.tolerance {
181 return TestResult::Fail(format!(
182 "Probability mismatch at index {}: expected {}, got {}",
183 index, expected_prob, actual_prob
184 ));
185 }
186 }
187
188 TestResult::Pass
189 }
190
191 pub fn is_entangled(&self, state: &Array1<Complex64>, qubit_indices: &[usize]) -> TestResult {
193 if qubit_indices.len() != 2 {
195 return TestResult::Skip(
196 "Entanglement check only implemented for 2-qubit subsystems".to_string(),
197 );
198 }
199
200 let n_qubits = (state.len() as f64).log2() as usize;
201 if qubit_indices.iter().any(|&i| i >= n_qubits) {
202 return TestResult::Fail("Qubit index out of bounds".to_string());
203 }
204
205 TestResult::Skip("Full entanglement check not yet implemented".to_string())
210 }
211}
212
/// A named test case: an optional setup hook, the test body and an optional teardown hook.
///
/// Instances are built with `QuantumTest::new` and the `with_setup` / `with_teardown`
/// builder methods.
pub struct QuantumTest {
    /// Label stored alongside this test's result when it is run by a `QuantumTestSuite`.
    name: String,
    /// Runs before the body. An `Err` makes `run` return `Fail("Setup failed: ...")`
    /// without running the body.
    setup: Option<Box<dyn Fn() -> Result<(), QuantRS2Error>>>,
    /// The test body; its return value becomes the result of `run`.
    test: Box<dyn Fn() -> TestResult>,
    /// Runs after the body when the body completes.
    teardown: Option<Box<dyn Fn()>>,
}
220
221impl QuantumTest {
222 pub fn new(name: impl Into<String>, test: impl Fn() -> TestResult + 'static) -> Self {
224 Self {
225 name: name.into(),
226 setup: None,
227 test: Box::new(test),
228 teardown: None,
229 }
230 }
231
232 pub fn with_setup(mut self, setup: impl Fn() -> Result<(), QuantRS2Error> + 'static) -> Self {
234 self.setup = Some(Box::new(setup));
235 self
236 }
237
238 pub fn with_teardown(mut self, teardown: impl Fn() + 'static) -> Self {
240 self.teardown = Some(Box::new(teardown));
241 self
242 }
243
244 pub fn run(&self) -> TestResult {
246 if let Some(setup) = &self.setup {
248 if let Err(e) = setup() {
249 return TestResult::Fail(format!("Setup failed: {}", e));
250 }
251 }
252
253 let result = (self.test)();
255
256 if let Some(teardown) = &self.teardown {
258 teardown();
259 }
260
261 result
262 }
263}
264
/// An ordered collection of [`QuantumTest`]s that are run together under one suite name.
pub struct QuantumTestSuite {
    /// Suite label, copied into the [`TestSuiteResult`] produced by `run`.
    name: String,
    /// Registered tests, in the order they were added (which is also the run order).
    tests: Vec<QuantumTest>,
}
270
271impl QuantumTestSuite {
272 pub fn new(name: impl Into<String>) -> Self {
274 Self {
275 name: name.into(),
276 tests: Vec::new(),
277 }
278 }
279
280 pub fn add_test(&mut self, test: QuantumTest) {
282 self.tests.push(test);
283 }
284
285 pub fn run(&self) -> TestSuiteResult {
287 let mut results = Vec::new();
288
289 for test in &self.tests {
290 let result = test.run();
291 results.push((test.name.clone(), result));
292 }
293
294 TestSuiteResult {
295 suite_name: self.name.clone(),
296 results,
297 }
298 }
299}
300
301pub struct TestSuiteResult {
303 suite_name: String,
304 results: Vec<(String, TestResult)>,
305}
306
307impl TestSuiteResult {
308 pub fn passed(&self) -> usize {
310 self.results.iter().filter(|(_, r)| r.is_pass()).count()
311 }
312
313 pub fn failed(&self) -> usize {
315 self.results
316 .iter()
317 .filter(|(_, r)| matches!(r, TestResult::Fail(_)))
318 .count()
319 }
320
321 pub fn skipped(&self) -> usize {
323 self.results
324 .iter()
325 .filter(|(_, r)| matches!(r, TestResult::Skip(_)))
326 .count()
327 }
328
329 pub fn total(&self) -> usize {
331 self.results.len()
332 }
333
334 pub fn all_passed(&self) -> bool {
336 self.failed() == 0
337 }
338}
339
340impl fmt::Display for TestSuiteResult {
341 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342 writeln!(f, "\n{} Test Results:", self.suite_name)?;
343 writeln!(f, "{}", "=".repeat(50))?;
344
345 for (name, result) in &self.results {
346 let status = match result {
347 TestResult::Pass => "✓ PASS",
348 TestResult::Fail(_) => "✗ FAIL",
349 TestResult::Skip(_) => "⊙ SKIP",
350 };
351
352 writeln!(f, "{:<6} {}", status, name)?;
353
354 if let TestResult::Fail(reason) = result {
355 writeln!(f, " Reason: {}", reason)?;
356 } else if let TestResult::Skip(reason) = result {
357 writeln!(f, " Reason: {}", reason)?;
358 }
359 }
360
361 writeln!(f, "{}", "=".repeat(50))?;
362 writeln!(
363 f,
364 "Total: {} | Passed: {} | Failed: {} | Skipped: {}",
365 self.total(),
366 self.passed(),
367 self.failed(),
368 self.skipped()
369 )?;
370
371 Ok(())
372 }
373}
374
/// Shorthand for `QuantumTest::new(name, body)`.
///
/// The expansion names `QuantumTest` unqualified, so that type must be in scope at the
/// call site.
#[macro_export]
macro_rules! quantum_test {
    ($label:expr, $body:expr) => {
        QuantumTest::new($label, $body)
    };
}
382
/// Evaluates `QuantumAssert::states_equal(state1, state2)` and yields its `TestResult`.
///
/// With two arguments, `QuantumAssert::default()` is used. A third argument is passed to
/// `QuantumAssert::with_tolerance` instead. The expansion names `QuantumAssert` unqualified,
/// so that type must be in scope at the call site.
#[macro_export]
macro_rules! assert_states_equal {
    ($lhs:expr, $rhs:expr) => {{
        QuantumAssert::default().states_equal($lhs, $rhs)
    }};
    ($lhs:expr, $rhs:expr, $tol:expr) => {{
        QuantumAssert::with_tolerance($tol).states_equal($lhs, $rhs)
    }};
}
394
/// Evaluates `QuantumAssert::default().matrix_unitary(matrix)` and yields its `TestResult`.
///
/// The expansion names `QuantumAssert` unqualified, so that type must be in scope at the
/// call site.
#[macro_export]
macro_rules! assert_unitary {
    ($m:expr) => {{
        QuantumAssert::default().matrix_unitary($m)
    }};
}
402
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Amplitude of each basis state in an equal two-term superposition (1/√2).
    fn half_amplitude() -> f64 {
        1.0 / 2.0_f64.sqrt()
    }

    #[test]
    fn test_quantum_assert_states_equal() {
        let assert = QuantumAssert::default();

        let state1 = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        let state2 = state1.clone();
        assert!(assert.states_equal(&state1, &state2).is_pass());

        // i|0⟩ differs from |0⟩ only by a global phase, so the states count as equal.
        let state3 = array![Complex64::new(0.0, 1.0), Complex64::new(0.0, 0.0)];
        assert!(assert.states_equal(&state1, &state3).is_pass());

        let state4 = array![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)];
        assert!(!assert.states_equal(&state1, &state4).is_pass());

        let shorter = array![Complex64::new(1.0, 0.0)];
        assert!(!assert.states_equal(&state1, &shorter).is_pass());
    }

    #[test]
    fn test_quantum_assert_normalized() {
        let assert = QuantumAssert::default();
        let h = half_amplitude();

        let state1 = array![Complex64::new(h, 0.0), Complex64::new(h, 0.0)];
        assert!(assert.state_normalized(&state1).is_pass());

        let state2 = array![Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)];
        assert!(!assert.state_normalized(&state2).is_pass());
    }

    #[test]
    fn test_quantum_assert_orthogonal() {
        let assert = QuantumAssert::default();
        let h = half_amplitude();

        let zero = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        let one = array![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)];
        let plus = array![Complex64::new(h, 0.0), Complex64::new(h, 0.0)];
        let shorter = array![Complex64::new(1.0, 0.0)];

        assert!(assert.states_orthogonal(&zero, &one).is_pass());
        assert!(!assert.states_orthogonal(&zero, &plus).is_pass());
        assert!(!assert.states_orthogonal(&zero, &shorter).is_pass());
    }

    #[test]
    fn test_quantum_assert_unitary() {
        let assert = QuantumAssert::default();
        let h = half_amplitude();

        let hadamard = array![
            [Complex64::new(h, 0.0), Complex64::new(h, 0.0)],
            [Complex64::new(h, 0.0), Complex64::new(-h, 0.0)]
        ];
        assert!(assert.matrix_unitary(&hadamard).is_pass());

        let shear = array![
            [Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)]
        ];
        assert!(!assert.matrix_unitary(&shear).is_pass());

        let rectangular = Array2::<Complex64>::zeros((2, 3));
        assert!(!assert.matrix_unitary(&rectangular).is_pass());
    }

    #[test]
    fn test_quantum_assert_measurement_probabilities() {
        let assert = QuantumAssert::default();
        let state = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];

        assert!(assert
            .measurement_probabilities(&state, &[(0, 1.0), (1, 0.0)])
            .is_pass());
        assert!(!assert
            .measurement_probabilities(&state, &[(0, 0.0)])
            .is_pass());
        assert!(!assert
            .measurement_probabilities(&state, &[(2, 0.0)])
            .is_pass());
    }

    #[test]
    fn test_quantum_assert_is_entangled_input_validation() {
        let assert = QuantumAssert::default();
        let h = half_amplitude();

        // Bell state (|00⟩ + |11⟩)/√2.
        let bell = array![
            Complex64::new(h, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(h, 0.0)
        ];
        assert!(matches!(
            assert.is_entangled(&bell, &[0, 1, 2]),
            TestResult::Skip(_)
        ));
        assert!(matches!(
            assert.is_entangled(&bell, &[0, 5]),
            TestResult::Fail(_)
        ));
    }

    #[test]
    fn test_quantum_test_runs_teardown() {
        use std::cell::Cell;
        use std::rc::Rc;

        let torn_down = Rc::new(Cell::new(false));
        let flag = Rc::clone(&torn_down);
        let test = QuantumTest::new("teardown", || TestResult::Pass)
            .with_teardown(move || flag.set(true));

        assert!(test.run().is_pass());
        assert!(torn_down.get());
    }

    #[test]
    fn test_quantum_test_suite() {
        let mut suite = QuantumTestSuite::new("Example Suite");

        suite.add_test(QuantumTest::new("Test 1", || TestResult::Pass));
        suite.add_test(QuantumTest::new("Test 2", || {
            TestResult::Fail("Expected failure".to_string())
        }));
        suite.add_test(QuantumTest::new("Test 3", || {
            TestResult::Skip("Not implemented".to_string())
        }));

        let results = suite.run();
        assert_eq!(results.total(), 3);
        assert_eq!(results.passed(), 1);
        assert_eq!(results.failed(), 1);
        assert_eq!(results.skipped(), 1);
        assert!(!results.all_passed());

        let report = results.to_string();
        assert!(report.contains("Example Suite Test Results:"));
        assert!(report.contains("Total: 3 | Passed: 1 | Failed: 1 | Skipped: 1"));
    }
}