#[cfg(feature = "daemon")]
use super::types::*;
#[cfg(feature = "daemon")]
use super::urlencoding;
#[cfg(feature = "daemon")]
use super::{HF_API_BASE, HF_MODELS_BASE};
use crate::error::MullamaError;
#[cfg(feature = "daemon")]
use crate::Model;
use std::fs;
#[cfg(feature = "daemon")]
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
/// Blocking client for browsing the Hugging Face Hub and downloading GGUF model files.
///
/// Downloads are written below `download_dir`, which is also where local models are listed from.
pub struct HFClient {
    /// Root directory that downloads are written to and local models are listed from.
    pub download_dir: PathBuf,
    // Bearer token attached to outgoing requests. In this file it is only read by the
    // `daemon`-gated network methods, so without that feature it is never read.
    #[cfg_attr(not(feature = "daemon"), allow(dead_code))]
    token: Option<String>,
    // Value sent as the HTTP `User-Agent`; like `token`, only read when `daemon` is enabled.
    #[cfg_attr(not(feature = "daemon"), allow(dead_code))]
    user_agent: String,
}
impl HFClient {
pub fn new() -> Self {
let download_dir = dirs::cache_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("mullama")
.join("models");
Self {
download_dir,
token: None,
user_agent: format!("mullama/{}", env!("CARGO_PKG_VERSION")),
}
}
pub fn with_download_dir<P: AsRef<Path>>(download_dir: P) -> Self {
Self {
download_dir: download_dir.as_ref().to_path_buf(),
token: None,
user_agent: format!("mullama/{}", env!("CARGO_PKG_VERSION")),
}
}
pub fn with_token(mut self, token: &str) -> Self {
self.token = Some(token.to_string());
self
}
/// Configures the access token from the environment.
///
/// `HF_TOKEN` is checked first, then `HUGGING_FACE_HUB_TOKEN`. Values are trimmed, and
/// empty values are skipped. An empty token would otherwise be sent as
/// `Authorization: Bearer `, which the Hub presumably rejects. If neither variable holds
/// a usable token, any token configured earlier (e.g. via `with_token`) is kept rather
/// than being cleared.
pub fn with_token_from_env(mut self) -> Self {
    let env_token = ["HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"]
        .iter()
        .filter_map(|name| std::env::var(name).ok())
        .map(|value| value.trim().to_string())
        .find(|value| !value.is_empty());
    if let Some(token) = env_token {
        self.token = Some(token);
    }
    self
}
/// Searches Hub models matching `filters`.
///
/// Query, author, tag and sort values are percent-encoded into the query string. When a
/// sort key is set, results are requested with `direction=-1`.
///
/// `min_downloads` is applied locally after the response arrives, so fewer than `limit`
/// entries may be returned. Entries that `parse_model_info` cannot interpret are skipped.
///
/// # Errors
/// Propagates `http_get` failures; returns `MullamaError::HuggingFaceError` when the body
/// is not a JSON array.
#[cfg(feature = "daemon")]
pub fn search_models(
    &self,
    filters: &ModelSearchFilters,
) -> Result<Vec<HFModelInfo>, MullamaError> {
    let mut params = Vec::new();
    if let Some(ref query) = filters.query {
        params.push(format!("search={}", urlencoding::encode(query)));
    }
    if let Some(ref author) = filters.author {
        params.push(format!("author={}", urlencoding::encode(author)));
    }
    for tag in &filters.tags {
        params.push(format!("tags={}", urlencoding::encode(tag)));
    }
    if let Some(ref sort) = filters.sort {
        // Encoded like every other value so a free-form sort key cannot smuggle in
        // additional query parameters.
        params.push(format!("sort={}", urlencoding::encode(&sort.to_string())));
        params.push("direction=-1".to_string());
    }
    if let Some(limit) = filters.limit {
        params.push(format!("limit={}", limit));
    }
    if filters.gguf_only {
        params.push("filter=gguf".to_string());
    }

    let base = format!("{}/models", HF_API_BASE);
    let url = if params.is_empty() {
        base
    } else {
        format!("{}?{}", base, params.join("&"))
    };

    let response = self.http_get(&url)?;
    let models: Vec<serde_json::Value> = serde_json::from_str(&response).map_err(|e| {
        MullamaError::HuggingFaceError(format!("Failed to parse response: {}", e))
    })?;

    let min_downloads = filters.min_downloads;
    Ok(models
        .iter()
        .filter_map(|model_json| self.parse_model_info(model_json))
        .filter(|info| min_downloads.map_or(true, |min| info.downloads >= min))
        .collect())
}
/// Fetches metadata for a single repository from the Hub API.
///
/// # Errors
/// Propagates `http_get` failures. Returns `MullamaError::HuggingFaceError` when the body is
/// not valid JSON, or when it carries no string `modelId`/`id`.
#[cfg(feature = "daemon")]
pub fn get_model_info(&self, model_id: &str) -> Result<HFModelInfo, MullamaError> {
    let body = self.http_get(&format!("{}/models/{}", HF_API_BASE, model_id))?;
    let parsed: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
        MullamaError::HuggingFaceError(format!("Failed to parse model info: {}", e))
    })?;
    match self.parse_model_info(&parsed) {
        Some(info) => Ok(info),
        None => Err(MullamaError::HuggingFaceError(format!(
            "Invalid model data for {}",
            model_id
        ))),
    }
}
/// Lists the `.gguf` files (matched in any letter case) in the `main` tree of `model_id`.
///
/// Files are returned smallest first. Each `download_url` points at the `resolve/main`
/// endpoint.
///
/// `size` comes from the entry's `size`, falling back to `lfs.size` and then to 0 (unknown).
///
/// `sha256` is read from `lfs.oid`. The top-level `oid` in a Hub tree listing is the Git
/// blob id (SHA-1) rather than a SHA-256 digest, so it is deliberately not used. Entries
/// without LFS metadata get `None`.
///
/// NOTE(review): the listing is not requested recursively, so files inside sub-folders are
/// presumably not returned — confirm against the Hub API.
///
/// # Errors
/// Propagates `http_get` failures; returns `MullamaError::HuggingFaceError` when the body is
/// not a JSON array.
#[cfg(feature = "daemon")]
pub fn list_gguf_files(&self, model_id: &str) -> Result<Vec<GGUFFile>, MullamaError> {
    let url = format!("{}/models/{}/tree/main", HF_API_BASE, model_id);
    let response = self.http_get(&url)?;
    let files: Vec<serde_json::Value> = serde_json::from_str(&response).map_err(|e| {
        MullamaError::HuggingFaceError(format!("Failed to parse file list: {}", e))
    })?;

    let mut gguf_files: Vec<GGUFFile> = files
        .iter()
        .filter_map(|file| {
            let filename = file.get("path")?.as_str()?;
            if !filename.to_lowercase().ends_with(".gguf") {
                return None;
            }
            let lfs = file.get("lfs");
            let size = file
                .get("size")
                .and_then(|s| s.as_u64())
                .or_else(|| lfs.and_then(|l| l.get("size")).and_then(|s| s.as_u64()))
                .unwrap_or(0);
            let sha256 = lfs
                .and_then(|l| l.get("oid"))
                .and_then(|o| o.as_str())
                .map(str::to_string);
            Some(GGUFFile {
                filename: filename.to_string(),
                size,
                quantization: QuantizationType::from_filename(filename),
                download_url: format!(
                    "{}/{}/resolve/main/{}",
                    HF_MODELS_BASE, model_id, filename
                ),
                sha256,
            })
        })
        .collect();

    gguf_files.sort_by_key(|f| f.size);
    Ok(gguf_files)
}
/// Downloads one GGUF file of `model_id` below `download_dir`.
///
/// Files go into a per-model directory named after `model_id` with every `/` replaced by `_`.
/// If a regular file already exists at the destination and its length equals
/// `gguf_file.size`, that file is returned and no request is made.
///
/// # Errors
/// - `MullamaError::HuggingFaceError` if the sanitised `model_id` is not a single plain path
///   component.
/// - `MullamaError::HuggingFaceError` if `gguf_file.filename` is empty, absolute, or contains
///   `..`, which would escape the per-model directory.
/// - `MullamaError::IoError` if directories cannot be created.
/// - Any error from the download itself.
#[cfg(feature = "daemon")]
pub fn download_gguf(
    &self,
    model_id: &str,
    gguf_file: &GGUFFile,
    progress_callback: Option<ProgressCallback>,
) -> Result<PathBuf, MullamaError> {
    use std::path::Component;

    let dir_name = model_id.replace('/', "_");
    // Must be exactly one ordinary component: rejects "", ".", ".." and any remaining
    // platform separator (e.g. `\` on Windows).
    let mut dir_parts = Path::new(&dir_name).components();
    let dir_is_safe = matches!(
        (dir_parts.next(), dir_parts.next()),
        (Some(Component::Normal(_)), None)
    );

    // `filename` normally comes from the remote file listing, so treat it as untrusted:
    // nested relative paths are fine, but nothing absolute and no climbing out via `..`.
    let file_path = Path::new(&gguf_file.filename);
    let file_is_safe = file_path.file_name().is_some()
        && !file_path.components().any(|c| {
            matches!(
                c,
                Component::ParentDir | Component::RootDir | Component::Prefix(_)
            )
        });

    if !dir_is_safe || !file_is_safe {
        return Err(MullamaError::HuggingFaceError(format!(
            "Refusing to download to unsafe path: {} / {}",
            model_id, gguf_file.filename
        )));
    }

    let dest_path = self.download_dir.join(&dir_name).join(file_path);
    // Create the file's own parent so nested filenames work as well as flat ones.
    if let Some(parent) = dest_path.parent() {
        fs::create_dir_all(parent).map_err(MullamaError::IoError)?;
    }

    let already_complete = fs::metadata(&dest_path)
        .map(|meta| meta.is_file() && meta.len() == gguf_file.size)
        .unwrap_or(false);
    if already_complete {
        return Ok(dest_path);
    }

    self.download_file(
        &gguf_file.download_url,
        &dest_path,
        gguf_file.size,
        &gguf_file.filename,
        progress_callback,
    )?;
    Ok(dest_path)
}
/// Downloads a LoRA adapter stored as a GGUF file in `model_id`.
///
/// With `filename`, the file whose name matches exactly or case-insensitively is chosen.
/// Without it, the first file whose lower-cased name contains `lora` or `adapter` is chosen,
/// falling back to the first listed file. Because `list_gguf_files` sorts by size, "first"
/// means smallest.
///
/// # Errors
/// `MullamaError::HuggingFaceError` if the repository has no GGUF files or the requested
/// file is not present; otherwise any error from listing or downloading.
#[cfg(feature = "daemon")]
pub fn download_lora(
    &self,
    model_id: &str,
    filename: Option<&str>,
    progress_callback: Option<ProgressCallback>,
) -> Result<PathBuf, MullamaError> {
    let candidates = self.list_gguf_files(model_id)?;
    let fallback = match candidates.first() {
        Some(first) => first,
        None => {
            return Err(MullamaError::HuggingFaceError(format!(
                "No GGUF files found in repository: {}",
                model_id
            )));
        }
    };

    let chosen = match filename {
        Some(wanted) => {
            let wanted_lower = wanted.to_lowercase();
            candidates
                .iter()
                .find(|f| f.filename == wanted || f.filename.to_lowercase() == wanted_lower)
                .ok_or_else(|| {
                    MullamaError::HuggingFaceError(format!(
                        "LoRA file '{}' not found in {}",
                        wanted, model_id
                    ))
                })?
        }
        None => candidates
            .iter()
            .find(|f| {
                let lower = f.filename.to_lowercase();
                lower.contains("lora") || lower.contains("adapter")
            })
            .unwrap_or(fallback),
    };

    self.download_gguf(model_id, chosen, progress_callback)
}
/// Streams `url` into `dest`, reporting progress through `progress_callback`.
///
/// The body is written to a temporary file (`dest` with its extension replaced by
/// `download`). That file is renamed onto `dest` only after the whole body has been read. If
/// reading, writing or the rename fails, the temporary file is deleted before returning.
///
/// `expected_size` is used as the progress total. When it is 0 (unknown), the response's
/// `Content-Length` is used if present. The final progress event reports the number of bytes
/// actually received.
///
/// # Errors
/// - `MullamaError::HuggingFaceError` for client/request failures, a non-success HTTP status,
///   or read/write errors while streaming.
/// - `MullamaError::IoError` when the temporary file cannot be created or renamed into place.
#[cfg(feature = "daemon")]
fn download_file(
    &self,
    url: &str,
    dest: &Path,
    expected_size: u64,
    filename: &str,
    progress_callback: Option<ProgressCallback>,
) -> Result<(), MullamaError> {
    let temp_path = dest.with_extension("download");

    let client = reqwest::blocking::Client::builder()
        .user_agent(&self.user_agent)
        .build()
        .map_err(|e| {
            MullamaError::HuggingFaceError(format!("Failed to create HTTP client: {}", e))
        })?;
    let mut request = client.get(url);
    if let Some(ref token) = self.token {
        request = request.header("Authorization", format!("Bearer {}", token));
    }
    let mut response = request.send().map_err(|e| {
        MullamaError::HuggingFaceError(format!("Download request failed: {}", e))
    })?;
    if !response.status().is_success() {
        // Also clears a stale temporary file left behind by an earlier run.
        let _ = fs::remove_file(&temp_path);
        return Err(MullamaError::HuggingFaceError(format!(
            "Download failed: HTTP {}",
            response.status()
        )));
    }

    // Give progress reports a denominator even when the caller did not know the size.
    let total = if expected_size > 0 {
        expected_size
    } else {
        response.content_length().unwrap_or(0)
    };

    let mut file = fs::File::create(&temp_path).map_err(MullamaError::IoError)?;
    let start_time = std::time::Instant::now();
    // Larger heap buffer: fewer syscalls and callback invocations on multi-GB files.
    let mut buf = vec![0u8; 64 * 1024];

    let streamed = (|| -> Result<u64, MullamaError> {
        let mut downloaded: u64 = 0;
        loop {
            let n = response
                .read(&mut buf)
                .map_err(|e| MullamaError::HuggingFaceError(format!("Download error: {}", e)))?;
            if n == 0 {
                return Ok(downloaded);
            }
            file.write_all(&buf[..n])
                .map_err(|e| MullamaError::HuggingFaceError(format!("Write error: {}", e)))?;
            downloaded += n as u64;
            if let Some(ref callback) = progress_callback {
                let elapsed = start_time.elapsed().as_secs_f64();
                let speed = if elapsed > 0.0 {
                    (downloaded as f64 / elapsed) as u64
                } else {
                    0
                };
                let remaining = total.saturating_sub(downloaded);
                let eta = if speed > 0 { remaining / speed } else { 0 };
                callback(DownloadProgress {
                    downloaded,
                    total,
                    speed_bps: speed,
                    eta_seconds: eta,
                    filename: filename.to_string(),
                });
            }
        }
    })();
    // Close the handle before renaming or deleting; some platforms refuse to do either
    // on an open file.
    drop(file);

    let downloaded = match streamed {
        Ok(bytes) => bytes,
        Err(e) => {
            // Don't leave a partial (potentially multi-GB) file behind.
            let _ = fs::remove_file(&temp_path);
            return Err(e);
        }
    };

    if let Err(e) = fs::rename(&temp_path, dest) {
        let _ = fs::remove_file(&temp_path);
        return Err(MullamaError::IoError(e));
    }

    if let Some(callback) = progress_callback {
        callback(DownloadProgress {
            downloaded,
            total: if total > 0 { total } else { downloaded },
            speed_bps: 0,
            eta_seconds: 0,
            filename: filename.to_string(),
        });
    }
    Ok(())
}
/// Loads the model at `model_path` and runs a lightweight smoke test.
///
/// On success the result carries load time, parameter count, training context length,
/// embedding size, layer count and vocabulary size. A load or tokenization failure is
/// recorded in the result's `error` field, and the partially filled result is returned as
/// `Ok`. No path in this method returns `Err`.
///
/// NOTE(review): generation is not exercised here, yet `generation_works` is always set to
/// true on this path, even when tokenization produced no tokens. Callers should not read
/// that flag as evidence that generation succeeded; consider reporting it as false/unknown.
#[cfg(feature = "daemon")]
pub fn test_model(&self, model_path: &Path) -> Result<ModelTestResult, MullamaError> {
    use std::sync::Arc;
    use std::time::Instant;

    let mut result = ModelTestResult {
        load_success: false,
        load_time_ms: 0,
        tokenization_works: false,
        generation_works: false,
        sample_output: None,
        n_params: 0,
        n_ctx: 0,
        n_embd: 0,
        n_layers: 0,
        vocab_size: 0,
        error: None,
    };

    let started = Instant::now();
    let model = match Model::load(model_path) {
        Ok(loaded) => Arc::new(loaded),
        Err(e) => {
            result.error = Some(format!("Failed to load model: {}", e));
            return Ok(result);
        }
    };
    result.load_time_ms = started.elapsed().as_millis() as u64;
    result.load_success = true;

    result.n_params = model.n_params();
    result.n_ctx = model.n_ctx_train() as u32;
    result.n_embd = model.n_embd() as u32;
    result.n_layers = model.n_layer() as u32;
    result.vocab_size = model.vocab_size() as u32;

    let tokens = match model.tokenize("Hello, world!", true, false) {
        Ok(tokens) => tokens,
        Err(e) => {
            result.error = Some(format!("Tokenization failed: {}", e));
            return Ok(result);
        }
    };
    result.tokenization_works = !tokens.is_empty();

    result.generation_works = true;
    result.sample_output = Some("(generation test skipped)".to_string());
    Ok(result)
}
/// Returns up to `limit` GGUF repositories, ordered by download count.
///
/// This is a `search_models` call using the `gguf_only`, `sort_by_downloads` and
/// `with_limit` filters.
#[cfg(feature = "daemon")]
pub fn get_popular_gguf_models(&self, limit: usize) -> Result<Vec<HFModelInfo>, MullamaError> {
    self.search_models(
        &ModelSearchFilters::new()
            .gguf_only()
            .sort_by_downloads()
            .with_limit(limit),
    )
}
/// Searches for GGUF repositories related to `model_name` and returns at most 20 results.
///
/// The search text is `model_name` followed by `" GGUF"`; results are sorted by downloads.
#[cfg(feature = "daemon")]
pub fn find_gguf_versions(&self, model_name: &str) -> Result<Vec<HFModelInfo>, MullamaError> {
    let search_text = format!("{} GGUF", model_name);
    let filters = ModelSearchFilters::new()
        .with_query(&search_text)
        .gguf_only()
        .sort_by_downloads()
        .with_limit(20);
    self.search_models(&filters)
}
/// Performs a GET request and returns the response body as text.
///
/// The configured `User-Agent` is sent, plus a bearer `Authorization` header when a token is
/// set.
///
/// # Errors
/// `MullamaError::HuggingFaceError` in these cases:
/// - the client cannot be built;
/// - the request fails;
/// - the status is not a success (the message includes the status and, when readable, the
///   body);
/// - the body cannot be read.
#[cfg(feature = "daemon")]
fn http_get(&self, url: &str) -> Result<String, MullamaError> {
    let client = reqwest::blocking::Client::builder()
        .user_agent(&self.user_agent)
        .build()
        .map_err(|e| {
            MullamaError::HuggingFaceError(format!("Failed to create HTTP client: {}", e))
        })?;

    let request = match self.token.as_deref() {
        Some(token) => client
            .get(url)
            .header("Authorization", format!("Bearer {}", token)),
        None => client.get(url),
    };

    let response = request
        .send()
        .map_err(|e| MullamaError::HuggingFaceError(format!("HTTP request failed: {}", e)))?;

    let status = response.status();
    if status.is_success() {
        return response.text().map_err(|e| {
            MullamaError::HuggingFaceError(format!("Failed to read response: {}", e))
        });
    }

    let body = response
        .text()
        .unwrap_or_else(|_| "(no body)".to_string());
    Err(MullamaError::HuggingFaceError(format!(
        "HTTP request failed: {} - {}",
        status, body
    )))
}
/// Converts one JSON model object from the Hub API into an `HFModelInfo`.
///
/// The id is taken from `modelId`, or from `id` when `modelId` is absent. `None` is returned
/// when the chosen value is missing or is not a string.
///
/// The id is split at its first `/` into `author` and `name`; an id without `/` gets an empty
/// author. All other fields are optional and default to `None`, `0`, or an empty list.
/// `gguf_files` is always left empty here.
#[cfg(feature = "daemon")]
fn parse_model_info(&self, json: &serde_json::Value) -> Option<HFModelInfo> {
    let text = |key: &str| json.get(key).and_then(|v| v.as_str()).map(String::from);
    let count = |key: &str| json.get(key).and_then(|v| v.as_u64()).unwrap_or(0);

    let model_id = json
        .get("modelId")
        .or_else(|| json.get("id"))
        .and_then(|v| v.as_str())?
        .to_string();

    let (author, name) = match model_id.split_once('/') {
        Some((owner, rest)) => (owner.to_string(), rest.to_string()),
        None => (String::new(), model_id.clone()),
    };

    let tags = json
        .get("tags")
        .and_then(|t| t.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|v| v.as_str().map(String::from))
                .collect::<Vec<String>>()
        })
        .unwrap_or_default();

    Some(HFModelInfo {
        model_id,
        author,
        name,
        description: text("description"),
        downloads: count("downloads"),
        likes: count("likes"),
        tags,
        last_modified: text("lastModified"),
        gguf_files: Vec::new(),
        pipeline_tag: text("pipeline_tag"),
        license: text("license"),
    })
}
/// Returns the root directory that downloads are stored in and local models are listed from.
pub fn download_dir(&self) -> &Path {
    self.download_dir.as_path()
}
pub fn list_local_models(&self) -> Result<Vec<PathBuf>, MullamaError> {
let mut models = Vec::new();
if !self.download_dir.exists() {
return Ok(models);
}
for entry in fs::read_dir(&self.download_dir).map_err(MullamaError::IoError)? {
let entry = entry.map_err(MullamaError::IoError)?;
let path = entry.path();
if path.is_dir() {
for file_entry in fs::read_dir(&path).map_err(MullamaError::IoError)? {
let file_entry = file_entry.map_err(MullamaError::IoError)?;
let file_path = file_entry.path();
if file_path.extension().map(|e| e == "gguf").unwrap_or(false) {
models.push(file_path);
}
}
} else if path.extension().map(|e| e == "gguf").unwrap_or(false) {
models.push(path);
}
}
Ok(models)
}
/// Deletes the file at `model_path`; a file that is already gone counts as success.
///
/// The removal is attempted directly instead of checking `exists()` first. That closes the
/// race where the file disappears between the check and the delete. It also stops a failing
/// existence check (e.g. permission denied) from being reported as a successful delete.
///
/// # Errors
/// `MullamaError::IoError` for any removal failure other than "not found".
pub fn delete_local_model(&self, model_path: &Path) -> Result<(), MullamaError> {
    match fs::remove_file(model_path) {
        Ok(()) => Ok(()),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
        Err(e) => Err(MullamaError::IoError(e)),
    }
}
}
impl Default for HFClient {
fn default() -> Self {
Self::new()
}
}