use super::models::{HfModelInfo, HfTreeEntry};
use anyhow::{bail, Context, Result};
use reqwest::header::HeaderMap;
/// Async client for the Hugging Face Hub HTTP API.
///
/// Holds two `reqwest` clients that share the same user agent and differ only in their total
/// request timeout: a default one for API calls and a long-timeout one for large transfers.
/// Cloning is cheap because `reqwest::Client` keeps its connection pool behind a shared
/// handle, so clones reuse the same pools.
///
/// `Debug` is intentionally not derived: it would print `token`.
#[derive(Clone)]
pub struct HfClient {
    /// Client used for regular API requests.
    client: reqwest::Client,
    /// Client with a much longer timeout, for large file transfers.
    download_client: reqwest::Client,
    /// Base URL of the Hub (e.g. `https://huggingface.co`); request paths are appended to it.
    api_base: String,
    /// Optional bearer token sent in the `Authorization` header.
    token: Option<String>,
}
impl HfClient {
pub fn new(token: Option<String>) -> Self {
Self::new_with_base(token, "https://huggingface.co".to_string())
}
pub fn new_with_base(token: Option<String>, api_base: String) -> Self {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(60))
.user_agent("securegit")
.build()
.expect("Failed to create HTTP client");
let download_client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(1800))
.user_agent("securegit")
.build()
.expect("Failed to create download HTTP client");
Self {
client,
download_client,
api_base,
token,
}
}
pub fn from_env() -> Self {
let token = std::env::var("HF_TOKEN").ok().or_else(|| {
dirs::cache_dir().and_then(|cache| {
let path = cache.join("huggingface").join("token");
std::fs::read_to_string(path)
.ok()
.map(|s| s.trim().to_string())
})
});
Self::new(token)
}
fn auth_headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
if let Some(ref token) = self.token {
if let Ok(val) = reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token)) {
headers.insert(reqwest::header::AUTHORIZATION, val);
}
}
headers
}
pub async fn search_models(
&self,
query: &str,
task: Option<&str>,
library: Option<&str>,
limit: usize,
) -> Result<Vec<HfModelInfo>> {
let encoded_query = urlencoding::encode(query);
let mut url = format!(
"{}/api/models?search={}&limit={}&sort=downloads&direction=-1",
self.api_base, encoded_query, limit
);
if let Some(task) = task {
url.push_str(&format!("&pipeline_tag={}", urlencoding::encode(task)));
}
if let Some(library) = library {
url.push_str(&format!("&library={}", urlencoding::encode(library)));
}
let resp = self
.client
.get(&url)
.headers(self.auth_headers())
.send()
.await
.context("Failed to search HuggingFace models")?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
bail!("HuggingFace API error ({}): {}", status, text);
}
let models: Vec<HfModelInfo> = resp
.json()
.await
.context("Failed to parse model search results")?;
Ok(models)
}
pub async fn model_info(&self, model_id: &str) -> Result<HfModelInfo> {
let url = format!("{}/api/models/{}", self.api_base, model_id);
let resp = self
.client
.get(&url)
.headers(self.auth_headers())
.send()
.await
.context("Failed to fetch model info")?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
bail!("HuggingFace API error ({}): {}", status, text);
}
let info: HfModelInfo = resp.json().await.context("Failed to parse model info")?;
Ok(info)
}
pub async fn list_files(&self, model_id: &str, revision: &str) -> Result<Vec<HfTreeEntry>> {
let url = format!(
"{}/api/models/{}/tree/{}",
self.api_base, model_id, revision
);
let resp = self
.client
.get(&url)
.headers(self.auth_headers())
.send()
.await
.context("Failed to list model files")?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
bail!("HuggingFace API error ({}): {}", status, text);
}
let entries: Vec<HfTreeEntry> =
resp.json().await.context("Failed to parse file listing")?;
Ok(entries)
}
pub fn file_download_url(&self, model_id: &str, revision: &str, filename: &str) -> String {
format!(
"{}/{}/resolve/{}/{}",
self.api_base, model_id, revision, filename
)
}
pub async fn resolve_revision(&self, model_id: &str, revision: &str) -> Result<String> {
let info = self.model_info(model_id).await?;
info.sha.context(format!(
"Model '{}' at revision '{}' has no SHA",
model_id, revision
))
}
pub async fn download_file(&self, url: &str) -> Result<Vec<u8>> {
let resp = self
.download_client
.get(url)
.headers(self.auth_headers())
.send()
.await
.context("Failed to download file")?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
bail!("HuggingFace download error ({}): {}", status, text);
}
let bytes = resp.bytes().await.context("Failed to read download body")?;
Ok(bytes.to_vec())
}
pub async fn upload_file(
&self,
repo_id: &str,
revision: &str,
path: &str,
data: Vec<u8>,
commit_message: &str,
) -> Result<()> {
use base64::Engine;
let url = format!(
"{}/api/models/{}/commit/{}",
self.api_base, repo_id, revision
);
let encoded = base64::engine::general_purpose::STANDARD.encode(&data);
let header_line = serde_json::json!({
"key": "header",
"value": {
"summary": commit_message,
"description": ""
}
});
let file_line = serde_json::json!({
"key": "file",
"value": {
"content": encoded,
"path": path,
"encoding": "base64"
}
});
let body = format!("{}\n{}\n", header_line, file_line);
let resp = self
.client
.post(&url)
.headers(self.auth_headers())
.header("Content-Type", "application/x-ndjson")
.body(body)
.send()
.await
.context("Failed to upload file")?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
bail!("HuggingFace upload error ({}): {}", status, text);
}
Ok(())
}
pub async fn create_repo(&self, repo_id: &str, private: bool) -> Result<()> {
let url = format!("{}/api/repos/create", self.api_base);
let (org, name) = if let Some((o, n)) = repo_id.split_once('/') {
(Some(o), n)
} else {
(None, repo_id)
};
let mut body = serde_json::json!({
"type": "model",
"name": name,
"private": private,
});
if let Some(org) = org {
body["organization"] = serde_json::json!(org);
}
let resp = self
.client
.post(&url)
.headers(self.auth_headers())
.json(&body)
.send()
.await
.context("Failed to create repository")?;
let status = resp.status();
if status == reqwest::StatusCode::CONFLICT {
return Ok(());
}
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
bail!("HuggingFace API error ({}): {}", status, text);
}
Ok(())
}
pub async fn fetch_openapi_spec(&self) -> Result<String> {
let url = format!("{}/.well-known/openapi.json", self.api_base);
let resp = self
.client
.get(&url)
.headers(self.auth_headers())
.send()
.await
.context("Failed to fetch OpenAPI spec")?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
bail!("HuggingFace API error ({}): {}", status, text);
}
let body = resp
.text()
.await
.context("Failed to read OpenAPI spec body")?;
Ok(body)
}
}