entrenar/hf_pipeline/export/
gguf_writer.rs1use aprender::format::gguf::GgmlType;
8
9use crate::quant::{GGUF_BLOCK_SIZE, Q4_0, Q8_0};
10
/// Selects how tensor data is encoded when it is written out as GGUF bytes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GgufQuantization {
    /// Leave the values unquantized; they are emitted as raw `f32` data.
    None,
    /// Quantize with the `Q4_0` block format.
    Q4_0,
    /// Quantize with the `Q8_0` block format.
    Q8_0,
}
21
22pub fn quantize_to_gguf_bytes(data: &[f32], quant: GgufQuantization) -> (Vec<u8>, GgmlType) {
27 if data.is_empty() {
37 let dtype = match quant {
38 GgufQuantization::None => GgmlType::F32,
39 GgufQuantization::Q4_0 => GgmlType::Q4_0,
40 GgufQuantization::Q8_0 => GgmlType::Q8_0,
41 };
42 return (Vec::new(), dtype);
43 }
44 contract_pre_quantize!(data);
45 let result = match quant {
46 GgufQuantization::None => {
47 let bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
48 (bytes, GgmlType::F32)
49 }
50 GgufQuantization::Q4_0 => {
51 let quantized = Q4_0::quantize(data);
52 (encode_q4_0_blocks(&quantized), GgmlType::Q4_0)
53 }
54 GgufQuantization::Q8_0 => {
55 let quantized = Q8_0::quantize(data);
56 (encode_q8_0_blocks(&quantized), GgmlType::Q8_0)
57 }
58 };
59 contract_post_quantize_precision_bound!(&result);
60 result
61}
62
63fn encode_q4_0_blocks(q: &Q4_0) -> Vec<u8> {
66 let num_blocks = q.num_blocks();
67 let mut bytes = Vec::with_capacity(num_blocks * 18);
68
69 for block_idx in 0..num_blocks {
70 let scale_f16 = half::f16::from_f32(q.scales[block_idx]);
72 bytes.extend_from_slice(&scale_f16.to_le_bytes());
73
74 let data_start = block_idx * 16;
76 let data_end = (data_start + 16).min(q.data.len());
77 bytes.extend_from_slice(&q.data[data_start..data_end]);
78
79 let pad = 16 - (data_end - data_start);
81 bytes.extend(std::iter::repeat_n(0u8, pad));
82 }
83
84 bytes
85}
86
87fn encode_q8_0_blocks(q: &Q8_0) -> Vec<u8> {
90 let num_blocks = q.num_blocks();
91 let mut bytes = Vec::with_capacity(num_blocks * 34);
92
93 for block_idx in 0..num_blocks {
94 let scale_f16 = half::f16::from_f32(q.scales[block_idx]);
96 bytes.extend_from_slice(&scale_f16.to_le_bytes());
97
98 let data_start = block_idx * GGUF_BLOCK_SIZE;
100 let data_end = (data_start + GGUF_BLOCK_SIZE).min(q.data.len());
101 let block_data = &q.data[data_start..data_end];
102
103 for &val in block_data {
105 bytes.push(val as u8);
106 }
107
108 let pad = GGUF_BLOCK_SIZE - block_data.len();
110 bytes.extend(std::iter::repeat_n(0u8, pad));
111 }
112
113 bytes
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Encoded size of a single Q4_0 block, as asserted throughout these tests.
    const Q4_0_BLOCK_BYTES: usize = 18;
    /// Encoded size of a single Q8_0 block, as asserted throughout these tests.
    const Q8_0_BLOCK_BYTES: usize = 34;

    /// Builds `[0.0, step, 2.0 * step, ...]` with `len` entries.
    fn ramp(len: usize, step: f32) -> Vec<f32> {
        (0..len).map(|i| i as f32 * step).collect()
    }

    /// Runs `quantize_to_gguf_bytes`, checks the returned dtype tag, and hands back the bytes.
    fn encode_expecting(data: &[f32], quant: GgufQuantization, expected: GgmlType) -> Vec<u8> {
        let (encoded, tag) = quantize_to_gguf_bytes(data, quant);
        assert_eq!(tag, expected);
        encoded
    }

    /// Decodes the f16 scale stored in the first two bytes of an encoded buffer.
    fn leading_scale(encoded: &[u8]) -> f32 {
        half::f16::from_le_bytes([encoded[0], encoded[1]]).to_f32()
    }

    // ---- direct block-encoder checks ----

    #[test]
    fn test_encode_q4_0_block_size() {
        let quantized = Q4_0::quantize(&ramp(32, 0.1));
        assert_eq!(encode_q4_0_blocks(&quantized).len(), Q4_0_BLOCK_BYTES);
    }

    #[test]
    fn test_encode_q8_0_block_size() {
        let quantized = Q8_0::quantize(&ramp(32, 0.1));
        assert_eq!(encode_q8_0_blocks(&quantized).len(), Q8_0_BLOCK_BYTES);
    }

    // ---- entry point, one full block ----

    #[test]
    fn test_quantize_to_gguf_bytes_none() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let encoded = encode_expecting(&input, GgufQuantization::None, GgmlType::F32);
        assert_eq!(encoded.len(), 16);
        let first = f32::from_le_bytes(encoded[0..4].try_into().expect("conversion should succeed"));
        assert!((first - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_quantize_to_gguf_bytes_q4_0() {
        let encoded = encode_expecting(&ramp(32, 0.1), GgufQuantization::Q4_0, GgmlType::Q4_0);
        assert_eq!(encoded.len(), Q4_0_BLOCK_BYTES);
    }

    #[test]
    fn test_quantize_to_gguf_bytes_q8_0() {
        let encoded = encode_expecting(&ramp(32, 0.1), GgufQuantization::Q8_0, GgmlType::Q8_0);
        assert_eq!(encoded.len(), Q8_0_BLOCK_BYTES);
    }

    // ---- empty input must yield empty output ----

    #[test]
    fn test_falsify_quantize_empty_data_none() {
        let encoded = encode_expecting(&[], GgufQuantization::None, GgmlType::F32);
        assert!(encoded.is_empty());
    }

    #[test]
    fn test_falsify_quantize_empty_data_q4_0() {
        let encoded = encode_expecting(&[], GgufQuantization::Q4_0, GgmlType::Q4_0);
        assert!(
            encoded.is_empty(),
            "empty input must produce empty output, got {} bytes",
            encoded.len()
        );
    }

    #[test]
    fn test_falsify_quantize_empty_data_q8_0() {
        let encoded = encode_expecting(&[], GgufQuantization::Q8_0, GgmlType::Q8_0);
        assert!(
            encoded.is_empty(),
            "empty input must produce empty output, got {} bytes",
            encoded.len()
        );
    }

    // ---- partial blocks are padded to a whole block ----

    #[test]
    fn test_falsify_quantize_single_element_q4_0() {
        let encoded = encode_expecting(&[42.0], GgufQuantization::Q4_0, GgmlType::Q4_0);
        assert_eq!(encoded.len(), Q4_0_BLOCK_BYTES);
    }

    #[test]
    fn test_falsify_quantize_single_element_q8_0() {
        let encoded = encode_expecting(&[42.0], GgufQuantization::Q8_0, GgmlType::Q8_0);
        assert_eq!(encoded.len(), Q8_0_BLOCK_BYTES);
    }

    #[test]
    fn test_falsify_quantize_33_elements_q4_0() {
        let encoded = encode_expecting(&ramp(33, 0.1), GgufQuantization::Q4_0, GgmlType::Q4_0);
        assert_eq!(encoded.len(), 2 * Q4_0_BLOCK_BYTES);
    }

    #[test]
    fn test_falsify_quantize_33_elements_q8_0() {
        let encoded = encode_expecting(&ramp(33, 0.1), GgufQuantization::Q8_0, GgmlType::Q8_0);
        assert_eq!(encoded.len(), 2 * Q8_0_BLOCK_BYTES);
    }

    #[test]
    fn test_falsify_quantize_63_elements_q4_0() {
        let encoded = encode_expecting(&ramp(63, 0.01), GgufQuantization::Q4_0, GgmlType::Q4_0);
        assert_eq!(encoded.len(), 2 * Q4_0_BLOCK_BYTES);
    }

    // ---- degenerate value ranges ----

    #[test]
    fn test_falsify_quantize_all_zeros_q4_0() {
        let zeros = [0.0f32; 64];
        let encoded = encode_expecting(&zeros, GgufQuantization::Q4_0, GgmlType::Q4_0);
        assert_eq!(encoded.len(), 2 * Q4_0_BLOCK_BYTES);
        assert_eq!(leading_scale(&encoded), 0.0, "scale for zero data must be 0");
    }

    #[test]
    fn test_falsify_quantize_all_zeros_q8_0() {
        let zeros = [0.0f32; 32];
        let encoded = encode_expecting(&zeros, GgufQuantization::Q8_0, GgmlType::Q8_0);
        assert_eq!(encoded.len(), Q8_0_BLOCK_BYTES);
        assert_eq!(leading_scale(&encoded), 0.0, "scale for zero data must be 0");
    }

    #[test]
    fn test_falsify_quantize_extreme_range_q4_0() {
        let mut input = vec![0.0f32; 32];
        input[0] = 1e30;
        input[1] = -1e30;
        let encoded = encode_expecting(&input, GgufQuantization::Q4_0, GgmlType::Q4_0);
        assert_eq!(encoded.len(), Q4_0_BLOCK_BYTES);
        assert!(
            !leading_scale(&encoded).is_nan(),
            "scale must not be NaN for extreme values"
        );
    }

    #[test]
    fn test_falsify_quantize_extreme_range_q8_0() {
        let mut input = vec![0.0f32; 32];
        input[0] = 1e30;
        input[1] = -1e30;
        let encoded = encode_expecting(&input, GgufQuantization::Q8_0, GgmlType::Q8_0);
        assert_eq!(encoded.len(), Q8_0_BLOCK_BYTES);
        assert!(
            !leading_scale(&encoded).is_nan(),
            "scale must not be NaN for extreme values"
        );
    }

    // ---- unquantized layout ----

    #[test]
    fn test_falsify_quantize_f32_exact_byte_layout() {
        let input = [std::f32::consts::PI, std::f32::consts::E];
        let encoded = encode_expecting(&input, GgufQuantization::None, GgmlType::F32);
        assert_eq!(encoded.len(), 8);
        assert_eq!(&encoded[0..4], &std::f32::consts::PI.to_le_bytes());
        assert_eq!(&encoded[4..8], &std::f32::consts::E.to_le_bytes());
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(50))]

        #[test]
        fn prop_q4_0_encode_correct_block_count(
            n_elements in 1usize..256,
        ) {
            let quantized = Q4_0::quantize(&vec![1.0f32; n_elements]);
            let encoded = encode_q4_0_blocks(&quantized);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(encoded.len(), expected_blocks * Q4_0_BLOCK_BYTES);
        }

        #[test]
        fn prop_q8_0_encode_correct_block_count(
            n_elements in 1usize..256,
        ) {
            let quantized = Q8_0::quantize(&vec![1.0f32; n_elements]);
            let encoded = encode_q8_0_blocks(&quantized);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(encoded.len(), expected_blocks * Q8_0_BLOCK_BYTES);
        }

        #[test]
        fn prop_falsify_quantize_none_preserves_all_bytes(
            n_elements in 1usize..128,
        ) {
            let data: Vec<f32> = (0..n_elements).map(|i| i as f32 * 0.7 - 50.0).collect();
            let (encoded, tag) = quantize_to_gguf_bytes(&data, GgufQuantization::None);
            prop_assert_eq!(tag, GgmlType::F32);
            prop_assert_eq!(encoded.len(), n_elements * 4);
            // Decode each 4-byte window and compare it with the value it came from.
            for (i, &expected) in data.iter().enumerate() {
                let window = &encoded[i * 4..(i + 1) * 4];
                let actual = f32::from_le_bytes(window.try_into().expect("conversion should succeed"));
                prop_assert!(
                    (actual - expected).abs() < f32::EPSILON,
                    "element {i}: expected {expected}, got {actual}"
                );
            }
        }

        #[test]
        fn prop_falsify_quantize_q4_0_byte_size_invariant(
            n_elements in 1usize..512,
        ) {
            let data = vec![0.5f32; n_elements];
            let (encoded, tag) = quantize_to_gguf_bytes(&data, GgufQuantization::Q4_0);
            prop_assert_eq!(tag, GgmlType::Q4_0);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(encoded.len(), expected_blocks * Q4_0_BLOCK_BYTES);
        }

        #[test]
        fn prop_falsify_quantize_q8_0_byte_size_invariant(
            n_elements in 1usize..512,
        ) {
            let data = vec![0.5f32; n_elements];
            let (encoded, tag) = quantize_to_gguf_bytes(&data, GgufQuantization::Q8_0);
            prop_assert_eq!(tag, GgmlType::Q8_0);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(encoded.len(), expected_blocks * Q8_0_BLOCK_BYTES);
        }
    }
}