use crate::bitnet::dequantize_bitnet_t158;
use crate::error::{Result, RuvLLMError};
/// Tensor storage types that can appear in a GGUF file, keyed by their
/// numeric type id (the `#[repr(u32)]` discriminants below).
///
/// Each variant has a fixed block geometry exposed through `block_size()`
/// (elements per block) and `type_size()` (bytes per block).
///
/// NOTE(review): upstream ggml appears to number id 29 as `IQ1_M` and id 30 as
/// `BF16`, whereas here 29 is `Bf16` and 30 is `BitnetT158`. Confirm this
/// project-specific numbering is intentional; otherwise BF16 tensors written
/// by llama.cpp tooling would be decoded as BitNet data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum GgufQuantType {
F32 = 0,
F16 = 1,
Q4_0 = 2,
Q4_1 = 3,
// Legacy ids; `dequantize_tensor` has no decoder for Q4_2 / Q4_3.
Q4_2 = 4,
Q4_3 = 5,
Q5_0 = 6,
Q5_1 = 7,
Q8_0 = 8,
Q8_1 = 9,
Q2_K = 10,
Q3_K = 11,
Q4_K = 12,
Q5_K = 13,
Q6_K = 14,
Q8_K = 15,
IQ2_XXS = 16,
IQ2_XS = 17,
IQ3_XXS = 18,
IQ1_S = 19,
IQ4_NL = 20,
IQ3_S = 21,
IQ2_S = 22,
IQ4_XS = 23,
I8 = 24,
I16 = 25,
I32 = 26,
I64 = 27,
F64 = 28,
Bf16 = 29,
// Ternary BitNet b1.58 blocks: 64 packed bytes + f16 scale per 256 weights
// (see the BITNET_T158_* constants and the wrappers in this file).
BitnetT158 = 30,
}
impl TryFrom<u32> for GgufQuantType {
type Error = RuvLLMError;
fn try_from(value: u32) -> Result<Self> {
match value {
0 => Ok(Self::F32),
1 => Ok(Self::F16),
2 => Ok(Self::Q4_0),
3 => Ok(Self::Q4_1),
4 => Ok(Self::Q4_2),
5 => Ok(Self::Q4_3),
6 => Ok(Self::Q5_0),
7 => Ok(Self::Q5_1),
8 => Ok(Self::Q8_0),
9 => Ok(Self::Q8_1),
10 => Ok(Self::Q2_K),
11 => Ok(Self::Q3_K),
12 => Ok(Self::Q4_K),
13 => Ok(Self::Q5_K),
14 => Ok(Self::Q6_K),
15 => Ok(Self::Q8_K),
16 => Ok(Self::IQ2_XXS),
17 => Ok(Self::IQ2_XS),
18 => Ok(Self::IQ3_XXS),
19 => Ok(Self::IQ1_S),
20 => Ok(Self::IQ4_NL),
21 => Ok(Self::IQ3_S),
22 => Ok(Self::IQ2_S),
23 => Ok(Self::IQ4_XS),
24 => Ok(Self::I8),
25 => Ok(Self::I16),
26 => Ok(Self::I32),
27 => Ok(Self::I64),
28 => Ok(Self::F64),
29 => Ok(Self::Bf16),
30 => Ok(Self::BitnetT158),
_ => Err(RuvLLMError::Model(format!(
"Unknown GGUF quantization type: {}",
value
))),
}
}
}
impl GgufQuantType {
    /// Storage geometry as `(elements per block, bytes per block)`.
    ///
    /// Scalar (non-quantized) types use one-element "blocks" whose byte size is
    /// the scalar width; quantized types pack a fixed number of weights per block.
    fn block_layout(&self) -> (usize, usize) {
        match self {
            // Plain scalars.
            Self::F32 => (1, 4),
            Self::F16 | Self::Bf16 => (1, 2),
            Self::F64 => (1, 8),
            Self::I8 => (1, 1),
            Self::I16 => (1, 2),
            Self::I32 => (1, 4),
            Self::I64 => (1, 8),
            // Legacy 32-element formats.
            Self::Q4_0 | Self::Q4_2 => (32, 18),
            Self::Q4_1 | Self::Q4_3 => (32, 20),
            Self::Q5_0 => (32, 22),
            Self::Q5_1 => (32, 24),
            Self::Q8_0 => (32, 34),
            Self::Q8_1 => (32, 36),
            // K-quants: 256-element super-blocks.
            Self::Q2_K => (256, 84),
            Self::Q3_K => (256, 110),
            Self::Q4_K => (256, 144),
            Self::Q5_K => (256, 176),
            Self::Q6_K => (256, 210),
            Self::Q8_K => (256, 292),
            // Importance-matrix quants.
            Self::IQ2_XXS => (256, 66),
            Self::IQ2_XS => (256, 74),
            Self::IQ2_S => (256, 82),
            Self::IQ3_XXS => (256, 98),
            Self::IQ3_S => (256, 110),
            Self::IQ1_S => (256, 50),
            Self::IQ4_NL => (32, 18),
            Self::IQ4_XS => (256, 136),
            // Ternary BitNet: 64 packed bytes + f16 scale.
            Self::BitnetT158 => (256, 66),
        }
    }

    /// Number of elements encoded by one block of this type.
    pub fn block_size(&self) -> usize {
        self.block_layout().0
    }

