use crate::error::{AprenderError, Result};
use half::f16;
use serde::{Deserialize, Serialize};
/// Number of f32 elements covered by one quantization block (all block formats).
pub const BLOCK_SIZE: usize = 32;
/// Bytes per Q8_0 block: 2-byte little-endian f16 scale + 32 one-byte quants.
pub const Q8_0_BLOCK_BYTES: usize = 34;
/// Bytes per Q4_0 block: 2-byte f16 scale + 16 bytes of packed 4-bit quants.
pub const Q4_0_BLOCK_BYTES: usize = 18;
/// Bytes per Q4_1 block: presumably 2-byte f16 scale + 2-byte f16 min + 16
/// packed bytes — TODO confirm; no Q4_1 codec exists in this file yet.
pub const Q4_1_BLOCK_BYTES: usize = 20;
/// Supported quantization formats. The `u8` discriminants are serialized
/// format tags (see [`QuantType::from_u8`]), so they must never be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum QuantType {
    /// 8-bit block quantization: one f16 scale per 32-element block.
    Q8_0 = 0x01,
    /// 4-bit block quantization: one f16 scale per block, symmetric around 0.
    Q4_0 = 0x02,
    /// 4-bit block quantization with offset — no codec in this file yet.
    Q4_1 = 0x03,
    /// Per-tensor 8-bit quantization (single scale/zero-point for the whole tensor).
    Q8Tensor = 0x10,
    /// Caller-provided scheme handled by an external `Quantizer` implementation.
    Custom = 0xFF,
}
impl QuantType {
    /// Decodes a raw format-tag byte back into a `QuantType`.
    ///
    /// Returns `None` for unknown tags so callers can reject corrupt or
    /// future-format headers instead of panicking.
    // #[must_use] added for consistency with bits_per_weight: ignoring the
    // decoded tag is always a bug.
    #[must_use]
    pub fn from_u8(value: u8) -> Option<Self> {
        match value {
            0x01 => Some(Self::Q8_0),
            0x02 => Some(Self::Q4_0),
            0x03 => Some(Self::Q4_1),
            0x10 => Some(Self::Q8Tensor),
            0xFF => Some(Self::Custom),
            _ => None,
        }
    }

    /// Average storage cost per weight in bits, including the per-block scale
    /// overhead (e.g. Q8_0: 8 bits/quant + 16-bit scale / 32 elements = 8.5).
    ///
    /// `Custom` reports 0.0 because its layout is defined by an external quantizer.
    #[must_use]
    pub fn bits_per_weight(&self) -> f32 {
        match self {
            Self::Q8_0 => 8.5,
            Self::Q4_0 => 4.5,
            Self::Q4_1 => 5.0,
            Self::Q8Tensor => 8.0,
            Self::Custom => 0.0,
        }
    }
}
/// A block-quantized tensor: flat byte payload of fixed-size blocks, each
/// carrying its own f16 scale (layout depends on `quant_type`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedBlock {
    /// Format of the payload in `blocks`.
    pub quant_type: QuantType,
    /// Logical tensor dimensions; element count is their product.
    pub shape: Vec<usize>,
    /// Raw serialized blocks, `num_blocks() * <format>_BLOCK_BYTES` long.
    pub blocks: Vec<u8>,
    /// Elements per block; the quantizers in this file always use `BLOCK_SIZE`.
    pub block_size: usize,
}
impl QuantizedBlock {
    /// Number of quantization blocks covering the tensor; the last block may
    /// be partially filled (padded on disk).
    ///
    /// NOTE(review): ceiling division panics if `block_size` is 0 — the
    /// quantizers in this file always set `BLOCK_SIZE`, but deserialized data
    /// is not validated here. TODO confirm upstream validation.
    #[must_use]
    pub fn num_blocks(&self) -> usize {
        // Reuse num_elements() rather than recomputing the shape product.
        (self.num_elements() + self.block_size - 1) / self.block_size
    }

    /// Total number of logical f32 elements described by `shape`.
    #[must_use]
    pub fn num_elements(&self) -> usize {
        self.shape.iter().product()
    }

