trueno/backends/sse2/
mod.rs1mod ops;
16
17use super::VectorBackend;
18
/// Zero-sized marker type whose [`VectorBackend`] implementation forwards each
/// operation to the kernels in this module's private `ops` submodule (or to the
/// scalar backend where no kernel is wired up). The forwarding methods are
/// compiled with `#[target_feature(enable = "sse2")]`.
pub struct Sse2Backend;
21
22impl VectorBackend for Sse2Backend {
23 #[inline]
24 #[target_feature(enable = "sse2")]
25 unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
27 unsafe {
28 ops::arithmetic::add(a, b, result);
29 }
30 }
31
32 #[inline]
33 #[target_feature(enable = "sse2")]
34 unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
36 unsafe {
37 ops::arithmetic::sub(a, b, result);
38 }
39 }
40
41 #[inline]
42 #[target_feature(enable = "sse2")]
43 unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
45 unsafe {
46 ops::arithmetic::mul(a, b, result);
47 }
48 }
49
50 #[inline]
51 #[target_feature(enable = "sse2")]
52 unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
54 unsafe {
55 ops::arithmetic::div(a, b, result);
56 }
57 }
58
59 #[inline]
60 #[target_feature(enable = "sse2")]
61 unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
63 unsafe { ops::reductions::dot(a, b) }
64 }
65
66 #[inline]
67 #[target_feature(enable = "sse2")]
68 unsafe fn sum(a: &[f32]) -> f32 {
70 unsafe { ops::reductions::sum(a) }
71 }
72
73 #[inline]
74 #[target_feature(enable = "sse2")]
75 unsafe fn max(a: &[f32]) -> f32 {
77 unsafe { ops::reductions::max(a) }
78 }
79
80 #[inline]
81 #[target_feature(enable = "sse2")]
82 unsafe fn min(a: &[f32]) -> f32 {
84 unsafe { ops::reductions::min(a) }
85 }
86
87 #[inline]
88 #[target_feature(enable = "sse2")]
89 unsafe fn argmax(a: &[f32]) -> usize {
91 unsafe { ops::reductions::argmax(a) }
92 }
93
94 #[inline]
95 #[target_feature(enable = "sse2")]
96 unsafe fn argmin(a: &[f32]) -> usize {
98 unsafe { ops::reductions::argmin(a) }
99 }
100
101 #[inline]
102 #[target_feature(enable = "sse2")]
103 unsafe fn sum_kahan(a: &[f32]) -> f32 {
105 unsafe { ops::reductions::sum_kahan(a) }
106 }
107
108 #[inline]
109 #[target_feature(enable = "sse2")]
110 unsafe fn norm_l2(a: &[f32]) -> f32 {
112 unsafe {
113 if a.is_empty() {
114 return 0.0;
115 }
116 Self::dot(a, a).sqrt()
117 }
118 }
119
120 #[inline]
121 #[target_feature(enable = "sse2")]
122 unsafe fn norm_l1(a: &[f32]) -> f32 {
124 unsafe { ops::elementwise::norm_l1(a) }
125 }
126
127 #[inline]
128 #[target_feature(enable = "sse2")]
129 unsafe fn norm_linf(a: &[f32]) -> f32 {
131 unsafe { ops::elementwise::norm_linf(a) }
132 }
133
134 #[inline]
135 #[target_feature(enable = "sse2")]
136 unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]) {
138 unsafe {
139 ops::elementwise::scale(a, scalar, result);
140 }
141 }
142
143 #[inline]
144 #[target_feature(enable = "sse2")]
145 unsafe fn abs(a: &[f32], result: &mut [f32]) {
147 unsafe {
148 ops::elementwise::abs(a, result);
149 }
150 }
151
152 #[inline]
153 #[target_feature(enable = "sse2")]
154 unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]) {
156 unsafe {
157 ops::elementwise::clamp(a, min_val, max_val, result);
158 }
159 }
160
161 #[inline]
162 #[target_feature(enable = "sse2")]
163 unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]) {
165 unsafe {
166 ops::elementwise::lerp(a, b, t, result);
167 }
168 }
169
170 #[inline]
171 #[target_feature(enable = "sse2")]
172 unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) {
174 unsafe {
175 ops::elementwise::fma(a, b, c, result);
176 }
177 }
178
179 #[inline]
180 #[target_feature(enable = "sse2")]
181 unsafe fn relu(a: &[f32], result: &mut [f32]) {
183 unsafe {
184 ops::elementwise::relu(a, result);
185 }
186 }
187
188 #[inline]
189 #[target_feature(enable = "sse2")]
190 unsafe fn exp(a: &[f32], result: &mut [f32]) {
192 unsafe {
193 ops::activations::exp(a, result);
194 }
195 }
196
197 #[inline]
198 #[target_feature(enable = "sse2")]
199 unsafe fn sigmoid(a: &[f32], result: &mut [f32]) {
201 unsafe {
202 ops::activations::sigmoid(a, result);
203 }
204 }
205
206 #[inline]
207 #[target_feature(enable = "sse2")]
208 unsafe fn gelu(a: &[f32], result: &mut [f32]) {
210 unsafe {
211 ops::activations::gelu(a, result);
212 }
213 }
214
215 #[inline]
216 #[target_feature(enable = "sse2")]
217 unsafe fn swish(a: &[f32], result: &mut [f32]) {
219 unsafe {
220 ops::activations::swish(a, result);
221 }
222 }
223
224 #[inline]
225 #[target_feature(enable = "sse2")]
226 unsafe fn tanh(a: &[f32], result: &mut [f32]) {
228 unsafe {
229 ops::activations::tanh(a, result);
230 }
231 }
232
233 #[inline]
234 #[target_feature(enable = "sse2")]
235 unsafe fn sqrt(a: &[f32], result: &mut [f32]) {
237 unsafe {
238 ops::elementwise::sqrt(a, result);
239 }
240 }
241
242 #[inline]
243 #[target_feature(enable = "sse2")]
244 unsafe fn recip(a: &[f32], result: &mut [f32]) {
246 unsafe {
247 ops::elementwise::recip(a, result);
248 }
249 }
250
251 unsafe fn ln(a: &[f32], result: &mut [f32]) {
253 unsafe {
254 super::scalar::ScalarBackend::ln(a, result);
255 }
256 }
257 unsafe fn log2(a: &[f32], result: &mut [f32]) {
259 unsafe {
260 super::scalar::ScalarBackend::log2(a, result);
261 }
262 }
263 unsafe fn log10(a: &[f32], result: &mut [f32]) {
265 unsafe {
266 super::scalar::ScalarBackend::log10(a, result);
267 }
268 }
269 unsafe fn sin(a: &[f32], result: &mut [f32]) {
271 unsafe {
272 super::scalar::ScalarBackend::sin(a, result);
273 }
274 }
275 unsafe fn cos(a: &[f32], result: &mut [f32]) {
277 unsafe {
278 super::scalar::ScalarBackend::cos(a, result);
279 }
280 }
281 unsafe fn tan(a: &[f32], result: &mut [f32]) {
283 unsafe {
284 super::scalar::ScalarBackend::tan(a, result);
285 }
286 }
287 unsafe fn floor(a: &[f32], result: &mut [f32]) {
289 unsafe {
290 super::scalar::ScalarBackend::floor(a, result);
291 }
292 }
293 unsafe fn ceil(a: &[f32], result: &mut [f32]) {
295 unsafe {
296 super::scalar::ScalarBackend::ceil(a, result);
297 }
298 }
299 unsafe fn round(a: &[f32], result: &mut [f32]) {
301 unsafe {
302 super::scalar::ScalarBackend::round(a, result);
303 }
304 }
305}
306
307#[cfg(test)]
308mod tests;