use crate::{
Backend, Body, DirEntry, ExistsRequest, GetRequest, GetResponse, PutRequest, PutResponse,
StatRequest, StatResponse, DEFAULT_USER_AGENT, KEEP_ALIVE_INTERVAL, POOL_MAX_IDLE_PER_HOST,
};
use async_trait::async_trait;
use dragonfly_api::common::v2::Range;
use dragonfly_client_config::dfdaemon::Config;
use dragonfly_client_core::{
error::{BackendError, ErrorType, OrErr},
Error, Result,
};
use dragonfly_client_util::tls::NoVerifier;
use futures::TryStreamExt;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_LENGTH, RANGE, USER_AGENT};
use reqwest::Client;
use serde::Deserialize;
use std::error::Error as _;
use std::io::Error as IOError;
use std::sync::Arc;
use tokio_util::io::StreamReader;
use tracing::{debug, error};
use url::Url;
/// URL scheme handled by the Hugging Face backend (e.g. `hf://owner/repo/path/to/file`).
pub const SCHEME: &str = "hf";
/// Hub endpoint used when a request does not supply its own `base_url`.
const HUGGING_FACE_BASE_URL: &str = "https://huggingface.co";
/// Repository description returned by the Hub repository API; deserialized by `stat` when
/// the URL names a repository rather than a file.
///
/// Missing JSON fields fall back to `Default` because of the container-level `serde(default)`.
#[derive(Default, Debug, Deserialize)]
#[serde(default, rename_all = "camelCase")]
// Only `siblings` is read by this backend; `id`, `model_id` and `private` are deserialized
// but never accessed, so the dead-code lint is silenced for this struct.
#[allow(dead_code)]
struct Repository {
// The JSON key for this field is `_id`, which does not follow the camelCase rule.
#[serde(rename = "_id")]
id: String,
model_id: Option<String>,
private: bool,
// Files listed for the repository; `None` when the response carries no `siblings` key.
siblings: Option<Vec<Sibling>>,
}
/// One file entry from a repository's `siblings` list.
#[derive(Default, Debug, Deserialize)]
#[serde(default, rename_all = "camelCase")]
struct Sibling {
// Repository-relative file name; appended to the repository's `hf://` URL by `build_hf_url`.
rfilename: String,
// File size, when the API response includes it.
size: Option<u64>,
// LFS metadata; when present its `size` is preferred over `size` above.
lfs: Option<Lfs>,
}
/// LFS metadata attached to a sibling.
#[derive(Default, Debug, Deserialize)]
#[serde(default, rename_all = "camelCase")]
// Only `size` is read; `sha256` and `pointer_size` are deserialized but never accessed.
#[allow(dead_code)]
struct Lfs {
size: u64,
sha256: Option<String>,
// Read from the JSON key `pointerSize` via `rename_all = "camelCase"`.
pointer_size: Option<u64>,
}
/// An `hf://` URL split into the parts needed to address a Hugging Face Hub repository.
///
/// Produced by the `TryFrom<Url>` / `TryFrom<&str>` implementations, which accept
/// `hf://[models/|datasets/|spaces/]owner/name[/path/to/file]`.
#[derive(Debug, Clone, PartialEq)]
pub struct ParsedURL {
    /// The URL the components were extracted from, kept unchanged.
    pub url: Url,

    /// Repository identifier in `owner/name` form.
    pub repository_id: String,

    /// Kind of repository; `Model` when the URL carries no type prefix.
    pub repository_type: RepositoryType,

    /// Path of a file inside the repository, or `None` when the URL names the repository
    /// itself.
    pub file_path: Option<String>,
}
/// Kind of Hugging Face Hub repository addressed by an `hf://` URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RepositoryType {
    /// A model repository; the default when the URL has no type prefix.
    Model,

    /// A dataset repository (`hf://datasets/...`).
    Dataset,

    /// A Space repository (`hf://spaces/...`).
    Space,
}

impl RepositoryType {
    /// Returns the plural path segment for this repository type: `models`, `datasets` or
    /// `spaces`. It is used when building Hub API and `hf://` URLs.
    pub fn as_str(&self) -> &'static str {
        match self {
            RepositoryType::Model => "models",
            RepositoryType::Dataset => "datasets",
            RepositoryType::Space => "spaces",
        }
    }
}
impl TryFrom<Url> for ParsedURL {
    type Error = Error;

