use serde::{Deserialize, Serialize};
/// Embedding models that this module can describe and resolve by name.
///
/// Per-model metadata (canonical name, vector dimensions, ratings and memory
/// figure) is exposed through the inherent methods on this type.
///
/// NOTE(review): the declaration order below differs from the order returned
/// by `EmbeddingModel::all()`. Presumably reordering variants could affect any
/// index-based serialization format — confirm before changing it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EmbeddingModel {
/// `all-MiniLM-L6-v2` (384 dimensions); the value returned by `default_model()`.
AllMiniLmL6V2,
/// `BAAI/bge-small-en-v1.5` (384 dimensions).
BgeSmallEnV15,
/// `BAAI/bge-base-en-v1.5` (768 dimensions).
BgeBaseEnV15,
/// `sentence-transformers/all-mpnet-base-v2` (768 dimensions).
AllMpnetBaseV2,
/// `nomic-ai/nomic-embed-text-v1.5` (768 dimensions).
NomicEmbedTextV15,
/// `BAAI/bge-large-en-v1.5` (1024 dimensions).
BgeLargeEnV15,
}
impl EmbeddingModel {
    /// Resolves a user-supplied model name or alias to a model.
    ///
    /// Matching is case-insensitive: the input is lowercased before it is
    /// compared. Every model accepts its bare name, a short alias and its
    /// organisation-prefixed identifier, so the string returned by
    /// [`name`](Self::name) always resolves back to the same model. The alias
    /// `"default"` resolves to [`default_model`](Self::default_model).
    ///
    /// Returns `None` when `name` matches no known alias.
    pub fn from_name(name: &str) -> Option<Self> {
        match name.to_lowercase().as_str() {
            // Delegate so the alias can never drift from the real default.
            "default" => Some(Self::default_model()),
            // The org-prefixed id is accepted for parity with the other models.
            "all-minilm-l6-v2" | "minilm" | "sentence-transformers/all-minilm-l6-v2" => {
                Some(Self::AllMiniLmL6V2)
            }
            "bge-small-en-v1.5" | "bge-small" | "baai/bge-small-en-v1.5" => {
                Some(Self::BgeSmallEnV15)
            }
            "bge-base-en-v1.5" | "bge-base" | "baai/bge-base-en-v1.5" => {
                Some(Self::BgeBaseEnV15)
            }
            "bge-large-en-v1.5" | "bge-large" | "baai/bge-large-en-v1.5" => {
                Some(Self::BgeLargeEnV15)
            }
            "all-mpnet-base-v2" | "mpnet" | "sentence-transformers/all-mpnet-base-v2" => {
                Some(Self::AllMpnetBaseV2)
            }
            "nomic-embed-text-v1.5" | "nomic" | "nomic-ai/nomic-embed-text-v1.5" => {
                Some(Self::NomicEmbedTextV15)
            }
            _ => None,
        }
    }

