use crate::hf_pipeline::error::{FetchError, Result};
use std::path::PathBuf;
use super::options::FetchOptions;
use super::types::{ModelArtifact, WeightFormat};
/// Downloads model artifacts from the Hugging Face Hub into a local cache.
pub struct HfModelFetcher {
    /// Access token sent with Hub requests; `None` means anonymous access.
    pub(crate) token: Option<String>,
    /// Root directory under which per-repository download folders are created.
    pub(crate) cache_dir: PathBuf,
    /// Hub endpoint recorded at construction time.
    // Not read by the download path in this file yet; kept for a configurable endpoint.
    #[allow(dead_code)]
    pub(crate) api_base: String,
}
impl HfModelFetcher {
pub fn new() -> Result<Self> {
let token = Self::resolve_token();
let cache_dir = Self::default_cache_dir();
Ok(Self { token, cache_dir, api_base: "https://huggingface.co".into() })
}
#[must_use]
pub fn with_token(token: impl Into<String>) -> Self {
Self {
token: Some(token.into()),
cache_dir: Self::default_cache_dir(),
api_base: "https://huggingface.co".into(),
}
}
#[must_use]
pub fn cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
self.cache_dir = dir.into();
self
}
#[must_use]
pub fn resolve_token() -> Option<String> {
if let Ok(token) = std::env::var("HF_TOKEN") {
if !token.is_empty() {
return Some(token);
}
}
if let Some(home) = dirs::home_dir() {
let token_path = home.join(".huggingface").join("token");
if let Ok(token) = std::fs::read_to_string(token_path) {
let token = token.trim().to_string();
if !token.is_empty() {
return Some(token);
}
}
}
None
}
pub(crate) fn default_cache_dir() -> PathBuf {
dirs::cache_dir().unwrap_or_else(|| PathBuf::from(".cache")).join("huggingface").join("hub")
}
#[must_use]
pub fn is_authenticated(&self) -> bool {
self.token.is_some()
}
pub(crate) fn parse_repo_id(repo_id: &str) -> Result<(&str, &str)> {
let parts: Vec<&str> = repo_id.split('/').collect();
if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
return Err(FetchError::InvalidRepoId { repo_id: repo_id.to_string() });
}
Ok((parts[0], parts[1]))
}
fn resolve_files(options: &FetchOptions) -> Vec<String> {
if options.files.is_empty() {
vec!["model.safetensors".to_string(), "config.json".to_string()]
} else {
options.files.clone()
}
}
fn check_security(files: &[String], allow_pickle: bool) -> Result<()> {
for file in files {
if let Some(format) = WeightFormat::from_filename(file) {
if !format.is_safe() && !allow_pickle {
return Err(FetchError::PickleSecurityRisk);
}
}
}
Ok(())
}
fn build_api(&self, cache_path: &std::path::Path) -> Result<hf_hub::api::sync::Api> {
let mut api_builder =
hf_hub::api::sync::ApiBuilder::new().with_cache_dir(cache_path.to_path_buf());
if let Some(token) = &self.token {
api_builder = api_builder.with_token(Some(token.clone()));
}
api_builder.build().map_err(|e| FetchError::ConfigParseError {
message: format!("Failed to initialize HF API: {e}"),
})
}
fn download_file(
repo: &hf_hub::api::sync::ApiRepo,
api: &hf_hub::api::sync::Api,
repo_id: &str,
revision: &str,
file: &str,
cache_path: &std::path::Path,
) -> Result<()> {
let download_result = if revision == "main" {
repo.get(file)
} else {
let revision_repo = api.repo(hf_hub::Repo::with_revision(
repo_id.to_string(),
hf_hub::RepoType::Model,
revision.to_string(),
));
revision_repo.get(file)
};
match download_result {
Ok(path) => {
let dest = cache_path.join(file);
if path != dest {
if let Some(parent) = dest.parent() {
std::fs::create_dir_all(parent)?;
}
if path.exists() && !dest.exists() {
std::fs::copy(&path, &dest)?;
}
}
Ok(())
}
Err(hf_hub::api::sync::ApiError::RequestError(e)) => {
if e.to_string().contains("404") {
Err(FetchError::FileNotFound {
repo: repo_id.to_string(),
file: file.to_string(),
})
} else {
Err(FetchError::ConfigParseError { message: format!("Download failed: {e}") })
}
}
Err(e) => {
Err(FetchError::ConfigParseError { message: format!("Download failed: {e}") })
}
}
}
pub fn download_model(&self, repo_id: &str, options: FetchOptions) -> Result<ModelArtifact> {
Self::parse_repo_id(repo_id)?;
let files = Self::resolve_files(&options);
Self::check_security(&files, options.allow_pytorch_pickle)?;
let cache_path = options
.cache_dir
.clone()
.unwrap_or_else(|| self.cache_dir.clone())
.join(repo_id.replace('/', "--"))
.join(&options.revision);
std::fs::create_dir_all(&cache_path)?;
let format = files
.iter()
.find_map(|f| WeightFormat::from_filename(f))
.unwrap_or(WeightFormat::SafeTensors);
let api = self.build_api(&cache_path)?;
let repo = api.model(repo_id.to_string());
for file in &files {
Self::download_file(&repo, &api, repo_id, &options.revision, file, &cache_path)?;
}
Ok(ModelArtifact {
path: cache_path,
format,
architecture: None,
sha256: options.verify_sha256,
})
}
#[must_use]
pub fn estimate_memory(param_count: u64, dtype_bytes: u8) -> u64 {
param_count * u64::from(dtype_bytes)
}
}
impl Default for HfModelFetcher {
    /// Equivalent to [`HfModelFetcher::new`], unwrapped.
    ///
    /// # Panics
    ///
    /// Panics if [`HfModelFetcher::new`] returns an error.
    fn default() -> Self {
        Self::new().unwrap_or_else(|err| panic!("Failed to create HfModelFetcher: {err:?}"))
    }
}