use crate::error::{Error, Result};
use crate::quant::QuantFormat;
use numr::runtime::Runtime;
use numr::tensor::Storage;
/// A tensor whose elements are held in a block-quantized encoding.
///
/// The encoded bytes live in an untyped backend buffer (`storage`); `format`
/// says how to interpret them and `shape` is the logical element shape.
pub struct QuantTensor<R>
where
    R: Runtime,
{
    /// Raw encoded bytes (created as a `DType::U8` buffer by `from_bytes`).
    storage: Storage<R>,
    /// Quantization scheme the bytes in `storage` are encoded with.
    format: QuantFormat,
    /// Logical (element) shape; its product is the element count.
    shape: Vec<usize>,
    /// Device that was passed to `Storage::from_bytes` when `storage` was created.
    device: R::Device,
}
impl<R: Runtime<DType = numr::dtype::DType>> QuantTensor<R> {
pub fn from_bytes(
data: &[u8],
format: QuantFormat,
shape: &[usize],
device: &R::Device,
) -> Result<Self> {
if shape.is_empty() {
return Err(Error::QuantError {
reason: "QuantTensor shape must be non-empty".into(),
});
}
let last_dim = shape[shape.len() - 1];
if last_dim % format.block_size() != 0 {
return Err(Error::QuantError {
reason: format!(
"last dimension {} is not a multiple of {}'s block_size {}",
last_dim,
format.name(),
format.block_size(),
),
});
}
let numel: usize = shape.iter().product();
let expected_bytes = format.storage_bytes(numel)?;
if data.len() != expected_bytes {
return Err(Error::QuantError {
reason: format!(
"expected {} bytes for {} with {} elements, got {} bytes",
expected_bytes,
format.name(),
numel,
data.len(),
),
});
}
let storage =
Storage::<R>::from_bytes(data, numr::dtype::DType::U8, device).map_err(Error::Numr)?;
Ok(Self {
storage,
format,
shape: shape.to_vec(),
device: device.clone(),
})
}
pub fn format(&self) -> QuantFormat {
self.format
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn numel(&self) -> usize {
self.shape.iter().product()
}
pub fn num_blocks(&self) -> usize {
self.numel() / self.format.block_size()
}
pub fn storage_bytes(&self) -> usize {
self.num_blocks() * self.format.block_bytes()
}
pub fn device(&self) -> &R::Device {
&self.device
}
pub fn storage(&self) -> &Storage<R> {
&self.storage
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};

    fn cpu_device() -> CpuDevice {
        CpuDevice::new()
    }

    /// True when `result` failed validation (`Error::QuantError`). Matching the
    /// variant keeps the error tests from passing on an unrelated failure, such
    /// as a storage-allocation error.
    fn is_quant_error<T>(result: &Result<T>) -> bool {
        matches!(result, Err(Error::QuantError { .. }))
    }

    #[test]
    fn test_create_q4_0() {
        let device = cpu_device();
        let data = vec![0u8; 18];
        let qt = QuantTensor::<CpuRuntime>::from_bytes(&data, QuantFormat::Q4_0, &[32], &device)
            .unwrap();
        assert_eq!(qt.format(), QuantFormat::Q4_0);
        assert_eq!(qt.shape(), &[32]);
        assert_eq!(qt.numel(), 32);
        assert_eq!(qt.num_blocks(), 1);
        assert_eq!(qt.storage_bytes(), 18);
    }

    #[test]
    fn test_create_q4k_matrix() {
        let device = cpu_device();
        let numel = 4096 * 4096;
        let num_blocks = numel / 256;
        let data = vec![0u8; num_blocks * 144];
        let qt =
            QuantTensor::<CpuRuntime>::from_bytes(&data, QuantFormat::Q4K, &[4096, 4096], &device)
                .unwrap();
        assert_eq!(qt.shape(), &[4096, 4096]);
        assert_eq!(qt.numel(), numel);
        assert_eq!(qt.num_blocks(), num_blocks);
        assert_eq!(qt.storage_bytes(), num_blocks * 144);
    }

    #[test]
    fn test_only_last_dim_must_align() {
        let device = cpu_device();
        // 33 rows, each holding exactly one Q4_0 block (32 elements, 18 bytes).
        let data = vec![0u8; 33 * 18];
        let qt =
            QuantTensor::<CpuRuntime>::from_bytes(&data, QuantFormat::Q4_0, &[33, 32], &device)
                .unwrap();
        assert_eq!(qt.numel(), 33 * 32);
        assert_eq!(qt.num_blocks(), 33);
        assert_eq!(qt.storage_bytes(), 33 * 18);
    }

    #[test]
    fn test_alignment_error() {
        let device = cpu_device();
        let data = vec![0u8; 18];
        let result =
            QuantTensor::<CpuRuntime>::from_bytes(&data, QuantFormat::Q4_0, &[33], &device);
        assert!(is_quant_error(&result), "expected QuantError for misaligned last dim");
    }

    #[test]
    fn test_size_mismatch_error() {
        let device = cpu_device();
        let data = vec![0u8; 10];
        let result =
            QuantTensor::<CpuRuntime>::from_bytes(&data, QuantFormat::Q4_0, &[32], &device);
        assert!(is_quant_error(&result), "expected QuantError for wrong byte count");
    }

    #[test]
    fn test_empty_shape_error() {
        let device = cpu_device();
        let data = vec![0u8; 18];
        let result = QuantTensor::<CpuRuntime>::from_bytes(&data, QuantFormat::Q4_0, &[], &device);
        assert!(is_quant_error(&result), "expected QuantError for empty shape");
    }

    #[test]
    fn test_multi_block() {
        let device = cpu_device();
        let data = vec![0u8; 4 * 34];
        let qt = QuantTensor::<CpuRuntime>::from_bytes(&data, QuantFormat::Q8_0, &[128], &device)
            .unwrap();
        assert_eq!(qt.numel(), 128);
        assert_eq!(qt.num_blocks(), 4);
        assert_eq!(qt.storage_bytes(), 136);
    }
}