use crate::error::{RealizarError, Result};
use super::f16_to_f32_lut;
/// Number of values per small quantization block. Equals the length of
/// `Q8_0Block::quants`; presumably also the `Q4_0` block length — TODO confirm.
pub const BLOCK_SIZE: usize = 32;
/// Number of values per K-quant super-block. `InterleavedQ4K::num_values`
/// multiplies the super-block count by this.
pub const QK_K: usize = 256;
/// A `Q4_0` quantized block: one `f32` scale plus 16 bytes of quantized data.
///
/// NOTE(review): the 16 bytes presumably pack `BLOCK_SIZE` (32) 4-bit values,
/// two per byte. Nothing in this chunk encodes or decodes this type, so confirm
/// the nibble order against the producer/consumer.
#[derive(Debug, Clone)]
pub struct Q4_0Block {
/// Block-wide scale factor (assumed to multiply the dequantized nibbles — TODO confirm).
pub scale: f32,
/// Packed quantized values.
pub quants: [u8; 16],
}
/// A block of 32 values quantized to signed 8-bit integers with one shared scale.
///
/// Dequantization is `quants[i] as f32 * scale` (see [`Q8_0Block::dequantize`]).
#[derive(Debug, Clone)]
pub struct Q8_0Block {
    /// Multiplier applied to every quantized value when dequantizing.
    pub scale: f32,
    /// Quantized values, one per input element.
    pub quants: [i8; 32],
}

impl Q8_0Block {
    /// Symmetrically quantizes 32 values so the largest magnitude maps to ±127.
    ///
    /// When the peak magnitude is not above `1e-10`, the scale falls back to
    /// `1/127` instead of zero, so it is always safe to divide by. Each value is
    /// divided by the scale, rounded half away from zero, clamped to the `i8`
    /// range and cast to `i8`.
    #[must_use]
    pub fn quantize(values: &[f32; 32]) -> Self {
        let peak = values.iter().fold(0.0f32, |acc, &v| acc.max(v.abs()));
        let scale = if peak > 1e-10 { peak / 127.0 } else { 1.0 / 127.0 };
        let quants: [i8; 32] = std::array::from_fn(|i| {
            let scaled = values[i] / scale;
            scaled.round().clamp(-128.0, 127.0) as i8
        });
        Self { scale, quants }
    }

    /// Reconstructs approximate `f32` values as `quant * scale`.
    #[must_use]
    pub fn dequantize(&self) -> [f32; 32] {
        let scale = self.scale;
        self.quants.map(|q| f32::from(q) * scale)
    }

    /// Largest absolute difference between `original` and this block's
    /// dequantized values (the L∞ reconstruction error).
    #[must_use]
    pub fn quantization_error(&self, original: &[f32; 32]) -> f32 {
        let restored = self.dequantize();
        let mut worst = 0.0f32;
        for (&o, &r) in original.iter().zip(restored.iter()) {
            worst = worst.max((o - r).abs());
        }
        worst
    }

    /// [`Self::quantization_error`] divided by the peak magnitude of `original`.
    ///
    /// Returns `0.0` when that peak is below `1e-10`, which avoids dividing by
    /// a near-zero value.
    #[must_use]
    pub fn relative_error(&self, original: &[f32; 32]) -> f32 {
        let peak = original.iter().fold(0.0f32, |acc, &v| acc.max(v.abs()));
        if peak < 1e-10 {
            0.0
        } else {
            self.quantization_error(original) / peak
        }
    }
}
/// A super-block of 256 values quantized to signed 8-bit integers with a single
/// shared scale.
///
/// Dequantization is `quants[i] as f32 * scale` (see [`Q8KSuperBlock::dequantize`]).
#[derive(Debug, Clone)]
pub struct Q8KSuperBlock {
    /// Multiplier applied to every quantized value when dequantizing.
    pub scale: f32,
    /// Quantized values, one per input element.
    pub quants: [i8; 256],
}

impl Q8KSuperBlock {
    /// Symmetrically quantizes 256 values so the largest magnitude maps to ±127.
    ///
    /// Delegates to [`Self::quantize_into`], so the owned and in-place entry
    /// points share one implementation and cannot drift apart.
    #[must_use]
    pub fn quantize(values: &[f32; 256]) -> Self {
        let mut scale = 0.0f32;
        let mut quants = [0i8; 256];
        Self::quantize_into(values, &mut scale, &mut quants);
        Self { scale, quants }
    }

