use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::Array2;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs;
use std::io::Write;
use std::path::PathBuf;
const HUGGINGFACE_HUB_URL: &str = "https://huggingface.co";
/// Resolve the default on-disk cache directory for downloaded hub files,
/// mirroring the per-platform layout used by the official `huggingface_hub`
/// client, with a `/tmp`-style fallback when no home directory is known.
fn default_cache_dir() -> PathBuf {
    // Linux: honor XDG_CACHE_HOME, then ~/.cache, then a world-writable tmp dir.
    #[cfg(target_os = "linux")]
    {
        std::env::var("XDG_CACHE_HOME")
            .map(|xdg| PathBuf::from(xdg).join("huggingface/hub"))
            .or_else(|_| {
                std::env::var("HOME")
                    .map(|home| PathBuf::from(home).join(".cache/huggingface/hub"))
            })
            .unwrap_or_else(|_| PathBuf::from("/tmp/huggingface/hub"))
    }
    // macOS: conventional user cache location under ~/Library/Caches.
    #[cfg(target_os = "macos")]
    {
        std::env::var("HOME")
            .map(|home| PathBuf::from(home).join("Library/Caches/huggingface/hub"))
            .unwrap_or_else(|_| PathBuf::from("/tmp/huggingface/hub"))
    }
    // Windows: per-user local application data.
    #[cfg(target_os = "windows")]
    {
        std::env::var("LOCALAPPDATA")
            .map(|appdata| PathBuf::from(appdata).join("huggingface\\hub"))
            .unwrap_or_else(|_| PathBuf::from("C:\\Temp\\huggingface\\hub"))
    }
    // Any other OS: best-effort tmp location.
    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
    {
        PathBuf::from("/tmp/huggingface/hub")
    }
}
/// Subset of a HuggingFace `config.json` that this crate understands.
///
/// Several fields are remapped from the HuggingFace JSON keys via
/// `#[serde(rename = ...)]` (e.g. JSON `hidden_size` -> `hidden_dim`);
/// every unrecognized key is preserved verbatim in `extra` through
/// `#[serde(flatten)]`, so round-tripping a config does not drop data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
/// Architecture class names (JSON key `architectures`), e.g. `["MambaForCausalLM"]`.
#[serde(rename = "architectures")]
pub architecture: Option<Vec<String>>,
/// Hidden/embedding dimension (JSON key `hidden_size`).
#[serde(rename = "hidden_size")]
pub hidden_dim: Option<usize>,
/// Number of transformer/SSM layers (JSON key `num_hidden_layers`).
#[serde(rename = "num_hidden_layers")]
pub num_layers: Option<usize>,
/// Tokenizer vocabulary size.
pub vocab_size: Option<usize>,
/// Maximum sequence length supported by positional embeddings.
pub max_position_embeddings: Option<usize>,
/// State dimension for state-space models (JSON key `state_size`).
#[serde(rename = "state_size")]
pub state_dim: Option<usize>,
/// Number of attention heads, when the architecture uses attention.
pub num_attention_heads: Option<usize>,
/// Model family identifier, e.g. `"mamba"`.
pub model_type: Option<String>,
/// All config keys not captured by the named fields above.
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
/// Client for downloading model files from the HuggingFace Hub with a local
/// on-disk cache laid out like the official `huggingface_hub` cache.
#[derive(Debug, Clone)]
pub struct HuggingFaceHub {
/// Optional bearer token for private repositories; `new()` reads it from
/// the `HF_TOKEN` or `HUGGING_FACE_HUB_TOKEN` environment variables.
pub token: Option<String>,
/// Root directory where downloaded files are cached.
pub cache_dir: PathBuf,
// Shared HTTP client (connection pooling, 300 s request timeout).
client: reqwest::Client,
/// When true, `download_file` computes and logs a SHA-256 of each download.
pub verify_integrity: bool,
}
impl HuggingFaceHub {
/// Create a hub client with the platform default cache directory and an
/// API token sourced from the environment (`HF_TOKEN`, then
/// `HUGGING_FACE_HUB_TOKEN`), if one is set.
///
/// # Errors
/// Fails when the underlying HTTP client cannot be constructed.
pub fn new() -> ModelResult<Self> {
    // First matching environment variable wins; absence is not an error.
    let token = ["HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"]
        .iter()
        .find_map(|key| std::env::var(key).ok());
    let client = reqwest::Client::builder()
        .user_agent("kizzasi/0.1.0")
        .timeout(std::time::Duration::from_secs(300))
        .build()
        .map_err(|e| {
            ModelError::simple_load_error(format!("Failed to create HTTP client: {}", e))
        })?;
    Ok(Self {
        token,
        cache_dir: default_cache_dir(),
        client,
        verify_integrity: true,
    })
}
/// Create a hub client that caches into `cache_dir` instead of the
/// platform default; all other settings come from [`Self::new`].
pub fn with_cache_dir(cache_dir: impl Into<PathBuf>) -> ModelResult<Self> {
    Self::new().map(|mut hub| {
        hub.cache_dir = cache_dir.into();
        hub
    })
}
pub fn with_token(mut self, token: impl Into<String>) -> Self {
self.token = Some(token.into());
self
}
pub fn with_verification(mut self, verify: bool) -> Self {
self.verify_integrity = verify;
self
}
/// Download `filename` from `repo_id` at `revision` (default `"main"`),
/// caching it under the HuggingFace-style layout
/// `<cache>/models--{org}--{name}/snapshots/{revision}/{filename}`.
///
/// A cached file is returned immediately without contacting the network.
/// New downloads are written to a `.incomplete` sibling first and then
/// atomically renamed into place, so an interrupted or failed write can
/// never be mistaken for a valid cache entry by a later call.
///
/// # Errors
/// Fails on network errors, a non-success HTTP status, or any filesystem
/// error while writing the cache entry.
pub async fn download_file(
    &self,
    repo_id: &str,
    filename: &str,
    revision: Option<&str>,
) -> ModelResult<PathBuf> {
    let revision = revision.unwrap_or("main");
    // HuggingFace cache layout flattens "org/name" into "models--org--name".
    let repo_path = repo_id.replace('/', "--");
    let cache_path = self
        .cache_dir
        .join(format!("models--{}", repo_path))
        .join("snapshots")
        .join(revision)
        .join(filename);
    if cache_path.exists() {
        tracing::debug!("Using cached file: {}", cache_path.display());
        return Ok(cache_path);
    }
    let url = format!(
        "{}/{}/resolve/{}/{}",
        HUGGINGFACE_HUB_URL, repo_id, revision, filename
    );
    tracing::info!("Downloading {} from HuggingFace Hub", filename);
    tracing::debug!("URL: {}", url);
    let mut request = self.client.get(&url);
    // Private repositories require a bearer token.
    if let Some(token) = &self.token {
        request = request.header("Authorization", format!("Bearer {}", token));
    }
    let response = request
        .send()
        .await
        .map_err(|e| ModelError::simple_load_error(format!("Download failed: {}", e)))?;
    if !response.status().is_success() {
        return Err(ModelError::simple_load_error(format!(
            "Download failed with status {}: {}",
            response.status(),
            response.text().await.unwrap_or_default()
        )));
    }
    let total_size = response.content_length().unwrap_or(0);
    tracing::info!("Downloading {} bytes", total_size);
    // NOTE(review): the whole payload is buffered in memory; very large
    // model files may warrant a streaming copy instead.
    let bytes = response
        .bytes()
        .await
        .map_err(|e| ModelError::simple_load_error(format!("Download failed: {}", e)))?;
    if self.verify_integrity {
        // No expected digest is available from this endpoint, so the hash
        // is only logged for manual verification, not enforced.
        let hash = Sha256::digest(&bytes);
        let hash_hex: String = hash.iter().map(|b| format!("{b:02x}")).collect();
        tracing::debug!("SHA256: {}", hash_hex);
    }
    if let Some(parent) = cache_path.parent() {
        fs::create_dir_all(parent).map_err(|e| {
            ModelError::simple_load_error(format!("Failed to create cache directory: {}", e))
        })?;
    }
    // Write to a temporary sibling first so a partial write never satisfies
    // the `cache_path.exists()` fast path above.
    let mut tmp_os = cache_path.clone().into_os_string();
    tmp_os.push(".incomplete");
    let tmp_path = PathBuf::from(tmp_os);
    let mut file = fs::File::create(&tmp_path).map_err(|e| {
        ModelError::simple_load_error(format!("Failed to create cache file: {}", e))
    })?;
    if let Err(e) = file.write_all(&bytes) {
        // Best-effort cleanup of the partial temp file.
        let _ = fs::remove_file(&tmp_path);
        return Err(ModelError::simple_load_error(format!(
            "Failed to write cache file: {}",
            e
        )));
    }
    drop(file);
    // Atomic on the same filesystem: readers see no file or a complete one.
    fs::rename(&tmp_path, &cache_path).map_err(|e| {
        ModelError::simple_load_error(format!("Failed to finalize cache file: {}", e))
    })?;
    tracing::info!("Downloaded to: {}", cache_path.display());
    Ok(cache_path)
}
/// Download and parse the repository's `config.json` into a [`ModelConfig`].
///
/// # Errors
/// Fails when the file cannot be downloaded, read from the cache, or
/// parsed as JSON.
pub async fn load_config(
    &self,
    repo_id: &str,
    revision: Option<&str>,
) -> ModelResult<ModelConfig> {
    let path = self.download_file(repo_id, "config.json", revision).await?;
    let raw = fs::read_to_string(&path).map_err(|e| {
        ModelError::simple_load_error(format!("Failed to read config.json: {}", e))
    })?;
    serde_json::from_str::<ModelConfig>(&raw).map_err(|e| {
        ModelError::simple_load_error(format!("Failed to parse config.json: {}", e))
    })
}
/// Download a model's SafeTensors weights from the Hub and load every
/// tensor into memory as a 2-D `f32` array keyed by tensor name.
///
/// Delegates file discovery and download to [`Self::load_model_loader`] so
/// the candidate-filename logic lives in exactly one place instead of
/// being duplicated here.
///
/// # Errors
/// Fails when no SafeTensors file can be downloaded from the repository or
/// the downloaded file cannot be opened by `ModelLoader`.
pub async fn load_model(
    &self,
    repo_id: &str,
    revision: Option<&str>,
) -> ModelResult<HashMap<String, Array2<f32>>> {
    // Reuse the shared download/discovery path.
    let loader = self.load_model_loader(repo_id, revision).await?;
    let mut weights = HashMap::new();
    for name in loader.list_tensors() {
        // Best-effort: tensors that cannot be read as a 2-D f32 array are
        // skipped silently, matching the previous behavior.
        if let Ok(tensor) = loader.load_array2(&name) {
            weights.insert(name, tensor);
        }
    }
    tracing::info!("Loaded {} tensors from HuggingFace model", weights.len());
    Ok(weights)
}
/// Locate and download the repository's SafeTensors file, returning a
/// `ModelLoader` opened on the cached copy.
///
/// Tries each known candidate filename in order and uses the first one
/// that downloads successfully.
///
/// # Errors
/// Fails when none of the candidate files exist in the repository or the
/// downloaded file cannot be opened.
pub async fn load_model_loader(
    &self,
    repo_id: &str,
    revision: Option<&str>,
) -> ModelResult<crate::loader::ModelLoader> {
    let safetensors_files = ["model.safetensors", "pytorch_model.safetensors"];
    let mut found = None;
    for candidate in &safetensors_files {
        if let Ok(path) = self.download_file(repo_id, candidate, revision).await {
            found = Some(path);
            break;
        }
    }
    let path = found.ok_or_else(|| {
        ModelError::simple_load_error(format!(
            "No SafeTensors file found in repository: {}. Tried: {:?}",
            repo_id, safetensors_files
        ))
    })?;
    tracing::info!("Creating ModelLoader for: {}", path.display());
    crate::loader::ModelLoader::new(&path)
}
/// List the file paths in `repo_id` at `revision` (default `"main"`) via
/// the Hub's `tree` API endpoint.
///
/// # Errors
/// Fails on network errors, a non-success HTTP status, or an API response
/// that is not the expected JSON array.
pub async fn list_files(
    &self,
    repo_id: &str,
    revision: Option<&str>,
) -> ModelResult<Vec<String>> {
    let rev = revision.unwrap_or("main");
    let url = format!("{}/api/models/{}/tree/{}", HUGGINGFACE_HUB_URL, repo_id, rev);
    let mut request = self.client.get(&url);
    // Private repositories require a bearer token.
    if let Some(token) = self.token.as_deref() {
        request = request.header("Authorization", format!("Bearer {}", token));
    }
    let response = request
        .send()
        .await
        .map_err(|e| ModelError::simple_load_error(format!("API request failed: {}", e)))?;
    let status = response.status();
    if !status.is_success() {
        return Err(ModelError::simple_load_error(format!(
            "API request failed with status {}",
            status
        )));
    }
    let entries: Vec<serde_json::Value> = response.json().await.map_err(|e| {
        ModelError::simple_load_error(format!("Failed to parse API response: {}", e))
    })?;
    // Each tree entry carries its repo-relative path in the "path" field;
    // entries without one are skipped.
    Ok(entries
        .into_iter()
        .filter_map(|entry| entry.get("path").and_then(|p| p.as_str()).map(str::to_string))
        .collect())
}
/// Delete the cached files for a single repository; a repository with no
/// cache entry is a no-op.
///
/// # Errors
/// Fails when the cache subdirectory exists but cannot be removed.
pub fn clear_cache(&self, repo_id: &str) -> ModelResult<()> {
    let dir = self
        .cache_dir
        .join(format!("models--{}", repo_id.replace('/', "--")));
    if !dir.exists() {
        return Ok(());
    }
    fs::remove_dir_all(&dir)
        .map_err(|e| ModelError::simple_load_error(format!("Failed to clear cache: {}", e)))?;
    tracing::info!("Cleared cache for repository: {}", repo_id);
    Ok(())
}
/// Total size in bytes of everything under the cache directory.
///
/// Walks the tree recursively, because downloaded files live several
/// levels deep (`models--…/snapshots/<revision>/<file>`); summing only the
/// top level — as the previous implementation did — always reported
/// (close to) zero.
///
/// # Errors
/// Fails when any directory or entry metadata cannot be read.
pub fn cache_size(&self) -> ModelResult<u64> {
    // Recursive walker. Symlinks are not followed (`DirEntry::metadata`
    // does not traverse them), which avoids cycles.
    fn dir_size(dir: &std::path::Path) -> std::io::Result<u64> {
        let mut total = 0u64;
        for entry in fs::read_dir(dir)? {
            let entry = entry?;
            let metadata = entry.metadata()?;
            if metadata.is_dir() {
                total += dir_size(&entry.path())?;
            } else if metadata.is_file() {
                total += metadata.len();
            }
        }
        Ok(total)
    }
    if !self.cache_dir.exists() {
        return Ok(0);
    }
    dir_size(&self.cache_dir).map_err(|e| {
        ModelError::simple_load_error(format!("Failed to read cache directory: {}", e))
    })
}
}
impl Default for HuggingFaceHub {
/// Equivalent to [`HuggingFaceHub::new`].
///
/// # Panics
/// Panics if the HTTP client cannot be constructed (the only failure mode
/// of `new`); prefer `HuggingFaceHub::new()` when the error should be
/// handled instead.
fn default() -> Self {
Self::new().expect("Failed to create default HuggingFaceHub")
}
}
// Unit tests. These exercise only the offline parts of the client
// (constructors, builders, config parsing, cache bookkeeping); no network
// access is performed.
#[cfg(test)]
mod tests {
use super::*;
// The platform cache path always ends in a huggingface-specific suffix.
#[test]
fn test_default_cache_dir() {
let cache_dir = default_cache_dir();
assert!(cache_dir.to_string_lossy().contains("huggingface"));
}
// `new()` should succeed and pick up the default cache directory.
#[test]
fn test_hub_creation() {
let hub = HuggingFaceHub::new();
assert!(hub.is_ok());
let hub = hub.unwrap();
assert!(hub.cache_dir.to_string_lossy().contains("huggingface"));
}
// `with_cache_dir` must override the default cache location verbatim.
#[test]
fn test_custom_cache_dir() {
let custom_dir = PathBuf::from("/tmp/test_cache");
let hub = HuggingFaceHub::with_cache_dir(&custom_dir);
assert!(hub.is_ok());
let hub = hub.unwrap();
assert_eq!(hub.cache_dir, custom_dir);
}
// An explicitly supplied token overrides whatever the environment had.
#[test]
fn test_with_token() {
let hub = HuggingFaceHub::new()
.unwrap()
.with_token("test_token_12345");
assert_eq!(hub.token.as_deref(), Some("test_token_12345"));
}
// The verification flag round-trips through the builder in both directions.
#[test]
fn test_with_verification() {
let hub = HuggingFaceHub::new().unwrap().with_verification(false);
assert!(!hub.verify_integrity);
let hub2 = HuggingFaceHub::new().unwrap().with_verification(true);
assert!(hub2.verify_integrity);
}
// Serde renames map HuggingFace JSON keys onto this crate's field names.
#[test]
fn test_model_config_deserialization() {
let config_json = r#"{
"architectures": ["MambaForCausalLM"],
"hidden_size": 768,
"num_hidden_layers": 24,
"vocab_size": 50280,
"model_type": "mamba"
}"#;
let config: ModelConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.hidden_dim, Some(768));
assert_eq!(config.num_layers, Some(24));
assert_eq!(config.vocab_size, Some(50280));
assert_eq!(config.model_type.as_deref(), Some("mamba"));
}
// Mirrors the "org/name" -> "org--name" flattening used by the cache layout.
#[test]
fn test_repo_path_conversion() {
let repo_id = "state-spaces/mamba-130m";
let repo_path = repo_id.replace('/', "--");
assert_eq!(repo_path, "state-spaces--mamba-130m");
}
// A cache directory that does not exist (or is empty) reports zero bytes.
#[tokio::test]
async fn test_cache_size_empty() {
let temp_dir = std::env::temp_dir().join("kizzasi_test_cache_empty");
let hub = HuggingFaceHub::with_cache_dir(&temp_dir).unwrap();
let size = hub.cache_size().unwrap();
assert_eq!(size, 0);
}
}