//! 2-bit packed ternary tensors: each weight is stored as a 2-bit code
//! (0b00 = 0.0, 0b01 = +1.0, 0b10 = -1.0), packed four codes per byte, LSB-first.
use candle_core::{Device, IndexOp, Result, Tensor};

/// Added to the max-abs scale in `pack` so all-zero tensors never divide by zero.
const EPSILON: f32 = 1e-6;
#[derive(Debug, Clone)]
pub struct PackedTensor {
    /// Packed 2-bit codes, four weights per byte. Multibase tensors store
    /// `num_elem.div_ceil(4)` bytes per base, base-major.
    pub data: Tensor,
    /// Global scale applied to the unpacked {-1, 0, +1} codes.
    pub scale: f32,
    /// Per-base scales; `Some` marks this as a multibase tensor.
    pub adaptive_scales: Option<Tensor>,
    /// Logical (unpacked) shape.
    pub shape: candle_core::Shape,
    /// `shape.elem_count()`, cached.
    pub num_elem: usize,
    pub device: Device,
}
impl PackedTensor {
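    /// Builds a single-base tensor from already-packed bytes;
    /// `data.len()` must equal `shape.elem_count().div_ceil(4)`.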
pub fn new(
data: Vec<u8>,
shape: candle_core::Shape,
scale: f32,
device: &Device,
) -> Result<Self> {
let num_elem = shape.elem_count();
let capacity = num_elem.div_ceil(4);
let tensor = Tensor::from_vec(data, (capacity,), device)?;
Ok(Self {
data: tensor,
scale,
adaptive_scales: None,
shape: shape.clone(),
num_elem,
device: device.clone(),
})
}
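    /// Builds a multibase tensor from already-packed bytes laid out
    /// base-major: `div_ceil(num_elem, 4)` bytes per entry in `adaptive_scales`.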
pub fn new_adaptive(
data: Vec<u8>,
shape: candle_core::Shape,
scale: f32,
adaptive_scales: Tensor,
device: &Device,
) -> Result<Self> {
        let num_elem = shape.elem_count();
        // The multibase layout is base-major: `div_ceil(num_elem, 4)` packed
        // bytes per base, so the buffer holds that many bytes for each scale.
        let n_bases = adaptive_scales.elem_count();
        let capacity = num_elem.div_ceil(4) * n_bases;
        let tensor = Tensor::from_vec(data, (capacity,), device)?;
Ok(Self {
data: tensor,
scale,
adaptive_scales: Some(adaptive_scales),
shape: shape.clone(),
num_elem,
device: device.clone(),
})
}
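    /// Quantizes a dense tensor to ternary: the scale is the max-abs value,
    /// each scaled weight is rounded to {-1, 0, +1}, and four 2-bit codes
    /// are packed per byte, LSB-first.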
pub fn pack(tensor: &Tensor) -> Result<Self> {
let device = tensor.device();
let shape = tensor.shape().clone();
let num_elem = shape.elem_count();
let scale_t = tensor.abs()?.max_all()?;
let scale = scale_t.to_scalar::<f32>()? + EPSILON;
let w_scaled = (tensor / scale as f64)?;
let w_quant = w_scaled
.round()?
.clamp(-1.0, 1.0)?
.to_dtype(candle_core::DType::F32)?;
let flat = w_quant.flatten_all()?;
let vec = flat.to_vec1::<f32>()?;
let capacity = num_elem.div_ceil(4);
let mut packed_data = Vec::with_capacity(capacity);
for chunk in vec.chunks(4) {
let mut byte: u8 = 0;
        for (i, &val) in chunk.iter().enumerate() {
            // Map the rounded value to a 2-bit code: +1 -> 0b01, -1 -> 0b10, 0 -> 0b00.
            let code: u8 = if val > 0.5 {
                1
            } else if val < -0.5 {
                2
            } else {
                0
            };
            byte |= code << (i * 2);
        }
packed_data.push(byte);
}
let data_tensor =
Tensor::from_vec(packed_data, (capacity,), &Device::Cpu)?.to_device(device)?;
Ok(Self {
data: data_tensor,
scale,
            adaptive_scales: None,
            shape,
num_elem,
device: device.clone(),
})
}
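    /// Attaches per-base scales, turning this into a multibase tensor.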
pub fn with_adaptive_scales(self, scales: Tensor) -> Result<Self> {
Ok(Self {
adaptive_scales: Some(scales),
..self
})
}
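    /// Whether this tensor can take the fused kernel path, i.e. whether
    /// per-base scales are present.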
#[inline]
pub fn supports_fused_kernel(&self) -> bool {
self.adaptive_scales.is_some()
}
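    /// Reconstructs a packed tensor from data loaded off disk:
    /// `packed_data` is `[out_dim, in_dim/4]` (single-base) or
    /// `[out_dim, in_dim/4, n_bases]` (multibase), and `scales` holds
    /// one f32 per base.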
pub fn from_loaded(
packed_data: Tensor,
scales: Tensor,
out_dim: usize,
in_dim: usize,
device: &Device,
) -> Result<Self> {
let packed_dims = packed_data.dims();
let scales_vec = scales.to_vec1::<f32>()?;
let n_bases = scales_vec.len();
if packed_dims.len() == 3 {
let (d0, d1, d2) = (packed_dims[0], packed_dims[1], packed_dims[2]);
if d0 != out_dim || d1 != in_dim / 4 || d2 != n_bases {
return Err(candle_core::Error::Msg(format!(
"Packed tensor shape mismatch: expected [{}, {}, {}], got [{}, {}, {}]",
out_dim,
in_dim / 4,
n_bases,
d0,
d1,
d2
)));
}
} else if packed_dims.len() == 2 {
let (d0, d1) = (packed_dims[0], packed_dims[1]);
if d0 != out_dim || d1 != in_dim / 4 {
return Err(candle_core::Error::Msg(format!(
"Packed tensor shape mismatch: expected [{}, {}], got [{}, {}]",
out_dim,
in_dim / 4,
d0,
d1
)));
}
}
let shape = candle_core::Shape::from((out_dim, in_dim));
let num_elem = out_dim * in_dim;
if n_bases == 1 {
let flat_packed = if packed_dims.len() == 3 {
packed_data.squeeze(2)?.flatten_all()?
} else {
packed_data.flatten_all()?
};
let data = flat_packed.to_device(device)?;
Ok(Self {
data,
scale: scales_vec[0],
adaptive_scales: None,
shape,
num_elem,
device: device.clone(),
})
        } else {
            // The multibase decode below indexes the third (base) axis, so
            // the packed data must be 3-D here.
            if packed_dims.len() != 3 {
                return Err(candle_core::Error::Msg(format!(
                    "Multibase packed tensor must be 3-D [out, in/4, n_bases], got {:?}",
                    packed_dims
                )));
            }
let packed_per_base = out_dim * in_dim / 4;
let total_packed = packed_per_base * n_bases;
let mut flat_data = Vec::with_capacity(total_packed);
tracing::debug!(
"🔍 [from_loaded] packed_data shape: {:?}, is_contiguous: {}",
packed_data.dims(),
packed_data.is_contiguous()
);
for base_idx in 0..n_bases {
let base_slice = packed_data.i((.., .., base_idx))?;
tracing::debug!(
"🔍 [from_loaded] base[{}] slice shape: {:?}, is_contiguous: {}",
base_idx,
base_slice.dims(),
base_slice.is_contiguous()
);
let base_packed = base_slice.contiguous()?;
let base_vec = base_packed.flatten_all()?.to_vec1::<u8>()?;
tracing::debug!(
"🔍 [from_loaded] base[{}] first 8 bytes: {:?}",
base_idx,
&base_vec[..8.min(base_vec.len())]
);
flat_data.extend(base_vec);
}
let flat_packed =
Tensor::from_vec(flat_data, (total_packed,), &Device::Cpu)?.to_device(device)?;
let adaptive_scales = scales.to_device(device)?;
let primary_scale = scales_vec[0];
Ok(Self {
data: flat_packed,
scale: primary_scale,
adaptive_scales: Some(adaptive_scales),
shape,
num_elem,
device: device.clone(),
})
}
}
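    /// Decodes the packed codes back into a dense f32 tensor on `device`.
    /// Multibase tensors are reconstructed as the scale-weighted sum of their
    /// per-base ternary codes; single-base tensors are `codes * scale`.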
pub fn unpack(&self, device: &Device) -> Result<Tensor> {
if let Some(ref scales_tensor) = self.adaptive_scales {
let scales_vec = scales_tensor.to_vec1::<f32>()?;
let n_bases = scales_vec.len();
let packed_per_base = self.num_elem.div_ceil(4);
let data_vec = self.data.to_vec1::<u8>()?;
let mut combined = vec![0.0f32; self.num_elem];
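            // 2-bit code -> weight; 0b11 is unused and decodes to 0.0.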
const LUT: [f32; 4] = [0.0, 1.0, -1.0, 0.0];
#[allow(clippy::needless_range_loop)]
for base_idx in 0..n_bases {
let scale = scales_vec[base_idx];
let base_start = base_idx * packed_per_base;
let mut weight_idx = 0;
for byte_idx in 0..packed_per_base {
if base_start + byte_idx >= data_vec.len() {
break;
}
let byte = data_vec[base_start + byte_idx];
for bit_pos in 0..4 {
if weight_idx >= self.num_elem {
break;
}
let code = (byte >> (bit_pos * 2)) & 0b11;
let val = LUT[code as usize];
combined[weight_idx] += val * scale;
weight_idx += 1;
}
}
}
let t = Tensor::from_vec(combined, self.shape.clone(), device)?;
Ok(t)
} else {
let data_vec = self.data.to_vec1::<u8>()?;
let mut floats = Vec::with_capacity(self.num_elem);
for &byte in &data_vec {
for i in 0..4 {
if floats.len() >= self.num_elem {
break;
}
let code = (byte >> (i * 2)) & 0b11;
let val: f32 = match code {
1 => 1.0,
2 => -1.0,
_ => 0.0,
};
floats.push(val);
}
}
let t = Tensor::from_vec(floats, self.shape.clone(), device)?;
(t * self.scale as f64)?.to_dtype(candle_core::DType::F32)
}
}
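    /// True when this tensor carries per-base (multibase) scales.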
pub fn is_multibase(&self) -> bool {
self.adaptive_scales.is_some()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn assert_tensor_approx_eq(a: &[f32], b: &[f32], tol: f32) {
assert_eq!(a.len(), b.len(), "Tensor lengths mismatch");
for (i, (v1, v2)) in a.iter().zip(b.iter()).enumerate() {
assert!(
(v1 - v2).abs() < tol,
"Mismatch at index {}: {} vs {} (tol {})",
i,
v1,
v2,
tol
);
}
}
#[test]
fn test_packing_cycle_dense() -> Result<()> {
let input_data = vec![1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0];
let tensor = Tensor::new(&input_data[..], &Device::Cpu)?;
let packed = PackedTensor::pack(&tensor)?;
assert!((packed.scale - 1.0).abs() < 1e-3);
let unpacked = packed.unpack(&Device::Cpu)?;
let output_data = unpacked.to_vec1::<f32>()?;
assert_tensor_approx_eq(&input_data, &output_data, 1e-4);
Ok(())
}
    // Ignored: the expected scale of 0.5 assumes mean-abs scaling, while
    // `pack` derives its scale from the max-abs value (1.0 for this input).
    #[test]
    #[ignore]
    fn test_packing_cycle_sparse() -> Result<()> {
let input_data: Vec<f32> = vec![1.0, -1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0];
let tensor = Tensor::new(&input_data[..], &Device::Cpu)?;
let packed = PackedTensor::pack(&tensor)?;
assert!((packed.scale - 0.5).abs() < 1e-3);
let unpacked = packed.unpack(&Device::Cpu)?;
let output_data = unpacked.to_vec1::<f32>()?;
let expected_data = vec![0.5, -0.5, 0.0, 0.0, 0.5, -0.5, 0.0, 0.0];
assert_tensor_approx_eq(&output_data, &expected_data, 1e-4);
Ok(())
}
#[test]
fn test_packing_manual() -> Result<()> {
        // 73 = 0b01_00_10_01 -> codes [1, 2, 0, 1] LSB-first -> [+1, -1, 0, +1].
        let data = vec![73u8];
let shape = candle_core::Shape::from((4,));
let scale = 1.0;
let device = Device::Cpu;
let packed = PackedTensor::new(data, shape, scale, &device)?;
let unpacked = packed.unpack(&device)?;
let output = unpacked.to_vec1::<f32>()?;
assert_eq!(output[0], 1.0);
assert_eq!(output[1], -1.0);
assert_eq!(output[2], 0.0);
assert_eq!(output[3], 1.0);
Ok(())
}
#[test]
fn test_packing_padding() -> Result<()> {
let input_data = vec![1.0, 1.0, 1.0, 1.0, -1.0];
let tensor = Tensor::new(&input_data[..], &Device::Cpu)?;
let packed = PackedTensor::pack(&tensor)?;
assert_eq!(packed.data.dims1()?, 2);
let data = packed.data.to_vec1::<u8>()?;
        assert_eq!(data[0], 85); // 0b01_01_01_01: four +1 codes
        assert_eq!(data[1], 2); // 0b10 in slot 0: the trailing -1, rest zero padding
let unpacked = packed.unpack(&Device::Cpu)?;
let output_data = unpacked.to_vec1::<f32>()?;
assert_tensor_approx_eq(&input_data, &output_data, 1e-4);
Ok(())
}
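
    // A minimal multibase round-trip sketch, assuming the base-major layout
    // that `unpack` reads (`num_elem.div_ceil(4)` bytes per base, combined as
    // a scale-weighted sum). The bytes and scales here are illustrative
    // values, not taken from any real checkpoint.
    #[test]
    fn test_unpack_multibase() -> Result<()> {
        // Base 0 codes [+1, -1, 0, 0] -> 0b00_00_10_01 = 9 (LSB-first).
        // Base 1 codes [0, +1, +1, -1] -> 0b10_01_01_00 = 148.
        let data = vec![9u8, 148u8];
        let shape = candle_core::Shape::from((4,));
        let scales = Tensor::new(&[1.0f32, 0.5], &Device::Cpu)?;
        let packed = PackedTensor::new_adaptive(data, shape, 1.0, scales, &Device::Cpu)?;
        assert!(packed.is_multibase());
        let unpacked = packed.unpack(&Device::Cpu)?;
        let output = unpacked.to_vec1::<f32>()?;
        // base0 * 1.0 + base1 * 0.5
        let expected = vec![1.0, -0.5, 0.5, -0.5];
        assert_tensor_approx_eq(&output, &expected, 1e-6);
        Ok(())
    }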
}