use crate::error::AnamnesisError;
use crate::parse::safetensors::Dtype;
use crate::remember::fp8::f32_bits_to_bf16_bits;
use crate::remember::quant_utils::{read_scale_f32, read_u32_le};
/// Extracts one packed field from a GPTQ u32: shift the word right by
/// `shift`, then keep only the low bits selected by `mask`.
#[must_use]
fn unpack_gptq(packed: u32, shift: u32, mask: u32) -> u32 {
    let shifted = packed >> shift;
    shifted & mask
}
/// Unpacks the zero-points of quantization group `g` into `buf` as `f32`.
///
/// `qzeros_data` packs `32 / bits` zero-points per little-endian u32 along
/// the output-feature axis, one row per group. The stored value plus one is
/// written to `buf` (GPTQ stores zero-points offset by one).
///
/// # Errors
/// Returns a `Parse` error for bit widths this cannot unpack (`bits == 0`
/// or `bits >= 32`), for offset overflow, or for an out-of-bounds read.
fn unpack_zeros_for_group(
    buf: &mut [f32],
    qzeros_data: &[u8],
    g: usize,
    out_features: usize,
    bits: u8,
) -> crate::Result<()> {
    let bits_u32 = u32::from(bits);
    // `1u32 << bits` would panic for bits >= 32 and `32 / bits` for
    // bits == 0; reject both with an error instead of relying on callers
    // to pre-validate (the old `checked_div` guard fired too late to help).
    let mask = 1u32
        .checked_shl(bits_u32)
        .map(|m| m - 1)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: format!("unsupported qzeros bit width {bits} (must be 1..=31)"),
        })?;
    let pack_factor = 32usize
        .checked_div(usize::from(bits))
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "pack_factor is zero".into(),
        })?;
    // bits is now in 1..=31, so pack_factor >= 1 and this division is safe.
    let packed_cols = out_features / pack_factor;
    for (j, buf_val) in buf.iter_mut().enumerate() {
        let packed_col = j / pack_factor;
        let pos = j % pack_factor;
        // pos < pack_factor <= 32, so the cast cannot truncate.
        #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
        let shift = bits_u32 * (pos as u32);
        // Checked throughout so adversarial dimensions surface as errors,
        // not overflow panics.
        let byte_offset = g
            .checked_mul(packed_cols)
            .and_then(|row| row.checked_add(packed_col))
            .and_then(|idx| idx.checked_mul(4))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "qzeros byte offset overflow".into(),
            })?;
        let packed = read_u32_le(qzeros_data, byte_offset)?;
        let qz = unpack_gptq(packed, shift, mask);
        // GPTQ convention: the effective zero-point is the stored value + 1.
        #[allow(clippy::as_conversions, clippy::cast_precision_loss)]
        {
            *buf_val = (qz + 1) as f32;
        }
    }
    Ok(())
}
/// Reads the per-output-channel scale factors of quantization group `g`
/// into `buf`, converting each stored `scale_dtype` value to `f32`.
///
/// Scales are laid out row-major: one row of `out_features` entries per
/// group. All offset arithmetic is checked and reported as `Parse` errors.
fn unpack_scales_for_group(
    buf: &mut [f32],
    scales_data: &[u8],
    g: usize,
    out_features: usize,
    scale_dtype: Dtype,
) -> crate::Result<()> {
    let bytes_per_scale = scale_dtype.byte_size();
    // First element index of this group's row.
    let row_start = g
        .checked_mul(out_features)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "scales group row offset overflow".into(),
        })?;
    for (col, slot) in buf.iter_mut().enumerate() {
        let element_index = row_start.checked_add(col);
        let byte_offset = element_index
            .and_then(|idx| idx.checked_mul(bytes_per_scale))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "scale byte offset overflow".into(),
            })?;
        *slot = read_scale_f32(scales_data, byte_offset, scale_dtype)?;
    }
    Ok(())
}
/// Parses `g_idx` — one little-endian u32 group index per input row — into
/// a `Vec<usize>` of length `in_features`.
///
/// The byte length must be exactly `in_features * 4`; range-checking the
/// indices against the group count is left to the caller.
fn parse_g_idx(g_idx_data: &[u8], in_features: usize) -> crate::Result<Vec<usize>> {
    let expected_len = in_features
        .checked_mul(4)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "g_idx byte length overflow".into(),
        })?;
    if g_idx_data.len() != expected_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "g_idx data length {} != expected {expected_len} (in_features={in_features} × 4)",
                g_idx_data.len()
            ),
        });
    }
    (0..in_features)
        .map(|row| {
            let byte_offset = row.checked_mul(4).ok_or_else(|| AnamnesisError::Parse {
                reason: "g_idx byte offset overflow".into(),
            })?;
            let raw = read_u32_le(g_idx_data, byte_offset)?;
            // u32 -> usize; out-of-range groups are rejected by the caller.
            #[allow(clippy::as_conversions)]
            let idx = raw as usize;
            Ok(idx)
        })
        .collect()
}
/// Dequantizes a GPTQ-packed weight tensor into row-major BF16 bytes.
///
/// Expected input layout (as consumed below):
/// - `qweight_data`: `[in_features / pack_factor, out_features]` u32s, each
///   packing `pack_factor = 32 / bits` values along the input-feature axis.
/// - `scales_data`: `[num_groups, out_features]` values of `scale_dtype`,
///   where `num_groups = in_features / group_size`.
/// - `qzeros_data`: `[num_groups, out_features / pack_factor]` u32s packing
///   zero-points along the output-feature axis.
/// - `g_idx_data`: optional `[in_features]` u32s mapping each input row to
///   its group (act-order checkpoints); when absent, row `i` belongs to
///   group `i / group_size`.
///
/// Each element is computed as `(qw - (qz + 1)) * scale`, converted to BF16
/// and written little-endian: `in_features * out_features * 2` output bytes.
///
/// # Errors
/// `Unsupported` for bit widths other than 4 or 8; `Parse` for zero or
/// misaligned dimensions, input-length mismatches, out-of-range `g_idx`
/// entries, or overflow in any offset computation.
#[allow(clippy::too_many_arguments)]
pub fn dequantize_gptq_to_bf16(
    qweight_data: &[u8],
    scales_data: &[u8],
    qzeros_data: &[u8],
    g_idx_data: Option<&[u8]>,
    in_features: usize,
    out_features: usize,
    group_size: usize,
    bits: u8,
    scale_dtype: Dtype,
) -> crate::Result<Vec<u8>> {
    if bits != 4 && bits != 8 {
        return Err(AnamnesisError::Unsupported {
            format: "GPTQ".into(),
            detail: format!("{bits}-bit quantization not supported (expected 4 or 8)"),
        });
    }
    #[allow(clippy::as_conversions)]
    let pack_factor = 32 / bits as usize;
    if in_features == 0 || out_features == 0 || group_size == 0 {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "zero dimension: in_features={in_features}, out_features={out_features}, \
                 group_size={group_size}"
            ),
        });
    }
    if !in_features.is_multiple_of(pack_factor) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "in_features {in_features} is not a multiple of pack_factor {pack_factor}"
            ),
        });
    }
    if !out_features.is_multiple_of(pack_factor) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "out_features {out_features} is not a multiple of pack_factor {pack_factor}"
            ),
        });
    }
    if !in_features.is_multiple_of(group_size) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "in_features {in_features} is not a multiple of group_size {group_size}"
            ),
        });
    }
    let packed_rows = in_features / pack_factor;
    let packed_cols = out_features / pack_factor;
    let num_groups = in_features / group_size;
    // Validate every input buffer length up front so the per-row loop can
    // assume consistent shapes.
    let expected_qw_len = packed_rows
        .checked_mul(out_features)
        .and_then(|n| n.checked_mul(4))
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "qweight byte length overflow".into(),
        })?;
    if qweight_data.len() != expected_qw_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "qweight data length {} != expected {expected_qw_len}",
                qweight_data.len()
            ),
        });
    }
    let expected_scales_len = num_groups
        .checked_mul(out_features)
        .and_then(|n| n.checked_mul(scale_dtype.byte_size()))
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "scales byte length overflow".into(),
        })?;
    if scales_data.len() != expected_scales_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "scales data length {} != expected {expected_scales_len}",
                scales_data.len()
            ),
        });
    }
    let expected_qzeros_len = num_groups
        .checked_mul(packed_cols)
        .and_then(|n| n.checked_mul(4))
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "qzeros byte length overflow".into(),
        })?;
    if qzeros_data.len() != expected_qzeros_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "qzeros data length {} != expected {expected_qzeros_len}",
                qzeros_data.len()
            ),
        });
    }
    let g_idx = g_idx_data
        .map(|data| parse_g_idx(data, in_features))
        .transpose()?;
    // Fail fast on any group index that points past the scales/qzeros tables.
    if let Some(ref idx) = g_idx {
        for (i, &g) in idx.iter().enumerate() {
            if g >= num_groups {
                return Err(AnamnesisError::Parse {
                    reason: format!("g_idx[{i}] = {g} >= num_groups {num_groups}"),
                });
            }
        }
    }
    let out_byte_len = in_features
        .checked_mul(out_features)
        .and_then(|n| n.checked_mul(2))
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "output size overflow".into(),
        })?;
    let mut output = vec![0u8; out_byte_len];
    let bits_u32 = u32::from(bits);
    // bits is 4 or 8 here, so the shift cannot overflow.
    let mask = (1u32 << bits_u32) - 1;
    // Scratch rows reused across iterations; zeros/scales are re-unpacked
    // only when the group changes (cheap when groups are contiguous).
    let mut unpacked_buf = vec![0.0_f32; out_features];
    let mut zeros_buf = vec![0.0_f32; out_features];
    let mut scales_buf = vec![0.0_f32; out_features];
    let mut cached_group: Option<usize> = None;
    for i in 0..in_features {
        // Resolve this input row's quantization group.
        let g = if let Some(ref idx) = g_idx {
            idx.get(i).copied().ok_or_else(|| AnamnesisError::Parse {
                reason: format!("g_idx index {i} out of bounds"),
            })?
        } else {
            i / group_size
        };
        let packed_row = i / pack_factor;
        let pos = i % pack_factor;
        // pos < pack_factor <= 8, so the cast cannot truncate.
        #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
        let shift = bits_u32 * (pos as u32);
        let qw_row_start = packed_row
            .checked_mul(out_features)
            .and_then(|n| n.checked_mul(4))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "qweight row byte offset overflow".into(),
            })?;
        // Checked like every other offset in this function; the length
        // validation above makes overflow unreachable for accepted inputs,
        // but keeping the invariant local avoids relying on that at a
        // distance (the previous unchecked `start + out_features * 4` did).
        let qw_row_end = out_features
            .checked_mul(4)
            .and_then(|row_bytes| qw_row_start.checked_add(row_bytes))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "qweight row byte offset overflow".into(),
            })?;
        let qw_row = qweight_data
            .get(qw_row_start..qw_row_end)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("qweight row {packed_row} out of bounds"),
            })?;
        // Refresh the per-group zero-points and scales only on group change.
        if cached_group != Some(g) {
            unpack_zeros_for_group(&mut zeros_buf, qzeros_data, g, out_features, bits)?;
            unpack_scales_for_group(&mut scales_buf, scales_data, g, out_features, scale_dtype)?;
            cached_group = Some(g);
        }
        let out_row_start = i
            .checked_mul(out_features)
            .and_then(|n| n.checked_mul(2))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("output row {i} offset overflow"),
            })?;
        let out_row_end = out_features
            .checked_mul(2)
            .and_then(|row_bytes| out_row_start.checked_add(row_bytes))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("output row {i} end overflow"),
            })?;
        let out_row = output
            .get_mut(out_row_start..out_row_end)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("output row {i} out of bounds"),
            })?;
        let unpacked_row = unpacked_buf
            .get_mut(..out_features)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "unpacked buffer too short".into(),
            })?;
        // Extract this row's quantized value from each packed u32 column.
        #[allow(clippy::indexing_slicing)]
        for (j, qw_chunk) in qw_row.chunks_exact(4).enumerate() {
            let packed = u32::from_le_bytes([qw_chunk[0], qw_chunk[1], qw_chunk[2], qw_chunk[3]]);
            #[allow(clippy::as_conversions, clippy::cast_precision_loss)]
            let qw = unpack_gptq(packed, shift, mask) as f32;
            unpacked_row[j] = qw;
        }
        // zeros_buf already holds (qz + 1); see unpack_zeros_for_group.
        for (((out_pair, &qw), &zero), &scale) in out_row
            .chunks_exact_mut(2)
            .zip(unpacked_row.iter())
            .zip(zeros_buf.iter())
            .zip(scales_buf.iter())
        {
            let val = (qw - zero) * scale;
            let bf16 = f32_bits_to_bf16_bits(val.to_bits());
            out_pair.copy_from_slice(&bf16.to_le_bytes());
        }
    }
    Ok(output)
}
// Unit tests built from small hand-packed GPTQ tensors whose values are
// chosen so the expected dequantized outputs are exactly representable
// in BF16.
#[cfg(test)]
#[allow(
    clippy::panic,
    clippy::indexing_slicing,
    clippy::unwrap_used,
    clippy::as_conversions,
    clippy::cast_possible_truncation,
    clippy::float_cmp
)]
mod tests {
    use super::*;

