Skip to main content

trueno/backends/sse2/ops/
reductions.rs

1//! SSE2 reduction operations (dot, sum, max, min, argmax, argmin).
2
3#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
6use crate::backends::VectorBackend;
7
8/// SSE2 dot product.
9#[inline]
10#[target_feature(enable = "sse2")]
11// SAFETY: caller ensures preconditions are met for this unsafe function
12pub(crate) unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
13    unsafe {
14        let len = a.len();
15        let mut i = 0;
16        let mut acc = _mm_setzero_ps();
17
18        while i + 4 <= len {
19            let va = _mm_loadu_ps(a.as_ptr().add(i));
20            let vb = _mm_loadu_ps(b.as_ptr().add(i));
21            acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
22            i += 4;
23        }
24
25        let mut result = horizontal_sum_ps(acc);
26        result += a[i..].iter().zip(&b[i..]).map(|(x, y)| x * y).sum::<f32>();
27        result
28    }
29}
30
31/// SSE2 vector sum.
32#[inline]
33#[target_feature(enable = "sse2")]
34// SAFETY: caller ensures preconditions are met for this unsafe function
35pub(crate) unsafe fn sum(a: &[f32]) -> f32 {
36    unsafe {
37        let len = a.len();
38        let mut i = 0;
39        let mut acc = _mm_setzero_ps();
40
41        while i + 4 <= len {
42            acc = _mm_add_ps(acc, _mm_loadu_ps(a.as_ptr().add(i)));
43            i += 4;
44        }
45
46        let mut result = horizontal_sum_ps(acc);
47        result += a[i..].iter().sum::<f32>();
48        result
49    }
50}
51
52/// SSE2 vector max.
53#[inline]
54#[target_feature(enable = "sse2")]
55// SAFETY: caller ensures preconditions are met for this unsafe function
56pub(crate) unsafe fn max(a: &[f32]) -> f32 {
57    unsafe {
58        let len = a.len();
59        let mut i = 0;
60        let mut vmax = _mm_set1_ps(a[0]);
61
62        while i + 4 <= len {
63            vmax = _mm_max_ps(vmax, _mm_loadu_ps(a.as_ptr().add(i)));
64            i += 4;
65        }
66
67        let mut result = horizontal_max_ps(vmax);
68        for &val in &a[i..] {
69            if val > result {
70                result = val;
71            }
72        }
73        result
74    }
75}
76
77/// SSE2 vector min.
78#[inline]
79#[target_feature(enable = "sse2")]
80// SAFETY: caller ensures preconditions are met for this unsafe function
81pub(crate) unsafe fn min(a: &[f32]) -> f32 {
82    unsafe {
83        let len = a.len();
84        let mut i = 0;
85        let mut vmin = _mm_set1_ps(a[0]);
86
87        while i + 4 <= len {
88            vmin = _mm_min_ps(vmin, _mm_loadu_ps(a.as_ptr().add(i)));
89            i += 4;
90        }
91
92        let mut result = horizontal_min_ps(vmin);
93        for &val in &a[i..] {
94            if val < result {
95                result = val;
96            }
97        }
98        result
99    }
100}
101
/// SSE2 argmax: index of the largest element of `a`.
///
/// The body is a plain scalar scan; it uses no SSE2 intrinsics. Elements are
/// compared with a strict `>`, so when several elements share the largest
/// value the lowest index is returned.
///
/// # Panics
///
/// Panics if `a` is empty, because the first element seeds the scan.
///
/// # Safety
///
/// The CPU executing this function must support SSE2.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn argmax(a: &[f32]) -> usize {
    let first = a[0];

    // Index 0 is the seed, so the scan starts at index 1.
    let (best_idx, _) = a.iter().enumerate().skip(1).fold(
        (0usize, first),
        |(best_idx, best_val), (idx, &val)| {
            if val > best_val {
                (idx, val)
            } else {
                (best_idx, best_val)
            }
        },
    );
    best_idx
}
119
/// SSE2 argmin: index of the smallest element of `a`.
///
/// The body is a plain scalar scan; it uses no SSE2 intrinsics. Elements are
/// compared with a strict `<`, so when several elements share the smallest
/// value the lowest index is returned.
///
/// # Panics
///
/// Panics if `a` is empty, because the first element seeds the scan.
///
/// # Safety
///
/// The CPU executing this function must support SSE2.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn argmin(a: &[f32]) -> usize {
    let first = a[0];

    // Index 0 is the seed, so the scan starts at index 1.
    let (best_idx, _) = a.iter().enumerate().skip(1).fold(
        (0usize, first),
        |(best_idx, best_val), (idx, &val)| {
            if val < best_val {
                (idx, val)
            } else {
                (best_idx, best_val)
            }
        },
    );
    best_idx
}
137
138/// Kahan sum (delegates to scalar).
139#[inline]
140// SAFETY: caller ensures preconditions are met for this unsafe function
141pub(crate) unsafe fn sum_kahan(a: &[f32]) -> f32 {
142    unsafe { crate::backends::scalar::ScalarBackend::sum_kahan(a) }
143}
144
/// Helper: adds the four `f32` lanes of `v` and returns the total.
///
/// The lanes are combined pairwise as `(v0 + v2) + (v1 + v3)`.
///
/// # Safety
///
/// The CPU executing this function must support SSE2.
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn horizontal_sum_ps(v: __m128) -> f32 {
    // Fold the upper half onto the lower half: lane 0 = v0 + v2, lane 1 = v1 + v3.
    let upper_half = _mm_movehl_ps(v, v);
    let pair_sums = _mm_add_ps(v, upper_half);
    // Bring lane 1 down to lane 0 so the two partial sums can be combined.
    let odd_pair = _mm_shuffle_ps(pair_sums, pair_sums, 1);
    _mm_cvtss_f32(_mm_add_ss(pair_sums, odd_pair))
}
154
/// Helper: returns the largest of the four `f32` lanes of `v`.
///
/// The lanes are reduced pairwise: first `v0` with `v2` and `v1` with `v3`,
/// then the two pair results with each other.
///
/// # Safety
///
/// The CPU executing this function must support SSE2.
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn horizontal_max_ps(v: __m128) -> f32 {
    // Fold the upper half onto the lower half: lane 0 = max(v0, v2), lane 1 = max(v1, v3).
    let upper_half = _mm_movehl_ps(v, v);
    let pair_max = _mm_max_ps(v, upper_half);
    // Bring lane 1 down to lane 0 so the two pair results can be compared.
    let odd_pair = _mm_shuffle_ps(pair_max, pair_max, 1);
    _mm_cvtss_f32(_mm_max_ss(pair_max, odd_pair))
}
164
/// Helper: returns the smallest of the four `f32` lanes of `v`.
///
/// The lanes are reduced pairwise: first `v0` with `v2` and `v1` with `v3`,
/// then the two pair results with each other.
///
/// # Safety
///
/// The CPU executing this function must support SSE2.
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn horizontal_min_ps(v: __m128) -> f32 {
    // Fold the upper half onto the lower half: lane 0 = min(v0, v2), lane 1 = min(v1, v3).
    let upper_half = _mm_movehl_ps(v, v);
    let pair_min = _mm_min_ps(v, upper_half);
    // Bring lane 1 down to lane 0 so the two pair results can be compared.
    let odd_pair = _mm_shuffle_ps(pair_min, pair_min, 1);
    _mm_cvtss_f32(_mm_min_ss(pair_min, odd_pair))
}