use std::fs::File;
use std::io::{BufWriter, Read, Seek, SeekFrom, Write as IoWrite};
use std::path::Path;
use crate::error::{Result, RuvLLMError};
use crate::gguf::{GgufQuantType, GGUF_MAGIC, GGUF_VERSION};
pub const ANE_ALIGNMENT: usize = 16;
pub const K_BLOCK_SIZE: usize = 256;
pub const K_SUB_BLOCK_SIZE: usize = 32;
pub const Q8_BLOCK_SIZE: usize = 32;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TargetFormat {
Q4_K_M,
Q5_K_M,
Q8_0,
F16,
PiQ3,
PiQ2,
}
impl TargetFormat {
pub fn to_gguf_type(&self) -> GgufQuantType {
match self {
TargetFormat::Q4_K_M => GgufQuantType::Q4_K,
TargetFormat::Q5_K_M => GgufQuantType::Q5_K,
TargetFormat::Q8_0 => GgufQuantType::Q8_0,
TargetFormat::F16 => GgufQuantType::F16,
TargetFormat::PiQ3 => GgufQuantType::Q4_K, TargetFormat::PiQ2 => GgufQuantType::Q2_K, }
}
pub fn bits_per_weight(&self) -> f32 {
match self {
TargetFormat::Q4_K_M => 4.5,
TargetFormat::Q5_K_M => 5.5,
TargetFormat::Q8_0 => 8.5,
TargetFormat::F16 => 16.0,
TargetFormat::PiQ3 => 3.0625, TargetFormat::PiQ2 => 2.0625, }
}
pub fn block_size(&self) -> usize {
match self {
TargetFormat::Q4_K_M | TargetFormat::Q5_K_M => K_BLOCK_SIZE,
TargetFormat::Q8_0 => Q8_BLOCK_SIZE,
TargetFormat::F16 => 1,
TargetFormat::PiQ3 => 8, TargetFormat::PiQ2 => 4, }
}
pub fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"q4_k_m" | "q4k" | "q4km" | "q4" => Some(TargetFormat::Q4_K_M),
"q5_k_m" | "q5k" | "q5km" | "q5" => Some(TargetFormat::Q5_K_M),
"q8_0" | "q8" | "q80" => Some(TargetFormat::Q8_0),
"f16" | "fp16" | "half" => Some(TargetFormat::F16),
"piq3" | "pi_q3" | "pi3" | "pi-q3" => Some(TargetFormat::PiQ3),
"piq2" | "pi_q2" | "pi2" | "pi-q2" => Some(TargetFormat::PiQ2),
_ => None,
}
}
pub fn name(&self) -> &'static str {
match self {
TargetFormat::Q4_K_M => "Q4_K_M",
TargetFormat::Q5_K_M => "Q5_K_M",
TargetFormat::Q8_0 => "Q8_0",
TargetFormat::F16 => "F16",
TargetFormat::PiQ3 => "PiQ3",
TargetFormat::PiQ2 => "PiQ2",
}
}
pub fn is_pi_quant(&self) -> bool {
matches!(self, TargetFormat::PiQ3 | TargetFormat::PiQ2)
}
pub fn raw_bits(&self) -> u8 {
match self {
TargetFormat::Q4_K_M => 4,
TargetFormat::Q5_K_M => 5,
TargetFormat::Q8_0 => 8,
TargetFormat::F16 => 16,
TargetFormat::PiQ3 => 3,
TargetFormat::PiQ2 => 2,
}
}
}
#[derive(Debug, Clone)]
pub struct QuantConfig {
pub format: TargetFormat,
pub ane_optimize: bool,
pub calibration_samples: usize,
pub keep_embed_fp16: bool,
pub keep_output_fp16: bool,
pub chunk_size: usize,
pub verbose: bool,
}
impl Default for QuantConfig {
fn default() -> Self {
Self {
format: TargetFormat::Q4_K_M,
ane_optimize: true,
calibration_samples: 128,
keep_embed_fp16: true, keep_output_fp16: true, chunk_size: 64 * 1024 * 1024, verbose: false,
}
}
}
impl QuantConfig {
pub fn with_format(mut self, format: TargetFormat) -> Self {
self.format = format;
self
}
pub fn with_ane_optimization(mut self, enable: bool) -> Self {
self.ane_optimize = enable;
self
}
pub fn with_verbose(mut self, verbose: bool) -> Self {
self.verbose = verbose;
self
}
}
#[derive(Debug, Clone)]
pub struct MemoryEstimate {
pub total_bytes: usize,
pub total_mb: f64,
pub breakdown: MemoryBreakdown,
pub compression_ratio: f64,
}
#[derive(Debug, Clone)]
pub struct MemoryBreakdown {
pub embeddings: usize,
pub attention: usize,
pub mlp: usize,
pub norms: usize,
pub output: usize,
}
pub fn estimate_memory_q4(
params_billions: f64,
vocab_size: usize,
hidden_dim: usize,
num_layers: usize,
) -> MemoryEstimate {
estimate_memory_internal(
params_billions,
vocab_size,
hidden_dim,
num_layers,
TargetFormat::Q4_K_M,
)
}
pub fn estimate_memory_q5(
params_billions: f64,
vocab_size: usize,
hidden_dim: usize,
num_layers: usize,
) -> MemoryEstimate {
estimate_memory_internal(
params_billions,
vocab_size,
hidden_dim,
num_layers,
TargetFormat::Q5_K_M,
)
}
pub fn estimate_memory_q8(
params_billions: f64,
vocab_size: usize,
hidden_dim: usize,
num_layers: usize,
) -> MemoryEstimate {
estimate_memory_internal(
params_billions,
vocab_size,
hidden_dim,
num_layers,
TargetFormat::Q8_0,
)
}
fn estimate_memory_internal(
params_billions: f64,
vocab_size: usize,
hidden_dim: usize,
num_layers: usize,
format: TargetFormat,
) -> MemoryEstimate {
let bits_per_weight = format.bits_per_weight();
let embed_params = vocab_size * hidden_dim;
let embeddings = embed_params * 2;
let attention_params = hidden_dim * hidden_dim * 4; let attention_per_layer = (attention_params as f64 * bits_per_weight as f64 / 8.0) as usize;
let attention = attention_per_layer * num_layers;
let intermediate_dim = hidden_dim * 4; let mlp_params = hidden_dim * intermediate_dim * 3;
let mlp_per_layer = (mlp_params as f64 * bits_per_weight as f64 / 8.0) as usize;
let mlp = mlp_per_layer * num_layers;
let norm_params = hidden_dim * 2 * num_layers; let norms = norm_params * 4;
let output_params = hidden_dim * vocab_size;
let output = output_params * 2;
let total_bytes = embeddings + attention + mlp + norms + output;
let total_mb = total_bytes as f64 / (1024.0 * 1024.0);
let fp32_size = params_billions * 1e9 * 4.0;
let compression_ratio = fp32_size / total_bytes as f64;
MemoryEstimate {
total_bytes,
total_mb,
breakdown: MemoryBreakdown {
embeddings,
attention,
mlp,
norms,
output,
},
compression_ratio,
}
}
#[derive(Clone)]
pub struct Q4KMBlock {
pub d: u16,
pub dmin: u16,
pub scales: [u8; 12],
pub qs: [u8; 128],
}
impl Q4KMBlock {
pub const SIZE: usize = 144;
pub const ELEMENTS: usize = 256;
pub fn new() -> Self {
Self {
d: 0,
dmin: 0,
scales: [0u8; 12],
qs: [0u8; 128],
}
}
pub fn to_bytes(&self) -> [u8; Self::SIZE] {
let mut bytes = [0u8; Self::SIZE];
bytes[0..2].copy_from_slice(&self.d.to_le_bytes());
bytes[2..4].copy_from_slice(&self.dmin.to_le_bytes());
bytes[4..16].copy_from_slice(&self.scales);
bytes[16..144].copy_from_slice(&self.qs);
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Self {
let mut block = Self::new();
block.d = u16::from_le_bytes([bytes[0], bytes[1]]);
block.dmin = u16::from_le_bytes([bytes[2], bytes[3]]);
block.scales.copy_from_slice(&bytes[4..16]);
block.qs.copy_from_slice(&bytes[16..144]);
block
}
}
impl Default for Q4KMBlock {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone)]
pub struct Q5KMBlock {
pub d: u16,
pub dmin: u16,
pub scales: [u8; 12],
pub qh: [u8; 32],
pub qs: [u8; 128],
}
impl Q5KMBlock {
pub const SIZE: usize = 176;
pub const ELEMENTS: usize = 256;
pub fn new() -> Self {
Self {
d: 0,
dmin: 0,
scales: [0u8; 12],
qh: [0u8; 32],
qs: [0u8; 128],
}
}
pub fn to_bytes(&self) -> [u8; Self::SIZE] {
let mut bytes = [0u8; Self::SIZE];
bytes[0..2].copy_from_slice(&self.d.to_le_bytes());
bytes[2..4].copy_from_slice(&self.dmin.to_le_bytes());
bytes[4..16].copy_from_slice(&self.scales);
bytes[16..48].copy_from_slice(&self.qh);
bytes[48..176].copy_from_slice(&self.qs);
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Self {
let mut block = Self::new();
block.d = u16::from_le_bytes([bytes[0], bytes[1]]);
block.dmin = u16::from_le_bytes([bytes[2], bytes[3]]);
block.scales.copy_from_slice(&bytes[4..16]);
block.qh.copy_from_slice(&bytes[16..48]);
block.qs.copy_from_slice(&bytes[48..176]);
block
}
}
impl Default for Q5KMBlock {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone)]
pub struct Q8Block {
pub d: u16,
pub qs: [i8; 32],
}
impl Q8Block {
pub const SIZE: usize = 34;
pub const ELEMENTS: usize = 32;
pub fn new() -> Self {
Self {
d: 0,
qs: [0i8; 32],
}
}
pub fn to_bytes(&self) -> [u8; Self::SIZE] {
let mut bytes = [0u8; Self::SIZE];
bytes[0..2].copy_from_slice(&self.d.to_le_bytes());
for (i, &q) in self.qs.iter().enumerate() {
bytes[2 + i] = q as u8;
}
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Self {
let mut block = Self::new();
block.d = u16::from_le_bytes([bytes[0], bytes[1]]);
for i in 0..32 {
block.qs[i] = bytes[2 + i] as i8;
}
block
}
}
impl Default for Q8Block {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct QuantProgress {
pub current_tensor: String,
pub total_tensors: usize,
pub completed_tensors: usize,
pub bytes_processed: usize,
pub total_bytes: usize,
pub eta_seconds: Option<f64>,
}
#[derive(Debug, Clone, Default)]
pub struct QuantStats {
pub tensors_quantized: usize,
pub elements_processed: usize,
pub input_bytes: usize,
pub output_bytes: usize,
pub quantization_mse: f64,
pub peak_memory: usize,
pub processing_time: f64,
}
pub fn quantize_ruvltra_q4(input: &[f32]) -> Result<Vec<Q4KMBlock>> {
if input.len() % K_BLOCK_SIZE != 0 {
return Err(RuvLLMError::Model(format!(
"Input length {} is not a multiple of block size {}",
input.len(),
K_BLOCK_SIZE
)));
}
let num_blocks = input.len() / K_BLOCK_SIZE;
let mut blocks = Vec::with_capacity(num_blocks);
for block_idx in 0..num_blocks {
let start = block_idx * K_BLOCK_SIZE;
let block_data = &input[start..start + K_BLOCK_SIZE];
blocks.push(quantize_q4_k_block(block_data));
}
Ok(blocks)
}
pub fn quantize_ruvltra_q5(input: &[f32]) -> Result<Vec<Q5KMBlock>> {
if input.len() % K_BLOCK_SIZE != 0 {
return Err(RuvLLMError::Model(format!(
"Input length {} is not a multiple of block size {}",
input.len(),
K_BLOCK_SIZE
)));
}
let num_blocks = input.len() / K_BLOCK_SIZE;
let mut blocks = Vec::with_capacity(num_blocks);
for block_idx in 0..num_blocks {
let start = block_idx * K_BLOCK_SIZE;
let block_data = &input[start..start + K_BLOCK_SIZE];
blocks.push(quantize_q5_k_block(block_data));
}
Ok(blocks)
}
pub fn quantize_ruvltra_q8(input: &[f32]) -> Result<Vec<Q8Block>> {
if input.len() % Q8_BLOCK_SIZE != 0 {
return Err(RuvLLMError::Model(format!(
"Input length {} is not a multiple of block size {}",
input.len(),
Q8_BLOCK_SIZE
)));
}
let num_blocks = input.len() / Q8_BLOCK_SIZE;
let mut blocks = Vec::with_capacity(num_blocks);
for block_idx in 0..num_blocks {
let start = block_idx * Q8_BLOCK_SIZE;
let block_data = &input[start..start + Q8_BLOCK_SIZE];
blocks.push(quantize_q8_block(block_data));
}
Ok(blocks)
}
pub fn dequantize_for_ane(blocks: &[Q4KMBlock], output: &mut [f32]) {
let mut out_idx = 0;
for block in blocks {
dequantize_q4_k_block_to_fp32(block, &mut output[out_idx..out_idx + K_BLOCK_SIZE]);
out_idx += K_BLOCK_SIZE;
}
}
fn quantize_q4_k_block(data: &[f32]) -> Q4KMBlock {
let mut block = Q4KMBlock::new();
let mut min_val = f32::MAX;
let mut max_val = f32::MIN;
for &v in data {
min_val = min_val.min(v);
max_val = max_val.max(v);
}
let range = max_val - min_val;
let d = if range > 0.0 { range / 15.0 } else { 1.0 }; let dmin = min_val;
block.d = f32_to_f16(d);
block.dmin = f32_to_f16(dmin);
for sb in 0..8 {
let sb_start = sb * K_SUB_BLOCK_SIZE;
let sb_end = sb_start + K_SUB_BLOCK_SIZE;
let sb_data = &data[sb_start..sb_end];
let mut sb_min = f32::MAX;
let mut sb_max = f32::MIN;
for &v in sb_data {
sb_min = sb_min.min(v);
sb_max = sb_max.max(v);
}
let sb_range = sb_max - sb_min;
let sb_scale = if d > 0.0 {
(sb_range / d).min(63.0) as u8
} else {
0
};
let scale_byte_idx = (sb * 6) / 8;
let scale_bit_offset = (sb * 6) % 8;
if scale_bit_offset <= 2 {
block.scales[scale_byte_idx] |= sb_scale << scale_bit_offset;
} else {
block.scales[scale_byte_idx] |= sb_scale << scale_bit_offset;
if scale_byte_idx + 1 < 12 {
block.scales[scale_byte_idx + 1] |= sb_scale >> (8 - scale_bit_offset);
}
}
let eff_d = f16_to_f32(block.d);
let eff_min = f16_to_f32(block.dmin);
for i in 0..K_SUB_BLOCK_SIZE {
let val = sb_data[i];
let q = if eff_d > 0.0 {
((val - eff_min) / eff_d).clamp(0.0, 15.0) as u8
} else {
0
};
let elem_idx = sb_start + i;
let byte_idx = elem_idx / 2;
if elem_idx % 2 == 0 {
block.qs[byte_idx] = q;
} else {
block.qs[byte_idx] |= q << 4;
}
}
}
block
}
fn quantize_q5_k_block(data: &[f32]) -> Q5KMBlock {
let mut block = Q5KMBlock::new();
let mut min_val = f32::MAX;
let mut max_val = f32::MIN;
for &v in data {
min_val = min_val.min(v);
max_val = max_val.max(v);
}
let range = max_val - min_val;
let d = if range > 0.0 { range / 31.0 } else { 1.0 }; let dmin = min_val;
block.d = f32_to_f16(d);
block.dmin = f32_to_f16(dmin);
let eff_d = f16_to_f32(block.d);
let eff_min = f16_to_f32(block.dmin);
for i in 0..K_BLOCK_SIZE {
let val = data[i];
let q = if eff_d > 0.0 {
((val - eff_min) / eff_d).clamp(0.0, 31.0) as u8
} else {
0
};
let byte_idx = i / 2;
if i % 2 == 0 {
block.qs[byte_idx] = q & 0x0F;
} else {
block.qs[byte_idx] |= (q & 0x0F) << 4;
}
let qh_byte = i / 8;
let qh_bit = i % 8;
if q & 0x10 != 0 {
block.qh[qh_byte] |= 1 << qh_bit;
}
}
block
}
fn quantize_q8_block(data: &[f32]) -> Q8Block {
let mut block = Q8Block::new();
let mut amax = 0.0f32;
for &v in data {
amax = amax.max(v.abs());
}
let d = if amax > 0.0 { amax / 127.0 } else { 1.0 };
block.d = f32_to_f16(d);
let eff_d = f16_to_f32(block.d);
for i in 0..Q8_BLOCK_SIZE {
let val = data[i];
let q = if eff_d > 0.0 {
(val / eff_d).clamp(-128.0, 127.0).round() as i8
} else {
0
};
block.qs[i] = q;
}
block
}
fn dequantize_q4_k_block_to_fp32(block: &Q4KMBlock, output: &mut [f32]) {
let d = f16_to_f32(block.d);
let dmin = f16_to_f32(block.dmin);
for sb in 0..8 {
let scale_byte_idx = (sb * 6) / 8;
let scale_bit_offset = (sb * 6) % 8;
let mut sc = (block.scales[scale_byte_idx] >> scale_bit_offset) & 0x3F;
if scale_bit_offset > 2 && scale_byte_idx + 1 < 12 {
sc |= (block.scales[scale_byte_idx + 1] << (8 - scale_bit_offset)) & 0x3F;
}
let scale = d * (sc as f32);
let sb_start = sb * K_SUB_BLOCK_SIZE;
for i in 0..K_SUB_BLOCK_SIZE {
let elem_idx = sb_start + i;
let byte_idx = elem_idx / 2;
let q = if elem_idx % 2 == 0 {
block.qs[byte_idx] & 0x0F
} else {
(block.qs[byte_idx] >> 4) & 0x0F
};
output[elem_idx] = (q as f32) * scale + dmin;
}
}
}
#[inline(always)]
fn f32_to_f16(val: f32) -> u16 {
let bits = val.to_bits();
let sign = ((bits >> 16) & 0x8000) as u16;
let exp = ((bits >> 23) & 0xFF) as i32;
let frac = bits & 0x007FFFFF;
if exp == 255 {
return sign | 0x7C00 | ((frac != 0) as u16);
}
if exp == 0 {
return sign;
}
let new_exp = exp - 127 + 15;
if new_exp >= 31 {
return sign | 0x7C00;
}
if new_exp <= 0 {
if new_exp < -10 {
return sign;
}
let new_frac = (frac | 0x00800000) >> (1 - new_exp);
return sign | ((new_frac >> 13) as u16);
}
sign | ((new_exp as u16) << 10) | ((frac >> 13) as u16)
}
#[inline(always)]
fn f16_to_f32(bits: u16) -> f32 {
let sign = ((bits & 0x8000) as u32) << 16;
let exp = ((bits >> 10) & 0x1F) as u32;
let frac = (bits & 0x03FF) as u32;
if exp == 0 {
if frac == 0 {
return f32::from_bits(sign);
}
let mut e = 1u32;
let mut f = frac;
while (f & 0x0400) == 0 {
f <<= 1;
e += 1;
}
f &= 0x03FF;
return f32::from_bits(sign | ((127 - 15 + 1 - e) << 23) | (f << 13));
}
if exp == 31 {
return f32::from_bits(sign | 0x7F80_0000 | (frac << 13));
}
f32::from_bits(sign | ((exp + 127 - 15) << 23) | (frac << 13))
}
pub struct RuvltraQuantizer {
config: QuantConfig,
stats: QuantStats,
}
impl RuvltraQuantizer {
pub fn new(config: QuantConfig) -> Result<Self> {
Ok(Self {
config,
stats: QuantStats::default(),
})
}
pub fn config(&self) -> &QuantConfig {
&self.config
}
pub fn stats(&self) -> &QuantStats {
&self.stats
}
pub fn quantize_tensor(&mut self, data: &[f32], tensor_name: &str) -> Result<Vec<u8>> {
let is_embedding = tensor_name.contains("embed") || tensor_name.contains("token");
let is_output = tensor_name.contains("lm_head") || tensor_name.contains("output");
if (self.config.keep_embed_fp16 && is_embedding)
|| (self.config.keep_output_fp16 && is_output)
{
return self.quantize_to_fp16(data);
}
let block_size = self.config.format.block_size();
let padded_len = ((data.len() + block_size - 1) / block_size) * block_size;
let mut padded_data = data.to_vec();
padded_data.resize(padded_len, 0.0);
match self.config.format {
TargetFormat::Q4_K_M => {
let blocks = quantize_ruvltra_q4(&padded_data)?;
let mut bytes = Vec::with_capacity(blocks.len() * Q4KMBlock::SIZE);
for block in blocks {
bytes.extend_from_slice(&block.to_bytes());
}
self.stats.tensors_quantized += 1;
self.stats.elements_processed += data.len();
Ok(bytes)
}
TargetFormat::Q5_K_M => {
let blocks = quantize_ruvltra_q5(&padded_data)?;
let mut bytes = Vec::with_capacity(blocks.len() * Q5KMBlock::SIZE);
for block in blocks {
bytes.extend_from_slice(&block.to_bytes());
}
self.stats.tensors_quantized += 1;
self.stats.elements_processed += data.len();
Ok(bytes)
}
TargetFormat::Q8_0 => {
let blocks = quantize_ruvltra_q8(&padded_data)?;
let mut bytes = Vec::with_capacity(blocks.len() * Q8Block::SIZE);
for block in blocks {
bytes.extend_from_slice(&block.to_bytes());
}
self.stats.tensors_quantized += 1;
self.stats.elements_processed += data.len();
Ok(bytes)
}
TargetFormat::F16 => self.quantize_to_fp16(data),
TargetFormat::PiQ3 => {
use super::pi_quant_simd::{
pi_quantize_scalar, pi_scale_adaptive, DEFAULT_K, PI3_BYTES_PER_GROUP,
PI3_VALUES_PER_GROUP,
};
let num_groups =
(padded_data.len() + PI3_VALUES_PER_GROUP - 1) / PI3_VALUES_PER_GROUP;
let mut bytes = Vec::with_capacity(num_groups * (PI3_BYTES_PER_GROUP + 2));
for chunk in padded_data.chunks(PI3_VALUES_PER_GROUP) {
let max_abs = chunk.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
let alpha = if max_abs > 1e-10 {
max_abs / 4.0
} else {
1e-10
};
let scale = pi_scale_adaptive(alpha, DEFAULT_K);
bytes.extend_from_slice(&f32_to_f16(scale).to_le_bytes());
let mut chunk_padded = [0.0f32; PI3_VALUES_PER_GROUP];
for (i, &v) in chunk.iter().enumerate() {
chunk_padded[i] = v;
}
let mut packed_bytes = [0u8; PI3_BYTES_PER_GROUP];
pi_quantize_scalar(&chunk_padded, scale, &mut packed_bytes);
bytes.extend_from_slice(&packed_bytes);
}
self.stats.tensors_quantized += 1;
self.stats.elements_processed += data.len();
Ok(bytes)
}
TargetFormat::PiQ2 => {
use super::pi_quant_simd::{pi_scale_adaptive, DEFAULT_K};
let num_groups = (padded_data.len() + 3) / 4; let mut bytes = Vec::with_capacity(num_groups * 3);
for chunk in padded_data.chunks(4) {
let max_abs = chunk.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
let alpha = if max_abs > 1e-10 {
max_abs / 2.0
} else {
1e-10
};
let scale = pi_scale_adaptive(alpha, DEFAULT_K);
bytes.extend_from_slice(&f32_to_f16(scale).to_le_bytes());
let mut packed_byte = 0u8;
let inv_scale = if scale.abs() > 1e-10 {
1.0 / scale
} else {
0.0
};
for (i, &val) in chunk.iter().take(4).enumerate() {
let quantized = (val * inv_scale).round() as i32;
let clamped = quantized.clamp(-2, 1);
let q_unsigned = (clamped + 2) as u8;
packed_byte |= (q_unsigned & 0x3) << (i * 2);
}
bytes.push(packed_byte);
}
self.stats.tensors_quantized += 1;
self.stats.elements_processed += data.len();
Ok(bytes)
}
}
}
fn quantize_to_fp16(&self, data: &[f32]) -> Result<Vec<u8>> {
let mut bytes = Vec::with_capacity(data.len() * 2);
for &v in data {
bytes.extend_from_slice(&f32_to_f16(v).to_le_bytes());
}
Ok(bytes)
}
pub fn apply_ane_layout(&self, data: &mut [u8], shape: &[usize]) -> Result<()> {
if !self.config.ane_optimize {
return Ok(());
}
if data.as_ptr() as usize % ANE_ALIGNMENT != 0 {
}
let _ = shape;
Ok(())
}
pub fn estimate_output_size(&self, input_bytes: usize) -> usize {
let input_elements = input_bytes / 4; let block_size = self.config.format.block_size();
let num_blocks = (input_elements + block_size - 1) / block_size;
match self.config.format {
TargetFormat::Q4_K_M => num_blocks * Q4KMBlock::SIZE,
TargetFormat::Q5_K_M => num_blocks * Q5KMBlock::SIZE,
TargetFormat::Q8_0 => num_blocks * Q8Block::SIZE,
TargetFormat::F16 => input_elements * 2,
TargetFormat::PiQ3 => {
use super::pi_quant_simd::{PI3_BYTES_PER_GROUP, PI3_VALUES_PER_GROUP};
let num_groups = (input_elements + PI3_VALUES_PER_GROUP - 1) / PI3_VALUES_PER_GROUP;
num_groups * (PI3_BYTES_PER_GROUP + 2) }
TargetFormat::PiQ2 => {
let num_groups = (input_elements + 3) / 4; num_groups * 3 }
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_target_format_parsing() {
assert_eq!(TargetFormat::from_str("q4_k_m"), Some(TargetFormat::Q4_K_M));
assert_eq!(TargetFormat::from_str("Q4K"), Some(TargetFormat::Q4_K_M));
assert_eq!(TargetFormat::from_str("q8"), Some(TargetFormat::Q8_0));
assert_eq!(TargetFormat::from_str("f16"), Some(TargetFormat::F16));
assert_eq!(TargetFormat::from_str("invalid"), None);
}
#[test]
fn test_memory_estimation() {
let estimate = estimate_memory_q4(0.5, 151936, 896, 24);
assert!(
estimate.total_mb > 100.0 && estimate.total_mb < 1000.0,
"Estimate should be reasonable, got {:.1}MB",
estimate.total_mb
);
let estimate_q8 = estimate_memory_q8(0.5, 151936, 896, 24);
assert!(
estimate_q8.total_mb > estimate.total_mb,
"Q8 ({:.1}MB) should be larger than Q4 ({:.1}MB)",
estimate_q8.total_mb,
estimate.total_mb
);
assert!(
estimate.compression_ratio > 1.0,
"Compression ratio should be > 1, got {:.2}",
estimate.compression_ratio
);
}
#[test]
fn test_q4_k_quantization() {
let data: Vec<f32> = (0..256).map(|i| i as f32 / 256.0).collect();
let blocks = quantize_ruvltra_q4(&data).unwrap();
assert_eq!(blocks.len(), 1);
let mut output = vec![0.0f32; 256];
dequantize_for_ane(&blocks, &mut output);
let mse: f64 = data
.iter()
.zip(output.iter())
.map(|(a, b)| ((a - b) as f64).powi(2))
.sum::<f64>()
/ 256.0;
assert!(mse < 0.01, "Quantization MSE too high: {}", mse);
}
#[test]
fn test_q8_quantization() {
let data: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) / 16.0).collect();
let blocks = quantize_ruvltra_q8(&data).unwrap();
assert_eq!(blocks.len(), 1);
assert_eq!(blocks[0].qs.len(), 32);
}
#[test]
fn test_f16_conversion_roundtrip() {
let values = [0.0f32, 1.0, -1.0, 0.5, 100.0, 0.001];
for &val in &values {
let f16 = f32_to_f16(val);
let back = f16_to_f32(f16);
let error = (val - back).abs() / val.abs().max(1.0);
assert!(
error < 0.01,
"F16 roundtrip error too high for {}: got {}",
val,
back
);
}
}
#[test]
fn test_quantizer_config() {
let config = QuantConfig::default()
.with_format(TargetFormat::Q5_K_M)
.with_ane_optimization(true)
.with_verbose(true);
assert_eq!(config.format, TargetFormat::Q5_K_M);
assert!(config.ane_optimize);
assert!(config.verbose);
}
#[test]
fn test_block_serialization() {
let mut block = Q4KMBlock::new();
block.d = 0x3C00; block.dmin = 0x0000;
block.scales[0] = 0x3F; block.qs[0] = 0x12;
let bytes = block.to_bytes();
let restored = Q4KMBlock::from_bytes(&bytes);
assert_eq!(restored.d, block.d);
assert_eq!(restored.dmin, block.dmin);
assert_eq!(restored.scales[0], block.scales[0]);
assert_eq!(restored.qs[0], block.qs[0]);
}
}