use candle_core::{IndexOp, Result, Tensor};
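/// Floor for per-group scales so that all-zero groups do not produce a zero divisor.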
const EPSILON: f32 = 1e-5;
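
/// Quantizes a 2D `[out_dim, in_dim]` tensor (converted to F32) to signed 4-bit
/// codes, using one symmetric scale per `group_size` consecutive elements along
/// `in_dim`. Rows are zero-padded when `in_dim` is not a multiple of `group_size`.
///
/// Returns `(packed, scales)`: `packed` is a `[out_dim, n_groups * ceil(group_size / 2)]`
/// U8 tensor holding two codes per byte, and `scales` is `[out_dim, n_groups]` F32.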
pub fn pack_4bit_symmetric(tensor: &Tensor, group_size: usize) -> Result<(Tensor, Tensor)> {
let device = tensor.device();
let dims = tensor.dims();
if dims.len() != 2 {
return Err(candle_core::Error::Msg(format!(
"Expected 2D tensor [out_dim, in_dim], got shape {:?}",
dims
)));
}
let (out_dim, in_dim) = (dims[0], dims[1]);
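    // Zero-pad each row up to the next multiple of `group_size` so every group is full.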
let pad_size = if in_dim % group_size == 0 {
0
} else {
group_size - (in_dim % group_size)
};
let padded_in_dim = in_dim + pad_size;
let n_groups = padded_in_dim / group_size;
let tensor_f32 = tensor.to_dtype(candle_core::DType::F32)?;
let tensor_2d = if pad_size > 0 {
let zeros_pad = Tensor::zeros((out_dim, pad_size), candle_core::DType::F32, device)?;
Tensor::cat(&[&tensor_f32, &zeros_pad], 1)?
} else {
tensor_f32.clone()
};
let mut all_packed_data = Vec::new();
let mut all_scales = Vec::new();
for row_idx in 0..out_dim {
        let row = tensor_2d.i((row_idx, ..))?;
        let row_data = row.to_vec1::<f32>()?;
let mut row_scales = Vec::new();
let mut row_packed = Vec::new();
for group_idx in 0..n_groups {
let group_start = group_idx * group_size;
let group_end = (group_start + group_size).min(padded_in_dim);
let group = &row_data[group_start..group_end];
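            // Symmetric scale: map the group's largest magnitude onto the maximum
            // positive 4-bit code (+7); clamp below by EPSILON for all-zero groups.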
            let max_abs = group.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
let scale = (max_abs / 7.0).max(EPSILON);
row_scales.push(scale);
let mut quantized_group = Vec::new();
for &weight in group {
let q = (weight / scale).round().clamp(-8.0, 7.0) as i8;
quantized_group.push(q);
}
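            // Pack adjacent codes two per byte: element i in the low nibble,
            // element i + 1 in the high nibble (zero-padded for odd group sizes).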
for i in (0..group_size).step_by(2) {
let low = if i < quantized_group.len() {
quantized_group[i]
} else {
0i8
};
let high = if i + 1 < quantized_group.len() {
quantized_group[i + 1]
} else {
0i8
};
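                // Shift signed codes from [-8, 7] to offset-binary [0, 15] so each fits in a nibble.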
let low_unsigned = (low + 8) as u8;
let high_unsigned = (high + 8) as u8;
let packed_byte = low_unsigned | (high_unsigned << 4);
row_packed.push(packed_byte);
}
}
all_scales.extend(row_scales);
all_packed_data.extend(row_packed);
}
let bytes_per_group = group_size.div_ceil(2);
let packed_shape = (out_dim, n_groups * bytes_per_group);
let scales_shape = (out_dim, n_groups);
let packed_tensor = Tensor::from_vec(all_packed_data, packed_shape, device)?;
let scales_tensor = Tensor::from_vec(all_scales, scales_shape, device)?;
Ok((packed_tensor, scales_tensor))
}
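
/// Inverse of [`pack_4bit_symmetric`]: decodes the packed nibbles back into an
/// F32 tensor of `original_shape`, dropping any zero-padding columns that were
/// added to align `in_dim` to `group_size`.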
pub fn unpack_4bit_symmetric(
packed: &Tensor,
scales: &Tensor,
original_shape: (usize, usize),
group_size: usize,
) -> Result<Tensor> {
let device = packed.device();
let (out_dim, in_dim) = original_shape;
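    // Recompute the padding applied by `pack_4bit_symmetric` so the expected shapes can be checked.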
let pad_size = if in_dim % group_size == 0 {
0
} else {
group_size - (in_dim % group_size)
};
let padded_in_dim = in_dim + pad_size;
let n_groups = padded_in_dim / group_size;
let bytes_per_group = group_size.div_ceil(2);
let packed_dims = packed.dims();
let scales_dims = scales.dims();
if packed_dims != [out_dim, n_groups * bytes_per_group] {
return Err(candle_core::Error::Msg(format!(
"Packed tensor shape mismatch: expected [{}, {}], got {:?}",
out_dim,
n_groups * bytes_per_group,
packed_dims
)));
}
if scales_dims != [out_dim, n_groups] {
return Err(candle_core::Error::Msg(format!(
"Scales tensor shape mismatch: expected [{}, {}], got {:?}",
out_dim, n_groups, scales_dims
)));
}
let packed_data = packed.flatten_all()?.to_vec1::<u8>()?;
let scales_f32 = scales.to_dtype(candle_core::DType::F32)?;
let scales_data = scales_f32.flatten_all()?.to_vec1::<f32>()?;
let mut result_data: Vec<f32> = Vec::new();
for row_idx in 0..out_dim {
let mut row_data: Vec<f32> = Vec::new();
for group_idx in 0..n_groups {
let scale = scales_data[row_idx * n_groups + group_idx];
let group_start_packed =
(row_idx * n_groups * bytes_per_group) + (group_idx * bytes_per_group);
let mut group_weights = Vec::new();
for byte_idx in 0..bytes_per_group {
let packed_byte = packed_data[group_start_packed + byte_idx];
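                // Split the byte into its two nibbles and undo the +8 offset applied when packing.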
let low_unsigned = packed_byte & 0x0F;
let high_unsigned = (packed_byte >> 4) & 0x0F;
let low_signed = (low_unsigned as i8) - 8;
let high_signed = (high_unsigned as i8) - 8;
group_weights.push(low_signed as f32 * scale);
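                // For odd group sizes the last byte's high nibble is padding; skip it once the group is full.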
if group_weights.len() < group_size {
group_weights.push(high_signed as f32 * scale);
}
}
let group_size_actual = group_size.min(padded_in_dim - group_idx * group_size);
row_data.extend(&group_weights[..group_size_actual]);
}
        row_data.truncate(in_dim); // drop any zero-padding columns (no-op when already aligned)
result_data.extend(row_data);
}
let result_tensor = Tensor::from_vec(result_data, original_shape, device)?;
Ok(result_tensor)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::Device;
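    /// Asserts that two tensors are element-wise equal within an absolute tolerance.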
fn assert_tensor_approx_eq(a: &Tensor, b: &Tensor, tol: f32) -> Result<()> {
let a_f32 = a.to_dtype(candle_core::DType::F32)?;
let b_f32 = b.to_dtype(candle_core::DType::F32)?;
let a_vec = a_f32.flatten_all()?.to_vec1::<f32>()?;
let b_vec = b_f32.flatten_all()?.to_vec1::<f32>()?;
assert_eq!(a_vec.len(), b_vec.len(), "Tensor lengths mismatch");
for (i, (v1, v2)) in a_vec.iter().zip(b_vec.iter()).enumerate() {
assert!(
(v1 - v2).abs() < tol,
"Mismatch at index {}: {} vs {} (tol {})",
i,
v1,
v2,
tol
);
}
Ok(())
}
#[test]
fn test_4bit_packing_simple() -> Result<()> {
let device = Device::Cpu;
        let input_data = [7.0f32, -5.6, 0.0, 3.5];
let tensor = Tensor::new(&input_data[..], &device)?.reshape((1, 4))?;
let (packed, scales) = pack_4bit_symmetric(&tensor, 4)?;
        assert_eq!(packed.dims(), &[1, 2]);
        assert_eq!(scales.dims(), &[1, 1]);
let unpacked = unpack_4bit_symmetric(&packed, &scales, (1, 4), 4)?;
assert_tensor_approx_eq(&tensor, &unpacked, 1.0)?;
Ok(())
}
#[test]
fn test_4bit_packing_multi_group() -> Result<()> {
let device = Device::Cpu;
let input_data = vec![
            1.0, -2.0, 3.0, 4.0, -1.0, 0.0, // row 0
            -7.0, 2.0, 1.0, 0.5, -0.2, 3.0, // row 1
        ];
let tensor = Tensor::new(&input_data[..], &device)?.reshape((2, 6))?;
let (packed, scales) = pack_4bit_symmetric(&tensor, 3)?;
        assert_eq!(packed.dims(), &[2, 4]);
        assert_eq!(scales.dims(), &[2, 2]);
let unpacked = unpack_4bit_symmetric(&packed, &scales, (2, 6), 3)?;
assert_tensor_approx_eq(&tensor, &unpacked, 0.5)?;
Ok(())
}
#[test]
fn test_4bit_packing_padding() -> Result<()> {
let device = Device::Cpu;
        let input_data = [1.0f32, -2.0, 3.0, -4.0, 2.0];
let tensor = Tensor::new(&input_data[..], &device)?.reshape((1, 5))?;
let (packed, scales) = pack_4bit_symmetric(&tensor, 4)?;
        assert_eq!(packed.dims(), &[1, 4]);
        assert_eq!(scales.dims(), &[1, 2]);
let unpacked = unpack_4bit_symmetric(&packed, &scales, (1, 5), 4)?;
        // `unpack_4bit_symmetric` already trims padding, so compare directly.
        assert_tensor_approx_eq(&tensor, &unpacked, 0.5)?;
Ok(())
}
#[test]
fn test_4bit_quantization_range() -> Result<()> {
let device = Device::Cpu;
        let input_data = [10.0f32, -15.0, 0.0, 7.0];
        let tensor = Tensor::new(&input_data[..], &device)?.reshape((1, 4))?;
let (packed, scales) = pack_4bit_symmetric(&tensor, 4)?;
let unpacked = unpack_4bit_symmetric(&packed, &scales, (1, 4), 4)?;
let unpacked_vec = unpacked.flatten_all()?.to_vec1::<f32>()?;
let scale = scales.flatten_all()?.to_vec1::<f32>()?[0];
        let max_val = 7.0 * scale;
        let min_val = -8.0 * scale;
for &val in &unpacked_vec {
assert!(
val >= min_val - 1e-5 && val <= max_val + 1e-5,
"Unpacked value {} outside expected range [{}, {}]",
val,
min_val,
max_val
);
}
Ok(())
}
#[test]
fn test_4bit_round_trip_error() -> Result<()> {
let device = Device::Cpu;
        let input_data: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) * 0.1).collect();
        let tensor = Tensor::new(&input_data[..], &device)?.reshape((4, 8))?;
let (packed, scales) = pack_4bit_symmetric(&tensor, 8)?;
let unpacked = unpack_4bit_symmetric(&packed, &scales, (4, 8), 8)?;
        let diff = (&tensor - &unpacked)?;
let mae = diff.abs()?.mean_all()?.to_scalar::<f32>()?;
println!("4-bit Round-trip MAE: {:.6}", mae);
assert!(mae < 0.05, "Round-trip error too high: MAE = {}", mae);
Ok(())
}
#[test]
fn test_4bit_standalone_functionality() -> Result<()> {
println!("🧪 Testing 4-bit Packing Functions");
let device = Device::Cpu;
println!("Test 1: Simple 4-bit packing...");
let input_data: Vec<f32> = vec![7.0, -5.6, 0.0, 3.5];
let tensor = Tensor::new(&input_data[..], &device)?.reshape((1, 4))?;
let (packed, scales) = pack_4bit_symmetric(&tensor, 4)?;
        assert_eq!(packed.dims(), &[1, 2]);
        assert_eq!(scales.dims(), &[1, 1]);
let unpacked = unpack_4bit_symmetric(&packed, &scales, (1, 4), 4)?;
let unpacked_data = unpacked.flatten_all()?.to_vec1::<f32>()?;
        let mae: f32 = input_data
            .iter()
            .zip(&unpacked_data)
            .map(|(a, b)| (a - b).abs())
            .sum::<f32>()
            / input_data.len() as f32;
println!(" MAE: {:.6}", mae);
assert!(mae < 0.5, "MAE too high: {}", mae);
println!("Test 2: Multi-group 4-bit packing...");
let input_data: Vec<f32> = vec![
            1.0, -2.0, 3.0, 4.0, -1.0, 0.0, // row 0
            -7.0, 2.0, 1.0, 0.5, -0.2, 3.0, // row 1
        ];
let tensor = Tensor::new(&input_data[..], &device)?.reshape((2, 6))?;
let (packed, scales) = pack_4bit_symmetric(&tensor, 3)?;
        assert_eq!(packed.dims(), &[2, 4]);
        assert_eq!(scales.dims(), &[2, 2]);
let unpacked = unpack_4bit_symmetric(&packed, &scales, (2, 6), 3)?;
let diff = (&tensor - &unpacked)?;
let mae = diff.abs()?.mean_all()?.to_scalar::<f32>()?;
println!(" MAE: {:.6}", mae);
assert!(mae < 0.5, "MAE too high: {}", mae);
println!("Test 3: Padding test...");
let input_data: Vec<f32> = vec![1.0, -2.0, 3.0, -4.0, 2.0];
let tensor = Tensor::new(&input_data[..], &device)?.reshape((1, 5))?;
let (packed, scales) = pack_4bit_symmetric(&tensor, 4)?;
        assert_eq!(packed.dims(), &[1, 4]);
        assert_eq!(scales.dims(), &[1, 2]);
let unpacked = unpack_4bit_symmetric(&packed, &scales, (1, 5), 4)?;
        // Padding is already trimmed by unpacking, so the shapes match directly.
        let diff = (&tensor - &unpacked)?;
let mae = diff.abs()?.mean_all()?.to_scalar::<f32>()?;
println!(" MAE: {:.6}", mae);
assert!(mae < 0.5, "MAE too high: {}", mae);
println!("✅ All 4-bit packing tests passed!");
Ok(())
}
}