    /// In-place variant of [`Self::quantize`] that writes into caller-provided
    /// storage.
    ///
    /// Only the first 256 elements of `values` are read, and only the first 256
    /// elements of `quants_out` are written; any extra elements are left
    /// untouched. The scale is `max|v| / 127`, or `1/127` when no value's
    /// magnitude exceeds `1e-10`. Each value is multiplied by the reciprocal
    /// scale, rounded half away from zero, clamped to the `i8` range and cast.
    ///
    /// # Panics
    ///
    /// Panics if `values` or `quants_out` has fewer than 256 elements. Both
    /// lengths are checked before anything is written, so a panic never leaves
    /// `scale_out` or `quants_out` partially updated.
    #[inline]
    pub fn quantize_into(values: &[f32], scale_out: &mut f32, quants_out: &mut [i8]) {
        // Slicing up front enforces the length contract in release builds too
        // (a `debug_assert!` would be compiled out there). It also lets the loop
        // below run without per-element bounds checks.
        let values = &values[..256];
        let quants_out = &mut quants_out[..256];

        let max_abs = values.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        let scale = if max_abs > 1e-10 {
            max_abs / 127.0
        } else {
            1.0 / 127.0
        };
        *scale_out = scale;

        // Multiply by the reciprocal once instead of dividing per element.
        let inv_scale = 1.0 / scale;
        for (q, &v) in quants_out.iter_mut().zip(values) {
            *q = (v * inv_scale).round().clamp(-128.0, 127.0) as i8;
        }
    }

    /// Reconstructs approximate `f32` values as `quant * scale`.
    #[must_use]
    pub fn dequantize(&self) -> [f32; 256] {
        self.quants.map(|q| f32::from(q) * self.scale)
    }
}
/// A decoded `Q4_K` super-block.
///
/// Field sizes mirror the 144-byte raw layout read by `InterleavedQ4K::from_q4k`:
/// a 2-byte `d`, a 2-byte `dmin`, 12 scale bytes and 128 quant bytes. Here `d`
/// and `dmin` are held as `f32`. The 128 quant bytes presumably pack `QK_K`
/// (256) 4-bit values — TODO confirm.
#[derive(Debug, Clone)]
// The type name mirrors the `Q4_K` format identifier rather than Rust CamelCase.
#[allow(non_camel_case_types)]
pub struct Q4_KBlock {
/// Super-block `d` factor.
pub d: f32,
/// Super-block `dmin` factor (presumably scales the sub-block minimums — TODO confirm).
pub dmin: f32,
/// Packed sub-block scale bytes.
pub scales: [u8; 12],
/// Packed quantized values.
pub qs: [u8; 128],
}
/// A decoded `Q5_K` super-block.
///
/// Same header and `qs` sizes as `Q4_KBlock`, plus 32 bytes of `qh`.
/// NOTE(review): `qh` is 256 bits, presumably one extra high bit per value
/// (4 + 1 = 5 bits). Nothing in this chunk decodes it — TODO confirm.
#[derive(Debug, Clone)]
// The type name mirrors the `Q5_K` format identifier rather than Rust CamelCase.
#[allow(non_camel_case_types)]
pub struct Q5_KBlock {
/// Super-block `d` factor.
pub d: f32,
/// Super-block `dmin` factor.
pub dmin: f32,
/// Packed sub-block scale bytes.
pub scales: [u8; 12],
/// High-bit plane for the quantized values (assumed layout; see note above).
pub qh: [u8; 32],
/// Packed low bits of the quantized values.
pub qs: [u8; 128],
}
/// A decoded `Q6_K` super-block.
///
/// NOTE(review): `qh` (64 bytes = 512 bits) and `qs` (128 bytes = 1024 bits)
/// together hold 6 bits for each of 256 values, presumably 2 high + 4 low bits.
/// `scales` holds 16 signed per-sub-block scales. Nothing in this chunk decodes
/// this type — TODO confirm the layout.
#[derive(Debug, Clone)]
// The type name mirrors the `Q6_K` format identifier rather than Rust CamelCase.
#[allow(non_camel_case_types)]
pub struct Q6_KBlock {
/// Super-block `d` factor.
pub d: f32,
/// Signed sub-block scales.
pub scales: [i8; 16],
/// High-bit plane for the quantized values (assumed layout; see note above).
pub qh: [u8; 64],
/// Packed low bits of the quantized values.
pub qs: [u8; 128],
}
/// `Q4_K` super-blocks decoded into struct-of-arrays form.
///
/// Built by [`InterleavedQ4K::from_q4k`]. For every 144-byte super-block it
/// appends one entry to `d` and `dmin`, 12 bytes to `scales`, and 128 bytes to
/// `qs`. The `qs` bytes are currently kept in their original order; no
/// reinterleaving is performed, despite the type name.
#[derive(Debug, Clone)]
pub struct InterleavedQ4K {
    /// Per-super-block `d` factor, converted by `f16_to_f32_lut` from two
    /// little-endian bytes.
    pub d: Vec<f32>,
    /// Per-super-block `dmin` factor, converted the same way as `d`.
    pub dmin: Vec<f32>,
    /// Packed scale bytes, 12 per super-block, copied unchanged.
    pub scales: Vec<u8>,
    /// Quant bytes, 128 per super-block, copied unchanged.
    pub qs: Vec<u8>,
    /// Number of super-blocks parsed.
    pub num_super_blocks: usize,
}

