use std::path::{Path, PathBuf};
use hf_hub::api::tokio::{Api, ApiBuilder};
/// Name of the environment variable that `build_api` reads to obtain a HuggingFace access token.
const HF_TOKEN_ENV: &str = "HF_TOKEN";
/// Builds the HuggingFace Hub API client used by the download helpers.
///
/// The client starts from `ApiBuilder::from_env()` with progress reporting enabled.
/// When the `HF_TOKEN_ENV` environment variable is set to a non-empty value, that value
/// is handed to the builder via `with_token`; an unset, unreadable, or empty variable
/// leaves the builder untouched.
///
/// # Errors
///
/// Returns an error if the underlying builder fails to construct the client.
fn build_api() -> anyhow::Result<Api> {
    let base = ApiBuilder::from_env().with_progress(true);

    // Treat an empty variable the same as a missing one.
    let token = std::env::var(HF_TOKEN_ENV)
        .ok()
        .filter(|value| !value.is_empty());

    let configured = match token {
        Some(value) => base.with_token(Some(value)),
        None => base,
    };

    Ok(configured.build()?)
}
/// Repository file names that are never downloaded: `from_hf` skips them and
/// `download_tokenizer_from_hf` filters them out. Entries are compared for exact,
/// case-sensitive equality against the sibling's `rfilename`.
const IGNORED: [&str; 5] = [
".gitattributes",
"LICENSE",
"LICENSE.txt",
"README.md",
"USE_POLICY.md",
];
/// Returns `true` when `filename` ends with one of the known model-weight suffixes.
///
/// The comparison is case-sensitive and purely suffix based, so it also matches
/// files nested in subdirectories (e.g. `sub/model.safetensors`).
fn is_weight_file(filename: &str) -> bool {
    const WEIGHT_SUFFIXES: [&str; 5] = [
        ".bin",
        ".safetensors",
        ".h5",
        ".msgpack",
        ".ckpt.index",
    ];

    WEIGHT_SUFFIXES
        .iter()
        .any(|suffix| filename.ends_with(*suffix))
}
/// Returns `true` when `filename` has a `.png`, `.jpg`, or `.jpeg` extension,
/// ignoring letter case.
fn is_image(filename: &str) -> bool {
    const IMAGE_EXTENSIONS: [&str; 3] = [".png", ".jpg", ".jpeg"];

    // Lower-case once so the suffix checks are case-insensitive.
    let normalized = filename.to_lowercase();
    IMAGE_EXTENSIONS
        .iter()
        .any(|extension| normalized.ends_with(*extension))
}
/// Returns `true` when `filename` looks like a tokenizer artifact.
///
/// A file qualifies if it ends with one of the known tokenizer suffixes (the check
/// is suffix based and case-sensitive, so prefixed or nested names such as
/// `sub/vocab.json` match too) or if `is_chat_template_file` accepts it.
fn is_tokenizer_file(filename: &str) -> bool {
    const TOKENIZER_SUFFIXES: [&str; 7] = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        ".model",
        ".tiktoken",
    ];

    let has_tokenizer_suffix = TOKENIZER_SUFFIXES
        .iter()
        .any(|suffix| filename.ends_with(*suffix));

    has_tokenizer_suffix || is_chat_template_file(filename)
}
/// Returns `true` for chat-template files: any name ending in `.jinja`, or a name
/// that is exactly `chat_template.json`.
fn is_chat_template_file(filename: &str) -> bool {
    // The JSON form is only recognised by its exact name; other `.json` files are not templates.
    if filename == "chat_template.json" {
        return true;
    }
    filename.ends_with(".jinja")
}
/// Downloads only the tokenizer-related files of a HuggingFace model and returns the
/// local cache directory that holds them.
///
/// `model_id` is rendered with `Path::display` and used as the repository id. Every
/// repository file accepted by `is_tokenizer_file` (and not excluded by `IGNORED`,
/// `is_image`, or `is_weight_file`) is fetched. The parent directory of the first
/// fetched file is passed through `resolve_model_cache_dir` to produce the result.
///
/// # Errors
///
/// Returns an error when the API client cannot be built, the repository metadata
/// cannot be fetched, the repository lists no files, none of its files is a tokenizer
/// file, any tokenizer file fails to download, or no downloaded path has a parent
/// directory.
pub async fn download_tokenizer_from_hf(model_id: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
    let model_name = model_id.as_ref().display().to_string();
    let api = build_api()?;
    let repo = api.model(model_name.clone());

    let info = repo.info().await.map_err(|e| {
        anyhow::anyhow!(
            "Failed to fetch model '{model_name}' from HuggingFace: {e}. Is this a valid HuggingFace ID?"
        )
    })?;

    if info.siblings.is_empty() {
        return Err(anyhow::anyhow!(
            "Model '{model_name}' exists but contains no downloadable files."
        ));
    }

    // The exclusion checks are kept as a defensive guard in front of the tokenizer match.
    let tokenizer_files: Vec<_> = info
        .siblings
        .iter()
        .filter(|sib| {
            !IGNORED.contains(&sib.rfilename.as_str())
                && !is_image(&sib.rfilename)
                && !is_weight_file(&sib.rfilename)
                && is_tokenizer_file(&sib.rfilename)
        })
        .collect();

    if tokenizer_files.is_empty() {
        return Err(anyhow::anyhow!(
            "No tokenizer files found for model '{model_name}'."
        ));
    }

    // The list is non-empty and any failed download returns immediately, so reaching the
    // end of the loop means every tokenizer file was fetched; no separate "found" flag is
    // needed.
    let mut cache_dir: Option<PathBuf> = None;
    for sib in tokenizer_files {
        let path = repo.get(&sib.rfilename).await.map_err(|e| {
            anyhow::anyhow!(
                "Failed to download tokenizer file '{}' from model '{}': {}",
                sib.rfilename,
                model_name,
                e
            )
        })?;

        // Remember the directory of the first file that has one.
        if cache_dir.is_none() {
            cache_dir = path.parent().map(Path::to_path_buf);
        }
    }

    let dir = cache_dir
        .ok_or_else(|| anyhow::anyhow!("Invalid HF cache path for model '{model_name}'"))?;
    Ok(resolve_model_cache_dir(&dir, &model_name))
}
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
let name = name.as_ref();
let api = build_api()?;
let model_name = name.display().to_string();
let repo = api.model(model_name.clone());
let info = match repo.info().await {
Ok(info) => info,
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to fetch model '{model_name}' from HuggingFace: {e}. Is this a valid HuggingFace ID?"
));
}
};
if info.siblings.is_empty() {
return Err(anyhow::anyhow!(
"Model '{model_name}' exists but contains no downloadable files."
));
}
let mut p = PathBuf::new();
let mut files_downloaded = false;
for sib in info.siblings {
if IGNORED.contains(&sib.rfilename.as_str()) || is_image(&sib.rfilename) {
continue;
}
if ignore_weights && is_weight_file(&sib.rfilename) {
continue;
}
match repo.get(&sib.rfilename).await {
Ok(path) => {
p = path;
files_downloaded = true;
}
Err(e) => {
return Err(anyhow::anyhow!(
"Failed to download file '{}' from model '{}': {}",
sib.rfilename,
model_name,
e
));
}
}
}
if !files_downloaded {
let file_type = if ignore_weights {
"non-weight"
} else {
"valid"
};
return Err(anyhow::anyhow!(
"No {file_type} files found for model '{model_name}'."
));
}
match p.parent() {
Some(p) => {
let final_dir = resolve_model_cache_dir(p, &model_name);
Ok(final_dir)
}
None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
}
}
/// Maps the directory of a downloaded file to the directory reported to callers.
///
/// If the final component of `path` is `original`, the parent directory is returned,
/// so a file fetched from an `original/` subfolder resolves to the directory above it.
/// Every other path is returned unchanged.
///
/// `_model_name` is kept so existing call sites compile unchanged; it does not affect
/// the result. An earlier version derived a `models--{org}--{name}` pattern from it and
/// searched ancestor directories, but that search could never change the outcome:
///
/// * the pattern doubled hyphens inside the org/repo names, which does not match the
///   Hugging Face cache folder naming (only `/` is replaced by `--`);
/// * an ancestor's path string is always a prefix of `path`'s, so an ancestor could
///   only match after the "path already contains the pattern" early return had fired;
/// * every remaining branch returned `path` itself.
///
/// Dropping it yields identical results and removes blocking filesystem probes
/// (`exists`, `read_dir`) from the async download paths.
fn resolve_model_cache_dir(path: &Path, _model_name: &str) -> PathBuf {
    match (path.parent(), path.file_name()) {
        (Some(parent), Some(folder_name)) if folder_name == "original" => parent.to_path_buf(),
        _ => path.to_path_buf(),
    }
}
#[cfg(test)]
mod tests {
    use super::{is_chat_template_file, is_tokenizer_file, is_weight_file};

