// cjc_runtime/tensor_tiled.rs
//! Tensor Tiling — L2-friendly tiled matrix multiplication.
//!
//! Provides a tiled matmul implementation that operates on tiles that fit
//! within the L2 cache, improving locality for large matrices.
//!
//! # Determinism
//!
//! - Tile iteration order is deterministic (row-major over tiles).
//! - The summation within each tile uses the same accumulation order.
//! - Same inputs → bit-identical outputs on the same platform.
//!
//! # Tile Size
//!
//! Default tile size is 64×64 (32 KB per tile at f64, fits in most L2 caches).
//! Configurable via `TiledMatmul::with_tile_size()`.
use crate::tensor_simd;
/// Default tile dimension. 64×64 doubles = 32 KB per tile
/// (64 * 64 * 8 bytes), small enough for typical L2 caches.
const DEFAULT_TILE_SIZE: usize = 64;
/// Tiled matrix multiplication engine.
pub struct TiledMatmul {
    /// Tile dimension (square tiles). Public, so callers can tune it
    /// directly; note that `with_tile_size` rejects 0 by substituting the
    /// default, but direct field writes are not validated.
    pub tile_size: usize,
}
27
28impl TiledMatmul {
29    /// Create with default tile size (64).
30    pub fn new() -> Self {
31        TiledMatmul {
32            tile_size: DEFAULT_TILE_SIZE,
33        }
34    }
35
36    /// Create with a custom tile size.
37    pub fn with_tile_size(tile_size: usize) -> Self {
38        let ts = if tile_size == 0 { DEFAULT_TILE_SIZE } else { tile_size };
39        TiledMatmul { tile_size: ts }
40    }
41
42    /// Compute C = A × B using tiled iteration.
43    ///
44    /// - `a`: row-major matrix [m × k]
45    /// - `b`: row-major matrix [k × n]
46    /// - Returns: row-major matrix [m × n]
47    ///
48    /// Panics if inner dimensions don't match.
49    pub fn matmul(
50        &self,
51        a: &[f64],
52        m: usize,
53        k: usize,
54        b: &[f64],
55        n: usize,
56    ) -> Vec<f64> {
57        assert_eq!(a.len(), m * k, "a dimensions mismatch");
58        assert_eq!(b.len(), k * n, "b dimensions mismatch");
59
60        let mut c = vec![0.0f64; m * n];
61        let ts = self.tile_size;
62
63        // Tile over all three dimensions: i, j, p (deterministic order).
64        let mut ii = 0;
65        while ii < m {
66            let i_end = (ii + ts).min(m);
67            let mut jj = 0;
68            while jj < n {
69                let j_end = (jj + ts).min(n);
70                let mut pp = 0;
71                while pp < k {
72                    let p_end = (pp + ts).min(k);
73
74                    // Micro-kernel: accumulate tile contribution.
75                    // Uses SIMD-accelerated AXPY for the inner j-loop
76                    // (4-wide AVX2 when available, scalar fallback otherwise).
77                    let j_len = j_end - jj;
78                    for i in ii..i_end {
79                        for p in pp..p_end {
80                            let a_ip = a[i * k + p];
81                            let c_slice = &mut c[i * n + jj .. i * n + j_end];
82                            let b_slice = &b[p * n + jj .. p * n + j_end];
83                            tensor_simd::simd_axpy(c_slice, b_slice, a_ip, j_len);
84                        }
85                    }
86
87                    pp += ts;
88                }
89                jj += ts;
90            }
91            ii += ts;
92        }
93
94        c
95    }
96
97    /// Compute C = A × B^T using tiled iteration (useful when B is stored
98    /// in row-major but you need A × B^T).
99    ///
100    /// - `a`: row-major matrix [m × k]
101    /// - `b`: row-major matrix [n × k] (transposed: each row of b is a column of B)
102    /// - Returns: row-major matrix [m × n]
103    pub fn matmul_transposed_b(
104        &self,
105        a: &[f64],
106        m: usize,
107        k: usize,
108        b: &[f64],
109        n: usize,
110    ) -> Vec<f64> {
111        assert_eq!(a.len(), m * k, "a dimensions mismatch");
112        assert_eq!(b.len(), n * k, "b dimensions mismatch (n × k expected)");
113
114        let mut c = vec![0.0f64; m * n];
115        let ts = self.tile_size;
116
117        let mut ii = 0;
118        while ii < m {
119            let i_end = (ii + ts).min(m);
120            let mut jj = 0;
121            while jj < n {
122                let j_end = (jj + ts).min(n);
123
124                for i in ii..i_end {
125                    for j in jj..j_end {
126                        let mut sum = 0.0f64;
127                        for p in 0..k {
128                            sum += a[i * k + p] * b[j * k + p];
129                        }
130                        c[i * n + j] = sum;
131                    }
132                }
133
134                jj += ts;
135            }
136            ii += ts;
137        }
138
139        c
140    }
141}
142
// `Default` delegates to `new()`: tile size 64.
impl Default for TiledMatmul {
    fn default() -> Self {
        Self::new()
    }
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference O(m·k·n) matmul used to cross-check the tiled engine.
    fn naive_matmul(a: &[f64], m: usize, k: usize, b: &[f64], n: usize) -> Vec<f64> {
        let mut out = vec![0.0f64; m * n];
        for row in 0..m {
            for col in 0..n {
                let mut acc = 0.0;
                for p in 0..k {
                    acc += a[row * k + p] * b[p * n + col];
                }
                out[row * n + col] = acc;
            }
        }
        out
    }

