Skip to main content

trueno/backends/sse2/ops/
elementwise.rs

1//! SSE2 elementwise operations (scale, abs, clamp, lerp, fma, relu, sqrt, recip, norms).
2
3#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
/// SSE2 L1 norm: the sum of the absolute values of `a`.
///
/// Four lanes are accumulated per step. After a horizontal reduction, the
/// trailing `a.len() % 4` elements are added one at a time, left to right.
/// Returns `0.0` for an empty slice.
///
/// # Safety
///
/// The CPU must support SSE2 (always the case on `x86_64`).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn norm_l1(a: &[f32]) -> f32 {
    // SAFETY: each `_mm_loadu_ps` reads 4 floats at `base + off` with
    // `off + 4 <= a.len()`, so every read stays inside `a`.
    unsafe {
        let n = a.len();
        if n == 0 {
            return 0.0;
        }
        // Clearing the IEEE-754 sign bit yields |x|.
        let abs_mask = _mm_set1_ps(f32::from_bits(0x7FFF_FFFF));
        let base = a.as_ptr();
        let mut sum = _mm_setzero_ps();
        let mut off = 0;
        while off + 4 <= n {
            let lanes = _mm_and_ps(_mm_loadu_ps(base.add(off)), abs_mask);
            sum = _mm_add_ps(sum, lanes);
            off += 4;
        }
        // Horizontal add: high pair onto low pair, then lane 1 onto lane 0.
        let pairs = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
        let total = _mm_cvtss_f32(_mm_add_ss(pairs, _mm_shuffle_ps(pairs, pairs, 1)));
        // Scalar tail, accumulated in order onto the vector total.
        a[off..].iter().fold(total, |acc, &v| acc + v.abs())
    }
}
34
/// SSE2 L-infinity norm: the largest absolute value in `a`.
///
/// Returns `0.0` for an empty slice. The vector part reduces with
/// `_mm_max_ps`, and the tail uses a strict `>` comparison. NaN inputs are
/// therefore not guaranteed to propagate to the result.
///
/// # Safety
///
/// The CPU must support SSE2 (always the case on `x86_64`).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn norm_linf(a: &[f32]) -> f32 {
    // SAFETY: `chunks_exact(4)` only yields 4-element chunks, so each
    // `_mm_loadu_ps(chunk.as_ptr())` reads entirely within `a`.
    unsafe {
        if a.is_empty() {
            return 0.0;
        }
        // Clearing the IEEE-754 sign bit yields |x|.
        let abs_mask = _mm_set1_ps(f32::from_bits(0x7FFF_FFFF));
        let mut peak = _mm_setzero_ps();
        let mut quads = a.chunks_exact(4);
        for quad in &mut quads {
            let lanes = _mm_and_ps(_mm_loadu_ps(quad.as_ptr()), abs_mask);
            peak = _mm_max_ps(peak, lanes);
        }
        // Horizontal max: high pair onto low pair, then lane 1 onto lane 0.
        let folded = _mm_max_ps(peak, _mm_movehl_ps(peak, peak));
        let best = _mm_cvtss_f32(_mm_max_ss(folded, _mm_shuffle_ps(folded, folded, 1)));
        quads.remainder().iter().fold(best, |cur, &v| {
            let mag = v.abs();
            if mag > cur { mag } else { cur }
        })
    }
}
67
/// SSE2 scalar multiply: `result[k] = a[k] * scalar` for every `k < a.len()`.
///
/// Elements of `result` past `a.len()` are left untouched.
///
/// # Safety
///
/// - The CPU must support SSE2 (always the case on `x86_64`).
/// - `result.len() >= a.len()`. The vector loop stores through a raw pointer
///   without bounds checks, so a shorter `result` is an out-of-bounds write
///   (undefined behavior). This is checked with `debug_assert!` in debug
///   builds only.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]) {
    let len = a.len();
    debug_assert!(result.len() >= len, "scale: result is shorter than a");
    // SAFETY: the vector loop touches indices `i..i + 4` only while
    // `i + 4 <= len == a.len()`. The caller guarantees `result.len() >= len`,
    // so every unaligned load from `a` and store to `result` is in bounds.
    unsafe {
        let src = a.as_ptr();
        let dst = result.as_mut_ptr();
        let scalar_vec = _mm_set1_ps(scalar);
        let mut i = 0;
        while i + 4 <= len {
            _mm_storeu_ps(dst.add(i), _mm_mul_ps(_mm_loadu_ps(src.add(i)), scalar_vec));
            i += 4;
        }
        // Scalar tail for the remaining `len % 4` elements.
        for (out, &x) in result[i..len].iter_mut().zip(&a[i..]) {
            *out = x * scalar;
        }
    }
}
89
/// SSE2 absolute value: `result[k] = |a[k]|` for every `k < a.len()`.
///
/// The vector path clears the IEEE-754 sign bit, as does `f32::abs` on the
/// scalar tail. As a result, `-0.0` becomes `+0.0`.
///
/// # Safety
///
/// - The CPU must support SSE2 (always the case on `x86_64`).
/// - `result.len() >= a.len()`. The vector loop stores through a raw pointer
///   without bounds checks, so a shorter `result` is an out-of-bounds write
///   (undefined behavior). This is checked with `debug_assert!` in debug
///   builds only.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn abs(a: &[f32], result: &mut [f32]) {
    let len = a.len();
    debug_assert!(result.len() >= len, "abs: result is shorter than a");
    // SAFETY: the vector loop touches indices `i..i + 4` only while
    // `i + 4 <= len == a.len()`. The caller guarantees `result.len() >= len`,
    // so every unaligned load from `a` and store to `result` is in bounds.
    unsafe {
        let src = a.as_ptr();
        let dst = result.as_mut_ptr();
        let sign_mask = _mm_set1_ps(f32::from_bits(0x7FFF_FFFF));
        let mut i = 0;
        while i + 4 <= len {
            _mm_storeu_ps(dst.add(i), _mm_and_ps(_mm_loadu_ps(src.add(i)), sign_mask));
            i += 4;
        }
        // Scalar tail for the remaining `len % 4` elements.
        for (out, &x) in result[i..len].iter_mut().zip(&a[i..]) {
            *out = x.abs();
        }
    }
}
111
/// SSE2 clamp: `result[k] = a[k].max(min_val).min(max_val)` for every `k < a.len()`.
///
/// Unlike `f32::clamp`, this never panics, even when `min_val > max_val`.
/// A NaN element compares as "not greater" in both `_mm_max_ps` and
/// `f32::max`, so it becomes `min_val` (then limited by `max_val`).
///
/// # Safety
///
/// - The CPU must support SSE2 (always the case on `x86_64`).
/// - `result.len() >= a.len()`. The vector loop stores through a raw pointer
///   without bounds checks, so a shorter `result` is an out-of-bounds write
///   (undefined behavior). This is checked with `debug_assert!` in debug
///   builds only.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]) {
    let len = a.len();
    debug_assert!(result.len() >= len, "clamp: result is shorter than a");
    // SAFETY: the vector loop touches indices `i..i + 4` only while
    // `i + 4 <= len == a.len()`. The caller guarantees `result.len() >= len`,
    // so every unaligned load from `a` and store to `result` is in bounds.
    unsafe {
        let src = a.as_ptr();
        let dst = result.as_mut_ptr();
        let min_vec = _mm_set1_ps(min_val);
        let max_vec = _mm_set1_ps(max_val);
        let mut i = 0;
        while i + 4 <= len {
            let va = _mm_loadu_ps(src.add(i));
            // Operand order matters for NaN: `va` first, so a NaN lane yields `min_vec`.
            _mm_storeu_ps(dst.add(i), _mm_min_ps(_mm_max_ps(va, min_vec), max_vec));
            i += 4;
        }
        // Scalar tail for the remaining `len % 4` elements.
        for (out, &x) in result[i..len].iter_mut().zip(&a[i..]) {
            *out = x.max(min_val).min(max_val);
        }
    }
}
132
/// SSE2 linear interpolation: `result[k] = a[k] + t * (b[k] - a[k])` for every `k < a.len()`.
///
/// `t` is not restricted to `[0, 1]`, so values outside that range extrapolate.
///
/// # Safety
///
/// - The CPU must support SSE2 (always the case on `x86_64`).
/// - `b.len() >= a.len()` and `result.len() >= a.len()`. The vector loop
///   loads from `b` and stores to `result` through raw pointers without
///   bounds checks, so a shorter slice means out-of-bounds access (undefined
///   behavior). Both are checked with `debug_assert!` in debug builds only.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]) {
    let len = a.len();
    debug_assert!(b.len() >= len, "lerp: b is shorter than a");
    debug_assert!(result.len() >= len, "lerp: result is shorter than a");
    // SAFETY: the vector loop touches indices `i..i + 4` only while
    // `i + 4 <= len == a.len()`. The caller guarantees `b` and `result` are
    // at least `len` long, so every unaligned load and store is in bounds.
    unsafe {
        let (pa, pb, dst) = (a.as_ptr(), b.as_ptr(), result.as_mut_ptr());
        let t_vec = _mm_set1_ps(t);
        let mut i = 0;
        while i + 4 <= len {
            let va = _mm_loadu_ps(pa.add(i));
            let vb = _mm_loadu_ps(pb.add(i));
            _mm_storeu_ps(dst.add(i), _mm_add_ps(va, _mm_mul_ps(t_vec, _mm_sub_ps(vb, va))));
            i += 4;
        }
        // Scalar tail for the remaining `len % 4` elements.
        for ((out, &x), &y) in result[i..len].iter_mut().zip(&a[i..]).zip(&b[i..len]) {
            *out = x + t * (y - x);
        }
    }
}
156
/// SSE2 multiply-add: `result[k] = a[k] * b[k] + c[k]` for every `k < a.len()`.
///
/// SSE2 has no fused multiply-add instruction, so the product is rounded
/// before the addition (two roundings). Results may therefore differ in the
/// last bit from the single-rounding `f32::mul_add`. The scalar tail uses the
/// same unfused formula, so all elements are computed consistently.
///
/// # Safety
///
/// - The CPU must support SSE2 (always the case on `x86_64`).
/// - `b.len() >= a.len()`, `c.len() >= a.len()` and `result.len() >= a.len()`.
///   The vector loop loads and stores through raw pointers without bounds
///   checks, so a shorter slice means out-of-bounds access (undefined
///   behavior). All three are checked with `debug_assert!` in debug builds only.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) {
    let len = a.len();
    debug_assert!(b.len() >= len, "fma: b is shorter than a");
    debug_assert!(c.len() >= len, "fma: c is shorter than a");
    debug_assert!(result.len() >= len, "fma: result is shorter than a");
    // SAFETY: the vector loop touches indices `i..i + 4` only while
    // `i + 4 <= len == a.len()`. The caller guarantees `b`, `c` and `result`
    // are at least `len` long, so every unaligned load and store is in bounds.
    unsafe {
        let (pa, pb, pc) = (a.as_ptr(), b.as_ptr(), c.as_ptr());
        let dst = result.as_mut_ptr();
        let mut i = 0;
        while i + 4 <= len {
            let va = _mm_loadu_ps(pa.add(i));
            let vb = _mm_loadu_ps(pb.add(i));
            let vc = _mm_loadu_ps(pc.add(i));
            _mm_storeu_ps(dst.add(i), _mm_add_ps(_mm_mul_ps(va, vb), vc));
            i += 4;
        }
        // Scalar tail for the remaining `len % 4` elements.
        let tail = result[i..len]
            .iter_mut()
            .zip(&a[i..])
            .zip(&b[i..len])
            .zip(&c[i..len]);
        for (((out, &x), &y), &z) in tail {
            *out = x * y + z;
        }
    }
}
177
/// SSE2 ReLU activation: `result[k] = max(a[k], 0.0)` for every `k < a.len()`.
///
/// NaN inputs produce `0.0`. `_mm_max_ps` returns its second operand (zero)
/// when a lane is NaN, and `f32::max` returns the non-NaN argument.
///
/// # Safety
///
/// - The CPU must support SSE2 (always the case on `x86_64`).
/// - `result.len() >= a.len()`. The vector loop stores through a raw pointer
///   without bounds checks, so a shorter `result` is an out-of-bounds write
///   (undefined behavior). This is checked with `debug_assert!` in debug
///   builds only.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn relu(a: &[f32], result: &mut [f32]) {
    let len = a.len();
    debug_assert!(result.len() >= len, "relu: result is shorter than a");
    // SAFETY: the vector loop touches indices `i..i + 4` only while
    // `i + 4 <= len == a.len()`. The caller guarantees `result.len() >= len`,
    // so every unaligned load from `a` and store to `result` is in bounds.
    unsafe {
        let src = a.as_ptr();
        let dst = result.as_mut_ptr();
        let zero = _mm_setzero_ps();
        let mut i = 0;
        while i + 4 <= len {
            _mm_storeu_ps(dst.add(i), _mm_max_ps(_mm_loadu_ps(src.add(i)), zero));
            i += 4;
        }
        // Scalar tail for the remaining `len % 4` elements.
        for (out, &x) in result[i..len].iter_mut().zip(&a[i..]) {
            *out = x.max(0.0);
        }
    }
}
199
/// SSE2 square root: `result[k] = sqrt(a[k])` for every `k < a.len()`.
///
/// `_mm_sqrt_ps` and `f32::sqrt` are both correctly rounded IEEE-754
/// operations, so vector and tail lanes agree. Negative non-zero inputs
/// yield NaN.
///
/// # Safety
///
/// - The CPU must support SSE2 (always the case on `x86_64`).
/// - `result.len() >= a.len()`. The vector loop stores through a raw pointer
///   without bounds checks, so a shorter `result` is an out-of-bounds write
///   (undefined behavior). This is checked with `debug_assert!` in debug
///   builds only.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn sqrt(a: &[f32], result: &mut [f32]) {
    let len = a.len();
    debug_assert!(result.len() >= len, "sqrt: result is shorter than a");
    // SAFETY: the vector loop touches indices `i..i + 4` only while
    // `i + 4 <= len == a.len()`. The caller guarantees `result.len() >= len`,
    // so every unaligned load from `a` and store to `result` is in bounds.
    unsafe {
        let src = a.as_ptr();
        let dst = result.as_mut_ptr();
        let mut i = 0;
        while i + 4 <= len {
            _mm_storeu_ps(dst.add(i), _mm_sqrt_ps(_mm_loadu_ps(src.add(i))));
            i += 4;
        }
        // Scalar tail for the remaining `len % 4` elements.
        for (out, &x) in result[i..len].iter_mut().zip(&a[i..]) {
            *out = x.sqrt();
        }
    }
}
217
/// SSE2 reciprocal: `result[k] = 1.0 / a[k]` for every `k < a.len()`.
///
/// The vector path uses a true division (`_mm_div_ps`), not the approximate
/// `_mm_rcp_ps`, so vector and tail lanes agree exactly. `±0.0` maps to
/// `±inf`.
///
/// # Safety
///
/// - The CPU must support SSE2 (always the case on `x86_64`).
/// - `result.len() >= a.len()`. The vector loop stores through a raw pointer
///   without bounds checks, so a shorter `result` is an out-of-bounds write
///   (undefined behavior). This is checked with `debug_assert!` in debug
///   builds only.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn recip(a: &[f32], result: &mut [f32]) {
    let len = a.len();
    debug_assert!(result.len() >= len, "recip: result is shorter than a");
    // SAFETY: the vector loop touches indices `i..i + 4` only while
    // `i + 4 <= len == a.len()`. The caller guarantees `result.len() >= len`,
    // so every unaligned load from `a` and store to `result` is in bounds.
    unsafe {
        let src = a.as_ptr();
        let dst = result.as_mut_ptr();
        let one = _mm_set1_ps(1.0);
        let mut i = 0;
        while i + 4 <= len {
            _mm_storeu_ps(dst.add(i), _mm_div_ps(one, _mm_loadu_ps(src.add(i))));
            i += 4;
        }
        // Scalar tail for the remaining `len % 4` elements.
        for (out, &x) in result[i..len].iter_mut().zip(&a[i..]) {
            *out = x.recip();
        }
    }
}