use crate::backends::q4k::matmul_q4k_f32_dispatch;
use crate::blis::attention::fused_attention_decode;
use crate::blis::norms::rms_norm;
use crate::error::TruenoError;
use crate::inference::gguf::{GgmlType, GgufFile};
/// Transformer hyperparameters read from GGUF metadata (see `from_gguf`).
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Residual-stream width (`{arch}.embedding_length`).
    pub hidden_size: usize,
    /// Feed-forward inner width (`{arch}.feed_forward_length`).
    pub intermediate_size: usize,
    /// Number of transformer blocks (`{arch}.block_count`).
    pub num_layers: usize,
    /// Query head count (`{arch}.attention.head_count`).
    pub num_heads: usize,
    /// Key/value head count; fewer than `num_heads` means grouped-query
    /// attention. Defaults to `num_heads` when the GGUF key is absent.
    pub num_kv_heads: usize,
    /// Per-head width, derived as `hidden_size / num_heads`.
    pub head_dim: usize,
    /// Vocabulary size (explicit key, token-array length, or 32000 fallback).
    pub vocab_size: usize,
    /// RMSNorm epsilon (default 1e-5).
    pub rms_norm_eps: f32,
    /// RoPE base frequency (default 10000.0).
    pub rope_theta: f32,
    /// Maximum context length; sizes each KV-cache layer (default 2048).
    pub max_seq_len: usize,
    /// Architecture string from `general.architecture` (e.g. "llama").
    pub arch: String,
}
impl ModelConfig {
    /// Read model hyperparameters from GGUF metadata.
    ///
    /// Keys live under the `general.architecture` prefix (defaulting to
    /// `"llama"` when that key is absent). Missing required keys yield
    /// `InvalidInput`; optional keys fall back to common llama defaults.
    pub fn from_gguf(gguf: &GgufFile) -> Result<Self, TruenoError> {
        let arch = gguf.meta_str("general.architecture").unwrap_or("llama").to_string();
        // Fetch a required u32 key under the architecture prefix, reporting
        // `what` in the error message.
        let require = |suffix: &str, what: &str| -> Result<usize, TruenoError> {
            gguf.meta_u32(&format!("{arch}.{suffix}"))
                .map(|v| v as usize)
                .ok_or_else(|| TruenoError::InvalidInput(format!("Missing {what} in GGUF")))
        };
        let hidden_size = require("embedding_length", "embedding_length")?;
        let num_heads = require("attention.head_count", "head_count")?;
        // Absent kv head count means classic multi-head attention.
        let num_kv_heads = gguf
            .meta_u32(&format!("{arch}.attention.head_count_kv"))
            .map_or(num_heads, |v| v as usize);
        let num_layers = require("block_count", "block_count")?;
        let intermediate_size = require("feed_forward_length", "feed_forward_length")?;
        // NOTE(review): derived head size; some checkpoints carry an explicit
        // `{arch}.attention.key_length` that can differ — confirm not needed.
        let head_dim = hidden_size / num_heads;
        // Prefer the explicit vocab size, then the token-array length, then
        // the classic llama vocabulary size.
        let vocab_size = gguf
            .meta_u32("tokenizer.ggml.vocab_size")
            .or_else(|| match gguf.metadata.get("tokenizer.ggml.tokens") {
                Some(crate::inference::gguf::MetadataValue::Array(arr)) => Some(arr.len() as u32),
                _ => None,
            })
            .unwrap_or(32000) as usize;
        let rms_norm_eps =
            gguf.meta_f32(&format!("{arch}.attention.layer_norm_rms_epsilon")).unwrap_or(1e-5);
        let rope_theta = gguf.meta_f32(&format!("{arch}.rope.freq_base")).unwrap_or(10000.0);
        let max_seq_len =
            gguf.meta_u32(&format!("{arch}.context_length")).unwrap_or(2048) as usize;
        Ok(Self {
            hidden_size,
            intermediate_size,
            num_layers,
            num_heads,
            num_kv_heads,
            head_dim,
            vocab_size,
            rms_norm_eps,
            rope_theta,
            max_seq_len,
            arch,
        })
    }
}
/// A row-major 2-D weight, kept either in quantized form or as dense f32.
pub enum WeightMatrix {
    /// Raw GGML Q4_K block bytes, consumed directly by the quantized matmul.
    Q4K { data: Vec<u8>, rows: usize },
    /// Dense f32 data, `rows * in_dim` elements, one row per output.
    F32 { data: Vec<f32>, rows: usize },
}
impl WeightMatrix {
    /// Number of output rows, regardless of storage format.
    pub fn rows(&self) -> usize {
        match *self {
            WeightMatrix::Q4K { rows, .. } | WeightMatrix::F32 { rows, .. } => rows,
        }
    }
}
/// Weights for one transformer block (GGUF `blk.{i}.*` tensors).
pub struct LayerWeights {
    /// RMSNorm gain applied before attention.
    pub attn_norm: Vec<f32>,
    /// Query projection (input dim `hidden_size`).
    pub q_weight: WeightMatrix,
    /// Key projection (input dim `hidden_size`; kv_dim rows under GQA).
    pub k_weight: WeightMatrix,
    /// Value projection (input dim `hidden_size`; kv_dim rows under GQA).
    pub v_weight: WeightMatrix,
    /// Attention output projection back to the residual stream.
    pub o_weight: WeightMatrix,
    // Optional Q/K/V biases; absent in classic llama checkpoints.
    pub q_bias: Option<Vec<f32>>,
    pub k_bias: Option<Vec<f32>>,
    pub v_bias: Option<Vec<f32>>,
    /// RMSNorm gain applied before the feed-forward network.
    pub ffn_norm: Vec<f32>,
    /// SwiGLU gate projection (`hidden_size` input).
    pub gate_weight: WeightMatrix,
    /// SwiGLU up projection (`hidden_size` input).
    pub up_weight: WeightMatrix,
    /// Down projection (`intermediate_size` input).
    pub down_weight: WeightMatrix,
}
/// All tensors of a loaded model.
pub struct ModelWeights {
    /// Token embedding table, `vocab_size * hidden_size` f32, one row per token.
    pub token_embd: Vec<f32>,
    /// Final RMSNorm gain applied before the LM head.
    pub output_norm: Vec<f32>,
    /// LM head; falls back to a copy of the embedding table when the
    /// checkpoint ties its output weights (no `output.weight` tensor).
    pub output_weight: WeightMatrix,
    /// Per-layer transformer block weights.
    pub layers: Vec<LayerWeights>,
}
/// Per-layer key/value cache for autoregressive decoding.
pub struct KvCache {
    /// Keys: one flat `max_seq_len * kv_dim` buffer per layer.
    pub k: Vec<Vec<f32>>,
    /// Values: same layout as `k`.
    pub v: Vec<Vec<f32>>,
    /// Number of filled positions. NOTE(review): `LlamaModel::forward` takes
    /// `pos` explicitly and never advances this field — confirm callers
    /// maintain it if anything reads it.
    pub seq_len: usize,
}
impl KvCache {
    /// Allocate a zeroed cache sized for `max_seq_len` positions in every layer.
    pub fn new(config: &ModelConfig) -> Self {
        let kv_dim = config.num_kv_heads * config.head_dim;
        let per_layer = config.max_seq_len * kv_dim;
        Self {
            k: vec![vec![0.0f32; per_layer]; config.num_layers],
            v: vec![vec![0.0f32; per_layer]; config.num_layers],
            seq_len: 0,
        }
    }
}
/// A llama-family transformer loaded from GGUF, ready for token-by-token decoding.
pub struct LlamaModel {
    /// Hyperparameters read from GGUF metadata.
    pub config: ModelConfig,
    /// Embedding table, per-layer weights, and LM head.
    pub weights: ModelWeights,
}
impl LlamaModel {
    /// Parse the config from GGUF metadata, then load all weights into memory.
    pub fn from_gguf(gguf: &GgufFile) -> Result<Self, TruenoError> {
        let config = ModelConfig::from_gguf(gguf)?;
        // Load banner: layers, hidden width, query/kv head counts, FFN width.
        eprintln!(
            "Loading {} model: {}L × {}H ({}h {}kv) × {}I, vocab={}",
            config.arch,
            config.num_layers,
            config.hidden_size,
            config.num_heads,
            config.num_kv_heads,
            config.intermediate_size,
            config.vocab_size,
        );
        let weights = load_weights(gguf, &config)?;
        Ok(Self { config, weights })
    }

