Skip to main content

entrenar/hf_pipeline/export/
exporter.rs

1//! Model exporter implementation.
2
3use crate::hf_pipeline::error::{FetchError, Result};
4use serde::Serialize;
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7
8use super::format::ExportFormat;
9use super::gguf_writer::{quantize_to_gguf_bytes, GgufQuantization};
10use super::result::ExportResult;
11use super::weights::{ModelMetadata, ModelWeights};
12
13/// Model exporter
14pub struct Exporter {
15    /// Output directory
16    pub(super) output_dir: PathBuf,
17    /// Default format
18    pub(super) default_format: ExportFormat,
19    /// Include metadata
20    pub(super) include_metadata: bool,
21    /// GGUF quantization mode
22    pub(super) gguf_quantization: GgufQuantization,
23}
24
25impl Default for Exporter {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl Exporter {
32    /// Create new exporter
33    #[must_use]
34    pub fn new() -> Self {
35        Self {
36            output_dir: PathBuf::from("."),
37            default_format: ExportFormat::SafeTensors,
38            include_metadata: true,
39            gguf_quantization: GgufQuantization::None,
40        }
41    }
42
43    /// Set output directory
44    #[must_use]
45    pub fn output_dir(mut self, dir: impl Into<PathBuf>) -> Self {
46        self.output_dir = dir.into();
47        self
48    }
49
50    /// Set default format
51    #[must_use]
52    pub fn default_format(mut self, format: ExportFormat) -> Self {
53        self.default_format = format;
54        self
55    }
56
57    /// Set whether to include metadata
58    #[must_use]
59    pub fn include_metadata(mut self, include: bool) -> Self {
60        self.include_metadata = include;
61        self
62    }
63
64    /// Set GGUF quantization mode
65    #[must_use]
66    pub fn gguf_quantization(mut self, quant: GgufQuantization) -> Self {
67        self.gguf_quantization = quant;
68        self
69    }
70
71    /// Export weights to file
72    pub fn export(
73        &self,
74        weights: &ModelWeights,
75        format: ExportFormat,
76        filename: impl AsRef<Path>,
77    ) -> Result<ExportResult> {
78        let path = self.output_dir.join(filename);
79
80        // Ensure parent directory exists
81        if let Some(parent) = path.parent() {
82            std::fs::create_dir_all(parent).map_err(|e| FetchError::ConfigParseError {
83                message: format!("Failed to create output directory: {e}"),
84            })?;
85        }
86
87        match format {
88            ExportFormat::SafeTensors => self.export_safetensors(weights, &path),
89            ExportFormat::APR => self.export_apr(weights, &path),
90            ExportFormat::GGUF => self.export_gguf(weights, &path),
91            ExportFormat::PyTorch => Err(FetchError::PickleSecurityRisk),
92        }
93    }
94
95    /// Export to SafeTensors format
96    fn export_safetensors(&self, weights: &ModelWeights, path: &Path) -> Result<ExportResult> {
97        // Mock implementation - actual safetensors serialization would use the safetensors crate
98        let mut output = Vec::new();
99
100        // Header
101        let header = serde_json::json!({
102            "__metadata__": {
103                "format": "safetensors",
104                "version": "0.1.0",
105                "num_tensors": weights.tensors.len(),
106                "num_params": weights.param_count(),
107            }
108        });
109        let header_bytes = serde_json::to_vec(&header).map_err(|e| {
110            FetchError::ConfigParseError { message: format!("Failed to serialize header: {e}") }
111        })?;
112
113        // Write header length (8 bytes, little-endian)
114        output.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
115        output.extend_from_slice(&header_bytes);
116
117        // Write tensor data (mock - just count bytes)
118        let data_size: usize = weights.tensors.values().map(|t| t.len() * 4).sum();
119        output.extend(vec![0u8; data_size.min(1024)]); // Truncate for mock
120
121        std::fs::write(path, &output).map_err(|e| FetchError::ConfigParseError {
122            message: format!("Failed to write file: {e}"),
123        })?;
124
125        Ok(ExportResult {
126            path: path.to_path_buf(),
127            format: ExportFormat::SafeTensors,
128            size_bytes: output.len() as u64,
129            num_tensors: weights.tensors.len(),
130        })
131    }
132
133    /// Export to APR format (JSON-based)
134    fn export_apr(&self, weights: &ModelWeights, path: &Path) -> Result<ExportResult> {
135        #[derive(Serialize)]
136        struct AprFormat {
137            version: String,
138            metadata: ModelMetadata,
139            tensors: HashMap<String, AprTensor>,
140        }
141
142        #[derive(Serialize)]
143        struct AprTensor {
144            shape: Vec<usize>,
145            dtype: String,
146            data: Vec<f32>,
147        }
148
149        let apr = AprFormat {
150            version: "1.0".to_string(),
151            metadata: weights.metadata.clone(),
152            tensors: weights
153                .tensors
154                .iter()
155                .map(|(name, data)| {
156                    let shape = weights.shapes.get(name).cloned().unwrap_or_default();
157                    (
158                        name.clone(),
159                        AprTensor { shape, dtype: "f32".to_string(), data: data.clone() },
160                    )
161                })
162                .collect(),
163        };
164
165        let json = serde_json::to_string_pretty(&apr).map_err(|e| {
166            FetchError::ConfigParseError { message: format!("Failed to serialize APR: {e}") }
167        })?;
168
169        std::fs::write(path, &json).map_err(|e| FetchError::ConfigParseError {
170            message: format!("Failed to write file: {e}"),
171        })?;
172
173        Ok(ExportResult {
174            path: path.to_path_buf(),
175            format: ExportFormat::APR,
176            size_bytes: json.len() as u64,
177            num_tensors: weights.tensors.len(),
178        })
179    }
180
181    /// Export to GGUF format with real tensor data (delegates to aprender)
182    fn export_gguf(&self, weights: &ModelWeights, path: &Path) -> Result<ExportResult> {
183        use aprender::format::gguf::{export_tensors_to_gguf, GgufTensor, GgufValue};
184
185        // Build metadata
186        let mut metadata: Vec<(String, GgufValue)> = Vec::new();
187        if self.include_metadata {
188            if let Some(arch) = &weights.metadata.architecture {
189                metadata.push(("general.architecture".into(), GgufValue::String(arch.clone())));
190            }
191            if let Some(name) = &weights.metadata.model_name {
192                metadata.push(("general.name".into(), GgufValue::String(name.clone())));
193            }
194            metadata.push((
195                "general.parameter_count".into(),
196                GgufValue::Uint64(weights.metadata.num_params),
197            ));
198            if let Some(hidden) = weights.metadata.hidden_size {
199                metadata.push(("general.hidden_size".into(), GgufValue::Uint32(hidden as u32)));
200            }
201            if let Some(layers) = weights.metadata.num_layers {
202                metadata.push(("general.num_layers".into(), GgufValue::Uint32(layers as u32)));
203            }
204        }
205
206        // Build tensors — sort names for deterministic output
207        let mut tensor_names: Vec<&String> = weights.tensors.keys().collect();
208        tensor_names.sort();
209
210        let mut tensors: Vec<GgufTensor> = Vec::new();
211        for name in &tensor_names {
212            let data = &weights.tensors[*name];
213            let shape = weights.shapes.get(*name).cloned().unwrap_or_else(|| vec![data.len()]);
214            let (bytes, dtype) = quantize_to_gguf_bytes(data, self.gguf_quantization);
215            tensors.push(GgufTensor {
216                name: (*name).clone(),
217                shape: shape.iter().map(|&d| d as u64).collect(),
218                dtype,
219                data: bytes,
220            });
221        }
222
223        // Write via aprender
224        let mut file = std::fs::File::create(path).map_err(|e| FetchError::GgufWriteError {
225            message: format!("Failed to create GGUF file: {e}"),
226        })?;
227        export_tensors_to_gguf(&mut file, &tensors, &metadata).map_err(|e| {
228            FetchError::GgufWriteError { message: format!("Failed to write GGUF data: {e}") }
229        })?;
230
231        let size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
232
233        Ok(ExportResult {
234            path: path.to_path_buf(),
235            format: ExportFormat::GGUF,
236            size_bytes: size,
237            num_tensors: tensor_names.len(),
238        })
239    }
240
241    /// Export with automatic format detection from filename
242    pub fn export_auto(
243        &self,
244        weights: &ModelWeights,
245        filename: impl AsRef<Path>,
246    ) -> Result<ExportResult> {
247        let path = filename.as_ref();
248        let format = ExportFormat::from_path(path).unwrap_or(self.default_format);
249        self.export(weights, format, path)
250    }
251}
252
#[cfg(test)]
mod tests {
    //! Falsification tests for `Exporter`: builder semantics, format dispatch,
    //! `export_auto` detection, and the `num_tensors` invariant.

    use super::*;
    use crate::hf_pipeline::export::weights::ModelMetadata;

    /// One 8x8 tensor of ones plus name/architecture metadata.
    fn make_test_weights() -> ModelWeights {
        let mut weights = ModelWeights::new();
        weights.add_tensor("layer.0.weight", vec![1.0; 64], vec![8, 8]);
        weights.metadata = ModelMetadata {
            model_name: Some("test-model".to_string()),
            architecture: Some("llama".to_string()),
            num_params: 64,
            ..Default::default()
        };
        weights
    }

    // =================================================================
    // TIER 4: Builder pattern & defaults
    // =================================================================

    #[test]
    fn test_falsify_exporter_default_values() {
        let exp = Exporter::new();
        assert_eq!(exp.output_dir, PathBuf::from("."));
        assert_eq!(exp.default_format, ExportFormat::SafeTensors);
        assert!(exp.include_metadata);
        assert_eq!(exp.gguf_quantization, GgufQuantization::None);
    }

    #[test]
    fn test_falsify_exporter_default_eq_new() {
        let a = Exporter::new();
        let b = Exporter::default();
        assert_eq!(a.output_dir, b.output_dir);
        assert_eq!(a.default_format, b.default_format);
        assert_eq!(a.include_metadata, b.include_metadata);
        assert_eq!(a.gguf_quantization, b.gguf_quantization);
    }

    #[test]
    fn test_falsify_builder_order_independence() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");

        let result1 = Exporter::new()
            .output_dir(dir.path())
            .gguf_quantization(GgufQuantization::Q4_0)
            .include_metadata(false)
            .export(&weights, ExportFormat::GGUF, "a.gguf")
            .expect("GGUF export should succeed");

        let result2 = Exporter::new()
            .include_metadata(false)
            .gguf_quantization(GgufQuantization::Q4_0)
            .output_dir(dir.path())
            .export(&weights, ExportFormat::GGUF, "b.gguf")
            .expect("GGUF export should succeed");

        assert_eq!(result1.size_bytes, result2.size_bytes);
        assert_eq!(result1.num_tensors, result2.num_tensors);
    }

    #[test]
    fn test_falsify_builder_setter_override() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");

        // Set Q8_0 then override to Q4_0
        let _result = Exporter::new()
            .output_dir(dir.path())
            .gguf_quantization(GgufQuantization::Q8_0)
            .gguf_quantization(GgufQuantization::Q4_0)
            .include_metadata(false)
            .export(&weights, ExportFormat::GGUF, "override.gguf")
            .expect("GGUF export should succeed");

        let file_data =
            std::fs::read(dir.path().join("override.gguf")).expect("file read should succeed");
        let summary = crate::hf_pipeline::export::gguf_verify::verify_gguf(&file_data)
            .expect("exported GGUF should verify");
        // Should be Q4_0 (dtype=2), not Q8_0 (dtype=8)
        assert_eq!(summary.tensors[0].dtype, 2, "override should use Q4_0");
    }

    // =================================================================
    // TIER 4: Format rejection & regression
    // =================================================================

    #[test]
    fn test_falsify_pytorch_format_rejected() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let exporter = Exporter::new().output_dir(dir.path());
        let result = exporter.export(&weights, ExportFormat::PyTorch, "model.pt");
        assert!(result.is_err(), "PyTorch export must be rejected");
        let err = result.unwrap_err();
        assert!(
            matches!(err, FetchError::PickleSecurityRisk),
            "error must be PickleSecurityRisk, got {err:?}"
        );
    }

    #[test]
    fn test_falsify_safetensors_export_works() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let exporter = Exporter::new().output_dir(dir.path());
        let result = exporter
            .export(&weights, ExportFormat::SafeTensors, "model.safetensors")
            .expect("SafeTensors export should succeed");
        assert_eq!(result.format, ExportFormat::SafeTensors);
        assert!(result.size_bytes > 0);
        assert!(dir.path().join("model.safetensors").exists());
    }

    #[test]
    fn test_falsify_apr_export_works() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let exporter = Exporter::new().output_dir(dir.path());
        let result = exporter
            .export(&weights, ExportFormat::APR, "model.apr.json")
            .expect("APR export should succeed");
        assert_eq!(result.format, ExportFormat::APR);
        assert!(result.size_bytes > 0);
        assert!(dir.path().join("model.apr.json").exists());
    }

    #[test]
    fn test_falsify_apr_preserves_tensor_shape_and_data() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        Exporter::new()
            .output_dir(dir.path())
            .export(&weights, ExportFormat::APR, "model.apr.json")
            .expect("APR export should succeed");

        let text = std::fs::read_to_string(dir.path().join("model.apr.json"))
            .expect("file read should succeed");
        let doc: serde_json::Value =
            serde_json::from_str(&text).expect("APR output must be valid JSON");
        assert_eq!(doc["version"], "1.0");
        let tensor = &doc["tensors"]["layer.0.weight"];
        assert_eq!(tensor["shape"], serde_json::json!([8, 8]));
        assert_eq!(tensor["dtype"], "f32");
        assert_eq!(tensor["data"].as_array().map(Vec::len), Some(64));
    }

    #[test]
    fn test_falsify_export_creates_missing_parent_dirs() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let exporter = Exporter::new().output_dir(dir.path().join("nested"));
        exporter
            .export(&weights, ExportFormat::APR, "deeper/model.apr.json")
            .expect("export into a missing directory should succeed");
        assert!(dir.path().join("nested").join("deeper").join("model.apr.json").exists());
    }

    #[test]
    fn test_falsify_safetensors_ignores_quantization_setting() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        // Set Q4_0 quant — should be silently ignored for SafeTensors
        let exporter =
            Exporter::new().output_dir(dir.path()).gguf_quantization(GgufQuantization::Q4_0);
        let result = exporter
            .export(&weights, ExportFormat::SafeTensors, "model.safetensors")
            .expect("SafeTensors export should succeed");
        assert_eq!(result.format, ExportFormat::SafeTensors);
        assert!(result.size_bytes > 0);
    }

    // =================================================================
    // TIER 4: export_auto() format detection
    // =================================================================

    #[test]
    fn test_falsify_export_auto_detects_gguf() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let exporter = Exporter::new().output_dir(dir.path()).default_format(ExportFormat::APR);
        let result =
            exporter.export_auto(&weights, "model.gguf").expect("auto export should succeed");
        assert_eq!(result.format, ExportFormat::GGUF);
    }

    #[test]
    fn test_falsify_export_auto_detects_safetensors() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let exporter = Exporter::new().output_dir(dir.path()).default_format(ExportFormat::GGUF);
        let result = exporter
            .export_auto(&weights, "model.safetensors")
            .expect("auto export should succeed");
        assert_eq!(result.format, ExportFormat::SafeTensors);
    }

    #[test]
    fn test_falsify_export_auto_detects_apr() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let exporter = Exporter::new().output_dir(dir.path()).default_format(ExportFormat::GGUF);
        let result =
            exporter.export_auto(&weights, "model.apr.json").expect("auto export should succeed");
        assert_eq!(result.format, ExportFormat::APR);
    }

    #[test]
    fn test_falsify_export_auto_unknown_extension_uses_default() {
        let weights = make_test_weights();
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let exporter = Exporter::new().output_dir(dir.path()).default_format(ExportFormat::GGUF);
        let result =
            exporter.export_auto(&weights, "model.unknown").expect("auto export should succeed");
        assert_eq!(result.format, ExportFormat::GGUF);
    }

    // =================================================================
    // TIER 4: num_tensors invariant
    // =================================================================

    #[test]
    fn test_falsify_num_tensors_matches_input() {
        for n in [0, 1, 3, 10] {
            let mut weights = ModelWeights::new();
            for i in 0..n {
                weights.add_tensor(format!("t.{i}"), vec![1.0], vec![1]);
            }

            let dir = tempfile::tempdir().expect("temp dir creation should succeed");
            let exporter = Exporter::new().output_dir(dir.path()).include_metadata(false);
            let result = exporter
                .export(&weights, ExportFormat::GGUF, "count.gguf")
                .expect("GGUF export should succeed");
            assert_eq!(result.num_tensors, n, "num_tensors mismatch for {n} input tensors");

            let file_data =
                std::fs::read(dir.path().join("count.gguf")).expect("file read should succeed");
            let summary = crate::hf_pipeline::export::gguf_verify::verify_gguf(&file_data)
                .expect("exported GGUF should verify");
            assert_eq!(summary.tensor_count, n as u64, "GGUF header tensor_count mismatch for {n}");
        }
    }
}