use crate::array::Array;
use crate::llo::reduction::ReductionKind;
use crate::llo::ElementwiseKind;
use anyhow::{anyhow, Result};
// rayon is used unconditionally by the parallel reduction and matmul paths
// below, so the import is not gated on the "parallel" feature.
use rayon::prelude::*;

/// Reports whether a SIMD elementwise path (AVX2 or SSE2) is available on the
/// current CPU; always false on non-x86 targets.
pub fn elementwise_simd_supported() -> bool {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx2") {
            return true;
        }
        if std::is_x86_feature_detected!("sse2") {
            return true;
        }
        false
    }
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    {
        false
    }
}
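
/// Element-wise operation over two equal-shaped f32 arrays.
///
/// Add/Mul/Sub/Div/Sqrt map directly to AVX2 (8 lanes) or SSE2 (4 lanes)
/// instructions. The remaining kinds have no single x86 instruction, so the
/// vector loop spills each register to a stack buffer and applies the scalar
/// function per lane. Leftover elements are handled by a scalar tail loop.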
pub fn elementwise_simd(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    if a.shape != b.shape {
        return Err(anyhow!("shape mismatch in simd elementwise"));
    }

    let mut out = Array::<f32>::zeros(a.shape.clone());
    let n = a.len();

    // First index not yet handled by a vector loop.
    let mut i = 0usize;

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    unsafe {
        if std::is_x86_feature_detected!("avx2") {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;

            // AVX2 path: 8 f32 lanes per iteration.
            while i + 8 <= n {
                let pa = _mm256_loadu_ps(a.data.as_ptr().add(i));
                let pb = _mm256_loadu_ps(b.data.as_ptr().add(i));
                let pr = match kind {
                    ElementwiseKind::Add => _mm256_add_ps(pa, pb),
                    ElementwiseKind::Mul => _mm256_mul_ps(pa, pb),
                    ElementwiseKind::Sub => _mm256_sub_ps(pa, pb),
                    ElementwiseKind::Div => _mm256_div_ps(pa, pb),
                    ElementwiseKind::Sqrt => _mm256_sqrt_ps(pa),
                    // No single AVX2 instruction for these kinds: spill the
                    // lanes to a stack buffer, apply the scalar function per
                    // lane, then reload the result into a vector register.
                    ElementwiseKind::Sin
                    | ElementwiseKind::Cos
                    | ElementwiseKind::Tan
                    | ElementwiseKind::Abs
                    | ElementwiseKind::Neg
                    | ElementwiseKind::Exp
                    | ElementwiseKind::Log
                    | ElementwiseKind::Pow
                    | ElementwiseKind::Asin
                    | ElementwiseKind::Acos
                    | ElementwiseKind::Atan
                    | ElementwiseKind::Relu
                    | ElementwiseKind::LeakyRelu
                    | ElementwiseKind::Sigmoid
                    | ElementwiseKind::Tanh
                    | ElementwiseKind::Softplus => {
                        let mut tmp_a = [0.0f32; 8];
                        let mut tmp_b = [0.0f32; 8];
                        _mm256_storeu_ps(tmp_a.as_mut_ptr(), pa);
                        _mm256_storeu_ps(tmp_b.as_mut_ptr(), pb);
                        for j in 0..8 {
                            tmp_a[j] = match kind {
                                ElementwiseKind::Sin => tmp_a[j].sin(),
                                ElementwiseKind::Cos => tmp_a[j].cos(),
                                ElementwiseKind::Tan => tmp_a[j].tan(),
                                ElementwiseKind::Abs => tmp_a[j].abs(),
                                ElementwiseKind::Neg => -tmp_a[j],
                                ElementwiseKind::Exp => tmp_a[j].exp(),
                                ElementwiseKind::Log => tmp_a[j].ln(),
                                ElementwiseKind::Pow => tmp_a[j].powf(tmp_b[j]),
                                ElementwiseKind::Asin => tmp_a[j].asin(),
                                ElementwiseKind::Acos => tmp_a[j].acos(),
                                ElementwiseKind::Atan => tmp_a[j].atan(),
                                ElementwiseKind::Relu => tmp_a[j].max(0.0),
                                ElementwiseKind::LeakyRelu => {
                                    if tmp_a[j] > 0.0 {
                                        tmp_a[j]
                                    } else {
                                        0.01 * tmp_a[j]
                                    }
                                }
                                ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-tmp_a[j]).exp()),
                                ElementwiseKind::Tanh => tmp_a[j].tanh(),
                                ElementwiseKind::Softplus => (1.0 + tmp_a[j].exp()).ln(),
                                _ => tmp_a[j],
                            };
                        }
                        _mm256_loadu_ps(tmp_a.as_ptr())
                    }
                };
                _mm256_storeu_ps(out.data.as_mut_ptr().add(i), pr);
                i += 8;
            }
        } else if std::is_x86_feature_detected!("sse2") {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;

            // SSE2 path: 4 f32 lanes per iteration.
            while i + 4 <= n {
                let pa = _mm_loadu_ps(a.data.as_ptr().add(i));
                let pb = _mm_loadu_ps(b.data.as_ptr().add(i));
                let pr = match kind {
                    ElementwiseKind::Add => _mm_add_ps(pa, pb),
                    ElementwiseKind::Mul => _mm_mul_ps(pa, pb),
                    ElementwiseKind::Sub => _mm_sub_ps(pa, pb),
                    ElementwiseKind::Div => _mm_div_ps(pa, pb),
                    ElementwiseKind::Sqrt => _mm_sqrt_ps(pa),
                    // Same spill-to-scalar strategy as the AVX2 path above.
                    ElementwiseKind::Sin
                    | ElementwiseKind::Cos
                    | ElementwiseKind::Tan
                    | ElementwiseKind::Abs
                    | ElementwiseKind::Neg
                    | ElementwiseKind::Exp
                    | ElementwiseKind::Log
                    | ElementwiseKind::Pow
                    | ElementwiseKind::Asin
                    | ElementwiseKind::Acos
                    | ElementwiseKind::Atan
                    | ElementwiseKind::Relu
                    | ElementwiseKind::LeakyRelu
                    | ElementwiseKind::Sigmoid
                    | ElementwiseKind::Tanh
                    | ElementwiseKind::Softplus => {
                        let mut tmp_a = [0.0f32; 4];
                        let mut tmp_b = [0.0f32; 4];
                        _mm_storeu_ps(tmp_a.as_mut_ptr(), pa);
                        _mm_storeu_ps(tmp_b.as_mut_ptr(), pb);
                        for j in 0..4 {
                            tmp_a[j] = match kind {
                                ElementwiseKind::Sin => tmp_a[j].sin(),
                                ElementwiseKind::Cos => tmp_a[j].cos(),
                                ElementwiseKind::Tan => tmp_a[j].tan(),
                                ElementwiseKind::Abs => tmp_a[j].abs(),
                                ElementwiseKind::Neg => -tmp_a[j],
                                ElementwiseKind::Exp => tmp_a[j].exp(),
                                ElementwiseKind::Log => tmp_a[j].ln(),
                                ElementwiseKind::Pow => tmp_a[j].powf(tmp_b[j]),
                                ElementwiseKind::Asin => tmp_a[j].asin(),
                                ElementwiseKind::Acos => tmp_a[j].acos(),
                                ElementwiseKind::Atan => tmp_a[j].atan(),
                                ElementwiseKind::Relu => tmp_a[j].max(0.0),
                                ElementwiseKind::LeakyRelu => {
                                    if tmp_a[j] > 0.0 {
                                        tmp_a[j]
                                    } else {
                                        0.01 * tmp_a[j]
                                    }
                                }
                                ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-tmp_a[j]).exp()),
                                ElementwiseKind::Tanh => tmp_a[j].tanh(),
                                ElementwiseKind::Softplus => (1.0 + tmp_a[j].exp()).ln(),
                                _ => tmp_a[j],
                            };
                        }
                        _mm_loadu_ps(tmp_a.as_ptr())
                    }
                };
                _mm_storeu_ps(out.data.as_mut_ptr().add(i), pr);
                i += 4;
            }
        }
    }

    // Scalar tail: leftover elements, or the whole array when no SIMD path ran.
    for j in i..n {
        out.data[j] = match kind {
            ElementwiseKind::Add => a.data[j] + b.data[j],
            ElementwiseKind::Mul => a.data[j] * b.data[j],
            ElementwiseKind::Sub => a.data[j] - b.data[j],
            ElementwiseKind::Div => a.data[j] / b.data[j],
            ElementwiseKind::Sqrt => a.data[j].sqrt(),
            ElementwiseKind::Abs => a.data[j].abs(),
            ElementwiseKind::Neg => -a.data[j],
            ElementwiseKind::Exp => a.data[j].exp(),
            ElementwiseKind::Log => a.data[j].ln(),
            ElementwiseKind::Tan => a.data[j].tan(),
            ElementwiseKind::Pow => a.data[j].powf(b.data[j]),
            ElementwiseKind::Sin => a.data[j].sin(),
            ElementwiseKind::Cos => a.data[j].cos(),
            ElementwiseKind::Asin => a.data[j].asin(),
            ElementwiseKind::Acos => a.data[j].acos(),
            ElementwiseKind::Atan => a.data[j].atan(),
            ElementwiseKind::Relu => a.data[j].max(0.0),
            ElementwiseKind::LeakyRelu => {
                if a.data[j] > 0.0 {
                    a.data[j]
                } else {
                    0.01 * a.data[j]
                }
            }
            ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-a.data[j]).exp()),
            ElementwiseKind::Tanh => a.data[j].tanh(),
            ElementwiseKind::Softplus => (1.0 + a.data[j].exp()).ln(),
        };
    }

    Ok(out)
}
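
/// Reduction over all elements (`axis == None`) or along a single axis.
///
/// Full Sum/Max/Min reductions use AVX2 accumulators when available, and Mean
/// is computed as Sum / n. ArgMax and Variance, and reductions over
/// non-trailing axes, are delegated to the scalar backend.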
pub fn reduce_simd(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
    if axis.is_none() {
        let n = a.len();

        match kind {
            ReductionKind::Sum => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_setzero_ps();

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_add_ps(acc, p);
                            i += 8;
                        }

                        // Horizontal reduction of the 8 lane-wise partial sums.
                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut sum = s.iter().copied().sum::<f32>();

                        while i < n {
                            sum += a.data[i];
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![sum]));
                    }
                }

                let sum: f32 = a.data.iter().copied().sum();
                Ok(Array::new(vec![1], vec![sum]))
            }
            ReductionKind::Max => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_max_ps(acc, p);
                            i += 8;
                        }

                        // Horizontal max across the 8 lane maxima.
                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut max_val = s[0];
                        for &v in &s[1..] {
                            if v > max_val {
                                max_val = v;
                            }
                        }

                        while i < n {
                            if a.data[i] > max_val {
                                max_val = a.data[i];
                            }
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![max_val]));
                    }
                }

                let max_val = a
                    .data
                    .iter()
                    .copied()
                    .fold(f32::NEG_INFINITY, |acc, x| acc.max(x));
                Ok(Array::new(vec![1], vec![max_val]))
            }
            ReductionKind::Min => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_set1_ps(f32::INFINITY);

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_min_ps(acc, p);
                            i += 8;
                        }

                        // Horizontal min across the 8 lane minima.
                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut min_val = s[0];
                        for &v in &s[1..] {
                            if v < min_val {
                                min_val = v;
                            }
                        }

                        while i < n {
                            if a.data[i] < min_val {
                                min_val = a.data[i];
                            }
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![min_val]));
                    }
                }

                let min_val = a
                    .data
                    .iter()
                    .copied()
                    .fold(f32::INFINITY, |acc, x| acc.min(x));
                Ok(Array::new(vec![1], vec![min_val]))
            }
            ReductionKind::Mean => {
                let sum_result = reduce_simd(a, axis, ReductionKind::Sum)?;
                let mean = sum_result.data[0] / n as f32;
                Ok(Array::new(vec![1], vec![mean]))
            }
            // No SIMD kernels for these yet; use the scalar backend.
            ReductionKind::ArgMax | ReductionKind::Variance => {
                crate::backend::cpu::scalar::reduce_scalar(a, None, kind)
            }
        }
    } else {
        let axis = axis.unwrap();

        // Only the last axis is contiguous in memory, so only it gets a
        // dedicated SIMD path.
        if axis == a.shape.len() - 1 {
            return reduce_last_axis_simd(a, axis, kind);
        }

        crate::backend::cpu::scalar::reduce_scalar(a, Some(axis), kind)
    }
}
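
/// SIMD reduction over the last (contiguous) axis: each output element is the
/// reduction of one `axis_size`-long row, and rows are processed in parallel
/// with rayon.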
fn reduce_last_axis_simd(a: &Array, axis: usize, kind: ReductionKind) -> Result<Array> {
    // Output shape is the input shape with the reduced axis removed.
    let mut out_shape: Vec<usize> = a
        .shape
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != axis)
        .map(|(_, &d)| d)
        .collect();

    if out_shape.is_empty() {
        out_shape.push(1);
    }

    let out_size: usize = out_shape.iter().product();
    let axis_size = a.shape[axis];
    let mut out_data = vec![0.0f32; out_size];

    match kind {
        ReductionKind::Sum | ReductionKind::Mean => {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if std::is_x86_feature_detected!("avx2") {
                    // Each output element owns one contiguous input row.
                    out_data
                        .par_iter_mut()
                        .enumerate()
                        .for_each(|(row_idx, out_val)| {
                            let start = row_idx * axis_size;
                            let end = start + axis_size;

                            unsafe {
                                #[cfg(target_arch = "x86")]
                                use std::arch::x86::*;
                                #[cfg(target_arch = "x86_64")]
                                use std::arch::x86_64::*;

                                let mut acc = _mm256_setzero_ps();
                                let mut i = start;

                                while i + 8 <= end {
                                    let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                                    acc = _mm256_add_ps(acc, p);
                                    i += 8;
                                }

                                let mut s = [0.0f32; 8];
                                _mm256_storeu_ps(s.as_mut_ptr(), acc);
                                let mut sum: f32 = s.iter().sum();

                                while i < end {
                                    sum += a.data[i];
                                    i += 1;
                                }

                                *out_val = sum;
                            }
                        });

                    if kind == ReductionKind::Mean {
                        out_data.par_iter_mut().for_each(|x| *x /= axis_size as f32);
                    }

                    return Ok(Array::new(out_shape, out_data));
                }
            }

            return crate::backend::cpu::scalar::reduce_last_axis_optimized(
                a, axis_size, out_size, out_shape, kind,
            );
        }
        ReductionKind::Max => {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if std::is_x86_feature_detected!("avx2") {
                    out_data
                        .par_iter_mut()
                        .enumerate()
                        .for_each(|(row_idx, out_val)| {
                            let start = row_idx * axis_size;
                            let end = start + axis_size;

                            unsafe {
                                #[cfg(target_arch = "x86")]
                                use std::arch::x86::*;
                                #[cfg(target_arch = "x86_64")]
                                use std::arch::x86_64::*;

                                let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);
                                let mut i = start;

                                while i + 8 <= end {
                                    let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                                    acc = _mm256_max_ps(acc, p);
                                    i += 8;
                                }

                                let mut s = [0.0f32; 8];
                                _mm256_storeu_ps(s.as_mut_ptr(), acc);
                                let mut max_val = s[0];
                                for &v in &s[1..] {
                                    if v > max_val {
                                        max_val = v;
                                    }
                                }

                                while i < end {
                                    if a.data[i] > max_val {
                                        max_val = a.data[i];
                                    }
                                    i += 1;
                                }

                                *out_val = max_val;
                            }
                        });

                    return Ok(Array::new(out_shape, out_data));
                }
            }

            return crate::backend::cpu::scalar::reduce_last_axis_optimized(
                a, axis_size, out_size, out_shape, kind,
            );
        }
        _ => {
            return crate::backend::cpu::scalar::reduce_last_axis_optimized(
                a, axis_size, out_size, out_shape, kind,
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::cpu::scalar;

    fn make_arrays(len: usize) -> (Array, Array) {
        let a = (0..len).map(|i| i as f32 * 0.5 + 0.1).collect::<Vec<_>>();
        let b = (0..len).map(|i| (i as f32).sin()).collect::<Vec<_>>();
        (Array::new(vec![len], a), Array::new(vec![len], b))
    }

    #[test]
    fn simd_add_matches_scalar() {
        for len in &[1usize, 3, 7, 8, 15, 16, 33, 64] {
            let (a, b) = make_arrays(*len);
            let out_simd = elementwise_simd(&a, &b, ElementwiseKind::Add).unwrap();
            let out_scalar = scalar::elementwise_scalar(&a, &b, ElementwiseKind::Add).unwrap();
            assert_eq!(out_simd.data, out_scalar.data);
        }
    }

    #[test]
    fn simd_mul_matches_scalar() {
        for len in &[1usize, 3, 7, 8, 15, 16, 33, 64] {
            let (a, b) = make_arrays(*len);
            let out_simd = elementwise_simd(&a, &b, ElementwiseKind::Mul).unwrap();
            let out_scalar = scalar::elementwise_scalar(&a, &b, ElementwiseKind::Mul).unwrap();
            assert_eq!(out_simd.data, out_scalar.data);
        }
    }
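
    // Sketch of a consistency test for dot_simd against the scalar backend it
    // falls back to (scalar::dot_scalar, as called in dot_simd below). The
    // relative tolerance allows for FMA/reassociation differences between the
    // SIMD and scalar accumulation orders.
    #[test]
    fn simd_dot_matches_scalar() {
        for len in &[1usize, 3, 7, 8, 15, 16, 33, 64] {
            let (a, b) = make_arrays(*len);
            let dot = dot_simd(&a, &b).unwrap();
            let dot_ref = scalar::dot_scalar(&a, &b).unwrap();
            assert!((dot - dot_ref).abs() <= 1e-4 * dot_ref.abs().max(1.0));
        }
    }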
}
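
/// Dot product of two 1-D arrays of equal length, using an AVX2+FMA kernel
/// when the CPU supports it and falling back to `scalar::dot_scalar` otherwise.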
pub fn dot_simd(a: &Array, b: &Array) -> Result<f32> {
    if a.shape.len() != 1 || b.shape.len() != 1 {
        return Err(anyhow!("dot_simd: both inputs must be 1-D arrays"));
    }
    if a.shape[0] != b.shape[0] {
        return Err(anyhow!("dot_simd: arrays must have same length"));
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("fma") && std::is_x86_feature_detected!("avx2") {
            // Safety: the runtime checks above guarantee the CPU supports the
            // features `dot_simd_avx2_fma` is compiled with.
            unsafe {
                return dot_simd_avx2_fma(a, b);
            }
        }
    }

    crate::backend::cpu::scalar::dot_scalar(a, b)
}
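
/// AVX2+FMA dot-product kernel: 8-lane fused multiply-add accumulation, a
/// horizontal sum, and a scalar tail.
///
/// # Safety
/// The caller must have verified at runtime that the CPU supports AVX2 and FMA.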
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_simd_avx2_fma(a: &Array, b: &Array) -> Result<f32> {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let n = a.data.len();
    let mut sum = _mm256_setzero_ps();

    // Fused multiply-add over 8 lanes at a time.
    let chunks = n / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let va = _mm256_loadu_ps(a.data.as_ptr().add(offset));
        let vb = _mm256_loadu_ps(b.data.as_ptr().add(offset));
        sum = _mm256_fmadd_ps(va, vb, sum);
    }

    // Horizontal sum of the 8 lane accumulators.
    let mut result = [0.0f32; 8];
    _mm256_storeu_ps(result.as_mut_ptr(), sum);
    let mut total = result.iter().sum::<f32>();

    // Scalar tail.
    for i in (chunks * 8)..n {
        total += a.data[i] * b.data[i];
    }

    Ok(total)
}
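
/// 2-D matrix multiply. Panics on non-2-D inputs or an inner-dimension
/// mismatch; dispatches to the parallel AVX2+FMA kernel when available and to
/// the scalar implementation otherwise.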
pub fn matmul_simd(a: &Array, b: &Array) -> Array {
    if a.shape.len() != 2 || b.shape.len() != 2 {
        panic!("matmul_simd: both inputs must be 2-D arrays");
    }

    let m = a.shape[0];
    let k = a.shape[1];
    let n = b.shape[1];

    if k != b.shape[0] {
        panic!(
            "matmul_simd: inner dimension mismatch: {} != {}",
            k, b.shape[0]
        );
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            return matmul_simd_parallel(a, b, m, k, n);
        }
    }

    super::matmul_scalar_direct(a, b)
}
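
/// Splits the output rows of `a @ b` into blocks and runs the cache-blocked
/// AVX2+FMA kernel on each block in parallel via rayon.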
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn matmul_simd_parallel(a: &Array, b: &Array, m: usize, k: usize, n: usize) -> Array {
    use rayon::prelude::*;

    // Hand each worker `block_size` output rows (each of length n).
    let block_size = if m >= 2048 { 256 } else { 128 };
    let mut result = vec![0.0f32; m * n];

    result
        .par_chunks_mut(block_size * n)
        .enumerate()
        .for_each(|(block_idx, out_block)| {
            let start = block_idx * block_size;
            let end = (start + block_size).min(m);
            let block_rows = end - start;

            // Copy this block's rows of `a` so the kernel sees a dense 2-D array.
            let a_block_start = start * k;
            let a_block_end = end * k;
            let a_block_slice = &a.data[a_block_start..a_block_end];

            let a_block = Array::new(vec![block_rows, k], a_block_slice.to_vec());

            // Safety: this path is only reached after the runtime AVX2+FMA
            // checks in matmul_simd.
            unsafe {
                let block_result = matmul_simd_avx2_fma_blocked(
                    &a_block,
                    b,
                    block_rows,
                    k,
                    n,
                    vec![0.0f32; block_rows * n],
                );
                out_block.copy_from_slice(&block_result.data);
            }
        });

    Array::new(vec![m, n], result)
}
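
/// Cache-blocked AVX2+FMA matmul kernel: tiles the (m, n, k) loops, then runs
/// a 2-row by 16-column micro-kernel of 8-wide FMAs, with 8-wide and scalar
/// fallbacks for the edges. `result` accumulates into row-major output.
///
/// # Safety
/// The caller must have verified at runtime that the CPU supports AVX2 and FMA.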
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn matmul_simd_avx2_fma_blocked(
    a: &Array,
    b: &Array,
    m: usize,
    k: usize,
    n: usize,
    mut result: Vec<f32>,
) -> Array {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Cache-blocking tile sizes for the row (i), column (j), and inner (k) loops.
    const BLOCK_M: usize = 96;
    const BLOCK_N: usize = 256;
    const BLOCK_K: usize = 512;

    for i0 in (0..m).step_by(BLOCK_M) {
        let i_end = (i0 + BLOCK_M).min(m);

        for j0 in (0..n).step_by(BLOCK_N) {
            let j_end = (j0 + BLOCK_N).min(n);

            for k0 in (0..k).step_by(BLOCK_K) {
                let k_end = (k0 + BLOCK_K).min(k);

                // Micro-kernel: two output rows at a time.
                let mut i = i0;
                while i + 2 <= i_end {
                    let a_row0_offset = i * k;
                    let a_row1_offset = (i + 1) * k;
                    let result_row0_offset = i * n;
                    let result_row1_offset = (i + 1) * n;

                    // 2x16 tile: four 8-wide accumulators.
                    let mut j = j0;
                    while j + 16 <= j_end {
                        let mut sum0_0 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j));
                        let mut sum0_1 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j + 8));
                        let mut sum1_0 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j));
                        let mut sum1_1 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j + 8));

                        for kk in k0..k_end {
                            let a_val0 = _mm256_set1_ps(a.data[a_row0_offset + kk]);
                            let a_val1 = _mm256_set1_ps(a.data[a_row1_offset + kk]);
                            let b_row_offset = kk * n;
                            let b_vals0 = _mm256_loadu_ps(b.data.as_ptr().add(b_row_offset + j));
                            let b_vals1 =
                                _mm256_loadu_ps(b.data.as_ptr().add(b_row_offset + j + 8));

                            sum0_0 = _mm256_fmadd_ps(a_val0, b_vals0, sum0_0);
                            sum0_1 = _mm256_fmadd_ps(a_val0, b_vals1, sum0_1);
                            sum1_0 = _mm256_fmadd_ps(a_val1, b_vals0, sum1_0);
                            sum1_1 = _mm256_fmadd_ps(a_val1, b_vals1, sum1_1);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row0_offset + j), sum0_0);
                        _mm256_storeu_ps(
                            result.as_mut_ptr().add(result_row0_offset + j + 8),
                            sum0_1,
                        );
                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row1_offset + j), sum1_0);
                        _mm256_storeu_ps(
                            result.as_mut_ptr().add(result_row1_offset + j + 8),
                            sum1_1,
                        );
                        j += 16;
                    }

                    // 2x8 tile for the remaining full vector width.
                    while j + 8 <= j_end {
                        let mut sum0 = _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j));
                        let mut sum1 = _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j));

                        for kk in k0..k_end {
                            let a_val0 = _mm256_set1_ps(a.data[a_row0_offset + kk]);
                            let a_val1 = _mm256_set1_ps(a.data[a_row1_offset + kk]);
                            let b_vals = _mm256_loadu_ps(b.data.as_ptr().add(kk * n + j));

                            sum0 = _mm256_fmadd_ps(a_val0, b_vals, sum0);
                            sum1 = _mm256_fmadd_ps(a_val1, b_vals, sum1);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row0_offset + j), sum0);
                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row1_offset + j), sum1);
                        j += 8;
                    }

                    // Scalar remainder columns.
                    for j in j..j_end {
                        let mut sum0 = result[result_row0_offset + j];
                        let mut sum1 = result[result_row1_offset + j];
                        for kk in k0..k_end {
                            let b_val = b.data[kk * n + j];
                            sum0 += a.data[a_row0_offset + kk] * b_val;
                            sum1 += a.data[a_row1_offset + kk] * b_val;
                        }
                        result[result_row0_offset + j] = sum0;
                        result[result_row1_offset + j] = sum1;
                    }

                    i += 2;
                }

                // Odd final row of the row tile.
                if i < i_end {
                    let a_row_offset = i * k;
                    let result_row_offset = i * n;

                    let mut j = j0;
                    while j + 8 <= j_end {
                        let mut sum = _mm256_loadu_ps(result.as_ptr().add(result_row_offset + j));

                        for kk in k0..k_end {
                            let a_val = _mm256_set1_ps(a.data[a_row_offset + kk]);
                            let b_vals = _mm256_loadu_ps(b.data.as_ptr().add(kk * n + j));
                            sum = _mm256_fmadd_ps(a_val, b_vals, sum);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row_offset + j), sum);
                        j += 8;
                    }

                    for j in j..j_end {
                        let mut sum = result[result_row_offset + j];
                        for kk in k0..k_end {
                            sum += a.data[a_row_offset + kk] * b.data[kk * n + j];
                        }
                        result[result_row_offset + j] = sum;
                    }
                }
            }
        }
    }

    Array::new(vec![m, n], result)
}

/// SIMD 1-D convolution, implemented in the sibling `simd_conv` module.
pub use super::simd_conv::conv1d_simd;