use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::tensor::Tensor;
/// A 2-D tensor stored in 2:4 structured-sparse form.
///
/// The dense `[m, k]` matrix is represented by two tensors: the retained
/// values and a packed per-group bit mask. The layout invariants listed on the
/// fields are the ones checked by `Sparse24Tensor::new`; because the fields
/// are `pub(crate)`, crate-internal code that fills them in directly is
/// responsible for upholding the same invariants.
#[derive(Debug, Clone)]
pub struct Sparse24Tensor<R: Runtime> {
/// Retained values; `new` requires shape `[m, k / 2]`.
pub(crate) compressed_values: Tensor<R>,
/// Packed group masks; `new` requires dtype `U32` and shape
/// `[m, ceil((k / 4) / 8)]`, i.e. eight 4-bit nibbles per word.
/// NOTE(review): each nibble presumably marks which two of the four dense
/// elements in a group were kept — confirm against the compression kernel.
pub(crate) metadata: Tensor<R>,
/// Shape `[m, k]` of the dense matrix this tensor stands for.
pub(crate) original_shape: [usize; 2],
/// Element dtype; `new` copies it from `compressed_values`.
pub(crate) dtype: DType,
}
impl<R: Runtime<DType = DType>> Sparse24Tensor<R> {
    /// Builds a 2:4 sparse tensor from already-compressed parts, validating
    /// their layout against the dense shape `original_shape = [m, k]`.
    ///
    /// Requirements, checked in this order:
    /// 1. `k` is a multiple of 4.
    /// 2. `compressed_values` has shape `[m, k / 2]`.
    /// 3. `metadata` has shape `[m, ceil((k / 4) / 8)]`.
    /// 4. `metadata` has dtype `U32`.
    ///
    /// The element dtype of the result is taken from `compressed_values`.
    /// The bit patterns stored in `metadata` are not inspected here; use
    /// [`is_valid`](Self::is_valid) for that.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidArgument`] if `k` is not divisible by 4.
    /// - [`Error::ShapeMismatch`] if either tensor has an unexpected shape.
    /// - [`Error::DTypeMismatch`] if `metadata` is not `U32`.
    pub fn new(
        compressed_values: Tensor<R>,
        metadata: Tensor<R>,
        original_shape: [usize; 2],
    ) -> Result<Self> {
        let [m, k] = original_shape;

        if k % 4 != 0 {
            return Err(Error::InvalidArgument {
                arg: "original_shape",
                reason: format!("K dimension ({k}) must be divisible by 4 for 2:4 sparsity"),
            });
        }

        // Two values are stored for every group of four dense columns.
        let expected_val_shape = [m, k / 2];
        if compressed_values.shape() != expected_val_shape {
            return Err(Error::ShapeMismatch {
                expected: expected_val_shape.to_vec(),
                got: compressed_values.shape().to_vec(),
            });
        }

        // One 4-bit nibble per group, eight nibbles per `u32` word, rounded up.
        let num_groups = k / 4;
        let meta_cols = (num_groups + 7) / 8;
        let expected_meta_shape = [m, meta_cols];
        if metadata.shape() != expected_meta_shape {
            return Err(Error::ShapeMismatch {
                expected: expected_meta_shape.to_vec(),
                got: metadata.shape().to_vec(),
            });
        }

        if metadata.dtype() != DType::U32 {
            return Err(Error::DTypeMismatch {
                lhs: DType::U32,
                rhs: metadata.dtype(),
            });
        }

        let dtype = compressed_values.dtype();
        Ok(Self {
            compressed_values,
            metadata,
            original_shape,
            dtype,
        })
    }

    /// Shape `[m, k]` of the dense matrix this tensor represents.
    #[inline]
    pub fn shape(&self) -> [usize; 2] {
        self.original_shape
    }

