use crate::pooling::Pooling;
use super::model_info::ModelInfo;
use super::quantization::QuantizationMode;
use std::{collections::HashMap, fmt::Display, sync::OnceLock};
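/// Lazily-initialized registry mapping every `EmbeddingModel` variant to its `ModelInfo`.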
static MODEL_MAP: OnceLock<HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>>> = OnceLock::new();
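/// The supported text embedding models. Each variant is backed by an ONNX
/// checkpoint identified by a Hugging Face model code (see `ModelInfo`);
/// variants ending in `Q` use a quantized checkpoint of the same model.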
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EmbeddingModel {
AllMiniLML6V2,
AllMiniLML6V2Q,
AllMiniLML12V2,
AllMiniLML12V2Q,
BGEBaseENV15,
BGEBaseENV15Q,
BGELargeENV15,
BGELargeENV15Q,
BGESmallENV15,
BGESmallENV15Q,
NomicEmbedTextV1,
NomicEmbedTextV15,
NomicEmbedTextV15Q,
ParaphraseMLMiniLML12V2,
ParaphraseMLMiniLML12V2Q,
ParaphraseMLMpnetBaseV2,
BGESmallZHV15,
MultilingualE5Small,
MultilingualE5Base,
MultilingualE5Large,
MxbaiEmbedLargeV1,
MxbaiEmbedLargeV1Q,
GTEBaseENV15,
GTEBaseENV15Q,
GTELargeENV15,
GTELargeENV15Q,
ClipVitB32,
JinaEmbeddingsV2BaseCode,
}
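/// Builds the model registry. Each entry records the embedding dimension, a short
/// description, the Hugging Face repository (`model_code`), the ONNX file to load,
/// and any additional files the graph needs (e.g. external weight data).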
fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
let models_list = vec![
ModelInfo {
model: EmbeddingModel::AllMiniLML6V2,
dim: 384,
description: String::from("Sentence Transformer model, MiniLM-L6-v2"),
model_code: String::from("Qdrant/all-MiniLM-L6-v2-onnx"),
model_file: String::from("model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::AllMiniLML6V2Q,
dim: 384,
description: String::from("Quantized Sentence Transformer model, MiniLM-L6-v2"),
model_code: String::from("Xenova/all-MiniLM-L6-v2"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::AllMiniLML12V2,
dim: 384,
description: String::from("Sentence Transformer model, MiniLM-L12-v2"),
model_code: String::from("Xenova/all-MiniLM-L12-v2"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::AllMiniLML12V2Q,
dim: 384,
description: String::from("Quantized Sentence Transformer model, MiniLM-L12-v2"),
model_code: String::from("Xenova/all-MiniLM-L12-v2"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::BGEBaseENV15,
dim: 768,
description: String::from("v1.5 release of the base English model"),
model_code: String::from("Xenova/bge-base-en-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::BGEBaseENV15Q,
dim: 768,
description: String::from("Quantized v1.5 release of the large English model"),
model_code: String::from("Qdrant/bge-base-en-v1.5-onnx-Q"),
model_file: String::from("model_optimized.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::BGELargeENV15,
dim: 1024,
description: String::from("v1.5 release of the large English model"),
model_code: String::from("Xenova/bge-large-en-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::BGELargeENV15Q,
dim: 1024,
description: String::from("Quantized v1.5 release of the large English model"),
model_code: String::from("Qdrant/bge-large-en-v1.5-onnx-Q"),
model_file: String::from("model_optimized.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::BGESmallENV15,
dim: 384,
description: String::from("v1.5 release of the fast and default English model"),
model_code: String::from("Xenova/bge-small-en-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::BGESmallENV15Q,
dim: 384,
description: String::from(
"Quantized v1.5 release of the fast and default English model",
),
model_code: String::from("Qdrant/bge-small-en-v1.5-onnx-Q"),
model_file: String::from("model_optimized.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::NomicEmbedTextV1,
dim: 768,
description: String::from("8192 context length english model"),
model_code: String::from("nomic-ai/nomic-embed-text-v1"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::NomicEmbedTextV15,
dim: 768,
description: String::from("v1.5 release of the 8192 context length english model"),
model_code: String::from("nomic-ai/nomic-embed-text-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::NomicEmbedTextV15Q,
dim: 768,
description: String::from(
    "Quantized v1.5 release of the 8192 context length English model",
),
model_code: String::from("nomic-ai/nomic-embed-text-v1.5"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::ParaphraseMLMiniLML12V2Q,
dim: 384,
description: String::from("Quantized Multi-lingual model"),
model_code: String::from("Qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q"),
model_file: String::from("model_optimized.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::ParaphraseMLMiniLML12V2,
dim: 384,
description: String::from("Multi-lingual model"),
model_code: String::from("Xenova/paraphrase-multilingual-MiniLM-L12-v2"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::ParaphraseMLMpnetBaseV2,
dim: 768,
description: String::from(
"Sentence-transformers model for tasks like clustering or semantic search",
),
model_code: String::from("Xenova/paraphrase-multilingual-mpnet-base-v2"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::BGESmallZHV15,
dim: 512,
description: String::from("v1.5 release of the small Chinese model"),
model_code: String::from("Xenova/bge-small-zh-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::MultilingualE5Small,
dim: 384,
description: String::from("Small model of multilingual E5 Text Embeddings"),
model_code: String::from("intfloat/multilingual-e5-small"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::MultilingualE5Base,
dim: 768,
description: String::from("Base model of multilingual E5 Text Embeddings"),
model_code: String::from("intfloat/multilingual-e5-base"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::MultilingualE5Large,
dim: 1024,
description: String::from("Large model of multilingual E5 Text Embeddings"),
model_code: String::from("Qdrant/multilingual-e5-large-onnx"),
model_file: String::from("model.onnx"),
additional_files: vec!["model.onnx_data".to_string()],
},
ModelInfo {
model: EmbeddingModel::MxbaiEmbedLargeV1,
dim: 1024,
description: String::from("Large English embedding model from MixedBreed.ai"),
model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::MxbaiEmbedLargeV1Q,
dim: 1024,
description: String::from("Quantized Large English embedding model from MixedBreed.ai"),
model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::GTEBaseENV15,
dim: 768,
description: String::from("Large multilingual embedding model from Alibaba"),
model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::GTEBaseENV15Q,
dim: 768,
description: String::from("Quantized Large multilingual embedding model from Alibaba"),
model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::GTELargeENV15,
dim: 1024,
description: String::from("Large multilingual embedding model from Alibaba"),
model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::GTELargeENV15Q,
dim: 1024,
description: String::from("Quantized Large multilingual embedding model from Alibaba"),
model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
model_file: String::from("onnx/model_quantized.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::ClipVitB32,
dim: 512,
description: String::from("CLIP text encoder based on ViT-B/32"),
model_code: String::from("Qdrant/clip-ViT-B-32-text"),
model_file: String::from("model.onnx"),
additional_files: Vec::new(),
},
ModelInfo {
model: EmbeddingModel::JinaEmbeddingsV2BaseCode,
dim: 768,
description: String::from("Jina embeddings v2 base code"),
model_code: String::from("jinaai/jina-embeddings-v2-base-code"),
model_file: String::from("onnx/model.onnx"),
additional_files: Vec::new(),
},
];
models_list
    .into_iter()
    .map(|info| (info.model.clone(), info))
    .collect()
}
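/// Returns the shared model registry, initializing it on first use.
///
/// Illustrative usage (a sketch; the exact import path depends on how this
/// module is re-exported, hence `ignore`):
/// ```ignore
/// let map = models_map();
/// assert!(map.contains_key(&EmbeddingModel::BGESmallENV15));
/// ```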
pub fn models_map() -> &'static HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
MODEL_MAP.get_or_init(init_models_map)
}
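/// Looks up the `ModelInfo` for a model, or `None` if it is not registered.
///
/// Illustrative usage (a sketch; assumes `ModelInfo`'s fields are visible at the
/// call site):
/// ```ignore
/// let info = get_model_info(&EmbeddingModel::AllMiniLML6V2).expect("registered model");
/// assert_eq!(info.dim, 384);
/// ```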
pub fn get_model_info(model: &EmbeddingModel) -> Option<&ModelInfo<EmbeddingModel>> {
models_map().get(model)
}
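/// Returns an owned list of every registered model, cloned out of the shared registry.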
pub fn models_list() -> Vec<ModelInfo<EmbeddingModel>> {
models_map().values().cloned().collect()
}
impl EmbeddingModel {
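/// Default pooling strategy used to collapse per-token outputs into a single
/// sentence embedding: `Mean` averages the token vectors, `Cls` takes the
/// `[CLS]` token's vector. The BGE, mxbai and GTE models default to CLS pooling;
/// the remaining models default to mean pooling.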
pub fn get_default_pooling_method(&self) -> Option<Pooling> {
match self {
EmbeddingModel::AllMiniLML6V2 => Some(Pooling::Mean),
EmbeddingModel::AllMiniLML6V2Q => Some(Pooling::Mean),
EmbeddingModel::AllMiniLML12V2 => Some(Pooling::Mean),
EmbeddingModel::AllMiniLML12V2Q => Some(Pooling::Mean),
EmbeddingModel::BGEBaseENV15 => Some(Pooling::Cls),
EmbeddingModel::BGEBaseENV15Q => Some(Pooling::Cls),
EmbeddingModel::BGELargeENV15 => Some(Pooling::Cls),
EmbeddingModel::BGELargeENV15Q => Some(Pooling::Cls),
EmbeddingModel::BGESmallENV15 => Some(Pooling::Cls),
EmbeddingModel::BGESmallENV15Q => Some(Pooling::Cls),
EmbeddingModel::BGESmallZHV15 => Some(Pooling::Cls),
EmbeddingModel::NomicEmbedTextV1 => Some(Pooling::Mean),
EmbeddingModel::NomicEmbedTextV15 => Some(Pooling::Mean),
EmbeddingModel::NomicEmbedTextV15Q => Some(Pooling::Mean),
EmbeddingModel::ParaphraseMLMiniLML12V2 => Some(Pooling::Mean),
EmbeddingModel::ParaphraseMLMiniLML12V2Q => Some(Pooling::Mean),
EmbeddingModel::ParaphraseMLMpnetBaseV2 => Some(Pooling::Mean),
EmbeddingModel::MultilingualE5Base => Some(Pooling::Mean),
EmbeddingModel::MultilingualE5Small => Some(Pooling::Mean),
EmbeddingModel::MultilingualE5Large => Some(Pooling::Mean),
EmbeddingModel::MxbaiEmbedLargeV1 => Some(Pooling::Cls),
EmbeddingModel::MxbaiEmbedLargeV1Q => Some(Pooling::Cls),
EmbeddingModel::GTEBaseENV15 => Some(Pooling::Cls),
EmbeddingModel::GTEBaseENV15Q => Some(Pooling::Cls),
EmbeddingModel::GTELargeENV15 => Some(Pooling::Cls),
EmbeddingModel::GTELargeENV15Q => Some(Pooling::Cls),
EmbeddingModel::ClipVitB32 => Some(Pooling::Mean),
EmbeddingModel::JinaEmbeddingsV2BaseCode => Some(Pooling::Mean),
}
}
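/// Reports how the model's checkpoint is quantized: `Dynamic` or `Static` for
/// the `*Q` variants, `QuantizationMode::None` for full-precision models.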
pub fn get_quantization_mode(&self) -> QuantizationMode {
match self {
EmbeddingModel::AllMiniLML6V2Q => QuantizationMode::Dynamic,
EmbeddingModel::AllMiniLML12V2Q => QuantizationMode::Dynamic,
EmbeddingModel::BGEBaseENV15Q => QuantizationMode::Static,
EmbeddingModel::BGELargeENV15Q => QuantizationMode::Static,
EmbeddingModel::BGESmallENV15Q => QuantizationMode::Static,
EmbeddingModel::NomicEmbedTextV15Q => QuantizationMode::Dynamic,
EmbeddingModel::ParaphraseMLMiniLML12V2Q => QuantizationMode::Static,
EmbeddingModel::MxbaiEmbedLargeV1Q => QuantizationMode::Dynamic,
EmbeddingModel::GTEBaseENV15Q => QuantizationMode::Dynamic,
EmbeddingModel::GTELargeENV15Q => QuantizationMode::Dynamic,
_ => QuantizationMode::None,
}
}
}
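// Display renders a model as its Hugging Face model code, e.g.
// "Qdrant/all-MiniLM-L6-v2-onnx". The `expect` cannot fail as long as
// `init_models_map` registers every enum variant.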
impl Display for EmbeddingModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let model_info = get_model_info(self).expect("Model not found.");
write!(f, "{}", model_info.model_code)
}
}