use serde::{Deserialize, Serialize};
use std::time::SystemTime;
/// Record of which embedding model was loaded and when.
///
/// Normally built by `ModelProvenance::new`, which stamps the current time and
/// derives `hash` from the model id, that timestamp and the model variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProvenance {
/// The model variant that was loaded.
pub model: EmbeddingModel,
/// Caller-supplied identifier for the loaded model (free-form; not checked against `model`).
pub model_id: String,
/// Hex-encoded BLAKE3 digest (64 hex chars) of the model id, load timestamp and variant.
pub hash: String,
/// Wall-clock time at which this record was created.
pub loaded_at: SystemTime,
/// `loaded_at` rendered as an RFC 3339 timestamp in UTC.
pub loaded_at_iso: String,
}
impl ModelProvenance {
    /// Creates a provenance record for `model`, identified by `model_id`, stamped "now".
    ///
    /// The load time is captured once and stored both as a `SystemTime` and as an
    /// RFC 3339 UTC string. `hash` is the hex BLAKE3 digest of
    /// `"{model_id}:{loaded_at_iso}:{model:?}"`, so loads at instants that render
    /// differently in `loaded_at_iso` get different hashes.
    pub fn new(model: EmbeddingModel, model_id: String) -> Self {
        let loaded_at = SystemTime::now();
        let loaded_at_iso = chrono::DateTime::<chrono::Utc>::from(loaded_at).to_rfc3339();
        let digest = blake3::hash(format!("{model_id}:{loaded_at_iso}:{model:?}").as_bytes());
        let hash = digest.to_hex().to_string();
        Self {
            model,
            model_id,
            hash,
            loaded_at,
            loaded_at_iso,
        }
    }

    /// Embedding width of the recorded model, as reported by `EmbeddingModel::dimensions`.
    ///
    /// This record stores no output-dimension override, so this is always the model's
    /// native width.
    pub fn dimensions(&self) -> usize {
        let model = self.model;
        model.dimensions()
    }

    /// Returns `true` when the recorded model is exactly `expected`.
    pub fn matches_model(&self, expected: EmbeddingModel) -> bool {
        expected == self.model
    }
}
/// Embedding models known to this crate.
///
/// Serialized with serde in `snake_case` (e.g. `bge_small_en_v15`); each variant's
/// PascalCase name is also accepted as a deserialization alias. The `Display` /
/// `FromStr` impls use a separate, hyphenated naming scheme (e.g. `bge-small-en-v1.5`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum EmbeddingModel {
/// `BAAI/bge-small-en-v1.5`, 384 dimensions. The default model.
#[default]
#[serde(alias = "BgeSmallEnV15")]
BgeSmallEnV15,
/// `BAAI/bge-base-en-v1.5`, 768 dimensions.
#[serde(alias = "BgeBaseEnV15")]
BgeBaseEnV15,
/// `BAAI/bge-large-en-v1.5`, 1024 dimensions.
#[serde(alias = "BgeLargeEnV15")]
BgeLargeEnV15,
/// `intfloat/multilingual-e5-small`, 384 dimensions; uses a `"query: "` prefix.
#[serde(alias = "MultilingualE5Small")]
MultilingualE5Small,
/// `intfloat/multilingual-e5-base`, 768 dimensions; uses a `"query: "` prefix.
#[serde(alias = "MultilingualE5Base")]
MultilingualE5Base,
/// `Qwen/Qwen3-Embedding-0.6B`, 1024 native dimensions; output dimension is configurable.
#[serde(alias = "Qwen3Embedding0_6B")]
Qwen3Embedding0_6B,
/// `Qwen/Qwen3-Embedding-4B`, 2560 native dimensions; output dimension is configurable.
#[serde(alias = "Qwen3Embedding4B")]
Qwen3Embedding4B,
/// `sentence-transformers/all-MiniLM-L6-v2`, 384 dimensions.
#[serde(alias = "AllMiniLmL6V2")]
AllMiniLmL6V2,
/// `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2`, 384 dimensions.
#[serde(alias = "ParaphraseMultilingualMiniLmL12V2")]
ParaphraseMultilingualMiniLmL12V2,
/// `text-embedding-3-small`, 1536 dimensions; the only model classified as remote.
#[serde(alias = "TextEmbedding3Small")]
TextEmbedding3Small,
}
impl EmbeddingModel {
    /// Full-width dimension of the vectors this model produces.
    #[inline]
    pub const fn native_dimensions(&self) -> usize {
        match self {
            EmbeddingModel::BgeSmallEnV15
            | EmbeddingModel::MultilingualE5Small
            | EmbeddingModel::AllMiniLmL6V2
            | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => 384,
            EmbeddingModel::BgeBaseEnV15 | EmbeddingModel::MultilingualE5Base => 768,
            EmbeddingModel::BgeLargeEnV15 | EmbeddingModel::Qwen3Embedding0_6B => 1024,
            EmbeddingModel::Qwen3Embedding4B => 2560,
            EmbeddingModel::TextEmbedding3Small => 1536,
        }
    }

