use crate::huggingface::client::HfClient;
use anyhow::{bail, Context, Result};
use std::path::{Path, PathBuf};
/// Parameters controlling where and how [`upload_model`] uploads files.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UploadOptions {
    /// Target repository identifier. It is used verbatim in each upload call and in the
    /// returned URL (presumably `owner/name` — confirm against `HfClient`).
    pub repo_id: String,
    /// Revision (e.g. a branch name) passed to every file upload.
    pub revision: String,
    /// Commit message passed with every individual file upload.
    pub commit_message: String,
}
/// Summary returned by [`upload_model`] once every file has been uploaded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UploadResult {
    /// Repository the files were uploaded to (copied from [`UploadOptions::repo_id`]).
    pub repo_id: String,
    /// Number of files uploaded.
    pub files_uploaded: usize,
    /// Web URL of the target revision's file tree
    /// (`https://huggingface.co/{repo_id}/tree/{revision}`). Despite the field name, it does
    /// not identify a specific commit.
    pub commit_url: String,
}
/// Uploads the files found at `path` to the Hugging Face repository described by `opts`.
///
/// `path` may be a single file or a directory. A directory is walked recursively, skipping
/// entries whose names start with `.`. Files are uploaded one at a time through
/// `HfClient::upload_file`, each with the same revision and commit message from `opts`.
///
/// # Errors
///
/// Returns an error if file collection fails, if no files are found, if a file cannot be
/// read, or if an upload fails. Processing stops at the first failure, and files uploaded
/// before it are not rolled back.
pub async fn upload_model(
    client: &HfClient,
    path: &Path,
    opts: &UploadOptions,
) -> Result<UploadResult> {
    let mut pending: Vec<(String, PathBuf)> = Vec::new();
    collect_files(path, path, &mut pending)
        .with_context(|| format!("Failed to collect files from '{}'", path.display()))?;
    if pending.is_empty() {
        bail!("No files found to upload at '{}'", path.display());
    }

    for (repo_path, local_path) in &pending {
        // NOTE(review): `std::fs::read` is blocking I/O inside an async fn, and it loads the
        // whole file into memory. Consider the runtime's async fs or `spawn_blocking` if a
        // runtime dependency is acceptable.
        let contents = std::fs::read(local_path)
            .with_context(|| format!("Failed to read file '{}'", local_path.display()))?;
        let upload = client.upload_file(
            &opts.repo_id,
            &opts.revision,
            repo_path,
            contents,
            &opts.commit_message,
        );
        upload.await.with_context(|| {
            format!("Failed to upload '{}' to '{}'", repo_path, opts.repo_id)
        })?;
    }

    // Any failure above returns early, so reaching this point means every file was uploaded.
    Ok(UploadResult {
        repo_id: opts.repo_id.clone(),
        files_uploaded: pending.len(),
        commit_url: format!(
            "https://huggingface.co/{}/tree/{}",
            opts.repo_id, opts.revision
        ),
    })
}
fn collect_files(base: &Path, current: &Path, files: &mut Vec<(String, PathBuf)>) -> Result<()> {
if current.is_file() {
let relative = current
.strip_prefix(base)
.unwrap_or(current.file_name().map(Path::new).unwrap_or(current));
let relative_str = relative
.to_str()
.context("File path contains non-UTF-8 characters")?
.to_string();
files.push((relative_str, current.to_path_buf()));
return Ok(());
}
if current.is_dir() {
let entries = std::fs::read_dir(current)
.with_context(|| format!("Failed to read directory '{}'", current.display()))?;
for entry in entries {
let entry = entry
.with_context(|| format!("Failed to read entry in '{}'", current.display()))?;
let name = entry.file_name();
let name_str = name.to_string_lossy();
if name_str.starts_with('.') {
continue;
}
collect_files(base, &entry.path(), files)?;
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Runs `collect_files` with the temp dir as both base and start point, and returns
    /// only the collected repo paths.
    fn collected_names(root: &TempDir) -> Vec<String> {
        let mut found = Vec::new();
        collect_files(root.path(), root.path(), &mut found).unwrap();
        found.into_iter().map(|(name, _)| name).collect()
    }

    #[test]
    fn collect_files_single_file() {
        let tmp = TempDir::new().unwrap();
        fs::write(tmp.path().join("model.bin"), b"data").unwrap();

        assert_eq!(collected_names(&tmp), ["model.bin"]);
    }

    #[test]
    fn collect_files_nested_directory() {
        let tmp = TempDir::new().unwrap();
        let nested = tmp.path().join("subdir");
        fs::create_dir_all(&nested).unwrap();
        fs::write(tmp.path().join("config.json"), b"{}").unwrap();
        fs::write(nested.join("weights.bin"), b"weights").unwrap();

        let names = collected_names(&tmp);
        assert_eq!(names.len(), 2);
        assert!(names.iter().any(|n| n == "config.json"));
        assert!(names.iter().any(|n| n == "subdir/weights.bin"));
    }

    #[test]
    fn collect_files_skips_hidden() {
        let tmp = TempDir::new().unwrap();
        let git_dir = tmp.path().join(".git");
        fs::write(tmp.path().join("visible.txt"), b"ok").unwrap();
        fs::write(tmp.path().join(".hidden"), b"secret").unwrap();
        fs::create_dir_all(&git_dir).unwrap();
        fs::write(git_dir.join("HEAD"), b"ref").unwrap();

        assert_eq!(collected_names(&tmp), ["visible.txt"]);
    }

    #[test]
    fn collect_files_empty_directory() {
        let tmp = TempDir::new().unwrap();

        assert!(collected_names(&tmp).is_empty());
    }
}