quantrs2_sim/
optimized_simd.rs

1//! SIMD-accelerated operations for quantum state vector simulation
2//!
3//! This module provides SIMD-optimized implementations of quantum gate operations
4//! for improved performance on modern CPUs using `SciRS2` SIMD operations.
5
6use crate::scirs2_complex_simd::{
7    apply_cnot_complex_simd, apply_hadamard_gate_complex_simd,
8    apply_single_qubit_gate_complex_simd, ComplexSimdOps, ComplexSimdVector,
9};
10use scirs2_core::ndarray::{Array1, ArrayView1, ArrayViewMut1};
11use scirs2_core::parallel_ops::{
12    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
13};
14use scirs2_core::simd_ops::SimdUnifiedOps;
15use scirs2_core::Complex64;
16
17/// Simplified SIMD-like structure for complex operations
18/// NOTE: This is being deprecated in favor of `SciRS2` SIMD operations.
19/// New code should use `scirs2_core::simd_ops::SimdUnifiedOps` directly.
20#[derive(Clone, Copy, Debug)]
21#[deprecated(note = "Use scirs2_core::simd_ops::SimdUnifiedOps instead")]
22pub struct ComplexVec4 {
23    re: [f64; 4],
24    im: [f64; 4],
25}
26
27impl ComplexVec4 {
28    /// Create a new `ComplexVec4` from four Complex64 values
29    #[must_use]
30    pub fn new(values: [Complex64; 4]) -> Self {
31        let mut re = [0.0; 4];
32        let mut im = [0.0; 4];
33
34        for i in 0..4 {
35            re[i] = values[i].re;
36            im[i] = values[i].im;
37        }
38
39        Self { re, im }
40    }
41
42    /// Create a new `ComplexVec4` where all elements have the same value
43    #[must_use]
44    pub const fn splat(value: Complex64) -> Self {
45        Self {
46            re: [value.re, value.re, value.re, value.re],
47            im: [value.im, value.im, value.im, value.im],
48        }
49    }
50
51    /// Get the element at the specified index
52    #[must_use]
53    pub fn get(&self, idx: usize) -> Complex64 {
54        assert!(idx < 4, "Index out of bounds");
55        Complex64::new(self.re[idx], self.im[idx])
56    }
57
58    /// Multiply by another `ComplexVec4`
59    #[must_use]
60    pub fn mul(&self, other: &Self) -> Self {
61        let mut result = Self {
62            re: [0.0; 4],
63            im: [0.0; 4],
64        };
65
66        for i in 0..4 {
67            result.re[i] = self.re[i].mul_add(other.re[i], -(self.im[i] * other.im[i]));
68            result.im[i] = self.re[i].mul_add(other.im[i], self.im[i] * other.re[i]);
69        }
70
71        result
72    }
73
74    /// Add another `ComplexVec4`
75    #[must_use]
76    pub fn add(&self, other: &Self) -> Self {
77        let mut result = Self {
78            re: [0.0; 4],
79            im: [0.0; 4],
80        };
81
82        for i in 0..4 {
83            result.re[i] = self.re[i] + other.re[i];
84            result.im[i] = self.im[i] + other.im[i];
85        }
86
87        result
88    }
89
90    /// Subtract another `ComplexVec4`
91    #[must_use]
92    pub fn sub(&self, other: &Self) -> Self {
93        let mut result = Self {
94            re: [0.0; 4],
95            im: [0.0; 4],
96        };
97
98        for i in 0..4 {
99            result.re[i] = self.re[i] - other.re[i];
100            result.im[i] = self.im[i] - other.im[i];
101        }
102
103        result
104    }
105
106    /// Negate all elements
107    #[must_use]
108    pub fn neg(&self) -> Self {
109        let mut result = Self {
110            re: [0.0; 4],
111            im: [0.0; 4],
112        };
113
114        for i in 0..4 {
115            result.re[i] = -self.re[i];
116            result.im[i] = -self.im[i];
117        }
118
119        result
120    }
121}
122
123// ============================================================================
124// NEW SCIRS2-BASED SIMD IMPLEMENTATIONS
125// ============================================================================
126
127/// Apply a single-qubit gate using `SciRS2` SIMD operations
128///
129/// This function uses the `SciRS2` `SimdUnifiedOps` trait for better performance
130/// and compliance with the `SciRS2` integration policy.
131pub fn apply_single_qubit_gate_simd_v2(
132    matrix: &[Complex64; 4],
133    in_amps0: &[Complex64],
134    in_amps1: &[Complex64],
135    out_amps0: &mut [Complex64],
136    out_amps1: &mut [Complex64],
137) {
138    let len = in_amps0.len();
139
140    // Extract matrix elements
141    let m00 = matrix[0];
142    let m01 = matrix[1];
143    let m10 = matrix[2];
144    let m11 = matrix[3];
145
146    // Extract real and imaginary parts for SIMD operations
147    let mut a0_real: Vec<f64> = in_amps0.iter().map(|c| c.re).collect();
148    let mut a0_imag: Vec<f64> = in_amps0.iter().map(|c| c.im).collect();
149    let mut a1_real: Vec<f64> = in_amps1.iter().map(|c| c.re).collect();
150    let mut a1_imag: Vec<f64> = in_amps1.iter().map(|c| c.im).collect();
151
152    let a0_real_view = ArrayView1::from(&a0_real);
153    let a0_imag_view = ArrayView1::from(&a0_imag);
154    let a1_real_view = ArrayView1::from(&a1_real);
155    let a1_imag_view = ArrayView1::from(&a1_imag);
156
157    // Compute new_a0 = m00 * a0 + m01 * a1
158    // Real part: m00.re * a0.re - m00.im * a0.im + m01.re * a1.re - m01.im * a1.im
159    let term1 = f64::simd_scalar_mul(&a0_real_view, m00.re);
160    let term2 = f64::simd_scalar_mul(&a0_imag_view, m00.im);
161    let term3 = f64::simd_scalar_mul(&a1_real_view, m01.re);
162    let term4 = f64::simd_scalar_mul(&a1_imag_view, m01.im);
163
164    let temp1 = f64::simd_sub(&term1.view(), &term2.view());
165    let temp2 = f64::simd_sub(&term3.view(), &term4.view());
166    let new_a0_real = f64::simd_add(&temp1.view(), &temp2.view());
167
168    // Imaginary part: m00.re * a0.im + m00.im * a0.re + m01.re * a1.im + m01.im * a1.re
169    let term5 = f64::simd_scalar_mul(&a0_imag_view, m00.re);
170    let term6 = f64::simd_scalar_mul(&a0_real_view, m00.im);
171    let term7 = f64::simd_scalar_mul(&a1_imag_view, m01.re);
172    let term8 = f64::simd_scalar_mul(&a1_real_view, m01.im);
173
174    let temp3 = f64::simd_add(&term5.view(), &term6.view());
175    let temp4 = f64::simd_add(&term7.view(), &term8.view());
176    let new_a0_imag = f64::simd_add(&temp3.view(), &temp4.view());
177
178    // Compute new_a1 = m10 * a0 + m11 * a1
179    let term9 = f64::simd_scalar_mul(&a0_real_view, m10.re);
180    let term10 = f64::simd_scalar_mul(&a0_imag_view, m10.im);
181    let term11 = f64::simd_scalar_mul(&a1_real_view, m11.re);
182    let term12 = f64::simd_scalar_mul(&a1_imag_view, m11.im);
183
184    let temp5 = f64::simd_sub(&term9.view(), &term10.view());
185    let temp6 = f64::simd_sub(&term11.view(), &term12.view());
186    let new_a1_real = f64::simd_add(&temp5.view(), &temp6.view());
187
188    let term13 = f64::simd_scalar_mul(&a0_imag_view, m10.re);
189    let term14 = f64::simd_scalar_mul(&a0_real_view, m10.im);
190    let term15 = f64::simd_scalar_mul(&a1_imag_view, m11.re);
191    let term16 = f64::simd_scalar_mul(&a1_real_view, m11.im);
192
193    let temp7 = f64::simd_add(&term13.view(), &term14.view());
194    let temp8 = f64::simd_add(&term15.view(), &term16.view());
195    let new_a1_imag = f64::simd_add(&temp7.view(), &temp8.view());
196
197    // Write back results
198    for i in 0..len {
199        out_amps0[i] = Complex64::new(new_a0_real[i], new_a0_imag[i]);
200        out_amps1[i] = Complex64::new(new_a1_real[i], new_a1_imag[i]);
201    }
202}
203
204/// Apply Hadamard gate using `SciRS2` SIMD operations
205pub fn apply_h_gate_simd_v2(
206    in_amps0: &[Complex64],
207    in_amps1: &[Complex64],
208    out_amps0: &mut [Complex64],
209    out_amps1: &mut [Complex64],
210) {
211    let sqrt2_inv = std::f64::consts::FRAC_1_SQRT_2;
212    let len = in_amps0.len();
213
214    // Extract real and imaginary parts
215    let a0_real: Vec<f64> = in_amps0.iter().map(|c| c.re).collect();
216    let a0_imag: Vec<f64> = in_amps0.iter().map(|c| c.im).collect();
217    let a1_real: Vec<f64> = in_amps1.iter().map(|c| c.re).collect();
218    let a1_imag: Vec<f64> = in_amps1.iter().map(|c| c.im).collect();
219
220    let a0_real_view = ArrayView1::from(&a0_real);
221    let a0_imag_view = ArrayView1::from(&a0_imag);
222    let a1_real_view = ArrayView1::from(&a1_real);
223    let a1_imag_view = ArrayView1::from(&a1_imag);
224
225    // Hadamard: new_a0 = (a0 + a1) / sqrt(2), new_a1 = (a0 - a1) / sqrt(2)
226    let sum_real = f64::simd_add(&a0_real_view, &a1_real_view);
227    let sum_imag = f64::simd_add(&a0_imag_view, &a1_imag_view);
228    let diff_real = f64::simd_sub(&a0_real_view, &a1_real_view);
229    let diff_imag = f64::simd_sub(&a0_imag_view, &a1_imag_view);
230
231    let new_a0_real = f64::simd_scalar_mul(&sum_real.view(), sqrt2_inv);
232    let new_a0_imag = f64::simd_scalar_mul(&sum_imag.view(), sqrt2_inv);
233    let new_a1_real = f64::simd_scalar_mul(&diff_real.view(), sqrt2_inv);
234    let new_a1_imag = f64::simd_scalar_mul(&diff_imag.view(), sqrt2_inv);
235
236    // Write back results
237    for i in 0..len {
238        out_amps0[i] = Complex64::new(new_a0_real[i], new_a0_imag[i]);
239        out_amps1[i] = Complex64::new(new_a1_real[i], new_a1_imag[i]);
240    }
241}
242
243// ============================================================================
244// LEGACY IMPLEMENTATIONS (to be removed after full migration)
245// ============================================================================
246
247/// Apply a single-qubit gate to multiple amplitudes using SIMD-like operations
248///
249/// This function processes 4 pairs of amplitudes at once using SIMD-like operations
250///
251/// # Arguments
252///
253/// * `matrix` - The 2x2 matrix representation of the gate
254/// * `in_amps0` - The first set of input amplitudes (corresponding to bit=0)
255/// * `in_amps1` - The second set of input amplitudes (corresponding to bit=1)
256/// * `out_amps0` - Output buffer for the first set of amplitudes
257/// * `out_amps1` - Output buffer for the second set of amplitudes
258pub fn apply_single_qubit_gate_simd(
259    matrix: &[Complex64; 4],
260    in_amps0: &[Complex64],
261    in_amps1: &[Complex64],
262    out_amps0: &mut [Complex64],
263    out_amps1: &mut [Complex64],
264) {
265    // Process elements in chunks of 4
266    let chunks = in_amps0.len() / 4;
267
268    // Extract matrix elements for SIMD-like operations
269    let m00 = ComplexVec4::splat(matrix[0]);
270    let m01 = ComplexVec4::splat(matrix[1]);
271    let m10 = ComplexVec4::splat(matrix[2]);
272    let m11 = ComplexVec4::splat(matrix[3]);
273
274    for chunk in 0..chunks {
275        let offset = chunk * 4;
276
277        // Load 4 complex numbers from in_amps0 and in_amps1
278        let a0 = ComplexVec4::new([
279            in_amps0[offset],
280            in_amps0[offset + 1],
281            in_amps0[offset + 2],
282            in_amps0[offset + 3],
283        ]);
284
285        let a1 = ComplexVec4::new([
286            in_amps1[offset],
287            in_amps1[offset + 1],
288            in_amps1[offset + 2],
289            in_amps1[offset + 3],
290        ]);
291
292        // Compute complex multiplications
293        let m00a0 = m00.mul(&a0);
294        let m01a1 = m01.mul(&a1);
295        let m10a0 = m10.mul(&a0);
296        let m11a1 = m11.mul(&a1);
297
298        // Compute new amplitudes
299        let new_a0 = m00a0.add(&m01a1);
300        let new_a1 = m10a0.add(&m11a1);
301
302        // Store the results
303        for i in 0..4 {
304            out_amps0[offset + i] = new_a0.get(i);
305            out_amps1[offset + i] = new_a1.get(i);
306        }
307    }
308
309    // Handle remaining elements (less than 4)
310    let remainder_start = chunks * 4;
311    for i in remainder_start..in_amps0.len() {
312        let a0 = in_amps0[i];
313        let a1 = in_amps1[i];
314
315        out_amps0[i] = matrix[0] * a0 + matrix[1] * a1;
316        out_amps1[i] = matrix[2] * a0 + matrix[3] * a1;
317    }
318}
319
320/// Apply X gate to multiple amplitudes using SIMD-like operations
321///
322/// This is a specialized implementation for the Pauli X gate, which simply swaps
323/// amplitudes, making it very efficient to implement.
324///
325/// # Arguments
326///
327/// * `in_amps0` - The first set of input amplitudes (corresponding to bit=0)
328/// * `in_amps1` - The second set of input amplitudes (corresponding to bit=1)
329/// * `out_amps0` - Output buffer for the first set of amplitudes
330/// * `out_amps1` - Output buffer for the second set of amplitudes
331pub fn apply_x_gate_simd(
332    in_amps0: &[Complex64],
333    in_amps1: &[Complex64],
334    out_amps0: &mut [Complex64],
335    out_amps1: &mut [Complex64],
336) {
337    // Simply swap the amplitudes using copy_from_slice
338    out_amps0[..in_amps0.len()].copy_from_slice(&in_amps1[..in_amps0.len()]);
339    out_amps1[..in_amps0.len()].copy_from_slice(in_amps0);
340}
341
342/// Apply Z gate to multiple amplitudes using SIMD-like operations
343///
344/// This is a specialized implementation for the Pauli Z gate, which only flips the
345/// sign of amplitudes where the target bit is 1.
346///
347/// # Arguments
348///
349/// * `in_amps0` - The first set of input amplitudes (corresponding to bit=0)
350/// * `in_amps1` - The second set of input amplitudes (corresponding to bit=1)
351/// * `out_amps0` - Output buffer for the first set of amplitudes
352/// * `out_amps1` - Output buffer for the second set of amplitudes
353pub fn apply_z_gate_simd(
354    in_amps0: &[Complex64],
355    in_amps1: &[Complex64],
356    out_amps0: &mut [Complex64],
357    out_amps1: &mut [Complex64],
358) {
359    // For Z gate, a0 stays the same, a1 gets negated
360    for i in 0..in_amps0.len() {
361        out_amps0[i] = in_amps0[i];
362        out_amps1[i] = -in_amps1[i];
363    }
364}
365
366/// Apply Hadamard gate using SIMD-like operations
367///
368/// This is a specialized implementation for the Hadamard gate using the matrix:
369/// H = 1/√2 * [[1, 1], [1, -1]]
370pub fn apply_h_gate_simd(
371    in_amps0: &[Complex64],
372    in_amps1: &[Complex64],
373    out_amps0: &mut [Complex64],
374    out_amps1: &mut [Complex64],
375) {
376    use std::f64::consts::FRAC_1_SQRT_2;
377    let h_coeff = Complex64::new(FRAC_1_SQRT_2, 0.0);
378
379    // Process elements in chunks of 4
380    let chunks = in_amps0.len() / 4;
381    let h_vec = ComplexVec4::splat(h_coeff);
382
383    for chunk in 0..chunks {
384        let offset = chunk * 4;
385
386        let a0 = ComplexVec4::new([
387            in_amps0[offset],
388            in_amps0[offset + 1],
389            in_amps0[offset + 2],
390            in_amps0[offset + 3],
391        ]);
392
393        let a1 = ComplexVec4::new([
394            in_amps1[offset],
395            in_amps1[offset + 1],
396            in_amps1[offset + 2],
397            in_amps1[offset + 3],
398        ]);
399
400        // H|0⟩ = 1/√2(|0⟩ + |1⟩), H|1⟩ = 1/√2(|0⟩ - |1⟩)
401        let sum = a0.add(&a1);
402        let diff = a0.sub(&a1);
403
404        let new_a0 = h_vec.mul(&sum);
405        let new_a1 = h_vec.mul(&diff);
406
407        for i in 0..4 {
408            out_amps0[offset + i] = new_a0.get(i);
409            out_amps1[offset + i] = new_a1.get(i);
410        }
411    }
412
413    // Handle remaining elements
414    let remainder_start = chunks * 4;
415    for i in remainder_start..in_amps0.len() {
416        let a0 = in_amps0[i];
417        let a1 = in_amps1[i];
418
419        out_amps0[i] = h_coeff * (a0 + a1);
420        out_amps1[i] = h_coeff * (a0 - a1);
421    }
422}
423
424/// Apply Y gate using SIMD-like operations
425///
426/// Y gate: [[0, -i], [i, 0]]
427pub fn apply_y_gate_simd(
428    in_amps0: &[Complex64],
429    in_amps1: &[Complex64],
430    out_amps0: &mut [Complex64],
431    out_amps1: &mut [Complex64],
432) {
433    let i_pos = Complex64::new(0.0, 1.0);
434    let i_neg = Complex64::new(0.0, -1.0);
435
436    // Process elements in chunks of 4
437    let chunks = in_amps0.len() / 4;
438    let i_pos_vec = ComplexVec4::splat(i_pos);
439    let i_neg_vec = ComplexVec4::splat(i_neg);
440
441    for chunk in 0..chunks {
442        let offset = chunk * 4;
443
444        let a0 = ComplexVec4::new([
445            in_amps0[offset],
446            in_amps0[offset + 1],
447            in_amps0[offset + 2],
448            in_amps0[offset + 3],
449        ]);
450
451        let a1 = ComplexVec4::new([
452            in_amps1[offset],
453            in_amps1[offset + 1],
454            in_amps1[offset + 2],
455            in_amps1[offset + 3],
456        ]);
457
458        // Y|0⟩ = i|1⟩, Y|1⟩ = -i|0⟩
459        let new_a0 = i_neg_vec.mul(&a1);
460        let new_a1 = i_pos_vec.mul(&a0);
461
462        for i in 0..4 {
463            out_amps0[offset + i] = new_a0.get(i);
464            out_amps1[offset + i] = new_a1.get(i);
465        }
466    }
467
468    // Handle remaining elements
469    let remainder_start = chunks * 4;
470    for i in remainder_start..in_amps0.len() {
471        let a0 = in_amps0[i];
472        let a1 = in_amps1[i];
473
474        out_amps0[i] = i_neg * a1;
475        out_amps1[i] = i_pos * a0;
476    }
477}
478
479/// Apply phase gate (S gate) using SIMD-like operations
480///
481/// S gate: [[1, 0], [0, i]]
482pub fn apply_s_gate_simd(
483    in_amps0: &[Complex64],
484    in_amps1: &[Complex64],
485    out_amps0: &mut [Complex64],
486    out_amps1: &mut [Complex64],
487) {
488    let i_phase = Complex64::new(0.0, 1.0);
489
490    // Process elements in chunks of 4
491    let chunks = in_amps0.len() / 4;
492    let i_vec = ComplexVec4::splat(i_phase);
493
494    for chunk in 0..chunks {
495        let offset = chunk * 4;
496
497        let a1 = ComplexVec4::new([
498            in_amps1[offset],
499            in_amps1[offset + 1],
500            in_amps1[offset + 2],
501            in_amps1[offset + 3],
502        ]);
503
504        let new_a1 = i_vec.mul(&a1);
505
506        // Copy a0 unchanged, multiply a1 by i
507        for i in 0..4 {
508            out_amps0[offset + i] = in_amps0[offset + i];
509            out_amps1[offset + i] = new_a1.get(i);
510        }
511    }
512
513    // Handle remaining elements
514    let remainder_start = chunks * 4;
515    for i in remainder_start..in_amps0.len() {
516        out_amps0[i] = in_amps0[i];
517        out_amps1[i] = i_phase * in_amps1[i];
518    }
519}
520
521/// Apply rotation-X gate using SIMD-like operations
522///
523/// RX(θ) = [[cos(θ/2), -i*sin(θ/2)], [-i*sin(θ/2), cos(θ/2)]]
524pub fn apply_rx_gate_simd(
525    angle: f64,
526    in_amps0: &[Complex64],
527    in_amps1: &[Complex64],
528    out_amps0: &mut [Complex64],
529    out_amps1: &mut [Complex64],
530) {
531    let half_angle = angle / 2.0;
532    let cos_val = Complex64::new(half_angle.cos(), 0.0);
533    let neg_i_sin_val = Complex64::new(0.0, -half_angle.sin());
534
535    // Process elements in chunks of 4
536    let chunks = in_amps0.len() / 4;
537    let cos_vec = ComplexVec4::splat(cos_val);
538    let neg_i_sin_vec = ComplexVec4::splat(neg_i_sin_val);
539
540    for chunk in 0..chunks {
541        let offset = chunk * 4;
542
543        let a0 = ComplexVec4::new([
544            in_amps0[offset],
545            in_amps0[offset + 1],
546            in_amps0[offset + 2],
547            in_amps0[offset + 3],
548        ]);
549
550        let a1 = ComplexVec4::new([
551            in_amps1[offset],
552            in_amps1[offset + 1],
553            in_amps1[offset + 2],
554            in_amps1[offset + 3],
555        ]);
556
557        let cos_a0 = cos_vec.mul(&a0);
558        let neg_i_sin_a1 = neg_i_sin_vec.mul(&a1);
559        let neg_i_sin_a0 = neg_i_sin_vec.mul(&a0);
560        let cos_a1 = cos_vec.mul(&a1);
561
562        let new_a0 = cos_a0.add(&neg_i_sin_a1);
563        let new_a1 = neg_i_sin_a0.add(&cos_a1);
564
565        for i in 0..4 {
566            out_amps0[offset + i] = new_a0.get(i);
567            out_amps1[offset + i] = new_a1.get(i);
568        }
569    }
570
571    // Handle remaining elements
572    let remainder_start = chunks * 4;
573    for i in remainder_start..in_amps0.len() {
574        let a0 = in_amps0[i];
575        let a1 = in_amps1[i];
576
577        out_amps0[i] = cos_val * a0 + neg_i_sin_val * a1;
578        out_amps1[i] = neg_i_sin_val * a0 + cos_val * a1;
579    }
580}
581
582/// SIMD-optimized wrapper function for applying gates
583///
584/// This function uses enhanced `SciRS2` complex SIMD implementations for optimal performance.
585pub fn apply_single_qubit_gate_optimized(
586    matrix: &[Complex64; 4],
587    in_amps0: &[Complex64],
588    in_amps1: &[Complex64],
589    out_amps0: &mut [Complex64],
590    out_amps1: &mut [Complex64],
591) {
592    use std::f64::consts::FRAC_1_SQRT_2;
593
594    // Determine optimal implementation based on vector size and hardware capabilities
595    let vector_size = in_amps0.len();
596    let simd_threshold = 64; // Minimum size to benefit from complex SIMD
597
598    if vector_size >= simd_threshold && ComplexSimdVector::detect_simd_width() > 1 {
599        // Use enhanced complex SIMD implementation for large vectors
600        if is_hadamard_gate(matrix) {
601            apply_hadamard_gate_complex_simd(in_amps0, in_amps1, out_amps0, out_amps1);
602        } else {
603            apply_single_qubit_gate_complex_simd(matrix, in_amps0, in_amps1, out_amps0, out_amps1);
604        }
605    } else {
606        // Fall back to component-wise SIMD for smaller vectors or limited hardware
607        if is_hadamard_gate(matrix) {
608            apply_h_gate_simd_v2(in_amps0, in_amps1, out_amps0, out_amps1);
609        } else {
610            apply_single_qubit_gate_simd_v2(matrix, in_amps0, in_amps1, out_amps0, out_amps1);
611        }
612    }
613}
614
615/// Check if matrix represents a Hadamard gate
616fn is_hadamard_gate(matrix: &[Complex64; 4]) -> bool {
617    use std::f64::consts::FRAC_1_SQRT_2;
618
619    let h_matrix = [
620        Complex64::new(FRAC_1_SQRT_2, 0.0),
621        Complex64::new(FRAC_1_SQRT_2, 0.0),
622        Complex64::new(FRAC_1_SQRT_2, 0.0),
623        Complex64::new(-FRAC_1_SQRT_2, 0.0),
624    ];
625
626    matrix
627        .iter()
628        .zip(h_matrix.iter())
629        .all(|(a, b)| (a - b).norm() < 1e-10)
630}
631
632/// Apply rotation-Y gate using SIMD-like operations
633///
634/// RY(θ) = [[cos(θ/2), -sin(θ/2)], [sin(θ/2), cos(θ/2)]]
635pub fn apply_ry_gate_simd(
636    angle: f64,
637    in_amps0: &[Complex64],
638    in_amps1: &[Complex64],
639    out_amps0: &mut [Complex64],
640    out_amps1: &mut [Complex64],
641) {
642    let half_angle = angle / 2.0;
643    let cos_val = Complex64::new(half_angle.cos(), 0.0);
644    let sin_val = Complex64::new(half_angle.sin(), 0.0);
645    let neg_sin_val = Complex64::new(-half_angle.sin(), 0.0);
646
647    // Process elements in chunks of 4
648    let chunks = in_amps0.len() / 4;
649    let cos_vec = ComplexVec4::splat(cos_val);
650    let sin_vec = ComplexVec4::splat(sin_val);
651    let neg_sin_vec = ComplexVec4::splat(neg_sin_val);
652
653    for chunk in 0..chunks {
654        let offset = chunk * 4;
655
656        let a0 = ComplexVec4::new([
657            in_amps0[offset],
658            in_amps0[offset + 1],
659            in_amps0[offset + 2],
660            in_amps0[offset + 3],
661        ]);
662
663        let a1 = ComplexVec4::new([
664            in_amps1[offset],
665            in_amps1[offset + 1],
666            in_amps1[offset + 2],
667            in_amps1[offset + 3],
668        ]);
669
670        let cos_a0 = cos_vec.mul(&a0);
671        let neg_sin_a1 = neg_sin_vec.mul(&a1);
672        let sin_a0 = sin_vec.mul(&a0);
673        let cos_a1 = cos_vec.mul(&a1);
674
675        let new_a0 = cos_a0.add(&neg_sin_a1);
676        let new_a1 = sin_a0.add(&cos_a1);
677
678        for i in 0..4 {
679            out_amps0[offset + i] = new_a0.get(i);
680            out_amps1[offset + i] = new_a1.get(i);
681        }
682    }
683
684    // Handle remaining elements
685    let remainder_start = chunks * 4;
686    for i in remainder_start..in_amps0.len() {
687        let a0 = in_amps0[i];
688        let a1 = in_amps1[i];
689
690        out_amps0[i] = cos_val * a0 + neg_sin_val * a1;
691        out_amps1[i] = sin_val * a0 + cos_val * a1;
692    }
693}
694
695/// Apply rotation-Z gate using SIMD-like operations
696///
697/// RZ(θ) = [[e^(-iθ/2), 0], [0, e^(iθ/2)]]
698pub fn apply_rz_gate_simd(
699    angle: f64,
700    in_amps0: &[Complex64],
701    in_amps1: &[Complex64],
702    out_amps0: &mut [Complex64],
703    out_amps1: &mut [Complex64],
704) {
705    let half_angle = angle / 2.0;
706    let exp_neg_i = Complex64::new(half_angle.cos(), -half_angle.sin());
707    let exp_pos_i = Complex64::new(half_angle.cos(), half_angle.sin());
708
709    // Process elements in chunks of 4
710    let chunks = in_amps0.len() / 4;
711    let exp_neg_vec = ComplexVec4::splat(exp_neg_i);
712    let exp_pos_vec = ComplexVec4::splat(exp_pos_i);
713
714    for chunk in 0..chunks {
715        let offset = chunk * 4;
716
717        let a0 = ComplexVec4::new([
718            in_amps0[offset],
719            in_amps0[offset + 1],
720            in_amps0[offset + 2],
721            in_amps0[offset + 3],
722        ]);
723
724        let a1 = ComplexVec4::new([
725            in_amps1[offset],
726            in_amps1[offset + 1],
727            in_amps1[offset + 2],
728            in_amps1[offset + 3],
729        ]);
730
731        let new_a0 = exp_neg_vec.mul(&a0);
732        let new_a1 = exp_pos_vec.mul(&a1);
733
734        for i in 0..4 {
735            out_amps0[offset + i] = new_a0.get(i);
736            out_amps1[offset + i] = new_a1.get(i);
737        }
738    }
739
740    // Handle remaining elements
741    let remainder_start = chunks * 4;
742    for i in remainder_start..in_amps0.len() {
743        out_amps0[i] = exp_neg_i * in_amps0[i];
744        out_amps1[i] = exp_pos_i * in_amps1[i];
745    }
746}
747
748/// Apply T gate using SIMD-like operations
749///
750/// T gate: [[1, 0], [0, e^(iπ/4)]]
751pub fn apply_t_gate_simd(
752    in_amps0: &[Complex64],
753    in_amps1: &[Complex64],
754    out_amps0: &mut [Complex64],
755    out_amps1: &mut [Complex64],
756) {
757    use std::f64::consts::FRAC_PI_4;
758    let t_phase = Complex64::new(FRAC_PI_4.cos(), FRAC_PI_4.sin());
759
760    // Process elements in chunks of 4
761    let chunks = in_amps0.len() / 4;
762    let t_vec = ComplexVec4::splat(t_phase);
763
764    for chunk in 0..chunks {
765        let offset = chunk * 4;
766
767        let a1 = ComplexVec4::new([
768            in_amps1[offset],
769            in_amps1[offset + 1],
770            in_amps1[offset + 2],
771            in_amps1[offset + 3],
772        ]);
773
774        let new_a1 = t_vec.mul(&a1);
775
776        // Copy a0 unchanged, multiply a1 by t_phase
777        for i in 0..4 {
778            out_amps0[offset + i] = in_amps0[offset + i];
779            out_amps1[offset + i] = new_a1.get(i);
780        }
781    }
782
783    // Handle remaining elements
784    let remainder_start = chunks * 4;
785    for i in remainder_start..in_amps0.len() {
786        out_amps0[i] = in_amps0[i];
787        out_amps1[i] = t_phase * in_amps1[i];
788    }
789}
790
791/// Gate fusion structure for combining adjacent single-qubit gates
792#[derive(Debug, Clone)]
793pub struct GateFusion {
794    /// Fused matrix representation
795    pub fused_matrix: [Complex64; 4],
796    /// Target qubit
797    pub target: usize,
798    /// Number of gates fused
799    pub gate_count: usize,
800}
801
802impl GateFusion {
803    /// Create a new gate fusion starting with an identity gate
804    #[must_use]
805    pub const fn new(target: usize) -> Self {
806        Self {
807            fused_matrix: [
808                Complex64::new(1.0, 0.0), // I[0,0]
809                Complex64::new(0.0, 0.0), // I[0,1]
810                Complex64::new(0.0, 0.0), // I[1,0]
811                Complex64::new(1.0, 0.0), // I[1,1]
812            ],
813            target,
814            gate_count: 0,
815        }
816    }
817
818    /// Fuse another gate into this fusion
819    pub fn fuse_gate(&mut self, gate_matrix: &[Complex64; 4]) {
820        // Matrix multiplication: new_matrix = gate_matrix * fused_matrix
821        let m = &self.fused_matrix;
822        let g = gate_matrix;
823
824        self.fused_matrix = [
825            g[0] * m[0] + g[1] * m[2], // (0,0)
826            g[0] * m[1] + g[1] * m[3], // (0,1)
827            g[2] * m[0] + g[3] * m[2], // (1,0)
828            g[2] * m[1] + g[3] * m[3], // (1,1)
829        ];
830
831        self.gate_count += 1;
832    }
833
834    /// Check if this fusion can be applied using a specialized SIMD kernel
835    #[must_use]
836    pub fn can_use_specialized_kernel(&self) -> bool {
837        use std::f64::consts::FRAC_1_SQRT_2;
838
839        // Check for common gate patterns after fusion
840        let m = &self.fused_matrix;
841
842        // Identity gate (no-op)
843        if (m[0] - Complex64::new(1.0, 0.0)).norm() < 1e-10
844            && m[1].norm() < 1e-10
845            && m[2].norm() < 1e-10
846            && (m[3] - Complex64::new(1.0, 0.0)).norm() < 1e-10
847        {
848            return true;
849        }
850
851        // X gate
852        if m[0].norm() < 1e-10
853            && (m[1] - Complex64::new(1.0, 0.0)).norm() < 1e-10
854            && (m[2] - Complex64::new(1.0, 0.0)).norm() < 1e-10
855            && m[3].norm() < 1e-10
856        {
857            return true;
858        }
859
860        // Y gate
861        if m[0].norm() < 1e-10
862            && (m[1] - Complex64::new(0.0, -1.0)).norm() < 1e-10
863            && (m[2] - Complex64::new(0.0, 1.0)).norm() < 1e-10
864            && m[3].norm() < 1e-10
865        {
866            return true;
867        }
868
869        // Z gate
870        if (m[0] - Complex64::new(1.0, 0.0)).norm() < 1e-10
871            && m[1].norm() < 1e-10
872            && m[2].norm() < 1e-10
873            && (m[3] - Complex64::new(-1.0, 0.0)).norm() < 1e-10
874        {
875            return true;
876        }
877
878        // Hadamard gate
879        if (m[0] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-10
880            && (m[1] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-10
881            && (m[2] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-10
882            && (m[3] - Complex64::new(-FRAC_1_SQRT_2, 0.0)).norm() < 1e-10
883        {
884            return true;
885        }
886
887        false
888    }
889
890    /// Apply the fused gate using SIMD optimization
891    pub fn apply_simd(
892        &self,
893        in_amps0: &[Complex64],
894        in_amps1: &[Complex64],
895        out_amps0: &mut [Complex64],
896        out_amps1: &mut [Complex64],
897    ) {
898        apply_single_qubit_gate_optimized(
899            &self.fused_matrix,
900            in_amps0,
901            in_amps1,
902            out_amps0,
903            out_amps1,
904        );
905    }
906}
907
908/// Vectorized CNOT gate application using SIMD for processing multiple pairs
909///
910/// This processes control/target pairs in parallel where possible
911pub fn apply_cnot_vectorized(
912    state: &mut [Complex64],
913    control_indices: &[usize],
914    target_indices: &[usize],
915    num_qubits: usize,
916) {
917    let dim = 1 << num_qubits;
918    let mut new_state = vec![Complex64::new(0.0, 0.0); dim];
919
920    // Process all CNOT gates in parallel
921    new_state
922        .par_iter_mut()
923        .enumerate()
924        .for_each(|(i, new_amp)| {
925            let mut final_idx = i;
926
927            // Apply all CNOT gates in sequence
928            for (&control_idx, &target_idx) in control_indices.iter().zip(target_indices.iter()) {
929                if (final_idx >> control_idx) & 1 == 1 {
930                    final_idx ^= 1 << target_idx;
931                }
932            }
933
934            *new_amp = state[final_idx];
935        });
936
937    state.copy_from_slice(&new_state);
938}
939
940/// Scalar implementation of `apply_single_qubit_gate` for fallback
941///
942/// # Arguments
943///
944/// * `matrix` - The 2x2 matrix representation of the gate
945/// * `in_amps0` - The first set of input amplitudes (corresponding to bit=0)
946/// * `in_amps1` - The second set of input amplitudes (corresponding to bit=1)
947/// * `out_amps0` - Output buffer for the first set of amplitudes
948/// * `out_amps1` - Output buffer for the second set of amplitudes
949pub fn apply_single_qubit_gate_scalar(
950    matrix: &[Complex64; 4],
951    in_amps0: &[Complex64],
952    in_amps1: &[Complex64],
953    out_amps0: &mut [Complex64],
954    out_amps1: &mut [Complex64],
955) {
956    for i in 0..in_amps0.len() {
957        let a0 = in_amps0[i];
958        let a1 = in_amps1[i];
959
960        out_amps0[i] = matrix[0] * a0 + matrix[1] * a1;
961        out_amps1[i] = matrix[2] * a0 + matrix[3] * a1;
962    }
963}
964
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Largest distance at which two complex values still count as equal.
    const TOLERANCE: f64 = 1e-10;

    /// Builds a complex number with a zero imaginary part.
    fn real(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    /// Row-major 2x2 Hadamard matrix.
    fn hadamard_matrix() -> [Complex64; 4] {
        [
            real(FRAC_1_SQRT_2),
            real(FRAC_1_SQRT_2),
            real(FRAC_1_SQRT_2),
            real(-FRAC_1_SQRT_2),
        ]
    }

    /// Asserts that `actual` lies within `TOLERANCE` of `expected`.
    #[track_caller]
    fn assert_close(actual: Complex64, expected: Complex64) {
        assert!((actual - expected).norm() < TOLERANCE);
    }

    /// Checks the outputs of a Hadamard applied to the pairs |0> (position 0)
    /// and |1> (position 1).
    #[track_caller]
    fn assert_hadamard_outputs(out_amps0: &[Complex64], out_amps1: &[Complex64]) {
        // H|0> = (|0> + |1>) / sqrt(2)
        assert_close(out_amps0[0], real(FRAC_1_SQRT_2));
        assert_close(out_amps1[0], real(FRAC_1_SQRT_2));

        // H|1> = (|0> - |1>) / sqrt(2)
        assert_close(out_amps0[1], real(FRAC_1_SQRT_2));
        assert_close(out_amps1[1], real(-FRAC_1_SQRT_2));
    }

    #[test]
    fn test_x_gate_scalar() {
        // Pauli-X: swaps the two amplitudes of every pair.
        let x_matrix = [real(0.0), real(1.0), real(1.0), real(0.0)];

        let in_amps0 = vec![real(1.0), real(0.5)];
        let in_amps1 = vec![real(0.0), real(0.5)];
        let mut out_amps0 = [real(0.0); 2];
        let mut out_amps1 = [real(0.0); 2];

        apply_single_qubit_gate_scalar(
            &x_matrix,
            &in_amps0,
            &in_amps1,
            &mut out_amps0,
            &mut out_amps1,
        );

        // The results are exact for these inputs, so compare without tolerance.
        assert_eq!(out_amps0, [real(0.0), real(0.5)]);
        assert_eq!(out_amps1, [real(1.0), real(0.5)]);
    }

    #[test]
    fn test_hadamard_gate_scalar() {
        let h_matrix = hadamard_matrix();

        // Position 0 holds |0>, position 1 holds |1>.
        let in_amps0 = vec![real(1.0), real(0.0)];
        let in_amps1 = vec![real(0.0), real(1.0)];
        let mut out_amps0 = [real(0.0); 2];
        let mut out_amps1 = [real(0.0); 2];

        apply_single_qubit_gate_scalar(
            &h_matrix,
            &in_amps0,
            &in_amps1,
            &mut out_amps0,
            &mut out_amps1,
        );

        assert_hadamard_outputs(&out_amps0, &out_amps1);
    }

    #[test]
    fn test_optimized_gate_wrapper() {
        let h_matrix = hadamard_matrix();

        // Position 0 holds |0>, position 1 holds |1>.
        let in_amps0 = vec![real(1.0), real(0.0)];
        let in_amps1 = vec![real(0.0), real(1.0)];
        let mut out_amps0 = [real(0.0); 2];
        let mut out_amps1 = [real(0.0); 2];

        // Same scenario as the scalar test, routed through the optimized entry point.
        apply_single_qubit_gate_optimized(
            &h_matrix,
            &in_amps0,
            &in_amps1,
            &mut out_amps0,
            &mut out_amps1,
        );

        assert_hadamard_outputs(&out_amps0, &out_amps1);
    }

    #[test]
    fn test_complex_vec4() {
        let scalar = Complex64::new(1.0, 2.0);
        let lanes = [
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, 4.0),
            Complex64::new(5.0, 6.0),
            Complex64::new(7.0, 8.0),
        ];

        // `splat` fills every lane with the same value.
        let a = ComplexVec4::splat(scalar);
        for lane in 0..4 {
            assert_eq!(a.get(lane), scalar);
        }

        // `new` keeps the lane order of its input.
        let b = ComplexVec4::new(lanes);
        for (lane, &expected) in lanes.iter().enumerate() {
            assert_eq!(b.get(lane), expected);
        }

        // Lane-wise multiplication.
        let product = a.mul(&b);
        for (lane, &value) in lanes.iter().enumerate() {
            assert_close(product.get(lane), scalar * value);
        }

        // Lane-wise addition.
        let sum = a.add(&b);
        for (lane, &value) in lanes.iter().enumerate() {
            assert_close(sum.get(lane), scalar + value);
        }

        // Lane-wise negation.
        let negated = b.neg();
        for (lane, &value) in lanes.iter().enumerate() {
            assert_close(negated.get(lane), -value);
        }
    }
}