    /// Same as [`Self::native_dimensions`].
    ///
    /// This does not account for a configured output dimension; use
    /// [`ModelConfig::dimensions`] for the effective width.
    #[inline]
    pub const fn dimensions(&self) -> usize {
        self.native_dimensions()
    }

    /// Whether the model is classified as local.
    ///
    /// Defined as the complement of [`Self::is_remote`] so the two can never disagree
    /// or both be `false` for a newly added variant.
    #[inline]
    pub const fn is_local(&self) -> bool {
        !self.is_remote()
    }

    /// Whether the model is classified as remote (currently only `text-embedding-3-small`).
    ///
    /// Written as an exhaustive `match` so that adding a variant forces a decision here.
    #[inline]
    pub const fn is_remote(&self) -> bool {
        match self {
            EmbeddingModel::TextEmbedding3Small => true,
            EmbeddingModel::BgeSmallEnV15
            | EmbeddingModel::BgeBaseEnV15
            | EmbeddingModel::BgeLargeEnV15
            | EmbeddingModel::MultilingualE5Small
            | EmbeddingModel::MultilingualE5Base
            | EmbeddingModel::AllMiniLmL6V2
            | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2
            | EmbeddingModel::Qwen3Embedding0_6B
            | EmbeddingModel::Qwen3Embedding4B => false,
        }
    }

    /// Maximum number of input tokens accepted for this model, as configured here.
    #[inline]
    pub const fn max_input_tokens(&self) -> usize {
        match self {
            EmbeddingModel::BgeSmallEnV15 => 512,
            EmbeddingModel::BgeBaseEnV15 => 512,
            EmbeddingModel::BgeLargeEnV15 => 512,
            EmbeddingModel::MultilingualE5Small => 512,
            EmbeddingModel::MultilingualE5Base => 512,
            EmbeddingModel::AllMiniLmL6V2 => 256,
            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => 128,
            EmbeddingModel::Qwen3Embedding0_6B => 8192,
            EmbeddingModel::Qwen3Embedding4B => 8192,
            EmbeddingModel::TextEmbedding3Small => 8191,
        }
    }