    /// Number of rows (`m`) of the dense matrix.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.original_shape[0]
    }

    /// Number of columns (`k`) of the dense matrix.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.original_shape[1]
    }

    /// Element dtype recorded for this tensor.
    #[inline]
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Borrows the tensor holding the retained values.
    #[inline]
    pub fn compressed_values(&self) -> &Tensor<R> {
        &self.compressed_values
    }

    /// Borrows the tensor holding the packed group masks.
    #[inline]
    pub fn metadata(&self) -> &Tensor<R> {
        &self.metadata
    }

    /// Number of stored values: `nrows * (ncols / 2)`.
    ///
    /// This counts storage slots; it does not look at whether a stored value
    /// happens to be zero.
    #[inline]
    pub fn nnz(&self) -> usize {
        self.original_shape[0] * (self.original_shape[1] / 2)
    }

    /// Returns the constant `2.0`: dense element count over stored value
    /// count. Storage used by the metadata is not accounted for.
    #[inline]
    pub fn compression_ratio(&self) -> f64 {
        2.0
    }

    /// Number of 4-column groups in one row (`ncols / 4`).
    #[inline]
    pub fn groups_per_row(&self) -> usize {
        self.original_shape[1] / 4
    }

    /// Number of `u32` metadata words per row: `groups_per_row` divided by 8,
    /// rounded up.
    #[inline]
    pub fn meta_cols(&self) -> usize {
        (self.groups_per_row() + 7) / 8
    }

    /// Checks that every group's metadata nibble has exactly two bits set.
    ///
    /// The metadata is materialised as a `Vec<u32>` via `to_vec` and read as
    /// `nrows` consecutive rows of [`meta_cols`](Self::meta_cols) words. Group
    /// `g` of a row occupies bits `4 * (g % 8) .. 4 * (g % 8) + 4` of word
    /// `g / 8`. Unused nibbles in the last word of a row are ignored.
    ///
    /// Returns `false` as soon as a nibble with a bit count other than two is
    /// found, or if the materialised buffer is shorter than the shape implies
    /// (such a buffer cannot be a valid encoding). A tensor with no groups is
    /// trivially valid.
    pub fn is_valid(&self) -> bool {
        const GROUPS_PER_WORD: usize = 8;
        const BITS_PER_GROUP: usize = 4;

        let words: Vec<u32> = self.metadata.to_vec();
        let groups = self.groups_per_row();
        // Loop-invariant row stride, computed once rather than per group.
        let words_per_row = self.meta_cols();

        (0..self.nrows()).all(|row| {
            let row_base = row * words_per_row;
            (0..groups).all(|g| match words.get(row_base + g / GROUPS_PER_WORD) {
                Some(&word) => {
                    let mask = (word >> ((g % GROUPS_PER_WORD) * BITS_PER_GROUP)) & 0xF;
                    mask.count_ones() == 2
                }
                // Truncated metadata: report invalid instead of panicking.
                None => false,
            })
        })
    }
}
/// Number of `u32` metadata words needed per row for a dense width of `k`.
///
/// `k` is split into groups of four columns (any remainder is dropped by the
/// integer division), and eight groups are packed into each word, rounding up.
#[inline]
pub fn meta_cols_for_k(k: usize) -> usize {
    const GROUP_WIDTH: usize = 4;
    const GROUPS_PER_WORD: usize = 8;

    let groups = k / GROUP_WIDTH;
    // Ceiling division: full words, plus one more if any groups are left over.
    let full_words = groups / GROUPS_PER_WORD;
    let has_partial_word = groups % GROUPS_PER_WORD != 0;
    full_words + usize::from(has_partial_word)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One metadata word covers up to eight groups, i.e. 32 dense columns.
    #[test]
    fn test_meta_cols_for_k() {
        // 1, 2 and 8 groups all fit in a single word.
        assert_eq!(meta_cols_for_k(4), 1);
        assert_eq!(meta_cols_for_k(8), 1);
        assert_eq!(meta_cols_for_k(32), 1);
        // 9 and 16 groups spill into a second word.
        assert_eq!(meta_cols_for_k(36), 2);
        assert_eq!(meta_cols_for_k(64), 2);
    }

    /// A dense width that is not a multiple of 4 must be rejected.
    #[test]
    fn test_k_must_be_divisible_by_4() {
        use crate::runtime::cpu::{CpuDevice, CpuRuntime};

        let device = CpuDevice::new();
        let vals = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[1, 2], &device);
        let meta = Tensor::<CpuRuntime>::from_slice(&[0u32], &[1, 1], &device);

        // k = 5 fails the divisibility check before any shape is examined.
        assert!(Sparse24Tensor::new(vals, meta, [1, 5]).is_err());
    }
}