    /// Splits an `hf://` URL into repository type, `owner/name` id and optional file path.
    ///
    /// Accepted shapes: `hf://[models/|datasets/|spaces/]owner/name[/path/to/file]`.
    /// Without a type prefix the repository is treated as a model.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidURI`] when the URL has no host component.
    /// * [`Error::InvalidParameter`] when the owner or repository name is missing.
    fn try_from(url: Url) -> std::result::Result<Self, Self::Error> {
        // In `hf://owner/name/...` the URL parser puts the first component in the host
        // position, so the host and the path are stitched back together before splitting.
        let host = url
            .host_str()
            .ok_or_else(|| Error::InvalidURI(url.to_string()))?;

        // Empty segments (from `//` or a trailing `/`) are dropped so they can never become
        // an empty owner, an empty repository name, or an empty file-path component.
        let segments: Vec<&str> = std::iter::once(host)
            .chain(url.path().split('/'))
            .filter(|segment| !segment.is_empty())
            .collect();

        let (repository_type, offset) = match segments.first().copied() {
            Some("datasets") => (RepositoryType::Dataset, 1),
            Some("spaces") => (RepositoryType::Space, 1),
            Some("models") => (RepositoryType::Model, 1),
            _ => (RepositoryType::Model, 0),
        };

        let [owner, name, rest @ ..] = &segments[offset..] else {
            return Err(Error::InvalidParameter);
        };

        let repository_id = format!("{}/{}", owner, name);
        let file_path = (!rest.is_empty()).then(|| rest.join("/"));

        Ok(ParsedURL {
            url,
            repository_type,
            repository_id,
            file_path,
        })
    }
}
impl TryFrom<&str> for ParsedURL {
    type Error = Error;

