Skip to main content

entrenar/hf_pipeline/export/
gguf_writer.rs

1//! GGUF quantization helpers
2//!
3//! Provides entrenar's quantization enum and byte-encoding functions that bridge
4//! entrenar's Q4_0/Q8_0 quant structs to raw GGUF block bytes. Binary GGUF
5//! serialization is delegated to `aprender::format::gguf::export_tensors_to_gguf`.
6
7use aprender::format::gguf::GgmlType;
8
9use crate::quant::{GGUF_BLOCK_SIZE, Q4_0, Q8_0};
10
/// Target encoding used when writing tensors into a GGUF file.
///
/// The quantized variants work on 32-element blocks. Each block is written
/// with its own f16 scale.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GgufQuantization {
    /// Keep full precision: tensors are written as raw `F32`.
    None,
    /// 4-bit block quantization (GGML `Q4_0`, 32 elements per block).
    Q4_0,
    /// 8-bit block quantization (GGML `Q8_0`, 32 elements per block).
    Q8_0,
}
21
22/// Quantize f32 data according to `quant` mode and return raw GGUF bytes + dtype.
23///
24/// For `GgufQuantization::None`, returns the f32 data as little-endian bytes with `GgmlType::F32`.
25/// For Q4_0/Q8_0, quantizes via entrenar's quant module and encodes to GGUF block format.
26pub fn quantize_to_gguf_bytes(data: &[f32], quant: GgufQuantization) -> (Vec<u8>, GgmlType) {
27    // Edge case: empty input is a valid no-op for all three quantization
28    // modes (empty tensor → empty bytes, dtype preserved). Return early so
29    // we don't trip `contract_pre_quantize!`'s `input.len() > 0` debug
30    // assertion — the contract's domain is non-empty inputs (where the
31    // precision bound is meaningful), but the function must still handle
32    // empty cleanly because callers may pass zero-length slices for
33    // skipped tensors. See `test_falsify_quantize_empty_data_*` for the
34    // documented invariant: empty input → empty output, dtype matches the
35    // requested quantization mode.
36    if data.is_empty() {
37        let dtype = match quant {
38            GgufQuantization::None => GgmlType::F32,
39            GgufQuantization::Q4_0 => GgmlType::Q4_0,
40            GgufQuantization::Q8_0 => GgmlType::Q8_0,
41        };
42        return (Vec::new(), dtype);
43    }
44    contract_pre_quantize!(data);
45    let result = match quant {
46        GgufQuantization::None => {
47            let bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
48            (bytes, GgmlType::F32)
49        }
50        GgufQuantization::Q4_0 => {
51            let quantized = Q4_0::quantize(data);
52            (encode_q4_0_blocks(&quantized), GgmlType::Q4_0)
53        }
54        GgufQuantization::Q8_0 => {
55            let quantized = Q8_0::quantize(data);
56            (encode_q8_0_blocks(&quantized), GgmlType::Q8_0)
57        }
58    };
59    contract_post_quantize_precision_bound!(&result);
60    result
61}
62
63/// Encode Q4_0 quantized data into GGUF binary block format
64/// Each block: f16 scale (2 bytes) + 16 bytes packed 4-bit data = 18 bytes
65fn encode_q4_0_blocks(q: &Q4_0) -> Vec<u8> {
66    let num_blocks = q.num_blocks();
67    let mut bytes = Vec::with_capacity(num_blocks * 18);
68
69    for block_idx in 0..num_blocks {
70        // Scale as f16
71        let scale_f16 = half::f16::from_f32(q.scales[block_idx]);
72        bytes.extend_from_slice(&scale_f16.to_le_bytes());
73
74        // 16 bytes of packed 4-bit data
75        let data_start = block_idx * 16;
76        let data_end = (data_start + 16).min(q.data.len());
77        bytes.extend_from_slice(&q.data[data_start..data_end]);
78
79        // Pad if the last block is short
80        let pad = 16 - (data_end - data_start);
81        bytes.extend(std::iter::repeat_n(0u8, pad));
82    }
83
84    bytes
85}
86
87/// Encode Q8_0 quantized data into GGUF binary block format
88/// Each block: f16 scale (2 bytes) + 32 bytes i8 data = 34 bytes
89fn encode_q8_0_blocks(q: &Q8_0) -> Vec<u8> {
90    let num_blocks = q.num_blocks();
91    let mut bytes = Vec::with_capacity(num_blocks * 34);
92
93    for block_idx in 0..num_blocks {
94        // Scale as f16
95        let scale_f16 = half::f16::from_f32(q.scales[block_idx]);
96        bytes.extend_from_slice(&scale_f16.to_le_bytes());
97
98        // 32 bytes of i8 data
99        let data_start = block_idx * GGUF_BLOCK_SIZE;
100        let data_end = (data_start + GGUF_BLOCK_SIZE).min(q.data.len());
101        let block_data = &q.data[data_start..data_end];
102
103        // Cast i8 slice to u8 slice for extending
104        for &val in block_data {
105            bytes.push(val as u8);
106        }
107
108        // Pad incomplete blocks
109        let pad = GGUF_BLOCK_SIZE - block_data.len();
110        bytes.extend(std::iter::repeat_n(0u8, pad));
111    }
112
113    bytes
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_encode_q4_0_block_size() {
        let values: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
        let q = Q4_0::quantize(&values);
        let bytes = encode_q4_0_blocks(&q);
        // 1 block * 18 bytes
        assert_eq!(bytes.len(), 18);
    }

    #[test]
    fn test_encode_q8_0_block_size() {
        let values: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
        let q = Q8_0::quantize(&values);
        let bytes = encode_q8_0_blocks(&q);
        // 1 block * 34 bytes
        assert_eq!(bytes.len(), 34);
    }

    #[test]
    fn test_quantize_to_gguf_bytes_none() {
        let data = [1.0f32, 2.0, 3.0, 4.0];
        let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::None);
        assert_eq!(dtype, GgmlType::F32);
        assert_eq!(bytes.len(), 16); // 4 floats * 4 bytes

        // F32 mode is a lossless copy, so every value must round-trip bit-for-bit.
        for (chunk, &expected) in bytes.chunks_exact(4).zip(&data) {
            let actual = f32::from_le_bytes(chunk.try_into().expect("chunk is 4 bytes"));
            assert_eq!(actual.to_bits(), expected.to_bits());
        }
    }

    #[test]
    fn test_quantize_to_gguf_bytes_q4_0() {
        let data: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
        let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert_eq!(bytes.len(), 18); // 1 block
    }

    #[test]
    fn test_quantize_to_gguf_bytes_q8_0() {
        let data: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
        let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        assert_eq!(bytes.len(), 34); // 1 block
    }

    // =====================================================================
    // Falsification: edge cases for quantize_to_gguf_bytes
    // =====================================================================

    #[test]
    fn test_falsify_quantize_empty_data_none() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&[], GgufQuantization::None);
        assert_eq!(dtype, GgmlType::F32);
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_falsify_quantize_empty_data_q4_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&[], GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert!(
            bytes.is_empty(),
            "empty input must produce empty output, got {} bytes",
            bytes.len()
        );
    }

    #[test]
    fn test_falsify_quantize_empty_data_q8_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&[], GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        assert!(
            bytes.is_empty(),
            "empty input must produce empty output, got {} bytes",
            bytes.len()
        );
    }

    #[test]
    fn test_falsify_quantize_single_element_q4_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&[42.0], GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        // 1 element → 1 block → 18 bytes (2 scale + 16 packed data)
        assert_eq!(bytes.len(), 18);
    }

    #[test]
    fn test_falsify_quantize_single_element_q8_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&[42.0], GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        // 1 element → 1 block → 34 bytes (2 scale + 32 i8 data)
        assert_eq!(bytes.len(), 34);
    }

    #[test]
    fn test_falsify_quantize_33_elements_q4_0() {
        // 33 elements = 2 blocks (32 + 1), second block is mostly padding
        let data: Vec<f32> = (0..33).map(|i| i as f32 * 0.1).collect();
        let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert_eq!(bytes.len(), 2 * 18); // exactly 2 blocks
    }

    #[test]
    fn test_falsify_quantize_33_elements_q8_0() {
        let data: Vec<f32> = (0..33).map(|i| i as f32 * 0.1).collect();
        let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        assert_eq!(bytes.len(), 2 * 34); // exactly 2 blocks
    }

    #[test]
    fn test_falsify_quantize_63_elements_q4_0() {
        // 63 elements = 2 blocks (32 + 31)
        let data: Vec<f32> = (0..63).map(|i| i as f32 * 0.01).collect();
        let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert_eq!(bytes.len(), 2 * 18);
    }

    #[test]
    fn test_falsify_quantize_all_zeros_q4_0() {
        // All zeros — the scale of every block must be 0
        let data = [0.0f32; 64];
        let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert_eq!(bytes.len(), 2 * 18);
        // Scale bytes are the first 2 of each 18-byte block; check every
        // block, not just the first one.
        for (i, block) in bytes.chunks_exact(18).enumerate() {
            let scale = half::f16::from_le_bytes([block[0], block[1]]);
            assert_eq!(scale.to_f32(), 0.0, "block {i}: scale for zero data must be 0");
        }
    }

    #[test]
    fn test_falsify_quantize_all_zeros_q8_0() {
        let data = [0.0f32; 32];
        let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        assert_eq!(bytes.len(), 34);
        let scale0 = half::f16::from_le_bytes([bytes[0], bytes[1]]);
        assert_eq!(scale0.to_f32(), 0.0, "scale for zero data must be 0");
    }

    #[test]
    fn test_falsify_quantize_extreme_range_q4_0() {
        // Values spanning huge range — tests f16 scale saturation
        let mut data = vec![0.0f32; 32];
        data[0] = 1e30;
        data[1] = -1e30;
        let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert_eq!(bytes.len(), 18);
        // Scale should be finite or infinity but not NaN
        let scale = half::f16::from_le_bytes([bytes[0], bytes[1]]);
        assert!(!scale.to_f32().is_nan(), "scale must not be NaN for extreme values");
    }

    #[test]
    fn test_falsify_quantize_extreme_range_q8_0() {
        let mut data = vec![0.0f32; 32];
        data[0] = 1e30;
        data[1] = -1e30;
        let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        assert_eq!(bytes.len(), 34);
        let scale = half::f16::from_le_bytes([bytes[0], bytes[1]]);
        assert!(!scale.to_f32().is_nan(), "scale must not be NaN for extreme values");
    }

    #[test]
    fn test_falsify_quantize_f32_exact_byte_layout() {
        // Verify F32 mode produces exact little-endian layout
        let data = [std::f32::consts::PI, std::f32::consts::E];
        let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::None);
        assert_eq!(dtype, GgmlType::F32);
        assert_eq!(bytes.len(), 8);
        let pi_bytes = std::f32::consts::PI.to_le_bytes();
        let e_bytes = std::f32::consts::E.to_le_bytes();
        assert_eq!(&bytes[0..4], &pi_bytes);
        assert_eq!(&bytes[4..8], &e_bytes);
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(50))]

        #[test]
        fn prop_q4_0_encode_correct_block_count(
            n_elements in 1usize..256,
        ) {
            let data: Vec<f32> = vec![1.0; n_elements];
            let q = Q4_0::quantize(&data);
            let bytes = encode_q4_0_blocks(&q);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(bytes.len(), expected_blocks * 18);
        }

        #[test]
        fn prop_q8_0_encode_correct_block_count(
            n_elements in 1usize..256,
        ) {
            let data: Vec<f32> = vec![1.0; n_elements];
            let q = Q8_0::quantize(&data);
            let bytes = encode_q8_0_blocks(&q);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(bytes.len(), expected_blocks * 34);
        }

        #[test]
        fn prop_falsify_quantize_none_preserves_all_bytes(
            n_elements in 1usize..128,
        ) {
            let data: Vec<f32> = (0..n_elements).map(|i| i as f32 * 0.7 - 50.0).collect();
            let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::None);
            prop_assert_eq!(dtype, GgmlType::F32);
            prop_assert_eq!(bytes.len(), n_elements * 4);
            // Verify every float round-trips bit-for-bit (F32 mode is lossless)
            for (i, (chunk, &expected)) in bytes.chunks_exact(4).zip(&data).enumerate() {
                let actual = f32::from_le_bytes(chunk.try_into().expect("chunk is 4 bytes"));
                prop_assert!(
                    actual.to_bits() == expected.to_bits(),
                    "element {i}: expected {expected}, got {actual}"
                );
            }
        }

        #[test]
        fn prop_falsify_quantize_q4_0_byte_size_invariant(
            n_elements in 1usize..512,
        ) {
            let data: Vec<f32> = vec![0.5; n_elements];
            let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::Q4_0);
            prop_assert_eq!(dtype, GgmlType::Q4_0);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(bytes.len(), expected_blocks * 18);
        }

        #[test]
        fn prop_falsify_quantize_q8_0_byte_size_invariant(
            n_elements in 1usize..512,
        ) {
            let data: Vec<f32> = vec![0.5; n_elements];
            let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::Q8_0);
            prop_assert_eq!(dtype, GgmlType::Q8_0);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(bytes.len(), expected_blocks * 34);
        }
    }
}