Skip to main content

trueno/tiling/
calculator.rs

//! TCB index calculation utilities.
2
3use super::config::TilingConfig;
4
5/// Index calculator for hierarchical tiling
6///
7/// Converts between linear indices and (row, col) coordinates at each tiling level.
8#[derive(Debug, Clone)]
9pub struct TcbIndexCalculator {
10    /// Tiling configuration
11    pub config: TilingConfig,
12    /// Problem dimensions
13    problem_m: u32,
14    problem_n: u32,
15    problem_k: u32,
16}
17
18impl TcbIndexCalculator {
19    /// Create a new index calculator for the given problem size
20    #[must_use]
21    pub fn new(config: TilingConfig, m: u32, n: u32, k: u32) -> Self {
22        Self { config, problem_m: m, problem_n: n, problem_k: k }
23    }
24
25    /// Get macro-tile offset for a given block index
26    ///
27    /// Returns (row_offset, col_offset) in the output matrix.
28    #[must_use]
29    pub fn macro_tile_offset(&self, block_idx: u32) -> (u32, u32) {
30        let tiles_per_row =
31            (self.problem_n + self.config.macro_tile.n - 1) / self.config.macro_tile.n;
32        let row = (block_idx / tiles_per_row) * self.config.macro_tile.m;
33        let col = (block_idx % tiles_per_row) * self.config.macro_tile.n;
34        (row, col)
35    }
36
37    /// Get midi-tile offset within a macro-tile
38    #[must_use]
39    pub fn midi_tile_offset(&self, midi_idx: u32) -> (u32, u32) {
40        let tiles_per_row = self.config.macro_tile.n / self.config.midi_tile.n;
41        let row = (midi_idx / tiles_per_row) * self.config.midi_tile.m;
42        let col = (midi_idx % tiles_per_row) * self.config.midi_tile.n;
43        (row, col)
44    }
45
46    /// Get micro-tile offset within a midi-tile
47    #[must_use]
48    pub fn micro_tile_offset(&self, micro_idx: u32) -> (u32, u32) {
49        let tiles_per_row = self.config.midi_tile.n / self.config.micro_tile.n;
50        let row = (micro_idx / tiles_per_row) * self.config.micro_tile.m;
51        let col = (micro_idx % tiles_per_row) * self.config.micro_tile.n;
52        (row, col)
53    }
54
55    /// Convert block index to linear memory offset
56    ///
57    /// For row-major C matrix with given stride.
58    #[must_use]
59    #[inline]
60    pub fn block_to_linear_offset(&self, block_idx: u32, stride: u32) -> usize {
61        let (row, col) = self.macro_tile_offset(block_idx);
62        (row * stride + col) as usize
63    }
64
65    /// Calculate A matrix offset for K-dimension blocking
66    #[must_use]
67    #[inline]
68    pub fn a_offset(&self, macro_row: u32, k_block: u32) -> usize {
69        let row = macro_row * self.config.macro_tile.m;
70        let col = k_block * self.config.macro_tile.k;
71        (row * self.problem_k + col) as usize
72    }
73
74    /// Calculate B matrix offset for K-dimension blocking
75    #[must_use]
76    #[inline]
77    pub fn b_offset(&self, k_block: u32, macro_col: u32) -> usize {
78        let row = k_block * self.config.macro_tile.k;
79        let col = macro_col * self.config.macro_tile.n;
80        (row * self.problem_n + col) as usize
81    }
82
83    /// Get number of K blocks needed
84    #[must_use]
85    pub fn num_k_blocks(&self) -> u32 {
86        (self.problem_k + self.config.macro_tile.k - 1) / self.config.macro_tile.k
87    }
88
89    /// Check if this is a boundary tile (may need masking)
90    #[must_use]
91    pub fn is_boundary_tile(&self, block_idx: u32) -> bool {
92        let (row, col) = self.macro_tile_offset(block_idx);
93        row + self.config.macro_tile.m > self.problem_m
94            || col + self.config.macro_tile.n > self.problem_n
95    }
96
97    /// Get actual tile dimensions (may be smaller at boundaries)
98    #[must_use]
99    pub fn actual_tile_dims(&self, block_idx: u32) -> (u32, u32) {
100        let (row, col) = self.macro_tile_offset(block_idx);
101        let actual_m = (self.problem_m - row).min(self.config.macro_tile.m);
102        let actual_n = (self.problem_n - col).min(self.config.macro_tile.n);
103        (actual_m, actual_n)
104    }
105}