use std::io::SeekFrom;
use std::path::{Component, Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use futures_util::StreamExt;
use reqwest::Client;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::sync::Mutex as AsyncMutex;
use tokio::task::JoinSet;
use serde::Deserialize;
use crate::chunked_state::ChunkedState;
use crate::error::FetchError;
use crate::progress::{self, ProgressEvent};
use crate::retry::{self, RetryPolicy};
/// New bytes downloaded per chunk before the resume state is checkpointed
/// to the sidecar file (16 MiB).
const SIDECAR_CHECKPOINT_BYTES: u64 = 16 * 1024 * 1024;
/// Shared, thread-safe progress callback invoked as bytes stream in.
type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
/// Base URL of the Hugging Face hub.
const HF_ENDPOINT: &str = "https://huggingface.co";
/// TCP connect timeout applied to every HTTP client built by this module.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Metadata gathered by [`probe_range_support`] that a chunked download needs.
pub(crate) struct RangeInfo {
    /// Total file size in bytes, parsed from the `Content-Range` header.
    pub content_length: u64,
    /// Commit hash reported by the hub's `x-repo-commit` header.
    pub commit_hash: String,
    /// Blob etag (surrounding quotes stripped); used to name the cache blob.
    pub etag: String,
    /// Final URL (possibly a pre-signed CDN URL) that chunks are fetched from.
    pub cdn_url: String,
    /// Expiry deadline derived from `X-Amz-Expires` in the CDN URL, if any.
    pub cdn_expires_at: Option<Instant>,
}
/// Builds the hub `resolve` URL for a file in a repository.
///
/// Slashes in the revision (e.g. `refs/pr/42`) are percent-encoded so the
/// revision stays a single path segment.
///
/// Bug fix: the format string previously emitted a literal `(unknown)`
/// instead of interpolating `filename`, so the `filename` parameter was
/// silently ignored (see `test_build_download_url`).
#[must_use]
pub(crate) fn build_download_url(repo_id: &str, revision: &str, filename: &str) -> String {
    let url_revision = revision.replace('/', "%2F");
    format!("{HF_ENDPOINT}/{repo_id}/resolve/{url_revision}/{filename}")
}
/// Constructs the default header map shared by every client this module
/// builds: a fixed user agent plus an optional bearer-auth header.
///
/// # Errors
/// Returns `FetchError::Http` when the token contains characters that are
/// invalid in an HTTP header value.
fn base_headers(token: Option<&str>) -> Result<reqwest::header::HeaderMap, FetchError> {
    let mut headers = reqwest::header::HeaderMap::new();
    headers.insert(
        reqwest::header::USER_AGENT,
        reqwest::header::HeaderValue::from_static("hf-fetch-model"),
    );
    if let Some(tok) = token {
        let auth_value = format!("Bearer {tok}");
        let header_val = reqwest::header::HeaderValue::from_str(auth_value.as_str())
            .map_err(|e| FetchError::Http(e.to_string()))?;
        headers.insert(reqwest::header::AUTHORIZATION, header_val);
    }
    Ok(headers)
}

/// Builds a client that does NOT follow redirects, so hub metadata headers
/// on the redirect response can be inspected before the CDN is contacted.
///
/// # Errors
/// Returns `FetchError::Http` if the token is not a valid header value or
/// the client cannot be constructed.
pub(crate) fn build_no_redirect_client(token: Option<&str>) -> Result<Client, FetchError> {
    Client::builder()
        .redirect(reqwest::redirect::Policy::none())
        .connect_timeout(CONNECT_TIMEOUT)
        .default_headers(base_headers(token)?)
        .build()
        .map_err(|e| FetchError::Http(e.to_string()))
}

/// Builds the standard redirect-following client used for downloads.
///
/// # Errors
/// Returns `FetchError::Http` if the token is not a valid header value or
/// the client cannot be constructed.
pub fn build_client(token: Option<&str>) -> Result<Client, FetchError> {
    Client::builder()
        .connect_timeout(CONNECT_TIMEOUT)
        .default_headers(base_headers(token)?)
        .build()
        .map_err(|e| FetchError::Http(e.to_string()))
}
/// Probes whether the file at `url` can be fetched with HTTP range requests
/// and collects the metadata a chunked download needs.
///
/// Sends a `bytes=0-0` request with redirects disabled so the hub's
/// `x-repo-commit` / `x-linked-etag` headers can be read from the redirect
/// response itself. Returns `Ok(None)` when the server answers with neither
/// a redirect nor `206 Partial Content`, i.e. range requests are not
/// supported and the caller should fall back to a direct download.
///
/// # Errors
/// Returns `FetchError::Http` on transport failures or when the required
/// commit/etag headers are missing from the response.
pub(crate) async fn probe_range_support(
    client: Client,
    url: String,
    token: Option<String>,
) -> Result<Option<RangeInfo>, FetchError> {
    // Redirects are handled manually: the hub metadata headers live on the
    // redirect response, not on the CDN response it points to.
    let no_redirect_client = build_no_redirect_client(token.as_deref())?;
    let response = no_redirect_client
        .get(url.as_str())
        .header(reqwest::header::RANGE, "bytes=0-0")
        .send()
        .await
        .map_err(|e| FetchError::Http(e.to_string()))?;
    let (hf_headers, redirect_url) = if response.status().is_redirection() {
        let headers = response.headers().clone();
        let location = headers
            .get(reqwest::header::LOCATION)
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned);
        (headers, location)
    } else if response.status() == reqwest::StatusCode::PARTIAL_CONTENT {
        // Served directly (no CDN redirect) but ranges are honored.
        let headers = response.headers().clone();
        (headers, None)
    } else {
        // Neither redirect nor partial content: ranges unsupported.
        return Ok(None);
    };
    let commit_hash = hf_headers
        .get("x-repo-commit")
        .and_then(|v| v.to_str().ok())
        .map(str::to_owned)
        .ok_or_else(|| FetchError::Http("missing x-repo-commit header".to_owned()))?;
    // Prefer `x-linked-etag` (LFS blob etag) over the plain ETag header.
    let etag = hf_headers
        .get("x-linked-etag")
        .or_else(|| hf_headers.get(reqwest::header::ETAG))
        .and_then(|v| v.to_str().ok())
        .map(str::to_owned)
        .ok_or_else(|| FetchError::Http("missing etag header".to_owned()))?;
    // Etags arrive quoted; strip quotes so the value is filesystem-safe.
    let etag = etag.replace('"', "");
    // Second 1-byte request (redirects allowed this time) to learn the
    // total size from the Content-Range header.
    let (cdn_url, content_length) = if let Some(ref loc) = redirect_url {
        let cdn_response = client
            .get(loc.as_str())
            .header(reqwest::header::RANGE, "bytes=0-0")
            .send()
            .await
            .map_err(|e| FetchError::Http(e.to_string()))?;
        let size = parse_content_length_from_range(&cdn_response)?;
        (loc.clone(), size)
    } else {
        let direct_response = client
            .get(url.as_str())
            .header(reqwest::header::RANGE, "bytes=0-0")
            .send()
            .await
            .map_err(|e| FetchError::Http(e.to_string()))?;
        let size = parse_content_length_from_range(&direct_response)?;
        (url, size)
    };
    let cdn_expires_at = parse_cdn_expiry(&cdn_url);
    Ok(Some(RangeInfo {
        content_length,
        commit_hash,
        etag,
        cdn_url,
        cdn_expires_at,
    }))
}
/// Reads the total object size out of a `Content-Range` response header.
///
/// A 206 response carries `Content-Range: bytes <start>-<end>/<total>`; the
/// value after the final `/` is the full content length.
///
/// # Errors
/// Returns `FetchError::Http` when the header is absent, non-ASCII, or its
/// final segment is not a valid `u64`.
fn parse_content_length_from_range(response: &reqwest::Response) -> Result<u64, FetchError> {
    let header = response.headers().get(reqwest::header::CONTENT_RANGE);
    let content_range = header
        .and_then(|v| v.to_str().ok())
        .ok_or_else(|| FetchError::Http("missing Content-Range header".to_owned()))?;
    // The total size is the last '/'-separated segment.
    match content_range.rsplit('/').next().map(str::parse::<u64>) {
        Some(Ok(size)) => Ok(size),
        _ => Err(FetchError::Http(format!(
            "invalid Content-Range header: {content_range}"
        ))),
    }
}
/// Extracts a best-effort expiry deadline from a pre-signed CDN URL.
///
/// Scans the query string for an `X-Amz-Expires=<seconds>` parameter and
/// turns it into an `Instant` relative to now. Returns `None` when the URL
/// has no query string, no such parameter, or a non-numeric value.
fn parse_cdn_expiry(url: &str) -> Option<Instant> {
    let query = url.split('?').nth(1)?;
    for param in query.split('&') {
        if let Some(raw) = param.strip_prefix("X-Amz-Expires=") {
            let seconds: u64 = raw.parse().ok()?;
            return Some(Instant::now() + Duration::from_secs(seconds));
        }
    }
    None
}
/// Downloads a single file using `connections` parallel HTTP range requests,
/// resuming from a sidecar state file if a previous attempt was interrupted.
///
/// Returns the snapshot pointer path. If the pointer already exists for the
/// resolved commit, the download is skipped entirely.
///
/// # Errors
/// Returns `FetchError::ChunkedDownload` when any chunk task fails after
/// retries, plus I/O and state errors from the resume/finalize steps.
#[allow(
    clippy::too_many_arguments,
    clippy::too_many_lines,
    clippy::indexing_slicing
)]
pub(crate) async fn download_chunked(
    client: Client,
    range_info: RangeInfo,
    cache_dir: PathBuf,
    repo_folder: String,
    revision: String,
    filename: String,
    connections: usize,
    retry_policy: RetryPolicy,
    on_progress: Option<ProgressCallback>,
    files_remaining: usize,
) -> Result<PathBuf, FetchError> {
    let total_size = range_info.content_length;
    let repo_dir = cache_dir.join(repo_folder.as_str());
    let blob_path = crate::cache_layout::blob_path(&repo_dir, range_info.etag.as_str());
    let pointer_path = crate::cache_layout::pointer_path(
        &repo_dir,
        range_info.commit_hash.as_str(),
        filename.as_str(),
    );
    if pointer_path.exists() {
        // Cache hit: this commit/file snapshot is already materialized.
        return Ok(pointer_path);
    }
    // Split the byte range evenly; the final chunk absorbs the remainder.
    // NOTE(review): if total_size < connections, chunk_size is 0 and the
    // non-final `end` computation (`... - 1`) underflows — confirm callers
    // only take the chunked path for files larger than the connection count.
    let chunk_size = total_size / u64::try_from(connections).unwrap_or(1);
    let chunks_layout: Vec<(usize, u64, u64)> = (0..connections)
        .map(|i| {
            let idx = u64::try_from(i).unwrap_or(0);
            let start = idx * chunk_size;
            let end = if i == connections - 1 {
                total_size - 1
            } else {
                (idx + 1) * chunk_size - 1
            };
            (i, start, end)
        })
        .collect();
    let temp_path = crate::cache_layout::temp_blob_path(&repo_dir, range_info.etag.as_str());
    let state_path = crate::cache_layout::temp_state_path(&repo_dir, range_info.etag.as_str());
    let resume_state = prepare_or_resume_temp_file(
        &blob_path,
        &pointer_path,
        &temp_path,
        &state_path,
        range_info.etag.as_str(),
        total_size,
        chunks_layout.as_slice(),
        connections,
    )
    .await?;
    // Guard keeps the partial file on drop so interrupted runs can resume.
    let _temp_guard = TempFileGuard::new(temp_path.clone());
    // Seed the shared progress counter with bytes already on disk.
    let already_done: u64 = resume_state.chunks.iter().map(|c| c.completed).sum();
    let bytes_downloaded = Arc::new(AtomicU64::new(already_done));
    let resume_offsets: Vec<u64> = resume_state.chunks.iter().map(|c| c.completed).collect();
    let shared_state = Arc::new(AsyncMutex::new(resume_state));
    let mut join_set = JoinSet::new();
    // One task per chunk, each writing its own disjoint byte range.
    for (chunk_idx, start, end) in chunks_layout {
        let task_client = client.clone();
        let task_url = range_info.cdn_url.clone();
        let task_temp = temp_path.clone();
        let task_state_path = state_path.clone();
        let task_state = Arc::clone(&shared_state);
        let task_policy = retry_policy.clone();
        let task_bytes = Arc::clone(&bytes_downloaded);
        let task_progress = on_progress.clone();
        let task_filename = filename.clone();
        let task_initial_offset = resume_offsets[chunk_idx];
        join_set.spawn(async move {
            download_chunk(
                task_client,
                task_url,
                task_temp,
                start,
                end,
                chunk_idx,
                task_initial_offset,
                task_state,
                task_state_path,
                &task_policy,
                &task_bytes,
                task_progress.as_ref(),
                task_filename.as_str(),
                total_size,
                files_remaining,
            )
            .await
        });
    }
    // Drain every task before failing so all errors are reported at once.
    let mut failures: Vec<String> = Vec::new();
    while let Some(join_result) = join_set.join_next().await {
        match join_result {
            Ok(Ok(())) => {}
            Ok(Err(e)) => failures.push(e.to_string()),
            Err(e) => failures.push(format!("chunk task failed: {e}")),
        }
    }
    if !failures.is_empty() {
        return Err(FetchError::ChunkedDownload {
            filename: filename.clone(),
            reason: failures.join("; "),
        });
    }
    finalize_chunked_download(
        &temp_path,
        &blob_path,
        &pointer_path,
        &repo_dir,
        revision.as_str(),
        range_info.commit_hash.as_str(),
    )
    .await?;
    // Success: the sidecar state is no longer needed (best-effort removal).
    let _ = ChunkedState::remove(&state_path).await;
    Ok(pointer_path)
}
/// RAII guard for the temporary `.part` file of a chunked download.
///
/// By default the partial file is preserved on drop so an interrupted
/// download can be resumed later; calling `mark_corrupt` flips the guard so
/// the file is deleted instead.
struct TempFileGuard {
    path: PathBuf,
    wipe_on_drop: bool,
}
impl TempFileGuard {
    /// Creates a guard that leaves the file in place when dropped.
    fn new(path: PathBuf) -> Self {
        Self {
            path,
            wipe_on_drop: false,
        }
    }
    /// Requests deletion of the partial file on drop, e.g. after detecting
    /// corruption that makes resuming pointless.
    #[allow(dead_code)]
    fn mark_corrupt(&mut self) {
        self.wipe_on_drop = true;
    }
}
impl Drop for TempFileGuard {
    fn drop(&mut self) {
        if !self.wipe_on_drop {
            return;
        }
        // Best-effort: the file may already be gone, and Drop cannot fail.
        let _ = std::fs::remove_file(self.path.as_path());
    }
}
/// Prepares the temp download file and resume state, reusing a previous
/// partial download when it is still valid.
///
/// A saved state is honored only when it matches the current etag, total
/// size, and connection count AND the temp file still exists; otherwise
/// both artifacts are discarded and recreated from scratch.
///
/// # Errors
/// Returns `FetchError::Io` on filesystem failures, plus state load/save
/// errors from `ChunkedState`.
#[allow(clippy::too_many_arguments)]
async fn prepare_or_resume_temp_file(
    blob_path: &Path,
    pointer_path: &Path,
    temp_path: &Path,
    state_path: &Path,
    etag: &str,
    total_size: u64,
    chunks_layout: &[(usize, u64, u64)],
    connections: usize,
) -> Result<ChunkedState, FetchError> {
    if let Some(parent) = blob_path.parent() {
        tokio::fs::create_dir_all(parent)
            .await
            .map_err(|e| FetchError::Io {
                path: parent.to_path_buf(),
                source: e,
            })?;
    }
    if let Some(parent) = pointer_path.parent() {
        tokio::fs::create_dir_all(parent)
            .await
            .map_err(|e| FetchError::Io {
                path: parent.to_path_buf(),
                source: e,
            })?;
    }
    let existing_state = ChunkedState::load(state_path).await?;
    let temp_exists = tokio::fs::try_exists(temp_path)
        .await
        .map_err(|e| FetchError::Io {
            path: temp_path.to_path_buf(),
            source: e,
        })?;
    if let Some(state) = existing_state {
        // Resume only when the saved layout still describes this download.
        if state.is_compatible_with(etag, total_size, connections) && temp_exists {
            return Ok(state);
        }
    }
    // Stale or missing artifacts: start over (file removal is best-effort).
    let _ = tokio::fs::remove_file(temp_path).await;
    ChunkedState::remove(state_path).await?;
    let file = tokio::fs::File::create(temp_path)
        .await
        .map_err(|e| FetchError::Io {
            path: temp_path.to_path_buf(),
            source: e,
        })?;
    // Preallocate the full size so chunk tasks can seek and write their own
    // disjoint ranges concurrently.
    file.set_len(total_size).await.map_err(|e| FetchError::Io {
        path: temp_path.to_path_buf(),
        source: e,
    })?;
    drop(file);
    let fresh = ChunkedState::new_fresh(etag.to_owned(), total_size, connections, chunks_layout);
    fresh.save_atomic(state_path).await?;
    Ok(fresh)
}
/// Promotes a completed temp file into the cache layout: renames it to the
/// blob path, creates the snapshot pointer (symlink or copy), and records
/// the revision -> commit mapping under `refs/`.
///
/// # Errors
/// Returns `FetchError::Io` when any of the filesystem steps fail.
async fn finalize_chunked_download(
    temp_path: &std::path::Path,
    blob_path: &std::path::Path,
    pointer_path: &std::path::Path,
    repo_dir: &std::path::Path,
    revision: &str,
    commit_hash: &str,
) -> Result<(), FetchError> {
    // Rename is the commit point: the blob only appears once fully written.
    tokio::fs::rename(temp_path, blob_path)
        .await
        .map_err(|e| FetchError::Io {
            path: blob_path.to_path_buf(),
            source: e,
        })?;
    symlink_or_copy(blob_path, pointer_path).map_err(|e| FetchError::Io {
        path: pointer_path.to_path_buf(),
        source: e,
    })?;
    let refs_dir = crate::cache_layout::refs_dir(repo_dir);
    tokio::fs::create_dir_all(&refs_dir)
        .await
        .map_err(|e| FetchError::Io {
            path: refs_dir.clone(),
            source: e,
        })?;
    // Cache the revision -> commit resolution for future lookups.
    let ref_path = crate::cache_layout::ref_path(repo_dir, revision);
    tokio::fs::write(&ref_path, commit_hash.as_bytes())
        .await
        .map_err(|e| FetchError::Io {
            path: ref_path,
            source: e,
        })?;
    Ok(())
}
/// Downloads one byte range `[start, end]` of the file into `temp_path` at
/// the matching offset, checkpointing progress to the sidecar state and
/// retrying per `retry_policy`.
///
/// Each retry attempt re-reads the shared state so it resumes from the last
/// persisted checkpoint instead of restarting the range from scratch.
///
/// # Errors
/// Returns `FetchError::ChunkedDownload` for request/stream failures and
/// `FetchError::Io` for filesystem failures, after retries are exhausted.
#[allow(
    clippy::too_many_arguments,
    clippy::too_many_lines,
    clippy::indexing_slicing
)]
async fn download_chunk(
    client: Client,
    url: String,
    temp_path: PathBuf,
    start: u64,
    end: u64,
    chunk_idx: usize,
    initial_offset: u64,
    shared_state: Arc<AsyncMutex<ChunkedState>>,
    state_path: PathBuf,
    retry_policy: &RetryPolicy,
    bytes_downloaded: &AtomicU64,
    on_progress: Option<&ProgressCallback>,
    filename: &str,
    total_size: u64,
    files_remaining: usize,
) -> Result<(), FetchError> {
    // Owned copies so the retry closure can clone them per attempt.
    let url_owned = url.clone();
    let temp_owned = temp_path.clone();
    let filename_owned = filename.to_owned();
    retry::retry_async(retry_policy, retry::is_retryable, || {
        let task_client = client.clone();
        let task_url = url_owned.clone();
        let task_temp = temp_owned.clone();
        let task_state_path = state_path.clone();
        let task_state = Arc::clone(&shared_state);
        let task_filename = filename_owned.clone();
        async move {
            // Read the latest checkpointed progress for this chunk.
            let (resume_completed, already_done) = {
                let guard = task_state.lock().await;
                let progress = &guard.chunks[chunk_idx];
                (progress.completed, progress.is_complete())
            };
            if already_done {
                return Ok(());
            }
            // `initial_offset` is the on-disk resume state loaded at startup;
            // take the max so neither source of progress is lost.
            let resume_byte = start.saturating_add(resume_completed.max(initial_offset));
            let effective_resume_completed = resume_byte.saturating_sub(start);
            let range_header = format!("bytes={resume_byte}-{end}");
            let response = task_client
                .get(task_url.as_str())
                .header(reqwest::header::RANGE, range_header.as_str())
                .send()
                .await
                .map_err(|e| FetchError::ChunkedDownload {
                    filename: task_filename.clone(),
                    reason: format!("chunk {chunk_idx} request failed: {e}"),
                })?;
            if !response.status().is_success() {
                return Err(FetchError::ChunkedDownload {
                    filename: task_filename.clone(),
                    reason: format!("chunk {chunk_idx} HTTP {}", response.status()),
                });
            }
            // Every chunk task opens its own handle and writes only within
            // its disjoint byte range of the preallocated temp file.
            let mut file = tokio::fs::OpenOptions::new()
                .write(true)
                .open(&task_temp)
                .await
                .map_err(|e| FetchError::Io {
                    path: task_temp.clone(),
                    source: e,
                })?;
            file.seek(SeekFrom::Start(resume_byte))
                .await
                .map_err(|e| FetchError::Io {
                    path: task_temp.clone(),
                    source: e,
                })?;
            let mut stream = response.bytes_stream();
            let mut current_completed = effective_resume_completed;
            let mut last_checkpoint = current_completed;
            while let Some(chunk_result) = stream.next().await {
                let bytes = chunk_result.map_err(|e| FetchError::ChunkedDownload {
                    filename: task_filename.clone(),
                    reason: format!("chunk {chunk_idx} stream error: {e}"),
                })?;
                file.write_all(&bytes).await.map_err(|e| FetchError::Io {
                    path: task_temp.clone(),
                    source: e,
                })?;
                let added = u64::try_from(bytes.len()).unwrap_or(0);
                current_completed = current_completed.saturating_add(added);
                // NOTE(review): on a retried attempt, bytes re-downloaded
                // since the last checkpoint are added to this shared counter
                // again, so reported progress can overshoot — confirm this
                // is acceptable for display purposes.
                let current = bytes_downloaded.fetch_add(added, Ordering::Relaxed) + added;
                if let Some(cb) = on_progress {
                    let event = progress::streaming_event(
                        task_filename.as_str(),
                        current,
                        total_size,
                        files_remaining,
                    );
                    cb(&event);
                }
                // Persist resume state every SIDECAR_CHECKPOINT_BYTES of new
                // data; save failures are ignored (worst case: re-download).
                if current_completed.saturating_sub(last_checkpoint) >= SIDECAR_CHECKPOINT_BYTES {
                    let snapshot = {
                        let mut guard = task_state.lock().await;
                        guard.chunks[chunk_idx].completed = current_completed;
                        guard.clone()
                    };
                    let _ = snapshot.save_atomic(task_state_path.as_path()).await;
                    last_checkpoint = current_completed;
                }
            }
            file.flush().await.map_err(|e| FetchError::Io {
                path: task_temp,
                source: e,
            })?;
            // Final checkpoint records the chunk as fully complete.
            let snapshot = {
                let mut guard = task_state.lock().await;
                guard.chunks[chunk_idx].completed = current_completed;
                guard.clone()
            };
            let _ = snapshot.save_atomic(task_state_path.as_path()).await;
            Ok(())
        }
    })
    .await
}
/// Minimal projection of the hub's `/api/models/{repo}` JSON response:
/// only the resolved commit `sha` is needed.
#[derive(Deserialize)]
struct ApiCommitInfo {
    // Commit hash the requested revision resolves to.
    sha: String,
}
/// Resolves `revision` to a commit hash, preferring the cached `refs/` file
/// and falling back to the hub's model API; the API result is written back
/// to `refs/` for future lookups.
///
/// # Errors
/// Returns `FetchError::Http` on API/transport failures and
/// `FetchError::Io` when the ref cache cannot be written.
async fn resolve_commit_hash(
    client: &Client,
    repo_id: &str,
    revision: &str,
    repo_dir: &Path,
) -> Result<String, FetchError> {
    // Fast path: a previous download already cached the resolution.
    let ref_path = crate::cache_layout::ref_path(repo_dir, revision);
    if let Ok(hash) = tokio::fs::read_to_string(&ref_path).await {
        let trimmed = hash.trim().to_owned();
        if !trimmed.is_empty() {
            return Ok(trimmed);
        }
    }
    // "main" is the API default, so the query parameter is only added for
    // non-default revisions.
    let mut url = format!("{HF_ENDPOINT}/api/models/{repo_id}");
    if revision != "main" {
        url = format!("{url}?revision={revision}");
    }
    let response = client
        .get(url.as_str())
        .send()
        .await
        .map_err(|e| FetchError::Http(format!("resolve commit hash: {e}")))?;
    if !response.status().is_success() {
        return Err(FetchError::Http(format!(
            "resolve commit hash: HTTP {}",
            response.status()
        )));
    }
    let info: ApiCommitInfo = response
        .json()
        .await
        .map_err(|e| FetchError::Http(format!("resolve commit hash: {e}")))?;
    // Cache the resolution so subsequent calls skip the API round trip.
    let refs_dir = crate::cache_layout::refs_dir(repo_dir);
    tokio::fs::create_dir_all(&refs_dir)
        .await
        .map_err(|e| FetchError::Io {
            path: refs_dir.clone(),
            source: e,
        })?;
    tokio::fs::write(&ref_path, info.sha.as_bytes())
        .await
        .map_err(|e| FetchError::Io {
            path: ref_path,
            source: e,
        })?;
    Ok(info.sha)
}
/// Downloads a small file in a single GET request (no range/chunking) and
/// writes it directly at the snapshot pointer path.
///
/// Skips the download when the pointer already exists for the resolved
/// commit. Unlike the chunked path, the full content is buffered in memory,
/// so this is only appropriate for small files.
///
/// Bug fix: the three error messages previously contained a literal
/// `(unknown)` placeholder instead of interpolating `filename`, making
/// failures impossible to attribute to a file.
///
/// # Errors
/// Returns `FetchError::Http` for request/status failures and
/// `FetchError::Io` for filesystem failures.
pub(crate) async fn download_direct(
    client: &Client,
    repo_id: &str,
    revision: &str,
    filename: &str,
    cache_dir: &Path,
) -> Result<PathBuf, FetchError> {
    let repo_dir = crate::cache_layout::repo_dir(cache_dir, repo_id);
    let commit_hash = resolve_commit_hash(client, repo_id, revision, &repo_dir).await?;
    let pointer_path = crate::cache_layout::pointer_path(&repo_dir, commit_hash.as_str(), filename);
    if pointer_path.exists() {
        // Cache hit: the snapshot pointer for this commit already exists.
        return Ok(pointer_path);
    }
    let url = build_download_url(repo_id, revision, filename);
    let response = client
        .get(url.as_str())
        .send()
        .await
        .map_err(|e| FetchError::Http(format!("direct download of {filename}: {e}")))?;
    if !response.status().is_success() {
        return Err(FetchError::Http(format!(
            "direct download of {filename}: HTTP {}",
            response.status()
        )));
    }
    let content = response
        .bytes()
        .await
        .map_err(|e| FetchError::Http(format!("direct download of {filename}: {e}")))?;
    if let Some(parent) = pointer_path.parent() {
        tokio::fs::create_dir_all(parent)
            .await
            .map_err(|e| FetchError::Io {
                path: parent.to_path_buf(),
                source: e,
            })?;
    }
    tokio::fs::write(&pointer_path, &content)
        .await
        .map_err(|e| FetchError::Io {
            path: pointer_path.clone(),
            source: e,
        })?;
    Ok(pointer_path)
}
/// Computes the path from `dst`'s parent directory to `src`, expressed
/// relatively.
///
/// Used to build relative symlink targets so the whole cache directory can
/// be moved without breaking pointer links: walk past the shared component
/// prefix, emit one `..` per remaining destination-directory component,
/// then append the rest of the source path.
fn make_relative(src: &Path, dst: &Path) -> PathBuf {
    let from: Vec<Component<'_>> = dst.parent().unwrap_or(dst).components().collect();
    let to: Vec<Component<'_>> = src.components().collect();
    let shared = to
        .iter()
        .zip(&from)
        .take_while(|(a, b)| a == b)
        .count();
    let ups = std::iter::repeat(Component::ParentDir).take(from.len() - shared);
    ups.chain(to.iter().skip(shared).copied()).collect()
}
/// Creates `dst` as a relative symlink to `src`, degrading gracefully:
/// on Windows a plain copy is made when symlink creation fails (e.g. no
/// privilege / developer mode), and on Unix an already-existing link —
/// presumably left by a concurrent download — is not treated as an error.
fn symlink_or_copy(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
    if dst.exists() {
        // Pointer already materialized; nothing to do.
        return Ok(());
    }
    // Relative target keeps the cache directory relocatable as a whole.
    let rel_src = make_relative(src, dst);
    #[cfg(target_os = "windows")]
    {
        if std::os::windows::fs::symlink_file(&rel_src, dst).is_err() {
            std::fs::copy(src, dst)?;
        }
    }
    #[cfg(target_family = "unix")]
    {
        if let Err(e) = std::os::unix::fs::symlink(rel_src, dst) {
            if e.kind() != std::io::ErrorKind::AlreadyExists {
                return Err(e);
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    #![allow(
        clippy::panic,
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::indexing_slicing
    )]
    use super::*;
    // Cache folder names encode the repo id with `--` separators.
    #[test]
    fn test_repo_folder_name() {
        assert_eq!(
            crate::cache_layout::repo_folder_name("google/gemma-2-2b"),
            "models--google--gemma-2-2b"
        );
        assert_eq!(
            crate::cache_layout::repo_folder_name("RWKV/RWKV7-Goose-World3-1.5B-HF"),
            "models--RWKV--RWKV7-Goose-World3-1.5B-HF"
        );
    }
    // Resolve URLs must include the filename and percent-encode slashes in
    // the revision.
    #[test]
    fn test_build_download_url() {
        assert_eq!(
            build_download_url("google/gemma-2-2b", "main", "config.json"),
            "https://huggingface.co/google/gemma-2-2b/resolve/main/config.json"
        );
        assert_eq!(
            build_download_url("org/model", "refs/pr/42", "file.bin"),
            "https://huggingface.co/org/model/resolve/refs%2Fpr%2F42/file.bin"
        );
    }
    // Mirrors the chunk-layout math in `download_chunked`: even split with
    // the final chunk absorbing the remainder, inclusive end offsets.
    #[test]
    fn test_chunk_boundaries() {
        let total: u64 = 1000;
        let connections: usize = 4;
        let chunk_size = total / u64::try_from(connections).unwrap();
        let chunks: Vec<(u64, u64)> = (0..connections)
            .map(|i| {
                let idx = u64::try_from(i).unwrap();
                let start = idx * chunk_size;
                let end = if i == connections - 1 {
                    total - 1
                } else {
                    (idx + 1) * chunk_size - 1
                };
                (start, end)
            })
            .collect();
        assert_eq!(chunks[0], (0, 249));
        assert_eq!(chunks[1], (250, 499));
        assert_eq!(chunks[2], (500, 749));
        assert_eq!(chunks[3], (750, 999));
    }
    // Default drop must keep the partial file so downloads can resume.
    #[test]
    fn temp_file_guard_keeps_file_on_drop_by_default() {
        let dir = std::env::temp_dir().join(format!("hf-fm-tempguard-keep-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("partial.chunked.part");
        std::fs::write(&path, b"some bytes").unwrap();
        assert!(path.exists());
        {
            let _guard = TempFileGuard::new(path.clone());
        }
        assert!(
            path.exists(),
            "default-drop should preserve the file at {}",
            path.display()
        );
        std::fs::remove_file(&path).ok();
        std::fs::remove_dir(&dir).ok();
    }
    // `mark_corrupt` flips the guard to delete the file on drop.
    #[test]
    fn temp_file_guard_wipes_file_after_mark_corrupt() {
        let dir = std::env::temp_dir().join(format!("hf-fm-tempguard-wipe-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("partial.chunked.part");
        std::fs::write(&path, b"corrupt bytes").unwrap();
        assert!(path.exists());
        {
            let mut guard = TempFileGuard::new(path.clone());
            guard.mark_corrupt();
        }
        assert!(
            !path.exists(),
            "mark_corrupt should wipe the file at {}",
            path.display()
        );
        std::fs::remove_dir(&dir).ok();
    }
    // Removal is best-effort: dropping a corrupt-marked guard whose file
    // never existed must not panic.
    #[test]
    fn temp_file_guard_drop_is_safe_when_file_already_gone() {
        let dir = std::env::temp_dir().join(format!("hf-fm-tempguard-gone-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("never-existed.chunked.part");
        {
            let mut guard = TempFileGuard::new(path.clone());
            guard.mark_corrupt();
        }
        assert!(!path.exists());
        std::fs::remove_dir(&dir).ok();
    }
}