use oxillama_gguf::GgufTensorType;
/// A tensor held as a raw byte buffer, paired with its shape and GGUF tensor type.
///
/// The type itself enforces no relationship between `data.len()` and the
/// shape; compare against [`QuantTensor::expected_data_size`] when the buffer
/// length matters.
#[derive(Debug, Clone)]
pub struct QuantTensor {
    /// Raw bytes backing the tensor.
    pub data: Vec<u8>,
    /// Size of each dimension; see [`QuantTensor::n_elements`] for how these
    /// are combined into an element count.
    pub shape: Vec<usize>,
    /// GGUF tensor type, which supplies the block size and bytes-per-block
    /// used by the size computations.
    pub tensor_type: GgufTensorType,
}
impl QuantTensor {
    /// Creates a tensor from its raw bytes, shape and tensor type.
    ///
    /// The arguments are stored exactly as given; `data.len()` is not checked
    /// against [`expected_data_size`](Self::expected_data_size).
    pub fn new(data: Vec<u8>, shape: Vec<usize>, tensor_type: GgufTensorType) -> Self {
        Self {
            data,
            shape,
            tensor_type,
        }
    }

    /// Returns the total number of elements, i.e. the product of all dimensions.
    ///
    /// An empty shape yields `0`. The product saturates at `usize::MAX`
    /// instead of overflowing, so an absurdly large shape can never wrap
    /// around to a small, plausible-looking element count.
    pub fn n_elements(&self) -> usize {
        if self.shape.is_empty() {
            return 0;
        }
        self.shape
            .iter()
            .fold(1usize, |acc, &dim| acc.saturating_mul(dim))
    }

    /// Returns the number of blocks needed to hold every element, rounding up
    /// so a partially filled trailing block is counted.
    ///
    /// Returns `0` when the tensor type reports a block size of zero, which
    /// also avoids a division by zero.
    pub fn n_blocks(&self) -> usize {
        match self.tensor_type.block_size() {
            0 => 0,
            block_size => self.n_elements().div_ceil(block_size),
        }
    }

    /// Returns the byte length implied by the shape and tensor type:
    /// [`n_blocks`](Self::n_blocks) times the type's bytes per block.
    ///
    /// The multiplication saturates at `usize::MAX` rather than overflowing,
    /// so comparing the result with `data.len()` fails for oversized shapes
    /// instead of spuriously matching a wrapped value.
    pub fn expected_data_size(&self) -> usize {
        self.n_blocks()
            .saturating_mul(self.tensor_type.block_bytes())
    }
}
/// Block layout figures for a tensor type, as produced by [`BlockInfo::for_type`].
///
/// Derives `PartialEq` so two descriptors can be compared directly, for
/// example with `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BlockInfo {
    /// Number of elements stored in one block.
    pub block_size: usize,
    /// Number of bytes occupied by one block.
    pub block_bytes: usize,
    /// Storage bits per element: `block_bytes * 8 / block_size`, or `0.0`
    /// when `block_size` is zero.
    pub bits_per_weight: f32,
}
impl BlockInfo {
    /// Builds the block layout description for `tensor_type`.
    ///
    /// `bits_per_weight` is the block's byte count times eight, divided by the
    /// number of elements per block. A block size of zero produces `0.0`
    /// rather than dividing by zero.
    pub fn for_type(tensor_type: GgufTensorType) -> Self {
        let elems_per_block = tensor_type.block_size();
        let bytes_per_block = tensor_type.block_bytes();

        let bits_per_weight = match elems_per_block {
            0 => 0.0,
            n => (bytes_per_block as f32 * 8.0) / n as f32,
        };

        Self {
            block_size: elems_per_block,
            block_bytes: bytes_per_block,
            bits_per_weight,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxillama_gguf::GgufTensorType;

    // The element count of a 2-D tensor is the product of its dimensions.
    #[test]
    fn test_quant_tensor_n_elements_2d() {
        let tensor = QuantTensor::new(vec![0u8; 32], vec![4, 8], GgufTensorType::Q8_0);
        assert_eq!(tensor.n_elements(), 32);
    }

    // An empty shape is reported as zero elements.
    #[test]
    fn test_quant_tensor_n_elements_empty_shape() {
        let tensor = QuantTensor::new(vec![], vec![], GgufTensorType::F32);
        assert_eq!(tensor.n_elements(), 0);
    }

    // 64 elements of Q4_0 are expected to occupy exactly two blocks.
    #[test]
    fn test_quant_tensor_n_blocks_q4_0() {
        let two_blocks_of_bytes = GgufTensorType::Q4_0.block_bytes() * 2;
        let tensor = QuantTensor::new(
            vec![0u8; two_blocks_of_bytes],
            vec![64],
            GgufTensorType::Q4_0,
        );
        assert_eq!(tensor.n_blocks(), 2);
    }

    // Five F32 elements are expected to need 20 bytes.
    #[test]
    fn test_quant_tensor_expected_data_size_f32() {
        let tensor = QuantTensor::new(vec![0u8; 20], vec![5], GgufTensorType::F32);
        assert_eq!(tensor.expected_data_size(), 20);
    }

    // Q8_0 layout: 32 elements per block, 34 bytes per block.
    #[test]
    fn test_block_info_for_q8_0() {
        let layout = BlockInfo::for_type(GgufTensorType::Q8_0);
        assert_eq!(layout.block_size, 32);
        assert_eq!(layout.block_bytes, 34);
        assert!(layout.bits_per_weight > 0.0);
    }

    // Q4_0 bits-per-weight should match 18 bytes spread over 32 elements.
    #[test]
    fn test_block_info_bits_per_weight_q4_0() {
        let layout = BlockInfo::for_type(GgufTensorType::Q4_0);
        let expected = (18.0f32 * 8.0) / 32.0;
        let delta = (layout.bits_per_weight - expected).abs();
        assert!(
            delta < 0.01,
            "bits_per_weight: {} vs expected {}",
            layout.bits_per_weight,
            expected
        );
    }

    // Cloning yields equal data and shape.
    #[test]
    fn test_quant_tensor_clone() {
        let original = QuantTensor::new(vec![1u8, 2, 3, 4], vec![2, 2], GgufTensorType::F32);
        let copy = original.clone();
        assert_eq!(copy.data, original.data);
        assert_eq!(copy.shape, original.shape);
    }
}