/// Runtime check for AVX2 support on the current CPU (x86_64 builds only).
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn has_avx2() -> bool {
    // std caches the feature-detection result after the first call, so this
    // is cheap enough to call inside dispatch paths.
    is_x86_feature_detected!("avx2")
}
30
/// Fallback for non-x86_64 targets: AVX2 is never available.
#[cfg(not(target_arch = "x86_64"))]
#[inline]
pub fn has_avx2() -> bool {
    false
}
36
/// Element-wise binary operation selector for `simd_binop`.
///
/// `Debug`/`PartialEq`/`Eq`/`Hash` are derived so callers can log, compare,
/// and key on the operation; this is additive and backward-compatible.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BinOp {
    /// `a + b`
    Add,
    /// `a - b`
    Sub,
    /// `a * b`
    Mul,
    /// `a / b`
    Div,
}
47
// Minimum element count before the rayon-parallel path is worth its
// scheduling overhead (only consulted when the `parallel` feature is on).
const PARALLEL_THRESHOLD: usize = 100_000;
51
52pub fn simd_binop(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
61 let n = a.len();
62 debug_assert_eq!(n, b.len());
63
64 #[cfg(feature = "parallel")]
66 {
67 if n >= PARALLEL_THRESHOLD {
68 return simd_binop_parallel(a, b, op);
69 }
70 }
71
72 simd_binop_sequential(a, b, op)
73}
74
75fn simd_binop_sequential(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
77 let n = a.len();
78 let mut out = vec![0.0f64; n];
79
80 #[cfg(target_arch = "x86_64")]
81 {
82 if has_avx2() {
83 unsafe {
84 match op {
85 BinOp::Add => avx2_binop::<ADD_TAG>(a, b, &mut out),
86 BinOp::Sub => avx2_binop::<SUB_TAG>(a, b, &mut out),
87 BinOp::Mul => avx2_binop::<MUL_TAG>(a, b, &mut out),
88 BinOp::Div => avx2_binop::<DIV_TAG>(a, b, &mut out),
89 }
90 }
91 return out;
92 }
93 }
94
95 match op {
97 BinOp::Add => { for i in 0..n { out[i] = a[i] + b[i]; } }
98 BinOp::Sub => { for i in 0..n { out[i] = a[i] - b[i]; } }
99 BinOp::Mul => { for i in 0..n { out[i] = a[i] * b[i]; } }
100 BinOp::Div => { for i in 0..n { out[i] = a[i] / b[i]; } }
101 }
102 out
103}
104
/// Rayon-parallel element-wise kernel, compiled only with the `parallel`
/// feature. Splits the output into fixed-size chunks processed on the
/// thread pool; each chunk reuses the same AVX2/scalar dispatch as the
/// sequential path.
///
/// Assumes `a.len() == b.len()` (enforced by the public entry point).
#[cfg(feature = "parallel")]
fn simd_binop_parallel(a: &[f64], b: &[f64], op: BinOp) -> Vec<f64> {
    use rayon::prelude::*;

    let n = a.len();
    let mut out = vec![0.0f64; n];
    // 4096 f64s = 32 KiB per chunk: enough work per task to amortize
    // rayon's scheduling overhead.
    let chunk_size = 4096;
    out.par_chunks_mut(chunk_size)
        .enumerate()
        .for_each(|(chunk_idx, out_chunk)| {
            // Map this output chunk back to the matching input windows.
            let start = chunk_idx * chunk_size;
            let len = out_chunk.len();
            let a_chunk = &a[start..start + len];
            let b_chunk = &b[start..start + len];

            #[cfg(target_arch = "x86_64")]
            {
                if has_avx2() {
                    // SAFETY: AVX2 availability was just checked; all three
                    // chunk slices have length `len`.
                    unsafe {
                        match op {
                            BinOp::Add => avx2_binop::<ADD_TAG>(a_chunk, b_chunk, out_chunk),
                            BinOp::Sub => avx2_binop::<SUB_TAG>(a_chunk, b_chunk, out_chunk),
                            BinOp::Mul => avx2_binop::<MUL_TAG>(a_chunk, b_chunk, out_chunk),
                            BinOp::Div => avx2_binop::<DIV_TAG>(a_chunk, b_chunk, out_chunk),
                        }
                    }
                    return;
                }
            }

            // Scalar fallback for this chunk.
            match op {
                BinOp::Add => { for i in 0..len { out_chunk[i] = a_chunk[i] + b_chunk[i]; } }
                BinOp::Sub => { for i in 0..len { out_chunk[i] = a_chunk[i] - b_chunk[i]; } }
                BinOp::Mul => { for i in 0..len { out_chunk[i] = a_chunk[i] * b_chunk[i]; } }
                BinOp::Div => { for i in 0..len { out_chunk[i] = a_chunk[i] / b_chunk[i]; } }
            }
        });

    out
}
150
// Const-generic tags selecting the operation inside `avx2_binop`. The
// `match OP` there is resolved at monomorphization time, so each
// instantiation compiles down to a single intrinsic in the inner loop.
const ADD_TAG: u8 = 0;
const SUB_TAG: u8 = 1;
const MUL_TAG: u8 = 2;
const DIV_TAG: u8 = 3;
156
/// AVX2 element-wise binary kernel; `OP` is one of the `*_TAG` constants.
///
/// Processes 4 doubles per iteration with unaligned 256-bit loads/stores,
/// then handles the remaining 0-3 elements in a scalar tail loop.
///
/// # Safety
/// - The caller must ensure the running CPU supports AVX2
///   (`#[target_feature(enable = "avx2")]`).
/// - `b.len()` and `out.len()` must be at least `a.len()`: the vector loop
///   reads and writes through raw pointers without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_binop<const OP: u8>(a: &[f64], b: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;
    let n = a.len();
    let mut i = 0;

    // Main vector loop: 4 f64 lanes per 256-bit register.
    while i + 4 <= n {
        let va = _mm256_loadu_pd(a.as_ptr().add(i));
        let vb = _mm256_loadu_pd(b.as_ptr().add(i));
        // `OP` is a const generic, so this match is compile-time dispatch.
        let vr = match OP {
            ADD_TAG => _mm256_add_pd(va, vb),
            SUB_TAG => _mm256_sub_pd(va, vb),
            MUL_TAG => _mm256_mul_pd(va, vb),
            // DIV_TAG (and any other value) falls through to divide.
            _ => _mm256_div_pd(va, vb),
        };
        _mm256_storeu_pd(out.as_mut_ptr().add(i), vr);
        i += 4;
    }

    // Scalar tail for the last n % 4 elements (checked indexing here).
    while i < n {
        out[i] = match OP {
            ADD_TAG => a[i] + b[i],
            SUB_TAG => a[i] - b[i],
            MUL_TAG => a[i] * b[i],
            _ => a[i] / b[i],
        };
        i += 1;
    }
}
188
/// Element-wise unary operation selector for `simd_unary`.
///
/// `Debug`/`PartialEq`/`Eq`/`Hash` are derived so callers can log, compare,
/// and key on the operation; this is additive and backward-compatible.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// `x.sqrt()`
    Sqrt,
    /// `x.abs()`
    Abs,
    /// `-x`
    Neg,
    /// `if x > 0.0 { x } else { 0.0 }`
    Relu,
}
199
200pub fn simd_unary(a: &[f64], op: UnaryOp) -> Vec<f64> {
209 let n = a.len();
210 let mut out = vec![0.0f64; n];
211
212 #[cfg(target_arch = "x86_64")]
213 {
214 if has_avx2() {
215 unsafe {
216 match op {
217 UnaryOp::Sqrt => avx2_sqrt(a, &mut out),
218 UnaryOp::Abs => avx2_abs(a, &mut out),
219 UnaryOp::Neg => avx2_neg(a, &mut out),
220 UnaryOp::Relu => avx2_relu(a, &mut out),
221 }
222 }
223 return out;
224 }
225 }
226
227 match op {
229 UnaryOp::Sqrt => { for i in 0..n { out[i] = a[i].sqrt(); } }
230 UnaryOp::Abs => { for i in 0..n { out[i] = a[i].abs(); } }
231 UnaryOp::Neg => { for i in 0..n { out[i] = -a[i]; } }
232 UnaryOp::Relu => { for i in 0..n { out[i] = if a[i] > 0.0 { a[i] } else { 0.0 }; } }
233 }
234 out
235}
236
/// AVX2 element-wise square root: `out[i] = a[i].sqrt()`.
///
/// # Safety
/// - Caller must ensure the CPU supports AVX2.
/// - `out.len()` must be at least `a.len()`: the vector stores go through a
///   raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sqrt(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;
    let n = a.len();
    let mut i = 0;
    // 4 doubles per iteration.
    while i + 4 <= n {
        let va = _mm256_loadu_pd(a.as_ptr().add(i));
        let vr = _mm256_sqrt_pd(va);
        _mm256_storeu_pd(out.as_mut_ptr().add(i), vr);
        i += 4;
    }
    // Scalar tail for the last n % 4 elements.
    while i < n { out[i] = a[i].sqrt(); i += 1; }
}
251
/// AVX2 element-wise absolute value: `out[i] = a[i].abs()`.
///
/// # Safety
/// - Caller must ensure the CPU supports AVX2.
/// - `out.len()` must be at least `a.len()`: the vector stores go through a
///   raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_abs(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;
    let n = a.len();
    // All bits set except the IEEE-754 sign bit; AND-ing clears the sign,
    // which is exactly f64::abs (works for -0.0, infinities, NaN payloads).
    let mask = _mm256_set1_pd(f64::from_bits(0x7FFF_FFFF_FFFF_FFFFu64));
    let mut i = 0;
    while i + 4 <= n {
        let va = _mm256_loadu_pd(a.as_ptr().add(i));
        let vr = _mm256_and_pd(va, mask);
        _mm256_storeu_pd(out.as_mut_ptr().add(i), vr);
        i += 4;
    }
    // Scalar tail for the last n % 4 elements.
    while i < n { out[i] = a[i].abs(); i += 1; }
}
268
/// AVX2 element-wise negation: `out[i] = -a[i]`.
///
/// # Safety
/// - Caller must ensure the CPU supports AVX2.
/// - `out.len()` must be at least `a.len()`: the vector stores go through a
///   raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_neg(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;
    // Only the IEEE-754 sign bit set; XOR-ing flips the sign, which is
    // exactly unary negation for floats.
    let sign_bit = _mm256_set1_pd(f64::from_bits(0x8000_0000_0000_0000u64));
    let n = a.len();
    let mut i = 0;
    while i + 4 <= n {
        let va = _mm256_loadu_pd(a.as_ptr().add(i));
        let vr = _mm256_xor_pd(va, sign_bit);
        _mm256_storeu_pd(out.as_mut_ptr().add(i), vr);
        i += 4;
    }
    // Scalar tail for the last n % 4 elements.
    while i < n { out[i] = -a[i]; i += 1; }
}
285
/// AVX2 element-wise ReLU: `out[i] = max(a[i], 0.0)`.
///
/// # Safety
/// - Caller must ensure the CPU supports AVX2.
/// - `out.len()` must be at least `a.len()`: the vector stores go through a
///   raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_relu(a: &[f64], out: &mut [f64]) {
    use std::arch::x86_64::*;
    let n = a.len();
    let zero = _mm256_setzero_pd();
    let mut i = 0;
    while i + 4 <= n {
        let va = _mm256_loadu_pd(a.as_ptr().add(i));
        let vr = _mm256_max_pd(va, zero);
        _mm256_storeu_pd(out.as_mut_ptr().add(i), vr);
        i += 4;
    }
    // Scalar tail, matching the branchy form used by the scalar fallback.
    while i < n { out[i] = if a[i] > 0.0 { a[i] } else { 0.0 }; i += 1; }
}
301
302pub fn simd_axpy(c: &mut [f64], b: &[f64], scalar: f64, len: usize) {
312 debug_assert!(c.len() >= len);
313 debug_assert!(b.len() >= len);
314
315 #[cfg(target_arch = "x86_64")]
316 {
317 if has_avx2() {
318 unsafe { avx2_axpy(c, b, scalar, len); }
319 return;
320 }
321 }
322
323 for j in 0..len {
325 c[j] += scalar * b[j];
326 }
327}
328
/// AVX2 AXPY kernel: `c[j] += scalar * b[j]` for `j in 0..len`.
///
/// # Safety
/// - Caller must ensure the CPU supports AVX2.
/// - `c.len()` and `b.len()` must both be at least `len`: the vector loop
///   uses raw pointers and the scalar tail uses `get_unchecked`, so no
///   bounds checks are performed anywhere in this function.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_axpy(c: &mut [f64], b: &[f64], scalar: f64, len: usize) {
    use std::arch::x86_64::*;
    // Broadcast the scalar into all 4 lanes once, outside the loop.
    let a_vec = _mm256_set1_pd(scalar);
    let mut j = 0;

    // Main vector loop: 4 doubles per iteration, read-modify-write on `c`.
    while j + 4 <= len {
        let c_ptr = c.as_mut_ptr().add(j);
        let b_ptr = b.as_ptr().add(j);
        let c_val = _mm256_loadu_pd(c_ptr);
        let b_val = _mm256_loadu_pd(b_ptr);
        let prod = _mm256_mul_pd(a_vec, b_val);
        let result = _mm256_add_pd(c_val, prod);
        _mm256_storeu_pd(c_ptr, result);
        j += 4;
    }

    // Scalar tail for the last len % 4 elements. Relies on the caller's
    // length guarantee (see # Safety) since indexing is unchecked here.
    while j < len {
        *c.get_unchecked_mut(j) += scalar * *b.get_unchecked(j);
        j += 1;
    }
}
354
// Unit tests: every SIMD path must be bit-identical to the scalar reference
// computed inline. Input length 17 = 4 full AVX2 vectors + a 1-element tail,
// so both the vector loop and the scalar tail are exercised. On machines
// without AVX2 these degrade to scalar-vs-scalar, which still checks the
// dispatch plumbing.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_add_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.3).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.7).collect();
        let result = simd_binop(&a, &b, BinOp::Add);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect();
        assert_eq!(result, expected, "SIMD add must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_sub_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 1.1).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.9).collect();
        let result = simd_binop(&a, &b, BinOp::Sub);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect();
        assert_eq!(result, expected, "SIMD sub must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_mul_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.1 + 0.01).collect();
        let b: Vec<f64> = (0..17).map(|i| (17 - i) as f64 * 0.2 + 0.03).collect();
        let result = simd_binop(&a, &b, BinOp::Mul);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect();
        assert_eq!(result, expected, "SIMD mul must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_div_matches_scalar() {
        // b is strictly positive, so no division by zero here.
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 0.5 + 1.0).collect();
        let b: Vec<f64> = (0..17).map(|i| (i + 1) as f64 * 0.3).collect();
        let result = simd_binop(&a, &b, BinOp::Div);
        let expected: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x / y).collect();
        assert_eq!(result, expected, "SIMD div must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_sqrt_matches_scalar() {
        let a: Vec<f64> = (0..17).map(|i| i as f64 * 2.5 + 0.1).collect();
        let result = simd_unary(&a, UnaryOp::Sqrt);
        let expected: Vec<f64> = a.iter().map(|&x| x.sqrt()).collect();
        assert_eq!(result, expected, "SIMD sqrt must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_abs_matches_scalar() {
        // -8..9 covers negative, zero, and positive inputs.
        let a: Vec<f64> = (-8..9).map(|i| i as f64 * 1.5).collect();
        let result = simd_unary(&a, UnaryOp::Abs);
        let expected: Vec<f64> = a.iter().map(|&x| x.abs()).collect();
        assert_eq!(result, expected, "SIMD abs must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_neg_matches_scalar() {
        let a: Vec<f64> = (-8..9).map(|i| i as f64 * 1.5).collect();
        let result = simd_unary(&a, UnaryOp::Neg);
        let expected: Vec<f64> = a.iter().map(|&x| -x).collect();
        assert_eq!(result, expected, "SIMD neg must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_relu_matches_scalar() {
        let a: Vec<f64> = (-8..9).map(|i| i as f64 * 1.5).collect();
        let result = simd_unary(&a, UnaryOp::Relu);
        let expected: Vec<f64> = a.iter().map(|&x| if x > 0.0 { x } else { 0.0 }).collect();
        assert_eq!(result, expected, "SIMD relu must be bit-identical to scalar");
    }

    #[test]
    fn test_simd_axpy_matches_scalar() {
        let b: Vec<f64> = (0..17).map(|i| i as f64 * 0.3).collect();
        let scalar = 2.5;
        let mut c_simd: Vec<f64> = (0..17).map(|i| i as f64 * 0.1).collect();
        let mut c_scalar = c_simd.clone();

        simd_axpy(&mut c_simd, &b, scalar, 17);
        for j in 0..17 {
            c_scalar[j] += scalar * b[j];
        }
        assert_eq!(c_simd, c_scalar, "SIMD axpy must be bit-identical to scalar");
    }

    // Edge cases: empty input, a single element (tail-only), and exactly
    // one full vector (vector-loop-only).
    #[test]
    fn test_simd_empty_input() {
        let empty: Vec<f64> = vec![];
        assert_eq!(simd_binop(&empty, &empty, BinOp::Add), Vec::<f64>::new());
        assert_eq!(simd_unary(&empty, UnaryOp::Sqrt), Vec::<f64>::new());
    }

    #[test]
    fn test_simd_single_element() {
        let a = vec![3.0];
        let b = vec![4.0];
        assert_eq!(simd_binop(&a, &b, BinOp::Add), vec![7.0]);
        assert_eq!(simd_unary(&a, UnaryOp::Sqrt), vec![3.0f64.sqrt()]);
    }

    #[test]
    fn test_simd_exactly_four_elements() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        assert_eq!(simd_binop(&a, &b, BinOp::Add), vec![6.0, 8.0, 10.0, 12.0]);
        assert_eq!(simd_binop(&a, &b, BinOp::Mul), vec![5.0, 12.0, 21.0, 32.0]);
    }

    #[test]
    fn test_avx2_detection() {
        // Smoke test: detection must not panic on any target.
        let _has = has_avx2();
    }
}