1use std::path::Path;
4
5use crate::cli::logging::log;
6use crate::cli::LogLevel;
7use crate::config::PublishArgs;
8
9pub fn run_publish(args: PublishArgs, level: LogLevel) -> Result<(), String> {
10 log(level, LogLevel::Normal, &format!("Publishing to {}", args.repo));
11
12 if !args.model_dir.exists() {
14 return Err(format!("Model directory not found: {}", args.model_dir.display()));
15 }
16
17 let files = collect_model_files(&args.model_dir).map_err(|e| format!("File scan: {e}"))?;
19 if files.is_empty() {
20 return Err(format!("No model files found in {}", args.model_dir.display()));
21 }
22
23 log(level, LogLevel::Normal, &format!(" Found {} file(s) to upload", files.len()));
24 for (path, remote) in &files {
25 log(level, LogLevel::Verbose, &format!(" {} -> {}", path.display(), remote));
26 }
27
28 if args.dry_run {
29 log(level, LogLevel::Normal, "Dry run — skipping upload");
30 return Ok(());
31 }
32
33 do_publish(&args, &files, level)
34}
35
/// Uploads `files` to the repository named by `args.repo` through
/// `HfPublisher` (compiled only with the `hub-publish` feature).
///
/// A model card is built and passed along when `args.model_card` is set.
///
/// # Errors
///
/// Returns a message string when the publisher cannot be created or when
/// the publish call fails.
#[cfg(feature = "hub-publish")]
fn do_publish(
    args: &PublishArgs,
    files: &[(std::path::PathBuf, String)],
    level: LogLevel,
) -> Result<(), String> {
    use crate::hf_pipeline::publish::{config::PublishConfig, publisher::HfPublisher};

    let config = PublishConfig {
        repo_id: args.repo.clone(),
        private: args.private,
        ..Default::default()
    };

    // The card is only generated when the caller asked for one.
    let card = args.model_card.then(|| build_model_card(args));

    let publisher = match HfPublisher::new(config) {
        Ok(publisher) => publisher,
        Err(e) => return Err(format!("Publisher initialization: {e}")),
    };

    // The publisher takes borrowed (local path, remote name) pairs.
    let mut pairs: Vec<(&Path, &str)> = Vec::with_capacity(files.len());
    for (local, remote) in files {
        pairs.push((local.as_path(), remote.as_str()));
    }

    let outcome =
        publisher.publish(&pairs, card.as_ref()).map_err(|e| format!("Upload failed: {e}"))?;

    log(level, LogLevel::Normal, &format!("Published: {outcome}"));
    Ok(())
}
64
65#[cfg(not(feature = "hub-publish"))]
66fn do_publish(
67 _args: &PublishArgs,
68 _files: &[(std::path::PathBuf, String)],
69 _level: LogLevel,
70) -> Result<(), String> {
71 Err("Publishing requires the 'hub-publish' feature. Rebuild with: cargo install entrenar --features hub-publish".to_string())
72}
73
/// Lists the uploadable files found directly inside `dir` (subdirectories
/// are not descended into).
///
/// An entry is kept when `Path::is_file` is true for it, its file name is
/// valid UTF-8, the name does not start with `.`, and the name ends with one
/// of the accepted extensions (compared case-sensitively). Each result pairs
/// the entry's path with its file name, and the list is sorted by file name.
///
/// # Errors
///
/// Propagates any I/O error raised while reading the directory or one of
/// its entries.
fn collect_model_files(dir: &Path) -> Result<Vec<(std::path::PathBuf, String)>, std::io::Error> {
    const EXTENSIONS: [&str; 7] = ["safetensors", "gguf", "bin", "json", "yaml", "yml", "txt"];

    let mut found = Vec::new();

    for entry in std::fs::read_dir(dir)? {
        let path = entry?.path();
        if !path.is_file() {
            continue;
        }

        // Skip names that are not valid UTF-8 as well as hidden (dot-prefixed) files.
        let name = match path.file_name().and_then(|n| n.to_str()) {
            Some(n) if !n.starts_with('.') => n.to_owned(),
            _ => continue,
        };

        // The text after the last '.' must be one of the accepted extensions.
        let wanted =
            matches!(name.rsplit_once('.'), Some((_, ext)) if EXTENSIONS.contains(&ext));
        if wanted {
            found.push((path, name));
        }
    }

    found.sort_by(|(_, left), (_, right)| left.cmp(right));

    Ok(found)
}
109
/// Assembles the `ModelCard` that accompanies an upload (compiled only with
/// the `hub-publish` feature).
///
/// The model name is the part of `args.repo` after its last `/` (or the
/// whole value when there is no `/`). Training details are taken from
/// `final_model.json` inside `args.model_dir` when that file yields any;
/// otherwise the field is `None`.
#[cfg(feature = "hub-publish")]
fn build_model_card(args: &PublishArgs) -> crate::hf_pipeline::publish::model_card::ModelCard {
    use crate::hf_pipeline::publish::model_card::ModelCard;

    let repo = args.repo.as_str();
    let model_name = match repo.rsplit_once('/') {
        Some((_, tail)) => tail.to_string(),
        None => repo.to_string(),
    };

    let training_details = read_training_metadata(&args.model_dir.join("final_model.json"));

    let tags = ["entrenar", "fine-tuned", "rust"].iter().map(|tag| (*tag).to_string()).collect();

    ModelCard {
        model_name,
        description: format!("Fine-tuned model published via entrenar from {}", args.repo),
        license: Some("apache-2.0".to_string()),
        language: Vec::new(),
        tags,
        metrics: Vec::new(),
        training_details,
        base_model: args.base_model.clone(),
    }
}
131
/// Reads the JSON file at `path` and renders selected training fields as a
/// Markdown bullet list.
///
/// The fields looked up are `epochs_completed` (unsigned integer),
/// `final_loss` (number, printed with six decimals) and `training_mode`
/// (string); each one that is present with the expected type contributes
/// one line, in that order.
///
/// Returns `None` when the file cannot be read, is not valid JSON, or
/// contains none of those fields.
#[cfg(any(feature = "hub-publish", test))]
fn read_training_metadata(path: &Path) -> Option<String> {
    use serde_json::Value;

    let raw = std::fs::read_to_string(path).ok()?;
    let doc: Value = serde_json::from_str(&raw).ok()?;

    // Each fragment is empty when its field is absent or has the wrong type.
    let epochs_line = doc
        .get("epochs_completed")
        .and_then(Value::as_u64)
        .map(|epochs| format!("- **Epochs:** {epochs}\n"))
        .unwrap_or_default();
    let loss_line = doc
        .get("final_loss")
        .and_then(Value::as_f64)
        .map(|loss| format!("- **Final loss:** {loss:.6}\n"))
        .unwrap_or_default();
    let mode_line = doc
        .get("training_mode")
        .and_then(Value::as_str)
        .map(|mode| format!("- **Training mode:** {mode}\n"))
        .unwrap_or_default();

    let details = [epochs_line, loss_line, mode_line].concat();
    Some(details).filter(|text| !text.is_empty())
}
155
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Creates a fresh temporary directory that is removed when dropped.
    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("temp dir creation should succeed")
    }

    /// Writes `contents` to `name` inside `dir`.
    fn write_file(dir: &Path, name: &str, contents: &[u8]) {
        std::fs::write(dir.join(name), contents).expect("file write should succeed");
    }

    /// Serializes `metadata` into `final_model.json` inside `dir` and returns its path.
    fn write_metadata(dir: &Path, metadata: &serde_json::Value) -> PathBuf {
        let path = dir.join("final_model.json");
        let body = serde_json::to_string(metadata).expect("JSON serialization should succeed");
        std::fs::write(&path, body).expect("file write should succeed");
        path
    }

    /// Builds the `PublishArgs` shared by the `run_publish` tests; only the
    /// model directory and the dry-run flag differ between them.
    fn publish_args(model_dir: PathBuf, dry_run: bool) -> PublishArgs {
        PublishArgs {
            model_dir,
            repo: "user/model".to_string(),
            private: false,
            model_card: true,
            merge_adapters: false,
            base_model: None,
            format: "safetensors".to_string(),
            dry_run,
        }
    }

    #[test]
    fn test_collect_model_files_empty_dir() {
        let dir = temp_dir();
        let files = collect_model_files(dir.path()).expect("directory scan should succeed");
        assert!(files.is_empty());
    }

    #[test]
    fn test_collect_model_files_filters_extensions() {
        let dir = temp_dir();
        write_file(dir.path(), "model.safetensors", b"data");
        write_file(dir.path(), "config.json", b"{}");
        write_file(dir.path(), "random.xyz", b"skip");
        write_file(dir.path(), ".hidden", b"skip");

        let files = collect_model_files(dir.path()).expect("directory scan should succeed");
        assert_eq!(files.len(), 2);
        let names: Vec<&str> = files.iter().map(|(_, n)| n.as_str()).collect();
        assert!(names.contains(&"model.safetensors"));
        assert!(names.contains(&"config.json"));
    }

    #[test]
    fn test_collect_model_files_sorted() {
        let dir = temp_dir();
        write_file(dir.path(), "z_weights.safetensors", b"w");
        write_file(dir.path(), "a_config.json", b"c");

        let files = collect_model_files(dir.path()).expect("directory scan should succeed");
        assert_eq!(files[0].1, "a_config.json");
        assert_eq!(files[1].1, "z_weights.safetensors");
    }

    #[test]
    fn test_run_publish_missing_dir() {
        // A path inside a fresh temp dir is guaranteed not to exist, unlike a
        // hard-coded location under /tmp.
        let dir = temp_dir();
        let args = publish_args(dir.path().join("definitely-nonexistent-dir"), false);
        let result = run_publish(args, LogLevel::Quiet);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("not found"));
    }

    #[test]
    fn test_run_publish_empty_dir() {
        let dir = temp_dir();
        let args = publish_args(dir.path().to_path_buf(), false);
        let result = run_publish(args, LogLevel::Quiet);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("No model files"));
    }

    #[test]
    fn test_run_publish_dry_run() {
        let dir = temp_dir();
        write_file(dir.path(), "model.safetensors", b"data");

        let args = publish_args(dir.path().to_path_buf(), true);
        let result = run_publish(args, LogLevel::Quiet);
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_publish_no_hub_feature() {
        let dir = temp_dir();
        write_file(dir.path(), "model.safetensors", b"data");

        let args = publish_args(dir.path().to_path_buf(), false);
        let result = run_publish(args, LogLevel::Quiet);
        #[cfg(not(feature = "hub-publish"))]
        assert!(result.unwrap_err().contains("hub-publish"));
        // With the feature enabled the outcome is not asserted here; the test
        // only checks that the call returns.
        #[cfg(feature = "hub-publish")]
        let _ = result;
    }

    #[test]
    fn test_read_training_metadata_missing() {
        // Same reasoning as the missing-dir test: use a path that cannot exist.
        let dir = temp_dir();
        let result = read_training_metadata(&dir.path().join("nonexistent.json"));
        assert!(result.is_none());
    }

    #[test]
    fn test_read_training_metadata_invalid_json() {
        let dir = temp_dir();
        let path = dir.path().join("final_model.json");
        std::fs::write(&path, "not json").expect("file write should succeed");
        let result = read_training_metadata(&path);
        assert!(result.is_none());
    }

    #[test]
    fn test_read_training_metadata_valid() {
        let dir = temp_dir();
        let metadata = serde_json::json!({
            "epochs_completed": 3,
            "final_loss": 1.5432,
            "training_mode": "LoRA"
        });
        let path = write_metadata(dir.path(), &metadata);

        let details = read_training_metadata(&path).expect("metadata should yield details");
        assert!(details.contains("Epochs"));
        assert!(details.contains("1.5432"));
        assert!(details.contains("LoRA"));
    }

    #[test]
    fn test_read_training_metadata_partial() {
        let dir = temp_dir();
        let metadata = serde_json::json!({
            "epochs_completed": 5
        });
        let path = write_metadata(dir.path(), &metadata);

        let details = read_training_metadata(&path).expect("metadata should yield details");
        assert!(details.contains("Epochs"));
        assert!(details.contains('5'));
    }

    #[test]
    fn test_read_training_metadata_empty_json() {
        let dir = temp_dir();
        let path = dir.path().join("final_model.json");
        std::fs::write(&path, "{}").expect("file write should succeed");

        let result = read_training_metadata(&path);
        assert!(result.is_none());
    }

    #[test]
    fn test_collect_model_files_all_extensions() {
        let dir = temp_dir();
        for ext in &["safetensors", "gguf", "bin", "json", "yaml", "yml", "txt"] {
            write_file(dir.path(), &format!("file.{ext}"), b"data");
        }
        let files = collect_model_files(dir.path()).expect("directory scan should succeed");
        assert_eq!(files.len(), 7);
    }

    #[test]
    fn test_collect_model_files_skips_directories() {
        let dir = temp_dir();
        write_file(dir.path(), "model.safetensors", b"data");
        std::fs::create_dir(dir.path().join("subdir"))
            .expect("directory creation should succeed");

        let files = collect_model_files(dir.path()).expect("directory scan should succeed");
        assert_eq!(files.len(), 1);
    }
}