use std::collections::HashMap;
use std::io::{self, Read, Seek, SeekFrom};
use std::path::Path;
use crate::gguf::{self, GGMLType, GGUFFile, GGUFTensorInfo};
/// Collection of model tensors held as flat `f32` buffers, keyed by tensor name.
#[derive(Debug, Clone)]
pub struct ModelWeights {
    /// Flat `f32` contents of each tensor, indexed by its name.
    pub tensors: HashMap<String, Vec<f32>>,
}

impl ModelWeights {
    /// Returns the tensor stored under `name`, or `None` when no such tensor exists.
    pub fn get(&self, name: &str) -> Option<&[f32]> {
        self.tensors.get(name).map(Vec::as_slice)
    }

    /// Returns the tensor stored under `name`.
    ///
    /// # Panics
    ///
    /// Panics with `weight not found: <name>` when the tensor is absent.
    pub fn tensor(&self, name: &str) -> &[f32] {
        match self.tensors.get(name) {
            Some(values) => values.as_slice(),
            None => panic!("weight not found: {name}"),
        }
    }

    /// Number of tensors in the collection.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Returns `true` when the collection holds no tensors.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

    /// Total number of `f32` elements summed over all tensors.
    pub fn total_elements(&self) -> usize {
        self.tensors.values().map(Vec::len).sum()
    }