    #[test]
    fn test_tiled_matmul_2x2() {
        // [1 2] × [5 6] = [19 22]
        // [3 4]   [7 8]   [43 50]
        let lhs = [1.0, 2.0, 3.0, 4.0];
        let rhs = [5.0, 6.0, 7.0, 8.0];
        let product = TiledMatmul::new().matmul(&lhs, 2, 2, &rhs, 2);
        assert_eq!(product, vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_tiled_matmul_nonsquare() {
        // [2 3] × [1 0] = [2+12  0+15] = [14 15]
        //         [4 5]
        let lhs = [2.0, 3.0];
        let rhs = [1.0, 0.0, 4.0, 5.0];
        let product = TiledMatmul::new().matmul(&lhs, 1, 2, &rhs, 2);
        assert_eq!(product, vec![14.0, 15.0]);
    }

    #[test]
    fn test_tiled_matmul_identity() {
        // A × I₃ must reproduce A exactly (adding 0.0 is exact in f64).
        let mat: Vec<f64> = (1..=9).map(f64::from).collect();
        let eye: Vec<f64> = (0..9)
            .map(|i| if i % 4 == 0 { 1.0 } else { 0.0 })
            .collect();
        let product = TiledMatmul::new().matmul(&mat, 3, 3, &eye, 3);
        assert_eq!(product, mat);
    }

    #[test]
    fn test_tiled_with_small_tile() {
        // tile_size=2 forces real tiling on a 4×4 input.
        let mat: Vec<f64> = (1..=16).map(f64::from).collect();
        let eye: Vec<f64> = (0..16)
            .map(|i| if i % 5 == 0 { 1.0 } else { 0.0 })
            .collect();
        let product = TiledMatmul::with_tile_size(2).matmul(&mat, 4, 4, &eye, 4);
        assert_eq!(product, mat, "A × I = A with tiling");
    }

    #[test]
    fn test_tiled_deterministic() {
        // Two independent runs over the same 5×5 inputs must agree bitwise.
        let lhs: Vec<f64> = (0..25).map(|i| i as f64 * 0.1).collect();
        let rhs: Vec<f64> = (0..25).map(|i| (25 - i) as f64 * 0.1).collect();

        let first = TiledMatmul::with_tile_size(3).matmul(&lhs, 5, 5, &rhs, 5);
        let second = TiledMatmul::with_tile_size(3).matmul(&lhs, 5, 5, &rhs, 5);

        assert_eq!(first, second, "deterministic tiled matmul");
    }

    #[test]
    fn test_tiled_matches_naive() {
        // 2×3 × 3×2 with tile_size=2 exercises partial tiles on every axis.
        let lhs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let rhs = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0];

        let tiled = TiledMatmul::with_tile_size(2).matmul(&lhs, 2, 3, &rhs, 2);
        let expected = naive_matmul(&lhs, 2, 3, &rhs, 2);

        for (i, (t, e)) in tiled.iter().zip(expected.iter()).enumerate() {
            assert!(
                (t - e).abs() < 1e-12,
                "mismatch at index {i}: tiled={t}, naive={e}"
            );
        }
    }

    #[test]
    fn test_transposed_b_matmul() {
        // A = [1 2; 3 4]; B = [5 6; 7 8] stored transposed as [5 7; 6 8].
        // A × B = [1*5+2*7  1*6+2*8] = [19 22]
        //         [3*5+4*7  3*6+4*8]   [43 50]
        let lhs = [1.0, 2.0, 3.0, 4.0];
        let rhs_t = [5.0, 7.0, 6.0, 8.0]; // B transposed, stored [n × k]
        let product = TiledMatmul::new().matmul_transposed_b(&lhs, 2, 2, &rhs_t, 2);
        assert_eq!(product, vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_large_tiled_correctness() {
        // 32×32 matrix multiplication with tile_size=8: many full tiles.
        let n = 32;
        let lhs: Vec<f64> = (0..n * n).map(|i| (i as f64) * 0.01).collect();
        let rhs: Vec<f64> = (0..n * n).map(|i| ((n * n - i) as f64) * 0.01).collect();

        let tiled = TiledMatmul::with_tile_size(8).matmul(&lhs, n, n, &rhs, n);
        let reference = naive_matmul(&lhs, n, n, &rhs, n);

        for (i, (t, e)) in tiled.iter().zip(reference.iter()).enumerate() {
            assert!(
                (t - e).abs() < 1e-8,
                "mismatch at [{}, {}]: tiled={t}, naive={e}",
                i / n,
                i % n
            );
        }
    }
}