use crate::error::{RealizarError, Result};
use crate::quantize::{BLOCK_SIZE, QK_K};
use super::simd::extract_scale_min;
/// Dequantizes a buffer of Q4_0 blocks into `f32` values.
///
/// Every block occupies 18 bytes: a little-endian f16 scale followed by 16
/// bytes that each pack two 4-bit quants. A quant `q` decodes to
/// `scale * (q - 8)`. The low nibble of packed byte `j` lands at offset `j` of
/// the block's output span and the high nibble at offset `j + 16`.
///
/// The returned vector holds `BLOCK_SIZE` entries per input block.
/// NOTE(review): the writes cover offsets `0..32` of each span, so
/// `BLOCK_SIZE` is presumably 32 — confirm against `crate::quantize`.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidShape`] when `data.len()` is not a multiple
/// of the 18-byte block size.
pub fn dequantize_q4_0(data: &[u8]) -> Result<Vec<f32>> {
    const BLOCK_BYTES: usize = 2 + 16;

    if !data.len().is_multiple_of(BLOCK_BYTES) {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q4_0 data length {} is not a multiple of block size {}",
                data.len(),
                BLOCK_BYTES
            ),
        });
    }

    let block_count = data.len() / BLOCK_BYTES;
    let mut output = vec![0.0f32; block_count * BLOCK_SIZE];

    for (index, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
        let base = index * BLOCK_SIZE;
        // Header = 2-byte scale, remainder = 16 packed quant bytes.
        let (header, packed) = block.split_at(2);
        let scale = half::f16::from_le_bytes([header[0], header[1]]).to_f32();

        for (j, &byte) in packed.iter().enumerate() {
            // Nibbles are stored with a bias of 8; widen before subtracting so
            // the result can go negative.
            let lo = i16::from(byte & 0x0F) - 8;
            let hi = i16::from(byte >> 4) - 8;
            output[base + j] = scale * f32::from(lo);
            output[base + j + 16] = scale * f32::from(hi);
        }
    }

    Ok(output)
}
/// Dequantizes a buffer of Q8_0 blocks into `f32` values.
///
/// Every block occupies 34 bytes: a little-endian 16-bit scale (converted via
/// [`f16_to_f32`]) followed by 32 signed 8-bit quants. Each quant `q` decodes
/// to `scale * q`, and values are emitted in input order, 32 per block.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidShape`] when `data.len()` is not a multiple
/// of the 34-byte block size.
pub fn dequantize_q8_0(data: &[u8]) -> Result<Vec<f32>> {
    const BLOCK_BYTES: usize = 2 + 32;

    if !data.len().is_multiple_of(BLOCK_BYTES) {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q8_0 data length {} is not a multiple of block size {}",
                data.len(),
                BLOCK_BYTES
            ),
        });
    }

    let block_count = data.len() / BLOCK_BYTES;
    let mut output = Vec::with_capacity(block_count * BLOCK_SIZE);

    for block in data.chunks_exact(BLOCK_BYTES) {
        // Header = 2-byte scale, remainder = 32 quant bytes.
        let (header, quants) = block.split_at(2);
        let scale = f16_to_f32(u16::from_le_bytes([header[0], header[1]]));

        // Reinterpret each byte as a signed quant before scaling.
        output.extend(
            quants
                .iter()
                .map(|&raw| scale * f32::from(i8::from_le_bytes([raw]))),
        );
    }

    Ok(output)
}
/// Converts the 16-bit value `h` to `f32` by delegating to
/// `trueno::f16_to_f32`.
///
/// Callers in this file build `h` from two little-endian bytes of the input
/// buffer.
/// NOTE(review): `h` is presumably the raw IEEE 754 binary16 bit pattern —
/// confirm against the `trueno` documentation.
#[inline]
pub fn f16_to_f32(h: u16) -> f32 {
trueno::f16_to_f32(h)
}
/// Converts a buffer of packed 16-bit values into `f32` values.
///
/// The input is consumed two bytes at a time; each pair is decoded as a
/// little-endian `u16` and passed through [`f16_to_f32`]. The output has one
/// element per input pair, in input order.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidShape`] when `data.len()` is odd.
pub fn dequantize_f16(data: &[u8]) -> Result<Vec<f32>> {
    if !data.len().is_multiple_of(2) {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "F16 data length {} is not a multiple of 2 bytes",
                data.len()
            ),
        });
    }

    let values: Vec<f32> = data
        .chunks_exact(2)
        .map(|pair| f16_to_f32(u16::from_le_bytes([pair[0], pair[1]])))
        .collect();

    Ok(values)
}
/// Dequantizes a buffer of Q4_1 blocks into `f32` values.
///
/// Every block occupies 20 bytes: a little-endian 16-bit scale `d`, a
/// little-endian 16-bit offset `min` (both converted via [`f16_to_f32`]), then
/// 16 bytes that each pack two unsigned 4-bit quants. A quant `q` decodes to
/// `d * q + min`. The low nibble of packed byte `j` lands at offset `j` of the
/// block's output span and the high nibble at offset `j + 16`.
///
/// The returned vector holds `BLOCK_SIZE` entries per input block.
/// NOTE(review): the writes cover offsets `0..32` of each span, so
/// `BLOCK_SIZE` is presumably 32 — confirm against `crate::quantize`.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidShape`] when `data.len()` is not a multiple
/// of the 20-byte block size.
pub fn dequantize_q4_1(data: &[u8]) -> Result<Vec<f32>> {
    const BLOCK_BYTES: usize = 20;

    if !data.len().is_multiple_of(BLOCK_BYTES) {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q4_1 data length {} is not a multiple of block size {}",
                data.len(),
                BLOCK_BYTES
            ),
        });
    }

    let block_count = data.len() / BLOCK_BYTES;
    let mut output = vec![0.0f32; block_count * BLOCK_SIZE];

    for (index, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
        let base = index * BLOCK_SIZE;
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let min = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));

        // Bytes 4..20 carry the packed nibbles.
        for (j, &byte) in block[4..].iter().enumerate() {
            let lo = byte & 0x0F;
            // A u8 shifted right by 4 already fits in 4 bits; no mask needed.
            let hi = byte >> 4;
            output[base + j] = d * f32::from(lo) + min;
            output[base + j + 16] = d * f32::from(hi) + min;
        }
    }

    Ok(output)
}
/// Dequantizes a buffer of Q5_0 blocks into `f32` values.
///
/// Every block occupies 22 bytes: a little-endian 16-bit scale `d` (converted
/// via [`f16_to_f32`]), a little-endian `u32` of high bits `qh`, then 16 bytes
/// that each pack two 4-bit quants. The fifth bit of each quant comes from
/// `qh`: bit `i` for the low nibble of packed byte `i`, bit `i + 16` for its
/// high nibble. A 5-bit quant `q` decodes to `d * (q - 16)`. The low-nibble
/// value lands at offset `i` of the block's output span and the high-nibble
/// value at offset `i + 16`.
///
/// The returned vector holds `BLOCK_SIZE` entries per input block.
/// NOTE(review): the writes cover offsets `0..32` of each span, so
/// `BLOCK_SIZE` is presumably 32 — confirm against `crate::quantize`.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidShape`] when `data.len()` is not a multiple
/// of the 22-byte block size.
pub fn dequantize_q5_0(data: &[u8]) -> Result<Vec<f32>> {
    const BLOCK_BYTES: usize = 22;

    if !data.len().is_multiple_of(BLOCK_BYTES) {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q5_0 data length {} is not a multiple of block size {}",
                data.len(),
                BLOCK_BYTES
            ),
        });
    }

    let block_count = data.len() / BLOCK_BYTES;
    let mut output = vec![0.0f32; block_count * BLOCK_SIZE];

    for (index, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
        let base = index * BLOCK_SIZE;
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let qh = u32::from_le_bytes([block[2], block[3], block[4], block[5]]);

        // Bit `pos` of `qh`, moved into bit 4 so it can be OR-ed onto a nibble.
        let fifth_bit = |pos: usize| u8::from((qh >> pos) & 1 == 1) << 4;

        // Bytes 6..22 carry the packed nibbles.
        for (i, &byte) in block[6..].iter().enumerate() {
            let lo = (byte & 0x0F) | fifth_bit(i);
            let hi = (byte >> 4) | fifth_bit(i + 16);
            // Quants are stored with a bias of 16; widen before subtracting so
            // the result can go negative.
            output[base + i] = d * f32::from(i16::from(lo) - 16);
            output[base + i + 16] = d * f32::from(i16::from(hi) - 16);
        }
    }

    Ok(output)
}
/// Dequantizes a buffer of Q5_1 blocks into `f32` values.
///
/// Every block occupies 24 bytes: a little-endian 16-bit scale `d`, a
/// little-endian 16-bit offset `min` (both converted via [`f16_to_f32`]), a
/// little-endian `u32` of high bits `qh`, then 16 bytes that each pack two
/// 4-bit quants. The fifth bit of each quant comes from `qh`: bit `i` for the
/// low nibble of packed byte `i`, bit `i + 16` for its high nibble. A 5-bit
/// quant `q` decodes to `d * q + min`. The low-nibble value lands at offset
/// `i` of the block's output span and the high-nibble value at offset
/// `i + 16`.
///
/// The returned vector holds `BLOCK_SIZE` entries per input block.
/// NOTE(review): the writes cover offsets `0..32` of each span, so
/// `BLOCK_SIZE` is presumably 32 — confirm against `crate::quantize`.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidShape`] when `data.len()` is not a multiple
/// of the 24-byte block size.
pub fn dequantize_q5_1(data: &[u8]) -> Result<Vec<f32>> {
    const BLOCK_BYTES: usize = 24;

    if !data.len().is_multiple_of(BLOCK_BYTES) {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Q5_1 data length {} is not a multiple of block size {}",
                data.len(),
                BLOCK_BYTES
            ),
        });
    }

    let block_count = data.len() / BLOCK_BYTES;
    let mut output = vec![0.0f32; block_count * BLOCK_SIZE];

    for (index, block) in data.chunks_exact(BLOCK_BYTES).enumerate() {
        let base = index * BLOCK_SIZE;
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let min = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
        let qh = u32::from_le_bytes([block[4], block[5], block[6], block[7]]);

        // Bit `pos` of `qh`, moved into bit 4 so it can be OR-ed onto a nibble.
        let fifth_bit = |pos: usize| u8::from((qh >> pos) & 1 == 1) << 4;

        // Bytes 8..24 carry the packed nibbles.
        for (i, &byte) in block[8..].iter().enumerate() {
            let lo = (byte & 0x0F) | fifth_bit(i);
            let hi = (byte >> 4) | fifth_bit(i + 16);
            output[base + i] = d * f32::from(lo) + min;
            output[base + i + 16] = d * f32::from(hi) + min;
        }
    }

    Ok(output)
}
// Textually splice the sibling source files into this module at this point,
// so their items share this file's imports and scope.
// NOTE(review): their contents are not visible here; presumably they hold the
// remaining dequantizers — confirm before relying on that.
include!("dequant_q4k.rs");
include!("dequant_f16.rs");