    /// Instruction prefix associated with query text for this model, or `None` if it uses none.
    ///
    /// NOTE(review): presumably prepended to queries by the embedding pipeline; confirm at
    /// call sites. The match is exhaustive so a new variant must choose its prefix explicitly.
    #[inline]
    pub const fn query_instruction(&self) -> Option<&'static str> {
        match self {
            EmbeddingModel::MultilingualE5Small | EmbeddingModel::MultilingualE5Base => {
                Some("query: ")
            }
            EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B => Some(
                "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ",
            ),
            EmbeddingModel::BgeSmallEnV15
            | EmbeddingModel::BgeBaseEnV15
            | EmbeddingModel::BgeLargeEnV15
            | EmbeddingModel::AllMiniLmL6V2
            | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2
            | EmbeddingModel::TextEmbedding3Small => None,
        }
    }

    /// Instruction prefix associated with document text. No model defines one, so this
    /// is always `None`.
    #[inline]
    pub const fn document_instruction(&self) -> Option<&'static str> {
        None
    }

    /// Upstream model identifier: a HuggingFace-style `org/name` id for local models,
    /// and the bare API model name for `text-embedding-3-small`.
    #[inline]
    pub const fn model_id(&self) -> &'static str {
        match self {
            EmbeddingModel::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
            EmbeddingModel::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
            EmbeddingModel::BgeLargeEnV15 => "BAAI/bge-large-en-v1.5",
            EmbeddingModel::MultilingualE5Small => "intfloat/multilingual-e5-small",
            EmbeddingModel::MultilingualE5Base => "intfloat/multilingual-e5-base",
            EmbeddingModel::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
                "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
            }
            EmbeddingModel::Qwen3Embedding0_6B => "Qwen/Qwen3-Embedding-0.6B",
            EmbeddingModel::Qwen3Embedding4B => "Qwen/Qwen3-Embedding-4B",
            EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
        }
    }

    /// Whether [`ModelConfig`] may request a reduced output dimension for this model
    /// (checked by `ModelConfig::validate`).
    #[inline]
    pub const fn supports_output_dim(&self) -> bool {
        matches!(
            self,
            EmbeddingModel::Qwen3Embedding0_6B | EmbeddingModel::Qwen3Embedding4B
        )
    }

    /// Version tag string for this model family.
    ///
    /// NOTE(review): presumably embedded in cache/index keys so vectors from different model
    /// generations don't collide; confirm at call sites. The match is exhaustive so a new
    /// variant cannot silently inherit `"v1.5"`.
    #[inline]
    pub const fn key_version(&self) -> &'static str {
        match self {
            EmbeddingModel::TextEmbedding3Small
            | EmbeddingModel::Qwen3Embedding0_6B
            | EmbeddingModel::Qwen3Embedding4B => "v3",
            EmbeddingModel::AllMiniLmL6V2 | EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
                "v2"
            }
            EmbeddingModel::BgeSmallEnV15
            | EmbeddingModel::BgeBaseEnV15
            | EmbeddingModel::BgeLargeEnV15
            | EmbeddingModel::MultilingualE5Small
            | EmbeddingModel::MultilingualE5Base => "v1.5",
        }
    }
}
impl std::fmt::Display for EmbeddingModel {
    /// Writes the model's canonical short name: lowercase and hyphenated, without an org
    /// prefix. Every name written here is accepted back by `FromStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EmbeddingModel::BgeSmallEnV15 => "bge-small-en-v1.5",
            EmbeddingModel::BgeBaseEnV15 => "bge-base-en-v1.5",
            EmbeddingModel::BgeLargeEnV15 => "bge-large-en-v1.5",
            EmbeddingModel::MultilingualE5Small => "multilingual-e5-small",
            EmbeddingModel::MultilingualE5Base => "multilingual-e5-base",
            EmbeddingModel::Qwen3Embedding0_6B => "qwen3-embedding-0.6b",
            EmbeddingModel::Qwen3Embedding4B => "qwen3-embedding-4b",
            EmbeddingModel::AllMiniLmL6V2 => "all-minilm-l6-v2",
            EmbeddingModel::ParaphraseMultilingualMiniLmL12V2 => {
                "paraphrase-multilingual-minilm-l12-v2"
            }
            EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
        };
        f.write_str(name)
    }
}
impl std::str::FromStr for EmbeddingModel {
    type Err = String;

