use crate::{Cache, Repo, RepoType};
use indicatif::{ProgressBar, ProgressStyle};
use rand::{distributions::Alphanumeric, Rng};
use std::collections::HashMap;
use super::RepoInfo;
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
use thiserror::Error;
use ureq::Agent;
/// Version of this crate, taken from `CARGO_PKG_VERSION` at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Name of this crate, taken from `CARGO_PKG_NAME` at compile time.
const NAME: &str = env!("CARGO_PKG_NAME");
/// Request header used to ask for a byte range.
const RANGE: &str = "Range";
/// Response header from which the total file size is parsed.
const CONTENT_RANGE: &str = "Content-Range";
/// Response header read when a response has a 3xx status.
const LOCATION: &str = "Location";
/// Request header carrying the client identification string.
const USER_AGENT: &str = "User-Agent";
/// Request header carrying the bearer token, when one is configured.
const AUTHORIZATION: &str = "Authorization";
/// Headers attached to every request: static header name mapped to its value.
type HeaderMap = HashMap<&'static str, String>;
/// Name of an HTTP header, as reported inside [`ApiError`] variants.
type HeaderName = &'static str;
/// A `ureq` agent bundled with a fixed set of headers that are set on every
/// request created through it.
#[derive(Clone)]
pub struct HeaderAgent {
/// Underlying HTTP agent used to create requests.
agent: Agent,
/// Headers set on each request this agent creates.
headers: HeaderMap,
}
impl HeaderAgent {
fn new(agent: Agent, headers: HeaderMap) -> Self {
Self { agent, headers }
}
fn get(&self, url: &str) -> ureq::Request {
let mut request = self.agent.get(url);
for (header, value) in &self.headers {
request = request.set(header, value);
}
request
}
}
/// Errors returned by the API in this module.
#[derive(Debug, Error)]
pub enum ApiError {
/// A response did not contain the named header.
#[error("Header {0} is missing")]
MissingHeader(HeaderName),
/// The named header was present but its value could not be used.
#[error("Header {0} is invalid")]
InvalidHeader(HeaderName),
/// The HTTP request failed; wraps the boxed `ureq` error.
#[error("request error: {0}")]
RequestError(#[from] Box<ureq::Error>),
/// An integer could not be parsed (e.g. the size taken from `Content-Range`).
#[error("Cannot parse int")]
ParseIntError(#[from] ParseIntError),
/// A filesystem or stream I/O operation failed.
#[error("I/O error {0}")]
IoError(#[from] std::io::Error),
/// Wraps another `ApiError`.
/// NOTE(review): nothing in the visible code constructs this variant — confirm where it is used.
#[error("Too many retries: {0}")]
TooManyRetries(Box<ApiError>),
}
/// Builder collecting the settings needed to create an [`Api`].
pub struct ApiBuilder {
/// Base URL of the server; substituted for `{endpoint}` in `url_template`.
endpoint: String,
/// Local cache used for the token file, blobs, pointers and refs.
cache: Cache,
/// Template for file URLs with `{endpoint}`, `{repo_id}`, `{revision}` and
/// `{filename}` placeholders.
url_template: String,
/// Optional token sent as `Authorization: Bearer <token>`.
token: Option<String>,
/// Whether downloads display a progress bar.
progress: bool,
}
impl Default for ApiBuilder {
fn default() -> Self {
Self::new()
}
}
impl ApiBuilder {
    /// Creates a builder with the default settings.
    ///
    /// The endpoint is `https://huggingface.co`, the cache is `Cache::default()`,
    /// progress bars are enabled, and the token is read from the cache's token
    /// file when that file can be read and is not blank.
    pub fn new() -> Self {
        let cache = Cache::default();
        // A missing or unreadable token file is not an error: it simply means
        // requests are sent without an `Authorization` header.
        let token = std::fs::read_to_string(cache.token_path())
            .ok()
            .map(|content| content.trim().to_string())
            .filter(|token| !token.is_empty());
        Self {
            endpoint: "https://huggingface.co".to_string(),
            url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(),
            cache,
            token,
            progress: true,
        }
    }

    /// Enables or disables the progress bar shown while downloading.
    pub fn with_progress(mut self, progress: bool) -> Self {
        self.progress = progress;
        self
    }

    /// Replaces the cache with one rooted at `cache_dir`.
    ///
    /// A token already loaded by [`ApiBuilder::new`] is kept as is.
    pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
        self.cache = Cache::new(cache_dir);
        self
    }

    /// Sets the token to send with requests; `None` disables authentication.
    pub fn with_token(mut self, token: Option<String>) -> Self {
        self.token = token;
        self
    }

    /// Builds the headers attached to every request: always a `User-Agent`,
    /// plus `Authorization: Bearer <token>` when a token is configured.
    fn build_headers(&self) -> Result<HeaderMap, ApiError> {
        let mut headers = HeaderMap::new();
        // The leading "unknown/None" segment stands for an unspecified calling
        // library (it was previously misspelled "unkown").
        let user_agent = format!("unknown/None; {NAME}/{VERSION}; rust/unknown");
        headers.insert(USER_AGENT, user_agent);
        if let Some(token) = &self.token {
            headers.insert(AUTHORIZATION, format!("Bearer {token}"));
        }
        Ok(headers)
    }

    /// Consumes the builder and creates the [`Api`].
    ///
    /// Two HTTP clients sharing the same headers are created: one that follows
    /// redirects and one configured with `redirects(0)`.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the headers.
    pub fn build(self) -> Result<Api, ApiError> {
        let headers = self.build_headers()?;
        let client = HeaderAgent::new(ureq::builder().build(), headers.clone());
        let no_redirect_client = HeaderAgent::new(ureq::builder().redirects(0).build(), headers);
        Ok(Api {
            endpoint: self.endpoint,
            url_template: self.url_template,
            cache: self.cache,
            client,
            no_redirect_client,
            progress: self.progress,
        })
    }
}
/// Information about a remote file gathered from response headers before downloading it.
#[derive(Debug)]
struct Metadata {
/// Value of the `x-repo-commit` response header.
commit_hash: String,
/// Value of `x-linked-etag` (or `etag` when absent), with double quotes removed.
etag: String,
/// Number parsed from the part of `Content-Range` after the last `/`.
size: usize,
}
/// Handle used to query the remote endpoint and download files into the local cache.
#[derive(Clone)]
pub struct Api {
/// Base URL of the server.
endpoint: String,
/// Template used to build file URLs.
url_template: String,
/// Local cache where blobs, pointers and refs are stored.
cache: Cache,
/// HTTP client built with the default redirect behaviour.
client: HeaderAgent,
/// HTTP client built with `redirects(0)`; used when reading file metadata.
no_redirect_client: HeaderAgent,
/// Whether downloads display a progress bar.
progress: bool,
}
/// Returns a path inside the system temporary directory whose file name is
/// seven random alphanumeric characters. The file itself is not created.
fn temp_filename() -> PathBuf {
    let random_name: String = rand::thread_rng()
        .sample_iter(&Alphanumeric)
        .take(7)
        .map(char::from)
        .collect();
    std::env::temp_dir().join(random_name)
}
/// Computes the path to `src` relative to the directory containing `dst`.
///
/// `dst` is treated as a file path: after the shared prefix is removed, one
/// remaining component of `dst` (its file name) is ignored and each further
/// component becomes a `..` step. The remaining components of `src` are then
/// appended.
///
/// # Panics
///
/// Panics when exactly one of the two paths is absolute.
fn make_relative(src: &Path, dst: &Path) -> PathBuf {
    if src.is_absolute() != dst.is_absolute() {
        panic!("This function is made to look at absolute paths only");
    }

    let mut src_parts = src.components().peekable();
    let mut dst_parts = dst.components().peekable();

    // Step over the components both paths have in common.
    while let (Some(a), Some(b)) = (src_parts.peek(), dst_parts.peek()) {
        if a != b {
            break;
        }
        src_parts.next();
        dst_parts.next();
    }

    let mut relative = PathBuf::new();
    // The first leftover component of `dst` needs no `..` of its own; every
    // component after it is one directory level to climb.
    for _ in dst_parts.skip(1) {
        relative.push(Component::ParentDir);
    }
    relative.extend(src_parts);
    relative
}
/// Makes `dst` refer to the file at `src`.
///
/// Does nothing when `dst` already exists. On Unix and Windows a symbolic link
/// is created at `dst` whose target is `src` expressed relative to `dst`'s
/// directory. On any other platform `src` is moved to `dst` instead.
///
/// # Errors
///
/// Returns the I/O error reported by the symlink or rename call.
fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
    if dst.exists() {
        return Ok(());
    }
    // A symlink target is resolved relative to the directory holding the link,
    // so only the symlink branches may use the relative form.
    #[cfg(any(target_family = "unix", target_os = "windows"))]
    let relative_src = make_relative(src, dst);
    #[cfg(target_os = "windows")]
    std::os::windows::fs::symlink_file(relative_src, dst)?;
    #[cfg(target_family = "unix")]
    std::os::unix::fs::symlink(relative_src, dst)?;
    // `rename` resolves relative paths against the current working directory,
    // so it must receive the original `src`, not the link-relative path.
    #[cfg(not(any(target_family = "unix", target_os = "windows")))]
    std::fs::rename(src, dst)?;
    Ok(())
}
impl Api {
    /// Creates an [`Api`] from the default [`ApiBuilder`] settings.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`ApiBuilder::build`].
    pub fn new() -> Result<Self, ApiError> {
        ApiBuilder::new().build()
    }

