1use crate::complex_ext::QuantumComplexExt;
7use crate::error::QuantRS2Error;
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::Complex64;
10use std::fmt;
11
/// Absolute tolerance used by [`QuantumAssert::default`] when comparing
/// floating-point amplitudes, probabilities and matrix entries.
pub const DEFAULT_TOLERANCE: f64 = 1.0e-10;
14
/// Outcome of a single quantum assertion or test case.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TestResult {
    /// The check succeeded.
    Pass,
    /// The check failed; the string explains why.
    Fail(String),
    /// The check was not performed; the string explains why.
    Skip(String),
}

impl TestResult {
    /// Returns `true` for [`TestResult::Pass`].
    #[must_use]
    pub const fn is_pass(&self) -> bool {
        matches!(self, Self::Pass)
    }

    /// Returns `true` for [`TestResult::Fail`].
    #[must_use]
    pub const fn is_fail(&self) -> bool {
        matches!(self, Self::Fail(_))
    }

    /// Returns `true` for [`TestResult::Skip`].
    #[must_use]
    pub const fn is_skip(&self) -> bool {
        matches!(self, Self::Skip(_))
    }

    /// Returns the explanation attached to a failure or skip, or `None` for a pass.
    #[must_use]
    pub fn reason(&self) -> Option<&str> {
        match self {
            Self::Pass => None,
            Self::Fail(reason) | Self::Skip(reason) => Some(reason.as_str()),
        }
    }
}
31
/// Tolerance-aware assertions over state vectors and operators.
///
/// Each check returns a [`TestResult`] rather than panicking.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QuantumAssert {
    /// Absolute threshold used when comparing floating-point quantities.
    tolerance: f64,
}
36
37impl Default for QuantumAssert {
38 fn default() -> Self {
39 Self {
40 tolerance: DEFAULT_TOLERANCE,
41 }
42 }
43}
44
45impl QuantumAssert {
46 pub const fn with_tolerance(tolerance: f64) -> Self {
48 Self { tolerance }
49 }
50
51 pub fn states_equal(
53 &self,
54 state1: &Array1<Complex64>,
55 state2: &Array1<Complex64>,
56 ) -> TestResult {
57 if state1.len() != state2.len() {
58 return TestResult::Fail(format!(
59 "State dimensions mismatch: {} vs {}",
60 state1.len(),
61 state2.len()
62 ));
63 }
64
65 let mut phase_factor = None;
67 for i in 0..state1.len() {
68 if state1[i].norm() > self.tolerance && state2[i].norm() > self.tolerance {
69 phase_factor = Some(state2[i] / state1[i]);
70 break;
71 }
72 }
73
74 let phase = phase_factor.unwrap_or(Complex64::new(1.0, 0.0));
75
76 for i in 0..state1.len() {
78 let expected = state1[i] * phase;
79 if (expected - state2[i]).norm() > self.tolerance {
80 return TestResult::Fail(format!(
81 "States differ at index {}: expected {}, got {}",
82 i, expected, state2[i]
83 ));
84 }
85 }
86
87 TestResult::Pass
88 }
89
90 pub fn state_normalized(&self, state: &Array1<Complex64>) -> TestResult {
92 let norm_squared: f64 = state.iter().map(|c| c.norm_sqr()).sum();
93
94 if (norm_squared - 1.0).abs() > self.tolerance {
95 TestResult::Fail(format!(
96 "State not normalized: norm^2 = {norm_squared} (expected 1.0)"
97 ))
98 } else {
99 TestResult::Pass
100 }
101 }
102
103 pub fn states_orthogonal(
105 &self,
106 state1: &Array1<Complex64>,
107 state2: &Array1<Complex64>,
108 ) -> TestResult {
109 if state1.len() != state2.len() {
110 return TestResult::Fail("State dimensions mismatch".to_string());
111 }
112
113 let inner_product: Complex64 = state1
114 .iter()
115 .zip(state2.iter())
116 .map(|(a, b)| a.conj() * b)
117 .sum();
118
119 if inner_product.norm() > self.tolerance {
120 TestResult::Fail(format!(
121 "States not orthogonal: inner product = {inner_product}"
122 ))
123 } else {
124 TestResult::Pass
125 }
126 }
127
128 pub fn matrix_unitary(&self, matrix: &Array2<Complex64>) -> TestResult {
130 let (rows, cols) = matrix.dim();
131 if rows != cols {
132 return TestResult::Fail(format!("Matrix not square: {rows}x{cols}"));
133 }
134
135 let conjugate_transpose = matrix.t().mapv(|c| c.conj());
137 let product = conjugate_transpose.dot(matrix);
138
139 for i in 0..rows {
141 for j in 0..cols {
142 let expected = if i == j {
143 Complex64::new(1.0, 0.0)
144 } else {
145 Complex64::new(0.0, 0.0)
146 };
147
148 if (product[[i, j]] - expected).norm() > self.tolerance {
149 return TestResult::Fail(format!(
150 "U†U not identity at ({},{}): got {}",
151 i,
152 j,
153 product[[i, j]]
154 ));
155 }
156 }
157 }
158
159 TestResult::Pass
160 }
161
162 pub fn measurement_probabilities(
164 &self,
165 state: &Array1<Complex64>,
166 expected_probs: &[(usize, f64)],
167 ) -> TestResult {
168 for &(index, expected_prob) in expected_probs {
169 if index >= state.len() {
170 return TestResult::Fail(format!(
171 "Index {} out of bounds for state of length {}",
172 index,
173 state.len()
174 ));
175 }
176
177 let actual_prob = state[index].probability();
178 if (actual_prob - expected_prob).abs() > self.tolerance {
179 return TestResult::Fail(format!(
180 "Probability mismatch at index {index}: expected {expected_prob}, got {actual_prob}"
181 ));
182 }
183 }
184
185 TestResult::Pass
186 }
187
188 pub fn is_entangled(&self, state: &Array1<Complex64>, qubit_indices: &[usize]) -> TestResult {
190 if qubit_indices.len() != 2 {
192 return TestResult::Skip(
193 "Entanglement check only implemented for 2-qubit subsystems".to_string(),
194 );
195 }
196
197 let n_qubits = (state.len() as f64).log2() as usize;
198 if qubit_indices.iter().any(|&i| i >= n_qubits) {
199 return TestResult::Fail("Qubit index out of bounds".to_string());
200 }
201
202 TestResult::Skip("Full entanglement check not yet implemented".to_string())
207 }
208}
209
210pub struct QuantumTest {
212 name: String,
213 setup: Option<Box<dyn Fn() -> Result<(), QuantRS2Error>>>,
214 test: Box<dyn Fn() -> TestResult>,
215 teardown: Option<Box<dyn Fn()>>,
216}
217
218impl QuantumTest {
219 pub fn new(name: impl Into<String>, test: impl Fn() -> TestResult + 'static) -> Self {
221 Self {
222 name: name.into(),
223 setup: None,
224 test: Box::new(test),
225 teardown: None,
226 }
227 }
228
229 #[must_use]
231 pub fn with_setup(mut self, setup: impl Fn() -> Result<(), QuantRS2Error> + 'static) -> Self {
232 self.setup = Some(Box::new(setup));
233 self
234 }
235
236 #[must_use]
238 pub fn with_teardown(mut self, teardown: impl Fn() + 'static) -> Self {
239 self.teardown = Some(Box::new(teardown));
240 self
241 }
242
243 pub fn run(&self) -> TestResult {
245 if let Some(setup) = &self.setup {
247 if let Err(e) = setup() {
248 return TestResult::Fail(format!("Setup failed: {e}"));
249 }
250 }
251
252 let result = (self.test)();
254
255 if let Some(teardown) = &self.teardown {
257 teardown();
258 }
259
260 result
261 }
262}
263
264pub struct QuantumTestSuite {
266 name: String,
267 tests: Vec<QuantumTest>,
268}
269
270impl QuantumTestSuite {
271 pub fn new(name: impl Into<String>) -> Self {
273 Self {
274 name: name.into(),
275 tests: Vec::new(),
276 }
277 }
278
279 pub fn add_test(&mut self, test: QuantumTest) {
281 self.tests.push(test);
282 }
283
284 pub fn run(&self) -> TestSuiteResult {
286 let mut results = Vec::new();
287
288 for test in &self.tests {
289 let result = test.run();
290 results.push((test.name.clone(), result));
291 }
292
293 TestSuiteResult {
294 suite_name: self.name.clone(),
295 results,
296 }
297 }
298}
299
300pub struct TestSuiteResult {
302 suite_name: String,
303 results: Vec<(String, TestResult)>,
304}
305
306impl TestSuiteResult {
307 pub fn passed(&self) -> usize {
309 self.results.iter().filter(|(_, r)| r.is_pass()).count()
310 }
311
312 pub fn failed(&self) -> usize {
314 self.results
315 .iter()
316 .filter(|(_, r)| matches!(r, TestResult::Fail(_)))
317 .count()
318 }
319
320 pub fn skipped(&self) -> usize {
322 self.results
323 .iter()
324 .filter(|(_, r)| matches!(r, TestResult::Skip(_)))
325 .count()
326 }
327
328 pub fn total(&self) -> usize {
330 self.results.len()
331 }
332
333 pub fn all_passed(&self) -> bool {
335 self.failed() == 0
336 }
337}
338
339impl fmt::Display for TestSuiteResult {
340 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341 writeln!(f, "\n{} Test Results:", self.suite_name)?;
342 writeln!(f, "{}", "=".repeat(50))?;
343
344 for (name, result) in &self.results {
345 let status = match result {
346 TestResult::Pass => "✓ PASS",
347 TestResult::Fail(_) => "✗ FAIL",
348 TestResult::Skip(_) => "⊙ SKIP",
349 };
350
351 writeln!(f, "{status:<6} {name}")?;
352
353 if let TestResult::Fail(reason) = result {
354 writeln!(f, " Reason: {reason}")?;
355 } else if let TestResult::Skip(reason) = result {
356 writeln!(f, " Reason: {reason}")?;
357 }
358 }
359
360 writeln!(f, "{}", "=".repeat(50))?;
361 writeln!(
362 f,
363 "Total: {} | Passed: {} | Failed: {} | Skipped: {}",
364 self.total(),
365 self.passed(),
366 self.failed(),
367 self.skipped()
368 )?;
369
370 Ok(())
371 }
372}
373
/// Builds a `QuantumTest` from a name and a test closure.
///
/// Expands to `QuantumTest::new($name, $test)`. The path is not `$crate`-qualified,
/// so `QuantumTest` must be in scope at the call site. A trailing comma is accepted.
#[macro_export]
macro_rules! quantum_test {
    ($name:expr, $test:expr $(,)?) => {
        QuantumTest::new($name, $test)
    };
}
381
/// Compares two state vectors up to a global phase via `QuantumAssert::states_equal`.
///
/// Despite its name, this does not panic: it evaluates to the resulting
/// `TestResult`. The two-argument form uses `QuantumAssert::default()`. The
/// three-argument form uses `QuantumAssert::with_tolerance($tolerance)`. The state
/// arguments are passed through unchanged, so they must already be references.
/// `QuantumAssert` is not `$crate`-qualified and must be in scope at the call site.
/// A trailing comma is accepted.
#[macro_export]
macro_rules! assert_states_equal {
    ($state1:expr, $state2:expr $(,)?) => {{
        let assert = QuantumAssert::default();
        assert.states_equal($state1, $state2)
    }};
    ($state1:expr, $state2:expr, $tolerance:expr $(,)?) => {{
        let assert = QuantumAssert::with_tolerance($tolerance);
        assert.states_equal($state1, $state2)
    }};
}
393
/// Checks a matrix for unitarity via `QuantumAssert::default().matrix_unitary`.
///
/// Despite its name, this does not panic: it evaluates to the resulting
/// `TestResult`. The argument is passed through unchanged, so it must already be a
/// reference. `QuantumAssert` is not `$crate`-qualified and must be in scope at the
/// call site. A trailing comma is accepted.
#[macro_export]
macro_rules! assert_unitary {
    ($matrix:expr $(,)?) => {{
        let assert = QuantumAssert::default();
        assert.matrix_unitary($matrix)
    }};
}
401
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_quantum_assert_states_equal() {
        let assert = QuantumAssert::default();

        let state1 = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        let state2 = state1.clone();
        assert!(assert.states_equal(&state1, &state2).is_pass());

        // Same state up to a global phase of i.
        let state3 = array![Complex64::new(0.0, 1.0), Complex64::new(0.0, 0.0)];
        assert!(assert.states_equal(&state1, &state3).is_pass());

        let state4 = array![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)];
        assert!(!assert.states_equal(&state1, &state4).is_pass());

        // Vectors of different lengths are never equal.
        let state5 = array![Complex64::new(1.0, 0.0)];
        assert!(!assert.states_equal(&state1, &state5).is_pass());
    }

    #[test]
    fn test_quantum_assert_normalized() {
        let assert = QuantumAssert::default();

        let state1 = array![
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0)
        ];
        assert!(assert.state_normalized(&state1).is_pass());

        let state2 = array![Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)];
        assert!(!assert.state_normalized(&state2).is_pass());
    }

    #[test]
    fn test_quantum_assert_orthogonal() {
        let assert = QuantumAssert::default();
        let s = 1.0 / 2.0_f64.sqrt();

        let zero = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        let one = array![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)];
        let plus = array![Complex64::new(s, 0.0), Complex64::new(s, 0.0)];

        assert!(assert.states_orthogonal(&zero, &one).is_pass());
        assert!(!assert.states_orthogonal(&zero, &plus).is_pass());
    }

    #[test]
    fn test_quantum_assert_unitary() {
        let assert = QuantumAssert::default();
        let s = 1.0 / 2.0_f64.sqrt();

        let hadamard = array![
            [Complex64::new(s, 0.0), Complex64::new(s, 0.0)],
            [Complex64::new(s, 0.0), Complex64::new(-s, 0.0)]
        ];
        assert!(assert.matrix_unitary(&hadamard).is_pass());

        let shear = array![
            [Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)]
        ];
        assert!(!assert.matrix_unitary(&shear).is_pass());

        let non_square = Array2::<Complex64>::zeros((2, 3));
        assert!(!assert.matrix_unitary(&non_square).is_pass());
    }

    #[test]
    fn test_quantum_test_suite() {
        let mut suite = QuantumTestSuite::new("Example Suite");

        suite.add_test(QuantumTest::new("Test 1", || TestResult::Pass));
        suite.add_test(QuantumTest::new("Test 2", || {
            TestResult::Fail("Expected failure".to_string())
        }));
        suite.add_test(QuantumTest::new("Test 3", || {
            TestResult::Skip("Not implemented".to_string())
        }));

        let results = suite.run();
        assert_eq!(results.total(), 3);
        assert_eq!(results.passed(), 1);
        assert_eq!(results.failed(), 1);
        assert_eq!(results.skipped(), 1);
        assert!(!results.all_passed());
    }

    #[test]
    fn test_suite_result_display() {
        let mut suite = QuantumTestSuite::new("Display Suite");
        suite.add_test(QuantumTest::new("passes", || TestResult::Pass));
        suite.add_test(QuantumTest::new("fails", || {
            TestResult::Fail("boom".to_string())
        }));

        let report = suite.run().to_string();
        assert!(report.contains("Display Suite Test Results:"));
        assert!(report.contains("boom"));
        assert!(report.contains("Total: 2 | Passed: 1 | Failed: 1 | Skipped: 0"));
    }

    #[test]
    fn test_quantum_test_teardown_runs() {
        use std::cell::Cell;
        use std::rc::Rc;

        let ran = Rc::new(Cell::new(false));
        let flag = Rc::clone(&ran);
        let test = QuantumTest::new("with teardown", || TestResult::Pass)
            .with_teardown(move || flag.set(true));

        assert!(test.run().is_pass());
        assert!(ran.get());
    }
}