/// Reports whether the AVX2 code paths in this module may be used.
///
/// On `x86_64` this asks the running CPU via `is_x86_feature_detected!`;
/// on every other architecture it always returns `false`.
#[inline]
pub fn has_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        is_x86_feature_detected!("avx2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
36
/// Element-wise binary arithmetic operation applied by `simd_binop`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// `a[i] + b[i]`
    Add,
    /// `a[i] - b[i]`
    Sub,
    /// `a[i] * b[i]`
    Mul,
    /// `a[i] / b[i]`
    Div,
}
47
48const PARALLEL_THRESHOLD: usize = 100_000;
51
52pub fn simd_binop(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
61 let n = a.len();
62 debug_assert_eq!(n, b.len());
63
64 #[cfg(feature = "parallel")]
66 {
67 if n >= PARALLEL_THRESHOLD {
68 return simd_binop_parallel(a, b, op);
69 }
70 }
71
72 simd_binop_sequential(a, b, op)
73}
74
75fn simd_binop_sequential(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
77 let n = a.len();
78 let mut out = vec![0.0f64; n];
79
80 #[cfg(target_arch = "x86_64")]
81 {
82 if has_avx2() {
83 unsafe {
84 match op {
85 BinOp::Add => avx2_binop::<ADD_TAG>(a, b, &mut out),
86 BinOp::Sub => avx2_binop::<SUB_TAG>(a, b, &mut out),
87 BinOp::Mul => avx2_binop::<MUL_TAG>(a, b, &mut out),
88 BinOp::Div => avx2_binop::<DIV_TAG>(a, b, &mut out),
89 }
90 }
91 return out;
92 }
93 }
94
95 match op {
97 BinOp::Add => { for i in 0..n { out[i] = a[i] + b[i]; } }
98 BinOp::Sub => { for i in 0..n { out[i] = a[i] - b[i]; } }
99 BinOp::Mul => { for i in 0..n { out[i] = a[i] * b[i]; } }
100 BinOp::Div => { for i in 0..n { out[i] = a[i] / b[i]; } }
101 }
102 out
103}
104
/// Multi-threaded variant of the element-wise binary operation.
///
/// The output is cut into fixed-size chunks that are processed through
/// rayon's `par_chunks_mut` inside the closure handed to
/// `crate::runtime_policy::run_parallel`. Each chunk uses the AVX2 kernel when
/// the CPU supports it and a scalar loop otherwise.
#[cfg(feature = "parallel")]
fn simd_binop_parallel(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
    use rayon::prelude::*;

    /// Number of output elements per parallel work item.
    const CHUNK_LEN: usize = 4096;

    let mut result = vec![0.0f64; a.len()];

    crate::runtime_policy::run_parallel(|| {
        result
            .par_chunks_mut(CHUNK_LEN)
            .enumerate()
            .for_each(|(index, dst)| {
                // Input windows that line up with this output chunk.
                let first = index * CHUNK_LEN;
                let window = first..first + dst.len();
                let lhs = &a[window.clone()];
                let rhs = &b[window];

                #[cfg(target_arch = "x86_64")]
                {
                    if has_avx2() {
                        // SAFETY: AVX2 was detected at runtime, and the
                        // slicing above guarantees `lhs`, `rhs` and `dst` all
                        // have the same length.
                        unsafe {
                            match op {
                                BinOp::Add => avx2_binop::<ADD_TAG>(lhs, rhs, dst),
                                BinOp::Sub => avx2_binop::<SUB_TAG>(lhs, rhs, dst),
                                BinOp::Mul => avx2_binop::<MUL_TAG>(lhs, rhs, dst),
                                BinOp::Div => avx2_binop::<DIV_TAG>(lhs, rhs, dst),
                            }
                        }
                        return;
                    }
                }

                // Scalar fallback; all three slices are equally long here.
                let pairs = lhs.iter().zip(rhs);
                match op {
                    BinOp::Add => {
                        for (slot, (x, y)) in dst.iter_mut().zip(pairs) {
                            *slot = x + y;
                        }
                    }
                    BinOp::Sub => {
                        for (slot, (x, y)) in dst.iter_mut().zip(pairs) {
                            *slot = x - y;
                        }
                    }
                    BinOp::Mul => {
                        for (slot, (x, y)) in dst.iter_mut().zip(pairs) {
                            *slot = x * y;
                        }
                    }
                    BinOp::Div => {
                        for (slot, (x, y)) in dst.iter_mut().zip(pairs) {
                            *slot = x / y;
                        }
                    }
                }
            });
    });

    result
}
154
/// Const-generic selector for `avx2_binop`: element-wise addition.
const ADD_TAG: u8 = 0;
/// Const-generic selector for `avx2_binop`: element-wise subtraction.
const SUB_TAG: u8 = 1;
/// Const-generic selector for `avx2_binop`: element-wise multiplication.
const MUL_TAG: u8 = 2;
/// Const-generic selector for `avx2_binop`: element-wise division.
/// The kernel treats every value other than the three tags above as division.
const DIV_TAG: u8 = 3;