    /// One autoregressive decode step: embed `token_id`, run every layer at
    /// position `pos` (writing this step's K/V into `kv_cache`), then apply
    /// the final norm and the LM head, returning the vocabulary logits.
    ///
    /// # Errors
    /// `InvalidInput` if `token_id` indexes past the embedding table.
    ///
    /// NOTE(review): `pos` is not validated here; `forward_layer` slices the
    /// KV cache at `pos * kv_dim` and panics if `pos >= max_seq_len` —
    /// confirm callers bound the position.
    pub fn forward(
        &self,
        token_id: u32,
        pos: usize,
        kv_cache: &mut KvCache,
    ) -> Result<Vec<f32>, TruenoError> {
        let cfg = &self.config;
        let w = &self.weights;
        // Embedding lookup: row `token_id` of the embedding table.
        let embd_start = token_id as usize * cfg.hidden_size;
        let embd_end = embd_start + cfg.hidden_size;
        if embd_end > w.token_embd.len() {
            return Err(TruenoError::InvalidInput(format!(
                "Token ID {token_id} out of range (vocab={})",
                cfg.vocab_size
            )));
        }
        let mut hidden = w.token_embd[embd_start..embd_end].to_vec();
        // The residual stream flows through every transformer block in order.
        for (layer_idx, lw) in w.layers.iter().enumerate() {
            hidden = self.forward_layer(layer_idx, lw, &hidden, pos, kv_cache)?;
        }
        // Final RMSNorm, then project to logits with the LM head.
        let mut normed = vec![0.0f32; cfg.hidden_size];
        rms_norm(&hidden, &w.output_norm, cfg.rms_norm_eps, &mut normed)?;
        let logits = matmul_weight(&w.output_weight, &normed, cfg.hidden_size);
        Ok(logits)
    }

