quantrs2_sim/
scirs2_complex_simd.rs

//! Enhanced complex-number SIMD operations using `SciRS2`.
//!
//! This module provides SIMD implementations of complex-number arithmetic
//! tailored to quantum state-vector operations. Complex vectors are stored in
//! split real/imaginary form so they can be processed with `SciRS2`'s
//! `SimdUnifiedOps` kernels.

7use quantrs2_core::platform::PlatformCapabilities;
8use scirs2_core::ndarray::{Array1, ArrayView1, ArrayViewMut1};
9use scirs2_core::simd_ops::SimdUnifiedOps;
10use scirs2_core::Complex64;
11
12#[cfg(target_arch = "x86_64")]
13use std::arch::x86_64::*;
14
/// Complex vector stored in split (structure-of-arrays) form for `SciRS2` SIMD kernels.
///
/// Real and imaginary parts live in separate `Vec<f64>` buffers whose length is
/// padded up to a multiple of `simd_width`. Only the first `length` entries are
/// logically part of the vector; the padding is zero-initialised on construction.
///
/// Note: the padding concerns buffer *length* only. `Vec<f64>` guarantees `f64`
/// alignment, not SIMD-register alignment of the buffer start.
#[derive(Debug, Clone)]
pub struct ComplexSimdVector {
    /// Real components; length is `length` padded up to the SIMD width.
    real: Vec<f64>,
    /// Imaginary components; same padded length as `real`.
    imag: Vec<f64>,
    /// Number of logical complex elements.
    length: usize,
    /// SIMD lane width (in `f64` lanes) used to pad the buffers.
    simd_width: usize,
}
27
28impl ComplexSimdVector {
29    /// Create a new SIMD-aligned complex vector
30    #[must_use]
31    pub fn new(length: usize) -> Self {
32        let simd_width = Self::detect_simd_width();
33        let aligned_length = Self::align_length(length, simd_width);
34
35        Self {
36            real: vec![0.0; aligned_length],
37            imag: vec![0.0; aligned_length],
38            length,
39            simd_width,
40        }
41    }
42
43    /// Create from slice of Complex64
44    #[must_use]
45    pub fn from_slice(data: &[Complex64]) -> Self {
46        let mut vec = Self::new(data.len());
47        for (i, &complex) in data.iter().enumerate() {
48            vec.real[i] = complex.re;
49            vec.imag[i] = complex.im;
50        }
51        vec
52    }
53
54    /// Convert back to Complex64 slice
55    #[must_use]
56    pub fn to_complex_vec(&self) -> Vec<Complex64> {
57        (0..self.length)
58            .map(|i| Complex64::new(self.real[i], self.imag[i]))
59            .collect()
60    }
61
62    /// Detect optimal SIMD width for current hardware using `PlatformCapabilities`
63    #[must_use]
64    pub fn detect_simd_width() -> usize {
65        PlatformCapabilities::detect().optimal_simd_width_f64()
66    }
67
68    /// Align length to SIMD boundary
69    const fn align_length(length: usize, simd_width: usize) -> usize {
70        length.div_ceil(simd_width) * simd_width
71    }
72
73    /// Get real part as array view
74    #[must_use]
75    pub fn real_view(&self) -> ArrayView1<'_, f64> {
76        ArrayView1::from(&self.real[..self.length])
77    }
78
79    /// Get imaginary part as array view
80    #[must_use]
81    pub fn imag_view(&self) -> ArrayView1<'_, f64> {
82        ArrayView1::from(&self.imag[..self.length])
83    }
84
85    /// Get length
86    #[must_use]
87    pub const fn len(&self) -> usize {
88        self.length
89    }
90
91    /// Check if empty
92    #[must_use]
93    pub const fn is_empty(&self) -> bool {
94        self.length == 0
95    }
96}
97
/// Namespace for element-wise complex arithmetic on split real/imaginary
/// vectors, implemented on top of `SciRS2`'s `SimdUnifiedOps` kernels.
///
/// This is a stateless unit struct; all operations are associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct ComplexSimdOps;
100
101impl ComplexSimdOps {
102    /// Complex multiplication: c = a * b, vectorized
103    pub fn complex_mul_simd(
104        a: &ComplexSimdVector,
105        b: &ComplexSimdVector,
106        c: &mut ComplexSimdVector,
107    ) {
108        assert_eq!(a.len(), b.len());
109        assert_eq!(a.len(), c.len());
110
111        let a_real = a.real_view();
112        let a_imag = a.imag_view();
113        let b_real = b.real_view();
114        let b_imag = b.imag_view();
115
116        // Complex multiplication: (a_r + i*a_i) * (b_r + i*b_i)
117        // = (a_r*b_r - a_i*b_i) + i*(a_r*b_i + a_i*b_r)
118
119        // Real part: a_r*b_r - a_i*b_i using SciRS2 SIMD operations
120        let ar_br = f64::simd_mul(&a_real, &b_real);
121        let ai_bi = f64::simd_mul(&a_imag, &b_imag);
122        let real_result = f64::simd_sub(&ar_br.view(), &ai_bi.view());
123
124        // Imaginary part: a_r*b_i + a_i*b_r using SciRS2 SIMD operations
125        let ar_bi = f64::simd_mul(&a_real, &b_imag);
126        let ai_br = f64::simd_mul(&a_imag, &b_real);
127        let imag_result = f64::simd_add(&ar_bi.view(), &ai_br.view());
128
129        // Store results
130        for i in 0..c.length {
131            c.real[i] = real_result[i];
132            c.imag[i] = imag_result[i];
133        }
134    }
135
136    /// Complex addition: c = a + b, vectorized
137    pub fn complex_add_simd(
138        a: &ComplexSimdVector,
139        b: &ComplexSimdVector,
140        c: &mut ComplexSimdVector,
141    ) {
142        assert_eq!(a.len(), b.len());
143        assert_eq!(a.len(), c.len());
144
145        let a_real = a.real_view();
146        let a_imag = a.imag_view();
147        let b_real = b.real_view();
148        let b_imag = b.imag_view();
149
150        let real_result = f64::simd_add(&a_real, &b_real);
151        let imag_result = f64::simd_add(&a_imag, &b_imag);
152
153        for i in 0..c.length {
154            c.real[i] = real_result[i];
155            c.imag[i] = imag_result[i];
156        }
157    }
158
159    /// Complex subtraction: c = a - b, vectorized
160    pub fn complex_sub_simd(
161        a: &ComplexSimdVector,
162        b: &ComplexSimdVector,
163        c: &mut ComplexSimdVector,
164    ) {
165        assert_eq!(a.len(), b.len());
166        assert_eq!(a.len(), c.len());
167
168        let a_real = a.real_view();
169        let a_imag = a.imag_view();
170        let b_real = b.real_view();
171        let b_imag = b.imag_view();
172
173        let real_result = f64::simd_sub(&a_real, &b_real);
174        let imag_result = f64::simd_sub(&a_imag, &b_imag);
175
176        for i in 0..c.length {
177            c.real[i] = real_result[i];
178            c.imag[i] = imag_result[i];
179        }
180    }
181
182    /// Scalar complex multiplication: c = a * scalar, vectorized
183    pub fn complex_scalar_mul_simd(
184        a: &ComplexSimdVector,
185        scalar: Complex64,
186        c: &mut ComplexSimdVector,
187    ) {
188        assert_eq!(a.len(), c.len());
189
190        let a_real = a.real_view();
191        let a_imag = a.imag_view();
192
193        // (a_r + i*a_i) * (s_r + i*s_i) = (a_r*s_r - a_i*s_i) + i*(a_r*s_i + a_i*s_r)
194        let ar_sr = f64::simd_scalar_mul(&a_real, scalar.re);
195        let ai_si = f64::simd_scalar_mul(&a_imag, scalar.im);
196        let real_result = f64::simd_sub(&ar_sr.view(), &ai_si.view());
197
198        let ar_si = f64::simd_scalar_mul(&a_real, scalar.im);
199        let ai_sr = f64::simd_scalar_mul(&a_imag, scalar.re);
200        let imag_result = f64::simd_add(&ar_si.view(), &ai_sr.view());
201
202        for i in 0..c.length {
203            c.real[i] = real_result[i];
204            c.imag[i] = imag_result[i];
205        }
206    }
207
208    /// Complex conjugate: c = conj(a), vectorized
209    pub fn complex_conj_simd(a: &ComplexSimdVector, c: &mut ComplexSimdVector) {
210        assert_eq!(a.len(), c.len());
211
212        let a_real = a.real_view();
213        let a_imag = a.imag_view();
214
215        // Copy real part unchanged
216        for i in 0..c.length {
217            c.real[i] = a.real[i];
218        }
219
220        // Negate imaginary part
221        let zero_array = Array1::zeros(a.length);
222        let negated_imag = f64::simd_sub(&zero_array.view(), &a_imag);
223
224        for i in 0..c.length {
225            c.imag[i] = negated_imag[i];
226        }
227    }
228
229    /// Complex magnitude squared: |a|^2, vectorized
230    #[must_use]
231    pub fn complex_norm_squared_simd(a: &ComplexSimdVector) -> Vec<f64> {
232        let a_real = a.real_view();
233        let a_imag = a.imag_view();
234
235        let real_sq = f64::simd_mul(&a_real, &a_real);
236        let imag_sq = f64::simd_mul(&a_imag, &a_imag);
237        let norm_sq = f64::simd_add(&real_sq.view(), &imag_sq.view());
238
239        norm_sq.to_vec()
240    }
241}
242
243/// Enhanced single-qubit gate application with native complex SIMD
244pub fn apply_single_qubit_gate_complex_simd(
245    matrix: &[Complex64; 4],
246    in_amps0: &[Complex64],
247    in_amps1: &[Complex64],
248    out_amps0: &mut [Complex64],
249    out_amps1: &mut [Complex64],
250) {
251    let len = in_amps0.len();
252
253    // Convert to SIMD vectors
254    let a0_simd = ComplexSimdVector::from_slice(in_amps0);
255    let a1_simd = ComplexSimdVector::from_slice(in_amps1);
256
257    let mut result0_simd = ComplexSimdVector::new(len);
258    let mut result1_simd = ComplexSimdVector::new(len);
259
260    // Temporary vectors for intermediate results
261    let mut temp0_simd = ComplexSimdVector::new(len);
262    let mut temp1_simd = ComplexSimdVector::new(len);
263
264    // Compute: out_a0 = matrix[0] * a0 + matrix[1] * a1
265    ComplexSimdOps::complex_scalar_mul_simd(&a0_simd, matrix[0], &mut temp0_simd);
266    ComplexSimdOps::complex_scalar_mul_simd(&a1_simd, matrix[1], &mut temp1_simd);
267    ComplexSimdOps::complex_add_simd(&temp0_simd, &temp1_simd, &mut result0_simd);
268
269    // Compute: out_a1 = matrix[2] * a0 + matrix[3] * a1
270    ComplexSimdOps::complex_scalar_mul_simd(&a0_simd, matrix[2], &mut temp0_simd);
271    ComplexSimdOps::complex_scalar_mul_simd(&a1_simd, matrix[3], &mut temp1_simd);
272    ComplexSimdOps::complex_add_simd(&temp0_simd, &temp1_simd, &mut result1_simd);
273
274    // Convert back to Complex64 arrays
275    let result0 = result0_simd.to_complex_vec();
276    let result1 = result1_simd.to_complex_vec();
277
278    out_amps0.copy_from_slice(&result0);
279    out_amps1.copy_from_slice(&result1);
280}
281
282/// Enhanced Hadamard gate with complex SIMD optimizations
283pub fn apply_hadamard_gate_complex_simd(
284    in_amps0: &[Complex64],
285    in_amps1: &[Complex64],
286    out_amps0: &mut [Complex64],
287    out_amps1: &mut [Complex64],
288) {
289    let len = in_amps0.len();
290    let sqrt2_inv = Complex64::new(std::f64::consts::FRAC_1_SQRT_2, 0.0);
291
292    let a0_simd = ComplexSimdVector::from_slice(in_amps0);
293    let a1_simd = ComplexSimdVector::from_slice(in_amps1);
294
295    let mut sum_simd = ComplexSimdVector::new(len);
296    let mut diff_simd = ComplexSimdVector::new(len);
297    let mut result0_simd = ComplexSimdVector::new(len);
298    let mut result1_simd = ComplexSimdVector::new(len);
299
300    // Hadamard: out_a0 = (a0 + a1) / sqrt(2), out_a1 = (a0 - a1) / sqrt(2)
301    ComplexSimdOps::complex_add_simd(&a0_simd, &a1_simd, &mut sum_simd);
302    ComplexSimdOps::complex_sub_simd(&a0_simd, &a1_simd, &mut diff_simd);
303
304    ComplexSimdOps::complex_scalar_mul_simd(&sum_simd, sqrt2_inv, &mut result0_simd);
305    ComplexSimdOps::complex_scalar_mul_simd(&diff_simd, sqrt2_inv, &mut result1_simd);
306
307    let result0 = result0_simd.to_complex_vec();
308    let result1 = result1_simd.to_complex_vec();
309
310    out_amps0.copy_from_slice(&result0);
311    out_amps1.copy_from_slice(&result1);
312}
313
314/// Optimized CNOT gate for multiple qubit pairs using complex SIMD
315pub fn apply_cnot_complex_simd(
316    state: &mut [Complex64],
317    control_qubit: usize,
318    target_qubit: usize,
319    num_qubits: usize,
320) {
321    let dim = 1 << num_qubits;
322    let control_mask = 1 << control_qubit;
323    let target_mask = 1 << target_qubit;
324
325    // Process state in SIMD chunks where possible
326    let chunk_size = ComplexSimdVector::detect_simd_width();
327    let num_chunks = dim / (chunk_size * 2); // Process pairs of indices
328
329    for chunk in 0..num_chunks {
330        let base_idx = chunk * chunk_size * 2;
331        let mut chunk_data = vec![Complex64::new(0.0, 0.0); chunk_size * 2];
332
333        // Collect indices that need swapping
334        let mut swap_indices = Vec::new();
335        for i in 0..chunk_size {
336            let idx = base_idx + i;
337            if idx < dim && (idx & control_mask) != 0 {
338                let swapped_idx = idx ^ target_mask;
339                if swapped_idx < dim {
340                    swap_indices.push((idx, swapped_idx));
341                    chunk_data[i * 2] = state[idx];
342                    chunk_data[i * 2 + 1] = state[swapped_idx];
343                }
344            }
345        }
346
347        // Apply swaps using SIMD operations
348        if !swap_indices.is_empty() {
349            let chunk_simd = ComplexSimdVector::from_slice(&chunk_data);
350            for (i, (idx, swapped_idx)) in swap_indices.iter().enumerate() {
351                state[*idx] = chunk_simd.to_complex_vec()[i * 2 + 1];
352                state[*swapped_idx] = chunk_simd.to_complex_vec()[i * 2];
353            }
354        }
355    }
356
357    // Handle remaining elements with scalar operations
358    let remaining_start = num_chunks * chunk_size * 2;
359    for i in remaining_start..dim {
360        if (i & control_mask) != 0 {
361            let swapped_i = i ^ target_mask;
362            if swapped_i < dim {
363                state.swap(i, swapped_i);
364            }
365        }
366    }
367}
368
369/// Performance benchmarking for complex SIMD operations
370#[must_use]
371pub fn benchmark_complex_simd_operations() -> std::collections::HashMap<String, f64> {
372    use std::time::Instant;
373    let mut results = std::collections::HashMap::new();
374
375    // Test vector sizes
376    let sizes = vec![1024, 4096, 16_384, 65_536];
377
378    for &size in &sizes {
379        let a = ComplexSimdVector::from_slice(&vec![Complex64::new(1.0, 0.5); size]);
380        let b = ComplexSimdVector::from_slice(&vec![Complex64::new(0.5, 1.0); size]);
381        let mut c = ComplexSimdVector::new(size);
382
383        // Benchmark complex multiplication
384        let start = Instant::now();
385        for _ in 0..1000 {
386            ComplexSimdOps::complex_mul_simd(&a, &b, &mut c);
387        }
388        let mul_time = start.elapsed().as_nanos() as f64 / 1000.0;
389        results.insert(format!("complex_mul_simd_{size}"), mul_time);
390
391        // Benchmark complex addition
392        let start = Instant::now();
393        for _ in 0..1000 {
394            ComplexSimdOps::complex_add_simd(&a, &b, &mut c);
395        }
396        let add_time = start.elapsed().as_nanos() as f64 / 1000.0;
397        results.insert(format!("complex_add_simd_{size}"), add_time);
398
399        // Benchmark scalar multiplication
400        let scalar = Complex64::new(2.0, 1.0);
401        let start = Instant::now();
402        for _ in 0..1000 {
403            ComplexSimdOps::complex_scalar_mul_simd(&a, scalar, &mut c);
404        }
405        let scalar_mul_time = start.elapsed().as_nanos() as f64 / 1000.0;
406        results.insert(format!("complex_scalar_mul_simd_{size}"), scalar_mul_time);
407    }
408
409    results
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// Shorthand constructor for test data.
    fn cx(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Assert equal lengths and element-wise agreement within `EPS`.
    fn assert_close(actual: &[Complex64], expected: &[Complex64]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (k, (&got, &want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).norm() < EPS,
                "index {k}: got {got:?}, expected {want:?}"
            );
        }
    }

    #[test]
    fn test_complex_simd_vector_creation() {
        let data = [cx(1.0, 2.0), cx(3.0, 4.0), cx(5.0, 6.0)];

        let simd_vec = ComplexSimdVector::from_slice(&data);
        assert_eq!(simd_vec.len(), 3);
        assert!(!simd_vec.is_empty());
        assert_close(&simd_vec.to_complex_vec(), &data);
    }

    #[test]
    fn test_empty_vector() {
        let v = ComplexSimdVector::new(0);
        assert!(v.is_empty());
        assert_eq!(v.len(), 0);
        assert!(v.to_complex_vec().is_empty());
    }

    #[test]
    fn test_complex_multiplication_simd() {
        let a = ComplexSimdVector::from_slice(&[cx(1.0, 2.0), cx(3.0, 4.0)]);
        let b = ComplexSimdVector::from_slice(&[cx(5.0, 6.0), cx(7.0, 8.0)]);
        let mut out = ComplexSimdVector::new(2);

        ComplexSimdOps::complex_mul_simd(&a, &b, &mut out);

        // (1+2i)(5+6i) = -7+16i ; (3+4i)(7+8i) = -11+52i
        assert_close(&out.to_complex_vec(), &[cx(-7.0, 16.0), cx(-11.0, 52.0)]);
    }

    #[test]
    fn test_complex_add_and_sub_simd() {
        let a = ComplexSimdVector::from_slice(&[cx(1.0, 2.0), cx(-0.5, 3.0)]);
        let b = ComplexSimdVector::from_slice(&[cx(3.0, 4.0), cx(0.5, -1.0)]);
        let mut out = ComplexSimdVector::new(2);

        ComplexSimdOps::complex_add_simd(&a, &b, &mut out);
        assert_close(&out.to_complex_vec(), &[cx(4.0, 6.0), cx(0.0, 2.0)]);

        ComplexSimdOps::complex_sub_simd(&a, &b, &mut out);
        assert_close(&out.to_complex_vec(), &[cx(-2.0, -2.0), cx(-1.0, 4.0)]);
    }

    #[test]
    fn test_complex_scalar_mul_simd() {
        let a = ComplexSimdVector::from_slice(&[cx(1.0, 2.0), cx(3.0, -1.0)]);
        let mut out = ComplexSimdVector::new(2);

        ComplexSimdOps::complex_scalar_mul_simd(&a, cx(2.0, 1.0), &mut out);

        // (1+2i)(2+i) = 5i ; (3-i)(2+i) = 7+i
        assert_close(&out.to_complex_vec(), &[cx(0.0, 5.0), cx(7.0, 1.0)]);
    }

    #[test]
    fn test_complex_conj_simd() {
        let a = ComplexSimdVector::from_slice(&[cx(1.0, 2.0), cx(-3.0, -0.5)]);
        let mut out = ComplexSimdVector::new(2);

        ComplexSimdOps::complex_conj_simd(&a, &mut out);

        assert_close(&out.to_complex_vec(), &[cx(1.0, -2.0), cx(-3.0, 0.5)]);
    }

    #[test]
    fn test_hadamard_gate_complex_simd() {
        // Pair 0 holds |0⟩ (a0 = 1, a1 = 0); pair 1 holds |1⟩ (a0 = 0, a1 = 1).
        let in_amps0 = [cx(1.0, 0.0), cx(0.0, 0.0)];
        let in_amps1 = [cx(0.0, 0.0), cx(1.0, 0.0)];
        let mut out_amps0 = [cx(0.0, 0.0); 2];
        let mut out_amps1 = [cx(0.0, 0.0); 2];

        apply_hadamard_gate_complex_simd(&in_amps0, &in_amps1, &mut out_amps0, &mut out_amps1);

        let h = FRAC_1_SQRT_2;
        // H|0⟩ = (|0⟩ + |1⟩)/√2 ; H|1⟩ = (|0⟩ − |1⟩)/√2
        assert_close(&out_amps0, &[cx(h, 0.0), cx(h, 0.0)]);
        assert_close(&out_amps1, &[cx(h, 0.0), cx(-h, 0.0)]);
    }

    #[test]
    fn test_single_qubit_gate_complex_simd() {
        // Pauli-X matrix, row-major.
        let x_matrix = [cx(0.0, 0.0), cx(1.0, 0.0), cx(1.0, 0.0), cx(0.0, 0.0)];

        let in_amps0 = [cx(1.0, 0.0), cx(0.5, 0.0)];
        let in_amps1 = [cx(0.0, 0.0), cx(0.5, 0.0)];
        let mut out_amps0 = [cx(0.0, 0.0); 2];
        let mut out_amps1 = [cx(0.0, 0.0); 2];

        apply_single_qubit_gate_complex_simd(
            &x_matrix,
            &in_amps0,
            &in_amps1,
            &mut out_amps0,
            &mut out_amps1,
        );

        // X swaps the two amplitudes of every pair.
        assert_close(&out_amps0, &in_amps1);
        assert_close(&out_amps1, &in_amps0);
    }

    #[test]
    fn test_complex_norm_squared_simd() {
        let data = [
            cx(3.0, 4.0), // |3+4i|² = 25
            cx(1.0, 1.0), // |1+i|²  = 2
        ];

        let simd_vec = ComplexSimdVector::from_slice(&data);
        let norms = ComplexSimdOps::complex_norm_squared_simd(&simd_vec);

        assert_eq!(norms.len(), 2);
        assert!((norms[0] - 25.0).abs() < EPS);
        assert!((norms[1] - 2.0).abs() < EPS);
    }
}