1#![allow(missing_docs)]
2use crate::brick::BrickBottleneck;
7use serde::{Deserialize, Serialize};
8
/// Storage format of a weight tensor.
///
/// Each variant has a fixed integer position (see `QuantType::to_index`) and a
/// per-parameter storage cost (see `QuantType::bytes_per_param`).
/// `Q4K` is the `Default`.
///
/// NOTE(review): variant order is load-bearing. It matches the `to_index` mapping and
/// may affect non-self-describing serde formats, so append new variants rather than
/// reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum QuantType {
    /// 4-bit format `Q4_0`.
    Q4_0,
    /// 4-bit format `Q4_1`.
    Q4_1,
    /// 4-bit K-format (`Q4_K`); the default.
    #[default]
    Q4K,
    /// 5-bit K-format (`Q5_K`).
    Q5K,
    /// 6-bit K-format (`Q6_K`).
    Q6K,
    /// 8-bit format `Q8_0`.
    Q8_0,
    /// 16-bit floating point, unquantized.
    F16,
    /// 32-bit floating point, unquantized.
    F32,
}
26
27impl QuantType {
28 pub fn to_index(self) -> usize {
30 match self {
31 QuantType::Q4_0 => 0,
32 QuantType::Q4_1 => 1,
33 QuantType::Q4K => 2,
34 QuantType::Q5K => 3,
35 QuantType::Q6K => 4,
36 QuantType::Q8_0 => 5,
37 QuantType::F16 => 6,
38 QuantType::F32 => 7,
39 }
40 }
41
42 pub fn bytes_per_param(self) -> f32 {
44 contract_pre_bytes_per_param!();
45 match self {
46 QuantType::Q4_0 | QuantType::Q4_1 | QuantType::Q4K => 0.5625, QuantType::Q5K => 0.6875, QuantType::Q6K => 0.8125, QuantType::Q8_0 => 1.0,
50 QuantType::F16 => 2.0,
51 QuantType::F32 => 4.0,
52 }
53 }
54}
55
/// Kernel implementation kinds.
///
/// Every variant has a stable integer index (see `KernelType::to_index` /
/// `KernelType::from_index`). `TiledQ4K` is the `Default`.
///
/// NOTE(review): the index mapping follows declaration order, with `Unknown` last.
/// Keep `Unknown` as the final variant and update both index functions when adding
/// kernels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum KernelType {
    // Kernels whose names indicate Q4_K-quantized weights.
    #[default]
    TiledQ4K,
    CoalescedQ4K,
    VectorizedQ4K,
    BatchedQ4K,
    Dp4aQ4K,
    FusedRmsNormQ4K,
    // Kernel whose name indicates Q6_K-quantized weights.
    CoalescedQ6K,
    // Attention kernels.
    IncrementalAttention,
    MultiWarpAttention,
    BatchedAttention,
    // RMS normalization kernels.
    RmsNorm,
    VectorizedRmsNorm,
    BatchedRmsNorm,
    // Fused QKV GEMV; the name suggests DP4A on Q4_K weights (confirm).
    FusedQKVHwDp4aQ4KGemv,
    // Presumably a catch-all for kernels without a dedicated variant. Confirm.
    Generic,
    /// Unidentified kernel; also produced by `KernelType::from_index` for any
    /// out-of-range index.
    Unknown,
}
87
88impl KernelType {
89 pub fn to_index(self) -> usize {
91 match self {
92 KernelType::TiledQ4K => 0,
93 KernelType::CoalescedQ4K => 1,
94 KernelType::VectorizedQ4K => 2,
95 KernelType::BatchedQ4K => 3,
96 KernelType::Dp4aQ4K => 4,
97 KernelType::FusedRmsNormQ4K => 5,
98 KernelType::CoalescedQ6K => 6,
99 KernelType::IncrementalAttention => 7,
100 KernelType::MultiWarpAttention => 8,
101 KernelType::BatchedAttention => 9,
102 KernelType::RmsNorm => 10,
103 KernelType::VectorizedRmsNorm => 11,
104 KernelType::BatchedRmsNorm => 12,
105 KernelType::FusedQKVHwDp4aQ4KGemv => 13,
106 KernelType::Generic => 14,
107 KernelType::Unknown => 15,
108 }
109 }
110
111 pub fn from_index(idx: usize) -> Self {
113 match idx {
114 0 => KernelType::TiledQ4K,
115 1 => KernelType::CoalescedQ4K,
116 2 => KernelType::VectorizedQ4K,
117 3 => KernelType::BatchedQ4K,
118 4 => KernelType::Dp4aQ4K,
119 5 => KernelType::FusedRmsNormQ4K,
120 6 => KernelType::CoalescedQ6K,
121 7 => KernelType::IncrementalAttention,
122 8 => KernelType::MultiWarpAttention,
123 9 => KernelType::BatchedAttention,
124 10 => KernelType::RmsNorm,
125 11 => KernelType::VectorizedRmsNorm,
126 12 => KernelType::BatchedRmsNorm,
127 13 => KernelType::FusedQKVHwDp4aQ4KGemv,
128 14 => KernelType::Generic,
129 15.. => KernelType::Unknown,
130 }
131 }
132
133 pub const COUNT: usize = 17;
135}
136
/// Classification of the dominant performance bottleneck.
///
/// Each class has a remediation hint (`BottleneckClass::recommended_action`) and a
/// stable index (`BottleneckClass::to_index`). `Unknown` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum BottleneckClass {
    /// Bottleneck not yet identified; the default.
    #[default]
    Unknown,
    /// Limited by memory traffic; produced from `BrickBottleneck::Memory`.
    MemoryBound,
    /// Limited by arithmetic throughput; produced from `BrickBottleneck::Compute`.
    ComputeBound,
    /// Limited by kernel-launch overhead.
    LaunchBound,
    /// Limited by the attention computation.
    AttentionBound,
}
151
152impl BottleneckClass {
153 pub fn from_brick_bottleneck(b: BrickBottleneck) -> Self {
155 match b {
156 BrickBottleneck::Memory => BottleneckClass::MemoryBound,
157 BrickBottleneck::Compute => BottleneckClass::ComputeBound,
158 BrickBottleneck::Unknown => BottleneckClass::Unknown,
159 }
160 }
161
162 pub fn recommended_action(self) -> &'static str {
164 match self {
165 BottleneckClass::MemoryBound => {
166 "Increase batch size (M) to amortize weight reads across sequences"
167 }
168 BottleneckClass::ComputeBound => {
169 "Rare for inference; check for redundant computation or use tensor cores"
170 }
171 BottleneckClass::LaunchBound => {
172 "Enable CUDA graphs or fuse kernels to reduce launch overhead"
173 }
174 BottleneckClass::AttentionBound => {
175 "Use Flash Decoding, reduce sequence length, or use batched attention"
176 }
177 BottleneckClass::Unknown => "Run profiling to identify bottleneck",
178 }
179 }
180
181 pub fn to_index(self) -> usize {
183 match self {
184 BottleneckClass::Unknown => 0,
185 BottleneckClass::MemoryBound => 1,
186 BottleneckClass::ComputeBound => 2,
187 BottleneckClass::LaunchBound => 3,
188 BottleneckClass::AttentionBound => 4,
189 }
190 }
191}
192
193impl std::fmt::Display for BottleneckClass {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 match self {
196 BottleneckClass::Unknown => write!(f, "Unknown"),
197 BottleneckClass::MemoryBound => write!(f, "MemoryBound"),
198 BottleneckClass::ComputeBound => write!(f, "ComputeBound"),
199 BottleneckClass::LaunchBound => write!(f, "LaunchBound"),
200 BottleneckClass::AttentionBound => write!(f, "AttentionBound"),
201 }
202 }
203}