use anyhow::bail;
use hf_hub::{Repo, RepoType};
use httpdate::parse_http_date;
use reqwest::header::{HeaderValue, CONTENT_LENGTH, LAST_MODIFIED, RANGE};
use reqwest::{Response, StatusCode, Url};
use std::path::PathBuf;
use std::str::FromStr;
use tokio::fs::{File, OpenOptions};
use tokio::io::AsyncWriteExt;
use crate::FileSource;
/// A progress update emitted while model files are downloaded and then loaded.
#[derive(Debug, Clone)]
pub enum ModelLoadingProgress {
    /// A file is being downloaded.
    Downloading {
        /// Label for the download; also used to tell concurrent downloads apart.
        source: String,
        /// When the download started; used to estimate the time remaining.
        start_time: std::time::Instant,
        /// Fraction (0.0 to 1.0) of the file downloaded so far.
        progress: f32,
    },
    /// The downloaded files are being loaded.
    Loading {
        /// Fraction (0.0 to 1.0) of loading completed.
        progress: f32,
    },
}
impl ModelLoadingProgress {
pub fn downloading(source: String, progress: f32, start_time: std::time::Instant) -> Self {
Self::Downloading {
source,
progress,
start_time,
}
}
pub fn downloading_progress(source: String) -> impl FnMut(f32) -> Self + Send + Sync {
let start = std::time::Instant::now();
move |progress| ModelLoadingProgress::downloading(source.clone(), progress, start)
}
pub fn loading(progress: f32) -> Self {
Self::Loading { progress }
}
pub fn estimate_time_remaining(&self) -> Option<std::time::Duration> {
match self {
Self::Downloading {
start_time,
progress,
..
} => {
let elapsed = start_time.elapsed();
let remaining = (1. - progress) * elapsed.as_secs_f32();
Some(std::time::Duration::from_secs_f32(remaining))
}
_ => None,
}
}
pub fn multi_bar_loading_indicator() -> impl FnMut(ModelLoadingProgress) + Send + Sync + 'static
{
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::collections::HashMap;
let m = MultiProgress::new();
let sty = ProgressStyle::with_template(
"{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] ({pos}/{len}, ETA {eta})",
)
.unwrap();
let mut progress_bars = HashMap::new();
move |progress| match progress {
ModelLoadingProgress::Downloading {
source, progress, ..
} => {
let n = 100;
let progress = progress * n as f32;
let progress_bar = progress_bars.entry(source.clone()).or_insert_with(|| {
let pb = m.add(ProgressBar::new(n));
pb.set_message(format!("Downloading {source}"));
pb.set_style(sty.clone());
pb
});
progress_bar.set_position(progress as u64);
}
ModelLoadingProgress::Loading { progress } => {
for pb in progress_bars.values_mut() {
pb.finish();
}
let progress = progress * 100.;
m.println(format!("Loading {progress:.2}%")).unwrap();
}
}
}
}
/// On-disk cache for model files, with an optional Hugging Face token for authenticated
/// downloads.
#[derive(Clone)]
pub struct Cache {
    /// Root directory; Hugging Face files live under `<location>/<model_id>/<revision>/`.
    location: PathBuf,
    /// Explicit access token; when `None`, a token is looked up at download time.
    huggingface_token: Option<String>,
}

