1use std::path::PathBuf;
7
8use crate::error::{AudioError, AudioResult};
9
/// Registry that resolves model identifiers to on-disk model directories.
///
/// Models are first looked up under `cache_dir` (see
/// `LocalModelRegistry::model_path`); missing models are fetched from the
/// HuggingFace Hub when a download-capable feature is enabled.
#[derive(Debug, Clone)]
pub struct LocalModelRegistry {
    /// Directory searched for pre-placed model directories named `<org>--<name>`.
    cache_dir: PathBuf,
}
28
29impl Default for LocalModelRegistry {
30 fn default() -> Self {
31 let cache_dir = dirs_cache_dir().join("adk-audio/models");
32 Self { cache_dir }
33 }
34}
35
36impl LocalModelRegistry {
37 pub fn new(cache_dir: impl Into<PathBuf>) -> Self {
39 Self { cache_dir: cache_dir.into() }
40 }
41
42 pub async fn get_or_download(&self, model_id: &str) -> AudioResult<PathBuf> {
54 if model_id.is_empty() {
55 return Err(AudioError::ModelDownload {
56 model_id: model_id.to_string(),
57 message: "model_id cannot be empty".into(),
58 });
59 }
60
61 let local_path = self.cache_dir.join(model_id.replace('/', "--"));
63 if local_path.exists() {
64 tracing::debug!(model_id, path = %local_path.display(), "model found in local cache");
65 return Ok(local_path);
66 }
67
68 self.download_from_hub(model_id).await
70 }
71
72 pub fn cache_dir(&self) -> &PathBuf {
74 &self.cache_dir
75 }
76
77 pub fn model_path(&self, model_id: &str) -> PathBuf {
79 self.cache_dir.join(model_id.replace('/', "--"))
80 }
81
82 #[cfg(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts"))]
89 async fn download_from_hub(&self, model_id: &str) -> AudioResult<PathBuf> {
90 let model_id_owned = model_id.to_string();
91
92 tracing::info!(model_id, "downloading model from HuggingFace Hub (first run)");
93
94 let model_dir = tokio::task::spawn_blocking(move || Self::download_sync(&model_id_owned))
95 .await
96 .map_err(|e| AudioError::ModelDownload {
97 model_id: model_id.to_string(),
98 message: format!("download task panicked: {e}"),
99 })??;
100
101 tracing::info!(
102 model_id,
103 path = %model_dir.display(),
104 "model download complete"
105 );
106
107 Ok(model_dir)
108 }
109
110 #[cfg(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts"))]
112 fn download_sync(model_id: &str) -> AudioResult<PathBuf> {
113 use hf_hub::api::sync::Api;
114
115 let api = Api::new().map_err(|e| AudioError::ModelDownload {
116 model_id: model_id.to_string(),
117 message: format!("failed to create HuggingFace API client: {e}"),
118 })?;
119
120 let repo = api.model(model_id.to_string());
121
122 let repo_info = repo.info().map_err(|e| AudioError::ModelDownload {
124 model_id: model_id.to_string(),
125 message: format!("failed to fetch repo info: {e}"),
126 })?;
127
128 let siblings = repo_info.siblings;
129 if siblings.is_empty() {
130 return Err(AudioError::ModelDownload {
131 model_id: model_id.to_string(),
132 message: "repository has no files".into(),
133 });
134 }
135
136 tracing::info!(model_id, file_count = siblings.len(), "downloading model files");
137
138 let mut last_path: Option<PathBuf> = None;
140 for sibling in &siblings {
141 let filename = &sibling.rfilename;
142
143 if filename.starts_with(".git") {
146 continue;
147 }
148
149 tracing::debug!(model_id, file = %filename, "downloading");
150 let path = repo.get(filename).map_err(|e| AudioError::ModelDownload {
151 model_id: model_id.to_string(),
152 message: format!("failed to download {filename}: {e}"),
153 })?;
154 last_path = Some(path);
155 }
156
157 let model_dir =
163 last_path.as_ref().and_then(|p| Self::find_snapshot_root(p)).ok_or_else(|| {
164 AudioError::ModelDownload {
165 model_id: model_id.to_string(),
166 message: "could not determine model directory from downloaded files".into(),
167 }
168 })?;
169
170 Ok(model_dir)
171 }
172
173 #[cfg(not(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts")))]
175 async fn download_from_hub(&self, model_id: &str) -> AudioResult<PathBuf> {
176 let local_path = self.cache_dir.join(model_id.replace('/', "--"));
177 Err(AudioError::ModelDownload {
178 model_id: model_id.to_string(),
179 message: format!(
180 "model not cached and hf-hub feature not enabled. \
181 Either enable the `onnx` or `mlx` feature, or manually place \
182 model files at: {}",
183 local_path.display()
184 ),
185 })
186 }
187
188 #[cfg(any(feature = "onnx", feature = "mlx", feature = "qwen3-tts"))]
198 fn find_snapshot_root(file_path: &std::path::Path) -> Option<PathBuf> {
199 let mut current = file_path.parent()?;
200 loop {
201 if let Some(parent) = current.parent() {
202 if parent.file_name().and_then(|n| n.to_str()) == Some("snapshots") {
203 return Some(current.to_path_buf());
204 }
205 current = parent;
206 } else {
207 return file_path.parent().map(|p| p.to_path_buf());
209 }
210 }
211 }
212}
213
/// Returns the base directory for per-user caches.
///
/// The directory is resolved in this order:
/// 1. `$HOME/.cache`.
/// 2. `%USERPROFILE%/.cache`, for platforms such as Windows where `HOME` is usually
///    unset.
/// 3. A `.cache` directory relative to the current working directory.
///
/// Empty variables are treated as unset. The variables are read as raw `OsString`s,
/// so a non-UTF-8 home directory (legal on Unix) is still honoured instead of
/// silently falling back to the relative path.
fn dirs_cache_dir() -> PathBuf {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(std::env::var_os)
        .find(|value| !value.is_empty())
        .map(|home| PathBuf::from(home).join(".cache"))
        .unwrap_or_else(|| PathBuf::from(".cache"))
}