fn extract_shard_filename(kv_pair: &str) -> Option<String> {
let colon_pos = kv_pair.rfind(':')?;
let value = kv_pair[colon_pos + 1..].trim();
let filename = value.trim_matches(|c: char| c == '"' || c.is_whitespace());
if filename.ends_with(".safetensors")
&& !filename.is_empty()
&& !filename.contains('/')
&& !filename.contains('\\')
&& !filename.contains("..")
{
Some(filename.to_string())
} else {
None
}
}
fn extract_shard_files_from_index(json: &str) -> Vec<String> {
let Some(weight_map_start) = json.find("\"weight_map\"") else {
return Vec::new();
};
let Some(entries) = find_brace_content(&json[weight_map_start..]) else {
return Vec::new();
};
let mut sorted: Vec<String> = entries
.split(',')
.filter_map(extract_shard_filename)
.collect::<HashSet<_>>()
.into_iter()
.collect();
sorted.sort();
sorted
}
fn home_dir() -> Option<std::path::PathBuf> {
std::env::var_os("HOME")
.or_else(|| std::env::var_os("USERPROFILE"))
.map(std::path::PathBuf::from)
}
fn resolve_hf_token() -> Option<String> {
if let Ok(token) = std::env::var("HF_TOKEN") {
if !token.is_empty() {
return Some(token);
}
}
if let Some(home) = home_dir() {
for path in [
home.join(".huggingface/token"),
home.join(".cache/huggingface/token"),
] {
if let Ok(token) = std::fs::read_to_string(&path) {
let token = token.trim().to_string();
if !token.is_empty() {
return Some(token);
}
}
}
}
None
}
fn hf_get(url: &str) -> ureq::Request {
let req = ureq::get(url);
if let Some(token) = resolve_hf_token() {
req.set("Authorization", &format!("Bearer {token}"))
} else {
req
}
}
fn format_gated_model_error(url: &str) -> String {
format_gated_model_error_inner(url, resolve_hf_token().is_some())
}
fn format_gated_model_error_inner(url: &str, has_token: bool) -> String {
if has_token {
format!(
"Access denied (HTTP 401) for {url}\n\
Your HF_TOKEN was sent but lacks access to this gated model.\n\
Request access at the model page on huggingface.co, then retry."
)
} else {
format!(
"Access denied (HTTP 401) for {url}\n\
This is a gated model that requires authentication.\n\
Set HF_TOKEN=hf_... or run: huggingface-cli login"
)
}
}
fn download_file(url: &str, path: &Path) -> Result<()> {
let response = hf_get(url).call().map_err(|e| match &e {
ureq::Error::Status(404, _) => {
CliError::HttpNotFound(format!("HTTP 404: {url}"))
}
ureq::Error::Status(401, _) => {
CliError::NetworkError(format_gated_model_error(url))
}
_ => CliError::NetworkError(format!("Download failed: {e}")),
})?;
let mut file = std::fs::File::create(path)?;
let mut reader = response.into_reader();
std::io::copy(&mut reader, &mut file)?;
Ok(())
}
fn download_file_with_progress(url: &str, path: &Path) -> Result<FileChecksum> {
let response = hf_get(url).call().map_err(|e| match &e {
ureq::Error::Status(401, _) => {
CliError::NetworkError(format_gated_model_error(url))
}
_ => CliError::NetworkError(format!("Download failed: {e}")),
})?;
let total = response
.header("Content-Length")
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(0);
let mut file = std::fs::File::create(path)?;
let mut reader = response.into_reader();
let mut hasher = blake3::Hasher::new();
let mut downloaded: u64 = 0;
let mut buf = vec![0u8; 64 * 1024];
let mut last_pct: u64 = 0;
loop {
let n = reader
.read(&mut buf)
.map_err(|e| CliError::NetworkError(format!("Read failed: {e}")))?;
if n == 0 {
break;
}
let chunk = &buf[..n];
io::Write::write_all(&mut file, chunk)?;
hasher.update(chunk);
downloaded += n as u64;
if total > 0 {
let pct = downloaded * 100 / total;
if pct / 10 > last_pct / 10 {
print!(" {}%", pct);
io::stdout().flush().ok();
last_pct = pct;
}
}
}
if total > 0 && downloaded != total {
let _ = std::fs::remove_file(path);
return Err(CliError::NetworkError(format!(
"Download incomplete for '{}': expected {} bytes, got {} bytes",
path.display(),
total,
downloaded
)));
}
Ok(FileChecksum {
size: downloaded,
blake3: hasher.finalize().to_hex().to_string(),
})
}
#[allow(dead_code)]
pub fn resolve_hf_uri(uri: &str) -> Result<String> {
match resolve_hf_model(uri)? {
ResolvedModel::SingleFile(s) => Ok(s),
ResolvedModel::Sharded { org, repo, .. } => {
Ok(format!(
"hf://{}/{}/model.safetensors.index.json",
org, repo
))
}
}
}