// Hand-written so the access token is never printed: `Cache` values may end up in logs.
impl std::fmt::Debug for Cache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Cache")
            .field("location", &self.location)
            .field(
                "huggingface_token",
                &self.huggingface_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
impl Cache {
    /// Creates a cache rooted at `location` with no explicit Hugging Face token.
    pub fn new(location: PathBuf) -> Self {
        Self {
            huggingface_token: None,
            location,
        }
    }

    /// Sets (or clears, with `None`) the Hugging Face token used for downloads.
    ///
    /// When no token is set, [`Cache::get`] falls back to the token in the local Hugging Face
    /// cache or the `HF_TOKEN` environment variable.
    pub fn with_huggingface_token(self, token: Option<String>) -> Self {
        Self {
            huggingface_token: token,
            ..self
        }
    }

    /// Returns whether `source` is already present on disk.
    ///
    /// For Hugging Face sources this only checks for the completed file at
    /// `<location>/<model_id>/<revision>/<file>`; it does not contact the hub, so a stale copy
    /// still counts as present.
    pub fn exists(&self, source: &FileSource) -> bool {
        match source {
            FileSource::Local(path) => path.exists(),
            FileSource::HuggingFace {
                model_id,
                revision,
                file,
                ..
            } => self
                .location
                .join(model_id)
                .join(revision)
                .join(file)
                .exists(),
        }
    }

    /// Resolves `source` to a local path, downloading it first if necessary.
    ///
    /// Local sources are returned unchanged. Hugging Face files are stored at
    /// `<location>/<model_id>/<revision>/<file>`. A HEAD request is always sent. A cached copy
    /// is re-downloaded only when that response carries a `Last-Modified` newer than the
    /// copy's modification time. If the request fails or the header is missing or unparsable,
    /// the cached copy is returned as-is. Downloads are written to a `<file>.partial` sibling
    /// and renamed into place once complete.
    ///
    /// `progress` receives the download fraction reported by the downloader.
    ///
    /// # Errors
    /// Fails if the hub URL cannot be built, if the HEAD request fails and no cached copy
    /// exists, or if reading metadata, downloading, or the final rename fails.
    pub async fn get(
        &self,
        source: &FileSource,
        progress: impl FnMut(f32),
    ) -> anyhow::Result<PathBuf> {
        let (model_id, revision, file) = match source {
            // Local files are used in place; nothing is copied into the cache.
            FileSource::Local(path) => return Ok(path.clone()),
            FileSource::HuggingFace {
                model_id,
                revision,
                file,
            } => (model_id, revision, file),
        };

        let token = self.huggingface_token.clone().or_else(huggingface_token);
        let revision_dir = self.location.join(model_id).join(revision);
        let cached_path = revision_dir.join(file);

        let remote_url = {
            let hub = hf_hub::api::sync::Api::new()?;
            let repo = hub.repo(Repo::with_revision(
                model_id.to_string(),
                RepoType::Model,
                revision.to_string(),
            ));
            Url::from_str(&repo.url(file))?
        };

        let client = reqwest::Client::new();
        let head = client
            .head(remote_url.clone())
            .with_authorization_header(token.clone())
            .send()
            .await;

        if cached_path.exists() {
            let cached_mtime = tokio::fs::metadata(&cached_path).await?.modified()?;
            let remote_mtime = head
                .as_ref()
                .ok()
                .and_then(|resp| resp.headers().get(LAST_MODIFIED))
                .and_then(|value| value.to_str().ok())
                .and_then(|text| parse_http_date(text).ok());
            // Without a usable Last-Modified (e.g. offline), trust the cached copy.
            let up_to_date = remote_mtime.map_or(true, |remote| remote <= cached_mtime);
            if up_to_date {
                return Ok(cached_path);
            }
        }

        let partial_path = revision_dir.join(format!("{}.partial", file));
        tracing::trace!("Downloading into {:?}", partial_path);
        download_into(remote_url, &partial_path, head?, client, token, progress).await?;
        tokio::fs::rename(&partial_path, &cached_path).await?;
        Ok(cached_path)
    }
}
impl Default for Cache {
    /// Uses `<platform data dir>/kalosm/cache` with no explicit Hugging Face token.
    ///
    /// # Panics
    /// Panics if the platform data directory cannot be determined.
    fn default() -> Self {
        let data_dir = dirs::data_dir().unwrap();
        Self::new(data_dir.join("kalosm").join("cache"))
    }
}
impl FileSource {
    /// Resolves this source to a local path through the default [`Cache`], downloading the
    /// file first if needed.
    ///
    /// `progress` receives the download fraction, as in [`Cache::get`].
    pub async fn download(&self, progress: impl FnMut(f32)) -> anyhow::Result<PathBuf> {
        Cache::default().get(self, progress).await
    }
}
/// Streams `url` into `file`, resuming from whatever bytes `file` already holds.
///
/// `head` is a previously fetched response for `url`; only its `Content-Length` header is
/// read, to learn the total size. When the size is known, a `Range` request asks for just the
/// missing suffix, and `progress` receives the completed fraction after every chunk.
/// `progress(1.0)` is always reported on success.
///
/// Some partial files cannot be resumed safely: one longer than the remote size, or one the
/// server answers with `200 OK` instead of `206 Partial Content`. In those cases the download
/// restarts from byte 0 rather than appending onto stale data.
///
/// # Errors
/// Fails on filesystem errors, transport errors, or a GET status other than `200 OK` or
/// `206 Partial Content`.
async fn download_into(
    url: Url,
    file: &PathBuf,
    head: Response,
    client: reqwest::Client,
    token: Option<String>,
    mut progress: impl FnMut(f32),
) -> anyhow::Result<()> {
    // Total size, if the server reported one. A missing or malformed Content-Length is not
    // fatal: the download still works, just without a Range request or fractional progress.
    let length = head
        .headers()
        .get(CONTENT_LENGTH)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| u64::from_str(text).ok());

    let existing_len = tokio::fs::metadata(file)
        .await
        .ok()
        .map(|metadata| metadata.len());
    let (start, mut output_file) = match existing_len {
        // A partial file no longer than the remote one is resumed by appending to it.
        Some(len) if length.map_or(true, |total| len <= total) => {
            let output_file = OpenOptions::new().append(true).open(file).await?;
            (len, output_file)
        }
        // A partial file longer than the remote one cannot be a prefix of it; start over
        // instead of sending an unsatisfiable Range request.
        Some(_) => {
            tracing::trace!(
                "Partial file {} is larger than the remote file; restarting",
                file.display()
            );
            (0, File::create(file).await?)
        }
        None => {
            if let Some(parent) = file.parent() {
                tokio::fs::create_dir_all(parent).await?;
            }
            (0, File::create(file).await?)
        }
    };

    // Skip a zero length so the callback never sees 0/0 = NaN.
    if let Some(length) = length.filter(|&total| total > 0) {
        progress(start as f32 / length as f32);
    }
    if Some(start) == length {
        tracing::trace!("File {} already downloaded", file.display());
        progress(1.0);
        return Ok(());
    }

    // If `length` is known, `start < length` holds here (checked above), so `length - 1`
    // cannot underflow.
    let range = length
        .and_then(|length| HeaderValue::from_str(&format!("bytes={}-{}", start, length - 1)).ok());
    tracing::trace!("Fetching range {:?}", range);
    let mut request = client.get(url).with_authorization_header(token);
    if let Some(range) = range {
        request = request.header(RANGE, range);
    }
    let mut response = request.send().await?;
    let status = response.status();
    if !(status == StatusCode::OK || status == StatusCode::PARTIAL_CONTENT) {
        bail!("Unexpected status code: {:?}", status);
    }

    let mut current_progress = start;
    if status == StatusCode::OK && start > 0 {
        // The server is returning the whole file, either because it ignored the Range or
        // because none was sent. Appending it to the partial data would corrupt the download,
        // so truncate first.
        tracing::trace!(
            "Server sent the full file for {}; discarding partial data",
            file.display()
        );
        output_file = File::create(file).await?;
        current_progress = 0;
    }

    while let Some(chunk) = response.chunk().await? {
        output_file.write_all(&chunk).await?;
        tracing::trace!("wrote chunk of size {}", chunk.len());
        current_progress += chunk.len() as u64;
        if let Some(length) = length {
            progress(current_progress as f32 / length as f32);
        }
    }
    // tokio's `File` completes writes in the background. Flushing waits for the last one and
    // surfaces its error before the caller renames the file into place.
    output_file.flush().await?;
    tracing::trace!("Download of {} complete", file.display());
    progress(1.0);
    Ok(())
}
/// Extension methods for attaching optional credentials to an HTTP request builder.
trait RequestBuilderExt {
    /// Returns the builder with an `Authorization: Bearer …` header when `bearer_token` is
    /// `Some`, or unchanged when it is `None`.
    fn with_authorization_header(self, bearer_token: Option<String>) -> Self;
}
impl RequestBuilderExt for reqwest::RequestBuilder {
    /// Adds `Authorization: Bearer <token>` when `token` is present; otherwise returns the
    /// builder unchanged.
    ///
    /// `bearer_auth` produces the same header value as formatting it by hand, but it also
    /// marks the header sensitive, so the token is redacted from `Debug` output.
    fn with_authorization_header(self, token: Option<String>) -> Self {
        match token {
            Some(token) => self.bearer_auth(token),
            None => self,
        }
    }
}
#[tokio::test]
async fn downloads_work() {
    // Network test: presumably httpbin serves 102400 bytes here, spread over ~2s so that
    // several chunks (and progress callbacks) arrive.
    let endpoint = "https://httpbin.org/range/102400?duration=2";
    let target = PathBuf::from("download.bin");
    let http = reqwest::Client::new();
    let head = http.head(endpoint).send().await.unwrap();
    let url = Url::from_str(endpoint).unwrap();
    download_into(url, &target, head, http, None, |p: f32| {
        println!("Progress: {}", p);
    })
    .await
    .unwrap();
    assert!(target.exists());
    tokio::fs::remove_file(target).await.unwrap();
}
/// Looks up a Hugging Face access token.
///
/// The token stored in the local Hugging Face cache takes priority; otherwise the `HF_TOKEN`
/// environment variable is used.
fn huggingface_token() -> Option<String> {
    if let Some(stored) = hf_hub::Cache::default().token() {
        return Some(stored);
    }
    std::env::var("HF_TOKEN").ok()
}