Skip to main content

autoagents_speech/
model_source.rs

1use std::path::{Path, PathBuf};
2
// Environment variables consulted when talking to the HuggingFace Hub.
//
// They are only read by code compiled under the `model-hf` feature, so builds without that
// feature would otherwise report each of them as dead code; the conditional `allow` below
// silences exactly that case and nothing else.

/// Environment variable that overrides the HuggingFace Hub endpoint.
#[cfg_attr(not(feature = "model-hf"), allow(dead_code))]
const HF_ENDPOINT_ENV: &str = "HF_ENDPOINT";
/// Access-token environment variable that is checked first.
#[cfg_attr(not(feature = "model-hf"), allow(dead_code))]
const HUGGINGFACE_HUB_TOKEN_ENV: &str = "HUGGINGFACE_HUB_TOKEN";
/// Access-token environment variable that is checked second.
#[cfg_attr(not(feature = "model-hf"), allow(dead_code))]
const HF_TOKEN_ENV: &str = "HF_TOKEN";
/// Access-token environment variable that is checked last.
#[cfg_attr(not(feature = "model-hf"), allow(dead_code))]
const HUGGINGFACE_TOKEN_ENV: &str = "HUGGINGFACE_TOKEN";
7
/// Source for loading a model from disk or HuggingFace.
///
/// Values are built with [`ModelSource::from_file`], [`ModelSource::from_hf`] or
/// [`ModelSource::from_hf_dir`] and turned into a path on disk with [`ModelSource::resolve`].
/// The type is comparable and hashable, so it can serve as a key in maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelSource {
    /// Which kind of source this is, together with its location data.
    kind: ModelSourceKind,
}