    // 0x7654_3210 packs the nibbles 0..=7, so position `pos` unpacks to `pos`.
    #[test]
    fn unpack_4bit_all_positions() {
        let packed: u32 = 0x7654_3210;
        let mask = 0xF;
        for pos in 0..8u32 {
            let shift = 4 * pos;
            assert_eq!(unpack_gptq(packed, shift, mask), pos);
        }
    }

    // All-ones word: every 4-bit field is the maximum value 15.
    #[test]
    fn unpack_4bit_max_value() {
        let packed: u32 = 0xFFFF_FFFF;
        let mask = 0xF;
        for pos in 0..8u32 {
            let shift = 4 * pos;
            assert_eq!(unpack_gptq(packed, shift, mask), 15);
        }
    }

    // Bytes 0x10, 0x20, 0x30, 0x40 packed little-endian into one u32.
    #[test]
    fn unpack_8bit_all_positions() {
        let packed: u32 = 0x4030_2010;
        let mask = 0xFF;
        assert_eq!(unpack_gptq(packed, 0, mask), 0x10);
        assert_eq!(unpack_gptq(packed, 8, mask), 0x20);
        assert_eq!(unpack_gptq(packed, 16, mask), 0x30);
        assert_eq!(unpack_gptq(packed, 24, mask), 0x40);
    }

