/// Number of elements per quantization block; `CacheConfig::num_blocks`
/// divides `head_dim` by this.
pub const QUANT_BLOCK_SIZE: usize = 32;
/// Bits in one byte, used by `CacheConfig::packed_dim` to convert a bit count
/// into a byte count.
pub const BITS_PER_BYTE: usize = u8::BITS as usize;
/// Default seed for the rotation step (presumably a seeded random rotation —
/// confirm against the code that consumes it).
pub const DEFAULT_ROTATION_SEED: u64 = 42;
/// Default seed for the QJL step (presumably seeds its random projection —
/// confirm against the code that consumes it).
pub const DEFAULT_QJL_SEED: u64 = 12_345;
/// Selects which norm a quantization block is scaled by.
///
/// The default is [`QuantNormMode::MaxNorm`]. The text forms are `maxnorm`
/// and `l2norm` (see the `FromStr` / `Display` impls for this type).
///
/// NOTE(review): variant order is kept stable on purpose; reordering would
/// change discriminant values and any index-based serde encoding.
#[derive(Clone, Copy, Debug, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum QuantNormMode {
/// Text form `l2norm`; presumably scales each block by its L2 norm — confirm
/// against the quantizer.
L2Norm,
/// Text form `maxnorm` and the default mode; presumably scales each block by
/// its maximum absolute value — confirm against the quantizer.
#[default]
MaxNorm,
}
impl std::str::FromStr for QuantNormMode {
    type Err = String;

    /// Parses a norm mode from its textual name: `maxnorm` or `l2norm`.
    ///
    /// Matching ignores ASCII case and surrounding whitespace. Inputs such as
    /// `MaxNorm` (the `Debug` spelling) or a config value with a trailing
    /// newline are therefore accepted. The canonical spelling emitted by
    /// `Display` stays lowercase.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input and listing the valid
    /// options when `s` matches neither name.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let name = s.trim();
        // `eq_ignore_ascii_case` compares without allocating a lowercased copy.
        if name.eq_ignore_ascii_case("maxnorm") {
            Ok(Self::MaxNorm)
        } else if name.eq_ignore_ascii_case("l2norm") {
            Ok(Self::L2Norm)
        } else {
            // Echo the caller's original (untrimmed) input, as before.
            Err(format!(
                "Unknown norm mode `{s}`. Options: maxnorm, l2norm"
            ))
        }
    }
}
/// Shape and quantization settings for a quantized cache.
///
/// The field names suggest a KV cache laid out per layer and per KV head.
/// NOTE(review): confirm that layout against the code that allocates storage.
#[derive(Clone, Debug)]
pub struct CacheConfig {
/// Bits per quantized element; [`CacheConfig::packed_dim`] multiplies
/// `head_dim` by this to size one packed head vector.
pub bits: u8,
/// Elements per head vector; split into `QUANT_BLOCK_SIZE`-element blocks by
/// [`CacheConfig::num_blocks`].
pub head_dim: usize,
/// Number of key/value heads (per the field name; presumably per layer).
pub num_kv_heads: usize,
/// Number of layers covered by the cache.
pub num_layers: usize,
/// Which norm quantization blocks are scaled by.
pub norm_mode: QuantNormMode,
/// Number of outlier blocks; when this is zero,
/// [`CacheConfig::qjl_enabled`] returns `true`.
pub outlier_blocks: usize,
}
impl CacheConfig {
    /// Returns `true` when QJL is enabled.
    ///
    /// This config ties QJL to having no outlier blocks: it is enabled exactly
    /// when `outlier_blocks == 0`.
    pub fn qjl_enabled(&self) -> bool {
        self.outlier_blocks == 0
    }

    /// Number of whole [`QUANT_BLOCK_SIZE`]-element blocks in one head vector.
    ///
    /// Uses floor division, so trailing elements beyond the last full block are
    /// not counted. Callers presumably keep `head_dim` a multiple of
    /// [`QUANT_BLOCK_SIZE`] — confirm where configs are constructed.
    pub fn num_blocks(&self) -> usize {
        self.head_dim / QUANT_BLOCK_SIZE
    }

    /// Number of bytes needed to bit-pack one head vector at `bits` bits per
    /// element.
    ///
    /// Rounds up, so a trailing partial byte is still counted. For example,
    /// `head_dim = 100` at 3 bits needs 300 bits and therefore 38 bytes; floor
    /// division would give 37. When `head_dim * bits` is a multiple of
    /// [`BITS_PER_BYTE`] the result equals plain division.
    pub fn packed_dim(&self) -> usize {
        // `usize::from` widens `u8` losslessly, unlike an `as` cast.
        let total_bits = self.head_dim * usize::from(self.bits);
        total_bits.div_ceil(BITS_PER_BYTE)
    }
}
impl std::fmt::Display for QuantNormMode {
    /// Writes the lowercase mode name (`maxnorm` / `l2norm`), which is the
    /// same spelling this type's `FromStr` impl accepts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::L2Norm => "l2norm",
            Self::MaxNorm => "maxnorm",
        };
        f.write_str(name)
    }
}