use super::pooling::Pooling;
use std::collections::HashMap;
use std::sync::OnceLock;
/// Identifiers for the text embedding models known to this module.
///
/// The default variant is [`EmbeddingModel::AllMiniLML6V2`]. The type is
/// hashable so it can serve as a key in the model registry.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum EmbeddingModel {
    /// MiniLM-L6-v2 (6 layers); the default model.
    #[default]
    AllMiniLML6V2,
    /// MiniLM-L12-v2 (12 layers).
    AllMiniLML12V2,
    /// mpnet-base-v2.
    AllMpnetBaseV2,
    /// BGE small English v1.5.
    BGESmallENV15,
    /// BGE base English v1.5.
    BGEBaseENV15,
    /// BGE large English v1.5.
    BGELargeENV15,
    /// BGE small Chinese v1.5.
    BGESmallZHV15,
    /// BGE large Chinese v1.5.
    BGELargeZHV15,
    /// Multilingual E5 small.
    MultilingualE5Small,
    /// Multilingual E5 base.
    MultilingualE5Base,
    /// Multilingual E5 large.
    MultilingualE5Large,
    /// Paraphrase multilingual MiniLM L12 v2.
    ParaphraseMLMiniLML12V2,
    /// Paraphrase multilingual mpnet base v2.
    ParaphraseMLMpnetBaseV2,
    /// Snowflake Arctic Embed XS.
    SnowflakeArcticEmbedXS,
    /// Snowflake Arctic Embed S.
    SnowflakeArcticEmbedS,
    /// Snowflake Arctic Embed M.
    SnowflakeArcticEmbedM,
    /// Snowflake Arctic Embed L.
    SnowflakeArcticEmbedL,
    /// MxBai embed large v1.
    MxbaiEmbedLargeV1,
    /// Nomic embed text v1.5.
    NomicEmbedTextV15,
}
/// Identifiers for the image embedding models known to this module.
///
/// The default (and currently only) variant is
/// [`ImageEmbeddingModel::NomicEmbedVisionV15`].
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum ImageEmbeddingModel {
    /// Nomic embed vision v1.5.
    #[default]
    NomicEmbedVisionV15,
}
/// Static metadata describing one image embedding model.
#[derive(Clone, Debug)]
pub struct ImageModelInfo {
    /// The enum variant this record describes; also used as the registry key.
    pub model: ImageEmbeddingModel,
    /// Dimensionality recorded for the model (presumably the size of the
    /// produced embedding vector; confirm against the consumer).
    pub dim: usize,
    /// Short human-readable summary of the model.
    pub description: &'static str,
    /// Repository identifier in `owner/name` form (the field name suggests the
    /// Hugging Face Hub).
    pub hf_repo: &'static str,
    /// Name of the model file to load (presumably relative to `hf_repo`).
    pub model_file: &'static str,
    /// Image size recorded for the model (presumably the input resolution in
    /// pixels; confirm against the preprocessing code).
    pub img_size: usize,
}
/// Process-wide registry of image model metadata, filled in on first access
/// with the result of `init_image_models_map`.
static IMAGE_MODEL_MAP: OnceLock<HashMap<ImageEmbeddingModel, ImageModelInfo>> = OnceLock::new();
/// Builds the image model registry, keyed by each record's `model` variant.
fn init_image_models_map() -> HashMap<ImageEmbeddingModel, ImageModelInfo> {
    let nomic_vision = ImageModelInfo {
        model: ImageEmbeddingModel::NomicEmbedVisionV15,
        dim: 768,
        description: "Nomic embed vision v1.5, 12 layers, 224px",
        hf_repo: "nomic-ai/nomic-embed-vision-v1.5",
        model_file: "model.safetensors",
        img_size: 224,
    };

    let mut registry = HashMap::new();
    // The key is cloned out of the record because the record itself is moved
    // into the map as the value.
    registry.insert(nomic_vision.model.clone(), nomic_vision);
    registry
}
impl ImageEmbeddingModel {
    /// Looks up the registry metadata for this image model.
    ///
    /// The registry is built on the first call. Returns `None` when the
    /// variant has no entry in it.
    pub fn get_info(&self) -> Option<&'static ImageModelInfo> {
        let registry = IMAGE_MODEL_MAP.get_or_init(init_image_models_map);
        registry.get(self)
    }
}
/// Architecture tag attached to each text model record.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ModelArch {
    /// Tag used by every registered text model except Nomic embed text.
    Bert,
    /// Tag used by the Nomic embed text v1.5 record.
    NomicBert,
}
/// Static metadata describing one text embedding model.
#[derive(Clone, Debug)]
pub struct ModelInfo {
    /// The enum variant this record describes; also used as the registry key.
    pub model: EmbeddingModel,
    /// Dimensionality recorded for the model (presumably the size of the
    /// produced embedding vector; confirm against the consumer).
    pub dim: usize,
    /// Short human-readable summary of the model.
    pub description: &'static str,
    /// Repository identifier in `owner/name` form (the field name suggests the
    /// Hugging Face Hub).
    pub hf_repo: &'static str,
    /// Name of the model file to load (presumably relative to `hf_repo`).
    pub model_file: &'static str,
    /// Pooling strategy associated with the model.
    pub pooling: Pooling,
    /// Maximum input length recorded for the model (presumably measured in
    /// tokens; confirm against the tokenizer/truncation code).
    pub max_length: usize,
    /// Architecture tag for the model.
    pub arch: ModelArch,
}
/// Process-wide registry of text model metadata, filled in on first access
/// with the result of `init_models_map`.
static MODEL_MAP: OnceLock<HashMap<EmbeddingModel, ModelInfo>> = OnceLock::new();
/// Builds the text model registry, keyed by each record's `model` variant.
///
/// Records are listed in the same order as the `EmbeddingModel` variants are
/// declared so that a missing or duplicated entry is easy to spot. The order
/// has no effect on the resulting map.
fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo> {
    // Every registered record uses the same model file name.
    const WEIGHTS_FILE: &str = "model.safetensors";

    vec![
        ModelInfo {
            model: EmbeddingModel::AllMiniLML6V2,
            dim: 384,
            description: "MiniLM-L6-v2, 6 layers, fast and lightweight",
            hf_repo: "sentence-transformers/all-MiniLM-L6-v2",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Mean,
            max_length: 256,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::AllMiniLML12V2,
            dim: 384,
            description: "MiniLM-L12-v2, 12 layers, higher quality",
            hf_repo: "sentence-transformers/all-MiniLM-L12-v2",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Mean,
            max_length: 256,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::AllMpnetBaseV2,
            dim: 768,
            description: "mpnet-base-v2, strong general-purpose embeddings",
            hf_repo: "sentence-transformers/all-mpnet-base-v2",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Mean,
            max_length: 384,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::BGESmallENV15,
            dim: 384,
            description: "BGE small English v1.5, compact and fast",
            hf_repo: "BAAI/bge-small-en-v1.5",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Cls,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::BGEBaseENV15,
            dim: 768,
            description: "BGE base English v1.5, balanced quality and speed",
            hf_repo: "BAAI/bge-base-en-v1.5",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Cls,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::BGELargeENV15,
            dim: 1024,
            description: "BGE large English v1.5, highest quality",
            hf_repo: "BAAI/bge-large-en-v1.5",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Cls,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::BGESmallZHV15,
            dim: 512,
            description: "BGE small Chinese v1.5, CLS pooling",
            hf_repo: "BAAI/bge-small-zh-v1.5",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Cls,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::BGELargeZHV15,
            dim: 1024,
            description: "BGE large Chinese v1.5",
            hf_repo: "BAAI/bge-large-zh-v1.5",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Cls,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::MultilingualE5Small,
            dim: 384,
            description: "Multilingual E5 small, 100+ languages",
            hf_repo: "intfloat/multilingual-e5-small",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Mean,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::MultilingualE5Base,
            dim: 768,
            description: "Multilingual E5 base, 100+ languages",
            hf_repo: "intfloat/multilingual-e5-base",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Mean,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::MultilingualE5Large,
            dim: 1024,
            description: "Multilingual E5 large, 100+ languages",
            hf_repo: "intfloat/multilingual-e5-large",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Mean,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::ParaphraseMLMiniLML12V2,
            dim: 384,
            description: "Paraphrase multilingual MiniLM L12 v2, mean pooling",
            hf_repo: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Mean,
            max_length: 128,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::ParaphraseMLMpnetBaseV2,
            dim: 768,
            description: "Paraphrase multilingual mpnet base v2",
            hf_repo: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Mean,
            // NOTE(review): the MiniLM sibling above uses 128, and the upstream
            // model card appears to list 128 for this checkpoint as well.
            // Confirm whether 384 is intentional before changing it.
            max_length: 384,
            arch: ModelArch::Bert,
        },
        // All Snowflake entries use the `Snowflake/` namespace spelling so the
        // four repository ids agree with each other.
        ModelInfo {
            model: EmbeddingModel::SnowflakeArcticEmbedXS,
            dim: 384,
            description: "Snowflake Arctic Embed XS",
            hf_repo: "Snowflake/snowflake-arctic-embed-xs",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Cls,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::SnowflakeArcticEmbedS,
            dim: 384,
            description: "Snowflake Arctic Embed S",
            hf_repo: "Snowflake/snowflake-arctic-embed-s",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Cls,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::SnowflakeArcticEmbedM,
            dim: 768,
            description: "Snowflake Arctic Embed M",
            hf_repo: "Snowflake/snowflake-arctic-embed-m",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Cls,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::SnowflakeArcticEmbedL,
            dim: 1024,
            description: "Snowflake Arctic Embed L",
            hf_repo: "Snowflake/snowflake-arctic-embed-l",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Cls,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::MxbaiEmbedLargeV1,
            dim: 1024,
            description: "MxBai embed large v1",
            hf_repo: "mixedbread-ai/mxbai-embed-large-v1",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Cls,
            max_length: 512,
            arch: ModelArch::Bert,
        },
        ModelInfo {
            model: EmbeddingModel::NomicEmbedTextV15,
            dim: 768,
            description: "Nomic embed text v1.5, RoPE, SwiGLU, 8192 context",
            hf_repo: "nomic-ai/nomic-embed-text-v1.5",
            model_file: WEIGHTS_FILE,
            pooling: Pooling::Mean,
            max_length: 8192,
            arch: ModelArch::NomicBert,
        },
    ]
    .into_iter()
    // The key is cloned out of each record because the record itself is moved
    // into the map as the value.
    .map(|info| (info.model.clone(), info))
    .collect()
}
pub fn models_map() -> &'static HashMap<EmbeddingModel, ModelInfo> {
MODEL_MAP.get_or_init(init_models_map)
}
impl EmbeddingModel {
    /// Looks up the registry metadata for this model.
    ///
    /// Returns `None` when the variant has no entry in [`models_map`].
    pub fn get_info(&self) -> Option<&'static ModelInfo> {
        models_map().get(self)
    }

    /// Lists the metadata of every registered model in a stable order.
    ///
    /// The registry is a `HashMap`, whose iteration order is arbitrary and can
    /// differ between runs. The records are therefore sorted by `hf_repo`,
    /// ignoring ASCII case, so that listings are reproducible.
    pub fn list_supported() -> Vec<&'static ModelInfo> {
        let mut infos: Vec<&'static ModelInfo> = models_map().values().collect();
        // Case-insensitive key keeps namespaces such as `BAAI/` from sorting
        // ahead of every lowercase one.
        infos.sort_by_cached_key(|info| info.hf_repo.to_ascii_lowercase());
        infos
    }
}