use crate::NodeId;
use std::collections::HashMap;
/// Quantization format applied to a tensor's stored elements.
///
/// The block-based integer variants carry `block_size`, the number of elements grouped under
/// one block. Per-scheme properties (bit width, scale/zero-point presence, GGUF block geometry)
/// are exposed as `const fn` methods on this type.
///
/// All payloads are plain `u32`, so `Eq` and `Hash` are derived as well. This lets schemes be
/// used as `HashMap`/`HashSet` keys, e.g. when grouping tensors by format.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantScheme {
    /// 8-bit integers with a per-block scale and no zero point.
    Int8Block { block_size: u32 },
    /// 8-bit integers with a per-block scale and a zero point (asymmetric).
    Int8BlockAsym { block_size: u32 },
    /// 4-bit integers with a per-block scale.
    Int4Block { block_size: u32 },
    /// 8-bit float, E4M3 layout, no separate scale.
    Fp8E4m3,
    /// 8-bit float, E5M2 layout, no separate scale.
    Fp8E5m2,
    /// GGUF Q4_K k-quant (256-element super-blocks).
    GgufQ4K,
    /// GGUF Q5_K k-quant (256-element super-blocks).
    GgufQ5K,
    /// GGUF Q6_K k-quant (256-element super-blocks).
    GgufQ6K,
    /// GGUF Q8_K k-quant (256-element super-blocks).
    GgufQ8K,
    /// GGUF Q2_K k-quant (256-element super-blocks).
    GgufQ2K,
    /// GGUF Q3_K k-quant (256-element super-blocks).
    GgufQ3K,
    /// NVFP4: 4-bit elements with FP8-encoded scales, one per fixed-size group
    /// (`crate::nvfp4::NVFP4_GROUP_SIZE` elements).
    Nvfp4Block,
}
impl QuantScheme {
    /// Storage cost of one element, in tenths of a bit.
    ///
    /// Tenths keep fractional GGUF rates (4.5 bits for Q4_K, for example) representable as a
    /// `u32`. For the GGUF k-quants the value equals `gguf_block_bytes() * 8 / 256` rounded to
    /// the nearest tenth, so the scales packed into each super-block are included. For every
    /// other scheme it is just the raw element width. Any per-block scale storage is not counted.
    pub const fn bits_per_element_x10(self) -> u32 {
        match self {
            // 8-bit payloads.
            Self::Int8Block { .. }
            | Self::Int8BlockAsym { .. }
            | Self::Fp8E4m3
            | Self::Fp8E5m2 => 80,
            // 4-bit payloads.
            Self::Int4Block { .. } | Self::Nvfp4Block => 40,
            // GGUF k-quants, narrowest first.
            Self::GgufQ2K => 26,
            Self::GgufQ3K => 34,
            Self::GgufQ4K => 45,
            Self::GgufQ5K => 55,
            Self::GgufQ6K => 66,
            Self::GgufQ8K => 91,
        }
    }

    /// Whole bits per element: [`Self::bits_per_element_x10`] truncated toward zero.
    ///
    /// For example, `GgufQ4K` reports 4, not 4.5.
    pub const fn bits_per_element(self) -> u32 {
        let tenths = self.bits_per_element_x10();
        tenths / 10
    }

    /// Whether this is a block-scaled format: `Int8Block`, `Int8BlockAsym`, `Int4Block` or
    /// `Nvfp4Block`.
    ///
    /// NOTE(review): the GGUF k-quants return `false` even though their blocks embed scales.
    /// Presumably this is because those scales are part of the packed block rather than
    /// stored alongside it; confirm against callers.
    pub const fn has_scale(self) -> bool {
        match self {
            Self::Int8Block { .. }
            | Self::Int8BlockAsym { .. }
            | Self::Int4Block { .. }
            | Self::Nvfp4Block => true,
            Self::Fp8E4m3
            | Self::Fp8E5m2
            | Self::GgufQ4K
            | Self::GgufQ5K
            | Self::GgufQ6K
            | Self::GgufQ8K
            | Self::GgufQ2K
            | Self::GgufQ3K => false,
        }
    }

