quantrs2_core/
testing.rs

1//! Quantum unit testing framework
2//!
3//! This module provides tools for testing quantum circuits, states, and operations
4//! with proper handling of quantum-specific properties like phase and entanglement.
5
6use crate::complex_ext::QuantumComplexExt;
7use crate::error::QuantRS2Error;
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::Complex64;
10use std::fmt;
11
/// Absolute tolerance that `QuantumAssert::default()` uses when comparing
/// amplitudes, norms and matrix entries.
pub const DEFAULT_TOLERANCE: f64 = 1.0e-10;
14
/// Outcome of a single quantum test.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TestResult {
    /// Every check succeeded.
    Pass,
    /// A check failed; the payload explains why.
    Fail(String),
    /// The test reached no verdict; the payload explains why it was skipped.
    Skip(String),
}

impl TestResult {
    /// Returns `true` only for [`TestResult::Pass`]; `Fail` and `Skip` both
    /// yield `false`.
    pub const fn is_pass(&self) -> bool {
        match self {
            Self::Pass => true,
            Self::Fail(_) | Self::Skip(_) => false,
        }
    }
}
31
32/// Quantum state assertion helper
33pub struct QuantumAssert {
34    tolerance: f64,
35}
36
37impl Default for QuantumAssert {
38    fn default() -> Self {
39        Self {
40            tolerance: DEFAULT_TOLERANCE,
41        }
42    }
43}
44
45impl QuantumAssert {
46    /// Create with custom tolerance
47    pub const fn with_tolerance(tolerance: f64) -> Self {
48        Self { tolerance }
49    }
50
51    /// Assert two quantum states are equal (up to global phase)
52    pub fn states_equal(
53        &self,
54        state1: &Array1<Complex64>,
55        state2: &Array1<Complex64>,
56    ) -> TestResult {
57        if state1.len() != state2.len() {
58            return TestResult::Fail(format!(
59                "State dimensions mismatch: {} vs {}",
60                state1.len(),
61                state2.len()
62            ));
63        }
64
65        // Find first non-zero amplitude to determine global phase
66        let mut phase_factor = None;
67        for i in 0..state1.len() {
68            if state1[i].norm() > self.tolerance && state2[i].norm() > self.tolerance {
69                phase_factor = Some(state2[i] / state1[i]);
70                break;
71            }
72        }
73
74        let phase = phase_factor.unwrap_or(Complex64::new(1.0, 0.0));
75
76        // Check all amplitudes match after phase correction
77        for i in 0..state1.len() {
78            let expected = state1[i] * phase;
79            if (expected - state2[i]).norm() > self.tolerance {
80                return TestResult::Fail(format!(
81                    "States differ at index {}: expected {}, got {}",
82                    i, expected, state2[i]
83                ));
84            }
85        }
86
87        TestResult::Pass
88    }
89
90    /// Assert a state is normalized
91    pub fn state_normalized(&self, state: &Array1<Complex64>) -> TestResult {
92        let norm_squared: f64 = state.iter().map(|c| c.norm_sqr()).sum();
93
94        if (norm_squared - 1.0).abs() > self.tolerance {
95            TestResult::Fail(format!(
96                "State not normalized: norm^2 = {norm_squared} (expected 1.0)"
97            ))
98        } else {
99            TestResult::Pass
100        }
101    }
102
103    /// Assert two states are orthogonal
104    pub fn states_orthogonal(
105        &self,
106        state1: &Array1<Complex64>,
107        state2: &Array1<Complex64>,
108    ) -> TestResult {
109        if state1.len() != state2.len() {
110            return TestResult::Fail("State dimensions mismatch".to_string());
111        }
112
113        let inner_product: Complex64 = state1
114            .iter()
115            .zip(state2.iter())
116            .map(|(a, b)| a.conj() * b)
117            .sum();
118
119        if inner_product.norm() > self.tolerance {
120            TestResult::Fail(format!(
121                "States not orthogonal: inner product = {inner_product}"
122            ))
123        } else {
124            TestResult::Pass
125        }
126    }
127
128    /// Assert a matrix is unitary
129    pub fn matrix_unitary(&self, matrix: &Array2<Complex64>) -> TestResult {
130        let (rows, cols) = matrix.dim();
131        if rows != cols {
132            return TestResult::Fail(format!("Matrix not square: {rows}x{cols}"));
133        }
134
135        // Compute U† U
136        let conjugate_transpose = matrix.t().mapv(|c| c.conj());
137        let product = conjugate_transpose.dot(matrix);
138
139        // Check if it's identity
140        for i in 0..rows {
141            for j in 0..cols {
142                let expected = if i == j {
143                    Complex64::new(1.0, 0.0)
144                } else {
145                    Complex64::new(0.0, 0.0)
146                };
147
148                if (product[[i, j]] - expected).norm() > self.tolerance {
149                    return TestResult::Fail(format!(
150                        "U†U not identity at ({},{}): got {}",
151                        i,
152                        j,
153                        product[[i, j]]
154                    ));
155                }
156            }
157        }
158
159        TestResult::Pass
160    }
161
162    /// Assert a state has specific measurement probabilities
163    pub fn measurement_probabilities(
164        &self,
165        state: &Array1<Complex64>,
166        expected_probs: &[(usize, f64)],
167    ) -> TestResult {
168        for &(index, expected_prob) in expected_probs {
169            if index >= state.len() {
170                return TestResult::Fail(format!(
171                    "Index {} out of bounds for state of length {}",
172                    index,
173                    state.len()
174                ));
175            }
176
177            let actual_prob = state[index].probability();
178            if (actual_prob - expected_prob).abs() > self.tolerance {
179                return TestResult::Fail(format!(
180                    "Probability mismatch at index {index}: expected {expected_prob}, got {actual_prob}"
181                ));
182            }
183        }
184
185        TestResult::Pass
186    }
187
188    /// Assert entanglement properties
189    pub fn is_entangled(&self, state: &Array1<Complex64>, qubit_indices: &[usize]) -> TestResult {
190        // For a 2-qubit system, check if state can be written as |ψ⟩ = |a⟩ ⊗ |b⟩
191        if qubit_indices.len() != 2 {
192            return TestResult::Skip(
193                "Entanglement check only implemented for 2-qubit subsystems".to_string(),
194            );
195        }
196
197        let n_qubits = (state.len() as f64).log2() as usize;
198        if qubit_indices.iter().any(|&i| i >= n_qubits) {
199            return TestResult::Fail("Qubit index out of bounds".to_string());
200        }
201
202        // Simplified check: for computational basis states
203        // A separable state has rank-1 reduced density matrix
204        // This is a placeholder - full implementation would compute partial trace
205
206        TestResult::Skip("Full entanglement check not yet implemented".to_string())
207    }
208}
209
210/// Quantum circuit test builder
211pub struct QuantumTest {
212    name: String,
213    setup: Option<Box<dyn Fn() -> Result<(), QuantRS2Error>>>,
214    test: Box<dyn Fn() -> TestResult>,
215    teardown: Option<Box<dyn Fn()>>,
216}
217
218impl QuantumTest {
219    /// Create a new quantum test
220    pub fn new(name: impl Into<String>, test: impl Fn() -> TestResult + 'static) -> Self {
221        Self {
222            name: name.into(),
223            setup: None,
224            test: Box::new(test),
225            teardown: None,
226        }
227    }
228
229    /// Add setup function
230    #[must_use]
231    pub fn with_setup(mut self, setup: impl Fn() -> Result<(), QuantRS2Error> + 'static) -> Self {
232        self.setup = Some(Box::new(setup));
233        self
234    }
235
236    /// Add teardown function
237    #[must_use]
238    pub fn with_teardown(mut self, teardown: impl Fn() + 'static) -> Self {
239        self.teardown = Some(Box::new(teardown));
240        self
241    }
242
243    /// Run the test
244    pub fn run(&self) -> TestResult {
245        // Run setup
246        if let Some(setup) = &self.setup {
247            if let Err(e) = setup() {
248                return TestResult::Fail(format!("Setup failed: {e}"));
249            }
250        }
251
252        // Run test
253        let result = (self.test)();
254
255        // Run teardown
256        if let Some(teardown) = &self.teardown {
257            teardown();
258        }
259
260        result
261    }
262}
263
264/// Test suite for organizing multiple quantum tests
265pub struct QuantumTestSuite {
266    name: String,
267    tests: Vec<QuantumTest>,
268}
269
270impl QuantumTestSuite {
271    /// Create a new test suite
272    pub fn new(name: impl Into<String>) -> Self {
273        Self {
274            name: name.into(),
275            tests: Vec::new(),
276        }
277    }
278
279    /// Add a test to the suite
280    pub fn add_test(&mut self, test: QuantumTest) {
281        self.tests.push(test);
282    }
283
284    /// Run all tests in the suite
285    pub fn run(&self) -> TestSuiteResult {
286        let mut results = Vec::new();
287
288        for test in &self.tests {
289            let result = test.run();
290            results.push((test.name.clone(), result));
291        }
292
293        TestSuiteResult {
294            suite_name: self.name.clone(),
295            results,
296        }
297    }
298}
299
300/// Results from running a test suite
301pub struct TestSuiteResult {
302    suite_name: String,
303    results: Vec<(String, TestResult)>,
304}
305
306impl TestSuiteResult {
307    /// Get number of passed tests
308    pub fn passed(&self) -> usize {
309        self.results.iter().filter(|(_, r)| r.is_pass()).count()
310    }
311
312    /// Get number of failed tests
313    pub fn failed(&self) -> usize {
314        self.results
315            .iter()
316            .filter(|(_, r)| matches!(r, TestResult::Fail(_)))
317            .count()
318    }
319
320    /// Get number of skipped tests
321    pub fn skipped(&self) -> usize {
322        self.results
323            .iter()
324            .filter(|(_, r)| matches!(r, TestResult::Skip(_)))
325            .count()
326    }
327
328    /// Get total number of tests
329    pub fn total(&self) -> usize {
330        self.results.len()
331    }
332
333    /// Check if all tests passed
334    pub fn all_passed(&self) -> bool {
335        self.failed() == 0
336    }
337}
338
339impl fmt::Display for TestSuiteResult {
340    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341        writeln!(f, "\n{} Test Results:", self.suite_name)?;
342        writeln!(f, "{}", "=".repeat(50))?;
343
344        for (name, result) in &self.results {
345            let status = match result {
346                TestResult::Pass => "✓ PASS",
347                TestResult::Fail(_) => "✗ FAIL",
348                TestResult::Skip(_) => "⊙ SKIP",
349            };
350
351            writeln!(f, "{status:<6} {name}")?;
352
353            if let TestResult::Fail(reason) = result {
354                writeln!(f, "       Reason: {reason}")?;
355            } else if let TestResult::Skip(reason) = result {
356                writeln!(f, "       Reason: {reason}")?;
357            }
358        }
359
360        writeln!(f, "{}", "=".repeat(50))?;
361        writeln!(
362            f,
363            "Total: {} | Passed: {} | Failed: {} | Skipped: {}",
364            self.total(),
365            self.passed(),
366            self.failed(),
367            self.skipped()
368        )?;
369
370        Ok(())
371    }
372}
373
/// Shorthand for building a quantum test: `quantum_test!(name, body)`
/// expands to `QuantumTest::new(name, body)`.
///
/// NOTE(review): the expansion names `QuantumTest` unqualified, so it must be
/// in scope at the call site; consider a `$crate::...` path once the module's
/// public path is confirmed.
#[macro_export]
macro_rules! quantum_test {
    ($label:expr, $body:expr) => {{
        QuantumTest::new($label, $body)
    }};
}
381
/// Compare two states up to global phase via `QuantumAssert::states_equal`,
/// evaluating to whatever that method returns.
///
/// The two-argument form uses `QuantumAssert::default()`; a third argument
/// supplies a custom tolerance via `QuantumAssert::with_tolerance`.
///
/// NOTE(review): `QuantumAssert` is named unqualified in the expansion, so it
/// must be in scope at the call site.
#[macro_export]
macro_rules! assert_states_equal {
    ($lhs:expr, $rhs:expr) => {{
        QuantumAssert::default().states_equal($lhs, $rhs)
    }};
    ($lhs:expr, $rhs:expr, $tol:expr) => {{
        QuantumAssert::with_tolerance($tol).states_equal($lhs, $rhs)
    }};
}
393
/// Check a matrix for unitarity via `QuantumAssert::matrix_unitary` with the
/// default tolerance, evaluating to whatever that method returns.
///
/// NOTE(review): `QuantumAssert` is named unqualified in the expansion, so it
/// must be in scope at the call site.
#[macro_export]
macro_rules! assert_unitary {
    ($candidate:expr) => {{
        QuantumAssert::default().matrix_unitary($candidate)
    }};
}
401
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Shorthand for the complex amplitude `re + i·im`.
    fn amp(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    #[test]
    fn test_quantum_assert_states_equal() {
        let qa = QuantumAssert::default();
        let ket0 = array![amp(1.0, 0.0), amp(0.0, 0.0)];

        // A state always equals an exact copy of itself.
        let copy = ket0.clone();
        assert!(qa.states_equal(&ket0, &copy).is_pass());

        // i|0⟩ differs from |0⟩ only by a global phase.
        let i_ket0 = array![amp(0.0, 1.0), amp(0.0, 0.0)];
        assert!(qa.states_equal(&ket0, &i_ket0).is_pass());

        // |1⟩ is a physically different state.
        let ket1 = array![amp(0.0, 0.0), amp(1.0, 0.0)];
        assert!(!qa.states_equal(&ket0, &ket1).is_pass());
    }

    #[test]
    fn test_quantum_assert_normalized() {
        let qa = QuantumAssert::default();
        let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();

        // (|0⟩ + |1⟩)/√2 has unit norm.
        let plus = array![amp(inv_sqrt2, 0.0), amp(inv_sqrt2, 0.0)];
        assert!(qa.state_normalized(&plus).is_pass());

        // Without the 1/√2 factor the squared norm is 2.
        let unscaled = array![amp(1.0, 0.0), amp(1.0, 0.0)];
        assert!(!qa.state_normalized(&unscaled).is_pass());
    }

    #[test]
    fn test_quantum_test_suite() {
        let mut suite = QuantumTestSuite::new("Example Suite");
        suite.add_test(QuantumTest::new("Test 1", || TestResult::Pass));
        suite.add_test(QuantumTest::new("Test 2", || {
            TestResult::Fail("Expected failure".to_string())
        }));
        suite.add_test(QuantumTest::new("Test 3", || {
            TestResult::Skip("Not implemented".to_string())
        }));

        let summary = suite.run();
        assert_eq!(
            (
                summary.total(),
                summary.passed(),
                summary.failed(),
                summary.skipped()
            ),
            (3, 1, 1, 1)
        );
    }
}