    /// Parses a model name leniently.
    ///
    /// Accepted inputs:
    /// - the `Display` names;
    /// - the upstream ids returned by `model_id`, with `BAAI/` optional for BGE models;
    /// - a set of short aliases (e.g. `small`, `e5-base`, `qwen3`, `minilm`, `openai`).
    ///
    /// Matching ignores ASCII/Unicode case and surrounding whitespace, and treats `_` as `-`.
    ///
    /// # Errors
    /// Returns a message naming the input and listing every canonical model name when the
    /// input matches no known name or alias.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.trim().to_lowercase().replace('_', "-");
        // BGE models may be given as `BAAI/<name>`; drop only a leading org prefix.
        let key = normalized.strip_prefix("baai/").unwrap_or(&normalized);
        match key {
            "bge-small-en-v1.5" | "bge-small-en" | "bge-small" | "small" => {
                Ok(EmbeddingModel::BgeSmallEnV15)
            }
            "bge-base-en-v1.5" | "bge-base-en" | "bge-base" | "base" => {
                Ok(EmbeddingModel::BgeBaseEnV15)
            }
            "bge-large-en-v1.5" | "bge-large-en" | "bge-large" | "large" => {
                Ok(EmbeddingModel::BgeLargeEnV15)
            }
            "multilingual-e5-small" | "e5-small" | "intfloat/multilingual-e5-small" => {
                Ok(EmbeddingModel::MultilingualE5Small)
            }
            "multilingual-e5-base" | "e5-base" | "intfloat/multilingual-e5-base" => {
                Ok(EmbeddingModel::MultilingualE5Base)
            }
            "qwen3-embedding-0.6b" | "qwen3-embedding" | "qwen3" | "qwen/qwen3-embedding-0.6b" => {
                Ok(EmbeddingModel::Qwen3Embedding0_6B)
            }
            "qwen3-embedding-4b" | "qwen3-4b" | "qwen/qwen3-embedding-4b" => {
                Ok(EmbeddingModel::Qwen3Embedding4B)
            }
            "all-minilm-l6-v2"
            | "minilm"
            | "all-minilm"
            | "sentence-transformers/all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLmL6V2),
            "paraphrase-multilingual-minilm-l12-v2"
            | "paraphrase-multilingual"
            | "multilingual-minilm"
            | "sentence-transformers/paraphrase-multilingual-minilm-l12-v2" => {
                Ok(EmbeddingModel::ParaphraseMultilingualMiniLmL12V2)
            }
            "text-embedding-3-small" | "openai-small" | "openai" => {
                Ok(EmbeddingModel::TextEmbedding3Small)
            }
            // List every supported model; the previous message omitted four of them.
            _ => Err(format!(
                "unknown embedding model: '{s}'. Valid: bge-small-en-v1.5, bge-base-en-v1.5, bge-large-en-v1.5, multilingual-e5-small, multilingual-e5-base, qwen3-embedding-0.6b, qwen3-embedding-4b, all-minilm-l6-v2, paraphrase-multilingual-minilm-l12-v2, text-embedding-3-small"
            )),
        }
    }
}
/// Smallest output dimension that `ModelConfig::validate` accepts for models with a
/// configurable (MRL-style) output dimension.
pub const MIN_MRL_OUTPUT_DIM: usize = 32;
/// An embedding model together with an optional reduced output dimension.
///
/// Fields are public and deserialization does not run `validate`, so call
/// `ModelConfig::validate` (or build via `ModelConfig::try_new`) before trusting `output_dim`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelConfig {
/// The embedding model to use.
pub model: EmbeddingModel,
/// Requested output dimension. `None` means the model's native dimension; a missing
/// field deserializes to `None`.
#[serde(default)]
pub output_dim: Option<usize>,
}
impl Default for ModelConfig {
fn default() -> Self {
Self::new(EmbeddingModel::default())
}
}
impl ModelConfig {
pub const fn new(model: EmbeddingModel) -> Self {
Self {
model,
output_dim: None,
}
}
pub fn try_new(
model: EmbeddingModel,
output_dim: Option<usize>,
) -> std::result::Result<Self, crate::error::EmbedError> {
let config = Self { model, output_dim };
config.validate()?;
Ok(config)
}
pub fn validate(&self) -> std::result::Result<(), crate::error::EmbedError> {
let Some(dim) = self.output_dim else {
return Ok(());
};
if !self.model.supports_output_dim() {
return Err(crate::error::EmbedError::InvalidInput(format!(
"{} does not support configurable embedding dimensions",
self.model
)));
}
if dim < MIN_MRL_OUTPUT_DIM {
return Err(crate::error::EmbedError::InvalidInput(format!(
"embedding output dimension {dim} is below minimum {MIN_MRL_OUTPUT_DIM}"
)));
}
let native = self.model.native_dimensions();
if dim > native {
return Err(crate::error::EmbedError::InvalidInput(format!(
"embedding output dimension {dim} exceeds native dimension {native} for {}",
self.model
)));
}
Ok(())
}
pub fn dimensions(&self) -> usize {
self.output_dim
.unwrap_or_else(|| self.model.native_dimensions())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, so per-variant invariants can be checked across the whole enum.
    const ALL_MODELS: [EmbeddingModel; 10] = [
        EmbeddingModel::BgeSmallEnV15,
        EmbeddingModel::BgeBaseEnV15,
        EmbeddingModel::BgeLargeEnV15,
        EmbeddingModel::MultilingualE5Small,
        EmbeddingModel::MultilingualE5Base,
        EmbeddingModel::Qwen3Embedding0_6B,
        EmbeddingModel::Qwen3Embedding4B,
        EmbeddingModel::AllMiniLmL6V2,
        EmbeddingModel::ParaphraseMultilingualMiniLmL12V2,
        EmbeddingModel::TextEmbedding3Small,
    ];

    #[test]
    fn test_default_model() {
        let model = EmbeddingModel::default();
        assert_eq!(model, EmbeddingModel::BgeSmallEnV15);
    }

    #[test]
    fn test_model_provenance_new() {
        let provenance = ModelProvenance::new(
            EmbeddingModel::BgeSmallEnV15,
            "BAAI/bge-small-en-v1.5".into(),
        );
        assert_eq!(provenance.model, EmbeddingModel::BgeSmallEnV15);
        assert_eq!(provenance.model_id, "BAAI/bge-small-en-v1.5");
        assert!(!provenance.hash.is_empty());
        // BLAKE3 produces a 32-byte digest, i.e. 64 hex characters.
        assert_eq!(provenance.hash.len(), 64);
        assert!(!provenance.loaded_at_iso.is_empty());
    }

    #[test]
    fn test_model_provenance_unique_hash() {
        let p1 = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "model1".into());
        // The hash covers the load timestamp; wait so the two timestamps differ.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let p2 = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "model1".into());
        assert_ne!(p1.hash, p2.hash);
    }

