use inference::{EmbeddingModel, ModelConfig};
/// `BgeLarge` must report a dimension of exactly 1024.
#[test]
fn model_bge_large_dimension_is_1024() {
    let dimension = EmbeddingModel::BgeLarge.dimension();
    assert_eq!(
        dimension, 1024,
        "BgeLarge must be 1024-dim for HNSW index compatibility"
    );
}
/// `MiniLM` must report a dimension of exactly 384.
#[test]
fn model_minilm_dimension_is_384() {
    let dimension = EmbeddingModel::MiniLM.dimension();
    assert_eq!(dimension, 384);
}
/// `BgeSmall` must report a dimension of exactly 384.
#[test]
fn model_bge_small_dimension_is_384() {
    let dimension = EmbeddingModel::BgeSmall.dimension();
    assert_eq!(dimension, 384);
}
/// `E5Small` must report a dimension of exactly 384.
#[test]
fn model_e5_small_dimension_is_384() {
    let dimension = EmbeddingModel::E5Small.dimension();
    assert_eq!(dimension, 384);
}
/// No variant listed by `EmbeddingModel::all()` may report a dimension of zero.
#[test]
fn all_models_have_positive_dimension() {
    for model in EmbeddingModel::all() {
        let dimension = model.dimension();
        assert!(
            dimension > 0,
            "{:?} reported dimension 0 — HNSW index would fail",
            model
        );
    }
}
/// `MiniLM` must report a maximum sequence length of 256.
#[test]
fn model_minilm_max_seq_is_256() {
    let max_len = EmbeddingModel::MiniLM.max_seq_length();
    assert_eq!(max_len, 256);
}
/// `BgeLarge`, `BgeSmall` and `E5Small` must each report a maximum sequence
/// length of 512.
#[test]
fn model_bge_and_e5_max_seq_is_512() {
    let bge_large = EmbeddingModel::BgeLarge.max_seq_length();
    assert_eq!(bge_large, 512);

    let bge_small = EmbeddingModel::BgeSmall.max_seq_length();
    assert_eq!(bge_small, 512);

    let e5_small = EmbeddingModel::E5Small.max_seq_length();
    assert_eq!(e5_small, 512);
}
/// `E5Small` must expose `"query: "` as its query prefix.
#[test]
fn model_e5_small_requires_query_prefix() {
    let prefix = EmbeddingModel::E5Small.query_prefix();
    assert_eq!(
        prefix,
        Some("query: "),
        "E5-small requires 'query: ' prefix — missing prefix drops retrieval quality"
    );
}
/// `E5Small` must expose `"passage: "` as its document prefix.
#[test]
fn model_e5_small_requires_passage_prefix() {
    let prefix = EmbeddingModel::E5Small.document_prefix();
    assert_eq!(
        prefix,
        Some("passage: "),
        "E5-small requires 'passage: ' prefix for documents"
    );
}
/// `BgeLarge` must report `None` for both the query and the document prefix.
#[test]
fn model_bge_large_has_no_prefix() {
    let query = EmbeddingModel::BgeLarge.query_prefix();
    assert_eq!(query, None);

    let document = EmbeddingModel::BgeLarge.document_prefix();
    assert_eq!(document, None);
}
/// `MiniLM` must report `None` for both the query and the document prefix.
#[test]
fn model_minilm_has_no_prefix() {
    let query = EmbeddingModel::MiniLM.query_prefix();
    assert_eq!(query, None);

    let document = EmbeddingModel::MiniLM.document_prefix();
    assert_eq!(document, None);
}
/// `BgeSmall` must report `None` for both the query and the document prefix.
#[test]
fn model_bge_small_has_no_prefix() {
    let query = EmbeddingModel::BgeSmall.query_prefix();
    assert_eq!(query, None);

    let document = EmbeddingModel::BgeSmall.document_prefix();
    assert_eq!(document, None);
}
/// Every variant listed by `EmbeddingModel::all()` must report
/// `use_mean_pooling()` as `true`.
#[test]
fn all_models_use_mean_pooling() {
    for model in EmbeddingModel::all() {
        let mean_pooled = model.use_mean_pooling();
        assert!(
            mean_pooled,
            "{:?} must use mean pooling — CLS pooling is incorrect for this family",
            model
        );
    }
}
/// Every variant listed by `EmbeddingModel::all()` must report
/// `normalize_embeddings()` as `true`.
#[test]
fn all_models_normalize_embeddings() {
    for model in EmbeddingModel::all() {
        let normalized = model.normalize_embeddings();
        assert!(
            normalized,
            "{:?} must normalize embeddings for cosine similarity correctness",
            model
        );
    }
}
/// `EmbeddingModel::parse` must map the short, upper-case and fully versioned
/// spellings to `BgeLarge`.
#[test]
fn model_parse_bge_large_variants() {
    let expected = Some(EmbeddingModel::BgeLarge);

    assert_eq!(EmbeddingModel::parse("bge-large"), expected);
    assert_eq!(EmbeddingModel::parse("BGE-LARGE"), expected);
    assert_eq!(EmbeddingModel::parse("bge-large-en-v1.5"), expected);
}
/// `EmbeddingModel::parse` must map both the lower-case and mixed-case
/// spellings to `MiniLM`.
#[test]
fn model_parse_minilm_variants() {
    let expected = Some(EmbeddingModel::MiniLM);

    assert_eq!(EmbeddingModel::parse("minilm"), expected);
    assert_eq!(EmbeddingModel::parse("MiniLM"), expected);
}
/// `EmbeddingModel::parse` must return `None` for an unrecognised name and for
/// the empty string.
#[test]
fn model_parse_unknown_returns_none() {
    let unrecognised = EmbeddingModel::parse("gpt4");
    assert_eq!(unrecognised, None);

    let empty = EmbeddingModel::parse("");
    assert_eq!(empty, None);
}
/// Pins the exact set of variants exposed by `EmbeddingModel::all()`.
///
/// NOTE: the function name still says "four", but the assertions below expect
/// five variants (the list includes `ModernBertEmbedBase`). The name is left
/// untouched so existing test filters keep matching; the assertions are the
/// authoritative statement of the expected set.
#[test]
fn model_all_returns_four_variants() {
    // Query the list once so every check below runs against the same value.
    let models = EmbeddingModel::all();

    assert_eq!(
        models.len(),
        5,
        "EmbeddingModel::all() should list exactly 5 variants"
    );

    // Check membership one variant at a time so a failure names the missing one.
    for expected in [
        EmbeddingModel::BgeLarge,
        EmbeddingModel::MiniLM,
        EmbeddingModel::BgeSmall,
        EmbeddingModel::E5Small,
        EmbeddingModel::ModernBertEmbedBase,
    ] {
        assert!(
            models.contains(&expected),
            "EmbeddingModel::all() is missing {:?}",
            expected
        );
    }
}
/// Pins the `to_string()` output of `BgeLarge`, `MiniLM`, `BgeSmall` and
/// `E5Small`.
#[test]
fn model_display_known_format_strings() {
    let bge_large = EmbeddingModel::BgeLarge.to_string();
    assert_eq!(bge_large, "bge-large-en-v1.5");

    let minilm = EmbeddingModel::MiniLM.to_string();
    assert_eq!(minilm, "all-MiniLM-L6-v2");

    let bge_small = EmbeddingModel::BgeSmall.to_string();
    assert_eq!(bge_small, "bge-small-en-v1.5");

    let e5_small = EmbeddingModel::E5Small.to_string();
    assert_eq!(e5_small, "e5-small-v2");
}
/// Every variant's `onnx_repo_id()` must start with `"Xenova/"`.
#[test]
fn model_onnx_repo_ids_are_xenova_hosted() {
    for model in EmbeddingModel::all() {
        // Fetch once; the same value is reused in the failure message.
        let repo_id = model.onnx_repo_id();
        assert!(
            repo_id.starts_with("Xenova/"),
            "{:?} onnx_repo_id should be 'Xenova/...', got: {}",
            model,
            repo_id
        );
    }
}
/// Every variant's `onnx_filename()` must contain the substring `"quantized"`.
#[test]
fn model_onnx_filename_is_quantized() {
    for model in EmbeddingModel::all() {
        // Fetch once; the same value is reused in the failure message.
        let filename = model.onnx_filename();
        assert!(
            filename.contains("quantized"),
            "{:?} should use quantized ONNX file, got: {}",
            model,
            filename
        );
    }
}
/// `ModelConfig::new` must store the model it was given in `model`.
#[test]
fn model_config_new_sets_model() {
    let expected = EmbeddingModel::MiniLM;
    let config = ModelConfig::new(EmbeddingModel::MiniLM);
    assert_eq!(config.model, expected);
}
/// `ModelConfig::default()` must select `BgeLarge` and leave `use_gpu` off.
#[test]
fn model_config_default_uses_bge_large() {
    let defaults = ModelConfig::default();
    let expected_model = EmbeddingModel::BgeLarge;

    assert_eq!(defaults.model, expected_model);
    assert!(!defaults.use_gpu, "GPU should be disabled by default");
}
/// `with_max_batch_size` must store the given value in `max_batch_size`.
#[test]
fn model_config_builder_max_batch_size() {
    let base = ModelConfig::new(EmbeddingModel::MiniLM);
    let config = base.with_max_batch_size(16);
    assert_eq!(config.max_batch_size, 16);
}
/// `with_gpu(true)` must set `use_gpu`.
#[test]
fn model_config_builder_with_gpu() {
    let base = ModelConfig::new(EmbeddingModel::MiniLM);
    let config = base.with_gpu(true);
    assert!(config.use_gpu);
}
/// `with_num_threads(8)` must store `Some(8)` in `num_threads`.
#[test]
fn model_config_builder_num_threads() {
    let base = ModelConfig::new(EmbeddingModel::MiniLM);
    let config = base.with_num_threads(8);
    assert_eq!(config.num_threads, Some(8));
}
/// Passing `0` to `with_session_pool_size` must leave `session_pool_size` at 1.
#[test]
fn model_config_session_pool_size_clamps_zero_to_one() {
    let base = ModelConfig::new(EmbeddingModel::MiniLM);
    let config = base.with_session_pool_size(0);
    assert_eq!(
        config.session_pool_size, 1,
        "session_pool_size=0 must be clamped to 1"
    );
}
/// A non-zero value passed to `with_session_pool_size` must be stored as-is.
#[test]
fn model_config_session_pool_size_accepts_valid_value() {
    let base = ModelConfig::new(EmbeddingModel::MiniLM);
    let config = base.with_session_pool_size(4);
    assert_eq!(config.session_pool_size, 4);
}
/// Chains every `with_*` builder method in one expression and checks that each
/// field ends up holding the value it was given.
#[test]
fn model_config_builder_chain_all_fields() {
    let cache_dir = "/tmp/models";

    let config = ModelConfig::new(EmbeddingModel::E5Small)
        .with_max_batch_size(32)
        .with_gpu(false)
        .with_num_threads(4)
        .with_session_pool_size(2)
        .with_cache_dir(cache_dir);

    assert_eq!(config.model, EmbeddingModel::E5Small);
    assert_eq!(config.max_batch_size, 32);
    assert!(!config.use_gpu);
    assert_eq!(config.num_threads, Some(4));
    assert_eq!(config.session_pool_size, 2);
    // Compare as `Option<&str>` so no owned `String` has to be built.
    assert_eq!(config.cache_dir.as_deref(), Some(cache_dir));
}