use half::f16;
use std::fmt;
/// Storage formats supported for tensor data.
///
/// The `Q*` variants group values into blocks of 32 that are stored together
/// with a small per-block header; `F16` and `F32` store each value on its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantType {
    /// 8 bits per value; 34 bytes per block (2-byte header + 32 payload bytes).
    Q8_0,
    /// 4 bits per value; 18 bytes per block (2-byte header + 16 payload bytes).
    Q4_0,
    /// 4 bits per value; 20 bytes per block (4-byte header + 16 payload bytes).
    Q4_1,
    /// 5 bits per value; 22 bytes per block (2-byte header + 20 payload bytes).
    Q5_0,
    /// 5 bits per value; 24 bytes per block (4-byte header + 20 payload bytes).
    Q5_1,
    /// Unquantized 16-bit floating point.
    F16,
    /// Unquantized 32-bit floating point.
    F32,
}

impl QuantType {
    /// Number of values stored per block: 32 for the quantized formats and 1
    /// for the floating-point formats.
    pub fn block_size(&self) -> usize {
        match self {
            QuantType::Q8_0
            | QuantType::Q4_0
            | QuantType::Q4_1
            | QuantType::Q5_0
            | QuantType::Q5_1 => 32,
            QuantType::F16 | QuantType::F32 => 1,
        }
    }

    /// Serialized size in bytes of one block, header included.
    pub fn bytes_per_block(&self) -> usize {
        match self {
            QuantType::Q8_0 => 2 + 32,
            QuantType::Q4_0 => 2 + 16,
            QuantType::Q4_1 => 4 + 16,
            QuantType::Q5_0 => 2 + 20,
            QuantType::Q5_1 => 4 + 20,
            QuantType::F16 => 2,
            QuantType::F32 => 4,
        }
    }

    /// Nominal width in bits of a single stored value (per-block headers are
    /// not included).
    pub fn bits_per_value(&self) -> usize {
        match self {
            QuantType::Q8_0 => 8,
            QuantType::Q4_0 | QuantType::Q4_1 => 4,
            QuantType::Q5_0 | QuantType::Q5_1 => 5,
            QuantType::F16 => 16,
            QuantType::F32 => 32,
        }
    }

    /// Nominal size reduction relative to 32-bit floats, i.e.
    /// `32 / bits_per_value()`.
    ///
    /// This is based on the value width only; the per-block header bytes
    /// reported by [`QuantType::bytes_per_block`] are not taken into account.
    pub fn compression_ratio(&self) -> f32 {
        32.0 / self.bits_per_value() as f32
    }

    /// Returns `true` for the formats that store values in blocks of 32.
    pub fn is_block_quantized(&self) -> bool {
        matches!(
            self,
            QuantType::Q8_0 | QuantType::Q4_0 | QuantType::Q4_1 | QuantType::Q5_0 | QuantType::Q5_1
        )
    }

    /// Parses a format name, ignoring ASCII case.
    ///
    /// Besides the canonical names (`Q8_0`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`,
    /// `F16`, `F32`) the aliases `Q8`/`INT8`, `Q4`/`INT4`, `Q5`,
    /// `FLOAT16`/`HALF` and `FLOAT32`/`FLOAT` are accepted. Returns `None`
    /// for anything else.
    pub fn parse_type(s: &str) -> Option<Self> {
        // Every accepted spelling is plain ASCII, so fold case with the
        // ASCII-only helper. Full Unicode upper-casing would turn look-alike
        // characters such as U+0131 (dotless i) or U+FB02 (fl ligature) into
        // ASCII letters and accept them as valid names.
        match s.to_ascii_uppercase().as_str() {
            "Q8_0" | "Q8" | "INT8" => Some(QuantType::Q8_0),
            "Q4_0" | "Q4" | "INT4" => Some(QuantType::Q4_0),
            "Q4_1" => Some(QuantType::Q4_1),
            "Q5_0" | "Q5" => Some(QuantType::Q5_0),
            "Q5_1" => Some(QuantType::Q5_1),
            "F16" | "FLOAT16" | "HALF" => Some(QuantType::F16),
            "F32" | "FLOAT32" | "FLOAT" => Some(QuantType::F32),
            _ => None,
        }
    }
}
/// Allows `str::parse::<QuantType>()`; accepts exactly the spellings that
/// `QuantType::parse_type` accepts.
impl std::str::FromStr for QuantType {
    /// Message naming the input that could not be recognised.
    type Err = String;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match Self::parse_type(s) {
            Some(parsed) => Ok(parsed),
            None => Err(format!("Unknown quant type: '{s}'")),
        }
    }
}
/// Prints the canonical upper-case name of the format (for example `Q4_1`).
impl fmt::Display for QuantType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            QuantType::Q8_0 => "Q8_0",
            QuantType::Q4_0 => "Q4_0",
            QuantType::Q4_1 => "Q4_1",
            QuantType::Q5_0 => "Q5_0",
            QuantType::Q5_1 => "Q5_1",
            QuantType::F16 => "F16",
            QuantType::F32 => "F32",
        };
        // The name is written verbatim; width and fill flags from the format
        // specification are not applied.
        f.write_str(label)
    }
}
/// One `Q8_0` block: a half-precision scale followed by 32 signed bytes.
#[derive(Debug, Clone)]
pub struct Q8Block {
    /// Per-block scale (how it is applied to `data` is defined by the
    /// dequantization code, which is not part of this type).
    pub scale: f16,
    /// The 32 quantized values.
    pub data: [i8; 32],
}

