Skip to main content

model_artifact/
lib.rs

1pub mod gguf;
2
3use std::path::Path;
4
5use anyhow::{Result, bail};
6use async_trait::async_trait;
7use model_ref::{
8    ModelRef, format_canonical_ref, gguf_matches_quant_selector, normalize_gguf_distribution_id,
9    parse_model_ref, split_gguf_shard_info,
10};
11use serde::{Deserialize, Serialize};
12
13#[async_trait]
14pub trait ModelRepository: Send + Sync {
15    async fn resolve_revision(&self, repo: &str, revision: Option<&str>) -> Result<String>;
16
17    async fn list_files(&self, repo: &str, revision: &str) -> Result<Vec<ModelArtifactFile>>;
18}
19
20#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
21pub struct ResolvedModelArtifact {
22    pub model_id: String,
23    pub source_repo: String,
24    pub source_revision: String,
25    pub selector: Option<String>,
26    pub format: ModelFormat,
27    pub files: Vec<ModelArtifactFile>,
28    pub primary_file: String,
29    pub canonical_ref: String,
30    pub distribution_id: String,
31}
32
33#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
34pub struct ModelIdentity {
35    pub model_id: String,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub source_repo: Option<String>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub source_revision: Option<String>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub source_file: Option<String>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub canonical_ref: Option<String>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub distribution_id: Option<String>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub selector: Option<String>,
48}
49
50impl ModelIdentity {
51    pub fn from_model_id(model_id: impl Into<String>) -> Self {
52        Self {
53            model_id: model_id.into(),
54            source_repo: None,
55            source_revision: None,
56            source_file: None,
57            canonical_ref: None,
58            distribution_id: None,
59            selector: None,
60        }
61    }
62}
63
64impl From<&ResolvedModelArtifact> for ModelIdentity {
65    fn from(artifact: &ResolvedModelArtifact) -> Self {
66        Self {
67            model_id: artifact.model_id.clone(),
68            source_repo: Some(artifact.source_repo.clone()),
69            source_revision: Some(artifact.source_revision.clone()),
70            source_file: Some(artifact.primary_file.clone()),
71            canonical_ref: Some(artifact.canonical_ref.clone()),
72            distribution_id: Some(artifact.distribution_id.clone()),
73            selector: artifact.selector.clone(),
74        }
75    }
76}
77
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
79#[serde(rename_all = "snake_case")]
80pub enum ModelFormat {
81    Gguf,
82    Safetensors,
83}
84
85#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
86pub struct ModelArtifactFile {
87    pub path: String,
88    pub size_bytes: Option<u64>,
89    pub sha256: Option<String>,
90}
91
92impl ModelArtifactFile {
93    pub fn new(path: impl Into<String>) -> Self {
94        Self {
95            path: path.into(),
96            size_bytes: None,
97            sha256: None,
98        }
99    }
100}
101
102pub async fn resolve_model_artifact_ref(
103    model_ref: &str,
104    repository: &impl ModelRepository,
105) -> Result<ResolvedModelArtifact> {
106    let parsed = parse_model_ref(model_ref)?;
107    resolve_model_artifact(&parsed, repository).await
108}
109
110pub async fn resolve_model_artifact(
111    model_ref: &ModelRef,
112    repository: &impl ModelRepository,
113) -> Result<ResolvedModelArtifact> {
114    let source_revision = repository
115        .resolve_revision(&model_ref.repo, model_ref.revision.as_deref())
116        .await?;
117    let mut repo_files = repository
118        .list_files(&model_ref.repo, &source_revision)
119        .await?;
120    repo_files.sort_by(|left, right| left.path.cmp(&right.path));
121
122    let primary_file = select_primary_file(model_ref.selector.as_deref(), &repo_files)?;
123    let format = format_for_file(&primary_file.path)?;
124    let files = artifact_file_set(&primary_file.path, &repo_files);
125    let distribution_id = distribution_id_for_file(&primary_file.path)?;
126
127    Ok(ResolvedModelArtifact {
128        model_id: model_ref.display_id(),
129        source_repo: model_ref.repo.clone(),
130        source_revision: source_revision.clone(),
131        selector: model_ref.selector.clone(),
132        format,
133        files,
134        primary_file: primary_file.path.clone(),
135        canonical_ref: format_canonical_ref(&model_ref.repo, &source_revision, &primary_file.path),
136        distribution_id,
137    })
138}
139
140pub fn select_primary_artifact_file(
141    selector: Option<&str>,
142    files: &[ModelArtifactFile],
143) -> Result<ModelArtifactFile> {
144    select_primary_file(selector, files)
145}
146
147pub fn artifact_files_for_primary(
148    primary_file: &str,
149    files: &[ModelArtifactFile],
150) -> Vec<ModelArtifactFile> {
151    artifact_file_set(primary_file, files)
152}
153
154fn select_primary_file(
155    selector: Option<&str>,
156    files: &[ModelArtifactFile],
157) -> Result<ModelArtifactFile> {
158    let Some(selector) = selector else {
159        return select_default_file(files);
160    };
161
162    let selector_lower = selector.to_ascii_lowercase();
163    let gguf_exact = format!("{selector}.gguf").to_ascii_lowercase();
164    let gguf_split_prefix = format!("{selector}-00001-of-").to_ascii_lowercase();
165    let safetensors_exact = format!("{selector}.safetensors").to_ascii_lowercase();
166    let safetensors_split_prefix = format!("{selector}-00001-of-").to_ascii_lowercase();
167
168    files
169        .iter()
170        .filter_map(|file| {
171            let lower = file.path.to_ascii_lowercase();
172            let basename = basename_lower(&file.path);
173            let rank = if lower == selector_lower || basename == selector_lower {
174                0
175            } else if gguf_matches_quant_selector(&file.path, selector) {
176                1
177            } else if basename == safetensors_exact {
178                2
179            } else if basename.starts_with(&safetensors_split_prefix)
180                && basename.ends_with(".safetensors")
181            {
182                3
183            } else if basename == gguf_exact {
184                4
185            } else if basename.starts_with(&gguf_split_prefix) && basename.ends_with(".gguf") {
186                5
187            } else {
188                return None;
189            };
190            Some((
191                rank,
192                artifact_preference_score(&file.path),
193                file.path.clone(),
194                file.clone(),
195            ))
196        })
197        .min_by(|left, right| (left.0, left.1, &left.2).cmp(&(right.0, right.1, &right.2)))
198        .map(|(_, _, _, file)| file)
199        .ok_or_else(|| {
200            anyhow::anyhow!("no model artifact matching selector '{selector}' in repository")
201        })
202}
203
204fn select_default_file(files: &[ModelArtifactFile]) -> Result<ModelArtifactFile> {
205    files
206        .iter()
207        .filter_map(|file| {
208            let lower = file.path.to_ascii_lowercase();
209            let basename = basename_lower(&file.path);
210            let rank = if basename == "model.safetensors" {
211                0
212            } else if is_split_safetensors_first_shard(&basename) {
213                1
214            } else if lower.ends_with(".gguf") {
215                if is_known_gguf_sidecar(&basename) {
216                    return None;
217                }
218                if lower.contains("-000") && !lower.contains("-00001-of-") {
219                    return None;
220                }
221                if lower.contains("-00001-of-") { 2 } else { 3 }
222            } else {
223                return None;
224            };
225            Some((
226                rank,
227                artifact_preference_score(&file.path),
228                file.path.clone(),
229                file.clone(),
230            ))
231        })
232        .min_by(|left, right| (left.0, left.1, &left.2).cmp(&(right.0, right.1, &right.2)))
233        .map(|(_, _, _, file)| file)
234        .ok_or_else(|| anyhow::anyhow!("no supported model artifact files found in repository"))
235}
236
237fn artifact_file_set(primary_file: &str, files: &[ModelArtifactFile]) -> Vec<ModelArtifactFile> {
238    if let Some(primary) = split_gguf_shard_info(primary_file) {
239        let mut shards = files
240            .iter()
241            .filter(|file| {
242                split_gguf_shard_info(&file.path)
243                    .map(|candidate| {
244                        candidate.prefix == primary.prefix && candidate.total == primary.total
245                    })
246                    .unwrap_or(false)
247            })
248            .cloned()
249            .collect::<Vec<_>>();
250        shards.sort_by(|left, right| left.path.cmp(&right.path));
251        if !shards.is_empty() {
252            return shards;
253        }
254    }
255
256    vec![
257        files
258            .iter()
259            .find(|file| file.path == primary_file)
260            .cloned()
261            .unwrap_or_else(|| ModelArtifactFile::new(primary_file)),
262    ]
263}
264
265fn format_for_file(file: &str) -> Result<ModelFormat> {
266    if file.ends_with(".gguf") {
267        return Ok(ModelFormat::Gguf);
268    }
269    if file.ends_with(".safetensors") || file.ends_with(".safetensors.index.json") {
270        return Ok(ModelFormat::Safetensors);
271    }
272    bail!("unsupported model artifact file format: {file}")
273}
274
275fn distribution_id_for_file(file: &str) -> Result<String> {
276    if file.ends_with(".gguf") {
277        return normalize_gguf_distribution_id(file)
278            .ok_or_else(|| anyhow::anyhow!("invalid GGUF artifact file name: {file}"));
279    }
280    let basename = Path::new(file)
281        .file_name()
282        .and_then(|value| value.to_str())
283        .unwrap_or(file);
284    let stem = basename.strip_suffix(".safetensors").unwrap_or(basename);
285    Ok(split_safetensors_shard_stem_prefix(stem)
286        .unwrap_or(stem)
287        .to_string())
288}
289
/// ASCII-lowercased final path component of `path`.
///
/// Falls back to lowercasing the whole input when `Path::file_name` yields nothing
/// usable (e.g. `..`, an empty string, or a non-UTF-8 name).
fn basename_lower(path: &str) -> String {
    match Path::new(path).file_name().and_then(|os_name| os_name.to_str()) {
        Some(name) => name.to_ascii_lowercase(),
        None => path.to_ascii_lowercase(),
    }
}
297
/// Tie-break score for candidates of equal match rank; lower is preferred.
///
/// Scores:
/// - 0: the first shard of a split set (`-00001-of-`).
/// - 1-based position of the first tag in `PREFERRED` that appears in the name.
/// - `PREFERRED.len() + 2`: names with no known tag.
///
/// Matching ignores ASCII case, consistent with the case-insensitive selector and
/// basename comparisons. Previously, lowercase names such as `model-q4_k_m.gguf`
/// lost their quantization preference.
fn artifact_preference_score(file: &str) -> usize {
    const PREFERRED: &[&str] = &[
        "Q4_K_M", "Q4_K_S", "Q4_1", "Q5_K_M", "Q5_K_S", "Q8_0", "BF16",
    ];
    // All tags above are uppercase, so compare against an uppercased copy.
    let upper = file.to_ascii_uppercase();
    if upper.contains("-00001-OF-") {
        return 0;
    }
    PREFERRED
        .iter()
        .position(|tag| upper.contains(tag))
        .map_or(PREFERRED.len() + 2, |pos| pos + 1)
}
311
/// Whether an (already lowercased) basename names a GGUF sidecar rather than model
/// weights. Only the `mmproj` prefix is recognised today.
fn is_known_gguf_sidecar(basename_lower: &str) -> bool {
    const SIDECAR_PREFIXES: &[&str] = &["mmproj"];
    SIDECAR_PREFIXES
        .iter()
        .any(|prefix| basename_lower.starts_with(prefix))
}
315
316fn is_split_safetensors_first_shard(basename_lower: &str) -> bool {
317    let Some(stem) = basename_lower.strip_suffix(".safetensors") else {
318        return false;
319    };
320    split_safetensors_shard_info(stem)
321        .map(|(_, part, _)| part == "00001")
322        .unwrap_or(false)
323}
324
325fn split_safetensors_shard_stem_prefix(stem: &str) -> Option<&str> {
326    split_safetensors_shard_info(stem).map(|(prefix, _, _)| prefix)
327}
328
/// Splits a shard stem of the form `<prefix>-NNNNN-of-NNNNN` into
/// `(prefix, part, total)`.
///
/// Both numeric fields must be exactly five ASCII digits. Returns `None` for any
/// stem not in that form.
fn split_safetensors_shard_info(stem: &str) -> Option<(&str, &str, &str)> {
    fn is_five_digit_field(field: &str) -> bool {
        field.len() == 5 && field.bytes().all(|byte| byte.is_ascii_digit())
    }

    let (head, total) = stem.rsplit_once("-of-")?;
    if !is_five_digit_field(total) {
        return None;
    }
    let (prefix, part) = head.rsplit_once('-')?;
    is_five_digit_field(part).then_some((prefix, part, total))
}
340
#[cfg(test)]
mod tests {
    // Behaviour tests for artifact resolution, driven by an in-memory repository.
    use super::*;
    use std::collections::HashMap;

