entrenar/hf_pipeline/export/
exporter.rs1use crate::hf_pipeline::error::{FetchError, Result};
4use serde::Serialize;
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7
8use super::format::ExportFormat;
9use super::gguf_writer::{quantize_to_gguf_bytes, GgufQuantization};
10use super::result::ExportResult;
11use super::weights::{ModelMetadata, ModelWeights};
12
13pub struct Exporter {
15 pub(super) output_dir: PathBuf,
17 pub(super) default_format: ExportFormat,
19 pub(super) include_metadata: bool,
21 pub(super) gguf_quantization: GgufQuantization,
23}
24
25impl Default for Exporter {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl Exporter {
32 #[must_use]
34 pub fn new() -> Self {
35 Self {
36 output_dir: PathBuf::from("."),
37 default_format: ExportFormat::SafeTensors,
38 include_metadata: true,
39 gguf_quantization: GgufQuantization::None,
40 }
41 }
42
43 #[must_use]
45 pub fn output_dir(mut self, dir: impl Into<PathBuf>) -> Self {
46 self.output_dir = dir.into();
47 self
48 }
49
50 #[must_use]
52 pub fn default_format(mut self, format: ExportFormat) -> Self {
53 self.default_format = format;
54 self
55 }
56
57 #[must_use]
59 pub fn include_metadata(mut self, include: bool) -> Self {
60 self.include_metadata = include;
61 self
62 }
63
64 #[must_use]
66 pub fn gguf_quantization(mut self, quant: GgufQuantization) -> Self {
67 self.gguf_quantization = quant;
68 self
69 }
70
71 pub fn export(
73 &self,
74 weights: &ModelWeights,
75 format: ExportFormat,
76 filename: impl AsRef<Path>,
77 ) -> Result<ExportResult> {
78 let path = self.output_dir.join(filename);
79
80 if let Some(parent) = path.parent() {
82 std::fs::create_dir_all(parent).map_err(|e| FetchError::ConfigParseError {
83 message: format!("Failed to create output directory: {e}"),
84 })?;
85 }
86
87 match format {
88 ExportFormat::SafeTensors => self.export_safetensors(weights, &path),
89 ExportFormat::APR => self.export_apr(weights, &path),
90 ExportFormat::GGUF => self.export_gguf(weights, &path),
91 ExportFormat::PyTorch => Err(FetchError::PickleSecurityRisk),
92 }
93 }
94
95 fn export_safetensors(&self, weights: &ModelWeights, path: &Path) -> Result<ExportResult> {
97 let mut output = Vec::new();
99
100 let header = serde_json::json!({
102 "__metadata__": {
103 "format": "safetensors",
104 "version": "0.1.0",
105 "num_tensors": weights.tensors.len(),
106 "num_params": weights.param_count(),
107 }
108 });
109 let header_bytes = serde_json::to_vec(&header).map_err(|e| {
110 FetchError::ConfigParseError { message: format!("Failed to serialize header: {e}") }
111 })?;
112
113 output.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
115 output.extend_from_slice(&header_bytes);
116
117 let data_size: usize = weights.tensors.values().map(|t| t.len() * 4).sum();
119 output.extend(vec![0u8; data_size.min(1024)]); std::fs::write(path, &output).map_err(|e| FetchError::ConfigParseError {
122 message: format!("Failed to write file: {e}"),
123 })?;
124
125 Ok(ExportResult {
126 path: path.to_path_buf(),
127 format: ExportFormat::SafeTensors,
128 size_bytes: output.len() as u64,
129 num_tensors: weights.tensors.len(),
130 })
131 }
132
133 fn export_apr(&self, weights: &ModelWeights, path: &Path) -> Result<ExportResult> {
135 #[derive(Serialize)]
136 struct AprFormat {
137 version: String,
138 metadata: ModelMetadata,
139 tensors: HashMap<String, AprTensor>,
140 }
141
142 #[derive(Serialize)]
143 struct AprTensor {
144 shape: Vec<usize>,
145 dtype: String,
146 data: Vec<f32>,
147 }
148
149 let apr = AprFormat {
150 version: "1.0".to_string(),
151 metadata: weights.metadata.clone(),
152 tensors: weights
153 .tensors
154 .iter()
155 .map(|(name, data)| {
156 let shape = weights.shapes.get(name).cloned().unwrap_or_default();
157 (
158 name.clone(),
159 AprTensor { shape, dtype: "f32".to_string(), data: data.clone() },
160 )
161 })
162 .collect(),
163 };
164
165 let json = serde_json::to_string_pretty(&apr).map_err(|e| {
166 FetchError::ConfigParseError { message: format!("Failed to serialize APR: {e}") }
167 })?;
168
169 std::fs::write(path, &json).map_err(|e| FetchError::ConfigParseError {
170 message: format!("Failed to write file: {e}"),
171 })?;
172
173 Ok(ExportResult {
174 path: path.to_path_buf(),
175 format: ExportFormat::APR,
176 size_bytes: json.len() as u64,
177 num_tensors: weights.tensors.len(),
178 })
179 }
180
181 fn export_gguf(&self, weights: &ModelWeights, path: &Path) -> Result<ExportResult> {
183 use aprender::format::gguf::{export_tensors_to_gguf, GgufTensor, GgufValue};
184
185 let mut metadata: Vec<(String, GgufValue)> = Vec::new();
187 if self.include_metadata {
188 if let Some(arch) = &weights.metadata.architecture {
189 metadata.push(("general.architecture".into(), GgufValue::String(arch.clone())));
190 }
191 if let Some(name) = &weights.metadata.model_name {
192 metadata.push(("general.name".into(), GgufValue::String(name.clone())));
193 }
194 metadata.push((
195 "general.parameter_count".into(),
196 GgufValue::Uint64(weights.metadata.num_params),
197 ));
198 if let Some(hidden) = weights.metadata.hidden_size {
199 metadata.push(("general.hidden_size".into(), GgufValue::Uint32(hidden as u32)));
200 }
201 if let Some(layers) = weights.metadata.num_layers {
202 metadata.push(("general.num_layers".into(), GgufValue::Uint32(layers as u32)));
203 }
204 }
205
206 let mut tensor_names: Vec<&String> = weights.tensors.keys().collect();
208 tensor_names.sort();
209
210 let mut tensors: Vec<GgufTensor> = Vec::new();
211 for name in &tensor_names {
212 let data = &weights.tensors[*name];
213 let shape = weights.shapes.get(*name).cloned().unwrap_or_else(|| vec![data.len()]);
214 let (bytes, dtype) = quantize_to_gguf_bytes(data, self.gguf_quantization);
215 tensors.push(GgufTensor {
216 name: (*name).clone(),
217 shape: shape.iter().map(|&d| d as u64).collect(),
218 dtype,
219 data: bytes,
220 });
221 }
222
223 let mut file = std::fs::File::create(path).map_err(|e| FetchError::GgufWriteError {
225 message: format!("Failed to create GGUF file: {e}"),
226 })?;
227 export_tensors_to_gguf(&mut file, &tensors, &metadata).map_err(|e| {
228 FetchError::GgufWriteError { message: format!("Failed to write GGUF data: {e}") }
229 })?;
230
231 let size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
232
233 Ok(ExportResult {
234 path: path.to_path_buf(),
235 format: ExportFormat::GGUF,
236 size_bytes: size,
237 num_tensors: tensor_names.len(),
238 })
239 }
240
241 pub fn export_auto(
243 &self,
244 weights: &ModelWeights,
245 filename: impl AsRef<Path>,
246 ) -> Result<ExportResult> {
247 let path = filename.as_ref();
248 let format = ExportFormat::from_path(path).unwrap_or(self.default_format);
249 self.export(weights, format, path)
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hf_pipeline::export::gguf_verify::verify_gguf;
    use crate::hf_pipeline::export::weights::ModelMetadata;

    /// Builds a single-tensor (8x8, all ones) model with name, architecture
    /// and parameter count filled in.
    fn make_test_weights() -> ModelWeights {
        let mut weights = ModelWeights::new();
        weights.add_tensor("layer.0.weight", vec![1.0; 64], vec![8, 8]);
        weights.metadata = ModelMetadata {
            model_name: Some("test-model".to_string()),
            architecture: Some("llama".to_string()),
            num_params: 64,
            ..Default::default()
        };
        weights
    }

    /// Creates a scratch directory that is removed when the guard is dropped.
    fn scratch_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("temp dir creation should succeed")
    }

    // --- construction -----------------------------------------------------

    #[test]
    fn test_falsify_exporter_default_values() {
        let exp = Exporter::new();
        assert_eq!(exp.output_dir, PathBuf::from("."));
        assert_eq!(exp.default_format, ExportFormat::SafeTensors);
        assert!(exp.include_metadata);
        assert_eq!(exp.gguf_quantization, GgufQuantization::None);
    }

    #[test]
    fn test_falsify_exporter_default_eq_new() {
        let a = Exporter::new();
        let b = Exporter::default();
        assert_eq!(a.output_dir, b.output_dir);
        assert_eq!(a.default_format, b.default_format);
        assert_eq!(a.include_metadata, b.include_metadata);
        assert_eq!(a.gguf_quantization, b.gguf_quantization);
    }

    // --- builder ----------------------------------------------------------

    /// The same settings applied in a different order must produce the same
    /// export size and tensor count.
    #[test]
    fn test_falsify_builder_order_independence() {
        let weights = make_test_weights();
        let dir = scratch_dir();

        let result1 = Exporter::new()
            .output_dir(dir.path())
            .gguf_quantization(GgufQuantization::Q4_0)
            .include_metadata(false)
            .export(&weights, ExportFormat::GGUF, "a.gguf")
            .expect("operation should succeed");

        let result2 = Exporter::new()
            .include_metadata(false)
            .gguf_quantization(GgufQuantization::Q4_0)
            .output_dir(dir.path())
            .export(&weights, ExportFormat::GGUF, "b.gguf")
            .expect("operation should succeed");

        assert_eq!(result1.size_bytes, result2.size_bytes);
        assert_eq!(result1.num_tensors, result2.num_tensors);
    }

    /// Calling a setter twice keeps the last value.
    #[test]
    fn test_falsify_builder_setter_override() {
        let weights = make_test_weights();
        let dir = scratch_dir();

        let _result = Exporter::new()
            .output_dir(dir.path())
            .gguf_quantization(GgufQuantization::Q8_0)
            .gguf_quantization(GgufQuantization::Q4_0)
            .include_metadata(false)
            .export(&weights, ExportFormat::GGUF, "override.gguf")
            .expect("operation should succeed");

        let file_data =
            std::fs::read(dir.path().join("override.gguf")).expect("file read should succeed");
        let summary = verify_gguf(&file_data).expect("operation should succeed");
        // 2 is the dtype id this test expects the file to record for Q4_0.
        assert_eq!(summary.tensors[0].dtype, 2, "override should use Q4_0");
    }

    // --- format dispatch --------------------------------------------------

    #[test]
    fn test_falsify_pytorch_format_rejected() {
        let weights = make_test_weights();
        let dir = scratch_dir();
        let exporter = Exporter::new().output_dir(dir.path());
        let result = exporter.export(&weights, ExportFormat::PyTorch, "model.pt");
        assert!(result.is_err(), "PyTorch export must be rejected");
        let err = result.unwrap_err();
        assert!(
            matches!(err, FetchError::PickleSecurityRisk),
            "error must be PickleSecurityRisk, got {err:?}"
        );
    }

    #[test]
    fn test_falsify_safetensors_export_works() {
        let weights = make_test_weights();
        let dir = scratch_dir();
        let exporter = Exporter::new().output_dir(dir.path());
        let result = exporter
            .export(&weights, ExportFormat::SafeTensors, "model.safetensors")
            .expect("SafeTensors export should succeed");
        assert_eq!(result.format, ExportFormat::SafeTensors);
        assert!(result.size_bytes > 0);
        assert!(dir.path().join("model.safetensors").exists());
    }

    #[test]
    fn test_falsify_apr_export_works() {
        let weights = make_test_weights();
        let dir = scratch_dir();
        let exporter = Exporter::new().output_dir(dir.path());
        let result = exporter
            .export(&weights, ExportFormat::APR, "model.apr.json")
            .expect("operation should succeed");
        assert_eq!(result.format, ExportFormat::APR);
        assert!(result.size_bytes > 0);
        assert!(dir.path().join("model.apr.json").exists());
    }

    /// A GGUF quantization setting must not stop a SafeTensors export.
    #[test]
    fn test_falsify_safetensors_ignores_quantization_setting() {
        let weights = make_test_weights();
        let dir = scratch_dir();
        let exporter =
            Exporter::new().output_dir(dir.path()).gguf_quantization(GgufQuantization::Q4_0);
        let result = exporter
            .export(&weights, ExportFormat::SafeTensors, "model.safetensors")
            .expect("SafeTensors export should succeed");
        assert_eq!(result.format, ExportFormat::SafeTensors);
        assert!(result.size_bytes > 0);
    }

    // --- export_auto ------------------------------------------------------

    /// Runs `export_auto` for `filename` with `fallback` configured as the
    /// default format and returns the format that was actually used.
    fn auto_export_format(fallback: ExportFormat, filename: &str) -> ExportFormat {
        let weights = make_test_weights();
        let dir = scratch_dir();
        let exporter = Exporter::new().output_dir(dir.path()).default_format(fallback);
        exporter.export_auto(&weights, filename).expect("operation should succeed").format
    }

    #[test]
    fn test_falsify_export_auto_detects_gguf() {
        assert_eq!(auto_export_format(ExportFormat::APR, "model.gguf"), ExportFormat::GGUF);
    }

    #[test]
    fn test_falsify_export_auto_detects_safetensors() {
        assert_eq!(
            auto_export_format(ExportFormat::GGUF, "model.safetensors"),
            ExportFormat::SafeTensors
        );
    }

    #[test]
    fn test_falsify_export_auto_detects_apr() {
        assert_eq!(auto_export_format(ExportFormat::GGUF, "model.apr.json"), ExportFormat::APR);
    }

    #[test]
    fn test_falsify_export_auto_unknown_extension_uses_default() {
        assert_eq!(auto_export_format(ExportFormat::GGUF, "model.unknown"), ExportFormat::GGUF);
    }

    // --- result bookkeeping -----------------------------------------------

    /// The reported tensor count and the count in the GGUF header must both
    /// equal the number of input tensors, including zero.
    #[test]
    fn test_falsify_num_tensors_matches_input() {
        for n in [0, 1, 3, 10] {
            let mut weights = ModelWeights::new();
            for i in 0..n {
                weights.add_tensor(format!("t.{i}"), vec![1.0], vec![1]);
            }

            let dir = scratch_dir();
            let exporter = Exporter::new().output_dir(dir.path()).include_metadata(false);
            let result = exporter
                .export(&weights, ExportFormat::GGUF, "count.gguf")
                .expect("operation should succeed");
            assert_eq!(result.num_tensors, n, "num_tensors mismatch for {n} input tensors");

            let file_data =
                std::fs::read(dir.path().join("count.gguf")).expect("file read should succeed");
            let summary = verify_gguf(&file_data).expect("operation should succeed");
            assert_eq!(summary.tensor_count, n as u64, "GGUF header tensor_count mismatch for {n}");
        }
    }
}