Skip to main content

adk_audio/
registry.rs

1//! Local model registry for downloading and caching model weights.
2//!
3//! Uses the `hf_hub` crate (when available) to download models from
4//! HuggingFace Hub on first use, caching them locally for subsequent runs.
5
6use std::path::PathBuf;
7
8use crate::error::{AudioError, AudioResult};
9
/// Registry for managing local model downloads and caching.
///
/// Supports both MLX-format (`.safetensors` + `config.json`) and
/// ONNX-format (`.onnx`) models from HuggingFace Hub.
///
/// Models already present under the registry's cache directory
/// (`<cache_dir>/<org>--<model>`) are returned without any network access.
/// Other models are downloaded via the HuggingFace Hub API, which requires
/// the `onnx`, `mlx`, or `qwen3-tts` feature (they pull in the `hf-hub`
/// dependency). `hf-hub` keeps its own on-disk cache, so model files are
/// not re-downloaded on later calls.
///
/// # Example
///
/// ```rust,ignore
/// let registry = LocalModelRegistry::default();
/// let path = registry.get_or_download("onnx-community/whisper-base").await?;
/// ```
#[derive(Debug, Clone)]
pub struct LocalModelRegistry {
    /// Root directory searched for manually placed / pre-cached models.
    cache_dir: PathBuf,
}
28
29impl Default for LocalModelRegistry {
30    fn default() -> Self {
31        let cache_dir = dirs_cache_dir().join("adk-audio/models");
32        Self { cache_dir }
33    }
34}
35
36impl LocalModelRegistry {
37    /// Create a registry with a custom cache directory.
38    pub fn new(cache_dir: impl Into<PathBuf>) -> Self {
39        Self { cache_dir: cache_dir.into() }
40    }
41
42    /// Get the local path for a model, downloading from HuggingFace Hub if not cached.
43    ///
44    /// The method first checks for a local cache directory at
45    /// `<cache_dir>/<org>--<model>`. If found, it returns immediately.
46    /// Otherwise it uses the `hf-hub` crate to download all repository
47    /// files into the HuggingFace cache and returns that path.
48    ///
49    /// # Errors
50    ///
51    /// Returns [`AudioError::ModelDownload`] if `model_id` is empty or
52    /// the download fails.
53    pub async fn get_or_download(&self, model_id: &str) -> AudioResult<PathBuf> {
54        if model_id.is_empty() {
55            return Err(AudioError::ModelDownload {
56                model_id: model_id.to_string(),
57                message: "model_id cannot be empty".into(),
58            });
59        }
60
61        // Check our own cache directory first
62        let local_path = self.cache_dir.join(model_id.replace('/', "--"));
63        if local_path.exists() {
64            tracing::debug!(model_id, path = %local_path.display(), "model found in local cache");
65            return Ok(local_path);
66        }
67
68        // Download via hf-hub
69        self.download_from_hub(model_id).await
70    }
71
72    /// Get the cache directory path.
73    pub fn cache_dir(&self) -> &PathBuf {
74        &self.cache_dir
75    }
76
77    /// Compute the local path for a model ID (without downloading).
78    pub fn model_path(&self, model_id: &str) -> PathBuf {
79        self.cache_dir.join(model_id.replace('/', "--"))
80    }
81
82    /// Download a model repository from HuggingFace Hub.
83    ///
84    /// Uses the `hf-hub` crate's sync API (wrapped in `spawn_blocking`
85    /// so we don't block the async runtime). The Hub API caches files
86    /// under `~/.cache/huggingface/hub/` by default; we return the
87    /// snapshot directory that contains all downloaded files.
88    #[cfg(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts"))]
89    async fn download_from_hub(&self, model_id: &str) -> AudioResult<PathBuf> {
90        let model_id_owned = model_id.to_string();
91
92        tracing::info!(model_id, "downloading model from HuggingFace Hub (first run)");
93
94        let model_dir = tokio::task::spawn_blocking(move || Self::download_sync(&model_id_owned))
95            .await
96            .map_err(|e| AudioError::ModelDownload {
97                model_id: model_id.to_string(),
98                message: format!("download task panicked: {e}"),
99            })??;
100
101        tracing::info!(
102            model_id,
103            path = %model_dir.display(),
104            "model download complete"
105        );
106
107        Ok(model_dir)
108    }
109
110    /// Synchronous download implementation using hf-hub.
111    #[cfg(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts"))]
112    fn download_sync(model_id: &str) -> AudioResult<PathBuf> {
113        use hf_hub::api::sync::Api;
114
115        let api = Api::new().map_err(|e| AudioError::ModelDownload {
116            model_id: model_id.to_string(),
117            message: format!("failed to create HuggingFace API client: {e}"),
118        })?;
119
120        let repo = api.model(model_id.to_string());
121
122        // Fetch the repo info to discover all files
123        let repo_info = repo.info().map_err(|e| AudioError::ModelDownload {
124            model_id: model_id.to_string(),
125            message: format!("failed to fetch repo info: {e}"),
126        })?;
127
128        let siblings = repo_info.siblings;
129        if siblings.is_empty() {
130            return Err(AudioError::ModelDownload {
131                model_id: model_id.to_string(),
132                message: "repository has no files".into(),
133            });
134        }
135
136        tracing::info!(model_id, file_count = siblings.len(), "downloading model files");
137
138        // Download each file — hf-hub handles caching and deduplication
139        let mut last_path: Option<PathBuf> = None;
140        for sibling in &siblings {
141            let filename = &sibling.rfilename;
142
143            // Skip very large files that aren't needed for inference
144            // (e.g. .git files, READMEs are fine to download)
145            if filename.starts_with(".git") {
146                continue;
147            }
148
149            tracing::debug!(model_id, file = %filename, "downloading");
150            let path = repo.get(filename).map_err(|e| AudioError::ModelDownload {
151                model_id: model_id.to_string(),
152                message: format!("failed to download {filename}: {e}"),
153            })?;
154            last_path = Some(path);
155        }
156
157        // The model directory is the snapshot root.
158        // hf-hub stores files under <cache>/models--<org>--<name>/snapshots/<rev>/
159        // Files may be in subdirectories (e.g. onnx/), so we walk up from the
160        // last downloaded file to find the snapshot root (the directory whose
161        // parent is named "snapshots").
162        let model_dir =
163            last_path.as_ref().and_then(|p| Self::find_snapshot_root(p)).ok_or_else(|| {
164                AudioError::ModelDownload {
165                    model_id: model_id.to_string(),
166                    message: "could not determine model directory from downloaded files".into(),
167                }
168            })?;
169
170        Ok(model_dir)
171    }
172
173    /// Fallback when hf-hub is not available.
174    #[cfg(not(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts")))]
175    async fn download_from_hub(&self, model_id: &str) -> AudioResult<PathBuf> {
176        let local_path = self.cache_dir.join(model_id.replace('/', "--"));
177        Err(AudioError::ModelDownload {
178            model_id: model_id.to_string(),
179            message: format!(
180                "model not cached and hf-hub feature not enabled. \
181                 Either enable the `onnx` or `mlx` feature, or manually place \
182                 model files at: {}",
183                local_path.display()
184            ),
185        })
186    }
187
188    /// Walk up from a file path to find the HuggingFace snapshot root directory.
189    ///
190    /// The snapshot root is the directory whose parent is named `"snapshots"`.
191    /// For example, given a path like:
192    /// `~/.cache/huggingface/hub/models--org--name/snapshots/abc123/onnx/model.onnx`
193    /// this returns `~/.cache/huggingface/hub/models--org--name/snapshots/abc123/`.
194    ///
195    /// Falls back to the immediate parent if no `snapshots` ancestor is found
196    /// (e.g. for locally cached models not from HuggingFace Hub).
197    #[cfg(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts"))]
198    fn find_snapshot_root(file_path: &std::path::Path) -> Option<PathBuf> {
199        let mut current = file_path.parent()?;
200        loop {
201            if let Some(parent) = current.parent() {
202                if parent.file_name().and_then(|n| n.to_str()) == Some("snapshots") {
203                    return Some(current.to_path_buf());
204                }
205                current = parent;
206            } else {
207                // No "snapshots" ancestor found — fall back to immediate parent
208                return file_path.parent().map(|p| p.to_path_buf());
209            }
210        }
211    }
212}
213
214/// Get the user's cache directory, falling back to current directory.
215fn dirs_cache_dir() -> PathBuf {
216    // Use HOME env var on macOS/Linux
217    std::env::var("HOME")
218        .map(|h| PathBuf::from(h).join(".cache"))
219        .unwrap_or_else(|_| PathBuf::from(".cache"))
220}