    /// In-memory `ModelRepository`: one default revision and a file listing per repo.
    struct MemoryRepository {
        default_revision: String,
        listings: HashMap<String, Vec<ModelArtifactFile>>,
    }

    #[async_trait]
    impl ModelRepository for MemoryRepository {
        async fn resolve_revision(&self, _repo: &str, revision: Option<&str>) -> Result<String> {
            let chosen = match revision {
                Some(requested) => requested,
                None => self.default_revision.as_str(),
            };
            Ok(chosen.to_owned())
        }

        async fn list_files(&self, repo: &str, _revision: &str) -> Result<Vec<ModelArtifactFile>> {
            Ok(self.listings.get(repo).cloned().unwrap_or_default())
        }
    }

    /// Repository `org/repo` with default revision `abc123` containing `paths`.
    fn repo(paths: Vec<&str>) -> MemoryRepository {
        let listing = artifact_files(&paths);
        MemoryRepository {
            default_revision: String::from("abc123"),
            listings: HashMap::from([(String::from("org/repo"), listing)]),
        }
    }

    /// Bare `ModelArtifactFile` entries for `paths`, in the given order.
    fn artifact_files(paths: &[&str]) -> Vec<ModelArtifactFile> {
        paths.iter().map(|path| ModelArtifactFile::new(*path)).collect()
    }

