use crate::error::AnamnesisError;
use crate::parse::safetensors::Dtype;
use crate::remember::fp8::f32_bits_to_bf16_bits;
use crate::remember::quant_utils::{read_scale_f32, read_u32_le};
/// Unpacks the per-output-channel zero points of quantization group `g`
/// into `buf` (one `f32` per output feature).
///
/// `qzeros_data` is a sequence of little-endian `u32` words, each packing
/// `32 / bits` zero-point values; group `g` occupies one row of
/// `out_features / (32 / bits)` words.
fn unpack_zeros_for_group(
    buf: &mut [f32],
    qzeros_data: &[u8],
    g: usize,
    out_features: usize,
    bits: u8,
) -> crate::Result<()> {
    let bit_width = u32::from(bits);
    let value_mask = (1u32 << bit_width) - 1;
    #[allow(clippy::as_conversions)]
    let values_per_word = 32 / bits as usize;
    let words_per_row =
        out_features
            .checked_div(values_per_word)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "pack_factor is zero".into(),
            })?;
    for (col, slot) in buf.iter_mut().enumerate() {
        let word_index = col / values_per_word;
        let lane = col % values_per_word;
        #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
        let shift = bit_width * (lane as u32);
        let offset = (g * words_per_row + word_index)
            .checked_mul(4)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "qzeros byte offset overflow".into(),
            })?;
        let word = read_u32_le(qzeros_data, offset)?;
        let zero_point = (word >> shift) & value_mask;
        #[allow(clippy::as_conversions, clippy::cast_precision_loss)]
        {
            *slot = zero_point as f32;
        }
    }
    Ok(())
}
/// Reads the dequantization scales of group `g` into `buf`, converting each
/// stored value (encoded as `scale_dtype`) to `f32`.
fn unpack_scales_for_group(
    buf: &mut [f32],
    scales_data: &[u8],
    g: usize,
    out_features: usize,
    scale_dtype: Dtype,
) -> crate::Result<()> {
    let elem_size = scale_dtype.byte_size();
    let base = g
        .checked_mul(out_features)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "scales group row offset overflow".into(),
        })?;
    for (col, slot) in buf.iter_mut().enumerate() {
        let offset = base
            .checked_add(col)
            .and_then(|idx| idx.checked_mul(elem_size))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "scale byte offset overflow".into(),
            })?;
        *slot = read_scale_f32(scales_data, offset, scale_dtype)?;
    }
    Ok(())
}
/// Dequantizes AWQ-packed 4/8-bit weights into a little-endian BF16 byte
/// buffer of shape `[in_features, out_features]`.
///
/// Input layouts (all little-endian, with `pack_factor = 32 / bits`):
/// - `qweight_data`: `in_features x (out_features / pack_factor)` u32 words,
///   each packing `pack_factor` quantized values.
/// - `scales_data`: `(in_features / group_size) x out_features` scalars of
///   `scale_dtype`.
/// - `qzeros_data`: `(in_features / group_size) x (out_features / pack_factor)`
///   u32 words, packed like `qweight_data`.
///
/// Each element dequantizes as `(qw - zero) * scale` — AWQ applies no `+1`
/// zero-point offset (see the `dequant_4bit_no_plus_one_offset` test).
///
/// NOTE(review): lanes are unpacked in sequential in-word order
/// (`col = packed_col * pack_factor + pos`); some AWQ checkpoints store lanes
/// interleaved (e.g. [0, 2, 4, 6, 1, 3, 5, 7] for 4-bit) — confirm callers
/// pass de-interleaved data.
///
/// # Errors
/// - `Unsupported` when `bits` is neither 4 nor 8.
/// - `Parse` for zero dimensions, dimensions not divisible by
///   `pack_factor`/`group_size`, input-buffer length mismatches, or arithmetic
///   overflow while computing sizes/offsets.
#[allow(clippy::too_many_arguments)]
pub fn dequantize_awq_to_bf16(
    qweight_data: &[u8],
    scales_data: &[u8],
    qzeros_data: &[u8],
    in_features: usize,
    out_features: usize,
    group_size: usize,
    bits: u8,
    scale_dtype: Dtype,
) -> crate::Result<Vec<u8>> {
    if bits != 4 && bits != 8 {
        return Err(AnamnesisError::Unsupported {
            format: "AWQ".into(),
            detail: format!("{bits}-bit quantization not supported (expected 4 or 8)"),
        });
    }
    // bits is 4 or 8 here, so pack_factor is 8 or 4 (never zero).
    #[allow(clippy::as_conversions)]
    let pack_factor = 32 / bits as usize;
    if in_features == 0 || out_features == 0 || group_size == 0 {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "zero dimension: in_features={in_features}, out_features={out_features}, \
                 group_size={group_size}"
            ),
        });
    }
    if !out_features.is_multiple_of(pack_factor) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "out_features {out_features} is not a multiple of pack_factor {pack_factor}"
            ),
        });
    }
    if !in_features.is_multiple_of(group_size) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "in_features {in_features} is not a multiple of group_size {group_size}"
            ),
        });
    }
    let packed_cols = out_features / pack_factor;
    let num_groups = in_features / group_size;
    // Validate each input buffer against its exact expected byte length so
    // later offset arithmetic stays in bounds.
    let expected_qw_len = in_features
        .checked_mul(packed_cols)
        .and_then(|n| n.checked_mul(4))
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "qweight byte length overflow".into(),
        })?;
    if qweight_data.len() != expected_qw_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "qweight data length {} != expected {expected_qw_len}",
                qweight_data.len()
            ),
        });
    }
    let expected_scales_len = num_groups
        .checked_mul(out_features)
        .and_then(|n| n.checked_mul(scale_dtype.byte_size()))
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "scales byte length overflow".into(),
        })?;
    if scales_data.len() != expected_scales_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "scales data length {} != expected {expected_scales_len}",
                scales_data.len()
            ),
        });
    }
    let expected_qzeros_len = num_groups
        .checked_mul(packed_cols)
        .and_then(|n| n.checked_mul(4))
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "qzeros byte length overflow".into(),
        })?;
    if qzeros_data.len() != expected_qzeros_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "qzeros data length {} != expected {expected_qzeros_len}",
                qzeros_data.len()
            ),
        });
    }
    // Output is BF16: 2 bytes per element, row-major [in_features, out_features].
    let out_byte_len = in_features
        .checked_mul(out_features)
        .and_then(|n| n.checked_mul(2))
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "output size overflow".into(),
        })?;
    let mut output = vec![0u8; out_byte_len];
    // Scratch rows reused across iterations to avoid per-row allocation.
    let mut unpacked_buf = vec![0.0_f32; out_features];
    let mut zeros_buf = vec![0.0_f32; out_features];
    let mut scales_buf = vec![0.0_f32; out_features];
    // Zeros/scales are constant within a group and `g` is non-decreasing in
    // `i`, so each group's zeros and scales are unpacked exactly once.
    let mut cached_group: Option<usize> = None;
    #[allow(clippy::as_conversions)]
    let bits_u32 = u32::from(bits);
    let mask = (1u32 << bits_u32) - 1;
    for i in 0..in_features {
        let g = i / group_size;
        let qw_row_start = i
            .checked_mul(packed_cols)
            .and_then(|n| n.checked_mul(4))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "qweight row byte offset overflow".into(),
            })?;
        // Cannot overflow: the expected_qw_len check above proved that
        // in_features * packed_cols * 4 fits in usize, and this sum is <= it.
        let qw_row_end = qw_row_start + packed_cols * 4;
        let qw_row =
            qweight_data
                .get(qw_row_start..qw_row_end)
                .ok_or_else(|| AnamnesisError::Parse {
                    reason: format!("qweight row {i} out of bounds"),
                })?;
        if cached_group != Some(g) {
            unpack_zeros_for_group(&mut zeros_buf, qzeros_data, g, out_features, bits)?;
            unpack_scales_for_group(&mut scales_buf, scales_data, g, out_features, scale_dtype)?;
            cached_group = Some(g);
        }
        let zeros_row = &zeros_buf[..];
        let scales_row = &scales_buf[..];
        let out_row_start = i
            .checked_mul(out_features)
            .and_then(|n| n.checked_mul(2))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("output row {i} offset overflow"),
            })?;
        let out_row_end = out_features
            .checked_mul(2)
            .and_then(|row_bytes| out_row_start.checked_add(row_bytes))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("output row {i} end overflow"),
            })?;
        let out_row =
            output
                .get_mut(out_row_start..out_row_end)
                .ok_or_else(|| AnamnesisError::Parse {
                    reason: format!("output row {i} out of bounds"),
                })?;
        let unpacked_row =
            unpacked_buf
                .get_mut(..out_features)
                .ok_or_else(|| AnamnesisError::Parse {
                    reason: "unpacked buffer too short".into(),
                })?;
        // Unpack this row's quantized values: lane `pos` of each u32 word
        // occupies bits [pos * bits, (pos + 1) * bits).
        #[allow(clippy::indexing_slicing)]
        for (packed_col, qw_chunk) in qw_row.chunks_exact(4).enumerate() {
            let packed = u32::from_le_bytes([qw_chunk[0], qw_chunk[1], qw_chunk[2], qw_chunk[3]]);
            for pos in 0..pack_factor {
                #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
                let shift = bits_u32 * (pos as u32);
                let qw = (packed >> shift) & mask;
                #[allow(clippy::as_conversions, clippy::cast_precision_loss)]
                let qw_f32 = qw as f32;
                unpacked_row[packed_col * pack_factor + pos] = qw_f32;
            }
        }
        // Dequantize and store each element as little-endian BF16.
        for (((out_pair, &qw), &zero), &scale) in out_row
            .chunks_exact_mut(2)
            .zip(unpacked_row.iter())
            .zip(zeros_row.iter())
            .zip(scales_row.iter())
        {
            let val = (qw - zero) * scale;
            let bf16 = f32_bits_to_bf16_bits(val.to_bits());
            out_pair.copy_from_slice(&bf16.to_le_bytes());
        }
    }
    Ok(output)
}
#[cfg(test)]
#[allow(
    clippy::panic,
    clippy::indexing_slicing,
    clippy::unwrap_used,
    clippy::as_conversions,
    clippy::cast_possible_truncation,
    clippy::float_cmp
)]
mod tests {
    use super::*;

