Skip to main content

trueno/backends/avx512/ops/
reductions.rs

//! AVX-512 reduction operations (dot, sum, max, min, argmax, argmin) plus a scalar Kahan sum.
2
3#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
6/// AVX-512 dot product.
7#[inline]
8#[target_feature(enable = "avx512f")]
9// SAFETY: caller ensures preconditions are met for this unsafe function
10pub(crate) unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
11    unsafe {
12        let len = a.len();
13        let mut i = 0;
14        let mut acc = _mm512_setzero_ps();
15
16        while i + 16 <= len {
17            let va = _mm512_loadu_ps(a.as_ptr().add(i));
18            let vb = _mm512_loadu_ps(b.as_ptr().add(i));
19            acc = _mm512_fmadd_ps(va, vb, acc);
20            i += 16;
21        }
22
23        let mut result = _mm512_reduce_add_ps(acc);
24        result += a[i..].iter().zip(&b[i..]).map(|(x, y)| x * y).sum::<f32>();
25        result
26    }
27}
28
29/// AVX-512 vector sum.
30#[inline]
31#[target_feature(enable = "avx512f")]
32// SAFETY: caller ensures preconditions are met for this unsafe function
33pub(crate) unsafe fn sum(a: &[f32]) -> f32 {
34    unsafe {
35        let len = a.len();
36        let mut i = 0;
37        let mut acc = _mm512_setzero_ps();
38
39        while i + 16 <= len {
40            acc = _mm512_add_ps(acc, _mm512_loadu_ps(a.as_ptr().add(i)));
41            i += 16;
42        }
43
44        let mut result = _mm512_reduce_add_ps(acc);
45        result += a[i..].iter().sum::<f32>();
46        result
47    }
48}
49
50/// AVX-512 vector max.
51#[inline]
52#[target_feature(enable = "avx512f")]
53// SAFETY: caller ensures preconditions are met for this unsafe function
54pub(crate) unsafe fn max(a: &[f32]) -> f32 {
55    unsafe {
56        let len = a.len();
57        let mut i = 0;
58        let mut vmax = _mm512_set1_ps(a[0]);
59
60        while i + 16 <= len {
61            vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(a.as_ptr().add(i)));
62            i += 16;
63        }
64
65        let mut result = _mm512_reduce_max_ps(vmax);
66        for &val in &a[i..] {
67            if val > result {
68                result = val;
69            }
70        }
71        result
72    }
73}
74
75/// AVX-512 vector min.
76#[inline]
77#[target_feature(enable = "avx512f")]
78// SAFETY: caller ensures preconditions are met for this unsafe function
79pub(crate) unsafe fn min(a: &[f32]) -> f32 {
80    unsafe {
81        let len = a.len();
82        let mut i = 0;
83        let mut vmin = _mm512_set1_ps(a[0]);
84
85        while i + 16 <= len {
86            vmin = _mm512_min_ps(vmin, _mm512_loadu_ps(a.as_ptr().add(i)));
87            i += 16;
88        }
89
90        let mut result = _mm512_reduce_min_ps(vmin);
91        for &val in &a[i..] {
92            if val < result {
93                result = val;
94            }
95        }
96        result
97    }
98}
99
100/// AVX-512 argmax.
101#[inline]
102#[target_feature(enable = "avx512f")]
103// SAFETY: caller ensures preconditions are met for this unsafe function
104pub(crate) unsafe fn argmax(a: &[f32]) -> usize {
105    let mut max_idx: usize = 0;
106    let mut max_val = a[0];
107    for (i, &val) in a.iter().enumerate() {
108        if val > max_val {
109            max_val = val;
110            max_idx = i;
111        }
112    }
113    max_idx
114}
115
116/// AVX-512 argmin.
117#[inline]
118#[target_feature(enable = "avx512f")]
119// SAFETY: caller ensures preconditions are met for this unsafe function
120pub(crate) unsafe fn argmin(a: &[f32]) -> usize {
121    let mut min_idx: usize = 0;
122    let mut min_val = a[0];
123    for (i, &val) in a.iter().enumerate() {
124        if val < min_val {
125            min_val = val;
126            min_idx = i;
127        }
128    }
129    min_idx
130}
131
/// Kahan (compensated) sum of `a`, computed with scalar code.
///
/// Alongside the running total, a compensation term tracks the low-order bits
/// lost by each addition and feeds them back into the next one. Empty input
/// yields `0.0`.
///
/// # Safety
///
/// The body performs no unsafe operations and has no target-feature
/// requirement; the `unsafe` marker is presumably kept for signature parity
/// with the other backend operations — TODO confirm against callers.
#[inline]
pub(crate) unsafe fn sum_kahan(a: &[f32]) -> f32 {
    let (total, _compensation) =
        a.iter()
            .fold((0.0_f32, 0.0_f32), |(running, compensation), &x| {
                // Re-inject the error captured from the previous step.
                let adjusted = x - compensation;
                let next = running + adjusted;
                // (next - running) is what was actually added; subtracting
                // `adjusted` isolates the rounding error of this step.
                (next, (next - running) - adjusted)
            });
    total
}