    /// Returns the HTTP client built with the default redirect behaviour.
    pub fn client(&self) -> &HeaderAgent {
        &self.client
    }

    /// Fetches the commit hash, etag and size of the file at `url`.
    ///
    /// A one-byte range request is sent with the non-redirecting client. The
    /// commit hash and etag are read from that first response. When the
    /// response is a redirect, the same range request is repeated against the
    /// `Location` target, and the size is read from that response's
    /// `Content-Range` header.
    ///
    /// # Errors
    ///
    /// * [`ApiError::MissingHeader`] when a required header is absent,
    ///   including `Location` on a redirect response.
    /// * [`ApiError::ParseIntError`] when the size in `Content-Range` is not a number.
    /// * [`ApiError::RequestError`] when a request fails.
    fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
        let response = self
            .no_redirect_client
            .get(url)
            .set(RANGE, "bytes=0-0")
            .call()
            .map_err(Box::new)?;
        let header_commit = "x-repo-commit";
        let header_linked_etag = "x-linked-etag";
        let header_etag = "etag";
        // Prefer the linked etag and fall back to the plain one; the quotes
        // surrounding the value are dropped.
        let etag = response
            .header(header_linked_etag)
            .or_else(|| response.header(header_etag))
            .ok_or(ApiError::MissingHeader(header_etag))?
            .replace('"', "");
        let commit_hash = response
            .header(header_commit)
            .ok_or(ApiError::MissingHeader(header_commit))?
            .to_string();
        let is_redirection = (300..400).contains(&response.status());
        let response = if is_redirection {
            // A redirect without a target is a malformed response: report it
            // as an error instead of panicking.
            let location = response
                .header(LOCATION)
                .ok_or(ApiError::MissingHeader(LOCATION))?;
            self.client
                .get(location)
                .set(RANGE, "bytes=0-0")
                .call()
                .map_err(Box::new)?
        } else {
            response
        };
        let content_range = response
            .header(CONTENT_RANGE)
            .ok_or(ApiError::MissingHeader(CONTENT_RANGE))?;
        // The total size is whatever follows the last '/' in the header value.
        let size = content_range
            .rsplit('/')
            .next()
            .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))?
            .parse()?;
        Ok(Metadata {
            commit_hash,
            etag,
            size,
        })
    }

    /// Downloads `url` into a freshly named file in the temporary directory
    /// and returns that file's path.
    ///
    /// When `progressbar` is provided it is advanced as bytes are read and
    /// finished once the copy succeeds.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::RequestError`] when the request fails and
    /// [`ApiError::IoError`] when the file cannot be created or written. No
    /// temporary file is left behind in either case.
    fn download_tempfile(
        &self,
        url: &str,
        progressbar: Option<ProgressBar>,
    ) -> Result<PathBuf, ApiError> {
        // Issue the request first so that a failed request never creates a file.
        let response = self.client.get(url).call().map_err(Box::new)?;
        let filename = temp_filename();
        let mut file = std::fs::File::create(&filename)?;
        let mut reader = response.into_reader();
        if let Some(p) = &progressbar {
            reader = Box::new(p.wrap_read(reader));
        }
        if let Err(err) = std::io::copy(&mut reader, &mut file) {
            // Best-effort removal of the partial download before reporting the failure.
            drop(file);
            let _ = std::fs::remove_file(&filename);
            return Err(ApiError::IoError(err));
        }
        if let Some(p) = progressbar {
            p.finish()
        }
        Ok(filename)
    }

    /// Returns a handle for operating on `repo` through a clone of this API.
    pub fn repo(&self, repo: Repo) -> ApiRepo {
        ApiRepo::new(self.clone(), repo)
    }

    /// Returns a handle for the model repository named `model_id`.
    pub fn model(&self, model_id: String) -> ApiRepo {
        self.repo(Repo::new(model_id, RepoType::Model))
    }

    /// Returns a handle for the dataset repository named `model_id`.
    pub fn dataset(&self, model_id: String) -> ApiRepo {
        self.repo(Repo::new(model_id, RepoType::Dataset))
    }

    /// Returns a handle for the space repository named `model_id`.
    pub fn space(&self, model_id: String) -> ApiRepo {
        self.repo(Repo::new(model_id, RepoType::Space))
    }
}
/// An [`Api`] paired with the repository it operates on.
pub struct ApiRepo{
/// API handle used for requests and cache access.
api: Api,
/// Repository that URLs and cache locations are derived from.
repo: Repo,
}
impl ApiRepo{
fn new(api: Api, repo: Repo) -> Self{
Self{api, repo}
}
}
impl ApiRepo {
    /// Builds the URL of `filename` in this repository by filling in the
    /// `{endpoint}`, `{repo_id}`, `{revision}` and `{filename}` placeholders
    /// of the API's URL template.
    pub fn url(&self, filename: &str) -> String {
        let endpoint = &self.api.endpoint;
        let revision = &self.repo.url_revision();
        self.api
            .url_template
            .replace("{endpoint}", endpoint)
            .replace("{repo_id}", &self.repo.url())
            .replace("{revision}", revision)
            .replace("{filename}", filename)
    }

