use axonml_tensor::Tensor;
use half::f16;
use rayon::prelude::*;
use crate::DEFAULT_BLOCK_SIZE;
use crate::error::QuantResult;
use crate::types::{
Q4_1Block, Q4Block, Q5_1Block, Q5Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor,
};
pub fn quantize_tensor(
tensor: &Tensor<f32>,
quant_type: QuantType,
) -> QuantResult<QuantizedTensor> {
let data = tensor.to_vec();
let shape = tensor.shape().to_vec();
match quant_type {
QuantType::Q8_0 => quantize_q8_0(&data, shape),
QuantType::Q4_0 => quantize_q4_0(&data, shape),
QuantType::Q4_1 => quantize_q4_1(&data, shape),
QuantType::Q5_0 => quantize_q5_0(&data, shape),
QuantType::Q5_1 => quantize_q5_1(&data, shape),
QuantType::F16 => quantize_f16(&data, shape),
QuantType::F32 => quantize_f32(&data, shape),
}
}
/// Quantizes a collection of named tensors to one format using rayon.
///
/// Every `(name, tensor)` pair is passed to [`quantize_tensor`]; successful
/// results are returned as owned `(String, QuantizedTensor)` pairs.
///
/// # Errors
///
/// Returns an error if quantizing any of the tensors fails.
pub fn quantize_model(
    tensors: &[(&str, &Tensor<f32>)],
    quant_type: QuantType,
) -> QuantResult<Vec<(String, QuantizedTensor)>> {
    tensors
        .par_iter()
        .map(|&(name, tensor)| {
            quantize_tensor(tensor, quant_type).map(|quantized| (name.to_owned(), quantized))
        })
        .collect()
}
fn quantize_q8_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
let block_size = DEFAULT_BLOCK_SIZE;
let n_blocks = data.len().div_ceil(block_size);
let blocks: Vec<QuantizedBlock> = (0..n_blocks)
.into_par_iter()
.map(|block_idx| {
let start = block_idx * block_size;
let end = (start + block_size).min(data.len());
let block_data = &data[start..end];
let max_abs = block_data
.iter()
.map(|x| x.abs())
.fold(0.0f32, |a, b| a.max(b));
let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
let mut quantized = [0i8; 32];
for (i, &val) in block_data.iter().enumerate() {
let q = (val / scale).round().clamp(-127.0, 127.0) as i8;
quantized[i] = q;
}
QuantizedBlock::Q8(Q8Block::new(f16::from_f32(scale), quantized))
})
.collect();
Ok(QuantizedTensor::new(shape, QuantType::Q8_0, blocks))
}
fn quantize_q4_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
let block_size = DEFAULT_BLOCK_SIZE;
let n_blocks = data.len().div_ceil(block_size);
let blocks: Vec<QuantizedBlock> = (0..n_blocks)
.into_par_iter()
.map(|block_idx| {
let start = block_idx * block_size;
let end = (start + block_size).min(data.len());
let block_data = &data[start..end];
let max_abs = block_data
.iter()
.map(|x| x.abs())
.fold(0.0f32, |a, b| a.max(b));
let scale = if max_abs > 0.0 { max_abs / 7.0 } else { 1.0 };
let mut quantized = [0i8; 32];
for (i, &val) in block_data.iter().enumerate() {
let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
quantized[i] = q;
}
let packed = Q4Block::pack(&quantized);
QuantizedBlock::Q4(Q4Block::new(f16::from_f32(scale), packed))
})
.collect();
Ok(QuantizedTensor::new(shape, QuantType::Q4_0, blocks))
}
/// Quantizes `data` to Q4_1: unsigned 4-bit levels with an `f16` scale and
/// `f16` minimum per block.
///
/// Blocks of `DEFAULT_BLOCK_SIZE` values are processed in parallel. For each
/// block the scale is `(max - min) / 15` (or `1.0` when all values are equal)
/// and every value is rounded to `(x - min) / scale`, clamped to `[0, 15]`.
/// Two levels are packed per byte: the even-indexed level in the low nibble
/// and the odd-indexed level in the high nibble. Slots past the end of a
/// short trailing block stay zero.
fn quantize_q4_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let block_size = DEFAULT_BLOCK_SIZE;
    let n_blocks = data.len().div_ceil(block_size);
    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
        .into_par_iter()
        .map(|block_idx| {
            let start = block_idx * block_size;
            let end = (start + block_size).min(data.len());
            let block_data = &data[start..end];

            let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
            let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let scale = if max > min { (max - min) / 15.0 } else { 1.0 };

            // NOTE(review): the buffer is fixed at 32 entries, so this assumes
            // DEFAULT_BLOCK_SIZE <= 32 — confirm against the constant.
            let mut quantized = [0u8; 32];
            for (i, &val) in block_data.iter().enumerate() {
                quantized[i] = ((val - min) / scale).round().clamp(0.0, 15.0) as u8;
            }

            // Pack every pair from the zero-initialised buffer. Bounding the
            // loop by `block_data.len() / 2` would drop the final value of an
            // odd-length trailing block; unused slots are already zero, so
            // packing all 16 bytes is safe.
            let mut packed = [0u8; 16];
            for (byte, pair) in packed.iter_mut().zip(quantized.chunks_exact(2)) {
                *byte = (pair[0] & 0x0F) | ((pair[1] & 0x0F) << 4);
            }

            QuantizedBlock::Q4_1(Q4_1Block::new(
                f16::from_f32(scale),
                f16::from_f32(min),
                packed,
            ))
        })
        .collect();
    Ok(QuantizedTensor::new(shape, QuantType::Q4_1, blocks))
}
fn quantize_q5_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
let block_size = DEFAULT_BLOCK_SIZE;
let n_blocks = data.len().div_ceil(block_size);
let blocks: Vec<QuantizedBlock> = (0..n_blocks)
.into_par_iter()
.map(|block_idx| {
let start = block_idx * block_size;
let end = (start + block_size).min(data.len());
let block_data = &data[start..end];
let max_abs = block_data
.iter()
.map(|x| x.abs())
.fold(0.0f32, |a, b| a.max(b));
let scale = if max_abs > 0.0 { max_abs / 15.0 } else { 1.0 };
let mut quantized = [0i8; 32];
for (i, &val) in block_data.iter().enumerate() {
let q = (val / scale).round().clamp(-16.0, 15.0) as i8;
quantized[i] = q;
}
let packed = Q5Block::pack(&quantized);
QuantizedBlock::Q5(Q5Block::new(f16::from_f32(scale), packed))
})
.collect();
Ok(QuantizedTensor::new(shape, QuantType::Q5_0, blocks))
}
fn quantize_q5_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
let block_size = DEFAULT_BLOCK_SIZE;
let n_blocks = data.len().div_ceil(block_size);
let blocks: Vec<QuantizedBlock> = (0..n_blocks)
.into_par_iter()
.map(|block_idx| {
let start = block_idx * block_size;
let end = (start + block_size).min(data.len());
let block_data = &data[start..end];
let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let scale = if max > min { (max - min) / 31.0 } else { 1.0 };
let mut quantized = [0u8; 32];
for (i, &val) in block_data.iter().enumerate() {
let q = ((val - min) / scale).round().clamp(0.0, 31.0) as u8;
quantized[i] = q;
}
let packed = Q5_1Block::pack(&quantized);
QuantizedBlock::Q5_1(Q5_1Block::new(
f16::from_f32(scale),
f16::from_f32(min),
packed,
))
})
.collect();
Ok(QuantizedTensor::new(shape, QuantType::Q5_1, blocks))
}
/// Converts `data` to half precision and stores it as a single `F16` block.
///
/// Each element is converted with `f16::from_f32` on rayon's parallel iterator.
fn quantize_f16(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let halves = data
        .par_iter()
        .map(|value| f16::from_f32(*value))
        .collect::<Vec<f16>>();
    Ok(QuantizedTensor::new(
        shape,
        QuantType::F16,
        vec![QuantizedBlock::F16(halves)],
    ))
}
/// Stores `data` unchanged as a single `F32` block (a copy of the input slice).
fn quantize_f32(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let passthrough = QuantizedBlock::F32(data.to_owned());
    Ok(QuantizedTensor::new(
        shape,
        QuantType::F32,
        vec![passthrough],
    ))
}
/// Returns the root-mean-square error between `original` and `dequantized`.
///
/// Returns `f32::INFINITY` when the slices differ in length or are empty,
/// since no meaningful error can be computed in either case.
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.is_empty() {
        return f32::INFINITY;
    }
    if original.len() != dequantized.len() {
        return f32::INFINITY;
    }

    // Accumulate squared differences in input order.
    let mut squared_sum = 0.0f32;
    for (reference, restored) in original.iter().zip(dequantized) {
        squared_sum += (reference - restored).powi(2);
    }

    let mean_squared = squared_sum / original.len() as f32;
    mean_squared.sqrt()
}
/// Summary of how closely dequantized values match the originals.
///
/// Produced by `compute_quantization_stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizationStats {
    /// Root-mean-square of the element-wise errors.
    pub rmse: f32,
    /// Largest absolute element-wise error.
    pub max_error: f32,
    /// Mean of the absolute element-wise errors.
    pub mean_error: f32,
    /// Compression ratio reported by the quantization type.
    pub compression_ratio: f32,
}
/// Computes RMSE, maximum absolute error and mean absolute error between
/// `original` and `dequantized`, plus the compression ratio of `quant_type`.
///
/// Values are compared pairwise. If the slices differ in length, only the
/// overlapping prefix is compared (unlike `compute_quantization_error`, which
/// rejects mismatched lengths).
///
/// When there is nothing to compare (either slice is empty) the three error
/// metrics are reported as `f32::INFINITY`, matching
/// `compute_quantization_error`, instead of the `NaN` a division by zero
/// would produce.
pub fn compute_quantization_stats(
    original: &[f32],
    dequantized: &[f32],
    quant_type: QuantType,
) -> QuantizationStats {
    let compression_ratio = quant_type.compression_ratio();

    // Number of pairs `zip` will yield.
    let pairs = original.len().min(dequantized.len());
    if pairs == 0 {
        return QuantizationStats {
            rmse: f32::INFINITY,
            max_error: f32::INFINITY,
            mean_error: f32::INFINITY,
            compression_ratio,
        };
    }

    // Single pass: no intermediate error vector is needed.
    let mut sum_squared = 0.0f32;
    let mut sum_abs = 0.0f32;
    let mut max_error = 0.0f32;
    for (a, b) in original.iter().zip(dequantized) {
        let err = (a - b).abs();
        sum_squared += err.powi(2);
        sum_abs += err;
        max_error = max_error.max(err);
    }

    let count = pairs as f32;
    QuantizationStats {
        rmse: (sum_squared / count).sqrt(),
        max_error,
        mean_error: sum_abs / count,
        compression_ratio,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A short input (fewer values than one block) still yields exactly one block.
    #[test]
    fn test_quantize_q8_0() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(data, &[8]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        assert_eq!(quantized.quant_type, QuantType::Q8_0);
        assert_eq!(quantized.shape, vec![8]);
        assert_eq!(quantized.num_blocks(), 1);
    }

    // 64 values are expected to split into two blocks.
    #[test]
    fn test_quantize_q4_0() {
        let data: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(data, &[64]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();
        assert_eq!(quantized.quant_type, QuantType::Q4_0);
        assert_eq!(quantized.num_blocks(), 2);
    }

    // The F16 path reports the F16 quantization type.
    #[test]
    fn test_quantize_f16() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_vec(data, &[4]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::F16).unwrap();
        assert_eq!(quantized.quant_type, QuantType::F16);
    }

    // Q4_0 should compress more than Q8_0, and Q8_0 more than 2x.
    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..256).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(data, &[256]).unwrap();
        let q8 = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let q4 = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();
        assert!(q8.compression_ratio() > 2.0);
        assert!(q4.compression_ratio() > q8.compression_ratio());
    }

    // Small perturbations give a small, non-zero RMSE.
    #[test]
    fn test_quantization_error() {
        let original = vec![1.0, 2.0, 3.0, 4.0];
        let dequantized = vec![1.1, 2.0, 2.9, 4.1];
        let rmse = compute_quantization_error(&original, &dequantized);
        assert!(rmse > 0.0);
        assert!(rmse < 0.2);
    }
}