Skip to main content

entrenar/cli/commands/
publish.rs

1//! Publish command implementation — upload trained models to HuggingFace Hub
2
3use std::path::Path;
4
5use crate::cli::logging::log;
6use crate::cli::LogLevel;
7use crate::config::PublishArgs;
8
9pub fn run_publish(args: PublishArgs, level: LogLevel) -> Result<(), String> {
10    log(level, LogLevel::Normal, &format!("Publishing to {}", args.repo));
11
12    // Validate model directory
13    if !args.model_dir.exists() {
14        return Err(format!("Model directory not found: {}", args.model_dir.display()));
15    }
16
17    // Find model files to upload
18    let files = collect_model_files(&args.model_dir).map_err(|e| format!("File scan: {e}"))?;
19    if files.is_empty() {
20        return Err(format!("No model files found in {}", args.model_dir.display()));
21    }
22
23    log(level, LogLevel::Normal, &format!("  Found {} file(s) to upload", files.len()));
24    for (path, remote) in &files {
25        log(level, LogLevel::Verbose, &format!("    {} -> {}", path.display(), remote));
26    }
27
28    if args.dry_run {
29        log(level, LogLevel::Normal, "Dry run — skipping upload");
30        return Ok(());
31    }
32
33    do_publish(&args, &files, level)
34}
35
/// Upload the collected files, plus an optional generated model card, through
/// `HfPublisher`.
///
/// Only compiled when the `hub-publish` feature is enabled.
///
/// # Errors
///
/// Returns a message when the publisher cannot be initialized or when the
/// publish call fails.
#[cfg(feature = "hub-publish")]
fn do_publish(
    args: &PublishArgs,
    files: &[(std::path::PathBuf, String)],
    level: LogLevel,
) -> Result<(), String> {
    use crate::hf_pipeline::publish::config::PublishConfig;
    use crate::hf_pipeline::publish::publisher::HfPublisher;

    let config = PublishConfig {
        repo_id: args.repo.clone(),
        private: args.private,
        ..Default::default()
    };

    // The card is only generated when the caller asked for one.
    let model_card = args.model_card.then(|| build_model_card(args));

    let publisher = match HfPublisher::new(config) {
        Ok(ready) => ready,
        Err(e) => return Err(format!("Publisher initialization: {e}")),
    };

    // The publisher takes borrowed (path, remote name) pairs.
    let mut uploads: Vec<(&Path, &str)> = Vec::with_capacity(files.len());
    for (local, remote) in files {
        uploads.push((local.as_path(), remote.as_str()));
    }

    match publisher.publish(&uploads, model_card.as_ref()) {
        Ok(outcome) => {
            log(level, LogLevel::Normal, &format!("Published: {outcome}"));
            Ok(())
        }
        Err(e) => Err(format!("Upload failed: {e}")),
    }
}
64
65#[cfg(not(feature = "hub-publish"))]
66fn do_publish(
67    _args: &PublishArgs,
68    _files: &[(std::path::PathBuf, String)],
69    _level: LogLevel,
70) -> Result<(), String> {
71    Err("Publishing requires the 'hub-publish' feature. Rebuild with: cargo install entrenar --features hub-publish".to_string())
72}
73
/// Gather the uploadable files that sit directly inside `dir`.
///
/// Each result pairs the file's full path with its bare file name. Only
/// regular files are considered; sub-directories are not descended into.
/// Names starting with `.` and names that are not valid UTF-8 are skipped,
/// as is anything whose name does not end in one of the known extensions.
/// The result is sorted by file name so the upload order is deterministic.
///
/// # Errors
///
/// Propagates any I/O error raised while reading the directory or one of
/// its entries.
fn collect_model_files(dir: &Path) -> Result<Vec<(std::path::PathBuf, String)>, std::io::Error> {
    const UPLOADABLE: [&str; 7] = ["safetensors", "gguf", "bin", "json", "yaml", "yml", "txt"];

    let mut found = Vec::new();

    for dir_entry in std::fs::read_dir(dir)? {
        let candidate = dir_entry?.path();
        if !candidate.is_file() {
            continue;
        }

        // Keep only UTF-8, non-hidden names.
        let file_name = match candidate.file_name().and_then(|os| os.to_str()) {
            Some(text) if !text.starts_with('.') => text.to_owned(),
            _ => continue,
        };

        // A name qualifies when it ends in ".<ext>" for a known extension.
        let wanted = UPLOADABLE.iter().any(|&ext| {
            file_name.strip_suffix(ext).map_or(false, |stem| stem.ends_with('.'))
        });
        if wanted {
            found.push((candidate, file_name));
        }
    }

    // Order by file name for a deterministic upload sequence.
    found.sort_by(|lhs, rhs| lhs.1.cmp(&rhs.1));

    Ok(found)
}
109
/// Assemble the model card that accompanies an upload.
///
/// The card's name is the part of `args.repo` after its last `/`, or the whole
/// string when it contains none. Training details come from
/// `final_model.json` inside the model directory when that file yields any;
/// otherwise the field is left empty.
#[cfg(feature = "hub-publish")]
fn build_model_card(args: &PublishArgs) -> crate::hf_pipeline::publish::model_card::ModelCard {
    use crate::hf_pipeline::publish::model_card::ModelCard;

    // Everything after the final '/' names the model itself.
    let repo = args.repo.as_str();
    let short_name = match repo.rfind('/') {
        Some(slash) => &repo[slash + 1..],
        None => repo,
    };

    let training_details = read_training_metadata(&args.model_dir.join("final_model.json"));

    ModelCard {
        model_name: short_name.to_string(),
        description: format!("Fine-tuned model published via entrenar from {}", args.repo),
        license: Some("apache-2.0".to_string()),
        language: Vec::new(),
        tags: ["entrenar", "fine-tuned", "rust"].iter().map(|tag| tag.to_string()).collect(),
        metrics: Vec::new(),
        training_details,
        base_model: args.base_model.clone(),
    }
}
131
/// Summarize training metadata stored as JSON at `path`.
///
/// Produces one Markdown bullet per recognized key that is present with the
/// expected type: `epochs_completed` (unsigned integer), `final_loss` (number,
/// printed with six decimals) and `training_mode` (string).
///
/// Returns `None` when the file cannot be read, is not valid JSON, or holds
/// none of the recognized keys.
#[cfg(any(feature = "hub-publish", test))]
fn read_training_metadata(path: &Path) -> Option<String> {
    let raw = std::fs::read_to_string(path).ok()?;
    let parsed: serde_json::Value = serde_json::from_str(&raw).ok()?;

    let epochs_line = parsed
        .get("epochs_completed")
        .and_then(serde_json::Value::as_u64)
        .map(|count| format!("- **Epochs:** {count}\n"));
    let loss_line = parsed
        .get("final_loss")
        .and_then(serde_json::Value::as_f64)
        .map(|value| format!("- **Final loss:** {value:.6}\n"));
    let mode_line = parsed
        .get("training_mode")
        .and_then(serde_json::Value::as_str)
        .map(|label| format!("- **Training mode:** {label}\n"));

    // Concatenate whichever bullets exist, keeping the fixed order above.
    let summary: String = epochs_line.into_iter().chain(loss_line).chain(mode_line).collect();

    Some(summary).filter(|text| !text.is_empty())
}
155
/// Unit tests for the publish command: file discovery, argument validation in
/// `run_publish`, and parsing of `final_model.json`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Build `PublishArgs` with the settings shared by every `run_publish` test;
    /// only the model directory and the dry-run flag vary between tests.
    fn publish_args(model_dir: PathBuf, dry_run: bool) -> PublishArgs {
        PublishArgs {
            model_dir,
            repo: "user/model".to_string(),
            private: false,
            model_card: true,
            merge_adapters: false,
            base_model: None,
            format: "safetensors".to_string(),
            dry_run,
        }
    }

    #[test]
    fn test_collect_model_files_empty_dir() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let files = collect_model_files(dir.path()).expect("directory scan should succeed");
        assert!(files.is_empty());
    }

    #[test]
    fn test_collect_model_files_filters_extensions() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        std::fs::write(dir.path().join("model.safetensors"), b"data")
            .expect("file write should succeed");
        std::fs::write(dir.path().join("config.json"), b"{}").expect("file write should succeed");
        std::fs::write(dir.path().join("random.xyz"), b"skip").expect("file write should succeed");
        std::fs::write(dir.path().join(".hidden"), b"skip").expect("file write should succeed");

        let files = collect_model_files(dir.path()).expect("directory scan should succeed");
        assert_eq!(files.len(), 2);
        let names: Vec<&str> = files.iter().map(|(_, n)| n.as_str()).collect();
        assert!(names.contains(&"model.safetensors"));
        assert!(names.contains(&"config.json"));
    }

    #[test]
    fn test_collect_model_files_sorted() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        std::fs::write(dir.path().join("z_weights.safetensors"), b"w")
            .expect("file write should succeed");
        std::fs::write(dir.path().join("a_config.json"), b"c").expect("file write should succeed");

        let files = collect_model_files(dir.path()).expect("directory scan should succeed");
        assert_eq!(files[0].1, "a_config.json");
        assert_eq!(files[1].1, "z_weights.safetensors");
    }

    #[test]
    fn test_run_publish_missing_dir() {
        // A child of a fresh temp dir is guaranteed not to exist on any platform.
        let parent = tempfile::tempdir().expect("temp dir creation should succeed");
        let args = publish_args(parent.path().join("definitely-nonexistent-dir"), false);

        let result = run_publish(args, LogLevel::Quiet);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("not found"));
    }

    #[test]
    fn test_run_publish_empty_dir() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let args = publish_args(dir.path().to_path_buf(), false);

        let result = run_publish(args, LogLevel::Quiet);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("No model files"));
    }

    #[test]
    fn test_run_publish_dry_run() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        std::fs::write(dir.path().join("model.safetensors"), b"data")
            .expect("file write should succeed");

        let args = publish_args(dir.path().to_path_buf(), true);
        let result = run_publish(args, LogLevel::Quiet);
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_publish_no_hub_feature() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        std::fs::write(dir.path().join("model.safetensors"), b"data")
            .expect("file write should succeed");

        let args = publish_args(dir.path().to_path_buf(), false);
        // Without hub-publish feature, this returns an error
        // With hub-publish feature, this would attempt actual upload
        let result = run_publish(args, LogLevel::Quiet);
        #[cfg(not(feature = "hub-publish"))]
        assert!(result.unwrap_err().contains("hub-publish"));
        #[cfg(feature = "hub-publish")]
        let _ = result; // May succeed or fail depending on HF_TOKEN
    }

    #[test]
    fn test_read_training_metadata_missing() {
        // Use a path inside a fresh temp dir so the file cannot already exist.
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let result = read_training_metadata(&dir.path().join("nonexistent.json"));
        assert!(result.is_none());
    }

    #[test]
    fn test_read_training_metadata_invalid_json() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let path = dir.path().join("final_model.json");
        std::fs::write(&path, "not json").expect("file write should succeed");
        let result = read_training_metadata(&path);
        assert!(result.is_none());
    }

    #[test]
    fn test_read_training_metadata_valid() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let metadata = serde_json::json!({
            "epochs_completed": 3,
            "final_loss": 1.5432,
            "training_mode": "LoRA"
        });
        let path = dir.path().join("final_model.json");
        let serialized =
            serde_json::to_string(&metadata).expect("metadata serialization should succeed");
        std::fs::write(&path, serialized).expect("file write should succeed");

        let details = read_training_metadata(&path).expect("metadata should be parsed");
        assert!(details.contains("Epochs"));
        assert!(details.contains("1.5432"));
        assert!(details.contains("LoRA"));
    }

    #[test]
    fn test_read_training_metadata_partial() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let metadata = serde_json::json!({
            "epochs_completed": 5
        });
        let path = dir.path().join("final_model.json");
        let serialized =
            serde_json::to_string(&metadata).expect("metadata serialization should succeed");
        std::fs::write(&path, serialized).expect("file write should succeed");

        let details = read_training_metadata(&path).expect("metadata should be parsed");
        assert!(details.contains("Epochs"));
        assert!(details.contains('5'));
    }

    #[test]
    fn test_read_training_metadata_empty_json() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let path = dir.path().join("final_model.json");
        std::fs::write(&path, "{}").expect("file write should succeed");

        let result = read_training_metadata(&path);
        assert!(result.is_none());
    }

    #[test]
    fn test_collect_model_files_all_extensions() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        for ext in &["safetensors", "gguf", "bin", "json", "yaml", "yml", "txt"] {
            std::fs::write(dir.path().join(format!("file.{ext}")), b"data")
                .expect("file write should succeed");
        }
        let files = collect_model_files(dir.path()).expect("directory scan should succeed");
        assert_eq!(files.len(), 7);
    }

    #[test]
    fn test_collect_model_files_skips_directories() {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        std::fs::write(dir.path().join("model.safetensors"), b"data")
            .expect("file write should succeed");
        std::fs::create_dir(dir.path().join("subdir")).expect("directory creation should succeed");

        let files = collect_model_files(dir.path()).expect("directory scan should succeed");
        assert_eq!(files.len(), 1);
    }
}
351
352        let files = collect_model_files(dir.path()).expect("operation should succeed");
353        assert_eq!(files.len(), 1);
354    }
355}