1use std::path::{Path, PathBuf};
2
/// Name of the environment variable whose value, when set, is passed to the
/// HuggingFace API builder as the endpoint to use.
const HF_ENDPOINT_ENV: &str = "HF_ENDPOINT";
/// Name of the first environment variable `hf_token` consults for an access token.
const HUGGINGFACE_HUB_TOKEN_ENV: &str = "HUGGINGFACE_HUB_TOKEN";
/// Name of the second environment variable `hf_token` consults for an access token.
const HF_TOKEN_ENV: &str = "HF_TOKEN";
/// Name of the last environment variable `hf_token` consults for an access token.
const HUGGINGFACE_TOKEN_ENV: &str = "HUGGINGFACE_TOKEN";
7
/// Describes where a model is obtained from.
///
/// A source is either a file on the local filesystem, a single file inside a
/// HuggingFace repository, or a directory inside a HuggingFace repository.
/// Values are created through the `from_*` constructors and turned into a
/// local path with `resolve`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelSource {
    // The concrete variant is kept private so callers go through the
    // constructors and accessor methods.
    kind: ModelSourceKind,
}
13
/// Internal representation of the different places a model can come from.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ModelSourceKind {
    /// A file that already lives on the local filesystem.
    File {
        /// Path to the file, stored exactly as supplied by the caller.
        path: PathBuf,
    },
    /// A single file hosted in a HuggingFace repository.
    HuggingFace {
        /// Repository identifier, e.g. `org/model`.
        repo_id: String,
        /// Name of the file inside the repository.
        filename: String,
        /// Revision to fetch; `None` makes resolution fall back to `"main"`.
        revision: Option<String>,
    },
    /// A directory hosted in a HuggingFace repository.
    HuggingFaceDir {
        /// Repository identifier, e.g. `org/model`.
        repo_id: String,
        /// Directory inside the repository whose files are fetched.
        directory: String,
        /// Revision to fetch; `None` makes resolution fall back to `"main"`.
        revision: Option<String>,
    },
}
30
31impl ModelSource {
32 pub fn from_file(path: impl Into<PathBuf>) -> Self {
34 Self {
35 kind: ModelSourceKind::File { path: path.into() },
36 }
37 }
38
39 pub fn from_hf(repo_id: impl Into<String>, filename: impl Into<String>) -> Self {
41 Self {
42 kind: ModelSourceKind::HuggingFace {
43 repo_id: repo_id.into(),
44 filename: filename.into(),
45 revision: None,
46 },
47 }
48 }
49
50 pub fn from_hf_dir(repo_id: impl Into<String>, directory: impl Into<String>) -> Self {
52 Self {
53 kind: ModelSourceKind::HuggingFaceDir {
54 repo_id: repo_id.into(),
55 directory: directory.into(),
56 revision: None,
57 },
58 }
59 }
60
61 pub fn with_revision(mut self, revision: impl Into<String>) -> Self {
63 match &mut self.kind {
64 ModelSourceKind::HuggingFace { revision: slot, .. }
65 | ModelSourceKind::HuggingFaceDir { revision: slot, .. } => {
66 *slot = Some(revision.into());
67 }
68 _ => {}
69 }
70 self
71 }
72
73 pub fn resolve(&self) -> Result<PathBuf, ModelSourceError> {
75 match &self.kind {
76 ModelSourceKind::File { path } => {
77 if path.is_file() {
78 Ok(path.clone())
79 } else {
80 Err(ModelSourceError::MissingLocalFile(path.clone()))
81 }
82 }
83 ModelSourceKind::HuggingFace {
84 repo_id,
85 filename,
86 revision,
87 } => resolve_hf(repo_id, filename, revision.as_deref()),
88 ModelSourceKind::HuggingFaceDir {
89 repo_id,
90 directory,
91 revision,
92 } => resolve_hf_dir(repo_id, directory, revision.as_deref()),
93 }
94 }
95
96 pub fn local_path(&self) -> Option<&Path> {
98 match &self.kind {
99 ModelSourceKind::File { path } => Some(path.as_path()),
100 _ => None,
101 }
102 }
103
104 pub fn repo_id(&self) -> Option<&str> {
106 match &self.kind {
107 ModelSourceKind::HuggingFace { repo_id, .. }
108 | ModelSourceKind::HuggingFaceDir { repo_id, .. } => Some(repo_id.as_str()),
109 _ => None,
110 }
111 }
112
113 pub fn filename(&self) -> Option<&str> {
115 match &self.kind {
116 ModelSourceKind::HuggingFace { filename, .. } => Some(filename.as_str()),
117 _ => None,
118 }
119 }
120
121 pub fn directory(&self) -> Option<&str> {
123 match &self.kind {
124 ModelSourceKind::HuggingFaceDir { directory, .. } => Some(directory.as_str()),
125 _ => None,
126 }
127 }
128}
129
/// Errors produced while resolving a `ModelSource` to a local path.
#[derive(Debug, thiserror::Error)]
pub enum ModelSourceError {
    /// A local file source does not point at an existing regular file; carries
    /// the path that was checked.
    #[error("Model file not found: {0}")]
    MissingLocalFile(PathBuf),
    /// A HuggingFace source was resolved in a build without the `model-hf` feature.
    #[error("HuggingFace support is not enabled; enable the `model-hf` feature")]
    HuggingFaceDisabled,
    /// Building the HuggingFace API client, querying repository info, or
    /// fetching a file failed; carries the underlying error rendered as text.
    #[error("HuggingFace download failed: {0}")]
    HuggingFaceDownload(String),
    /// The repository id of a HuggingFace source was empty.
    #[error("HuggingFace repo id is required")]
    MissingRepoId,
    /// The file name of a HuggingFace file source was empty.
    #[error("HuggingFace filename is required")]
    MissingFilename,
    /// The directory of a HuggingFace directory source was empty, or no file
    /// in the repository listing lies under that directory.
    #[error("HuggingFace directory is required")]
    MissingDirectory,
}
145
/// Fetches a single file from a HuggingFace repository and returns the local
/// path reported by the `hf_hub` client.
///
/// * `repo_id` - repository identifier; must not be empty.
/// * `filename` - file inside the repository; must not be empty.
/// * `revision` - revision to fetch; `None` selects `"main"`.
///
/// The endpoint is taken from `HF_ENDPOINT` when that variable is set, and an
/// access token is attached when `hf_token` finds one.
///
/// # Errors
///
/// * [`ModelSourceError::MissingRepoId`] / [`ModelSourceError::MissingFilename`]
///   for empty arguments.
/// * [`ModelSourceError::HuggingFaceDownload`] when the client cannot be built
///   or the file cannot be fetched.
#[cfg(feature = "model-hf")]
fn resolve_hf(
    repo_id: &str,
    filename: &str,
    revision: Option<&str>,
) -> Result<PathBuf, ModelSourceError> {
    use hf_hub::api::sync::ApiBuilder;
    use hf_hub::{Cache, Repo, RepoType};

    // Every failure reported by the client is surfaced as its display text.
    fn download_failed<E: std::fmt::Display>(err: E) -> ModelSourceError {
        ModelSourceError::HuggingFaceDownload(err.to_string())
    }

    if repo_id.is_empty() {
        return Err(ModelSourceError::MissingRepoId);
    }
    if filename.is_empty() {
        return Err(ModelSourceError::MissingFilename);
    }

    let builder = ApiBuilder::from_cache(Cache::from_env());
    let builder = match std::env::var(HF_ENDPOINT_ENV) {
        Ok(endpoint) => builder.with_endpoint(endpoint),
        Err(_) => builder,
    };
    let builder = match hf_token() {
        Some(token) => builder.with_token(Some(token)),
        None => builder,
    };
    let api = builder.build().map_err(download_failed)?;

    let repo = Repo::with_revision(
        repo_id.to_string(),
        RepoType::Model,
        revision.unwrap_or("main").to_string(),
    );
    api.repo(repo).get(filename).map_err(download_failed)
}
181
/// Fetches every file under `directory` of a HuggingFace repository and
/// returns the local directory that mirrors it.
///
/// * `repo_id` - repository identifier; must not be empty.
/// * `directory` - directory inside the repository; trailing slashes are
///   ignored, and the remainder must not be empty.
/// * `revision` - revision to fetch; `None` selects `"main"`.
///
/// The endpoint is taken from `HF_ENDPOINT` when that variable is set, and an
/// access token is attached when `hf_token` finds one.
///
/// # Errors
///
/// * [`ModelSourceError::MissingRepoId`] when `repo_id` is empty.
/// * [`ModelSourceError::MissingDirectory`] when `directory` is empty (or only
///   slashes), or when no file in the repository listing lies under it.
/// * [`ModelSourceError::HuggingFaceDownload`] when the client cannot be
///   built, the repository listing cannot be read, or a file cannot be fetched.
#[cfg(feature = "model-hf")]
fn resolve_hf_dir(
    repo_id: &str,
    directory: &str,
    revision: Option<&str>,
) -> Result<PathBuf, ModelSourceError> {
    use hf_hub::api::sync::ApiBuilder;
    use hf_hub::{Cache, Repo, RepoType};

    if repo_id.is_empty() {
        return Err(ModelSourceError::MissingRepoId);
    }
    // Normalise "weights", "weights/" and "weights//" to the same prefix, and
    // reject a directory that names nothing before touching the network.
    let directory = directory.trim_end_matches('/');
    if directory.is_empty() {
        return Err(ModelSourceError::MissingDirectory);
    }
    let prefix = format!("{directory}/");

    let cache = Cache::from_env();
    let mut api_builder = ApiBuilder::from_cache(cache);
    if let Ok(endpoint) = std::env::var(HF_ENDPOINT_ENV) {
        api_builder = api_builder.with_endpoint(endpoint);
    }
    if let Some(token) = hf_token() {
        api_builder = api_builder.with_token(Some(token));
    }
    let api = api_builder
        .build()
        .map_err(|err| ModelSourceError::HuggingFaceDownload(err.to_string()))?;
    let revision = revision.unwrap_or("main");
    let repo = Repo::with_revision(repo_id.to_string(), RepoType::Model, revision.to_string());
    let api_repo = api.repo(repo);
    let info = api_repo
        .info()
        .map_err(|err| ModelSourceError::HuggingFaceDownload(err.to_string()))?;

    // `local_dir` doubles as the "found at least one file" marker: it is set
    // from the first fetched file and stays `None` when nothing matched.
    let mut local_dir: Option<PathBuf> = None;
    for sibling in info.siblings {
        let rfilename = sibling.rfilename;
        if !rfilename.starts_with(&prefix) {
            continue;
        }
        let fetched = api_repo
            .get(&rfilename)
            .map_err(|err| ModelSourceError::HuggingFaceDownload(err.to_string()))?;

        if local_dir.is_none() {
            local_dir = Some(derive_directory(&fetched, &prefix, &rfilename));
        }
    }

    local_dir.ok_or(ModelSourceError::MissingDirectory)
}
247
248#[cfg(not(feature = "model-hf"))]
249fn resolve_hf_dir(
250 _repo_id: &str,
251 _directory: &str,
252 _revision: Option<&str>,
253) -> Result<PathBuf, ModelSourceError> {
254 Err(ModelSourceError::HuggingFaceDisabled)
255}
256
/// Works out the local directory that corresponds to `directory`, given the
/// local `path` of one file and that file's repository-relative name.
///
/// The file name has at least as many path components as the directory it
/// lives under; the difference is how many trailing components of `path`
/// have to be dropped. When `directory` is deeper than `rfilename`, nothing
/// is dropped and `path` is returned as is.
#[cfg(feature = "model-hf")]
fn derive_directory(path: &Path, directory: &str, rfilename: &str) -> PathBuf {
    let depth_of = |raw: &str| Path::new(raw).components().count();
    let levels_up = depth_of(rfilename).saturating_sub(depth_of(directory));

    // `ancestors` yields `path` itself first, so taking `levels_up + 1` items
    // and keeping the last one walks up at most `levels_up` parents and
    // simply stops at the topmost ancestor if there are fewer.
    path.ancestors()
        .take(levels_up + 1)
        .last()
        .unwrap_or(path)
        .to_path_buf()
}
270
271#[cfg(not(feature = "model-hf"))]
272fn resolve_hf(
273 _repo_id: &str,
274 _filename: &str,
275 _revision: Option<&str>,
276) -> Result<PathBuf, ModelSourceError> {
277 Err(ModelSourceError::HuggingFaceDisabled)
278}
279
/// Looks up a HuggingFace access token in the environment.
///
/// The variables are consulted in this order: `HUGGINGFACE_HUB_TOKEN`,
/// `HF_TOKEN`, `HUGGINGFACE_TOKEN`. Surrounding whitespace is removed from
/// each value, and a variable that is unset, not valid Unicode, or blank is
/// skipped so that an empty export cannot hide a usable token in a later
/// variable or be sent as an empty credential.
///
/// Returns the first usable token, or `None` when there is none.
#[cfg(feature = "model-hf")]
fn hf_token() -> Option<String> {
    [HUGGINGFACE_HUB_TOKEN_ENV, HF_TOKEN_ENV, HUGGINGFACE_TOKEN_ENV]
        .iter()
        // The chain is lazy: later variables are only read when the earlier
        // ones did not yield a usable token.
        .filter_map(|name| std::env::var(name).ok())
        .map(|raw| raw.trim().to_owned())
        .find(|token| !token.is_empty())
}
287
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// A file source exposes its path and has no repository id.
    #[test]
    fn from_file_tracks_path() {
        let local = ModelSource::from_file("model.onnx");
        assert_eq!(local.local_path(), Some(Path::new("model.onnx")));
        assert!(local.repo_id().is_none());
    }

    /// Resolving a path that does not exist reports that very path.
    #[test]
    fn resolve_missing_file_returns_error() {
        let other = ModelSource::from_file("missing.onnx")
            .resolve()
            .unwrap_err();
        let ModelSourceError::MissingLocalFile(path) = &other else {
            panic!("unexpected error: {other:?}");
        };
        assert_eq!(*path, PathBuf::from("missing.onnx"));
    }

    /// Resolving an existing file hands back the same path.
    #[test]
    fn resolve_existing_file() {
        let mut scratch = tempfile::NamedTempFile::new().unwrap();
        writeln!(scratch, "test").unwrap();
        let expected = scratch.path().to_path_buf();

        let resolved = ModelSource::from_file(&expected).resolve().unwrap();
        assert_eq!(resolved, expected);
    }

    /// A single-file HuggingFace source remembers repository and file name.
    #[test]
    fn from_hf_tracks_repo_and_filename() {
        let remote = ModelSource::from_hf("org/model", "model.onnx");
        assert_eq!(remote.repo_id(), Some("org/model"));
        assert_eq!(remote.filename(), Some("model.onnx"));
    }

    /// A directory HuggingFace source remembers repository and directory,
    /// and has no file name.
    #[test]
    fn from_hf_dir_tracks_repo_and_directory() {
        let remote = ModelSource::from_hf_dir("org/model", "weights");
        assert_eq!(remote.repo_id(), Some("org/model"));
        assert_eq!(remote.directory(), Some("weights"));
        assert!(remote.filename().is_none());
    }

    /// Without the `model-hf` feature, file sources cannot be resolved.
    #[test]
    #[cfg(not(feature = "model-hf"))]
    fn resolve_hf_requires_feature() {
        let other = ModelSource::from_hf("org/model", "model.onnx")
            .resolve()
            .unwrap_err();
        if !matches!(other, ModelSourceError::HuggingFaceDisabled) {
            panic!("unexpected error: {other:?}");
        }
    }

    /// Without the `model-hf` feature, directory sources cannot be resolved.
    #[test]
    #[cfg(not(feature = "model-hf"))]
    fn resolve_hf_dir_requires_feature() {
        let other = ModelSource::from_hf_dir("org/model", "weights")
            .resolve()
            .unwrap_err();
        if !matches!(other, ModelSourceError::HuggingFaceDisabled) {
            panic!("unexpected error: {other:?}");
        }
    }
}