use crate::error::{Error, Result};
use std::fmt;
/// Block-quantized tensor storage formats.
///
/// Each format packs a fixed number of elements per block
/// ([`QuantFormat::block_size`]) into a fixed number of bytes
/// ([`QuantFormat::block_bytes`]). The per-variant notes below list those two
/// numbers and the resulting bits per element. Formats map to and from ggml
/// type IDs via [`QuantFormat::ggml_type_id`] and
/// [`QuantFormat::from_ggml_type_id`].
///
/// Variants are declared in the same order as their
/// [`QuantFormat::format_id`] values (0 through 22).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantFormat {
/// `Q4_0`: 32-element blocks, 18 bytes each (4.5 bits/element).
Q4_0,
/// `Q4_1`: 32-element blocks, 20 bytes each (5 bits/element).
Q4_1,
/// `Q5_0`: 32-element blocks, 22 bytes each (5.5 bits/element).
Q5_0,
/// `Q5_1`: 32-element blocks, 24 bytes each (6 bits/element).
Q5_1,
/// `Q8_0`: 32-element blocks, 34 bytes each (8.5 bits/element).
Q8_0,
/// `Q8_1`: 32-element blocks, 36 bytes each (9 bits/element).
Q8_1,
/// `Q2_K`: 256-element blocks, 84 bytes each (2.625 bits/element).
Q2K,
/// `Q3_K`: 256-element blocks, 110 bytes each (3.4375 bits/element).
Q3K,
/// `Q4_K`: 256-element blocks, 144 bytes each (4.5 bits/element).
Q4K,
/// `Q5_K`: 256-element blocks, 176 bytes each (5.5 bits/element).
Q5K,
/// `Q6_K`: 256-element blocks, 210 bytes each (6.5625 bits/element).
Q6K,
/// `Q8_K`: 256-element blocks, 292 bytes each (9.125 bits/element).
Q8K,
/// `IQ1_S`: 256-element blocks, 50 bytes each (1.5625 bits/element).
IQ1S,
/// `IQ1_M`: 256-element blocks, 56 bytes each (1.75 bits/element).
IQ1M,
/// `IQ2_XXS`: 256-element blocks, 66 bytes each (2.0625 bits/element).
IQ2XXS,
/// `IQ2_XS`: 256-element blocks, 74 bytes each (2.3125 bits/element).
IQ2XS,
/// `IQ2_S`: 256-element blocks, 82 bytes each (2.5625 bits/element).
IQ2S,
/// `IQ3_XXS`: 256-element blocks, 98 bytes each (3.0625 bits/element).
IQ3XXS,
/// `IQ3_S`: 256-element blocks, 110 bytes each (3.4375 bits/element).
IQ3S,
/// `IQ4_NL`: 32-element blocks, 18 bytes each (4.5 bits/element).
IQ4NL,
/// `IQ4_XS`: 256-element blocks, 136 bytes each (4.25 bits/element).
IQ4XS,
/// `TQ1_0`: 256-element blocks, 54 bytes each (1.6875 bits/element).
TQ1_0,
/// `TQ2_0`: 256-element blocks, 66 bytes each (2.0625 bits/element).
TQ2_0,
}
impl QuantFormat {
pub const fn block_size(self) -> usize {
match self {
Self::Q4_0
| Self::Q4_1
| Self::Q5_0
| Self::Q5_1
| Self::Q8_0
| Self::Q8_1
| Self::IQ4NL => 32,
Self::Q2K
| Self::Q3K
| Self::Q4K
| Self::Q5K
| Self::Q6K
| Self::Q8K
| Self::IQ1S
| Self::IQ1M
| Self::IQ2XXS
| Self::IQ2XS
| Self::IQ2S
| Self::IQ3XXS
| Self::IQ3S
| Self::IQ4XS
| Self::TQ1_0
| Self::TQ2_0 => 256,
}
}
pub const fn block_bytes(self) -> usize {
match self {
Self::Q4_0 => 18,
Self::Q4_1 => 20,
Self::Q5_0 => 22,
Self::Q5_1 => 24,
Self::Q8_0 => 34,
Self::Q8_1 => 36,
Self::Q2K => 84,
Self::Q3K => 110,
Self::Q4K => 144,
Self::Q5K => 176,
Self::Q6K => 210,
Self::Q8K => 292,
Self::IQ1S => 50,
Self::IQ1M => 56,
Self::IQ2XXS => 66,
Self::IQ2XS => 74,
Self::IQ2S => 82,
Self::IQ3XXS => 98,
Self::IQ3S => 110,
Self::IQ4NL => 18,
Self::IQ4XS => 136,
Self::TQ1_0 => 54,
Self::TQ2_0 => 66,
}
}
pub fn storage_bytes(self, numel: usize) -> Result<usize> {
let bs = self.block_size();
if numel % bs != 0 {
return Err(Error::QuantError {
reason: format!(
"{}: element count {} is not a multiple of block_size {}",
self.name(),
numel,
bs,
),
});
}
Ok((numel / bs) * self.block_bytes())
}
pub fn num_blocks(self, numel: usize) -> Result<usize> {
let bs = self.block_size();
if numel % bs != 0 {
return Err(Error::QuantError {
reason: format!(
"{}: element count {} is not a multiple of block_size {}",
self.name(),
numel,
bs,
),
});
}
Ok(numel / bs)
}
pub const fn ggml_type_id(self) -> u32 {
match self {
Self::Q4_0 => 2,
Self::Q4_1 => 3,
Self::Q5_0 => 6,
Self::Q5_1 => 7,
Self::Q8_0 => 8,
Self::Q8_1 => 9,
Self::Q2K => 10,
Self::Q3K => 11,
Self::Q4K => 12,
Self::Q5K => 13,
Self::Q6K => 14,
Self::Q8K => 15,
Self::IQ2XXS => 16,
Self::IQ2XS => 17,
Self::IQ3XXS => 18,
Self::IQ1S => 19,
Self::IQ4NL => 20,
Self::IQ3S => 21,
Self::IQ2S => 22,
Self::IQ4XS => 23,
Self::IQ1M => 24,
Self::TQ1_0 => 34,
Self::TQ2_0 => 35,
}
}
pub const fn format_id(self) -> u32 {
match self {
Self::Q4_0 => 0,
Self::Q4_1 => 1,
Self::Q5_0 => 2,
Self::Q5_1 => 3,
Self::Q8_0 => 4,
Self::Q8_1 => 5,
Self::Q2K => 6,
Self::Q3K => 7,
Self::Q4K => 8,
Self::Q5K => 9,
Self::Q6K => 10,
Self::Q8K => 11,
Self::IQ1S => 12,
Self::IQ1M => 13,
Self::IQ2XXS => 14,
Self::IQ2XS => 15,
Self::IQ2S => 16,
Self::IQ3XXS => 17,
Self::IQ3S => 18,
Self::IQ4NL => 19,
Self::IQ4XS => 20,
Self::TQ1_0 => 21,
Self::TQ2_0 => 22,
}
}
pub fn from_ggml_type_id(id: u32) -> Result<Self> {
match id {
2 => Ok(Self::Q4_0),
3 => Ok(Self::Q4_1),
6 => Ok(Self::Q5_0),
7 => Ok(Self::Q5_1),
8 => Ok(Self::Q8_0),
9 => Ok(Self::Q8_1),
10 => Ok(Self::Q2K),
11 => Ok(Self::Q3K),
12 => Ok(Self::Q4K),
13 => Ok(Self::Q5K),
14 => Ok(Self::Q6K),
15 => Ok(Self::Q8K),
16 => Ok(Self::IQ2XXS),
17 => Ok(Self::IQ2XS),
18 => Ok(Self::IQ3XXS),
19 => Ok(Self::IQ1S),
20 => Ok(Self::IQ4NL),
21 => Ok(Self::IQ3S),
22 => Ok(Self::IQ2S),
23 => Ok(Self::IQ4XS),
24 => Ok(Self::IQ1M),
34 => Ok(Self::TQ1_0),
35 => Ok(Self::TQ2_0),
_ => Err(Error::UnsupportedQuantFormat {
format: format!("GGML type ID {}", id),
}),
}
}
pub const fn name(self) -> &'static str {
match self {
Self::Q4_0 => "Q4_0",
Self::Q4_1 => "Q4_1",
Self::Q5_0 => "Q5_0",
Self::Q5_1 => "Q5_1",
Self::Q8_0 => "Q8_0",
Self::Q8_1 => "Q8_1",
Self::Q2K => "Q2_K",
Self::Q3K => "Q3_K",
Self::Q4K => "Q4_K",
Self::Q5K => "Q5_K",
Self::Q6K => "Q6_K",
Self::Q8K => "Q8_K",
Self::IQ1S => "IQ1_S",
Self::IQ1M => "IQ1_M",
Self::IQ2XXS => "IQ2_XXS",
Self::IQ2XS => "IQ2_XS",
Self::IQ2S => "IQ2_S",
Self::IQ3XXS => "IQ3_XXS",
Self::IQ3S => "IQ3_S",
Self::IQ4NL => "IQ4_NL",
Self::IQ4XS => "IQ4_XS",
Self::TQ1_0 => "TQ1_0",
Self::TQ2_0 => "TQ2_0",
}
}
}
impl fmt::Display for QuantFormat {
    /// Writes the canonical format name (see [`QuantFormat::name`]).
    ///
    /// Uses [`fmt::Formatter::pad`] so width, fill, alignment and precision
    /// flags are honoured the same way they are for `str`. For example,
    /// `format!("{:<8}", q)` pads the name to eight columns.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every `QuantFormat` variant, in declaration order.
    const ALL: [QuantFormat; 23] = [
        QuantFormat::Q4_0,
        QuantFormat::Q4_1,
        QuantFormat::Q5_0,
        QuantFormat::Q5_1,
        QuantFormat::Q8_0,
        QuantFormat::Q8_1,
        QuantFormat::Q2K,
        QuantFormat::Q3K,
        QuantFormat::Q4K,
        QuantFormat::Q5K,
        QuantFormat::Q6K,
        QuantFormat::Q8K,
        QuantFormat::IQ1S,
        QuantFormat::IQ1M,
        QuantFormat::IQ2XXS,
        QuantFormat::IQ2XS,
        QuantFormat::IQ2S,
        QuantFormat::IQ3XXS,
        QuantFormat::IQ3S,
        QuantFormat::IQ4NL,
        QuantFormat::IQ4XS,
        QuantFormat::TQ1_0,
        QuantFormat::TQ2_0,
    ];