impl Q8Block {
    /// Builds a block from an already computed scale and payload.
    pub fn new(scale: f16, data: [i8; 32]) -> Self {
        Q8Block { scale, data }
    }

    /// Serializes the block to 34 bytes: the little-endian scale followed by
    /// each value reinterpreted as an unsigned byte.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut encoded = Vec::with_capacity(34);
        encoded.extend_from_slice(&self.scale.to_le_bytes());
        for &quant in &self.data {
            encoded.push(quant as u8);
        }
        encoded
    }

    /// Reads a block from the first 34 bytes of `bytes`.
    ///
    /// Returns `None` when fewer than 34 bytes are supplied; any bytes past
    /// the 34th are ignored.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let raw = bytes.get(..34)?;
        let (header, payload) = raw.split_at(2);

        let mut data = [0i8; 32];
        for (slot, &byte) in data.iter_mut().zip(payload) {
            *slot = byte as i8;
        }

        Some(Q8Block {
            scale: f16::from_le_bytes([header[0], header[1]]),
            data,
        })
    }
}
/// One `Q4_0` block: a half-precision scale followed by 32 signed 4-bit
/// values packed two per byte.
///
/// Each value in `-8..=7` is stored with a bias of 8 as an unsigned nibble;
/// the even-indexed value of a pair occupies the low nibble and the
/// odd-indexed value the high nibble.
#[derive(Debug, Clone)]
pub struct Q4Block {
    /// Per-block scale.
    pub scale: f16,
    /// 16 bytes holding 32 biased nibbles.
    pub data: [u8; 16],
}

impl Q4Block {
    /// Builds a block from a scale and an already packed payload.
    pub fn new(scale: f16, data: [u8; 16]) -> Self {
        Self { scale, data }
    }

    /// Expands the packed payload into 32 values in `-8..=7`.
    pub fn unpack(&self) -> [i8; 32] {
        let mut result = [0i8; 32];
        for (pair, &byte) in result.chunks_exact_mut(2).zip(self.data.iter()) {
            pair[0] = (byte & 0x0F) as i8 - 8;
            pair[1] = (byte >> 4) as i8 - 8;
        }
        result
    }

    /// Packs 32 values into 16 bytes.
    ///
    /// Values outside `-8..=7` are saturated to the nearest representable
    /// value, so `pack` followed by [`Q4Block::unpack`] returns the input
    /// clamped to that range.
    pub fn pack(values: &[i8; 32]) -> [u8; 16] {
        // Saturate before adding the bias: `value + 8` overflows `i8` for
        // large inputs, and merely masking the result would turn a slightly
        // too large positive value into a negative one.
        let nibble = |value: i8| (value.clamp(-8, 7) + 8) as u8;

        let mut data = [0u8; 16];
        for (byte, pair) in data.iter_mut().zip(values.chunks_exact(2)) {
            *byte = nibble(pair[0]) | (nibble(pair[1]) << 4);
        }
        data
    }

