numrs/backend/cpu/
simd.rs

//! SIMD CPU execution strategy: runtime detection of AVX2/SSE2 on x86/x86_64,
//! vectorized kernels built on arch intrinsics, and scalar fallbacks elsewhere.

pub fn elementwise_simd_supported() -> bool {
    // Try to detect SIMD availability at runtime on x86/x86_64 targets.
    // We conservatively require at least SSE2 or AVX2 to consider SIMD available.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // is_x86_feature_detected! expands to a runtime check for the CPU
        // features available on the host.
        if std::is_x86_feature_detected!("avx2") {
            return true;
        }
        if std::is_x86_feature_detected!("sse2") {
            return true;
        }
        false
    }
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    {
        // On other architectures we conservatively return false for now.
        false
    }
}
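
// Usage sketch (illustrative only): callers that dispatch between kernels can
// gate on the runtime check above. `run_simd` / `run_scalar` are hypothetical
// names standing in for whatever dispatch the backend actually performs.
//
//     let out = if elementwise_simd_supported() {
//         run_simd(&a, &b)
//     } else {
//         run_scalar(&a, &b)
//     };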

use crate::array::Array;
use crate::llo::reduction::ReductionKind;
use crate::llo::ElementwiseKind;
use anyhow::{anyhow, Result};
use rayon::prelude::*;

// We intentionally avoid the unstable `portable_simd` (`core::simd::Simd`) API
// so the prototype keeps building on stable Rust. Instead we use arch
// intrinsics for x86/x86_64 (AVX2/SSE) and fall back to scalar loops on hosts
// or architectures without the required features.

/// Prototype SIMD elementwise kernel: AVX2/SSE fast paths on x86/x86_64 with a
/// scalar tail and a scalar fallback on other targets.
pub fn elementwise_simd(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    // Prototype SIMD implementation mirroring the scalar logic but operating
    // on small fixed-size vector chunks. This keeps correctness identical to
    // the scalar version while allowing the compiler to generate vectorized
    // code paths on supported targets.

    if a.shape != b.shape {
        return Err(anyhow!("shape mismatch in simd elementwise"));
    }

    let mut out = Array::<f32>::zeros(a.shape.clone());
    let n = a.len();

    let mut i = 0usize;

    // x86/x86_64 specialised fast paths (AVX2 -> 8 floats, SSE -> 4 floats)
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    unsafe {
        if std::is_x86_feature_detected!("avx2") {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;

            while i + 8 <= n {
                let pa = _mm256_loadu_ps(a.data.as_ptr().add(i));
                let pb = _mm256_loadu_ps(b.data.as_ptr().add(i));
                let pr = match kind {
                    ElementwiseKind::Add => _mm256_add_ps(pa, pb),
                    ElementwiseKind::Mul => _mm256_mul_ps(pa, pb),
                    ElementwiseKind::Sub => _mm256_sub_ps(pa, pb),
                    ElementwiseKind::Div => _mm256_div_ps(pa, pb),
                    ElementwiseKind::Sqrt => _mm256_sqrt_ps(pa),
                    // Fall back to scalar math for these ops inside the vectorized loop.
                    ElementwiseKind::Sin
                    | ElementwiseKind::Cos
                    | ElementwiseKind::Tan
                    | ElementwiseKind::Abs
                    | ElementwiseKind::Neg
                    | ElementwiseKind::Exp
                    | ElementwiseKind::Log
                    | ElementwiseKind::Pow
                    | ElementwiseKind::Asin
                    | ElementwiseKind::Acos
                    | ElementwiseKind::Atan
                    | ElementwiseKind::Relu
                    | ElementwiseKind::LeakyRelu
                    | ElementwiseKind::Sigmoid
                    | ElementwiseKind::Tanh
                    | ElementwiseKind::Softplus => {
                        // Spill this chunk to scalars; load both a and b lanes.
                        let mut tmp_a = [0.0f32; 8];
                        let mut tmp_b = [0.0f32; 8];
                        _mm256_storeu_ps(tmp_a.as_mut_ptr(), pa);
                        _mm256_storeu_ps(tmp_b.as_mut_ptr(), pb);
                        for j in 0..8 {
                            tmp_a[j] = match kind {
                                ElementwiseKind::Sin => tmp_a[j].sin(),
                                ElementwiseKind::Cos => tmp_a[j].cos(),
                                ElementwiseKind::Tan => tmp_a[j].tan(),
                                ElementwiseKind::Abs => tmp_a[j].abs(),
                                ElementwiseKind::Neg => -tmp_a[j],
                                ElementwiseKind::Exp => tmp_a[j].exp(),
                                ElementwiseKind::Log => tmp_a[j].ln(),
                                ElementwiseKind::Pow => tmp_a[j].powf(tmp_b[j]),
                                ElementwiseKind::Asin => tmp_a[j].asin(),
                                ElementwiseKind::Acos => tmp_a[j].acos(),
                                ElementwiseKind::Atan => tmp_a[j].atan(),
                                ElementwiseKind::Relu => tmp_a[j].max(0.0),
                                ElementwiseKind::LeakyRelu => {
                                    if tmp_a[j] > 0.0 {
                                        tmp_a[j]
                                    } else {
                                        0.01 * tmp_a[j]
                                    }
                                }
                                ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-tmp_a[j]).exp()),
                                ElementwiseKind::Tanh => tmp_a[j].tanh(),
                                ElementwiseKind::Softplus => (1.0 + tmp_a[j].exp()).ln(),
                                _ => tmp_a[j],
                            };
                        }
                        _mm256_loadu_ps(tmp_a.as_ptr())
                    }
                };
                _mm256_storeu_ps(out.data.as_mut_ptr().add(i), pr);
                i += 8;
            }
        } else if std::is_x86_feature_detected!("sse2") {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;

            while i + 4 <= n {
                let pa = _mm_loadu_ps(a.data.as_ptr().add(i));
                let pb = _mm_loadu_ps(b.data.as_ptr().add(i));
                let pr = match kind {
                    ElementwiseKind::Add => _mm_add_ps(pa, pb),
                    ElementwiseKind::Mul => _mm_mul_ps(pa, pb),
                    ElementwiseKind::Sub => _mm_sub_ps(pa, pb),
                    ElementwiseKind::Div => _mm_div_ps(pa, pb),
                    ElementwiseKind::Sqrt => _mm_sqrt_ps(pa),
                    ElementwiseKind::Sin
                    | ElementwiseKind::Cos
                    | ElementwiseKind::Tan
                    | ElementwiseKind::Abs
                    | ElementwiseKind::Neg
                    | ElementwiseKind::Exp
                    | ElementwiseKind::Log
                    | ElementwiseKind::Pow
                    | ElementwiseKind::Asin
                    | ElementwiseKind::Acos
                    | ElementwiseKind::Atan
                    | ElementwiseKind::Relu
                    | ElementwiseKind::LeakyRelu
                    | ElementwiseKind::Sigmoid
                    | ElementwiseKind::Tanh
                    | ElementwiseKind::Softplus => {
                        let mut tmp_a = [0.0f32; 4];
                        let mut tmp_b = [0.0f32; 4];
                        _mm_storeu_ps(tmp_a.as_mut_ptr(), pa);
                        _mm_storeu_ps(tmp_b.as_mut_ptr(), pb);
                        for j in 0..4 {
                            tmp_a[j] = match kind {
                                ElementwiseKind::Sin => tmp_a[j].sin(),
                                ElementwiseKind::Cos => tmp_a[j].cos(),
                                ElementwiseKind::Tan => tmp_a[j].tan(),
                                ElementwiseKind::Abs => tmp_a[j].abs(),
                                ElementwiseKind::Neg => -tmp_a[j],
                                ElementwiseKind::Exp => tmp_a[j].exp(),
                                ElementwiseKind::Log => tmp_a[j].ln(),
                                ElementwiseKind::Pow => tmp_a[j].powf(tmp_b[j]),
                                ElementwiseKind::Asin => tmp_a[j].asin(),
                                ElementwiseKind::Acos => tmp_a[j].acos(),
                                ElementwiseKind::Atan => tmp_a[j].atan(),
                                ElementwiseKind::Relu => tmp_a[j].max(0.0),
                                ElementwiseKind::LeakyRelu => {
                                    if tmp_a[j] > 0.0 {
                                        tmp_a[j]
                                    } else {
                                        0.01 * tmp_a[j]
                                    }
                                }
                                ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-tmp_a[j]).exp()),
                                ElementwiseKind::Tanh => tmp_a[j].tanh(),
                                ElementwiseKind::Softplus => (1.0 + tmp_a[j].exp()).ln(),
                                _ => tmp_a[j],
                            };
                        }
                        _mm_loadu_ps(tmp_a.as_ptr())
                    }
                };
                _mm_storeu_ps(out.data.as_mut_ptr().add(i), pr);
                i += 4;
            }
        }
    }

    // Remaining elements / scalar fallback
    for j in i..n {
        out.data[j] = match kind {
            ElementwiseKind::Add => a.data[j] + b.data[j],
            ElementwiseKind::Mul => a.data[j] * b.data[j],
            ElementwiseKind::Sub => a.data[j] - b.data[j],
            ElementwiseKind::Div => a.data[j] / b.data[j],
            ElementwiseKind::Sqrt => a.data[j].sqrt(),
            ElementwiseKind::Abs => a.data[j].abs(),
            ElementwiseKind::Neg => -a.data[j],
            ElementwiseKind::Exp => a.data[j].exp(),
            ElementwiseKind::Log => a.data[j].ln(),
            ElementwiseKind::Tan => a.data[j].tan(),
            ElementwiseKind::Pow => a.data[j].powf(b.data[j]),
            ElementwiseKind::Sin => a.data[j].sin(),
            ElementwiseKind::Cos => a.data[j].cos(),
            ElementwiseKind::Asin => a.data[j].asin(),
            ElementwiseKind::Acos => a.data[j].acos(),
            ElementwiseKind::Atan => a.data[j].atan(),
            ElementwiseKind::Relu => a.data[j].max(0.0),
            ElementwiseKind::LeakyRelu => {
                if a.data[j] > 0.0 {
                    a.data[j]
                } else {
                    0.01 * a.data[j]
                }
            }
            ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-a.data[j]).exp()),
            ElementwiseKind::Tanh => a.data[j].tanh(),
            ElementwiseKind::Softplus => (1.0 + a.data[j].exp()).ln(),
        };
    }

    Ok(out)
}

/// SIMD-accelerated reduction (sum, max, min, mean). For full reductions
/// (`axis == None`) we run an AVX2 loop that accumulates into an `__m256`
/// register and then reduces it horizontally. On other architectures, or when
/// AVX2 is not detected at runtime, we fall back to scalar code.
pub fn reduce_simd(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
    if axis.is_none() {
        let n = a.len();

        match kind {
            ReductionKind::Sum => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_setzero_ps();

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_add_ps(acc, p);
                            i += 8;
                        }

                        // Horizontal sum of acc
                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut sum = s.iter().copied().sum::<f32>();

                        // tail
                        while i < n {
                            sum += a.data[i];
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![sum]));
                    }
                }

                // Non-x86 path, or AVX2 not detected: fall back to scalar
                let sum: f32 = a.data.iter().copied().sum();
                Ok(Array::new(vec![1], vec![sum]))
            }
            ReductionKind::Max => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_max_ps(acc, p);
                            i += 8;
                        }

                        // Horizontal max of acc
                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut max_val = s[0];
                        for &v in &s[1..] {
                            if v > max_val {
                                max_val = v;
                            }
                        }

                        // tail
                        while i < n {
                            if a.data[i] > max_val {
                                max_val = a.data[i];
                            }
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![max_val]));
                    }
                }

                // Fallback to scalar
                let max_val = a
                    .data
                    .iter()
                    .copied()
                    .fold(f32::NEG_INFINITY, |acc, x| acc.max(x));
                Ok(Array::new(vec![1], vec![max_val]))
            }
            ReductionKind::Min => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_set1_ps(f32::INFINITY);

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_min_ps(acc, p);
                            i += 8;
                        }

                        // Horizontal min of acc
                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut min_val = s[0];
                        for &v in &s[1..] {
                            if v < min_val {
                                min_val = v;
                            }
                        }

                        // tail
                        while i < n {
                            if a.data[i] < min_val {
                                min_val = a.data[i];
                            }
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![min_val]));
                    }
                }

                // Fallback to scalar
                let min_val = a
                    .data
                    .iter()
                    .copied()
                    .fold(f32::INFINITY, |acc, x| acc.min(x));
                Ok(Array::new(vec![1], vec![min_val]))
            }
            ReductionKind::Mean => {
                // Reuse the sum reduction and divide by the element count
                let sum_result = reduce_simd(a, axis, ReductionKind::Sum)?;
                let mean = sum_result.data[0] / n as f32;
                Ok(Array::new(vec![1], vec![mean]))
            }
            ReductionKind::ArgMax | ReductionKind::Variance => {
                // Fall back to scalar for these more involved reductions
                crate::backend::cpu::scalar::reduce_scalar(a, None, kind)
            }
        }
    } else {
        // Axis-based reduction
        let axis = axis.unwrap();

        // OPTIMIZED PATH: reducing over the last axis with SIMD
        if axis == a.shape.len() - 1 {
            return reduce_last_axis_simd(a, axis, kind);
        }

        // For other axes, fall back to the scalar implementation.
        // TODO: optimize specific cases (e.g., reducing axis 0 with proper striding)
        crate::backend::cpu::scalar::reduce_scalar(a, Some(axis), kind)
    }
}
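
// Usage sketch (illustrative; assumes the Array constructor used throughout
// this file and the ReductionKind variants handled above):
//
//     let a = Array::new(vec![4], vec![1.0, 2.0, 3.0, 4.0]);
//     let total = reduce_simd(&a, None, ReductionKind::Sum)?;   // shape [1], value 10.0
//     let mean = reduce_simd(&a, None, ReductionKind::Mean)?;   // shape [1], value 2.5
//     let m = Array::new(vec![2, 2], vec![1.0, 5.0, 3.0, 2.0]);
//     let row_max = reduce_simd(&m, Some(1), ReductionKind::Max)?; // shape [2], values [5.0, 3.0]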

/// Optimized SIMD reduction over the last axis.
/// This is the most cache-friendly case, since the data being reduced is contiguous.
fn reduce_last_axis_simd(a: &Array, axis: usize, kind: ReductionKind) -> Result<Array> {
    // Compute output shape
    let mut out_shape: Vec<usize> = a
        .shape
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != axis)
        .map(|(_, &d)| d)
        .collect();

    if out_shape.is_empty() {
        out_shape.push(1);
    }

    let out_size: usize = out_shape.iter().product();
    let axis_size = a.shape[axis];
    let mut out_data = vec![0.0; out_size];

    match kind {
        ReductionKind::Sum | ReductionKind::Mean => {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if std::is_x86_feature_detected!("avx2") {
                    out_data
                        .par_iter_mut()
                        .enumerate()
                        .for_each(|(row_idx, out_val)| {
                            let start = row_idx * axis_size;
                            let end = start + axis_size;

                            unsafe {
                                #[cfg(target_arch = "x86")]
                                use std::arch::x86::*;
                                #[cfg(target_arch = "x86_64")]
                                use std::arch::x86_64::*;

                                let mut acc = _mm256_setzero_ps();
                                let mut i = start;

                                // Process 8 elements at a time with SIMD
                                while i + 8 <= end {
                                    let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                                    acc = _mm256_add_ps(acc, p);
                                    i += 8;
                                }

                                // Horizontal sum
                                let mut s = [0.0f32; 8];
                                _mm256_storeu_ps(s.as_mut_ptr(), acc);
                                let mut sum: f32 = s.iter().sum();

                                // Handle remaining elements
                                while i < end {
                                    sum += a.data[i];
                                    i += 1;
                                }

                                *out_val = sum;
                            }
                        });

                    if kind == ReductionKind::Mean {
                        out_data.par_iter_mut().for_each(|x| *x /= axis_size as f32);
                    }

                    return Ok(Array::new(out_shape, out_data));
                }
            }

            // Fallback to scalar
            return crate::backend::cpu::scalar::reduce_last_axis_optimized(
                a, axis_size, out_size, out_shape, kind,
            );
        }
        ReductionKind::Max => {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if std::is_x86_feature_detected!("avx2") {
                    out_data
                        .par_iter_mut()
                        .enumerate()
                        .for_each(|(row_idx, out_val)| {
                            let start = row_idx * axis_size;
                            let end = start + axis_size;

                            unsafe {
                                #[cfg(target_arch = "x86")]
                                use std::arch::x86::*;
                                #[cfg(target_arch = "x86_64")]
                                use std::arch::x86_64::*;

                                let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);
                                let mut i = start;

                                while i + 8 <= end {
                                    let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                                    acc = _mm256_max_ps(acc, p);
                                    i += 8;
                                }

                                // Horizontal max
                                let mut s = [0.0f32; 8];
                                _mm256_storeu_ps(s.as_mut_ptr(), acc);
                                let mut max_val = s[0];
                                for &v in &s[1..] {
                                    if v > max_val {
                                        max_val = v;
                                    }
                                }

                                // Handle remaining elements
                                while i < end {
                                    if a.data[i] > max_val {
                                        max_val = a.data[i];
                                    }
                                    i += 1;
                                }

                                *out_val = max_val;
                            }
                        });

                    return Ok(Array::new(out_shape, out_data));
                }
            }

            // Fallback
            return crate::backend::cpu::scalar::reduce_last_axis_optimized(
                a, axis_size, out_size, out_shape, kind,
            );
        }
        _ => {
            // For other operations, use the scalar optimized version
            return crate::backend::cpu::scalar::reduce_last_axis_optimized(
                a, axis_size, out_size, out_shape, kind,
            );
        }
    }
}
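
// Index-mapping sketch (illustrative): for `a.shape == [2, 3, 4]` and
// `axis == 2`, the output shape is `[2, 3]` and output row `r` reduces the
// contiguous slice `a.data[r * 4 .. (r + 1) * 4]`, where `r = i * 3 + j` for
// the output coordinate (i, j).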

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::cpu::scalar;

    fn make_arrays(len: usize) -> (Array, Array) {
        let a = (0..len).map(|i| i as f32 * 0.5 + 0.1).collect::<Vec<_>>();
        let b = (0..len).map(|i| (i as f32).sin()).collect::<Vec<_>>();
        (Array::new(vec![len], a), Array::new(vec![len], b))
    }

    #[test]
    fn simd_add_matches_scalar() {
        for len in &[1usize, 3, 7, 8, 15, 16, 33, 64] {
            let (a, b) = make_arrays(*len);
            let out_simd = elementwise_simd(&a, &b, ElementwiseKind::Add).unwrap();
            let out_scalar = scalar::elementwise_scalar(&a, &b, ElementwiseKind::Add).unwrap();
            assert_eq!(out_simd.data, out_scalar.data);
        }
    }

    #[test]
    fn simd_mul_matches_scalar() {
        for len in &[1usize, 3, 7, 8, 15, 16, 33, 64] {
            let (a, b) = make_arrays(*len);
            let out_simd = elementwise_simd(&a, &b, ElementwiseKind::Mul).unwrap();
            let out_scalar = scalar::elementwise_scalar(&a, &b, ElementwiseKind::Mul).unwrap();
            assert_eq!(out_simd.data, out_scalar.data);
        }
    }
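
    // Additional sanity-check sketches: exercise the per-lane scalar fallback
    // inside the AVX2/SSE loop (Sigmoid) and the full-array SIMD sum. They
    // assume `scalar::elementwise_scalar` handles every ElementwiseKind
    // variant, and they compare with a tolerance because SIMD summation
    // reorders f32 additions.
    #[test]
    fn simd_sigmoid_matches_scalar() {
        for len in &[1usize, 7, 8, 33] {
            let (a, b) = make_arrays(*len);
            let out_simd = elementwise_simd(&a, &b, ElementwiseKind::Sigmoid).unwrap();
            let out_scalar = scalar::elementwise_scalar(&a, &b, ElementwiseKind::Sigmoid).unwrap();
            for (x, y) in out_simd.data.iter().zip(out_scalar.data.iter()) {
                assert!((x - y).abs() <= 1e-6, "sigmoid mismatch: {} vs {}", x, y);
            }
        }
    }

    #[test]
    fn simd_sum_close_to_sequential() {
        let (a, _) = make_arrays(100);
        let expected: f32 = a.data.iter().sum();
        let got = reduce_simd(&a, None, ReductionKind::Sum).unwrap();
        assert!((got.data[0] - expected).abs() <= 1e-3 * expected.abs().max(1.0));
    }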
}

/// Dot product SIMD implementation with FMA (fused multiply-add)
pub fn dot_simd(a: &Array, b: &Array) -> Result<f32> {
    if a.shape.len() != 1 || b.shape.len() != 1 {
        return Err(anyhow!("dot_simd: both inputs must be 1-D arrays"));
    }
    if a.shape[0] != b.shape[0] {
        return Err(anyhow!("dot_simd: arrays must have same length"));
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("fma") && std::is_x86_feature_detected!("avx2") {
            // SAFETY: We checked for AVX2+FMA support
            unsafe {
                return dot_simd_avx2_fma(a, b);
            }
        }
    }

    // Fall back to scalar if SIMD is not available
    crate::backend::cpu::scalar::dot_scalar(a, b)
}
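
// Usage sketch (illustrative):
//
//     let a = Array::new(vec![3], vec![1.0, 2.0, 3.0]);
//     let b = Array::new(vec![3], vec![4.0, 5.0, 6.0]);
//     let d = dot_simd(&a, &b)?; // 1*4 + 2*5 + 3*6 = 32.0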

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_simd_avx2_fma(a: &Array, b: &Array) -> Result<f32> {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let n = a.data.len();
    let mut sum = _mm256_setzero_ps();

    // Process 8 floats at a time
    let chunks = n / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let va = _mm256_loadu_ps(a.data.as_ptr().add(offset));
        let vb = _mm256_loadu_ps(b.data.as_ptr().add(offset));
        // FMA: sum = sum + (va * vb)
        sum = _mm256_fmadd_ps(va, vb, sum);
    }

    // Horizontal sum of 8 lanes
    let mut result = [0.0f32; 8];
    _mm256_storeu_ps(result.as_mut_ptr(), sum);
    let mut total = result.iter().sum::<f32>();

    // Handle remaining elements
    for i in (chunks * 8)..n {
        total += a.data[i] * b.data[i];
    }

    Ok(total)
}

/// SIMD-accelerated matrix multiplication.
/// Uses a blocked, tiled algorithm with SIMD vectorization in the inner loops.
pub fn matmul_simd(a: &Array, b: &Array) -> Array {
    if a.shape.len() != 2 || b.shape.len() != 2 {
        panic!("matmul_simd: both inputs must be 2-D arrays");
    }

    let m = a.shape[0];
    let k = a.shape[1];
    let n = b.shape[1];

    if k != b.shape[0] {
        panic!(
            "matmul_simd: inner dimension mismatch: {} != {}",
            k, b.shape[0]
        );
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            // Always use Rayon + SIMD here for consistency; the adaptive
            // dispatch layer has already decided this kernel is appropriate.
            return matmul_simd_parallel(a, b, m, k, n);
        }
    }

    // Fall back to scalar if SIMD is not available
    super::matmul_scalar_direct(a, b)
}

/// SIMD matmul with Rayon parallelization for large matrices.
/// Output rows are processed in adaptively sized blocks.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn matmul_simd_parallel(a: &Array, b: &Array, m: usize, k: usize, n: usize) -> Array {
    use rayon::prelude::*;

    // Adaptive block size based on matrix dimensions:
    // larger blocks give better cache utilization on large matrices.
    let block_size = if m >= 2048 { 256 } else { 128 };
    let mut result = vec![0.0f32; m * n];

    result
        .par_chunks_mut(block_size * n)
        .enumerate()
        .for_each(|(block_idx, out_block)| {
            let start = block_idx * block_size;
            let end = (start + block_size).min(m);
            let block_rows = end - start;

            // Slice out the rows of A that belong to this block.
            let a_block_start = start * k;
            let a_block_end = end * k;
            let a_block_slice = &a.data[a_block_start..a_block_end];

            // Copy the block into a temporary Array (a borrowed view type
            // would avoid this copy, but Array::new takes owned data).
            let a_block = Array::new(vec![block_rows, k], a_block_slice.to_vec());

            // Process the block with SIMD.
            // SAFETY: AVX2+FMA availability was checked in the caller.
            unsafe {
                let block_result = matmul_simd_avx2_fma_blocked(
                    &a_block,
                    b,
                    block_rows,
                    k,
                    n,
                    vec![0.0f32; block_rows * n],
                );
                out_block.copy_from_slice(&block_result.data);
            }
        });

    Array::new(vec![m, n], result)
}
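
// Partitioning sketch (illustrative): with m = 300 rows and block_size = 128,
// `par_chunks_mut(128 * n)` yields three row blocks covering rows
// [0, 128), [128, 256), and [256, 300); the final block simply has fewer rows.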

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn matmul_simd_avx2_fma_blocked(
    a: &Array,
    b: &Array,
    m: usize,
    k: usize,
    n: usize,
    mut result: Vec<f32>,
) -> Array {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Optimized block sizes: larger blocks for better arithmetic intensity
    const BLOCK_M: usize = 96; // Process 96 rows per block
    const BLOCK_N: usize = 256; // Process 256 cols per block (32 AVX2 vectors of 8 floats)
    const BLOCK_K: usize = 512; // Larger K block for better data reuse

    // Blocked matrix multiplication with advanced SIMD
    for i0 in (0..m).step_by(BLOCK_M) {
        let i_end = (i0 + BLOCK_M).min(m);

        for j0 in (0..n).step_by(BLOCK_N) {
            let j_end = (j0 + BLOCK_N).min(n);

            for k0 in (0..k).step_by(BLOCK_K) {
                let k_end = (k0 + BLOCK_K).min(k);

                // Process 2 rows at a time for better register utilization
                let mut i = i0;
                while i + 2 <= i_end {
                    let a_row0_offset = i * k;
                    let a_row1_offset = (i + 1) * k;
                    let result_row0_offset = i * n;
                    let result_row1_offset = (i + 1) * n;

                    // Process 16 columns at a time (2 AVX2 registers)
                    let mut j = j0;
                    while j + 16 <= j_end {
                        let mut sum0_0 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j));
                        let mut sum0_1 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j + 8));
                        let mut sum1_0 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j));
                        let mut sum1_1 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j + 8));

                        // Inner loop: accumulate over the k-block with FMA
                        for kk in k0..k_end {
                            let a_val0 = _mm256_set1_ps(a.data[a_row0_offset + kk]);
                            let a_val1 = _mm256_set1_ps(a.data[a_row1_offset + kk]);
                            let b_row_offset = kk * n;
                            let b_vals0 = _mm256_loadu_ps(b.data.as_ptr().add(b_row_offset + j));
                            let b_vals1 =
                                _mm256_loadu_ps(b.data.as_ptr().add(b_row_offset + j + 8));

                            // FMA: sum = sum + (a_val * b_vals)
                            // Process both rows and both column sets simultaneously
                            sum0_0 = _mm256_fmadd_ps(a_val0, b_vals0, sum0_0);
                            sum0_1 = _mm256_fmadd_ps(a_val0, b_vals1, sum0_1);
                            sum1_0 = _mm256_fmadd_ps(a_val1, b_vals0, sum1_0);
                            sum1_1 = _mm256_fmadd_ps(a_val1, b_vals1, sum1_1);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row0_offset + j), sum0_0);
                        _mm256_storeu_ps(
                            result.as_mut_ptr().add(result_row0_offset + j + 8),
                            sum0_1,
                        );
                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row1_offset + j), sum1_0);
                        _mm256_storeu_ps(
                            result.as_mut_ptr().add(result_row1_offset + j + 8),
                            sum1_1,
                        );
                        j += 16;
                    }

                    // Process remaining columns in chunks of 8
                    while j + 8 <= j_end {
                        let mut sum0 = _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j));
                        let mut sum1 = _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j));

                        for kk in k0..k_end {
                            let a_val0 = _mm256_set1_ps(a.data[a_row0_offset + kk]);
                            let a_val1 = _mm256_set1_ps(a.data[a_row1_offset + kk]);
                            let b_vals = _mm256_loadu_ps(b.data.as_ptr().add(kk * n + j));

                            sum0 = _mm256_fmadd_ps(a_val0, b_vals, sum0);
                            sum1 = _mm256_fmadd_ps(a_val1, b_vals, sum1);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row0_offset + j), sum0);
                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row1_offset + j), sum1);
                        j += 8;
                    }

                    // Handle remaining columns with scalar code
                    for j in j..j_end {
                        let mut sum0 = result[result_row0_offset + j];
                        let mut sum1 = result[result_row1_offset + j];
                        for kk in k0..k_end {
                            let b_val = b.data[kk * n + j];
                            sum0 += a.data[a_row0_offset + kk] * b_val;
                            sum1 += a.data[a_row1_offset + kk] * b_val;
                        }
                        result[result_row0_offset + j] = sum0;
                        result[result_row1_offset + j] = sum1;
                    }

                    i += 2;
                }

                // Handle the remaining single row when the block has an odd row count
                if i < i_end {
                    let a_row_offset = i * k;
                    let result_row_offset = i * n;

                    let mut j = j0;
                    while j + 8 <= j_end {
                        let mut sum = _mm256_loadu_ps(result.as_ptr().add(result_row_offset + j));

                        for kk in k0..k_end {
                            let a_val = _mm256_set1_ps(a.data[a_row_offset + kk]);
                            let b_vals = _mm256_loadu_ps(b.data.as_ptr().add(kk * n + j));
                            sum = _mm256_fmadd_ps(a_val, b_vals, sum);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row_offset + j), sum);
                        j += 8;
                    }

                    for j in j..j_end {
                        let mut sum = result[result_row_offset + j];
                        for kk in k0..k_end {
                            sum += a.data[a_row_offset + kk] * b.data[kk * n + j];
                        }
                        result[result_row_offset + j] = sum;
                    }
                }
            }
        }
    }

    Array::new(vec![m, n], result)
}
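
// Sanity-check sketches (illustrative): compare the FMA dot product and the
// blocked SIMD matmul against naive loops written inline here. Tolerances
// account for FMA and re-association slightly changing f32 rounding.
#[cfg(test)]
mod simd_linalg_tests {
    use super::*;

    #[test]
    fn dot_simd_close_to_naive() {
        let len = 37;
        let a: Vec<f32> = (0..len).map(|i| i as f32 * 0.25 - 1.0).collect();
        let b: Vec<f32> = (0..len).map(|i| (i as f32 * 0.1).cos()).collect();
        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let a = Array::new(vec![len], a);
        let b = Array::new(vec![len], b);
        let got = dot_simd(&a, &b).unwrap();
        assert!((got - expected).abs() <= 1e-3 * expected.abs().max(1.0));
    }

    #[test]
    fn matmul_simd_close_to_naive() {
        let (m, k, n) = (5usize, 7usize, 9usize);
        let a: Vec<f32> = (0..m * k).map(|i| (i as f32 * 0.3).sin()).collect();
        let b: Vec<f32> = (0..k * n).map(|i| (i as f32 * 0.2).cos()).collect();
        // Naive triple-loop reference
        let mut expected = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                for kk in 0..k {
                    expected[i * n + j] += a[i * k + kk] * b[kk * n + j];
                }
            }
        }
        let out = matmul_simd(&Array::new(vec![m, k], a), &Array::new(vec![k, n], b));
        for (x, y) in out.data.iter().zip(expected.iter()) {
            assert!((x - y).abs() <= 1e-4 * y.abs().max(1.0), "{} vs {}", x, y);
        }
    }
}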

/// Re-export of the Conv1D SIMD implementation (currently a stub).
pub use super::simd_conv::conv1d_simd;