    /// Number of bytes occupied by one block of this type.
    pub fn type_size(&self) -> usize {
        self.block_layout().1
    }

    /// Bytes needed to store `num_elements` values, rounded up to whole blocks.
    pub fn tensor_size(&self, num_elements: usize) -> usize {
        let (per_block, bytes_per_block) = self.block_layout();
        let blocks = (num_elements + per_block - 1) / per_block;
        blocks * bytes_per_block
    }

    /// `true` for block-quantized formats, `false` for plain float/int scalars.
    pub fn is_quantized(&self) -> bool {
        // Exactly the scalar formats store one element per block.
        self.block_layout().0 != 1
    }

    /// Average storage cost in bits per element, including per-block metadata.
    pub fn bits_per_weight(&self) -> f32 {
        let (per_block, bytes_per_block) = self.block_layout();
        let bits_per_block = bytes_per_block as f32 * 8.0;
        bits_per_block / per_block as f32
    }

    /// Canonical upper-case display name of the type.
    pub fn name(&self) -> &'static str {
        match self {
            Self::F32 => "F32",
            Self::F16 => "F16",
            Self::Q4_0 => "Q4_0",
            Self::Q4_1 => "Q4_1",
            Self::Q4_2 => "Q4_2",
            Self::Q4_3 => "Q4_3",
            Self::Q5_0 => "Q5_0",
            Self::Q5_1 => "Q5_1",
            Self::Q8_0 => "Q8_0",
            Self::Q8_1 => "Q8_1",
            Self::Q2_K => "Q2_K",
            Self::Q3_K => "Q3_K",
            Self::Q4_K => "Q4_K",
            Self::Q5_K => "Q5_K",
            Self::Q6_K => "Q6_K",
            Self::Q8_K => "Q8_K",
            Self::IQ2_XXS => "IQ2_XXS",
            Self::IQ2_XS => "IQ2_XS",
            Self::IQ3_XXS => "IQ3_XXS",
            Self::IQ1_S => "IQ1_S",
            Self::IQ4_NL => "IQ4_NL",
            Self::IQ3_S => "IQ3_S",
            Self::IQ2_S => "IQ2_S",
            Self::IQ4_XS => "IQ4_XS",
            Self::I8 => "I8",
            Self::I16 => "I16",
            Self::I32 => "I32",
            Self::I64 => "I64",
            Self::F64 => "F64",
            Self::Bf16 => "BF16",
            Self::BitnetT158 => "BITNET_T158",
        }
    }
}
/// A tensor's raw encoded bytes together with its storage type and shape.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
/// Encoded bytes, laid out as consecutive blocks of `dtype`.
pub data: Vec<u8>,
/// Storage format of `data`.
pub dtype: GgufQuantType,
/// Tensor dimensions. Not read by the methods in this file; presumably its
/// product equals `num_elements` — not checked here.
pub shape: Vec<usize>,
/// Number of logical elements; drives `dequantize()` and `block_count()`.
pub num_elements: usize,
}
impl QuantizedTensor {
    /// Decodes the whole tensor into `f32` values via [`dequantize_tensor`].
    pub fn dequantize(&self) -> Result<Vec<f32>> {
        let Self {
            data,
            dtype,
            num_elements,
            ..
        } = self;
        dequantize_tensor(data, *dtype, *num_elements)
    }

