quantrs2_sim/
optimized_simd.rs

//! SIMD-accelerated operations for quantum state-vector simulation.
//!
//! This module provides SIMD-optimized implementations of quantum gate operations
//! for improved performance on modern CPUs, built on SciRS2 SIMD operations.
5
6use crate::scirs2_complex_simd::{
7    apply_cnot_complex_simd, apply_hadamard_gate_complex_simd,
8    apply_single_qubit_gate_complex_simd, ComplexSimdOps, ComplexSimdVector,
9};
10use scirs2_core::ndarray::{Array1, ArrayView1, ArrayViewMut1};
11use scirs2_core::Complex64;
12use scirs2_core::parallel_ops::*;
13use scirs2_core::simd_ops::SimdUnifiedOps;
14
15/// Simplified SIMD-like structure for complex operations
16/// NOTE: This is being deprecated in favor of SciRS2 SIMD operations.
17/// New code should use scirs2_core::simd_ops::SimdUnifiedOps directly.
18#[derive(Clone, Copy, Debug)]
19#[deprecated(note = "Use scirs2_core::simd_ops::SimdUnifiedOps instead")]
20pub struct ComplexVec4 {
21    re: [f64; 4],
22    im: [f64; 4],
23}
24
25impl ComplexVec4 {
26    /// Create a new ComplexVec4 from four Complex64 values
27    pub fn new(values: [Complex64; 4]) -> Self {
28        let mut re = [0.0; 4];
29        let mut im = [0.0; 4];
30
31        for i in 0..4 {
32            re[i] = values[i].re;
33            im[i] = values[i].im;
34        }
35
36        Self { re, im }
37    }
38
39    /// Create a new ComplexVec4 where all elements have the same value
40    pub fn splat(value: Complex64) -> Self {
41        Self {
42            re: [value.re, value.re, value.re, value.re],
43            im: [value.im, value.im, value.im, value.im],
44        }
45    }
46
47    /// Get the element at the specified index
48    pub fn get(&self, idx: usize) -> Complex64 {
49        assert!(idx < 4, "Index out of bounds");
50        Complex64::new(self.re[idx], self.im[idx])
51    }
52
53    /// Multiply by another ComplexVec4
54    pub fn mul(&self, other: &ComplexVec4) -> ComplexVec4 {
55        let mut result = ComplexVec4 {
56            re: [0.0; 4],
57            im: [0.0; 4],
58        };
59
60        for i in 0..4 {
61            result.re[i] = self.re[i] * other.re[i] - self.im[i] * other.im[i];
62            result.im[i] = self.re[i] * other.im[i] + self.im[i] * other.re[i];
63        }
64
65        result
66    }
67
68    /// Add another ComplexVec4
69    pub fn add(&self, other: &ComplexVec4) -> ComplexVec4 {
70        let mut result = ComplexVec4 {
71            re: [0.0; 4],
72            im: [0.0; 4],
73        };
74
75        for i in 0..4 {
76            result.re[i] = self.re[i] + other.re[i];
77            result.im[i] = self.im[i] + other.im[i];
78        }
79
80        result
81    }
82
83    /// Subtract another ComplexVec4
84    pub fn sub(&self, other: &ComplexVec4) -> ComplexVec4 {
85        let mut result = ComplexVec4 {
86            re: [0.0; 4],
87            im: [0.0; 4],
88        };
89
90        for i in 0..4 {
91            result.re[i] = self.re[i] - other.re[i];
92            result.im[i] = self.im[i] - other.im[i];
93        }
94
95        result
96    }
97
98    /// Negate all elements
99    pub fn neg(&self) -> ComplexVec4 {
100        let mut result = ComplexVec4 {
101            re: [0.0; 4],
102            im: [0.0; 4],
103        };
104
105        for i in 0..4 {
106            result.re[i] = -self.re[i];
107            result.im[i] = -self.im[i];
108        }
109
110        result
111    }
112}
113
// ============================================================================
// NEW SCIRS2-BASED SIMD IMPLEMENTATIONS
// ============================================================================
117
118/// Apply a single-qubit gate using SciRS2 SIMD operations
119///
120/// This function uses the SciRS2 SimdUnifiedOps trait for better performance
121/// and compliance with the SciRS2 integration policy.
122pub fn apply_single_qubit_gate_simd_v2(
123    matrix: &[Complex64; 4],
124    in_amps0: &[Complex64],
125    in_amps1: &[Complex64],
126    out_amps0: &mut [Complex64],
127    out_amps1: &mut [Complex64],
128) {
129    let len = in_amps0.len();
130
131    // Extract matrix elements
132    let m00 = matrix[0];
133    let m01 = matrix[1];
134    let m10 = matrix[2];
135    let m11 = matrix[3];
136
137    // Extract real and imaginary parts for SIMD operations
138    let mut a0_real: Vec<f64> = in_amps0.iter().map(|c| c.re).collect();
139    let mut a0_imag: Vec<f64> = in_amps0.iter().map(|c| c.im).collect();
140    let mut a1_real: Vec<f64> = in_amps1.iter().map(|c| c.re).collect();
141    let mut a1_imag: Vec<f64> = in_amps1.iter().map(|c| c.im).collect();
142
143    let a0_real_view = ArrayView1::from(&a0_real);
144    let a0_imag_view = ArrayView1::from(&a0_imag);
145    let a1_real_view = ArrayView1::from(&a1_real);
146    let a1_imag_view = ArrayView1::from(&a1_imag);
147
148    // Compute new_a0 = m00 * a0 + m01 * a1
149    // Real part: m00.re * a0.re - m00.im * a0.im + m01.re * a1.re - m01.im * a1.im
150    let term1 = f64::simd_scalar_mul(&a0_real_view, m00.re);
151    let term2 = f64::simd_scalar_mul(&a0_imag_view, m00.im);
152    let term3 = f64::simd_scalar_mul(&a1_real_view, m01.re);
153    let term4 = f64::simd_scalar_mul(&a1_imag_view, m01.im);
154
155    let temp1 = f64::simd_sub(&term1.view(), &term2.view());
156    let temp2 = f64::simd_sub(&term3.view(), &term4.view());
157    let new_a0_real = f64::simd_add(&temp1.view(), &temp2.view());
158
159    // Imaginary part: m00.re * a0.im + m00.im * a0.re + m01.re * a1.im + m01.im * a1.re
160    let term5 = f64::simd_scalar_mul(&a0_imag_view, m00.re);
161    let term6 = f64::simd_scalar_mul(&a0_real_view, m00.im);
162    let term7 = f64::simd_scalar_mul(&a1_imag_view, m01.re);
163    let term8 = f64::simd_scalar_mul(&a1_real_view, m01.im);
164
165    let temp3 = f64::simd_add(&term5.view(), &term6.view());
166    let temp4 = f64::simd_add(&term7.view(), &term8.view());
167    let new_a0_imag = f64::simd_add(&temp3.view(), &temp4.view());
168
169    // Compute new_a1 = m10 * a0 + m11 * a1
170    let term9 = f64::simd_scalar_mul(&a0_real_view, m10.re);
171    let term10 = f64::simd_scalar_mul(&a0_imag_view, m10.im);
172    let term11 = f64::simd_scalar_mul(&a1_real_view, m11.re);
173    let term12 = f64::simd_scalar_mul(&a1_imag_view, m11.im);
174
175    let temp5 = f64::simd_sub(&term9.view(), &term10.view());
176    let temp6 = f64::simd_sub(&term11.view(), &term12.view());
177    let new_a1_real = f64::simd_add(&temp5.view(), &temp6.view());
178
179    let term13 = f64::simd_scalar_mul(&a0_imag_view, m10.re);
180    let term14 = f64::simd_scalar_mul(&a0_real_view, m10.im);
181    let term15 = f64::simd_scalar_mul(&a1_imag_view, m11.re);
182    let term16 = f64::simd_scalar_mul(&a1_real_view, m11.im);
183
184    let temp7 = f64::simd_add(&term13.view(), &term14.view());
185    let temp8 = f64::simd_add(&term15.view(), &term16.view());
186    let new_a1_imag = f64::simd_add(&temp7.view(), &temp8.view());
187
188    // Write back results
189    for i in 0..len {
190        out_amps0[i] = Complex64::new(new_a0_real[i], new_a0_imag[i]);
191        out_amps1[i] = Complex64::new(new_a1_real[i], new_a1_imag[i]);
192    }
193}
194
195/// Apply Hadamard gate using SciRS2 SIMD operations
196pub fn apply_h_gate_simd_v2(
197    in_amps0: &[Complex64],
198    in_amps1: &[Complex64],
199    out_amps0: &mut [Complex64],
200    out_amps1: &mut [Complex64],
201) {
202    let sqrt2_inv = std::f64::consts::FRAC_1_SQRT_2;
203    let len = in_amps0.len();
204
205    // Extract real and imaginary parts
206    let a0_real: Vec<f64> = in_amps0.iter().map(|c| c.re).collect();
207    let a0_imag: Vec<f64> = in_amps0.iter().map(|c| c.im).collect();
208    let a1_real: Vec<f64> = in_amps1.iter().map(|c| c.re).collect();
209    let a1_imag: Vec<f64> = in_amps1.iter().map(|c| c.im).collect();
210
211    let a0_real_view = ArrayView1::from(&a0_real);
212    let a0_imag_view = ArrayView1::from(&a0_imag);
213    let a1_real_view = ArrayView1::from(&a1_real);
214    let a1_imag_view = ArrayView1::from(&a1_imag);
215
216    // Hadamard: new_a0 = (a0 + a1) / sqrt(2), new_a1 = (a0 - a1) / sqrt(2)
217    let sum_real = f64::simd_add(&a0_real_view, &a1_real_view);
218    let sum_imag = f64::simd_add(&a0_imag_view, &a1_imag_view);
219    let diff_real = f64::simd_sub(&a0_real_view, &a1_real_view);
220    let diff_imag = f64::simd_sub(&a0_imag_view, &a1_imag_view);
221
222    let new_a0_real = f64::simd_scalar_mul(&sum_real.view(), sqrt2_inv);
223    let new_a0_imag = f64::simd_scalar_mul(&sum_imag.view(), sqrt2_inv);
224    let new_a1_real = f64::simd_scalar_mul(&diff_real.view(), sqrt2_inv);
225    let new_a1_imag = f64::simd_scalar_mul(&diff_imag.view(), sqrt2_inv);
226
227    // Write back results
228    for i in 0..len {
229        out_amps0[i] = Complex64::new(new_a0_real[i], new_a0_imag[i]);
230        out_amps1[i] = Complex64::new(new_a1_real[i], new_a1_imag[i]);
231    }
232}
233
// ============================================================================
// LEGACY IMPLEMENTATIONS (to be removed after full migration)
// ============================================================================
237
238/// Apply a single-qubit gate to multiple amplitudes using SIMD-like operations
239///
240/// This function processes 4 pairs of amplitudes at once using SIMD-like operations
241///
242/// # Arguments
243///
244/// * `matrix` - The 2x2 matrix representation of the gate
245/// * `in_amps0` - The first set of input amplitudes (corresponding to bit=0)
246/// * `in_amps1` - The second set of input amplitudes (corresponding to bit=1)
247/// * `out_amps0` - Output buffer for the first set of amplitudes
248/// * `out_amps1` - Output buffer for the second set of amplitudes
249pub fn apply_single_qubit_gate_simd(
250    matrix: &[Complex64; 4],
251    in_amps0: &[Complex64],
252    in_amps1: &[Complex64],
253    out_amps0: &mut [Complex64],
254    out_amps1: &mut [Complex64],
255) {
256    // Process elements in chunks of 4
257    let chunks = in_amps0.len() / 4;
258
259    // Extract matrix elements for SIMD-like operations
260    let m00 = ComplexVec4::splat(matrix[0]);
261    let m01 = ComplexVec4::splat(matrix[1]);
262    let m10 = ComplexVec4::splat(matrix[2]);
263    let m11 = ComplexVec4::splat(matrix[3]);
264
265    for chunk in 0..chunks {
266        let offset = chunk * 4;
267
268        // Load 4 complex numbers from in_amps0 and in_amps1
269        let a0 = ComplexVec4::new([
270            in_amps0[offset],
271            in_amps0[offset + 1],
272            in_amps0[offset + 2],
273            in_amps0[offset + 3],
274        ]);
275
276        let a1 = ComplexVec4::new([
277            in_amps1[offset],
278            in_amps1[offset + 1],
279            in_amps1[offset + 2],
280            in_amps1[offset + 3],
281        ]);
282
283        // Compute complex multiplications
284        let m00a0 = m00.mul(&a0);
285        let m01a1 = m01.mul(&a1);
286        let m10a0 = m10.mul(&a0);
287        let m11a1 = m11.mul(&a1);
288
289        // Compute new amplitudes
290        let new_a0 = m00a0.add(&m01a1);
291        let new_a1 = m10a0.add(&m11a1);
292
293        // Store the results
294        for i in 0..4 {
295            out_amps0[offset + i] = new_a0.get(i);
296            out_amps1[offset + i] = new_a1.get(i);
297        }
298    }
299
300    // Handle remaining elements (less than 4)
301    let remainder_start = chunks * 4;
302    for i in remainder_start..in_amps0.len() {
303        let a0 = in_amps0[i];
304        let a1 = in_amps1[i];
305
306        out_amps0[i] = matrix[0] * a0 + matrix[1] * a1;
307        out_amps1[i] = matrix[2] * a0 + matrix[3] * a1;
308    }
309}
310
311/// Apply X gate to multiple amplitudes using SIMD-like operations
312///
313/// This is a specialized implementation for the Pauli X gate, which simply swaps
314/// amplitudes, making it very efficient to implement.
315///
316/// # Arguments
317///
318/// * `in_amps0` - The first set of input amplitudes (corresponding to bit=0)
319/// * `in_amps1` - The second set of input amplitudes (corresponding to bit=1)
320/// * `out_amps0` - Output buffer for the first set of amplitudes
321/// * `out_amps1` - Output buffer for the second set of amplitudes
322pub fn apply_x_gate_simd(
323    in_amps0: &[Complex64],
324    in_amps1: &[Complex64],
325    out_amps0: &mut [Complex64],
326    out_amps1: &mut [Complex64],
327) {
328    // Simply swap the amplitudes using copy_from_slice
329    out_amps0[..in_amps0.len()].copy_from_slice(&in_amps1[..in_amps0.len()]);
330    out_amps1[..in_amps0.len()].copy_from_slice(in_amps0);
331}
332
333/// Apply Z gate to multiple amplitudes using SIMD-like operations
334///
335/// This is a specialized implementation for the Pauli Z gate, which only flips the
336/// sign of amplitudes where the target bit is 1.
337///
338/// # Arguments
339///
340/// * `in_amps0` - The first set of input amplitudes (corresponding to bit=0)
341/// * `in_amps1` - The second set of input amplitudes (corresponding to bit=1)
342/// * `out_amps0` - Output buffer for the first set of amplitudes
343/// * `out_amps1` - Output buffer for the second set of amplitudes
344pub fn apply_z_gate_simd(
345    in_amps0: &[Complex64],
346    in_amps1: &[Complex64],
347    out_amps0: &mut [Complex64],
348    out_amps1: &mut [Complex64],
349) {
350    // For Z gate, a0 stays the same, a1 gets negated
351    for i in 0..in_amps0.len() {
352        out_amps0[i] = in_amps0[i];
353        out_amps1[i] = -in_amps1[i];
354    }
355}
356
357/// Apply Hadamard gate using SIMD-like operations
358///
359/// This is a specialized implementation for the Hadamard gate using the matrix:
360/// H = 1/√2 * [[1, 1], [1, -1]]
361pub fn apply_h_gate_simd(
362    in_amps0: &[Complex64],
363    in_amps1: &[Complex64],
364    out_amps0: &mut [Complex64],
365    out_amps1: &mut [Complex64],
366) {
367    use std::f64::consts::FRAC_1_SQRT_2;
368    let h_coeff = Complex64::new(FRAC_1_SQRT_2, 0.0);
369
370    // Process elements in chunks of 4
371    let chunks = in_amps0.len() / 4;
372    let h_vec = ComplexVec4::splat(h_coeff);
373
374    for chunk in 0..chunks {
375        let offset = chunk * 4;
376
377        let a0 = ComplexVec4::new([
378            in_amps0[offset],
379            in_amps0[offset + 1],
380            in_amps0[offset + 2],
381            in_amps0[offset + 3],
382        ]);
383
384        let a1 = ComplexVec4::new([
385            in_amps1[offset],
386            in_amps1[offset + 1],
387            in_amps1[offset + 2],
388            in_amps1[offset + 3],
389        ]);
390
391        // H|0⟩ = 1/√2(|0⟩ + |1⟩), H|1⟩ = 1/√2(|0⟩ - |1⟩)
392        let sum = a0.add(&a1);
393        let diff = a0.sub(&a1);
394
395        let new_a0 = h_vec.mul(&sum);
396        let new_a1 = h_vec.mul(&diff);
397
398        for i in 0..4 {
399            out_amps0[offset + i] = new_a0.get(i);
400            out_amps1[offset + i] = new_a1.get(i);
401        }
402    }
403
404    // Handle remaining elements
405    let remainder_start = chunks * 4;
406    for i in remainder_start..in_amps0.len() {
407        let a0 = in_amps0[i];
408        let a1 = in_amps1[i];
409
410        out_amps0[i] = h_coeff * (a0 + a1);
411        out_amps1[i] = h_coeff * (a0 - a1);
412    }
413}
414
415/// Apply Y gate using SIMD-like operations
416///
417/// Y gate: [[0, -i], [i, 0]]
418pub fn apply_y_gate_simd(
419    in_amps0: &[Complex64],
420    in_amps1: &[Complex64],
421    out_amps0: &mut [Complex64],
422    out_amps1: &mut [Complex64],
423) {
424    let i_pos = Complex64::new(0.0, 1.0);
425    let i_neg = Complex64::new(0.0, -1.0);
426
427    // Process elements in chunks of 4
428    let chunks = in_amps0.len() / 4;
429    let i_pos_vec = ComplexVec4::splat(i_pos);
430    let i_neg_vec = ComplexVec4::splat(i_neg);
431
432    for chunk in 0..chunks {
433        let offset = chunk * 4;
434
435        let a0 = ComplexVec4::new([
436            in_amps0[offset],
437            in_amps0[offset + 1],
438            in_amps0[offset + 2],
439            in_amps0[offset + 3],
440        ]);
441
442        let a1 = ComplexVec4::new([
443            in_amps1[offset],
444            in_amps1[offset + 1],
445            in_amps1[offset + 2],
446            in_amps1[offset + 3],
447        ]);
448
449        // Y|0⟩ = i|1⟩, Y|1⟩ = -i|0⟩
450        let new_a0 = i_neg_vec.mul(&a1);
451        let new_a1 = i_pos_vec.mul(&a0);
452
453        for i in 0..4 {
454            out_amps0[offset + i] = new_a0.get(i);
455            out_amps1[offset + i] = new_a1.get(i);
456        }
457    }
458
459    // Handle remaining elements
460    let remainder_start = chunks * 4;
461    for i in remainder_start..in_amps0.len() {
462        let a0 = in_amps0[i];
463        let a1 = in_amps1[i];
464
465        out_amps0[i] = i_neg * a1;
466        out_amps1[i] = i_pos * a0;
467    }
468}
469
470/// Apply phase gate (S gate) using SIMD-like operations
471///
472/// S gate: [[1, 0], [0, i]]
473pub fn apply_s_gate_simd(
474    in_amps0: &[Complex64],
475    in_amps1: &[Complex64],
476    out_amps0: &mut [Complex64],
477    out_amps1: &mut [Complex64],
478) {
479    let i_phase = Complex64::new(0.0, 1.0);
480
481    // Process elements in chunks of 4
482    let chunks = in_amps0.len() / 4;
483    let i_vec = ComplexVec4::splat(i_phase);
484
485    for chunk in 0..chunks {
486        let offset = chunk * 4;
487
488        let a1 = ComplexVec4::new([
489            in_amps1[offset],
490            in_amps1[offset + 1],
491            in_amps1[offset + 2],
492            in_amps1[offset + 3],
493        ]);
494
495        let new_a1 = i_vec.mul(&a1);
496
497        // Copy a0 unchanged, multiply a1 by i
498        for i in 0..4 {
499            out_amps0[offset + i] = in_amps0[offset + i];
500            out_amps1[offset + i] = new_a1.get(i);
501        }
502    }
503
504    // Handle remaining elements
505    let remainder_start = chunks * 4;
506    for i in remainder_start..in_amps0.len() {
507        out_amps0[i] = in_amps0[i];
508        out_amps1[i] = i_phase * in_amps1[i];
509    }
510}
511
512/// Apply rotation-X gate using SIMD-like operations
513///
514/// RX(θ) = [[cos(θ/2), -i*sin(θ/2)], [-i*sin(θ/2), cos(θ/2)]]
515pub fn apply_rx_gate_simd(
516    angle: f64,
517    in_amps0: &[Complex64],
518    in_amps1: &[Complex64],
519    out_amps0: &mut [Complex64],
520    out_amps1: &mut [Complex64],
521) {
522    let half_angle = angle / 2.0;
523    let cos_val = Complex64::new(half_angle.cos(), 0.0);
524    let neg_i_sin_val = Complex64::new(0.0, -half_angle.sin());
525
526    // Process elements in chunks of 4
527    let chunks = in_amps0.len() / 4;
528    let cos_vec = ComplexVec4::splat(cos_val);
529    let neg_i_sin_vec = ComplexVec4::splat(neg_i_sin_val);
530
531    for chunk in 0..chunks {
532        let offset = chunk * 4;
533
534        let a0 = ComplexVec4::new([
535            in_amps0[offset],
536            in_amps0[offset + 1],
537            in_amps0[offset + 2],
538            in_amps0[offset + 3],
539        ]);
540
541        let a1 = ComplexVec4::new([
542            in_amps1[offset],
543            in_amps1[offset + 1],
544            in_amps1[offset + 2],
545            in_amps1[offset + 3],
546        ]);
547
548        let cos_a0 = cos_vec.mul(&a0);
549        let neg_i_sin_a1 = neg_i_sin_vec.mul(&a1);
550        let neg_i_sin_a0 = neg_i_sin_vec.mul(&a0);
551        let cos_a1 = cos_vec.mul(&a1);
552
553        let new_a0 = cos_a0.add(&neg_i_sin_a1);
554        let new_a1 = neg_i_sin_a0.add(&cos_a1);
555
556        for i in 0..4 {
557            out_amps0[offset + i] = new_a0.get(i);
558            out_amps1[offset + i] = new_a1.get(i);
559        }
560    }
561
562    // Handle remaining elements
563    let remainder_start = chunks * 4;
564    for i in remainder_start..in_amps0.len() {
565        let a0 = in_amps0[i];
566        let a1 = in_amps1[i];
567
568        out_amps0[i] = cos_val * a0 + neg_i_sin_val * a1;
569        out_amps1[i] = neg_i_sin_val * a0 + cos_val * a1;
570    }
571}
572
573/// SIMD-optimized wrapper function for applying gates
574///
575/// This function uses enhanced SciRS2 complex SIMD implementations for optimal performance.
576pub fn apply_single_qubit_gate_optimized(
577    matrix: &[Complex64; 4],
578    in_amps0: &[Complex64],
579    in_amps1: &[Complex64],
580    out_amps0: &mut [Complex64],
581    out_amps1: &mut [Complex64],
582) {
583    use std::f64::consts::FRAC_1_SQRT_2;
584
585    // Determine optimal implementation based on vector size and hardware capabilities
586    let vector_size = in_amps0.len();
587    let simd_threshold = 64; // Minimum size to benefit from complex SIMD
588
589    if vector_size >= simd_threshold && ComplexSimdVector::detect_simd_width() > 1 {
590        // Use enhanced complex SIMD implementation for large vectors
591        if is_hadamard_gate(matrix) {
592            apply_hadamard_gate_complex_simd(in_amps0, in_amps1, out_amps0, out_amps1);
593        } else {
594            apply_single_qubit_gate_complex_simd(matrix, in_amps0, in_amps1, out_amps0, out_amps1);
595        }
596    } else {
597        // Fall back to component-wise SIMD for smaller vectors or limited hardware
598        if is_hadamard_gate(matrix) {
599            apply_h_gate_simd_v2(in_amps0, in_amps1, out_amps0, out_amps1);
600        } else {
601            apply_single_qubit_gate_simd_v2(matrix, in_amps0, in_amps1, out_amps0, out_amps1);
602        }
603    }
604}
605
606/// Check if matrix represents a Hadamard gate
607fn is_hadamard_gate(matrix: &[Complex64; 4]) -> bool {
608    use std::f64::consts::FRAC_1_SQRT_2;
609
610    let h_matrix = [
611        Complex64::new(FRAC_1_SQRT_2, 0.0),
612        Complex64::new(FRAC_1_SQRT_2, 0.0),
613        Complex64::new(FRAC_1_SQRT_2, 0.0),
614        Complex64::new(-FRAC_1_SQRT_2, 0.0),
615    ];
616
617    matrix
618        .iter()
619        .zip(h_matrix.iter())
620        .all(|(a, b)| (a - b).norm() < 1e-10)
621}
622
623/// Apply rotation-Y gate using SIMD-like operations
624///
625/// RY(θ) = [[cos(θ/2), -sin(θ/2)], [sin(θ/2), cos(θ/2)]]
626pub fn apply_ry_gate_simd(
627    angle: f64,
628    in_amps0: &[Complex64],
629    in_amps1: &[Complex64],
630    out_amps0: &mut [Complex64],
631    out_amps1: &mut [Complex64],
632) {
633    let half_angle = angle / 2.0;
634    let cos_val = Complex64::new(half_angle.cos(), 0.0);
635    let sin_val = Complex64::new(half_angle.sin(), 0.0);
636    let neg_sin_val = Complex64::new(-half_angle.sin(), 0.0);
637
638    // Process elements in chunks of 4
639    let chunks = in_amps0.len() / 4;
640    let cos_vec = ComplexVec4::splat(cos_val);
641    let sin_vec = ComplexVec4::splat(sin_val);
642    let neg_sin_vec = ComplexVec4::splat(neg_sin_val);
643
644    for chunk in 0..chunks {
645        let offset = chunk * 4;
646
647        let a0 = ComplexVec4::new([
648            in_amps0[offset],
649            in_amps0[offset + 1],
650            in_amps0[offset + 2],
651            in_amps0[offset + 3],
652        ]);
653
654        let a1 = ComplexVec4::new([
655            in_amps1[offset],
656            in_amps1[offset + 1],
657            in_amps1[offset + 2],
658            in_amps1[offset + 3],
659        ]);
660
661        let cos_a0 = cos_vec.mul(&a0);
662        let neg_sin_a1 = neg_sin_vec.mul(&a1);
663        let sin_a0 = sin_vec.mul(&a0);
664        let cos_a1 = cos_vec.mul(&a1);
665
666        let new_a0 = cos_a0.add(&neg_sin_a1);
667        let new_a1 = sin_a0.add(&cos_a1);
668
669        for i in 0..4 {
670            out_amps0[offset + i] = new_a0.get(i);
671            out_amps1[offset + i] = new_a1.get(i);
672        }
673    }
674
675    // Handle remaining elements
676    let remainder_start = chunks * 4;
677    for i in remainder_start..in_amps0.len() {
678        let a0 = in_amps0[i];
679        let a1 = in_amps1[i];
680
681        out_amps0[i] = cos_val * a0 + neg_sin_val * a1;
682        out_amps1[i] = sin_val * a0 + cos_val * a1;
683    }
684}
685
686/// Apply rotation-Z gate using SIMD-like operations
687///
688/// RZ(θ) = [[e^(-iθ/2), 0], [0, e^(iθ/2)]]
689pub fn apply_rz_gate_simd(
690    angle: f64,
691    in_amps0: &[Complex64],
692    in_amps1: &[Complex64],
693    out_amps0: &mut [Complex64],
694    out_amps1: &mut [Complex64],
695) {
696    let half_angle = angle / 2.0;
697    let exp_neg_i = Complex64::new(half_angle.cos(), -half_angle.sin());
698    let exp_pos_i = Complex64::new(half_angle.cos(), half_angle.sin());
699
700    // Process elements in chunks of 4
701    let chunks = in_amps0.len() / 4;
702    let exp_neg_vec = ComplexVec4::splat(exp_neg_i);
703    let exp_pos_vec = ComplexVec4::splat(exp_pos_i);
704
705    for chunk in 0..chunks {
706        let offset = chunk * 4;
707
708        let a0 = ComplexVec4::new([
709            in_amps0[offset],
710            in_amps0[offset + 1],
711            in_amps0[offset + 2],
712            in_amps0[offset + 3],
713        ]);
714
715        let a1 = ComplexVec4::new([
716            in_amps1[offset],
717            in_amps1[offset + 1],
718            in_amps1[offset + 2],
719            in_amps1[offset + 3],
720        ]);
721
722        let new_a0 = exp_neg_vec.mul(&a0);
723        let new_a1 = exp_pos_vec.mul(&a1);
724
725        for i in 0..4 {
726            out_amps0[offset + i] = new_a0.get(i);
727            out_amps1[offset + i] = new_a1.get(i);
728        }
729    }
730
731    // Handle remaining elements
732    let remainder_start = chunks * 4;
733    for i in remainder_start..in_amps0.len() {
734        out_amps0[i] = exp_neg_i * in_amps0[i];
735        out_amps1[i] = exp_pos_i * in_amps1[i];
736    }
737}
738
739/// Apply T gate using SIMD-like operations
740///
741/// T gate: [[1, 0], [0, e^(iπ/4)]]
742pub fn apply_t_gate_simd(
743    in_amps0: &[Complex64],
744    in_amps1: &[Complex64],
745    out_amps0: &mut [Complex64],
746    out_amps1: &mut [Complex64],
747) {
748    use std::f64::consts::FRAC_PI_4;
749    let t_phase = Complex64::new(FRAC_PI_4.cos(), FRAC_PI_4.sin());
750
751    // Process elements in chunks of 4
752    let chunks = in_amps0.len() / 4;
753    let t_vec = ComplexVec4::splat(t_phase);
754
755    for chunk in 0..chunks {
756        let offset = chunk * 4;
757
758        let a1 = ComplexVec4::new([
759            in_amps1[offset],
760            in_amps1[offset + 1],
761            in_amps1[offset + 2],
762            in_amps1[offset + 3],
763        ]);
764
765        let new_a1 = t_vec.mul(&a1);
766
767        // Copy a0 unchanged, multiply a1 by t_phase
768        for i in 0..4 {
769            out_amps0[offset + i] = in_amps0[offset + i];
770            out_amps1[offset + i] = new_a1.get(i);
771        }
772    }
773
774    // Handle remaining elements
775    let remainder_start = chunks * 4;
776    for i in remainder_start..in_amps0.len() {
777        out_amps0[i] = in_amps0[i];
778        out_amps1[i] = t_phase * in_amps1[i];
779    }
780}
781
782/// Gate fusion structure for combining adjacent single-qubit gates
783#[derive(Debug, Clone)]
784pub struct GateFusion {
785    /// Fused matrix representation
786    pub fused_matrix: [Complex64; 4],
787    /// Target qubit
788    pub target: usize,
789    /// Number of gates fused
790    pub gate_count: usize,
791}
792
793impl GateFusion {
794    /// Create a new gate fusion starting with an identity gate
795    pub fn new(target: usize) -> Self {
796        Self {
797            fused_matrix: [
798                Complex64::new(1.0, 0.0), // I[0,0]
799                Complex64::new(0.0, 0.0), // I[0,1]
800                Complex64::new(0.0, 0.0), // I[1,0]
801                Complex64::new(1.0, 0.0), // I[1,1]
802            ],
803            target,
804            gate_count: 0,
805        }
806    }
807
808    /// Fuse another gate into this fusion
809    pub fn fuse_gate(&mut self, gate_matrix: &[Complex64; 4]) {
810        // Matrix multiplication: new_matrix = gate_matrix * fused_matrix
811        let m = &self.fused_matrix;
812        let g = gate_matrix;
813
814        self.fused_matrix = [
815            g[0] * m[0] + g[1] * m[2], // (0,0)
816            g[0] * m[1] + g[1] * m[3], // (0,1)
817            g[2] * m[0] + g[3] * m[2], // (1,0)
818            g[2] * m[1] + g[3] * m[3], // (1,1)
819        ];
820
821        self.gate_count += 1;
822    }
823
824    /// Check if this fusion can be applied using a specialized SIMD kernel
825    pub fn can_use_specialized_kernel(&self) -> bool {
826        use std::f64::consts::FRAC_1_SQRT_2;
827
828        // Check for common gate patterns after fusion
829        let m = &self.fused_matrix;
830
831        // Identity gate (no-op)
832        if (m[0] - Complex64::new(1.0, 0.0)).norm() < 1e-10
833            && m[1].norm() < 1e-10
834            && m[2].norm() < 1e-10
835            && (m[3] - Complex64::new(1.0, 0.0)).norm() < 1e-10
836        {
837            return true;
838        }
839
840        // X gate
841        if m[0].norm() < 1e-10
842            && (m[1] - Complex64::new(1.0, 0.0)).norm() < 1e-10
843            && (m[2] - Complex64::new(1.0, 0.0)).norm() < 1e-10
844            && m[3].norm() < 1e-10
845        {
846            return true;
847        }
848
849        // Y gate
850        if m[0].norm() < 1e-10
851            && (m[1] - Complex64::new(0.0, -1.0)).norm() < 1e-10
852            && (m[2] - Complex64::new(0.0, 1.0)).norm() < 1e-10
853            && m[3].norm() < 1e-10
854        {
855            return true;
856        }
857
858        // Z gate
859        if (m[0] - Complex64::new(1.0, 0.0)).norm() < 1e-10
860            && m[1].norm() < 1e-10
861            && m[2].norm() < 1e-10
862            && (m[3] - Complex64::new(-1.0, 0.0)).norm() < 1e-10
863        {
864            return true;
865        }
866
867        // Hadamard gate
868        if (m[0] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-10
869            && (m[1] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-10
870            && (m[2] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-10
871            && (m[3] - Complex64::new(-FRAC_1_SQRT_2, 0.0)).norm() < 1e-10
872        {
873            return true;
874        }
875
876        false
877    }
878
879    /// Apply the fused gate using SIMD optimization
880    pub fn apply_simd(
881        &self,
882        in_amps0: &[Complex64],
883        in_amps1: &[Complex64],
884        out_amps0: &mut [Complex64],
885        out_amps1: &mut [Complex64],
886    ) {
887        apply_single_qubit_gate_optimized(
888            &self.fused_matrix,
889            in_amps0,
890            in_amps1,
891            out_amps0,
892            out_amps1,
893        );
894    }
895}
896
897/// Vectorized CNOT gate application using SIMD for processing multiple pairs
898///
899/// This processes control/target pairs in parallel where possible
900pub fn apply_cnot_vectorized(
901    state: &mut [Complex64],
902    control_indices: &[usize],
903    target_indices: &[usize],
904    num_qubits: usize,
905) {
906    let dim = 1 << num_qubits;
907    let mut new_state = vec![Complex64::new(0.0, 0.0); dim];
908
909    // Process all CNOT gates in parallel
910    new_state
911        .par_iter_mut()
912        .enumerate()
913        .for_each(|(i, new_amp)| {
914            let mut final_idx = i;
915
916            // Apply all CNOT gates in sequence
917            for (&control_idx, &target_idx) in control_indices.iter().zip(target_indices.iter()) {
918                if (final_idx >> control_idx) & 1 == 1 {
919                    final_idx ^= 1 << target_idx;
920                }
921            }
922
923            *new_amp = state[final_idx];
924        });
925
926    state.copy_from_slice(&new_state);
927}
928
929/// Scalar implementation of apply_single_qubit_gate for fallback
930///
931/// # Arguments
932///
933/// * `matrix` - The 2x2 matrix representation of the gate
934/// * `in_amps0` - The first set of input amplitudes (corresponding to bit=0)
935/// * `in_amps1` - The second set of input amplitudes (corresponding to bit=1)
936/// * `out_amps0` - Output buffer for the first set of amplitudes
937/// * `out_amps1` - Output buffer for the second set of amplitudes
938pub fn apply_single_qubit_gate_scalar(
939    matrix: &[Complex64; 4],
940    in_amps0: &[Complex64],
941    in_amps1: &[Complex64],
942    out_amps0: &mut [Complex64],
943    out_amps1: &mut [Complex64],
944) {
945    for i in 0..in_amps0.len() {
946        let a0 = in_amps0[i];
947        let a1 = in_amps1[i];
948
949        out_amps0[i] = matrix[0] * a0 + matrix[1] * a1;
950        out_amps1[i] = matrix[2] * a0 + matrix[3] * a1;
951    }
952}
953
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Tolerance for floating-point amplitude comparisons.
    const EPS: f64 = 1e-10;

    /// Asserts that `actual` is within `EPS` of `expected`.
    fn assert_close(actual: Complex64, expected: Complex64) {
        assert!(
            (actual - expected).norm() < EPS,
            "expected {:?}, got {:?}",
            expected,
            actual
        );
    }

    /// Row-major Hadamard matrix `[[1, 1], [1, -1]] / sqrt(2)`.
    fn hadamard_matrix() -> [Complex64; 4] {
        [
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(-FRAC_1_SQRT_2, 0.0),
        ]
    }

    /// Input pairs encoding |0> at index 0 and |1> at index 1.
    fn basis_inputs() -> ([Complex64; 2], [Complex64; 2]) {
        (
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
        )
    }

    /// Checks the outputs of applying H to the pairs from [`basis_inputs`].
    fn assert_hadamard_outputs(out_amps0: &[Complex64; 2], out_amps1: &[Complex64; 2]) {
        // H|0> = (|0> + |1>)/sqrt(2)
        assert_close(out_amps0[0], Complex64::new(FRAC_1_SQRT_2, 0.0));
        assert_close(out_amps1[0], Complex64::new(FRAC_1_SQRT_2, 0.0));
        // H|1> = (|0> - |1>)/sqrt(2)
        assert_close(out_amps0[1], Complex64::new(FRAC_1_SQRT_2, 0.0));
        assert_close(out_amps1[1], Complex64::new(-FRAC_1_SQRT_2, 0.0));
    }

    #[test]
    fn test_x_gate_scalar() {
        let x_matrix = [
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];
        let in_amps0 = [Complex64::new(1.0, 0.0), Complex64::new(0.5, 0.0)];
        let in_amps1 = [Complex64::new(0.0, 0.0), Complex64::new(0.5, 0.0)];
        let mut out_amps0 = [Complex64::new(0.0, 0.0); 2];
        let mut out_amps1 = [Complex64::new(0.0, 0.0); 2];

        apply_single_qubit_gate_scalar(
            &x_matrix,
            &in_amps0,
            &in_amps1,
            &mut out_amps0,
            &mut out_amps1,
        );

        // X swaps the bit=0 and bit=1 amplitudes exactly (no rounding involved).
        assert_eq!(out_amps0[0], Complex64::new(0.0, 0.0));
        assert_eq!(out_amps1[0], Complex64::new(1.0, 0.0));
        assert_eq!(out_amps0[1], Complex64::new(0.5, 0.0));
        assert_eq!(out_amps1[1], Complex64::new(0.5, 0.0));
    }

    #[test]
    fn test_hadamard_gate_scalar() {
        let h_matrix = hadamard_matrix();
        let (in_amps0, in_amps1) = basis_inputs();
        let mut out_amps0 = [Complex64::new(0.0, 0.0); 2];
        let mut out_amps1 = [Complex64::new(0.0, 0.0); 2];

        apply_single_qubit_gate_scalar(
            &h_matrix,
            &in_amps0,
            &in_amps1,
            &mut out_amps0,
            &mut out_amps1,
        );

        assert_hadamard_outputs(&out_amps0, &out_amps1);
    }

    #[test]
    fn test_optimized_gate_wrapper() {
        let h_matrix = hadamard_matrix();
        let (in_amps0, in_amps1) = basis_inputs();
        let mut out_amps0 = [Complex64::new(0.0, 0.0); 2];
        let mut out_amps1 = [Complex64::new(0.0, 0.0); 2];

        apply_single_qubit_gate_optimized(
            &h_matrix,
            &in_amps0,
            &in_amps1,
            &mut out_amps0,
            &mut out_amps1,
        );

        assert_hadamard_outputs(&out_amps0, &out_amps1);
    }

    #[test]
    // ComplexVec4 is deprecated but still public; keep it covered until it is
    // removed, without tripping `-D warnings` builds.
    #[allow(deprecated)]
    fn test_complex_vec4() {
        let scalar = Complex64::new(1.0, 2.0);
        let values = [
            Complex64::new(1.0, 2.0),
            Complex64::new(3.0, 4.0),
            Complex64::new(5.0, 6.0),
            Complex64::new(7.0, 8.0),
        ];

        let a = ComplexVec4::splat(scalar);
        let b = ComplexVec4::new(values);
        let product = a.mul(&b);
        let sum = a.add(&b);
        let negated = b.neg();

        for (i, &v) in values.iter().enumerate() {
            assert_eq!(a.get(i), scalar);
            assert_eq!(b.get(i), v);
            assert_close(product.get(i), scalar * v);
            assert_close(sum.get(i), scalar + v);
            assert_close(negated.get(i), -v);
        }
    }
}