// Crate: dakera-inference 0.11.61
//
// Embedded inference engine for Dakera - generates embeddings locally via ONNX Runtime.
//! Integration tests for `crates/inference` model configuration and types.
//!
//! These tests guard the public API contracts of `EmbeddingModel` and `ModelConfig`
//! without requiring ONNX Runtime or network access.
//!
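//! Covers:
//! - Dimension correctness (regression guard: a wrong dimension silently breaks the HNSW index)
//! - Maximum sequence lengths (must match each model's published limit)
//! - Prefix contracts for E5/BGE/MiniLM (a wrong prefix silently degrades retrieval quality)
//! - Post-processing consistency across the INT8 models: every model normalizes and
//!   mean-pools its output, so one shared pipeline applies with no model-specific branches
//! - `EmbeddingModel` parsing, `Display` names, and ONNX artifact selection
//!   (Xenova-hosted, quantized files)
//! - `ModelConfig` builder invariants (session pool size clamping, thread count, GPU flag,
//!   batch size, cache directory)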

use inference::{EmbeddingModel, ModelConfig};

// ─────────────────────────────────────────────────────────────
// EmbeddingModel — dimensions
// ─────────────────────────────────────────────────────────────

/// BGE-large must keep reporting the 1024-wide vectors the HNSW index is built for.
#[test]
fn model_bge_large_dimension_is_1024() {
    let dim = EmbeddingModel::BgeLarge.dimension();
    assert_eq!(
        dim, 1024,
        "BgeLarge must be 1024-dim for HNSW index compatibility"
    );
}

/// MiniLM must report a 384-dimensional embedding.
#[test]
fn model_minilm_dimension_is_384() {
    let dim = EmbeddingModel::MiniLM.dimension();
    assert_eq!(dim, 384);
}

/// BGE-small must report a 384-dimensional embedding.
#[test]
fn model_bge_small_dimension_is_384() {
    let dim = EmbeddingModel::BgeSmall.dimension();
    assert_eq!(dim, 384);
}

/// E5-small must report a 384-dimensional embedding.
#[test]
fn model_e5_small_dimension_is_384() {
    let dim = EmbeddingModel::E5Small.dimension();
    assert_eq!(dim, 384);
}

/// No supported model may advertise an empty embedding; the index cannot be built on it.
#[test]
fn all_models_have_positive_dimension() {
    let models = EmbeddingModel::all();
    for model in models.iter() {
        let dim = model.dimension();
        assert!(
            dim > 0,
            "{:?} reported dimension 0 — HNSW index would fail",
            model
        );
    }
}

// ─────────────────────────────────────────────────────────────
// EmbeddingModel — sequence lengths
// ─────────────────────────────────────────────────────────────

/// MiniLM's sequence cap must match its published 256-word-piece limit.
#[test]
fn model_minilm_max_seq_is_256() {
    // The published all-MiniLM-L6-v2 configuration truncates input longer than
    // 256 word pieces; the model itself was fine-tuned on sequences of at most
    // 128 tokens. Feeding longer sequences goes beyond the published spec and can
    // silently degrade quality — guard that the constant matches it.
    assert_eq!(EmbeddingModel::MiniLM.max_seq_length(), 256);
}

/// The BGE and E5 models share a 512-token sequence cap.
#[test]
fn model_bge_and_e5_max_seq_is_512() {
    let models = [
        EmbeddingModel::BgeLarge,
        EmbeddingModel::BgeSmall,
        EmbeddingModel::E5Small,
    ];
    for model in models {
        assert_eq!(model.max_seq_length(), 512);
    }
}

// ─────────────────────────────────────────────────────────────
// EmbeddingModel — query/document prefixes
// ─────────────────────────────────────────────────────────────

/// E5-small queries must be prefixed with `query: `.
#[test]
fn model_e5_small_requires_query_prefix() {
    let prefix = EmbeddingModel::E5Small.query_prefix();
    assert_eq!(
        prefix,
        Some("query: "),
        "E5-small requires 'query: ' prefix — missing prefix drops retrieval quality"
    );
}

/// E5-small documents must be prefixed with `passage: `.
#[test]
fn model_e5_small_requires_passage_prefix() {
    let prefix = EmbeddingModel::E5Small.document_prefix();
    assert_eq!(
        prefix,
        Some("passage: "),
        "E5-small requires 'passage: ' prefix for documents"
    );
}

/// BGE-large is fed raw text: neither queries nor documents get a prefix.
#[test]
fn model_bge_large_has_no_prefix() {
    let model = EmbeddingModel::BgeLarge;
    assert_eq!(model.query_prefix(), None);
    assert_eq!(model.document_prefix(), None);
}

/// MiniLM is fed raw text: neither queries nor documents get a prefix.
#[test]
fn model_minilm_has_no_prefix() {
    let model = EmbeddingModel::MiniLM;
    let (query, document) = (model.query_prefix(), model.document_prefix());
    assert_eq!(query, None);
    assert_eq!(document, None);
}

/// BGE-small is fed raw text: neither queries nor documents get a prefix.
#[test]
fn model_bge_small_has_no_prefix() {
    let model = EmbeddingModel::BgeSmall;
    let (query, document) = (model.query_prefix(), model.document_prefix());
    assert_eq!(query, None);
    assert_eq!(document, None);
}

// ─────────────────────────────────────────────────────────────
// INT8 quantization consistency — all models must normalize + mean-pool
// ─────────────────────────────────────────────────────────────

/// Every supported model goes through the same mean-pooling step.
#[test]
fn all_models_use_mean_pooling() {
    // The Xenova ONNX exports return per-token hidden states (`last_hidden_state`);
    // collapsing them into one vector happens in this crate's shared
    // post-processing, which mean-pools every model with no per-model branch.
    //
    // NOTE: this is a crate-level choice, not an upstream requirement — the BGE
    // model cards use CLS pooling as their reference, while MiniLM and E5 are
    // mean-pooled upstream. Changing the pooling for any model changes every
    // vector it produces and invalidates indexes built with the old setting, so
    // flipping this flag must be a deliberate, re-indexing change.
    for model in EmbeddingModel::all() {
        assert!(
            model.use_mean_pooling(),
            "{:?} must use mean pooling — the shared pipeline has no per-model pooling branch",
            model
        );
    }
}

/// Every supported model must emit unit-length vectors.
#[test]
fn all_models_normalize_embeddings() {
    // Scoring by dot product only equals cosine similarity when vectors are
    // normalized; skipping it for any model would skew similarity scores.
    for model in EmbeddingModel::all().iter() {
        let normalizes = model.normalize_embeddings();
        assert!(
            normalizes,
            "{:?} must normalize embeddings for cosine similarity correctness",
            model
        );
    }
}

// ─────────────────────────────────────────────────────────────
// EmbeddingModel — parse
// ─────────────────────────────────────────────────────────────

/// Short, upper-cased, and fully-qualified spellings all resolve to BGE-large.
#[test]
fn model_parse_bge_large_variants() {
    for alias in ["bge-large", "BGE-LARGE", "bge-large-en-v1.5"] {
        assert_eq!(
            EmbeddingModel::parse(alias),
            Some(EmbeddingModel::BgeLarge)
        );
    }
}

/// MiniLM parses regardless of letter case.
#[test]
fn model_parse_minilm_variants() {
    for alias in ["minilm", "MiniLM"] {
        assert_eq!(EmbeddingModel::parse(alias), Some(EmbeddingModel::MiniLM));
    }
}

/// Unrecognised names, including the empty string, are rejected with `None`.
#[test]
fn model_parse_unknown_returns_none() {
    for input in ["gpt4", ""] {
        assert_eq!(EmbeddingModel::parse(input), None);
    }
}

/// `EmbeddingModel::all()` enumerates exactly the four supported models.
#[test]
fn model_all_returns_four_variants() {
    let all = EmbeddingModel::all();
    assert_eq!(all.len(), 4);

    let expected = [
        EmbeddingModel::BgeLarge,
        EmbeddingModel::MiniLM,
        EmbeddingModel::BgeSmall,
        EmbeddingModel::E5Small,
    ];
    for variant in &expected {
        assert!(all.contains(variant));
    }
}

// ─────────────────────────────────────────────────────────────
// EmbeddingModel — Display
// ─────────────────────────────────────────────────────────────

/// `Display` renders each model under its canonical upstream name.
#[test]
fn model_display_known_format_strings() {
    let cases = [
        (EmbeddingModel::BgeLarge, "bge-large-en-v1.5"),
        (EmbeddingModel::MiniLM, "all-MiniLM-L6-v2"),
        (EmbeddingModel::BgeSmall, "bge-small-en-v1.5"),
        (EmbeddingModel::E5Small, "e5-small-v2"),
    ];
    for (model, expected) in cases {
        assert_eq!(model.to_string(), expected);
    }
}

/// Every ONNX artifact is fetched from the `Xenova/` namespace.
#[test]
fn model_onnx_repo_ids_are_xenova_hosted() {
    // The INT8 ONNX exports all live under Xenova on Hugging Face; a repo ID
    // outside that namespace would make the download pull a different file.
    for model in EmbeddingModel::all().iter() {
        let repo = model.onnx_repo_id();
        assert!(
            repo.starts_with("Xenova/"),
            "{:?} onnx_repo_id should be 'Xenova/...', got: {}",
            model,
            repo
        );
    }
}

/// Every model loads the quantized (INT8) ONNX file rather than the fp32 export.
#[test]
fn model_onnx_filename_is_quantized() {
    for model in EmbeddingModel::all().iter() {
        let file = model.onnx_filename();
        assert!(
            file.contains("quantized"),
            "{:?} should use quantized ONNX file, got: {}",
            model,
            file
        );
    }
}

// ─────────────────────────────────────────────────────────────
// ModelConfig — construction and builder API
// ─────────────────────────────────────────────────────────────

/// `ModelConfig::new` stores the requested model unchanged.
#[test]
fn model_config_new_sets_model() {
    let config = ModelConfig::new(EmbeddingModel::MiniLM);
    let selected = &config.model;
    assert_eq!(selected, &EmbeddingModel::MiniLM);
}

/// The default configuration selects BGE-large and keeps the GPU off.
#[test]
fn model_config_default_uses_bge_large() {
    // lib.rs has a unit test for this too; this one pins it through the public API.
    let defaults = ModelConfig::default();
    assert_eq!(defaults.model, EmbeddingModel::BgeLarge);
    assert!(!defaults.use_gpu, "GPU should be disabled by default");
}

/// `with_max_batch_size` records the requested batch size.
#[test]
fn model_config_builder_max_batch_size() {
    let cfg = ModelConfig::new(EmbeddingModel::MiniLM);
    let cfg = cfg.with_max_batch_size(16);
    assert_eq!(cfg.max_batch_size, 16);
}

/// `with_gpu(true)` switches the GPU flag on.
#[test]
fn model_config_builder_with_gpu() {
    let base = ModelConfig::new(EmbeddingModel::MiniLM);
    let gpu_cfg = base.with_gpu(true);
    assert!(gpu_cfg.use_gpu);
}

/// `with_num_threads` stores an explicit thread count.
#[test]
fn model_config_builder_num_threads() {
    let cfg = ModelConfig::new(EmbeddingModel::MiniLM).with_num_threads(8);
    let threads = cfg.num_threads;
    assert_eq!(threads, Some(8));
}

/// A requested pool size of zero is raised to the minimum of one session.
#[test]
fn model_config_session_pool_size_clamps_zero_to_one() {
    // An empty pool would leave ORT with no session to run on, so the builder
    // has to clamp the requested size up to at least 1.
    let clamped = ModelConfig::new(EmbeddingModel::MiniLM)
        .with_session_pool_size(0)
        .session_pool_size;
    assert_eq!(clamped, 1, "session_pool_size=0 must be clamped to 1");
}

/// A positive pool size is stored as given.
#[test]
fn model_config_session_pool_size_accepts_valid_value() {
    let pool = ModelConfig::new(EmbeddingModel::MiniLM)
        .with_session_pool_size(4)
        .session_pool_size;
    assert_eq!(pool, 4);
}

/// Every builder setter must still take effect when chained with the others.
#[test]
fn model_config_builder_chain_all_fields() {
    // GPU is off by default (see `model_config_default_uses_bge_large`), so chaining
    // `with_gpu(false)` and asserting `!use_gpu` would pass even if the setter were
    // a no-op. Enabling it proves `with_gpu` actually applies mid-chain.
    let cfg = ModelConfig::new(EmbeddingModel::E5Small)
        .with_max_batch_size(32)
        .with_gpu(true)
        .with_num_threads(4)
        .with_session_pool_size(2)
        .with_cache_dir("/tmp/models");
    assert_eq!(cfg.model, EmbeddingModel::E5Small);
    assert_eq!(cfg.max_batch_size, 32);
    assert!(cfg.use_gpu);
    assert_eq!(cfg.num_threads, Some(4));
    assert_eq!(cfg.session_pool_size, 2);
    // Compare as `Option<&str>` so the check needs no throwaway `String`.
    assert_eq!(cfg.cache_dir.as_deref(), Some("/tmp/models"));
}