    // All-ones word: every 8-bit field is the maximum value 255.
    #[test]
    fn unpack_8bit_max_value() {
        let packed: u32 = 0xFFFF_FFFF;
        let mask = 0xFF;
        for pos in 0..4u32 {
            let shift = 8 * pos;
            assert_eq!(unpack_gptq(packed, shift, mask), 255);
        }
    }

    // Uniform 4-bit case: every weight nibble is 5 (0x5555_5555), every
    // zero nibble is 3 (stored zero + 1 = 4), scale is 2.0, so every
    // element dequantizes to (5 - 4) * 2.0 = 2.0 (BF16 bytes [0x00, 0x40]).
    #[test]
    fn dequant_4bit_uniform() {
        let in_features = 8;
        let out_features = 8;
        let group_size = 8;
        let bits: u8 = 4;
        // One packed row (in_features / 8) of `out_features` u32 columns.
        let qweight_i32 = 0x5555_5555u32;
        let mut qweight_data = Vec::new();
        for _j in 0..out_features {
            qweight_data.extend_from_slice(&qweight_i32.to_le_bytes());
        }
        let scale_f16 = half::f16::from_f32(2.0).to_le_bytes();
        let mut scales_data = Vec::new();
        for _ in 0..out_features {
            scales_data.extend_from_slice(&scale_f16);
        }
        // One group × one packed qzeros column (out_features / 8).
        let qzeros_i32 = 0x3333_3333u32;
        let qzeros_data = qzeros_i32.to_le_bytes().to_vec();
        let output = dequantize_gptq_to_bf16(
            &qweight_data,
            &scales_data,
            &qzeros_data,
            None,
            in_features,
            out_features,
            group_size,
            bits,
            Dtype::F16,
        )
        .unwrap();
        assert_eq!(output.len(), in_features * out_features * 2);
        for chunk in output.chunks_exact(2) {
            assert_eq!(chunk, &[0x00, 0x40], "expected BF16 2.0");
        }
    }

