use std::collections::HashMap;
use std::io::{self, Read, Seek, SeekFrom};
use std::path::Path;
use crate::gguf::{self, GGMLType, GGUFFile, GGUFTensorInfo};
/// Storage for a single tensor as held by [`ModelWeightsRaw`].
///
/// `load_from_file_mixed` keeps Q8_0 and Q4_0 tensors as their raw file bytes and
/// converts every other supported tensor type to `f32`.
#[derive(Debug, Clone)]
pub enum WeightData {
/// Values already expanded to `f32`.
F32(Vec<f32>),
/// Raw Q8_0 bytes; the helpers in this file treat 34 bytes as one block of 32 elements.
Q8_0Raw(Vec<u8>),
/// Raw Q4_0 bytes; the helpers in this file treat 18 bytes as one block of 32 elements.
Q4_0Raw(Vec<u8>),
}
/// Collection of tensors in mixed representation (see [`WeightData`]).
#[derive(Debug, Clone)]
pub struct ModelWeightsRaw {
/// Tensor data keyed by tensor name (the loaders in this file insert the names
/// produced by `gguf_name_to_hf`).
pub tensors: HashMap<String, WeightData>,
}
impl ModelWeightsRaw {
    /// Returns the `f32` values stored under `name`.
    ///
    /// Yields `None` when the tensor is absent or is stored in a quantized form.
    pub fn get_f32(&self, name: &str) -> Option<&[f32]> {
        if let Some(WeightData::F32(values)) = self.tensors.get(name) {
            return Some(&values[..]);
        }
        None
    }

    /// Returns the raw Q8_0 bytes stored under `name`.
    ///
    /// Yields `None` when the tensor is absent or is not stored as Q8_0.
    pub fn get_q8_raw(&self, name: &str) -> Option<&[u8]> {
        if let Some(WeightData::Q8_0Raw(bytes)) = self.tensors.get(name) {
            return Some(&bytes[..]);
        }
        None
    }

    /// Returns the raw Q4_0 bytes stored under `name`.
    ///
    /// Yields `None` when the tensor is absent or is not stored as Q4_0.
    pub fn get_q4_raw(&self, name: &str) -> Option<&[u8]> {
        if let Some(WeightData::Q4_0Raw(bytes)) = self.tensors.get(name) {
            return Some(&bytes[..]);
        }
        None
    }

    /// Number of tensors in the collection.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// `true` when the collection holds no tensors.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

    /// Total payload size in bytes: four bytes per `f32` element plus the raw byte
    /// length of every quantized tensor.
    pub fn memory_bytes(&self) -> usize {
        let mut total = 0usize;
        for data in self.tensors.values() {
            total += match data {
                WeightData::F32(values) => values.len() * 4,
                WeightData::Q8_0Raw(bytes) | WeightData::Q4_0Raw(bytes) => bytes.len(),
            };
        }
        total
    }

    /// Returns the tensor stored under `name` in whatever representation it has.
    pub fn get(&self, name: &str) -> Option<&WeightData> {
        self.tensors.get(name)
    }
}
/// Loads a GGUF file, keeping Q8_0 and Q4_0 tensors as raw bytes and dequantizing all
/// other supported tensor types to `f32`.
///
/// Tensor names are translated with `gguf_name_to_hf`, and fused QKV / gate-up tensors
/// are split afterwards when the file metadata provides the needed dimensions.
///
/// # Errors
///
/// * `WeightLoadError::Io` when the file cannot be opened, mapped or parsed, or when a
///   tensor's byte range (including ranges whose offset arithmetic would overflow) does
///   not lie inside the file.
/// * `WeightLoadError::UnsupportedType` when a tensor uses a type `dequantize` does not
///   handle.
pub fn load_from_file_mixed(
    path: impl AsRef<std::path::Path>,
) -> Result<(crate::gguf::GGUFFile, ModelWeightsRaw), WeightLoadError> {
    let file = std::fs::File::open(path.as_ref())?;
    // SAFETY: the mapping is only read, and only inside this function; every tensor is
    // copied into owned buffers before returning. As with any file-backed mapping this
    // relies on no other process truncating or rewriting the file while it is mapped.
    let mmap = unsafe { memmap2::Mmap::map(&file)? };
    let mut cursor = std::io::Cursor::new(&mmap[..]);
    let gguf = crate::gguf::parse(&mut cursor)
        .map_err(|e| WeightLoadError::Io(std::io::Error::other(e)))?;
    let mut tensors = HashMap::with_capacity(gguf.tensors.len());
    for tensor_info in &gguf.tensors {
        let data_size = tensor_info.data_size() as usize;
        let numel = tensor_info.numel() as usize;
        // Header values come from the file and are untrusted: use checked arithmetic so
        // a corrupt offset/size is reported as an error instead of wrapping and panicking
        // on the slice below.
        let span = gguf
            .tensor_data_offset
            .checked_add(tensor_info.offset)
            // An offset inside the mapping always fits in `usize`.
            .filter(|&offset| offset <= mmap.len() as u64)
            .map(|offset| offset as usize)
            .and_then(|start| Some((start, start.checked_add(data_size)?)))
            .filter(|&(_, end)| end <= mmap.len());
        let Some((start, end)) = span else {
            return Err(WeightLoadError::Io(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                format!("tensor {} extends past end of file", tensor_info.name),
            )));
        };
        let raw = &mmap[start..end];
        let hf_name = gguf_name_to_hf(&tensor_info.name);
        let weight_data = if tensor_info.ggml_type == crate::gguf::GGMLType::Q8_0 {
            WeightData::Q8_0Raw(raw.to_vec())
        } else if tensor_info.ggml_type == crate::gguf::GGMLType::Q4_0 {
            WeightData::Q4_0Raw(raw.to_vec())
        } else {
            WeightData::F32(dequantize(raw, tensor_info.ggml_type, numel)?)
        };
        tensors.insert(hf_name, weight_data);
    }
    let mut weights = ModelWeightsRaw { tensors };
    auto_split_fused_tensors_raw(&mut weights, &gguf);
    Ok((gguf, weights))
}
/// Collection of tensors where every tensor is stored as `f32` values.
#[derive(Debug, Clone)]
pub struct ModelWeights {
/// Tensor values keyed by tensor name (the loaders in this file insert the names
/// produced by `gguf_name_to_hf`).
pub tensors: HashMap<String, Vec<f32>>,
}
impl ModelWeights {
    /// Returns the values stored under `name`, or `None` when no such tensor exists.
    pub fn get(&self, name: &str) -> Option<&[f32]> {
        match self.tensors.get(name) {
            Some(values) => Some(values.as_slice()),
            None => None,
        }
    }

    /// Returns the values stored under `name`.
    ///
    /// # Panics
    ///
    /// Panics when no tensor with that name exists.
    pub fn tensor(&self, name: &str) -> &[f32] {
        match self.tensors.get(name) {
            Some(values) => values.as_slice(),
            None => panic!("weight not found: {name}"),
        }
    }

    /// Number of tensors in the collection.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// `true` when the collection holds no tensors.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

    /// Sum of the element counts of all tensors.
    pub fn total_elements(&self) -> usize {
        self.tensors.values().map(Vec::len).sum()
    }

