use super::hadamard::HadamardTransform;
use super::incoherence::IncoherenceTransform;
use crate::error::{Result, RuvLLMError};
use serde::{Deserialize, Serialize};
use std::f32::consts::PI;
pub const Q2_NUM_LEVELS: usize = 4;
pub const Q2_BLOCK_SIZE: usize = 4;
pub const Q2_SUPER_BLOCK_SIZE: usize = 256;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuipConfig {
pub enable_incoherence: bool,
pub hadamard_seed: u64,
pub use_lattice_codebook: bool,
pub enable_ldlq: bool,
pub codebook: QuipCodebook,
pub per_channel: bool,
}
impl Default for QuipConfig {
fn default() -> Self {
Self {
enable_incoherence: true,
hadamard_seed: 42,
use_lattice_codebook: true,
enable_ldlq: false, codebook: QuipCodebook::E8P,
per_channel: true,
}
}
}
impl QuipConfig {
pub fn minimal() -> Self {
Self {
enable_incoherence: false,
use_lattice_codebook: false,
enable_ldlq: false,
codebook: QuipCodebook::Uniform,
..Default::default()
}
}
pub fn full() -> Self {
Self {
enable_incoherence: true,
use_lattice_codebook: true,
enable_ldlq: true,
codebook: QuipCodebook::E8P,
..Default::default()
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuipCodebook {
Uniform,
E8P,
D4,
HalfInt,
}
impl QuipCodebook {
pub fn values(&self) -> [f32; 4] {
match self {
QuipCodebook::Uniform => [-1.5, -0.5, 0.5, 1.5],
QuipCodebook::E8P => [-1.41, -0.47, 0.47, 1.41], QuipCodebook::D4 => [-1.22, -0.41, 0.41, 1.22], QuipCodebook::HalfInt => [-1.5, -0.5, 0.5, 1.5], }
}
pub fn name(&self) -> &'static str {
match self {
QuipCodebook::Uniform => "Uniform",
QuipCodebook::E8P => "E8P",
QuipCodebook::D4 => "D4",
QuipCodebook::HalfInt => "HalfInt",
}
}
}
#[derive(Clone, Debug)]
pub struct Q2QuipBlock {
pub scale: u16,
pub zero: u16,
pub packed: u8,
}
impl Q2QuipBlock {
pub const SIZE: usize = 5;
pub const ELEMENTS: usize = 4;
pub fn new() -> Self {
Self {
scale: 0,
zero: 0,
packed: 0,
}
}
#[inline]
pub fn pack(values: &[u8; 4]) -> u8 {
debug_assert!(values.iter().all(|&v| v < 4));
(values[0] & 0x03)
| ((values[1] & 0x03) << 2)
| ((values[2] & 0x03) << 4)
| ((values[3] & 0x03) << 6)
}
#[inline]
pub fn unpack(packed: u8) -> [u8; 4] {
[
packed & 0x03,
(packed >> 2) & 0x03,
(packed >> 4) & 0x03,
(packed >> 6) & 0x03,
]
}
pub fn to_bytes(&self) -> [u8; Self::SIZE] {
let mut bytes = [0u8; Self::SIZE];
bytes[0..2].copy_from_slice(&self.scale.to_le_bytes());
bytes[2..4].copy_from_slice(&self.zero.to_le_bytes());
bytes[4] = self.packed;
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Self {
Self {
scale: u16::from_le_bytes([bytes[0], bytes[1]]),
zero: u16::from_le_bytes([bytes[2], bytes[3]]),
packed: bytes[4],
}
}
}
impl Default for Q2QuipBlock {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone, Debug)]
pub struct Q2QuipSuperBlock {
pub d: u16,
pub dmin: u16,
pub sub_scales: [u8; 32],
pub packed: [u8; 64],
}
impl Q2QuipSuperBlock {
pub const SIZE: usize = 100; pub const ELEMENTS: usize = 256;
pub const NUM_SUB_BLOCKS: usize = 64;
pub const SUB_BLOCK_SIZE: usize = 4;
pub fn new() -> Self {
Self {
d: 0,
dmin: 0,
sub_scales: [0u8; 32],
packed: [0u8; 64],
}
}
pub fn to_bytes(&self) -> [u8; Self::SIZE] {
let mut bytes = [0u8; Self::SIZE];
bytes[0..2].copy_from_slice(&self.d.to_le_bytes());
bytes[2..4].copy_from_slice(&self.dmin.to_le_bytes());
bytes[4..36].copy_from_slice(&self.sub_scales);
bytes[36..100].copy_from_slice(&self.packed);
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Self {
let mut block = Self::new();
block.d = u16::from_le_bytes([bytes[0], bytes[1]]);
block.dmin = u16::from_le_bytes([bytes[2], bytes[3]]);
block.sub_scales.copy_from_slice(&bytes[4..36]);
block.packed.copy_from_slice(&bytes[36..100]);
block
}
}
impl Default for Q2QuipSuperBlock {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QuipMetadata {
pub shape: Vec<usize>,
pub padded_shape: Vec<usize>,
pub hadamard_seed: u64,
pub incoherence_applied: bool,
pub codebook: QuipCodebook,
pub channel_scales: Option<Vec<f32>>,
}
pub struct QuipQuantizer {
config: QuipConfig,
incoherence: Option<IncoherenceTransform>,
stats: QuipStats,
}
#[derive(Debug, Clone, Default)]
pub struct QuipStats {
pub tensors_quantized: usize,
pub elements_processed: usize,
pub mse: f64,
pub cosine_similarity: f64,
pub hadamard_time_us: u64,
pub quant_time_us: u64,
}
impl QuipQuantizer {
pub fn new(config: QuipConfig) -> Self {
Self {
config,
incoherence: None,
stats: QuipStats::default(),
}
}
pub fn config(&self) -> &QuipConfig {
&self.config
}
pub fn stats(&self) -> &QuipStats {
&self.stats
}
pub fn quantize(
&mut self,
data: &[f32],
shape: &[usize],
) -> Result<(Vec<Q2QuipSuperBlock>, QuipMetadata)> {
use std::time::Instant;
let total_elements: usize = shape.iter().product();
if total_elements != data.len() {
return Err(RuvLLMError::Model(format!(
"Shape {:?} implies {} elements, got {}",
shape,
total_elements,
data.len()
)));
}
let padded_len = data.len().next_power_of_two();
let mut working = data.to_vec();
working.resize(padded_len, 0.0);
let hadamard_start = Instant::now();
let incoherence_applied = if self.config.enable_incoherence {
if self.incoherence.is_none() {
let config = super::incoherence::IncoherenceConfig {
randomized: true,
seed: Some(self.config.hadamard_seed),
..Default::default()
};
self.incoherence = Some(IncoherenceTransform::new(config)?);
}
if let Some(ref mut transform) = self.incoherence {
transform.apply_before_quantization(&mut working)?;
true
} else {
false
}
} else {
false
};
self.stats.hadamard_time_us += hadamard_start.elapsed().as_micros() as u64;
let quant_start = Instant::now();
let super_block_size = Q2QuipSuperBlock::ELEMENTS;
let num_super_blocks = (padded_len + super_block_size - 1) / super_block_size;
let mut blocks = Vec::with_capacity(num_super_blocks);
for sb_idx in 0..num_super_blocks {
let start = sb_idx * super_block_size;
let end = (start + super_block_size).min(working.len());
let sb_data = &working[start..end];
blocks.push(self.quantize_super_block(sb_data)?);
}
self.stats.quant_time_us += quant_start.elapsed().as_micros() as u64;
self.stats.tensors_quantized += 1;
self.stats.elements_processed += data.len();
let metadata = QuipMetadata {
shape: shape.to_vec(),
padded_shape: vec![padded_len],
hadamard_seed: self.config.hadamard_seed,
incoherence_applied,
codebook: self.config.codebook,
channel_scales: None,
};
Ok((blocks, metadata))
}
fn quantize_super_block(&self, data: &[f32]) -> Result<Q2QuipSuperBlock> {
let mut block = Q2QuipSuperBlock::new();
let mut padded = [0.0f32; Q2QuipSuperBlock::ELEMENTS];
let copy_len = data.len().min(Q2QuipSuperBlock::ELEMENTS);
padded[..copy_len].copy_from_slice(&data[..copy_len]);
let mut min_val = f32::MAX;
let mut max_val = f32::MIN;
for &v in &padded {
min_val = min_val.min(v);
max_val = max_val.max(v);
}
let range = max_val - min_val;
let codebook = self.config.codebook.values();
let codebook_range = codebook[3] - codebook[0];
let d = if range > 1e-10 {
range / codebook_range
} else {
1.0
};
block.d = f32_to_f16(d);
block.dmin = f32_to_f16(min_val);
let eff_d = f16_to_f32(block.d);
let eff_min = f16_to_f32(block.dmin);
for sb in 0..64 {
let sb_start = sb * 4;
let sb_data = &padded[sb_start..sb_start + 4];
let mut sb_min = f32::MAX;
let mut sb_max = f32::MIN;
for &v in sb_data {
sb_min = sb_min.min(v);
sb_max = sb_max.max(v);
}
let sb_scale = if eff_d > 1e-10 {
((sb_max - sb_min) / eff_d).min(15.0) as u8
} else {
0
};
let scale_byte = sb / 2;
let scale_shift = (sb % 2) * 4;
block.sub_scales[scale_byte] |= sb_scale << scale_shift;
let mut quantized = [0u8; 4];
for i in 0..4 {
let val = sb_data[i];
let normalized = if eff_d > 1e-10 {
(val - eff_min) / eff_d
} else {
0.0
};
let q = self.nearest_codebook_entry(normalized);
quantized[i] = q;
}
block.packed[sb] = Q2QuipBlock::pack(&quantized);
}
Ok(block)
}
#[inline]
fn nearest_codebook_entry(&self, val: f32) -> u8 {
let codebook = self.config.codebook.values();
let mut min_dist = f32::MAX;
let mut best_idx = 0u8;
for (i, &cb_val) in codebook.iter().enumerate() {
let dist = (val - cb_val).abs();
if dist < min_dist {
min_dist = dist;
best_idx = i as u8;
}
}
best_idx
}
pub fn dequantize(
&mut self,
blocks: &[Q2QuipSuperBlock],
metadata: &QuipMetadata,
) -> Result<Vec<f32>> {
let padded_len = metadata.padded_shape.iter().product();
let mut output = vec![0.0f32; padded_len];
for (sb_idx, block) in blocks.iter().enumerate() {
let start = sb_idx * Q2QuipSuperBlock::ELEMENTS;
let end = (start + Q2QuipSuperBlock::ELEMENTS).min(padded_len);
self.dequantize_super_block(block, &mut output[start..end]);
}
if metadata.incoherence_applied {
let config = super::incoherence::IncoherenceConfig {
randomized: true,
seed: Some(metadata.hadamard_seed),
..Default::default()
};
let mut transform = IncoherenceTransform::new(config)?;
let original_len: usize = metadata.shape.iter().product();
transform.restore_after_dequantization(&mut output, Some(original_len))?;
}
let original_len: usize = metadata.shape.iter().product();
output.truncate(original_len);
Ok(output)
}
fn dequantize_super_block(&self, block: &Q2QuipSuperBlock, output: &mut [f32]) {
let d = f16_to_f32(block.d);
let dmin = f16_to_f32(block.dmin);
let codebook = self.config.codebook.values();
for sb in 0..64 {
let sb_start = sb * 4;
if sb_start >= output.len() {
break;
}
let scale_byte = sb / 2;
let scale_shift = (sb % 2) * 4;
let _sb_scale = (block.sub_scales[scale_byte] >> scale_shift) & 0x0F;
let quantized = Q2QuipBlock::unpack(block.packed[sb]);
for i in 0..4 {
let elem_idx = sb_start + i;
if elem_idx < output.len() {
let cb_val = codebook[quantized[i] as usize];
output[elem_idx] = dmin + cb_val * d;
}
}
}
}
pub fn compute_metrics(&mut self, original: &[f32], restored: &[f32]) {
if original.len() != restored.len() || original.is_empty() {
return;
}
let n = original.len() as f64;
let mse: f64 = original
.iter()
.zip(restored.iter())
.map(|(a, b)| ((a - b) as f64).powi(2))
.sum::<f64>()
/ n;
self.stats.mse = mse;
let dot: f64 = original
.iter()
.zip(restored.iter())
.map(|(a, b)| (*a as f64) * (*b as f64))
.sum();
let norm_a: f64 = original
.iter()
.map(|a| (*a as f64).powi(2))
.sum::<f64>()
.sqrt();
let norm_b: f64 = restored
.iter()
.map(|b| (*b as f64).powi(2))
.sum::<f64>()
.sqrt();
self.stats.cosine_similarity = if norm_a > 0.0 && norm_b > 0.0 {
dot / (norm_a * norm_b)
} else {
0.0
};
}
}
#[inline(always)]
fn f32_to_f16(val: f32) -> u16 {
let bits = val.to_bits();
let sign = ((bits >> 16) & 0x8000) as u16;
let exp = ((bits >> 23) & 0xFF) as i32;
let frac = bits & 0x007FFFFF;
if exp == 255 {
return sign | 0x7C00 | ((frac != 0) as u16);
}
if exp == 0 {
return sign;
}
let new_exp = exp - 127 + 15;
if new_exp >= 31 {
return sign | 0x7C00;
}
if new_exp <= 0 {
if new_exp < -10 {
return sign;
}
let new_frac = (frac | 0x00800000) >> (1 - new_exp);
return sign | ((new_frac >> 13) as u16);
}
sign | ((new_exp as u16) << 10) | ((frac >> 13) as u16)
}
#[inline(always)]
fn f16_to_f32(bits: u16) -> f32 {
let sign = ((bits & 0x8000) as u32) << 16;
let exp = ((bits >> 10) & 0x1F) as u32;
let frac = (bits & 0x03FF) as u32;
if exp == 0 {
if frac == 0 {
return f32::from_bits(sign);
}
let mut e = 1u32;
let mut f = frac;
while (f & 0x0400) == 0 {
f <<= 1;
e += 1;
}
f &= 0x03FF;
return f32::from_bits(sign | ((127 - 15 + 1 - e) << 23) | (f << 13));
}
if exp == 31 {
return f32::from_bits(sign | 0x7F80_0000 | (frac << 13));
}
f32::from_bits(sign | ((exp + 127 - 15) << 23) | (frac << 13))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_quip_config() {
let config = QuipConfig::default();
assert!(config.enable_incoherence);
assert!(config.use_lattice_codebook);
assert_eq!(config.codebook, QuipCodebook::E8P);
let minimal = QuipConfig::minimal();
assert!(!minimal.enable_incoherence);
assert!(!minimal.use_lattice_codebook);
}
#[test]
fn test_codebook_values() {
let uniform = QuipCodebook::Uniform.values();
assert_eq!(uniform.len(), 4);
assert!(uniform[0] < uniform[1]);
assert!(uniform[1] < uniform[2]);
assert!(uniform[2] < uniform[3]);
let e8p = QuipCodebook::E8P.values();
assert_eq!(e8p.len(), 4);
}
#[test]
fn test_q2_pack_unpack() {
let values = [0, 1, 2, 3];
let packed = Q2QuipBlock::pack(&values);
let unpacked = Q2QuipBlock::unpack(packed);
assert_eq!(values, unpacked);
for v0 in 0..4 {
for v1 in 0..4 {
for v2 in 0..4 {
for v3 in 0..4 {
let vals = [v0, v1, v2, v3];
let packed = Q2QuipBlock::pack(&vals);
let unpacked = Q2QuipBlock::unpack(packed);
assert_eq!(vals, unpacked);
}
}
}
}
}
#[test]
fn test_q2_block_serialization() {
let mut block = Q2QuipBlock::new();
block.scale = 0x3C00;
block.zero = 0x0000;
block.packed = 0xE4;
let bytes = block.to_bytes();
let restored = Q2QuipBlock::from_bytes(&bytes);
assert_eq!(restored.scale, block.scale);
assert_eq!(restored.zero, block.zero);
assert_eq!(restored.packed, block.packed);
}
#[test]
fn test_quip_quantize_dequantize() {
let config = QuipConfig::minimal();
let mut quantizer = QuipQuantizer::new(config);
let data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 128.0).collect();
let shape = vec![256];
let (blocks, metadata) = quantizer.quantize(&data, &shape).unwrap();
let restored = quantizer.dequantize(&blocks, &metadata).unwrap();
assert_eq!(restored.len(), data.len());
let mse: f64 = data
.iter()
.zip(restored.iter())
.map(|(a, b)| ((a - b) as f64).powi(2))
.sum::<f64>()
/ data.len() as f64;
assert!(mse < 1.0, "MSE too high: {}", mse);
}
#[test]
fn test_quip_with_incoherence() {
let config = QuipConfig::default();
let mut quantizer = QuipQuantizer::new(config);
let data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 128.0).collect();
let shape = vec![256];
let result = quantizer.quantize(&data, &shape);
assert!(result.is_ok());
let (blocks, metadata) = result.unwrap();
assert!(metadata.incoherence_applied);
let restored = quantizer.dequantize(&blocks, &metadata).unwrap();
assert_eq!(restored.len(), data.len());
}
#[test]
fn test_super_block_serialization() {
let mut block = Q2QuipSuperBlock::new();
block.d = 0x3C00;
block.dmin = 0x0000;
block.sub_scales[0] = 0xAB;
block.packed[0] = 0xCD;
let bytes = block.to_bytes();
let restored = Q2QuipSuperBlock::from_bytes(&bytes);
assert_eq!(restored.d, block.d);
assert_eq!(restored.dmin, block.dmin);
assert_eq!(restored.sub_scales[0], block.sub_scales[0]);
assert_eq!(restored.packed[0], block.packed[0]);
}
}