    /// Number of storage blocks needed for `num_elements`, rounded up.
    pub fn block_count(&self) -> usize {
        let per_block = self.dtype.block_size();
        (self.num_elements + per_block - 1) / per_block
    }
}
pub fn dequantize_tensor(
data: &[u8],
dtype: GgufQuantType,
num_elements: usize,
) -> Result<Vec<f32>> {
let mut output = vec![0.0f32; num_elements];
match dtype {
GgufQuantType::F32 => dequantize_f32(data, &mut output),
GgufQuantType::F16 => dequantize_f16(data, &mut output),
GgufQuantType::Bf16 => dequantize_bf16(data, &mut output),
GgufQuantType::Q4_0 => dequantize_q4_0(data, &mut output),
GgufQuantType::Q4_1 => dequantize_q4_1(data, &mut output),
GgufQuantType::Q5_0 => dequantize_q5_0(data, &mut output),
GgufQuantType::Q5_1 => dequantize_q5_1(data, &mut output),
GgufQuantType::Q8_0 => dequantize_q8_0(data, &mut output),
GgufQuantType::Q8_1 => dequantize_q8_1(data, &mut output),
GgufQuantType::Q2_K => dequantize_q2_k(data, &mut output),
GgufQuantType::Q3_K => dequantize_q3_k(data, &mut output),
GgufQuantType::Q4_K => dequantize_q4_k(data, &mut output),
GgufQuantType::Q5_K => dequantize_q5_k(data, &mut output),
GgufQuantType::Q6_K => dequantize_q6_k(data, &mut output),
GgufQuantType::IQ4_NL => dequantize_iq4_nl(data, &mut output),
GgufQuantType::BitnetT158 => dequantize_bitnet_t158_wrapper(data, &mut output),
GgufQuantType::IQ1_S => {
return Err(RuvLLMError::Model(
"IQ1_S dequantization requires codebook lookup tables (not yet implemented). \
For BitNet ternary quantization, use BITNET_T158 type instead."
.to_string(),
));
}
_ => {
return Err(RuvLLMError::Model(format!(
"Dequantization not implemented for {:?}",
dtype
)));
}
}
Ok(output)
}
pub fn dequantize_block(data: &[u8], dtype: GgufQuantType, output: &mut [f32]) {
match dtype {
GgufQuantType::Q4_0 => dequantize_q4_0_block(data, output),
GgufQuantType::Q4_1 => dequantize_q4_1_block(data, output),
GgufQuantType::Q8_0 => dequantize_q8_0_block(data, output),
GgufQuantType::Q4_K => dequantize_q4_k_block(data, output),
GgufQuantType::BitnetT158 => dequantize_bitnet_t158_block_wrapper(data, output),
_ => {
output.fill(0.0);
}
}
}
/// Decodes one BitNet b1.58 block (64 packed bytes + little-endian f16 scale)
/// into at most `BITNET_T158_BLOCK_SIZE` leading entries of `output`.
///
/// If `data` is shorter than one encoded block, `output` is zero-filled.
fn dequantize_bitnet_t158_block_wrapper(data: &[u8], output: &mut [f32]) {
    if data.len() >= BITNET_T158_TYPE_SIZE {
        let (packed, scale_bytes) = data.split_at(64);
        let scale = f16_to_f32(u16::from_le_bytes([scale_bytes[0], scale_bytes[1]]));
        let wanted = BITNET_T158_BLOCK_SIZE.min(output.len());
        let values = dequantize_bitnet_t158(packed, &[scale], wanted);
        output[..values.len()].copy_from_slice(&values);
    } else {
        output.fill(0.0);
    }
}
/// Copies little-endian `f32` words from `data` into `output`, stopping when
/// either runs out; trailing bytes that do not form a full word are ignored.
fn dequantize_f32(data: &[u8], output: &mut [f32]) {
    let words = data.chunks_exact(4);
    for (slot, word) in output.iter_mut().zip(words) {
        *slot = f32::from_le_bytes([word[0], word[1], word[2], word[3]]);
    }
}
/// Decodes little-endian IEEE half-precision values from `data` into
/// `output`, stopping when either runs out.
fn dequantize_f16(data: &[u8], output: &mut [f32]) {
    output
        .iter_mut()
        .zip(data.chunks_exact(2))
        .for_each(|(dst, half)| *dst = f16_to_f32(u16::from_le_bytes([half[0], half[1]])));
}
/// Decodes little-endian bfloat16 values: each 16-bit pattern is the upper
/// half of an `f32`, so widening is a 16-bit left shift of the raw bits.
fn dequantize_bf16(data: &[u8], output: &mut [f32]) {
    for (dst, pair) in output.iter_mut().zip(data.chunks_exact(2)) {
        let upper = u32::from(u16::from_le_bytes([pair[0], pair[1]]));
        *dst = f32::from_bits(upper << 16);
    }
}
const Q4_0_BLOCK_SIZE: usize = 32;
const Q4_0_TYPE_SIZE: usize = 18;
/// Dequantizes consecutive Q4_0 blocks from `data` into `output`.
///
/// Decodes `min(output.len() / 32, data.len() / 18)` whole blocks; output
/// elements past the last decoded block are left untouched.
fn dequantize_q4_0(data: &[u8], output: &mut [f32]) {
    let encoded = data.chunks_exact(Q4_0_TYPE_SIZE);
    let decoded = output.chunks_exact_mut(Q4_0_BLOCK_SIZE);
    encoded
        .zip(decoded)
        .for_each(|(block, out)| dequantize_q4_0_block(block, out));
}
/// Dequantizes a single Q4_0 block (32 weights) into `output[..32]`.
///
/// Block layout (18 bytes): little-endian f16 scale `d`, then 16 bytes of
/// packed 4-bit quants. Matching ggml's `dequantize_row_q4_0`, the low nibble
/// of byte `j` is weight `j` and the high nibble is weight `j + 16`; the two
/// halves are not interleaved. Each weight decodes to `(q - 8) * d`.
///
/// # Panics
/// Panics if `block` is shorter than 18 bytes or `output` shorter than 32.
fn dequantize_q4_0_block(block: &[u8], output: &mut [f32]) {
    let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
    let (first, second) = output[..32].split_at_mut(16);
    for ((&byte, lo), hi) in block[2..18].iter().zip(first).zip(second) {
        *lo = (i32::from(byte & 0x0F) - 8) as f32 * scale;
        *hi = (i32::from(byte >> 4) - 8) as f32 * scale;
    }
}
const Q4_1_BLOCK_SIZE: usize = 32;
const Q4_1_TYPE_SIZE: usize = 20;
/// Dequantizes consecutive Q4_1 blocks from `data` into `output`.
///
/// Decodes `min(output.len() / 32, data.len() / 20)` whole blocks; output
/// elements past the last decoded block are left untouched.
fn dequantize_q4_1(data: &[u8], output: &mut [f32]) {
    let pairs = data
        .chunks_exact(Q4_1_TYPE_SIZE)
        .zip(output.chunks_exact_mut(Q4_1_BLOCK_SIZE));
    for (block, out) in pairs {
        dequantize_q4_1_block(block, out);
    }
}
/// Dequantizes a single Q4_1 block (32 weights) into `output[..32]`.
///
/// Block layout (20 bytes): f16 scale `d`, f16 minimum `m`, then 16 bytes of
/// packed 4-bit quants. As in ggml's `dequantize_row_q4_1`, the low nibble of
/// byte `j` is weight `j` and the high nibble is weight `j + 16`. Each weight
/// decodes to `q * d + m`.
///
/// # Panics
/// Panics if `block` is shorter than 20 bytes or `output` shorter than 32.
fn dequantize_q4_1_block(block: &[u8], output: &mut [f32]) {
    let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
    let min = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
    let (first, second) = output[..32].split_at_mut(16);
    for ((&byte, lo), hi) in block[4..20].iter().zip(first).zip(second) {
        *lo = f32::from(byte & 0x0F) * scale + min;
        *hi = f32::from(byte >> 4) * scale + min;
    }
}
const Q5_0_BLOCK_SIZE: usize = 32;
const Q5_0_TYPE_SIZE: usize = 22;
/// Dequantizes consecutive Q5_0 blocks from `data` into `output`.
///
/// Block layout (22 bytes, ggml `block_q5_0`): f16 scale `d`, a little-endian
/// `u32` of fifth bits `qh`, then 16 bytes of packed low nibbles. For `j` in
/// `0..16`:
/// - weight `j` is byte `j`'s low nibble plus bit `j` of `qh`;
/// - weight `j + 16` is byte `j`'s high nibble plus bit `j + 16` of `qh`.
///
/// Each 5-bit value `q` decodes to `(q - 16) * d`.
///
/// Decodes `min(output.len() / 32, data.len() / 22)` whole blocks and leaves
/// the rest of `output` untouched.
fn dequantize_q5_0(data: &[u8], output: &mut [f32]) {
    for (block, out) in data
        .chunks_exact(Q5_0_TYPE_SIZE)
        .zip(output.chunks_exact_mut(Q5_0_BLOCK_SIZE))
    {
        let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let qh = u32::from_le_bytes([block[2], block[3], block[4], block[5]]);
        let (first, second) = out.split_at_mut(Q5_0_BLOCK_SIZE / 2);
        for (j, &byte) in block[6..].iter().enumerate() {
            let q0 = (u32::from(byte & 0x0F) | (((qh >> j) & 1) << 4)) as i32 - 16;
            let q1 = (u32::from(byte >> 4) | (((qh >> (j + 16)) & 1) << 4)) as i32 - 16;
            first[j] = q0 as f32 * scale;
            second[j] = q1 as f32 * scale;
        }
    }
}
const Q5_1_BLOCK_SIZE: usize = 32;
const Q5_1_TYPE_SIZE: usize = 24;
/// Dequantizes consecutive Q5_1 blocks from `data` into `output`.
///
/// Block layout (24 bytes, ggml `block_q5_1`): f16 scale `d`, f16 minimum
/// `m`, a little-endian `u32` of fifth bits `qh`, then 16 bytes of packed low
/// nibbles. For `j` in `0..16`:
/// - weight `j` uses byte `j`'s low nibble and bit `j` of `qh`;
/// - weight `j + 16` uses the high nibble and bit `j + 16`.
///
/// Each 5-bit value `q` decodes to `q * d + m`.
///
/// Decodes `min(output.len() / 32, data.len() / 24)` whole blocks and leaves
/// the rest of `output` untouched.
fn dequantize_q5_1(data: &[u8], output: &mut [f32]) {
    for (block, out) in data
        .chunks_exact(Q5_1_TYPE_SIZE)
        .zip(output.chunks_exact_mut(Q5_1_BLOCK_SIZE))
    {
        let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let min = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
        let qh = u32::from_le_bytes([block[4], block[5], block[6], block[7]]);
        let (first, second) = out.split_at_mut(Q5_1_BLOCK_SIZE / 2);
        for (j, &byte) in block[8..].iter().enumerate() {
            let q0 = u32::from(byte & 0x0F) | (((qh >> j) & 1) << 4);
            let q1 = u32::from(byte >> 4) | (((qh >> (j + 16)) & 1) << 4);
            first[j] = q0 as f32 * scale + min;
            second[j] = q1 as f32 * scale + min;
        }
    }
}
const Q8_0_BLOCK_SIZE: usize = 32;
const Q8_0_TYPE_SIZE: usize = 34;
/// Dequantizes consecutive Q8_0 blocks from `data` into `output`.
///
/// Decodes `min(output.len() / 32, data.len() / 34)` whole blocks; output
/// elements past the last decoded block are left untouched.
fn dequantize_q8_0(data: &[u8], output: &mut [f32]) {
    let encoded = data.chunks_exact(Q8_0_TYPE_SIZE);
    let decoded = output.chunks_exact_mut(Q8_0_BLOCK_SIZE);
    encoded
        .zip(decoded)
        .for_each(|(block, out)| dequantize_q8_0_block(block, out));
}
/// Dequantizes one Q8_0 block: an f16 scale followed by 32 signed bytes,
/// each decoding to `q * d`.
///
/// # Panics
/// Panics if `block` is shorter than 34 bytes or `output` shorter than 32.
fn dequantize_q8_0_block(block: &[u8], output: &mut [f32]) {
    let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
    let quants = &block[2..34];
    for (dst, &raw) in output[..32].iter_mut().zip(quants) {
        *dst = f32::from(raw as i8) * d;
    }
}
const Q8_1_BLOCK_SIZE: usize = 32;
const Q8_1_TYPE_SIZE: usize = 36;
fn dequantize_q8_1(data: &[u8], output: &mut [f32]) {
let num_blocks = output.len() / Q8_1_BLOCK_SIZE;
for block_idx in 0..num_blocks {
let block_start = block_idx * Q8_1_TYPE_SIZE;
let out_start = block_idx * Q8_1_BLOCK_SIZE;
if block_start + Q8_1_TYPE_SIZE > data.len() {
break;
}
let scale = f16_to_f32(u16::from_le_bytes([
data[block_start],
data[block_start + 1],
]));
let offset = f16_to_f32(u16::from_le_bytes([
data[block_start + 2],
data[block_start + 3],
]));
for i in 0..32 {
let q = data[block_start + 4 + i] as i8;
output[out_start + i] = (q as f32) * scale + offset;
}
}
}
const Q2_K_BLOCK_SIZE: usize = 256;
const Q2_K_TYPE_SIZE: usize = 84;
/// Dequantizes consecutive Q2_K super-blocks from `data` into `output`.
///
/// Block layout (84 bytes, ggml `block_q2_K`):
/// - `scales[16]` at 0..16: low nibble is the sub-block scale, high nibble the
///   sub-block min;
/// - `qs[64]` at 16..80: 2-bit quants;
/// - f16 `d` at 80..82 and f16 `dmin` at 82..84.
///
/// Following ggml's `dequantize_row_q2_K`, each 128-weight half reads 32 `qs`
/// bytes four times at shifts 0, 2, 4, 6. Each pass emits two 16-weight
/// sub-blocks, from bytes 0..16 and 16..32. Weight = `d*sc*q - dmin*m`.
///
/// Decodes `min(output.len() / 256, data.len() / 84)` whole blocks and leaves
/// the rest of `output` untouched.
fn dequantize_q2_k(data: &[u8], output: &mut [f32]) {
    for (block, out) in data
        .chunks_exact(Q2_K_TYPE_SIZE)
        .zip(output.chunks_exact_mut(Q2_K_BLOCK_SIZE))
    {
        let scales = &block[..16];
        let qs = &block[16..80];
        let d = f16_to_f32(u16::from_le_bytes([block[80], block[81]]));
        let dmin = f16_to_f32(u16::from_le_bytes([block[82], block[83]]));
        for (n, q) in qs.chunks_exact(32).enumerate() {
            for j in 0..4 {
                let shift = 2 * j;
                for (part, half) in q.chunks_exact(16).enumerate() {
                    // Sub-block index in emission order, 0..16.
                    let sub = n * 8 + j * 2 + part;
                    let sc = scales[sub];
                    let dl = d * f32::from(sc & 0x0F);
                    let ml = dmin * f32::from(sc >> 4);
                    let dst = &mut out[sub * 16..][..16];
                    for (y, &byte) in dst.iter_mut().zip(half) {
                        *y = dl * f32::from((byte >> shift) & 3) - ml;
                    }
                }
            }
        }
    }
}
const Q3_K_BLOCK_SIZE: usize = 256;
const Q3_K_TYPE_SIZE: usize = 110;
/// Dequantizes consecutive Q3_K super-blocks from `data` into `output`.
///
/// Block layout (110 bytes, ggml `block_q3_K`):
/// - `hmask[32]` at 0..32: third (high) bit of each quant;
/// - `qs[64]` at 32..96: low 2 bits;
/// - `scales[12]` at 96..108: sixteen packed 6-bit scales;
/// - f16 `d` at 108..110.
///
/// Following ggml's `dequantize_row_q3_K`, weights are emitted in 16-weight
/// sub-blocks. A weight is `d * (scale - 32) * (low2 - (high_bit ? 0 : 4))`.
///
/// Decodes `min(output.len() / 256, data.len() / 110)` whole blocks and leaves
/// the rest of `output` untouched.
fn dequantize_q3_k(data: &[u8], output: &mut [f32]) {
    for (block, out) in data
        .chunks_exact(Q3_K_TYPE_SIZE)
        .zip(output.chunks_exact_mut(Q3_K_BLOCK_SIZE))
    {
        let hmask = &block[..32];
        let qs = &block[32..96];
        let packed_scales = &block[96..108];
        let d = f16_to_f32(u16::from_le_bytes([block[108], block[109]]));
        for (n, q) in qs.chunks_exact(32).enumerate() {
            for j in 0..4 {
                let shift = 2 * j;
                // The hmask bit advances once per shift pass across the whole block.
                let m = 1u8 << (n * 4 + j);
                for part in 0..2 {
                    let sub = n * 8 + j * 2 + part;
                    let dl = d * q3_k_subblock_scale(packed_scales, sub) as f32;
                    let q_half = &q[part * 16..][..16];
                    let h_half = &hmask[part * 16..][..16];
                    let dst = &mut out[sub * 16..][..16];
                    for ((y, &byte), &hbits) in dst.iter_mut().zip(q_half).zip(h_half) {
                        let low = i32::from((byte >> shift) & 3);
                        let offset = if (hbits & m) != 0 { 0 } else { 4 };
                        *y = dl * (low - offset) as f32;
                    }
                }
            }
        }
    }
}

