// trueno/tuner/types.rs
1#![allow(missing_docs)]
2//! Tuner Type Definitions
3//!
4//! Core enums for quantization, kernel selection, and bottleneck classification.
5
6use crate::brick::BrickBottleneck;
7use serde::{Deserialize, Serialize};
8
9// ============================================================================
10// QuantType
11// ============================================================================
12
/// Quantization type for feature encoding.
///
/// One-hot position comes from [`QuantType::to_index`]; approximate storage
/// cost comes from [`QuantType::bytes_per_param`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum QuantType {
    /// 4-bit quantization, "type 0" (~4.5 bits/param effective).
    Q4_0,
    /// 4-bit quantization, "type 1" (~4.5 bits/param effective).
    Q4_1,
    /// 4-bit K-quant; the default for this tuner.
    #[default]
    Q4K,
    /// 5-bit K-quant (~5.5 bits/param effective).
    Q5K,
    /// 6-bit K-quant (~6.5 bits/param effective).
    Q6K,
    /// 8-bit quantization (~1 byte/param).
    Q8_0,
    /// Half-precision float (unquantized, 2 bytes/param).
    F16,
    /// Full-precision float (unquantized, 4 bytes/param).
    F32,
}
26
27impl QuantType {
28    /// One-hot encoding index (0-7)
29    pub fn to_index(self) -> usize {
30        match self {
31            QuantType::Q4_0 => 0,
32            QuantType::Q4_1 => 1,
33            QuantType::Q4K => 2,
34            QuantType::Q5K => 3,
35            QuantType::Q6K => 4,
36            QuantType::Q8_0 => 5,
37            QuantType::F16 => 6,
38            QuantType::F32 => 7,
39        }
40    }
41
42    /// Bytes per parameter (approximate)
43    pub fn bytes_per_param(self) -> f32 {
44        contract_pre_bytes_per_param!();
45        match self {
46            QuantType::Q4_0 | QuantType::Q4_1 | QuantType::Q4K => 0.5625, // 4.5 bits
47            QuantType::Q5K => 0.6875,                                     // 5.5 bits
48            QuantType::Q6K => 0.8125,                                     // 6.5 bits
49            QuantType::Q8_0 => 1.0,
50            QuantType::F16 => 2.0,
51            QuantType::F32 => 4.0,
52        }
53    }
54}
55
56// ============================================================================
57// KernelType
58// ============================================================================
59
/// Kernel type for feature encoding.
///
/// Each variant names one GPU kernel implementation tracked by the tuner;
/// [`KernelType::to_index`] gives its slot in the one-hot feature vector,
/// and [`KernelType::from_index`] recovers the variant from that slot.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum KernelType {
    // Q4K variants
    /// Baseline tiled Q4K kernel; the default classification.
    #[default]
    TiledQ4K,
    CoalescedQ4K,
    VectorizedQ4K,
    BatchedQ4K,
    Dp4aQ4K,
    FusedRmsNormQ4K,
    // Q6K variants
    CoalescedQ6K,
    // Attention variants
    IncrementalAttention,
    MultiWarpAttention,
    BatchedAttention,
    // Normalization
    RmsNorm,
    VectorizedRmsNorm,
    BatchedRmsNorm,
    // Fused attention projection
    FusedQKVHwDp4aQ4KGemv,
    // Other
    /// Recognized kernel that fits no specialized category above.
    Generic,
    /// Unrecognized kernel; also the catch-all for out-of-range indices
    /// in `from_index`.
    Unknown,
}
87
88impl KernelType {
89    /// One-hot encoding index (0-16)
90    pub fn to_index(self) -> usize {
91        match self {
92            KernelType::TiledQ4K => 0,
93            KernelType::CoalescedQ4K => 1,
94            KernelType::VectorizedQ4K => 2,
95            KernelType::BatchedQ4K => 3,
96            KernelType::Dp4aQ4K => 4,
97            KernelType::FusedRmsNormQ4K => 5,
98            KernelType::CoalescedQ6K => 6,
99            KernelType::IncrementalAttention => 7,
100            KernelType::MultiWarpAttention => 8,
101            KernelType::BatchedAttention => 9,
102            KernelType::RmsNorm => 10,
103            KernelType::VectorizedRmsNorm => 11,
104            KernelType::BatchedRmsNorm => 12,
105            KernelType::FusedQKVHwDp4aQ4KGemv => 13,
106            KernelType::Generic => 14,
107            KernelType::Unknown => 15,
108        }
109    }
110
111    /// Convert kernel index to type (inverse of to_index())
112    pub fn from_index(idx: usize) -> Self {
113        match idx {
114            0 => KernelType::TiledQ4K,
115            1 => KernelType::CoalescedQ4K,
116            2 => KernelType::VectorizedQ4K,
117            3 => KernelType::BatchedQ4K,
118            4 => KernelType::Dp4aQ4K,
119            5 => KernelType::FusedRmsNormQ4K,
120            6 => KernelType::CoalescedQ6K,
121            7 => KernelType::IncrementalAttention,
122            8 => KernelType::MultiWarpAttention,
123            9 => KernelType::BatchedAttention,
124            10 => KernelType::RmsNorm,
125            11 => KernelType::VectorizedRmsNorm,
126            12 => KernelType::BatchedRmsNorm,
127            13 => KernelType::FusedQKVHwDp4aQ4KGemv,
128            14 => KernelType::Generic,
129            15.. => KernelType::Unknown,
130        }
131    }
132
133    /// Number of kernel types
134    pub const COUNT: usize = 17;
135}
136
137// ============================================================================
138// BottleneckClass
139// ============================================================================
140
/// Bottleneck classification for ML model.
///
/// One-hot position comes from [`BottleneckClass::to_index`]; a suggested
/// remediation string from [`BottleneckClass::recommended_action`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum BottleneckClass {
    /// Not yet classified; the default.
    #[default]
    Unknown,
    /// Limited by memory bandwidth (weight/activation reads).
    MemoryBound,
    /// Limited by arithmetic throughput.
    ComputeBound,
    /// Limited by kernel-launch overhead.
    LaunchBound,
    /// Limited by the attention computation.
    AttentionBound,
}
151
152impl BottleneckClass {
153    /// Convert from BrickBottleneck
154    pub fn from_brick_bottleneck(b: BrickBottleneck) -> Self {
155        match b {
156            BrickBottleneck::Memory => BottleneckClass::MemoryBound,
157            BrickBottleneck::Compute => BottleneckClass::ComputeBound,
158            BrickBottleneck::Unknown => BottleneckClass::Unknown,
159        }
160    }
161
162    /// Recommended action for this bottleneck
163    pub fn recommended_action(self) -> &'static str {
164        match self {
165            BottleneckClass::MemoryBound => {
166                "Increase batch size (M) to amortize weight reads across sequences"
167            }
168            BottleneckClass::ComputeBound => {
169                "Rare for inference; check for redundant computation or use tensor cores"
170            }
171            BottleneckClass::LaunchBound => {
172                "Enable CUDA graphs or fuse kernels to reduce launch overhead"
173            }
174            BottleneckClass::AttentionBound => {
175                "Use Flash Decoding, reduce sequence length, or use batched attention"
176            }
177            BottleneckClass::Unknown => "Run profiling to identify bottleneck",
178        }
179    }
180
181    /// One-hot encoding index (0-4)
182    pub fn to_index(self) -> usize {
183        match self {
184            BottleneckClass::Unknown => 0,
185            BottleneckClass::MemoryBound => 1,
186            BottleneckClass::ComputeBound => 2,
187            BottleneckClass::LaunchBound => 3,
188            BottleneckClass::AttentionBound => 4,
189        }
190    }
191}
192
193impl std::fmt::Display for BottleneckClass {
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        match self {
196            BottleneckClass::Unknown => write!(f, "Unknown"),
197            BottleneckClass::MemoryBound => write!(f, "MemoryBound"),
198            BottleneckClass::ComputeBound => write!(f, "ComputeBound"),
199            BottleneckClass::LaunchBound => write!(f, "LaunchBound"),
200            BottleneckClass::AttentionBound => write!(f, "AttentionBound"),
201        }
202    }
203}