    /// Size of the quantized payload in bytes.
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        self.blocks.len()
    }

    /// Size the tensor would occupy unquantized as f32 (4 bytes per element).
    #[must_use]
    pub fn original_size_bytes(&self) -> usize {
        self.num_elements() * 4
    }

    /// Ratio of unquantized f32 size to quantized size; returns 1.0 for an
    /// empty payload to avoid dividing by zero.
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        if self.blocks.is_empty() {
            return 1.0;
        }
        self.original_size_bytes() as f32 / self.size_bytes() as f32
    }
}
/// A per-tensor quantization result: one scale/zero-point pair shared by all
/// elements, unlike the per-block layout of [`QuantizedBlock`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
    /// Format tag — presumably `Q8Tensor`; TODO confirm (no producer in this file).
    pub quant_type: QuantType,
    /// Logical tensor dimensions.
    pub shape: Vec<usize>,
    /// Quantized values, one i8 per element.
    pub data: Vec<i8>,
    /// Dequantization scale — presumably real = (q - zero_point) * scale;
    /// TODO confirm against the producer.
    pub scale: f32,
    /// Zero-point offset for asymmetric quantization.
    pub zero_point: i8,
}
/// Metadata describing how a tensor was quantized, for provenance/debugging.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationInfo {
    /// Quantization format that was applied.
    pub quant_type: QuantType,
    /// Calibration strategy name (defaults to "minmax"; free-form string).
    pub calibration_method: String,
    /// Number of calibration samples used (0 when no calibration ran).
    pub calibration_samples: u32,
    /// Dtype of the source tensor before quantization (e.g. "f32").
    pub original_dtype: String,
    /// Measured quantization error, if it was computed — units/metric not
    /// defined here; TODO confirm with the code that fills this in.
    pub quantization_error: Option<f32>,
}
impl Default for QuantizationInfo {
    /// Defaults: Q8_0 format, min/max calibration, zero samples, f32 source,
    /// no recorded error.
    fn default() -> Self {
        Self {
            quant_type: QuantType::Q8_0,
            calibration_method: String::from("minmax"),
            calibration_samples: 0,
            original_dtype: String::from("f32"),
            quantization_error: None,
        }
    }
}
/// Pluggable block-quantization codec.
///
/// Bounded `Send + Sync` so implementations can be shared across threads.
pub trait Quantizer: Send + Sync {
    /// Short codec identifier (e.g. "Q8_0").
    fn name(&self) -> &'static str;
    /// Quantizes `data`, whose logical shape is `shape`, into a self-describing block.
    fn quantize(&self, data: &[f32], shape: &[usize]) -> Result<QuantizedBlock>;
    /// Reconstructs approximate f32 values from a previously quantized block.
    fn dequantize(&self, block: &QuantizedBlock) -> Result<Vec<f32>>;
    /// Average bits consumed per weight, including scale overhead.
    fn bits_per_weight(&self) -> f32;
}
/// Stateless Q8_0 codec: 8-bit symmetric quantization, f16 scale per 32-element block.
#[derive(Debug, Clone, Copy, Default)]
pub struct Q8_0Quantizer;
impl Quantizer for Q8_0Quantizer {
    fn name(&self) -> &'static str {
        "Q8_0"
    }

    /// Quantizes `data` into Q8_0 blocks: per 32-element block, an f16 scale
    /// (max-abs / 127) followed by 32 signed bytes; the tail block is
    /// zero-padded to full width.
    fn quantize(&self, data: &[f32], shape: &[usize]) -> Result<QuantizedBlock> {
        let expected_len: usize = shape.iter().product();
        if data.len() != expected_len {
            return Err(AprenderError::DimensionMismatch {
                expected: expected_len.to_string(),
                actual: data.len().to_string(),
            });
        }
        let num_blocks = (data.len() + BLOCK_SIZE - 1) / BLOCK_SIZE;
        let mut blocks = Vec::with_capacity(num_blocks * Q8_0_BLOCK_BYTES);
        for chunk in data.chunks(BLOCK_SIZE) {
            // Symmetric scale so that the largest magnitude maps to +/-127.
            let max_abs = chunk.iter().fold(0.0_f32, |m, &v| m.max(v.abs()));
            let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
            let inv_scale = if scale > 0.0 { 1.0 / scale } else { 0.0 };
            blocks.extend_from_slice(&f16::from_f32(scale).to_le_bytes());
            for &v in chunk {
                let q = (v * inv_scale).round().clamp(-127.0, 127.0) as i8;
                blocks.push(q as u8);
            }
            // Pad a short tail chunk with zero quants to keep blocks fixed-size.
            blocks.resize(blocks.len() + (BLOCK_SIZE - chunk.len()), 0);
        }
        Ok(QuantizedBlock {
            quant_type: QuantType::Q8_0,
            shape: shape.to_vec(),
            blocks,
            block_size: BLOCK_SIZE,
        })
    }

    /// Reverses `quantize`: validates the payload size, then emits
    /// `quant * scale` per element, ignoring tail padding.
    fn dequantize(&self, block: &QuantizedBlock) -> Result<Vec<f32>> {
        if block.quant_type != QuantType::Q8_0 {
            return Err(AprenderError::FormatError {
                message: format!("Expected Q8_0 block, got {:?}", block.quant_type),
            });
        }
        let total_elements: usize = block.shape.iter().product();
        let num_blocks = block.num_blocks();
        if block.blocks.len() != num_blocks * Q8_0_BLOCK_BYTES {
            return Err(AprenderError::FormatError {
                message: format!(
                    "Invalid Q8_0 block data size: expected {}, got {}",
                    num_blocks * Q8_0_BLOCK_BYTES,
                    block.blocks.len()
                ),
            });
        }
        let mut result = Vec::with_capacity(total_elements);
        for (block_idx, raw) in block.blocks.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
            let scale = f16::from_le_bytes([raw[0], raw[1]]).to_f32();
            // Last block may hold fewer than BLOCK_SIZE real elements.
            let take = (total_elements - block_idx * BLOCK_SIZE).min(BLOCK_SIZE);
            result.extend(raw[2..2 + take].iter().map(|&b| f32::from(b as i8) * scale));
        }
        Ok(result)
    }

    fn bits_per_weight(&self) -> f32 {
        8.5
    }
}
/// Stateless Q4_0 codec: 4-bit symmetric quantization, two values packed per
/// byte, f16 scale per 32-element block.
#[derive(Debug, Clone, Copy, Default)]
pub struct Q4_0Quantizer;
impl Quantizer for Q4_0Quantizer {
    fn name(&self) -> &'static str {
        "Q4_0"
    }

    /// Quantizes `data` into Q4_0 blocks: per 32-element block, an f16 scale
    /// (max-abs / 7) then 16 bytes, each packing two 4-bit quants stored with
    /// a +8 bias (low nibble = even index, high nibble = odd index).
    fn quantize(&self, data: &[f32], shape: &[usize]) -> Result<QuantizedBlock> {
        let expected_len: usize = shape.iter().product();
        if data.len() != expected_len {
            return Err(AprenderError::DimensionMismatch {
                expected: expected_len.to_string(),
                actual: data.len().to_string(),
            });
        }
        // Ceiling division: tail block covers the remainder.
        let num_blocks = (data.len() + BLOCK_SIZE - 1) / BLOCK_SIZE;
        let mut blocks = Vec::with_capacity(num_blocks * Q4_0_BLOCK_BYTES);
        for block_idx in 0..num_blocks {
            let start = block_idx * BLOCK_SIZE;
            let end = (start + BLOCK_SIZE).min(data.len());
            let block_data = &data[start..end];
            // Scale chosen so max magnitude maps to +/-7 (the -8 code is only
            // reachable via rounding/clamp, not targeted by the scale).
            let max_abs = block_data.iter().map(|x| x.abs()).fold(0.0_f32, f32::max);
            let scale = if max_abs > 0.0 { max_abs / 7.0 } else { 1.0 };
            let inv_scale = if scale > 0.0 { 1.0 / scale } else { 0.0 };
            let scale_f16 = f16::from_f32(scale);
            blocks.extend_from_slice(&scale_f16.to_le_bytes());
            // Zero-pad the tail so every block packs exactly 16 bytes.
            let mut padded_data = block_data.to_vec();
            padded_data.resize(BLOCK_SIZE, 0.0);
            for i in (0..BLOCK_SIZE).step_by(2) {
                // Bias by +8 so the signed range [-8, 7] becomes unsigned [0, 15].
                let q0 = ((padded_data[i] * inv_scale).round().clamp(-8.0, 7.0) as i8 + 8) as u8;
                let q1 =
                    ((padded_data[i + 1] * inv_scale).round().clamp(-8.0, 7.0) as i8 + 8) as u8;
                let packed = (q0 & 0x0F) | ((q1 & 0x0F) << 4);
                blocks.push(packed);
            }
        }
        Ok(QuantizedBlock {
            quant_type: QuantType::Q4_0,
            shape: shape.to_vec(),
            blocks,
            block_size: BLOCK_SIZE,
        })
    }

    /// Reverses `quantize`: unpacks nibbles, removes the +8 bias, and scales;
    /// padding nibbles in the final block are skipped/truncated.
    fn dequantize(&self, block: &QuantizedBlock) -> Result<Vec<f32>> {
        if block.quant_type != QuantType::Q4_0 {
            return Err(AprenderError::FormatError {
                message: format!("Expected Q4_0 block, got {:?}", block.quant_type),
            });
        }
        let total_elements: usize = block.shape.iter().product();
        let num_blocks = block.num_blocks();
        if block.blocks.len() != num_blocks * Q4_0_BLOCK_BYTES {
            return Err(AprenderError::FormatError {
                message: format!(
                    "Invalid Q4_0 block data size: expected {}, got {}",
                    num_blocks * Q4_0_BLOCK_BYTES,
                    block.blocks.len()
                ),
            });
        }
        let mut result = Vec::with_capacity(total_elements);
        for block_idx in 0..num_blocks {
            let block_start = block_idx * Q4_0_BLOCK_BYTES;
            let scale_bytes = [block.blocks[block_start], block.blocks[block_start + 1]];
            let scale = f16::from_le_bytes(scale_bytes).to_f32();
            let quants_start = block_start + 2;
            // Last block may hold fewer than BLOCK_SIZE real elements.
            let elements_in_block = if block_idx == num_blocks - 1 {
                let remaining = total_elements % BLOCK_SIZE;
                if remaining == 0 {
                    BLOCK_SIZE
                } else {
                    remaining
                }
            } else {
                BLOCK_SIZE
            };
            // Ceil(e/2) packed bytes cover e elements.
            for i in 0..(elements_in_block + 1) / 2 {
                let packed = block.blocks[quants_start + i];
                let q0 = (packed & 0x0F) as i8 - 8;
                let q1 = ((packed >> 4) & 0x0F) as i8 - 8;
                result.push(f32::from(q0) * scale);
                // Only emit the high nibble if it is a real (non-padding) element.
                if result.len() < total_elements && (i * 2 + 1) < elements_in_block {
                    result.push(f32::from(q1) * scale);
                }
            }
        }
        // Defensive: guarantees the exact logical length even if a guard above slips.
        result.truncate(total_elements);
        Ok(result)
    }

    fn bits_per_weight(&self) -> f32 {
        4.5
    }
}
/// Dispatches to the built-in quantizer for `quant_type`.
///
/// # Errors
/// Returns `DimensionMismatch` from the underlying codec when `data` does not
/// match `shape`, and `FormatError` for formats without a built-in codec
/// (`Q4_1`, `Q8Tensor`, `Custom`).
pub fn quantize(data: &[f32], shape: &[usize], quant_type: QuantType) -> Result<QuantizedBlock> {
    // Helper to build the "no built-in codec" error uniformly.
    let unsupported = |message: String| AprenderError::FormatError { message };
    match quant_type {
        QuantType::Q8_0 => Q8_0Quantizer.quantize(data, shape),
        QuantType::Q4_0 => Q4_0Quantizer.quantize(data, shape),
        QuantType::Q4_1 => Err(unsupported("Q4_1 quantization not yet implemented".to_string())),
        QuantType::Q8Tensor => {
            Err(unsupported("Q8Tensor quantization not yet implemented".to_string()))
        }
        QuantType::Custom => Err(unsupported(
            "Custom quantization requires a custom Quantizer implementation".to_string(),
        )),
    }
}
// Sibling source files compiled into this module by textual inclusion; they
// share this module's scope (types, constants, imports above).
include!("quantize_dequant.rs");
include!("generate.rs");
include!("quantize_tests_falsification.rs");