use std::path::PathBuf;
use crate::error::{AudioError, AudioResult};
/// Resolves model identifiers to local directories on disk.
///
/// Models are first looked up under a registry-owned cache directory; when
/// a download backend feature (`onnx`, `mlx` or `qwen3-tts`) is enabled,
/// missing models are fetched from the HuggingFace Hub.
#[derive(Debug, Clone)]
pub struct LocalModelRegistry {
    /// Directory searched for pre-placed models, laid out as
    /// `<cache_dir>/<model_id with '/' replaced by "--">`.
    cache_dir: PathBuf,
}
impl Default for LocalModelRegistry {
    /// Builds a registry rooted at `<user cache dir>/adk-audio/models`,
    /// where the user cache dir comes from `dirs_cache_dir()`.
    fn default() -> Self {
        Self::new(dirs_cache_dir().join("adk-audio/models"))
    }
}
impl LocalModelRegistry {
    /// Creates a registry that looks for cached models under `cache_dir`.
    pub fn new(cache_dir: impl Into<PathBuf>) -> Self {
        Self { cache_dir: cache_dir.into() }
    }

    /// Returns a local directory for `model_id`, downloading it if needed.
    ///
    /// The registry's cache directory is checked first (see
    /// [`Self::model_path`]); if nothing exists there, the model is fetched
    /// from the HuggingFace Hub (only when a download feature is enabled).
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::ModelDownload`] when `model_id` is empty or is
    /// `.`/`..`, when the download fails, or when the model is not cached and
    /// no download feature (`onnx`, `mlx`, `qwen3-tts`) is enabled.
    pub async fn get_or_download(&self, model_id: &str) -> AudioResult<PathBuf> {
        if model_id.is_empty() {
            return Err(AudioError::ModelDownload {
                model_id: model_id.to_string(),
                message: "model_id cannot be empty".into(),
            });
        }
        // `.` and `..` contain no '/', so `model_path` would resolve them to the
        // cache directory itself or its parent. Those exist whenever the cache
        // directory does, and would wrongly be returned as a "cached model".
        if model_id == "." || model_id == ".." {
            return Err(AudioError::ModelDownload {
                model_id: model_id.to_string(),
                message: "model_id cannot be `.` or `..`".into(),
            });
        }
        let local_path = self.model_path(model_id);
        if local_path.exists() {
            tracing::debug!(model_id, path = %local_path.display(), "model found in local cache");
            return Ok(local_path);
        }
        self.download_from_hub(model_id).await
    }

    /// Returns the directory searched for pre-placed models.
    pub fn cache_dir(&self) -> &PathBuf {
        &self.cache_dir
    }

    /// Returns where `model_id` is expected inside the cache directory.
    ///
    /// Every '/' in the ID (e.g. `org/name`) is replaced by `--`, so the
    /// model maps to a single path component. The filesystem is not touched.
    pub fn model_path(&self, model_id: &str) -> PathBuf {
        self.cache_dir.join(model_id.replace('/', "--"))
    }

    /// Downloads every file of `model_id` from the HuggingFace Hub.
    ///
    /// The blocking `hf_hub` client runs via `tokio::task::spawn_blocking`,
    /// so network I/O does not stall the async executor.
    ///
    /// NOTE(review): the returned directory is derived from where `hf_hub`
    /// stored the files, not from `self.cache_dir`. As a result,
    /// `get_or_download` will not find this model in `cache_dir` on the next
    /// call and will reach the Hub again. Confirm this is intended.
    #[cfg(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts"))]
    async fn download_from_hub(&self, model_id: &str) -> AudioResult<PathBuf> {
        let model_id_owned = model_id.to_string();
        tracing::info!(model_id, "downloading model from HuggingFace Hub (first run)");
        // The first `?` handles a panicked/cancelled blocking task (JoinError);
        // the second propagates the AudioError returned by `download_sync`.
        let model_dir = tokio::task::spawn_blocking(move || Self::download_sync(&model_id_owned))
            .await
            .map_err(|e| AudioError::ModelDownload {
                model_id: model_id.to_string(),
                message: format!("download task panicked: {e}"),
            })??;
        tracing::info!(
            model_id,
            path = %model_dir.display(),
            "model download complete"
        );
        Ok(model_dir)
    }

    /// Synchronously downloads all repository files of `model_id`.
    ///
    /// Files whose names start with `.git` (e.g. `.gitattributes`) are
    /// skipped. Returns the snapshot directory containing the downloaded
    /// files, as located by `find_snapshot_root`.
    #[cfg(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts"))]
    fn download_sync(model_id: &str) -> AudioResult<PathBuf> {
        use hf_hub::api::sync::Api;
        let api = Api::new().map_err(|e| AudioError::ModelDownload {
            model_id: model_id.to_string(),
            message: format!("failed to create HuggingFace API client: {e}"),
        })?;
        let repo = api.model(model_id.to_string());
        // NOTE(review): `info()` presumably always queries the Hub, so this
        // path likely fails offline even when files are already cached by
        // hf_hub. Verify against the hf_hub docs.
        let repo_info = repo.info().map_err(|e| AudioError::ModelDownload {
            model_id: model_id.to_string(),
            message: format!("failed to fetch repo info: {e}"),
        })?;
        let siblings = repo_info.siblings;
        if siblings.is_empty() {
            return Err(AudioError::ModelDownload {
                model_id: model_id.to_string(),
                message: "repository has no files".into(),
            });
        }
        tracing::info!(model_id, file_count = siblings.len(), "downloading model files");
        // Any one downloaded file is enough to locate the snapshot root;
        // keep the last one.
        let mut last_path: Option<PathBuf> = None;
        for sibling in &siblings {
            let filename = &sibling.rfilename;
            if filename.starts_with(".git") {
                continue;
            }
            tracing::debug!(model_id, file = %filename, "downloading");
            let path = repo.get(filename).map_err(|e| AudioError::ModelDownload {
                model_id: model_id.to_string(),
                message: format!("failed to download {filename}: {e}"),
            })?;
            last_path = Some(path);
        }
        let model_dir =
            last_path.as_ref().and_then(|p| Self::find_snapshot_root(p)).ok_or_else(|| {
                AudioError::ModelDownload {
                    model_id: model_id.to_string(),
                    message: "could not determine model directory from downloaded files".into(),
                }
            })?;
        Ok(model_dir)
    }

    /// Fallback used when no download feature is compiled in.
    ///
    /// Always fails with [`AudioError::ModelDownload`]. The error message
    /// tells the user which features enable downloads, and where to place
    /// the model files manually.
    #[cfg(not(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts")))]
    async fn download_from_hub(&self, model_id: &str) -> AudioResult<PathBuf> {
        let local_path = self.model_path(model_id);
        Err(AudioError::ModelDownload {
            model_id: model_id.to_string(),
            message: format!(
                "model not cached and hf-hub download support not enabled. \
                 Either enable the `onnx`, `mlx` or `qwen3-tts` feature, or manually place \
                 model files at: {}",
                local_path.display()
            ),
        })
    }

    /// Finds the snapshot directory that contains `file_path`.
    ///
    /// Walks up from the file's parent directory and returns the first
    /// directory whose own parent is named `snapshots`. This presumably
    /// matches hf_hub's `.../snapshots/<revision>/...` cache layout; TODO
    /// confirm. Falls back to the file's parent directory when no such
    /// ancestor exists. Returns `None` only when `file_path` has no parent.
    #[cfg(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts"))]
    fn find_snapshot_root(file_path: &std::path::Path) -> Option<PathBuf> {
        let snapshots = std::ffi::OsStr::new("snapshots");
        file_path
            .ancestors()
            // Skip the file itself; start from its containing directory.
            .skip(1)
            .find(|dir| dir.parent().and_then(|p| p.file_name()) == Some(snapshots))
            .or_else(|| file_path.parent())
            .map(|p| p.to_path_buf())
    }
}
/// Returns the base directory used for on-disk caches.
///
/// Resolves to `$HOME/.cache` when `HOME` is set. Otherwise returns the
/// relative path `.cache`, which resolves against the process's current
/// working directory.
fn dirs_cache_dir() -> PathBuf {
    // `var_os` rather than `var`: a HOME containing non-UTF-8 bytes is still a
    // valid path on Unix and must not silently fall back to a relative `.cache`.
    std::env::var_os("HOME")
        .map(|home| PathBuf::from(home).join(".cache"))
        .unwrap_or_else(|| PathBuf::from(".cache"))
}