//! numrs/backend/cpu/mod.rs — CPU backend module root.

1pub mod batchnorm;
2pub mod conv;
3pub mod dropout;
4pub mod parallel;
5pub mod random;
6pub mod scalar;
7pub mod simd;
8pub mod simd_conv;
9
10// Pub use elementwise removed (doesn't exist)
11#[cfg(not(target_arch = "wasm32"))]
12use rayon::prelude::*;
13
14use crate::array::Array;
15
// Local wrappers used by method selection. Each wrapper accepts a tuple
// `(&Array, &Array)` and returns an `Array`, matching the `fn(&T)->R` signature
// used by KernelSelectionContext.
#[cfg(not(target_arch = "wasm32"))]
fn matmul_scalar_impl(inputs: &(&Array, &Array)) -> Array {
    // Cache-blocked, Rayon-parallel matrix multiply: out = A * B for row-major
    // f32 data, where A is (m x k1) and B is (k1 x n); result is (m x n).
    // NOTE(review): B's row count (b_arr.shape[0]) is never checked against
    // k1; a mismatch would index out of bounds — confirm callers validate.
    let a_arr = inputs.0;
    let b_arr = inputs.1;
    let m = a_arr.shape[0];
    let k1 = a_arr.shape[1];
    let n = b_arr.shape[1];

    // Optimized block sizes for better cache utilization
    // BM/BN sized for L2 cache (~256KB), BK for L1 cache
    let bm = 128usize; // Larger blocks for better parallelization
    let bn = 128usize;
    let bk = 256usize; // Larger K blocks to amortize memory access

    // Zero-initialized output: the k0 loop below accumulates partial sums
    // across K-blocks, so the buffer must start at 0.
    let mut out = vec![0.0f32; m * n];

    // Partition output into row-blocks and parallelize.
    // Each Rayon task owns a disjoint slice of up to `bm` output rows, so no
    // synchronization is needed between tasks.
    // NOTE(review): `par_chunks_mut` panics if the chunk size is 0, i.e. when
    // n == 0 — confirm empty shapes cannot reach this kernel.
    let rows_per_block = bm;
    out.par_chunks_mut(n * rows_per_block)
        .enumerate()
        .for_each(|(block_idx, out_block)| {
            // Absolute row range [i0, i_max) covered by this chunk; the final
            // chunk may hold fewer than `rows_per_block` rows.
            let i0 = block_idx * rows_per_block;
            let i_max = (i0 + rows_per_block).min(m);
            let rows_in_block = i_max - i0;

            // Reorder loops: k -> i -> j for better cache locality
            for k0 in (0..k1).step_by(bk) {
                let k_max = (k0 + bk).min(k1);

                for j0 in (0..n).step_by(bn) {
                    let j_max = (j0 + bn).min(n);

                    for ii in 0..rows_in_block {
                        // `i` is the global row into A; `ii` is the local row
                        // inside this chunk's slice of `out`.
                        let i = i0 + ii;
                        let a_row_off = i * k1;
                        let out_row_off = ii * n;

                        // Process 4 columns at a time (manual loop unrolling)
                        let mut j = j0;
                        while j + 4 <= j_max {
                            // Read back the running sums so accumulation stays
                            // correct across successive K-blocks.
                            let mut sum0 = out_block[out_row_off + j];
                            let mut sum1 = out_block[out_row_off + j + 1];
                            let mut sum2 = out_block[out_row_off + j + 2];
                            let mut sum3 = out_block[out_row_off + j + 3];

                            for kk in k0..k_max {
                                // One A element feeds four output columns,
                                // amortizing the A load over 4 FMAs.
                                let a_val = a_arr.data[a_row_off + kk];
                                let b_row_off = kk * n;
                                sum0 += a_val * b_arr.data[b_row_off + j];
                                sum1 += a_val * b_arr.data[b_row_off + j + 1];
                                sum2 += a_val * b_arr.data[b_row_off + j + 2];
                                sum3 += a_val * b_arr.data[b_row_off + j + 3];
                            }

                            out_block[out_row_off + j] = sum0;
                            out_block[out_row_off + j + 1] = sum1;
                            out_block[out_row_off + j + 2] = sum2;
                            out_block[out_row_off + j + 3] = sum3;
                            j += 4;
                        }

                        // Handle remaining columns (j_max - j0 not a multiple
                        // of 4) with a plain scalar inner product.
                        while j < j_max {
                            let mut sum = out_block[out_row_off + j];
                            for kk in k0..k_max {
                                sum += a_arr.data[a_row_off + kk] * b_arr.data[kk * n + j];
                            }
                            out_block[out_row_off + j] = sum;
                            j += 1;
                        }
                    }
                }
            }
        });

    crate::array::Array::new(vec![m, n], out)
}
96
#[cfg(target_arch = "wasm32")]
fn matmul_scalar_impl(inputs: &(&Array, &Array)) -> Array {
    // WASM variant: identical cache-blocked matmul to the non-wasm version,
    // but executed serially (`chunks_mut` instead of Rayon's
    // `par_chunks_mut`), since Rayon is unavailable on wasm32 here.
    // NOTE(review): duplicated with the non-wasm kernel; keep the two loop
    // nests in sync when changing either.
    // NOTE(review): B's row count (b_arr.shape[0]) is never checked against
    // k1 — confirm callers validate shapes.
    let a_arr = inputs.0;
    let b_arr = inputs.1;
    let m = a_arr.shape[0];
    let k1 = a_arr.shape[1];
    let n = b_arr.shape[1];

    // Same block sizes as the parallel kernel (L2-sized BM/BN, L1-sized BK).
    let bm = 128usize;
    let bn = 128usize;
    let bk = 256usize;

    // Zero-initialized so partial sums accumulate across K-blocks.
    let mut out = vec![0.0f32; m * n];

    // Serial execution for WASM
    let rows_per_block = bm;

    // Using standard chunks_mut().enumerate() for serial processing.
    // NOTE(review): `chunks_mut` panics if the chunk size is 0, i.e. when
    // n == 0 — confirm empty shapes cannot reach this kernel.
    out.chunks_mut(n * rows_per_block)
        .enumerate()
        .for_each(|(block_idx, out_block)| {
            // Absolute row range [i0, i_max) covered by this chunk.
            let i0 = block_idx * rows_per_block;
            let i_max = (i0 + rows_per_block).min(m);
            let rows_in_block = i_max - i0;

            // Loop order k -> i -> j for cache locality (matches non-wasm).
            for k0 in (0..k1).step_by(bk) {
                let k_max = (k0 + bk).min(k1);

                for j0 in (0..n).step_by(bn) {
                    let j_max = (j0 + bn).min(n);

                    for ii in 0..rows_in_block {
                        // Global A row vs. chunk-local output row.
                        let i = i0 + ii;
                        let a_row_off = i * k1;
                        let out_row_off = ii * n;

                        // 4-wide manual unroll over output columns.
                        let mut j = j0;
                        while j + 4 <= j_max {
                            // Read back running sums for cross-K-block accumulation.
                            let mut sum0 = out_block[out_row_off + j];
                            let mut sum1 = out_block[out_row_off + j + 1];
                            let mut sum2 = out_block[out_row_off + j + 2];
                            let mut sum3 = out_block[out_row_off + j + 3];

                            for kk in k0..k_max {
                                let a_val = a_arr.data[a_row_off + kk];
                                let b_row_off = kk * n;
                                sum0 += a_val * b_arr.data[b_row_off + j];
                                sum1 += a_val * b_arr.data[b_row_off + j + 1];
                                sum2 += a_val * b_arr.data[b_row_off + j + 2];
                                sum3 += a_val * b_arr.data[b_row_off + j + 3];
                            }

                            out_block[out_row_off + j] = sum0;
                            out_block[out_row_off + j + 1] = sum1;
                            out_block[out_row_off + j + 2] = sum2;
                            out_block[out_row_off + j + 3] = sum3;
                            j += 4;
                        }

                        // Scalar tail for columns past the last multiple of 4.
                        while j < j_max {
                            let mut sum = out_block[out_row_off + j];
                            for kk in k0..k_max {
                                sum += a_arr.data[a_row_off + kk] * b_arr.data[kk * n + j];
                            }
                            out_block[out_row_off + j] = sum;
                            j += 1;
                        }
                    }
                }
            }
        });

    crate::array::Array::new(vec![m, n], out)
}
171
172// ============================================================================
173// Public kernel wrappers for dispatch system
174// ============================================================================
175
176/// Matmul scalar con paralelización Rayon (para benchmarking)
177/// Usa bloques optimizados pero sin instrucciones SIMD
178pub fn matmul_scalar_parallel(a: &Array, b: &Array) -> Array {
179    eprintln!(
180        "[SCALAR_IMPL] matmul_scalar_parallel called for {}x{}",
181        a.shape[0], a.shape[1]
182    );
183    matmul_scalar_impl(&(a, b))
184}
185
/// SIMD-accelerated matmul (uses AVX2+FMA when available, falls back to scalar)
///
/// Thin delegation to [`simd::matmul_simd`], kept as a stable, directly
/// callable entry point for the dispatch/benchmark system.
pub fn matmul_simd_direct(a: &Array, b: &Array) -> Array {
    simd::matmul_simd(a, b)
}
190
/// Scalar fallback matmul (always available)
///
/// Thin delegation to the private blocked kernel `matmul_scalar_impl`
/// (Rayon-parallel on native targets, serial on wasm32).
pub fn matmul_scalar_direct(a: &Array, b: &Array) -> Array {
    matmul_scalar_impl(&(a, b))
}
195
/// CPU backend orchestrates scalar and SIMD strategies.
///
/// Currently stateless; the empty body reserves room for future
/// configuration (thread counts, affinity, detected SIMD levels).
///
/// Fix: derive `Default` alongside `Debug`/`Clone` — a public `new()` with no
/// `Default` impl trips clippy's `new_without_default`, and the derive is
/// backward compatible for an empty struct.
#[derive(Debug, Clone, Default)]
pub struct CpuBackend {
    // future: CPU threads, affinity, simd levels
}
201
impl CpuBackend {
    /// Creates a new CPU backend. No configuration is taken yet; the
    /// struct is currently stateless.
    pub fn new() -> Self {
        Self {}
    }

    /// Expose a fallback matmul entry point for microbench/testing that
    /// directly invokes the scalar/parallel implementation without using BLAS.
    ///
    /// Associated function (no `self`): delegates straight to the private
    /// blocked kernel `matmul_scalar_impl`.
    pub fn matmul_fallback(
        a: &crate::array::Array,
        b: &crate::array::Array,
    ) -> crate::array::Array {
        matmul_scalar_impl(&(a, b))
    }

    // execute() method removed - use ops::fast::* functions with dispatch system instead
}
215
216    // execute() method removed - use ops::fast::* functions with dispatch system instead
217}