    // Act-order case: g_idx maps rows 0..4 to group 1 and rows 4..8 to
    // group 0. Every weight nibble is 0xA = 10. Group 1 has scale 3.0 and
    // zero nibble 4 (+1 = 5): (10 - 5) * 3.0 = 15.0. Group 0 has scale 1.0
    // and zero nibble 7 (+1 = 8): (10 - 8) * 1.0 = 2.0.
    #[test]
    fn dequant_4bit_with_g_idx() {
        let in_features = 8;
        let out_features = 8;
        let group_size = 4;
        let bits: u8 = 4;
        let g_idx_values: Vec<u32> = vec![1, 1, 1, 1, 0, 0, 0, 0];
        let g_idx_data: Vec<u8> = g_idx_values.iter().flat_map(|v| v.to_le_bytes()).collect();
        let qweight_i32 = 0xAAAA_AAAAu32;
        let mut qweight_data = Vec::new();
        for _j in 0..out_features {
            qweight_data.extend_from_slice(&qweight_i32.to_le_bytes());
        }
        let scale_1 = half::f16::from_f32(1.0).to_le_bytes();
        let scale_3 = half::f16::from_f32(3.0).to_le_bytes();
        let mut scales_data = Vec::new();
        // Group 0 row: scale 1.0; group 1 row: scale 3.0.
        for _ in 0..out_features {
            scales_data.extend_from_slice(&scale_1);
        }
        for _ in 0..out_features {
            scales_data.extend_from_slice(&scale_3);
        }
        let qz_group0 = 0x7777_7777u32;
        let qz_group1 = 0x4444_4444u32;
        let mut qzeros_data = Vec::new();
        qzeros_data.extend_from_slice(&qz_group0.to_le_bytes());
        qzeros_data.extend_from_slice(&qz_group1.to_le_bytes());
        let output = dequantize_gptq_to_bf16(
            &qweight_data,
            &scales_data,
            &qzeros_data,
            Some(&g_idx_data),
            in_features,
            out_features,
            group_size,
            bits,
            Dtype::F16,
        )
        .unwrap();
        let bf16_15 = f32_bits_to_bf16_bits(15.0_f32.to_bits());
        for i in 0..4 {
            for j in 0..out_features {
                let offset = (i * out_features + j) * 2;
                let actual = u16::from_le_bytes([output[offset], output[offset + 1]]);
                assert_eq!(actual, bf16_15, "element [{i},{j}]: expected BF16 15.0");
            }
        }
        let bf16_2 = f32_bits_to_bf16_bits(2.0_f32.to_bits());
        for i in 4..8 {
            for j in 0..out_features {
                let offset = (i * out_features + j) * 2;
                let actual = u16::from_le_bytes([output[offset], output[offset + 1]]);
                assert_eq!(actual, bf16_2, "element [{i},{j}]: expected BF16 2.0");
            }
        }
    }