    #[tokio::test]
    async fn resolves_quant_selector_to_gguf_file() {
        let repository = repo(vec!["Model-Q5_K_M.gguf", "Model-Q4_K_M.gguf", "README.md"]);

        let resolved = resolve_model_artifact_ref("org/repo:Q4_K_M", &repository)
            .await
            .expect("Q4_K_M selector should resolve");

        assert_eq!(resolved.model_id, "org/repo:Q4_K_M");
        assert_eq!(resolved.source_revision, "abc123");
        assert_eq!(resolved.primary_file, "Model-Q4_K_M.gguf");
        assert_eq!(resolved.canonical_ref, "org/repo@abc123/Model-Q4_K_M.gguf");
        assert_eq!(resolved.distribution_id, "Model-Q4_K_M");
        assert_eq!(resolved.files.len(), 1);
    }

    #[tokio::test]
    async fn resolves_split_gguf_selector_to_all_shards() {
        let repository = repo(vec![
            "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00002-of-00003.gguf",
            "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00001-of-00003.gguf",
            "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00003-of-00003.gguf",
            "UD-Q4_K_M/GLM-5.1-UD-Q4_K_M-00001-of-00003.gguf",
        ]);

        let resolved = resolve_model_artifact_ref("org/repo:UD-IQ2_M", &repository)
            .await
            .expect("split selector should resolve");

        assert_eq!(
            resolved.primary_file,
            "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00001-of-00003.gguf"
        );
        assert_eq!(resolved.distribution_id, "GLM-5.1-UD-IQ2_M");
        assert_eq!(resolved.files.len(), 3);
        assert_eq!(
            resolved.files[2].path,
            "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00003-of-00003.gguf"
        );
    }

