1#[cfg(feature = "no-std")]
7use alloc::{vec, vec::Vec};
8
9pub fn parallel_sum_f32_simd(arr: &[f32]) -> f32 {
12 if arr.is_empty() {
13 return 0.0;
14 }
15
16 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
17 {
18 if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
19 return unsafe { parallel_sum_avx2(arr) };
20 } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
21 return unsafe { parallel_sum_sse2(arr) };
22 }
23 }
24
25 parallel_sum_scalar(arr)
26}
27
/// Portable fallback: adds up `arr` with the standard iterator `sum`.
fn parallel_sum_scalar(arr: &[f32]) -> f32 {
    arr.iter().copied().sum::<f32>()
}
31
/// Sums `arr` four lanes at a time with SSE2, then folds the lanes and the tail.
///
/// Full 4-element chunks are accumulated lane-wise. The four partial sums are
/// then added left to right, and the remaining `len % 4` elements are added one
/// by one. Because additions are grouped per lane, rounding can differ slightly
/// from a strict front-to-back scalar sum.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn parallel_sum_sse2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut chunks = arr.chunks_exact(4);
    let mut acc = _mm_setzero_ps();
    for chunk in chunks.by_ref() {
        // SAFETY: `chunk` is exactly 4 contiguous `f32`s, so the unaligned
        // 16-byte load stays inside the memory the pointer was derived from.
        acc = _mm_add_ps(acc, _mm_loadu_ps(chunk.as_ptr()));
    }

    let mut lanes = [0.0f32; 4];
    // SAFETY: `lanes` provides 16 writable bytes for the unaligned store.
    _mm_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Fewer than 4 elements remain; add them individually.
    for &x in chunks.remainder() {
        total += x;
    }

    total
}
60
/// Sums `arr` eight lanes at a time with 256-bit vectors, then folds the lanes and the tail.
///
/// Full 8-element chunks are accumulated lane-wise. The eight partial sums are
/// then combined with the iterator `sum`, and the remaining `len % 8` elements
/// are added one by one. Because additions are grouped per lane, rounding can
/// differ slightly from a strict front-to-back scalar sum.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn parallel_sum_avx2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut chunks = arr.chunks_exact(8);
    let mut acc = _mm256_setzero_ps();
    for chunk in chunks.by_ref() {
        // SAFETY: `chunk` is exactly 8 contiguous `f32`s, so the unaligned
        // 32-byte load stays inside the memory the pointer was derived from.
        acc = _mm256_add_ps(acc, _mm256_loadu_ps(chunk.as_ptr()));
    }

    let mut lanes = [0.0f32; 8];
    // SAFETY: `lanes` provides 32 writable bytes for the unaligned store.
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes.iter().sum::<f32>();

    // Fewer than 8 elements remain; add them individually.
    for &x in chunks.remainder() {
        total += x;
    }

    total
}
89
90pub fn parallel_product_f32_simd(arr: &[f32]) -> f32 {
93 if arr.is_empty() {
94 return 1.0;
95 }
96
97 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
98 {
99 if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
100 return unsafe { parallel_product_avx2(arr) };
101 } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
102 return unsafe { parallel_product_sse2(arr) };
103 }
104 }
105
106 parallel_product_scalar(arr)
107}
108
/// Portable fallback: multiplies the elements of `arr` with the standard iterator `product`.
fn parallel_product_scalar(arr: &[f32]) -> f32 {
    arr.iter().copied().product::<f32>()
}
112
/// Multiplies `arr` four lanes at a time with SSE2, then folds the lanes and the tail.
///
/// Full 4-element chunks are multiplied lane-wise into an accumulator that
/// starts at `1.0`. The four partial products are then multiplied left to
/// right, and the remaining `len % 4` elements are multiplied in one by one.
/// Because the multiplications are grouped per lane, rounding can differ
/// slightly from a strict front-to-back scalar product.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn parallel_product_sse2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut chunks = arr.chunks_exact(4);
    let mut acc = _mm_set1_ps(1.0);
    for chunk in chunks.by_ref() {
        // SAFETY: `chunk` is exactly 4 contiguous `f32`s, so the unaligned
        // 16-byte load stays inside the memory the pointer was derived from.
        acc = _mm_mul_ps(acc, _mm_loadu_ps(chunk.as_ptr()));
    }

    let mut lanes = [0.0f32; 4];
    // SAFETY: `lanes` provides 16 writable bytes for the unaligned store.
    _mm_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes[0] * lanes[1] * lanes[2] * lanes[3];

    // Fewer than 4 elements remain; multiply them in individually.
    for &x in chunks.remainder() {
        total *= x;
    }

    total
}
141
/// Multiplies `arr` eight lanes at a time with 256-bit vectors, then folds the lanes and the tail.
///
/// Full 8-element chunks are multiplied lane-wise into an accumulator that
/// starts at `1.0`. The eight partial products are then combined with the
/// iterator `product`, and the remaining `len % 8` elements are multiplied in
/// one by one. Because the multiplications are grouped per lane, rounding can
/// differ slightly from a strict front-to-back scalar product.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn parallel_product_avx2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut chunks = arr.chunks_exact(8);
    let mut acc = _mm256_set1_ps(1.0);
    for chunk in chunks.by_ref() {
        // SAFETY: `chunk` is exactly 8 contiguous `f32`s, so the unaligned
        // 32-byte load stays inside the memory the pointer was derived from.
        acc = _mm256_mul_ps(acc, _mm256_loadu_ps(chunk.as_ptr()));
    }

    let mut lanes = [0.0f32; 8];
    // SAFETY: `lanes` provides 32 writable bytes for the unaligned store.
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes.iter().product::<f32>();

    // Fewer than 8 elements remain; multiply them in individually.
    for &x in chunks.remainder() {
        total *= x;
    }

    total
}
170
171pub fn parallel_max_f32_simd(arr: &[f32]) -> Option<f32> {
174 if arr.is_empty() {
175 return None;
176 }
177
178 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
179 {
180 if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
181 return Some(unsafe { parallel_max_avx2(arr) });
182 } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
183 return Some(unsafe { parallel_max_sse2(arr) });
184 }
185 }
186
187 parallel_max_scalar(arr)
188}
189
/// Portable fallback: folds `arr` with `f32::max`, seeded by the first element.
///
/// Returns `None` for an empty slice.
fn parallel_max_scalar(arr: &[f32]) -> Option<f32> {
    let (&first, rest) = arr.split_first()?;
    let largest = rest.iter().fold(first, |best, &x| x.max(best));
    Some(largest)
}
194
/// Finds the largest element of `arr` four lanes at a time with SSE2.
///
/// NaN elements are ignored, matching `f32::max`; the result is NaN only when
/// every element is NaN. For an empty slice the accumulator seed
/// `f32::NEG_INFINITY` is returned.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn parallel_max_sse2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut chunks = arr.chunks_exact(4);
    let mut best = _mm_set1_ps(f32::NEG_INFINITY);
    for chunk in chunks.by_ref() {
        // SAFETY: `chunk` is exactly 4 contiguous `f32`s, so the unaligned
        // 16-byte load stays inside the memory the pointer was derived from.
        let v = _mm_loadu_ps(chunk.as_ptr());
        // MAXPS yields its second operand whenever either input is NaN. Keeping
        // the accumulator second means a NaN input leaves the running maximum
        // untouched instead of overwriting it.
        best = _mm_max_ps(v, best);
    }

    let mut lanes = [0.0f32; 4];
    // SAFETY: `lanes` provides 16 writable bytes for the unaligned store.
    _mm_storeu_ps(lanes.as_mut_ptr(), best);
    let mut max_val = lanes[0].max(lanes[1]).max(lanes[2]).max(lanes[3]);

    // Fewer than 4 elements remain; `f32::max` also skips NaN here.
    for &x in chunks.remainder() {
        max_val = max_val.max(x);
    }

    // The seed survives only if no element exceeded it. Distinguish the all-NaN
    // input (scalar answer: NaN) from genuine negative infinities.
    if max_val == f32::NEG_INFINITY && !arr.is_empty() && arr.iter().all(|x| x.is_nan()) {
        return f32::NAN;
    }

    max_val
}
223
/// Finds the largest element of `arr` eight lanes at a time with 256-bit vectors.
///
/// NaN elements are ignored, matching `f32::max`; the result is NaN only when
/// every element is NaN. For an empty slice the accumulator seed
/// `f32::NEG_INFINITY` is returned.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn parallel_max_avx2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut chunks = arr.chunks_exact(8);
    let mut best = _mm256_set1_ps(f32::NEG_INFINITY);
    for chunk in chunks.by_ref() {
        // SAFETY: `chunk` is exactly 8 contiguous `f32`s, so the unaligned
        // 32-byte load stays inside the memory the pointer was derived from.
        let v = _mm256_loadu_ps(chunk.as_ptr());
        // VMAXPS yields its second operand whenever either input is NaN. Keeping
        // the accumulator second means a NaN input leaves the running maximum
        // untouched instead of overwriting it.
        best = _mm256_max_ps(v, best);
    }

    let mut lanes = [0.0f32; 8];
    // SAFETY: `lanes` provides 32 writable bytes for the unaligned store.
    _mm256_storeu_ps(lanes.as_mut_ptr(), best);
    let mut max_val = lanes[0];
    for &lane in &lanes[1..] {
        max_val = max_val.max(lane);
    }

    // Fewer than 8 elements remain; `f32::max` also skips NaN here.
    for &x in chunks.remainder() {
        max_val = max_val.max(x);
    }

    // The seed survives only if no element exceeded it. Distinguish the all-NaN
    // input (scalar answer: NaN) from genuine negative infinities.
    if max_val == f32::NEG_INFINITY && !arr.is_empty() && arr.iter().all(|x| x.is_nan()) {
        return f32::NAN;
    }

    max_val
}
255
256pub fn parallel_min_f32_simd(arr: &[f32]) -> Option<f32> {
259 if arr.is_empty() {
260 return None;
261 }
262
263 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
264 {
265 if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
266 return Some(unsafe { parallel_min_avx2(arr) });
267 } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
268 return Some(unsafe { parallel_min_sse2(arr) });
269 }
270 }
271
272 parallel_min_scalar(arr)
273}
274
/// Portable fallback: folds `arr` with `f32::min`, seeded by the first element.
///
/// Returns `None` for an empty slice.
fn parallel_min_scalar(arr: &[f32]) -> Option<f32> {
    let (&first, rest) = arr.split_first()?;
    let smallest = rest.iter().fold(first, |best, &x| x.min(best));
    Some(smallest)
}
279
/// Finds the smallest element of `arr` four lanes at a time with SSE2.
///
/// NaN elements are ignored, matching `f32::min`; the result is NaN only when
/// every element is NaN. For an empty slice the accumulator seed
/// `f32::INFINITY` is returned.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn parallel_min_sse2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut chunks = arr.chunks_exact(4);
    let mut best = _mm_set1_ps(f32::INFINITY);
    for chunk in chunks.by_ref() {
        // SAFETY: `chunk` is exactly 4 contiguous `f32`s, so the unaligned
        // 16-byte load stays inside the memory the pointer was derived from.
        let v = _mm_loadu_ps(chunk.as_ptr());
        // MINPS yields its second operand whenever either input is NaN. Keeping
        // the accumulator second means a NaN input leaves the running minimum
        // untouched instead of overwriting it.
        best = _mm_min_ps(v, best);
    }

    let mut lanes = [0.0f32; 4];
    // SAFETY: `lanes` provides 16 writable bytes for the unaligned store.
    _mm_storeu_ps(lanes.as_mut_ptr(), best);
    let mut min_val = lanes[0].min(lanes[1]).min(lanes[2]).min(lanes[3]);

    // Fewer than 4 elements remain; `f32::min` also skips NaN here.
    for &x in chunks.remainder() {
        min_val = min_val.min(x);
    }

    // The seed survives only if no element was below it. Distinguish the all-NaN
    // input (scalar answer: NaN) from genuine positive infinities.
    if min_val == f32::INFINITY && !arr.is_empty() && arr.iter().all(|x| x.is_nan()) {
        return f32::NAN;
    }

    min_val
}
308
/// Finds the smallest element of `arr` eight lanes at a time with 256-bit vectors.
///
/// NaN elements are ignored, matching `f32::min`; the result is NaN only when
/// every element is NaN. For an empty slice the accumulator seed
/// `f32::INFINITY` is returned.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn parallel_min_avx2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut chunks = arr.chunks_exact(8);
    let mut best = _mm256_set1_ps(f32::INFINITY);
    for chunk in chunks.by_ref() {
        // SAFETY: `chunk` is exactly 8 contiguous `f32`s, so the unaligned
        // 32-byte load stays inside the memory the pointer was derived from.
        let v = _mm256_loadu_ps(chunk.as_ptr());
        // VMINPS yields its second operand whenever either input is NaN. Keeping
        // the accumulator second means a NaN input leaves the running minimum
        // untouched instead of overwriting it.
        best = _mm256_min_ps(v, best);
    }

    let mut lanes = [0.0f32; 8];
    // SAFETY: `lanes` provides 32 writable bytes for the unaligned store.
    _mm256_storeu_ps(lanes.as_mut_ptr(), best);
    let mut min_val = lanes[0];
    for &lane in &lanes[1..] {
        min_val = min_val.min(lane);
    }

    // Fewer than 8 elements remain; `f32::min` also skips NaN here.
    for &x in chunks.remainder() {
        min_val = min_val.min(x);
    }

    // The seed survives only if no element was below it. Distinguish the all-NaN
    // input (scalar answer: NaN) from genuine positive infinities.
    if min_val == f32::INFINITY && !arr.is_empty() && arr.iter().all(|x| x.is_nan()) {
        return f32::NAN;
    }

    min_val
}
340
341pub fn prefix_sum_f32_simd(input: &[f32], output: &mut [f32]) {
344 assert_eq!(
345 input.len(),
346 output.len(),
347 "Input and output must have same length"
348 );
349
350 if input.is_empty() {
351 return;
352 }
353
354 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
355 {
356 if crate::simd_feature_detected!("avx2") && input.len() >= 8 {
357 unsafe { prefix_sum_avx2(input, output) };
358 return;
359 } else if crate::simd_feature_detected!("sse2") && input.len() >= 4 {
360 unsafe { prefix_sum_sse2(input, output) };
361 return;
362 }
363 }
364
365 prefix_sum_scalar(input, output);
366}
367
/// Portable fallback: writes the inclusive prefix sum of `input` into `output`.
///
/// `output[0]` receives `input[0]` unchanged; each later slot is the previous
/// running total plus the matching input element. Empty input is a no-op.
fn prefix_sum_scalar(input: &[f32], output: &mut [f32]) {
    if let Some((&first, tail)) = input.split_first() {
        let mut running = first;
        output[0] = running;
        for (offset, &x) in tail.iter().enumerate() {
            running = running + x;
            output[offset + 1] = running;
        }
    }
}
378
/// Writes the inclusive prefix sum of `input` into `output` (SSE2 dispatch target).
///
/// `output[k]` is `0.0 + input[0] + ... + input[k]`, accumulated strictly left to
/// right. The scan is performed as one running sum. The previous form loaded
/// each 4-element chunk into a vector, spilled it straight back to the stack and
/// re-added every prefix from scratch; this yields bit-identical results with
/// one addition per element. A true in-register scan would regroup the
/// additions and could change rounding, so it is deliberately not used here.
///
/// Only the first `input.len()` slots of `output` are written.
///
/// # Panics
///
/// Panics if `output` is shorter than `input`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn prefix_sum_sse2(input: &[f32], output: &mut [f32]) {
    // One bounds check up front instead of one per element.
    let output = &mut output[..input.len()];

    let mut running_sum = 0.0f32;
    for (slot, &x) in output.iter_mut().zip(input) {
        running_sum += x;
        *slot = running_sum;
    }
}
415
/// Writes the inclusive prefix sum of `input` into `output` (AVX2 dispatch target).
///
/// `output[k]` is `0.0 + input[0] + ... + input[k]`, accumulated strictly left to
/// right. The scan is performed as one running sum. The previous form loaded
/// each 8-element chunk into a vector, spilled it straight back to the stack and
/// re-added every prefix from scratch (36 additions per chunk); this yields
/// bit-identical results with one addition per element. A true in-register scan
/// would regroup the additions and could change rounding, so it is deliberately
/// not used here.
///
/// Only the first `input.len()` slots of `output` are written.
///
/// # Panics
///
/// Panics if `output` is shorter than `input`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn prefix_sum_avx2(input: &[f32], output: &mut [f32]) {
    // One bounds check up front instead of one per element.
    let output = &mut output[..input.len()];

    let mut running_sum = 0.0f32;
    for (slot, &x) in output.iter_mut().zip(input) {
        running_sum += x;
        *slot = running_sum;
    }
}
465
466pub fn exclusive_scan_f32_simd(input: &[f32], output: &mut [f32]) {
469 assert_eq!(
470 input.len(),
471 output.len(),
472 "Input and output must have same length"
473 );
474
475 if input.is_empty() {
476 return;
477 }
478
479 let mut temp = vec![0.0; input.len()];
481 prefix_sum_f32_simd(input, &mut temp);
482
483 output[0] = 0.0;
484 output[1..].copy_from_slice(&temp[..input.len() - 1]);
485}
486
/// Writes a running sum of `input` into `output`, restarting wherever a flag is set.
///
/// At each index where `segment_flags` is `true` the running total is reset to
/// that input element; otherwise the element is added to the total. The total
/// starts at `0.0`, so a leading `false` flag behaves like a segment start.
/// The loop is scalar; no SIMD kernel is involved.
///
/// # Panics
///
/// Panics if `segment_flags` or `output` differ in length from `input`.
pub fn segmented_sum_f32_simd(input: &[f32], segment_flags: &[bool], output: &mut [f32]) {
    assert_eq!(
        input.len(),
        segment_flags.len(),
        "Input and flags must have same length"
    );
    assert_eq!(
        input.len(),
        output.len(),
        "Input and output must have same length"
    );

    let mut acc = 0.0;
    let elements = output.iter_mut().zip(input).zip(segment_flags);
    for ((slot, &value), &starts_segment) in elements {
        acc = if starts_segment { value } else { acc + value };
        *slot = acc;
    }
}
517
518pub fn conditional_sum_f32_simd(input: &[f32], condition: &[bool]) -> f32 {
521 assert_eq!(
522 input.len(),
523 condition.len(),
524 "Input and condition must have same length"
525 );
526
527 if input.is_empty() {
528 return 0.0;
529 }
530
531 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
532 {
533 if crate::simd_feature_detected!("avx2") && input.len() >= 8 {
534 return unsafe { conditional_sum_avx2(input, condition) };
535 } else if crate::simd_feature_detected!("sse2") && input.len() >= 4 {
536 return unsafe { conditional_sum_sse2(input, condition) };
537 }
538 }
539
540 conditional_sum_scalar(input, condition)
541}
542
/// Portable fallback: sums the elements of `input` whose paired flag is `true`.
///
/// Elements with a `false` flag contribute `0.0`. Pairing stops at the shorter
/// of the two slices.
fn conditional_sum_scalar(input: &[f32], condition: &[bool]) -> f32 {
    let contributions = input.iter().zip(condition).map(|(value, keep)| match *keep {
        true => *value,
        false => 0.0,
    });
    contributions.sum::<f32>()
}
550
/// Sums the elements of `input` whose `condition` flag is `true`, four lanes at a time.
///
/// Excluded lanes are cleared with a bitwise mask rather than multiplied by
/// `0.0`, so an excluded infinity or NaN contributes exactly `0.0` instead of
/// turning the whole sum into NaN. Only the first `input.len()` flags are read.
///
/// # Panics
///
/// Panics if `condition` is shorter than `input`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn conditional_sum_sse2(input: &[f32], condition: &[bool]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // One bounds check up front; also rejects a too-short `condition`.
    let flags = &condition[..input.len()];

    let value_chunks = input.chunks_exact(4);
    let flag_chunks = flags.chunks_exact(4);
    let value_tail = value_chunks.remainder();
    let flag_tail = flag_chunks.remainder();

    let mut acc = _mm_setzero_ps();
    for (values, keep) in value_chunks.zip(flag_chunks) {
        // `-1` is all bits set, `0` is all bits clear; the highest lane comes first.
        let mask = _mm_castsi128_ps(_mm_set_epi32(
            -i32::from(keep[3]),
            -i32::from(keep[2]),
            -i32::from(keep[1]),
            -i32::from(keep[0]),
        ));
        // SAFETY: `values` is exactly 4 contiguous `f32`s, so the unaligned
        // 16-byte load stays inside the memory the pointer was derived from.
        let loaded = _mm_loadu_ps(values.as_ptr());
        acc = _mm_add_ps(acc, _mm_and_ps(loaded, mask));
    }

    let mut lanes = [0.0f32; 4];
    // SAFETY: `lanes` provides 16 writable bytes for the unaligned store.
    _mm_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Fewer than 4 elements remain; handle them individually.
    for (&x, &keep) in value_tail.iter().zip(flag_tail) {
        if keep {
            total += x;
        }
    }

    total
}
591
/// Sums the elements of `input` whose `condition` flag is `true`, eight lanes at a time.
///
/// Excluded lanes are cleared with a bitwise mask rather than multiplied by
/// `0.0`, so an excluded infinity or NaN contributes exactly `0.0` instead of
/// turning the whole sum into NaN. Only the first `input.len()` flags are read.
///
/// # Panics
///
/// Panics if `condition` is shorter than `input`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn conditional_sum_avx2(input: &[f32], condition: &[bool]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module per target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // One bounds check up front; also rejects a too-short `condition`.
    let flags = &condition[..input.len()];

    let value_chunks = input.chunks_exact(8);
    let flag_chunks = flags.chunks_exact(8);
    let value_tail = value_chunks.remainder();
    let flag_tail = flag_chunks.remainder();

    let mut acc = _mm256_setzero_ps();
    for (values, keep) in value_chunks.zip(flag_chunks) {
        // `-1` is all bits set, `0` is all bits clear; the highest lane comes first.
        let mask = _mm256_castsi256_ps(_mm256_set_epi32(
            -i32::from(keep[7]),
            -i32::from(keep[6]),
            -i32::from(keep[5]),
            -i32::from(keep[4]),
            -i32::from(keep[3]),
            -i32::from(keep[2]),
            -i32::from(keep[1]),
            -i32::from(keep[0]),
        ));
        // SAFETY: `values` is exactly 8 contiguous `f32`s, so the unaligned
        // 32-byte load stays inside the memory the pointer was derived from.
        let loaded = _mm256_loadu_ps(values.as_ptr());
        acc = _mm256_add_ps(acc, _mm256_and_ps(loaded, mask));
    }

    let mut lanes = [0.0f32; 8];
    // SAFETY: `lanes` provides 32 writable bytes for the unaligned store.
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes.iter().sum::<f32>();

    // Fewer than 8 elements remain; handle them individually.
    for (&x, &keep) in value_tail.iter().zip(flag_tail) {
        if keep {
            total += x;
        }
    }

    total
}
636
/// A key paired with a value.
///
/// `reduce_by_key_f32_simd` uses `KeyValue<i32, f32>` as both its input element
/// type and its output element type.
#[derive(Debug, Clone, PartialEq)]
pub struct KeyValue<K, V> {
    /// The key of this entry.
    pub key: K,
    /// The value associated with `key`.
    pub value: V,
}
644
645pub fn reduce_by_key_f32_simd(
646 input: &[KeyValue<i32, f32>],
647 reduction_op: fn(f32, f32) -> f32,
648) -> Vec<KeyValue<i32, f32>> {
649 if input.is_empty() {
650 return Vec::new();
651 }
652
653 let mut result = Vec::new();
654 let mut current_key = input[0].key;
655 let mut current_value = input[0].value;
656
657 for item in input.iter().skip(1) {
658 if item.key == current_key {
659 current_value = reduction_op(current_value, item.value);
660 } else {
661 result.push(KeyValue {
662 key: current_key,
663 value: current_value,
664 });
665 current_key = item.key;
666 current_value = item.value;
667 }
668 }
669
670 result.push(KeyValue {
672 key: current_key,
673 value: current_value,
674 });
675
676 result
677}
678
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Checks that `actual` matches every entry of `expected` within `1e-6`.
    fn assert_all_close(actual: &[f32], expected: &[f32]) {
        for (i, &want) in expected.iter().enumerate() {
            assert_relative_eq!(actual[i], want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_parallel_sum() {
        let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let reference: f32 = values.iter().sum();
        let got = parallel_sum_f32_simd(&values);
        assert_relative_eq!(got, reference, epsilon = 1e-6);
    }

    #[test]
    fn test_parallel_product() {
        let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let reference: f32 = values.iter().product();
        let got = parallel_product_f32_simd(&values);
        assert_relative_eq!(got, reference, epsilon = 1e-6);
    }

    #[test]
    fn test_parallel_max() {
        let values: Vec<f32> = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        assert_eq!(parallel_max_f32_simd(&values), Some(9.0));
    }

    #[test]
    fn test_parallel_min() {
        let values: Vec<f32> = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        assert_eq!(parallel_min_f32_simd(&values), Some(1.0));
    }

    #[test]
    fn test_prefix_sum() {
        let input: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = vec![0.0; input.len()];
        prefix_sum_f32_simd(&input, &mut output);

        assert_all_close(&output, &[1.0, 3.0, 6.0, 10.0, 15.0]);
    }

    #[test]
    fn test_exclusive_scan() {
        let input: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = vec![0.0; input.len()];
        exclusive_scan_f32_simd(&input, &mut output);

        assert_all_close(&output, &[0.0, 1.0, 3.0, 6.0, 10.0]);
    }

    #[test]
    fn test_segmented_sum() {
        let input: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let flags = vec![true, false, false, true, false, true];
        let mut output = vec![0.0; input.len()];

        segmented_sum_f32_simd(&input, &flags, &mut output);

        // Segments start at indices 0, 3 and 5.
        assert_all_close(&output, &[1.0, 3.0, 6.0, 4.0, 9.0, 6.0]);
    }

    #[test]
    fn test_conditional_sum() {
        let input: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let condition = vec![true, false, true, false, true];

        let got = conditional_sum_f32_simd(&input, &condition);
        // 1.0 + 3.0 + 5.0
        assert_relative_eq!(got, 9.0, epsilon = 1e-6);
    }

    #[test]
    fn test_reduce_by_key() {
        let kv = |key: i32, value: f32| KeyValue { key, value };
        let input = vec![
            kv(1, 10.0),
            kv(1, 20.0),
            kv(2, 30.0),
            kv(2, 40.0),
            kv(3, 50.0),
        ];

        let result = reduce_by_key_f32_simd(&input, |a, b| a + b);

        assert_eq!(result.len(), 3);
        assert_eq!(result[0], kv(1, 30.0));
        assert_eq!(result[1], kv(2, 70.0));
        assert_eq!(result[2], kv(3, 50.0));
    }

    #[test]
    fn test_empty_arrays() {
        let empty: Vec<f32> = Vec::new();
        assert_eq!(parallel_sum_f32_simd(&empty), 0.0);
        assert_eq!(parallel_product_f32_simd(&empty), 1.0);
        assert_eq!(parallel_max_f32_simd(&empty), None);
        assert_eq!(parallel_min_f32_simd(&empty), None);
    }

    #[test]
    fn test_single_element() {
        let single: Vec<f32> = vec![42.0];
        assert_eq!(parallel_sum_f32_simd(&single), 42.0);
        assert_eq!(parallel_product_f32_simd(&single), 42.0);
        assert_eq!(parallel_max_f32_simd(&single), Some(42.0));
        assert_eq!(parallel_min_f32_simd(&single), Some(42.0));
    }
}