    // Uniform 8-bit case: every weight byte is 0x64 = 100, every zero byte
    // is 0x31 = 49 (+1 = 50), scale is 0.5: (100 - 50) * 0.5 = 25.0.
    #[test]
    fn dequant_8bit_uniform() {
        let in_features = 4;
        let out_features = 4;
        let group_size = 4;
        let bits: u8 = 8;
        let qweight_i32 = 0x6464_6464u32;
        let mut qweight_data = Vec::new();
        for _j in 0..out_features {
            qweight_data.extend_from_slice(&qweight_i32.to_le_bytes());
        }
        let scale_half = half::f16::from_f32(0.5).to_le_bytes();
        let mut scales_data = Vec::new();
        for _ in 0..out_features {
            scales_data.extend_from_slice(&scale_half);
        }
        let qzeros_i32 = 0x3131_3131u32;
        let qzeros_data = qzeros_i32.to_le_bytes().to_vec();
        let output = dequantize_gptq_to_bf16(
            &qweight_data,
            &scales_data,
            &qzeros_data,
            None,
            in_features,
            out_features,
            group_size,
            bits,
            Dtype::F16,
        )
        .unwrap();
        let bf16_25 = f32_bits_to_bf16_bits(25.0_f32.to_bits());
        for chunk in output.chunks_exact(2) {
            let actual = u16::from_le_bytes([chunk[0], chunk[1]]);
            assert_eq!(actual, bf16_25, "expected BF16 25.0");
        }
    }