    /// Returns the local path of `filename`, downloading it only when the
    /// cache has no entry for it.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ApiRepo::download`].
    pub fn get(&self, filename: &str) -> Result<PathBuf, ApiError> {
        if let Some(path) = self.api.cache.get(&self.repo, filename) {
            Ok(path)
        } else {
            self.download(filename)
        }
    }

    /// Downloads `filename` into the cache and returns the pointer path that
    /// refers to it.
    ///
    /// The file is first fetched into a temporary file, then moved to its blob
    /// location (named after the etag). A pointer is created under the
    /// directory of the commit hash, and the ref for that commit is recorded
    /// in the cache.
    ///
    /// # Errors
    ///
    /// Returns an [`ApiError`] when the metadata or download request fails or
    /// when a filesystem operation fails.
    pub fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
        let url = self.url(filename);
        let metadata = self.api.metadata(&url)?;
        let blob_path = self.api.cache.blob_path(&self.repo, &metadata.etag);
        std::fs::create_dir_all(blob_path.parent().unwrap())?;
        let progressbar = if self.api.progress {
            let progress = ProgressBar::new(metadata.size as u64);
            progress.set_style(
                ProgressStyle::with_template(
                    "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
                )
                .unwrap(),
            );
            let maxlength = 30;
            let message = if filename.len() > maxlength {
                // Keep roughly the last `maxlength` bytes, moving the cut
                // forward to the next character boundary so that multi-byte
                // UTF-8 file names cannot cause a slicing panic.
                let mut start = filename.len() - maxlength;
                while !filename.is_char_boundary(start) {
                    start += 1;
                }
                format!("..{}", &filename[start..])
            } else {
                filename.to_string()
            };
            progress.set_message(message);
            Some(progress)
        } else {
            None
        };
        let tmp_filename = self.api.download_tempfile(&url, progressbar)?;
        if std::fs::rename(&tmp_filename, &blob_path).is_err() {
            // Renaming can fail (for instance across filesystems): fall back
            // to copying, and remove the temporary file whether or not the
            // copy succeeded so it does not pile up in the temp directory.
            let copied = std::fs::copy(&tmp_filename, &blob_path);
            let _ = std::fs::remove_file(&tmp_filename);
            copied?;
        }
        let mut pointer_path = self
            .api
            .cache
            .pointer_path(&self.repo, &metadata.commit_hash);
        pointer_path.push(filename);
        // A failure here is deliberately ignored: it resurfaces as an error
        // from the link creation just below.
        std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
        symlink_or_rename(&blob_path, &pointer_path)?;
        self.api.cache.create_ref(&self.repo, &metadata.commit_hash)?;
        Ok(pointer_path)
    }

    /// Requests `{endpoint}/api/{repo api url}` and deserializes the JSON
    /// response into a [`RepoInfo`].
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::RequestError`] when the request fails and
    /// [`ApiError::IoError`] when the body cannot be read or deserialized.
    pub fn info(&self) -> Result<RepoInfo, ApiError> {
        let url = format!("{}/api/{}", self.api.endpoint, self.repo.api_url());
        let response = self.api.client.get(&url).call().map_err(Box::new)?;
        let model_info = response.into_json()?;
        Ok(model_info)
    }
}
// Integration-style tests: each one builds an `Api` with the default endpoint
// and a throw-away cache directory, then performs real requests.
#[cfg(test)]
mod tests {
use super::*;
use crate::api::Siblings;
use crate::RepoType;
use hex_literal::hex;
use rand::{distributions::Alphanumeric, Rng};
use sha2::{Digest, Sha256};
/// Randomly named directory under the system temp dir, removed again on drop.
struct TempDir {
path: PathBuf,
}
impl TempDir {
/// Creates the directory; panics if it cannot be created.
pub fn new() -> Self {
let s: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(7)
.map(char::from)
.collect();
let mut path = std::env::temp_dir();
path.push(s);
std::fs::create_dir(&path).unwrap();
Self { path }
}
}
impl Drop for TempDir {
// Removes the directory and everything in it; panics if removal fails.
fn drop(&mut self) {
std::fs::remove_dir_all(&self.path).unwrap()
}
}
/// Downloads `config.json` from a model repo, checks its SHA-256 digest, and
/// checks that a cache lookup returns the same path as the download.
#[test]
fn simple() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let model_id = "julien-c/dummy-unknown".to_string();
let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);
let cache_path = api.cache.get(&Repo::new(model_id, RepoType::Model), "config.json").unwrap();
assert_eq!(cache_path, downloaded_path);
}
/// Downloads a file located in a sub-directory of a dataset repo at an
/// explicit revision and checks its SHA-256 digest.
#[test]
fn dataset() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::with_revision(
"wikitext".to_string(),
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let downloaded_path = api.repo(repo)
.download("wikitext-103-v1/wikitext-test.parquet")
.unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("59ce09415ad8aa45a9e34f88cec2548aeb9de9a73fcda9f6b33a86a065f32b90")
)
}
/// Fetches the repo info of a dataset revision and compares it with a
/// hard-coded commit sha and file listing.
#[test]
fn info() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::with_revision(
"wikitext".to_string(),
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let model_info = api.repo(repo).info().unwrap();
// The expected value pins the exact remote state, so this assertion fails
// whenever the remote revision changes.
assert_eq!(
model_info,
RepoInfo {
sha: "2dd3f79917d431e9af1c81bfa96a575741774077".to_string(),
siblings: vec![
Siblings {
rfilename: ".gitattributes".to_string()
},
Siblings {
rfilename: "wikitext-103-raw-v1/wikitext-test.parquet".to_string()
},
Siblings {
rfilename: "wikitext-103-raw-v1/wikitext-train-00000-of-00002.parquet"
.to_string()
},
Siblings {
rfilename: "wikitext-103-raw-v1/wikitext-train-00001-of-00002.parquet"
.to_string()
},
Siblings {
rfilename: "wikitext-103-raw-v1/wikitext-validation.parquet".to_string()
},
Siblings {
rfilename: "wikitext-103-v1/test/index.duckdb".to_string()
},
Siblings {
rfilename: "wikitext-103-v1/validation/index.duckdb".to_string()
},
Siblings {
rfilename: "wikitext-103-v1/wikitext-test.parquet".to_string()
},
Siblings {
rfilename: "wikitext-103-v1/wikitext-train-00000-of-00002.parquet"
.to_string()
},
Siblings {
rfilename: "wikitext-103-v1/wikitext-train-00001-of-00002.parquet"
.to_string()
},
Siblings {
rfilename: "wikitext-103-v1/wikitext-validation.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-raw-v1/test/index.duckdb".to_string()
},
Siblings {
rfilename: "wikitext-2-raw-v1/train/index.duckdb".to_string()
},
Siblings {
rfilename: "wikitext-2-raw-v1/validation/index.duckdb".to_string()
},
Siblings {
rfilename: "wikitext-2-raw-v1/wikitext-test.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-raw-v1/wikitext-train.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-raw-v1/wikitext-validation.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-v1/wikitext-test.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-v1/wikitext-train.parquet".to_string()
},
Siblings {
rfilename: "wikitext-2-v1/wikitext-validation.parquet".to_string()
}
],
}
)
}
}