/// Internal representation of the places a model can be loaded from.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum ModelSourceKind {
    /// A model file on the local filesystem.
    File {
        /// Path to the model file, stored exactly as supplied by the caller.
        path: PathBuf,
    },
    /// A single file inside a HuggingFace repository.
    HuggingFace {
        /// Repository identifier, e.g. `org/model`.
        repo_id: String,
        /// Name of the file inside the repository.
        filename: String,
        /// Optional revision; `None` leaves the choice to the resolver.
        revision: Option<String>,
    },
    /// Every file below a directory prefix inside a HuggingFace repository.
    HuggingFaceDir {
        /// Repository identifier, e.g. `org/model`.
        repo_id: String,
        /// Directory prefix inside the repository.
        directory: String,
        /// Optional revision; `None` leaves the choice to the resolver.
        revision: Option<String>,
    },
}
30
31impl ModelSource {
32    /// Create a source backed by a local model file.
33    pub fn from_file(path: impl Into<PathBuf>) -> Self {
34        Self {
35            kind: ModelSourceKind::File { path: path.into() },
36        }
37    }
38
39    /// Create a source backed by a HuggingFace repo + filename.
40    pub fn from_hf(repo_id: impl Into<String>, filename: impl Into<String>) -> Self {
41        Self {
42            kind: ModelSourceKind::HuggingFace {
43                repo_id: repo_id.into(),
44                filename: filename.into(),
45                revision: None,
46            },
47        }
48    }
49
50    /// Create a source backed by a HuggingFace repo + directory prefix.
51    pub fn from_hf_dir(repo_id: impl Into<String>, directory: impl Into<String>) -> Self {
52        Self {
53            kind: ModelSourceKind::HuggingFaceDir {
54                repo_id: repo_id.into(),
55                directory: directory.into(),
56                revision: None,
57            },
58        }
59    }
60
61    /// Set the HuggingFace revision (branch, tag, or commit SHA).
62    pub fn with_revision(mut self, revision: impl Into<String>) -> Self {
63        match &mut self.kind {
64            ModelSourceKind::HuggingFace { revision: slot, .. }
65            | ModelSourceKind::HuggingFaceDir { revision: slot, .. } => {
66                *slot = Some(revision.into());
67            }
68            _ => {}
69        }
70        self
71    }
72
73    /// Resolve the model path, downloading if necessary.
74    pub fn resolve(&self) -> Result<PathBuf, ModelSourceError> {
75        match &self.kind {
76            ModelSourceKind::File { path } => {
77                if path.is_file() {
78                    Ok(path.clone())
79                } else {
80                    Err(ModelSourceError::MissingLocalFile(path.clone()))
81                }
82            }
83            ModelSourceKind::HuggingFace {
84                repo_id,
85                filename,
86                revision,
87            } => resolve_hf(repo_id, filename, revision.as_deref()),
88            ModelSourceKind::HuggingFaceDir {
89                repo_id,
90                directory,
91                revision,
92            } => resolve_hf_dir(repo_id, directory, revision.as_deref()),
93        }
94    }
95
96    /// Return the local path when the source is a file.
97    pub fn local_path(&self) -> Option<&Path> {
98        match &self.kind {
99            ModelSourceKind::File { path } => Some(path.as_path()),
100            _ => None,
101        }
102    }
103
104    /// Return the HuggingFace repo ID if applicable.
105    pub fn repo_id(&self) -> Option<&str> {
106        match &self.kind {
107            ModelSourceKind::HuggingFace { repo_id, .. }
108            | ModelSourceKind::HuggingFaceDir { repo_id, .. } => Some(repo_id.as_str()),
109            _ => None,
110        }
111    }
112
113    /// Return the HuggingFace filename if applicable.
114    pub fn filename(&self) -> Option<&str> {
115        match &self.kind {
116            ModelSourceKind::HuggingFace { filename, .. } => Some(filename.as_str()),
117            _ => None,
118        }
119    }
120
121    /// Return the HuggingFace directory prefix if applicable.
122    pub fn directory(&self) -> Option<&str> {
123        match &self.kind {
124            ModelSourceKind::HuggingFaceDir { directory, .. } => Some(directory.as_str()),
125            _ => None,
126        }
127    }
128}
129
130#[derive(Debug, thiserror::Error)]
131pub enum ModelSourceError {
132    #[error("Model file not found: {0}")]
133    MissingLocalFile(PathBuf),
134    #[error("HuggingFace support is not enabled; enable the `model-hf` feature")]
135    HuggingFaceDisabled,
136    #[error("HuggingFace download failed: {0}")]
137    HuggingFaceDownload(String),
138    #[error("HuggingFace repo id is required")]
139    MissingRepoId,
140    #[error("HuggingFace filename is required")]
141    MissingFilename,
142    #[error("HuggingFace directory is required")]
143    MissingDirectory,
144}
145
/// Build a synchronous HuggingFace Hub handle for `repo_id` at `revision`.
///
/// The client uses the cache location taken from the environment, the endpoint from
/// `HF_ENDPOINT` when that variable holds a non-blank value, and the token returned by
/// `hf_token()` when there is one. A missing revision falls back to `main`.
///
/// # Errors
///
/// Returns [`ModelSourceError::HuggingFaceDownload`] when the client cannot be built.
#[cfg(feature = "model-hf")]
fn hf_api_repo(
    repo_id: &str,
    revision: Option<&str>,
) -> Result<hf_hub::api::sync::ApiRepo, ModelSourceError> {
    use hf_hub::api::sync::ApiBuilder;
    use hf_hub::{Cache, Repo, RepoType};

    let mut builder = ApiBuilder::from_cache(Cache::from_env());
    // A blank endpoint cannot name a server, so it is treated the same as an unset variable
    // instead of being forwarded and breaking every request URL.
    if let Some(endpoint) = std::env::var(HF_ENDPOINT_ENV)
        .ok()
        .filter(|value| !value.trim().is_empty())
    {
        builder = builder.with_endpoint(endpoint);
    }
    if let Some(token) = hf_token() {
        builder = builder.with_token(Some(token));
    }
    let api = builder
        .build()
        .map_err(|err| ModelSourceError::HuggingFaceDownload(err.to_string()))?;

    let repo = Repo::with_revision(
        repo_id.to_string(),
        RepoType::Model,
        revision.unwrap_or("main").to_string(),
    );
    Ok(api.repo(repo))
}

/// Fetch a single file from a HuggingFace model repo and return its local path.
///
/// # Errors
///
/// - [`ModelSourceError::MissingRepoId`] when `repo_id` is empty.
/// - [`ModelSourceError::MissingFilename`] when `filename` is empty.
/// - [`ModelSourceError::HuggingFaceDownload`] when the client cannot be built or the
///   file cannot be fetched.
#[cfg(feature = "model-hf")]
fn resolve_hf(
    repo_id: &str,
    filename: &str,
    revision: Option<&str>,
) -> Result<PathBuf, ModelSourceError> {
    if repo_id.is_empty() {
        return Err(ModelSourceError::MissingRepoId);
    }
    if filename.is_empty() {
        return Err(ModelSourceError::MissingFilename);
    }

    hf_api_repo(repo_id, revision)?
        .get(filename)
        .map_err(|err| ModelSourceError::HuggingFaceDownload(err.to_string()))
}

