Skip to main content

trueno/tiling/
config.rs

1//! Tiling configuration and backend selection.
2
3use super::error::TilingError;
4use super::geometry::TcbGeometry;
5use serde::{Deserialize, Serialize};
6
7/// Complete tiling configuration for a kernel
8///
9/// Contains geometry for all three tiling levels, enabling hierarchical
10/// cache-aware execution.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TilingConfig {
13    /// Kernel name for identification
14    pub name: String,
15    /// Macro-tile geometry (L3/Global)
16    pub macro_tile: TcbGeometry,
17    /// Midi-tile geometry (L2/Shared)
18    pub midi_tile: TcbGeometry,
19    /// Micro-tile geometry (Registers)
20    pub micro_tile: TcbGeometry,
21    /// Target backend
22    pub backend: TilingBackend,
23}
24
25/// Backend target for tiling configuration
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
27pub enum TilingBackend {
28    /// CPU with AVX2 (256-bit SIMD)
29    CpuAvx2,
30    /// CPU with AVX-512 (512-bit SIMD)
31    CpuAvx512,
32    /// CPU with NEON (128-bit SIMD)
33    CpuNeon,
34    /// GPU (CUDA/wgpu)
35    Gpu,
36    /// Scalar fallback
37    Scalar,
38}
39
40impl TilingConfig {
41    /// Create configuration for GPU Q4_K MatVec
42    ///
43    /// Optimized for single-token generation where M=1.
44    #[must_use]
45    pub fn gpu_q4k_matvec() -> Self {
46        Self {
47            name: "Q4K_MatVec_GPU".into(),
48            macro_tile: TcbGeometry::with_alignment(1, 4096, 256, 64),
49            midi_tile: TcbGeometry::with_alignment(1, 256, 256, 64),
50            micro_tile: TcbGeometry::with_alignment(1, 32, 256, 64),
51            backend: TilingBackend::Gpu,
52        }
53    }
54
55    /// Create configuration for GPU Q4_K MatMul (batched)
56    ///
57    /// Optimized for prefill where M > 1.
58    #[must_use]
59    pub fn gpu_q4k_matmul() -> Self {
60        Self {
61            name: "Q4K_MatMul_GPU".into(),
62            macro_tile: TcbGeometry::with_alignment(128, 128, 256, 64),
63            midi_tile: TcbGeometry::with_alignment(32, 32, 256, 64),
64            micro_tile: TcbGeometry::with_alignment(8, 8, 256, 64),
65            backend: TilingBackend::Gpu,
66        }
67    }
68
69    /// Create configuration for GPU Softmax
70    #[must_use]
71    pub fn gpu_softmax() -> Self {
72        Self {
73            name: "Softmax_GPU".into(),
74            macro_tile: TcbGeometry::with_alignment(1, 32000, 1, 64),
75            midi_tile: TcbGeometry::with_alignment(1, 1024, 1, 64),
76            micro_tile: TcbGeometry::with_alignment(1, 32, 1, 64),
77            backend: TilingBackend::Gpu,
78        }
79    }
80
81    /// Create configuration for CPU AVX-512 MatMul
82    ///
83    /// Optimized for 512-bit wide SIMD:
84    /// - 16 floats per ZMM register
85    /// - 32 ZMM registers available
86    /// - 4×16 micro-kernel uses 8 registers (4 accumulators + 4 scratch)
87    #[must_use]
88    pub fn cpu_avx512_matmul() -> Self {
89        Self {
90            name: "MatMul_AVX512".into(),
91            macro_tile: TcbGeometry::with_alignment(512, 512, 512, 64),
92            midi_tile: TcbGeometry::with_alignment(128, 128, 128, 64),
93            // 16 floats wide × 4 rows = 64 elements in registers
94            micro_tile: TcbGeometry::with_alignment(4, 16, 128, 64),
95            backend: TilingBackend::CpuAvx512,
96        }
97    }
98
99    /// Create configuration for CPU AVX-512 Q4K MatVec
100    ///
101    /// Optimized for Q4_K quantized inference with 512-bit SIMD.
102    /// Key differences from AVX2:
103    /// - 64-byte aligned for cache line optimization
104    /// - 4×1 micro-kernel processes 4 rows simultaneously
105    /// - K=256 aligned to Q4_K superblock
106    #[must_use]
107    pub fn cpu_avx512_q4k_matvec() -> Self {
108        Self {
109            name: "Q4K_MatVec_AVX512".into(),
110            // Large macro-tile to amortize L3 access
111            macro_tile: TcbGeometry::with_alignment(4096, 1, 4096, 64),
112            // Midi-tile fits in L2 (256KB)
113            // 64 rows × 256 K × 0.5625 bytes/element ≈ 9KB weights
114            midi_tile: TcbGeometry::with_alignment(64, 1, 256, 64),
115            // 4 rows × 1 output, K=256 (Q4_K superblock)
116            micro_tile: TcbGeometry::with_alignment(4, 1, 256, 64),
117            backend: TilingBackend::CpuAvx512,
118        }
119    }
120
121    /// Create configuration for AVX-512 VNNI Q4K×Q8K integer dot product
122    ///
123    /// AVX-512 VNNI (Vector Neural Network Instructions) provides:
124    /// - VPDPBUSD: 8-bit unsigned × 8-bit signed multiply-add to i32
125    /// - VPDPWSSD: 16-bit signed × 16-bit signed multiply-add to i32
126    ///
127    /// This enables pure integer Q4K×Q8K without intermediate f32 conversion.
128    #[must_use]
129    pub fn cpu_avx512_vnni_q4k_q8k() -> Self {
130        Self {
131            name: "Q4K_Q8K_VNNI".into(),
132            macro_tile: TcbGeometry::with_alignment(4096, 1, 4096, 64),
133            midi_tile: TcbGeometry::with_alignment(64, 1, 256, 64),
134            // VNNI processes 64 i8 values per ZMM register
135            micro_tile: TcbGeometry::with_alignment(4, 1, 256, 64),
136            backend: TilingBackend::CpuAvx512,
137        }
138    }
139
140    /// Create configuration for CPU AVX2 MatMul
141    #[must_use]
142    pub fn cpu_avx2_matmul() -> Self {
143        Self {
144            name: "MatMul_AVX2".into(),
145            macro_tile: TcbGeometry::with_alignment(256, 256, 256, 32),
146            midi_tile: TcbGeometry::with_alignment(64, 64, 64, 32),
147            // 8 floats wide × 4 rows = 32 elements in registers
148            micro_tile: TcbGeometry::with_alignment(4, 8, 64, 32),
149            backend: TilingBackend::CpuAvx2,
150        }
151    }
152
153    /// Create configuration for CPU Q4_K MatVec (AVX2)
154    #[must_use]
155    pub fn cpu_avx2_q4k_matvec() -> Self {
156        Self {
157            name: "Q4K_MatVec_AVX2".into(),
158            // Process 4 rows at a time (4×1 micro-kernel)
159            macro_tile: TcbGeometry::with_alignment(4096, 1, 4096, 32),
160            midi_tile: TcbGeometry::with_alignment(64, 1, 256, 32),
161            // 4 rows × 1 output, K=256 (Q4_K superblock)
162            micro_tile: TcbGeometry::with_alignment(4, 1, 256, 32),
163            backend: TilingBackend::CpuAvx2,
164        }
165    }
166
167    /// Create configuration for RMSNorm (CPU)
168    #[must_use]
169    pub fn cpu_rmsnorm() -> Self {
170        Self {
171            name: "RMSNorm_CPU".into(),
172            macro_tile: TcbGeometry::with_alignment(1, 4096, 1, 32),
173            midi_tile: TcbGeometry::with_alignment(1, 256, 1, 32),
174            micro_tile: TcbGeometry::with_alignment(1, 16, 1, 32),
175            backend: TilingBackend::CpuAvx512,
176        }
177    }
178
179    /// Validate that tiling configuration is internally consistent
180    pub fn validate(&self) -> Result<(), TilingError> {
181        // Macro must be >= Midi >= Micro
182        if self.midi_tile.m > self.macro_tile.m
183            || self.midi_tile.n > self.macro_tile.n
184            || self.midi_tile.k > self.macro_tile.k
185        {
186            return Err(TilingError::InvalidHierarchy {
187                reason: "Midi-tile larger than macro-tile".into(),
188            });
189        }
190
191        if self.micro_tile.m > self.midi_tile.m
192            || self.micro_tile.n > self.midi_tile.n
193            || self.micro_tile.k > self.midi_tile.k
194        {
195            return Err(TilingError::InvalidHierarchy {
196                reason: "Micro-tile larger than midi-tile".into(),
197            });
198        }
199
200        // Check divisibility
201        if self.macro_tile.m % self.midi_tile.m != 0 {
202            return Err(TilingError::DivisibilityError {
203                level: "macro/midi",
204                dimension: "M",
205                larger: self.macro_tile.m,
206                smaller: self.midi_tile.m,
207            });
208        }
209
210        if self.midi_tile.m % self.micro_tile.m != 0 {
211            return Err(TilingError::DivisibilityError {
212                level: "midi/micro",
213                dimension: "M",
214                larger: self.midi_tile.m,
215                smaller: self.micro_tile.m,
216            });
217        }
218
219        Ok(())
220    }
221
222    /// Calculate total number of macro-tiles for given problem size
223    #[must_use]
224    pub fn num_macro_tiles(&self, m: u32, n: u32) -> u32 {
225        let m_tiles = (m + self.macro_tile.m - 1) / self.macro_tile.m;
226        let n_tiles = (n + self.macro_tile.n - 1) / self.macro_tile.n;
227        m_tiles * n_tiles
228    }
229
230    /// Calculate total number of midi-tiles within a macro-tile
231    #[must_use]
232    pub fn midi_tiles_per_macro(&self) -> u32 {
233        let m_tiles = self.macro_tile.m / self.midi_tile.m;
234        let n_tiles = self.macro_tile.n / self.midi_tile.n;
235        m_tiles * n_tiles
236    }
237
238    /// Calculate total number of micro-tiles within a midi-tile
239    #[must_use]
240    pub fn micro_tiles_per_midi(&self) -> u32 {
241        let m_tiles = self.midi_tile.m / self.micro_tile.m;
242        let n_tiles = self.midi_tile.n / self.micro_tile.n;
243        m_tiles * n_tiles
244    }
245}