moenarch-model-runtime 0.1.0

Generic model specs, bundles, downloads, and job helpers for multimodal runtimes.
Documentation
use std::collections::BTreeMap;

use tempfile::tempdir;

use crate::{
    resolve_or_download_bundle_with_downloader, DownloadedModel, HuggingFaceModelSpec, ModelBundle,
    ModelBundleResolveOptions, ModelBundleStore, ModelDownloader, ModelFileRequest, ModelPreset,
    ModelRuntimeBackend, ModelSource, ModelSpec, ModelTask, RawPrediction,
};

/// Test double for [`ModelDownloader`] that writes small placeholder files
/// under `root` rather than fetching real model artifacts.
#[derive(Debug, Clone)]
struct FakeDownloader {
    /// Directory that receives the placeholder files.
    root: std::path::PathBuf,
}

impl ModelDownloader for FakeDownloader {
    /// Fabricates a download for `spec` without touching the network.
    ///
    /// For every file request in `spec.files` a placeholder file containing
    /// the bytes `fake` is written into `self.root`, and the returned
    /// [`DownloadedModel`] maps the remote path to that local file.
    ///
    /// * `Required` and `Optional` requests are both materialized.
    /// * `FirstAvailable` requests resolve to their first candidate.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from creating `self.root` or writing a file.
    ///
    /// # Panics
    ///
    /// Panics if a `FirstAvailable` request has an empty candidate list.
    fn download_model(&self, spec: &HuggingFaceModelSpec) -> crate::Result<DownloadedModel> {
        std::fs::create_dir_all(&self.root)?;
        let mut files = BTreeMap::new();
        for request in &spec.files {
            let remote_path = match request {
                ModelFileRequest::Required(path) | ModelFileRequest::Optional(path) => path.clone(),
                // An empty candidate list is a broken spec; fail with a message
                // that names the invariant instead of a bare index panic.
                ModelFileRequest::FirstAvailable(paths) => paths
                    .first()
                    .expect("FirstAvailable request must list at least one candidate path")
                    .clone(),
            };
            // Flatten the remote path so every file lands directly in `root`.
            let local = self.root.join(remote_path.replace('/', "_"));
            std::fs::write(&local, b"fake")?;
            files.insert(remote_path, local);
        }
        Ok(DownloadedModel {
            spec: spec.clone(),
            files,
        })
    }
}

/// A spec built with `ModelSpec::new` reports its repo id, revision,
/// `hugging_face` source kind and requested files; a spec built from a custom
/// source has no repo id and reports the custom kind string.
#[test]
fn model_spec_records_generic_source_and_compat_fields() {
    let hub_spec = ModelSpec::new("owner/model", ModelTask::TextEmbedding)
        .revision("v1")
        .file("config.json");

    assert_eq!(hub_spec.repo_id_value(), Some("owner/model"));
    assert_eq!(hub_spec.revision_value(), Some("v1"));
    assert_eq!(hub_spec.source.kind(), "hugging_face");
    assert_eq!(
        hub_spec.files,
        vec![ModelFileRequest::required("config.json")]
    );

    let custom_source = ModelSource::Custom(String::from("fixture"));
    let custom_spec =
        ModelSpec::from_source("local-detector", ModelTask::ObjectDetection, custom_source);
    assert_eq!(custom_spec.repo_id_value(), None);
    assert_eq!(custom_spec.source.kind(), "fixture");
}