    /// Size in bytes of all tensor elements (four bytes per `f32`).
    pub fn memory_bytes(&self) -> usize {
        self.total_elements() * std::mem::size_of::<f32>()
    }
}
/// Errors returned by the weight-loading functions in this file.
#[derive(Debug, thiserror::Error)]
pub enum WeightLoadError {
/// An underlying I/O operation failed; also used to wrap GGUF parse errors
/// and out-of-bounds tensor ranges in `load_from_file`.
#[error("I/O error: {0}")]
Io(#[from] io::Error),
/// The requested tensor name is not present in the GGUF file (payload is the name).
#[error("tensor not found in GGUF: {0}")]
TensorNotFound(String),
/// The tensor's GGML type has no dequantization routine in `dequantize`.
#[error("unsupported GGML type for dequantization: {0:?}")]
UnsupportedType(GGMLType),
}
/// Reads and dequantizes every tensor listed in `gguf` from `reader`.
///
/// Each tensor is stored under the name produced by `gguf_name_to_hf`.
///
/// # Errors
///
/// Returns the first error reported by `load_tensor`; tensors loaded before
/// the failure are discarded.
pub fn load_all<R: Read + Seek>(
    reader: &mut R,
    gguf: &GGUFFile,
) -> Result<ModelWeights, WeightLoadError> {
    let mut by_name = HashMap::with_capacity(gguf.tensors.len());
    for info in gguf.tensors.iter() {
        // Load first so a failing tensor aborts before its name is mapped.
        let values = load_tensor(reader, gguf, info)?;
        by_name.insert(gguf_name_to_hf(&info.name), values);
    }
    Ok(ModelWeights { tensors: by_name })
}
pub fn load_from_file(path: impl AsRef<Path>) -> Result<(GGUFFile, ModelWeights), WeightLoadError> {
let file = std::fs::File::open(path.as_ref())?;
let mmap = unsafe { memmap2::Mmap::map(&file)? };
let mut cursor = io::Cursor::new(&mmap[..]);
let gguf = gguf::parse(&mut cursor).map_err(|e| WeightLoadError::Io(io::Error::other(e)))?;
let mut tensors = HashMap::with_capacity(gguf.tensors.len());
for tensor_info in &gguf.tensors {
let data_offset = gguf.tensor_data_offset + tensor_info.offset;
let data_size = tensor_info.data_size() as usize;
let numel = tensor_info.numel() as usize;
let start = data_offset as usize;
let end = start + data_size;
if end > mmap.len() {
return Err(WeightLoadError::Io(io::Error::new(
io::ErrorKind::UnexpectedEof,
format!("tensor {} extends past end of file", tensor_info.name),
)));
}
let raw = &mmap[start..end];
let data = dequantize(raw, tensor_info.ggml_type, numel)?;
let hf_name = gguf_name_to_hf(&tensor_info.name);
tensors.insert(hf_name, data);
}
Ok((gguf, ModelWeights { tensors }))
}
/// Translates a GGUF tensor name into its Hugging Face-style counterpart.
///
/// Three top-level names are mapped directly. Names of the form
/// `blk.<layer>.<suffix>` become `model.layers.<layer>.<mapped suffix>`, where
/// unrecognised suffixes are carried over verbatim. Every other name is
/// returned unchanged.
fn gguf_name_to_hf(name: &str) -> String {
    let top_level = match name {
        "token_embd.weight" => Some("model.embed_tokens.weight"),
        "output_norm.weight" => Some("model.norm.weight"),
        "output.weight" => Some("lm_head.weight"),
        _ => None,
    };
    if let Some(mapped) = top_level {
        return mapped.to_string();
    }

    // Per-layer tensors: split `blk.<layer>.<suffix>` at the first dot after the prefix.
    let Some((layer_num, suffix)) = name
        .strip_prefix("blk.")
        .and_then(|rest| rest.split_once('.'))
    else {
        return name.to_string();
    };

    let hf_suffix = match suffix {
        "attn_norm.weight" => "input_layernorm.weight",
        "attn_q.weight" => "self_attn.q_proj.weight",
        "attn_k.weight" => "self_attn.k_proj.weight",
        "attn_v.weight" => "self_attn.v_proj.weight",
        "attn_q.bias" => "self_attn.q_proj.bias",
        "attn_k.bias" => "self_attn.k_proj.bias",
        "attn_v.bias" => "self_attn.v_proj.bias",
        "attn_output.weight" => "self_attn.o_proj.weight",
        "ffn_norm.weight" => "post_attention_layernorm.weight",
        "ffn_gate.weight" => "mlp.gate_proj.weight",
        "ffn_up.weight" => "mlp.up_proj.weight",
        "ffn_down.weight" => "mlp.down_proj.weight",
        unmapped => unmapped,
    };
    format!("model.layers.{layer_num}.{hf_suffix}")
}
pub fn load_tensor<R: Read + Seek>(
reader: &mut R,
gguf: &GGUFFile,
tensor_info: &GGUFTensorInfo,
) -> Result<Vec<f32>, WeightLoadError> {
let data_offset = gguf.tensor_data_offset + tensor_info.offset;
let data_size = tensor_info.data_size() as usize;
let numel = tensor_info.numel() as usize;
reader.seek(SeekFrom::Start(data_offset))?;
let mut raw = vec![0u8; data_size];
reader.read_exact(&mut raw)?;
dequantize(&raw, tensor_info.ggml_type, numel)
}
/// Looks up the tensor called `name` in `gguf` and loads it via `load_tensor`.
///
/// # Errors
///
/// Returns `WeightLoadError::TensorNotFound` when `gguf` has no tensor with
/// that name; otherwise forwards whatever `load_tensor` returns.
pub fn load_tensor_by_name<R: Read + Seek>(
    reader: &mut R,
    gguf: &GGUFFile,
    name: &str,
) -> Result<Vec<f32>, WeightLoadError> {
    let Some(info) = gguf.tensor(name) else {
        return Err(WeightLoadError::TensorNotFound(name.to_string()));
    };
    load_tensor(reader, gguf, info)
}
fn dequantize(data: &[u8], ggml_type: GGMLType, numel: usize) -> Result<Vec<f32>, WeightLoadError> {
match ggml_type {
GGMLType::F32 => Ok(dequant_f32(data, numel)),
GGMLType::F16 => Ok(dequant_f16(data, numel)),
GGMLType::BF16 => Ok(dequant_bf16(data, numel)),
GGMLType::Q8_0 => Ok(dequant_q8_0(data, numel)),
GGMLType::Q4_0 => Ok(dequant_q4_0(data, numel)),
GGMLType::Q4_1 => Ok(dequant_q4_1(data, numel)),
GGMLType::Q6K => Ok(dequant_q6_k(data, numel)),
GGMLType::Q5K => Ok(dequant_q5_k(data, numel)),
GGMLType::Q4K => Ok(dequant_q4_k(data, numel)),
GGMLType::Q8K => Ok(dequant_q8_k(data, numel)),
GGMLType::Q3K => Ok(dequant_q3_k(data, numel)),
GGMLType::Q2K => Ok(dequant_q2_k(data, numel)),
other => Err(WeightLoadError::UnsupportedType(other)),
}
}
/// Decodes little-endian `f32` values from `data` into a vector of length `numel`.
///
/// Elements for which `data` has no complete 4-byte group stay `0.0`; bytes
/// beyond `numel` elements are ignored.
fn dequant_f32(data: &[u8], numel: usize) -> Vec<f32> {
    let mut values = vec![0.0f32; numel];
    for (slot, word) in values.iter_mut().zip(data.chunks_exact(4)) {
        *slot = f32::from_le_bytes([word[0], word[1], word[2], word[3]]);
    }
    values
}
/// Decodes little-endian IEEE half-precision values from `data` into `numel` `f32`s.
///
/// Elements for which `data` has no complete 2-byte group stay `0.0`.
fn dequant_f16(data: &[u8], numel: usize) -> Vec<f32> {
    let mut values = vec![0.0f32; numel];
    for (slot, pair) in values.iter_mut().zip(data.chunks_exact(2)) {
        *slot = f16_to_f32(u16::from_le_bytes([pair[0], pair[1]]));
    }
    values
}
/// Decodes little-endian bfloat16 values from `data` into `numel` `f32`s.
///
/// A bfloat16 is the upper 16 bits of an `f32`, so widening is a left shift.
/// Elements for which `data` has no complete 2-byte group stay `0.0`.
fn dequant_bf16(data: &[u8], numel: usize) -> Vec<f32> {
    let mut values = vec![0.0f32; numel];
    for (slot, pair) in values.iter_mut().zip(data.chunks_exact(2)) {
        let upper = u32::from(u16::from_le_bytes([pair[0], pair[1]]));
        *slot = f32::from_bits(upper << 16);
    }
    values
}
/// Dequantizes Q8_0 data: 34-byte blocks of an f16 scale followed by 32 signed bytes.
///
/// Output elements whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q8_0(data: &[u8], numel: usize) -> Vec<f32> {
    const VALUES_PER_BLOCK: usize = 32;
    const BYTES_PER_BLOCK: usize = 34;