    /// Whether the scales are FP8-encoded. Only `Nvfp4Block` returns `true`.
    pub const fn scale_is_fp8(self) -> bool {
        matches!(self, Self::Nvfp4Block)
    }

    /// Elements per scale group for `Nvfp4Block` (`crate::nvfp4::NVFP4_GROUP_SIZE`).
    ///
    /// Returns 0 for every other scheme.
    pub const fn nvfp4_group_size(self) -> u32 {
        if matches!(self, Self::Nvfp4Block) {
            crate::nvfp4::NVFP4_GROUP_SIZE as u32
        } else {
            0
        }
    }

    /// Whether elements carry a zero point. Only `Int8BlockAsym` returns `true`.
    pub const fn has_zero_point(self) -> bool {
        matches!(self, Self::Int8BlockAsym { .. })
    }

    /// Elements per GGUF super-block: 256 for every GGUF k-quant, 0 for all other schemes.
    pub const fn gguf_block_size(self) -> u32 {
        if self.is_gguf() {
            256
        } else {
            0
        }
    }

    /// Encoded bytes per GGUF super-block; 0 for non-GGUF schemes.
    pub const fn gguf_block_bytes(self) -> u32 {
        match self {
            Self::GgufQ2K => 84,
            Self::GgufQ3K => 110,
            Self::GgufQ4K => 144,
            Self::GgufQ5K => 176,
            Self::GgufQ6K => 210,
            Self::GgufQ8K => 292,
            Self::Int8Block { .. }
            | Self::Int8BlockAsym { .. }
            | Self::Int4Block { .. }
            | Self::Fp8E4m3
            | Self::Fp8E5m2
            | Self::Nvfp4Block => 0,
        }
    }

    /// Whether the scheme is one of the GGUF k-quants.
    pub const fn is_gguf(self) -> bool {
        matches!(
            self,
            Self::GgufQ2K
                | Self::GgufQ3K
                | Self::GgufQ4K
                | Self::GgufQ5K
                | Self::GgufQ6K
                | Self::GgufQ8K
        )
    }
}
impl std::fmt::Display for QuantScheme {
    /// Writes the scheme's short lowercase tag, e.g. `int8/32`, `fp8e4m3`, `gguf_q4k` or
    /// `nvfp4/16`. Block-based variants append their block/group size after a `/`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Int8Block { block_size } => write!(f, "int8/{block_size}"),
            Self::Int8BlockAsym { block_size } => write!(f, "int8a/{block_size}"),
            Self::Int4Block { block_size } => write!(f, "int4/{block_size}"),
            Self::Fp8E4m3 => f.write_str("fp8e4m3"),
            Self::Fp8E5m2 => f.write_str("fp8e5m2"),
            Self::GgufQ4K => f.write_str("gguf_q4k"),
            Self::GgufQ5K => f.write_str("gguf_q5k"),
            Self::GgufQ6K => f.write_str("gguf_q6k"),
            Self::GgufQ8K => f.write_str("gguf_q8k"),
            Self::GgufQ2K => f.write_str("gguf_q2k"),
            Self::GgufQ3K => f.write_str("gguf_q3k"),
            // Take the group size from the NVFP4 constant (via `nvfp4_group_size`) instead of
            // hard-coding 16. The tag then cannot disagree with the value the scheme reports.
            Self::Nvfp4Block => write!(f, "nvfp4/{}", self.nvfp4_group_size()),
        }
    }
}
/// Per-node quantization assignments, keyed by [`NodeId`].
///
/// A node that is absent from the map has no scheme recorded here.
#[derive(Debug, Clone, Default)]
pub struct QuantMap {
    /// Backing table. It is kept private so the public surface stays the small API below.
    schemes: HashMap<NodeId, QuantScheme>,
}

impl QuantMap {
    /// Creates an empty map; equivalent to [`QuantMap::default`].
    pub fn new() -> Self {
        Self {
            schemes: HashMap::new(),
        }
    }

    /// Returns the scheme recorded for `id`, or `None` if the node has none.
    pub fn get(&self, id: NodeId) -> Option<QuantScheme> {
        let found = self.schemes.get(&id)?;
        Some(*found)
    }