/// AVX2 kernel computing `out[i] = a[i] <OP> b[i]` for every `i < a.len()`.
///
/// Four lanes are processed per iteration with unaligned loads and stores;
/// the up-to-three leftover elements are handled with scalar arithmetic.
/// `OP` is one of `ADD_TAG`, `SUB_TAG`, `MUL_TAG` or `DIV_TAG`; because it is
/// a const parameter, the `match` is resolved per instantiation.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
///
/// # Panics
///
/// Panics if `b` or `out` holds fewer than `a.len()` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_binop<const OP: u8>(a: &[f64], b: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;

    let n = a.len();
    // The vector loop below goes through raw pointers, so nothing else stops
    // it from running past the end of a shorter `b` or `out`.
    assert!(
        b.len() >= n && out.len() >= n,
        "avx2_binop: `b` and `out` must hold at least `a.len()` elements"
    );

    let mut i = 0;

    while i + 4 <= n {
        // SAFETY: `i + 4 <= n`, and all three slices hold at least `n`
        // elements (checked above), so every 4-wide access is in bounds.
        let va = _mm256_loadu_pd(a.as_ptr().add(i));
        let vb = _mm256_loadu_pd(b.as_ptr().add(i));
        let vr = match OP {
            ADD_TAG => _mm256_add_pd(va, vb),
            SUB_TAG => _mm256_sub_pd(va, vb),
            MUL_TAG => _mm256_mul_pd(va, vb),
            _ => _mm256_div_pd(va, vb),
        };
        _mm256_storeu_pd(out.as_mut_ptr().add(i), vr);
        i += 4;
    }

    // Scalar tail for the remaining `n % 4` elements.
    while i < n {
        out[i] = match OP {
            ADD_TAG => a[i] + b[i],
            SUB_TAG => a[i] - b[i],
            MUL_TAG => a[i] * b[i],
            _ => a[i] / b[i],
        };
        i += 1;
    }
}
192
/// Element-wise unary operation applied by `simd_unary`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Square root of each element.
    Sqrt,
    /// Absolute value of each element.
    Abs,
    /// Sign-flipped copy of each element.
    Neg,
    /// Keeps an element when it is greater than `0.0`, otherwise yields `0.0`.
    Relu,
}
203
204pub fn simd_unary(a: &[f64], op: UnaryOp) -> Vec<f64> {
213 let n = a.len();
214 let mut out = vec![0.0f64; n];
215
216 #[cfg(target_arch = "x86_64")]
217 {
218 if has_avx2() {
219 unsafe {
220 match op {
221 UnaryOp::Sqrt => avx2_sqrt(a, &mut out),
222 UnaryOp::Abs => avx2_abs(a, &mut out),
223 UnaryOp::Neg => avx2_neg(a, &mut out),
224 UnaryOp::Relu => avx2_relu(a, &mut out),
225 }
226 }
227 return out;
228 }
229 }
230
231 match op {
233 UnaryOp::Sqrt => { for i in 0..n { out[i] = a[i].sqrt(); } }
234 UnaryOp::Abs => { for i in 0..n { out[i] = a[i].abs(); } }
235 UnaryOp::Neg => { for i in 0..n { out[i] = -a[i]; } }
236 UnaryOp::Relu => { for i in 0..n { out[i] = if a[i] > 0.0 { a[i] } else { 0.0 }; } }
237 }
238 out
239}
240
/// AVX2 kernel writing the square root of `a[i]` into `out[i]` for every
/// element of `a`.
///
/// Four lanes are processed per iteration; the up-to-three leftover elements
/// fall back to scalar `f64::sqrt`.
///
/// # Safety
///
/// The CPU must support AVX2, and `out` must hold at least `a.len()`
/// elements because the vector stores go through unchecked raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sqrt(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;

    const LANES: usize = 4;
    let len = a.len();
    // Largest multiple of `LANES` not exceeding `len`.
    let vector_end = len - len % LANES;

    let src = a.as_ptr();
    let dst = out.as_mut_ptr();
    for offset in (0..vector_end).step_by(LANES) {
        let lanes = _mm256_loadu_pd(src.add(offset));
        _mm256_storeu_pd(dst.add(offset), _mm256_sqrt_pd(lanes));
    }

    for idx in vector_end..len {
        out[idx] = a[idx].sqrt();
    }
}
255
/// AVX2 kernel writing the absolute value of `a[i]` into `out[i]` for every
/// element of `a`.
///
/// The vector path clears the sign bit with a bitwise AND; the up-to-three
/// leftover elements use scalar `f64::abs`.
///
/// # Safety
///
/// The CPU must support AVX2, and `out` must hold at least `a.len()`
/// elements because the vector stores go through unchecked raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_abs(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;

    const LANES: usize = 4;
    let len = a.len();
    let vector_end = len - len % LANES;

    // All bits set except the sign bit (bit 63).
    let keep_magnitude = _mm256_set1_pd(f64::from_bits(u64::MAX >> 1));

    let src = a.as_ptr();
    let dst = out.as_mut_ptr();
    for offset in (0..vector_end).step_by(LANES) {
        let lanes = _mm256_loadu_pd(src.add(offset));
        _mm256_storeu_pd(dst.add(offset), _mm256_and_pd(lanes, keep_magnitude));
    }

    for idx in vector_end..len {
        out[idx] = a[idx].abs();
    }
}
272
/// AVX2 kernel writing `-a[i]` into `out[i]` for every element of `a`.
///
/// The vector path flips the sign bit with a bitwise XOR; the up-to-three
/// leftover elements use scalar negation.
///
/// # Safety
///
/// The CPU must support AVX2, and `out` must hold at least `a.len()`
/// elements because the vector stores go through unchecked raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_neg(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;

    const LANES: usize = 4;
    let len = a.len();
    let vector_end = len - len % LANES;

    // `-0.0` has only the sign bit (bit 63) set.
    let flip_sign = _mm256_set1_pd(-0.0);

    let src = a.as_ptr();
    let dst = out.as_mut_ptr();
    for offset in (0..vector_end).step_by(LANES) {
        let lanes = _mm256_loadu_pd(src.add(offset));
        _mm256_storeu_pd(dst.add(offset), _mm256_xor_pd(lanes, flip_sign));
    }

    for idx in vector_end..len {
        out[idx] = -a[idx];
    }
}
289
/// AVX2 kernel writing `a[i]` into `out[i]` when `a[i] > 0.0` and `0.0`
/// otherwise, for every element of `a`.
///
/// The vector path takes the lane-wise maximum against zero; the up-to-three
/// leftover elements use a scalar comparison.
///
/// # Safety
///
/// The CPU must support AVX2, and `out` must hold at least `a.len()`
/// elements because the vector stores go through unchecked raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_relu(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;

    const LANES: usize = 4;
    let len = a.len();
    let vector_end = len - len % LANES;

    let floor = _mm256_setzero_pd();

    let src = a.as_ptr();
    let dst = out.as_mut_ptr();
    for offset in (0..vector_end).step_by(LANES) {
        let lanes = _mm256_loadu_pd(src.add(offset));
        // Operand order is kept as (value, zero).
        _mm256_storeu_pd(dst.add(offset), _mm256_max_pd(lanes, floor));
    }

    for idx in vector_end..len {
        let x = a[idx];
        out[idx] = if x > 0.0 { x } else { 0.0 };
    }
}
305
306pub fn simd_axpy(c: &mut [f64], b: &[f64], scalar: f64, len: usize) {
316 debug_assert!(c.len() >= len);
317 debug_assert!(b.len() >= len);
318
319 #[cfg(target_arch = "x86_64")]
320 {
321 if has_avx2() {
322 unsafe { avx2_axpy(c, b, scalar, len); }
323 return;
324 }
325 }
326
327 for j in 0..len {
329 c[j] += scalar * b[j];
330 }
331}
332
/// AVX2 kernel computing `c[j] += scalar * b[j]` in place for every `j < len`.
///
/// Four lanes are processed per iteration using a separate multiply and add;
/// the up-to-three leftover elements use the same scalar expression.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
///
/// # Panics
///
/// Panics if `c` or `b` holds fewer than `len` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_axpy(c: &mut [f64], b: &[f64], scalar: f64, len: usize) {
    use std::arch::x86_64::*;

    // The vector loop goes through raw pointers, so nothing else stops an
    // oversized `len` from running past the end of either slice.
    assert!(
        c.len() >= len && b.len() >= len,
        "avx2_axpy: `c` and `b` must hold at least `len` elements"
    );

    let a_vec = _mm256_set1_pd(scalar);
    let mut j = 0;

    while j + 4 <= len {
        // SAFETY: `j + 4 <= len`, and both slices hold at least `len`
        // elements (checked above), so every 4-wide access is in bounds.
        let c_ptr = c.as_mut_ptr().add(j);
        let b_ptr = b.as_ptr().add(j);
        let c_val = _mm256_loadu_pd(c_ptr);
        let b_val = _mm256_loadu_pd(b_ptr);
        let prod = _mm256_mul_pd(a_vec, b_val);
        let result = _mm256_add_pd(c_val, prod);
        _mm256_storeu_pd(c_ptr, result);
        j += 4;
    }

    // Scalar tail for the remaining `len % 4` elements; at most three
    // iterations, so plain checked indexing is used.
    while j < len {
        c[j] += scalar * b[j];
        j += 1;
    }
}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `op` through `simd_binop` and compares the output with `scalar`
    /// applied pairwise to the same inputs.
    fn check_binop(
        a: &[f64],
        b: &[f64],
        op: BinOp,
        scalar: impl Fn(f64, f64) -> f64,
        msg: &str,
    ) {
        let got = simd_binop(a, b, op);
        let want: Vec<f64> = a.iter().zip(b).map(|(&x, &y)| scalar(x, y)).collect();
        assert_eq!(got, want, "{}", msg);
    }

    /// Runs `op` through `simd_unary` and compares the output with `scalar`
    /// applied to each input element.
    fn check_unary(a: &[f64], op: UnaryOp, scalar: impl Fn(f64) -> f64, msg: &str) {
        let got = simd_unary(a, op);
        let want: Vec<f64> = a.iter().map(|&x| scalar(x)).collect();
        assert_eq!(got, want, "{}", msg);
    }

    /// Seventeen samples spanning negative, zero and positive values.
    fn signed_ramp() -> Vec<f64> {
        (-8..9).map(|i| i as f64 * 1.5).collect()
    }

    #[test]
    fn test_simd_add_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.3).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.7).collect();
        check_binop(
            &a,
            &b,
            BinOp::Add,
            |x, y| x + y,
            "SIMD add must be bit-identical to scalar",
        );
    }

    #[test]
    fn test_simd_sub_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 1.1).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.9).collect();
        check_binop(
            &a,
            &b,
            BinOp::Sub,
            |x, y| x - y,
            "SIMD sub must be bit-identical to scalar",
        );
    }

    #[test]
    fn test_simd_mul_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.1 + 0.01).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.2 + 0.03).collect();
        check_binop(
            &a,
            &b,
            BinOp::Mul,
            |x, y| x * y,
            "SIMD mul must be bit-identical to scalar",
        );
    }

    #[test]
    fn test_simd_div_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.5 + 1.0).collect();
        let b: Vec<f64> = (0..17).map(|i| (i + 1) as f64 * 0.3).collect();
        check_binop(
            &a,
            &b,
            BinOp::Div,
            |x, y| x / y,
            "SIMD div must be bit-identical to scalar",
        );
    }

    #[test]
    fn test_simd_sqrt_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 2.5 + 0.1).collect();
        check_unary(
            &a,
            UnaryOp::Sqrt,
            |x| x.sqrt(),
            "SIMD sqrt must be bit-identical to scalar",
        );
    }

    #[test]
    fn test_simd_abs_matches_scalar() {
        check_unary(
            &signed_ramp(),
            UnaryOp::Abs,
            |x| x.abs(),
            "SIMD abs must be bit-identical to scalar",
        );
    }

    #[test]
    fn test_simd_neg_matches_scalar() {
        check_unary(
            &signed_ramp(),
            UnaryOp::Neg,
            |x| -x,
            "SIMD neg must be bit-identical to scalar",
        );
    }

    #[test]
    fn test_simd_relu_matches_scalar() {
        check_unary(
            &signed_ramp(),
            UnaryOp::Relu,
            |x| if x > 0.0 { x } else { 0.0 },
            "SIMD relu must be bit-identical to scalar",
        );
    }

    #[test]
    fn test_simd_axpy_matches_scalar() {
        let b: Vec<f64> = (0..17).map(|i| i as f64 * 0.3).collect();
        let scalar = 2.5;
        let initial: Vec<f64> = (0..17).map(|i| i as f64 * 0.1).collect();

        let mut got = initial.clone();
        simd_axpy(&mut got, &b, scalar, 17);

        // Reference result: the same update written as a plain scalar expression.
        let want: Vec<f64> = initial
            .iter()
            .zip(&b)
            .map(|(&c, &x)| c + scalar * x)
            .collect();
        assert_eq!(got, want, "SIMD axpy must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_empty_input() {
        let nothing: [f64; 0] = [];
        assert_eq!(simd_binop(&nothing, &nothing, BinOp::Add), Vec::<f64>::new());
        assert_eq!(simd_unary(&nothing, UnaryOp::Sqrt), Vec::<f64>::new());
    }

    #[test]
    fn test_simd_single_element() {
        // One element is below the vector width, so only the scalar tail runs.
        let (a, b) = ([3.0], [4.0]);
        assert_eq!(simd_binop(&a, &b, BinOp::Add), vec![7.0]);
        assert_eq!(simd_unary(&a, UnaryOp::Sqrt), vec![3.0f64.sqrt()]);
    }

    #[test]
    fn test_simd_exactly_four_elements() {
        // Exactly one full vector and no scalar tail.
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        assert_eq!(simd_binop(&a, &b, BinOp::Add), vec![6.0, 8.0, 10.0, 12.0]);
        assert_eq!(simd_binop(&a, &b, BinOp::Mul), vec![5.0, 12.0, 21.0, 32.0]);
    }

    #[test]
    fn test_avx2_detection() {
        // Either answer is acceptable; the call just has to complete.
        let _ = has_avx2();
    }
}