Skip to main content

entrenar/hf_pipeline/export/
pipeline.rs

1//! Quantize-Export pipeline
2//!
3//! Quantizes model weights and exports them as GGUF files in a single operation.
4
5use crate::hf_pipeline::error::{FetchError, Result};
6use crate::hf_pipeline::export::gguf_writer::GgufQuantization;
7use std::path::Path;
8
9use super::exporter::Exporter;
10use super::format::ExportFormat;
11use super::result::ExportResult;
12use super::weights::ModelWeights;
13
/// Result of the quantize-export pipeline.
///
/// Returned by [`quantize_and_export`]. It bundles the GGUF export outcome with
/// the quantization mode that was applied and the README text that was
/// generated for it.
#[derive(Debug, Clone)]
pub struct QuantExportResult {
    /// Export result with file path and size (see [`ExportResult`])
    pub export: ExportResult,
    /// Quantization mode the tensors were exported with
    pub quantization: GgufQuantization,
    /// Generated README content; `quantize_and_export` always sets this to `Some`
    pub readme: Option<String>,
}
24
25/// Quantize model weights and export to GGUF format
26///
27/// Performs the full quantize→export pipeline:
28/// 1. Quantize all tensors according to the config
29/// 2. Export as GGUF with tensor data
30/// 3. Generate a README with quantization metadata
31pub fn quantize_and_export(
32    weights: &ModelWeights,
33    quantization: GgufQuantization,
34    output_dir: impl AsRef<Path>,
35    filename: impl AsRef<Path>,
36) -> Result<QuantExportResult> {
37    let output_dir = output_dir.as_ref();
38
39    // Determine output filename
40    let filename = filename.as_ref();
41
42    // Build exporter with quantization config
43    let exporter = Exporter::new().output_dir(output_dir).gguf_quantization(quantization);
44
45    let export = exporter.export(weights, ExportFormat::GGUF, filename)?;
46
47    // Generate README with quantization metadata
48    let readme = generate_quant_readme(weights, quantization, &export);
49
50    // Write README alongside the model
51    let readme_path = output_dir.join("README.md");
52    std::fs::write(&readme_path, &readme).map_err(|e| FetchError::GgufWriteError {
53        message: format!("Failed to write README: {e}"),
54    })?;
55
56    Ok(QuantExportResult { export, quantization, readme: Some(readme) })
57}
58
59/// Generate a README with quantization metadata
60fn generate_quant_readme(
61    weights: &ModelWeights,
62    quantization: GgufQuantization,
63    export: &ExportResult,
64) -> String {
65    let quant_name = match quantization {
66        GgufQuantization::None => "F32 (unquantized)",
67        GgufQuantization::Q4_0 => "Q4_0 (4-bit)",
68        GgufQuantization::Q8_0 => "Q8_0 (8-bit)",
69    };
70
71    let model_name = weights.metadata.model_name.as_deref().unwrap_or("Unknown Model");
72
73    let arch = weights.metadata.architecture.as_deref().unwrap_or("unknown");
74
75    format!(
76        "---\ntags:\n- entrenar\n- gguf\n- quantized\n---\n\n\
77         # {model_name} ({quant_name})\n\n\
78         Quantized with [Entrenar](https://github.com/paiml/entrenar).\n\n\
79         ## Model Details\n\n\
80         | Property | Value |\n\
81         |----------|-------|\n\
82         | Architecture | {arch} |\n\
83         | Parameters | {} |\n\
84         | Quantization | {quant_name} |\n\
85         | File Size | {} |\n\
86         | Tensors | {} |\n\
87         | Format | GGUF v3 |\n",
88        weights.metadata.num_params,
89        export.size_human(),
90        export.num_tensors,
91    )
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hf_pipeline::export::weights::{ModelMetadata, ModelWeights};
    use tempfile::TempDir;

    /// Two-tensor fixture: a 16x16 weight plus a 16-element bias (272 params),
    /// with model name and architecture metadata populated.
    fn make_test_weights() -> ModelWeights {
        let mut weights = ModelWeights::new();
        weights.add_tensor("layer.0.weight", vec![1.0; 256], vec![16, 16]);
        weights.add_tensor("layer.0.bias", vec![0.1; 16], vec![16]);
        weights.metadata = ModelMetadata {
            model_name: Some("test-model".to_string()),
            architecture: Some("llama".to_string()),
            num_params: 272,
            ..Default::default()
        };
        weights
    }

    #[test]
    fn test_quantize_export_f32() {
        let weights = make_test_weights();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let result =
            quantize_and_export(&weights, GgufQuantization::None, tmp.path(), "model.gguf")
                .expect("operation should succeed");

        assert_eq!(result.quantization, GgufQuantization::None);
        assert!(result.export.size_bytes > 0);
        assert!(result.readme.is_some());
        assert!(tmp.path().join("model.gguf").exists());
        assert!(tmp.path().join("README.md").exists());
    }

    #[test]
    fn test_quantize_export_q4_0() {
        let weights = make_test_weights();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let result =
            quantize_and_export(&weights, GgufQuantization::Q4_0, tmp.path(), "model-q4.gguf")
                .expect("operation should succeed");

        assert_eq!(result.quantization, GgufQuantization::Q4_0);
        assert!(result.export.size_bytes > 0);
    }

    #[test]
    fn test_quantize_export_q8_0() {
        let weights = make_test_weights();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let result =
            quantize_and_export(&weights, GgufQuantization::Q8_0, tmp.path(), "model-q8.gguf")
                .expect("operation should succeed");

        assert_eq!(result.quantization, GgufQuantization::Q8_0);
    }

    #[test]
    fn test_quantize_export_readme_content() {
        let weights = make_test_weights();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let result =
            quantize_and_export(&weights, GgufQuantization::Q4_0, tmp.path(), "model.gguf")
                .expect("operation should succeed");

        let readme = result.readme.expect("pipeline should return a README");
        assert!(readme.contains("test-model"));
        assert!(readme.contains("Q4_0"));
        assert!(readme.contains("llama"));
        assert!(readme.contains("entrenar"));
    }

    #[test]
    fn test_quantize_export_q4_smaller_than_f32() {
        let weights = make_test_weights();
        let tmp_f32 = TempDir::new().expect("temp file creation should succeed");
        let tmp_q4 = TempDir::new().expect("temp file creation should succeed");

        let f32_result =
            quantize_and_export(&weights, GgufQuantization::None, tmp_f32.path(), "model.gguf")
                .expect("operation should succeed");
        let q4_result =
            quantize_and_export(&weights, GgufQuantization::Q4_0, tmp_q4.path(), "model.gguf")
                .expect("operation should succeed");

        assert!(
            q4_result.export.size_bytes < f32_result.export.size_bytes,
            "Q4_0 ({}) should be smaller than F32 ({})",
            q4_result.export.size_bytes,
            f32_result.export.size_bytes
        );
    }

    // =====================================================================
    // Falsification: pipeline roundtrip via verify_gguf
    // =====================================================================

    #[test]
    fn test_falsify_pipeline_f32_gguf_is_valid() {
        use crate::hf_pipeline::export::gguf_verify::verify_gguf;

        let weights = make_test_weights();
        let tmp = TempDir::new().expect("temp file creation should succeed");
        quantize_and_export(&weights, GgufQuantization::None, tmp.path(), "f32.gguf")
            .expect("operation should succeed");

        let file_data =
            std::fs::read(tmp.path().join("f32.gguf")).expect("file read should succeed");
        let summary = verify_gguf(&file_data).expect("operation should succeed");

        assert_eq!(summary.version, 3);
        assert_eq!(summary.tensor_count, 2);
        // Metadata: architecture + name + parameter_count = 3
        assert_eq!(summary.metadata_count, 3);
        // Tensors sorted alphabetically
        assert_eq!(summary.tensors[0].name, "layer.0.bias");
        assert_eq!(summary.tensors[1].name, "layer.0.weight");
        // Both F32
        assert_eq!(summary.tensors[0].dtype, 0);
        assert_eq!(summary.tensors[1].dtype, 0);
        // Shapes
        assert_eq!(summary.tensors[0].shape, vec![16]);
        assert_eq!(summary.tensors[1].shape, vec![16, 16]);
    }

    #[test]
    fn test_falsify_pipeline_q4_0_gguf_is_valid() {
        use crate::hf_pipeline::export::gguf_verify::verify_gguf;

        let weights = make_test_weights();
        let tmp = TempDir::new().expect("temp file creation should succeed");
        quantize_and_export(&weights, GgufQuantization::Q4_0, tmp.path(), "q4.gguf")
            .expect("operation should succeed");

        let file_data =
            std::fs::read(tmp.path().join("q4.gguf")).expect("file read should succeed");
        let summary = verify_gguf(&file_data).expect("operation should succeed");

        assert_eq!(summary.tensor_count, 2);
        // Both Q4_0
        assert_eq!(summary.tensors[0].dtype, 2);
        assert_eq!(summary.tensors[1].dtype, 2);
    }

    #[test]
    fn test_falsify_pipeline_q8_0_gguf_is_valid() {
        use crate::hf_pipeline::export::gguf_verify::verify_gguf;

        let weights = make_test_weights();
        let tmp = TempDir::new().expect("temp file creation should succeed");
        quantize_and_export(&weights, GgufQuantization::Q8_0, tmp.path(), "q8.gguf")
            .expect("operation should succeed");

        let file_data =
            std::fs::read(tmp.path().join("q8.gguf")).expect("file read should succeed");
        let summary = verify_gguf(&file_data).expect("operation should succeed");

        assert_eq!(summary.tensor_count, 2);
        // Both Q8_0
        assert_eq!(summary.tensors[0].dtype, 8);
        assert_eq!(summary.tensors[1].dtype, 8);
    }

    #[test]
    fn test_falsify_pipeline_q8_smaller_than_f32() {
        let weights = make_test_weights();
        let tmp_f32 = TempDir::new().expect("temp file creation should succeed");
        let tmp_q8 = TempDir::new().expect("temp file creation should succeed");

        let f32_result =
            quantize_and_export(&weights, GgufQuantization::None, tmp_f32.path(), "model.gguf")
                .expect("operation should succeed");
        let q8_result =
            quantize_and_export(&weights, GgufQuantization::Q8_0, tmp_q8.path(), "model.gguf")
                .expect("operation should succeed");

        assert!(
            q8_result.export.size_bytes < f32_result.export.size_bytes,
            "Q8_0 ({}) should be smaller than F32 ({})",
            q8_result.export.size_bytes,
            f32_result.export.size_bytes
        );
    }

    #[test]
    fn test_falsify_pipeline_q4_smaller_than_q8() {
        let weights = make_test_weights();
        let tmp_q4 = TempDir::new().expect("temp file creation should succeed");
        let tmp_q8 = TempDir::new().expect("temp file creation should succeed");

        let q4_result =
            quantize_and_export(&weights, GgufQuantization::Q4_0, tmp_q4.path(), "model.gguf")
                .expect("operation should succeed");
        let q8_result =
            quantize_and_export(&weights, GgufQuantization::Q8_0, tmp_q8.path(), "model.gguf")
                .expect("operation should succeed");

        assert!(
            q4_result.export.size_bytes < q8_result.export.size_bytes,
            "Q4_0 ({}) should be smaller than Q8_0 ({})",
            q4_result.export.size_bytes,
            q8_result.export.size_bytes
        );
    }

    #[test]
    fn test_falsify_pipeline_readme_contains_quantization_mode() {
        let weights = make_test_weights();

        for (quant, expected_str) in [
            (GgufQuantization::None, "F32 (unquantized)"),
            (GgufQuantization::Q4_0, "Q4_0 (4-bit)"),
            (GgufQuantization::Q8_0, "Q8_0 (8-bit)"),
        ] {
            let tmp = TempDir::new().expect("temp file creation should succeed");
            let result = quantize_and_export(&weights, quant, tmp.path(), "model.gguf")
                .expect("operation should succeed");
            let readme = result.readme.expect("pipeline should return a README");
            assert!(
                readme.contains(expected_str),
                "README for {quant:?} should contain '{expected_str}', got:\n{readme}"
            );
        }
    }

    #[test]
    fn test_falsify_pipeline_f32_data_integrity_through_pipeline() {
        // Verify actual tensor bytes survive the full pipeline
        use crate::hf_pipeline::export::gguf_verify::verify_gguf;

        let mut weights = ModelWeights::new();
        let original: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
        weights.add_tensor("test_data", original.clone(), vec![8, 8]);
        weights.metadata.num_params = 64;

        let tmp = TempDir::new().expect("temp file creation should succeed");
        quantize_and_export(&weights, GgufQuantization::None, tmp.path(), "data.gguf")
            .expect("operation should succeed");

        let file_data =
            std::fs::read(tmp.path().join("data.gguf")).expect("file read should succeed");
        let summary = verify_gguf(&file_data).expect("operation should succeed");
        assert_eq!(summary.tensors[0].name, "test_data");
        assert_eq!(summary.tensors[0].shape, vec![8, 8]);
        assert_eq!(summary.tensors[0].dtype, 0); // F32

        // Extract and verify actual data.
        // GGUF header: magic (4) + version (u32) + tensor_count (u64) +
        // metadata_kv_count (u64) = 24 bytes, then metadata, then tensor info.
        let mut pos = 24;
        // Skip metadata key/value pairs
        for _ in 0..summary.metadata_count {
            // Skip key string (u64 length prefix + bytes)
            let key_len = u64::from_le_bytes(
                file_data[pos..pos + 8].try_into().expect("conversion should succeed"),
            ) as usize;
            pos += 8 + key_len;
            let value_type = u32::from_le_bytes(
                file_data[pos..pos + 4].try_into().expect("conversion should succeed"),
            );
            pos += 4;
            // Every scalar GGUF value type is handled. Anything else (ARRAY or an
            // unknown id) fails loudly here instead of silently desynchronizing
            // the parser and producing a confusing data mismatch later.
            match value_type {
                0 | 1 | 7 => pos += 1, // U8/I8/BOOL
                2 | 3 => pos += 2,     // U16/I16
                4..=6 => pos += 4,     // U32/I32/F32
                8 => {
                    // STRING: u64 length prefix followed by the bytes
                    let len = u64::from_le_bytes(
                        file_data[pos..pos + 8].try_into().expect("conversion should succeed"),
                    ) as usize;
                    pos += 8 + len;
                }
                10..=12 => pos += 8, // U64/I64/F64
                other => panic!("unsupported GGUF metadata value type {other} in test parser"),
            }
        }
        // Skip tensor info: name, n_dims, dims, dtype, offset
        let name_len = u64::from_le_bytes(
            file_data[pos..pos + 8].try_into().expect("conversion should succeed"),
        ) as usize;
        pos += 8 + name_len;
        let n_dims = u32::from_le_bytes(
            file_data[pos..pos + 4].try_into().expect("conversion should succeed"),
        ) as usize;
        pos += 4 + n_dims * 8 + 4 + 8; // dims + dtype + offset

        // Skip alignment padding before tensor data — `aprender::format::gguf::
        // export_tensors_to_gguf` writes `padding_for_alignment(header_size,
        // GGUF_DEFAULT_ALIGNMENT)` zero bytes after the tensor-info section so
        // tensor data begins at a 32-byte-aligned offset (types.rs:445). Without
        // this skip we'd read padding zeros instead of the actual f32 bytes.
        use aprender::format::gguf::{padding_for_alignment, GGUF_DEFAULT_ALIGNMENT};
        pos += padding_for_alignment(pos, GGUF_DEFAULT_ALIGNMENT);

        // Now pos is at the start of tensor data
        let data_start = pos;
        let recovered: Vec<f32> = (0..64)
            .map(|i| {
                let off = data_start + i * 4;
                f32::from_le_bytes(
                    file_data[off..off + 4].try_into().expect("conversion should succeed"),
                )
            })
            .collect();
        assert_eq!(original, recovered, "f32 data must survive pipeline exactly");
    }

    // =====================================================================
    // TIER 3: File size monotonicity across varying tensor counts
    // =====================================================================

    #[test]
    fn test_falsify_pipeline_size_monotonic_with_tensor_count() {
        // For same quant mode, more tensors → larger file
        let mut prev_size = 0u64;
        for n in [1, 2, 4, 8] {
            let mut weights = ModelWeights::new();
            for i in 0..n {
                weights.add_tensor(format!("layer.{i}.weight"), vec![1.0; 64], vec![8, 8]);
            }
            weights.metadata.num_params = n as u64 * 64;

            let tmp = TempDir::new().expect("temp file creation should succeed");
            let result =
                quantize_and_export(&weights, GgufQuantization::None, tmp.path(), "model.gguf")
                    .expect("operation should succeed");

            assert!(
                result.export.size_bytes > prev_size,
                "F32 {n} tensors ({}) must be > prev ({prev_size})",
                result.export.size_bytes
            );
            prev_size = result.export.size_bytes;
        }
    }

    #[test]
    fn test_falsify_pipeline_q4_size_monotonic_with_tensor_count() {
        let mut prev_size = 0u64;
        for n in [1, 2, 4, 8] {
            let mut weights = ModelWeights::new();
            for i in 0..n {
                weights.add_tensor(format!("layer.{i}.weight"), vec![1.0; 64], vec![8, 8]);
            }
            weights.metadata.num_params = n as u64 * 64;

            let tmp = TempDir::new().expect("temp file creation should succeed");
            let result =
                quantize_and_export(&weights, GgufQuantization::Q4_0, tmp.path(), "model.gguf")
                    .expect("operation should succeed");

            assert!(
                result.export.size_bytes > prev_size,
                "Q4_0 {n} tensors ({}) must be > prev ({prev_size})",
                result.export.size_bytes
            );
            prev_size = result.export.size_bytes;
        }
    }

    #[test]
    fn test_falsify_pipeline_size_ordering_at_multiple_scales() {
        // For various tensor element counts, always: Q4 < Q8 < F32
        for n_elements in [32, 128, 512, 1024] {
            let mut weights = ModelWeights::new();
            weights.add_tensor("w", vec![0.5; n_elements], vec![n_elements]);
            weights.metadata.num_params = n_elements as u64;

            let sizes: Vec<(GgufQuantization, u64)> =
                [GgufQuantization::None, GgufQuantization::Q8_0, GgufQuantization::Q4_0]
                    .iter()
                    .map(|&quant| {
                        let tmp = TempDir::new().expect("temp file creation should succeed");
                        let result = quantize_and_export(&weights, quant, tmp.path(), "m.gguf")
                            .expect("operation should succeed");
                        (quant, result.export.size_bytes)
                    })
                    .collect();

            let (_, f32_size) = sizes[0];
            let (_, q8_size) = sizes[1];
            let (_, q4_size) = sizes[2];

            assert!(
                q4_size < q8_size,
                "at {n_elements} elements: Q4={q4_size} must be < Q8={q8_size}"
            );
            assert!(
                q8_size < f32_size,
                "at {n_elements} elements: Q8={q8_size} must be < F32={f32_size}"
            );
        }
    }

    #[test]
    fn test_falsify_pipeline_magic_bytes_all_quant_modes() {
        let weights = make_test_weights();
        for quant in [GgufQuantization::None, GgufQuantization::Q4_0, GgufQuantization::Q8_0] {
            let tmp = TempDir::new().expect("temp file creation should succeed");
            quantize_and_export(&weights, quant, tmp.path(), "model.gguf")
                .expect("operation should succeed");
            let file_data =
                std::fs::read(tmp.path().join("model.gguf")).expect("file read should succeed");
            assert_eq!(&file_data[0..4], b"GGUF", "magic bytes wrong for pipeline {quant:?}");
        }
    }

    #[test]
    fn test_falsify_pipeline_readme_file_size_field() {
        let weights = make_test_weights();
        let tmp = TempDir::new().expect("temp file creation should succeed");
        let result =
            quantize_and_export(&weights, GgufQuantization::None, tmp.path(), "model.gguf")
                .expect("operation should succeed");
        let readme = result.readme.expect("pipeline should return a README");
        // README should carry the human-readable size in its own table row
        let size_row = format!("| File Size | {} |", result.export.size_human());
        assert!(
            readme.contains(&size_row),
            "README should contain row '{size_row}', got:\n{readme}"
        );
    }

    #[test]
    fn test_falsify_pipeline_readme_tensor_count() {
        let weights = make_test_weights();
        let tmp = TempDir::new().expect("temp file creation should succeed");
        let result =
            quantize_and_export(&weights, GgufQuantization::None, tmp.path(), "model.gguf")
                .expect("operation should succeed");
        let readme = result.readme.expect("pipeline should return a README");
        // Match the whole table row. A bare `contains("2")` would also be
        // satisfied by the parameter count (272), so it could never fail.
        let tensor_row = format!("| Tensors | {} |", result.export.num_tensors);
        assert!(
            readme.contains(&tensor_row),
            "README should contain row '{tensor_row}', got:\n{readme}"
        );
    }

    #[test]
    fn test_falsify_pipeline_readme_has_yaml_frontmatter() {
        let weights = make_test_weights();
        let tmp = TempDir::new().expect("temp file creation should succeed");
        let result =
            quantize_and_export(&weights, GgufQuantization::None, tmp.path(), "model.gguf")
                .expect("operation should succeed");
        let readme = result.readme.expect("pipeline should return a README");
        assert!(readme.starts_with("---\n"), "README must start with YAML frontmatter");
        assert!(readme.contains("tags:"), "README must have tags in frontmatter");
        assert!(readme.contains("- gguf"), "README frontmatter must tag 'gguf'");
        assert!(readme.contains("- entrenar"), "README frontmatter must tag 'entrenar'");
    }
}