Skip to main content

trueno/backends/sse2/
mod.rs

//! SSE2 backend implementation (x86_64 baseline SIMD)
//!
//! This backend uses SSE2 intrinsics for 128-bit SIMD operations.
//! SSE2 is available on all x86_64 CPUs as a baseline requirement.
//!
//! # Performance
//!
//! Expected speedup: up to 4x for operations on aligned f32 vectors (4 elements per register)
//!
//! # Safety
//!
//! All SSE2 intrinsics are marked `unsafe` by Rust. This module carefully isolates
//! all unsafe code and verifies correctness through comprehensive testing.
14
15mod ops;
16
17use super::VectorBackend;
18
/// SSE2 backend (128-bit SIMD for x86_64).
///
/// Zero-sized marker type: it carries no data and exists only so the backend
/// trait can be implemented for it. The common marker traits are derived so the
/// type can be printed, copied and default-constructed like any other unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct Sse2Backend;
21
22impl VectorBackend for Sse2Backend {
23    #[inline]
24    #[target_feature(enable = "sse2")]
25    // SAFETY: caller ensures preconditions are met for this unsafe function
26    unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
27        unsafe {
28            ops::arithmetic::add(a, b, result);
29        }
30    }
31
32    #[inline]
33    #[target_feature(enable = "sse2")]
34    // SAFETY: caller ensures preconditions are met for this unsafe function
35    unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
36        unsafe {
37            ops::arithmetic::sub(a, b, result);
38        }
39    }
40
41    #[inline]
42    #[target_feature(enable = "sse2")]
43    // SAFETY: caller ensures preconditions are met for this unsafe function
44    unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
45        unsafe {
46            ops::arithmetic::mul(a, b, result);
47        }
48    }
49
50    #[inline]
51    #[target_feature(enable = "sse2")]
52    // SAFETY: caller ensures preconditions are met for this unsafe function
53    unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
54        unsafe {
55            ops::arithmetic::div(a, b, result);
56        }
57    }
58
59    #[inline]
60    #[target_feature(enable = "sse2")]
61    // SAFETY: caller ensures preconditions are met for this unsafe function
62    unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
63        unsafe { ops::reductions::dot(a, b) }
64    }
65
66    #[inline]
67    #[target_feature(enable = "sse2")]
68    // SAFETY: caller ensures preconditions are met for this unsafe function
69    unsafe fn sum(a: &[f32]) -> f32 {
70        unsafe { ops::reductions::sum(a) }
71    }
72
73    #[inline]
74    #[target_feature(enable = "sse2")]
75    // SAFETY: caller ensures preconditions are met for this unsafe function
76    unsafe fn max(a: &[f32]) -> f32 {
77        unsafe { ops::reductions::max(a) }
78    }
79
80    #[inline]
81    #[target_feature(enable = "sse2")]
82    // SAFETY: caller ensures preconditions are met for this unsafe function
83    unsafe fn min(a: &[f32]) -> f32 {
84        unsafe { ops::reductions::min(a) }
85    }
86
87    #[inline]
88    #[target_feature(enable = "sse2")]
89    // SAFETY: caller ensures preconditions are met for this unsafe function
90    unsafe fn argmax(a: &[f32]) -> usize {
91        unsafe { ops::reductions::argmax(a) }
92    }
93
94    #[inline]
95    #[target_feature(enable = "sse2")]
96    // SAFETY: caller ensures preconditions are met for this unsafe function
97    unsafe fn argmin(a: &[f32]) -> usize {
98        unsafe { ops::reductions::argmin(a) }
99    }
100
101    #[inline]
102    #[target_feature(enable = "sse2")]
103    // SAFETY: caller ensures preconditions are met for this unsafe function
104    unsafe fn sum_kahan(a: &[f32]) -> f32 {
105        unsafe { ops::reductions::sum_kahan(a) }
106    }
107
108    #[inline]
109    #[target_feature(enable = "sse2")]
110    // SAFETY: caller ensures preconditions are met for this unsafe function
111    unsafe fn norm_l2(a: &[f32]) -> f32 {
112        unsafe {
113            if a.is_empty() {
114                return 0.0;
115            }
116            Self::dot(a, a).sqrt()
117        }
118    }
119
120    #[inline]
121    #[target_feature(enable = "sse2")]
122    // SAFETY: caller ensures preconditions are met for this unsafe function
123    unsafe fn norm_l1(a: &[f32]) -> f32 {
124        unsafe { ops::elementwise::norm_l1(a) }
125    }
126
127    #[inline]
128    #[target_feature(enable = "sse2")]
129    // SAFETY: caller ensures preconditions are met for this unsafe function
130    unsafe fn norm_linf(a: &[f32]) -> f32 {
131        unsafe { ops::elementwise::norm_linf(a) }
132    }
133
134    #[inline]
135    #[target_feature(enable = "sse2")]
136    // SAFETY: caller ensures preconditions are met for this unsafe function
137    unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]) {
138        unsafe {
139            ops::elementwise::scale(a, scalar, result);
140        }
141    }
142
143    #[inline]
144    #[target_feature(enable = "sse2")]
145    // SAFETY: caller ensures preconditions are met for this unsafe function
146    unsafe fn abs(a: &[f32], result: &mut [f32]) {
147        unsafe {
148            ops::elementwise::abs(a, result);
149        }
150    }
151
152    #[inline]
153    #[target_feature(enable = "sse2")]
154    // SAFETY: caller ensures preconditions are met for this unsafe function
155    unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]) {
156        unsafe {
157            ops::elementwise::clamp(a, min_val, max_val, result);
158        }
159    }
160
161    #[inline]
162    #[target_feature(enable = "sse2")]
163    // SAFETY: caller ensures preconditions are met for this unsafe function
164    unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]) {
165        unsafe {
166            ops::elementwise::lerp(a, b, t, result);
167        }
168    }
169
170    #[inline]
171    #[target_feature(enable = "sse2")]
172    // SAFETY: caller ensures preconditions are met for this unsafe function
173    unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) {
174        unsafe {
175            ops::elementwise::fma(a, b, c, result);
176        }
177    }
178
179    #[inline]
180    #[target_feature(enable = "sse2")]
181    // SAFETY: caller ensures preconditions are met for this unsafe function
182    unsafe fn relu(a: &[f32], result: &mut [f32]) {
183        unsafe {
184            ops::elementwise::relu(a, result);
185        }
186    }
187
188    #[inline]
189    #[target_feature(enable = "sse2")]
190    // SAFETY: caller ensures preconditions are met for this unsafe function
191    unsafe fn exp(a: &[f32], result: &mut [f32]) {
192        unsafe {
193            ops::activations::exp(a, result);
194        }
195    }
196
197    #[inline]
198    #[target_feature(enable = "sse2")]
199    // SAFETY: caller ensures preconditions are met for this unsafe function
200    unsafe fn sigmoid(a: &[f32], result: &mut [f32]) {
201        unsafe {
202            ops::activations::sigmoid(a, result);
203        }
204    }
205
206    #[inline]
207    #[target_feature(enable = "sse2")]
208    // SAFETY: caller ensures preconditions are met for this unsafe function
209    unsafe fn gelu(a: &[f32], result: &mut [f32]) {
210        unsafe {
211            ops::activations::gelu(a, result);
212        }
213    }
214
215    #[inline]
216    #[target_feature(enable = "sse2")]
217    // SAFETY: caller ensures preconditions are met for this unsafe function
218    unsafe fn swish(a: &[f32], result: &mut [f32]) {
219        unsafe {
220            ops::activations::swish(a, result);
221        }
222    }
223
224    #[inline]
225    #[target_feature(enable = "sse2")]
226    // SAFETY: caller ensures preconditions are met for this unsafe function
227    unsafe fn tanh(a: &[f32], result: &mut [f32]) {
228        unsafe {
229            ops::activations::tanh(a, result);
230        }
231    }
232
233    #[inline]
234    #[target_feature(enable = "sse2")]
235    // SAFETY: caller ensures preconditions are met for this unsafe function
236    unsafe fn sqrt(a: &[f32], result: &mut [f32]) {
237        unsafe {
238            ops::elementwise::sqrt(a, result);
239        }
240    }
241
242    #[inline]
243    #[target_feature(enable = "sse2")]
244    // SAFETY: caller ensures preconditions are met for this unsafe function
245    unsafe fn recip(a: &[f32], result: &mut [f32]) {
246        unsafe {
247            ops::elementwise::recip(a, result);
248        }
249    }
250
251    // SAFETY: caller ensures preconditions are met for this unsafe function
252    unsafe fn ln(a: &[f32], result: &mut [f32]) {
253        unsafe {
254            super::scalar::ScalarBackend::ln(a, result);
255        }
256    }
257    // SAFETY: caller ensures preconditions are met for this unsafe function
258    unsafe fn log2(a: &[f32], result: &mut [f32]) {
259        unsafe {
260            super::scalar::ScalarBackend::log2(a, result);
261        }
262    }
263    // SAFETY: caller ensures preconditions are met for this unsafe function
264    unsafe fn log10(a: &[f32], result: &mut [f32]) {
265        unsafe {
266            super::scalar::ScalarBackend::log10(a, result);
267        }
268    }
269    // SAFETY: caller ensures preconditions are met for this unsafe function
270    unsafe fn sin(a: &[f32], result: &mut [f32]) {
271        unsafe {
272            super::scalar::ScalarBackend::sin(a, result);
273        }
274    }
275    // SAFETY: caller ensures preconditions are met for this unsafe function
276    unsafe fn cos(a: &[f32], result: &mut [f32]) {
277        unsafe {
278            super::scalar::ScalarBackend::cos(a, result);
279        }
280    }
281    // SAFETY: caller ensures preconditions are met for this unsafe function
282    unsafe fn tan(a: &[f32], result: &mut [f32]) {
283        unsafe {
284            super::scalar::ScalarBackend::tan(a, result);
285        }
286    }
287    // SAFETY: caller ensures preconditions are met for this unsafe function
288    unsafe fn floor(a: &[f32], result: &mut [f32]) {
289        unsafe {
290            super::scalar::ScalarBackend::floor(a, result);
291        }
292    }
293    // SAFETY: caller ensures preconditions are met for this unsafe function
294    unsafe fn ceil(a: &[f32], result: &mut [f32]) {
295        unsafe {
296            super::scalar::ScalarBackend::ceil(a, result);
297        }
298    }
299    // SAFETY: caller ensures preconditions are met for this unsafe function
300    unsafe fn round(a: &[f32], result: &mut [f32]) {
301        unsafe {
302            super::scalar::ScalarBackend::round(a, result);
303        }
304    }
305}
306
307#[cfg(test)]
308mod tests;