#![allow(missing_docs)]
use crate::brick::BrickBottleneck;
use serde::{Deserialize, Serialize};
/// Weight quantization formats supported by the kernel library.
///
/// Variant names mirror GGML-style quantization schemes; `Q4K` is the
/// default format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum QuantType {
    Q4_0,
    Q4_1,
    #[default]
    Q4K,
    Q5K,
    Q6K,
    Q8_0,
    F16,
    F32,
}

impl QuantType {
    /// Dense index (0..=7) for use as an array/feature-vector offset.
    ///
    /// This is a fieldless enum with no explicit discriminants, so the
    /// compiler-assigned discriminant is exactly the declaration-order
    /// index we want.
    pub fn to_index(self) -> usize {
        self as usize
    }

    /// Average storage cost of one parameter, in bytes.
    ///
    /// NOTE(review): these look like nominal bits-per-weight / 8 figures
    /// (e.g. 4.5 bpw for the Q4 family) rather than exact per-block GGML
    /// sizes — confirm against the serialization code before using them
    /// for precise memory accounting.
    pub fn bytes_per_param(self) -> f32 {
        match self {
            QuantType::Q4_0 | QuantType::Q4_1 | QuantType::Q4K => 0.5625,
            QuantType::Q5K => 0.6875,
            QuantType::Q6K => 0.8125,
            QuantType::Q8_0 => 1.0,
            QuantType::F16 => 2.0,
            QuantType::F32 => 4.0,
        }
    }
}
/// CUDA kernel families tracked by the profiler/classifier.
///
/// [`KernelType::to_index`] and [`KernelType::from_index`] form a
/// round-trip over `0..Self::COUNT`; out-of-range indices decode to
/// [`KernelType::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum KernelType {
    #[default]
    TiledQ4K,
    CoalescedQ4K,
    VectorizedQ4K,
    BatchedQ4K,
    Dp4aQ4K,
    FusedRmsNormQ4K,
    CoalescedQ6K,
    IncrementalAttention,
    MultiWarpAttention,
    BatchedAttention,
    RmsNorm,
    VectorizedRmsNorm,
    BatchedRmsNorm,
    Generic,
    Unknown,
}

impl KernelType {
    /// Number of distinct kernel types — one more than the largest value
    /// `to_index` can return.
    ///
    /// BUG FIX: this was 16, but there are only 15 variants (indices
    /// 0..=14). Arrays sized `[T; KernelType::COUNT]` carried a dead
    /// trailing slot, and iterating `0..COUNT` through `from_index`
    /// yielded `Unknown` twice (for 14 and 15).
    pub const COUNT: usize = 15;

    /// Dense index in `0..Self::COUNT`, suitable as an array offset.
    ///
    /// Kept as an exhaustive match (rather than a discriminant cast) so
    /// that adding a variant without updating the index mapping is a
    /// compile error.
    pub fn to_index(self) -> usize {
        match self {
            KernelType::TiledQ4K => 0,
            KernelType::CoalescedQ4K => 1,
            KernelType::VectorizedQ4K => 2,
            KernelType::BatchedQ4K => 3,
            KernelType::Dp4aQ4K => 4,
            KernelType::FusedRmsNormQ4K => 5,
            KernelType::CoalescedQ6K => 6,
            KernelType::IncrementalAttention => 7,
            KernelType::MultiWarpAttention => 8,
            KernelType::BatchedAttention => 9,
            KernelType::RmsNorm => 10,
            KernelType::VectorizedRmsNorm => 11,
            KernelType::BatchedRmsNorm => 12,
            KernelType::Generic => 13,
            KernelType::Unknown => 14,
        }
    }

    /// Inverse of [`KernelType::to_index`]; any out-of-range index maps
    /// to [`KernelType::Unknown`].
    pub fn from_index(idx: usize) -> Self {
        match idx {
            0 => KernelType::TiledQ4K,
            1 => KernelType::CoalescedQ4K,
            2 => KernelType::VectorizedQ4K,
            3 => KernelType::BatchedQ4K,
            4 => KernelType::Dp4aQ4K,
            5 => KernelType::FusedRmsNormQ4K,
            6 => KernelType::CoalescedQ6K,
            7 => KernelType::IncrementalAttention,
            8 => KernelType::MultiWarpAttention,
            9 => KernelType::BatchedAttention,
            10 => KernelType::RmsNorm,
            11 => KernelType::VectorizedRmsNorm,
            12 => KernelType::BatchedRmsNorm,
            13 => KernelType::Generic,
            14.. => KernelType::Unknown,
        }
    }
}
/// Coarse performance-bottleneck category assigned to a profiled workload.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum BottleneckClass {
    #[default]
    Unknown,
    MemoryBound,
    ComputeBound,
    LaunchBound,
    AttentionBound,
}

impl BottleneckClass {
    /// Widens a brick-level bottleneck verdict into this classification.
    ///
    /// `BrickBottleneck` has no launch/attention notion, so those classes
    /// are never produced by this conversion.
    pub fn from_brick_bottleneck(b: BrickBottleneck) -> Self {
        match b {
            BrickBottleneck::Memory => Self::MemoryBound,
            BrickBottleneck::Compute => Self::ComputeBound,
            BrickBottleneck::Unknown => Self::Unknown,
        }
    }

    /// Human-readable tuning suggestion for this bottleneck class.
    pub fn recommended_action(self) -> &'static str {
        match self {
            Self::MemoryBound => "Increase batch size (M) to amortize weight reads across sequences",
            Self::ComputeBound => "Rare for inference; check for redundant computation or use tensor cores",
            Self::LaunchBound => "Enable CUDA graphs or fuse kernels to reduce launch overhead",
            Self::AttentionBound => "Use Flash Decoding, reduce sequence length, or use batched attention",
            Self::Unknown => "Run profiling to identify bottleneck",
        }
    }

    /// Dense index (0..=4) in declaration order.
    ///
    /// Relies on the implicit discriminants of this fieldless enum, which
    /// follow declaration order starting at 0.
    pub fn to_index(self) -> usize {
        self as usize
    }
}
impl std::fmt::Display for BottleneckClass {
    /// Writes the variant name exactly as spelled in the enum declaration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            BottleneckClass::Unknown => "Unknown",
            BottleneckClass::MemoryBound => "MemoryBound",
            BottleneckClass::ComputeBound => "ComputeBound",
            BottleneckClass::LaunchBound => "LaunchBound",
            BottleneckClass::AttentionBound => "AttentionBound",
        };
        f.write_str(label)
    }
}