trueno/tiling/geometry.rs
1//! TCB geometry types and level definitions.
2
3use serde::{Deserialize, Serialize};
4use std::fmt;
5
6// ============================================================================
7// TILE-001: TcbGeometry Struct
8// ============================================================================
9
/// Dimensions for a Tiling Compute Block (TCB)
///
/// Represents the (M, N, K) dimensions of a tile in matrix operations:
/// - M: Output rows
/// - N: Output columns
/// - K: Reduction dimension (inner product)
///
/// # Alignment Constraints
///
/// Per the TCB-03 pattern (Tile Quantization Alignment), K must align with
/// the quantization superblock size:
/// - Q4_0: K % 32 == 0
/// - Q4_K: K % 256 == 0
/// - Q8_0: K % 32 == 0
///
/// # Validation
///
/// The type itself enforces nothing: every field is public, so a value built
/// with a struct literal (or produced by the derived `Deserialize` impl) holds
/// whatever numbers it was given, including zero dimensions or a K that
/// violates the constraints above.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct TcbGeometry {
    /// Items processed in M dimension (rows)
    pub m: u32,
    /// Items processed in N dimension (columns)
    pub n: u32,
    /// Reduction dimension (inner product)
    pub k: u32,
    /// Alignment requirement in bytes (typically 16 for SSE/NEON, 32 for AVX2, 64 for AVX-512)
    pub alignment: u32,
}
35
36impl TcbGeometry {
37 /// Create a new TCB geometry
38 ///
39 /// # Panics
40 /// Panics if any dimension is zero.
41 #[must_use]
42 pub fn new(m: u32, n: u32, k: u32) -> Self {
43 assert!(m > 0 && n > 0 && k > 0, "TCB dimensions must be non-zero");
44 Self {
45 m,
46 n,
47 k,
48 alignment: 16, // Default to SSE/NEON alignment
49 }
50 }
51
52 /// Create geometry with explicit alignment
53 #[must_use]
54 pub fn with_alignment(m: u32, n: u32, k: u32, alignment: u32) -> Self {
55 assert!(m > 0 && n > 0 && k > 0, "TCB dimensions must be non-zero");
56 assert!(alignment.is_power_of_two(), "Alignment must be power of 2");
57 Self { m, n, k, alignment }
58 }
59
60 /// Calculate arithmetic intensity (FLOPS per byte loaded)
61 ///
62 /// For GEMM: AI = (2 * M * N * K) / (M*K + K*N) * sizeof(f32)
63 ///
64 /// Higher AI means compute-bound; lower means memory-bound.
65 #[must_use]
66 pub fn arithmetic_intensity(&self) -> f32 {
67 let flops = 2.0 * self.m as f64 * self.n as f64 * self.k as f64;
68 let bytes = (self.m as f64 * self.k as f64 + self.k as f64 * self.n as f64) * 4.0;
69 (flops / bytes) as f32
70 }
71
72 /// Calculate total elements in the tile
73 #[must_use]
74 pub fn total_elements(&self) -> u64 {
75 self.m as u64 * self.n as u64
76 }
77
78 /// Calculate total FLOPs for this tile
79 #[must_use]
80 pub fn total_flops(&self) -> u64 {
81 2 * self.m as u64 * self.n as u64 * self.k as u64
82 }
83
84 /// Check if K dimension aligns with Q4_K superblock (256)
85 #[must_use]
86 pub fn is_q4k_aligned(&self) -> bool {
87 self.k % 256 == 0
88 }
89
90 /// Check if K dimension aligns with Q4_0/Q8_0 block (32)
91 #[must_use]
92 pub fn is_q4_0_aligned(&self) -> bool {
93 self.k % 32 == 0
94 }
95
96 /// Calculate bytes needed for A tile (M × K × sizeof(f32))
97 #[must_use]
98 pub fn a_tile_bytes(&self) -> usize {
99 self.m as usize * self.k as usize * 4
100 }
101
102 /// Calculate bytes needed for B tile (K × N × sizeof(f32))
103 #[must_use]
104 pub fn b_tile_bytes(&self) -> usize {
105 self.k as usize * self.n as usize * 4
106 }
107
108 /// Calculate bytes needed for C tile (M × N × sizeof(f32))
109 #[must_use]
110 pub fn c_tile_bytes(&self) -> usize {
111 self.m as usize * self.n as usize * 4
112 }
113
114 /// Check if tile fits in given cache size (bytes)
115 #[must_use]
116 pub fn fits_in_cache(&self, cache_bytes: usize) -> bool {
117 self.a_tile_bytes() + self.b_tile_bytes() <= cache_bytes
118 }
119}
120
121impl Default for TcbGeometry {
122 fn default() -> Self {
123 // Sensible default: 4×4 micro-tile for SIMD
124 Self { m: 4, n: 4, k: 4, alignment: 16 }
125 }
126}
127
128impl fmt::Display for TcbGeometry {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 write!(
131 f,
132 "TCB({}×{}×{}, align={}, AI={:.2})",
133 self.m,
134 self.n,
135 self.k,
136 self.alignment,
137 self.arithmetic_intensity()
138 )
139 }
140}
141// ============================================================================
142// TILE-001: Tiling Levels
143// ============================================================================
144
/// Tiling hierarchy level
///
/// Variants are listed from the outermost, largest level of the hierarchy
/// down to the innermost one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TcbLevel {
    /// Macro-tile: L3 cache / GPU global memory partitioning
    Macro,
    /// Midi-tile: L2 cache / GPU shared memory
    Midi,
    /// Micro-tile: Registers / SIMD lanes
    Micro,
}
155
156impl TcbLevel {
157 /// Get typical cache size for this level (x86_64)
158 #[must_use]
159 pub fn typical_cache_bytes(&self) -> usize {
160 match self {
161 TcbLevel::Macro => 32 * 1024 * 1024, // 32 MB L3
162 TcbLevel::Midi => 256 * 1024, // 256 KB L2
163 TcbLevel::Micro => 32 * 1024, // 32 KB L1
164 }
165 }
166}