use crate::array::Array;
use crate::llo::reduction::ReductionKind;
use crate::llo::ElementwiseKind;
use anyhow::{anyhow, Result};
use rayon::prelude::*;

/// Returns true when a SIMD elementwise path (AVX2 or SSE2) is available on
/// the current CPU; always false on non-x86 targets.
pub fn elementwise_simd_supported() -> bool {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx2") {
            return true;
        }
        if std::is_x86_feature_detected!("sse2") {
            return true;
        }
        false
    }
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    {
        false
    }
}

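/// Applies `kind` elementwise over two equal-shaped arrays (unary kinds read
/// only `a`; `Pow` takes its exponent from `b`).
///
/// Vectorizes with AVX2 (8 lanes) or SSE2 (4 lanes) when available; kinds
/// without packed intrinsics are computed lane-wise inside the vector loop,
/// and the scalar loop at the end handles any remaining tail elements.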
pub fn elementwise_simd(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    if a.shape != b.shape {
        return Err(anyhow!("shape mismatch in simd elementwise"));
    }

    let mut out = Array::<f32>::zeros(a.shape.clone());
    let n = a.len();

    let mut i = 0usize;
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    unsafe {
        if std::is_x86_feature_detected!("avx2") {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;

            while i + 8 <= n {
                let pa = _mm256_loadu_ps(a.data.as_ptr().add(i));
                let pb = _mm256_loadu_ps(b.data.as_ptr().add(i));
                let pr = match kind {
                    ElementwiseKind::Add => _mm256_add_ps(pa, pb),
                    ElementwiseKind::Mul => _mm256_mul_ps(pa, pb),
                    ElementwiseKind::Sub => _mm256_sub_ps(pa, pb),
                    ElementwiseKind::Div => _mm256_div_ps(pa, pb),
                    ElementwiseKind::Sqrt => _mm256_sqrt_ps(pa),
                    // No packed intrinsics for these kinds: spill the eight
                    // lanes to a stack buffer, apply the scalar function per
                    // lane, and reload the result.
                    ElementwiseKind::Sin
                    | ElementwiseKind::Cos
                    | ElementwiseKind::Tan
                    | ElementwiseKind::Abs
                    | ElementwiseKind::Neg
                    | ElementwiseKind::Exp
                    | ElementwiseKind::Log
                    | ElementwiseKind::Pow
                    | ElementwiseKind::Asin
                    | ElementwiseKind::Acos
                    | ElementwiseKind::Atan
                    | ElementwiseKind::Relu
                    | ElementwiseKind::LeakyRelu
                    | ElementwiseKind::Sigmoid
                    | ElementwiseKind::Tanh
                    | ElementwiseKind::Softplus => {
                        let mut tmp_a = [0.0f32; 8];
                        let mut tmp_b = [0.0f32; 8];
                        _mm256_storeu_ps(tmp_a.as_mut_ptr(), pa);
                        _mm256_storeu_ps(tmp_b.as_mut_ptr(), pb);
                        for j in 0..8 {
                            tmp_a[j] = match kind {
                                ElementwiseKind::Sin => tmp_a[j].sin(),
                                ElementwiseKind::Cos => tmp_a[j].cos(),
                                ElementwiseKind::Tan => tmp_a[j].tan(),
                                ElementwiseKind::Abs => tmp_a[j].abs(),
                                ElementwiseKind::Neg => -tmp_a[j],
                                ElementwiseKind::Exp => tmp_a[j].exp(),
                                ElementwiseKind::Log => tmp_a[j].ln(),
                                ElementwiseKind::Pow => tmp_a[j].powf(tmp_b[j]),
                                ElementwiseKind::Asin => tmp_a[j].asin(),
                                ElementwiseKind::Acos => tmp_a[j].acos(),
                                ElementwiseKind::Atan => tmp_a[j].atan(),
                                ElementwiseKind::Relu => tmp_a[j].max(0.0),
                                ElementwiseKind::LeakyRelu => {
                                    if tmp_a[j] > 0.0 {
                                        tmp_a[j]
                                    } else {
                                        0.01 * tmp_a[j]
                                    }
                                }
                                ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-tmp_a[j]).exp()),
                                ElementwiseKind::Tanh => tmp_a[j].tanh(),
                                ElementwiseKind::Softplus => (1.0 + tmp_a[j].exp()).ln(),
                                // Packed kinds are handled above and cannot
                                // reach this inner match.
                                _ => tmp_a[j],
                            };
                        }
                        _mm256_loadu_ps(tmp_a.as_ptr())
                    }
                };
                _mm256_storeu_ps(out.data.as_mut_ptr().add(i), pr);
                i += 8;
            }
        } else if std::is_x86_feature_detected!("sse2") {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;

            while i + 4 <= n {
                let pa = _mm_loadu_ps(a.data.as_ptr().add(i));
                let pb = _mm_loadu_ps(b.data.as_ptr().add(i));
                let pr = match kind {
                    ElementwiseKind::Add => _mm_add_ps(pa, pb),
                    ElementwiseKind::Mul => _mm_mul_ps(pa, pb),
                    ElementwiseKind::Sub => _mm_sub_ps(pa, pb),
                    ElementwiseKind::Div => _mm_div_ps(pa, pb),
                    ElementwiseKind::Sqrt => _mm_sqrt_ps(pa),
                    ElementwiseKind::Sin
                    | ElementwiseKind::Cos
                    | ElementwiseKind::Tan
                    | ElementwiseKind::Abs
                    | ElementwiseKind::Neg
                    | ElementwiseKind::Exp
                    | ElementwiseKind::Log
                    | ElementwiseKind::Pow
                    | ElementwiseKind::Asin
                    | ElementwiseKind::Acos
                    | ElementwiseKind::Atan
                    | ElementwiseKind::Relu
                    | ElementwiseKind::LeakyRelu
                    | ElementwiseKind::Sigmoid
                    | ElementwiseKind::Tanh
                    | ElementwiseKind::Softplus => {
                        let mut tmp_a = [0.0f32; 4];
                        let mut tmp_b = [0.0f32; 4];
                        _mm_storeu_ps(tmp_a.as_mut_ptr(), pa);
                        _mm_storeu_ps(tmp_b.as_mut_ptr(), pb);
                        for j in 0..4 {
                            tmp_a[j] = match kind {
                                ElementwiseKind::Sin => tmp_a[j].sin(),
                                ElementwiseKind::Cos => tmp_a[j].cos(),
                                ElementwiseKind::Tan => tmp_a[j].tan(),
                                ElementwiseKind::Abs => tmp_a[j].abs(),
                                ElementwiseKind::Neg => -tmp_a[j],
                                ElementwiseKind::Exp => tmp_a[j].exp(),
                                ElementwiseKind::Log => tmp_a[j].ln(),
                                ElementwiseKind::Pow => tmp_a[j].powf(tmp_b[j]),
                                ElementwiseKind::Asin => tmp_a[j].asin(),
                                ElementwiseKind::Acos => tmp_a[j].acos(),
                                ElementwiseKind::Atan => tmp_a[j].atan(),
                                ElementwiseKind::Relu => tmp_a[j].max(0.0),
                                ElementwiseKind::LeakyRelu => {
                                    if tmp_a[j] > 0.0 {
                                        tmp_a[j]
                                    } else {
                                        0.01 * tmp_a[j]
                                    }
                                }
                                ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-tmp_a[j]).exp()),
                                ElementwiseKind::Tanh => tmp_a[j].tanh(),
                                ElementwiseKind::Softplus => (1.0 + tmp_a[j].exp()).ln(),
                                _ => tmp_a[j],
                            };
                        }
                        _mm_loadu_ps(tmp_a.as_ptr())
                    }
                };
                _mm_storeu_ps(out.data.as_mut_ptr().add(i), pr);
                i += 4;
            }
        }
    }

    // Scalar tail: handles the remainder after the vector loop, or the whole
    // array when no SIMD path was taken.
    for j in i..n {
        out.data[j] = match kind {
            ElementwiseKind::Add => a.data[j] + b.data[j],
            ElementwiseKind::Mul => a.data[j] * b.data[j],
            ElementwiseKind::Sub => a.data[j] - b.data[j],
            ElementwiseKind::Div => a.data[j] / b.data[j],
            ElementwiseKind::Sqrt => a.data[j].sqrt(),
            ElementwiseKind::Abs => a.data[j].abs(),
            ElementwiseKind::Neg => -a.data[j],
            ElementwiseKind::Exp => a.data[j].exp(),
            ElementwiseKind::Log => a.data[j].ln(),
            ElementwiseKind::Tan => a.data[j].tan(),
            ElementwiseKind::Pow => a.data[j].powf(b.data[j]),
            ElementwiseKind::Sin => a.data[j].sin(),
            ElementwiseKind::Cos => a.data[j].cos(),
            ElementwiseKind::Asin => a.data[j].asin(),
            ElementwiseKind::Acos => a.data[j].acos(),
            ElementwiseKind::Atan => a.data[j].atan(),
            ElementwiseKind::Relu => a.data[j].max(0.0),
            ElementwiseKind::LeakyRelu => {
                if a.data[j] > 0.0 {
                    a.data[j]
                } else {
                    0.01 * a.data[j]
                }
            }
            ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-a.data[j]).exp()),
            ElementwiseKind::Tanh => a.data[j].tanh(),
            ElementwiseKind::Softplus => (1.0 + a.data[j].exp()).ln(),
        };
    }

    Ok(out)
}

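/// Reduces `a` either over the whole buffer (`axis == None`) or along one
/// axis.
///
/// Full-buffer Sum/Max/Min use AVX2 accumulators when available and scalar
/// iterators otherwise; Mean is computed as Sum / n; ArgMax and Variance
/// delegate to the scalar backend. For axis reductions, only the last
/// (contiguous) axis takes the SIMD path.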
pub fn reduce_simd(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
    if axis.is_none() {
        let n = a.len();

        match kind {
            ReductionKind::Sum => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_setzero_ps();

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_add_ps(acc, p);
                            i += 8;
                        }

                        // Horizontal reduction of the eight partial sums,
                        // then a scalar tail for the remainder.
                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut sum = s.iter().copied().sum::<f32>();

                        while i < n {
                            sum += a.data[i];
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![sum]));
                    }
                }

                let sum: f32 = a.data.iter().copied().sum();
                Ok(Array::new(vec![1], vec![sum]))
            }
            ReductionKind::Max => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_max_ps(acc, p);
                            i += 8;
                        }

                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut max_val = s[0];
                        for &v in &s[1..] {
                            if v > max_val {
                                max_val = v;
                            }
                        }

                        while i < n {
                            if a.data[i] > max_val {
                                max_val = a.data[i];
                            }
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![max_val]));
                    }
                }

                let max_val = a
                    .data
                    .iter()
                    .copied()
                    .fold(f32::NEG_INFINITY, |acc, x| acc.max(x));
                Ok(Array::new(vec![1], vec![max_val]))
            }
            ReductionKind::Min => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_set1_ps(f32::INFINITY);

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_min_ps(acc, p);
                            i += 8;
                        }

                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut min_val = s[0];
                        for &v in &s[1..] {
                            if v < min_val {
                                min_val = v;
                            }
                        }

                        while i < n {
                            if a.data[i] < min_val {
                                min_val = a.data[i];
                            }
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![min_val]));
                    }
                }

                let min_val = a
                    .data
                    .iter()
                    .copied()
                    .fold(f32::INFINITY, |acc, x| acc.min(x));
                Ok(Array::new(vec![1], vec![min_val]))
            }
            ReductionKind::Mean => {
                let sum_result = reduce_simd(a, axis, ReductionKind::Sum)?;
                let mean = sum_result.data[0] / n as f32;
                Ok(Array::new(vec![1], vec![mean]))
            }
            ReductionKind::ArgMax | ReductionKind::Variance => {
                crate::backend::cpu::scalar::reduce_scalar(a, None, kind)
            }
        }
    } else {
        let axis = axis.unwrap();

        // Only the innermost (contiguous) axis has a SIMD kernel.
        if axis == a.shape.len() - 1 {
            return reduce_last_axis_simd(a, axis, kind);
        }

        crate::backend::cpu::scalar::reduce_scalar(a, Some(axis), kind)
    }
}

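/// Reduces the last axis: every output element is a reduction over one
/// contiguous row of length `axis_size`, so rows are processed in parallel
/// with rayon and accumulated with AVX2. Kinds without an AVX2 kernel here
/// delegate to the scalar backend.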
fn reduce_last_axis_simd(a: &Array, axis: usize, kind: ReductionKind) -> Result<Array> {
    let mut out_shape: Vec<usize> = a
        .shape
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != axis)
        .map(|(_, &d)| d)
        .collect();

    if out_shape.is_empty() {
        out_shape.push(1);
    }

    let out_size: usize = out_shape.iter().product();
    let axis_size = a.shape[axis];
    let mut out_data = vec![0.0f32; out_size];

    match kind {
        ReductionKind::Sum | ReductionKind::Mean => {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if std::is_x86_feature_detected!("avx2") {
                    // One contiguous row per parallel task; each row is
                    // accumulated with AVX2 and finished with a scalar tail.
                    out_data
                        .par_iter_mut()
                        .enumerate()
                        .for_each(|(row_idx, out_val)| {
                            let start = row_idx * axis_size;
                            let end = start + axis_size;

                            unsafe {
                                #[cfg(target_arch = "x86")]
                                use std::arch::x86::*;
                                #[cfg(target_arch = "x86_64")]
                                use std::arch::x86_64::*;

                                let mut acc = _mm256_setzero_ps();
                                let mut i = start;

                                while i + 8 <= end {
                                    let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                                    acc = _mm256_add_ps(acc, p);
                                    i += 8;
                                }

                                let mut s = [0.0f32; 8];
                                _mm256_storeu_ps(s.as_mut_ptr(), acc);
                                let mut sum: f32 = s.iter().sum();

                                while i < end {
                                    sum += a.data[i];
                                    i += 1;
                                }

                                *out_val = sum;
                            }
                        });

                    if kind == ReductionKind::Mean {
                        out_data.par_iter_mut().for_each(|x| *x /= axis_size as f32);
                    }

                    return Ok(Array::new(out_shape, out_data));
                }
            }

            crate::backend::cpu::scalar::reduce_last_axis_optimized(
                a, axis_size, out_size, out_shape, kind,
            )
        }
        ReductionKind::Max => {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if std::is_x86_feature_detected!("avx2") {
                    out_data
                        .par_iter_mut()
                        .enumerate()
                        .for_each(|(row_idx, out_val)| {
                            let start = row_idx * axis_size;
                            let end = start + axis_size;

                            unsafe {
                                #[cfg(target_arch = "x86")]
                                use std::arch::x86::*;
                                #[cfg(target_arch = "x86_64")]
                                use std::arch::x86_64::*;

                                let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);
                                let mut i = start;

                                while i + 8 <= end {
                                    let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                                    acc = _mm256_max_ps(acc, p);
                                    i += 8;
                                }

                                let mut s = [0.0f32; 8];
                                _mm256_storeu_ps(s.as_mut_ptr(), acc);
                                let mut max_val = s[0];
                                for &v in &s[1..] {
                                    if v > max_val {
                                        max_val = v;
                                    }
                                }

                                while i < end {
                                    if a.data[i] > max_val {
                                        max_val = a.data[i];
                                    }
                                    i += 1;
                                }

                                *out_val = max_val;
                            }
                        });

                    return Ok(Array::new(out_shape, out_data));
                }
            }

            crate::backend::cpu::scalar::reduce_last_axis_optimized(
                a, axis_size, out_size, out_shape, kind,
            )
        }
        // Remaining kinds (Min, ArgMax, Variance) delegate to the scalar
        // backend.
        _ => crate::backend::cpu::scalar::reduce_last_axis_optimized(
            a, axis_size, out_size, out_shape, kind,
        ),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::cpu::scalar;

    fn make_arrays(len: usize) -> (Array, Array) {
        let a = (0..len).map(|i| i as f32 * 0.5 + 0.1).collect::<Vec<_>>();
        let b = (0..len).map(|i| (i as f32).sin()).collect::<Vec<_>>();
        (Array::new(vec![len], a), Array::new(vec![len], b))
    }

    #[test]
    fn simd_add_matches_scalar() {
        for len in &[1usize, 3, 7, 8, 15, 16, 33, 64] {
            let (a, b) = make_arrays(*len);
            let out_simd = elementwise_simd(&a, &b, ElementwiseKind::Add).unwrap();
            let out_scalar = scalar::elementwise_scalar(&a, &b, ElementwiseKind::Add).unwrap();
            assert_eq!(out_simd.data, out_scalar.data);
        }
    }
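
    // A minimal sketch of a consistency check for the lane-wise fallback
    // path (Sigmoid has no packed intrinsic). Assumes the scalar backend
    // uses the same 1 / (1 + e^-x) formula as the fallback loop above;
    // compares with a small tolerance in case it does not match bitwise.
    #[test]
    fn simd_sigmoid_close_to_scalar() {
        for len in &[1usize, 7, 8, 33, 64] {
            let (a, b) = make_arrays(*len);
            let out_simd = elementwise_simd(&a, &b, ElementwiseKind::Sigmoid).unwrap();
            let out_scalar =
                scalar::elementwise_scalar(&a, &b, ElementwiseKind::Sigmoid).unwrap();
            for (x, y) in out_simd.data.iter().zip(out_scalar.data.iter()) {
                assert!((x - y).abs() <= 1e-6, "{} vs {}", x, y);
            }
        }
    }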

    #[test]
    fn simd_mul_matches_scalar() {
        for len in &[1usize, 3, 7, 8, 15, 16, 33, 64] {
            let (a, b) = make_arrays(*len);
            let out_simd = elementwise_simd(&a, &b, ElementwiseKind::Mul).unwrap();
            let out_scalar = scalar::elementwise_scalar(&a, &b, ElementwiseKind::Mul).unwrap();
            assert_eq!(out_simd.data, out_scalar.data);
        }
    }
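
    // A hedged check that the AVX2 full-array reductions agree with plain
    // iterator reductions. Sum reassociates under SIMD, so it is compared
    // with a relative tolerance; Max should match exactly.
    #[test]
    fn simd_full_reduce_matches_iterators() {
        let (a, _) = make_arrays(1000);

        let sum = reduce_simd(&a, None, ReductionKind::Sum).unwrap().data[0];
        let expected: f32 = a.data.iter().sum();
        assert!((sum - expected).abs() <= expected.abs() * 1e-5);

        let max = reduce_simd(&a, None, ReductionKind::Max).unwrap().data[0];
        let expected_max = a.data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        assert_eq!(max, expected_max);
    }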
}

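/// Dot product of two 1-D arrays of equal length. Uses the AVX2+FMA kernel
/// when both features are available, otherwise the scalar backend.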
pub fn dot_simd(a: &Array, b: &Array) -> Result<f32> {
    if a.shape.len() != 1 || b.shape.len() != 1 {
        return Err(anyhow!("dot_simd: both inputs must be 1-D arrays"));
    }
    if a.shape[0] != b.shape[0] {
        return Err(anyhow!("dot_simd: arrays must have the same length"));
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("fma") && std::is_x86_feature_detected!("avx2") {
            unsafe {
                return dot_simd_avx2_fma(a, b);
            }
        }
    }

    crate::backend::cpu::scalar::dot_scalar(a, b)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_simd_avx2_fma(a: &Array, b: &Array) -> Result<f32> {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let n = a.data.len();
    let mut sum = _mm256_setzero_ps();

    let chunks = n / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let va = _mm256_loadu_ps(a.data.as_ptr().add(offset));
        let vb = _mm256_loadu_ps(b.data.as_ptr().add(offset));
        sum = _mm256_fmadd_ps(va, vb, sum);
    }

    // Horizontal sum of the eight partial products, then the scalar tail.
    let mut result = [0.0f32; 8];
    _mm256_storeu_ps(result.as_mut_ptr(), sum);
    let mut total = result.iter().sum::<f32>();

    for i in (chunks * 8)..n {
        total += a.data[i] * b.data[i];
    }

    Ok(total)
}

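/// Matrix product of a 2-D `m x k` array with a 2-D `k x n` array. Panics on
/// rank or inner-dimension mismatch; dispatches to the parallel AVX2+FMA
/// kernel when available, otherwise to the scalar implementation.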
pub fn matmul_simd(a: &Array, b: &Array) -> Array {
    if a.shape.len() != 2 || b.shape.len() != 2 {
        panic!("matmul_simd: both inputs must be 2-D arrays");
    }

    let m = a.shape[0];
    let k = a.shape[1];
    let n = b.shape[1];

    if k != b.shape[0] {
        panic!(
            "matmul_simd: inner dimension mismatch: {} != {}",
            k, b.shape[0]
        );
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            return matmul_simd_parallel(a, b, m, k, n);
        }
    }

    super::matmul_scalar_direct(a, b)
}

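/// Splits the rows of the output into contiguous blocks and runs the blocked
/// AVX2+FMA kernel on each block in parallel via rayon. Callers must have
/// already verified that AVX2 and FMA are available.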
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn matmul_simd_parallel(a: &Array, b: &Array, m: usize, k: usize, n: usize) -> Array {
    let block_size = if m >= 2048 { 256 } else { 128 };
    let mut result = vec![0.0f32; m * n];

    // Each rayon task owns `block_size` consecutive output rows (the last
    // block may be shorter) and runs the blocked kernel on its row slice.
    result
        .par_chunks_mut(block_size * n)
        .enumerate()
        .for_each(|(block_idx, out_block)| {
            let start = block_idx * block_size;
            let end = (start + block_size).min(m);
            let block_rows = end - start;

            let a_block_start = start * k;
            let a_block_end = end * k;
            let a_block_slice = &a.data[a_block_start..a_block_end];

            let a_block = Array::new(vec![block_rows, k], a_block_slice.to_vec());

            unsafe {
                let block_result = matmul_simd_avx2_fma_blocked(
                    &a_block,
                    b,
                    block_rows,
                    k,
                    n,
                    vec![0.0f32; block_rows * n],
                );
                out_block.copy_from_slice(&block_result.data);
            }
        });

    Array::new(vec![m, n], result)
}

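/// Cache-blocked AVX2+FMA matmul kernel. Tiles the (m, n, k) loop nest, then
/// computes two output rows and up to sixteen columns per inner step by
/// broadcasting one element of `a` against packed rows of `b` with fused
/// multiply-adds. `result` is accumulated in place, so it must arrive
/// zero-initialized (or holding partial sums).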
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn matmul_simd_avx2_fma_blocked(
    a: &Array,
    b: &Array,
    m: usize,
    k: usize,
    n: usize,
    mut result: Vec<f32>,
) -> Array {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Cache-blocking tile sizes for the row, column, and inner dimensions.
    const BLOCK_M: usize = 96;
    const BLOCK_N: usize = 256;
    const BLOCK_K: usize = 512;

    for i0 in (0..m).step_by(BLOCK_M) {
        let i_end = (i0 + BLOCK_M).min(m);

        for j0 in (0..n).step_by(BLOCK_N) {
            let j_end = (j0 + BLOCK_N).min(n);

            for k0 in (0..k).step_by(BLOCK_K) {
                let k_end = (k0 + BLOCK_K).min(k);

                // Process two output rows at a time so each loaded `b`
                // vector feeds two FMA accumulators.
                let mut i = i0;
                while i + 2 <= i_end {
                    let a_row0_offset = i * k;
                    let a_row1_offset = (i + 1) * k;
                    let result_row0_offset = i * n;
                    let result_row1_offset = (i + 1) * n;

                    // Sixteen columns (two AVX2 vectors) per iteration.
                    let mut j = j0;
                    while j + 16 <= j_end {
                        let mut sum0_0 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j));
                        let mut sum0_1 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j + 8));
                        let mut sum1_0 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j));
                        let mut sum1_1 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j + 8));

                        for kk in k0..k_end {
                            let a_val0 = _mm256_set1_ps(a.data[a_row0_offset + kk]);
                            let a_val1 = _mm256_set1_ps(a.data[a_row1_offset + kk]);
                            let b_row_offset = kk * n;
                            let b_vals0 = _mm256_loadu_ps(b.data.as_ptr().add(b_row_offset + j));
                            let b_vals1 =
                                _mm256_loadu_ps(b.data.as_ptr().add(b_row_offset + j + 8));

                            sum0_0 = _mm256_fmadd_ps(a_val0, b_vals0, sum0_0);
                            sum0_1 = _mm256_fmadd_ps(a_val0, b_vals1, sum0_1);
                            sum1_0 = _mm256_fmadd_ps(a_val1, b_vals0, sum1_0);
                            sum1_1 = _mm256_fmadd_ps(a_val1, b_vals1, sum1_1);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row0_offset + j), sum0_0);
                        _mm256_storeu_ps(
                            result.as_mut_ptr().add(result_row0_offset + j + 8),
                            sum0_1,
                        );
                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row1_offset + j), sum1_0);
                        _mm256_storeu_ps(
                            result.as_mut_ptr().add(result_row1_offset + j + 8),
                            sum1_1,
                        );
                        j += 16;
                    }

                    // Eight-column remainder.
                    while j + 8 <= j_end {
                        let mut sum0 = _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j));
                        let mut sum1 = _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j));

                        for kk in k0..k_end {
                            let a_val0 = _mm256_set1_ps(a.data[a_row0_offset + kk]);
                            let a_val1 = _mm256_set1_ps(a.data[a_row1_offset + kk]);
                            let b_vals = _mm256_loadu_ps(b.data.as_ptr().add(kk * n + j));

                            sum0 = _mm256_fmadd_ps(a_val0, b_vals, sum0);
                            sum1 = _mm256_fmadd_ps(a_val1, b_vals, sum1);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row0_offset + j), sum0);
                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row1_offset + j), sum1);
                        j += 8;
                    }

                    // Scalar remainder columns.
                    for j in j..j_end {
                        let mut sum0 = result[result_row0_offset + j];
                        let mut sum1 = result[result_row1_offset + j];
                        for kk in k0..k_end {
                            let b_val = b.data[kk * n + j];
                            sum0 += a.data[a_row0_offset + kk] * b_val;
                            sum1 += a.data[a_row1_offset + kk] * b_val;
                        }
                        result[result_row0_offset + j] = sum0;
                        result[result_row1_offset + j] = sum1;
                    }

                    i += 2;
                }

                // Odd leftover row in this tile.
                if i < i_end {
                    let a_row_offset = i * k;
                    let result_row_offset = i * n;

                    let mut j = j0;
                    while j + 8 <= j_end {
                        let mut sum = _mm256_loadu_ps(result.as_ptr().add(result_row_offset + j));

                        for kk in k0..k_end {
                            let a_val = _mm256_set1_ps(a.data[a_row_offset + kk]);
                            let b_vals = _mm256_loadu_ps(b.data.as_ptr().add(kk * n + j));
                            sum = _mm256_fmadd_ps(a_val, b_vals, sum);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row_offset + j), sum);
                        j += 8;
                    }

                    for j in j..j_end {
                        let mut sum = result[result_row_offset + j];
                        for kk in k0..k_end {
                            sum += a.data[a_row_offset + kk] * b.data[kk * n + j];
                        }
                        result[result_row_offset + j] = sum;
                    }
                }
            }
        }
    }

    Array::new(vec![m, n], result)
}

861pub use super::simd_conv::conv1d_simd;
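
// A minimal sketch of consistency tests for the dot and matmul kernels,
// comparing against naive reference loops written inline. FMA and blocking
// reassociate the sums, so comparisons use tolerances, not exact equality.
#[cfg(test)]
mod linalg_tests {
    use super::*;

    #[test]
    fn dot_simd_close_to_naive() {
        let n = 103; // not a multiple of 8, to exercise the scalar tail
        let a = Array::new(vec![n], (0..n).map(|i| i as f32 * 0.25).collect());
        let b = Array::new(vec![n], (0..n).map(|i| (i as f32).cos()).collect());
        let got = dot_simd(&a, &b).unwrap();
        let expected: f32 = a.data.iter().zip(&b.data).map(|(x, y)| x * y).sum();
        assert!((got - expected).abs() <= expected.abs().max(1.0) * 1e-4);
    }

    #[test]
    fn matmul_simd_close_to_naive() {
        let (m, k, n) = (17, 33, 19); // odd sizes to hit the remainder paths
        let a = Array::new(vec![m, k], (0..m * k).map(|i| (i % 7) as f32 * 0.5).collect());
        let b = Array::new(vec![k, n], (0..k * n).map(|i| (i % 5) as f32 - 2.0).collect());
        let c = matmul_simd(&a, &b);
        for i in 0..m {
            for j in 0..n {
                let mut expected = 0.0f32;
                for kk in 0..k {
                    expected += a.data[i * k + kk] * b.data[kk * n + j];
                }
                let got = c.data[i * n + j];
                assert!((got - expected).abs() <= expected.abs().max(1.0) * 1e-4);
            }
        }
    }
}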