    #[test]
    fn public_selector_api_resolves_mesh_split_stem_to_first_part() {
        let candidates = artifact_files(&[
            "zai-org.GLM-5.1.Q2_K-00002-of-00018.gguf",
            "zai-org.GLM-5.1.Q2_K-00001-of-00018.gguf",
        ]);

        let selected =
            select_primary_artifact_file(Some("zai-org.GLM-5.1.Q2_K"), &candidates).unwrap();

        assert_eq!(selected.path, "zai-org.GLM-5.1.Q2_K-00001-of-00018.gguf");
    }

    #[test]
    fn public_selector_api_resolves_mesh_quant_aliases() {
        let candidates = artifact_files(&[
            "qwen3.5-moe-0.87B-d0.8B.Q2_K.gguf",
            "gemma-4-31B-it-Q4_0.gguf",
            "Qwen3-8B-Q4_K_M.gguf",
        ]);
        let pick = |selector: &str| {
            select_primary_artifact_file(Some(selector), &candidates)
                .unwrap()
                .path
        };

        assert_eq!(pick("Q2_K"), "qwen3.5-moe-0.87B-d0.8B.Q2_K.gguf");
        assert_eq!(pick("Q4_0"), "gemma-4-31B-it-Q4_0.gguf");
    }

    #[test]
    fn public_selector_api_resolves_mesh_mlx_shorthand() {
        let candidates = artifact_files(&[
            "model-00002-of-00048.safetensors",
            "model-00001-of-00048.safetensors",
            "model.safetensors.index.json",
        ]);

        let selected = select_primary_artifact_file(Some("model"), &candidates).unwrap();

        assert_eq!(selected.path, "model-00001-of-00048.safetensors");
    }

