quantrs2_core/
testing.rs

1//! Quantum unit testing framework
2//!
3//! This module provides tools for testing quantum circuits, states, and operations
4//! with proper handling of quantum-specific properties like phase and entanglement.
5
6use crate::complex_ext::QuantumComplexExt;
7use crate::error::QuantRS2Error;
8use ndarray::{Array1, Array2};
9use num_complex::Complex64;
10use std::fmt;
11
/// Default absolute tolerance for quantum comparisons.
///
/// `QuantumAssert::default()` uses this value when comparing amplitudes,
/// probabilities, norms and matrix entries.
pub const DEFAULT_TOLERANCE: f64 = 1.0e-10;
14
/// Outcome of a single quantum test or assertion.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TestResult {
    /// The check succeeded.
    Pass,
    /// The check failed; the string explains why.
    Fail(String),
    /// The check was not performed; the string explains why.
    Skip(String),
}

impl TestResult {
    /// Returns `true` if this result is [`TestResult::Pass`].
    #[must_use]
    pub fn is_pass(&self) -> bool {
        matches!(self, TestResult::Pass)
    }

    /// Returns `true` if this result is [`TestResult::Fail`].
    #[must_use]
    pub fn is_fail(&self) -> bool {
        matches!(self, TestResult::Fail(_))
    }

    /// Returns `true` if this result is [`TestResult::Skip`].
    #[must_use]
    pub fn is_skip(&self) -> bool {
        matches!(self, TestResult::Skip(_))
    }

