trueno/backends/sse2/ops/
elementwise.rs1#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
/// Returns the L1 norm of `a`: the sum of the absolute values of its elements.
///
/// An empty slice yields `0.0`. Full groups of four elements are accumulated in
/// one SSE register and then reduced horizontally. The remaining `a.len() % 4`
/// elements are added afterwards in slice order.
///
/// # Safety
///
/// The executing CPU must support SSE2 (part of the `x86_64` baseline).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn norm_l1(a: &[f32]) -> f32 {
    if a.is_empty() {
        return 0.0;
    }

    let chunks = a.chunks_exact(4);
    let tail = chunks.remainder();

    // SAFETY: every chunk yielded by `chunks_exact(4)` holds exactly four f32s,
    // so each unaligned 128-bit load reads only memory inside `a`.
    let lanes_total = unsafe {
        // Clearing the IEEE-754 sign bit yields |x| for all four lanes at once.
        let abs_mask = _mm_set1_ps(f32::from_bits(0x7FFF_FFFF));
        let mut sum4 = _mm_setzero_ps();
        for chunk in chunks {
            let v = _mm_loadu_ps(chunk.as_ptr());
            sum4 = _mm_add_ps(sum4, _mm_and_ps(v, abs_mask));
        }
        // Horizontal reduction: fold lanes {2,3} onto {0,1}, then lane 1 onto lane 0.
        let pair = _mm_add_ps(sum4, _mm_movehl_ps(sum4, sum4));
        let single = _mm_add_ss(pair, _mm_shuffle_ps(pair, pair, 1));
        _mm_cvtss_f32(single)
    };

    tail.iter().fold(lanes_total, |total, x| total + x.abs())
}
34
/// Returns the L∞ norm of `a`: the largest absolute value among its elements.
///
/// An empty slice yields `0.0`, and the running maximum starts from `0.0`. Full
/// groups of four elements are compared in one SSE register and reduced
/// horizontally. The remaining `a.len() % 4` elements are then checked with a
/// strict `>` comparison.
///
/// NaN handling follows the exact operand order used below. `_mm_max_ps` returns
/// its second operand when either operand is NaN. The scalar tail replaces the
/// running maximum only when `|x| > max`.
///
/// # Safety
///
/// The executing CPU must support SSE2 (part of the `x86_64` baseline).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn norm_linf(a: &[f32]) -> f32 {
    if a.is_empty() {
        return 0.0;
    }

    let chunks = a.chunks_exact(4);
    let tail = chunks.remainder();

    // SAFETY: every chunk yielded by `chunks_exact(4)` holds exactly four f32s,
    // so each unaligned 128-bit load reads only memory inside `a`.
    let simd_peak = unsafe {
        // Clearing the IEEE-754 sign bit yields |x| for all four lanes at once.
        let abs_mask = _mm_set1_ps(f32::from_bits(0x7FFF_FFFF));
        let mut peak4 = _mm_setzero_ps();
        for chunk in chunks {
            let magnitudes = _mm_and_ps(_mm_loadu_ps(chunk.as_ptr()), abs_mask);
            // Running maximum stays the first operand (significant for NaN lanes).
            peak4 = _mm_max_ps(peak4, magnitudes);
        }
        // Horizontal reduction: fold lanes {2,3} onto {0,1}, then lane 1 onto lane 0.
        let pair = _mm_max_ps(peak4, _mm_movehl_ps(peak4, peak4));
        let single = _mm_max_ss(pair, _mm_shuffle_ps(pair, pair, 1));
        _mm_cvtss_f32(single)
    };

    tail.iter().fold(simd_peak, |peak, x| {
        let magnitude = x.abs();
        if magnitude > peak {
            magnitude
        } else {
            peak
        }
    })
}
67
/// Multiplies every element of `a` by `scalar`, writing the products into
/// `result[..a.len()]`.
///
/// Elements of `result` beyond `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result.len() < a.len()`.
///
/// # Safety
///
/// The executing CPU must support SSE2 (part of the `x86_64` baseline).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]) {
    let len = a.len();
    // Re-slice up front. A too-short `result` now panics here instead of being
    // written out of bounds by the raw-pointer stores below.
    let result = &mut result[..len];
    let mut i = 0;

    // SAFETY: `a` and `result` are both exactly `len` elements long, and the loop
    // only runs while `i + 4 <= len`. Every unaligned 4-lane load and store is
    // therefore in bounds.
    unsafe {
        let scalar_vec = _mm_set1_ps(scalar);
        while i + 4 <= len {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_mul_ps(va, scalar_vec));
            i += 4;
        }
    }

    // Scalar tail for the last `len % 4` elements.
    for (dst, &x) in result[i..].iter_mut().zip(&a[i..]) {
        *dst = x * scalar;
    }
}
89
/// Writes the absolute value of every element of `a` into `result[..a.len()]`.
///
/// The SIMD path clears the IEEE-754 sign bit, and the scalar tail uses
/// `f32::abs`. Both turn `-0.0` into `0.0`. Elements of `result` beyond
/// `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result.len() < a.len()`.
///
/// # Safety
///
/// The executing CPU must support SSE2 (part of the `x86_64` baseline).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn abs(a: &[f32], result: &mut [f32]) {
    let len = a.len();
    // Re-slice up front. A too-short `result` now panics here instead of being
    // written out of bounds by the raw-pointer stores below.
    let result = &mut result[..len];
    let mut i = 0;

    // SAFETY: `a` and `result` are both exactly `len` elements long, and the loop
    // only runs while `i + 4 <= len`. Every unaligned 4-lane load and store is
    // therefore in bounds.
    unsafe {
        let sign_mask = _mm_set1_ps(f32::from_bits(0x7FFF_FFFF));
        while i + 4 <= len {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_and_ps(va, sign_mask));
            i += 4;
        }
    }

    // Scalar tail for the last `len % 4` elements.
    for (dst, &x) in result[i..].iter_mut().zip(&a[i..]) {
        *dst = x.abs();
    }
}
111
/// Clamps every element of `a`, writing `x.max(min_val).min(max_val)` into
/// `result[..a.len()]`.
///
/// Unlike `f32::clamp`, this does not panic when `min_val > max_val`; every
/// output is then `max_val`. When `min_val <= max_val`, a NaN input becomes
/// `min_val` in both the SIMD and scalar paths. Elements of `result` beyond
/// `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result.len() < a.len()`.
///
/// # Safety
///
/// The executing CPU must support SSE2 (part of the `x86_64` baseline).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]) {
    let len = a.len();
    // Re-slice up front. A too-short `result` now panics here instead of being
    // written out of bounds by the raw-pointer stores below.
    let result = &mut result[..len];
    let mut i = 0;

    // SAFETY: `a` and `result` are both exactly `len` elements long, and the loop
    // only runs while `i + 4 <= len`. Every unaligned 4-lane load and store is
    // therefore in bounds.
    unsafe {
        let min_vec = _mm_set1_ps(min_val);
        let max_vec = _mm_set1_ps(max_val);
        while i + 4 <= len {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            _mm_storeu_ps(
                result.as_mut_ptr().add(i),
                _mm_min_ps(_mm_max_ps(va, min_vec), max_vec),
            );
            i += 4;
        }
    }

    // Scalar tail for the last `len % 4` elements.
    for (dst, &x) in result[i..].iter_mut().zip(&a[i..]) {
        *dst = x.max(min_val).min(max_val);
    }
}
132
/// Linearly interpolates between `a` and `b`, writing `a + t * (b - a)` into
/// `result[..a.len()]`.
///
/// Only the first `a.len()` elements of `b` are read. Elements of `result`
/// beyond `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `b.len() < a.len()` or `result.len() < a.len()`.
///
/// # Safety
///
/// The executing CPU must support SSE2 (part of the `x86_64` baseline).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]) {
    let len = a.len();
    // Re-slice up front. A too-short `b` or `result` now panics here instead of
    // being read or written out of bounds by the raw-pointer loop below.
    let b = &b[..len];
    let result = &mut result[..len];
    let mut i = 0;

    // SAFETY: `a`, `b` and `result` are all exactly `len` elements long, and the
    // loop only runs while `i + 4 <= len`. Every unaligned 4-lane load and store
    // is therefore in bounds.
    unsafe {
        let t_vec = _mm_set1_ps(t);
        while i + 4 <= len {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            _mm_storeu_ps(
                result.as_mut_ptr().add(i),
                _mm_add_ps(va, _mm_mul_ps(t_vec, _mm_sub_ps(vb, va))),
            );
            i += 4;
        }
    }

    // Scalar tail for the last `len % 4` elements (same operation order as SIMD).
    for ((dst, &x), &y) in result[i..].iter_mut().zip(&a[i..]).zip(&b[i..]) {
        *dst = x + t * (y - x);
    }
}
156
/// Writes `a * b + c`, element by element, into `result[..a.len()]`.
///
/// Despite the name, the multiply and the add are separate operations with two
/// roundings. SSE2 has no fused multiply-add instruction, and the scalar tail
/// uses the same unfused expression so both paths agree. Only the first
/// `a.len()` elements of `b` and `c` are read. Elements of `result` beyond
/// `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `b`, `c` or `result` is shorter than `a`.
///
/// # Safety
///
/// The executing CPU must support SSE2 (part of the `x86_64` baseline).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) {
    let len = a.len();
    // Re-slice up front. A too-short input or output now panics here instead of
    // being accessed out of bounds by the raw-pointer loop below.
    let b = &b[..len];
    let c = &c[..len];
    let result = &mut result[..len];
    let mut i = 0;

    // SAFETY: `a`, `b`, `c` and `result` are all exactly `len` elements long, and
    // the loop only runs while `i + 4 <= len`. Every unaligned 4-lane load and
    // store is therefore in bounds.
    unsafe {
        while i + 4 <= len {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            let vc = _mm_loadu_ps(c.as_ptr().add(i));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_add_ps(_mm_mul_ps(va, vb), vc));
            i += 4;
        }
    }

    // Scalar tail for the last `len % 4` elements.
    for (((dst, &x), &y), &z) in result[i..]
        .iter_mut()
        .zip(&a[i..])
        .zip(&b[i..])
        .zip(&c[i..])
    {
        *dst = x * y + z;
    }
}
177
/// Applies ReLU, writing `max(x, 0.0)` for every element of `a` into
/// `result[..a.len()]`.
///
/// Negative inputs become `0.0`. In both the SIMD path (`_mm_max_ps` with zero as
/// the second operand) and the scalar path (`f32::max`), a NaN input also becomes
/// `0.0`. Elements of `result` beyond `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result.len() < a.len()`.
///
/// # Safety
///
/// The executing CPU must support SSE2 (part of the `x86_64` baseline).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn relu(a: &[f32], result: &mut [f32]) {
    let len = a.len();
    // Re-slice up front. A too-short `result` now panics here instead of being
    // written out of bounds by the raw-pointer stores below.
    let result = &mut result[..len];
    let mut i = 0;

    // SAFETY: `a` and `result` are both exactly `len` elements long, and the loop
    // only runs while `i + 4 <= len`. Every unaligned 4-lane load and store is
    // therefore in bounds.
    unsafe {
        let zero = _mm_setzero_ps();
        while i + 4 <= len {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_max_ps(va, zero));
            i += 4;
        }
    }

    // Scalar tail for the last `len % 4` elements.
    for (dst, &x) in result[i..].iter_mut().zip(&a[i..]) {
        *dst = x.max(0.0);
    }
}
199
/// Writes the square root of every element of `a` into `result[..a.len()]`.
///
/// `_mm_sqrt_ps` and `f32::sqrt` are both IEEE-754 correctly rounded, so the SIMD
/// and scalar paths agree. Negative inputs produce NaN. Elements of `result`
/// beyond `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result.len() < a.len()`.
///
/// # Safety
///
/// The executing CPU must support SSE2 (part of the `x86_64` baseline).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn sqrt(a: &[f32], result: &mut [f32]) {
    let len = a.len();
    // Re-slice up front. A too-short `result` now panics here instead of being
    // written out of bounds by the raw-pointer stores below.
    let result = &mut result[..len];
    let mut i = 0;

    // SAFETY: `a` and `result` are both exactly `len` elements long, and the loop
    // only runs while `i + 4 <= len`. Every unaligned 4-lane load and store is
    // therefore in bounds.
    unsafe {
        while i + 4 <= len {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_sqrt_ps(va));
            i += 4;
        }
    }

    // Scalar tail for the last `len % 4` elements.
    for (dst, &x) in result[i..].iter_mut().zip(&a[i..]) {
        *dst = x.sqrt();
    }
}
217
/// Writes the reciprocal `1.0 / x` of every element of `a` into
/// `result[..a.len()]`.
///
/// The SIMD path uses a true division (`_mm_div_ps`), not the approximate
/// `_mm_rcp_ps`, so it matches the scalar `f32::recip` tail exactly. Elements of
/// `result` beyond `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result.len() < a.len()`.
///
/// # Safety
///
/// The executing CPU must support SSE2 (part of the `x86_64` baseline).
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn recip(a: &[f32], result: &mut [f32]) {
    let len = a.len();
    // Re-slice up front. A too-short `result` now panics here instead of being
    // written out of bounds by the raw-pointer stores below.
    let result = &mut result[..len];
    let mut i = 0;

    // SAFETY: `a` and `result` are both exactly `len` elements long, and the loop
    // only runs while `i + 4 <= len`. Every unaligned 4-lane load and store is
    // therefore in bounds.
    unsafe {
        let one = _mm_set1_ps(1.0);
        while i + 4 <= len {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_div_ps(one, va));
            i += 4;
        }
    }

    // Scalar tail for the last `len % 4` elements.
    for (dst, &x) in result[i..].iter_mut().zip(&a[i..]) {
        *dst = x.recip();
    }
}