Skip to main content

sklears_simd/
reduction.rs

1//! SIMD-optimized reduction and scan operations
2//!
3//! This module provides vectorized implementations of reduction operations
4//! including parallel reductions, prefix sums (scan), and segment-based operations.
5
6#[cfg(feature = "no-std")]
7use alloc::{vec, vec::Vec};
8
9/// SIMD-optimized parallel reduction sum
10/// Computes the sum of all elements in the array using parallel reduction
11pub fn parallel_sum_f32_simd(arr: &[f32]) -> f32 {
12    if arr.is_empty() {
13        return 0.0;
14    }
15
16    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
17    {
18        if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
19            return unsafe { parallel_sum_avx2(arr) };
20        } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
21            return unsafe { parallel_sum_sse2(arr) };
22        }
23    }
24
25    parallel_sum_scalar(arr)
26}
27
/// Portable fallback: adds the elements of `arr` in order.
fn parallel_sum_scalar(arr: &[f32]) -> f32 {
    arr.iter().copied().sum::<f32>()
}
31
/// Sums `arr` with 128-bit vectors, four `f32` lanes at a time.
///
/// Each lane accumulates every fourth element. The four lane totals are then
/// added left to right, and the trailing elements (fewer than four) are folded
/// in one by one. The additions are therefore grouped differently from a
/// strictly sequential sum.
///
/// # Safety
///
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn parallel_sum_sse2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module
    // that matches the target this function is being compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut blocks = arr.chunks_exact(4);
    let mut acc = _mm_setzero_ps();
    for block in blocks.by_ref() {
        // SAFETY: `block` is exactly four `f32`s long and the pointer is
        // derived from the whole block; `_mm_loadu_ps` needs no alignment.
        acc = _mm_add_ps(acc, _mm_loadu_ps(block.as_ptr()));
    }

    // Horizontal sum of the four lane accumulators.
    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Elements that did not fill a whole vector.
    for &value in blocks.remainder() {
        total += value;
    }

    total
}
60
/// Sums `arr` with 256-bit vectors, eight `f32` lanes at a time.
///
/// Each lane accumulates every eighth element. The lane totals are then added
/// in lane order, and the trailing elements (fewer than eight) are folded in
/// one by one. The additions are therefore grouped differently from a strictly
/// sequential sum.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn parallel_sum_avx2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module
    // that matches the target this function is being compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut blocks = arr.chunks_exact(8);
    let mut acc = _mm256_setzero_ps();
    for block in blocks.by_ref() {
        // SAFETY: `block` is exactly eight `f32`s long and the pointer is
        // derived from the whole block; `_mm256_loadu_ps` needs no alignment.
        acc = _mm256_add_ps(acc, _mm256_loadu_ps(block.as_ptr()));
    }

    // Horizontal sum of the eight lane accumulators.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes.iter().sum::<f32>();

    // Elements that did not fill a whole vector.
    for &value in blocks.remainder() {
        total += value;
    }

    total
}
89
90/// SIMD-optimized parallel reduction product
91/// Computes the product of all elements in the array using parallel reduction
92pub fn parallel_product_f32_simd(arr: &[f32]) -> f32 {
93    if arr.is_empty() {
94        return 1.0;
95    }
96
97    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
98    {
99        if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
100            return unsafe { parallel_product_avx2(arr) };
101        } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
102            return unsafe { parallel_product_sse2(arr) };
103        }
104    }
105
106    parallel_product_scalar(arr)
107}
108
/// Portable fallback: multiplies the elements of `arr` in order.
fn parallel_product_scalar(arr: &[f32]) -> f32 {
    arr.iter().copied().product::<f32>()
}
112
/// Multiplies the elements of `arr` with 128-bit vectors, four lanes at a time.
///
/// Each lane accumulates the product of every fourth element. The four lane
/// products are then multiplied left to right, and the trailing elements
/// (fewer than four) are folded in one by one.
///
/// # Safety
///
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn parallel_product_sse2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module
    // that matches the target this function is being compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut blocks = arr.chunks_exact(4);
    let mut acc = _mm_set1_ps(1.0);
    for block in blocks.by_ref() {
        // SAFETY: `block` is exactly four `f32`s long and the pointer is
        // derived from the whole block; `_mm_loadu_ps` needs no alignment.
        acc = _mm_mul_ps(acc, _mm_loadu_ps(block.as_ptr()));
    }

    // Horizontal product of the four lane accumulators.
    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes[0] * lanes[1] * lanes[2] * lanes[3];

    // Elements that did not fill a whole vector.
    for &value in blocks.remainder() {
        total *= value;
    }

    total
}
141
/// Multiplies the elements of `arr` with 256-bit vectors, eight lanes at a time.
///
/// Each lane accumulates the product of every eighth element. The lane
/// products are then multiplied in lane order, and the trailing elements
/// (fewer than eight) are folded in one by one.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn parallel_product_avx2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module
    // that matches the target this function is being compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut blocks = arr.chunks_exact(8);
    let mut acc = _mm256_set1_ps(1.0);
    for block in blocks.by_ref() {
        // SAFETY: `block` is exactly eight `f32`s long and the pointer is
        // derived from the whole block; `_mm256_loadu_ps` needs no alignment.
        acc = _mm256_mul_ps(acc, _mm256_loadu_ps(block.as_ptr()));
    }

    // Horizontal product of the eight lane accumulators.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes.iter().product::<f32>();

    // Elements that did not fill a whole vector.
    for &value in blocks.remainder() {
        total *= value;
    }

    total
}
170
171/// SIMD-optimized parallel reduction maximum
172/// Finds the maximum element in the array using parallel reduction
173pub fn parallel_max_f32_simd(arr: &[f32]) -> Option<f32> {
174    if arr.is_empty() {
175        return None;
176    }
177
178    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
179    {
180        if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
181            return Some(unsafe { parallel_max_avx2(arr) });
182        } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
183            return Some(unsafe { parallel_max_sse2(arr) });
184        }
185    }
186
187    parallel_max_scalar(arr)
188}
189
/// Portable fallback: folds `f32::max` over `arr`, starting from its first
/// element. Returns `None` when `arr` is empty.
fn parallel_max_scalar(arr: &[f32]) -> Option<f32> {
    let (&first, rest) = arr.split_first()?;
    Some(rest.iter().fold(first, |best, &x| x.max(best)))
}
194
/// Finds the maximum of `arr` with 128-bit vectors, four lanes at a time.
///
/// NaN elements are ignored, matching `f32::max`; the result is NaN only when
/// `arr` is non-empty and every element is NaN. An empty slice yields
/// `f32::NEG_INFINITY`.
///
/// # Safety
///
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn parallel_max_sse2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module
    // that matches the target this function is being compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut blocks = arr.chunks_exact(4);
    let mut lane_max = _mm_set1_ps(f32::NEG_INFINITY);
    for block in blocks.by_ref() {
        // SAFETY: `block` is exactly four `f32`s long and the pointer is
        // derived from the whole block; `_mm_loadu_ps` needs no alignment.
        let values = _mm_loadu_ps(block.as_ptr());
        // MAXPS returns its *second* operand when either operand is NaN.
        // Keeping the accumulator second means a NaN input leaves the lane
        // untouched instead of overwriting the maximum seen so far.
        lane_max = _mm_max_ps(values, lane_max);
    }

    // Horizontal max of the four lane accumulators.
    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), lane_max);
    let mut best = lanes[0].max(lanes[1]).max(lanes[2]).max(lanes[3]);

    // Elements that did not fill a whole vector.
    for &value in blocks.remainder() {
        best = best.max(value);
    }

    // The accumulator started at -inf, so an all-NaN input still reads -inf
    // here. A fold of `f32::max` over such input yields NaN, so report that.
    if best == f32::NEG_INFINITY && !arr.is_empty() && arr.iter().all(|x| x.is_nan()) {
        return f32::NAN;
    }

    best
}
223
/// Finds the maximum of `arr` with 256-bit vectors, eight lanes at a time.
///
/// NaN elements are ignored, matching `f32::max`; the result is NaN only when
/// `arr` is non-empty and every element is NaN. An empty slice yields
/// `f32::NEG_INFINITY`.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn parallel_max_avx2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module
    // that matches the target this function is being compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut blocks = arr.chunks_exact(8);
    let mut lane_max = _mm256_set1_ps(f32::NEG_INFINITY);
    for block in blocks.by_ref() {
        // SAFETY: `block` is exactly eight `f32`s long and the pointer is
        // derived from the whole block; `_mm256_loadu_ps` needs no alignment.
        let values = _mm256_loadu_ps(block.as_ptr());
        // VMAXPS returns its *second* operand when either operand is NaN.
        // Keeping the accumulator second means a NaN input leaves the lane
        // untouched instead of overwriting the maximum seen so far.
        lane_max = _mm256_max_ps(values, lane_max);
    }

    // Horizontal max of the eight lane accumulators.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), lane_max);
    let mut best = lanes[0];
    for &lane in &lanes[1..] {
        best = best.max(lane);
    }

    // Elements that did not fill a whole vector.
    for &value in blocks.remainder() {
        best = best.max(value);
    }

    // The accumulator started at -inf, so an all-NaN input still reads -inf
    // here. A fold of `f32::max` over such input yields NaN, so report that.
    if best == f32::NEG_INFINITY && !arr.is_empty() && arr.iter().all(|x| x.is_nan()) {
        return f32::NAN;
    }

    best
}
255
256/// SIMD-optimized parallel reduction minimum
257/// Finds the minimum element in the array using parallel reduction
258pub fn parallel_min_f32_simd(arr: &[f32]) -> Option<f32> {
259    if arr.is_empty() {
260        return None;
261    }
262
263    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
264    {
265        if crate::simd_feature_detected!("avx2") && arr.len() >= 8 {
266            return Some(unsafe { parallel_min_avx2(arr) });
267        } else if crate::simd_feature_detected!("sse2") && arr.len() >= 4 {
268            return Some(unsafe { parallel_min_sse2(arr) });
269        }
270    }
271
272    parallel_min_scalar(arr)
273}
274
/// Portable fallback: folds `f32::min` over `arr`, starting from its first
/// element. Returns `None` when `arr` is empty.
fn parallel_min_scalar(arr: &[f32]) -> Option<f32> {
    let (&first, rest) = arr.split_first()?;
    Some(rest.iter().fold(first, |best, &x| x.min(best)))
}
279
/// Finds the minimum of `arr` with 128-bit vectors, four lanes at a time.
///
/// NaN elements are ignored, matching `f32::min`; the result is NaN only when
/// `arr` is non-empty and every element is NaN. An empty slice yields
/// `f32::INFINITY`.
///
/// # Safety
///
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn parallel_min_sse2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module
    // that matches the target this function is being compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut blocks = arr.chunks_exact(4);
    let mut lane_min = _mm_set1_ps(f32::INFINITY);
    for block in blocks.by_ref() {
        // SAFETY: `block` is exactly four `f32`s long and the pointer is
        // derived from the whole block; `_mm_loadu_ps` needs no alignment.
        let values = _mm_loadu_ps(block.as_ptr());
        // MINPS returns its *second* operand when either operand is NaN.
        // Keeping the accumulator second means a NaN input leaves the lane
        // untouched instead of overwriting the minimum seen so far.
        lane_min = _mm_min_ps(values, lane_min);
    }

    // Horizontal min of the four lane accumulators.
    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), lane_min);
    let mut best = lanes[0].min(lanes[1]).min(lanes[2]).min(lanes[3]);

    // Elements that did not fill a whole vector.
    for &value in blocks.remainder() {
        best = best.min(value);
    }

    // The accumulator started at +inf, so an all-NaN input still reads +inf
    // here. A fold of `f32::min` over such input yields NaN, so report that.
    if best == f32::INFINITY && !arr.is_empty() && arr.iter().all(|x| x.is_nan()) {
        return f32::NAN;
    }

    best
}
308
/// Finds the minimum of `arr` with 256-bit vectors, eight lanes at a time.
///
/// NaN elements are ignored, matching `f32::min`; the result is NaN only when
/// `arr` is non-empty and every element is NaN. An empty slice yields
/// `f32::INFINITY`.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn parallel_min_avx2(arr: &[f32]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module
    // that matches the target this function is being compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let mut blocks = arr.chunks_exact(8);
    let mut lane_min = _mm256_set1_ps(f32::INFINITY);
    for block in blocks.by_ref() {
        // SAFETY: `block` is exactly eight `f32`s long and the pointer is
        // derived from the whole block; `_mm256_loadu_ps` needs no alignment.
        let values = _mm256_loadu_ps(block.as_ptr());
        // VMINPS returns its *second* operand when either operand is NaN.
        // Keeping the accumulator second means a NaN input leaves the lane
        // untouched instead of overwriting the minimum seen so far.
        lane_min = _mm256_min_ps(values, lane_min);
    }

    // Horizontal min of the eight lane accumulators.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), lane_min);
    let mut best = lanes[0];
    for &lane in &lanes[1..] {
        best = best.min(lane);
    }

    // Elements that did not fill a whole vector.
    for &value in blocks.remainder() {
        best = best.min(value);
    }

    // The accumulator started at +inf, so an all-NaN input still reads +inf
    // here. A fold of `f32::min` over such input yields NaN, so report that.
    if best == f32::INFINITY && !arr.is_empty() && arr.iter().all(|x| x.is_nan()) {
        return f32::NAN;
    }

    best
}
340
341/// SIMD-optimized prefix sum (inclusive scan)
342/// Computes cumulative sum where output\[i\] = sum of input\[0\] through input\[i\]
343pub fn prefix_sum_f32_simd(input: &[f32], output: &mut [f32]) {
344    assert_eq!(
345        input.len(),
346        output.len(),
347        "Input and output must have same length"
348    );
349
350    if input.is_empty() {
351        return;
352    }
353
354    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
355    {
356        if crate::simd_feature_detected!("avx2") && input.len() >= 8 {
357            unsafe { prefix_sum_avx2(input, output) };
358            return;
359        } else if crate::simd_feature_detected!("sse2") && input.len() >= 4 {
360            unsafe { prefix_sum_sse2(input, output) };
361            return;
362        }
363    }
364
365    prefix_sum_scalar(input, output);
366}
367
/// Portable inclusive scan: `output[0]` is `input[0]` and every later slot is
/// the previous slot plus the matching input. Empty input is a no-op.
///
/// Panics if `output` is shorter than `input`.
fn prefix_sum_scalar(input: &[f32], output: &mut [f32]) {
    let (&first, rest) = match input.split_first() {
        Some(parts) => parts,
        None => return,
    };

    let mut running = first;
    output[0] = running;
    for (slot, &value) in output[1..input.len()].iter_mut().zip(rest) {
        running += value;
        *slot = running;
    }
}
378
/// Inclusive prefix sum kernel selected when SSE2 is available.
///
/// Writes `input[0] + ... + input[i]` (accumulated left to right, starting
/// from `0.0`) into `output[i]`. Every output depends on the previous one, so
/// the work is a single running accumulation; no vector intrinsics are used.
///
/// Panics if `output` is shorter than `input`; extra trailing slots in
/// `output` are left untouched.
///
/// # Safety
///
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn prefix_sum_sse2(input: &[f32], output: &mut [f32]) {
    // Restrict the destination to the slots that have a matching input.
    let output = &mut output[..input.len()];

    let mut running_sum = 0.0f32;
    for (slot, &value) in output.iter_mut().zip(input) {
        running_sum += value;
        *slot = running_sum;
    }
}
415
/// Inclusive prefix sum kernel selected when AVX2 is available.
///
/// Writes `input[0] + ... + input[i]` (accumulated left to right, starting
/// from `0.0`) into `output[i]`. Every output depends on the previous one, so
/// the work is a single running accumulation; no vector intrinsics are used.
///
/// Panics if `output` is shorter than `input`; extra trailing slots in
/// `output` are left untouched.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn prefix_sum_avx2(input: &[f32], output: &mut [f32]) {
    // Restrict the destination to the slots that have a matching input.
    let output = &mut output[..input.len()];

    let mut running_sum = 0.0f32;
    for (slot, &value) in output.iter_mut().zip(input) {
        running_sum += value;
        *slot = running_sum;
    }
}
465
466/// SIMD-optimized exclusive scan (prefix sum where output\[i\] = sum of input\[0\] through input\[i-1\])
467/// output\[0\] = 0, output\[i\] = sum of input\[0\] through input\[i-1\] for i > 0
468pub fn exclusive_scan_f32_simd(input: &[f32], output: &mut [f32]) {
469    assert_eq!(
470        input.len(),
471        output.len(),
472        "Input and output must have same length"
473    );
474
475    if input.is_empty() {
476        return;
477    }
478
479    // Compute exclusive scan by shifting the inclusive scan
480    let mut temp = vec![0.0; input.len()];
481    prefix_sum_f32_simd(input, &mut temp);
482
483    output[0] = 0.0;
484    output[1..].copy_from_slice(&temp[..input.len() - 1]);
485}
486
/// Running sum that restarts at every segment boundary.
///
/// `segment_flags[i] == true` marks position `i` as the first element of a new
/// segment. `output[i]` receives the sum of the inputs from the start of the
/// current segment through `i`. Elements before the first flagged position
/// accumulate from `0.0`.
///
/// # Panics
///
/// Panics if `segment_flags` or `output` differ in length from `input`.
pub fn segmented_sum_f32_simd(input: &[f32], segment_flags: &[bool], output: &mut [f32]) {
    assert_eq!(
        input.len(),
        segment_flags.len(),
        "Input and flags must have same length"
    );
    assert_eq!(
        input.len(),
        output.len(),
        "Input and output must have same length"
    );

    let mut segment_total = 0.0;
    let elements = input.iter().zip(segment_flags).zip(output.iter_mut());
    for ((&value, &starts_segment), slot) in elements {
        segment_total = if starts_segment {
            value
        } else {
            segment_total + value
        };
        *slot = segment_total;
    }
}
517
518/// SIMD-optimized conditional reduction
519/// Performs reduction only on elements where condition\[i\] is true
520pub fn conditional_sum_f32_simd(input: &[f32], condition: &[bool]) -> f32 {
521    assert_eq!(
522        input.len(),
523        condition.len(),
524        "Input and condition must have same length"
525    );
526
527    if input.is_empty() {
528        return 0.0;
529    }
530
531    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
532    {
533        if crate::simd_feature_detected!("avx2") && input.len() >= 8 {
534            return unsafe { conditional_sum_avx2(input, condition) };
535        } else if crate::simd_feature_detected!("sse2") && input.len() >= 4 {
536            return unsafe { conditional_sum_sse2(input, condition) };
537        }
538    }
539
540    conditional_sum_scalar(input, condition)
541}
542
/// Portable fallback: adds `input[i]` where `condition[i]` is `true` and
/// `0.0` where it is `false`, pairing elements up to the shorter slice.
fn conditional_sum_scalar(input: &[f32], condition: &[bool]) -> f32 {
    let selected = input
        .iter()
        .zip(condition)
        .map(|(value, keep)| if *keep { *value } else { 0.0 });
    selected.sum::<f32>()
}
550
/// Sums the elements of `input` whose `condition` entry is `true`, using
/// 128-bit vectors, four lanes at a time.
///
/// Excluded elements are cleared with a bitwise mask, so an excluded infinity
/// or NaN contributes exactly `0.0`. Trailing elements (fewer than four) are
/// handled one by one.
///
/// Panics if `condition` is shorter than `input`.
///
/// # Safety
///
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn conditional_sum_sse2(input: &[f32], condition: &[bool]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module
    // that matches the target this function is being compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Pair every input with its flag; panics if there are too few flags.
    let condition = &condition[..input.len()];
    let vector_len = input.len() - input.len() % 4;
    let (body, tail) = input.split_at(vector_len);
    let (body_flags, tail_flags) = condition.split_at(vector_len);

    let mut acc = _mm_setzero_ps();
    for (values, flags) in body.chunks_exact(4).zip(body_flags.chunks_exact(4)) {
        // `true` becomes an all-ones lane (-1), `false` an all-zeros lane.
        // ANDing clears excluded lanes outright; multiplying by 0.0 would
        // turn an excluded infinity or NaN into NaN.
        let mask = _mm_castsi128_ps(_mm_set_epi32(
            -(flags[3] as i32),
            -(flags[2] as i32),
            -(flags[1] as i32),
            -(flags[0] as i32),
        ));
        // SAFETY: `values` is exactly four `f32`s long and the pointer is
        // derived from the whole block; `_mm_loadu_ps` needs no alignment.
        let selected = _mm_and_ps(_mm_loadu_ps(values.as_ptr()), mask);
        acc = _mm_add_ps(acc, selected);
    }

    // Horizontal sum of the four lane accumulators.
    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Elements that did not fill a whole vector.
    for (&value, &keep) in tail.iter().zip(tail_flags) {
        if keep {
            total += value;
        }
    }

    total
}
591
/// Sums the elements of `input` whose `condition` entry is `true`, using
/// 256-bit vectors, eight lanes at a time.
///
/// Excluded elements are cleared with a bitwise mask, so an excluded infinity
/// or NaN contributes exactly `0.0`. Trailing elements (fewer than eight) are
/// handled one by one.
///
/// Panics if `condition` is shorter than `input`.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn conditional_sum_avx2(input: &[f32], condition: &[bool]) -> f32 {
    // `core::arch::x86_64` does not exist on 32-bit x86, so pick the module
    // that matches the target this function is being compiled for.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Pair every input with its flag; panics if there are too few flags.
    let condition = &condition[..input.len()];
    let vector_len = input.len() - input.len() % 8;
    let (body, tail) = input.split_at(vector_len);
    let (body_flags, tail_flags) = condition.split_at(vector_len);

    let mut acc = _mm256_setzero_ps();
    for (values, flags) in body.chunks_exact(8).zip(body_flags.chunks_exact(8)) {
        // `true` becomes an all-ones lane (-1), `false` an all-zeros lane.
        // ANDing clears excluded lanes outright; multiplying by 0.0 would
        // turn an excluded infinity or NaN into NaN.
        let mask = _mm256_castsi256_ps(_mm256_set_epi32(
            -(flags[7] as i32),
            -(flags[6] as i32),
            -(flags[5] as i32),
            -(flags[4] as i32),
            -(flags[3] as i32),
            -(flags[2] as i32),
            -(flags[1] as i32),
            -(flags[0] as i32),
        ));
        // SAFETY: `values` is exactly eight `f32`s long and the pointer is
        // derived from the whole block; `_mm256_loadu_ps` needs no alignment.
        let selected = _mm256_and_ps(_mm256_loadu_ps(values.as_ptr()), mask);
        acc = _mm256_add_ps(acc, selected);
    }

    // Horizontal sum of the eight lane accumulators.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut total = lanes.iter().sum::<f32>();

    // Elements that did not fill a whole vector.
    for (&value, &keep) in tail.iter().zip(tail_flags) {
        if keep {
            total += value;
        }
    }

    total
}
636
/// A key paired with the value associated with it.
///
/// This is the element type that [`reduce_by_key_f32_simd`] takes as input and
/// returns as output (with `K = i32` and `V = f32`).
#[derive(Debug, Clone, PartialEq)]
pub struct KeyValue<K, V> {
    /// Key identifying the group this entry belongs to.
    pub key: K,
    /// Value carried by this entry.
    pub value: V,
}
644
645pub fn reduce_by_key_f32_simd(
646    input: &[KeyValue<i32, f32>],
647    reduction_op: fn(f32, f32) -> f32,
648) -> Vec<KeyValue<i32, f32>> {
649    if input.is_empty() {
650        return Vec::new();
651    }
652
653    let mut result = Vec::new();
654    let mut current_key = input[0].key;
655    let mut current_value = input[0].value;
656
657    for item in input.iter().skip(1) {
658        if item.key == current_key {
659            current_value = reduction_op(current_value, item.value);
660        } else {
661            result.push(KeyValue {
662                key: current_key,
663                value: current_value,
664            });
665            current_key = item.key;
666            current_value = item.value;
667        }
668    }
669
670    // Don't forget the last group
671    result.push(KeyValue {
672        key: current_key,
673        value: current_value,
674    });
675
676    result
677}
678
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Checks `actual` against `expected` slot by slot with the tolerance used
    /// by the scan tests.
    fn assert_slice_close(actual: &[f32], expected: &[f32]) {
        for (idx, want) in expected.iter().enumerate() {
            assert_relative_eq!(actual[idx], *want, epsilon = 1e-6);
        }
    }

    /// Shorthand for building the key/value pairs used below.
    fn kv(key: i32, value: f32) -> KeyValue<i32, f32> {
        KeyValue { key, value }
    }

    #[test]
    fn test_parallel_sum() {
        let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let reference: f32 = values.iter().sum();
        assert_relative_eq!(parallel_sum_f32_simd(&values), reference, epsilon = 1e-6);
    }

    #[test]
    fn test_parallel_product() {
        let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let reference: f32 = values.iter().product();
        assert_relative_eq!(
            parallel_product_f32_simd(&values),
            reference,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_parallel_max() {
        let values = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        assert_eq!(parallel_max_f32_simd(&values), Some(9.0));
    }

    #[test]
    fn test_parallel_min() {
        let values = vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        assert_eq!(parallel_min_f32_simd(&values), Some(1.0));
    }

    #[test]
    fn test_prefix_sum() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut scanned = vec![0.0; input.len()];
        prefix_sum_f32_simd(&input, &mut scanned);

        assert_slice_close(&scanned, &[1.0, 3.0, 6.0, 10.0, 15.0]);
    }

    #[test]
    fn test_exclusive_scan() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut scanned = vec![0.0; input.len()];
        exclusive_scan_f32_simd(&input, &mut scanned);

        assert_slice_close(&scanned, &[0.0, 1.0, 3.0, 6.0, 10.0]);
    }

    #[test]
    fn test_segmented_sum() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let flags = vec![true, false, false, true, false, true];
        let mut sums = vec![0.0; input.len()];

        segmented_sum_f32_simd(&input, &flags, &mut sums);

        // Segments are [1, 2, 3], [4, 5] and [6]; each restarts its running sum.
        assert_slice_close(&sums, &[1.0, 3.0, 6.0, 4.0, 9.0, 6.0]);
    }

    #[test]
    fn test_conditional_sum() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let condition = vec![true, false, true, false, true];

        // Only 1 + 3 + 5 are selected.
        let total = conditional_sum_f32_simd(&input, &condition);
        assert_relative_eq!(total, 9.0, epsilon = 1e-6);
    }

    #[test]
    fn test_reduce_by_key() {
        let input = vec![
            kv(1, 10.0),
            kv(1, 20.0),
            kv(2, 30.0),
            kv(2, 40.0),
            kv(3, 50.0),
        ];

        let reduced = reduce_by_key_f32_simd(&input, |a, b| a + b);

        assert_eq!(reduced.len(), 3);
        assert_eq!(reduced[0], kv(1, 30.0));
        assert_eq!(reduced[1], kv(2, 70.0));
        assert_eq!(reduced[2], kv(3, 50.0));
    }

    #[test]
    fn test_empty_arrays() {
        let nothing: Vec<f32> = Vec::new();
        assert_eq!(parallel_sum_f32_simd(&nothing), 0.0);
        assert_eq!(parallel_product_f32_simd(&nothing), 1.0);
        assert_eq!(parallel_max_f32_simd(&nothing), None);
        assert_eq!(parallel_min_f32_simd(&nothing), None);
    }

    #[test]
    fn test_single_element() {
        let lone = vec![42.0];
        assert_eq!(parallel_sum_f32_simd(&lone), 42.0);
        assert_eq!(parallel_product_f32_simd(&lone), 42.0);
        assert_eq!(parallel_max_f32_simd(&lone), Some(42.0));
        assert_eq!(parallel_min_f32_simd(&lone), Some(42.0));
    }
}