    // Only 4-bit and 8-bit widths are supported.
    #[test]
    fn validation_unsupported_bits() {
        let result = dequantize_gptq_to_bf16(&[], &[], &[], None, 8, 8, 8, 3, Dtype::F16);
        assert!(result.is_err());
    }

    // in_features == 0 must be rejected before any buffer access.
    #[test]
    fn validation_zero_dimensions() {
        let result = dequantize_gptq_to_bf16(&[], &[], &[], None, 0, 8, 8, 4, Dtype::F16);
        assert!(result.is_err());
    }

    // in_features = 5 is not divisible by pack_factor = 8 for 4-bit.
    #[test]
    fn validation_in_features_not_multiple_of_pack_factor() {
        let result = dequantize_gptq_to_bf16(&[], &[], &[], None, 5, 8, 5, 4, Dtype::F16);
        assert!(result.is_err());
    }

    // A g_idx entry of 99 with num_groups = 2 must fail fast with a message
    // naming the offending index and value.
    #[test]
    fn validation_g_idx_out_of_range() {
        let in_features = 8;
        let out_features = 8;
        let group_size = 4;
        let bits: u8 = 4;
        let qweight_i32 = 0x5555_5555u32;
        let mut qweight_data = Vec::new();
        for _ in 0..out_features {
            qweight_data.extend_from_slice(&qweight_i32.to_le_bytes());
        }
        let scale_f16 = half::f16::from_f32(1.0).to_le_bytes();
        let mut scales_data = Vec::new();
        for _ in 0..2 * out_features {
            scales_data.extend_from_slice(&scale_f16);
        }
        let qz = 0x3333_3333u32;
        let mut qzeros_data = Vec::new();
        qzeros_data.extend_from_slice(&qz.to_le_bytes());
        qzeros_data.extend_from_slice(&qz.to_le_bytes());
        let g_idx_values: Vec<u32> = vec![0, 0, 0, 0, 1, 99, 1, 1];
        let g_idx_data: Vec<u8> = g_idx_values.iter().flat_map(|v| v.to_le_bytes()).collect();
        let result = dequantize_gptq_to_bf16(
            &qweight_data,
            &scales_data,
            &qzeros_data,
            Some(&g_idx_data),
            in_features,
            out_features,
            group_size,
            bits,
            Dtype::F16,
        );
        let err = result.unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("g_idx[5]") && msg.contains("99"),
            "expected fail-fast g_idx error, got: {msg}"
        );
    }

    // qweight is 16 bytes but 8×8 4-bit weights need 32; must be rejected.
    #[test]
    fn validation_qweight_length_mismatch() {
        let result = dequantize_gptq_to_bf16(
            &[0u8; 16], &[0u8; 16], &[0u8; 4], None,
            8,
            8,
            8,
            4,
            Dtype::F16,
        );
        assert!(result.is_err());
    }
}