/// Unpacks signed sub-block scale `j` (0..16) from Q3_K's 12-byte scale array.
///
/// Low 4 bits come from byte `j` (j < 8) or the high nibble of byte `j - 8`.
/// The top 2 bits come from byte `8 + j % 4` at bit offset `2 * (j / 4)`.
/// The 6-bit result is biased by 32.
fn q3_k_subblock_scale(packed: &[u8], j: usize) -> i32 {
    let low = if j < 8 { packed[j] & 0x0F } else { packed[j - 8] >> 4 };
    let high = (packed[8 + j % 4] >> (2 * (j / 4))) & 0x03;
    i32::from(low | (high << 4)) - 32
}
const Q4_K_BLOCK_SIZE: usize = 256;
const Q4_K_TYPE_SIZE: usize = 144;
/// Dequantizes consecutive Q4_K super-blocks from `data` into `output`.
///
/// Decodes `min(output.len() / 256, data.len() / 144)` whole blocks; output
/// elements past the last decoded block are left untouched.
fn dequantize_q4_k(data: &[u8], output: &mut [f32]) {
    for (block, out) in data
        .chunks_exact(Q4_K_TYPE_SIZE)
        .zip(output.chunks_exact_mut(Q4_K_BLOCK_SIZE))
    {
        dequantize_q4_k_block(block, out);
    }
}

