use super::super::{
base64_encode, HfHubClient, HfHubError, ModelCard, PushOptions, Result, UploadProgress,
UploadResult,
};
use std::path::PathBuf;
impl HfHubClient {
/// Creates a client configured from the process environment.
///
/// The access token is read from the `HF_TOKEN` environment variable. Surrounding
/// whitespace is trimmed, and an unset, non-Unicode, empty or whitespace-only value is
/// treated as "no token": otherwise a blank variable would make the client look
/// authenticated and later send an empty bearer token instead of failing early.
///
/// The cache directory is [`Self::default_cache_dir`] and the API base is
/// `https://huggingface.co`.
///
/// # Errors
///
/// This implementation never returns `Err`; the `Result` return type is part of the
/// existing interface.
pub fn new() -> Result<Self> {
    let token = std::env::var("HF_TOKEN")
        .ok()
        .map(|raw| raw.trim().to_owned())
        .filter(|cleaned| !cleaned.is_empty());
    Ok(Self {
        token,
        cache_dir: Self::default_cache_dir(),
        api_base: "https://huggingface.co".to_string(),
    })
}
/// Builds a client that uses the given access token.
///
/// No environment variable is read here. The cache directory is
/// [`Self::default_cache_dir`] and the API base is `https://huggingface.co`.
#[must_use]
pub fn with_token(token: impl Into<String>) -> Self {
    let token = Some(token.into());
    let cache_dir = Self::default_cache_dir();
    let api_base = "https://huggingface.co".to_string();
    Self {
        token,
        cache_dir,
        api_base,
    }
}
/// Returns the client with its cache directory replaced by `path`.
///
/// Consumes `self` so it can be chained onto a constructor call.
#[must_use]
pub fn with_cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
    let replacement: PathBuf = path.into();
    self.cache_dir = replacement;
    self
}
/// Returns the default cache location: `<cache>/huggingface/hub`.
///
/// `<cache>` is whatever `dirs::cache_dir()` reports; when it reports nothing, the
/// relative path `.` is used as the base instead.
pub(crate) fn default_cache_dir() -> PathBuf {
    let mut location = match dirs::cache_dir() {
        Some(base) => base,
        None => PathBuf::from("."),
    };
    location.push("huggingface");
    location.push("hub");
    location
}
/// Reports whether a token is configured on this client.
///
/// Only the presence of a token is checked; nothing is validated against the server.
#[must_use]
pub fn is_authenticated(&self) -> bool {
    Option::is_some(&self.token)
}
/// Splits a repository id of the form `org/name` into its two components.
///
/// Both halves borrow from `repo_id`; nothing is allocated on the success path.
///
/// # Errors
///
/// Returns [`HfHubError::InvalidRepoId`] (carrying the offending input) when the id
/// does not contain exactly one `/`, or when either side of the `/` is empty
/// (for example `"/name"`, `"org/"` or `"/"`).
pub(crate) fn parse_repo_id(repo_id: &str) -> Result<(&str, &str)> {
    match repo_id.split_once('/') {
        // A second '/' can only end up in the right-hand part, so checking `name` is enough.
        Some((org, name)) if !org.is_empty() && !name.is_empty() && !name.contains('/') => {
            Ok((org, name))
        }
        _ => Err(HfHubError::InvalidRepoId(repo_id.to_string())),
    }
}
/// Fetches `model.apr` from the given model repository and returns the local path
/// handed back by the `hf_hub` sync API.
///
/// `repo_id` must have the form `org/name`.
///
/// NOTE(review): neither `self.cache_dir` nor `self.api_base` is passed to the
/// `hf_hub` builder here, so this call relies on that crate's own defaults — confirm
/// this is intended.
///
/// # Errors
///
/// * Whatever [`Self::parse_repo_id`] returns for a malformed `repo_id`.
/// * [`HfHubError::NetworkError`] when the API client cannot be built.
/// * [`HfHubError::FileNotFound`] for any failure while fetching `model.apr`.
#[cfg(feature = "hf-hub-integration")]
pub fn pull_from_hub(&self, repo_id: &str) -> Result<PathBuf> {
    use hf_hub::api::sync::ApiBuilder;

    let (org, name) = Self::parse_repo_id(repo_id)?;

    // Only override the builder's token when we actually hold one; otherwise the
    // builder is left exactly as `ApiBuilder::new()` produced it.
    let builder = match &self.token {
        Some(token) => ApiBuilder::new().with_token(Some(token.clone())),
        None => ApiBuilder::new(),
    };
    let api = match builder.build() {
        Ok(api) => api,
        Err(e) => return Err(HfHubError::NetworkError(e.to_string())),
    };

    api.model(format!("{org}/{name}"))
        .get("model.apr")
        .map_err(|e| HfHubError::FileNotFound(format!("model.apr: {e}")))
}
/// Stand-in used when the crate is built without the `hf-hub-integration` feature.
///
/// # Errors
///
/// Always returns [`HfHubError::NetworkError`] stating that the feature is disabled.
#[cfg(not(feature = "hf-hub-integration"))]
pub fn pull_from_hub(&self, _repo_id: &str) -> Result<PathBuf> {
    let reason = String::from("hf-hub-integration feature not enabled");
    Err(HfHubError::NetworkError(reason))
}
/// Uploads a model file plus a generated `README.md` to a repository on the Hub.
///
/// Steps, in order:
/// 1. require a configured token and a well-formed `repo_id`;
/// 2. render the README from `options.model_card`, or from a default card built from
///    `repo_id` and version `"1.0.0"` when none is supplied;
/// 3. create the repository first when `options.create_repo` is set;
/// 4. upload `model_data` under `options.filename`, then the README, each through
///    [`Self::upload_file_with_retry`];
/// 5. invoke `options.progress_callback` (if any) before the first upload and once
///    more after the last one.
///
/// The returned `commit_sha` is the fixed placeholder string `"uploaded"`.
///
/// # Errors
///
/// * [`HfHubError::MissingToken`] when the client holds no token.
/// * Whatever [`Self::parse_repo_id`] returns for a malformed `repo_id`.
/// * Any error from repository creation or from either upload; the first failure
///   aborts the remaining steps.
#[cfg(feature = "hf-hub-integration")]
// `options` is only borrowed below, but it is taken by value in the public signature.
#[allow(clippy::needless_pass_by_value)]
pub fn push_to_hub(
    &self,
    repo_id: &str,
    model_data: &[u8],
    options: PushOptions,
) -> Result<UploadResult> {
    // The model file and the README.
    const TOTAL_FILES: usize = 2;

    let token = match self.token.as_ref() {
        Some(token) => token,
        None => return Err(HfHubError::MissingToken),
    };
    // Validation only: the parsed components are not needed here.
    Self::parse_repo_id(repo_id)?;

    let card = match options.model_card.clone() {
        Some(card) => card,
        None => ModelCard::new(repo_id, "1.0.0").with_description("Model uploaded via aprender"),
    };
    let readme = card.to_huggingface();
    let summary = match options.commit_message.clone() {
        Some(message) => message,
        None => "Upload model via aprender".to_string(),
    };

    if options.create_repo {
        self.create_repo_if_not_exists(repo_id, token, options.private)?;
    }

    let total_bytes = model_data.len() as u64 + readme.len() as u64;
    let mut bytes_transferred = 0u64;

    if let Some(notify) = options.progress_callback.as_ref() {
        notify(UploadProgress {
            bytes_sent: 0,
            total_bytes,
            current_file: options.filename.clone(),
            files_completed: 0,
            total_files: TOTAL_FILES,
        });
    }

    // Upload order matters for progress reporting: model first, README second.
    let uploads: [(&str, &[u8]); TOTAL_FILES] = [
        (options.filename.as_str(), model_data),
        ("README.md", readme.as_bytes()),
    ];
    let mut files_uploaded = Vec::with_capacity(TOTAL_FILES);
    for (index, &(path, bytes)) in uploads.iter().enumerate() {
        self.upload_file_with_retry(
            repo_id,
            path,
            bytes,
            &summary,
            token,
            &options,
            &mut bytes_transferred,
            total_bytes,
            index,
            TOTAL_FILES,
        )?;
        files_uploaded.push(path.to_string());
    }

    if let Some(notify) = options.progress_callback.as_ref() {
        notify(UploadProgress {
            bytes_sent: bytes_transferred,
            total_bytes,
            current_file: "Complete".to_string(),
            files_completed: TOTAL_FILES,
            total_files: TOTAL_FILES,
        });
    }

    Ok(UploadResult {
        repo_url: format!("{}/{}", self.api_base, repo_id),
        commit_sha: "uploaded".to_string(),
        files_uploaded,
        bytes_transferred,
    })
}
/// Stand-in used when the crate is built without the `hf-hub-integration` feature.
///
/// All arguments are ignored.
///
/// # Errors
///
/// Always returns [`HfHubError::NetworkError`] stating that the feature is disabled.
#[cfg(not(feature = "hf-hub-integration"))]
pub fn push_to_hub(
    &self,
    _repo_id: &str,
    _model_data: &[u8],
    _options: PushOptions,
) -> Result<UploadResult> {
    let reason = String::from("hf-hub-integration feature not enabled");
    Err(HfHubError::NetworkError(reason))
}
/// Asks the Hub to create the model repository `repo_id`, tolerating "already there".
///
/// Sends a JSON `POST` to `{api_base}/api/repos/create` with the repository type,
/// name, organization and visibility, authenticated with `token` as a bearer token.
///
/// # Errors
///
/// * Whatever [`Self::parse_repo_id`] returns for a malformed `repo_id`.
/// * [`HfHubError::NetworkError`] for any HTTP status error other than 409 and 400,
///   and for transport-level failures.
#[cfg(feature = "hf-hub-integration")]
// NOTE(review): presumably needed because of the project's clippy `disallowed-methods`
// configuration — confirm which call in this function triggers it.
#[allow(clippy::disallowed_methods)]
fn create_repo_if_not_exists(&self, repo_id: &str, token: &str, private: bool) -> Result<()> {
    let (org, name) = Self::parse_repo_id(repo_id)?;
    let endpoint = format!("{}/api/repos/create", self.api_base);
    let authorization = format!("Bearer {token}");
    let payload = serde_json::json!({
        "type": "model",
        "name": name,
        "organization": org,
        "private": private
    });

    let outcome = ureq::post(&endpoint)
        .set("Authorization", &authorization)
        .set("Content-Type", "application/json")
        .send_json(&payload);

    match outcome {
        // 409 and 400 are deliberately treated like success.
        // NOTE(review): presumably both mean "repository already exists"; a 400 could
        // also hide a genuinely malformed request — confirm against the Hub API.
        Ok(_) | Err(ureq::Error::Status(409 | 400, _)) => Ok(()),
        Err(ureq::Error::Status(code, resp)) => {
            let body = resp.into_string().unwrap_or_default();
            Err(HfHubError::NetworkError(format!(
                "Failed to create repo (HTTP {code}): {body}"
            )))
        }
        Err(e) => Err(HfHubError::NetworkError(format!(
            "Network error creating repo: {e}"
        ))),
    }
}
/// Uploads one file, retrying failed attempts with exponential backoff.
///
/// At most `options.max_retries + 1` attempts are made. Before every attempt after
/// the first, the thread sleeps for the current backoff, which starts at
/// `options.initial_backoff_ms` and is then doubled (saturating, capped at 30 s).
/// `options.progress_callback`, when set, is invoked before every attempt with the
/// byte count transferred so far.
///
/// * `bytes_transferred` — running total across files; increased by `data.len()`
///   only when the upload succeeds.
/// * `total_bytes`, `files_completed`, `total_files` — forwarded unchanged to the
///   progress callback.
///
/// # Errors
///
/// Returns the error of the last failed attempt once all attempts are exhausted.
#[cfg(feature = "hf-hub-integration")]
fn upload_file_with_retry(
    &self,
    repo_id: &str,
    filename: &str,
    data: &[u8],
    commit_msg: &str,
    token: &str,
    options: &PushOptions,
    bytes_transferred: &mut u64,
    total_bytes: u64,
    files_completed: usize,
    total_files: usize,
) -> Result<()> {
    // Upper bound on the delay between two attempts, in milliseconds.
    const MAX_BACKOFF_MS: u64 = 30_000;

    let mut last_error = None;
    let mut backoff_ms = options.initial_backoff_ms;
    for attempt in 0..=options.max_retries {
        if attempt > 0 {
            std::thread::sleep(std::time::Duration::from_millis(backoff_ms));
            // Saturating: a very large configured initial backoff must not overflow
            // (panic in debug builds, wrap-around in release builds).
            backoff_ms = backoff_ms.saturating_mul(2).min(MAX_BACKOFF_MS);
        }
        if let Some(ref cb) = options.progress_callback {
            cb(UploadProgress {
                bytes_sent: *bytes_transferred,
                total_bytes,
                current_file: filename.to_string(),
                files_completed,
                total_files,
            });
        }
        match self.upload_file_once(repo_id, filename, data, commit_msg, token) {
            Ok(()) => {
                *bytes_transferred += data.len() as u64;
                return Ok(());
            }
            // Remember the failure; the loop ends by itself after the final attempt.
            Err(e) => last_error = Some(e),
        }
    }
    // The range above is never empty, so `last_error` is always set here; the
    // fallback only exists to avoid an `unwrap`.
    Err(last_error
        .unwrap_or_else(|| HfHubError::NetworkError("Upload failed after retries".to_string())))
}
/// Size threshold, in bytes, for choosing the upload route (10 MiB).
const LFS_THRESHOLD: usize = 10 << 20;
/// Performs a single upload attempt, choosing the route by payload size.
///
/// Payloads shorter than [`Self::LFS_THRESHOLD`] bytes go through
/// [`Self::upload_direct`]; everything else goes through `upload_via_lfs`.
///
/// # Errors
///
/// Propagates whatever the chosen upload routine returns.
#[cfg(feature = "hf-hub-integration")]
fn upload_file_once(
    &self,
    repo_id: &str,
    filename: &str,
    data: &[u8],
    commit_msg: &str,
    token: &str,
) -> Result<()> {
    if data.len() < Self::LFS_THRESHOLD {
        return self.upload_direct(repo_id, filename, data, commit_msg, token);
    }
    self.upload_via_lfs(repo_id, filename, data, commit_msg, token)
}
/// Uploads `data` inline (base64-encoded) as a single commit on the `main` branch.
///
/// Sends a JSON `POST` to `{api_base}/api/models/{repo_id}/commit/main` containing
/// the commit summary and one `addOrUpdate` operation for `filename`. The request is
/// authenticated with `token` as a bearer token and uses a 120-second timeout.
///
/// NOTE(review): the Hub's commit endpoint is, as far as I know, documented with an
/// NDJSON payload — confirm that this plain-JSON shape is accepted by the server.
///
/// # Errors
///
/// Returns [`HfHubError::NetworkError`] when the response status is outside
/// `200..300`, when the HTTP client reports a status error, or on a transport failure.
#[cfg(feature = "hf-hub-integration")]
// NOTE(review): presumably needed because of the project's clippy `disallowed-methods`
// configuration — confirm which call in this function triggers it.
#[allow(clippy::disallowed_methods)]
fn upload_direct(
    &self,
    repo_id: &str,
    filename: &str,
    data: &[u8],
    commit_msg: &str,
    token: &str,
) -> Result<()> {
    let endpoint = format!("{}/api/models/{}/commit/main", self.api_base, repo_id);
    let authorization = format!("Bearer {token}");
    let payload = serde_json::json!({
        "summary": commit_msg,
        "operations": [{
            "op": "addOrUpdate",
            "path": filename,
            "content": base64_encode(data)
        }]
    });

    let outcome = ureq::post(&endpoint)
        .set("Authorization", &authorization)
        .set("Content-Type", "application/json")
        .timeout(std::time::Duration::from_secs(120))
        .send_json(&payload);

    match outcome {
        Ok(resp) if (200..300).contains(&resp.status()) => Ok(()),
        // A response the client did not flag as an error but that is not 2xx either.
        Ok(resp) => {
            let body = resp.into_string().unwrap_or_default();
            Err(HfHubError::NetworkError(format!("Upload failed: {body}")))
        }
        Err(ureq::Error::Status(code, resp)) => {
            let body = resp.into_string().unwrap_or_default();
            Err(HfHubError::NetworkError(format!(
                "Upload failed (HTTP {code}): {body}"
            )))
        }
        Err(e) => Err(HfHubError::NetworkError(format!("Network error: {e}"))),
    }
}
}