quantrs2_sim/
scirs2_complex_simd.rs

1//! Enhanced Complex Number SIMD Operations using SciRS2
2//!
3//! This module provides advanced SIMD implementations specifically optimized
4//! for complex number arithmetic in quantum state vector operations.
5//! It leverages SciRS2's SimdUnifiedOps for maximum performance.
6
7use quantrs2_core::platform::PlatformCapabilities;
8use scirs2_core::ndarray::{Array1, ArrayView1, ArrayViewMut1};
9use scirs2_core::simd_ops::SimdUnifiedOps;
10use scirs2_core::Complex64;
11
12#[cfg(target_arch = "x86_64")]
13use std::arch::x86_64::*;
14
15/// Complex SIMD vector operations using native SciRS2 primitives
16#[derive(Debug, Clone)]
17pub struct ComplexSimdVector {
18    /// Real components (SIMD-aligned)
19    real: Vec<f64>,
20    /// Imaginary components (SIMD-aligned)
21    imag: Vec<f64>,
22    /// Number of complex elements
23    length: usize,
24    /// SIMD lane width for this platform
25    simd_width: usize,
26}
27
28impl ComplexSimdVector {
29    /// Create a new SIMD-aligned complex vector
30    pub fn new(length: usize) -> Self {
31        let simd_width = Self::detect_simd_width();
32        let aligned_length = Self::align_length(length, simd_width);
33
34        Self {
35            real: vec![0.0; aligned_length],
36            imag: vec![0.0; aligned_length],
37            length,
38            simd_width,
39        }
40    }
41
42    /// Create from slice of Complex64
43    pub fn from_slice(data: &[Complex64]) -> Self {
44        let mut vec = Self::new(data.len());
45        for (i, &complex) in data.iter().enumerate() {
46            vec.real[i] = complex.re;
47            vec.imag[i] = complex.im;
48        }
49        vec
50    }
51
52    /// Convert back to Complex64 slice
53    pub fn to_complex_vec(&self) -> Vec<Complex64> {
54        (0..self.length)
55            .map(|i| Complex64::new(self.real[i], self.imag[i]))
56            .collect()
57    }
58
59    /// Detect optimal SIMD width for current hardware using PlatformCapabilities
60    pub fn detect_simd_width() -> usize {
61        PlatformCapabilities::detect().optimal_simd_width_f64()
62    }
63
64    /// Align length to SIMD boundary
65    const fn align_length(length: usize, simd_width: usize) -> usize {
66        length.div_ceil(simd_width) * simd_width
67    }
68
69    /// Get real part as array view
70    pub fn real_view(&self) -> ArrayView1<'_, f64> {
71        ArrayView1::from(&self.real[..self.length])
72    }
73
74    /// Get imaginary part as array view
75    pub fn imag_view(&self) -> ArrayView1<'_, f64> {
76        ArrayView1::from(&self.imag[..self.length])
77    }
78
79    /// Get length
80    pub const fn len(&self) -> usize {
81        self.length
82    }
83
84    /// Check if empty
85    pub const fn is_empty(&self) -> bool {
86        self.length == 0
87    }
88}
89
90/// High-performance complex arithmetic operations using SciRS2 SIMD
91pub struct ComplexSimdOps;
92
93impl ComplexSimdOps {
94    /// Complex multiplication: c = a * b, vectorized
95    pub fn complex_mul_simd(
96        a: &ComplexSimdVector,
97        b: &ComplexSimdVector,
98        c: &mut ComplexSimdVector,
99    ) {
100        assert_eq!(a.len(), b.len());
101        assert_eq!(a.len(), c.len());
102
103        let a_real = a.real_view();
104        let a_imag = a.imag_view();
105        let b_real = b.real_view();
106        let b_imag = b.imag_view();
107
108        // Complex multiplication: (a_r + i*a_i) * (b_r + i*b_i)
109        // = (a_r*b_r - a_i*b_i) + i*(a_r*b_i + a_i*b_r)
110
111        // Real part: a_r*b_r - a_i*b_i using SciRS2 SIMD operations
112        let ar_br = f64::simd_mul(&a_real, &b_real);
113        let ai_bi = f64::simd_mul(&a_imag, &b_imag);
114        let real_result = f64::simd_sub(&ar_br.view(), &ai_bi.view());
115
116        // Imaginary part: a_r*b_i + a_i*b_r using SciRS2 SIMD operations
117        let ar_bi = f64::simd_mul(&a_real, &b_imag);
118        let ai_br = f64::simd_mul(&a_imag, &b_real);
119        let imag_result = f64::simd_add(&ar_bi.view(), &ai_br.view());
120
121        // Store results
122        for i in 0..c.length {
123            c.real[i] = real_result[i];
124            c.imag[i] = imag_result[i];
125        }
126    }
127
128    /// Complex addition: c = a + b, vectorized
129    pub fn complex_add_simd(
130        a: &ComplexSimdVector,
131        b: &ComplexSimdVector,
132        c: &mut ComplexSimdVector,
133    ) {
134        assert_eq!(a.len(), b.len());
135        assert_eq!(a.len(), c.len());
136
137        let a_real = a.real_view();
138        let a_imag = a.imag_view();
139        let b_real = b.real_view();
140        let b_imag = b.imag_view();
141
142        let real_result = f64::simd_add(&a_real, &b_real);
143        let imag_result = f64::simd_add(&a_imag, &b_imag);
144
145        for i in 0..c.length {
146            c.real[i] = real_result[i];
147            c.imag[i] = imag_result[i];
148        }
149    }
150
151    /// Complex subtraction: c = a - b, vectorized
152    pub fn complex_sub_simd(
153        a: &ComplexSimdVector,
154        b: &ComplexSimdVector,
155        c: &mut ComplexSimdVector,
156    ) {
157        assert_eq!(a.len(), b.len());
158        assert_eq!(a.len(), c.len());
159
160        let a_real = a.real_view();
161        let a_imag = a.imag_view();
162        let b_real = b.real_view();
163        let b_imag = b.imag_view();
164
165        let real_result = f64::simd_sub(&a_real, &b_real);
166        let imag_result = f64::simd_sub(&a_imag, &b_imag);
167
168        for i in 0..c.length {
169            c.real[i] = real_result[i];
170            c.imag[i] = imag_result[i];
171        }
172    }
173
174    /// Scalar complex multiplication: c = a * scalar, vectorized
175    pub fn complex_scalar_mul_simd(
176        a: &ComplexSimdVector,
177        scalar: Complex64,
178        c: &mut ComplexSimdVector,
179    ) {
180        assert_eq!(a.len(), c.len());
181
182        let a_real = a.real_view();
183        let a_imag = a.imag_view();
184
185        // (a_r + i*a_i) * (s_r + i*s_i) = (a_r*s_r - a_i*s_i) + i*(a_r*s_i + a_i*s_r)
186        let ar_sr = f64::simd_scalar_mul(&a_real, scalar.re);
187        let ai_si = f64::simd_scalar_mul(&a_imag, scalar.im);
188        let real_result = f64::simd_sub(&ar_sr.view(), &ai_si.view());
189
190        let ar_si = f64::simd_scalar_mul(&a_real, scalar.im);
191        let ai_sr = f64::simd_scalar_mul(&a_imag, scalar.re);
192        let imag_result = f64::simd_add(&ar_si.view(), &ai_sr.view());
193
194        for i in 0..c.length {
195            c.real[i] = real_result[i];
196            c.imag[i] = imag_result[i];
197        }
198    }
199
200    /// Complex conjugate: c = conj(a), vectorized
201    pub fn complex_conj_simd(a: &ComplexSimdVector, c: &mut ComplexSimdVector) {
202        assert_eq!(a.len(), c.len());
203
204        let a_real = a.real_view();
205        let a_imag = a.imag_view();
206
207        // Copy real part unchanged
208        for i in 0..c.length {
209            c.real[i] = a.real[i];
210        }
211
212        // Negate imaginary part
213        let zero_array = Array1::zeros(a.length);
214        let negated_imag = f64::simd_sub(&zero_array.view(), &a_imag);
215
216        for i in 0..c.length {
217            c.imag[i] = negated_imag[i];
218        }
219    }
220
221    /// Complex magnitude squared: |a|^2, vectorized
222    pub fn complex_norm_squared_simd(a: &ComplexSimdVector) -> Vec<f64> {
223        let a_real = a.real_view();
224        let a_imag = a.imag_view();
225
226        let real_sq = f64::simd_mul(&a_real, &a_real);
227        let imag_sq = f64::simd_mul(&a_imag, &a_imag);
228        let norm_sq = f64::simd_add(&real_sq.view(), &imag_sq.view());
229
230        norm_sq.to_vec()
231    }
232}
233
234/// Enhanced single-qubit gate application with native complex SIMD
235pub fn apply_single_qubit_gate_complex_simd(
236    matrix: &[Complex64; 4],
237    in_amps0: &[Complex64],
238    in_amps1: &[Complex64],
239    out_amps0: &mut [Complex64],
240    out_amps1: &mut [Complex64],
241) {
242    let len = in_amps0.len();
243
244    // Convert to SIMD vectors
245    let a0_simd = ComplexSimdVector::from_slice(in_amps0);
246    let a1_simd = ComplexSimdVector::from_slice(in_amps1);
247
248    let mut result0_simd = ComplexSimdVector::new(len);
249    let mut result1_simd = ComplexSimdVector::new(len);
250
251    // Temporary vectors for intermediate results
252    let mut temp0_simd = ComplexSimdVector::new(len);
253    let mut temp1_simd = ComplexSimdVector::new(len);
254
255    // Compute: out_a0 = matrix[0] * a0 + matrix[1] * a1
256    ComplexSimdOps::complex_scalar_mul_simd(&a0_simd, matrix[0], &mut temp0_simd);
257    ComplexSimdOps::complex_scalar_mul_simd(&a1_simd, matrix[1], &mut temp1_simd);
258    ComplexSimdOps::complex_add_simd(&temp0_simd, &temp1_simd, &mut result0_simd);
259
260    // Compute: out_a1 = matrix[2] * a0 + matrix[3] * a1
261    ComplexSimdOps::complex_scalar_mul_simd(&a0_simd, matrix[2], &mut temp0_simd);
262    ComplexSimdOps::complex_scalar_mul_simd(&a1_simd, matrix[3], &mut temp1_simd);
263    ComplexSimdOps::complex_add_simd(&temp0_simd, &temp1_simd, &mut result1_simd);
264
265    // Convert back to Complex64 arrays
266    let result0 = result0_simd.to_complex_vec();
267    let result1 = result1_simd.to_complex_vec();
268
269    out_amps0.copy_from_slice(&result0);
270    out_amps1.copy_from_slice(&result1);
271}
272
273/// Enhanced Hadamard gate with complex SIMD optimizations
274pub fn apply_hadamard_gate_complex_simd(
275    in_amps0: &[Complex64],
276    in_amps1: &[Complex64],
277    out_amps0: &mut [Complex64],
278    out_amps1: &mut [Complex64],
279) {
280    let len = in_amps0.len();
281    let sqrt2_inv = Complex64::new(std::f64::consts::FRAC_1_SQRT_2, 0.0);
282
283    let a0_simd = ComplexSimdVector::from_slice(in_amps0);
284    let a1_simd = ComplexSimdVector::from_slice(in_amps1);
285
286    let mut sum_simd = ComplexSimdVector::new(len);
287    let mut diff_simd = ComplexSimdVector::new(len);
288    let mut result0_simd = ComplexSimdVector::new(len);
289    let mut result1_simd = ComplexSimdVector::new(len);
290
291    // Hadamard: out_a0 = (a0 + a1) / sqrt(2), out_a1 = (a0 - a1) / sqrt(2)
292    ComplexSimdOps::complex_add_simd(&a0_simd, &a1_simd, &mut sum_simd);
293    ComplexSimdOps::complex_sub_simd(&a0_simd, &a1_simd, &mut diff_simd);
294
295    ComplexSimdOps::complex_scalar_mul_simd(&sum_simd, sqrt2_inv, &mut result0_simd);
296    ComplexSimdOps::complex_scalar_mul_simd(&diff_simd, sqrt2_inv, &mut result1_simd);
297
298    let result0 = result0_simd.to_complex_vec();
299    let result1 = result1_simd.to_complex_vec();
300
301    out_amps0.copy_from_slice(&result0);
302    out_amps1.copy_from_slice(&result1);
303}
304
305/// Optimized CNOT gate for multiple qubit pairs using complex SIMD
306pub fn apply_cnot_complex_simd(
307    state: &mut [Complex64],
308    control_qubit: usize,
309    target_qubit: usize,
310    num_qubits: usize,
311) {
312    let dim = 1 << num_qubits;
313    let control_mask = 1 << control_qubit;
314    let target_mask = 1 << target_qubit;
315
316    // Process state in SIMD chunks where possible
317    let chunk_size = ComplexSimdVector::detect_simd_width();
318    let num_chunks = dim / (chunk_size * 2); // Process pairs of indices
319
320    for chunk in 0..num_chunks {
321        let base_idx = chunk * chunk_size * 2;
322        let mut chunk_data = vec![Complex64::new(0.0, 0.0); chunk_size * 2];
323
324        // Collect indices that need swapping
325        let mut swap_indices = Vec::new();
326        for i in 0..chunk_size {
327            let idx = base_idx + i;
328            if idx < dim && (idx & control_mask) != 0 {
329                let swapped_idx = idx ^ target_mask;
330                if swapped_idx < dim {
331                    swap_indices.push((idx, swapped_idx));
332                    chunk_data[i * 2] = state[idx];
333                    chunk_data[i * 2 + 1] = state[swapped_idx];
334                }
335            }
336        }
337
338        // Apply swaps using SIMD operations
339        if !swap_indices.is_empty() {
340            let chunk_simd = ComplexSimdVector::from_slice(&chunk_data);
341            for (i, (idx, swapped_idx)) in swap_indices.iter().enumerate() {
342                state[*idx] = chunk_simd.to_complex_vec()[i * 2 + 1];
343                state[*swapped_idx] = chunk_simd.to_complex_vec()[i * 2];
344            }
345        }
346    }
347
348    // Handle remaining elements with scalar operations
349    let remaining_start = num_chunks * chunk_size * 2;
350    for i in remaining_start..dim {
351        if (i & control_mask) != 0 {
352            let swapped_i = i ^ target_mask;
353            if swapped_i < dim {
354                state.swap(i, swapped_i);
355            }
356        }
357    }
358}
359
360/// Performance benchmarking for complex SIMD operations
361pub fn benchmark_complex_simd_operations() -> std::collections::HashMap<String, f64> {
362    use std::time::Instant;
363    let mut results = std::collections::HashMap::new();
364
365    // Test vector sizes
366    let sizes = vec![1024, 4096, 16384, 65536];
367
368    for &size in &sizes {
369        let a = ComplexSimdVector::from_slice(&vec![Complex64::new(1.0, 0.5); size]);
370        let b = ComplexSimdVector::from_slice(&vec![Complex64::new(0.5, 1.0); size]);
371        let mut c = ComplexSimdVector::new(size);
372
373        // Benchmark complex multiplication
374        let start = Instant::now();
375        for _ in 0..1000 {
376            ComplexSimdOps::complex_mul_simd(&a, &b, &mut c);
377        }
378        let mul_time = start.elapsed().as_nanos() as f64 / 1000.0;
379        results.insert(format!("complex_mul_simd_{size}"), mul_time);
380
381        // Benchmark complex addition
382        let start = Instant::now();
383        for _ in 0..1000 {
384            ComplexSimdOps::complex_add_simd(&a, &b, &mut c);
385        }
386        let add_time = start.elapsed().as_nanos() as f64 / 1000.0;
387        results.insert(format!("complex_add_simd_{size}"), add_time);
388
389        // Benchmark scalar multiplication
390        let scalar = Complex64::new(2.0, 1.0);
391        let start = Instant::now();
392        for _ in 0..1000 {
393            ComplexSimdOps::complex_scalar_mul_simd(&a, scalar, &mut c);
394        }
395        let scalar_mul_time = start.elapsed().as_nanos() as f64 / 1000.0;
396        results.insert(format!("complex_scalar_mul_simd_{size}"), scalar_mul_time);
397    }
398
399    results
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Absolute tolerance shared by every comparison in this module
    const TOLERANCE: f64 = 1e-10;

    /// Shorthand constructor to keep the fixtures readable
    fn cx(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Assert that two complex values agree to within `TOLERANCE`
    fn assert_close(actual: Complex64, expected: Complex64) {
        assert!((actual - expected).norm() < TOLERANCE);
    }

    #[test]
    fn test_complex_simd_vector_creation() {
        let data = vec![cx(1.0, 2.0), cx(3.0, 4.0), cx(5.0, 6.0)];

        let simd_vec = ComplexSimdVector::from_slice(&data);
        assert_eq!(simd_vec.len(), 3);

        // Round trip through the split representation must be lossless
        let round_trip = simd_vec.to_complex_vec();
        for (idx, &expected) in data.iter().enumerate() {
            assert_close(round_trip[idx], expected);
        }
    }

    #[test]
    fn test_complex_multiplication_simd() {
        let a = ComplexSimdVector::from_slice(&[cx(1.0, 2.0), cx(3.0, 4.0)]);
        let b = ComplexSimdVector::from_slice(&[cx(5.0, 6.0), cx(7.0, 8.0)]);
        let mut c = ComplexSimdVector::new(2);

        ComplexSimdOps::complex_mul_simd(&a, &b, &mut c);
        let product = c.to_complex_vec();

        // (1+2i)*(5+6i) = 5+6i+10i-12 = -7+16i
        assert_close(product[0], cx(-7.0, 16.0));

        // (3+4i)*(7+8i) = 21+24i+28i-32 = -11+52i
        assert_close(product[1], cx(-11.0, 52.0));
    }

    #[test]
    fn test_hadamard_gate_complex_simd() {
        let in_amps0 = vec![cx(1.0, 0.0), cx(0.0, 0.0)];
        let in_amps1 = vec![cx(0.0, 0.0), cx(1.0, 0.0)];
        let mut out_amps0 = vec![cx(0.0, 0.0); 2];
        let mut out_amps1 = vec![cx(0.0, 0.0); 2];

        apply_hadamard_gate_complex_simd(&in_amps0, &in_amps1, &mut out_amps0, &mut out_amps1);

        let amplitude = cx(FRAC_1_SQRT_2, 0.0);

        // H|0⟩ = (|0⟩ + |1⟩)/√2
        assert_close(out_amps0[0], amplitude);
        assert_close(out_amps1[0], amplitude);

        // H|1⟩ = (|0⟩ - |1⟩)/√2
        assert_close(out_amps0[1], amplitude);
        assert_close(out_amps1[1], -amplitude);
    }

    #[test]
    fn test_single_qubit_gate_complex_simd() {
        // Pauli-X in row-major order
        let x_matrix = [cx(0.0, 0.0), cx(1.0, 0.0), cx(1.0, 0.0), cx(0.0, 0.0)];

        let in_amps0 = vec![cx(1.0, 0.0), cx(0.5, 0.0)];
        let in_amps1 = vec![cx(0.0, 0.0), cx(0.5, 0.0)];
        let mut out_amps0 = vec![cx(0.0, 0.0); 2];
        let mut out_amps1 = vec![cx(0.0, 0.0); 2];

        apply_single_qubit_gate_complex_simd(
            &x_matrix,
            &in_amps0,
            &in_amps1,
            &mut out_amps0,
            &mut out_amps1,
        );

        // X exchanges the two amplitudes of every pair
        assert_close(out_amps0[0], cx(0.0, 0.0));
        assert_close(out_amps1[0], cx(1.0, 0.0));
        assert_close(out_amps0[1], cx(0.5, 0.0));
        assert_close(out_amps1[1], cx(0.5, 0.0));
    }

    #[test]
    fn test_complex_norm_squared_simd() {
        let data = vec![
            cx(3.0, 4.0), // |3+4i|² = 9+16 = 25
            cx(1.0, 1.0), // |1+i|² = 1+1 = 2
        ];

        let simd_vec = ComplexSimdVector::from_slice(&data);
        let norms = ComplexSimdOps::complex_norm_squared_simd(&simd_vec);

        assert!((norms[0] - 25.0).abs() < TOLERANCE);
        assert!((norms[1] - 2.0).abs() < TOLERANCE);
    }
}