/// Dequantizes one Q4_K super-block (256 weights) into `output[..256]`.
///
/// Block layout (144 bytes, ggml `block_q4_K`):
/// - f16 `d` and f16 `dmin`;
/// - `scales[12]` at 4..16: eight packed 6-bit (scale, min) pairs;
/// - `qs[128]` at 16..144.
///
/// Each 32-byte chunk of `qs` yields 64 weights: low nibbles use pair `2k`,
/// high nibbles use pair `2k + 1`. Weight = `d*sc*q - dmin*m`, as in ggml's
/// `dequantize_row_q4_K`.
///
/// # Panics
/// Panics if `block` is shorter than 144 bytes or `output` shorter than 256.
fn dequantize_q4_k_block(block: &[u8], output: &mut [f32]) {
    let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
    let dmin = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
    let scales = &block[4..16];
    let qs = &block[16..Q4_K_TYPE_SIZE];
    let out = &mut output[..Q4_K_BLOCK_SIZE];
    for (pair, (q, dst)) in qs.chunks_exact(32).zip(out.chunks_exact_mut(64)).enumerate() {
        let (sc_lo, m_lo) = k_quant_scale_min(2 * pair, scales);
        let (sc_hi, m_hi) = k_quant_scale_min(2 * pair + 1, scales);
        let (d_lo, min_lo) = (d * f32::from(sc_lo), dmin * f32::from(m_lo));
        let (d_hi, min_hi) = (d * f32::from(sc_hi), dmin * f32::from(m_hi));
        let (lo, hi) = dst.split_at_mut(32);
        for ((&byte, y_lo), y_hi) in q.iter().zip(lo).zip(hi) {
            *y_lo = d_lo * f32::from(byte & 0x0F) - min_lo;
            *y_hi = d_hi * f32::from(byte >> 4) - min_hi;
        }
    }
}