    /// Serializes `count` copies of `word` as little-endian bytes.
    fn repeat_word_le(word: u32, count: usize) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(count * 4);
        for _ in 0..count {
            bytes.extend_from_slice(&word.to_le_bytes());
        }
        bytes
    }

    /// Serializes `count` copies of `value` as little-endian f16 bytes.
    fn repeat_f16_le(value: f32, count: usize) -> Vec<u8> {
        let encoded = half::f16::from_f32(value).to_le_bytes();
        let mut bytes = Vec::with_capacity(count * 2);
        for _ in 0..count {
            bytes.extend_from_slice(&encoded);
        }
        bytes
    }

    #[test]
    fn dequant_4bit_uniform() {
        // Every qweight nibble is 0xA (10), every zero nibble is 0x3 (3),
        // scale is 2.0 everywhere: (10 - 3) * 2.0 = 14.0.
        let (in_features, out_features, group_size) = (8, 8, 8);
        let qweight_data = repeat_word_le(0xAAAA_AAAA, in_features);
        let scales_data = repeat_f16_le(2.0, out_features);
        let qzeros_data = repeat_word_le(0x3333_3333, 1);
        let output = dequantize_awq_to_bf16(
            &qweight_data,
            &scales_data,
            &qzeros_data,
            in_features,
            out_features,
            group_size,
            4,
            Dtype::F16,
        )
        .unwrap();
        assert_eq!(output.len(), in_features * out_features * 2);
        let bf16_14 = f32_bits_to_bf16_bits(14.0_f32.to_bits());
        for chunk in output.chunks_exact(2) {
            let actual = u16::from_le_bytes([chunk[0], chunk[1]]);
            assert_eq!(actual, bf16_14, "expected BF16 14.0");
        }
    }

    #[test]
    fn dequant_4bit_no_plus_one_offset() {
        // Nibbles of 0x5 (5) against zeros of 0x3 (3) with unit scale must
        // give exactly 2.0 — AWQ does not add GPTQ's +1 zero-point offset.
        let (in_features, out_features, group_size) = (8, 8, 8);
        let qweight_data = repeat_word_le(0x5555_5555, in_features);
        let scales_data = repeat_f16_le(1.0, out_features);
        let qzeros_data = repeat_word_le(0x3333_3333, 1);
        let output = dequantize_awq_to_bf16(
            &qweight_data,
            &scales_data,
            &qzeros_data,
            in_features,
            out_features,
            group_size,
            4,
            Dtype::F16,
        )
        .unwrap();
        let bf16_2 = f32_bits_to_bf16_bits(2.0_f32.to_bits());
        for chunk in output.chunks_exact(2) {
            let actual = u16::from_le_bytes([chunk[0], chunk[1]]);
            assert_eq!(actual, bf16_2, "expected BF16 2.0 (AWQ: no +1 offset)");
        }
    }

    #[test]
    fn dequant_8bit_uniform() {
        // Every qweight byte is 0x64 (100), every zero byte is 0x32 (50),
        // scale is 0.5: (100 - 50) * 0.5 = 25.0.
        let (in_features, out_features, group_size) = (4, 4, 4);
        let qweight_data = repeat_word_le(0x6464_6464, in_features);
        let scales_data = repeat_f16_le(0.5, out_features);
        let qzeros_data = repeat_word_le(0x3232_3232, 1);
        let output = dequantize_awq_to_bf16(
            &qweight_data,
            &scales_data,
            &qzeros_data,
            in_features,
            out_features,
            group_size,
            8,
            Dtype::F16,
        )
        .unwrap();
        let bf16_25 = f32_bits_to_bf16_bits(25.0_f32.to_bits());
        for chunk in output.chunks_exact(2) {
            let actual = u16::from_le_bytes([chunk[0], chunk[1]]);
            assert_eq!(actual, bf16_25, "expected BF16 25.0");
        }
    }

    #[test]
    fn dequant_4bit_two_groups() {
        // All qweight nibbles are 0x8 (8); two groups of 4 rows each.
        // Group 0: zero 0x6, scale 1.0 -> (8 - 6) * 1.0 = 2.0.
        // Group 1: zero 0x2, scale 3.0 -> (8 - 2) * 3.0 = 18.0.
        let (in_features, out_features, group_size) = (8, 8, 4);
        let qweight_data = repeat_word_le(0x8888_8888, in_features);
        let mut scales_data = repeat_f16_le(1.0, out_features);
        scales_data.extend_from_slice(&repeat_f16_le(3.0, out_features));
        let mut qzeros_data = repeat_word_le(0x6666_6666, 1);
        qzeros_data.extend_from_slice(&repeat_word_le(0x2222_2222, 1));
        let output = dequantize_awq_to_bf16(
            &qweight_data,
            &scales_data,
            &qzeros_data,
            in_features,
            out_features,
            group_size,
            4,
            Dtype::F16,
        )
        .unwrap();
        let bf16_2 = f32_bits_to_bf16_bits(2.0_f32.to_bits());
        for i in 0..4 {
            for j in 0..out_features {
                let offset = (i * out_features + j) * 2;
                let actual = u16::from_le_bytes([output[offset], output[offset + 1]]);
                assert_eq!(actual, bf16_2, "element [{i},{j}]: expected BF16 2.0");
            }
        }
        let bf16_18 = f32_bits_to_bf16_bits(18.0_f32.to_bits());
        for i in 4..8 {
            for j in 0..out_features {
                let offset = (i * out_features + j) * 2;
                let actual = u16::from_le_bytes([output[offset], output[offset + 1]]);
                assert_eq!(actual, bf16_18, "element [{i},{j}]: expected BF16 18.0");
            }
        }
    }

    #[test]
    fn validation_unsupported_bits() {
        // 3-bit quantization is rejected up front.
        assert!(dequantize_awq_to_bf16(&[], &[], &[], 8, 8, 8, 3, Dtype::F16).is_err());
    }

    #[test]
    fn validation_zero_dimensions() {
        assert!(dequantize_awq_to_bf16(&[], &[], &[], 0, 8, 8, 4, Dtype::F16).is_err());
    }

    #[test]
    fn validation_qweight_length_mismatch() {
        // 8x8 at 4 bits needs 8 rows x 1 word x 4 bytes = 32 qweight bytes.
        let result =
            dequantize_awq_to_bf16(&[0u8; 16], &[0u8; 16], &[0u8; 4], 8, 8, 8, 4, Dtype::F16);
        assert!(result.is_err());
    }
}