    /// Serializes the block to 18 bytes: the little-endian scale followed by
    /// the packed payload.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(18);
        bytes.extend_from_slice(&self.scale.to_le_bytes());
        bytes.extend_from_slice(&self.data);
        bytes
    }

    /// Reads a block from the first 18 bytes of `bytes`.
    ///
    /// Returns `None` when fewer than 18 bytes are supplied; any bytes past
    /// the 18th are ignored.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() < 18 {
            return None;
        }
        let scale = f16::from_le_bytes([bytes[0], bytes[1]]);
        let mut data = [0u8; 16];
        data.copy_from_slice(&bytes[2..18]);
        Some(Self { scale, data })
    }
}
/// One `Q4_1` block: a half-precision scale and minimum followed by 32
/// unsigned 4-bit values packed two per byte.
#[derive(Debug, Clone)]
pub struct Q4_1Block {
    /// Per-block scale.
    pub scale: f16,
    /// Per-block minimum.
    pub min: f16,
    /// 16 bytes holding 32 nibbles.
    pub data: [u8; 16],
}

impl Q4_1Block {
    /// Builds a block from its scale, minimum and already packed payload.
    pub fn new(scale: f16, min: f16, data: [u8; 16]) -> Self {
        Q4_1Block { scale, min, data }
    }

    /// Expands the payload into 32 values in `0..=15`.
    ///
    /// For every payload byte the low nibble is emitted first, then the high
    /// nibble.
    pub fn unpack(&self) -> [u8; 32] {
        let mut nibbles = [0u8; 32];
        for (pair, &byte) in nibbles.chunks_exact_mut(2).zip(self.data.iter()) {
            pair[0] = byte & 0x0F;
            pair[1] = byte >> 4;
        }
        nibbles
    }

    /// Serializes the block to 20 bytes: little-endian scale, little-endian
    /// minimum, then the packed payload.
    pub fn to_bytes(&self) -> Vec<u8> {
        let scale_bytes = self.scale.to_le_bytes();
        let min_bytes = self.min.to_le_bytes();

        let mut encoded = Vec::with_capacity(20);
        for part in [&scale_bytes[..], &min_bytes[..], &self.data[..]] {
            encoded.extend_from_slice(part);
        }
        encoded
    }
}
/// One `Q5_0` block: a half-precision scale followed by 32 signed 5-bit
/// values packed into 20 bytes.
///
/// Value `i` occupies bits `5 * i .. 5 * i + 5` of the payload, counting from
/// the least significant bit of byte 0, and is stored in two's complement.
#[derive(Debug, Clone)]
pub struct Q5Block {
    /// Per-block scale.
    pub scale: f16,
    /// 20 bytes holding 32 five-bit fields.
    pub data: [u8; 20],
}

impl Q5Block {
    /// Builds a block from a scale and an already packed payload.
    pub fn new(scale: f16, data: [u8; 20]) -> Self {
        Q5Block { scale, data }
    }

    /// Packs 32 values into 20 bytes, keeping only the low five bits of each
    /// value (so inputs outside `-16..=15` wrap).
    pub fn pack(values: &[i8; 32]) -> [u8; 20] {
        let mut packed = [0u8; 20];
        // Eight 5-bit fields fill exactly five bytes, so each group of eight
        // values can be assembled in one integer and copied out.
        for (group, out) in values.chunks_exact(8).zip(packed.chunks_exact_mut(5)) {
            let mut bits: u64 = 0;
            for (slot, &value) in group.iter().enumerate() {
                let field = u64::from((value as u8) & 0x1F);
                bits |= field << (slot * 5);
            }
            out.copy_from_slice(&bits.to_le_bytes()[..5]);
        }
        packed
    }

    /// Expands the payload into 32 sign-extended values in `-16..=15`.
    pub fn unpack(&self) -> [i8; 32] {
        let mut result = [0i8; 32];
        for (bytes, out) in self.data.chunks_exact(5).zip(result.chunks_exact_mut(8)) {
            let mut wide = [0u8; 8];
            wide[..5].copy_from_slice(bytes);
            let bits = u64::from_le_bytes(wide);
            for (slot, value) in out.iter_mut().enumerate() {
                let field = ((bits >> (slot * 5)) & 0x1F) as u8;
                // Move the field's sign bit (bit 4) up to bit 7, then shift
                // back arithmetically to sign-extend.
                *value = ((field << 3) as i8) >> 3;
            }
        }
        result
    }