impl InterleavedQ4K {
    /// Parses raw `Q4_K` data, a concatenation of 144-byte super-blocks.
    ///
    /// Layout of each super-block: bytes `0..2` hold `d`, bytes `2..4` hold
    /// `dmin` (each read as a little-endian `u16` and passed to
    /// `f16_to_f32_lut`), bytes `4..16` hold the 12 scale bytes, and bytes
    /// `16..144` hold the 128 quant bytes. Empty input yields zero super-blocks.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::InvalidShape` when `q4k_data.len()` is not a
    /// multiple of 144.
    pub fn from_q4k(q4k_data: &[u8]) -> Result<Self> {
        const SUPER_BLOCK_BYTES: usize = 144;

        let blocks = q4k_data.chunks_exact(SUPER_BLOCK_BYTES);
        if !blocks.remainder().is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Q4_K data length {} is not a multiple of super-block size {}",
                    q4k_data.len(),
                    SUPER_BLOCK_BYTES
                ),
            });
        }

        let num_super_blocks = blocks.len();
        let mut parsed = Self {
            d: Vec::with_capacity(num_super_blocks),
            dmin: Vec::with_capacity(num_super_blocks),
            scales: Vec::with_capacity(num_super_blocks * 12),
            qs: Vec::with_capacity(num_super_blocks * 128),
            num_super_blocks,
        };

        for block in blocks {
            // 16-byte header (d, dmin, scales) followed by 128 quant bytes.
            let (header, quants) = block.split_at(16);
            parsed
                .d
                .push(f16_to_f32_lut(u16::from_le_bytes([header[0], header[1]])));
            parsed
                .dmin
                .push(f16_to_f32_lut(u16::from_le_bytes([header[2], header[3]])));
            parsed.scales.extend_from_slice(&header[4..]);
            parsed.qs.extend_from_slice(quants);
        }

        Ok(parsed)
    }

    /// Total number of quantized values: `QK_K` per super-block.
    #[must_use]
    pub fn num_values(&self) -> usize {
        QK_K * self.num_super_blocks
    }
}
/// Dequantization counters, tagged with a SIMD backend.
///
/// The derived `Default` sets both counters to zero. The default
/// `simd_backend` is whatever `SimdBackend::default()` returns; that type comes
/// from the `simd_backend.rs` file included below. Nothing in this chunk
/// updates these counters.
#[derive(Debug, Clone, Default)]
pub struct DequantStats {
/// Block counter (presumably incremented by the dequantization routines — TODO confirm).
pub blocks_processed: u64,
/// Byte counter (presumably incremented by the dequantization routines — TODO confirm).
pub bytes_processed: u64,
/// SIMD backend associated with the recorded work.
pub simd_backend: SimdBackend,
}
include!("simd_backend.rs");