entrenar/quant/
quant4bit.rs1use serde::{Deserialize, Serialize};
14
/// Number of consecutive values that share one `f32` scale factor.
///
/// Must be even: [`quantize_4bit`] packs nibble pairs block by block, while
/// [`dequantize_4bit`] locates nibbles by their global index. The two layouts
/// only agree when every block starts on a byte boundary.
pub const BLOCK_SIZE: usize = 64;

// Compile-time guard for the evenness invariant documented on `BLOCK_SIZE`.
const _: () = assert!(
    BLOCK_SIZE.is_multiple_of(2),
    "BLOCK_SIZE must be even so blocks start on a byte boundary"
);
17
18#[derive(Clone, Debug, Serialize, Deserialize)]
24pub struct Quantized4Bit {
25 pub scales: Vec<f32>,
27 pub data: Vec<u8>,
29 pub len: usize,
31}
32
33impl Quantized4Bit {
34 pub fn memory_bytes(&self) -> usize {
36 self.scales.len() * 4 + self.data.len()
37 }
38
39 #[provable_contracts_macros::contract("quantization-v1", equation = "compression_ratio")]
41 pub fn compression_ratio(&self) -> f32 {
42 let original_bytes = self.len * 4; let compressed_bytes = self.memory_bytes();
44 original_bytes as f32 / compressed_bytes as f32
45 }
46}
47
48pub fn quantize_4bit(values: &[f32]) -> Quantized4Bit {
56 let len = values.len();
57 let num_blocks = len.div_ceil(BLOCK_SIZE);
58
59 let mut scales = Vec::with_capacity(num_blocks);
60 let mut data = Vec::with_capacity(len.div_ceil(2)); for block_idx in 0..num_blocks {
63 let start = block_idx * BLOCK_SIZE;
64 let end = (start + BLOCK_SIZE).min(len);
65 let block = &values[start..end];
66
67 let max_abs = block.iter().map(|v| v.abs()).max_by(f32::total_cmp).unwrap_or(1e-8);
70
71 let scale = max_abs / 7.0;
72 scales.push(scale);
73
74 for (i, &val) in block.iter().enumerate() {
76 let quantized = quantize_value(val, scale);
77
78 if i.is_multiple_of(2) {
80 data.push(((quantized as u8) & 0x0F) << 4);
82 } else {
83 let last_idx = data.len() - 1;
85 data[last_idx] |= (quantized as u8) & 0x0F;
86 }
87 }
88
89 }
92
93 Quantized4Bit { scales, data, len }
94}
95
96pub fn dequantize_4bit(quantized: &Quantized4Bit) -> Vec<f32> {
104 let mut result = Vec::with_capacity(quantized.len);
105
106 let num_blocks = quantized.scales.len();
107
108 for block_idx in 0..num_blocks {
109 let scale = quantized.scales[block_idx];
110 let start = block_idx * BLOCK_SIZE;
111 let end = (start + BLOCK_SIZE).min(quantized.len);
112 let block_len = end - start;
113
114 for i in 0..block_len {
115 let byte_idx = usize::midpoint(start, i);
116 let byte = quantized.data[byte_idx];
117
118 let q_val = if (start + i).is_multiple_of(2) {
120 let nibble = (byte >> 4) & 0x0F;
122 if nibble & 0x08 != 0 {
124 (nibble | 0xF0) as i8
125 } else {
126 nibble as i8
127 }
128 } else {
129 let nibble = byte & 0x0F;
131 if nibble & 0x08 != 0 {
133 (nibble | 0xF0) as i8
134 } else {
135 nibble as i8
136 }
137 };
138
139 let deq_val = f32::from(q_val) * scale;
141 result.push(deq_val);
142 }
143 }
144
145 result
146}
147
/// Maps `val` onto the signed 4-bit grid `-7..=7` defined by `scale`.
///
/// Computes `round(clamp(val / scale, -7, 7))`, rounding half away from zero.
/// A NaN quotient (for example `0.0 / 0.0` in an all-zero block) becomes 0 via
/// the saturating float-to-int `as` cast.
fn quantize_value(val: f32, scale: f32) -> i8 {
    let code = (val / scale).clamp(-7.0, 7.0).round();
    code as i8
}
156
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_quantize_dequantize_round_trip() {
        let values = vec![1.0, -2.0, 3.5, -4.2, 0.5, -0.8, 2.1, -1.5];
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        assert_eq!(dequantized.len(), values.len());

        for (original, deq) in values.iter().zip(dequantized.iter()) {
            let error = (original - deq).abs();
            let relative_error = error / original.abs().max(1e-6);
            assert!(
                relative_error < 0.3,
                "Relative error too large: {original} vs {deq} (error: {error}, rel_error: {relative_error})"
            );
        }
    }

    #[test]
    fn test_quantize_zeros() {
        let values = vec![0.0; 64];
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        for val in dequantized {
            assert_abs_diff_eq!(val, 0.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_quantize_uniform() {
        let values = vec![1.0; 64];
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        for val in dequantized {
            assert_abs_diff_eq!(val, 1.0, epsilon = 0.2);
        }
    }

    #[test]
    fn test_quantize_range() {
        let values: Vec<f32> = (-7..=7).map(|x| x as f32).collect();
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        for (original, deq) in values.iter().zip(dequantized.iter()) {
            assert_abs_diff_eq!(original, deq, epsilon = 0.5);
        }
    }

    #[test]
    fn test_quantize_multiple_blocks() {
        let values: Vec<f32> = (0..200).map(|i| (i as f32 * 0.1).sin()).collect();
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        assert_eq!(dequantized.len(), values.len());

        let expected_blocks = 200_usize.div_ceil(BLOCK_SIZE);
        assert_eq!(quantized.scales.len(), expected_blocks);
    }

    #[test]
    fn test_memory_savings() {
        let values = vec![1.0; 1024];
        let quantized = quantize_4bit(&values);

        let original_bytes = values.len() * 4;
        let compressed_bytes = quantized.memory_bytes();

        let compression = original_bytes as f32 / compressed_bytes as f32;
        assert!(compression > 6.0, "Compression ratio {compression} should be > 6.0");
    }

    #[test]
    fn test_compression_ratio() {
        let values = vec![1.5; 1024];
        let quantized = quantize_4bit(&values);

        let ratio = quantized.compression_ratio();
        assert!(ratio > 6.0, "Compression ratio {ratio} should be > 6.0");
    }

    #[test]
    fn test_quantize_small_values() {
        let values = vec![0.001, 0.002, 0.003, 0.004];
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        for (original, deq) in values.iter().zip(dequantized.iter()) {
            let error = (original - deq).abs();
            assert!(error < 0.001, "Error {error} too large for small value");
        }
    }

    #[test]
    fn test_quantize_mixed_magnitudes() {
        let values = vec![10.0, 1.0, -5.0, 0.5, 7.5, -2.0];
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        assert_eq!(dequantized.len(), values.len());

        for (original, deq) in values.iter().zip(dequantized.iter()) {
            let error = (original - deq).abs();
            if original.abs() < 1.0 {
                assert!(
                    error < 1.5,
                    "Absolute error {error} too large for small value {original} vs {deq}"
                );
            } else {
                let relative_error = error / original.abs();
                assert!(
                    relative_error < 0.5,
                    "Relative error {relative_error} too large for {original} vs {deq} (error: {error})"
                );
            }
        }
    }

    #[test]
    fn test_quantize_odd_length() {
        let values: Vec<f32> = (0..77).map(|i| i as f32 * 0.5).collect();
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        assert_eq!(dequantized.len(), 77);
    }

    #[test]
    fn test_quantized_data_size() {
        let values = vec![1.0; 128];
        let quantized = quantize_4bit(&values);

        assert_eq!(quantized.data.len(), 64);

        assert_eq!(quantized.scales.len(), 2);
    }

    #[test]
    fn test_quantize_empty() {
        let quantized = quantize_4bit(&[]);

        assert!(quantized.scales.is_empty());
        assert!(quantized.data.is_empty());
        assert_eq!(quantized.len, 0);
        assert!(dequantize_4bit(&quantized).is_empty());
    }

    #[test]
    fn test_packed_nibble_layout() {
        // max |x| = 7 gives scale 1.0, so every value is stored as its own code.
        let quantized = quantize_4bit(&[7.0, -7.0, 3.0]);

        assert_eq!(quantized.scales, vec![1.0]);
        // 7 -> 0x7 (high nibble), -7 -> 0x9 (low nibble, two's complement),
        // 3 -> 0x3 (high nibble) with the trailing low nibble zero-padded.
        assert_eq!(quantized.data, vec![0x79, 0x30]);
        assert_eq!(dequantize_4bit(&quantized), vec![7.0, -7.0, 3.0]);
    }
}