Skip to main content

quantrs2_sim/
state_vector_simd.rs

1//! SIMD-accelerated single-qubit gate kernels for state vector simulation.
2//!
3//! This module provides a high-level public API for applying individual quantum gates
4//! directly to `Vec<Complex64>` state vectors. It owns the gather/scatter logic
5//! internally so callers do not need to split amplitudes into separate `in_amps0` /
6//! `in_amps1` buffers themselves.
7//!
8//! Internally the implementations delegate to the existing
9//! `crate::optimized_simd` SIMD primitives (backed by
10//! `scirs2_core::simd_ops::SimdUnifiedOps`) for all vector arithmetic, with
11//! a scalar fallback for small state vectors (< 256 amplitudes, i.e. < 8 qubits).
12//!
13//! ## Usage example
14//!
15//! ```no_run
16//! use quantrs2_sim::state_vector_simd::{apply_h_simd, apply_x_simd};
17//! use scirs2_core::Complex64;
18//!
19//! let n_qubits = 4usize;
20//! let mut state = vec![Complex64::new(0.0, 0.0); 1 << n_qubits];
21//! state[0] = Complex64::new(1.0, 0.0); // |0000⟩
22//!
23//! apply_h_simd(&mut state, 0, n_qubits);   // qubit 0 → |+⟩
24//! apply_x_simd(&mut state, 1, n_qubits);   // qubit 1 → |1⟩ (unconditional bit flip)
25//! ```
26//!
27//! These SIMD kernels are provided as a standalone module and can be called
28//! directly. Integration into the main `StateVectorSimulator` dispatch already
29//! exists in `statevector.rs` via `apply_single_qubit_gate_simd`; this module
30//! adds the named, standalone API surface (see also TODO.md).
31
32use scirs2_core::Complex64;
33
34use crate::optimized_simd::{
35    apply_h_gate_simd, apply_rx_gate_simd, apply_ry_gate_simd, apply_rz_gate_simd,
36    apply_s_gate_simd, apply_single_qubit_gate_optimized, apply_t_gate_simd, apply_x_gate_simd,
37    apply_y_gate_simd, apply_z_gate_simd,
38};
39
40// ============================================================================
41// Internal gather / scatter helpers
42// ============================================================================
43
44/// Gather amplitudes for a given target qubit into `out0` (bit=0) and `out1` (bit=1)
45/// buffers.  Returns the number of pairs gathered (= `n_states / 2`).
46fn gather_pairs(
47    state: &[Complex64],
48    target: usize,
49    n_qubits: usize,
50    out0: &mut Vec<Complex64>,
51    out1: &mut Vec<Complex64>,
52) -> usize {
53    let n_states = 1usize << n_qubits;
54    let stride = 1usize << target;
55    let total_pairs = n_states / 2;
56
57    out0.clear();
58    out1.clear();
59    out0.reserve(total_pairs);
60    out1.reserve(total_pairs);
61
62    let mut i = 0usize;
63    while i < n_states {
64        for j in i..(i + stride) {
65            out0.push(state[j]);
66            out1.push(state[j + stride]);
67        }
68        i += 2 * stride;
69    }
70
71    total_pairs
72}
73
74/// Scatter computed amplitude pairs back into `state`.
75fn scatter_pairs(
76    state: &mut [Complex64],
77    target: usize,
78    n_qubits: usize,
79    src0: &[Complex64],
80    src1: &[Complex64],
81) {
82    let n_states = 1usize << n_qubits;
83    let stride = 1usize << target;
84
85    let mut pair_idx = 0usize;
86    let mut i = 0usize;
87    while i < n_states {
88        for j in i..(i + stride) {
89            state[j] = src0[pair_idx];
90            state[j + stride] = src1[pair_idx];
91            pair_idx += 1;
92        }
93        i += 2 * stride;
94    }
95}
96
97// ============================================================================
98// Threshold
99// ============================================================================
100
/// Smallest state-vector length, in amplitudes, that is routed to the SIMD
/// kernels; shorter vectors take the scalar fallback instead.
///
/// The value is 2^8 = 256, the length of an 8-qubit state vector.
const SIMD_THRESHOLD: usize = 1 << 8;
104
105// ============================================================================
106// Scalar fallback for 2×2 unitary gate
107// ============================================================================
108
109/// Pure-scalar application of a 2×2 unitary `matrix` to qubit `target`.
110///
111/// Uses the stride-based pair traversal; always correct regardless of
112/// `n_qubits` or `target`.
113pub fn apply_gate_2x2_scalar(
114    state: &mut [Complex64],
115    matrix: [[Complex64; 2]; 2],
116    target: usize,
117    n_qubits: usize,
118) {
119    let stride = 1usize << target;
120    let n_states = 1usize << n_qubits;
121    let [[a, b], [c, d]] = matrix;
122
123    let mut i = 0usize;
124    while i < n_states {
125        for j in i..(i + stride) {
126            let zero = state[j];
127            let one = state[j + stride];
128            state[j] = a * zero + b * one;
129            state[j + stride] = c * zero + d * one;
130        }
131        i += 2 * stride;
132    }
133}
134
135// ============================================================================
136// Public API — generic dispatcher
137// ============================================================================
138
139/// Apply a generic 2×2 unitary gate to qubit `target` using SIMD acceleration.
140///
141/// `matrix` is given as `[[a, b], [c, d]]` where the new amplitudes are:
142///   - new_|0⟩ = a·old_|0⟩ + b·old_|1⟩
143///   - new_|1⟩ = c·old_|0⟩ + d·old_|1⟩
144///
145/// Falls back to scalar arithmetic for `state.len() < 256`.
146pub fn apply_gate_2x2_simd(
147    state: &mut Vec<Complex64>,
148    matrix: [[Complex64; 2]; 2],
149    target: usize,
150    n_qubits: usize,
151) {
152    if state.len() < SIMD_THRESHOLD {
153        apply_gate_2x2_scalar(state, matrix, target, n_qubits);
154        return;
155    }
156
157    let flat = [matrix[0][0], matrix[0][1], matrix[1][0], matrix[1][1]];
158
159    let mut amps0 = Vec::with_capacity(state.len() / 2);
160    let mut amps1 = Vec::with_capacity(state.len() / 2);
161    let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
162
163    let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
164    let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
165
166    apply_single_qubit_gate_optimized(&flat, &amps0, &amps1, &mut out0, &mut out1);
167    scatter_pairs(state, target, n_qubits, &out0, &out1);
168}
169
170// ============================================================================
171// Named single-qubit gates
172// ============================================================================
173
174/// Apply the Hadamard gate H to qubit `target` using SIMD.
175///
176/// H = (1/√2) [[1, 1], [1, -1]]
177pub fn apply_h_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
178    if state.len() < SIMD_THRESHOLD {
179        use std::f64::consts::FRAC_1_SQRT_2;
180        let h = [
181            [
182                Complex64::new(FRAC_1_SQRT_2, 0.0),
183                Complex64::new(FRAC_1_SQRT_2, 0.0),
184            ],
185            [
186                Complex64::new(FRAC_1_SQRT_2, 0.0),
187                Complex64::new(-FRAC_1_SQRT_2, 0.0),
188            ],
189        ];
190        apply_gate_2x2_scalar(state, h, target, n_qubits);
191        return;
192    }
193
194    let mut amps0 = Vec::with_capacity(state.len() / 2);
195    let mut amps1 = Vec::with_capacity(state.len() / 2);
196    let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
197
198    let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
199    let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
200
201    apply_h_gate_simd(&amps0, &amps1, &mut out0, &mut out1);
202    scatter_pairs(state, target, n_qubits, &out0, &out1);
203}
204
205/// Apply the Pauli-X (NOT) gate to qubit `target` using SIMD.
206///
207/// X = [[0, 1], [1, 0]]
208pub fn apply_x_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
209    if state.len() < SIMD_THRESHOLD {
210        apply_gate_2x2_scalar(
211            state,
212            [
213                [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
214                [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
215            ],
216            target,
217            n_qubits,
218        );
219        return;
220    }
221
222    let mut amps0 = Vec::with_capacity(state.len() / 2);
223    let mut amps1 = Vec::with_capacity(state.len() / 2);
224    let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
225
226    let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
227    let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
228
229    apply_x_gate_simd(&amps0, &amps1, &mut out0, &mut out1);
230    scatter_pairs(state, target, n_qubits, &out0, &out1);
231}
232
233/// Apply the Pauli-Y gate to qubit `target` using SIMD.
234///
235/// Y = [[0, -i], [i, 0]]
236pub fn apply_y_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
237    if state.len() < SIMD_THRESHOLD {
238        apply_gate_2x2_scalar(
239            state,
240            [
241                [Complex64::new(0.0, 0.0), Complex64::new(0.0, -1.0)],
242                [Complex64::new(0.0, 1.0), Complex64::new(0.0, 0.0)],
243            ],
244            target,
245            n_qubits,
246        );
247        return;
248    }
249
250    let mut amps0 = Vec::with_capacity(state.len() / 2);
251    let mut amps1 = Vec::with_capacity(state.len() / 2);
252    let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
253
254    let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
255    let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
256
257    apply_y_gate_simd(&amps0, &amps1, &mut out0, &mut out1);
258    scatter_pairs(state, target, n_qubits, &out0, &out1);
259}
260
261/// Apply the Pauli-Z gate to qubit `target` using SIMD.
262///
263/// Z = [[1, 0], [0, -1]]
264pub fn apply_z_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
265    if state.len() < SIMD_THRESHOLD {
266        apply_gate_2x2_scalar(
267            state,
268            [
269                [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
270                [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)],
271            ],
272            target,
273            n_qubits,
274        );
275        return;
276    }
277
278    let mut amps0 = Vec::with_capacity(state.len() / 2);
279    let mut amps1 = Vec::with_capacity(state.len() / 2);
280    let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
281
282    let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
283    let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
284
285    apply_z_gate_simd(&amps0, &amps1, &mut out0, &mut out1);
286    scatter_pairs(state, target, n_qubits, &out0, &out1);
287}
288
289/// Apply the S (phase) gate to qubit `target` using SIMD.
290///
291/// S = [[1, 0], [0, i]]
292pub fn apply_s_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
293    if state.len() < SIMD_THRESHOLD {
294        apply_gate_2x2_scalar(
295            state,
296            [
297                [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
298                [Complex64::new(0.0, 0.0), Complex64::new(0.0, 1.0)],
299            ],
300            target,
301            n_qubits,
302        );
303        return;
304    }
305
306    let mut amps0 = Vec::with_capacity(state.len() / 2);
307    let mut amps1 = Vec::with_capacity(state.len() / 2);
308    let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
309
310    let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
311    let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
312
313    apply_s_gate_simd(&amps0, &amps1, &mut out0, &mut out1);
314    scatter_pairs(state, target, n_qubits, &out0, &out1);
315}
316
317/// Apply the T gate to qubit `target` using SIMD.
318///
319/// T = [[1, 0], [0, exp(iπ/4)]]
320pub fn apply_t_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
321    if state.len() < SIMD_THRESHOLD {
322        use std::f64::consts::FRAC_PI_4;
323        apply_gate_2x2_scalar(
324            state,
325            [
326                [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
327                [
328                    Complex64::new(0.0, 0.0),
329                    Complex64::new(FRAC_PI_4.cos(), FRAC_PI_4.sin()),
330                ],
331            ],
332            target,
333            n_qubits,
334        );
335        return;
336    }
337
338    let mut amps0 = Vec::with_capacity(state.len() / 2);
339    let mut amps1 = Vec::with_capacity(state.len() / 2);
340    let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
341
342    let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
343    let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
344
345    apply_t_gate_simd(&amps0, &amps1, &mut out0, &mut out1);
346    scatter_pairs(state, target, n_qubits, &out0, &out1);
347}
348
349/// Apply the RX(theta) rotation gate to qubit `target` using SIMD.
350///
351/// RX(θ) = [[cos(θ/2), −i·sin(θ/2)], [−i·sin(θ/2), cos(θ/2)]]
352pub fn apply_rx_simd(state: &mut Vec<Complex64>, theta: f64, target: usize, n_qubits: usize) {
353    if state.len() < SIMD_THRESHOLD {
354        let h = theta / 2.0;
355        apply_gate_2x2_scalar(
356            state,
357            [
358                [Complex64::new(h.cos(), 0.0), Complex64::new(0.0, -h.sin())],
359                [Complex64::new(0.0, -h.sin()), Complex64::new(h.cos(), 0.0)],
360            ],
361            target,
362            n_qubits,
363        );
364        return;
365    }
366
367    let mut amps0 = Vec::with_capacity(state.len() / 2);
368    let mut amps1 = Vec::with_capacity(state.len() / 2);
369    let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
370
371    let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
372    let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
373
374    apply_rx_gate_simd(theta, &amps0, &amps1, &mut out0, &mut out1);
375    scatter_pairs(state, target, n_qubits, &out0, &out1);
376}
377
378/// Apply the RY(theta) rotation gate to qubit `target` using SIMD.
379///
380/// RY(θ) = [[cos(θ/2), −sin(θ/2)], [sin(θ/2), cos(θ/2)]]
381pub fn apply_ry_simd(state: &mut Vec<Complex64>, theta: f64, target: usize, n_qubits: usize) {
382    if state.len() < SIMD_THRESHOLD {
383        let h = theta / 2.0;
384        apply_gate_2x2_scalar(
385            state,
386            [
387                [Complex64::new(h.cos(), 0.0), Complex64::new(-h.sin(), 0.0)],
388                [Complex64::new(h.sin(), 0.0), Complex64::new(h.cos(), 0.0)],
389            ],
390            target,
391            n_qubits,
392        );
393        return;
394    }
395
396    let mut amps0 = Vec::with_capacity(state.len() / 2);
397    let mut amps1 = Vec::with_capacity(state.len() / 2);
398    let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
399
400    let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
401    let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
402
403    apply_ry_gate_simd(theta, &amps0, &amps1, &mut out0, &mut out1);
404    scatter_pairs(state, target, n_qubits, &out0, &out1);
405}
406
407/// Apply the RZ(theta) rotation gate to qubit `target` using SIMD.
408///
409/// RZ(θ) = [[exp(−iθ/2), 0], [0, exp(iθ/2)]]
410pub fn apply_rz_simd(state: &mut Vec<Complex64>, theta: f64, target: usize, n_qubits: usize) {
411    if state.len() < SIMD_THRESHOLD {
412        let h = theta / 2.0;
413        apply_gate_2x2_scalar(
414            state,
415            [
416                [Complex64::new(h.cos(), -h.sin()), Complex64::new(0.0, 0.0)],
417                [Complex64::new(0.0, 0.0), Complex64::new(h.cos(), h.sin())],
418            ],
419            target,
420            n_qubits,
421        );
422        return;
423    }
424
425    let mut amps0 = Vec::with_capacity(state.len() / 2);
426    let mut amps1 = Vec::with_capacity(state.len() / 2);
427    let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
428
429    let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
430    let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
431
432    apply_rz_gate_simd(theta, &amps0, &amps1, &mut out0, &mut out1);
433    scatter_pairs(state, target, n_qubits, &out0, &out1);
434}
435
436// ============================================================================
437// Runtime SIMD capability detection
438// ============================================================================
439
/// Report whether the running CPU offers the SIMD instruction set this module
/// associates with acceleration.
///
/// - `x86_64`: `true` only if AVX2 is detected at runtime.
/// - `aarch64`: always `true`.
/// - any other architecture: always `false`.
pub fn simd_available() -> bool {
    // Exactly one of the three `cfg` arms below survives compilation.
    #[cfg(target_arch = "x86_64")]
    return std::arch::is_x86_feature_detected!("avx2");

    #[cfg(target_arch = "aarch64")]
    return true;

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    return false;
}
458
459// ============================================================================
460// In-file unit tests
461// ============================================================================
462
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_1_SQRT_2, PI};

    /// Smallest qubit count whose state vector (2^8 = 256 amplitudes) reaches
    /// `SIMD_THRESHOLD`, so the dispatchers take their SIMD branch.
    const SIMD_QUBITS: usize = 8;

    /// Build an n-qubit |0...0⟩ state.
    fn zero_state(n: usize) -> Vec<Complex64> {
        let mut s = vec![Complex64::new(0.0, 0.0); 1 << n];
        s[0] = Complex64::new(1.0, 0.0);
        s
    }

    /// Largest per-amplitude distance |a_i − b_i| between two state vectors.
    fn max_diff(a: &[Complex64], b: &[Complex64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).norm())
            .fold(0.0_f64, f64::max)
    }

    /// Guard for the SIMD-path tests: fail loudly if `state` is short enough
    /// that the dispatchers would silently use the scalar fallback instead.
    fn assert_simd_sized(state: &[Complex64]) {
        assert!(
            state.len() >= SIMD_THRESHOLD,
            "state of {} amplitudes would take the scalar fallback (threshold {})",
            state.len(),
            SIMD_THRESHOLD
        );
    }

    // -----------------------------------------------------------------------
    // Scalar fallback tests (small state, n < 8)
    // -----------------------------------------------------------------------

    #[test]
    fn test_h_gate_zero_state() {
        // H|0⟩ = (|0⟩ + |1⟩)/√2
        let mut state = zero_state(1);
        apply_h_simd(&mut state, 0, 1);

        assert!(
            (state[0] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-12,
            "H|0> amplitude of |0> wrong: {:?}",
            state[0]
        );
        assert!(
            (state[1] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-12,
            "H|0> amplitude of |1> wrong: {:?}",
            state[1]
        );
    }

    #[test]
    fn test_x_gate() {
        // X on qubit 0 of |00⟩ moves the amplitude from index 0 to index 1.
        let mut state = zero_state(2);
        apply_x_simd(&mut state, 0, 2);

        assert!(
            (state[0] - Complex64::new(0.0, 0.0)).norm() < 1e-12,
            "X|0>: state[0] should be 0"
        );
        assert!(
            (state[1] - Complex64::new(1.0, 0.0)).norm() < 1e-12,
            "X|0>: state[1] should be 1"
        );
    }

    #[test]
    fn test_z_gate_on_plus_state() {
        // Prepare |+⟩ = (|0⟩ + |1⟩)/√2 with H, then apply Z → |−⟩.
        let mut state = zero_state(1);
        apply_h_simd(&mut state, 0, 1);
        apply_z_simd(&mut state, 0, 1);

        // Z|+⟩ = |−⟩ = (|0⟩ - |1⟩)/√2
        assert!(
            (state[0] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-12,
            "Z|+>: state[0] wrong"
        );
        assert!(
            (state[1] - Complex64::new(-FRAC_1_SQRT_2, 0.0)).norm() < 1e-12,
            "Z|+>: state[1] wrong"
        );
    }

    #[test]
    fn test_rx_half_pi() {
        // RX(π/2)|0⟩ = cos(π/4)|0⟩ − i·sin(π/4)|1⟩
        let theta = PI / 2.0;
        let mut state = zero_state(1);
        apply_rx_simd(&mut state, theta, 0, 1);

        let expected0 = Complex64::new((theta / 2.0).cos(), 0.0);
        let expected1 = Complex64::new(0.0, -(theta / 2.0).sin());

        assert!(
            (state[0] - expected0).norm() < 1e-12,
            "RX(π/2)|0>: state[0] wrong: {:?}",
            state[0]
        );
        assert!(
            (state[1] - expected1).norm() < 1e-12,
            "RX(π/2)|0>: state[1] wrong: {:?}",
            state[1]
        );
    }

    #[test]
    fn test_ry_pi() {
        // RY(π)|0⟩ = |1⟩, since cos(π/2) = 0 and sin(π/2) = 1.
        let mut state = zero_state(1);
        apply_ry_simd(&mut state, PI, 0, 1);

        assert!(
            state[0].norm() < 1e-12,
            "RY(π)|0>: state[0] should be ~0, got {:?}",
            state[0]
        );
        assert!(
            (state[1] - Complex64::new(1.0, 0.0)).norm() < 1e-12,
            "RY(π)|0>: state[1] should be ~1, got {:?}",
            state[1]
        );
    }

    #[test]
    fn test_s_gate() {
        // S|1⟩ = i|1⟩
        let mut state = zero_state(1);
        apply_x_simd(&mut state, 0, 1); // |1⟩
        apply_s_simd(&mut state, 0, 1);

        assert!(state[0].norm() < 1e-12, "S|1>: state[0] should be 0");
        assert!(
            (state[1] - Complex64::new(0.0, 1.0)).norm() < 1e-12,
            "S|1>: state[1] should be i"
        );
    }

    #[test]
    fn test_t_gate() {
        // T|1⟩ = exp(iπ/4)|1⟩
        use std::f64::consts::FRAC_PI_4;
        let mut state = zero_state(1);
        apply_x_simd(&mut state, 0, 1); // |1⟩
        apply_t_simd(&mut state, 0, 1);

        let expected = Complex64::new(FRAC_PI_4.cos(), FRAC_PI_4.sin());
        assert!(state[0].norm() < 1e-12, "T|1>: state[0] should be 0");
        assert!((state[1] - expected).norm() < 1e-12, "T|1>: state[1] wrong");
    }

    // -----------------------------------------------------------------------
    // SIMD vs scalar consistency — random states of SIMD_QUBITS qubits.
    //
    // The states must hold at least SIMD_THRESHOLD amplitudes; anything
    // smaller is dispatched to the scalar fallback and the comparison would
    // be scalar-vs-scalar.
    // -----------------------------------------------------------------------

    /// Simple LCG for reproducible random states without external rand dep.
    fn lcg_random_state(n_qubits: usize, seed: u64) -> Vec<Complex64> {
        let mut rng = seed;
        let mut state: Vec<Complex64> = (0..(1usize << n_qubits))
            .map(|_| {
                rng = rng
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1_442_695_040_888_963_407);
                let re = (rng as f64) / (u64::MAX as f64) * 2.0 - 1.0;
                rng = rng
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1_442_695_040_888_963_407);
                let im = (rng as f64) / (u64::MAX as f64) * 2.0 - 1.0;
                Complex64::new(re, im)
            })
            .collect();

        // Normalize
        let norm: f64 = state.iter().map(|c| c.norm_sqr()).sum::<f64>().sqrt();
        state.iter_mut().for_each(|c| *c /= norm);
        state
    }

    #[test]
    fn test_simd_vs_scalar_h() {
        let n = SIMD_QUBITS;
        let base = lcg_random_state(n, 42);
        assert_simd_sized(&base);

        let h = [
            [
                Complex64::new(FRAC_1_SQRT_2, 0.0),
                Complex64::new(FRAC_1_SQRT_2, 0.0),
            ],
            [
                Complex64::new(FRAC_1_SQRT_2, 0.0),
                Complex64::new(-FRAC_1_SQRT_2, 0.0),
            ],
        ];

        for target in 0..n {
            let mut simd_state = base.clone();
            let mut scalar_state = base.clone();

            apply_h_simd(&mut simd_state, target, n);
            apply_gate_2x2_scalar(&mut scalar_state, h, target, n);

            let diff = max_diff(&simd_state, &scalar_state);
            assert!(
                diff < 1e-12,
                "SIMD vs scalar H mismatch at target={}: max_diff={}",
                target,
                diff
            );
        }
    }

    #[test]
    fn test_simd_vs_scalar_x() {
        let n = SIMD_QUBITS;
        let base = lcg_random_state(n, 123);
        assert_simd_sized(&base);

        let x_mat = [
            [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
        ];

        for target in 0..n {
            let mut simd_state = base.clone();
            let mut scalar_state = base.clone();

            apply_x_simd(&mut simd_state, target, n);
            apply_gate_2x2_scalar(&mut scalar_state, x_mat, target, n);

            let diff = max_diff(&simd_state, &scalar_state);
            assert!(
                diff < 1e-12,
                "SIMD vs scalar X mismatch at target={}: max_diff={}",
                target,
                diff
            );
        }
    }

    #[test]
    fn test_simd_vs_scalar_rz() {
        let n = SIMD_QUBITS;
        let base = lcg_random_state(n, 999);
        assert_simd_sized(&base);

        let theta = 1.23456_f64;
        let h = theta / 2.0;
        let rz_mat = [
            [Complex64::new(h.cos(), -h.sin()), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(h.cos(), h.sin())],
        ];

        for target in 0..n {
            let mut simd_state = base.clone();
            let mut scalar_state = base.clone();

            apply_rz_simd(&mut simd_state, theta, target, n);
            apply_gate_2x2_scalar(&mut scalar_state, rz_mat, target, n);

            let diff = max_diff(&simd_state, &scalar_state);
            assert!(
                diff < 1e-12,
                "SIMD vs scalar RZ mismatch at target={}: max_diff={}",
                target,
                diff
            );
        }
    }

    #[test]
    fn test_gate_2x2_simd_identity() {
        // Applying identity through the generic dispatcher should leave the
        // state unchanged on both the scalar-fallback size (4 qubits) and the
        // SIMD size.
        let id = [
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
        ];

        for n in [4usize, SIMD_QUBITS] {
            let mut state = lcg_random_state(n, 7);
            let original = state.clone();

            for target in 0..n {
                apply_gate_2x2_simd(&mut state, id, target, n);
            }

            let diff = max_diff(&state, &original);
            assert!(
                diff < 1e-12,
                "Identity gate altered {}-qubit state: max_diff={}",
                n,
                diff
            );
        }
    }

    #[test]
    fn test_y_gate_eigenvalue() {
        // Y|+y⟩ = |+y⟩ where |+y⟩ = (|0⟩ + i|1⟩)/√2
        let mut state = vec![
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(0.0, FRAC_1_SQRT_2),
        ];
        let original = state.clone();
        apply_y_simd(&mut state, 0, 1);

        // Y|+y⟩ = |+y⟩  (eigenvalue +1)
        let diff = max_diff(&state, &original);
        assert!(
            diff < 1e-12,
            "Y eigenstate property failed: max_diff={}",
            diff
        );
    }
}