Skip to main content

model_runtime/
bundles.rs

1use std::collections::BTreeMap;
2use std::fs;
3use std::io::ErrorKind;
4use std::path::{Component, Path, PathBuf};
5use std::sync::Arc;
6
7use crate::{ModelRuntimeError, Result};
8use jobs_core::{ArtifactKind, ArtifactRef};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    DownloadedModel, HuggingFaceDownloader, HuggingFaceModelSpec, ModelDownloader,
13    ModelFileRequest, ModelTask,
14};
15
16#[derive(Clone)]
17/// Data type for model bundle store.
18pub struct ModelBundleStore {
19    root: PathBuf,
20    downloader: Arc<dyn ModelDownloader + Send + Sync>,
21    overwrite: bool,
22}
23
24impl std::fmt::Debug for ModelBundleStore {
25    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        formatter
27            .debug_struct("ModelBundleStore")
28            .field("root", &self.root)
29            .field("overwrite", &self.overwrite)
30            .finish_non_exhaustive()
31    }
32}
33
34#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
35/// Data type for model bundle manifest.
36pub struct ModelBundleManifest {
37    /// The schema version value.
38    pub schema_version: u32,
39    /// Human-readable name for this value.
40    pub name: String,
41    /// The repo identifier value.
42    pub repo_id: String,
43    /// The revision value.
44    pub revision: String,
45    /// The task value.
46    pub task: ModelTask,
47    /// The files value.
48    pub files: BTreeMap<String, ModelBundleFile>,
49}
50
51#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
52/// Data type for model bundle file.
53pub struct ModelBundleFile {
54    /// The remote path value.
55    pub remote_path: String,
56    /// The local path value.
57    pub local_path: String,
58    /// The size bytes value.
59    pub size_bytes: u64,
60}
61
62#[derive(Debug, Clone)]
63/// Data type for model bundle.
64pub struct ModelBundle {
65    /// The root value.
66    pub root: PathBuf,
67    /// The manifest value.
68    pub manifest: ModelBundleManifest,
69}
70
71impl ModelBundleStore {
72    /// Creates a new value.
73    pub fn new(root: impl Into<PathBuf>) -> Self {
74        Self {
75            root: root.into(),
76            downloader: Arc::new(HuggingFaceDownloader::new()),
77            overwrite: false,
78        }
79    }
80
81    /// Returns downloader.
82    pub fn downloader(mut self, downloader: HuggingFaceDownloader) -> Self {
83        self.downloader = Arc::new(downloader);
84        self
85    }
86
87    /// Sets a custom downloader implementation.
88    pub fn model_downloader(
89        mut self,
90        downloader: impl ModelDownloader + Send + Sync + 'static,
91    ) -> Self {
92        self.downloader = Arc::new(downloader);
93        self
94    }
95
96    /// Returns overwrite.
97    pub fn overwrite(mut self, value: bool) -> Self {
98        self.overwrite = value;
99        self
100    }
101
102    /// Returns root.
103    pub fn root(&self) -> &Path {
104        &self.root
105    }
106
107    /// Returns bundle dir.
108    pub fn bundle_dir(&self, spec: &HuggingFaceModelSpec) -> PathBuf {
109        self.root
110            .join(safe_bundle_segment(&spec.name))
111            .join(safe_bundle_segment(&spec.revision))
112    }
113
114    /// Returns download.
115    pub fn download(&self, spec: &HuggingFaceModelSpec) -> Result<ModelBundle> {
116        let downloaded = self.downloader.download_model(spec)?;
117        self.materialize(&downloaded)
118    }
119
120    /// Returns materialize.
121    pub fn materialize(&self, downloaded: &DownloadedModel) -> Result<ModelBundle> {
122        let bundle_root = self.bundle_dir(&downloaded.spec);
123        let manifest_path = bundle_root.join("manifest.json");
124        for remote_path in downloaded.files.keys() {
125            validate_remote_path(remote_path)?;
126        }
127        if manifest_path.exists() && !self.overwrite {
128            return ModelBundle::load(manifest_path);
129        }
130
131        let files_dir = bundle_root.join("files");
132        fs::create_dir_all(&files_dir)?;
133
134        let mut manifest_files = BTreeMap::new();
135        for (remote_path, source_path) in &downloaded.files {
136            let relative_file_path = Path::new("files").join(remote_path);
137            let destination_path = bundle_root.join(&relative_file_path);
138            if let Some(parent) = destination_path.parent() {
139                fs::create_dir_all(parent)?;
140            }
141            if self.overwrite && fs::symlink_metadata(&destination_path).is_ok() {
142                fs::remove_file(&destination_path)?;
143            }
144            let mut should_materialize = match fs::symlink_metadata(&destination_path) {
145                Ok(_) => false,
146                Err(err) if err.kind() == ErrorKind::NotFound => true,
147                Err(err) => return Err(err.into()),
148            };
149            if !should_materialize && fs::metadata(&destination_path).is_err() {
150                // A stale/dangling symlink should be replaced with fresh materialized bytes.
151                fs::remove_file(&destination_path)?;
152                should_materialize = true;
153            }
154            if should_materialize {
155                let source_metadata = fs::symlink_metadata(source_path)?;
156                let linked = !source_metadata.file_type().is_symlink()
157                    && fs::hard_link(source_path, &destination_path).is_ok();
158                if !linked {
159                    let source_for_copy = if source_metadata.file_type().is_symlink() {
160                        fs::canonicalize(source_path)?
161                    } else {
162                        source_path.clone()
163                    };
164                    fs::copy(source_for_copy, &destination_path)?;
165                }
166            }
167
168            let size_bytes = fs::metadata(&destination_path)?.len();
169            manifest_files.insert(
170                remote_path.clone(),
171                ModelBundleFile {
172                    remote_path: remote_path.clone(),
173                    local_path: path_to_manifest_string(&relative_file_path),
174                    size_bytes,
175                },
176            );
177        }
178
179        let manifest = ModelBundleManifest {
180            schema_version: 1,
181            name: downloaded.spec.name.clone(),
182            repo_id: downloaded.spec.repo_id.clone(),
183            revision: downloaded.spec.revision.clone(),
184            task: downloaded.spec.task.clone(),
185            files: manifest_files,
186        };
187        let encoded = serde_json::to_vec_pretty(&manifest).map_err(|err| {
188            ModelRuntimeError::Source(format!("failed to encode model manifest: {err}"))
189        })?;
190        fs::write(&manifest_path, encoded)?;
191
192        Ok(ModelBundle {
193            root: bundle_root,
194            manifest,
195        })
196    }
197
198    /// Returns load.
199    pub fn load(&self, name: impl AsRef<str>, revision: impl AsRef<str>) -> Result<ModelBundle> {
200        ModelBundle::load(
201            self.root
202                .join(safe_bundle_segment(name.as_ref()))
203                .join(safe_bundle_segment(revision.as_ref()))
204                .join("manifest.json"),
205        )
206    }
207}
208
/// Settings for [`resolve_or_download_bundle`]: where bundles live, whether a missing
/// bundle may be fetched, and how the Hugging Face downloader is configured.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelBundleResolveOptions {
    /// Directory that holds (or will hold) the model bundles.
    pub bundle_root: PathBuf,
    /// When `false`, a bundle that is not already on disk is an error instead of a download.
    pub auto_download: bool,
    /// Progress-reporting flag forwarded to the Hugging Face downloader.
    pub download_progress: bool,
    /// Access token forwarded to the Hugging Face downloader, when set.
    pub hf_token: Option<String>,
    /// Cache directory override forwarded to the Hugging Face downloader, when set.
    pub cache_dir: Option<PathBuf>,
    /// Retry limit forwarded to the Hugging Face downloader.
    pub max_retries: usize,
    /// Whether materialization replaces existing bundle contents.
    pub overwrite: bool,
}
227
228impl Default for ModelBundleResolveOptions {
229    fn default() -> Self {
230        Self {
231            bundle_root: PathBuf::from(".model-runtime"),
232            auto_download: true,
233            download_progress: true,
234            hf_token: None,
235            cache_dir: None,
236            max_retries: 1,
237            overwrite: false,
238        }
239    }
240}
241
242impl ModelBundleResolveOptions {
243    /// Builds the configured Hugging Face downloader.
244    pub fn downloader(&self) -> HuggingFaceDownloader {
245        let mut downloader = HuggingFaceDownloader::new()
246            .progress(self.download_progress)
247            .max_retries(self.max_retries);
248        if let Some(cache_dir) = &self.cache_dir {
249            downloader = downloader.cache_dir(cache_dir.clone());
250        }
251        if let Some(token) = &self.hf_token {
252            downloader = downloader.token(token.clone());
253        }
254        downloader
255    }
256}
257
258/// Resolves a bundle from disk, optionally downloading and materializing it first.
259pub fn resolve_or_download_bundle(
260    spec: &HuggingFaceModelSpec,
261    options: &ModelBundleResolveOptions,
262) -> Result<ModelBundle> {
263    resolve_or_download_bundle_with_downloader(spec, options, options.downloader())
264}
265
266/// Resolves a bundle with a caller-provided downloader seam.
267pub fn resolve_or_download_bundle_with_downloader(
268    spec: &HuggingFaceModelSpec,
269    options: &ModelBundleResolveOptions,
270    downloader: impl ModelDownloader + Send + Sync + 'static,
271) -> Result<ModelBundle> {
272    let store = ModelBundleStore::new(options.bundle_root.clone())
273        .model_downloader(downloader)
274        .overwrite(options.overwrite);
275    if let Ok(bundle) = store.load(&spec.name, &spec.revision) {
276        return Ok(bundle);
277    }
278    if !options.auto_download {
279        let expected_path = store.bundle_dir(spec).join("manifest.json");
280        return Err(ModelRuntimeError::InvalidArgument(format!(
281            "missing model bundle `{}` at `{}` and autoDownload is false",
282            spec.name,
283            expected_path.display()
284        )));
285    }
286    store.download(spec)
287}
288
289impl ModelBundle {
290    /// Returns manifest path.
291    pub fn manifest_path(&self) -> PathBuf {
292        self.root.join("manifest.json")
293    }
294
295    /// Returns file path.
296    pub fn file_path(&self, remote_path: &str) -> Option<PathBuf> {
297        self.manifest
298            .files
299            .get(remote_path)
300            .map(|file| self.root.join(&file.local_path))
301    }
302
303    /// Returns generic job artifact references for the files in this model bundle.
304    pub fn artifact_refs(&self) -> Vec<ArtifactRef> {
305        self.manifest
306            .files
307            .iter()
308            .map(|(remote_path, file)| {
309                let local_path = self.root.join(&file.local_path);
310                let mut artifact = ArtifactRef::new(
311                    format!("model:{}", remote_path.replace(['/', '\\'], "_")),
312                    model_file_kind(remote_path),
313                    model_file_media_type(remote_path),
314                    file_uri(&local_path),
315                );
316                artifact.size_bytes = Some(file.size_bytes);
317                artifact
318                    .metadata
319                    .insert("model.repoId".to_string(), self.manifest.repo_id.clone());
320                artifact
321                    .metadata
322                    .insert("model.revision".to_string(), self.manifest.revision.clone());
323                artifact.metadata.insert(
324                    "model.task".to_string(),
325                    self.manifest.task.as_protocol_str().to_string(),
326                );
327                artifact.metadata.insert(
328                    "model.fileRole".to_string(),
329                    model_file_role(remote_path).to_string(),
330                );
331                artifact
332            })
333            .collect()
334    }
335
336    /// Converts this value to downloaded model.
337    pub fn to_downloaded_model(&self) -> DownloadedModel {
338        let files = self
339            .manifest
340            .files
341            .iter()
342            .map(|(remote_path, file)| {
343                (
344                    remote_path.clone(),
345                    absolute_path(self.root.join(&file.local_path)),
346                )
347            })
348            .collect();
349        let mut spec =
350            HuggingFaceModelSpec::new(self.manifest.repo_id.clone(), self.manifest.task.clone())
351                .name(self.manifest.name.clone())
352                .revision(self.manifest.revision.clone());
353        spec.files = self
354            .manifest
355            .files
356            .keys()
357            .map(|remote_path| ModelFileRequest::required(remote_path.clone()))
358            .collect();
359        DownloadedModel { spec, files }
360    }
361
362    /// Returns load.
363    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
364        let path = path.as_ref();
365        let manifest_path = if path.is_dir() {
366            path.join("manifest.json")
367        } else {
368            path.to_path_buf()
369        };
370        let root = manifest_path.parent().ok_or_else(|| {
371            ModelRuntimeError::InvalidArgument(format!(
372                "model bundle manifest `{}` has no parent directory",
373                manifest_path.display()
374            ))
375        })?;
376        let data = fs::read(&manifest_path)?;
377        let manifest = serde_json::from_slice(&data).map_err(|err| {
378            ModelRuntimeError::Source(format!(
379                "failed to decode model bundle manifest `{}`: {err}",
380                manifest_path.display()
381            ))
382        })?;
383        Ok(Self {
384            root: root.to_path_buf(),
385            manifest,
386        })
387    }
388}
389
/// Maps `value` to a single path segment that is safe to join under the bundle root.
///
/// ASCII alphanumerics, `.`, `_` and `-` are kept and every other character becomes `_`.
/// Empty input becomes `_`. A result made only of dots (such as `.` or `..`) has each dot
/// replaced by `_`, so a model name or revision can never refer to the current or parent
/// directory and escape the bundle root via `Path::join`.
fn safe_bundle_segment(value: &str) -> String {
    let safe: String = value
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() || matches!(ch, '.' | '_' | '-') {
                ch
            } else {
                '_'
            }
        })
        .collect();
    if safe.is_empty() {
        "_".to_string()
    } else if safe.chars().all(|ch| ch == '.') {
        // `.`/`..` would keep the path in, or climb out of, the parent directory.
        "_".repeat(safe.len())
    } else {
        safe
    }
}
407
408fn validate_remote_path(path: &str) -> Result<()> {
409    let remote_path = Path::new(path);
410    if path.is_empty() || remote_path.is_absolute() {
411        return Err(ModelRuntimeError::InvalidArgument(format!(
412            "model file path `{path}` must be relative"
413        )));
414    }
415    for component in remote_path.components() {
416        match component {
417            Component::Normal(_) => {}
418            Component::ParentDir => {
419                return Err(ModelRuntimeError::InvalidArgument(format!(
420                    "model file path `{path}` must not contain `..`"
421                )));
422            }
423            _ => {
424                return Err(ModelRuntimeError::InvalidArgument(format!(
425                    "model file path `{path}` contains an invalid path component"
426                )));
427            }
428        }
429    }
430    Ok(())
431}
432
/// Renders `path` by joining its components with `/`, independent of the platform's
/// native separator (non-UTF-8 components are converted lossily).
fn path_to_manifest_string(path: &Path) -> String {
    let mut rendered = String::new();
    for (index, component) in path.components().enumerate() {
        if index > 0 {
            rendered.push('/');
        }
        rendered.push_str(&component.as_os_str().to_string_lossy());
    }
    rendered
}
439
/// Returns `path` unchanged when it is absolute; otherwise joins it onto the current
/// working directory. If the working directory cannot be determined, the relative path is
/// returned as-is. No `.`/`..` normalization or symlink resolution is performed.
fn absolute_path(path: PathBuf) -> PathBuf {
    if path.is_absolute() {
        return path;
    }
    match std::env::current_dir() {
        Ok(cwd) => cwd.join(path),
        Err(_) => path,
    }
}
449
/// Builds a `file://` URI for `path`.
///
/// Bytes that may not appear literally in a URI path (space, `#`, `?`, `%`, control
/// characters, non-ASCII bytes, ...) are percent-encoded, so the result is a valid URI.
/// RFC 3986 unreserved characters, sub-delimiters, `:`, `@` and `/` are kept as-is. On
/// Windows, `\` separators become `/`. A drive-letter path such as `C:/x` is emitted as
/// `file:///C:/x` so the drive is not parsed as the URI host.
///
/// `path` should be absolute: the first segment of a relative path would be read as a host.
fn file_uri(path: &Path) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let text = path.to_string_lossy();
    // Only Windows treats `\` as a path separator; elsewhere it is an ordinary byte.
    let text = if cfg!(windows) {
        text.replace('\\', "/")
    } else {
        text.into_owned()
    };

    let mut uri = String::with_capacity(text.len() + "file:///".len());
    uri.push_str("file://");
    if matches!(text.as_bytes(), [drive, b':', ..] if drive.is_ascii_alphabetic()) {
        uri.push('/');
    }
    for &byte in text.as_bytes() {
        let keep = byte.is_ascii_alphanumeric()
            || matches!(
                byte,
                b'-' | b'.'
                    | b'_'
                    | b'~'
                    | b'!'
                    | b'$'
                    | b'&'
                    | b'\''
                    | b'('
                    | b')'
                    | b'*'
                    | b'+'
                    | b','
                    | b';'
                    | b'='
                    | b':'
                    | b'@'
                    | b'/'
            );
        if keep {
            uri.push(char::from(byte));
        } else {
            uri.push('%');
            uri.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            uri.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    uri
}
453
454fn model_file_kind(remote_path: &str) -> ArtifactKind {
455    match model_file_role(remote_path) {
456        "config" | "tokenizer" => ArtifactKind::Json,
457        "vocabulary" => ArtifactKind::Text,
458        _ => ArtifactKind::Binary,
459    }
460}
461
/// Returns the MIME type of a model file based only on its extension: `.json` →
/// `application/json`, `.txt` → `text/plain`, anything else → `application/octet-stream`.
fn model_file_media_type(remote_path: &str) -> &'static str {
    const BY_SUFFIX: [(&str, &str); 2] = [(".json", "application/json"), (".txt", "text/plain")];
    BY_SUFFIX
        .iter()
        .find(|(suffix, _)| remote_path.ends_with(*suffix))
        .map_or("application/octet-stream", |&(_, media_type)| media_type)
}
471
/// Classifies a model file by the last `/`-separated segment of its path, checking in order:
/// `config.json` → `"config"`, a name containing `tokenizer` → `"tokenizer"`,
/// `vocab.txt`/`merges.txt` → `"vocabulary"`, a `.onnx`/`.safetensors`/`.bin`/`.pt`
/// extension → `"weights"`, anything else → `"artifact"`. Because of that order,
/// `tokenizer.bin` is a tokenizer rather than weights.
fn model_file_role(remote_path: &str) -> &'static str {
    const WEIGHT_EXTENSIONS: [&str; 4] = [".onnx", ".safetensors", ".bin", ".pt"];

    let file_name = match remote_path.rsplit_once('/') {
        Some((_, name)) => name,
        None => remote_path,
    };
    if file_name == "config.json" {
        return "config";
    }
    if file_name.contains("tokenizer") {
        return "tokenizer";
    }
    if file_name == "vocab.txt" || file_name == "merges.txt" {
        return "vocabulary";
    }
    if WEIGHT_EXTENSIONS
        .iter()
        .any(|extension| file_name.ends_with(*extension))
    {
        "weights"
    } else {
        "artifact"
    }
}