/// Fetch every repo file below `directory` and return the matching local directory.
///
/// The repo's file listing is filtered by the `directory/` prefix, each matching file is
/// fetched, and the local directory is derived from the path of the first fetched file.
/// Trailing slashes on `directory` are ignored.
///
/// # Errors
///
/// - [`ModelSourceError::MissingRepoId`] when `repo_id` is empty.
/// - [`ModelSourceError::MissingDirectory`] when `directory` is empty (or only slashes),
///   or when no file in the repo lives under it.
/// - [`ModelSourceError::HuggingFaceDownload`] when the client cannot be built, the repo
///   listing cannot be read, or a file cannot be fetched.
#[cfg(feature = "model-hf")]
fn resolve_hf_dir(
    repo_id: &str,
    directory: &str,
    revision: Option<&str>,
) -> Result<PathBuf, ModelSourceError> {
    if repo_id.is_empty() {
        return Err(ModelSourceError::MissingRepoId);
    }
    // Normalise `weights`, `weights/` and `weights//` to the same prefix; a directory made
    // only of slashes names nothing and is rejected before any network traffic.
    let directory = directory.trim_end_matches('/');
    if directory.is_empty() {
        return Err(ModelSourceError::MissingDirectory);
    }
    let prefix = format!("{directory}/");

    let api_repo = hf_api_repo(repo_id, revision)?;
    let info = api_repo
        .info()
        .map_err(|err| ModelSourceError::HuggingFaceDownload(err.to_string()))?;

    // Stays `None` until the first matching file has been fetched, which doubles as the
    // "did anything match?" signal.
    let mut local_dir: Option<PathBuf> = None;

    for sibling in info.siblings {
        let rfilename = sibling.rfilename;
        if !rfilename.starts_with(&prefix) {
            continue;
        }
        let path = api_repo
            .get(&rfilename)
            .map_err(|err| ModelSourceError::HuggingFaceDownload(err.to_string()))?;

        if local_dir.is_none() {
            local_dir = Some(derive_directory(&path, &prefix, &rfilename));
        }
    }

    local_dir.ok_or(ModelSourceError::MissingDirectory)
}
247
248#[cfg(not(feature = "model-hf"))]
249fn resolve_hf_dir(
250    _repo_id: &str,
251    _directory: &str,
252    _revision: Option<&str>,
253) -> Result<PathBuf, ModelSourceError> {
254    Err(ModelSourceError::HuggingFaceDisabled)
255}
256
/// Derive the local directory that corresponds to a repo directory prefix.
///
/// `path` is the local path of a fetched file, `rfilename` is that file's name inside the
/// repo, and `directory` is the repo directory prefix. The result is `path` with as many
/// trailing components removed as `rfilename` has beyond `directory`.
#[cfg(feature = "model-hf")]
fn derive_directory(path: &Path, directory: &str, rfilename: &str) -> PathBuf {
    let prefix_depth = Path::new(directory).components().count();
    let file_depth = Path::new(rfilename).components().count();
    let levels_up = file_depth.saturating_sub(prefix_depth);

    // `nth` runs out when asked to climb above the top-most ancestor; in that case stop at
    // the top-most ancestor, which is where repeated `PathBuf::pop` calls would also stop.
    path.ancestors()
        .nth(levels_up)
        .or_else(|| path.ancestors().last())
        .unwrap_or(path)
        .to_path_buf()
}
270
271#[cfg(not(feature = "model-hf"))]
272fn resolve_hf(
273    _repo_id: &str,
274    _filename: &str,
275    _revision: Option<&str>,
276) -> Result<PathBuf, ModelSourceError> {
277    Err(ModelSourceError::HuggingFaceDisabled)
278}
279
/// Look up a HuggingFace access token in the environment.
///
/// The variables are tried in order: `HUGGINGFACE_HUB_TOKEN`, `HF_TOKEN`, then
/// `HUGGINGFACE_TOKEN`. Surrounding whitespace is trimmed, and a variable that is unset,
/// not valid Unicode, or blank after trimming is skipped so that a later variable can still
/// supply the token. Returns `None` when none of them holds a usable value.
#[cfg(feature = "model-hf")]
fn hf_token() -> Option<String> {
    [HUGGINGFACE_HUB_TOKEN_ENV, HF_TOKEN_ENV, HUGGINGFACE_TOKEN_ENV]
        .iter()
        // The chain is lazy: a later variable is only read when the earlier ones were
        // rejected.
        .filter_map(|name| std::env::var(name).ok())
        .map(|value| value.trim().to_owned())
        .find(|token| !token.is_empty())
}
287
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// A file source exposes its path and none of the HuggingFace accessors.
    #[test]
    fn from_file_tracks_path() {
        let source = ModelSource::from_file("model.onnx");
        assert_eq!(source.local_path(), Some(Path::new("model.onnx")));
        assert!(source.repo_id().is_none());
        assert!(source.filename().is_none());
        assert!(source.directory().is_none());
    }

    /// Resolving a path that does not exist reports that exact path.
    #[test]
    fn resolve_missing_file_returns_error() {
        let source = ModelSource::from_file("missing.onnx");
        match source.resolve() {
            Err(ModelSourceError::MissingLocalFile(path)) => {
                assert_eq!(path, PathBuf::from("missing.onnx"));
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    /// Resolving an existing file hands back the same path.
    #[test]
    fn resolve_existing_file() {
        let mut file = tempfile::NamedTempFile::new().unwrap();
        writeln!(file, "test").unwrap();
        let path = file.path().to_path_buf();

        let source = ModelSource::from_file(&path);
        let resolved = source.resolve().unwrap();
        assert_eq!(resolved, path);
    }

    /// A single-file HuggingFace source exposes repo and filename, but no local path.
    #[test]
    fn from_hf_tracks_repo_and_filename() {
        let source = ModelSource::from_hf("org/model", "model.onnx");
        assert_eq!(source.repo_id(), Some("org/model"));
        assert_eq!(source.filename(), Some("model.onnx"));
        assert!(source.directory().is_none());
        assert!(source.local_path().is_none());
    }

    /// A directory HuggingFace source exposes repo and directory, but no filename.
    #[test]
    fn from_hf_dir_tracks_repo_and_directory() {
        let source = ModelSource::from_hf_dir("org/model", "weights");
        assert_eq!(source.repo_id(), Some("org/model"));
        assert_eq!(source.directory(), Some("weights"));
        assert!(source.filename().is_none());
        assert!(source.local_path().is_none());
    }

    /// `with_revision` stores the revision on both HuggingFace variants.
    #[test]
    fn with_revision_sets_revision_on_hf_sources() {
        let file_source = ModelSource::from_hf("org/model", "model.onnx").with_revision("v1");
        assert_eq!(
            file_source.kind,
            ModelSourceKind::HuggingFace {
                repo_id: "org/model".to_string(),
                filename: "model.onnx".to_string(),
                revision: Some("v1".to_string()),
            }
        );

        let dir_source = ModelSource::from_hf_dir("org/model", "weights").with_revision("v1");
        assert_eq!(
            dir_source.kind,
            ModelSourceKind::HuggingFaceDir {
                repo_id: "org/model".to_string(),
                directory: "weights".to_string(),
                revision: Some("v1".to_string()),
            }
        );
    }

    /// `with_revision` leaves a local file source untouched.
    #[test]
    fn with_revision_is_noop_for_local_files() {
        let plain = ModelSource::from_file("model.onnx");
        let revised = ModelSource::from_file("model.onnx").with_revision("v1");
        assert_eq!(revised, plain);
    }

    /// Without the `model-hf` feature, single-file sources cannot be resolved.
    #[test]
    #[cfg(not(feature = "model-hf"))]
    fn resolve_hf_requires_feature() {
        let source = ModelSource::from_hf("org/model", "model.onnx");
        let result = source.resolve();
        assert!(
            matches!(result, Err(ModelSourceError::HuggingFaceDisabled)),
            "unexpected result: {result:?}"
        );
    }

    /// Without the `model-hf` feature, directory sources cannot be resolved.
    #[test]
    #[cfg(not(feature = "model-hf"))]
    fn resolve_hf_dir_requires_feature() {
        let source = ModelSource::from_hf_dir("org/model", "weights");
        let result = source.resolve();
        assert!(
            matches!(result, Err(ModelSourceError::HuggingFaceDisabled)),
            "unexpected result: {result:?}"
        );
    }
}