    #[test]
    fn test_block_sizes() {
        assert_eq!(QuantFormat::Q4_0.block_size(), 32);
        assert_eq!(QuantFormat::Q4K.block_size(), 256);
        assert_eq!(QuantFormat::IQ4NL.block_size(), 32);
        assert_eq!(QuantFormat::TQ1_0.block_size(), 256);
    }

    #[test]
    fn test_block_bytes() {
        assert_eq!(QuantFormat::Q4_0.block_bytes(), 18);
        assert_eq!(QuantFormat::Q4K.block_bytes(), 144);
        assert_eq!(QuantFormat::Q6K.block_bytes(), 210);
        assert_eq!(QuantFormat::Q8K.block_bytes(), 292);
        assert_eq!(QuantFormat::IQ4NL.block_bytes(), 18);
    }

    #[test]
    fn test_storage_bytes() {
        assert_eq!(QuantFormat::Q4_0.storage_bytes(32).unwrap(), 18);
        assert_eq!(QuantFormat::Q4_0.storage_bytes(64).unwrap(), 36);
        assert_eq!(QuantFormat::Q4_0.storage_bytes(1024).unwrap(), 576);
        assert_eq!(QuantFormat::Q4K.storage_bytes(256).unwrap(), 144);
        assert_eq!(QuantFormat::Q4K.storage_bytes(4096).unwrap(), 2304);
    }

