1use oxillama_gguf::GgufTensorType;
4
5#[derive(Debug, Clone)]
11pub struct QuantTensor {
12 pub data: Vec<u8>,
14 pub shape: Vec<usize>,
16 pub tensor_type: GgufTensorType,
18}
19
20impl QuantTensor {
21 pub fn new(data: Vec<u8>, shape: Vec<usize>, tensor_type: GgufTensorType) -> Self {
23 Self {
24 data,
25 shape,
26 tensor_type,
27 }
28 }
29
30 pub fn n_elements(&self) -> usize {
32 if self.shape.is_empty() {
33 return 0;
34 }
35 self.shape.iter().product()
36 }
37
38 pub fn n_blocks(&self) -> usize {
40 let block_size = self.tensor_type.block_size();
41 if block_size == 0 {
42 return 0;
43 }
44 self.n_elements().div_ceil(block_size)
45 }
46
47 pub fn expected_data_size(&self) -> usize {
49 self.n_blocks() * self.tensor_type.block_bytes()
50 }
51}
52
/// Block layout figures for a single tensor type.
///
/// `BlockInfo::for_type` fills this in from a `GgufTensorType`.
#[derive(Clone, Copy, Debug)]
pub struct BlockInfo {
    /// Number of elements grouped into one block.
    pub block_size: usize,
    /// Number of bytes one block occupies.
    pub block_bytes: usize,
    /// Bits per element: `block_bytes * 8 / block_size`, or `0.0` when
    /// `block_size` is zero (as computed by `BlockInfo::for_type`).
    pub bits_per_weight: f32,
}
63
64impl BlockInfo {
65 pub fn for_type(tensor_type: GgufTensorType) -> Self {
67 let block_size = tensor_type.block_size();
68 let block_bytes = tensor_type.block_bytes();
69 let bits_per_weight = if block_size > 0 {
70 (block_bytes as f32 * 8.0) / block_size as f32
71 } else {
72 0.0
73 };
74 Self {
75 block_size,
76 block_bytes,
77 bits_per_weight,
78 }
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use oxillama_gguf::GgufTensorType;

    // A 4x8 shape multiplies out to 32 elements.
    #[test]
    fn test_quant_tensor_n_elements_2d() {
        let tensor = QuantTensor::new(vec![0u8; 32], vec![4, 8], GgufTensorType::Q8_0);
        assert_eq!(tensor.n_elements(), 32);
    }

    // An empty shape is reported as zero elements.
    #[test]
    fn test_quant_tensor_n_elements_empty_shape() {
        let tensor = QuantTensor::new(Vec::new(), Vec::new(), GgufTensorType::F32);
        assert_eq!(tensor.n_elements(), 0);
    }

    // 64 elements of Q4_0 are expected to span two blocks.
    #[test]
    fn test_quant_tensor_n_blocks_q4_0() {
        let two_blocks_len = 2 * GgufTensorType::Q4_0.block_bytes();
        let tensor = QuantTensor::new(vec![0u8; two_blocks_len], vec![64], GgufTensorType::Q4_0);
        assert_eq!(tensor.n_blocks(), 2);
    }

    // Five F32 elements are expected to need 20 bytes.
    #[test]
    fn test_quant_tensor_expected_data_size_f32() {
        let tensor = QuantTensor::new(vec![0u8; 20], vec![5], GgufTensorType::F32);
        assert_eq!(tensor.expected_data_size(), 20);
    }

    // Q8_0 is expected to report 32 elements in 34 bytes per block.
    #[test]
    fn test_block_info_for_q8_0() {
        let BlockInfo {
            block_size,
            block_bytes,
            bits_per_weight,
        } = BlockInfo::for_type(GgufTensorType::Q8_0);
        assert_eq!(block_size, 32);
        assert_eq!(block_bytes, 34);
        assert!(bits_per_weight > 0.0);
    }

    // Q4_0 density should be 18 bytes * 8 bits spread over 32 elements.
    #[test]
    fn test_block_info_bits_per_weight_q4_0() {
        let info = BlockInfo::for_type(GgufTensorType::Q4_0);
        let expected = (18.0f32 * 8.0) / 32.0;
        let delta = (info.bits_per_weight - expected).abs();
        assert!(
            delta < 0.01,
            "bits_per_weight: {} vs expected {}",
            info.bits_per_weight,
            expected
        );
    }

    // Cloning keeps the byte buffer and the shape intact.
    #[test]
    fn test_quant_tensor_clone() {
        let original = QuantTensor::new(vec![1u8, 2, 3, 4], vec![2, 2], GgufTensorType::F32);
        let duplicate = original.clone();
        assert_eq!(duplicate.data, original.data);
        assert_eq!(duplicate.shape, original.shape);
    }
}