1use scirs2_core::ndarray::{Array1, ArrayView1};
8use scirs2_core::Complex64;
9
/// Unroll factor for the slice kernels in this module: the number of
/// elements handled per main-loop iteration, which is also the number of
/// independent partial accumulators used by the reductions.
/// Before changing this value, check that no kernel hard-codes the lane count.
const UNROLL: usize = 4;
12
/// Arithmetic and reduction helpers for `f64` scalars, slices, and 1-D
/// ndarray views.
///
/// The behavior notes below describe the `f64` implementation in this module.
/// `simd_dot` and `simd_sum` accumulate into several partial sums, so their
/// rounding can differ slightly from a strictly left-to-right sum.
pub trait SimdF64 {
    /// Returns `self + other`.
    fn simd_add(self, other: f64) -> f64;
    /// Returns `self - other`.
    fn simd_sub(self, other: f64) -> f64;
    /// Returns `self * other`.
    fn simd_mul(self, other: f64) -> f64;
    /// Returns a new array holding every element of `view` multiplied by `scalar`.
    /// Works for both contiguous and strided views.
    fn simd_scalar_mul(view: &ArrayView1<f64>, scalar: f64) -> Array1<f64>;
    /// Element-wise `a + b`. Panics if the lengths differ.
    fn simd_add_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64>;
    /// Element-wise `a - b`. Panics if the lengths differ.
    fn simd_sub_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64>;
    /// Element-wise `a * b`. Panics if the lengths differ.
    fn simd_mul_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64>;
    /// Dot product `sum(a[i] * b[i])`; `0.0` for empty input.
    /// Panics if the lengths differ.
    fn simd_dot(a: &[f64], b: &[f64]) -> f64;
    /// Sum of all elements; `0.0` for an empty slice.
    fn simd_sum(slice: &[f64]) -> f64;
    /// Sum of all elements of a 1-D view (contiguous or strided).
    fn simd_sum_array(a: &ArrayView1<f64>) -> f64;
    /// Largest element, or `f64::NEG_INFINITY` for an empty slice.
    /// NaN entries are skipped (via `f64::max`), so an all-NaN slice also
    /// yields `NEG_INFINITY`.
    fn simd_max(a: &[f64]) -> f64;
    /// Smallest element, or `f64::INFINITY` for an empty slice.
    /// NaN entries are skipped (via `f64::min`), so an all-NaN slice also
    /// yields `INFINITY`.
    fn simd_min(a: &[f64]) -> f64;
    /// Returns `a[i] * b[i] + c[i]` for every `i`. Panics if the three
    /// lengths differ. The multiply and add are separate operations (two
    /// roundings), not a fused `f64::mul_add`.
    fn simd_fmadd(a: &[f64], b: &[f64], c: &[f64]) -> Vec<f64>;
}
29
30impl SimdF64 for f64 {
31 #[inline(always)]
32 fn simd_add(self, other: f64) -> f64 {
33 self + other
34 }
35
36 #[inline(always)]
37 fn simd_sub(self, other: f64) -> f64 {
38 self - other
39 }
40
41 #[inline(always)]
42 fn simd_mul(self, other: f64) -> f64 {
43 self * other
44 }
45
46 #[inline]
48 fn simd_scalar_mul(view: &ArrayView1<f64>, scalar: f64) -> Array1<f64> {
49 let n = view.len();
50 let slice = view.as_slice().unwrap_or(&[]);
51
52 if !slice.is_empty() {
54 let mut out = vec![0.0f64; n];
55 let chunks = n / UNROLL;
56 let rem = n % UNROLL;
57 let base = chunks * UNROLL;
58
59 for i in 0..chunks {
60 let j = i * UNROLL;
61 out[j] = slice[j] * scalar;
62 out[j + 1] = slice[j + 1] * scalar;
63 out[j + 2] = slice[j + 2] * scalar;
64 out[j + 3] = slice[j + 3] * scalar;
65 }
66 for k in 0..rem {
67 out[base + k] = slice[base + k] * scalar;
68 }
69 return Array1::from(out);
70 }
71
72 view.mapv(|x| x * scalar)
74 }
75
76 #[inline]
78 fn simd_add_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
79 assert_eq!(a.len(), b.len(), "simd_add_arrays: length mismatch");
80 let n = a.len();
81
82 match (a.as_slice(), b.as_slice()) {
83 (Some(sa), Some(sb)) => {
84 let mut out = vec![0.0f64; n];
85 let chunks = n / UNROLL;
86 let rem = n % UNROLL;
87 let base = chunks * UNROLL;
88
89 for i in 0..chunks {
90 let j = i * UNROLL;
91 out[j] = sa[j] + sb[j];
92 out[j + 1] = sa[j + 1] + sb[j + 1];
93 out[j + 2] = sa[j + 2] + sb[j + 2];
94 out[j + 3] = sa[j + 3] + sb[j + 3];
95 }
96 for k in 0..rem {
97 out[base + k] = sa[base + k] + sb[base + k];
98 }
99 Array1::from(out)
100 }
101 _ => a + b,
102 }
103 }
104
105 #[inline]
107 fn simd_sub_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
108 assert_eq!(a.len(), b.len(), "simd_sub_arrays: length mismatch");
109 let n = a.len();
110
111 match (a.as_slice(), b.as_slice()) {
112 (Some(sa), Some(sb)) => {
113 let mut out = vec![0.0f64; n];
114 let chunks = n / UNROLL;
115 let rem = n % UNROLL;
116 let base = chunks * UNROLL;
117
118 for i in 0..chunks {
119 let j = i * UNROLL;
120 out[j] = sa[j] - sb[j];
121 out[j + 1] = sa[j + 1] - sb[j + 1];
122 out[j + 2] = sa[j + 2] - sb[j + 2];
123 out[j + 3] = sa[j + 3] - sb[j + 3];
124 }
125 for k in 0..rem {
126 out[base + k] = sa[base + k] - sb[base + k];
127 }
128 Array1::from(out)
129 }
130 _ => a - b,
131 }
132 }
133
134 #[inline]
136 fn simd_mul_arrays(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array1<f64> {
137 assert_eq!(a.len(), b.len(), "simd_mul_arrays: length mismatch");
138 let n = a.len();
139
140 match (a.as_slice(), b.as_slice()) {
141 (Some(sa), Some(sb)) => {
142 let mut out = vec![0.0f64; n];
143 let chunks = n / UNROLL;
144 let rem = n % UNROLL;
145 let base = chunks * UNROLL;
146
147 for i in 0..chunks {
148 let j = i * UNROLL;
149 out[j] = sa[j] * sb[j];
150 out[j + 1] = sa[j + 1] * sb[j + 1];
151 out[j + 2] = sa[j + 2] * sb[j + 2];
152 out[j + 3] = sa[j + 3] * sb[j + 3];
153 }
154 for k in 0..rem {
155 out[base + k] = sa[base + k] * sb[base + k];
156 }
157 Array1::from(out)
158 }
159 _ => a * b,
160 }
161 }
162
163 #[inline]
165 fn simd_dot(a: &[f64], b: &[f64]) -> f64 {
166 assert_eq!(a.len(), b.len(), "simd_dot: length mismatch");
167 let n = a.len();
168 let chunks = n / UNROLL;
169 let rem = n % UNROLL;
170 let base = chunks * UNROLL;
171
172 let mut acc0 = 0.0f64;
174 let mut acc1 = 0.0f64;
175 let mut acc2 = 0.0f64;
176 let mut acc3 = 0.0f64;
177
178 for i in 0..chunks {
179 let j = i * UNROLL;
180 acc0 += a[j] * b[j];
181 acc1 += a[j + 1] * b[j + 1];
182 acc2 += a[j + 2] * b[j + 2];
183 acc3 += a[j + 3] * b[j + 3];
184 }
185
186 let mut tail = acc0 + acc1 + acc2 + acc3;
187 for k in 0..rem {
188 tail += a[base + k] * b[base + k];
189 }
190 tail
191 }
192
193 #[inline]
195 fn simd_sum(slice: &[f64]) -> f64 {
196 let n = slice.len();
197 let chunks = n / UNROLL;
198 let rem = n % UNROLL;
199 let base = chunks * UNROLL;
200
201 let mut acc0 = 0.0f64;
202 let mut acc1 = 0.0f64;
203 let mut acc2 = 0.0f64;
204 let mut acc3 = 0.0f64;
205
206 for i in 0..chunks {
207 let j = i * UNROLL;
208 acc0 += slice[j];
209 acc1 += slice[j + 1];
210 acc2 += slice[j + 2];
211 acc3 += slice[j + 3];
212 }
213
214 let mut total = acc0 + acc1 + acc2 + acc3;
215 for k in 0..rem {
216 total += slice[base + k];
217 }
218 total
219 }
220
221 #[inline]
222 fn simd_sum_array(a: &ArrayView1<f64>) -> f64 {
223 match a.as_slice() {
224 Some(s) => <f64 as SimdF64>::simd_sum(s),
225 None => a.sum(),
226 }
227 }
228
229 #[inline]
231 fn simd_max(a: &[f64]) -> f64 {
232 if a.is_empty() {
233 return f64::NEG_INFINITY;
234 }
235 let n = a.len();
236 let chunks = n / UNROLL;
237 let rem = n % UNROLL;
238 let base = chunks * UNROLL;
239
240 let mut m0 = f64::NEG_INFINITY;
241 let mut m1 = f64::NEG_INFINITY;
242 let mut m2 = f64::NEG_INFINITY;
243 let mut m3 = f64::NEG_INFINITY;
244
245 for i in 0..chunks {
246 let j = i * UNROLL;
247 m0 = m0.max(a[j]);
248 m1 = m1.max(a[j + 1]);
249 m2 = m2.max(a[j + 2]);
250 m3 = m3.max(a[j + 3]);
251 }
252
253 let mut max = m0.max(m1).max(m2).max(m3);
254 for k in 0..rem {
255 max = max.max(a[base + k]);
256 }
257 max
258 }
259
260 #[inline]
262 fn simd_min(a: &[f64]) -> f64 {
263 if a.is_empty() {
264 return f64::INFINITY;
265 }
266 let n = a.len();
267 let chunks = n / UNROLL;
268 let rem = n % UNROLL;
269 let base = chunks * UNROLL;
270
271 let mut m0 = f64::INFINITY;
272 let mut m1 = f64::INFINITY;
273 let mut m2 = f64::INFINITY;
274 let mut m3 = f64::INFINITY;
275
276 for i in 0..chunks {
277 let j = i * UNROLL;
278 m0 = m0.min(a[j]);
279 m1 = m1.min(a[j + 1]);
280 m2 = m2.min(a[j + 2]);
281 m3 = m3.min(a[j + 3]);
282 }
283
284 let mut min = m0.min(m1).min(m2).min(m3);
285 for k in 0..rem {
286 min = min.min(a[base + k]);
287 }
288 min
289 }
290
291 #[inline]
293 fn simd_fmadd(a: &[f64], b: &[f64], c: &[f64]) -> Vec<f64> {
294 let n = a.len();
295 assert_eq!(n, b.len(), "simd_fmadd: a/b length mismatch");
296 assert_eq!(n, c.len(), "simd_fmadd: a/c length mismatch");
297
298 let mut out = vec![0.0f64; n];
299 let chunks = n / UNROLL;
300 let rem = n % UNROLL;
301 let base = chunks * UNROLL;
302
303 for i in 0..chunks {
304 let j = i * UNROLL;
305 out[j] = a[j] * b[j] + c[j];
306 out[j + 1] = a[j + 1] * b[j + 1] + c[j + 1];
307 out[j + 2] = a[j + 2] * b[j + 2] + c[j + 2];
308 out[j + 3] = a[j + 3] * b[j + 3] + c[j + 3];
309 }
310 for k in 0..rem {
311 out[base + k] = a[base + k] * b[base + k] + c[base + k];
312 }
313 out
314 }
315}
316
/// Arithmetic and reduction helpers for `Complex64` values and slices.
///
/// The behavior notes below describe the `Complex64` implementation in this
/// module.
pub trait SimdComplex64 {
    /// Returns `self + other`.
    fn simd_add(self, other: Complex64) -> Complex64;
    /// Returns `self - other`.
    fn simd_sub(self, other: Complex64) -> Complex64;
    /// Returns `self * other` (complex multiplication).
    fn simd_mul(self, other: Complex64) -> Complex64;
    /// Returns `self * scalar`. Unlike `SimdF64::simd_scalar_mul`, this acts
    /// on a single value, not an array.
    fn simd_scalar_mul(self, scalar: Complex64) -> Complex64;
    /// Unconjugated dot product `sum(a[i] * b[i])`: neither operand is
    /// complex-conjugated. Returns zero for empty input; panics if the
    /// lengths differ.
    fn simd_dot(a: &[Complex64], b: &[Complex64]) -> Complex64;
    /// Sum of all elements; zero for an empty slice.
    fn simd_sum(slice: &[Complex64]) -> Complex64;
    /// Sum of all elements of a 1-D view (contiguous or strided).
    fn simd_sum_array(a: &ArrayView1<Complex64>) -> Complex64;
}
327
328impl SimdComplex64 for Complex64 {
329 #[inline(always)]
330 fn simd_add(self, other: Complex64) -> Complex64 {
331 self + other
332 }
333
334 #[inline(always)]
335 fn simd_sub(self, other: Complex64) -> Complex64 {
336 self - other
337 }
338
339 #[inline(always)]
340 fn simd_mul(self, other: Complex64) -> Complex64 {
341 self * other
342 }
343
344 #[inline(always)]
345 fn simd_scalar_mul(self, scalar: Complex64) -> Complex64 {
346 self * scalar
347 }
348
349 #[inline]
351 fn simd_dot(a: &[Complex64], b: &[Complex64]) -> Complex64 {
352 assert_eq!(a.len(), b.len(), "simd_dot complex: length mismatch");
353 let n = a.len();
354 let chunks = n / UNROLL;
355 let rem = n % UNROLL;
356 let base = chunks * UNROLL;
357
358 let zero = Complex64::new(0.0, 0.0);
359 let mut acc0 = zero;
360 let mut acc1 = zero;
361 let mut acc2 = zero;
362 let mut acc3 = zero;
363
364 for i in 0..chunks {
365 let j = i * UNROLL;
366 acc0 += a[j] * b[j];
367 acc1 += a[j + 1] * b[j + 1];
368 acc2 += a[j + 2] * b[j + 2];
369 acc3 += a[j + 3] * b[j + 3];
370 }
371
372 let mut total = acc0 + acc1 + acc2 + acc3;
373 for k in 0..rem {
374 total += a[base + k] * b[base + k];
375 }
376 total
377 }
378
379 #[inline]
381 fn simd_sum(slice: &[Complex64]) -> Complex64 {
382 let n = slice.len();
383 let chunks = n / UNROLL;
384 let rem = n % UNROLL;
385 let base = chunks * UNROLL;
386
387 let zero = Complex64::new(0.0, 0.0);
388 let mut acc0 = zero;
389 let mut acc1 = zero;
390 let mut acc2 = zero;
391 let mut acc3 = zero;
392
393 for i in 0..chunks {
394 let j = i * UNROLL;
395 acc0 += slice[j];
396 acc1 += slice[j + 1];
397 acc2 += slice[j + 2];
398 acc3 += slice[j + 3];
399 }
400
401 let mut total = acc0 + acc1 + acc2 + acc3;
402 for k in 0..rem {
403 total += slice[base + k];
404 }
405 total
406 }
407
408 #[inline]
409 fn simd_sum_array(a: &ArrayView1<Complex64>) -> Complex64 {
410 match a.as_slice() {
411 Some(s) => <Complex64 as SimdComplex64>::simd_sum(s),
412 None => a.sum(),
413 }
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Straightforward reference dot product for cross-checking `simd_dot`.
    fn naive_dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    #[test]
    fn test_simd_dot_basic() {
        let a = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let b = [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let result = <f64 as SimdF64>::simd_dot(&a, &b);
        let expected = naive_dot(&a, &b);
        assert!(
            (result - expected).abs() < 1e-12,
            "simd_dot mismatch: {result} vs {expected}"
        );
    }

    #[test]
    fn test_simd_dot_and_sum_all_remainders() {
        // Lengths 0..=9 cover empty input, tail-only input (< UNROLL), exact
        // multiples of the unroll width, and mixed chunk + tail. Small
        // integer values keep every partial sum exact.
        for n in 0..=9usize {
            let a: Vec<f64> = (0..n).map(|i| i as f64 + 1.0).collect();
            let b: Vec<f64> = (0..n).map(|i| 2.0 * i as f64 - 3.0).collect();
            let dot = <f64 as SimdF64>::simd_dot(&a, &b);
            assert!((dot - naive_dot(&a, &b)).abs() < 1e-12, "dot n={n}");
            let sum = <f64 as SimdF64>::simd_sum(&a);
            let expected: f64 = a.iter().sum();
            assert!((sum - expected).abs() < 1e-12, "sum n={n}");
        }
    }

    #[test]
    #[should_panic(expected = "simd_dot: length mismatch")]
    fn test_simd_dot_length_mismatch_panics() {
        let _ = <f64 as SimdF64>::simd_dot(&[1.0, 2.0], &[1.0]);
    }

    #[test]
    fn test_simd_sum_unrolled() {
        let data: Vec<f64> = (0..17).map(|i| i as f64).collect();
        let result = <f64 as SimdF64>::simd_sum(&data);
        let expected: f64 = data.iter().sum();
        assert!((result - expected).abs() < 1e-12);
    }

    #[test]
    fn test_simd_fmadd() {
        let a = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0f64, 2.0, 2.0, 2.0, 2.0];
        let c = vec![0.5f64, 0.5, 0.5, 0.5, 0.5];
        let result = <f64 as SimdF64>::simd_fmadd(&a, &b, &c);
        let expected: Vec<f64> = a
            .iter()
            .zip(b.iter())
            .zip(c.iter())
            .map(|((ai, bi), ci)| ai * bi + ci)
            .collect();
        // Check lengths first: zip() alone would silently hide a short result.
        assert_eq!(result.len(), expected.len());
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-12);
        }
    }

    #[test]
    fn test_simd_add_arrays_unrolled() {
        let a = array![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let b = array![9.0f64, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let result = <f64 as SimdF64>::simd_add_arrays(&a.view(), &b.view());
        assert_eq!(result.len(), 9);
        for v in result.iter() {
            assert!((v - 10.0).abs() < 1e-12);
        }
    }

    #[test]
    fn test_strided_views_use_fallback() {
        // Columns of a row-major 2-D array are strided, so `as_slice()` is
        // `None` and the ndarray fallback paths are exercised.
        let m = array![[1.0f64, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let c0 = m.column(0);
        let c1 = m.column(1);
        assert!(c0.as_slice().is_none());

        let added = <f64 as SimdF64>::simd_add_arrays(&c0, &c1);
        assert_eq!(added.to_vec(), vec![11.0, 22.0, 33.0]);
        let diff = <f64 as SimdF64>::simd_sub_arrays(&c1, &c0);
        assert_eq!(diff.to_vec(), vec![9.0, 18.0, 27.0]);
        let prod = <f64 as SimdF64>::simd_mul_arrays(&c0, &c1);
        assert_eq!(prod.to_vec(), vec![10.0, 40.0, 90.0]);
        let scaled = <f64 as SimdF64>::simd_scalar_mul(&c0, 2.0);
        assert_eq!(scaled.to_vec(), vec![2.0, 4.0, 6.0]);
        assert_eq!(<f64 as SimdF64>::simd_sum_array(&c1), 60.0);
    }

    #[test]
    fn test_simd_max_min() {
        let data = vec![3.0f64, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0];
        assert!(((<f64 as SimdF64>::simd_max(&data)) - 9.0).abs() < 1e-12);
        assert!(((<f64 as SimdF64>::simd_min(&data)) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_simd_max_min_edges() {
        // Empty input returns the documented sentinels.
        assert_eq!(<f64 as SimdF64>::simd_max(&[]), f64::NEG_INFINITY);
        assert_eq!(<f64 as SimdF64>::simd_min(&[]), f64::INFINITY);
        // Extremum located in the tail (past the last full chunk).
        let tail_max = [1.0f64, 2.0, 3.0, 4.0, 10.0];
        assert_eq!(<f64 as SimdF64>::simd_max(&tail_max), 10.0);
        assert_eq!(<f64 as SimdF64>::simd_min(&tail_max), 1.0);
        // Shorter than one chunk, all negative.
        let short = [-5.0f64, -1.0, -3.0];
        assert_eq!(<f64 as SimdF64>::simd_max(&short), -1.0);
        assert_eq!(<f64 as SimdF64>::simd_min(&short), -5.0);
    }

    #[test]
    fn test_complex_simd_dot() {
        let a = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 1.0),
            Complex64::new(1.0, 1.0),
            Complex64::new(2.0, -1.0),
            Complex64::new(0.5, 0.5),
        ];
        let b = a.clone();
        let result = <Complex64 as SimdComplex64>::simd_dot(&a, &b);
        let expected: Complex64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        assert!((result.re - expected.re).abs() < 1e-12);
        assert!((result.im - expected.im).abs() < 1e-12);
    }

    #[test]
    fn test_complex_simd_sum_and_strided_sum_array() {
        // Integer-valued components keep every sum exact.
        for n in 0..=6usize {
            let v: Vec<Complex64> = (0..n)
                .map(|i| Complex64::new(i as f64, -(i as f64)))
                .collect();
            let expected: Complex64 = v.iter().copied().sum();
            assert_eq!(<Complex64 as SimdComplex64>::simd_sum(&v), expected);
        }

        let m = array![
            [Complex64::new(1.0, 2.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(3.0, -1.0), Complex64::new(0.0, 0.0)]
        ];
        let col = m.column(0);
        assert!(col.as_slice().is_none());
        assert_eq!(
            <Complex64 as SimdComplex64>::simd_sum_array(&col),
            Complex64::new(4.0, 1.0)
        );
    }
}