    /// Asserts that `classify` accepts every name in `accepted` and rejects every name
    /// in `rejected`, naming the offending file on failure.
    fn assert_classification(classify: fn(&str) -> bool, accepted: &[&str], rejected: &[&str]) {
        for &name in accepted {
            assert!(classify(name), "expected `{name}` to be accepted");
        }
        for &name in rejected {
            assert!(!classify(name), "expected `{name}` to be rejected");
        }
    }

    #[test]
    fn test_is_tokenizer_file() {
        assert_classification(
            is_tokenizer_file,
            &[
                "tokenizer.json",
                "tokenizer_config.json",
                "special_tokens_map.json",
                "vocab.json",
                "merges.txt",
                "spiece.model",
                "chat_template.jinja",
                "template.jinja",
            ],
            &["model.bin", "README.md"],
        );
    }

    #[test]
    fn test_is_chat_template_file() {
        assert_classification(
            is_chat_template_file,
            &[
                "chat_template.jinja",
                "template.jinja",
                "any_file.jinja",
                "chat_template.json",
            ],
            &[
                "tokenizer.json",
                "other_file.json",
                "chat_template",
                "README.md",
            ],
        );
    }

    #[test]
    fn test_is_weight_file() {
        assert_classification(
            is_weight_file,
            &["model.bin", "model.safetensors", "pytorch_model.bin"],
            &["tokenizer.json", "config.json"],
        );
    }
}