    #[test]
    fn test_storage_bytes_alignment_error() {
        assert!(QuantFormat::Q4_0.storage_bytes(33).is_err());
        assert!(QuantFormat::Q4K.storage_bytes(100).is_err());
    }

    /// `storage_bytes` must always equal `num_blocks * block_bytes`.
    #[test]
    fn test_storage_bytes_consistent_with_num_blocks() {
        for &q in &ALL {
            for &blocks in &[0usize, 1, 3, 16] {
                let numel = blocks * q.block_size();
                assert_eq!(q.num_blocks(numel).unwrap(), blocks, "{:?}", q);
                assert_eq!(
                    q.storage_bytes(numel).unwrap(),
                    blocks * q.block_bytes(),
                    "{:?}",
                    q
                );
            }
        }
    }

    #[test]
    fn test_num_blocks() {
        assert_eq!(QuantFormat::Q4_0.num_blocks(32).unwrap(), 1);
        assert_eq!(QuantFormat::Q4_0.num_blocks(1024).unwrap(), 32);
        assert_eq!(QuantFormat::Q4K.num_blocks(4096).unwrap(), 16);
    }

    #[test]
    fn test_num_blocks_alignment_error() {
        assert!(QuantFormat::Q4_0.num_blocks(33).is_err());
        assert!(QuantFormat::Q4K.num_blocks(100).is_err());
    }

    #[test]
    fn test_ggml_roundtrip() {
        for &q in &ALL {
            let id = q.ggml_type_id();
            let recovered = QuantFormat::from_ggml_type_id(id).unwrap();
            assert_eq!(q, recovered, "roundtrip failed for {:?} (id={})", q, id);
        }
    }

    #[test]
    fn test_ggml_ids_unique() {
        let ids: HashSet<u32> = ALL.iter().map(|q| q.ggml_type_id()).collect();
        assert_eq!(ids.len(), ALL.len());
    }

    #[test]
    fn test_from_ggml_unknown() {
        // F32, F16, the removed Q4_2/Q4_3 slots, and out-of-range values.
        for &id in &[0u32, 1, 4, 5, 999, u32::MAX] {
            assert!(QuantFormat::from_ggml_type_id(id).is_err(), "id={}", id);
        }
    }

    #[test]
    fn test_format_ids_dense_in_declaration_order() {
        for (i, q) in ALL.iter().enumerate() {
            assert_eq!(q.format_id() as usize, i, "{:?}", q);
        }
    }

    #[test]
    fn test_names_unique() {
        let names: HashSet<&str> = ALL.iter().map(|q| q.name()).collect();
        assert_eq!(names.len(), ALL.len());
    }

    #[test]
    fn test_display() {
        assert_eq!(QuantFormat::Q4K.to_string(), "Q4_K");
        assert_eq!(QuantFormat::IQ2XXS.to_string(), "IQ2_XXS");
        for &q in &ALL {
            assert_eq!(q.to_string(), q.name());
        }
    }
}