numrs/backend/cpu/
simd.rs

/// SIMD CPU execution strategy. Provides runtime detection of AVX2/SSE2 and
/// vectorized kernels for elementwise ops, reductions, dot product and matmul,
/// with scalar fallbacks on other architectures (NEON is not yet implemented).

pub fn elementwise_simd_supported() -> bool {
    // Try to detect SIMD availability at runtime on x86/x86_64 targets.
    // We conservatively require at least SSE2 or AVX2 to consider SIMD available.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // is_x86_feature_detected! is a macro that expands to a runtime helper
        // to check for available CPU features on the host.
        if std::is_x86_feature_detected!("avx2") {
            return true;
        }
        if std::is_x86_feature_detected!("sse2") {
            return true;
        }
        false
    }
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    {
        // On other architectures we conservatively return false for now.
        false
    }
}
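
// Illustrative sketch only (hypothetical helper, not part of the existing API):
// a dispatcher could consult `elementwise_simd_supported()` to choose between the
// SIMD and scalar elementwise kernels at runtime. It only uses items that already
// exist in this backend (`elementwise_simd` below and the scalar fallback).
#[allow(dead_code)]
fn elementwise_dispatch_sketch(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    if elementwise_simd_supported() {
        elementwise_simd(a, b, kind)
    } else {
        crate::backend::cpu::scalar::elementwise_scalar(a, b, kind)
    }
}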

use crate::array::Array;
use crate::llo::reduction::ReductionKind;
use crate::llo::ElementwiseKind;
use anyhow::{anyhow, Result};
#[cfg(feature = "parallel")]
use rayon::prelude::*;

// We intentionally avoid the unstable `portable_simd` API (`core::simd::Simd`) so the
// prototype keeps building on stable Rust. Instead we use arch intrinsics for
// x86/x86_64 (AVX2/SSE) and fall back to scalar code when those features are not
// available, so results stay correct on every host.

/// Prototype SIMD elementwise path: AVX2/SSE2 fast paths on x86/x86_64 with a
/// scalar fallback for the remaining elements and for unsupported targets.
pub fn elementwise_simd(a: &Array, b: &Array, kind: ElementwiseKind) -> Result<Array> {
    // Mirrors the scalar logic but operates on small fixed-size vector chunks.
    // Ops without a vector instruction (sin, exp, sigmoid, ...) are computed
    // per-lane in scalar code inside the vectorized loop, so correctness stays
    // identical to the scalar version.

    if a.shape != b.shape {
        return Err(anyhow!("shape mismatch in simd elementwise"));
    }

    let mut out = Array::<f32>::zeros(a.shape.clone());
    let n = a.len();

    let mut i = 0usize;

    // x86/x86_64 specialised fast paths (AVX2 -> 8 floats, SSE -> 4 floats)
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    unsafe {
        if std::is_x86_feature_detected!("avx2") {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;

            while i + 8 <= n {
                let pa = _mm256_loadu_ps(a.data.as_ptr().add(i));
                let pb = _mm256_loadu_ps(b.data.as_ptr().add(i));
                let pr = match kind {
                    ElementwiseKind::Add => _mm256_add_ps(pa, pb),
                    ElementwiseKind::Mul => _mm256_mul_ps(pa, pb),
                    ElementwiseKind::Sub => _mm256_sub_ps(pa, pb),
                    ElementwiseKind::Div => _mm256_div_ps(pa, pb),
                    ElementwiseKind::Sqrt => _mm256_sqrt_ps(pa),
                    // fall back to scalar for certain ops inside vectorized loop
                    ElementwiseKind::Sin
                    | ElementwiseKind::Cos
                    | ElementwiseKind::Tan
                    | ElementwiseKind::Abs
                    | ElementwiseKind::Neg
                    | ElementwiseKind::Exp
                    | ElementwiseKind::Log
                    | ElementwiseKind::Pow
                    | ElementwiseKind::Asin
                    | ElementwiseKind::Acos
                    | ElementwiseKind::Atan
                    | ElementwiseKind::Relu
                    | ElementwiseKind::LeakyRelu
                    | ElementwiseKind::Sigmoid
                    | ElementwiseKind::Tanh
                    | ElementwiseKind::Softplus => {
                        // convert to scalars for this chunk; load both a and b lanes
                        let mut tmp_a = [0.0f32; 8];
                        let mut tmp_b = [0.0f32; 8];
                        _mm256_storeu_ps(tmp_a.as_mut_ptr(), pa);
                        _mm256_storeu_ps(tmp_b.as_mut_ptr(), pb);
                        for j in 0..8 {
                            tmp_a[j] = match kind {
                                ElementwiseKind::Sin => tmp_a[j].sin(),
                                ElementwiseKind::Cos => tmp_a[j].cos(),
                                ElementwiseKind::Tan => tmp_a[j].tan(),
                                ElementwiseKind::Abs => tmp_a[j].abs(),
                                ElementwiseKind::Neg => -tmp_a[j],
                                ElementwiseKind::Exp => tmp_a[j].exp(),
                                ElementwiseKind::Log => tmp_a[j].ln(),
                                ElementwiseKind::Pow => tmp_a[j].powf(tmp_b[j]),
                                ElementwiseKind::Asin => tmp_a[j].asin(),
                                ElementwiseKind::Acos => tmp_a[j].acos(),
                                ElementwiseKind::Atan => tmp_a[j].atan(),
                                ElementwiseKind::Relu => tmp_a[j].max(0.0),
                                ElementwiseKind::LeakyRelu => {
                                    if tmp_a[j] > 0.0 {
                                        tmp_a[j]
                                    } else {
                                        0.01 * tmp_a[j]
                                    }
                                }
                                ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-tmp_a[j]).exp()),
                                ElementwiseKind::Tanh => tmp_a[j].tanh(),
                                ElementwiseKind::Softplus => (1.0 + tmp_a[j].exp()).ln(),
                                _ => tmp_a[j],
                            };
                        }
                        _mm256_loadu_ps(tmp_a.as_ptr())
                    }
                };
                _mm256_storeu_ps(out.data.as_mut_ptr().add(i), pr);
                i += 8;
            }
        } else if std::is_x86_feature_detected!("sse2") {
            #[cfg(target_arch = "x86")]
            use std::arch::x86::*;
            #[cfg(target_arch = "x86_64")]
            use std::arch::x86_64::*;

            while i + 4 <= n {
                let pa = _mm_loadu_ps(a.data.as_ptr().add(i));
                let pb = _mm_loadu_ps(b.data.as_ptr().add(i));
                let pr = match kind {
                    ElementwiseKind::Add => _mm_add_ps(pa, pb),
                    ElementwiseKind::Mul => _mm_mul_ps(pa, pb),
                    ElementwiseKind::Sub => _mm_sub_ps(pa, pb),
                    ElementwiseKind::Div => _mm_div_ps(pa, pb),
                    ElementwiseKind::Sqrt => _mm_sqrt_ps(pa),
                    ElementwiseKind::Sin
                    | ElementwiseKind::Cos
                    | ElementwiseKind::Tan
                    | ElementwiseKind::Abs
                    | ElementwiseKind::Neg
                    | ElementwiseKind::Exp
                    | ElementwiseKind::Log
                    | ElementwiseKind::Pow
                    | ElementwiseKind::Asin
                    | ElementwiseKind::Acos
                    | ElementwiseKind::Atan
                    | ElementwiseKind::Relu
                    | ElementwiseKind::LeakyRelu
                    | ElementwiseKind::Sigmoid
                    | ElementwiseKind::Tanh
                    | ElementwiseKind::Softplus => {
                        let mut tmp_a = [0.0f32; 4];
                        let mut tmp_b = [0.0f32; 4];
                        _mm_storeu_ps(tmp_a.as_mut_ptr(), pa);
                        _mm_storeu_ps(tmp_b.as_mut_ptr(), pb);
                        for j in 0..4 {
                            tmp_a[j] = match kind {
                                ElementwiseKind::Sin => tmp_a[j].sin(),
                                ElementwiseKind::Cos => tmp_a[j].cos(),
                                ElementwiseKind::Tan => tmp_a[j].tan(),
                                ElementwiseKind::Abs => tmp_a[j].abs(),
                                ElementwiseKind::Neg => -tmp_a[j],
                                ElementwiseKind::Exp => tmp_a[j].exp(),
                                ElementwiseKind::Log => tmp_a[j].ln(),
                                ElementwiseKind::Pow => tmp_a[j].powf(tmp_b[j]),
                                ElementwiseKind::Asin => tmp_a[j].asin(),
                                ElementwiseKind::Acos => tmp_a[j].acos(),
                                ElementwiseKind::Atan => tmp_a[j].atan(),
                                ElementwiseKind::Relu => tmp_a[j].max(0.0),
                                ElementwiseKind::LeakyRelu => {
                                    if tmp_a[j] > 0.0 {
                                        tmp_a[j]
                                    } else {
                                        0.01 * tmp_a[j]
                                    }
                                }
                                ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-tmp_a[j]).exp()),
                                ElementwiseKind::Tanh => tmp_a[j].tanh(),
                                ElementwiseKind::Softplus => (1.0 + tmp_a[j].exp()).ln(),
                                _ => tmp_a[j],
                            };
                        }
                        _mm_loadu_ps(tmp_a.as_ptr())
                    }
                };
                _mm_storeu_ps(out.data.as_mut_ptr().add(i), pr);
                i += 4;
            }
        }
    }

    // Remaining / fallback scalar
    for j in i..n {
        out.data[j] = match kind {
            ElementwiseKind::Add => a.data[j] + b.data[j],
            ElementwiseKind::Mul => a.data[j] * b.data[j],
            ElementwiseKind::Sub => a.data[j] - b.data[j],
            ElementwiseKind::Div => a.data[j] / b.data[j],
            ElementwiseKind::Sqrt => a.data[j].sqrt(),
            ElementwiseKind::Abs => a.data[j].abs(),
            ElementwiseKind::Neg => -a.data[j],
            ElementwiseKind::Exp => a.data[j].exp(),
            ElementwiseKind::Log => a.data[j].ln(),
            ElementwiseKind::Tan => a.data[j].tan(),
            ElementwiseKind::Pow => a.data[j].powf(b.data[j]),
            ElementwiseKind::Sin => a.data[j].sin(),
            ElementwiseKind::Cos => a.data[j].cos(),
            ElementwiseKind::Asin => a.data[j].asin(),
            ElementwiseKind::Acos => a.data[j].acos(),
            ElementwiseKind::Atan => a.data[j].atan(),
            ElementwiseKind::Relu => a.data[j].max(0.0),
            ElementwiseKind::LeakyRelu => {
                if a.data[j] > 0.0 {
                    a.data[j]
                } else {
                    0.01 * a.data[j]
                }
            }
            ElementwiseKind::Sigmoid => 1.0 / (1.0 + (-a.data[j]).exp()),
            ElementwiseKind::Tanh => a.data[j].tanh(),
            ElementwiseKind::Softplus => (1.0 + a.data[j].exp()).ln(),
        };
    }

    Ok(out)
}

/// SIMD-accelerated reduction (sum, max, min, mean). For full reductions (axis = None)
/// we run an AVX2 vectorized loop that accumulates into an `__m256` register and then
/// reduces it horizontally. On other architectures, or when AVX2 is absent, we fall
/// back to the scalar implementation.
pub fn reduce_simd(a: &Array, axis: Option<usize>, kind: ReductionKind) -> Result<Array> {
    if axis.is_none() {
        let n = a.len();

        match kind {
            ReductionKind::Sum => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_setzero_ps();

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_add_ps(acc, p);
                            i += 8;
                        }

                        // Horizontal sum of acc
                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut sum = s.iter().copied().sum::<f32>();

                        // tail
                        while i < n {
                            sum += a.data[i];
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![sum]));
                    }
                }

                // Non x86/AVX2 path or if not detected: fallback to scalar
                let sum: f32 = a.data.iter().copied().sum();
                Ok(Array::new(vec![1], vec![sum]))
            }
            ReductionKind::Max => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_max_ps(acc, p);
                            i += 8;
                        }

                        // Horizontal max of acc
                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut max_val = s[0];
                        for &v in &s[1..] {
                            if v > max_val {
                                max_val = v;
                            }
                        }

                        // tail
                        while i < n {
                            if a.data[i] > max_val {
                                max_val = a.data[i];
                            }
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![max_val]));
                    }
                }

                // Fallback to scalar
                let max_val = a
                    .data
                    .iter()
                    .copied()
                    .fold(f32::NEG_INFINITY, |acc, x| acc.max(x));
                Ok(Array::new(vec![1], vec![max_val]))
            }
            ReductionKind::Min => {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                unsafe {
                    if std::is_x86_feature_detected!("avx2") {
                        #[cfg(target_arch = "x86")]
                        use std::arch::x86::*;
                        #[cfg(target_arch = "x86_64")]
                        use std::arch::x86_64::*;

                        let mut i = 0usize;
                        let mut acc = _mm256_set1_ps(f32::INFINITY);

                        while i + 8 <= n {
                            let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                            acc = _mm256_min_ps(acc, p);
                            i += 8;
                        }

                        // Horizontal min of acc
                        let mut s = [0.0f32; 8];
                        _mm256_storeu_ps(s.as_mut_ptr(), acc);
                        let mut min_val = s[0];
                        for &v in &s[1..] {
                            if v < min_val {
                                min_val = v;
                            }
                        }

                        // tail
                        while i < n {
                            if a.data[i] < min_val {
                                min_val = a.data[i];
                            }
                            i += 1;
                        }

                        return Ok(Array::new(vec![1], vec![min_val]));
                    }
                }

                // Fallback to scalar
                let min_val = a
                    .data
                    .iter()
                    .copied()
                    .fold(f32::INFINITY, |acc, x| acc.min(x));
                Ok(Array::new(vec![1], vec![min_val]))
            }
            ReductionKind::Mean => {
                // Reuse sum and divide
                let sum_result = reduce_simd(a, axis, ReductionKind::Sum)?;
                let mean = sum_result.data[0] / n as f32;
                Ok(Array::new(vec![1], vec![mean]))
            }
            ReductionKind::ArgMax | ReductionKind::Variance => {
                // Fallback to scalar - complex algorithms
                crate::backend::cpu::scalar::reduce_scalar(a, None, kind)
            }
        }
    } else {
        // Axis-based reduction
        let axis = axis.unwrap();

        // OPTIMIZED PATH: Reducing over last axis with SIMD
        if axis == a.shape.len() - 1 {
            return reduce_last_axis_simd(a, axis, kind);
        }

        // For other axes, fallback to scalar implementation
        // TODO: optimize specific cases (e.g., reducing axis 0 with proper striding)
        crate::backend::cpu::scalar::reduce_scalar(a, Some(axis), kind)
    }
}

/// Optimized SIMD reduction over the last axis
/// This is the most cache-friendly case as data is contiguous
fn reduce_last_axis_simd(a: &Array, axis: usize, kind: ReductionKind) -> Result<Array> {
    // Compute output shape
    let mut out_shape: Vec<usize> = a
        .shape
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != axis)
        .map(|(_, &d)| d)
        .collect();

    if out_shape.is_empty() {
        out_shape.push(1);
    }

    let out_size: usize = out_shape.iter().product();
    let axis_size = a.shape[axis];
    let mut out_data = vec![0.0; out_size];

    match kind {
        ReductionKind::Sum | ReductionKind::Mean => {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if std::is_x86_feature_detected!("avx2") {
                    out_data
                        .par_iter_mut()
                        .enumerate()
                        .for_each(|(row_idx, out_val)| {
                            let start = row_idx * axis_size;
                            let end = start + axis_size;

                            unsafe {
                                #[cfg(target_arch = "x86")]
                                use std::arch::x86::*;
                                #[cfg(target_arch = "x86_64")]
                                use std::arch::x86_64::*;

                                let mut acc = _mm256_setzero_ps();
                                let mut i = start;

                                // Process 8 elements at a time with SIMD
                                while i + 8 <= end {
                                    let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                                    acc = _mm256_add_ps(acc, p);
                                    i += 8;
                                }

                                // Horizontal sum
                                let mut s = [0.0f32; 8];
                                _mm256_storeu_ps(s.as_mut_ptr(), acc);
                                let mut sum: f32 = s.iter().sum();

                                // Handle remaining elements
                                while i < end {
                                    sum += a.data[i];
                                    i += 1;
                                }

                                *out_val = sum;
                            }
                        });

                    if kind == ReductionKind::Mean {
                        out_data.par_iter_mut().for_each(|x| *x /= axis_size as f32);
                    }

                    return Ok(Array::new(out_shape, out_data));
                }
            }

            // Fallback to scalar
            return crate::backend::cpu::scalar::reduce_last_axis_optimized(
                a, axis_size, out_size, out_shape, kind,
            );
        }
        ReductionKind::Max => {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if std::is_x86_feature_detected!("avx2") {
                    out_data
                        .par_iter_mut()
                        .enumerate()
                        .for_each(|(row_idx, out_val)| {
                            let start = row_idx * axis_size;
                            let end = start + axis_size;

                            unsafe {
                                #[cfg(target_arch = "x86")]
                                use std::arch::x86::*;
                                #[cfg(target_arch = "x86_64")]
                                use std::arch::x86_64::*;

                                let mut acc = _mm256_set1_ps(f32::NEG_INFINITY);
                                let mut i = start;

                                while i + 8 <= end {
                                    let p = _mm256_loadu_ps(a.data.as_ptr().add(i));
                                    acc = _mm256_max_ps(acc, p);
                                    i += 8;
                                }

                                // Horizontal max
                                let mut s = [0.0f32; 8];
                                _mm256_storeu_ps(s.as_mut_ptr(), acc);
                                let mut max_val = s[0];
                                for &v in &s[1..] {
                                    if v > max_val {
                                        max_val = v;
                                    }
                                }

                                // Handle remaining elements
                                while i < end {
                                    if a.data[i] > max_val {
                                        max_val = a.data[i];
                                    }
                                    i += 1;
                                }

                                *out_val = max_val;
                            }
                        });

                    return Ok(Array::new(out_shape, out_data));
                }
            }

            // Fallback
            return crate::backend::cpu::scalar::reduce_last_axis_optimized(
                a, axis_size, out_size, out_shape, kind,
            );
        }
        _ => {
            // For other operations, use scalar optimized version
            return crate::backend::cpu::scalar::reduce_last_axis_optimized(
                a, axis_size, out_size, out_shape, kind,
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::cpu::scalar;

    fn make_arrays(len: usize) -> (Array, Array) {
        let a = (0..len).map(|i| i as f32 * 0.5 + 0.1).collect::<Vec<_>>();
        let b = (0..len).map(|i| (i as f32).sin()).collect::<Vec<_>>();
        (Array::new(vec![len], a), Array::new(vec![len], b))
    }

    #[test]
    fn simd_add_matches_scalar() {
        for len in &[1usize, 3, 7, 8, 15, 16, 33, 64] {
            let (a, b) = make_arrays(*len);
            let out_simd = elementwise_simd(&a, &b, ElementwiseKind::Add).unwrap();
            let out_scalar = scalar::elementwise_scalar(&a, &b, ElementwiseKind::Add).unwrap();
            assert_eq!(out_simd.data, out_scalar.data);
        }
    }

    #[test]
    fn simd_mul_matches_scalar() {
        for len in &[1usize, 3, 7, 8, 15, 16, 33, 64] {
            let (a, b) = make_arrays(*len);
            let out_simd = elementwise_simd(&a, &b, ElementwiseKind::Mul).unwrap();
            let out_scalar = scalar::elementwise_scalar(&a, &b, ElementwiseKind::Mul).unwrap();
            assert_eq!(out_simd.data, out_scalar.data);
        }
    }
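
    // Additional correctness sketches (assumptions: `scalar::reduce_scalar` and
    // `scalar::dot_scalar` are the scalar reference implementations used elsewhere
    // in this module). The SIMD paths change the floating-point accumulation order
    // (and the dot product uses FMA), so these compare with a small relative
    // tolerance instead of exact equality.
    fn assert_close(x: f32, y: f32, tol: f32) {
        let scale = x.abs().max(y.abs()).max(1.0);
        assert!((x - y).abs() <= tol * scale, "{} vs {}", x, y);
    }

    #[test]
    fn simd_sigmoid_matches_scalar() {
        // Exercises the per-lane scalar fallback inside the vectorized loop.
        for len in &[1usize, 3, 7, 8, 15, 16, 33, 64] {
            let (a, b) = make_arrays(*len);
            let out_simd = elementwise_simd(&a, &b, ElementwiseKind::Sigmoid).unwrap();
            let out_scalar = scalar::elementwise_scalar(&a, &b, ElementwiseKind::Sigmoid).unwrap();
            for (x, y) in out_simd.data.iter().zip(out_scalar.data.iter()) {
                assert_close(*x, *y, 1e-6);
            }
        }
    }

    #[test]
    fn simd_reduce_sum_matches_scalar() {
        for len in &[1usize, 7, 8, 33, 257] {
            let (a, _) = make_arrays(*len);
            let out_simd = reduce_simd(&a, None, ReductionKind::Sum).unwrap();
            let out_scalar = scalar::reduce_scalar(&a, None, ReductionKind::Sum).unwrap();
            assert_close(out_simd.data[0], out_scalar.data[0], 1e-4);
        }
    }

    #[test]
    fn simd_reduce_last_axis_sum_matches_scalar() {
        // Exercises `reduce_last_axis_simd` through the axis-based path.
        let data = (0..4 * 7).map(|i| (i as f32 * 0.3).cos()).collect::<Vec<_>>();
        let a = Array::new(vec![4, 7], data);
        let out_simd = reduce_simd(&a, Some(1), ReductionKind::Sum).unwrap();
        let out_scalar = scalar::reduce_scalar(&a, Some(1), ReductionKind::Sum).unwrap();
        for (x, y) in out_simd.data.iter().zip(out_scalar.data.iter()) {
            assert_close(*x, *y, 1e-5);
        }
    }

    #[test]
    fn simd_dot_matches_scalar() {
        for len in &[1usize, 7, 8, 33, 257] {
            let (a, b) = make_arrays(*len);
            let d_simd = dot_simd(&a, &b).unwrap();
            let d_scalar = scalar::dot_scalar(&a, &b).unwrap();
            assert_close(d_simd, d_scalar, 1e-4);
        }
    }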
}

/// Dot product SIMD implementation with FMA (fused multiply-add)
pub fn dot_simd(a: &Array, b: &Array) -> Result<f32> {
    if a.shape.len() != 1 || b.shape.len() != 1 {
        return Err(anyhow!("dot_simd: both inputs must be 1-D arrays"));
    }
    if a.shape[0] != b.shape[0] {
        return Err(anyhow!("dot_simd: arrays must have same length"));
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("fma") && std::is_x86_feature_detected!("avx2") {
            // SAFETY: We checked for AVX2+FMA support
            unsafe {
                return dot_simd_avx2_fma(a, b);
            }
        }
    }

    // Fallback to scalar if SIMD not available
    crate::backend::cpu::scalar::dot_scalar(a, b)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_simd_avx2_fma(a: &Array, b: &Array) -> Result<f32> {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let n = a.data.len();
    let mut sum = _mm256_setzero_ps();

    // Process 8 floats at a time
    let chunks = n / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let va = _mm256_loadu_ps(a.data.as_ptr().add(offset));
        let vb = _mm256_loadu_ps(b.data.as_ptr().add(offset));
        // FMA: sum = sum + (va * vb)
        sum = _mm256_fmadd_ps(va, vb, sum);
    }

    // Horizontal sum of 8 lanes
    let mut result = [0.0f32; 8];
    _mm256_storeu_ps(result.as_mut_ptr(), sum);
    let mut total = result.iter().sum::<f32>();

    // Handle remaining elements
    for i in (chunks * 8)..n {
        total += a.data[i] * b.data[i];
    }

    Ok(total)
}

/// SIMD-accelerated matrix multiplication.
/// Uses a blocked (tiled) algorithm with SIMD vectorization in the inner loops.
pub fn matmul_simd(a: &Array, b: &Array) -> Array {
    if a.shape.len() != 2 || b.shape.len() != 2 {
        panic!("matmul_simd: both inputs must be 2-D arrays");
    }

    let m = a.shape[0];
    let k = a.shape[1];
    let n = b.shape[1];

    if k != b.shape[0] {
        panic!(
            "matmul_simd: inner dimension mismatch: {} != {}",
            k, b.shape[0]
        );
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            // Always use Rayon + SIMD here for consistency; the adaptive dispatch
            // has already decided whether this kernel is appropriate.
            return matmul_simd_parallel(a, b, m, k, n);
        }
    }

    // Fallback to scalar if SIMD not available
    super::matmul_scalar_direct(a, b)
}

/// SIMD matmul parallelized with Rayon for large matrices.
/// Rows are split into adaptive blocks that are processed in parallel.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn matmul_simd_parallel(a: &Array, b: &Array, m: usize, k: usize, n: usize) -> Array {
    use rayon::prelude::*;

    // Adaptive block size based on matrix dimensions.
    // Larger blocks give better cache utilization on large matrices.
    let block_size = if m >= 2048 { 256 } else { 128 };
    let mut result = vec![0.0f32; m * n];

    result
        .par_chunks_mut(block_size * n)
        .enumerate()
        .for_each(|(block_idx, out_block)| {
            let start = block_idx * block_size;
            let end = (start + block_size).min(m);
            let block_rows = end - start;

            // Slice out this block's rows of A directly from the source buffer...
            let a_block_start = start * k;
            let a_block_end = end * k;
            let a_block_slice = &a.data[a_block_start..a_block_end];

            // ...and materialise them as a temporary Array for the blocked kernel.
            // Note: `to_vec()` copies the block; a borrowed view type would avoid the copy.
            let a_block = Array::new(vec![block_rows, k], a_block_slice.to_vec());

            // Process the block with SIMD.
            // SAFETY: We already checked for AVX2+FMA in the parent function.
            unsafe {
                let block_result = matmul_simd_avx2_fma_blocked(
                    &a_block,
                    b,
                    block_rows,
                    k,
                    n,
                    vec![0.0f32; block_rows * n],
                );
                out_block.copy_from_slice(&block_result.data);
            }
        });

    Array::new(vec![m, n], result)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn matmul_simd_avx2_fma_blocked(
    a: &Array,
    b: &Array,
    m: usize,
    k: usize,
    n: usize,
    mut result: Vec<f32>,
) -> Array {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Optimized block sizes: larger blocks for better arithmetic intensity
    const BLOCK_M: usize = 96; // Process 96 rows per block
    const BLOCK_N: usize = 256; // Process 256 cols per block (32 eight-lane AVX2 vectors)
    const BLOCK_K: usize = 512; // Larger K block for better data reuse

    // Blocked matrix multiplication with advanced SIMD
    for i0 in (0..m).step_by(BLOCK_M) {
        let i_end = (i0 + BLOCK_M).min(m);

        for j0 in (0..n).step_by(BLOCK_N) {
            let j_end = (j0 + BLOCK_N).min(n);

            for k0 in (0..k).step_by(BLOCK_K) {
                let k_end = (k0 + BLOCK_K).min(k);

                // Process 2 rows at a time for better register utilization
                let mut i = i0;
                while i + 2 <= i_end {
                    let a_row0_offset = i * k;
                    let a_row1_offset = (i + 1) * k;
                    let result_row0_offset = i * n;
                    let result_row1_offset = (i + 1) * n;

                    // Process 16 columns at a time (2 AVX2 registers)
                    let mut j = j0;
                    while j + 16 <= j_end {
                        let mut sum0_0 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j));
                        let mut sum0_1 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j + 8));
                        let mut sum1_0 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j));
                        let mut sum1_1 =
                            _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j + 8));

                        // Inner loop: accumulate over k-block with FMA
                        for kk in k0..k_end {
                            let a_val0 = _mm256_set1_ps(a.data[a_row0_offset + kk]);
                            let a_val1 = _mm256_set1_ps(a.data[a_row1_offset + kk]);
                            let b_row_offset = kk * n;
                            let b_vals0 = _mm256_loadu_ps(b.data.as_ptr().add(b_row_offset + j));
                            let b_vals1 =
                                _mm256_loadu_ps(b.data.as_ptr().add(b_row_offset + j + 8));

                            // FMA: sum = sum + (a_val * b_vals)
                            // Process both rows and both column sets simultaneously
                            sum0_0 = _mm256_fmadd_ps(a_val0, b_vals0, sum0_0);
                            sum0_1 = _mm256_fmadd_ps(a_val0, b_vals1, sum0_1);
                            sum1_0 = _mm256_fmadd_ps(a_val1, b_vals0, sum1_0);
                            sum1_1 = _mm256_fmadd_ps(a_val1, b_vals1, sum1_1);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row0_offset + j), sum0_0);
                        _mm256_storeu_ps(
                            result.as_mut_ptr().add(result_row0_offset + j + 8),
                            sum0_1,
                        );
                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row1_offset + j), sum1_0);
                        _mm256_storeu_ps(
                            result.as_mut_ptr().add(result_row1_offset + j + 8),
                            sum1_1,
                        );
                        j += 16;
                    }

                    // Process remaining columns in chunks of 8
                    while j + 8 <= j_end {
                        let mut sum0 = _mm256_loadu_ps(result.as_ptr().add(result_row0_offset + j));
                        let mut sum1 = _mm256_loadu_ps(result.as_ptr().add(result_row1_offset + j));

                        for kk in k0..k_end {
                            let a_val0 = _mm256_set1_ps(a.data[a_row0_offset + kk]);
                            let a_val1 = _mm256_set1_ps(a.data[a_row1_offset + kk]);
                            let b_vals = _mm256_loadu_ps(b.data.as_ptr().add(kk * n + j));

                            sum0 = _mm256_fmadd_ps(a_val0, b_vals, sum0);
                            sum1 = _mm256_fmadd_ps(a_val1, b_vals, sum1);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row0_offset + j), sum0);
                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row1_offset + j), sum1);
                        j += 8;
                    }

                    // Handle remaining columns with scalar code
                    for j in j..j_end {
                        let mut sum0 = result[result_row0_offset + j];
                        let mut sum1 = result[result_row1_offset + j];
                        for kk in k0..k_end {
                            let b_val = b.data[kk * n + j];
                            sum0 += a.data[a_row0_offset + kk] * b_val;
                            sum1 += a.data[a_row1_offset + kk] * b_val;
                        }
                        result[result_row0_offset + j] = sum0;
                        result[result_row1_offset + j] = sum1;
                    }

                    i += 2;
                }

                // Handle remaining single row if m is odd
                if i < i_end {
                    let a_row_offset = i * k;
                    let result_row_offset = i * n;

                    let mut j = j0;
                    while j + 8 <= j_end {
                        let mut sum = _mm256_loadu_ps(result.as_ptr().add(result_row_offset + j));

                        for kk in k0..k_end {
                            let a_val = _mm256_set1_ps(a.data[a_row_offset + kk]);
                            let b_vals = _mm256_loadu_ps(b.data.as_ptr().add(kk * n + j));
                            sum = _mm256_fmadd_ps(a_val, b_vals, sum);
                        }

                        _mm256_storeu_ps(result.as_mut_ptr().add(result_row_offset + j), sum);
                        j += 8;
                    }

                    for j in j..j_end {
                        let mut sum = result[result_row_offset + j];
                        for kk in k0..k_end {
                            sum += a.data[a_row_offset + kk] * b.data[kk * n + j];
                        }
                        result[result_row_offset + j] = sum;
                    }
                }
            }
        }
    }

    Array::new(vec![m, n], result)
}
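
// A minimal correctness sketch for the blocked SIMD matmul above: it compares
// `matmul_simd` against a naive triple-loop reference written inline, so it does
// not depend on any other backend API. FMA and blocking change the floating-point
// accumulation order, so the comparison uses a small relative tolerance rather
// than exact equality. Shapes that are not multiples of 8 exercise the scalar
// tails of the kernel.
#[cfg(test)]
mod matmul_tests {
    use super::*;

    fn naive_matmul(a: &Array, b: &Array, m: usize, k: usize, n: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; m * n];
        for i in 0..m {
            for p in 0..k {
                let a_ip = a.data[i * k + p];
                for j in 0..n {
                    out[i * n + j] += a_ip * b.data[p * n + j];
                }
            }
        }
        out
    }

    #[test]
    fn matmul_simd_matches_naive() {
        for &(m, k, n) in &[(1usize, 1usize, 1usize), (3, 5, 7), (17, 9, 13), (64, 64, 64)] {
            let a = Array::new(
                vec![m, k],
                (0..m * k).map(|i| (i as f32 * 0.37).sin()).collect(),
            );
            let b = Array::new(
                vec![k, n],
                (0..k * n).map(|i| (i as f32 * 0.11).cos()).collect(),
            );
            let got = matmul_simd(&a, &b);
            let want = naive_matmul(&a, &b, m, k, n);
            for (x, y) in got.data.iter().zip(want.iter()) {
                let scale = x.abs().max(y.abs()).max(1.0);
                assert!((x - y).abs() <= 1e-4 * scale, "{} vs {}", x, y);
            }
        }
    }
}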

/// Re-export of the Conv1D SIMD implementation (currently a stub).
pub use super::simd_conv::conv1d_simd;