Skip to main content

oxillama_quant/
types.rs

//! Quantization data types and tensor wrapper.
2
3use oxillama_gguf::GgufTensorType;
4
5/// A quantized tensor — raw block data plus shape and type metadata.
6///
7/// The data is stored as raw bytes in the GGUF block format.
8/// Use a [`QuantKernel`](crate::traits::QuantKernel) to dequantize or
9/// perform fused operations.
10#[derive(Debug, Clone)]
11pub struct QuantTensor {
12    /// Raw block data (packed quantized weights).
13    pub data: Vec<u8>,
14    /// Tensor shape (e.g., [out_features, in_features] for a linear layer).
15    pub shape: Vec<usize>,
16    /// The GGUF quantization type.
17    pub tensor_type: GgufTensorType,
18}
19
20impl QuantTensor {
21    /// Create a new quantized tensor.
22    pub fn new(data: Vec<u8>, shape: Vec<usize>, tensor_type: GgufTensorType) -> Self {
23        Self {
24            data,
25            shape,
26            tensor_type,
27        }
28    }
29
30    /// Total number of elements (product of all dimensions).
31    pub fn n_elements(&self) -> usize {
32        if self.shape.is_empty() {
33            return 0;
34        }
35        self.shape.iter().product()
36    }
37
38    /// Number of quantized blocks in this tensor.
39    pub fn n_blocks(&self) -> usize {
40        let block_size = self.tensor_type.block_size();
41        if block_size == 0 {
42            return 0;
43        }
44        self.n_elements().div_ceil(block_size)
45    }
46
47    /// Expected total data size in bytes.
48    pub fn expected_data_size(&self) -> usize {
49        self.n_blocks() * self.tensor_type.block_bytes()
50    }
51}
52
53/// Information about a quantization block format.
54#[derive(Debug, Clone, Copy)]
55pub struct BlockInfo {
56    /// Number of weights per block.
57    pub block_size: usize,
58    /// Number of bytes per block.
59    pub block_bytes: usize,
60    /// Effective bits per weight.
61    pub bits_per_weight: f32,
62}
63
64impl BlockInfo {
65    /// Compute block info for a given GGUF tensor type.
66    pub fn for_type(tensor_type: GgufTensorType) -> Self {
67        let block_size = tensor_type.block_size();
68        let block_bytes = tensor_type.block_bytes();
69        let bits_per_weight = if block_size > 0 {
70            (block_bytes as f32 * 8.0) / block_size as f32
71        } else {
72            0.0
73        };
74        Self {
75            block_size,
76            block_bytes,
77            bits_per_weight,
78        }
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use oxillama_gguf::GgufTensorType;

    #[test]
    fn test_quant_tensor_n_elements_2d() {
        let t = QuantTensor::new(vec![0u8; 32], vec![4, 8], GgufTensorType::Q8_0);
        assert_eq!(t.n_elements(), 32);
    }

    #[test]
    fn test_quant_tensor_n_elements_empty_shape() {
        let t = QuantTensor::new(vec![], vec![], GgufTensorType::F32);
        assert_eq!(t.n_elements(), 0);
    }

    #[test]
    fn test_quant_tensor_n_elements_zero_dim() {
        // Any zero-sized dimension means the tensor holds no elements or blocks.
        let t = QuantTensor::new(vec![], vec![4, 0], GgufTensorType::Q8_0);
        assert_eq!(t.n_elements(), 0);
        assert_eq!(t.n_blocks(), 0);
        assert_eq!(t.expected_data_size(), 0);
    }

    #[test]
    fn test_quant_tensor_n_blocks_q4_0() {
        // Q4_0: 32 weights per block, so 64 elements → 2 blocks.
        let data_len = GgufTensorType::Q4_0.block_bytes() * 2;
        let t = QuantTensor::new(vec![0u8; data_len], vec![64], GgufTensorType::Q4_0);
        assert_eq!(t.n_blocks(), 2);
        assert_eq!(t.expected_data_size(), data_len);
    }

    #[test]
    fn test_quant_tensor_n_blocks_rounds_up_partial_block() {
        // Q8_0: 32 weights / 34 bytes per block; 33 elements need a second, partial block.
        let t = QuantTensor::new(vec![], vec![33], GgufTensorType::Q8_0);
        assert_eq!(t.n_blocks(), 2);
        assert_eq!(t.expected_data_size(), 2 * 34);
    }

    #[test]
    fn test_quant_tensor_expected_data_size_f32() {
        // F32: 1 weight per block, 4 bytes per block
        let t = QuantTensor::new(vec![0u8; 20], vec![5], GgufTensorType::F32);
        assert_eq!(t.expected_data_size(), 20); // 5 * 4
    }

    #[test]
    fn test_block_info_for_q8_0() {
        let info = BlockInfo::for_type(GgufTensorType::Q8_0);
        assert_eq!(info.block_size, 32);
        assert_eq!(info.block_bytes, 34); // 2 (scale) + 32 (quants)
        // (34 * 8) / 32 = 8.5, exactly representable in f32.
        let expected = (34.0f32 * 8.0) / 32.0;
        assert!(
            (info.bits_per_weight - expected).abs() < 0.01,
            "bits_per_weight: {} vs expected {}",
            info.bits_per_weight,
            expected
        );
    }

    #[test]
    fn test_block_info_bits_per_weight_q4_0() {
        let info = BlockInfo::for_type(GgufTensorType::Q4_0);
        // Q4_0: 18 bytes per 32 weights → (18*8)/32 = 4.5
        let expected = (18.0f32 * 8.0) / 32.0;
        assert!(
            (info.bits_per_weight - expected).abs() < 0.01,
            "bits_per_weight: {} vs expected {}",
            info.bits_per_weight,
            expected
        );
    }

    #[test]
    fn test_quant_tensor_clone() {
        let t = QuantTensor::new(vec![1u8, 2, 3, 4], vec![2, 2], GgufTensorType::F32);
        let t2 = t.clone();
        assert_eq!(t2.data, t.data);
        assert_eq!(t2.shape, t.shape);
    }
}