    /// Serializes the block to 22 bytes: the little-endian scale followed by
    /// the packed payload.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut encoded = Vec::with_capacity(22);
        encoded.extend_from_slice(&self.scale.to_le_bytes());
        encoded.extend_from_slice(&self.data);
        encoded
    }
}
/// One `Q5_1` block: a half-precision scale and minimum followed by 32
/// unsigned 5-bit values packed into 20 bytes.
///
/// Value `i` occupies bits `5 * i .. 5 * i + 5` of the payload, counting from
/// the least significant bit of byte 0.
#[derive(Debug, Clone)]
pub struct Q5_1Block {
    /// Per-block scale.
    pub scale: f16,
    /// Per-block minimum.
    pub min: f16,
    /// 20 bytes holding 32 five-bit fields.
    pub data: [u8; 20],
}

impl Q5_1Block {
    /// Builds a block from its scale, minimum and already packed payload.
    pub fn new(scale: f16, min: f16, data: [u8; 20]) -> Self {
        Q5_1Block { scale, min, data }
    }

    /// Packs 32 values into 20 bytes, keeping only the low five bits of each
    /// value.
    pub fn pack(values: &[u8; 32]) -> [u8; 20] {
        let mut packed = [0u8; 20];
        for (index, &value) in values.iter().enumerate() {
            let start = index * 5;
            let (byte_index, shift) = (start / 8, start % 8);
            // Shift inside a 16-bit window so that bits crossing the byte
            // boundary land in `high` instead of being discarded.
            let [low, high] = (u16::from(value & 0x1F) << shift).to_le_bytes();
            packed[byte_index] |= low;
            // `high` is non-zero only when the field straddles two bytes, and
            // the final field ends exactly on the last byte boundary.
            if high != 0 {
                packed[byte_index + 1] |= high;
            }
        }
        packed
    }

    /// Expands the payload into 32 values in `0..=31`.
    pub fn unpack(&self) -> [u8; 32] {
        let mut result = [0u8; 32];
        for (index, slot) in result.iter_mut().enumerate() {
            let start = index * 5;
            let (byte_index, shift) = (start / 8, start % 8);
            let low = self.data[byte_index];
            // There is no following byte for the last field; it never needs one.
            let high = self.data.get(byte_index + 1).copied().unwrap_or(0);
            let window = u16::from_le_bytes([low, high]);
            *slot = ((window >> shift) & 0x1F) as u8;
        }
        result
    }

    /// Serializes the block to 24 bytes: little-endian scale, little-endian
    /// minimum, then the packed payload.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut encoded = Vec::with_capacity(24);
        encoded.extend_from_slice(&self.scale.to_le_bytes());
        encoded.extend_from_slice(&self.min.to_le_bytes());
        encoded.extend_from_slice(&self.data);
        encoded
    }
}
/// A unit of stored tensor data in any of the supported formats.
#[derive(Debug, Clone)]
pub enum QuantizedBlock {
    /// A `Q8_0` block.
    Q8(Q8Block),
    /// A `Q4_0` block.
    Q4(Q4Block),
    /// A `Q4_1` block.
    Q4_1(Q4_1Block),
    /// A `Q5_0` block.
    Q5(Q5Block),
    /// A `Q5_1` block.
    Q5_1(Q5_1Block),
    /// Unquantized half-precision values.
    F16(Vec<f16>),
    /// Unquantized single-precision values.
    F32(Vec<f32>),
}

