use std::io::SeekFrom;
use std::path::{Component, Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use futures_util::StreamExt;
use reqwest::Client;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::task::JoinSet;
use serde::Deserialize;
use crate::error::FetchError;
use crate::progress::{self, ProgressEvent};
use crate::retry::{self, RetryPolicy};
type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
const HF_ENDPOINT: &str = "https://huggingface.co";
pub(crate) struct RangeInfo {
pub content_length: u64,
pub commit_hash: String,
pub etag: String,
pub cdn_url: String,
}
#[must_use]
pub(crate) fn build_download_url(repo_id: &str, revision: &str, filename: &str) -> String {
let url_revision = revision.replace('/', "%2F");
format!("{HF_ENDPOINT}/{repo_id}/resolve/{url_revision}/{filename}")
}
#[must_use]
pub(crate) fn repo_folder_name(repo_id: &str) -> String {
format!("models--{}", repo_id.replace('/', "--"))
}
pub(crate) fn build_no_redirect_client(token: Option<&str>) -> Result<Client, FetchError> {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::USER_AGENT,
reqwest::header::HeaderValue::from_static("hf-fetch-model"),
);
if let Some(tok) = token {
let auth_value = format!("Bearer {tok}");
let header_val = reqwest::header::HeaderValue::from_str(auth_value.as_str())
.map_err(|e| FetchError::Http(e.to_string()))?;
headers.insert(reqwest::header::AUTHORIZATION, header_val);
}
Client::builder()
.redirect(reqwest::redirect::Policy::none())
.default_headers(headers)
.build()
.map_err(|e| FetchError::Http(e.to_string()))
}
pub(crate) fn build_client(token: Option<&str>) -> Result<Client, FetchError> {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::USER_AGENT,
reqwest::header::HeaderValue::from_static("hf-fetch-model"),
);
if let Some(tok) = token {
let auth_value = format!("Bearer {tok}");
let header_val = reqwest::header::HeaderValue::from_str(auth_value.as_str())
.map_err(|e| FetchError::Http(e.to_string()))?;
headers.insert(reqwest::header::AUTHORIZATION, header_val);
}
Client::builder()
.default_headers(headers)
.build()
.map_err(|e| FetchError::Http(e.to_string()))
}
pub(crate) async fn probe_range_support(
client: Client,
url: String,
token: Option<String>,
) -> Result<Option<RangeInfo>, FetchError> {
let no_redirect_client = build_no_redirect_client(token.as_deref())?;
let response = no_redirect_client
.get(url.as_str())
.header(reqwest::header::RANGE, "bytes=0-0")
.send()
.await
.map_err(|e| FetchError::Http(e.to_string()))?;
let (hf_headers, redirect_url) = if response.status().is_redirection() {
let headers = response.headers().clone();
let location = headers
.get(reqwest::header::LOCATION)
.and_then(|v| v.to_str().ok())
.map(str::to_owned);
(headers, location)
} else if response.status() == reqwest::StatusCode::PARTIAL_CONTENT {
let headers = response.headers().clone();
(headers, None)
} else {
return Ok(None);
};
let commit_hash = hf_headers
.get("x-repo-commit")
.and_then(|v| v.to_str().ok())
.map(str::to_owned)
.ok_or_else(|| FetchError::Http("missing x-repo-commit header".to_owned()))?;
let etag = hf_headers
.get("x-linked-etag")
.or_else(|| hf_headers.get(reqwest::header::ETAG))
.and_then(|v| v.to_str().ok())
.map(str::to_owned)
.ok_or_else(|| FetchError::Http("missing etag header".to_owned()))?;
let etag = etag.replace('"', "");
let (cdn_url, content_length) = if let Some(ref loc) = redirect_url {
let cdn_response = client
.get(loc.as_str())
.header(reqwest::header::RANGE, "bytes=0-0")
.send()
.await
.map_err(|e| FetchError::Http(e.to_string()))?;
let size = parse_content_length_from_range(&cdn_response)?;
(loc.clone(), size)
} else {
let direct_response = client
.get(url.as_str())
.header(reqwest::header::RANGE, "bytes=0-0")
.send()
.await
.map_err(|e| FetchError::Http(e.to_string()))?;
let size = parse_content_length_from_range(&direct_response)?;
(url, size)
};
Ok(Some(RangeInfo {
content_length,
commit_hash,
etag,
cdn_url,
}))
}
fn parse_content_length_from_range(response: &reqwest::Response) -> Result<u64, FetchError> {
let content_range = response
.headers()
.get(reqwest::header::CONTENT_RANGE)
.and_then(|v| v.to_str().ok())
.ok_or_else(|| FetchError::Http("missing Content-Range header".to_owned()))?;
content_range
.split('/')
.next_back()
.and_then(|s| s.parse::<u64>().ok())
.ok_or_else(|| FetchError::Http(format!("invalid Content-Range header: {content_range}")))
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn download_chunked(
client: Client,
range_info: RangeInfo,
cache_dir: PathBuf,
repo_folder: String,
revision: String,
filename: String,
connections: usize,
retry_policy: RetryPolicy,
on_progress: Option<ProgressCallback>,
files_remaining: usize,
) -> Result<PathBuf, FetchError> {
let total_size = range_info.content_length;
let repo_dir = cache_dir.join(repo_folder.as_str());
let blob_path = repo_dir.join("blobs").join(range_info.etag.as_str());
let snapshot_dir = repo_dir
.join("snapshots")
.join(range_info.commit_hash.as_str());
let pointer_path = snapshot_dir.join(filename.as_str());
if pointer_path.exists() {
return Ok(pointer_path);
}
let temp_path = prepare_temp_file(&blob_path, &pointer_path, total_size).await?;
let chunk_size = total_size / u64::try_from(connections).unwrap_or(1);
let chunks: Vec<(usize, u64, u64)> = (0..connections)
.map(|i| {
let idx = u64::try_from(i).unwrap_or(0);
let start = idx * chunk_size;
let end = if i == connections - 1 {
total_size - 1
} else {
(idx + 1) * chunk_size - 1
};
(i, start, end)
})
.collect();
let bytes_downloaded = Arc::new(AtomicU64::new(0));
let mut join_set = JoinSet::new();
for (chunk_idx, start, end) in chunks {
let task_client = client.clone();
let task_url = range_info.cdn_url.clone();
let task_temp = temp_path.clone();
let task_policy = retry_policy.clone();
let task_bytes = Arc::clone(&bytes_downloaded);
let task_progress = on_progress.clone();
let task_filename = filename.clone();
join_set.spawn(async move {
download_chunk(
task_client,
task_url,
task_temp,
start,
end,
chunk_idx,
&task_policy,
&task_bytes,
task_progress.as_ref(),
task_filename.as_str(),
total_size,
files_remaining,
)
.await
});
}
let mut failures: Vec<String> = Vec::new();
while let Some(join_result) = join_set.join_next().await {
match join_result {
Ok(Ok(())) => {}
Ok(Err(e)) => failures.push(e.to_string()),
Err(e) => failures.push(format!("chunk task failed: {e}")),
}
}
if !failures.is_empty() {
let _ = tokio::fs::remove_file(&temp_path).await;
return Err(FetchError::ChunkedDownload {
filename: filename.clone(),
reason: failures.join("; "),
});
}
finalize_chunked_download(
&temp_path,
&blob_path,
&pointer_path,
&repo_dir,
revision.as_str(),
range_info.commit_hash.as_str(),
)
.await?;
Ok(pointer_path)
}
async fn prepare_temp_file(
blob_path: &std::path::Path,
pointer_path: &std::path::Path,
total_size: u64,
) -> Result<PathBuf, FetchError> {
if let Some(parent) = blob_path.parent() {
tokio::fs::create_dir_all(parent)
.await
.map_err(|e| FetchError::Io {
path: parent.to_path_buf(),
source: e,
})?;
}
if let Some(parent) = pointer_path.parent() {
tokio::fs::create_dir_all(parent)
.await
.map_err(|e| FetchError::Io {
path: parent.to_path_buf(),
source: e,
})?;
}
let temp_path = blob_path.with_extension("chunked.part");
let f = tokio::fs::File::create(&temp_path)
.await
.map_err(|e| FetchError::Io {
path: temp_path.clone(),
source: e,
})?;
f.set_len(total_size).await.map_err(|e| FetchError::Io {
path: temp_path.clone(),
source: e,
})?;
Ok(temp_path)
}
async fn finalize_chunked_download(
temp_path: &std::path::Path,
blob_path: &std::path::Path,
pointer_path: &std::path::Path,
repo_dir: &std::path::Path,
revision: &str,
commit_hash: &str,
) -> Result<(), FetchError> {
tokio::fs::rename(temp_path, blob_path)
.await
.map_err(|e| FetchError::Io {
path: blob_path.to_path_buf(),
source: e,
})?;
symlink_or_rename(blob_path, pointer_path).map_err(|e| FetchError::Io {
path: pointer_path.to_path_buf(),
source: e,
})?;
let refs_dir = repo_dir.join("refs");
tokio::fs::create_dir_all(&refs_dir)
.await
.map_err(|e| FetchError::Io {
path: refs_dir.clone(),
source: e,
})?;
let ref_path = refs_dir.join(revision);
tokio::fs::write(&ref_path, commit_hash.as_bytes())
.await
.map_err(|e| FetchError::Io {
path: ref_path,
source: e,
})?;
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn download_chunk(
client: Client,
url: String,
temp_path: PathBuf,
start: u64,
end: u64,
chunk_idx: usize,
retry_policy: &RetryPolicy,
bytes_downloaded: &AtomicU64,
on_progress: Option<&ProgressCallback>,
filename: &str,
total_size: u64,
files_remaining: usize,
) -> Result<(), FetchError> {
let url_owned = url.clone();
let temp_owned = temp_path.clone();
let filename_owned = filename.to_owned();
retry::retry_async(retry_policy, retry::is_retryable, || {
let task_client = client.clone();
let task_url = url_owned.clone();
let task_temp = temp_owned.clone();
let task_filename = filename_owned.clone();
async move {
let range_header = format!("bytes={start}-{end}");
let response = task_client
.get(task_url.as_str())
.header(reqwest::header::RANGE, range_header.as_str())
.send()
.await
.map_err(|e| FetchError::ChunkedDownload {
filename: task_filename.clone(),
reason: format!("chunk {chunk_idx} request failed: {e}"),
})?;
if !response.status().is_success() {
return Err(FetchError::ChunkedDownload {
filename: task_filename.clone(),
reason: format!("chunk {chunk_idx} HTTP {}", response.status()),
});
}
let mut file = tokio::fs::OpenOptions::new()
.write(true)
.open(&task_temp)
.await
.map_err(|e| FetchError::Io {
path: task_temp.clone(),
source: e,
})?;
file.seek(SeekFrom::Start(start))
.await
.map_err(|e| FetchError::Io {
path: task_temp.clone(),
source: e,
})?;
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
let bytes = chunk_result.map_err(|e| FetchError::ChunkedDownload {
filename: task_filename.clone(),
reason: format!("chunk {chunk_idx} stream error: {e}"),
})?;
file.write_all(&bytes).await.map_err(|e| FetchError::Io {
path: task_temp.clone(),
source: e,
})?;
let added = u64::try_from(bytes.len()).unwrap_or(0);
let current = bytes_downloaded.fetch_add(added, Ordering::Relaxed) + added;
if let Some(cb) = on_progress {
let event = progress::streaming_event(
task_filename.as_str(),
current,
total_size,
files_remaining,
);
cb(&event);
}
}
file.flush().await.map_err(|e| FetchError::Io {
path: task_temp,
source: e,
})?;
Ok(())
}
})
.await
}
#[derive(Deserialize)]
struct ApiCommitInfo {
sha: String,
}
async fn resolve_commit_hash(
client: &Client,
repo_id: &str,
revision: &str,
repo_dir: &Path,
) -> Result<String, FetchError> {
let ref_path = repo_dir.join("refs").join(revision);
if let Ok(hash) = tokio::fs::read_to_string(&ref_path).await {
let trimmed = hash.trim().to_owned();
if !trimmed.is_empty() {
return Ok(trimmed);
}
}
let mut url = format!("{HF_ENDPOINT}/api/models/{repo_id}");
if revision != "main" {
url = format!("{url}?revision={revision}");
}
let response = client
.get(url.as_str())
.send()
.await
.map_err(|e| FetchError::Http(format!("resolve commit hash: {e}")))?;
if !response.status().is_success() {
return Err(FetchError::Http(format!(
"resolve commit hash: HTTP {}",
response.status()
)));
}
let info: ApiCommitInfo = response
.json()
.await
.map_err(|e| FetchError::Http(format!("resolve commit hash: {e}")))?;
let refs_dir = repo_dir.join("refs");
tokio::fs::create_dir_all(&refs_dir)
.await
.map_err(|e| FetchError::Io {
path: refs_dir.clone(),
source: e,
})?;
tokio::fs::write(&ref_path, info.sha.as_bytes())
.await
.map_err(|e| FetchError::Io {
path: ref_path,
source: e,
})?;
Ok(info.sha)
}
pub(crate) async fn download_direct(
client: &Client,
repo_id: &str,
revision: &str,
filename: &str,
cache_dir: &Path,
) -> Result<PathBuf, FetchError> {
let repo_folder = repo_folder_name(repo_id);
let repo_dir = cache_dir.join(repo_folder.as_str());
let commit_hash = resolve_commit_hash(client, repo_id, revision, &repo_dir).await?;
let pointer_path = repo_dir
.join("snapshots")
.join(commit_hash.as_str())
.join(filename);
if pointer_path.exists() {
return Ok(pointer_path);
}
let url = build_download_url(repo_id, revision, filename);
let response = client
.get(url.as_str())
.send()
.await
.map_err(|e| FetchError::Http(format!("direct download of {filename}: {e}")))?;
if !response.status().is_success() {
return Err(FetchError::Http(format!(
"direct download of {filename}: HTTP {}",
response.status()
)));
}
let content = response
.bytes()
.await
.map_err(|e| FetchError::Http(format!("direct download of {filename}: {e}")))?;
if let Some(parent) = pointer_path.parent() {
tokio::fs::create_dir_all(parent)
.await
.map_err(|e| FetchError::Io {
path: parent.to_path_buf(),
source: e,
})?;
}
tokio::fs::write(&pointer_path, &content)
.await
.map_err(|e| FetchError::Io {
path: pointer_path.clone(),
source: e,
})?;
Ok(pointer_path)
}
fn make_relative(src: &Path, dst: &Path) -> PathBuf {
let src_components: Vec<Component<'_>> = src.components().collect();
let dst_parent = dst.parent().unwrap_or(dst);
let dst_components: Vec<Component<'_>> = dst_parent.components().collect();
let common_len = src_components
.iter()
.zip(dst_components.iter())
.take_while(|(a, b)| a == b)
.count();
let mut rel = PathBuf::new();
for _ in common_len..dst_components.len() {
rel.push(Component::ParentDir);
}
for comp in src_components.iter().skip(common_len) {
rel.push(comp);
}
rel
}
fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
if dst.exists() {
return Ok(());
}
let rel_src = make_relative(src, dst);
#[cfg(target_os = "windows")]
{
if std::os::windows::fs::symlink_file(&rel_src, dst).is_err() {
std::fs::rename(src, dst)?;
}
}
#[cfg(target_family = "unix")]
std::os::unix::fs::symlink(rel_src, dst)?;
Ok(())
}
#[cfg(test)]
mod tests {
#![allow(
clippy::panic,
clippy::unwrap_used,
clippy::expect_used,
clippy::indexing_slicing
)]
use super::*;
#[test]
fn test_repo_folder_name() {
assert_eq!(
repo_folder_name("google/gemma-2-2b"),
"models--google--gemma-2-2b"
);
assert_eq!(
repo_folder_name("RWKV/RWKV7-Goose-World3-1.5B-HF"),
"models--RWKV--RWKV7-Goose-World3-1.5B-HF"
);
}
#[test]
fn test_build_download_url() {
assert_eq!(
build_download_url("google/gemma-2-2b", "main", "config.json"),
"https://huggingface.co/google/gemma-2-2b/resolve/main/config.json"
);
assert_eq!(
build_download_url("org/model", "refs/pr/42", "file.bin"),
"https://huggingface.co/org/model/resolve/refs%2Fpr%2F42/file.bin"
);
}
#[test]
fn test_chunk_boundaries() {
let total: u64 = 1000;
let connections: usize = 4;
let chunk_size = total / u64::try_from(connections).unwrap();
let chunks: Vec<(u64, u64)> = (0..connections)
.map(|i| {
let idx = u64::try_from(i).unwrap();
let start = idx * chunk_size;
let end = if i == connections - 1 {
total - 1
} else {
(idx + 1) * chunk_size - 1
};
(start, end)
})
.collect();
assert_eq!(chunks[0], (0, 249));
assert_eq!(chunks[1], (250, 499));
assert_eq!(chunks[2], (500, 749));
assert_eq!(chunks[3], (750, 999));
}
}