    /// Returns the failure or skip reason, or `None` for [`TestResult::Pass`].
    #[must_use]
    pub fn reason(&self) -> Option<&str> {
        match self {
            TestResult::Pass => None,
            TestResult::Fail(reason) | TestResult::Skip(reason) => Some(reason),
        }
    }
}
31
32/// Quantum state assertion helper
33pub struct QuantumAssert {
34    tolerance: f64,
35}
36
37impl Default for QuantumAssert {
38    fn default() -> Self {
39        Self {
40            tolerance: DEFAULT_TOLERANCE,
41        }
42    }
43}
44
45impl QuantumAssert {
46    /// Create with custom tolerance
47    pub fn with_tolerance(tolerance: f64) -> Self {
48        Self { tolerance }
49    }
50
51    /// Assert two quantum states are equal (up to global phase)
52    pub fn states_equal(
53        &self,
54        state1: &Array1<Complex64>,
55        state2: &Array1<Complex64>,
56    ) -> TestResult {
57        if state1.len() != state2.len() {
58            return TestResult::Fail(format!(
59                "State dimensions mismatch: {} vs {}",
60                state1.len(),
61                state2.len()
62            ));
63        }
64
65        // Find first non-zero amplitude to determine global phase
66        let mut phase_factor = None;
67        for i in 0..state1.len() {
68            if state1[i].norm() > self.tolerance && state2[i].norm() > self.tolerance {
69                phase_factor = Some(state2[i] / state1[i]);
70                break;
71            }
72        }
73
74        let phase = phase_factor.unwrap_or(Complex64::new(1.0, 0.0));
75
76        // Check all amplitudes match after phase correction
77        for i in 0..state1.len() {
78            let expected = state1[i] * phase;
79            if (expected - state2[i]).norm() > self.tolerance {
80                return TestResult::Fail(format!(
81                    "States differ at index {}: expected {}, got {}",
82                    i, expected, state2[i]
83                ));
84            }
85        }
86
87        TestResult::Pass
88    }
89
90    /// Assert a state is normalized
91    pub fn state_normalized(&self, state: &Array1<Complex64>) -> TestResult {
92        let norm_squared: f64 = state.iter().map(|c| c.norm_sqr()).sum();
93
94        if (norm_squared - 1.0).abs() > self.tolerance {
95            TestResult::Fail(format!(
96                "State not normalized: norm^2 = {} (expected 1.0)",
97                norm_squared
98            ))
99        } else {
100            TestResult::Pass
101        }
102    }
103
104    /// Assert two states are orthogonal
105    pub fn states_orthogonal(
106        &self,
107        state1: &Array1<Complex64>,
108        state2: &Array1<Complex64>,
109    ) -> TestResult {
110        if state1.len() != state2.len() {
111            return TestResult::Fail("State dimensions mismatch".to_string());
112        }
113
114        let inner_product: Complex64 = state1
115            .iter()
116            .zip(state2.iter())
117            .map(|(a, b)| a.conj() * b)
118            .sum();
119
120        if inner_product.norm() > self.tolerance {
121            TestResult::Fail(format!(
122                "States not orthogonal: inner product = {}",
123                inner_product
124            ))
125        } else {
126            TestResult::Pass
127        }
128    }
129
130    /// Assert a matrix is unitary
131    pub fn matrix_unitary(&self, matrix: &Array2<Complex64>) -> TestResult {
132        let (rows, cols) = matrix.dim();
133        if rows != cols {
134            return TestResult::Fail(format!("Matrix not square: {}x{}", rows, cols));
135        }
136
137        // Compute U† U
138        let conjugate_transpose = matrix.t().mapv(|c| c.conj());
139        let product = conjugate_transpose.dot(matrix);
140
141        // Check if it's identity
142        for i in 0..rows {
143            for j in 0..cols {
144                let expected = if i == j {
145                    Complex64::new(1.0, 0.0)
146                } else {
147                    Complex64::new(0.0, 0.0)
148                };
149
150                if (product[[i, j]] - expected).norm() > self.tolerance {
151                    return TestResult::Fail(format!(
152                        "U†U not identity at ({},{}): got {}",
153                        i,
154                        j,
155                        product[[i, j]]
156                    ));
157                }
158            }
159        }
160
161        TestResult::Pass
162    }
163
164    /// Assert a state has specific measurement probabilities
165    pub fn measurement_probabilities(
166        &self,
167        state: &Array1<Complex64>,
168        expected_probs: &[(usize, f64)],
169    ) -> TestResult {
170        for &(index, expected_prob) in expected_probs {
171            if index >= state.len() {
172                return TestResult::Fail(format!(
173                    "Index {} out of bounds for state of length {}",
174                    index,
175                    state.len()
176                ));
177            }
178
179            let actual_prob = state[index].probability();
180            if (actual_prob - expected_prob).abs() > self.tolerance {
181                return TestResult::Fail(format!(
182                    "Probability mismatch at index {}: expected {}, got {}",
183                    index, expected_prob, actual_prob
184                ));
185            }
186        }
187
188        TestResult::Pass
189    }
190
191    /// Assert entanglement properties
192    pub fn is_entangled(&self, state: &Array1<Complex64>, qubit_indices: &[usize]) -> TestResult {
193        // For a 2-qubit system, check if state can be written as |ψ⟩ = |a⟩ ⊗ |b⟩
194        if qubit_indices.len() != 2 {
195            return TestResult::Skip(
196                "Entanglement check only implemented for 2-qubit subsystems".to_string(),
197            );
198        }
199
200        let n_qubits = (state.len() as f64).log2() as usize;
201        if qubit_indices.iter().any(|&i| i >= n_qubits) {
202            return TestResult::Fail("Qubit index out of bounds".to_string());
203        }
204
205        // Simplified check: for computational basis states
206        // A separable state has rank-1 reduced density matrix
207        // This is a placeholder - full implementation would compute partial trace
208
209        TestResult::Skip("Full entanglement check not yet implemented".to_string())
210    }
211}
212
213/// Quantum circuit test builder
214pub struct QuantumTest {
215    name: String,
216    setup: Option<Box<dyn Fn() -> Result<(), QuantRS2Error>>>,
217    test: Box<dyn Fn() -> TestResult>,
218    teardown: Option<Box<dyn Fn()>>,
219}
220
221impl QuantumTest {
222    /// Create a new quantum test
223    pub fn new(name: impl Into<String>, test: impl Fn() -> TestResult + 'static) -> Self {
224        Self {
225            name: name.into(),
226            setup: None,
227            test: Box::new(test),
228            teardown: None,
229        }
230    }
231
232    /// Add setup function
233    pub fn with_setup(mut self, setup: impl Fn() -> Result<(), QuantRS2Error> + 'static) -> Self {
234        self.setup = Some(Box::new(setup));
235        self
236    }
237
238    /// Add teardown function
239    pub fn with_teardown(mut self, teardown: impl Fn() + 'static) -> Self {
240        self.teardown = Some(Box::new(teardown));
241        self
242    }
243
244    /// Run the test
245    pub fn run(&self) -> TestResult {
246        // Run setup
247        if let Some(setup) = &self.setup {
248            if let Err(e) = setup() {
249                return TestResult::Fail(format!("Setup failed: {}", e));
250            }
251        }
252
253        // Run test
254        let result = (self.test)();
255
256        // Run teardown
257        if let Some(teardown) = &self.teardown {
258            teardown();
259        }
260
261        result
262    }
263}
264
265/// Test suite for organizing multiple quantum tests
266pub struct QuantumTestSuite {
267    name: String,
268    tests: Vec<QuantumTest>,
269}
270
271impl QuantumTestSuite {
272    /// Create a new test suite
273    pub fn new(name: impl Into<String>) -> Self {
274        Self {
275            name: name.into(),
276            tests: Vec::new(),
277        }
278    }
279
280    /// Add a test to the suite
281    pub fn add_test(&mut self, test: QuantumTest) {
282        self.tests.push(test);
283    }
284
285    /// Run all tests in the suite
286    pub fn run(&self) -> TestSuiteResult {
287        let mut results = Vec::new();
288
289        for test in &self.tests {
290            let result = test.run();
291            results.push((test.name.clone(), result));
292        }
293
294        TestSuiteResult {
295            suite_name: self.name.clone(),
296            results,
297        }
298    }
299}
300
301/// Results from running a test suite
302pub struct TestSuiteResult {
303    suite_name: String,
304    results: Vec<(String, TestResult)>,
305}
306
307impl TestSuiteResult {
308    /// Get number of passed tests
309    pub fn passed(&self) -> usize {
310        self.results.iter().filter(|(_, r)| r.is_pass()).count()
311    }
312
313    /// Get number of failed tests
314    pub fn failed(&self) -> usize {
315        self.results
316            .iter()
317            .filter(|(_, r)| matches!(r, TestResult::Fail(_)))
318            .count()
319    }
320
321    /// Get number of skipped tests
322    pub fn skipped(&self) -> usize {
323        self.results
324            .iter()
325            .filter(|(_, r)| matches!(r, TestResult::Skip(_)))
326            .count()
327    }
328
329    /// Get total number of tests
330    pub fn total(&self) -> usize {
331        self.results.len()
332    }
333
334    /// Check if all tests passed
335    pub fn all_passed(&self) -> bool {
336        self.failed() == 0
337    }
338}
339
340impl fmt::Display for TestSuiteResult {
341    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342        writeln!(f, "\n{} Test Results:", self.suite_name)?;
343        writeln!(f, "{}", "=".repeat(50))?;
344
345        for (name, result) in &self.results {
346            let status = match result {
347                TestResult::Pass => "✓ PASS",
348                TestResult::Fail(_) => "✗ FAIL",
349                TestResult::Skip(_) => "⊙ SKIP",
350            };
351
352            writeln!(f, "{:<6} {}", status, name)?;
353
354            if let TestResult::Fail(reason) = result {
355                writeln!(f, "       Reason: {}", reason)?;
356            } else if let TestResult::Skip(reason) = result {
357                writeln!(f, "       Reason: {}", reason)?;
358            }
359        }
360
361        writeln!(f, "{}", "=".repeat(50))?;
362        writeln!(
363            f,
364            "Total: {} | Passed: {} | Failed: {} | Skipped: {}",
365            self.total(),
366            self.passed(),
367            self.failed(),
368            self.skipped()
369        )?;
370
371        Ok(())
372    }
373}
374
/// Builds a `QuantumTest` from a name and a test closure.
///
/// Expands to `QuantumTest::new($name, $test)` through a `$crate`-qualified path, so
/// call sites do not need `QuantumTest` imported.
#[macro_export]
macro_rules! quantum_test {
    ($name:expr, $test:expr) => {
        // NOTE(review): assumes this file is mounted as the crate's `testing` module
        // (quantrs2_core/testing.rs); adjust the path if it is re-homed.
        $crate::testing::QuantumTest::new($name, $test)
    };
}
382
/// Compares two states up to global phase and evaluates to the resulting `TestResult`.
///
/// Despite the `assert_` prefix this macro does not panic; it returns the value of
/// `QuantumAssert::states_equal`. An optional third argument overrides the default
/// tolerance. Paths are `$crate`-qualified, so call sites do not need `QuantumAssert`
/// imported.
#[macro_export]
macro_rules! assert_states_equal {
    // NOTE(review): both arms assume this file is mounted as the crate's `testing` module
    // (quantrs2_core/testing.rs); adjust the path if it is re-homed.
    ($state1:expr, $state2:expr) => {{
        let checker = $crate::testing::QuantumAssert::default();
        checker.states_equal($state1, $state2)
    }};
    ($state1:expr, $state2:expr, $tolerance:expr) => {{
        let checker = $crate::testing::QuantumAssert::with_tolerance($tolerance);
        checker.states_equal($state1, $state2)
    }};
}
394
/// Checks that a matrix is unitary and evaluates to the resulting `TestResult`.
///
/// Despite the `assert_` prefix this macro does not panic; it returns the value of
/// `QuantumAssert::matrix_unitary` using the default tolerance. The path is
/// `$crate`-qualified, so call sites do not need `QuantumAssert` imported.
#[macro_export]
macro_rules! assert_unitary {
    ($matrix:expr) => {{
        // NOTE(review): assumes this file is mounted as the crate's `testing` module
        // (quantrs2_core/testing.rs); adjust the path if it is re-homed.
        let checker = $crate::testing::QuantumAssert::default();
        checker.matrix_unitary($matrix)
    }};
}
402
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use std::cell::Cell;
    use std::rc::Rc;

    #[test]
    fn test_quantum_assert_states_equal() {
        let assert = QuantumAssert::default();

        // Test equal states
        let state1 = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        let state2 = state1.clone();
        assert!(assert.states_equal(&state1, &state2).is_pass());

        // Test states with global phase
        let state3 = array![Complex64::new(0.0, 1.0), Complex64::new(0.0, 0.0)]; // i|0⟩
        assert!(assert.states_equal(&state1, &state3).is_pass());

        // Test different states
        let state4 = array![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)];
        assert!(!assert.states_equal(&state1, &state4).is_pass());
    }

    #[test]
    fn test_quantum_assert_normalized() {
        let assert = QuantumAssert::default();

        // Normalized state
        let state1 = array![
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0)
        ];
        assert!(assert.state_normalized(&state1).is_pass());

        // Not normalized
        let state2 = array![Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)];
        assert!(!assert.state_normalized(&state2).is_pass());
    }

    #[test]
    fn test_quantum_assert_orthogonal() {
        let assert = QuantumAssert::default();
        let h = 1.0 / 2.0_f64.sqrt();

        let zero = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        let one = array![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)];
        let plus = array![Complex64::new(h, 0.0), Complex64::new(h, 0.0)];
        let three_dim = array![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0)
        ];

        assert!(assert.states_orthogonal(&zero, &one).is_pass());
        assert!(!assert.states_orthogonal(&zero, &plus).is_pass());
        // Dimension mismatch is a failure, not a panic.
        assert!(!assert.states_orthogonal(&zero, &three_dim).is_pass());
    }

    #[test]
    fn test_quantum_assert_unitary() {
        let assert = QuantumAssert::default();
        let h = 1.0 / 2.0_f64.sqrt();

        let hadamard = array![
            [Complex64::new(h, 0.0), Complex64::new(h, 0.0)],
            [Complex64::new(h, 0.0), Complex64::new(-h, 0.0)]
        ];
        assert!(assert.matrix_unitary(&hadamard).is_pass());

        let shear = array![
            [Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)]
        ];
        assert!(!assert.matrix_unitary(&shear).is_pass());

        let non_square = Array2::<Complex64>::zeros((2, 3));
        assert!(!assert.matrix_unitary(&non_square).is_pass());
    }

    #[test]
    fn test_quantum_assert_measurement_probabilities() {
        let assert = QuantumAssert::default();
        let zero = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];

        assert!(assert
            .measurement_probabilities(&zero, &[(0, 1.0), (1, 0.0)])
            .is_pass());
        assert!(!assert.measurement_probabilities(&zero, &[(0, 0.0)]).is_pass());
        // Out-of-range indices are reported as failures.
        assert!(!assert.measurement_probabilities(&zero, &[(2, 0.0)]).is_pass());
    }

    #[test]
    fn test_is_entangled_argument_validation() {
        let assert = QuantumAssert::default();
        let h = 1.0 / 2.0_f64.sqrt();
        let bell = array![
            Complex64::new(h, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(h, 0.0)
        ];

        assert!(matches!(assert.is_entangled(&bell, &[0]), TestResult::Skip(_)));
        assert!(matches!(
            assert.is_entangled(&bell, &[0, 5]),
            TestResult::Fail(_)
        ));
    }

    #[test]
    fn test_quantum_test_runs_teardown() {
        let torn_down = Rc::new(Cell::new(false));
        let flag = Rc::clone(&torn_down);
        let test = QuantumTest::new("teardown", || TestResult::Pass)
            .with_teardown(move || flag.set(true));

        assert!(test.run().is_pass());
        assert!(torn_down.get());
    }

    #[test]
    fn test_quantum_test_suite() {
        let mut suite = QuantumTestSuite::new("Example Suite");

        suite.add_test(QuantumTest::new("Test 1", || TestResult::Pass));
        suite.add_test(QuantumTest::new("Test 2", || {
            TestResult::Fail("Expected failure".to_string())
        }));
        suite.add_test(QuantumTest::new("Test 3", || {
            TestResult::Skip("Not implemented".to_string())
        }));

        let results = suite.run();
        assert_eq!(results.total(), 3);
        assert_eq!(results.passed(), 1);
        assert_eq!(results.failed(), 1);
        assert_eq!(results.skipped(), 1);
    }

    #[test]
    fn test_suite_result_report() {
        let mut suite = QuantumTestSuite::new("Report Suite");
        suite.add_test(QuantumTest::new("ok", || TestResult::Pass));
        suite.add_test(QuantumTest::new("bad", || {
            TestResult::Fail("broken".to_string())
        }));

        let results = suite.run();
        assert!(!results.all_passed());

        let report = results.to_string();
        assert!(report.contains("Report Suite Test Results:"));
        assert!(report.contains("Reason: broken"));
        assert!(report.contains("Total: 2 | Passed: 1 | Failed: 1 | Skipped: 0"));
    }
}