    /// Returns the canonical, case-preserving name of the model.
    ///
    /// The lowercase form of every returned string is accepted by
    /// [`from_name`](Self::from_name).
    pub fn name(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => "all-MiniLM-L6-v2",
            Self::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
            Self::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
            Self::BgeLargeEnV15 => "BAAI/bge-large-en-v1.5",
            Self::AllMpnetBaseV2 => "sentence-transformers/all-mpnet-base-v2",
            Self::NomicEmbedTextV15 => "nomic-ai/nomic-embed-text-v1.5",
        }
    }

    /// Returns the number of dimensions recorded for the model's embeddings.
    pub fn dimensions(&self) -> usize {
        match self {
            Self::AllMiniLmL6V2 => 384,
            Self::BgeSmallEnV15 => 384,
            Self::BgeBaseEnV15 => 768,
            Self::BgeLargeEnV15 => 1024,
            Self::AllMpnetBaseV2 => 768,
            Self::NomicEmbedTextV15 => 768,
        }
    }

    /// Returns a one-line, human-readable description of the model.
    pub fn description(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => "Fast general-purpose model, good for most use cases",
            Self::BgeSmallEnV15 => "High quality small model from BAAI, great for semantic search",
            Self::BgeBaseEnV15 => "Higher quality base model from BAAI, better accuracy",
            Self::BgeLargeEnV15 => "Highest quality large model from BAAI, best accuracy",
            Self::AllMpnetBaseV2 => "High quality model from sentence-transformers",
            Self::NomicEmbedTextV15 => "Modern model with good quality from Nomic AI",
        }
    }

    /// Returns the hard-coded speed rating of the model.
    ///
    /// The values in use are 1, 3 and 5. NOTE(review): this appears to be a
    /// 1–5 scale where higher means faster — confirm with the consumers.
    pub fn speed_rating(&self) -> u8 {
        match self {
            Self::AllMiniLmL6V2 => 5,
            Self::BgeSmallEnV15 => 5,
            Self::BgeBaseEnV15 => 3,
            Self::BgeLargeEnV15 => 1,
            Self::AllMpnetBaseV2 => 3,
            Self::NomicEmbedTextV15 => 3,
        }
    }

    /// Returns the hard-coded quality rating of the model.
    ///
    /// The values in use are 3, 4 and 5. NOTE(review): this appears to be a
    /// 1–5 scale where higher means better — confirm with the consumers.
    pub fn quality_rating(&self) -> u8 {
        match self {
            Self::AllMiniLmL6V2 => 3,
            Self::BgeSmallEnV15 => 4,
            Self::BgeBaseEnV15 => 4,
            Self::BgeLargeEnV15 => 5,
            Self::AllMpnetBaseV2 => 4,
            Self::NomicEmbedTextV15 => 4,
        }
    }

    /// Returns the hard-coded memory figure for the model.
    ///
    /// NOTE(review): presumably an approximate footprint in megabytes, going
    /// by the method name — confirm how these figures were obtained.
    pub fn memory_mb(&self) -> usize {
        match self {
            Self::AllMiniLmL6V2 => 90,
            Self::BgeSmallEnV15 => 130,
            Self::BgeBaseEnV15 => 440,
            Self::BgeLargeEnV15 => 1340,
            Self::AllMpnetBaseV2 => 440,
            Self::NomicEmbedTextV15 => 550,
        }
    }

    /// Returns every model, in a fixed order that starts with the default.
    pub fn all() -> &'static [EmbeddingModel] {
        &[
            Self::AllMiniLmL6V2,
            Self::BgeSmallEnV15,
            Self::BgeBaseEnV15,
            Self::BgeLargeEnV15,
            Self::AllMpnetBaseV2,
            Self::NomicEmbedTextV15,
        ]
    }

    /// Returns the model used when the caller expresses no preference.
    pub fn default_model() -> Self {
        Self::AllMiniLmL6V2
    }
}
impl Default for EmbeddingModel {
fn default() -> Self {
Self::default_model()
}
}
/// Owned, serializable description of an embedding model.
///
/// The `From<EmbeddingModel>` conversion fills every field from the matching
/// `EmbeddingModel` accessor and sets `loaded` to `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
/// Canonical model name, as returned by `EmbeddingModel::name`.
pub name: String,
/// Number of embedding dimensions, from `EmbeddingModel::dimensions`.
pub dimensions: i32,
/// Human-readable summary, from `EmbeddingModel::description`.
pub description: String,
/// Speed rating, from `EmbeddingModel::speed_rating`.
pub speed_rating: i32,
/// Quality rating, from `EmbeddingModel::quality_rating`.
pub quality_rating: i32,
/// Memory figure, from `EmbeddingModel::memory_mb` (presumably megabytes — confirm).
pub memory_mb: i32,
/// Load state of the model; the `From<EmbeddingModel>` conversion always
/// sets this to `false`.
/// NOTE(review): presumably updated by whatever code loads models — confirm.
pub loaded: bool,
}
impl From<EmbeddingModel> for ModelInfo {
fn from(model: EmbeddingModel) -> Self {
Self {
name: model.name().to_string(),
dimensions: model.dimensions() as i32,
description: model.description().to_string(),
speed_rating: model.speed_rating() as i32,
quality_rating: model.quality_rating() as i32,
memory_mb: model.memory_mb() as i32,
loaded: false,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Known names and aliases resolve to their model; anything else is `None`.
    #[test]
    fn test_model_parsing() {
        let expectations = [
            ("all-minilm-l6-v2", Some(EmbeddingModel::AllMiniLmL6V2)),
            ("minilm", Some(EmbeddingModel::AllMiniLmL6V2)),
            ("default", Some(EmbeddingModel::AllMiniLmL6V2)),
            ("bge-small", Some(EmbeddingModel::BgeSmallEnV15)),
            ("unknown", None),
        ];

        for &(input, expected) in &expectations {
            assert_eq!(EmbeddingModel::from_name(input), expected);
        }
    }

    /// Spot-checks one model from each distinct dimension size.
    #[test]
    fn test_model_dimensions() {
        let expectations = [
            (EmbeddingModel::AllMiniLmL6V2, 384),
            (EmbeddingModel::BgeBaseEnV15, 768),
            (EmbeddingModel::BgeLargeEnV15, 1024),
        ];

        for &(model, dims) in &expectations {
            assert_eq!(model.dimensions(), dims);
        }
    }

    /// Every listed model carries a non-empty name and a positive dimension.
    #[test]
    fn test_all_models() {
        let catalogue = EmbeddingModel::all();
        assert!(catalogue.len() >= 4);
        assert!(catalogue.iter().all(|m| !m.name().is_empty()));
        assert!(catalogue.iter().all(|m| m.dimensions() > 0));
    }
}