1use std::collections::BTreeMap;
2use std::fs;
3use std::io::ErrorKind;
4use std::path::{Component, Path, PathBuf};
5use std::sync::Arc;
6
7use crate::{ModelRuntimeError, Result};
8use jobs_core::{ArtifactKind, ArtifactRef};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 DownloadedModel, HuggingFaceDownloader, HuggingFaceModelSpec, ModelDownloader,
13 ModelFileRequest, ModelTask,
14};
15
16#[derive(Clone)]
17pub struct ModelBundleStore {
19 root: PathBuf,
20 downloader: Arc<dyn ModelDownloader + Send + Sync>,
21 overwrite: bool,
22}
23
24impl std::fmt::Debug for ModelBundleStore {
25 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 formatter
27 .debug_struct("ModelBundleStore")
28 .field("root", &self.root)
29 .field("overwrite", &self.overwrite)
30 .finish_non_exhaustive()
31 }
32}
33
34#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
35pub struct ModelBundleManifest {
37 pub schema_version: u32,
39 pub name: String,
41 pub repo_id: String,
43 pub revision: String,
45 pub task: ModelTask,
47 pub files: BTreeMap<String, ModelBundleFile>,
49}
50
51#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
52pub struct ModelBundleFile {
54 pub remote_path: String,
56 pub local_path: String,
58 pub size_bytes: u64,
60}
61
62#[derive(Debug, Clone)]
63pub struct ModelBundle {
65 pub root: PathBuf,
67 pub manifest: ModelBundleManifest,
69}
70
71impl ModelBundleStore {
72 pub fn new(root: impl Into<PathBuf>) -> Self {
74 Self {
75 root: root.into(),
76 downloader: Arc::new(HuggingFaceDownloader::new()),
77 overwrite: false,
78 }
79 }
80
81 pub fn downloader(mut self, downloader: HuggingFaceDownloader) -> Self {
83 self.downloader = Arc::new(downloader);
84 self
85 }
86
87 pub fn model_downloader(
89 mut self,
90 downloader: impl ModelDownloader + Send + Sync + 'static,
91 ) -> Self {
92 self.downloader = Arc::new(downloader);
93 self
94 }
95
96 pub fn overwrite(mut self, value: bool) -> Self {
98 self.overwrite = value;
99 self
100 }
101
102 pub fn root(&self) -> &Path {
104 &self.root
105 }
106
107 pub fn bundle_dir(&self, spec: &HuggingFaceModelSpec) -> PathBuf {
109 self.root
110 .join(safe_bundle_segment(&spec.name))
111 .join(safe_bundle_segment(&spec.revision))
112 }
113
114 pub fn download(&self, spec: &HuggingFaceModelSpec) -> Result<ModelBundle> {
116 let downloaded = self.downloader.download_model(spec)?;
117 self.materialize(&downloaded)
118 }
119
120 pub fn materialize(&self, downloaded: &DownloadedModel) -> Result<ModelBundle> {
122 let bundle_root = self.bundle_dir(&downloaded.spec);
123 let manifest_path = bundle_root.join("manifest.json");
124 for remote_path in downloaded.files.keys() {
125 validate_remote_path(remote_path)?;
126 }
127 if manifest_path.exists() && !self.overwrite {
128 return ModelBundle::load(manifest_path);
129 }
130
131 let files_dir = bundle_root.join("files");
132 fs::create_dir_all(&files_dir)?;
133
134 let mut manifest_files = BTreeMap::new();
135 for (remote_path, source_path) in &downloaded.files {
136 let relative_file_path = Path::new("files").join(remote_path);
137 let destination_path = bundle_root.join(&relative_file_path);
138 if let Some(parent) = destination_path.parent() {
139 fs::create_dir_all(parent)?;
140 }
141 if self.overwrite && fs::symlink_metadata(&destination_path).is_ok() {
142 fs::remove_file(&destination_path)?;
143 }
144 let mut should_materialize = match fs::symlink_metadata(&destination_path) {
145 Ok(_) => false,
146 Err(err) if err.kind() == ErrorKind::NotFound => true,
147 Err(err) => return Err(err.into()),
148 };
149 if !should_materialize && fs::metadata(&destination_path).is_err() {
150 fs::remove_file(&destination_path)?;
152 should_materialize = true;
153 }
154 if should_materialize {
155 let source_metadata = fs::symlink_metadata(source_path)?;
156 let linked = !source_metadata.file_type().is_symlink()
157 && fs::hard_link(source_path, &destination_path).is_ok();
158 if !linked {
159 let source_for_copy = if source_metadata.file_type().is_symlink() {
160 fs::canonicalize(source_path)?
161 } else {
162 source_path.clone()
163 };
164 fs::copy(source_for_copy, &destination_path)?;
165 }
166 }
167
168 let size_bytes = fs::metadata(&destination_path)?.len();
169 manifest_files.insert(
170 remote_path.clone(),
171 ModelBundleFile {
172 remote_path: remote_path.clone(),
173 local_path: path_to_manifest_string(&relative_file_path),
174 size_bytes,
175 },
176 );
177 }
178
179 let manifest = ModelBundleManifest {
180 schema_version: 1,
181 name: downloaded.spec.name.clone(),
182 repo_id: downloaded.spec.repo_id.clone(),
183 revision: downloaded.spec.revision.clone(),
184 task: downloaded.spec.task.clone(),
185 files: manifest_files,
186 };
187 let encoded = serde_json::to_vec_pretty(&manifest).map_err(|err| {
188 ModelRuntimeError::Source(format!("failed to encode model manifest: {err}"))
189 })?;
190 fs::write(&manifest_path, encoded)?;
191
192 Ok(ModelBundle {
193 root: bundle_root,
194 manifest,
195 })
196 }
197
198 pub fn load(&self, name: impl AsRef<str>, revision: impl AsRef<str>) -> Result<ModelBundle> {
200 ModelBundle::load(
201 self.root
202 .join(safe_bundle_segment(name.as_ref()))
203 .join(safe_bundle_segment(revision.as_ref()))
204 .join("manifest.json"),
205 )
206 }
207}
208
/// Options that control how a bundle is looked up and, if needed, downloaded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelBundleResolveOptions {
    /// Root directory of the bundle store.
    pub bundle_root: PathBuf,
    /// Whether a bundle that cannot be loaded locally may be downloaded.
    pub auto_download: bool,
    /// Forwarded to the downloader's `progress` setting.
    pub download_progress: bool,
    /// Optional token forwarded to the downloader.
    pub hf_token: Option<String>,
    /// Optional cache directory forwarded to the downloader.
    pub cache_dir: Option<PathBuf>,
    /// Forwarded to the downloader's `max_retries` setting.
    pub max_retries: usize,
    /// Forwarded to the store's `overwrite` setting.
    pub overwrite: bool,
}
227
228impl Default for ModelBundleResolveOptions {
229 fn default() -> Self {
230 Self {
231 bundle_root: PathBuf::from(".model-runtime"),
232 auto_download: true,
233 download_progress: true,
234 hf_token: None,
235 cache_dir: None,
236 max_retries: 1,
237 overwrite: false,
238 }
239 }
240}
241
242impl ModelBundleResolveOptions {
243 pub fn downloader(&self) -> HuggingFaceDownloader {
245 let mut downloader = HuggingFaceDownloader::new()
246 .progress(self.download_progress)
247 .max_retries(self.max_retries);
248 if let Some(cache_dir) = &self.cache_dir {
249 downloader = downloader.cache_dir(cache_dir.clone());
250 }
251 if let Some(token) = &self.hf_token {
252 downloader = downloader.token(token.clone());
253 }
254 downloader
255 }
256}
257
258pub fn resolve_or_download_bundle(
260 spec: &HuggingFaceModelSpec,
261 options: &ModelBundleResolveOptions,
262) -> Result<ModelBundle> {
263 resolve_or_download_bundle_with_downloader(spec, options, options.downloader())
264}
265
266pub fn resolve_or_download_bundle_with_downloader(
268 spec: &HuggingFaceModelSpec,
269 options: &ModelBundleResolveOptions,
270 downloader: impl ModelDownloader + Send + Sync + 'static,
271) -> Result<ModelBundle> {
272 let store = ModelBundleStore::new(options.bundle_root.clone())
273 .model_downloader(downloader)
274 .overwrite(options.overwrite);
275 if let Ok(bundle) = store.load(&spec.name, &spec.revision) {
276 return Ok(bundle);
277 }
278 if !options.auto_download {
279 let expected_path = store.bundle_dir(spec).join("manifest.json");
280 return Err(ModelRuntimeError::InvalidArgument(format!(
281 "missing model bundle `{}` at `{}` and autoDownload is false",
282 spec.name,
283 expected_path.display()
284 )));
285 }
286 store.download(spec)
287}
288
289impl ModelBundle {
290 pub fn manifest_path(&self) -> PathBuf {
292 self.root.join("manifest.json")
293 }
294
295 pub fn file_path(&self, remote_path: &str) -> Option<PathBuf> {
297 self.manifest
298 .files
299 .get(remote_path)
300 .map(|file| self.root.join(&file.local_path))
301 }
302
303 pub fn artifact_refs(&self) -> Vec<ArtifactRef> {
305 self.manifest
306 .files
307 .iter()
308 .map(|(remote_path, file)| {
309 let local_path = self.root.join(&file.local_path);
310 let mut artifact = ArtifactRef::new(
311 format!("model:{}", remote_path.replace(['/', '\\'], "_")),
312 model_file_kind(remote_path),
313 model_file_media_type(remote_path),
314 file_uri(&local_path),
315 );
316 artifact.size_bytes = Some(file.size_bytes);
317 artifact
318 .metadata
319 .insert("model.repoId".to_string(), self.manifest.repo_id.clone());
320 artifact
321 .metadata
322 .insert("model.revision".to_string(), self.manifest.revision.clone());
323 artifact.metadata.insert(
324 "model.task".to_string(),
325 self.manifest.task.as_protocol_str().to_string(),
326 );
327 artifact.metadata.insert(
328 "model.fileRole".to_string(),
329 model_file_role(remote_path).to_string(),
330 );
331 artifact
332 })
333 .collect()
334 }
335
336 pub fn to_downloaded_model(&self) -> DownloadedModel {
338 let files = self
339 .manifest
340 .files
341 .iter()
342 .map(|(remote_path, file)| {
343 (
344 remote_path.clone(),
345 absolute_path(self.root.join(&file.local_path)),
346 )
347 })
348 .collect();
349 let mut spec =
350 HuggingFaceModelSpec::new(self.manifest.repo_id.clone(), self.manifest.task.clone())
351 .name(self.manifest.name.clone())
352 .revision(self.manifest.revision.clone());
353 spec.files = self
354 .manifest
355 .files
356 .keys()
357 .map(|remote_path| ModelFileRequest::required(remote_path.clone()))
358 .collect();
359 DownloadedModel { spec, files }
360 }
361
362 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
364 let path = path.as_ref();
365 let manifest_path = if path.is_dir() {
366 path.join("manifest.json")
367 } else {
368 path.to_path_buf()
369 };
370 let root = manifest_path.parent().ok_or_else(|| {
371 ModelRuntimeError::InvalidArgument(format!(
372 "model bundle manifest `{}` has no parent directory",
373 manifest_path.display()
374 ))
375 })?;
376 let data = fs::read(&manifest_path)?;
377 let manifest = serde_json::from_slice(&data).map_err(|err| {
378 ModelRuntimeError::Source(format!(
379 "failed to decode model bundle manifest `{}`: {err}",
380 manifest_path.display()
381 ))
382 })?;
383 Ok(Self {
384 root: root.to_path_buf(),
385 manifest,
386 })
387 }
388}
389
/// Turns `value` into a single directory name that is safe to join onto the
/// store root.
///
/// ASCII letters, digits, `.`, `_` and `-` are kept; every other character
/// becomes `_`. An empty result becomes `_`. The special names `.` and `..`
/// have their dots replaced by `_`, so a segment can never refer to the
/// current or parent directory.
fn safe_bundle_segment(value: &str) -> String {
    let safe: String = value
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() || matches!(ch, '.' | '_' | '-') {
                ch
            } else {
                '_'
            }
        })
        .collect();
    match safe.as_str() {
        "" => "_".to_string(),
        // Joining `.` or `..` would stay in, or climb out of, the parent directory.
        "." | ".." => safe.replace('.', "_"),
        _ => safe,
    }
}
407
408fn validate_remote_path(path: &str) -> Result<()> {
409 let remote_path = Path::new(path);
410 if path.is_empty() || remote_path.is_absolute() {
411 return Err(ModelRuntimeError::InvalidArgument(format!(
412 "model file path `{path}` must be relative"
413 )));
414 }
415 for component in remote_path.components() {
416 match component {
417 Component::Normal(_) => {}
418 Component::ParentDir => {
419 return Err(ModelRuntimeError::InvalidArgument(format!(
420 "model file path `{path}` must not contain `..`"
421 )));
422 }
423 _ => {
424 return Err(ModelRuntimeError::InvalidArgument(format!(
425 "model file path `{path}` contains an invalid path component"
426 )));
427 }
428 }
429 }
430 Ok(())
431}
432
/// Renders `path` for the manifest by joining its components with `/`,
/// converting each component lossily to UTF-8.
fn path_to_manifest_string(path: &Path) -> String {
    let mut rendered = String::new();
    for (index, component) in path.components().enumerate() {
        if index > 0 {
            rendered.push('/');
        }
        rendered.push_str(&component.as_os_str().to_string_lossy());
    }
    rendered
}
439
/// Returns `path` unchanged when it is absolute; otherwise joins it onto the
/// current directory. If the current directory cannot be determined, `path`
/// is returned as given.
fn absolute_path(path: PathBuf) -> PathBuf {
    if path.is_relative() {
        if let Ok(cwd) = std::env::current_dir() {
            return cwd.join(path);
        }
    }
    path
}
449
/// Formats `path` as a `file://` URI.
///
/// A relative path is first resolved against the current directory, because
/// in `file://relative/path` the first segment would be read as the URI
/// authority rather than as part of the path. If the current directory
/// cannot be determined the path is used as given. The path is converted
/// lossily to UTF-8 and is not percent-encoded.
fn file_uri(path: &Path) -> String {
    if path.is_relative() {
        if let Ok(current_dir) = std::env::current_dir() {
            return format!("file://{}", current_dir.join(path).to_string_lossy());
        }
    }
    format!("file://{}", path.to_string_lossy())
}
453
454fn model_file_kind(remote_path: &str) -> ArtifactKind {
455 match model_file_role(remote_path) {
456 "config" | "tokenizer" => ArtifactKind::Json,
457 "vocabulary" => ArtifactKind::Text,
458 _ => ArtifactKind::Binary,
459 }
460}
461
/// Picks a media type from the text after the last `.` in `remote_path`:
/// `json` gives `application/json`, `txt` gives `text/plain`, and anything
/// else (including no dot at all) gives `application/octet-stream`.
fn model_file_media_type(remote_path: &str) -> &'static str {
    let extension = remote_path.rsplit_once('.').map(|(_, ext)| ext);
    match extension {
        Some("json") => "application/json",
        Some("txt") => "text/plain",
        _ => "application/octet-stream",
    }
}
471
/// Classifies a remote path by the file name after its last `/`.
///
/// Checked in order: `config.json` is `config`; any name containing
/// "tokenizer" is `tokenizer`; `vocab.txt` and `merges.txt` are
/// `vocabulary`; names ending in `.onnx`, `.safetensors`, `.bin` or `.pt`
/// are `weights`; everything else is `artifact`.
fn model_file_role(remote_path: &str) -> &'static str {
    const WEIGHT_SUFFIXES: [&str; 4] = [".onnx", ".safetensors", ".bin", ".pt"];

    let base_name = match remote_path.rfind('/') {
        Some(slash) => &remote_path[slash + 1..],
        None => remote_path,
    };
    match base_name {
        "config.json" => "config",
        name if name.contains("tokenizer") => "tokenizer",
        "vocab.txt" | "merges.txt" => "vocabulary",
        name if WEIGHT_SUFFIXES.iter().any(|suffix| name.ends_with(*suffix)) => "weights",
        _ => "artifact",
    }
}
489}