impl QuantizedBlock {
    /// Returns the [`QuantType`] that corresponds to this block's variant.
    pub fn quant_type(&self) -> QuantType {
        match *self {
            Self::Q8(..) => QuantType::Q8_0,
            Self::Q4(..) => QuantType::Q4_0,
            Self::Q4_1(..) => QuantType::Q4_1,
            Self::Q5(..) => QuantType::Q5_0,
            Self::Q5_1(..) => QuantType::Q5_1,
            Self::F16(..) => QuantType::F16,
            Self::F32(..) => QuantType::F32,
        }
    }
}
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
pub shape: Vec<usize>,
pub quant_type: QuantType,
pub blocks: Vec<QuantizedBlock>,
pub numel: usize,
}
impl QuantizedTensor {
pub fn new(shape: Vec<usize>, quant_type: QuantType, blocks: Vec<QuantizedBlock>) -> Self {
let numel = shape.iter().product();
Self {
shape,
quant_type,
blocks,
numel,
}
}
pub fn size_bytes(&self) -> usize {
self.blocks.len() * self.quant_type.bytes_per_block()
}
pub fn compression_ratio(&self) -> f32 {
let original_bytes = self.numel * 4;
original_bytes as f32 / self.size_bytes() as f32
}
pub fn num_blocks(&self) -> usize {
self.blocks.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Block geometry, value width and the quantized/float split.
    #[test]
    fn test_quant_type_properties() {
        assert_eq!(QuantType::Q8_0.block_size(), 32);
        assert_eq!(QuantType::Q4_0.block_size(), 32);
        assert_eq!(QuantType::F16.block_size(), 1);
        assert_eq!(QuantType::Q8_0.bits_per_value(), 8);
        assert_eq!(QuantType::Q4_0.bits_per_value(), 4);
        assert_eq!(QuantType::Q8_0.bytes_per_block(), 34);
        assert_eq!(QuantType::Q5_1.bytes_per_block(), 24);
        assert_eq!(QuantType::Q4_0.compression_ratio(), 8.0);
        assert!(QuantType::Q8_0.is_block_quantized());
        assert!(!QuantType::F16.is_block_quantized());
    }

    /// Name parsing through both `parse_type` and the `FromStr` impl.
    #[test]
    fn test_quant_type_from_str() {
        assert_eq!(QuantType::parse_type("Q8_0"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::parse_type("INT8"), Some(QuantType::Q8_0));
        assert_eq!(QuantType::parse_type("Q4"), Some(QuantType::Q4_0));
        assert_eq!(QuantType::parse_type("F16"), Some(QuantType::F16));
        assert_eq!(QuantType::parse_type("half"), Some(QuantType::F16));
        assert_eq!(QuantType::parse_type("invalid"), None);

        assert_eq!("q5_1".parse::<QuantType>(), Ok(QuantType::Q5_1));
        assert_eq!(
            "invalid".parse::<QuantType>(),
            Err("Unknown quant type: 'invalid'".to_string())
        );
    }

    /// Every displayed name must parse back to the same variant.
    #[test]
    fn test_display_round_trips_through_parse() {
        let all = [
            QuantType::Q8_0,
            QuantType::Q4_0,
            QuantType::Q4_1,
            QuantType::Q5_0,
            QuantType::Q5_1,
            QuantType::F16,
            QuantType::F32,
        ];
        for quant_type in all {
            assert_eq!(
                QuantType::parse_type(&quant_type.to_string()),
                Some(quant_type)
            );
        }
    }

    #[test]
    fn test_q4_pack_unpack() {
        let values: [i8; 32] = [
            -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1,
            0, 1, 2, 3, 4, 5, 6, 7,
        ];
        let packed = Q4Block::pack(&values);
        let block = Q4Block::new(f16::from_f32(1.0), packed);
        let unpacked = block.unpack();
        assert_eq!(values, unpacked);
    }

    /// Uses a payload with distinct positive and negative values so that sign
    /// or ordering mistakes in the byte conversion are detected.
    #[test]
    fn test_q8_block() {
        let mut data = [0i8; 32];
        for (i, slot) in data.iter_mut().enumerate() {
            *slot = (i as i32 * 8 - 128) as i8;
        }
        let block = Q8Block::new(f16::from_f32(0.5), data);
        let bytes = block.to_bytes();
        assert_eq!(bytes.len(), 34);
        let restored = Q8Block::from_bytes(&bytes).unwrap();
        assert_eq!(block.scale, restored.scale);
        assert_eq!(block.data, restored.data);
        assert!(Q8Block::from_bytes(&bytes[..33]).is_none());
    }

    #[test]
    fn test_q5_pack_unpack() {
        let mut signed = [0i8; 32];
        for (i, slot) in signed.iter_mut().enumerate() {
            *slot = i as i8 - 16;
        }
        let block = Q5Block::new(f16::from_f32(1.0), Q5Block::pack(&signed));
        assert_eq!(block.unpack(), signed);

        let mut unsigned = [0u8; 32];
        for (i, slot) in unsigned.iter_mut().enumerate() {
            *slot = i as u8;
        }
        let block = Q5_1Block::new(
            f16::from_f32(1.0),
            f16::from_f32(0.0),
            Q5_1Block::pack(&unsigned),
        );
        assert_eq!(block.unpack(), unsigned);
    }
}