//! Runtime-dispatched SIMD kernels for `f32` slices: element-wise add and
//! multiply, dot product, and ReLU, with AVX2, SSE4.1, and scalar fallbacks.

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

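/// Element-wise addition: `out[i] = a[i] + b[i]`.
///
/// All three slices must have the same length (checked with `assert_eq!`).
/// Dispatches at runtime to an AVX2 or SSE4.1 kernel when available, falling
/// back to a scalar loop otherwise.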
#[inline]
pub fn simd_add_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), out.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { simd_add_f32_avx2(a, b, out) }
        } else if is_x86_feature_detected!("sse4.1") {
            unsafe { simd_add_f32_sse(a, b, out) }
        } else {
            scalar_add_f32(a, b, out)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_add_f32(a, b, out)
    }
}

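/// # Safety
///
/// The caller must ensure the CPU supports AVX2; `simd_add_f32` checks this
/// with `is_x86_feature_detected!` before dispatching here.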
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_add_f32_avx2(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    let mut i = 0;

    // Add 8 lanes per iteration using 256-bit AVX2 vectors.
    while i + 8 <= len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        let vout = _mm256_add_ps(va, vb);
        _mm256_storeu_ps(out.as_mut_ptr().add(i), vout);
        i += 8;
    }

    // Scalar loop for the leftover tail.
    while i < len {
        out[i] = a[i] + b[i];
        i += 1;
    }
}

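/// # Safety
///
/// The caller must ensure the CPU supports SSE4.1.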
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_add_f32_sse(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    let mut i = 0;

    // Add 4 lanes per iteration using 128-bit SSE vectors.
    while i + 4 <= len {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let vout = _mm_add_ps(va, vb);
        _mm_storeu_ps(out.as_mut_ptr().add(i), vout);
        i += 4;
    }

    // Scalar loop for the leftover tail.
    while i < len {
        out[i] = a[i] + b[i];
        i += 1;
    }
}

#[inline]
fn scalar_add_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    for i in 0..a.len() {
        out[i] = a[i] + b[i];
    }
}

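/// Element-wise multiplication: `out[i] = a[i] * b[i]`.
///
/// Same length requirements and runtime dispatch as [`simd_add_f32`].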
#[inline]
pub fn simd_mul_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), out.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { simd_mul_f32_avx2(a, b, out) }
        } else if is_x86_feature_detected!("sse4.1") {
            unsafe { simd_mul_f32_sse(a, b, out) }
        } else {
            scalar_mul_f32(a, b, out)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_mul_f32(a, b, out)
    }
}

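/// # Safety
///
/// Requires AVX2 support, guaranteed by the runtime check in `simd_mul_f32`.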
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_mul_f32_avx2(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    let mut i = 0;

    while i + 8 <= len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        let vout = _mm256_mul_ps(va, vb);
        _mm256_storeu_ps(out.as_mut_ptr().add(i), vout);
        i += 8;
    }

    while i < len {
        out[i] = a[i] * b[i];
        i += 1;
    }
}

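/// # Safety
///
/// Requires SSE4.1 support.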
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_mul_f32_sse(a: &[f32], b: &[f32], out: &mut [f32]) {
    let len = a.len();
    let mut i = 0;

    while i + 4 <= len {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let vout = _mm_mul_ps(va, vb);
        _mm_storeu_ps(out.as_mut_ptr().add(i), vout);
        i += 4;
    }

    while i < len {
        out[i] = a[i] * b[i];
        i += 1;
    }
}

#[inline]
fn scalar_mul_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    for i in 0..a.len() {
        out[i] = a[i] * b[i];
    }
}

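/// Dot product of `a` and `b`.
///
/// Both slices must have the same length. The vectorized paths accumulate in
/// a different order than the scalar fallback, so the result can differ from
/// it by floating-point rounding error.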
#[inline]
pub fn simd_dot_f32(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { simd_dot_f32_avx2(a, b) }
        } else if is_x86_feature_detected!("sse4.1") {
            unsafe { simd_dot_f32_sse(a, b) }
        } else {
            scalar_dot_f32(a, b)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_dot_f32(a, b)
    }
}

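/// # Safety
///
/// Requires AVX2 support, guaranteed by the runtime check in `simd_dot_f32`.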
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_dot_f32_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut i = 0;
    let mut sum = _mm256_setzero_ps();

    while i + 8 <= len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        let vprod = _mm256_mul_ps(va, vb);
        sum = _mm256_add_ps(sum, vprod);
        i += 8;
    }

    // Horizontal sum: reinterpret the 256-bit accumulator as 8 f32 lanes and
    // add them together.
    let mut result = 0.0f32;
    let sum_array: [f32; 8] = std::mem::transmute(sum);
    for &val in &sum_array {
        result += val;
    }

    // Scalar loop for the leftover tail.
    while i < len {
        result += a[i] * b[i];
        i += 1;
    }

    result
}

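/// # Safety
///
/// Requires SSE4.1 support.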
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_dot_f32_sse(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut i = 0;
    let mut sum = _mm_setzero_ps();

    while i + 4 <= len {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let vprod = _mm_mul_ps(va, vb);
        sum = _mm_add_ps(sum, vprod);
        i += 4;
    }

    // Horizontal sum of the 4 accumulator lanes.
    let mut result = 0.0f32;
    let sum_array: [f32; 4] = std::mem::transmute(sum);
    for &val in &sum_array {
        result += val;
    }

    // Scalar loop for the leftover tail.
    while i < len {
        result += a[i] * b[i];
        i += 1;
    }

    result
}

#[inline]
fn scalar_dot_f32(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

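/// ReLU: `output[i] = input[i].max(0.0)`.
///
/// Both slices must have the same length; dispatch works as in
/// [`simd_add_f32`].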
#[inline]
pub fn simd_relu_f32(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len());

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { simd_relu_f32_avx2(input, output) }
        } else if is_x86_feature_detected!("sse4.1") {
            unsafe { simd_relu_f32_sse(input, output) }
        } else {
            scalar_relu_f32(input, output)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_relu_f32(input, output)
    }
}

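/// # Safety
///
/// Requires AVX2 support, guaranteed by the runtime check in `simd_relu_f32`.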
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn simd_relu_f32_avx2(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let mut i = 0;
    let zero = _mm256_setzero_ps();

    while i + 8 <= len {
        let v = _mm256_loadu_ps(input.as_ptr().add(i));
        let vout = _mm256_max_ps(v, zero);
        _mm256_storeu_ps(output.as_mut_ptr().add(i), vout);
        i += 8;
    }

    while i < len {
        output[i] = input[i].max(0.0);
        i += 1;
    }
}

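/// # Safety
///
/// Requires SSE4.1 support.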
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
unsafe fn simd_relu_f32_sse(input: &[f32], output: &mut [f32]) {
    let len = input.len();
    let mut i = 0;
    let zero = _mm_setzero_ps();

    while i + 4 <= len {
        let v = _mm_loadu_ps(input.as_ptr().add(i));
        let vout = _mm_max_ps(v, zero);
        _mm_storeu_ps(output.as_mut_ptr().add(i), vout);
        i += 4;
    }

    while i < len {
        output[i] = input[i].max(0.0);
        i += 1;
    }
}

#[inline]
fn scalar_relu_f32(input: &[f32], output: &mut [f32]) {
    for i in 0..input.len() {
        output[i] = input[i].max(0.0);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_add() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![8.0f32, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let mut out = vec![0.0f32; 8];

        simd_add_f32(&a, &b, &mut out);

        for i in 0..8 {
            assert_eq!(out[i], 9.0);
        }
    }

    #[test]
    fn test_simd_mul() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0];
        let mut out = vec![0.0f32; 4];

        simd_mul_f32(&a, &b, &mut out);

        assert_eq!(out, vec![2.0, 6.0, 12.0, 20.0]);
    }

    #[test]
    fn test_simd_dot() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![5.0f32, 6.0, 7.0, 8.0];

        let result = simd_dot_f32(&a, &b);

        assert_eq!(result, 70.0);
    }

    #[test]
    fn test_simd_relu() {
        let input = vec![-2.0f32, -1.0, 0.0, 1.0, 2.0];
        let mut output = vec![0.0f32; 5];

        simd_relu_f32(&input, &mut output);

        assert_eq!(output, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
    }
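
    // Additional test (not in the original suite): lengths that are not a
    // multiple of the SIMD width (8 for AVX2, 4 for SSE) exercise the scalar
    // tail loops. All values are small integers, so every operation is exact
    // in f32 and the comparisons can be exact.
    #[test]
    fn test_simd_tail_handling() {
        let a: Vec<f32> = (0..11).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..11).map(|i| (i + 1) as f32).collect();
        let mut out = vec![0.0f32; 11];

        simd_add_f32(&a, &b, &mut out);
        for i in 0..11 {
            assert_eq!(out[i], a[i] + b[i]);
        }

        simd_mul_f32(&a, &b, &mut out);
        for i in 0..11 {
            assert_eq!(out[i], a[i] * b[i]);
        }

        let dot = simd_dot_f32(&a, &b);
        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        assert_eq!(dot, expected);
    }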
}