    let mut values = vec![0.0f32; numel];
    // Pair each output chunk with its source block; zip stops at the first
    // block that is not fully present in `data`.
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    for (out, block) in values.chunks_mut(VALUES_PER_BLOCK).zip(blocks) {
        let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        for (slot, &raw) in out.iter_mut().zip(&block[2..]) {
            *slot = (raw as i8) as f32 * scale;
        }
    }
    values
}
/// Dequantizes Q4_0 data: 18-byte blocks of an f16 scale followed by 16 packed bytes.
///
/// Byte `j` holds element `j` in its low nibble and element `j + 16` in its
/// high nibble; each nibble is offset by 8 before scaling. Output elements
/// whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q4_0(data: &[u8], numel: usize) -> Vec<f32> {
    const VALUES_PER_BLOCK: usize = 32;
    const BYTES_PER_BLOCK: usize = 18;

    let mut values = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    for (out, block) in values.chunks_mut(VALUES_PER_BLOCK).zip(blocks) {
        let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        for (j, &packed) in block[2..].iter().enumerate() {
            let low = i32::from(packed & 0x0F) - 8;
            let high = i32::from(packed >> 4) - 8;
            // `get_mut` skips positions past `numel` in a trailing partial block.
            if let Some(slot) = out.get_mut(j) {
                *slot = low as f32 * scale;
            }
            if let Some(slot) = out.get_mut(j + 16) {
                *slot = high as f32 * scale;
            }
        }
    }
    values
}
/// Dequantizes Q4_1 data: 20-byte blocks of an f16 scale, an f16 minimum, and
/// 16 packed bytes.
///
/// Byte `j` holds element `j` in its low nibble and element `j + 16` in its
/// high nibble; each value is `nibble * scale + min`. Output elements whose
/// block is missing or incomplete in `data` stay `0.0`.
fn dequant_q4_1(data: &[u8], numel: usize) -> Vec<f32> {
    const VALUES_PER_BLOCK: usize = 32;
    const BYTES_PER_BLOCK: usize = 20;

    let mut values = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    for (out, block) in values.chunks_mut(VALUES_PER_BLOCK).zip(blocks) {
        let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let min = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
        for (j, &packed) in block[4..].iter().enumerate() {
            let low = f32::from(packed & 0x0F);
            let high = f32::from(packed >> 4);
            // `get_mut` skips positions past `numel` in a trailing partial block.
            if let Some(slot) = out.get_mut(j) {
                *slot = low * scale + min;
            }
            if let Some(slot) = out.get_mut(j + 16) {
                *slot = high * scale + min;
            }
        }
    }
    values
}
/// Dequantizes Q6_K data following the ggml reference layout.
///
/// Each 210-byte block encodes 256 values: `ql` (128 bytes, low 4 bits),
/// `qh` (64 bytes, high 2 bits), 16 signed 8-bit sub-block scales, and an f16
/// super-scale `d`. A value is `d * scale * (q - 32)` with `q` the 6-bit quant.
///
/// The quants are interleaved, not linearly packed: within each 128-value
/// half, byte `ql[l]` carries values `l` (low nibble) and `l + 64` (high
/// nibble), `ql[l + 32]` carries `l + 32` and `l + 96`, and `qh[l]` carries the
/// high bits of those four values in bit pairs 0, 1, 2, 3.
///
/// Output elements whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q6_k(data: &[u8], numel: usize) -> Vec<f32> {
    const VALUES_PER_BLOCK: usize = 256;
    const BYTES_PER_BLOCK: usize = 210;

    let mut output = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    for (out, block) in output.chunks_mut(VALUES_PER_BLOCK).zip(blocks) {
        let ql = &block[..128];
        let qh = &block[128..192];
        let scales = &block[192..208];
        let d = f16_to_f32(u16::from_le_bytes([block[208], block[209]]));

        for (idx, slot) in out.iter_mut().enumerate() {
            let half = idx / 128; // which 128-value half of the block
            let quarter = (idx % 128) / 32; // which 32-value group inside the half
            let l = idx % 32;

            // Groups 0/1 use low nibbles of ql[l] / ql[l + 32]; groups 2/3 the high nibbles.
            let ql_byte = ql[half * 64 + (quarter % 2) * 32 + l];
            let low = if quarter < 2 { ql_byte & 0x0F } else { ql_byte >> 4 };
            let high = (qh[half * 32 + l] >> (quarter * 2)) & 0x03;
            let q = i32::from(low | (high << 4)) - 32;

            // One signed scale per 16 consecutive output values.
            let scale = scales[idx / 16] as i8;
            *slot = d * f32::from(scale) * q as f32;
        }
    }
    output
}
/// Dequantizes Q5_K data following the ggml reference layout.
///
/// Each 176-byte block encodes 256 values: f16 `d`, f16 `dmin`, 12 bytes of
/// packed 6-bit scales and mins (eight of each), `qh` (32 bytes, one high bit
/// per value), then `qs` (128 bytes, low 4 bits). Note that `qh` precedes `qs`.
/// A value is `d * scale * q - dmin * min` with `q` the 5-bit quant.
///
/// Quants are interleaved per 64 values: byte `qs[32 * pair + l]` carries value
/// `64 * pair + l` in its low nibble and `64 * pair + 32 + l` in its high
/// nibble; bit `s` of `qh[l]` is the high bit for 32-value sub-block `s`.
///
/// Output elements whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q5_k(data: &[u8], numel: usize) -> Vec<f32> {
    const VALUES_PER_BLOCK: usize = 256;
    const BYTES_PER_BLOCK: usize = 176;

    /// Unpacks the 6-bit (scale, min) pair of sub-block `j` (0..8) from the 12 packed bytes.
    fn scale_min_k4(packed: &[u8], j: usize) -> (u8, u8) {
        if j < 4 {
            (packed[j] & 63, packed[j + 4] & 63)
        } else {
            (
                (packed[j + 4] & 0x0F) | ((packed[j - 4] >> 6) << 4),
                (packed[j + 4] >> 4) | ((packed[j] >> 6) << 4),
            )
        }
    }

    let mut output = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    for (out, block) in output.chunks_mut(VALUES_PER_BLOCK).zip(blocks) {
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let dmin = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
        let scales = &block[4..16];
        let qh = &block[16..48];
        let qs = &block[48..176];

        for (idx, slot) in out.iter_mut().enumerate() {
            let sub = idx / 32; // 32-value sub-block, 0..8
            let l = idx % 32;
            let (sc, m) = scale_min_k4(scales, sub);

            // Even sub-blocks use low nibbles, odd ones the high nibbles of the same bytes.
            let byte = qs[(idx / 64) * 32 + l];
            let low = if sub % 2 == 0 { byte & 0x0F } else { byte >> 4 };
            let high = (qh[l] >> sub) & 1;
            let q = f32::from(low | (high << 4));

            *slot = d * f32::from(sc) * q - dmin * f32::from(m);
        }
    }
    output
}
/// Dequantizes Q4_K data following the ggml reference layout.
///
/// Each 144-byte block encodes 256 values: f16 `d`, f16 `dmin`, 12 bytes of
/// packed 6-bit scales and mins (eight of each), and `qs` (128 bytes of 4-bit
/// quants). A value is `d * scale * q - dmin * min`.
///
/// Quants are interleaved per 64 values: byte `qs[32 * pair + l]` carries value
/// `64 * pair + l` in its low nibble and `64 * pair + 32 + l` in its high nibble.
///
/// Output elements whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q4_k(data: &[u8], numel: usize) -> Vec<f32> {
    const VALUES_PER_BLOCK: usize = 256;
    const BYTES_PER_BLOCK: usize = 144;

    /// Unpacks the 6-bit (scale, min) pair of sub-block `j` (0..8) from the 12 packed bytes.
    fn scale_min_k4(packed: &[u8], j: usize) -> (u8, u8) {
        if j < 4 {
            (packed[j] & 63, packed[j + 4] & 63)
        } else {
            (
                (packed[j + 4] & 0x0F) | ((packed[j - 4] >> 6) << 4),
                (packed[j + 4] >> 4) | ((packed[j] >> 6) << 4),
            )
        }
    }

    let mut output = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    for (out, block) in output.chunks_mut(VALUES_PER_BLOCK).zip(blocks) {
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let dmin = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
        let scales = &block[4..16];
        let qs = &block[16..144];

        for (idx, slot) in out.iter_mut().enumerate() {
            let sub = idx / 32; // 32-value sub-block, 0..8
            let (sc, m) = scale_min_k4(scales, sub);

            // Even sub-blocks use low nibbles, odd ones the high nibbles of the same bytes.
            let byte = qs[(idx / 64) * 32 + idx % 32];
            let q = if sub % 2 == 0 { byte & 0x0F } else { byte >> 4 };

            *slot = d * f32::from(sc) * f32::from(q) - dmin * f32::from(m);
        }
    }
    output
}
/// Dequantizes Q8_K data: 292-byte blocks of an `f32` scale followed by 256
/// signed bytes; the remaining 32 bytes of each block are not read.
///
/// Output elements whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q8_k(data: &[u8], numel: usize) -> Vec<f32> {
    const VALUES_PER_BLOCK: usize = 256;
    const BYTES_PER_BLOCK: usize = 292;

    let mut values = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    for (out, block) in values.chunks_mut(VALUES_PER_BLOCK).zip(blocks) {
        let scale = f32::from_le_bytes([block[0], block[1], block[2], block[3]]);
        let quants = &block[4..4 + VALUES_PER_BLOCK];
        for (slot, &raw) in out.iter_mut().zip(quants) {
            *slot = scale * (raw as i8) as f32;
        }
    }
    values
}
/// Dequantizes Q3_K data following the ggml reference layout.
///
/// Each 110-byte block encodes 256 values: `hmask` (32 bytes, one high bit per
/// value), `qs` (64 bytes, low 2 bits), 12 bytes of packed 6-bit scales
/// (sixteen of them, stored with an offset of 32), and an f16 super-scale `d`.
/// A value is `d * (scale - 32) * q`, where `q` is the 2-bit quant minus 4
/// when the corresponding `hmask` bit is clear.
///
/// Quants are interleaved: within each 128-value half, byte `qs[32 * half + p]`
/// carries the values at positions `p`, `p + 32`, `p + 64`, `p + 96` in bit
/// pairs 0..4, and bit `4 * half + group` of `hmask[p]` is the matching high bit.
///
/// Output elements whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q3_k(data: &[u8], numel: usize) -> Vec<f32> {
    const VALUES_PER_BLOCK: usize = 256;
    const BYTES_PER_BLOCK: usize = 110;

    let mut output = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    for (out, block) in output.chunks_mut(VALUES_PER_BLOCK).zip(blocks) {
        let hmask = &block[..32];
        let qs = &block[32..96];
        let packed = &block[96..108];
        let d = f16_to_f32(u16::from_le_bytes([block[108], block[109]]));

        // Unpack the sixteen 6-bit scales: the low 4 bits live in the nibbles of
        // bytes 0..8, the high 2 bits in bit pairs of bytes 8..12.
        let mut scales = [0i32; 16];
        for (j, scale) in scales.iter_mut().enumerate() {
            let low = if j < 8 { packed[j] & 0x0F } else { packed[j - 8] >> 4 };
            let high = (packed[8 + j % 4] >> (2 * (j / 4))) & 0x03;
            *scale = i32::from(low | (high << 4)) - 32;
        }

        for (idx, slot) in out.iter_mut().enumerate() {
            let half = idx / 128; // which 128-value half of the block
            let group = (idx % 128) / 32; // which 32-value group inside the half
            let pos = idx % 32;

            let low = i32::from((qs[half * 32 + pos] >> (2 * group)) & 0x03);
            let high_set = (hmask[pos] >> (half * 4 + group)) & 1 != 0;
            // A clear mask bit means the value is shifted down by 4.
            let q = if high_set { low } else { low - 4 };

            // One scale per 16 consecutive output values.
            *slot = d * scales[idx / 16] as f32 * q as f32;
        }
    }
    output
}
/// Dequantizes Q2_K data following the ggml reference layout.
///
/// Each 84-byte block encodes 256 values: 16 scale bytes (low nibble = scale,
/// high nibble = min, one per 16 values), `qs` (64 bytes of 2-bit quants),
/// f16 `d`, and f16 `dmin`. A value is `d * scale * q - dmin * min`.
///
/// Quants are interleaved: within each 128-value half, byte `qs[32 * half + p]`
/// carries the values at positions `p`, `p + 32`, `p + 64`, `p + 96` in bit
/// pairs 0..4.
///
/// Output elements whose block is missing or incomplete in `data` stay `0.0`.
fn dequant_q2_k(data: &[u8], numel: usize) -> Vec<f32> {
    const VALUES_PER_BLOCK: usize = 256;
    const BYTES_PER_BLOCK: usize = 84;

    let mut output = vec![0.0f32; numel];
    let blocks = data.chunks_exact(BYTES_PER_BLOCK);
    for (out, block) in output.chunks_mut(VALUES_PER_BLOCK).zip(blocks) {
        let scales = &block[..16];
        let qs = &block[16..80];
        let d = f16_to_f32(u16::from_le_bytes([block[80], block[81]]));
        let dmin = f16_to_f32(u16::from_le_bytes([block[82], block[83]]));

        for (idx, slot) in out.iter_mut().enumerate() {
            let packed_scale = scales[idx / 16];
            let sc = packed_scale & 0x0F;
            let m = packed_scale >> 4;

            let half = idx / 128; // which 128-value half of the block
            let within = idx % 128;
            let shift = (within / 32) * 2; // bit pair selected by the 32-value group
            let q = (qs[half * 32 + within % 32] >> shift) & 0x03;

            *slot = d * f32::from(sc) * f32::from(q) - dmin * f32::from(m);
        }
    }
    output
}
/// Widens an IEEE 754 half-precision bit pattern to `f32`.
///
/// Handles signed zero, subnormals, normals, infinities, and NaNs (NaN payload
/// bits are shifted into the upper mantissa bits).
fn f16_to_f32(bits: u16) -> f32 {
    let half = u32::from(bits);
    let sign_bit = (half & 0x8000) << 16;
    let exponent = (half >> 10) & 0x1F;
    let fraction = half & 0x3FF;

    let magnitude = match (exponent, fraction) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: shift the leading one up to the implicit-bit position
        // (bit 10) and lower the exponent by the same amount.
        (0, _) => {
            let shift = fraction.leading_zeros() - 21;
            let normalized = (fraction << shift) & 0x3FF;
            let biased = 127 - 14 - shift;
            (biased << 23) | (normalized << 13)
        }
        // Infinity or NaN.
        (0x1F, _) => (0xFF << 23) | (fraction << 13),
        // Normal: rebias the exponent from 15 to 127.
        _ => ((exponent + 112) << 23) | (fraction << 13),
    };
    f32::from_bits(sign_bit | magnitude)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons below.
    const EPS: f32 = 1e-6;

    /// Returns `true` when `actual` is within `EPS` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Serializes 16-bit words as consecutive little-endian bytes.
    fn le_bytes(words: &[u16]) -> Vec<u8> {
        words.iter().flat_map(|w| w.to_le_bytes()).collect()
    }

    #[test]
    fn f16_conversion_basic() {
        assert_eq!(f16_to_f32(0x0000), 0.0);
        let cases = [(0x3C00u16, 1.0f32), (0xBC00, -1.0), (0x3800, 0.5), (0x4000, 2.0)];
        for (bits, expected) in cases {
            assert!(close(f16_to_f32(bits), expected));
        }
    }

    #[test]
    fn f16_special_values() {
        assert!(f16_to_f32(0x7C00).is_infinite());
        assert!(f16_to_f32(0x7E00).is_nan());
        assert_eq!(f16_to_f32(0x8000), -0.0);
    }

    #[test]
    fn bf16_conversion() {
        // 0x3F80 is the upper half of the f32 bit pattern for 1.0.
        let decoded = dequant_bf16(&le_bytes(&[0x3F80]), 1);
        assert!(close(decoded[0], 1.0));
    }

    #[test]
    fn dequant_f32_identity() {
        let expected = vec![1.0f32, 2.0, -3.5, 0.0];
        let mut raw = Vec::with_capacity(expected.len() * 4);
        for value in &expected {
            raw.extend_from_slice(&value.to_le_bytes());
        }
        assert_eq!(dequant_f32(&raw, 4), expected);
    }

    #[test]
    fn dequant_f16_roundtrip() {
        // f16 encodings of 1.0 and 0.5.
        let decoded = dequant_f16(&le_bytes(&[0x3C00, 0x3800]), 2);
        assert!(close(decoded[0], 1.0));
        assert!(close(decoded[1], 0.5));
    }

    #[test]
    fn dequant_q8_0_basic() {
        // Scale 1.0 followed by quants 0..32.
        let mut block = le_bytes(&[0x3C00]);
        block.extend(0..32u8);
        let decoded = dequant_q8_0(&block, 32);
        assert_eq!(decoded.len(), 32);
        for index in [0usize, 1, 31] {
            assert!(close(decoded[index], index as f32));
        }
    }

    #[test]
    fn dequant_q4_0_basic() {
        // Scale 1.0; every nibble is 8, which maps to zero after the -8 offset.
        let mut block = le_bytes(&[0x3C00]);
        block.extend_from_slice(&[0x88u8; 16]);
        let decoded = dequant_q4_0(&block, 32);
        assert_eq!(decoded.len(), 32);
        assert!(decoded.iter().all(|&value| close(value, 0.0)));
    }

    #[test]
    fn dequant_q4_1_basic() {
        // Scale 2.0, min 1.0; every nibble is 0, so every value equals the min.
        let mut block = le_bytes(&[0x4000, 0x3C00]);
        block.extend_from_slice(&[0x00u8; 16]);
        let decoded = dequant_q4_1(&block, 32);
        assert_eq!(decoded.len(), 32);
        for val in decoded.iter() {
            assert!(close(*val, 1.0), "expected 1.0, got {val}");
        }
    }

    #[test]
    fn model_weights_accessors() {
        let tensors: HashMap<String, Vec<f32>> = [
            ("w1".to_string(), vec![1.0f32; 100]),
            ("w2".to_string(), vec![2.0f32; 200]),
        ]
        .into_iter()
        .collect();
        let weights = ModelWeights { tensors };
        assert_eq!(weights.len(), 2);
        assert!(!weights.is_empty());
        assert_eq!(weights.total_elements(), 300);
        assert_eq!(weights.memory_bytes(), 1200);
        assert_eq!(weights.get("w1").unwrap().len(), 100);
        assert_eq!(weights.tensor("w2").len(), 200);
    }
}