Skip to main content

entrenar/hf_pipeline/export/
format.rs

//! Export format selection and detection.
2
3use serde::{Deserialize, Serialize};
4use std::path::Path;
5
6/// Export format selection
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
8pub enum ExportFormat {
9    /// SafeTensors format (recommended)
10    SafeTensors,
11    /// Aprender format (JSON-based)
12    APR,
13    /// GGUF quantized format
14    GGUF,
15    /// PyTorch state dict (for compatibility)
16    PyTorch,
17}
18
19impl ExportFormat {
20    /// Get file extension for format
21    #[must_use]
22    pub fn extension(&self) -> &'static str {
23        match self {
24            Self::SafeTensors => "safetensors",
25            Self::APR => "apr.json",
26            Self::GGUF => "gguf",
27            Self::PyTorch => "pt",
28        }
29    }
30
31    /// Check if format is safe (no pickle/arbitrary code)
32    #[must_use]
33    pub fn is_safe(&self) -> bool {
34        matches!(self, Self::SafeTensors | Self::APR | Self::GGUF)
35    }
36
37    /// Detect format from file path
38    #[must_use]
39    pub fn from_path(path: &Path) -> Option<Self> {
40        let name = path.file_name()?.to_str()?;
41        if name.ends_with(".safetensors") {
42            Some(Self::SafeTensors)
43        } else if name.ends_with(".apr.json") || name.ends_with(".apr") {
44            Some(Self::APR)
45        } else if name.ends_with(".gguf") {
46            Some(Self::GGUF)
47        } else if name.ends_with(".pt") || name.ends_with(".bin") {
48            Some(Self::PyTorch)
49        } else {
50            None
51        }
52    }
53}
54
55impl std::fmt::Display for ExportFormat {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        match self {
58            Self::SafeTensors => write!(f, "SafeTensors"),
59            Self::APR => write!(f, "APR"),
60            Self::GGUF => write!(f, "GGUF"),
61            Self::PyTorch => write!(f, "PyTorch"),
62        }
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs format detection on a path given as a plain string.
    fn detect(path: &str) -> Option<ExportFormat> {
        ExportFormat::from_path(Path::new(path))
    }

    // -----------------------------------------------------------------
    // TIER 4: from_path() — one test per recognized suffix, then the
    // negative cases and a nested path.
    // -----------------------------------------------------------------

    #[test]
    fn test_falsify_from_path_gguf() {
        assert_eq!(detect("model.gguf"), Some(ExportFormat::GGUF));
    }

    #[test]
    fn test_falsify_from_path_safetensors() {
        assert_eq!(detect("model.safetensors"), Some(ExportFormat::SafeTensors));
    }

    #[test]
    fn test_falsify_from_path_apr_json() {
        assert_eq!(detect("model.apr.json"), Some(ExportFormat::APR));
    }

    #[test]
    fn test_falsify_from_path_apr() {
        assert_eq!(detect("model.apr"), Some(ExportFormat::APR));
    }

    #[test]
    fn test_falsify_from_path_pytorch_pt() {
        assert_eq!(detect("model.pt"), Some(ExportFormat::PyTorch));
    }

    #[test]
    fn test_falsify_from_path_pytorch_bin() {
        assert_eq!(detect("model.bin"), Some(ExportFormat::PyTorch));
    }

    #[test]
    fn test_falsify_from_path_unknown() {
        assert_eq!(detect("model.xyz"), None);
    }

    #[test]
    fn test_falsify_from_path_no_extension() {
        assert_eq!(detect("model"), None);
    }

    #[test]
    fn test_falsify_from_path_nested_path() {
        // Only the final path component is inspected.
        assert_eq!(detect("/deep/nested/dir/model.gguf"), Some(ExportFormat::GGUF));
    }

    // -----------------------------------------------------------------
    // extension(), is_safe() and Display
    // -----------------------------------------------------------------

    #[test]
    fn test_falsify_extension_roundtrip() {
        // Whatever extension() returns must be detected as the format that produced it.
        let every_format = [
            ExportFormat::SafeTensors,
            ExportFormat::APR,
            ExportFormat::GGUF,
            ExportFormat::PyTorch,
        ];
        for fmt in every_format {
            let detected = detect(&format!("model.{}", fmt.extension()));
            assert_eq!(
                detected,
                Some(fmt),
                "extension '{}' should roundtrip for {fmt:?}",
                fmt.extension()
            );
        }
    }

    #[test]
    fn test_falsify_is_safe() {
        // PyTorch is the only format expected to be reported as unsafe.
        assert!(ExportFormat::SafeTensors.is_safe());
        assert!(ExportFormat::APR.is_safe());
        assert!(ExportFormat::GGUF.is_safe());
        assert!(!ExportFormat::PyTorch.is_safe());
    }

    #[test]
    fn test_falsify_display() {
        let expected_names = [
            (ExportFormat::SafeTensors, "SafeTensors"),
            (ExportFormat::APR, "APR"),
            (ExportFormat::GGUF, "GGUF"),
            (ExportFormat::PyTorch, "PyTorch"),
        ];
        for (format, name) in expected_names {
            assert_eq!(format.to_string(), name);
        }
    }
}