    #[test]
    fn test_model_provenance_dimensions() {
        let p1 = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "small".into());
        assert_eq!(p1.dimensions(), 384);
        let p2 = ModelProvenance::new(EmbeddingModel::BgeBaseEnV15, "base".into());
        assert_eq!(p2.dimensions(), 768);
        let p3 = ModelProvenance::new(EmbeddingModel::BgeLargeEnV15, "large".into());
        assert_eq!(p3.dimensions(), 1024);
    }

    #[test]
    fn test_model_provenance_matches_model() {
        let provenance = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "test".into());
        assert!(provenance.matches_model(EmbeddingModel::BgeSmallEnV15));
        assert!(!provenance.matches_model(EmbeddingModel::BgeBaseEnV15));
        assert!(!provenance.matches_model(EmbeddingModel::BgeLargeEnV15));
    }

    #[test]
    fn test_model_provenance_serialization() {
        let provenance = ModelProvenance::new(EmbeddingModel::BgeSmallEnV15, "test-model".into());
        let json = serde_json::to_string(&provenance).unwrap();
        assert!(json.contains("bge_small_en_v15"), "json={json}");
        assert!(json.contains("test-model"));
        assert!(json.contains(&provenance.hash));
        let parsed: ModelProvenance = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model, provenance.model);
        assert_eq!(parsed.model_id, provenance.model_id);
        assert_eq!(parsed.hash, provenance.hash);
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(EmbeddingModel::BgeSmallEnV15.dimensions(), 384);
        assert_eq!(EmbeddingModel::BgeBaseEnV15.dimensions(), 768);
        assert_eq!(EmbeddingModel::BgeLargeEnV15.dimensions(), 1024);
        assert_eq!(EmbeddingModel::Qwen3Embedding4B.dimensions(), 2560);
    }

    #[test]
    fn test_model_config_default() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg, ModelConfig::new(EmbeddingModel::BgeSmallEnV15));
        assert!(cfg.output_dim.is_none());
    }

    #[test]
    fn test_model_config_native_dims() {
        assert_eq!(
            ModelConfig::new(EmbeddingModel::Qwen3Embedding4B).dimensions(),
            2560
        );
        assert_eq!(
            ModelConfig::new(EmbeddingModel::Qwen3Embedding0_6B).dimensions(),
            1024
        );
        assert_eq!(
            ModelConfig::new(EmbeddingModel::BgeSmallEnV15).dimensions(),
            384
        );
    }

    #[test]
    fn test_model_config_configured_dim() {
        let cfg = ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(1024)).unwrap();
        assert_eq!(cfg.dimensions(), 1024);
        let cfg = ModelConfig::try_new(EmbeddingModel::Qwen3Embedding0_6B, Some(512)).unwrap();
        assert_eq!(cfg.dimensions(), 512);
    }

    #[test]
    fn test_model_config_validation_below_min() {
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(31)).is_err());
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(0)).is_err());
    }

    #[test]
    fn test_model_config_validation_above_native() {
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, Some(2561)).is_err());
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding0_6B, Some(1025)).is_err());
    }

    #[test]
    fn test_model_config_validation_non_mrl_model() {
        assert!(ModelConfig::try_new(EmbeddingModel::BgeSmallEnV15, Some(128)).is_err());
        assert!(ModelConfig::try_new(EmbeddingModel::BgeBaseEnV15, Some(512)).is_err());
    }

    #[test]
    fn test_model_config_none_output_dim_ok_for_any_model() {
        assert!(ModelConfig::try_new(EmbeddingModel::BgeSmallEnV15, None).is_ok());
        assert!(ModelConfig::try_new(EmbeddingModel::Qwen3Embedding4B, None).is_ok());
    }

    #[test]
    fn test_is_local() {
        assert!(EmbeddingModel::BgeSmallEnV15.is_local());
        assert!(EmbeddingModel::BgeBaseEnV15.is_local());
        assert!(EmbeddingModel::BgeLargeEnV15.is_local());
    }

    #[test]
    fn test_local_and_remote_are_complementary() {
        for model in ALL_MODELS {
            assert_ne!(model.is_local(), model.is_remote(), "{model}");
        }
        assert!(EmbeddingModel::TextEmbedding3Small.is_remote());
    }

    #[test]
    fn test_display() {
        assert_eq!(
            EmbeddingModel::BgeSmallEnV15.to_string(),
            "bge-small-en-v1.5"
        );
        assert_eq!(EmbeddingModel::BgeBaseEnV15.to_string(), "bge-base-en-v1.5");
        assert_eq!(
            EmbeddingModel::BgeLargeEnV15.to_string(),
            "bge-large-en-v1.5"
        );
    }

    #[test]
    fn test_serialization_roundtrip() {
        let model = EmbeddingModel::BgeSmallEnV15;
        let json = serde_json::to_string(&model).unwrap();
        let parsed: EmbeddingModel = serde_json::from_str(&json).unwrap();
        assert_eq!(model, parsed);
    }

    #[test]
    fn test_serialization_roundtrip_all_models() {
        for model in ALL_MODELS {
            let json = serde_json::to_string(&model).unwrap();
            let parsed: EmbeddingModel = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, model, "json={json}");
        }
    }

    #[test]
    fn test_max_input_tokens() {
        assert_eq!(EmbeddingModel::BgeSmallEnV15.max_input_tokens(), 512);
        assert_eq!(EmbeddingModel::BgeBaseEnV15.max_input_tokens(), 512);
        assert_eq!(EmbeddingModel::BgeLargeEnV15.max_input_tokens(), 512);
    }

    #[test]
    fn test_from_str_display_names() {
        assert_eq!(
            "bge-small-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeSmallEnV15
        );
        assert_eq!(
            "bge-base-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeBaseEnV15
        );
        assert_eq!(
            "bge-large-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeLargeEnV15
        );
    }

    #[test]
    fn test_display_name_roundtrips_through_from_str() {
        for model in ALL_MODELS {
            let name = model.to_string();
            let parsed: EmbeddingModel = name.parse().unwrap_or_else(|e| panic!("{name}: {e}"));
            assert_eq!(parsed, model);
        }
    }

    #[test]
    fn test_model_id_roundtrips_through_from_str() {
        for model in ALL_MODELS {
            let id = model.model_id();
            let parsed: EmbeddingModel = id.parse().unwrap_or_else(|e| panic!("{id}: {e}"));
            assert_eq!(parsed, model);
        }
    }

    #[test]
    fn test_from_str_short_names() {
        assert_eq!(
            "small".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeSmallEnV15
        );
        assert_eq!(
            "bge-base".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeBaseEnV15
        );
        assert_eq!(
            "LARGE".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeLargeEnV15
        );
    }

    #[test]
    fn test_from_str_huggingface_ids() {
        assert_eq!(
            "BAAI/bge-small-en-v1.5".parse::<EmbeddingModel>().unwrap(),
            EmbeddingModel::BgeSmallEnV15
        );
    }

    #[test]
    fn test_from_str_invalid() {
        let result = "unknown-model".parse::<EmbeddingModel>();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("unknown embedding model"));
    }
}