    #[test]
    fn public_default_api_preserves_mesh_default_ordering() {
        let candidates = artifact_files(&[
            "Qwen3-8B-Q8_0.gguf",
            "mmproj-BF16.gguf",
            "Qwen3-8B-Q4_K_M.gguf",
        ]);

        let selected = select_primary_artifact_file(None, &candidates).unwrap();

        assert_eq!(selected.path, "Qwen3-8B-Q4_K_M.gguf");
    }

    #[test]
    fn public_default_api_prefers_mlx_weights_over_gguf() {
        let candidates = artifact_files(&[
            "Qwen3-8B-Q4_K_M.gguf",
            "model.safetensors",
            "model.safetensors.index.json",
        ]);

        let selected = select_primary_artifact_file(None, &candidates).unwrap();

        assert_eq!(selected.path, "model.safetensors");
    }

    #[test]
    fn public_artifact_set_returns_all_split_gguf_shards() {
        let candidates = artifact_files(&[
            "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00002-of-00003.gguf",
            "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00001-of-00003.gguf",
            "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00003-of-00003.gguf",
            "UD-Q4_K_M/GLM-5.1-UD-Q4_K_M-00001-of-00003.gguf",
        ]);

        let shards = artifact_files_for_primary(
            "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00001-of-00003.gguf",
            &candidates,
        );
        let shard_paths: Vec<&str> = shards.iter().map(|file| file.path.as_str()).collect();

        assert_eq!(
            shard_paths,
            [
                "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00001-of-00003.gguf",
                "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00002-of-00003.gguf",
                "UD-IQ2_M/GLM-5.1-UD-IQ2_M-00003-of-00003.gguf",
            ]
        );
    }

    #[tokio::test]
    async fn accepts_revisioned_selector_refs() {
        let repository = repo(vec!["Model-Q4_K_M.gguf"]);

        let resolved = resolve_model_artifact_ref("org/repo:Q4_K_M@rev-1", &repository)
            .await
            .expect("revisioned selector should resolve");

        assert_eq!(resolved.model_id, "org/repo@rev-1:Q4_K_M");
        assert_eq!(resolved.source_revision, "rev-1");
        assert_eq!(resolved.canonical_ref, "org/repo@rev-1/Model-Q4_K_M.gguf");
    }

    #[tokio::test]
    async fn default_selection_prefers_primary_weights() {
        let repository = repo(vec![
            "README.md",
            "Qwen3-8B-Q4_K_M.gguf",
            "Qwen3-8B-Q5_K_M.gguf",
        ]);

        let resolved = resolve_model_artifact_ref("org/repo", &repository)
            .await
            .expect("default selection should resolve");

        assert_eq!(resolved.primary_file, "Qwen3-8B-Q4_K_M.gguf");
        assert_eq!(resolved.format, ModelFormat::Gguf);
    }

    #[tokio::test]
    async fn unknown_selector_returns_error() {
        let repository = repo(vec!["Model-Q4_K_M.gguf"]);

        let message = resolve_model_artifact_ref("org/repo:Q5_K_M", &repository)
            .await
            .unwrap_err()
            .to_string();

        assert!(
            message.contains("no model artifact matching selector"),
            "unexpected error: {message}"
        );
    }
}