    /// Total payload size in bytes (four bytes per element).
    pub fn memory_bytes(&self) -> usize {
        let elements = self.total_elements();
        elements * 4
    }
}
/// Adds `1.0` to every element of the per-layer `input_layernorm` and
/// `post_attention_layernorm` weights for layers `0..num_layers`, and of the final
/// `model.norm.weight`.
///
/// Tensors that are not present are skipped. `_hidden_size` is accepted but unused.
pub fn apply_gemma_weight_tweaks(
    weights: &mut ModelWeights,
    _hidden_size: usize,
    num_layers: usize,
) {
    let layer_norms = (0..num_layers).flat_map(|layer| {
        ["input_layernorm", "post_attention_layernorm"]
            .map(|norm| format!("model.layers.{layer}.{norm}.weight"))
    });
    let final_norm = std::iter::once(String::from("model.norm.weight"));
    for name in layer_norms.chain(final_norm) {
        if let Some(values) = weights.tensors.get_mut(&name) {
            values.iter_mut().for_each(|v| *v += 1.0);
        }
    }
}
/// Errors returned by the weight-loading functions in this file.
#[derive(Debug, thiserror::Error)]
pub enum WeightLoadError {
/// An I/O failure. The loaders in this file also use this variant for GGUF parse
/// failures and for tensors whose data extends past the end of the file.
#[error("I/O error: {0}")]
Io(#[from] io::Error),
/// No tensor with the requested name exists in the GGUF file.
#[error("tensor not found in GGUF: {0}")]
TensorNotFound(String),
/// The tensor's GGML type has no dequantization routine in `dequantize`.
#[error("unsupported GGML type for dequantization: {0:?}")]
UnsupportedType(GGMLType),
}
/// Reads and dequantizes every tensor listed in `gguf` from `reader`.
///
/// Tensor names are translated with `gguf_name_to_hf`; fused QKV / gate-up tensors are
/// split afterwards when the file metadata provides the needed dimensions.
///
/// # Errors
///
/// Returns the first error produced by `load_tensor`.
pub fn load_all<R: Read + Seek>(
    reader: &mut R,
    gguf: &GGUFFile,
) -> Result<ModelWeights, WeightLoadError> {
    let mut by_name = HashMap::with_capacity(gguf.tensors.len());
    for info in gguf.tensors.iter() {
        let values = load_tensor(reader, gguf, info)?;
        by_name.insert(gguf_name_to_hf(&info.name), values);
    }
    let mut loaded = ModelWeights { tensors: by_name };
    auto_split_fused_tensors_f32(&mut loaded, gguf);
    Ok(loaded)
}
/// Loads a GGUF file and dequantizes every tensor to `f32`.
///
/// Tensor names are translated with `gguf_name_to_hf`, and fused QKV / gate-up tensors
/// are split afterwards when the file metadata provides the needed dimensions.
///
/// # Errors
///
/// * `WeightLoadError::Io` when the file cannot be opened, mapped or parsed, or when a
///   tensor's byte range (including ranges whose offset arithmetic would overflow) does
///   not lie inside the file.
/// * `WeightLoadError::UnsupportedType` when a tensor uses a type `dequantize` does not
///   handle.
pub fn load_from_file(path: impl AsRef<Path>) -> Result<(GGUFFile, ModelWeights), WeightLoadError> {
    let file = std::fs::File::open(path.as_ref())?;
    // SAFETY: the mapping is only read, and only inside this function; every tensor is
    // converted into an owned `Vec<f32>` before returning. As with any file-backed
    // mapping this relies on no other process truncating or rewriting the file while it
    // is mapped.
    let mmap = unsafe { memmap2::Mmap::map(&file)? };
    let mut cursor = io::Cursor::new(&mmap[..]);
    let gguf = gguf::parse(&mut cursor).map_err(|e| WeightLoadError::Io(io::Error::other(e)))?;
    let mut tensors = HashMap::with_capacity(gguf.tensors.len());
    for tensor_info in &gguf.tensors {
        let data_size = tensor_info.data_size() as usize;
        let numel = tensor_info.numel() as usize;
        // Header values come from the file and are untrusted: use checked arithmetic so
        // a corrupt offset/size is reported as an error instead of wrapping and panicking
        // on the slice below.
        let span = gguf
            .tensor_data_offset
            .checked_add(tensor_info.offset)
            // An offset inside the mapping always fits in `usize`.
            .filter(|&offset| offset <= mmap.len() as u64)
            .map(|offset| offset as usize)
            .and_then(|start| Some((start, start.checked_add(data_size)?)))
            .filter(|&(_, end)| end <= mmap.len());
        let Some((start, end)) = span else {
            return Err(WeightLoadError::Io(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                format!("tensor {} extends past end of file", tensor_info.name),
            )));
        };
        let raw = &mmap[start..end];
        let data = dequantize(raw, tensor_info.ggml_type, numel)?;
        let hf_name = gguf_name_to_hf(&tensor_info.name);
        tensors.insert(hf_name, data);
    }
    let mut weights = ModelWeights { tensors };
    auto_split_fused_tensors_f32(&mut weights, &gguf);
    Ok((gguf, weights))
}
/// Translates a GGUF tensor name into the `model.*` / `lm_head.*` naming used as keys
/// by the weight collections in this file.
///
/// Three top-level names are mapped explicitly. Names of the form `blk.<layer>.<suffix>`
/// become `model.layers.<layer>.<mapped suffix>`, where unknown suffixes are passed
/// through unchanged. Any other name is returned as-is.
fn gguf_name_to_hf(name: &str) -> String {
    let top_level = match name {
        "token_embd.weight" => Some("model.embed_tokens.weight"),
        "output_norm.weight" => Some("model.norm.weight"),
        "output.weight" => Some("lm_head.weight"),
        _ => None,
    };
    if let Some(mapped) = top_level {
        return mapped.to_string();
    }

    // `blk.<layer>.<suffix>`: split at the first dot after the prefix.
    let block_parts = name
        .strip_prefix("blk.")
        .and_then(|rest| rest.split_once('.'));
    let Some((layer, tail)) = block_parts else {
        return name.to_string();
    };

    let mapped_tail = match tail {
        "attn_norm.weight" => "input_layernorm.weight",
        "attn_q.weight" => "self_attn.q_proj.weight",
        "attn_k.weight" => "self_attn.k_proj.weight",
        "attn_v.weight" => "self_attn.v_proj.weight",
        "attn_q.bias" => "self_attn.q_proj.bias",
        "attn_k.bias" => "self_attn.k_proj.bias",
        "attn_v.bias" => "self_attn.v_proj.bias",
        "attn_output.weight" => "self_attn.o_proj.weight",
        "attn_qkv.weight" => "self_attn.qkv_proj.weight",
        "ffn_norm.weight" => "post_attention_layernorm.weight",
        "ffn_gate.weight" => "mlp.gate_proj.weight",
        "ffn_up.weight" => "mlp.up_proj.weight",
        "ffn_down.weight" => "mlp.down_proj.weight",
        unknown => unknown,
    };
    format!("model.layers.{layer}.{mapped_tail}")
}
/// Splits fused QKV / gate-up tensors in `weights` when layer 0 looks fused.
///
/// Layer 0 counts as fused when it has a `qkv_proj` weight, or when it has an `up_proj`
/// weight but no `gate_proj` weight. Nothing happens if neither holds, or if the model
/// dimensions cannot be read from the GGUF metadata.
fn auto_split_fused_tensors_raw(weights: &mut ModelWeightsRaw, gguf: &GGUFFile) {
    let present = |key: &str| weights.tensors.contains_key(key);
    let qkv_is_fused = present("model.layers.0.self_attn.qkv_proj.weight");
    let ffn_is_fused = !present("model.layers.0.mlp.gate_proj.weight")
        && present("model.layers.0.mlp.up_proj.weight");
    if !(qkv_is_fused || ffn_is_fused) {
        return;
    }
    let Some(dims) = fused_params_from_gguf(gguf) else {
        return;
    };
    split_fused_tensors(
        weights,
        dims.num_layers,
        dims.hidden_size,
        dims.num_heads,
        dims.num_kv_heads,
        dims.head_dim,
        dims.intermediate_size,
    );
}
/// Splits fused QKV / gate-up tensors in an all-`f32` collection when layer 0 looks
/// fused.
///
/// Layer 0 counts as fused when it has a `qkv_proj` weight, or when it has an `up_proj`
/// weight but no `gate_proj` weight. Nothing happens if neither holds, or if the model
/// dimensions cannot be read from the GGUF metadata.
fn auto_split_fused_tensors_f32(weights: &mut ModelWeights, gguf: &GGUFFile) {
    let present = |key: &str| weights.tensors.contains_key(key);
    let qkv_is_fused = present("model.layers.0.self_attn.qkv_proj.weight");
    let ffn_is_fused = !present("model.layers.0.mlp.gate_proj.weight")
        && present("model.layers.0.mlp.up_proj.weight");
    if !(qkv_is_fused || ffn_is_fused) {
        return;
    }
    let Some(dims) = fused_params_from_gguf(gguf) else {
        return;
    };
    split_fused_tensors_f32(
        weights,
        dims.num_layers,
        dims.hidden_size,
        dims.num_heads,
        dims.num_kv_heads,
        dims.head_dim,
        dims.intermediate_size,
    );
}
/// Model dimensions read from GGUF metadata by `fused_params_from_gguf` and passed to
/// the fused-tensor splitting functions.
struct FusedParams {
// From `<arch>.block_count`.
num_layers: usize,
// From `<arch>.embedding_length`.
hidden_size: usize,
// From `<arch>.attention.head_count`.
num_heads: usize,
// From `<arch>.attention.head_count_kv`; falls back to `num_heads`.
num_kv_heads: usize,
// Per-head dimension as computed by `fused_params_from_gguf`.
head_dim: usize,
// From `<arch>.feed_forward_length`; falls back to `hidden_size * 4`.
intermediate_size: usize,
}
/// Reads the model dimensions needed to split fused tensors from GGUF metadata.
///
/// Keys are looked up under the architecture prefix given by `general.architecture`.
/// Returns `None` when the architecture, `block_count`, `embedding_length` or
/// `attention.head_count` is missing, or when `attention.head_count` is zero (which
/// would otherwise cause a division by zero when deriving the head dimension).
///
/// Fallbacks for optional keys:
/// * `attention.head_count_kv` defaults to the head count;
/// * `feed_forward_length` defaults to `4 * embedding_length`;
/// * the head dimension uses `attention.key_length` when present and non-zero,
///   otherwise `embedding_length / head_count`.
fn fused_params_from_gguf(gguf: &GGUFFile) -> Option<FusedParams> {
    let arch = gguf.get_str("general.architecture")?.to_string();
    let lookup = |key: &str| gguf.get_u32(&format!("{arch}.{key}")).map(|v| v as usize);

    let num_layers = lookup("block_count")?;
    let hidden_size = lookup("embedding_length")?;
    let num_heads = lookup("attention.head_count")?;
    if num_heads == 0 {
        // Corrupt metadata: no meaningful head dimension can be derived.
        return None;
    }
    let num_kv_heads = lookup("attention.head_count_kv").unwrap_or(num_heads);
    let intermediate_size = lookup("feed_forward_length").unwrap_or(hidden_size * 4);
    // Some architectures use a head size that is not `hidden_size / num_heads`; the
    // GGUF metadata records it explicitly as `attention.key_length` in that case.
    let head_dim = lookup("attention.key_length")
        .filter(|&dim| dim != 0)
        .unwrap_or(hidden_size / num_heads);

    Some(FusedParams {
        num_layers,
        hidden_size,
        num_heads,
        num_kv_heads,
        head_dim,
        intermediate_size,
    })
}
/// Replaces fused tensors in `weights` with their separate parts, layer by layer.
///
/// * A `self_attn.qkv_proj.weight` tensor is removed and split into `q_proj`, `k_proj`
///   and `v_proj` weights of `hidden_size * num_heads * head_dim`,
///   `hidden_size * num_kv_heads * head_dim` and again
///   `hidden_size * num_kv_heads * head_dim` elements.
/// * When a layer has no `mlp.gate_proj.weight` and its `mlp.up_proj.weight` holds
///   exactly `2 * hidden_size * intermediate_size` elements, that tensor is split into
///   a gate half followed by an up half.
///
/// # Panics
///
/// Panics (inside `split_weight_three`) when a fused QKV tensor does not have the
/// expected size.
pub fn split_fused_tensors(
    weights: &mut ModelWeightsRaw,
    num_layers: usize,
    hidden_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    intermediate_size: usize,
) {
    for layer in 0..num_layers {
        let attn = format!("model.layers.{layer}.self_attn");
        if let Some(fused_qkv) = weights.tensors.remove(&format!("{attn}.qkv_proj.weight")) {
            let q_len = hidden_size * num_heads * head_dim;
            let kv_len = hidden_size * num_kv_heads * head_dim;
            let (q, k, v) = split_weight_three(&fused_qkv, q_len, kv_len, kv_len);
            for (proj, part) in ["q_proj", "k_proj", "v_proj"].iter().zip([q, k, v]) {
                weights
                    .tensors
                    .insert(format!("{attn}.{proj}.weight"), part);
            }
        }

        let mlp = format!("model.layers.{layer}.mlp");
        let gate_key = format!("{mlp}.gate_proj.weight");
        if weights.tensors.contains_key(&gate_key) {
            continue;
        }
        let up_key = format!("{mlp}.up_proj.weight");
        let up_len = weights.tensors.get(&up_key).map(weight_elem_count);
        let half = hidden_size * intermediate_size;
        if up_len != Some(2 * half) {
            continue;
        }
        if let Some(fused_ffn) = weights.tensors.remove(&up_key) {
            let (gate, up) = split_weight_two(&fused_ffn, half, half);
            weights.tensors.insert(gate_key, gate);
            weights.tensors.insert(up_key, up);
        }
    }
}
/// Replaces fused tensors in an all-`f32` collection with their separate parts, layer
/// by layer.
///
/// * A `self_attn.qkv_proj.weight` tensor is removed and split into `q_proj`, `k_proj`
///   and `v_proj` weights of `hidden_size * num_heads * head_dim`,
///   `hidden_size * num_kv_heads * head_dim` and again
///   `hidden_size * num_kv_heads * head_dim` elements.
/// * When a layer has no `mlp.gate_proj.weight` and its `mlp.up_proj.weight` holds
///   exactly `2 * hidden_size * intermediate_size` elements, that tensor is split into
///   a gate half followed by an up half.
///
/// # Panics
///
/// Panics when a fused QKV tensor does not have the expected number of elements.
pub fn split_fused_tensors_f32(
    weights: &mut ModelWeights,
    num_layers: usize,
    hidden_size: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    intermediate_size: usize,
) {
    for layer in 0..num_layers {
        let attn = format!("model.layers.{layer}.self_attn");
        if let Some(fused_qkv) = weights.tensors.remove(&format!("{attn}.qkv_proj.weight")) {
            let q_len = hidden_size * num_heads * head_dim;
            let kv_len = hidden_size * num_kv_heads * head_dim;
            assert_eq!(fused_qkv.len(), q_len + 2 * kv_len);
            let (q, rest) = fused_qkv.split_at(q_len);
            let (k, v) = rest.split_at(kv_len);
            for (proj, part) in ["q_proj", "k_proj", "v_proj"].iter().zip([q, k, v]) {
                weights
                    .tensors
                    .insert(format!("{attn}.{proj}.weight"), part.to_vec());
            }
        }

        let mlp = format!("model.layers.{layer}.mlp");
        let gate_key = format!("{mlp}.gate_proj.weight");
        if weights.tensors.contains_key(&gate_key) {
            continue;
        }
        let up_key = format!("{mlp}.up_proj.weight");
        let up_len = weights.tensors.get(&up_key).map(Vec::len);
        let half = hidden_size * intermediate_size;
        if up_len != Some(2 * half) {
            continue;
        }
        if let Some(fused_ffn) = weights.tensors.remove(&up_key) {
            let (gate, up) = fused_ffn.split_at(half);
            weights.tensors.insert(gate_key, gate.to_vec());
            weights.tensors.insert(up_key, up.to_vec());
        }
    }
}
/// Number of elements represented by `w`.
///
/// Quantized data is counted in whole blocks of 32 elements (34 bytes per block for
/// Q8_0, 18 bytes per block for Q4_0); trailing bytes that do not fill a block are
/// ignored.
fn weight_elem_count(w: &WeightData) -> usize {
    let (byte_len, bytes_per_block) = match w {
        WeightData::F32(values) => return values.len(),
        WeightData::Q8_0Raw(bytes) => (bytes.len(), 34),
        WeightData::Q4_0Raw(bytes) => (bytes.len(), 18),
    };
    byte_len / bytes_per_block * 32
}
/// Splits `w` into three consecutive parts of `n1`, `n2` and `n3` elements, keeping the
/// representation of the input.
///
/// For quantized data the element counts are converted to byte counts in whole blocks
/// of 32 elements (34 bytes per Q8_0 block, 18 bytes per Q4_0 block).
///
/// # Panics
///
/// Panics when the three parts do not add up to exactly the length of the data.
fn split_weight_three(
    w: &WeightData,
    n1: usize,
    n2: usize,
    n3: usize,
) -> (WeightData, WeightData, WeightData) {
    // Cuts `buf` into three owned pieces after checking the total length.
    fn carve<T: Clone>(buf: &[T], a: usize, b: usize, c: usize) -> (Vec<T>, Vec<T>, Vec<T>) {
        assert_eq!(buf.len(), a + b + c);
        let (first, rest) = buf.split_at(a);
        let (second, third) = rest.split_at(b);
        (first.to_vec(), second.to_vec(), third.to_vec())
    }
    let byte_len = |elems: usize, bytes_per_block: usize| elems / 32 * bytes_per_block;

    match w {
        WeightData::F32(values) => {
            let (a, b, c) = carve(&values[..], n1, n2, n3);
            (WeightData::F32(a), WeightData::F32(b), WeightData::F32(c))
        }
        WeightData::Q8_0Raw(bytes) => {
            let (a, b, c) = carve(
                &bytes[..],
                byte_len(n1, 34),
                byte_len(n2, 34),
                byte_len(n3, 34),
            );
            (
                WeightData::Q8_0Raw(a),
                WeightData::Q8_0Raw(b),
                WeightData::Q8_0Raw(c),
            )
        }
        WeightData::Q4_0Raw(bytes) => {
            let (a, b, c) = carve(
                &bytes[..],
                byte_len(n1, 18),
                byte_len(n2, 18),
                byte_len(n3, 18),
            );
            (
                WeightData::Q4_0Raw(a),
                WeightData::Q4_0Raw(b),
                WeightData::Q4_0Raw(c),
            )
        }
    }
}
/// Splits `w` into two consecutive parts of `n1` and `n2` elements, keeping the
/// representation of the input.
///
/// For quantized data the element counts are converted to byte counts in whole blocks
/// of 32 elements (34 bytes per Q8_0 block, 18 bytes per Q4_0 block).
///
/// # Panics
///
/// Panics when the two parts do not add up to exactly the length of the data.
fn split_weight_two(w: &WeightData, n1: usize, n2: usize) -> (WeightData, WeightData) {
    // Cuts `buf` into two owned pieces after checking the total length.
    fn halve<T: Clone>(buf: &[T], a: usize, b: usize) -> (Vec<T>, Vec<T>) {
        assert_eq!(buf.len(), a + b);
        let (first, second) = buf.split_at(a);
        (first.to_vec(), second.to_vec())
    }
    let byte_len = |elems: usize, bytes_per_block: usize| elems / 32 * bytes_per_block;

    match w {
        WeightData::F32(values) => {
            let (a, b) = halve(&values[..], n1, n2);
            (WeightData::F32(a), WeightData::F32(b))
        }
        WeightData::Q8_0Raw(bytes) => {
            let (a, b) = halve(&bytes[..], byte_len(n1, 34), byte_len(n2, 34));
            (WeightData::Q8_0Raw(a), WeightData::Q8_0Raw(b))
        }
        WeightData::Q4_0Raw(bytes) => {
            let (a, b) = halve(&bytes[..], byte_len(n1, 18), byte_len(n2, 18));
            (WeightData::Q4_0Raw(a), WeightData::Q4_0Raw(b))
        }
    }
}
pub fn load_tensor<R: Read + Seek>(
reader: &mut R,
gguf: &GGUFFile,
tensor_info: &GGUFTensorInfo,
) -> Result<Vec<f32>, WeightLoadError> {
let data_offset = gguf.tensor_data_offset + tensor_info.offset;
let data_size = tensor_info.data_size() as usize;
let numel = tensor_info.numel() as usize;
reader.seek(SeekFrom::Start(data_offset))?;
let mut raw = vec![0u8; data_size];
reader.read_exact(&mut raw)?;
dequantize(&raw, tensor_info.ggml_type, numel)
}
/// Looks up the tensor called `name` in `gguf` and loads it as `f32` values.
///
/// # Errors
///
/// Returns `WeightLoadError::TensorNotFound` when the file has no tensor with that
/// name; otherwise any error from `load_tensor`.
pub fn load_tensor_by_name<R: Read + Seek>(
    reader: &mut R,
    gguf: &GGUFFile,
    name: &str,
) -> Result<Vec<f32>, WeightLoadError> {
    match gguf.tensor(name) {
        Some(info) => load_tensor(reader, gguf, info),
        None => Err(WeightLoadError::TensorNotFound(name.to_string())),
    }
}
/// Public entry point for Q8_0 dequantization: expands `raw` Q8_0 block bytes into
/// `numel` `f32` values by delegating to `dequant_q8_0`.
pub fn dequantize_q8_0_to_f32(raw: &[u8], numel: usize) -> Vec<f32> {
dequant_q8_0(raw, numel)
}
fn dequantize(data: &[u8], ggml_type: GGMLType, numel: usize) -> Result<Vec<f32>, WeightLoadError> {
match ggml_type {
GGMLType::F32 => Ok(dequant_f32(data, numel)),
GGMLType::F16 => Ok(dequant_f16(data, numel)),
GGMLType::BF16 => Ok(dequant_bf16(data, numel)),
GGMLType::Q8_0 => Ok(dequant_q8_0(data, numel)),
GGMLType::Q4_0 => Ok(dequant_q4_0(data, numel)),
GGMLType::Q4_1 => Ok(dequant_q4_1(data, numel)),
GGMLType::Q6K => Ok(dequant_q6_k(data, numel)),
GGMLType::Q5K => Ok(dequant_q5_k(data, numel)),
GGMLType::Q4K => Ok(dequant_q4_k(data, numel)),
GGMLType::Q8K => Ok(dequant_q8_k(data, numel)),
GGMLType::Q3K => Ok(dequant_q3_k(data, numel)),
GGMLType::Q2K => Ok(dequant_q2_k(data, numel)),
other => Err(WeightLoadError::UnsupportedType(other)),
}
}
/// Decodes little-endian `f32` values from `data`.
///
/// The result always has `numel` entries; entries for which `data` has no complete
/// 4-byte group stay `0.0`, and surplus bytes are ignored.
fn dequant_f32(data: &[u8], numel: usize) -> Vec<f32> {
    let mut values = vec![0.0f32; numel];
    for (slot, word) in values.iter_mut().zip(data.chunks_exact(4)) {
        *slot = f32::from_le_bytes([word[0], word[1], word[2], word[3]]);
    }
    values
}
/// Decodes little-endian half-precision values from `data` via `f16_to_f32`.
///
/// The result always has `numel` entries; entries for which `data` has no complete
/// 2-byte group stay `0.0`, and surplus bytes are ignored.
fn dequant_f16(data: &[u8], numel: usize) -> Vec<f32> {
    let mut values = vec![0.0f32; numel];
    for (slot, pair) in values.iter_mut().zip(data.chunks_exact(2)) {
        *slot = f16_to_f32(u16::from_le_bytes([pair[0], pair[1]]));
    }
    values
}
/// Decodes little-endian bfloat16 values from `data` by placing each 16-bit pattern in
/// the upper half of an `f32` bit pattern.
///
/// The result always has `numel` entries; entries for which `data` has no complete
/// 2-byte group stay `0.0`, and surplus bytes are ignored.
fn dequant_bf16(data: &[u8], numel: usize) -> Vec<f32> {
    let mut values = vec![0.0f32; numel];
    for (slot, pair) in values.iter_mut().zip(data.chunks_exact(2)) {
        let upper = u32::from(u16::from_le_bytes([pair[0], pair[1]]));
        *slot = f32::from_bits(upper << 16);
    }
    values
}
/// Expands Q8_0 blocks into `f32` values.
///
/// Each 34-byte block is read as a little-endian half-precision scale followed by 32
/// signed 8-bit quants; every output is `quant * scale`. The result always has `numel`
/// entries; outputs whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q8_0(data: &[u8], numel: usize) -> Vec<f32> {
    const ELEMS_PER_BLOCK: usize = 32;
    const BYTES_PER_BLOCK: usize = 34;
    let mut values = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    // The last output chunk may be shorter than a block; zipping stops writing there.
    for (block, out) in blocks.zip(values.chunks_mut(ELEMS_PER_BLOCK)) {
        let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        for (slot, &quant) in out.iter_mut().zip(&block[2..]) {
            *slot = (quant as i8) as f32 * scale;
        }
    }
    values
}
/// Expands Q4_0 blocks into `f32` values.
///
/// Each 18-byte block is read as a little-endian half-precision scale followed by 16
/// bytes of packed 4-bit quants. For byte `j`, the low nibble produces output `j` and
/// the high nibble output `j + 16` of the block, each as `(nibble - 8) * scale`. The
/// result always has `numel` entries; outputs whose block is missing or incomplete in
/// `data` stay `0.0`.
fn dequant_q4_0(data: &[u8], numel: usize) -> Vec<f32> {
    const ELEMS_PER_BLOCK: usize = 32;
    const BYTES_PER_BLOCK: usize = 18;
    let mut values = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    for (block, out) in blocks.zip(values.chunks_mut(ELEMS_PER_BLOCK)) {
        let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        for (j, &packed) in block[2..].iter().enumerate() {
            let low = i32::from(packed & 0x0F) - 8;
            let high = i32::from(packed >> 4) - 8;
            // `out` may be a short final chunk, so either half can fall outside it.
            if let Some(slot) = out.get_mut(j) {
                *slot = low as f32 * scale;
            }
            if let Some(slot) = out.get_mut(j + 16) {
                *slot = high as f32 * scale;
            }
        }
    }
    values
}
/// Expands Q4_1 blocks into `f32` values.
///
/// Each 20-byte block is read as a little-endian half-precision scale, a half-precision
/// minimum, and 16 bytes of packed 4-bit quants. For byte `j`, the low nibble produces
/// output `j` and the high nibble output `j + 16` of the block, each as
/// `nibble * scale + min`. The result always has `numel` entries; outputs whose block
/// is missing or incomplete in `data` stay `0.0`.
fn dequant_q4_1(data: &[u8], numel: usize) -> Vec<f32> {
    const ELEMS_PER_BLOCK: usize = 32;
    const BYTES_PER_BLOCK: usize = 20;
    let mut values = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    for (block, out) in blocks.zip(values.chunks_mut(ELEMS_PER_BLOCK)) {
        let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let min = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
        for (j, &packed) in block[4..].iter().enumerate() {
            let low = f32::from(packed & 0x0F);
            let high = f32::from(packed >> 4);
            // `out` may be a short final chunk, so either half can fall outside it.
            if let Some(slot) = out.get_mut(j) {
                *slot = low * scale + min;
            }
            if let Some(slot) = out.get_mut(j + 16) {
                *slot = high * scale + min;
            }
        }
    }
    values
}
/// Extracts the `j`-th 6-bit (scale, min) pair from the packed scale bytes used by the
/// Q4_K and Q5_K routines in this file.
///
/// For `j < 4` both values are the low six bits of `scales[j]` and `scales[j + 4]`.
/// For larger `j` each value combines a nibble of `scales[j + 4]` with the top two bits
/// of `scales[j - 4]` (scale) or `scales[j]` (min).
///
/// # Panics
///
/// Panics when an accessed index is outside `scales`.
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    match j {
        0..=3 => (scales[j] & 63, scales[j + 4] & 63),
        _ => {
            let nibbles = scales[j + 4];
            let scale_top = scales[j - 4] >> 6;
            let min_top = scales[j] >> 6;
            (
                (nibbles & 0x0F) | (scale_top << 4),
                (nibbles >> 4) | (min_top << 4),
            )
        }
    }
}
/// Expands Q6_K super-blocks into `f32` values.
///
/// Each 210-byte block covers 256 outputs and is read as: 128 bytes of low nibbles
/// (`ql`), 64 bytes of packed 2-bit high parts (`qh`), 16 signed 8-bit scales, and a
/// little-endian half-precision block scale `d`. Every output is
/// `d * scale * (6-bit quant - 32)`. The result always has `numel` entries; outputs
/// whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q6_k(data: &[u8], numel: usize) -> Vec<f32> {
let mut output = vec![0.0f32; numel];
let block_size = 256;
let type_size = 210;
let num_blocks = numel.div_ceil(block_size);
for block_idx in 0..num_blocks {
let bs = block_idx * type_size;
// Stop at the first block that is not fully present in `data`.
if bs + type_size > data.len() {
break;
}
let ql = &data[bs..bs + 128];
let qh = &data[bs + 128..bs + 192];
let scales = &data[bs + 192..bs + 208];
let d_bits = u16::from_le_bytes([data[bs + 208], data[bs + 209]]);
let d = f16_to_f32(d_bits);
let out_base = block_idx * block_size;
// Two halves of 128 outputs; each half uses 64 `ql` bytes, 32 `qh` bytes, 8 scales.
for n_off in (0..block_size).step_by(128) {
let ql_off = n_off / 2; let qh_off = n_off / 4; let sc_off = n_off / 16;
for l in 0..32 {
// Scale sub-index: lanes 0..16 use one scale, lanes 16..32 the next.
let is = l / 16;
// Each `qh` byte supplies the two high bits of four quants (bit pairs 0, 2, 4, 6).
let q1 = ((ql[ql_off + l] & 0x0F) | ((qh[qh_off + l] & 3) << 4)) as i32 - 32;
let q2 =
((ql[ql_off + l + 32] & 0x0F) | (((qh[qh_off + l] >> 2) & 3) << 4)) as i32 - 32;
let q3 = ((ql[ql_off + l] >> 4) | (((qh[qh_off + l] >> 4) & 3) << 4)) as i32 - 32;
let q4 =
((ql[ql_off + l + 32] >> 4) | (((qh[qh_off + l] >> 6) & 3) << 4)) as i32 - 32;
let sc1 = scales[sc_off + is] as i8;
let sc2 = scales[sc_off + is + 2] as i8;
let sc3 = scales[sc_off + is + 4] as i8;
let sc4 = scales[sc_off + is + 6] as i8;
// The four quants land 32 outputs apart; each write is guarded against `numel`.
let out_idx = out_base + n_off + l;
if out_idx < numel {
output[out_idx] = d * sc1 as f32 * q1 as f32;
}
if out_idx + 32 < numel {
output[out_idx + 32] = d * sc2 as f32 * q2 as f32;
}
if out_idx + 64 < numel {
output[out_idx + 64] = d * sc3 as f32 * q3 as f32;
}
if out_idx + 96 < numel {
output[out_idx + 96] = d * sc4 as f32 * q4 as f32;
}
}
}
}
output
}
/// Expands Q5_K super-blocks into `f32` values.
///
/// Each 176-byte block covers 256 outputs and is read as: half-precision `d` and
/// `dmin` (little-endian), 12 bytes of packed 6-bit scales/mins, 32 bytes of high bits
/// (`qh`), and 128 bytes of packed low nibbles (`qs`). Every output is
/// `d * scale * quant - dmin * min`, where the 5-bit quant is a nibble plus 16 when its
/// high bit is set. The result always has `numel` entries; outputs whose block is
/// missing or incomplete in `data` stay `0.0`.
fn dequant_q5_k(data: &[u8], numel: usize) -> Vec<f32> {
let mut output = vec![0.0f32; numel];
let block_size = 256;
let type_size = 176;
let num_blocks = numel.div_ceil(block_size);
for block_idx in 0..num_blocks {
let bs = block_idx * type_size;
// Stop at the first block that is not fully present in `data`.
if bs + type_size > data.len() {
break;
}
let d_bits = u16::from_le_bytes([data[bs], data[bs + 1]]);
let dmin_bits = u16::from_le_bytes([data[bs + 2], data[bs + 3]]);
let d = f16_to_f32(d_bits);
let dmin = f16_to_f32(dmin_bits);
let scales = &data[bs + 4..bs + 16];
let qh = &data[bs + 16..bs + 48];
let qs = &data[bs + 48..bs + 176];
let out_base = block_idx * block_size;
let mut ql_off = 0usize;
let mut is = 0usize;
// Bit masks selecting this chunk's high bit in `qh` for the low/high nibble groups.
let mut u1: u8 = 1;
let mut u2: u8 = 2;
// Four chunks of 64 outputs: 32 from low nibbles, then 32 from high nibbles.
for chunk in 0..4 {
let (sc1, m1) = get_scale_min_k4(is, scales);
let (sc2, m2) = get_scale_min_k4(is + 1, scales);
let d1 = d * sc1 as f32;
let m1 = dmin * m1 as f32;
let d2 = d * sc2 as f32;
let m2 = dmin * m2 as f32;
for l in 0..32 {
let out_idx = out_base + chunk * 64 + l;
if out_idx >= numel {
break;
}
let q_lo = (qs[ql_off + l] & 0x0F) as u32;
let qh_bit = if qh[l] & u1 != 0 { 16u32 } else { 0u32 };
output[out_idx] = d1 * (q_lo + qh_bit) as f32 - m1;
}
for l in 0..32 {
let out_idx = out_base + chunk * 64 + 32 + l;
if out_idx >= numel {
break;
}
let q_hi = (qs[ql_off + l] >> 4) as u32;
let qh_bit = if qh[l] & u2 != 0 { 16u32 } else { 0u32 };
output[out_idx] = d2 * (q_hi + qh_bit) as f32 - m2;
}
// Advance to the next 32 quant bytes, the next scale pair and the next bit pair.
ql_off += 32;
is += 2;
u1 <<= 2;
u2 <<= 2;
}
}
output
}
/// Expands Q4_K super-blocks into `f32` values.
///
/// Each 144-byte block covers 256 outputs and is read as: half-precision `d` and
/// `dmin` (little-endian), 12 bytes of packed 6-bit scales/mins, and 128 bytes of
/// packed 4-bit quants. Every output is `d * scale * nibble - dmin * min`. The result
/// always has `numel` entries; outputs whose block is missing or incomplete in `data`
/// stay `0.0`.
fn dequant_q4_k(data: &[u8], numel: usize) -> Vec<f32> {
let mut output = vec![0.0f32; numel];
let block_size = 256;
let type_size = 144;
let num_blocks = numel.div_ceil(block_size);
for block_idx in 0..num_blocks {
let bs = block_idx * type_size;
// Stop at the first block that is not fully present in `data`.
if bs + type_size > data.len() {
break;
}
let d_bits = u16::from_le_bytes([data[bs], data[bs + 1]]);
let dmin_bits = u16::from_le_bytes([data[bs + 2], data[bs + 3]]);
let d = f16_to_f32(d_bits);
let dmin = f16_to_f32(dmin_bits);
let scales = &data[bs + 4..bs + 16];
let qs = &data[bs + 16..bs + 144];
let out_base = block_idx * block_size;
let mut q_off = 0usize;
let mut is = 0usize;
// Four chunks of 64 outputs: 32 from low nibbles, then 32 from high nibbles of the
// same 32 quant bytes, each group with its own (scale, min) pair.
for chunk in 0..4 {
let (sc1, m1) = get_scale_min_k4(is, scales);
let (sc2, m2) = get_scale_min_k4(is + 1, scales);
let d1 = d * sc1 as f32;
let m1 = dmin * m1 as f32;
let d2 = d * sc2 as f32;
let m2 = dmin * m2 as f32;
for l in 0..32 {
let out_idx = out_base + chunk * 64 + l;
if out_idx >= numel {
break;
}
output[out_idx] = d1 * (qs[q_off + l] & 0x0F) as f32 - m1;
}
for l in 0..32 {
let out_idx = out_base + chunk * 64 + 32 + l;
if out_idx >= numel {
break;
}
output[out_idx] = d2 * (qs[q_off + l] >> 4) as f32 - m2;
}
q_off += 32;
is += 2;
}
}
output
}
/// Expands Q8_K super-blocks into `f32` values.
///
/// Each 292-byte block covers 256 outputs and starts with a little-endian `f32` scale
/// followed by 256 signed 8-bit quants; the remaining bytes of the block are not read.
/// Every output is `scale * quant`. The result always has `numel` entries; outputs
/// whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q8_k(data: &[u8], numel: usize) -> Vec<f32> {
    const ELEMS_PER_BLOCK: usize = 256;
    const BYTES_PER_BLOCK: usize = 292;
    let mut values = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    // The last output chunk may be shorter than a block; zipping stops writing there.
    for (block, out) in blocks.zip(values.chunks_mut(ELEMS_PER_BLOCK)) {
        let scale = f32::from_le_bytes([block[0], block[1], block[2], block[3]]);
        let quants = &block[4..4 + ELEMS_PER_BLOCK];
        for (slot, &quant) in out.iter_mut().zip(quants) {
            *slot = scale * (quant as i8) as f32;
        }
    }
    values
}
/// Expands Q3_K super-blocks into `f32` values, following the ggml `block_q3_K` layout.
///
/// Each 110-byte block covers 256 outputs:
/// * bytes `0..32` (`hmask`): bit `t / 32` of `hmask[t % 32]` is the high bit of output
///   `t`; when it is clear, 4 is subtracted from the quant;
/// * bytes `32..96` (`qs`): 2-bit low parts; each 32-byte slab serves 128 outputs, one
///   bit-pair plane per 32 outputs;
/// * bytes `96..108`: sixteen 6-bit scales (4 low bits packed as nibbles in the first 8
///   bytes, 2 high bits packed in the last 4 bytes), stored with an offset of 32;
/// * bytes `108..110`: little-endian half-precision block scale `d`.
///
/// Every output is `d * (scale - 32) * quant`, with one scale per 16 outputs. The
/// result always has `numel` entries; outputs whose block is missing or incomplete in
/// `data` stay `0.0`.
fn dequant_q3_k(data: &[u8], numel: usize) -> Vec<f32> {
    let mut output = vec![0.0f32; numel];
    let block_size = 256;
    let type_size = 110;
    let num_blocks = numel.div_ceil(block_size);
    for block_idx in 0..num_blocks {
        let bs = block_idx * type_size;
        // Stop at the first block that is not fully present in `data`.
        if bs + type_size > data.len() {
            break;
        }
        let hmask = &data[bs..bs + 32];
        let qs = &data[bs + 32..bs + 96];
        let scales_raw = &data[bs + 96..bs + 108];
        let d = f16_to_f32(u16::from_le_bytes([data[bs + 108], data[bs + 109]]));

        // Unpack the sixteen 6-bit scales and remove their +32 bias.
        let mut scales = [0i32; 16];
        for (k, scale) in scales.iter_mut().enumerate() {
            let low = if k < 8 {
                scales_raw[k] & 0x0F
            } else {
                scales_raw[k - 8] >> 4
            };
            let high = (scales_raw[8 + k % 4] >> (2 * (k / 4))) & 0x03;
            *scale = i32::from(low | (high << 4)) - 32;
        }

        let out_base = block_idx * block_size;
        for j in 0..block_size {
            let out_idx = out_base + j;
            if out_idx >= numel {
                break;
            }
            let lane = j % 32;
            // Slab of 32 quant bytes for this half, bit-pair plane within the slab.
            let q_byte = qs[(j / 128) * 32 + lane];
            let plane_shift = ((j % 128) / 32) * 2;
            let low2 = i32::from((q_byte >> plane_shift) & 0x03);
            let high_bit_set = hmask[lane] & (1u8 << (j / 32)) != 0;
            let q = if high_bit_set { low2 } else { low2 - 4 };
            output[out_idx] = d * scales[j / 16] as f32 * q as f32;
        }
    }
    output
}
/// Expands Q2_K super-blocks into `f32` values, following the ggml `block_q2_K` layout.
///
/// Each 84-byte block covers 256 outputs:
/// * bytes `0..16`: one byte per group of 16 outputs, low nibble = scale, high nibble =
///   min;
/// * bytes `16..80` (`qs`): 2-bit quants; each 32-byte slab serves 128 outputs, one
///   bit-pair plane per 32 outputs (output `t` of a half reads byte `t % 32`, bits
///   `2 * (t / 32)`);
/// * bytes `80..84`: little-endian half-precision `d` and `dmin`.
///
/// Every output is `d * scale * quant - dmin * min`. The result always has `numel`
/// entries; outputs whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q2_k(data: &[u8], numel: usize) -> Vec<f32> {
    let mut output = vec![0.0f32; numel];
    let block_size = 256;
    let type_size = 84;
    let num_blocks = numel.div_ceil(block_size);
    for block_idx in 0..num_blocks {
        let bs = block_idx * type_size;
        // Stop at the first block that is not fully present in `data`.
        if bs + type_size > data.len() {
            break;
        }
        let scales = &data[bs..bs + 16];
        let qs = &data[bs + 16..bs + 80];
        let d = f16_to_f32(u16::from_le_bytes([data[bs + 80], data[bs + 81]]));
        let dmin = f16_to_f32(u16::from_le_bytes([data[bs + 82], data[bs + 83]]));
        let out_base = block_idx * block_size;
        for j in 0..block_size {
            let out_idx = out_base + j;
            if out_idx >= numel {
                break;
            }
            let packed_scale = scales[j / 16];
            let sc = packed_scale & 0x0F;
            let m = packed_scale >> 4;
            // Slab of 32 quant bytes for this half, bit-pair plane within the slab.
            let q_byte = qs[(j / 128) * 32 + j % 32];
            let plane_shift = ((j % 128) / 32) * 2;
            let q = (q_byte >> plane_shift) & 0x03;
            output[out_idx] = d * sc as f32 * q as f32 - dmin * m as f32;
        }
    }
    output
}
/// Converts an IEEE 754 half-precision bit pattern to `f32`.
///
/// Signed zeros, subnormals, infinities and NaNs are all converted; NaN payload bits
/// are carried over into the upper mantissa bits.
fn f16_to_f32(bits: u16) -> f32 {
    let sign_bit = u32::from(bits >> 15) << 31;
    let exp_field = u32::from((bits >> 10) & 0x1F);
    let frac = u32::from(bits & 0x3FF);

    let magnitude = match (exp_field, frac) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: shift the fraction left until its leading one reaches bit 10,
        // lowering the exponent by one per shift starting from -14.
        (0, _) => {
            let shifts = frac.leading_zeros() - 21;
            let exp = ((-14 - shifts as i32 + 127) as u32) & 0xFF;
            let normalized = (frac << shifts) & 0x3FF;
            (exp << 23) | (normalized << 13)
        }
        // Infinity or NaN.
        (31, _) => (0xFF << 23) | (frac << 13),
        // Normal number: re-bias the exponent from 15 to 127.
        _ => ((exp_field + 112) << 23) | (frac << 13),
    };
    f32::from_bits(sign_bit | magnitude)
}
/// Converts an `f32` to an IEEE 754 half-precision bit pattern.
///
/// * Mantissa bits that do not fit are truncated (rounding toward zero).
/// * Magnitudes too large for a half become infinity; magnitudes below the smallest
///   half subnormal (2^-24) become a signed zero.
/// * Values between 2^-24 and 2^-14 are encoded as half subnormals.
/// * NaN stays NaN: when the payload would be truncated to zero (which would otherwise
///   encode infinity), the quiet bit is set instead.
fn f32_to_f16(x: f32) -> u16 {
    let bits = x.to_bits();
    let sign = (bits >> 31) & 1;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mant = bits & 0x7F_FFFF;
    if exp == 0xFF {
        let mut h_mant = (mant >> 13) & 0x3FF;
        if mant != 0 && h_mant == 0 {
            // The payload lived only in the low 13 bits; keep the value a NaN.
            h_mant = 0x200;
        }
        return ((sign << 15) | (0x1F << 10) | h_mant) as u16;
    }
    let unbiased = exp - 127;
    if unbiased > 15 {
        return ((sign << 15) | (0x1F << 10)) as u16;
    }
    if unbiased < -24 {
        return (sign << 15) as u16;
    }
    if unbiased < -14 {
        // Half subnormal: value = h_mant * 2^-24. The 24-bit significand (implicit one
        // restored) represents 1.m * 2^23, so it must be shifted right by
        // 13 + (-14 - unbiased) bits: 14 for 2^-15 down to 23 for 2^-24.
        let shift = (-14 - unbiased) as u32;
        let h_mant = (mant | 0x80_0000) >> (13 + shift);
        return ((sign << 15) | h_mant) as u16;
    }
    let h_exp = (unbiased + 15) as u32;
    let h_mant = mant >> 13;
    ((sign << 15) | (h_exp << 10) | h_mant) as u16
}
/// Quantizes `data` into Q4_0 blocks as laid out by `dequant_q4_0`.
///
/// Every group of 32 values becomes an 18-byte block: a little-endian half-precision
/// scale (`max |value| / 8`) followed by 16 bytes of packed nibbles. Value `j` goes to
/// the low nibble of byte `j` and value `j + 16` to its high nibble, each stored as
/// `round(value / scale) + 8` clamped to `0..=15`. A final partial group is padded with
/// zeros; an all-zero group uses an inverse scale of 0.
pub fn quantize_f32_to_q4_0(data: &[f32]) -> Vec<u8> {
    const ELEMS_PER_BLOCK: usize = 32;
    const BYTES_PER_BLOCK: usize = 18;
    let mut packed = vec![0u8; data.len().div_ceil(ELEMS_PER_BLOCK) * BYTES_PER_BLOCK];

    let groups = data.chunks(ELEMS_PER_BLOCK);
    for (values, block) in groups.zip(packed.chunks_exact_mut(BYTES_PER_BLOCK)) {
        // Largest magnitude in the group determines the scale.
        let mut peak = 0.0f32;
        for magnitude in values.iter().map(|v| v.abs()) {
            if magnitude > peak {
                peak = magnitude;
            }
        }
        let scale = peak / 8.0;
        let inv_scale = if scale != 0.0 { 1.0 / scale } else { 0.0 };

        let (header, nibbles) = block.split_at_mut(2);
        header.copy_from_slice(&f32_to_f16(scale).to_le_bytes());

        // Positions past the end of a partial group quantize as 0.0.
        let code = |idx: usize| -> u8 {
            let value = values.get(idx).copied().unwrap_or(0.0);
            ((value * inv_scale).round() as i32 + 8).clamp(0, 15) as u8
        };
        for (j, slot) in nibbles.iter_mut().enumerate() {
            *slot = code(j) | (code(j + 16) << 4);
        }
    }
    packed
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn f16_conversion_basic() {
assert_eq!(f16_to_f32(0x0000), 0.0);
assert!((f16_to_f32(0x3C00) - 1.0).abs() < 1e-6);
assert!((f16_to_f32(0xBC00) - (-1.0)).abs() < 1e-6);
assert!((f16_to_f32(0x3800) - 0.5).abs() < 1e-6);
assert!((f16_to_f32(0x4000) - 2.0).abs() < 1e-6);
}
#[test]
fn f16_special_values() {
assert!(f16_to_f32(0x7C00).is_infinite());
assert!(f16_to_f32(0x7E00).is_nan());
assert_eq!(f16_to_f32(0x8000), -0.0);
}
#[test]
fn bf16_conversion() {
let data = 0x3F80u16.to_le_bytes();
let result = dequant_bf16(&data, 1);
assert!((result[0] - 1.0).abs() < 1e-6);
}
#[test]
fn dequant_f32_identity() {
let values = vec![1.0f32, 2.0, -3.5, 0.0];
let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
let result = dequant_f32(&bytes, 4);
assert_eq!(result, values);
}
#[test]
fn dequant_f16_roundtrip() {
let f16_one = 0x3C00u16; let f16_half = 0x3800u16; let bytes: Vec<u8> = [f16_one, f16_half]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect();
let result = dequant_f16(&bytes, 2);
assert!((result[0] - 1.0).abs() < 1e-6);
assert!((result[1] - 0.5).abs() < 1e-6);
}
#[test]
fn split_fused_qkv_and_ffn_f32() {
let mut tensors = HashMap::new();
tensors.insert(
"model.layers.0.self_attn.qkv_proj.weight".to_string(),
(1..=48).map(|x| x as f32).collect::<Vec<f32>>(),
);
let gate = (1..=24).map(|x| x as f32).collect::<Vec<f32>>();
let up = (25..=48).map(|x| x as f32).collect::<Vec<f32>>();
let fused_ffn = gate.iter().chain(up.iter()).copied().collect::<Vec<f32>>();
tensors.insert("model.layers.0.mlp.up_proj.weight".to_string(), fused_ffn);
let mut weights = ModelWeights { tensors };
split_fused_tensors_f32(&mut weights, 1, 4, 2, 2, 2, 6);
assert_eq!(
weights
.get("model.layers.0.self_attn.q_proj.weight")
.unwrap()
.len(),
16
);
assert_eq!(
weights
.get("model.layers.0.self_attn.k_proj.weight")
.unwrap()
.len(),
16
);
assert_eq!(
weights
.get("model.layers.0.self_attn.v_proj.weight")
.unwrap()
.len(),
16
);
assert_eq!(
weights
.get("model.layers.0.self_attn.q_proj.weight")
.unwrap(),
&(1..=16).map(|x| x as f32).collect::<Vec<f32>>()[..]
);
assert_eq!(
weights
.get("model.layers.0.self_attn.v_proj.weight")
.unwrap(),
&(33..=48).map(|x| x as f32).collect::<Vec<f32>>()[..]
);
assert!(weights
.get("model.layers.0.self_attn.qkv_proj.weight")
.is_none());
assert_eq!(
weights.get("model.layers.0.mlp.gate_proj.weight").unwrap(),
&gate[..]
);
assert_eq!(
weights.get("model.layers.0.mlp.up_proj.weight").unwrap(),
&up[..]
);
}
#[test]
fn split_fused_noop_on_llama_layout() {
let mut tensors = HashMap::new();
tensors.insert(
"model.layers.0.self_attn.q_proj.weight".to_string(),
vec![1.0f32; 8],
);
tensors.insert(
"model.layers.0.mlp.gate_proj.weight".to_string(),
vec![2.0f32; 24],
);
tensors.insert(
"model.layers.0.mlp.up_proj.weight".to_string(),
vec![3.0f32; 24],
);
let mut weights = ModelWeights {
tensors: tensors.clone(),
};
split_fused_tensors_f32(&mut weights, 1, 4, 2, 2, 2, 6);
assert_eq!(weights.tensors.len(), tensors.len());
assert_eq!(
weights
.get("model.layers.0.mlp.up_proj.weight")
.unwrap()
.len(),
24
);
}
#[test]
fn split_fused_qkv_q8_block_boundaries() {
let block_bytes: [u8; 34] = {
let mut b = [0u8; 34];
b[0] = 0x00;
b[1] = 0x3C; for i in 0..32 {
b[2 + i] = i as u8;
}
b
};
let mut q_block = block_bytes;
let mut k_block = block_bytes;
let mut v_block = block_bytes;
q_block[2] = 1;
k_block[2] = 2;
v_block[2] = 3;
let mut fused = Vec::with_capacity(34 * 3);
fused.extend_from_slice(&q_block);
fused.extend_from_slice(&k_block);
fused.extend_from_slice(&v_block);
let mut tensors = HashMap::new();
tensors.insert(
"model.layers.0.self_attn.qkv_proj.weight".to_string(),
WeightData::Q8_0Raw(fused),
);
let mut weights = ModelWeightsRaw { tensors };
split_fused_tensors(&mut weights, 1, 1, 1, 1, 32, 0);
let q = weights
.get_q8_raw("model.layers.0.self_attn.q_proj.weight")
.unwrap();
let k = weights
.get_q8_raw("model.layers.0.self_attn.k_proj.weight")
.unwrap();
let v = weights
.get_q8_raw("model.layers.0.self_attn.v_proj.weight")
.unwrap();
assert_eq!(q.len(), 34);
assert_eq!(k.len(), 34);
assert_eq!(v.len(), 34);
assert_eq!(q[2], 1);
assert_eq!(k[2], 2);
assert_eq!(v[2], 3);
}
#[test]
fn dequant_q8_0_basic() {
let scale_f16: u16 = 0x3C00; let mut block = Vec::new();
block.extend_from_slice(&scale_f16.to_le_bytes());
for i in 0..32 {
block.push(i as u8); }
let result = dequant_q8_0(&block, 32);
assert_eq!(result.len(), 32);
assert!((result[0] - 0.0).abs() < 1e-6);
assert!((result[1] - 1.0).abs() < 1e-6);
assert!((result[31] - 31.0).abs() < 1e-6);
}
/// One Q4_0 block whose packed bytes are all 0x88 (both nibbles equal to 8)
/// with scale bits 0x3C00 is expected to dequantize to 32 zeros.
#[test]
fn dequant_q4_0_basic() {
    // 18 bytes: start with every byte at 0x88, then overwrite the first two
    // with the little-endian scale bits.
    let mut block = vec![0x88u8; 18];
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());

    let result = dequant_q4_0(&block, 32);
    assert_eq!(result.len(), 32);
    assert!(result.iter().all(|val| (val - 0.0).abs() < 1e-6));
}
/// One Q4_1 block with scale bits 0x4000 (2.0), min bits 0x3C00 (1.0) and
/// all-zero packed quants is expected to dequantize to 1.0 everywhere.
#[test]
fn dequant_q4_1_basic() {
    // Block = scale bits, min bits (both little-endian), then 16 packed bytes.
    let scale_bits = 0x4000u16.to_le_bytes();
    let min_bits = 0x3C00u16.to_le_bytes();
    let quants = [0x00u8; 16];
    let block = [&scale_bits[..], &min_bits[..], &quants[..]].concat();

    let result = dequant_q4_1(&block, 32);
    assert_eq!(result.len(), 32);
    result.iter().for_each(|val| {
        assert!((val - 1.0).abs() < 1e-6, "expected 1.0, got {val}");
    });
}
/// Exercises the counting and lookup accessors of `ModelWeights` on two
/// tensors of 100 and 200 `f32` elements.
#[test]
fn model_weights_accessors() {
    let weights = ModelWeights {
        tensors: HashMap::from([
            (String::from("w1"), vec![1.0f32; 100]),
            (String::from("w2"), vec![2.0f32; 200]),
        ]),
    };

    // Tensor count and emptiness.
    assert_eq!(weights.len(), 2);
    assert!(!weights.is_empty());

    // 100 + 200 elements; the test expects 4 bytes per element.
    assert_eq!(weights.total_elements(), 300);
    assert_eq!(weights.memory_bytes(), 1200);

    // Lookup through the fallible and the direct accessor.
    assert_eq!(weights.get("w1").unwrap().len(), 100);
    assert_eq!(weights.tensor("w2").len(), 200);
}
/// Checks that `ModelWeightsRaw` only hands out a tensor through the accessor
/// matching its variant (F32 vs. Q8_0 here) and that `memory_bytes` sums
/// 4 bytes per F32 element plus the raw byte length.
#[test]
fn model_weights_raw_accessors() {
    let raw = ModelWeightsRaw {
        tensors: HashMap::from([
            (
                String::from("norm.weight"),
                WeightData::F32(vec![1.0f32; 64]),
            ),
            (
                String::from("q_proj.weight"),
                WeightData::Q8_0Raw(vec![0u8; 68]),
            ),
        ]),
    };

    assert_eq!(raw.len(), 2);
    assert!(!raw.is_empty());

    // Matching accessor: tensor is present with the stored length.
    assert_eq!(raw.get_f32("norm.weight").map(|v| v.len()), Some(64));
    assert_eq!(raw.get_q8_raw("q_proj.weight").map(|v| v.len()), Some(68));

    // Mismatched accessor: nothing is returned.
    assert!(raw.get_f32("q_proj.weight").is_none());
    assert!(raw.get_q8_raw("norm.weight").is_none());

    assert_eq!(raw.memory_bytes(), 64 * 4 + 68);
}
/// Checks that a `Q4_0Raw` tensor is reachable only through `get_q4_raw`
/// and that its raw byte length is what `memory_bytes` counts.
#[test]
fn model_weights_raw_q4_0_accessor() {
    let tensors: HashMap<String, WeightData> = vec![
        ("norm.weight", WeightData::F32(vec![1.0f32; 64])),
        ("q_proj.weight", WeightData::Q4_0Raw(vec![0u8; 36])),
    ]
    .into_iter()
    .map(|(name, data)| (name.to_string(), data))
    .collect();
    let raw = ModelWeightsRaw { tensors };

    assert_eq!(raw.len(), 2);

    // The Q4_0 tensor is visible through the Q4 accessor only.
    assert_eq!(raw.get_q4_raw("q_proj.weight").map(|v| v.len()), Some(36));
    assert!(raw.get_f32("q_proj.weight").is_none());
    assert!(raw.get_q8_raw("q_proj.weight").is_none());

    // The F32 tensor is not visible through the Q4 accessor.
    assert!(raw.get_q4_raw("norm.weight").is_none());

    assert_eq!(raw.memory_bytes(), 64 * 4 + 36);
}
/// Wrapping an 18-byte buffer in `WeightData::Q4_0Raw` must keep the bytes
/// unchanged and keep the variant.
#[test]
fn weight_data_q4_0_raw_stored_correctly() {
    // 18 bytes: little-endian 0x3C00 scale bits followed by sixteen 0x88 bytes.
    let mut block = vec![0x88u8; 18];
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());

    let wd = WeightData::Q4_0Raw(block.clone());
    if let WeightData::Q4_0Raw(stored) = &wd {
        assert_eq!(stored.len(), 18);
        assert_eq!(stored.as_slice(), block.as_slice());
    } else {
        panic!("expected Q4_0Raw variant");
    }
}
/// Quantizes 64 evenly spaced values to Q4_0 and back, and requires every
/// element to stay within one quantization step (plus a small slack) of the
/// original.
#[test]
fn quantize_q4_0_roundtrip() {
    // Values (i - 32) * 0.1 for i in 0..64.
    let mut input: Vec<f32> = Vec::with_capacity(64);
    for i in 0..64 {
        input.push((i as f32 - 32.0) * 0.1);
    }

    // The test expects 18 bytes per 32 input values.
    let q4_bytes = quantize_f32_to_q4_0(&input);
    assert_eq!(q4_bytes.len(), 2 * 18);

    let output = dequant_q4_0(&q4_bytes, 64);
    assert_eq!(output.len(), 64);

    // Tolerance is derived from the largest magnitude over the whole input.
    let max_abs = input.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
    let step = max_abs / 8.0;

    input
        .iter()
        .zip(&output)
        .enumerate()
        .for_each(|(i, (&orig, &decoded))| {
            let diff = (orig - decoded).abs();
            assert!(
                diff <= step + 0.01,
                "element {i}: orig={orig}, decoded={decoded}, diff={diff}, step={step}"
            );
        });
}
/// An all-zero input block must quantize to a zero scale and dequantize back
/// to exactly 32 zeros.
#[test]
fn quantize_q4_0_zeros() {
    let input = vec![0.0f32; 32];
    let q4_bytes = quantize_f32_to_q4_0(&input);
    assert_eq!(q4_bytes.len(), 18);

    // The first two bytes are read as the little-endian scale bits.
    let scale_bits = u16::from_le_bytes([q4_bytes[0], q4_bytes[1]]);
    assert_eq!(scale_bits, 0, "zero input should produce zero scale");

    let output = dequant_q4_0(&q4_bytes, 32);
    // Without this check the loop below would pass vacuously on an empty
    // (or truncated) output.
    assert_eq!(
        output.len(),
        32,
        "dequantized block should contain 32 values"
    );
    for &v in &output {
        assert_eq!(v, 0.0);
    }
}
/// For groups 0, 1 and 3 the test expects `get_scale_min_k4` to return the
/// byte at the group index as scale and the byte four positions later as min.
/// All fixture values fit in 6 bits.
#[test]
fn get_scale_min_k4_low_groups() {
    // index:            0   1  2   3   4   5  6   7  8  9 10 11
    let scales: [u8; 12] = [10, 30, 0, 63, 20, 40, 0, 63, 0, 0, 0, 0];

    // (group, expected scale, expected min)
    for (group, want_sc, want_m) in [(0, 10, 20), (1, 30, 40), (3, 63, 63)] {
        let (sc, m) = get_scale_min_k4(group, &scales);
        assert_eq!(sc, want_sc);
        assert_eq!(m, want_m);
    }
}
/// Group 4 is expected to combine three bytes into two 6-bit values.
///
/// The expected results are consistent with: scale = low nibble of
/// `scales[8]` (5) with the top two bits of `scales[0]` (3) placed above it,
/// and min = high nibble of `scales[8]` (10) with the top two bits of
/// `scales[4]` (2) placed above it.
#[test]
fn get_scale_min_k4_high_groups() {
    // index:            0    1  2  3  4     5  6  7  8     9 10 11
    let scales: [u8; 12] = [0xC0, 0, 0, 0, 0x80, 0, 0, 0, 0xA5, 0, 0, 0];

    let (sc, m) = get_scale_min_k4(4, &scales);
    assert_eq!(sc, 0b11_0101); // 5 | (3 << 4)
    assert_eq!(m, 0b10_1010); // 10 | (2 << 4)
}
/// A 144-byte Q4_K block with d = 0x3C00 (1.0), dmin = 0 and all-zero packed
/// quants is expected to dequantize to 256 zeros.
#[test]
fn dequant_q4_k_basic() {
    let mut block = vec![0u8; 144];
    // Bytes 0..2 hold the d bits and bytes 2..4 the dmin bits, little-endian.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[2..4].copy_from_slice(&0x0000u16.to_le_bytes());
    // Header bytes 4..8 and 12..16 are set to 1; bytes 8..12 stay zero.
    block[4..8].fill(1);
    block[12..16].fill(0x01);

    let result = dequant_q4_k(&block, 256);
    assert_eq!(result.len(), 256);
    result.iter().enumerate().for_each(|(i, &val)| {
        assert!(val.abs() < 1e-6, "element {i}: expected 0.0, got {val}");
    });
}
/// Checks the first 32 outputs of a Q4_K block with non-trivial scale and min.
///
/// The expected 28.0 is consistent with d * scale * nibble - dmin * min
/// = 2.0 * 3 * 5 - 1.0 * 2 for the values written below.
#[test]
fn dequant_q4_k_with_values() {
    let mut block = vec![0u8; 144];
    // d = 0x4000 (2.0) and dmin = 0x3C00 (1.0), little-endian.
    block[..2].copy_from_slice(&0x4000u16.to_le_bytes());
    block[2..4].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[4] = 3;
    block[8] = 2;
    // First 32 packed bytes: low nibble 5, high nibble 0.
    block[16..48].fill(0x05);

    let result = dequant_q4_k(&block, 256);
    for (i, val) in result[..32].iter().enumerate() {
        assert!(
            (val - 28.0).abs() < 1e-4,
            "element {i}: expected 28.0, got {}",
            val
        );
    }
}
/// A 176-byte Q5_K block with d = 0x3C00 (1.0), dmin = 0 and every quant bit
/// cleared is expected to dequantize to 256 zeros.
#[test]
fn dequant_q5_k_basic_zeros() {
    let mut block = vec![0u8; 176];
    // Little-endian d bits in bytes 0..2, dmin bits in bytes 2..4.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[2..4].copy_from_slice(&0x0000u16.to_le_bytes());
    // Header bytes 4..8 and 12..16 are set to 1; everything else stays zero.
    for header in [4..8, 12..16] {
        block[header].fill(1);
    }

    let result = dequant_q5_k(&block, 256);
    assert_eq!(result.len(), 256);
    result.iter().enumerate().for_each(|(i, &val)| {
        assert!(val.abs() < 1e-6, "element {i}: expected 0.0, got {val}");
    });
}
/// Element 0 of a Q5_K block must pick up its fifth bit.
///
/// The expected 19.0 is consistent with a low nibble of 3 (byte 48) plus 16
/// contributed by bit 0 of byte 16, with d = 1.0 and a scale byte of 1.
#[test]
fn dequant_q5_k_with_high_bit() {
    let mut block = vec![0u8; 176];
    // d = 0x3C00 (1.0), dmin = 0, both little-endian.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[2..4].copy_from_slice(&0x0000u16.to_le_bytes());
    // (offset, byte) pairs for the remaining non-zero bytes.
    for (offset, byte) in [(4usize, 1u8), (16, 0x01), (48, 0x03)] {
        block[offset] = byte;
    }

    let result = dequant_q5_k(&block, 256);
    let first = result[0];
    assert!(
        (first - 19.0).abs() < 1e-4,
        "element 0: expected 19.0, got {}",
        first
    );
}
/// Element 32 of a Q5_K block must combine a high nibble with its fifth bit.
///
/// The expected 46.0 is consistent with (7 + 16) * 2: high nibble 7 from
/// byte 48, 16 from bit 1 of byte 16, and a scale byte of 2 at offset 5,
/// with d = 1.0.
#[test]
fn dequant_q5_k_high_nibble_with_qh() {
    let mut block = vec![0u8; 176];
    // d = 0x3C00 (1.0), dmin = 0, both little-endian.
    block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
    block[2..4].copy_from_slice(&0x0000u16.to_le_bytes());
    // (offset, byte) pairs for the remaining non-zero bytes.
    for (offset, byte) in [(5usize, 2u8), (16, 0x02), (48, 0x70)] {
        block[offset] = byte;
    }

    let result = dequant_q5_k(&block, 256);
    let probe = result[32];
    assert!(
        (probe - 46.0).abs() < 1e-4,
        "element 32: expected 46.0, got {}",
        probe
    );
}
/// A 210-byte Q6_K block with all quant bits cleared, every scale byte set
/// to 1 and d = 0x3C00 (1.0) is expected to dequantize to -32.0 everywhere.
#[test]
fn dequant_q6_k_basic_zeros() {
    let mut block = vec![0u8; 210];
    // Bytes 192..208 are the scale bytes in this fixture; the last two bytes
    // carry the little-endian d bits.
    block[192..208].fill(1);
    block[208..].copy_from_slice(&0x3C00u16.to_le_bytes());

    let result = dequant_q6_k(&block, 256);
    assert_eq!(result.len(), 256);
    result.iter().enumerate().for_each(|(i, &val)| {
        assert!(
            (val - (-32.0)).abs() < 1e-4,
            "element {i}: expected -32.0, got {val}"
        );
    });
}
/// Checks element 0 of a Q6_K block with a fractional d and non-unit scale.
///
/// The expected 10.0 is consistent with d * scale * (q - 32)
/// = 0.5 * 2 * (42 - 32), where 42 = 0x0A | (0x02 << 4) is assembled from
/// byte 0 (low bits) and byte 128 (high bits).
#[test]
fn dequant_q6_k_with_values() {
    let mut block = vec![0u8; 210];
    // d = 0x3800 (0.5) lives in the last two bytes, little-endian.
    block[208..].copy_from_slice(&0x3800u16.to_le_bytes());
    // (offset, byte) pairs for the remaining non-zero bytes.
    for (offset, byte) in [(192usize, 2u8), (0, 0x0A), (128, 0x02)] {
        block[offset] = byte;
    }

    let result = dequant_q6_k(&block, 256);
    let first = result[0];
    assert!(
        (first - 10.0).abs() < 1e-4,
        "element 0: expected 10.0, got {}",
        first
    );
}
/// Checks that the four 6-bit values sharing bytes 0, 32 and 128 of a Q6_K
/// block land at output positions 0, 32, 64 and 96.
///
/// With d = 1.0 and all scale bytes 1, the expected outputs are consistent
/// with (low4 | high2 << 4) - 32, where byte 128 = 0xE4 = 0b11_10_01_00
/// supplies the high bit pairs 0, 1, 2, 3 (lowest pair first):
/// -31 = (0x1 | 0 << 4) - 32, -14 = (0x2 | 1 << 4) - 32,
///   3 = (0x3 | 2 << 4) - 32,  21 = (0x5 | 3 << 4) - 32.
#[test]
fn dequant_q6_k_four_interleaved_values() {
    let mut block = vec![0u8; 210];
    block[192..208].fill(1);
    block[208..].copy_from_slice(&0x3C00u16.to_le_bytes());
    // (offset, byte) pairs carrying the packed quant bits.
    for (offset, byte) in [(0usize, 0x31u8), (32, 0x52), (128, 0xE4)] {
        block[offset] = byte;
    }

    let result = dequant_q6_k(&block, 256);
    let (q1, q2, q3, q4) = (result[0], result[32], result[64], result[96]);

    assert!(
        (q1 - (-31.0)).abs() < 1e-4,
        "q1 at [0]: expected -31.0, got {}",
        q1
    );
    assert!(
        (q2 - (-14.0)).abs() < 1e-4,
        "q2 at [32]: expected -14.0, got {}",
        q2
    );
    assert!(
        (q3 - 3.0).abs() < 1e-4,
        "q3 at [64]: expected 3.0, got {}",
        q3
    );
    assert!(
        (q4 - 21.0).abs() < 1e-4,
        "q4 at [96]: expected 21.0, got {}",
        q4
    );
}
/// Large-magnitude inputs must survive a Q4_0 round trip as finite values,
/// and the two largest ones must keep their sign.
#[test]
fn quantize_q4_0_handles_large_values() {
    // One block of zeros with four large outliers.
    let mut input = vec![0.0f32; 32];
    for (slot, value) in [(0usize, 1000.0f32), (1, -1000.0), (15, 500.0), (16, -500.0)] {
        input[slot] = value;
    }

    let q4_bytes = quantize_f32_to_q4_0(&input);
    assert_eq!(q4_bytes.len(), 18);

    let output = dequant_q4_0(&q4_bytes, 32);

    // Report the first non-finite value, if any.
    if let Some((i, v)) = output
        .iter()
        .copied()
        .enumerate()
        .find(|(_, v)| !v.is_finite())
    {
        panic!("dequantized value at index {i} is not finite: {v}");
    }

    assert!(
        output[0] > 0.0,
        "large positive value should remain positive after Q4_0 roundtrip"
    );
    assert!(
        output[1] < 0.0,
        "large negative value should remain negative after Q4_0 roundtrip"
    );
}
/// Quant bytes must be interpreted as signed: with scale bits 0x4000 (2.0),
/// even positions hold +i and odd positions hold -i, so the output is
/// expected to alternate in sign and equal twice the signed quant.
#[test]
fn dequant_q8_0_roundtrip_preserves_sign() {
    // Block = little-endian scale bits, then 0, -1, 2, -3, ..., -31 stored as
    // two's-complement bytes.
    let mut block = 0x4000u16.to_le_bytes().to_vec();
    block.extend((0..32i8).map(|i| {
        let signed = if i % 2 == 0 { i } else { -i };
        signed as u8
    }));

    let result = dequant_q8_0(&block, 32);
    assert_eq!(result.len(), 32);

    // Even positions first, then odd positions.
    for (i, &value) in result.iter().enumerate().filter(|(i, _)| i % 2 == 0) {
        assert!(
            value >= 0.0,
            "Q8_0 dequant: index {i} should be non-negative, got {}",
            value
        );
    }
    for (i, &value) in result.iter().enumerate().filter(|(i, _)| i % 2 == 1) {
        assert!(
            value < 0.0,
            "Q8_0 dequant: index {i} should be negative, got {}",
            value
        );
    }

    // Exact spot checks.
    for (index, expected) in [(0usize, 0.0), (1, -2.0), (2, 4.0), (31, -62.0)] {
        assert!((result[index] - expected).abs() < 1e-6);
    }
}
/// A Q8_0 block whose scale bits are zero must dequantize to exactly 0.0
/// regardless of the quant bytes.
///
/// Despite the name, only the scale is zero in this fixture: every quant
/// byte is set to 127.
#[test]
fn dequant_q8_0_all_zeros() {
    // 34 bytes: two zero scale bytes followed by 32 quant bytes of 127.
    let mut block = vec![0u8; 34];
    block[2..].fill(127);

    let result = dequant_q8_0(&block, 32);
    // Without this check the loop below would pass vacuously on an empty
    // (or truncated) result.
    assert_eq!(result.len(), 32, "one block should dequantize to 32 values");
    for (i, &v) in result.iter().enumerate() {
        assert_eq!(v, 0.0, "with zero scale, index {i} should be 0.0, got {v}");
    }
}
/// A Q4_0 round trip must not flip signs: the first 16 inputs are negative
/// and the last 16 positive. The checks are non-strict, so a value that
/// rounds to zero is accepted on either side.
#[test]
fn quantize_q4_0_roundtrip_sign_preservation() {
    // -1, -2, ..., -16 followed by 1, 2, ..., 16.
    let input: Vec<f32> = (1..=16)
        .map(|magnitude| -(magnitude as f32))
        .chain((1..=16).map(|magnitude| magnitude as f32))
        .collect();

    let q4_bytes = quantize_f32_to_q4_0(&input);
    let output = dequant_q4_0(&q4_bytes, 32);

    for i in 0..32 {
        let (src, got) = (input[i], output[i]);
        if i < 16 {
            assert!(
                got <= 0.0,
                "input[{i}]={}, roundtrip output[{i}]={} should be <= 0.0",
                src,
                got
            );
        } else {
            assert!(
                got >= 0.0,
                "input[{i}]={}, roundtrip output[{i}]={} should be >= 0.0",
                src,
                got
            );
        }
    }
}
}