use std::path::Path;
use super::config::PublishConfig;
use super::model_card::ModelCard;
use super::result::{PublishError, PublishResult};
use crate::hf_pipeline::HfModelFetcher;
/// Base URL of the Hugging Face HTTP API; endpoint paths are appended to it.
const HF_API_BASE: &str = "https://huggingface.co/api";
/// Publishes files to a Hugging Face repository over the HTTP API.
///
/// Holds the publish settings, a reusable blocking HTTP client and the bearer
/// token sent with every request.
pub struct HfPublisher {
/// Publish settings; `repo_id`, `repo_type` and `private` are read from here.
config: PublishConfig,
/// Blocking HTTP client shared by all requests made through this publisher.
client: reqwest::blocking::Client,
/// Bearer token attached to each request.
token: String,
}
impl HfPublisher {
pub fn new(config: PublishConfig) -> Result<Self, PublishError> {
let token = config
.token
.clone()
.or_else(HfModelFetcher::resolve_token)
.ok_or(PublishError::AuthRequired)?;
if config.repo_id.is_empty() || !config.repo_id.contains('/') {
return Err(PublishError::InvalidRepoId { repo_id: config.repo_id.clone() });
}
let client =
reqwest::blocking::Client::builder().user_agent("entrenar/0.5").build().map_err(
|e| PublishError::Http { message: format!("Failed to create HTTP client: {e}") },
)?;
Ok(Self { config, client, token })
}
pub fn create_repo(&self) -> Result<String, PublishError> {
let url = format!("{HF_API_BASE}/repos/create");
let mut body = serde_json::json!({
"name": self.repo_name(),
"type": self.config.repo_type.to_string(),
"private": self.config.private,
});
if let Some(org) = self.repo_org() {
body["organization"] = serde_json::Value::String(org);
}
let response =
self.client.post(&url).bearer_auth(&self.token).json(&body).send().map_err(|e| {
PublishError::Http { message: format!("Create repo request failed: {e}") }
})?;
if response.status().is_success() || response.status().as_u16() == 409 {
let repo_url = format!(
"https://huggingface.co/{}/{}",
self.config.repo_type.api_path(),
self.config.repo_id
);
Ok(repo_url)
} else {
let status = response.status();
let body = response.text().unwrap_or_default();
Err(PublishError::RepoCreationFailed {
repo_id: self.config.repo_id.clone(),
message: format!("HTTP {status}: {body}"),
})
}
}
pub fn upload_file(&self, local_path: &Path, path_in_repo: &str) -> Result<(), PublishError> {
let content = std::fs::read(local_path).map_err(PublishError::Io)?;
self.upload_bytes(&content, path_in_repo)
}
pub fn upload_bytes(&self, content: &[u8], path_in_repo: &str) -> Result<(), PublishError> {
let url = format!(
"{HF_API_BASE}/{}/{}/upload/main/{}",
self.config.repo_type.api_path(),
self.config.repo_id,
path_in_repo
);
let response = self
.client
.put(&url)
.bearer_auth(&self.token)
.header("Content-Type", "application/octet-stream")
.body(content.to_vec())
.send()
.map_err(|e| PublishError::UploadFailed {
path: path_in_repo.to_string(),
message: format!("Upload request failed: {e}"),
})?;
if response.status().is_success() {
Ok(())
} else {
let status = response.status();
let body = response.text().unwrap_or_default();
Err(PublishError::UploadFailed {
path: path_in_repo.to_string(),
message: format!("HTTP {status}: {body}"),
})
}
}
pub fn publish(
&self,
files: &[(&Path, &str)],
model_card: Option<&ModelCard>,
) -> Result<PublishResult, PublishError> {
let repo_url = self.create_repo()?;
let mut files_uploaded = 0;
for (local_path, remote_path) in files {
self.upload_file(local_path, remote_path)?;
files_uploaded += 1;
}
let model_card_generated = if let Some(card) = model_card {
let markdown = card.to_markdown();
self.upload_bytes(markdown.as_bytes(), "README.md")?;
true
} else {
false
};
Ok(PublishResult {
repo_url,
repo_id: self.config.repo_id.clone(),
files_uploaded,
model_card_generated,
})
}
fn repo_name(&self) -> &str {
self.config.repo_id.rsplit('/').next().unwrap_or(&self.config.repo_id)
}
fn repo_org(&self) -> Option<String> {
let parts: Vec<&str> = self.config.repo_id.splitn(2, '/').collect();
if parts.len() == 2 {
Some(parts[0].to_string())
} else {
None
}
}
}
/// Debug output shows only the repository settings from the configuration;
/// the HTTP client and the token are left out and indicated by the trailing
/// `..`, so the token never appears in debug logs.
impl std::fmt::Debug for HfPublisher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cfg = &self.config;
        let mut out = f.debug_struct("HfPublisher");
        out.field("repo_id", &cfg.repo_id);
        out.field("repo_type", &cfg.repo_type);
        out.field("private", &cfg.private);
        out.finish_non_exhaustive()
    }
}