    /// Records `scheme` for `id` and returns the scheme it replaced, if there was one.
    pub fn insert(&mut self, id: NodeId, scheme: QuantScheme) -> Option<QuantScheme> {
        self.schemes.insert(id, scheme)
    }

    /// Returns `true` when no node has a scheme recorded.
    pub fn is_empty(&self) -> bool {
        self.schemes.is_empty()
    }

    /// Returns the number of nodes with a recorded scheme.
    pub fn len(&self) -> usize {
        self.schemes.len()
    }

    /// Iterates over `(node, scheme)` pairs in unspecified (hash) order.
    pub fn iter(&self) -> impl Iterator<Item = (&NodeId, &QuantScheme)> {
        self.schemes.iter()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every GGUF variant, used to check that the per-scheme tables agree with each other.
    const GGUF: [QuantScheme; 6] = [
        QuantScheme::GgufQ2K,
        QuantScheme::GgufQ3K,
        QuantScheme::GgufQ4K,
        QuantScheme::GgufQ5K,
        QuantScheme::GgufQ6K,
        QuantScheme::GgufQ8K,
    ];

    #[test]
    fn scheme_traits() {
        assert_eq!(
            QuantScheme::Int4Block { block_size: 32 }.bits_per_element(),
            4
        );
        assert!(QuantScheme::Int8BlockAsym { block_size: 64 }.has_zero_point());
        assert!(!QuantScheme::Fp8E4m3.has_scale());
    }

    #[test]
    fn gguf_tables_are_consistent() {
        for s in GGUF {
            assert!(s.is_gguf(), "{s}");
            assert_eq!(s.gguf_block_size(), 256, "{s}");
            // Tenths of a bit = block bits spread over 256 elements, rounded to nearest.
            let expected_x10 = (s.gguf_block_bytes() * 80 + 128) / 256;
            assert_eq!(s.bits_per_element_x10(), expected_x10, "{s}");
        }
        let plain = QuantScheme::Int8Block { block_size: 32 };
        assert!(!plain.is_gguf());
        assert_eq!(plain.gguf_block_size(), 0);
        assert_eq!(plain.gguf_block_bytes(), 0);
    }

    #[test]
    fn nvfp4_traits() {
        let s = QuantScheme::Nvfp4Block;
        assert!(s.has_scale());
        assert!(s.scale_is_fp8());
        assert!(!s.has_zero_point());
        assert!(!s.is_gguf());
        assert_ne!(s.nvfp4_group_size(), 0);
        assert_eq!(QuantScheme::Int4Block { block_size: 32 }.nvfp4_group_size(), 0);
        // The display tag must advertise the same group size the scheme reports.
        assert_eq!(s.to_string(), format!("nvfp4/{}", s.nvfp4_group_size()));
    }

    #[test]
    fn display_tags() {
        assert_eq!(QuantScheme::Int8Block { block_size: 32 }.to_string(), "int8/32");
        assert_eq!(
            QuantScheme::Int8BlockAsym { block_size: 64 }.to_string(),
            "int8a/64"
        );
        assert_eq!(QuantScheme::Int4Block { block_size: 128 }.to_string(), "int4/128");
        assert_eq!(QuantScheme::Fp8E5m2.to_string(), "fp8e5m2");
        assert_eq!(QuantScheme::GgufQ4K.to_string(), "gguf_q4k");
    }

    #[test]
    fn quant_map_lookup() {
        let mut q = QuantMap::new();
        assert!(q.is_empty());
        let id = NodeId(7);
        assert_eq!(q.insert(id, QuantScheme::Int8Block { block_size: 32 }), None);
        assert_eq!(q.get(id), Some(QuantScheme::Int8Block { block_size: 32 }));
        assert_eq!(q.get(NodeId(99)), None);
        // Re-inserting replaces the scheme and hands back the previous one.
        assert_eq!(
            q.insert(id, QuantScheme::GgufQ6K),
            Some(QuantScheme::Int8Block { block_size: 32 })
        );
        assert_eq!(q.get(id), Some(QuantScheme::GgufQ6K));
        assert_eq!(q.len(), 1);
        assert_eq!(q.iter().count(), 1);
    }
}