/// Materializing a downloaded model writes a manifest and the requested file
/// into the bundle, and loading the manifest back yields an equal manifest.
#[test]
fn model_bundle_store_materializes_generic_manifest() {
    let workspace = tempdir().unwrap();
    let config_path = workspace.path().join("config.json");
    std::fs::write(&config_path, br#"{"model_type":"fixture"}"#).unwrap();

    let spec = ModelSpec::new("owner/model", ModelTask::TextClassification)
        .name("fixture-model")
        .file("config.json");
    let mut files = BTreeMap::new();
    files.insert("config.json".to_string(), config_path);
    let downloaded = DownloadedModel { spec, files };

    let bundle = ModelBundleStore::new(workspace.path().join("bundles"))
        .materialize(&downloaded)
        .unwrap();

    assert_eq!(bundle.manifest.name, "fixture-model");
    assert_eq!(bundle.manifest.task, ModelTask::TextClassification);
    let manifest_path = bundle.manifest_path();
    assert!(manifest_path.exists());
    assert!(bundle.file_path("config.json").unwrap().exists());

    // Round-trip through disk to confirm the manifest is persisted faithfully.
    let reloaded = ModelBundle::load(manifest_path).unwrap();
    assert_eq!(reloaded.manifest, bundle.manifest);
}

/// `ModelPreset::ALL` must list the four ONNX preset ids checked below.
#[test]
fn model_presets_include_local_onnx_defaults() {
    let is_listed = |id: &str| ModelPreset::ALL.iter().any(|preset| preset.as_str() == id);

    assert!(is_listed("roberta-base-squad2-onnx"));
    assert!(is_listed("vit-base-patch16-224-onnx"));
    assert!(is_listed("vit-gpt2-image-captioning-onnx"));
    assert!(is_listed("xenova-yolov8n-pose-onnx"));
}

/// Checks repo id, task, required files and license metadata for the F5-TTS,
/// E2-TTS and Vocos presets.
#[test]
fn tts_presets_are_explicit_speaker_conditioned_and_carry_license_metadata() {
    // F5-TTS v1 base: also carries license scope and explicit opt-in flags.
    {
        let spec = ModelPreset::F5TtsV1Base.spec();
        let requires =
            |path: &'static str| spec.files.contains(&ModelFileRequest::required(path));

        assert_eq!(spec.repo_id, "SWivid/F5-TTS");
        assert_eq!(spec.task, ModelTask::SpeakerConditionedTts);
        assert!(requires("F5TTS_v1_Base/model_1250000.safetensors"));
        assert!(requires("F5TTS_v1_Base/vocab.txt"));
        assert_eq!(spec.metadata["license"], "cc-by-nc-4.0");
        assert_eq!(spec.metadata["licenseScope"], "model");
        assert_eq!(spec.metadata["explicitOptIn"], "true");
    }

    // F5-TTS base.
    {
        let spec = ModelPreset::F5TtsBase.spec();
        let requires =
            |path: &'static str| spec.files.contains(&ModelFileRequest::required(path));

        assert_eq!(spec.repo_id, "SWivid/F5-TTS");
        assert_eq!(spec.task, ModelTask::SpeakerConditionedTts);
        assert!(requires("F5TTS_Base/model_1200000.safetensors"));
        assert!(requires("F5TTS_Base/vocab.txt"));
        assert_eq!(spec.metadata["license"], "cc-by-nc-4.0");
    }

    // E2-TTS base.
    {
        let spec = ModelPreset::E2TtsBase.spec();
        let requires =
            |path: &'static str| spec.files.contains(&ModelFileRequest::required(path));

        assert_eq!(spec.repo_id, "SWivid/E2-TTS");
        assert_eq!(spec.task, ModelTask::SpeakerConditionedTts);
        assert!(requires("E2TTS_Base/model_1200000.safetensors"));
        assert_eq!(spec.metadata["license"], "cc-by-nc-4.0");
    }

    // Vocos vocoder: an audio-generation model under a different license.
    {
        let spec = ModelPreset::VocosMel24Khz.spec();
        let requires =
            |path: &'static str| spec.files.contains(&ModelFileRequest::required(path));

        assert_eq!(spec.repo_id, "charactr/vocos-mel-24khz");
        assert_eq!(spec.task, ModelTask::AudioGeneration);
        assert!(requires("config.yaml"));
        assert!(requires("pytorch_model.bin"));
        assert_eq!(spec.metadata["license"], "mit");
    }
}

/// The YOLOv8n pose preset targets `Xenova/yolov8n-pose`, is a 2D pose
/// estimation task, and requests both config files plus one of three ONNX
/// weight candidates.
#[test]
fn yolov8n_pose_onnx_preset_uses_pose_bundle_layout() {
    let preset = ModelPreset::XenovaYolov8nPoseOnnx;
    let spec = preset.spec();
    let requires = |path: &'static str| spec.files.contains(&ModelFileRequest::required(path));

    assert_eq!(preset.as_str(), "xenova-yolov8n-pose-onnx");
    assert_eq!(spec.repo_id, "Xenova/yolov8n-pose");
    assert_eq!(spec.task, ModelTask::PoseEstimation2d);
    assert!(requires("config.json"));
    assert!(requires("preprocessor_config.json"));

    let onnx_weights = ModelFileRequest::first_available([
        "onnx/model_quantized.onnx",
        "onnx/model_int8.onnx",
        "onnx/model.onnx",
    ]);
    assert!(spec.files.contains(&onnx_weights));
}

/// The wav2vec2 preset accepts either `tokenizer.json` or `vocab.json` for its
/// tokenizer and requires the config, preprocessor config and safetensors
/// weights.
#[test]
fn wav2vec2_base_preset_accepts_vocab_json_tokenizer_layout() {
    let spec = ModelPreset::Wav2Vec2Base960h.spec();
    let requires = |path: &'static str| spec.files.contains(&ModelFileRequest::required(path));

    assert_eq!(spec.repo_id, "facebook/wav2vec2-base-960h");
    assert!(requires("config.json"));
    assert!(requires("preprocessor_config.json"));

    let tokenizer = ModelFileRequest::first_available(["tokenizer.json", "vocab.json"]);
    assert!(spec.files.contains(&tokenizer));

    assert!(requires("model.safetensors"));
}

/// The DistilBERT SST-2 preset requires `vocab.txt` and does not require
/// `tokenizer.json`; its weights may be safetensors or a PyTorch checkpoint.
#[test]
fn distilbert_sst2_preset_uses_wordpiece_vocab_layout() {
    let spec = ModelPreset::DistilbertSst2.spec();
    let requires = |path: &'static str| spec.files.contains(&ModelFileRequest::required(path));

    assert_eq!(
        spec.repo_id,
        "distilbert-base-uncased-finetuned-sst-2-english"
    );
    assert_eq!(spec.name, "distilbert-sst2");
    assert!(requires("config.json"));
    assert!(requires("tokenizer_config.json"));
    assert!(requires("vocab.txt"));
    assert!(!requires("tokenizer.json"));

    let weights = ModelFileRequest::first_available(["model.safetensors", "pytorch_model.bin"]);
    assert!(spec.files.contains(&weights));
}

/// The BART MNLI ONNX preset lists the quantized model first among its
/// candidate weight files.
#[test]
fn bart_mnli_onnx_preset_prefers_quantized_smoke_artifact() {
    let spec = ModelPreset::XenovaBartLargeMnliOnnx.spec();
    let onnx_weights = ModelFileRequest::first_available([
        "onnx/model_quantized.onnx",
        "onnx/encoder_model.onnx",
        "onnx/model.onnx",
    ]);

    assert_eq!(spec.repo_id, "Xenova/bart-large-mnli");
    assert!(spec.files.contains(&onnx_weights));
}

/// When a bundle already exists under the bundle root, the resolver returns
/// a bundle with the same manifest path as the pre-existing one.
#[test]
fn resolver_loads_existing_bundle_before_download() {
    let workspace = tempdir().unwrap();
    let bundle_root = workspace.path().join("bundles");
    let spec = ModelPreset::OnnxCommunityRobertaBaseSquad2.spec();

    // Seed the bundle root through the store so there is something to resolve.
    let seeding_downloader = FakeDownloader {
        root: workspace.path().join("cache"),
    };
    let store = ModelBundleStore::new(bundle_root.clone()).model_downloader(seeding_downloader);
    let existing = store.download(&spec).unwrap();

    let options = ModelBundleResolveOptions {
        bundle_root,
        auto_download: true,
        ..ModelBundleResolveOptions::default()
    };
    let second_downloader = FakeDownloader {
        root: workspace.path().join("unused-cache"),
    };
    let bundle =
        resolve_or_download_bundle_with_downloader(&spec, &options, second_downloader).unwrap();

    assert_eq!(bundle.manifest_path(), existing.manifest_path());
}

/// With `auto_download` off and no bundle on disk, resolution fails with an
/// error that mentions both the disabled flag and the spec name.
#[test]
fn resolver_reports_missing_bundle_when_download_disabled() {
    let workspace = tempdir().unwrap();
    let spec = ModelPreset::XenovaVitBasePatch16_224Onnx.spec();

    let options = ModelBundleResolveOptions {
        bundle_root: workspace.path().join("bundles"),
        auto_download: false,
        ..ModelBundleResolveOptions::default()
    };
    let downloader = FakeDownloader {
        root: workspace.path().join("cache"),
    };

    let message = resolve_or_download_bundle_with_downloader(&spec, &options, downloader)
        .expect_err("missing bundle")
        .to_string();

    assert!(message.contains("autoDownload is false"));
    assert!(message.contains(spec.name.as_str()));
}

/// With `auto_download` on and an empty bundle root, the resolver uses the
/// supplied downloader and the resulting bundle exposes the config and both
/// quantized ONNX files.
#[test]
fn resolver_fake_downloader_materializes_expected_manifest() {
    let workspace = tempdir().unwrap();
    let spec = ModelPreset::XenovaVitGpt2ImageCaptioningOnnx.spec();

    let options = ModelBundleResolveOptions {
        bundle_root: workspace.path().join("bundles"),
        auto_download: true,
        ..ModelBundleResolveOptions::default()
    };
    let downloader = FakeDownloader {
        root: workspace.path().join("cache"),
    };
    let bundle = resolve_or_download_bundle_with_downloader(&spec, &options, downloader).unwrap();

    let is_bundled = |relative: &'static str| bundle.file_path(relative).unwrap().exists();

    assert_eq!(bundle.manifest.repo_id, "Xenova/vit-gpt2-image-captioning");
    assert!(is_bundled("config.json"));
    assert!(is_bundled("onnx/encoder_model_quantized.onnx"));
    assert!(is_bundled("onnx/decoder_model_quantized.onnx"));
}

/// A materialized bundle with one file exports exactly one artifact ref whose
/// metadata records the repo id, revision, task and file role.
#[cfg(feature = "jobs")]
#[test]
fn model_bundle_exports_generic_artifact_metadata() {
    let workspace = tempdir().unwrap();
    let tokenizer_path = workspace.path().join("tokenizer.json");
    std::fs::write(&tokenizer_path, br#"{"tokenizer":"fixture"}"#).unwrap();

    let spec = ModelSpec::new("owner/model", ModelTask::TextEmbedding)
        .name("fixture-model")
        .revision("v1")
        .file("tokenizer.json");
    let mut files = BTreeMap::new();
    files.insert("tokenizer.json".to_string(), tokenizer_path);
    let downloaded = DownloadedModel { spec, files };

    let bundle = ModelBundleStore::new(workspace.path().join("bundles"))
        .materialize(&downloaded)
        .unwrap();
    let artifacts = bundle.artifact_refs();

    assert_eq!(artifacts.len(), 1);
    let metadata = &artifacts[0].metadata;
    assert_eq!(metadata["model.repoId"], "owner/model");
    assert_eq!(metadata["model.revision"], "v1");
    assert_eq!(metadata["model.task"], "text_embedding");
    assert_eq!(metadata["model.fileRole"], "tokenizer");
}

/// Two single-label prediction sets whose scores differ by less than
/// `max_score_delta` compare successfully and report one compared prediction.
#[test]
fn blue_green_prediction_check_remains_generic() {
    let green_predictions = vec![RawPrediction::label("positive", 0.9)];
    let blue_predictions = vec![RawPrediction::label("positive", 0.90001)];
    let options = crate::BlueGreenPredictionTestOptions {
        max_score_delta: 0.001,
        compare_regions: true,
    };

    let report =
        crate::compare_blue_green_predictions(&green_predictions, &blue_predictions, options)
            .unwrap();

    assert_eq!(report.compared_predictions, 1);
    let backend_name = ModelRuntimeBackend::Onnx.as_str();
    assert!(backend_name.contains("onnx"));
}

/// A download job spec carries the `model-download` kind and metadata entries
/// for the model name, task, runtime, revision and repo id.
#[cfg(feature = "jobs")]
#[test]
fn model_job_spec_records_standard_metadata() {
    let spec = ModelSpec::new("owner/model", ModelTask::TextClassification).revision("v1");
    let kind = crate::jobs::ModelJobKind::Download;

    let job =
        crate::jobs::model_job_spec("job-1", kind, &spec, ModelRuntimeBackend::Onnx).unwrap();
    let metadata = &job.metadata;

    assert_eq!(job.kind.as_deref(), Some("model-download"));
    assert_eq!(metadata["model.name"], "owner/model");
    assert_eq!(metadata["model.task"], "text_classification");
    assert_eq!(metadata["model.runtime"], "onnx");
    assert_eq!(metadata["model.revision"], "v1");
    assert_eq!(metadata["model.repoId"], "owner/model");
}

/// Running an inference request inline echoes the job id, kind and backend,
/// and yields an output whose `inline` entry is `true`.
#[cfg(feature = "jobs")]
#[test]
fn model_access_job_request_records_standard_metadata() {
    use crate::jobs::{run_model_job_inline_for_tests, ModelAccessJobRequest, ModelJobKind};

    let spec = ModelSpec::new("owner/model", ModelTask::TextClassification).revision("v1");
    let payload = serde_json::json!({"text": "hello"});
    let mut metadata = BTreeMap::new();
    metadata.insert("caller".to_string(), "test".to_string());

    let request = ModelAccessJobRequest {
        id: Some(String::from("inline-model-job")),
        kind: ModelJobKind::Inference,
        spec,
        backend: ModelRuntimeBackend::Heuristic,
        inputs: vec![crate::jobs::ModelJobInput::Json(payload)],
        output_artifact_prefix: Some(String::from("prediction")),
        metadata,
    };

    let outcome = run_model_job_inline_for_tests(request).unwrap();
    assert_eq!(outcome.job_id.as_str(), "inline-model-job");
    assert_eq!(outcome.kind, ModelJobKind::Inference);
    assert_eq!(outcome.backend, ModelRuntimeBackend::Heuristic);
    assert_eq!(outcome.output.unwrap()["inline"], true);
}