/// Decodes 6-bit (scale, min) pair `j` (0..8) from the 12-byte packed array
/// shared by Q4_K and Q5_K (ggml `get_scale_min_k4`).
///
/// Pairs 0..4 sit in the low 6 bits of bytes `j` / `j + 4`. Pairs 4..8 take
/// their low 4 bits from byte `j + 4` and their top 2 bits from the spare high
/// bits of bytes `j - 4` / `j`.
fn k_quant_scale_min(j: usize, packed: &[u8]) -> (u8, u8) {
    if j < 4 {
        (packed[j] & 0x3F, packed[j + 4] & 0x3F)
    } else {
        let scale = (packed[j + 4] & 0x0F) | ((packed[j - 4] >> 6) << 4);
        let min = (packed[j + 4] >> 4) | ((packed[j] >> 6) << 4);
        (scale, min)
    }
}

const Q5_K_BLOCK_SIZE: usize = 256;
const Q5_K_TYPE_SIZE: usize = 176;
/// Dequantizes consecutive Q5_K super-blocks from `data` into `output`.
///
/// Block layout (176 bytes, ggml `block_q5_K`):
/// - f16 `d` and f16 `dmin`;
/// - `scales[12]` at 4..16;
/// - `qh[32]` at 16..48: fifth bits;
/// - `qs[128]` at 48..176: low nibbles.
///
/// Chunk `k` of 32 `qs` bytes yields 64 weights. The low nibbles take their
/// fifth bit from `qh` bit `2k`; the high nibbles from bit `2k + 1`.
/// Weight = `d*sc*q - dmin*m`, as in ggml's `dequantize_row_q5_K`.
///
/// Decodes `min(output.len() / 256, data.len() / 176)` whole blocks and leaves
/// the rest of `output` untouched.
fn dequantize_q5_k(data: &[u8], output: &mut [f32]) {
    for (block, out) in data
        .chunks_exact(Q5_K_TYPE_SIZE)
        .zip(output.chunks_exact_mut(Q5_K_BLOCK_SIZE))
    {
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let dmin = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
        let scales = &block[4..16];
        let qh = &block[16..48];
        let ql = &block[48..Q5_K_TYPE_SIZE];
        for (pair, (q, dst)) in ql.chunks_exact(32).zip(out.chunks_exact_mut(64)).enumerate() {
            let (sc_lo, m_lo) = k_quant_scale_min(2 * pair, scales);
            let (sc_hi, m_hi) = k_quant_scale_min(2 * pair + 1, scales);
            let (d_lo, min_lo) = (d * f32::from(sc_lo), dmin * f32::from(m_lo));
            let (d_hi, min_hi) = (d * f32::from(sc_hi), dmin * f32::from(m_hi));
            let bit_lo = 1u8 << (2 * pair);
            let bit_hi = 2u8 << (2 * pair);
            let (lo, hi) = dst.split_at_mut(32);
            for (l, &byte) in q.iter().enumerate() {
                let h_lo = if (qh[l] & bit_lo) != 0 { 16 } else { 0 };
                let h_hi = if (qh[l] & bit_hi) != 0 { 16 } else { 0 };
                lo[l] = d_lo * f32::from((byte & 0x0F) + h_lo) - min_lo;
                hi[l] = d_hi * f32::from((byte >> 4) + h_hi) - min_hi;
            }
        }
    }
}
const Q6_K_BLOCK_SIZE: usize = 256;
const Q6_K_TYPE_SIZE: usize = 210;
/// Dequantizes consecutive Q6_K super-blocks from `data` into `output`.
///
/// Block layout (210 bytes, ggml `block_q6_K`):
/// - `ql[128]` at 0..128: low 4 bits;
/// - `qh[64]` at 128..192: high 2 bits;
/// - `scales[16]` at 192..208: signed 8-bit, one per 16 weights;
/// - f16 `d` at 208..210.
///
/// Following ggml's `dequantize_row_q6_K`, each 128-weight half uses 64 `ql`
/// bytes, 32 `qh` bytes and 8 scales. Weight `l + 32*k` (k in 0..4) combines
/// a nibble of `ql[l]` / `ql[l + 32]` with bits `2k..2k+2` of `qh[l]`.
/// Weight = `d * scale * (q - 32)`.
///
/// Decodes `min(output.len() / 256, data.len() / 210)` whole blocks and leaves
/// the rest of `output` untouched.
fn dequantize_q6_k(data: &[u8], output: &mut [f32]) {
    for (block, out) in data
        .chunks_exact(Q6_K_TYPE_SIZE)
        .zip(output.chunks_exact_mut(Q6_K_BLOCK_SIZE))
    {
        let (ql_all, rest) = block.split_at(128);
        let (qh_all, rest) = rest.split_at(64);
        let (scales_all, d_bytes) = rest.split_at(16);
        let d = f16_to_f32(u16::from_le_bytes([d_bytes[0], d_bytes[1]]));
        for half in 0..2 {
            let ql = &ql_all[half * 64..][..64];
            let qh = &qh_all[half * 32..][..32];
            let sc = &scales_all[half * 8..][..8];
            let y = &mut out[half * 128..][..128];
            let scale = |k: usize| d * f32::from(sc[k] as i8);
            for l in 0..32 {
                let is = l / 16;
                let h = qh[l];
                let q1 = i32::from((ql[l] & 0x0F) | ((h & 3) << 4)) - 32;
                let q2 = i32::from((ql[l + 32] & 0x0F) | (((h >> 2) & 3) << 4)) - 32;
                let q3 = i32::from((ql[l] >> 4) | (((h >> 4) & 3) << 4)) - 32;
                let q4 = i32::from((ql[l + 32] >> 4) | (((h >> 6) & 3) << 4)) - 32;
                y[l] = scale(is) * q1 as f32;
                y[l + 32] = scale(is + 2) * q2 as f32;
                y[l + 64] = scale(is + 4) * q3 as f32;
                y[l + 96] = scale(is + 6) * q4 as f32;
            }
        }
    }
}
const IQ4_NL_BLOCK_SIZE: usize = 32;
const IQ4_NL_TYPE_SIZE: usize = 18;
/// Non-linear 4-bit codebook (ggml `kvalues_iq4nl`); each entry is multiplied
/// by the block scale.
const IQ4_NL_LUT: [f32; 16] = [
    -127.0, -104.0, -83.0, -65.0, -49.0, -35.0, -22.0, -10.0, 1.0, 13.0, 25.0, 38.0, 53.0, 69.0,
    89.0, 113.0,
];
/// Dequantizes consecutive IQ4_NL blocks from `data` into `output`.
///
/// Block layout (18 bytes, ggml `block_iq4_nl`): f16 scale `d`, then 16 bytes
/// of 4-bit codebook indices. As in ggml's `dequantize_row_iq4_nl`, the low
/// nibble of byte `j` gives weight `j` and the high nibble gives weight
/// `j + 16`. Weight = `d * IQ4_NL_LUT[index]`.
///
/// Decodes `min(output.len() / 32, data.len() / 18)` whole blocks and leaves
/// the rest of `output` untouched.
fn dequantize_iq4_nl(data: &[u8], output: &mut [f32]) {
    for (block, out) in data
        .chunks_exact(IQ4_NL_TYPE_SIZE)
        .zip(output.chunks_exact_mut(IQ4_NL_BLOCK_SIZE))
    {
        let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        let (first, second) = out.split_at_mut(IQ4_NL_BLOCK_SIZE / 2);
        for ((&byte, lo), hi) in block[2..].iter().zip(first).zip(second) {
            *lo = IQ4_NL_LUT[usize::from(byte & 0x0F)] * scale;
            *hi = IQ4_NL_LUT[usize::from(byte >> 4)] * scale;
        }
    }
}
const BITNET_T158_BLOCK_SIZE: usize = 256;
const BITNET_T158_TYPE_SIZE: usize = 66;
/// Splits BitNet b1.58 blocks into packed ternary bytes and per-block f16
/// scales, then delegates to `dequantize_bitnet_t158`.
///
/// At most `output.len() / 256` whole blocks are gathered, stopping early if
/// `data` runs out; the decoded values are copied into the front of `output`.
fn dequantize_bitnet_t158_wrapper(data: &[u8], output: &mut [f32]) {
    let block_count = output.len() / BITNET_T158_BLOCK_SIZE;
    let mut packed_data = Vec::with_capacity(block_count * 64);
    let mut scales = Vec::with_capacity(block_count);
    for block in data.chunks_exact(BITNET_T158_TYPE_SIZE).take(block_count) {
        let (packed, scale_bytes) = block.split_at(64);
        packed_data.extend_from_slice(packed);
        scales.push(f16_to_f32(u16::from_le_bytes([scale_bytes[0], scale_bytes[1]])));
    }
    let dequantized = dequantize_bitnet_t158(&packed_data, &scales, output.len());
    output[..dequantized.len()].copy_from_slice(&dequantized);
}
/// Converts an IEEE 754 binary16 (half-precision) bit pattern to `f32`.
///
/// Handles signed zeros, subnormals, infinities and NaNs (the NaN payload is
/// preserved). Every binary16 value is exactly representable in `f32`, so the
/// conversion is lossless.
#[inline(always)]
fn f16_to_f32(bits: u16) -> f32 {
    let sign = u32::from(bits & 0x8000) << 16;
    let exp = u32::from((bits >> 10) & 0x1F);
    let frac = u32::from(bits & 0x03FF);
    match exp {
        // Signed zero.
        0 if frac == 0 => f32::from_bits(sign),
        // Subnormal half: value = frac * 2^-24. frac < 2^10 and the divisor is
        // a power of two, so this is exact in f32. The previous hand-rolled
        // renormalisation was off by one and halved every subnormal.
        0 => {
            let magnitude = frac as f32 / 16_777_216.0;
            f32::from_bits(sign | magnitude.to_bits())
        }
        // Infinity (frac == 0) or NaN (frac != 0).
        0x1F => f32::from_bits(sign | 0x7F80_0000 | (frac << 13)),
        // Normal: rebias the exponent from 15 to 127 and widen the mantissa.
        _ => f32::from_bits(sign | ((exp + 127 - 15) << 23) | (frac << 13)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quant_type_sizes() {
        let cases = [
            (GgufQuantType::F32, 1, 4),
            (GgufQuantType::Q4_0, 32, 18),
            (GgufQuantType::Q4_K, 256, 144),
        ];
        for (dtype, block, bytes) in cases {
            assert_eq!(dtype.block_size(), block, "block size of {:?}", dtype);
            assert_eq!(dtype.type_size(), bytes, "type size of {:?}", dtype);
        }
    }

    #[test]
    fn test_quant_type_bits() {
        let cases = [
            (GgufQuantType::F32, 32.0f32),
            (GgufQuantType::Q4_0, 4.5),
            (GgufQuantType::Q8_0, 8.5),
        ];
        for (dtype, bits) in cases {
            assert!((dtype.bits_per_weight() - bits).abs() < 0.1, "{:?}", dtype);
        }
    }

    #[test]
    fn test_f16_conversion() {
        assert_eq!(f16_to_f32(0x0000), 0.0);
        assert_eq!(f16_to_f32(0x3C00), 1.0);
        assert_eq!(f16_to_f32(0xBC00), -1.0);
        assert!((f16_to_f32(0x3800) - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_q4_0_dequantize() {
        // Scale 1.0 (f16 0x3C00) and every nibble 8, i.e. quantized zero.
        let mut block = [0x88u8; 18];
        block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
        let mut output = [0.0f32; 32];
        dequantize_q4_0_block(&block, &mut output);
        assert!(output.iter().all(|v| v.abs() < 0.001));
    }

    #[test]
    fn test_q8_0_dequantize() {
        let mut block = [0u8; 34];
        block[..2].copy_from_slice(&0x3C00u16.to_le_bytes());
        for (i, q) in block[2..].iter_mut().enumerate() {
            *q = (i + 1) as u8;
        }
        let mut output = [0.0f32; 32];
        dequantize_q8_0_block(&block, &mut output);
        for (i, v) in output.iter().enumerate() {
            assert!((v - (i + 1) as f32).abs() < 0.001);
        }
    }

    #[test]
    fn test_quant_type_try_from() {
        assert_eq!(GgufQuantType::try_from(0).unwrap(), GgufQuantType::F32);
        assert_eq!(GgufQuantType::try_from(12).unwrap(), GgufQuantType::Q4_K);
        assert!(GgufQuantType::try_from(100).is_err());
    }

    #[test]
    fn test_quantized_tensor() {
        let tensor = QuantizedTensor {
            data: vec![0u8; 144],
            dtype: GgufQuantType::Q4_K,
            shape: vec![256],
            num_elements: 256,
        };
        assert_eq!(tensor.block_count(), 1);
        assert!(tensor.dtype.is_quantized());
    }
}