Skip to main content

quantrs2_core/
simd_ops_stubs.rs

1//! SIMD-like batch operations stubs replacing scirs2_core::simd_ops
2//!
3//! Implements explicit loop-unrolled vectorized operations on f64 and Complex64
4//! arrays. Uses 4-wide manual unrolling to exploit CPU instruction-level
5//! parallelism without requiring nightly std::simd.
6
7use scirs2_core::ndarray::{Array1, ArrayView1};
8use scirs2_core::Complex64;
9
/// Number of lanes used by the hand-unrolled loops in this module.
///
/// Four `f64` values occupy 256 bits, i.e. one register on common 256-bit
/// SIMD instruction sets.
const UNROLL: usize = 4;
12
13/// Trait for SIMD-like batch operations on f64
14pub trait SimdF64 {
15    fn simd_add(self, other: f64) -> f64;
16    fn simd_sub(self, other: f64) -> f64;
17    fn simd_mul(self, other: f64) -> f64;
18    fn simd_scalar_mul(view: &ArrayView1<f64>, scalar: f64) -> Array1<f64>;
19    fn simd_add_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64>;
20    fn simd_sub_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64>;
21    fn simd_mul_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64>;
22    fn simd_dot(a: &[f64], b: &[f64]) -> f64;
23    fn simd_sum(slice: &[f64]) -> f64;
24    fn simd_sum_array(a: &ArrayView1<f64>) -> f64;
25    fn simd_max(a: &[f64]) -> f64;
26    fn simd_min(a: &[f64]) -> f64;
27    fn simd_fmadd(a: &[f64], b: &[f64], c: &[f64]) -> Vec<f64>;
28}
29
30impl SimdF64 for f64 {
31    #[inline(always)]
32    fn simd_add(self, other: f64) -> f64 {
33        self + other
34    }
35
36    #[inline(always)]
37    fn simd_sub(self, other: f64) -> f64 {
38        self - other
39    }
40
41    #[inline(always)]
42    fn simd_mul(self, other: f64) -> f64 {
43        self * other
44    }
45
46    /// Scalar-multiply every element, 4-wide unrolled
47    #[inline]
48    fn simd_scalar_mul(view: &ArrayView1<f64>, scalar: f64) -> Array1<f64> {
49        let n = view.len();
50        let slice = view.as_slice().unwrap_or(&[]);
51
52        // Fast path: contiguous memory — unrolled loop
53        if !slice.is_empty() {
54            let mut out = vec![0.0f64; n];
55            let chunks = n / UNROLL;
56            let rem = n % UNROLL;
57            let base = chunks * UNROLL;
58
59            for i in 0..chunks {
60                let j = i * UNROLL;
61                out[j] = slice[j] * scalar;
62                out[j + 1] = slice[j + 1] * scalar;
63                out[j + 2] = slice[j + 2] * scalar;
64                out[j + 3] = slice[j + 3] * scalar;
65            }
66            for k in 0..rem {
67                out[base + k] = slice[base + k] * scalar;
68            }
69            return Array1::from(out);
70        }
71
72        // Non-contiguous fallback
73        view.mapv(|x| x * scalar)
74    }
75
76    /// Element-wise addition, 4-wide unrolled
77    #[inline]
78    fn simd_add_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
79        assert_eq!(a.len(), b.len(), "simd_add_arrays: length mismatch");
80        let n = a.len();
81
82        match (a.as_slice(), b.as_slice()) {
83            (Some(sa), Some(sb)) => {
84                let mut out = vec![0.0f64; n];
85                let chunks = n / UNROLL;
86                let rem = n % UNROLL;
87                let base = chunks * UNROLL;
88
89                for i in 0..chunks {
90                    let j = i * UNROLL;
91                    out[j] = sa[j] + sb[j];
92                    out[j + 1] = sa[j + 1] + sb[j + 1];
93                    out[j + 2] = sa[j + 2] + sb[j + 2];
94                    out[j + 3] = sa[j + 3] + sb[j + 3];
95                }
96                for k in 0..rem {
97                    out[base + k] = sa[base + k] + sb[base + k];
98                }
99                Array1::from(out)
100            }
101            _ => a + b,
102        }
103    }
104
105    /// Element-wise subtraction, 4-wide unrolled
106    #[inline]
107    fn simd_sub_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
108        assert_eq!(a.len(), b.len(), "simd_sub_arrays: length mismatch");
109        let n = a.len();
110
111        match (a.as_slice(), b.as_slice()) {
112            (Some(sa), Some(sb)) => {
113                let mut out = vec![0.0f64; n];
114                let chunks = n / UNROLL;
115                let rem = n % UNROLL;
116                let base = chunks * UNROLL;
117
118                for i in 0..chunks {
119                    let j = i * UNROLL;
120                    out[j] = sa[j] - sb[j];
121                    out[j + 1] = sa[j + 1] - sb[j + 1];
122                    out[j + 2] = sa[j + 2] - sb[j + 2];
123                    out[j + 3] = sa[j + 3] - sb[j + 3];
124                }
125                for k in 0..rem {
126                    out[base + k] = sa[base + k] - sb[base + k];
127                }
128                Array1::from(out)
129            }
130            _ => a - b,
131        }
132    }
133
134    /// Element-wise multiplication, 4-wide unrolled
135    #[inline]
136    fn simd_mul_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
137        assert_eq!(a.len(), b.len(), "simd_mul_arrays: length mismatch");
138        let n = a.len();
139
140        match (a.as_slice(), b.as_slice()) {
141            (Some(sa), Some(sb)) => {
142                let mut out = vec![0.0f64; n];
143                let chunks = n / UNROLL;
144                let rem = n % UNROLL;
145                let base = chunks * UNROLL;
146
147                for i in 0..chunks {
148                    let j = i * UNROLL;
149                    out[j] = sa[j] * sb[j];
150                    out[j + 1] = sa[j + 1] * sb[j + 1];
151                    out[j + 2] = sa[j + 2] * sb[j + 2];
152                    out[j + 3] = sa[j + 3] * sb[j + 3];
153                }
154                for k in 0..rem {
155                    out[base + k] = sa[base + k] * sb[base + k];
156                }
157                Array1::from(out)
158            }
159            _ => a * b,
160        }
161    }
162
163    /// Dot product with 4-wide accumulator unrolling to reduce dependency chains
164    #[inline]
165    fn simd_dot(a: &[f64], b: &[f64]) -> f64 {
166        assert_eq!(a.len(), b.len(), "simd_dot: length mismatch");
167        let n = a.len();
168        let chunks = n / UNROLL;
169        let rem = n % UNROLL;
170        let base = chunks * UNROLL;
171
172        // Four independent accumulators — breaks scalar dependency chain
173        let mut acc0 = 0.0f64;
174        let mut acc1 = 0.0f64;
175        let mut acc2 = 0.0f64;
176        let mut acc3 = 0.0f64;
177
178        for i in 0..chunks {
179            let j = i * UNROLL;
180            acc0 += a[j] * b[j];
181            acc1 += a[j + 1] * b[j + 1];
182            acc2 += a[j + 2] * b[j + 2];
183            acc3 += a[j + 3] * b[j + 3];
184        }
185
186        let mut tail = acc0 + acc1 + acc2 + acc3;
187        for k in 0..rem {
188            tail += a[base + k] * b[base + k];
189        }
190        tail
191    }
192
193    /// Horizontal sum with 4-wide accumulator unrolling
194    #[inline]
195    fn simd_sum(slice: &[f64]) -> f64 {
196        let n = slice.len();
197        let chunks = n / UNROLL;
198        let rem = n % UNROLL;
199        let base = chunks * UNROLL;
200
201        let mut acc0 = 0.0f64;
202        let mut acc1 = 0.0f64;
203        let mut acc2 = 0.0f64;
204        let mut acc3 = 0.0f64;
205
206        for i in 0..chunks {
207            let j = i * UNROLL;
208            acc0 += slice[j];
209            acc1 += slice[j + 1];
210            acc2 += slice[j + 2];
211            acc3 += slice[j + 3];
212        }
213
214        let mut total = acc0 + acc1 + acc2 + acc3;
215        for k in 0..rem {
216            total += slice[base + k];
217        }
218        total
219    }
220
221    #[inline]
222    fn simd_sum_array(a: &ArrayView1<f64>) -> f64 {
223        match a.as_slice() {
224            Some(s) => <f64 as SimdF64>::simd_sum(s),
225            None => a.sum(),
226        }
227    }
228
229    /// Maximum value with 4-wide unrolled comparison
230    #[inline]
231    fn simd_max(a: &[f64]) -> f64 {
232        if a.is_empty() {
233            return f64::NEG_INFINITY;
234        }
235        let n = a.len();
236        let chunks = n / UNROLL;
237        let rem = n % UNROLL;
238        let base = chunks * UNROLL;
239
240        let mut m0 = f64::NEG_INFINITY;
241        let mut m1 = f64::NEG_INFINITY;
242        let mut m2 = f64::NEG_INFINITY;
243        let mut m3 = f64::NEG_INFINITY;
244
245        for i in 0..chunks {
246            let j = i * UNROLL;
247            m0 = m0.max(a[j]);
248            m1 = m1.max(a[j + 1]);
249            m2 = m2.max(a[j + 2]);
250            m3 = m3.max(a[j + 3]);
251        }
252
253        let mut max = m0.max(m1).max(m2).max(m3);
254        for k in 0..rem {
255            max = max.max(a[base + k]);
256        }
257        max
258    }
259
260    /// Minimum value with 4-wide unrolled comparison
261    #[inline]
262    fn simd_min(a: &[f64]) -> f64 {
263        if a.is_empty() {
264            return f64::INFINITY;
265        }
266        let n = a.len();
267        let chunks = n / UNROLL;
268        let rem = n % UNROLL;
269        let base = chunks * UNROLL;
270
271        let mut m0 = f64::INFINITY;
272        let mut m1 = f64::INFINITY;
273        let mut m2 = f64::INFINITY;
274        let mut m3 = f64::INFINITY;
275
276        for i in 0..chunks {
277            let j = i * UNROLL;
278            m0 = m0.min(a[j]);
279            m1 = m1.min(a[j + 1]);
280            m2 = m2.min(a[j + 2]);
281            m3 = m3.min(a[j + 3]);
282        }
283
284        let mut min = m0.min(m1).min(m2).min(m3);
285        for k in 0..rem {
286            min = min.min(a[base + k]);
287        }
288        min
289    }
290
291    /// Fused multiply-add: out\[i\] = a\[i\]\*b\[i\] + c\[i\], 4-wide unrolled
292    #[inline]
293    fn simd_fmadd(a: &[f64], b: &[f64], c: &[f64]) -> Vec<f64> {
294        let n = a.len();
295        assert_eq!(n, b.len(), "simd_fmadd: a/b length mismatch");
296        assert_eq!(n, c.len(), "simd_fmadd: a/c length mismatch");
297
298        let mut out = vec![0.0f64; n];
299        let chunks = n / UNROLL;
300        let rem = n % UNROLL;
301        let base = chunks * UNROLL;
302
303        for i in 0..chunks {
304            let j = i * UNROLL;
305            out[j] = a[j] * b[j] + c[j];
306            out[j + 1] = a[j + 1] * b[j + 1] + c[j + 1];
307            out[j + 2] = a[j + 2] * b[j + 2] + c[j + 2];
308            out[j + 3] = a[j + 3] * b[j + 3] + c[j + 3];
309        }
310        for k in 0..rem {
311            out[base + k] = a[base + k] * b[base + k] + c[base + k];
312        }
313        out
314    }
315}
316
317/// Trait for SIMD-like batch operations on Complex64
318pub trait SimdComplex64 {
319    fn simd_add(self, other: Complex64) -> Complex64;
320    fn simd_sub(self, other: Complex64) -> Complex64;
321    fn simd_mul(self, other: Complex64) -> Complex64;
322    fn simd_scalar_mul(self, scalar: Complex64) -> Complex64;
323    fn simd_dot(a: &[Complex64], b: &[Complex64]) -> Complex64;
324    fn simd_sum(slice: &[Complex64]) -> Complex64;
325    fn simd_sum_array(a: &ArrayView1<Complex64>) -> Complex64;
326}
327
328impl SimdComplex64 for Complex64 {
329    #[inline(always)]
330    fn simd_add(self, other: Complex64) -> Complex64 {
331        self + other
332    }
333
334    #[inline(always)]
335    fn simd_sub(self, other: Complex64) -> Complex64 {
336        self - other
337    }
338
339    #[inline(always)]
340    fn simd_mul(self, other: Complex64) -> Complex64 {
341        self * other
342    }
343
344    #[inline(always)]
345    fn simd_scalar_mul(self, scalar: Complex64) -> Complex64 {
346        self * scalar
347    }
348
349    /// Complex dot product with 4-wide unrolled accumulation
350    #[inline]
351    fn simd_dot(a: &[Complex64], b: &[Complex64]) -> Complex64 {
352        assert_eq!(a.len(), b.len(), "simd_dot complex: length mismatch");
353        let n = a.len();
354        let chunks = n / UNROLL;
355        let rem = n % UNROLL;
356        let base = chunks * UNROLL;
357
358        let zero = Complex64::new(0.0, 0.0);
359        let mut acc0 = zero;
360        let mut acc1 = zero;
361        let mut acc2 = zero;
362        let mut acc3 = zero;
363
364        for i in 0..chunks {
365            let j = i * UNROLL;
366            acc0 += a[j] * b[j];
367            acc1 += a[j + 1] * b[j + 1];
368            acc2 += a[j + 2] * b[j + 2];
369            acc3 += a[j + 3] * b[j + 3];
370        }
371
372        let mut total = acc0 + acc1 + acc2 + acc3;
373        for k in 0..rem {
374            total += a[base + k] * b[base + k];
375        }
376        total
377    }
378
379    /// Horizontal sum with 4-wide accumulator unrolling
380    #[inline]
381    fn simd_sum(slice: &[Complex64]) -> Complex64 {
382        let n = slice.len();
383        let chunks = n / UNROLL;
384        let rem = n % UNROLL;
385        let base = chunks * UNROLL;
386
387        let zero = Complex64::new(0.0, 0.0);
388        let mut acc0 = zero;
389        let mut acc1 = zero;
390        let mut acc2 = zero;
391        let mut acc3 = zero;
392
393        for i in 0..chunks {
394            let j = i * UNROLL;
395            acc0 += slice[j];
396            acc1 += slice[j + 1];
397            acc2 += slice[j + 2];
398            acc3 += slice[j + 3];
399        }
400
401        let mut total = acc0 + acc1 + acc2 + acc3;
402        for k in 0..rem {
403            total += slice[base + k];
404        }
405        total
406    }
407
408    #[inline]
409    fn simd_sum_array(a: &ArrayView1<Complex64>) -> Complex64 {
410        match a.as_slice() {
411            Some(s) => <Complex64 as SimdComplex64>::simd_sum(s),
412            None => a.sum(),
413        }
414    }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance used for floating-point comparisons.
    const TOL: f64 = 1e-12;

    // ---- f64 reductions ----

    /// Nine elements: two full lanes' worth of chunks plus a one-element tail.
    #[test]
    fn test_simd_dot_basic() {
        let a = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let b = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let result = <f64 as SimdF64>::simd_dot(&a, &b);
        let expected: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        assert!(
            (result - expected).abs() < TOL,
            "simd_dot mismatch: {result} vs {expected}"
        );
    }

    #[test]
    #[should_panic(expected = "simd_dot: length mismatch")]
    fn test_simd_dot_length_mismatch_panics() {
        let _ = <f64 as SimdF64>::simd_dot(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn test_simd_sum_unrolled() {
        let data: Vec<f64> = (0..17).map(|i| i as f64).collect();
        let result = <f64 as SimdF64>::simd_sum(&data);
        let expected: f64 = data.iter().sum();
        assert!((result - expected).abs() < TOL);
    }

    #[test]
    fn test_simd_max_min() {
        let data = vec![3.0f64, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0];
        assert!(((<f64 as SimdF64>::simd_max(&data)) - 9.0).abs() < TOL);
        assert!(((<f64 as SimdF64>::simd_min(&data)) - 1.0).abs() < TOL);
    }

    /// NaN elements are skipped both inside a full chunk and in the tail.
    #[test]
    fn test_simd_max_min_skip_nan() {
        let data = [1.0f64, f64::NAN, 3.0, 2.0, f64::NAN];
        assert_eq!(<f64 as SimdF64>::simd_max(&data), 3.0);
        assert_eq!(<f64 as SimdF64>::simd_min(&data), 1.0);
    }

    /// Empty input yields the identity element of each reduction.
    #[test]
    fn test_empty_inputs() {
        assert_eq!(<f64 as SimdF64>::simd_sum(&[]), 0.0);
        assert_eq!(<f64 as SimdF64>::simd_dot(&[], &[]), 0.0);
        assert_eq!(<f64 as SimdF64>::simd_max(&[]), f64::NEG_INFINITY);
        assert_eq!(<f64 as SimdF64>::simd_min(&[]), f64::INFINITY);
        assert!(<f64 as SimdF64>::simd_fmadd(&[], &[], &[]).is_empty());
    }

    // ---- f64 element-wise operations ----

    #[test]
    fn test_simd_fmadd() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f64, 2.0, 2.0, 2.0, 2.0];
        let c = vec![0.5f64, 0.5, 0.5, 0.5, 0.5];
        let result = <f64 as SimdF64>::simd_fmadd(&a, &b, &c);
        let expected: Vec<f64> = a
            .iter()
            .zip(b.iter())
            .zip(c.iter())
            .map(|((ai, bi), ci)| ai * bi + ci)
            .collect();
        assert_eq!(result.len(), expected.len());
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < TOL);
        }
    }

    #[test]
    fn test_simd_add_arrays_unrolled() {
        let a = array![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let b = array![9.0f64, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let result = <f64 as SimdF64>::simd_add_arrays(&a.view(), &b.view());
        assert_eq!(result.len(), a.len());
        for v in result.iter() {
            assert!((v - 10.0).abs() < TOL);
        }
    }

    #[test]
    fn test_simd_sub_and_mul_arrays() {
        let a = array![5.0f64, 6.0, 7.0, 8.0, 9.0];
        let b = array![1.0f64, 2.0, 3.0, 4.0, 5.0];

        let diff = <f64 as SimdF64>::simd_sub_arrays(&a.view(), &b.view());
        assert_eq!(diff.to_vec(), vec![4.0; 5]);

        let prod = <f64 as SimdF64>::simd_mul_arrays(&a.view(), &b.view());
        assert_eq!(prod.to_vec(), vec![5.0, 12.0, 21.0, 32.0, 45.0]);
    }

    #[test]
    fn test_simd_scalar_mul_contiguous() {
        let a = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let scaled = <f64 as SimdF64>::simd_scalar_mul(&a.view(), 3.0);
        assert_eq!(scaled.to_vec(), vec![3.0, 6.0, 9.0, 12.0, 15.0]);
    }

    /// Columns of a row-major matrix are strided views, so `as_slice()` is
    /// `None` and the fallback branches are exercised.
    #[test]
    fn test_non_contiguous_views() {
        let m = array![[1.0f64, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let left = m.column(0);
        let right = m.column(1);
        assert!(left.as_slice().is_none());

        let scaled = <f64 as SimdF64>::simd_scalar_mul(&left, 2.0);
        assert_eq!(scaled.to_vec(), vec![2.0, 4.0, 6.0]);

        let added = <f64 as SimdF64>::simd_add_arrays(&left, &right);
        assert_eq!(added.to_vec(), vec![11.0, 22.0, 33.0]);

        let total = <f64 as SimdF64>::simd_sum_array(&right);
        assert!((total - 60.0).abs() < TOL);
    }

    // ---- Complex64 reductions ----

    /// Shared five-element fixture: one full chunk plus a one-element tail.
    fn complex_fixture() -> Vec<Complex64> {
        vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 1.0),
            Complex64::new(1.0, 1.0),
            Complex64::new(2.0, -1.0),
            Complex64::new(0.5, 0.5),
        ]
    }

    #[test]
    fn test_complex_simd_dot() {
        let a = complex_fixture();
        let b = a.clone();
        let result = <Complex64 as SimdComplex64>::simd_dot(&a, &b);
        let expected: Complex64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        assert!((result.re - expected.re).abs() < TOL);
        assert!((result.im - expected.im).abs() < TOL);
    }

    #[test]
    fn test_complex_simd_sum() {
        let values = complex_fixture();
        let result = <Complex64 as SimdComplex64>::simd_sum(&values);
        let expected: Complex64 = values.iter().copied().sum();
        assert!((result.re - expected.re).abs() < TOL);
        assert!((result.im - expected.im).abs() < TOL);
    }
}