1pub mod batchnorm;
2pub mod conv;
3pub mod dropout;
4pub mod parallel;
5pub mod random;
6pub mod scalar;
7pub mod simd;
8pub mod simd_conv;
9
10#[cfg(not(target_arch = "wasm32"))]
12use rayon::prelude::*;
13
14use crate::array::Array;
15
16#[cfg(not(target_arch = "wasm32"))]
20fn matmul_scalar_impl(inputs: &(&Array, &Array)) -> Array {
21 let a_arr = inputs.0;
22 let b_arr = inputs.1;
23 let m = a_arr.shape[0];
24 let k1 = a_arr.shape[1];
25 let n = b_arr.shape[1];
26
27 let bm = 128usize; let bn = 128usize;
31 let bk = 256usize; let mut out = vec![0.0f32; m * n];
34
35 let rows_per_block = bm;
37 out.par_chunks_mut(n * rows_per_block)
38 .enumerate()
39 .for_each(|(block_idx, out_block)| {
40 let i0 = block_idx * rows_per_block;
41 let i_max = (i0 + rows_per_block).min(m);
42 let rows_in_block = i_max - i0;
43
44 for k0 in (0..k1).step_by(bk) {
46 let k_max = (k0 + bk).min(k1);
47
48 for j0 in (0..n).step_by(bn) {
49 let j_max = (j0 + bn).min(n);
50
51 for ii in 0..rows_in_block {
52 let i = i0 + ii;
53 let a_row_off = i * k1;
54 let out_row_off = ii * n;
55
56 let mut j = j0;
58 while j + 4 <= j_max {
59 let mut sum0 = out_block[out_row_off + j];
60 let mut sum1 = out_block[out_row_off + j + 1];
61 let mut sum2 = out_block[out_row_off + j + 2];
62 let mut sum3 = out_block[out_row_off + j + 3];
63
64 for kk in k0..k_max {
65 let a_val = a_arr.data[a_row_off + kk];
66 let b_row_off = kk * n;
67 sum0 += a_val * b_arr.data[b_row_off + j];
68 sum1 += a_val * b_arr.data[b_row_off + j + 1];
69 sum2 += a_val * b_arr.data[b_row_off + j + 2];
70 sum3 += a_val * b_arr.data[b_row_off + j + 3];
71 }
72
73 out_block[out_row_off + j] = sum0;
74 out_block[out_row_off + j + 1] = sum1;
75 out_block[out_row_off + j + 2] = sum2;
76 out_block[out_row_off + j + 3] = sum3;
77 j += 4;
78 }
79
80 while j < j_max {
82 let mut sum = out_block[out_row_off + j];
83 for kk in k0..k_max {
84 sum += a_arr.data[a_row_off + kk] * b_arr.data[kk * n + j];
85 }
86 out_block[out_row_off + j] = sum;
87 j += 1;
88 }
89 }
90 }
91 }
92 });
93
94 crate::array::Array::new(vec![m, n], out)
95}
96
/// Multiplies two row-major matrices with a cache-blocked scalar kernel,
/// processing blocks of output rows sequentially (wasm32 build).
///
/// `inputs.0` is read as an `m x k` matrix and `inputs.1` as a `k x n` matrix,
/// where `m` and `k` come from the first operand's `shape[0]` / `shape[1]` and
/// `n` from the second operand's `shape[1]`. The result is an `m x n` array.
///
/// An empty result (`m == 0` or `n == 0`) is returned immediately; without
/// that guard a zero-width output would request a zero-sized chunk, which
/// `chunks_mut` rejects with a panic.
///
/// # Panics
///
/// Panics if either operand has fewer than two shape entries or if an operand's
/// data is shorter than the indices derived from the shapes.
/// NOTE(review): `inputs.1.shape[0]` is never compared with `k`; callers are
/// presumably expected to validate the inner dimension — confirm.
#[cfg(target_arch = "wasm32")]
fn matmul_scalar_impl(inputs: &(&Array, &Array)) -> Array {
    // Tile extents: output rows per block, output columns per tile, and
    // reduction-dimension elements per tile.
    const BLOCK_ROWS: usize = 128;
    const BLOCK_COLS: usize = 128;
    const BLOCK_DEPTH: usize = 256;

    let (a_arr, b_arr) = *inputs;
    let m = a_arr.shape[0];
    let k1 = a_arr.shape[1];
    let n = b_arr.shape[1];

    let mut out = vec![0.0f32; m * n];

    // Nothing to compute for an empty output, and `n == 0` would make the
    // chunk size below zero, which `chunks_mut` does not accept.
    if out.is_empty() {
        return Array::new(vec![m, n], out);
    }

    // Each chunk holds BLOCK_ROWS complete output rows (fewer in the last one).
    out.chunks_mut(n * BLOCK_ROWS)
        .enumerate()
        .for_each(|(block_idx, out_block)| {
            let i0 = block_idx * BLOCK_ROWS;
            let rows_in_block = (i0 + BLOCK_ROWS).min(m) - i0;

            for k0 in (0..k1).step_by(BLOCK_DEPTH) {
                let k_end = (k0 + BLOCK_DEPTH).min(k1);

                for j0 in (0..n).step_by(BLOCK_COLS) {
                    let j_end = (j0 + BLOCK_COLS).min(n);

                    for ii in 0..rows_in_block {
                        let a_row_off = (i0 + ii) * k1;
                        let out_row_off = ii * n;

                        // Main loop: four adjacent output columns per pass so
                        // each loaded `a` value is reused four times.
                        let mut j = j0;
                        while j + 4 <= j_end {
                            let o = out_row_off + j;
                            // Resume from the partial sums left by earlier
                            // reduction tiles.
                            let mut sum0 = out_block[o];
                            let mut sum1 = out_block[o + 1];
                            let mut sum2 = out_block[o + 2];
                            let mut sum3 = out_block[o + 3];

                            for kk in k0..k_end {
                                let a_val = a_arr.data[a_row_off + kk];
                                let b_off = kk * n + j;
                                sum0 += a_val * b_arr.data[b_off];
                                sum1 += a_val * b_arr.data[b_off + 1];
                                sum2 += a_val * b_arr.data[b_off + 2];
                                sum3 += a_val * b_arr.data[b_off + 3];
                            }

                            out_block[o] = sum0;
                            out_block[o + 1] = sum1;
                            out_block[o + 2] = sum2;
                            out_block[o + 3] = sum3;
                            j += 4;
                        }

                        // Tail: leftover columns when the tile width is not a
                        // multiple of four.
                        while j < j_end {
                            let mut sum = out_block[out_row_off + j];
                            for kk in k0..k_end {
                                sum += a_arr.data[a_row_off + kk] * b_arr.data[kk * n + j];
                            }
                            out_block[out_row_off + j] = sum;
                            j += 1;
                        }
                    }
                }
            }
        });

    Array::new(vec![m, n], out)
}
171
172pub fn matmul_scalar_parallel(a: &Array, b: &Array) -> Array {
179 eprintln!(
180 "[SCALAR_IMPL] matmul_scalar_parallel called for {}x{}",
181 a.shape[0], a.shape[1]
182 );
183 matmul_scalar_impl(&(a, b))
184}
185
186pub fn matmul_simd_direct(a: &Array, b: &Array) -> Array {
188 simd::matmul_simd(a, b)
189}
190
191pub fn matmul_scalar_direct(a: &Array, b: &Array) -> Array {
193 matmul_scalar_impl(&(a, b))
194}
195
/// CPU backend handle.
///
/// The type currently declares no fields.
#[derive(Clone, Debug)]
pub struct CpuBackend {}
201
202impl CpuBackend {
203 pub fn new() -> Self {
204 Self {}
205 }
206
207 pub fn matmul_fallback(
210 a: &crate::array::Array,
211 b: &crate::array::Array,
212 ) -> crate::array::Array {
213 matmul_scalar_impl(&(a, b))
214 }
215
216 }