use half::f16;
pub fn dequant_iq3_xxs(blocks: &[u8], output: &mut [f32]) {
const BLOCK_SIZE: usize = 256;
const BLOCK_BYTES: usize = 98;
let num_blocks = blocks.len() / BLOCK_BYTES;
debug_assert_eq!(output.len(), num_blocks * BLOCK_SIZE);
for b in 0..num_blocks {
let block = &blocks[b * BLOCK_BYTES..];
let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
let qs = &block[2..98];
let out = &mut output[b * BLOCK_SIZE..][..BLOCK_SIZE];
for group in 0..8 {
let gdata = &qs[group * 12..(group + 1) * 12];
let signs = u32::from_le_bytes([gdata[8], gdata[9], gdata[10], gdata[11]]);
let sub_scale_bits = (signs >> 28) & 0x0F;
let sub_scale = d * (1.0 + sub_scale_bits as f32);
let group_out = &mut out[group * 32..(group + 1) * 32];
for sub in 0..4 {
let grid_lo = gdata[sub * 2] as u16;
let grid_hi = gdata[sub * 2 + 1] as u16;
let grid_idx = (grid_lo | (grid_hi << 8)) as usize;
for k in 0..8 {
let val = ((grid_idx >> (k * 2)) & 0x03) as f32 + 1.0;
let sign_bit = (signs >> (sub * 8 + k)) & 1;
let sign = if sign_bit != 0 { -1.0f32 } else { 1.0f32 };
group_out[sub * 8 + k] = sub_scale * val * sign;
}
}
}
}
}
#[allow(clippy::needless_range_loop)]
pub fn dequant_iq3_s(blocks: &[u8], output: &mut [f32]) {
const BLOCK_SIZE: usize = 256;
const BLOCK_BYTES: usize = 110;
let num_blocks = blocks.len() / BLOCK_BYTES;
debug_assert_eq!(output.len(), num_blocks * BLOCK_SIZE);
for b in 0..num_blocks {
let block = &blocks[b * BLOCK_BYTES..];
let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
let qs = &block[2..34]; let qh = &block[34..38]; let signs = &block[38..70]; let scales = &block[70..78];
let out = &mut output[b * BLOCK_SIZE..][..BLOCK_SIZE];
for sb in 0..8 {
let scale_byte = scales[sb];
let sub_scale = d * (1.0 + (scale_byte & 0x0F) as f32);
let sub_out = &mut out[sb * 32..(sb + 1) * 32];
for k in 0..32 {
let byte_idx = sb * 4 + k / 8;
let bit_pos = k % 8;
let q3 = if byte_idx < 32 {
((qs[byte_idx] >> (bit_pos % 4 * 2)) & 0x03) as f32
} else {
0.0
};
let qh_byte_idx = (sb * 32 + k) / 8;
let qh_bit = if qh_byte_idx < 4 {
((qh[qh_byte_idx] >> ((sb * 32 + k) % 8)) & 1) as f32
} else {
0.0
};
let val = q3 + qh_bit * 4.0 + 1.0;
let sign_byte_idx = sb * 4 + k / 8;
let sign_byte = if sign_byte_idx < 32 {
signs[sign_byte_idx]
} else {
0
};
let sign = if (sign_byte >> (k % 8)) & 1 != 0 {
-1.0f32
} else {
1.0f32
};
sub_out[k] = sub_scale * val * sign;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dequant_iq3_xxs_zeros() {
let block = [0u8; 98];
let mut output = [0.0f32; 256];
dequant_iq3_xxs(&block, &mut output);
for &v in &output {
assert!(v.abs() < 1e-5, "expected 0, got {}", v);
}
}
#[test]
fn test_dequant_iq3_s_zeros() {
let block = [0u8; 110];
let mut output = [0.0f32; 256];
dequant_iq3_s(&block, &mut output);
for &v in &output {
assert!(v.abs() < 1e-5, "expected 0, got {}", v);
}
}
}