    /// Parses `url` as a URL and then splits it into its Hugging Face components.
    ///
    /// # Errors
    ///
    /// Returns a parse error when `url` is not a valid URL. Otherwise it returns whatever
    /// error the `TryFrom<Url>` conversion reports.
    fn try_from(url: &str) -> std::result::Result<Self, Self::Error> {
        Self::try_from(Url::parse(url).or_err(ErrorType::ParseError)?)
    }
}
pub struct HuggingFace {
scheme: String,
client: Client,
}
impl HuggingFace {
pub fn new(config: Arc<Config>) -> Result<Self> {
let client_config_builder = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(NoVerifier::new())
.with_no_client_auth();
let client = reqwest::Client::builder()
.no_gzip()
.no_brotli()
.no_zstd()
.no_deflate()
.hickory_dns(config.backend.enable_hickory_dns)
.use_preconfigured_tls(client_config_builder)
.pool_max_idle_per_host(POOL_MAX_IDLE_PER_HOST)
.tcp_keepalive(KEEP_ALIVE_INTERVAL)
.tcp_nodelay(true)
.build()?;
Ok(Self {
scheme: SCHEME.to_string(),
client,
})
}
fn resolve_base_urls(base_url: Option<&str>) -> Result<(Url, Url)> {
let base_url = Url::parse(base_url.unwrap_or(HUGGING_FACE_BASE_URL))?;
let api_base_url = base_url.join("/api/")?;
Ok((base_url, api_base_url))
}
fn build_download_url(
parsed_url: &ParsedURL,
file_path: &str,
revision: &str,
base_url: &Url,
) -> Result<Url> {
let path = match parsed_url.repository_type {
RepositoryType::Model => {
format!(
"{}/resolve/{}/{}",
parsed_url.repository_id, revision, file_path
)
}
RepositoryType::Dataset => {
format!(
"datasets/{}/resolve/{}/{}",
parsed_url.repository_id, revision, file_path
)
}
RepositoryType::Space => {
format!(
"spaces/{}/resolve/{}/{}",
parsed_url.repository_id, revision, file_path
)
}
};
Ok(base_url.join(&path)?)
}
fn build_repository_url(parsed_url: &ParsedURL, api_base_url: &Url) -> Result<Url> {
let path = format!(
"{}/{}",
parsed_url.repository_type.as_str(),
parsed_url.repository_id
);
Ok(api_base_url.join(&path)?)
}
fn build_repository_revision_url(
parsed_url: &ParsedURL,
revision: &str,
api_base_url: &Url,
) -> Result<Url> {
let path = format!(
"{}/{}?revision={}",
parsed_url.repository_type.as_str(),
parsed_url.repository_id,
revision
);
Ok(api_base_url.join(&path)?)
}
fn build_hf_url(parsed_url: &ParsedURL, filename: &str) -> Result<Url> {
let url = match parsed_url.repository_type {
RepositoryType::Model => {
format!("{}://{}/{}", SCHEME, parsed_url.repository_id, filename)
}
RepositoryType::Dataset => {
format!(
"{}://datasets/{}/{}",
SCHEME, parsed_url.repository_id, filename
)
}
RepositoryType::Space => {
format!(
"{}://spaces/{}/{}",
SCHEME, parsed_url.repository_id, filename
)
}
};
Ok(Url::parse(&url)?)
}
fn build_request_headers(token: Option<String>, range: Option<Range>) -> Result<HeaderMap> {
let mut request_header = HeaderMap::new();
if let Some(range) = &range {
request_header.insert(
RANGE,
format!("bytes={}-{}", range.start, range.start + range.length - 1).parse()?,
);
};
request_header
.entry(USER_AGENT)
.or_insert(HeaderValue::from_static(DEFAULT_USER_AGENT));
if let Some(token) = token {
request_header.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
);
}
Ok(request_header)
}
}
#[async_trait]
impl Backend for HuggingFace {
fn scheme(&self) -> String {
self.scheme.clone()
}
async fn stat(&self, request: StatRequest) -> Result<StatResponse> {
debug!(
"stat request {} {}: {:?}",
request.task_id, request.url, request.http_header
);
let request_header = Self::build_request_headers(
request
.hugging_face
.as_ref()
.and_then(|hf| hf.token.clone()),
None,
)?;
let hugging_face = request.hugging_face.as_ref().ok_or_else(|| {
error!(
"stat request {} {}: missing Hugging Face information",
request.task_id, request.url
);
Error::InvalidParameter
})?;
let parsed_url = ParsedURL::try_from(request.url.as_str())?;
let (base_url, api_base_url) = Self::resolve_base_urls(hugging_face.base_url.as_deref())?;
match &parsed_url.file_path {
Some(file_path) => {
let download_url = Self::build_download_url(
&parsed_url,
file_path,
&hugging_face.revision,
&base_url,
)?;
let response = match self
.client
.head(download_url.as_str())
.headers(request_header)
.timeout(request.timeout)
.send()
.await
{
Ok(response) => response,
Err(err) => {
error!(
"stat request failed {} {}: {}",
request.task_id, download_url, err
);
return Ok(StatResponse {
success: false,
content_length: None,
http_header: None,
http_status_code: None,
entries: Vec::new(),
error_message: Some(err.to_string()),
});
}
};
let response_status_code = response.status();
let response_header = response.headers().clone();
let content_length = match response_header.get(CONTENT_LENGTH) {
Some(content_length) => content_length.to_str()?.parse::<u64>().ok(),
None => response.content_length(),
};
debug!(
"stat response {} {}: {:?} {:?} {:?}",
request.task_id,
download_url,
response_status_code,
content_length,
response_header
);
Ok(StatResponse {
success: response_status_code.is_success(),
content_length,
http_header: Some(response_header),
http_status_code: Some(response_status_code),
error_message: Some(response_status_code.to_string()),
entries: Vec::new(),
})
}
None => {
let repository_revision_url = Self::build_repository_revision_url(
&parsed_url,
&hugging_face.revision,
&api_base_url,
)?;
let response = match self
.client
.get(repository_revision_url.as_str())
.headers(request_header)
.timeout(request.timeout)
.send()
.await
{
Ok(response) => response,
Err(err) => {
error!(
"stat request failed {} {}: {}",
request.task_id, repository_revision_url, err
);
return Ok(StatResponse {
success: false,
content_length: None,
http_header: None,
http_status_code: None,
entries: Vec::new(),
error_message: Some(err.to_string()),
});
}
};
let response_status_code = response.status();
let response_header = response.headers().clone();
let content_length = match response_header.get(CONTENT_LENGTH) {
Some(content_length) => content_length.to_str()?.parse::<u64>().ok(),
None => response.content_length(),
};
if !response.status().is_success() {
return Ok(StatResponse {
success: false,
content_length: None,
http_header: Some(response_header),
http_status_code: response_status_code.into(),
error_message: Some(response_status_code.to_string()),
entries: Vec::new(),
});
}
let text = response.text().await.map_err(|err| {
error!(
"stat request failed {} {}: {}",
request.task_id, repository_revision_url, err
);
Error::BackendError(Box::new(BackendError {
message: err.to_string(),
status_code: None,
header: None,
}))
})?;
let repository: Repository = serde_json::from_str(&text).map_err(|err| {
error!(
"stat request failed {} {}: {}",
request.task_id, repository_revision_url, err
);
Error::BackendError(Box::new(BackendError {
message: err.to_string(),
status_code: None,
header: None,
}))
})?;
let entries: Vec<DirEntry> = repository
.siblings
.unwrap_or_default()
.into_iter()
.map(|sibling: Sibling| -> Result<DirEntry> {
let hf_url: Url = Self::build_hf_url(&parsed_url, &sibling.rfilename)?;
let content_length: u64 = sibling
.lfs
.as_ref()
.map(|lfs: &Lfs| lfs.size)
.or(sibling.size)
.unwrap_or(0);
Ok(DirEntry {
url: hf_url.to_string(),
content_length: content_length as usize,
is_dir: false,
})
})
.collect::<Result<Vec<_>>>()?;
debug!(
"stat response {} {}: {:?} {:?} {:?}",
request.task_id,
repository_revision_url,
response_status_code,
content_length,
response_header
);
Ok(StatResponse {
success: response_status_code.is_success(),
content_length,
http_header: Some(response_header),
http_status_code: Some(response_status_code),
error_message: Some(response_status_code.to_string()),
entries,
})
}
}
}
async fn get(&self, request: GetRequest) -> Result<GetResponse<Body>> {
debug!(
"get request {} {} {}: {:?}",
request.task_id, request.piece_id, request.url, request.http_header
);
let request_header = Self::build_request_headers(
request
.hugging_face
.as_ref()
.and_then(|hf| hf.token.clone()),
request.range,
)?;
let hugging_face = request.hugging_face.as_ref().ok_or_else(|| {
error!(
"get request {} {}: missing Hugging Face information",
request.task_id, request.url
);
Error::InvalidParameter
})?;
let parsed_url = ParsedURL::try_from(request.url.as_str())?;
let Some(file_path) = &parsed_url.file_path else {
error!(
"get request {} {}: URL must specify a file path",
request.task_id, request.url
);
return Err(Error::InvalidParameter);
};
let (base_url, _) = Self::resolve_base_urls(hugging_face.base_url.as_deref())?;
let download_url =
Self::build_download_url(&parsed_url, file_path, &hugging_face.revision, &base_url)?;
let response = match self
.client
.get(download_url.as_str())
.headers(request_header)
.timeout(request.timeout)
.send()
.await
{
Ok(response) => response,
Err(err) => {
error!(
"get request failed {} {} {}: {}",
request.task_id, request.piece_id, download_url, err
);
return Ok(GetResponse {
success: false,
http_header: None,
http_status_code: None,
reader: Box::new(tokio::io::empty()),
error_message: Some(err.to_string()),
});
}
};
let response_header = response.headers().clone();
let response_status_code = response.status();
let response_reader = Box::new(StreamReader::new(response.bytes_stream().map_err(
move |err| {
let mut chain = err.to_string();
let mut source = err.source();
while let Some(err) = source {
chain.push_str(": ");
chain.push_str(&err.to_string());
source = err.source();
}
IOError::other(chain)
},
)));
debug!(
"get response {} {}: {:?} {:?}",
request.task_id, request.piece_id, response_status_code, response_header,
);
Ok(GetResponse {
success: response_status_code.is_success(),
http_header: Some(response_header),
http_status_code: Some(response_status_code),
reader: response_reader,
error_message: Some(response_status_code.to_string()),
})
}
async fn put(&self, _request: PutRequest) -> Result<PutResponse> {
unimplemented!()
}
async fn exists(&self, request: ExistsRequest) -> Result<bool> {
debug!(
"exists request {} {}: {:?}",
request.task_id, request.url, request.http_header
);
let request_header = Self::build_request_headers(
request
.hugging_face
.as_ref()
.and_then(|hf| hf.token.clone()),
None,
)?;
let hugging_face = request.hugging_face.as_ref().ok_or_else(|| {
error!(
"exists request {} {}: missing Hugging Face information",
request.task_id, request.url
);
Error::InvalidParameter
})?;
let parsed_url = ParsedURL::try_from(request.url.as_str())?;
let (base_url, api_base_url) = Self::resolve_base_urls(hugging_face.base_url.as_deref())?;
match &parsed_url.file_path {
Some(file_path) => {
let download_url = Self::build_download_url(
&parsed_url,
file_path,
&hugging_face.revision,
&base_url,
)?;
let response = self
.client
.head(download_url.as_str())
.headers(request_header)
.timeout(request.timeout)
.send()
.await
.inspect_err(|err| {
error!(
"exists request failed {} {}: {}",
request.task_id, request.url, err
);
})?;
let response_status_code = response.status();
debug!(
"exists response {} {}: {:?} {:?}",
request.task_id,
request.url,
response_status_code,
response.headers()
);
Ok(response_status_code.is_success())
}
None => {
let repository_url = Self::build_repository_url(&parsed_url, &api_base_url)?;
let response = self
.client
.head(repository_url.as_str())
.headers(request_header)
.timeout(request.timeout)
.send()
.await
.inspect_err(|err| {
error!(
"exists request failed {} {}: {}",
request.task_id, request.url, err
);
})?;
let response_status_code = response.status();
debug!(
"exists response {} {}: {:?} {:?}",
request.task_id,
request.url,
response_status_code,
response.headers()
);
Ok(response_status_code.is_success())
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_USER_AGENT;

    #[test]
    fn test_parse_url_simple() {
        let parsed_url = ParsedURL::try_from("hf://deepseek-ai/DeepSeek-OCR").unwrap();
        assert_eq!(parsed_url.repository_id, "deepseek-ai/DeepSeek-OCR");
        assert_eq!(parsed_url.repository_type, RepositoryType::Model);
        assert!(parsed_url.file_path.is_none());
    }

    #[test]
    fn test_parse_url_with_file() {
        let parsed_url =
            ParsedURL::try_from("hf://deepseek-ai/DeepSeek-OCR/model.safetensors").unwrap();
        assert_eq!(parsed_url.repository_id, "deepseek-ai/DeepSeek-OCR");
        assert_eq!(parsed_url.repository_type, RepositoryType::Model);
        assert_eq!(parsed_url.file_path, Some("model.safetensors".to_string()));
    }

    // A trailing slash must not turn the repository URL into a file URL.
    #[test]
    fn test_parse_url_trailing_slash() {
        let parsed_url = ParsedURL::try_from("hf://deepseek-ai/DeepSeek-OCR/").unwrap();
        assert_eq!(parsed_url.repository_id, "deepseek-ai/DeepSeek-OCR");
        assert_eq!(parsed_url.repository_type, RepositoryType::Model);
        assert!(parsed_url.file_path.is_none());
    }

    #[test]
    fn test_parse_url_with_nested_path() {
        let parsed_url =
            ParsedURL::try_from("hf://deepseek-ai/DeepSeek-OCR/models/v1/model.bin").unwrap();
        assert_eq!(parsed_url.repository_id, "deepseek-ai/DeepSeek-OCR");
        assert_eq!(parsed_url.repository_type, RepositoryType::Model);
        assert_eq!(
            parsed_url.file_path,
            Some("models/v1/model.bin".to_string())
        );
    }

    #[test]
    fn test_parse_url_dataset() {
        let parsed_url = ParsedURL::try_from("hf://datasets/huggingface/squad").unwrap();
        assert_eq!(parsed_url.repository_id, "huggingface/squad");
        assert_eq!(parsed_url.repository_type, RepositoryType::Dataset);
        assert!(parsed_url.file_path.is_none());
    }

    #[test]
    fn test_parse_url_dataset_with_path() {
        let parsed_url = ParsedURL::try_from("hf://datasets/huggingface/squad/train.json").unwrap();
        assert_eq!(parsed_url.repository_id, "huggingface/squad");
        assert_eq!(parsed_url.repository_type, RepositoryType::Dataset);
        assert_eq!(parsed_url.file_path, Some("train.json".to_string()));
    }

    #[test]
    fn test_parse_url_space() {
        let parsed_url = ParsedURL::try_from("hf://spaces/huggingface/transformers-demo").unwrap();
        assert_eq!(parsed_url.repository_id, "huggingface/transformers-demo");
        assert_eq!(parsed_url.repository_type, RepositoryType::Space);
        assert!(parsed_url.file_path.is_none());
    }

    #[test]
    fn test_parse_url_space_with_path() {
        let parsed_url =
            ParsedURL::try_from("hf://spaces/huggingface/transformers-demo/app.py").unwrap();
        assert_eq!(parsed_url.repository_id, "huggingface/transformers-demo");
        assert_eq!(parsed_url.repository_type, RepositoryType::Space);
        assert_eq!(parsed_url.file_path, Some("app.py".to_string()));
    }

    #[test]
    fn test_parse_url_explicit_model_type() {
        let parsed_url =
            ParsedURL::try_from("hf://models/deepseek-ai/DeepSeek-OCR/model.safetensors").unwrap();
        assert_eq!(parsed_url.repository_id, "deepseek-ai/DeepSeek-OCR");
        assert_eq!(parsed_url.repository_type, RepositoryType::Model);
        assert_eq!(parsed_url.file_path, Some("model.safetensors".to_string()));
    }

    #[test]
    fn test_parse_url_missing_repo() {
        let result = ParsedURL::try_from("hf://deepseek-ai");
        assert!(result.is_err());
    }

    #[test]
    fn test_build_download_url_model() {
        let parsed_url =
            ParsedURL::try_from("hf://deepseek-ai/DeepSeek-OCR/model.safetensors").unwrap();
        let url = HuggingFace::build_download_url(
            &parsed_url,
            "model.safetensors",
            "main",
            &Url::parse(HUGGING_FACE_BASE_URL).unwrap(),
        )
        .unwrap();
        assert_eq!(
            url.as_str(),
            "https://huggingface.co/deepseek-ai/DeepSeek-OCR/resolve/main/model.safetensors"
        );
    }

    #[test]
    fn test_build_download_url_dataset() {
        let parsed_url = ParsedURL::try_from("hf://datasets/huggingface/squad/train.json").unwrap();
        let url = HuggingFace::build_download_url(
            &parsed_url,
            "train.json",
            "main",
            &Url::parse(HUGGING_FACE_BASE_URL).unwrap(),
        )
        .unwrap();
        assert_eq!(
            url.as_str(),
            "https://huggingface.co/datasets/huggingface/squad/resolve/main/train.json"
        );
    }

    #[test]
    fn test_build_download_url_space() {
        let parsed_url =
            ParsedURL::try_from("hf://spaces/huggingface/transformers-demo/app.py").unwrap();
        let url = HuggingFace::build_download_url(
            &parsed_url,
            "app.py",
            "main",
            &Url::parse(HUGGING_FACE_BASE_URL).unwrap(),
        )
        .unwrap();
        assert_eq!(
            url.as_str(),
            "https://huggingface.co/spaces/huggingface/transformers-demo/resolve/main/app.py"
        );
    }

    #[test]
    fn test_build_api_url_model() {
        let parsed_url = ParsedURL::try_from("hf://deepseek-ai/DeepSeek-OCR").unwrap();
        let url = HuggingFace::build_repository_url(
            &parsed_url,
            &Url::parse("https://huggingface.co/api/").unwrap(),
        )
        .unwrap();
        assert_eq!(
            url.as_str(),
            "https://huggingface.co/api/models/deepseek-ai/DeepSeek-OCR"
        );
    }

    #[test]
    fn test_build_api_url_dataset() {
        let parsed_url = ParsedURL::try_from("hf://datasets/huggingface/squad").unwrap();
        let url = HuggingFace::build_repository_url(
            &parsed_url,
            &Url::parse("https://huggingface.co/api/").unwrap(),
        )
        .unwrap();
        assert_eq!(
            url.as_str(),
            "https://huggingface.co/api/datasets/huggingface/squad"
        );
    }

    #[test]
    fn test_build_api_url_space() {
        let parsed_url = ParsedURL::try_from("hf://spaces/huggingface/transformers-demo").unwrap();
        let url = HuggingFace::build_repository_url(
            &parsed_url,
            &Url::parse("https://huggingface.co/api/").unwrap(),
        )
        .unwrap();
        assert_eq!(
            url.as_str(),
            "https://huggingface.co/api/spaces/huggingface/transformers-demo"
        );
    }

    #[test]
    fn test_build_hf_url_model() {
        let parsed_url = ParsedURL::try_from("hf://deepseek-ai/DeepSeek-OCR").unwrap();
        let url = HuggingFace::build_hf_url(&parsed_url, "model.safetensors").unwrap();
        assert_eq!(
            url.as_str(),
            "hf://deepseek-ai/DeepSeek-OCR/model.safetensors"
        );
    }

    #[test]
    fn test_build_hf_url_dataset() {
        let parsed_url = ParsedURL::try_from("hf://datasets/huggingface/squad").unwrap();
        let url = HuggingFace::build_hf_url(&parsed_url, "train.json").unwrap();
        assert_eq!(url.as_str(), "hf://datasets/huggingface/squad/train.json");
    }

    #[test]
    fn test_build_hf_url_space() {
        let parsed_url = ParsedURL::try_from("hf://spaces/huggingface/transformers-demo").unwrap();
        let url = HuggingFace::build_hf_url(&parsed_url, "app.py").unwrap();
        assert_eq!(
            url.as_str(),
            "hf://spaces/huggingface/transformers-demo/app.py"
        );
    }

    #[test]
    fn test_resolve_base_urls() {
        let (base_url, api_base_url) =
            HuggingFace::resolve_base_urls(Some("https://hf-mirror.com/")).unwrap();
        assert_eq!(base_url.as_str(), "https://hf-mirror.com/");
        assert_eq!(api_base_url.as_str(), "https://hf-mirror.com/api/");
    }

    #[test]
    fn test_build_headers_default_user_agent() {
        let request_header = HuggingFace::build_request_headers(None, None).unwrap();
        assert_eq!(
            request_header.get(USER_AGENT).unwrap(),
            HeaderValue::from_static(DEFAULT_USER_AGENT)
        );
    }

    #[test]
    fn test_build_headers_with_token() {
        let request_headers =
            HuggingFace::build_request_headers(Some("test-token".to_string()), None).unwrap();
        assert_eq!(
            request_headers.get(reqwest::header::AUTHORIZATION).unwrap(),
            "Bearer test-token"
        );
        assert_eq!(
            request_headers.get(USER_AGENT).unwrap(),
            HeaderValue::from_static(DEFAULT_USER_AGENT)
        );
    }

    #[test]
    fn test_build_headers_with_range() {
        let request_headers = HuggingFace::build_request_headers(
            None,
            Some(Range {
                start: 0,
                length: 1024,
            }),
        )
        .unwrap();
        assert_eq!(
            request_headers.get(RANGE).unwrap(),
            HeaderValue::from_static("bytes=0-1023")
        );
    }

    #[test]
    fn test_build_headers_with_token_and_range() {
        let request_headers = HuggingFace::build_request_headers(
            Some("test-token".to_string()),
            Some(Range {
                start: 1024,
                length: 1024,
            }),
        )
        .unwrap();
        assert_eq!(
            request_headers.get(RANGE).unwrap(),
            HeaderValue::from_static("bytes=1024-2047")
        );
        assert_eq!(
            request_headers.get(AUTHORIZATION).unwrap(),
            "Bearer test-token"
        );
    }
}