trueno/backends/avx512/ops/reductions.rs
//! AVX-512 reduction operations: vectorized `dot`, `sum`, `max`, and `min`,
//! plus scalar `argmax`, `argmin`, and a compensated (Kahan) `sum_kahan`.
2
3#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
6/// AVX-512 dot product.
7#[inline]
8#[target_feature(enable = "avx512f")]
9// SAFETY: caller ensures preconditions are met for this unsafe function
10pub(crate) unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
11 unsafe {
12 let len = a.len();
13 let mut i = 0;
14 let mut acc = _mm512_setzero_ps();
15
16 while i + 16 <= len {
17 let va = _mm512_loadu_ps(a.as_ptr().add(i));
18 let vb = _mm512_loadu_ps(b.as_ptr().add(i));
19 acc = _mm512_fmadd_ps(va, vb, acc);
20 i += 16;
21 }
22
23 let mut result = _mm512_reduce_add_ps(acc);
24 result += a[i..].iter().zip(&b[i..]).map(|(x, y)| x * y).sum::<f32>();
25 result
26 }
27}
28
29/// AVX-512 vector sum.
30#[inline]
31#[target_feature(enable = "avx512f")]
32// SAFETY: caller ensures preconditions are met for this unsafe function
33pub(crate) unsafe fn sum(a: &[f32]) -> f32 {
34 unsafe {
35 let len = a.len();
36 let mut i = 0;
37 let mut acc = _mm512_setzero_ps();
38
39 while i + 16 <= len {
40 acc = _mm512_add_ps(acc, _mm512_loadu_ps(a.as_ptr().add(i)));
41 i += 16;
42 }
43
44 let mut result = _mm512_reduce_add_ps(acc);
45 result += a[i..].iter().sum::<f32>();
46 result
47 }
48}
49
50/// AVX-512 vector max.
51#[inline]
52#[target_feature(enable = "avx512f")]
53// SAFETY: caller ensures preconditions are met for this unsafe function
54pub(crate) unsafe fn max(a: &[f32]) -> f32 {
55 unsafe {
56 let len = a.len();
57 let mut i = 0;
58 let mut vmax = _mm512_set1_ps(a[0]);
59
60 while i + 16 <= len {
61 vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(a.as_ptr().add(i)));
62 i += 16;
63 }
64
65 let mut result = _mm512_reduce_max_ps(vmax);
66 for &val in &a[i..] {
67 if val > result {
68 result = val;
69 }
70 }
71 result
72 }
73}
74
75/// AVX-512 vector min.
76#[inline]
77#[target_feature(enable = "avx512f")]
78// SAFETY: caller ensures preconditions are met for this unsafe function
79pub(crate) unsafe fn min(a: &[f32]) -> f32 {
80 unsafe {
81 let len = a.len();
82 let mut i = 0;
83 let mut vmin = _mm512_set1_ps(a[0]);
84
85 while i + 16 <= len {
86 vmin = _mm512_min_ps(vmin, _mm512_loadu_ps(a.as_ptr().add(i)));
87 i += 16;
88 }
89
90 let mut result = _mm512_reduce_min_ps(vmin);
91 for &val in &a[i..] {
92 if val < result {
93 result = val;
94 }
95 }
96 result
97 }
98}
99
100/// AVX-512 argmax.
101#[inline]
102#[target_feature(enable = "avx512f")]
103// SAFETY: caller ensures preconditions are met for this unsafe function
104pub(crate) unsafe fn argmax(a: &[f32]) -> usize {
105 let mut max_idx: usize = 0;
106 let mut max_val = a[0];
107 for (i, &val) in a.iter().enumerate() {
108 if val > max_val {
109 max_val = val;
110 max_idx = i;
111 }
112 }
113 max_idx
114}
115
116/// AVX-512 argmin.
117#[inline]
118#[target_feature(enable = "avx512f")]
119// SAFETY: caller ensures preconditions are met for this unsafe function
120pub(crate) unsafe fn argmin(a: &[f32]) -> usize {
121 let mut min_idx: usize = 0;
122 let mut min_val = a[0];
123 for (i, &val) in a.iter().enumerate() {
124 if val < min_val {
125 min_val = val;
126 min_idx = i;
127 }
128 }
129 min_idx
130}
131
/// Compensated (Kahan) summation of `a`, written as plain scalar code.
///
/// A running compensation term records the low-order bits lost when each
/// element is added to the total, and subtracts them from the next element
/// before it is added. An empty slice yields `0.0`.
///
/// # Safety
///
/// The body performs no unsafe operations and carries no `target_feature`
/// attribute. NOTE(review): the `unsafe` qualifier appears to be kept only
/// for signature parity with the other kernels — confirm before relying on it.
#[inline]
pub(crate) unsafe fn sum_kahan(a: &[f32]) -> f32 {
    let (total, _compensation) =
        a.iter()
            .fold((0.0f32, 0.0f32), |(running, comp), &x| {
                let adjusted = x - comp;
                let next = running + adjusted;
                (next, (next - running) - adjusted)
            });
    total
}