1use crate::tensor_simd;
18
/// Tile edge length (64 elements) used by `TiledMatmul::new` and as the
/// fallback when a tile size of zero is requested.
const DEFAULT_TILE_SIZE: usize = 1 << 6;
21
/// Blocked (tiled) dense matrix-multiplication engine for row-major `f64` matrices.
///
/// The only configuration is the edge length of the square tiles the kernels
/// step through.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TiledMatmul {
    /// Edge length of the square tiles used by the kernels.
    pub tile_size: usize,
}
27
28impl TiledMatmul {
29 pub fn new() -> Self {
31 TiledMatmul {
32 tile_size: DEFAULT_TILE_SIZE,
33 }
34 }
35
36 pub fn with_tile_size(tile_size: usize) -> Self {
38 let ts = if tile_size == 0 { DEFAULT_TILE_SIZE } else { tile_size };
39 TiledMatmul { tile_size: ts }
40 }
41
42 pub fn matmul(
50 &self,
51 a: &[f64],
52 m: usize,
53 k: usize,
54 b: &[f64],
55 n: usize,
56 ) -> Vec<f64> {
57 assert_eq!(a.len(), m * k, "a dimensions mismatch");
58 assert_eq!(b.len(), k * n, "b dimensions mismatch");
59
60 let mut c = vec![0.0f64; m * n];
61 let ts = self.tile_size;
62
63 let mut ii = 0;
65 while ii < m {
66 let i_end = (ii + ts).min(m);
67 let mut jj = 0;
68 while jj < n {
69 let j_end = (jj + ts).min(n);
70 let mut pp = 0;
71 while pp < k {
72 let p_end = (pp + ts).min(k);
73
74 let j_len = j_end - jj;
78 for i in ii..i_end {
79 for p in pp..p_end {
80 let a_ip = a[i * k + p];
81 let c_slice = &mut c[i * n + jj .. i * n + j_end];
82 let b_slice = &b[p * n + jj .. p * n + j_end];
83 tensor_simd::simd_axpy(c_slice, b_slice, a_ip, j_len);
84 }
85 }
86
87 pp += ts;
88 }
89 jj += ts;
90 }
91 ii += ts;
92 }
93
94 c
95 }
96
97 pub fn matmul_transposed_b(
104 &self,
105 a: &[f64],
106 m: usize,
107 k: usize,
108 b: &[f64],
109 n: usize,
110 ) -> Vec<f64> {
111 assert_eq!(a.len(), m * k, "a dimensions mismatch");
112 assert_eq!(b.len(), n * k, "b dimensions mismatch (n × k expected)");
113
114 let mut c = vec![0.0f64; m * n];
115 let ts = self.tile_size;
116
117 let mut ii = 0;
118 while ii < m {
119 let i_end = (ii + ts).min(m);
120 let mut jj = 0;
121 while jj < n {
122 let j_end = (jj + ts).min(n);
123
124 for i in ii..i_end {
125 for j in jj..j_end {
126 let mut sum = 0.0f64;
127 for p in 0..k {
128 sum += a[i * k + p] * b[j * k + p];
129 }
130 c[i * n + j] = sum;
131 }
132 }
133
134 jj += ts;
135 }
136 ii += ts;
137 }
138
139 c
140 }
141}
142
143impl Default for TiledMatmul {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tiled_matmul_2x2() {
        let engine = TiledMatmul::new();
        // A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]]
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let c = engine.matmul(&a, 2, 2, &b, 2);
        assert_eq!(c, vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_tiled_matmul_nonsquare() {
        let engine = TiledMatmul::new();
        // (1 × 2) × (2 × 2)
        let a = vec![2.0, 3.0];
        let b = vec![1.0, 0.0, 4.0, 5.0];
        let c = engine.matmul(&a, 1, 2, &b, 2);
        assert_eq!(c, vec![14.0, 15.0]);
    }

    #[test]
    fn test_tiled_matmul_identity() {
        let engine = TiledMatmul::new();
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let eye = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let c = engine.matmul(&a, 3, 3, &eye, 3);
        assert_eq!(c, a);
    }

    #[test]
    fn test_tiled_with_small_tile() {
        let engine = TiledMatmul::with_tile_size(2);
        let a = vec![
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        let b = vec![
            1.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 0.0, 1.0,
        ];
        let c = engine.matmul(&a, 4, 4, &b, 4);
        assert_eq!(c, a, "A × I = A with tiling");
    }

    #[test]
    fn test_tiled_deterministic() {
        let e1 = TiledMatmul::with_tile_size(3);
        let e2 = TiledMatmul::with_tile_size(3);

        let a: Vec<f64> = (0..25).map(|i| i as f64 * 0.1).collect();
        let b: Vec<f64> = (0..25).map(|i| (25 - i) as f64 * 0.1).collect();

        let c1 = e1.matmul(&a, 5, 5, &b, 5);
        let c2 = e2.matmul(&a, 5, 5, &b, 5);

        assert_eq!(c1, c2, "deterministic tiled matmul");
    }

    #[test]
    fn test_tiled_matches_naive() {
        let engine = TiledMatmul::with_tile_size(2);
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];

        let tiled = engine.matmul(&a, 2, 3, &b, 2);

        let expected = naive_matmul(&a, 2, 3, &b, 2);

        for (i, (t, e)) in tiled.iter().zip(expected.iter()).enumerate() {
            assert!(
                (t - e).abs() < 1e-12,
                "mismatch at index {i}: tiled={t}, naive={e}"
            );
        }
    }

    #[test]
    fn test_tiled_ragged_tiles_match_naive() {
        // Tile size 3 does not divide m = 5, k = 7 or n = 4, so every loop
        // level ends with a partial tile.
        let engine = TiledMatmul::with_tile_size(3);
        let (m, k, n) = (5, 7, 4);
        let a: Vec<f64> = (0..m * k).map(|i| i as f64 * 0.25 - 2.0).collect();
        let b: Vec<f64> = (0..k * n).map(|i| ((i * 3) % 11) as f64 - 5.0).collect();

        let tiled = engine.matmul(&a, m, k, &b, n);
        let naive = naive_matmul(&a, m, k, &b, n);

        assert_eq!(tiled.len(), naive.len());
        for (i, (t, e)) in tiled.iter().zip(naive.iter()).enumerate() {
            assert!(
                (t - e).abs() < 1e-12,
                "mismatch at [{}, {}]: tiled={t}, naive={e}",
                i / n,
                i % n
            );
        }
    }

    #[test]
    fn test_transposed_b_matmul() {
        let engine = TiledMatmul::new();
        let a = vec![1.0, 2.0, 3.0, 4.0];
        // Transpose of B = [[5, 6], [7, 8]].
        let bt = vec![5.0, 7.0, 6.0, 8.0];
        let c = engine.matmul_transposed_b(&a, 2, 2, &bt, 2);
        assert_eq!(c, vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_transposed_b_matches_naive_with_small_tiles() {
        let engine = TiledMatmul::with_tile_size(2);
        let (m, k, n) = (3, 4, 5);
        let a: Vec<f64> = (0..m * k).map(|i| i as f64 * 0.5 - 1.0).collect();
        let b: Vec<f64> = (0..k * n).map(|i| (i % 7) as f64 - 3.0).collect();

        // bt holds the n × k transpose of the k × n matrix b.
        let mut bt = vec![0.0; n * k];
        for p in 0..k {
            for j in 0..n {
                bt[j * k + p] = b[p * n + j];
            }
        }

        let got = engine.matmul_transposed_b(&a, m, k, &bt, n);
        let expected = naive_matmul(&a, m, k, &b, n);

        assert_eq!(got.len(), expected.len());
        for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
            assert!(
                (g - e).abs() < 1e-12,
                "mismatch at [{}, {}]: transposed={g}, naive={e}",
                i / n,
                i % n
            );
        }
    }

    #[test]
    fn test_empty_dimensions() {
        let engine = TiledMatmul::with_tile_size(2);
        // m = 0: empty result.
        assert!(engine.matmul(&[], 0, 3, &[0.0; 6], 2).is_empty());
        // k = 0: empty reduction yields an all-zero m × n result.
        assert_eq!(engine.matmul(&[], 2, 0, &[], 3), vec![0.0; 6]);
        // n = 0: empty result.
        assert!(engine.matmul_transposed_b(&[0.0; 6], 2, 3, &[], 0).is_empty());
        assert_eq!(engine.matmul_transposed_b(&[], 2, 0, &[], 3), vec![0.0; 6]);
    }

    #[test]
    fn test_with_tile_size_zero_uses_default() {
        assert_eq!(TiledMatmul::with_tile_size(0).tile_size, DEFAULT_TILE_SIZE);
        assert_eq!(TiledMatmul::default().tile_size, DEFAULT_TILE_SIZE);
    }

    #[test]
    fn test_large_tiled_correctness() {
        let engine = TiledMatmul::with_tile_size(8);
        let n = 32;
        let a: Vec<f64> = (0..n * n).map(|i| (i as f64) * 0.01).collect();
        let b: Vec<f64> = (0..n * n).map(|i| ((n * n - i) as f64) * 0.01).collect();

        let tiled = engine.matmul(&a, n, n, &b, n);
        let naive = naive_matmul(&a, n, n, &b, n);

        for (i, (t, e)) in tiled.iter().zip(naive.iter()).enumerate() {
            assert!(
                (t - e).abs() < 1e-8,
                "mismatch at [{}, {}]: tiled={t}, naive={e}",
                i / n,
                i % n
            );
        }
    }

    /// Reference triple-loop product of row-major `a` (m × k) and `b` (k × n).
    fn naive_matmul(a: &[f64], m: usize, k: usize, b: &[f64], n: usize) -> Vec<f64> {
        let mut c = vec![0.0f64; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0;
                for p in 0..k {
                    sum += a[i * k + p] * b[p * n + j];
                }
                c[i * n + j] = sum;
            }
        }
        c
    }
}