Skip to main content

trueno/tiling/
geometry.rs

//! TCB geometry types and level definitions.
2
3use serde::{Deserialize, Serialize};
4use std::fmt;
5
// ============================================================================
// TILE-001: TcbGeometry Struct
// ============================================================================
9
10/// Dimensions for a Tiling Compute Block
11///
12/// Represents the (M, N, K) dimensions of a tile in matrix operations:
13/// - M: Output rows
14/// - N: Output columns
15/// - K: Reduction dimension (inner product)
16///
17/// # Alignment Constraints
18///
19/// Per the TCB-03 pattern (Tile Quantization Alignment), K must align with
20/// the quantization superblock size:
21/// - Q4_0: K % 32 == 0
22/// - Q4_K: K % 256 == 0
23/// - Q8_0: K % 32 == 0
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
25pub struct TcbGeometry {
26    /// Items processed in M dimension (rows)
27    pub m: u32,
28    /// Items processed in N dimension (columns)
29    pub n: u32,
30    /// Reduction dimension (inner product)
31    pub k: u32,
32    /// Alignment requirement in bytes (typically 16 for SIMD, 32 for AVX2, 64 for AVX-512)
33    pub alignment: u32,
34}
35
36impl TcbGeometry {
37    /// Create a new TCB geometry
38    ///
39    /// # Panics
40    /// Panics if any dimension is zero.
41    #[must_use]
42    pub fn new(m: u32, n: u32, k: u32) -> Self {
43        assert!(m > 0 && n > 0 && k > 0, "TCB dimensions must be non-zero");
44        Self {
45            m,
46            n,
47            k,
48            alignment: 16, // Default to SSE/NEON alignment
49        }
50    }
51
52    /// Create geometry with explicit alignment
53    #[must_use]
54    pub fn with_alignment(m: u32, n: u32, k: u32, alignment: u32) -> Self {
55        assert!(m > 0 && n > 0 && k > 0, "TCB dimensions must be non-zero");
56        assert!(alignment.is_power_of_two(), "Alignment must be power of 2");
57        Self { m, n, k, alignment }
58    }
59
60    /// Calculate arithmetic intensity (FLOPS per byte loaded)
61    ///
62    /// For GEMM: AI = (2 * M * N * K) / (M*K + K*N) * sizeof(f32)
63    ///
64    /// Higher AI means compute-bound; lower means memory-bound.
65    #[must_use]
66    pub fn arithmetic_intensity(&self) -> f32 {
67        let flops = 2.0 * self.m as f64 * self.n as f64 * self.k as f64;
68        let bytes = (self.m as f64 * self.k as f64 + self.k as f64 * self.n as f64) * 4.0;
69        (flops / bytes) as f32
70    }
71
72    /// Calculate total elements in the tile
73    #[must_use]
74    pub fn total_elements(&self) -> u64 {
75        self.m as u64 * self.n as u64
76    }
77
78    /// Calculate total FLOPs for this tile
79    #[must_use]
80    pub fn total_flops(&self) -> u64 {
81        2 * self.m as u64 * self.n as u64 * self.k as u64
82    }
83
84    /// Check if K dimension aligns with Q4_K superblock (256)
85    #[must_use]
86    pub fn is_q4k_aligned(&self) -> bool {
87        self.k % 256 == 0
88    }
89
90    /// Check if K dimension aligns with Q4_0/Q8_0 block (32)
91    #[must_use]
92    pub fn is_q4_0_aligned(&self) -> bool {
93        self.k % 32 == 0
94    }
95
96    /// Calculate bytes needed for A tile (M × K × sizeof(f32))
97    #[must_use]
98    pub fn a_tile_bytes(&self) -> usize {
99        self.m as usize * self.k as usize * 4
100    }
101
102    /// Calculate bytes needed for B tile (K × N × sizeof(f32))
103    #[must_use]
104    pub fn b_tile_bytes(&self) -> usize {
105        self.k as usize * self.n as usize * 4
106    }
107
108    /// Calculate bytes needed for C tile (M × N × sizeof(f32))
109    #[must_use]
110    pub fn c_tile_bytes(&self) -> usize {
111        self.m as usize * self.n as usize * 4
112    }
113
114    /// Check if tile fits in given cache size (bytes)
115    #[must_use]
116    pub fn fits_in_cache(&self, cache_bytes: usize) -> bool {
117        self.a_tile_bytes() + self.b_tile_bytes() <= cache_bytes
118    }
119}
120
121impl Default for TcbGeometry {
122    fn default() -> Self {
123        // Sensible default: 4×4 micro-tile for SIMD
124        Self { m: 4, n: 4, k: 4, alignment: 16 }
125    }
126}
127
128impl fmt::Display for TcbGeometry {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(
131            f,
132            "TCB({}×{}×{}, align={}, AI={:.2})",
133            self.m,
134            self.n,
135            self.k,
136            self.alignment,
137            self.arithmetic_intensity()
138        )
139    }
140}
// ============================================================================
// TILE-001: Tiling Levels
// ============================================================================
144
145/// Tiling hierarchy level
146#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
147pub enum TcbLevel {
148    /// Macro-tile: L3 cache / GPU global memory partitioning
149    Macro,
150    /// Midi-tile: L2 cache / GPU shared memory
151    Midi,
152    /// Micro-tile: Registers / SIMD lanes
153    Micro,
154}
155
156impl TcbLevel {
157    /// Get typical cache size for this level (x86_64)
158    #[must_use]
159    pub fn typical_cache_bytes(&self) -> usize {
160        match self {
161            TcbLevel::Macro => 32 * 1024 * 1024, // 32 MB L3
162            TcbLevel::Midi => 256 * 1024,        // 256 KB L2
163            TcbLevel::Micro => 32 * 1024,        // 32 KB L1
164        }
165    }
166}