entrenar/hf_pipeline/export/
format.rs1use serde::{Deserialize, Serialize};
4use std::path::Path;
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
8pub enum ExportFormat {
9 SafeTensors,
11 APR,
13 GGUF,
15 PyTorch,
17}
18
19impl ExportFormat {
20 #[must_use]
22 pub fn extension(&self) -> &'static str {
23 match self {
24 Self::SafeTensors => "safetensors",
25 Self::APR => "apr.json",
26 Self::GGUF => "gguf",
27 Self::PyTorch => "pt",
28 }
29 }
30
31 #[must_use]
33 pub fn is_safe(&self) -> bool {
34 matches!(self, Self::SafeTensors | Self::APR | Self::GGUF)
35 }
36
37 #[must_use]
39 pub fn from_path(path: &Path) -> Option<Self> {
40 let name = path.file_name()?.to_str()?;
41 if name.ends_with(".safetensors") {
42 Some(Self::SafeTensors)
43 } else if name.ends_with(".apr.json") || name.ends_with(".apr") {
44 Some(Self::APR)
45 } else if name.ends_with(".gguf") {
46 Some(Self::GGUF)
47 } else if name.ends_with(".pt") || name.ends_with(".bin") {
48 Some(Self::PyTorch)
49 } else {
50 None
51 }
52 }
53}
54
55impl std::fmt::Display for ExportFormat {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 Self::SafeTensors => write!(f, "SafeTensors"),
59 Self::APR => write!(f, "APR"),
60 Self::GGUF => write!(f, "GGUF"),
61 Self::PyTorch => write!(f, "PyTorch"),
62 }
63 }
64}
65
#[cfg(test)]
mod tests {
    //! Falsification tests: each case tries to break one claim about how `ExportFormat`
    //! maps file names, extensions, safety flags and display names.

    use super::*;
    use super::ExportFormat::{PyTorch, SafeTensors, APR, GGUF};

    /// Runs format detection on a path given as a string slice.
    fn detect(path: &str) -> Option<ExportFormat> {
        ExportFormat::from_path(Path::new(path))
    }

    #[test]
    fn test_falsify_from_path_gguf() {
        assert_eq!(detect("model.gguf"), Some(GGUF));
    }

    #[test]
    fn test_falsify_from_path_safetensors() {
        assert_eq!(detect("model.safetensors"), Some(SafeTensors));
    }

    #[test]
    fn test_falsify_from_path_apr_json() {
        assert_eq!(detect("model.apr.json"), Some(APR));
    }

    #[test]
    fn test_falsify_from_path_apr() {
        assert_eq!(detect("model.apr"), Some(APR));
    }

    #[test]
    fn test_falsify_from_path_pytorch_pt() {
        assert_eq!(detect("model.pt"), Some(PyTorch));
    }

    #[test]
    fn test_falsify_from_path_pytorch_bin() {
        assert_eq!(detect("model.bin"), Some(PyTorch));
    }

    #[test]
    fn test_falsify_from_path_unknown() {
        assert_eq!(detect("model.xyz"), None);
    }

    #[test]
    fn test_falsify_from_path_no_extension() {
        assert_eq!(detect("model"), None);
    }

    #[test]
    fn test_falsify_from_path_nested_path() {
        assert_eq!(detect("/deep/nested/dir/model.gguf"), Some(GGUF));
    }

    /// Every extension produced by `extension()` must be detected as the same format.
    #[test]
    fn test_falsify_extension_roundtrip() {
        let every_format = [SafeTensors, APR, GGUF, PyTorch];
        for fmt in every_format {
            let ext = fmt.extension();
            let detected = detect(&format!("model.{ext}"));
            assert_eq!(
                detected,
                Some(fmt),
                "extension '{}' should roundtrip for {fmt:?}",
                ext
            );
        }
    }

    #[test]
    fn test_falsify_is_safe() {
        for safe in [SafeTensors, APR, GGUF] {
            assert!(safe.is_safe());
        }
        assert!(!PyTorch.is_safe());
    }

    #[test]
    fn test_falsify_display() {
        let expected = [
            (SafeTensors, "SafeTensors"),
            (APR, "APR"),
            (GGUF, "GGUF"),
            (PyTorch, "PyTorch"),
        ];
        for (fmt, name) in expected {
            assert_eq!(fmt.to_string(), name);
        }
    }
}