    /// One transformer block: pre-norm attention (QKV projections, optional
    /// biases, RoPE, KV-cache append, grouped-query attention, output
    /// projection) followed by a pre-norm SwiGLU feed-forward, each wrapped
    /// in a residual add.
    fn forward_layer(
        &self,
        layer_idx: usize,
        lw: &LayerWeights,
        hidden: &[f32],
        pos: usize,
        kv_cache: &mut KvCache,
    ) -> Result<Vec<f32>, TruenoError> {
        let cfg = &self.config;
        let kv_dim = cfg.num_kv_heads * cfg.head_dim;
        // --- Attention half ---
        let mut attn_input = vec![0.0f32; cfg.hidden_size];
        rms_norm(hidden, &lw.attn_norm, cfg.rms_norm_eps, &mut attn_input)?;
        // Q is hidden_size wide; K/V are kv_dim wide under GQA.
        let mut q = matmul_weight(&lw.q_weight, &attn_input, cfg.hidden_size);
        let mut k_proj = matmul_weight(&lw.k_weight, &attn_input, cfg.hidden_size);
        let mut v_proj = matmul_weight(&lw.v_weight, &attn_input, cfg.hidden_size);
        // Optional projection biases (present in e.g. Qwen-style checkpoints).
        if let Some(bias) = &lw.q_bias {
            for (v, b) in q.iter_mut().zip(bias.iter()) {
                *v += b;
            }
        }
        if let Some(bias) = &lw.k_bias {
            for (v, b) in k_proj.iter_mut().zip(bias.iter()) {
                *v += b;
            }
        }
        if let Some(bias) = &lw.v_bias {
            for (v, b) in v_proj.iter_mut().zip(bias.iter()) {
                *v += b;
            }
        }
        // Rotary position embedding on queries and keys (values are not rotated).
        let mut q_rope = q;
        let mut k_rope = k_proj;
        apply_rope(&mut q_rope, cfg.num_heads, cfg.head_dim, pos, cfg.rope_theta);
        apply_rope(&mut k_rope, cfg.num_kv_heads, cfg.head_dim, pos, cfg.rope_theta);
        // Append this position's K/V to the cache (panics if pos overruns).
        let kv_off = pos * kv_dim;
        kv_cache.k[layer_idx][kv_off..kv_off + kv_dim].copy_from_slice(&k_rope);
        kv_cache.v[layer_idx][kv_off..kv_off + kv_dim].copy_from_slice(&v_proj);
        let seq_len = pos + 1;
        let mut attn_out = vec![0.0f32; cfg.hidden_size];
        // GQA: each group of `heads_per_kv` query heads shares one KV head.
        let heads_per_kv = cfg.num_heads / cfg.num_kv_heads;
        for h in 0..cfg.num_heads {
            let kv_h = h / heads_per_kv;
            let q_head = &q_rope[h * cfg.head_dim..(h + 1) * cfg.head_dim];
            // Gather this head's K/V history into contiguous scratch buffers.
            // NOTE(review): this copies seq_len * head_dim floats per head per
            // token — a strided attention kernel would avoid the copies.
            let mut k_cache_head = vec![0.0f32; seq_len * cfg.head_dim];
            let mut v_cache_head = vec![0.0f32; seq_len * cfg.head_dim];
            for s in 0..seq_len {
                let src_off = s * kv_dim + kv_h * cfg.head_dim;
                let dst_off = s * cfg.head_dim;
                k_cache_head[dst_off..dst_off + cfg.head_dim]
                    .copy_from_slice(&kv_cache.k[layer_idx][src_off..src_off + cfg.head_dim]);
                v_cache_head[dst_off..dst_off + cfg.head_dim]
                    .copy_from_slice(&kv_cache.v[layer_idx][src_off..src_off + cfg.head_dim]);
            }
            let out_head = &mut attn_out[h * cfg.head_dim..(h + 1) * cfg.head_dim];
            // Single-query attention over the cached history; scaling/softmax
            // are presumably handled inside the fused kernel — see its docs.
            fused_attention_decode(
                q_head,
                &k_cache_head,
                &v_cache_head,
                cfg.head_dim,
                seq_len,
                out_head,
            );
        }
        // Output projection, then the first residual connection.
        let attn_projected = matmul_weight(&lw.o_weight, &attn_out, cfg.hidden_size);
        let mut residual: Vec<f32> =
            hidden.iter().zip(attn_projected.iter()).map(|(a, b)| a + b).collect();
        // --- Feed-forward half (SwiGLU) ---
        let mut ffn_input = vec![0.0f32; cfg.hidden_size];
        rms_norm(&residual, &lw.ffn_norm, cfg.rms_norm_eps, &mut ffn_input)?;
        let gate = matmul_weight(&lw.gate_weight, &ffn_input, cfg.hidden_size);
        let up = matmul_weight(&lw.up_weight, &ffn_input, cfg.hidden_size);
        // SwiGLU: silu(gate) * up, elementwise.
        let swiglu: Vec<f32> = gate
            .iter()
            .zip(up.iter())
            .map(|(&g, &u)| {
                let silu_g = g / (1.0 + (-g).exp());
                silu_g * u
            })
            .collect();
        // Down projection and the second residual connection.
        let ffn_out = matmul_weight(&lw.down_weight, &swiglu, cfg.intermediate_size);
        for (r, f) in residual.iter_mut().zip(ffn_out.iter()) {
            *r += f;
        }
        Ok(residual)
    }
}
/// Apply interleaved rotary position embedding in place.
///
/// Each head's adjacent (even, odd) pair is rotated by `pos * theta^(-i/d)`,
/// where `i` is the pair's even element index and `d` is `head_dim`.
fn apply_rope(x: &mut [f32], num_heads: usize, head_dim: usize, pos: usize, theta: f32) {
    for head in x.chunks_exact_mut(head_dim).take(num_heads) {
        for (pair_idx, pair) in head.chunks_exact_mut(2).enumerate() {
            // Frequency falls off geometrically across the head dimension.
            let inv_freq = theta.powf((2 * pair_idx) as f32 / head_dim as f32).recip();
            let (sin_a, cos_a) = (pos as f32 * inv_freq).sin_cos();
            let (x0, x1) = (pair[0], pair[1]);
            pair[0] = x0 * cos_a - x1 * sin_a;
            pair[1] = x0 * sin_a + x1 * cos_a;
        }
    }
}
/// Load every model tensor from the GGUF file according to `config`.
///
/// Norm gains and the embedding table become dense f32; projection weights
/// go through `load_weight_matrix`, which keeps Q4_K quantized. When the
/// checkpoint has no `output.weight`, the LM head is tied to (a copy of)
/// the embedding table.
///
/// # Errors
/// `InvalidInput` when any required tensor is missing.
fn load_weights(gguf: &GgufFile, config: &ModelConfig) -> Result<ModelWeights, TruenoError> {
    let token_embd = load_f32_or_dequant_tensor(
        gguf,
        "token_embd.weight",
        config.vocab_size * config.hidden_size,
    )?;
    let output_norm = load_f32_tensor(gguf, "output_norm.weight", config.hidden_size)?;
    // Tied-embedding fallback: reuse the embedding table as the LM head.
    let output_weight = if gguf.tensor_info("output.weight").is_some() {
        load_weight_matrix(gguf, "output.weight", config.hidden_size)?
    } else {
        WeightMatrix::F32 { data: token_embd.clone(), rows: config.vocab_size }
    };
    let mut layers = Vec::with_capacity(config.num_layers);
    for i in 0..config.num_layers {
        // Per-layer tensors follow the GGUF `blk.{i}.*` naming convention.
        let prefix = format!("blk.{i}");
        let attn_norm =
            load_f32_tensor(gguf, &format!("{prefix}.attn_norm.weight"), config.hidden_size)?;
        let ffn_norm =
            load_f32_tensor(gguf, &format!("{prefix}.ffn_norm.weight"), config.hidden_size)?;
        let q_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_q.weight"), config.hidden_size)?;
        let k_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_k.weight"), config.hidden_size)?;
        let v_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_v.weight"), config.hidden_size)?;
        let o_weight =
            load_weight_matrix(gguf, &format!("{prefix}.attn_output.weight"), config.hidden_size)?;
        // K/V biases are kv_dim wide (GQA); all three are optional.
        let kv_dim = config.num_kv_heads * config.head_dim;
        let q_bias = load_optional_f32(gguf, &format!("{prefix}.attn_q.bias"), config.hidden_size);
        let k_bias = load_optional_f32(gguf, &format!("{prefix}.attn_k.bias"), kv_dim);
        let v_bias = load_optional_f32(gguf, &format!("{prefix}.attn_v.bias"), kv_dim);
        let gate_weight =
            load_weight_matrix(gguf, &format!("{prefix}.ffn_gate.weight"), config.hidden_size)?;
        let up_weight =
            load_weight_matrix(gguf, &format!("{prefix}.ffn_up.weight"), config.hidden_size)?;
        // The down projection maps intermediate_size back to hidden_size.
        let down_weight = load_weight_matrix(
            gguf,
            &format!("{prefix}.ffn_down.weight"),
            config.intermediate_size,
        )?;
        // Print the first layer's shapes once as a sanity check.
        if i == 0 {
            eprintln!(
                "  Layer 0: Q[{}×{}] K[{}×{}] V[{}×{}] Gate[{}×{}]",
                q_weight.rows(),
                config.hidden_size,
                k_weight.rows(),
                config.hidden_size,
                v_weight.rows(),
                config.hidden_size,
                gate_weight.rows(),
                config.hidden_size,
            );
        }
        layers.push(LayerWeights {
            attn_norm,
            q_weight,
            k_weight,
            v_weight,
            o_weight,
            q_bias,
            k_bias,
            v_bias,
            ffn_norm,
            gate_weight,
            up_weight,
            down_weight,
        });
    }
    eprintln!("  Loaded {} layers", layers.len());
    Ok(ModelWeights { token_embd, output_norm, output_weight, layers })
}
/// Load a tensor as dense f32, dequantizing quantized dtypes.
///
/// `expected_elements` sizes the output only on the f32/f16/bf16 paths and
/// the unsupported-dtype zero fallback; quantized paths use the element
/// count from the tensor header instead.
///
/// # Errors
/// `InvalidInput` when the tensor entry or its byte payload is missing.
fn load_f32_or_dequant_tensor(
    gguf: &GgufFile,
    name: &str,
    expected_elements: usize,
) -> Result<Vec<f32>, TruenoError> {
    let info = gguf
        .tensor_info(name)
        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor: {name}")))?;
    let data = gguf
        .tensor_data(name)
        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor data: {name}")))?;
    match info.dtype {
        GgmlType::F32 | GgmlType::F16 | GgmlType::Bf16 => {
            Ok(to_f32_from_any(data, info.dtype, expected_elements))
        }
        GgmlType::Q4K => {
            let n_elements = info.n_elements() as usize;
            Ok(crate::backends::q4k::dequantize_q4k_to_f32(data, n_elements))
        }
        GgmlType::Q6K => Ok(dequantize_q6k_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q5K => Ok(dequantize_q5k_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q8_0 => Ok(dequantize_q8_0_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q4_0 => Ok(dequantize_q4_0_to_f32(data, info.n_elements() as usize)),
        GgmlType::Q4_1 => Ok(dequantize_q4_1_to_f32(data, info.n_elements() as usize)),
        // Lenient fallback: warn and substitute zeros rather than fail the load.
        _ => {
            eprintln!(
                "  WARNING: tensor '{name}' has unsupported dtype {:?}, using zeros",
                info.dtype
            );
            Ok(vec![0.0f32; expected_elements])
        }
    }
}
/// Load an optional f32 tensor (e.g. attention biases); `None` when absent.
fn load_optional_f32(gguf: &GgufFile, name: &str, expected_elements: usize) -> Option<Vec<f32>> {
    let info = gguf.tensor_info(name)?;
    gguf.tensor_data(name)
        .map(|data| to_f32_from_any(data, info.dtype, expected_elements))
}
/// Load a tensor that must be convertible to dense f32 (norm gains etc.).
///
/// # Errors
/// `InvalidInput` when the tensor entry or its byte payload is absent.
fn load_f32_tensor(
    gguf: &GgufFile,
    name: &str,
    expected_elements: usize,
) -> Result<Vec<f32>, TruenoError> {
    // Both the header entry and the data section must exist.
    let missing = |what: &str| TruenoError::InvalidInput(format!("{what}: {name}"));
    let info = gguf.tensor_info(name).ok_or_else(|| missing("Missing tensor"))?;
    let data = gguf.tensor_data(name).ok_or_else(|| missing("Missing tensor data"))?;
    Ok(to_f32_from_any(data, info.dtype, expected_elements))
}
/// Load a 2-D projection weight.
///
/// Q4_K tensors stay in quantized form for the fused matmul kernel; every
/// other supported dtype is dequantized to dense f32 rows. The row count is
/// inferred as element count divided by `in_dim` (the input dimension).
///
/// # Errors
/// `InvalidInput` when the tensor entry or its byte payload is absent.
fn load_weight_matrix(
    gguf: &GgufFile,
    name: &str,
    in_dim: usize,
) -> Result<WeightMatrix, TruenoError> {
    let info = gguf
        .tensor_info(name)
        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor: {name}")))?;
    let data = gguf
        .tensor_data(name)
        .ok_or_else(|| TruenoError::InvalidInput(format!("Missing tensor data: {name}")))?;
    let n_elements = info.n_elements() as usize;
    let out_dim = n_elements / in_dim;
    // Q4_K is the one format kept quantized end-to-end.
    if let GgmlType::Q4K = info.dtype {
        return Ok(WeightMatrix::Q4K { data: data.to_vec(), rows: out_dim });
    }
    // Everything else becomes a dense f32 matrix.
    let f32_data = match info.dtype {
        GgmlType::F32 | GgmlType::F16 | GgmlType::Bf16 => {
            to_f32_from_any(data, info.dtype, n_elements)
        }
        GgmlType::Q6K => dequantize_q6k_to_f32(data, n_elements),
        GgmlType::Q5K => dequantize_q5k_to_f32(data, n_elements),
        GgmlType::Q8_0 => dequantize_q8_0_to_f32(data, n_elements),
        GgmlType::Q4_0 => dequantize_q4_0_to_f32(data, n_elements),
        GgmlType::Q4_1 => dequantize_q4_1_to_f32(data, n_elements),
        // Lenient fallback: warn and substitute zeros rather than fail the load.
        _ => {
            eprintln!("  WARNING: tensor '{name}' dtype {:?} unsupported, using zeros", info.dtype);
            vec![0.0f32; n_elements]
        }
    };
    Ok(WeightMatrix::F32 { data: f32_data, rows: out_dim })
}
/// Multiply `weight` (rows × `in_dim`, row-major) by the column vector `input`.
///
/// Q4_K weights run through the quantized kernel; f32 weights use a plain
/// per-row dot product.
fn matmul_weight(weight: &WeightMatrix, input: &[f32], in_dim: usize) -> Vec<f32> {
    match weight {
        WeightMatrix::Q4K { data, rows } => matmul_q4k_f32_dispatch(data, input, *rows, in_dim),
        WeightMatrix::F32 { data, rows } => data
            .chunks_exact(in_dim)
            .take(*rows)
            .map(|row| row.iter().zip(input).map(|(&w, &x)| w * x).sum())
            .collect(),
    }
}
/// Convert an IEEE 754 binary16 bit pattern to f32.
///
/// Handles all four cases: signed zero, subnormals (exact, since every
/// half subnormal is representable as an f32), infinity/NaN (payload is
/// widened into the f32 mantissa position), and normals (exponent rebias
/// from 15 to 127, i.e. +112).
fn f16_to_f32(bits: u16) -> f32 {
    let sign = ((bits & 0x8000) as u32) << 16;
    let exp = ((bits >> 10) & 0x1F) as u32;
    let frac = (bits & 0x3FF) as u32;
    match exp {
        0 => {
            if frac == 0 {
                // Signed zero.
                f32::from_bits(sign)
            } else {
                // Subnormal half = frac * 2^-24, exact in f32 (frac <= 1023).
                let magnitude = frac as f32 / 16_777_216.0; // 2^24
                if sign != 0 { -magnitude } else { magnitude }
            }
        }
        // Max exponent: infinity (frac == 0) or NaN (payload preserved).
        31 => f32::from_bits(sign | 0x7F80_0000 | (frac << 13)),
        // Normal number: rebias exponent and widen the mantissa.
        _ => f32::from_bits(sign | ((exp + 112) << 23) | (frac << 13)),
    }
}
/// Reinterpret raw little-endian tensor bytes as f32 values.
///
/// Converts at most `n_elements` values (fewer when `data` is short).
/// Unknown dtypes yield zeros, mirroring the loader's lenient fallback.
fn to_f32_from_any(data: &[u8], dtype: GgmlType, n_elements: usize) -> Vec<f32> {
    match dtype {
        GgmlType::F32 => data
            .chunks_exact(4)
            .take(n_elements)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
        GgmlType::F16 => data
            .chunks_exact(2)
            .take(n_elements)
            .map(|c| f16_to_f32(u16::from_le_bytes([c[0], c[1]])))
            .collect(),
        GgmlType::Bf16 => data
            .chunks_exact(2)
            .take(n_elements)
            // bf16 is the top half of an f32; shift the bits back up.
            .map(|c| f32::from_bits((u16::from_le_bytes([c[0], c[1]]) as u32) << 16))
            .collect(),
        _ => vec![0.0f32; n_elements],
    }
}
/// Dequantize GGML `Q6_K` super-blocks to f32.
///
/// Block layout (210 bytes per 256 elements): `ql[128]` low 4 bits,
/// `qh[64]` high 2 bits, `scales[16]` signed per-16-element scales, f16 `d`.
///
/// Bit packing follows ggml's `dequantize_row_q6_K`: each super-block is two
/// independent 128-element halves. Within a half, ql bytes `l` / `l + 32`
/// and the four 2-bit fields of qh byte `l` feed output lanes `l`, `l + 32`,
/// `l + 64`, `l + 96` — the elements are NOT packed linearly. Each 6-bit
/// quant is centered by subtracting 32.
///
/// Trailing elements past the last complete block stay zero; the result is
/// truncated to `num_elements`.
fn dequantize_q6k_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
    const BLOCK_SIZE: usize = 256;
    const BLOCK_BYTES: usize = 210;
    let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
    let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
    for sb in 0..num_blocks {
        let sb_start = sb * BLOCK_BYTES;
        if sb_start + BLOCK_BYTES > data.len() {
            break; // truncated input: leave remaining elements zeroed
        }
        let block = &data[sb_start..sb_start + BLOCK_BYTES];
        let d = f16_to_f32(u16::from_le_bytes([block[208], block[209]]));
        let out_base = sb * BLOCK_SIZE;
        // Two 128-element halves; each consumes 64 ql bytes, 32 qh bytes and
        // 8 scale bytes (ggml advances its pointers by these strides).
        for half in 0..2usize {
            let ql = &block[half * 64..half * 64 + 64];
            let qh = &block[128 + half * 32..128 + half * 32 + 32];
            let sc = &block[192 + half * 8..192 + half * 8 + 8];
            let y = out_base + half * 128;
            for l in 0..32usize {
                let is = l / 16; // scale sub-group within each 32-lane stripe
                let q1 = ((ql[l] & 0x0F) | ((qh[l] & 0x03) << 4)) as i32 - 32;
                let q2 = ((ql[l + 32] & 0x0F) | (((qh[l] >> 2) & 0x03) << 4)) as i32 - 32;
                let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 0x03) << 4)) as i32 - 32;
                let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 0x03) << 4)) as i32 - 32;
                result[y + l] = d * (sc[is] as i8) as f32 * q1 as f32;
                result[y + l + 32] = d * (sc[is + 2] as i8) as f32 * q2 as f32;
                result[y + l + 64] = d * (sc[is + 4] as i8) as f32 * q3 as f32;
                result[y + l + 96] = d * (sc[is + 6] as i8) as f32 * q4 as f32;
            }
        }
    }
    result.truncate(num_elements);
    result
}
/// Dequantize GGML `Q5_K` super-blocks to f32.
///
/// Block layout (176 bytes per 256 elements): f16 `d`, f16 `dmin`,
/// `scales[12]` (eight 6-bit scale/min pairs, packed per ggml's
/// `get_scale_min_k4`), `qh[32]` fifth bits, `qs[128]` low 4 bits.
///
/// Value packing follows ggml's `dequantize_row_q5_K`: each 64-element chunk
/// `j` takes its first 32 lanes from the LOW nibbles of 32 qs bytes and the
/// next 32 lanes from the HIGH nibbles of the same bytes (not interleaved);
/// the fifth bit of lane `l` is bit `2j` (low half) / `2j + 1` (high half)
/// of `qh[l]`. y = d*scale*q - dmin*min.
///
/// Trailing elements past the last complete block stay zero; the result is
/// truncated to `num_elements`.
fn dequantize_q5k_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
    const BLOCK_SIZE: usize = 256;
    const BLOCK_BYTES: usize = 176;
    let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
    let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
    for sb in 0..num_blocks {
        let sb_start = sb * BLOCK_BYTES;
        if sb_start + BLOCK_BYTES > data.len() {
            break; // truncated input: leave remaining elements zeroed
        }
        let block = &data[sb_start..sb_start + BLOCK_BYTES];
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let dmin = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
        let sc = &block[4..16];
        // Unpack the eight 6-bit (scale, min) pairs (ggml get_scale_min_k4).
        let mut scales = [0u8; 8];
        let mut mins = [0u8; 8];
        for i in 0..4 {
            scales[i] = sc[i] & 0x3F;
            mins[i] = sc[i + 4] & 0x3F;
            scales[i + 4] = (sc[i + 8] & 0x0F) | ((sc[i] >> 6) << 4);
            mins[i + 4] = (sc[i + 8] >> 4) | ((sc[i + 4] >> 6) << 4);
        }
        let qh = &block[16..48];
        let qs = &block[48..176];
        let out_base = sb * BLOCK_SIZE;
        let mut u1: u8 = 1; // qh mask for the low-nibble half of each chunk
        let mut u2: u8 = 2; // qh mask for the high-nibble half
        for j in 0..4usize {
            let ql = &qs[j * 32..j * 32 + 32];
            let d1 = d * scales[2 * j] as f32;
            let m1 = dmin * mins[2 * j] as f32;
            let d2 = d * scales[2 * j + 1] as f32;
            let m2 = dmin * mins[2 * j + 1] as f32;
            let y = out_base + j * 64;
            for l in 0..32 {
                let hi: u8 = if qh[l] & u1 != 0 { 16 } else { 0 };
                result[y + l] = d1 * ((ql[l] & 0x0F) + hi) as f32 - m1;
            }
            for l in 0..32 {
                let hi: u8 = if qh[l] & u2 != 0 { 16 } else { 0 };
                result[y + 32 + l] = d2 * ((ql[l] >> 4) + hi) as f32 - m2;
            }
            u1 <<= 2;
            u2 <<= 2;
        }
    }
    result.truncate(num_elements);
    result
}
/// Dequantize GGML `Q8_0` blocks to f32.
///
/// Block layout (34 bytes per 32 elements): f16 scale `d`, then 32 signed
/// 8-bit quants stored linearly. y = d * q.
fn dequantize_q8_0_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
    const BLOCK_SIZE: usize = 32;
    const BLOCK_BYTES: usize = 34;
    let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
    let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
    // chunks_exact skips any trailing partial block, leaving zeros behind.
    for (b, block) in data.chunks_exact(BLOCK_BYTES).take(num_blocks).enumerate() {
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let out = &mut result[b * BLOCK_SIZE..(b + 1) * BLOCK_SIZE];
        for (dst, &q) in out.iter_mut().zip(&block[2..]) {
            *dst = d * (q as i8) as f32;
        }
    }
    result.truncate(num_elements);
    result
}
/// Dequantize GGML `Q4_0` blocks to f32.
///
/// Block layout (18 bytes per 32 elements): f16 scale `d`, then 16 bytes of
/// 4-bit quants. Per ggml's `dequantize_row_q4_0`, byte `j` holds element
/// `j` in its LOW nibble and element `j + 16` in its HIGH nibble — the two
/// block halves are not interleaved. Values are centered by -8: y = d*(q-8).
///
/// Trailing elements past the last complete block stay zero; the result is
/// truncated to `num_elements`.
fn dequantize_q4_0_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
    const BLOCK_SIZE: usize = 32;
    const BLOCK_BYTES: usize = 18;
    let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
    let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
    for b in 0..num_blocks {
        let b_start = b * BLOCK_BYTES;
        if b_start + BLOCK_BYTES > data.len() {
            break; // truncated input: leave remaining elements zeroed
        }
        let block = &data[b_start..b_start + BLOCK_BYTES];
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let out_base = b * BLOCK_SIZE;
        for j in 0..16 {
            let byte = block[2 + j];
            // Low nibble -> first half, high nibble -> second half.
            let lo = (byte & 0x0F) as i32 - 8;
            let hi = ((byte >> 4) & 0x0F) as i32 - 8;
            result[out_base + j] = d * lo as f32;
            result[out_base + j + 16] = d * hi as f32;
        }
    }
    result.truncate(num_elements);
    result
}
/// Dequantize GGML `Q4_1` blocks to f32.
///
/// Block layout (20 bytes per 32 elements): f16 scale `d`, f16 offset `m`,
/// then 16 bytes of 4-bit quants. Per ggml's `dequantize_row_q4_1`, byte `j`
/// holds element `j` in its LOW nibble and element `j + 16` in its HIGH
/// nibble — the two block halves are not interleaved. y = d * q + m.
///
/// Trailing elements past the last complete block stay zero; the result is
/// truncated to `num_elements`.
fn dequantize_q4_1_to_f32(data: &[u8], num_elements: usize) -> Vec<f32> {
    const BLOCK_SIZE: usize = 32;
    const BLOCK_BYTES: usize = 20;
    let num_blocks = (num_elements + BLOCK_SIZE - 1) / BLOCK_SIZE;
    let mut result = vec![0.0f32; num_blocks * BLOCK_SIZE];
    for b in 0..num_blocks {
        let b_start = b * BLOCK_BYTES;
        if b_start + BLOCK_BYTES > data.len() {
            break; // truncated input: leave remaining elements zeroed
        }
        let block = &data[b_start..b_start + BLOCK_BYTES];
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let m = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
        let out_base = b * BLOCK_SIZE;
        for j in 0..16 {
            let byte = block[4 + j];
            // Low nibble -> first half, high nibble -> second half.
            let lo = (byte & 0x0F) as f32;
            let hi = ((byte >> 4) & 0x0F) as f32;
            result[out_base + j] = d * lo + m;
            result[out_base + j + 16] = d * hi + m;
        }
    }
    result.truncate(num_elements);
    result
}