use serde::{Deserialize, Serialize};
/// Sentence-embedding models supported by this crate.
///
/// The default variant is [`EmbeddingModel::MiniLM`].
///
/// Serialization note: serde's `kebab-case` rule inserts a dash before *every* uppercase
/// letter, so `MiniLM` is serialized as `"mini-l-m"`, `BgeSmall` as `"bge-small"` and
/// `E5Small` as `"e5-small"`. The serialized form is left unchanged for compatibility.
/// The `alias` attributes let deserialization also accept two further sets of spellings:
/// the lowercase aliases that `EmbeddingModel::parse` recognizes, and the `Display` names.
/// Serde matching is case-sensitive, unlike `parse`.
/// NOTE(review): consider `#[serde(rename = "mini-lm")]` if changing the serialized name is
/// acceptable for existing consumers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum EmbeddingModel {
    /// `sentence-transformers/all-MiniLM-L6-v2`; the default model.
    #[default]
    #[serde(
        alias = "mini-lm",
        alias = "minilm",
        alias = "all-minilm-l6-v2",
        alias = "all-MiniLM-L6-v2"
    )]
    MiniLM,
    /// `BAAI/bge-small-en-v1.5`.
    #[serde(alias = "bge", alias = "bge-small-en", alias = "bge-small-en-v1.5")]
    BgeSmall,
    /// `intfloat/e5-small-v2`.
    #[serde(alias = "e5", alias = "e5-small-v2")]
    E5Small,
}
impl EmbeddingModel {
    /// Returns the upstream model identifier in Hugging Face Hub `organisation/name` form.
    pub fn model_id(&self) -> &'static str {
        match self {
            EmbeddingModel::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
            EmbeddingModel::BgeSmall => "BAAI/bge-small-en-v1.5",
            EmbeddingModel::E5Small => "intfloat/e5-small-v2",
        }
    }

    /// Returns the length of the embedding vectors the model produces.
    ///
    /// Every supported model currently has the same size. The match is kept exhaustive so
    /// that adding a variant forces this value to be revisited.
    pub fn dimension(&self) -> usize {
        match self {
            EmbeddingModel::MiniLM | EmbeddingModel::BgeSmall | EmbeddingModel::E5Small => 384,
        }
    }

    /// Returns the maximum input sequence length configured for the model.
    ///
    /// The value is presumably measured in tokens, with longer inputs truncated by the
    /// tokenizer; confirm with the caller.
    pub fn max_seq_length(&self) -> usize {
        match self {
            EmbeddingModel::MiniLM => 256,
            EmbeddingModel::BgeSmall | EmbeddingModel::E5Small => 512,
        }
    }

    /// Returns the text to prepend to search queries, if the model expects one.
    ///
    /// Only E5 uses a marker (`"query: "`). The other models take raw text.
    pub fn query_prefix(&self) -> Option<&'static str> {
        match self {
            EmbeddingModel::MiniLM | EmbeddingModel::BgeSmall => None,
            EmbeddingModel::E5Small => Some("query: "),
        }
    }

    /// Returns the text to prepend to documents/passages, if the model expects one.
    ///
    /// Only E5 uses a marker (`"passage: "`). The other models take raw text.
    pub fn document_prefix(&self) -> Option<&'static str> {
        match self {
            EmbeddingModel::MiniLM | EmbeddingModel::BgeSmall => None,
            EmbeddingModel::E5Small => Some("passage: "),
        }
    }

    /// Returns whether token embeddings should be mean-pooled into one sentence vector.
    ///
    /// NOTE(review): the BAAI/bge-small-en-v1.5 model card describes CLS-token pooling rather
    /// than mean pooling. Confirm that `true` is intended for `BgeSmall`.
    pub fn use_mean_pooling(&self) -> bool {
        match self {
            EmbeddingModel::MiniLM | EmbeddingModel::BgeSmall | EmbeddingModel::E5Small => true,
        }
    }

    /// Returns whether output embeddings should be normalized; `true` for every model.
    pub fn normalize_embeddings(&self) -> bool {
        true
    }

    /// Returns a hard-coded estimate of CPU throughput in tokens per second.
    ///
    /// NOTE(review): the origin of these figures is not recorded here. Treat them as rough
    /// heuristics.
    pub fn tokens_per_second_cpu(&self) -> usize {
        match self {
            EmbeddingModel::MiniLM => 5000,
            EmbeddingModel::BgeSmall | EmbeddingModel::E5Small => 3000,
        }
    }

    /// Returns every supported model, in declaration order.
    pub fn all() -> &'static [EmbeddingModel] {
        &[
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ]
    }

    /// Parses a model name, ignoring case and surrounding whitespace.
    ///
    /// Accepted spellings, per model:
    /// * short aliases: `"minilm"`, `"mini-lm"`, `"bge"`, `"bge-small-en"`, `"e5"`;
    /// * the serde names produced by `rename_all = "kebab-case"`: `"mini-l-m"`, `"bge-small"`,
    ///   `"e5-small"`;
    /// * the `Display` names (so `to_string()` round-trips): `"all-MiniLM-L6-v2"`,
    ///   `"bge-small-en-v1.5"`, `"e5-small-v2"`;
    /// * the full hub id returned by [`EmbeddingModel::model_id`], e.g. `"BAAI/bge-small-en-v1.5"`.
    ///
    /// Returns `None` for anything else, including the empty string.
    pub fn parse(s: &str) -> Option<Self> {
        let key = s.trim().to_lowercase();
        let by_alias = match key.as_str() {
            "minilm" | "all-minilm-l6-v2" | "mini-lm" | "mini-l-m" => Some(EmbeddingModel::MiniLM),
            "bge-small" | "bge" | "bge-small-en" | "bge-small-en-v1.5" => {
                Some(EmbeddingModel::BgeSmall)
            }
            "e5-small" | "e5" | "e5-small-v2" => Some(EmbeddingModel::E5Small),
            _ => None,
        };
        // Fall back to the full `organisation/name` hub id. Hub ids are ASCII, so an ASCII
        // case-insensitive comparison against the already-lowercased key is sufficient.
        by_alias.or_else(|| {
            Self::all()
                .iter()
                .copied()
                .find(|model| model.model_id().eq_ignore_ascii_case(&key))
        })
    }
}
impl std::fmt::Display for EmbeddingModel {
    /// Writes the short model name, which is the repository part of `model_id()` without the
    /// organisation prefix.
    ///
    /// Uses [`std::fmt::Formatter::pad`] rather than `write!` so that width, fill, alignment
    /// and precision flags (e.g. `{:<20}` in tabular output) behave as they do for `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EmbeddingModel::MiniLM => "all-MiniLM-L6-v2",
            EmbeddingModel::BgeSmall => "bge-small-en-v1.5",
            EmbeddingModel::E5Small => "e5-small-v2",
        };
        f.pad(name)
    }
}
/// Settings for loading and running an [`EmbeddingModel`].
///
/// Construct with `ModelConfig::default()` / `ModelConfig::new(model)` and adjust with the
/// `with_*` builder methods. All fields are public and derive serde `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Which embedding model to use; `EmbeddingModel::default()` in the default config.
    pub model: EmbeddingModel,
    /// Optional directory setting; `None` in the default config.
    /// NOTE(review): presumably where model files are cached/downloaded. Confirm with the loader.
    pub cache_dir: Option<String>,
    /// Maximum batch size; `32` in the default config.
    /// NOTE(review): presumably the number of inputs embedded per call. Confirm with the caller.
    pub max_batch_size: usize,
    /// Whether GPU execution is requested; `false` in the default config.
    pub use_gpu: bool,
    /// Explicit thread count; `None` in the default config.
    /// NOTE(review): `None` is assumed to mean "let the backend decide". Confirm.
    pub num_threads: Option<usize>,
}
impl Default for ModelConfig {
fn default() -> Self {
Self {
model: EmbeddingModel::default(),
cache_dir: None,
max_batch_size: 32,
use_gpu: false,
num_threads: None,
}
}
}
impl ModelConfig {
    /// Creates a configuration for `model`, taking every other field from
    /// `ModelConfig::default()`.
    #[must_use]
    pub fn new(model: EmbeddingModel) -> Self {
        Self {
            model,
            ..Default::default()
        }
    }

    /// Sets `cache_dir` to `dir`.
    #[must_use]
    pub fn with_cache_dir(mut self, dir: impl Into<String>) -> Self {
        self.cache_dir = Some(dir.into());
        self
    }

    /// Sets `max_batch_size`, raising `0` to `1`.
    ///
    /// A batch size of zero can never make progress, and common batching helpers reject it
    /// outright (for example, `slice::chunks(0)` panics). The smallest meaningful value is
    /// stored instead.
    #[must_use]
    pub fn with_max_batch_size(mut self, size: usize) -> Self {
        self.max_batch_size = size.max(1);
        self
    }

    /// Sets whether GPU execution is requested.
    #[must_use]
    pub fn with_gpu(mut self, use_gpu: bool) -> Self {
        self.use_gpu = use_gpu;
        self
    }

    /// Sets an explicit thread count (`num_threads = Some(threads)`).
    ///
    /// NOTE(review): `0` is stored as `Some(0)`. Whether the backend treats that as "auto" or
    /// as an error is not visible here; confirm.
    #[must_use]
    pub fn with_num_threads(mut self, threads: usize) -> Self {
        self.num_threads = Some(threads);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_ids() {
        let expected = [
            (EmbeddingModel::MiniLM, "sentence-transformers/all-MiniLM-L6-v2"),
            (EmbeddingModel::BgeSmall, "BAAI/bge-small-en-v1.5"),
            (EmbeddingModel::E5Small, "intfloat/e5-small-v2"),
        ];
        for (model, id) in expected {
            assert_eq!(model.model_id(), id);
        }
    }

    #[test]
    fn test_dimensions() {
        for model in EmbeddingModel::all() {
            assert_eq!(model.dimension(), 384);
        }
    }

    #[test]
    fn test_max_seq_lengths() {
        assert_eq!(EmbeddingModel::MiniLM.max_seq_length(), 256);
        assert_eq!(EmbeddingModel::BgeSmall.max_seq_length(), 512);
        assert_eq!(EmbeddingModel::E5Small.max_seq_length(), 512);
    }

    #[test]
    fn test_all_lists_every_variant_once() {
        let all = EmbeddingModel::all();
        assert_eq!(all.len(), 3);
        for model in [
            EmbeddingModel::MiniLM,
            EmbeddingModel::BgeSmall,
            EmbeddingModel::E5Small,
        ] {
            assert_eq!(all.iter().filter(|&&m| m == model).count(), 1);
        }
    }

    #[test]
    fn test_from_str() {
        assert_eq!(
            EmbeddingModel::parse("minilm"),
            Some(EmbeddingModel::MiniLM)
        );
        assert_eq!(
            EmbeddingModel::parse("BGE-SMALL"),
            Some(EmbeddingModel::BgeSmall)
        );
        assert_eq!(EmbeddingModel::parse("e5"), Some(EmbeddingModel::E5Small));
        assert_eq!(EmbeddingModel::parse("unknown"), None);
    }

    #[test]
    fn test_display_names() {
        assert_eq!(EmbeddingModel::MiniLM.to_string(), "all-MiniLM-L6-v2");
        assert_eq!(EmbeddingModel::BgeSmall.to_string(), "bge-small-en-v1.5");
        assert_eq!(EmbeddingModel::E5Small.to_string(), "e5-small-v2");
    }

    #[test]
    fn test_e5_prefixes() {
        assert_eq!(EmbeddingModel::E5Small.query_prefix(), Some("query: "));
        assert_eq!(EmbeddingModel::E5Small.document_prefix(), Some("passage: "));
        assert_eq!(EmbeddingModel::MiniLM.query_prefix(), None);
        assert_eq!(EmbeddingModel::MiniLM.document_prefix(), None);
        assert_eq!(EmbeddingModel::BgeSmall.query_prefix(), None);
        assert_eq!(EmbeddingModel::BgeSmall.document_prefix(), None);
    }

    #[test]
    fn test_model_config_defaults() {
        let config = ModelConfig::default();
        assert_eq!(config.model, EmbeddingModel::MiniLM);
        assert_eq!(config.cache_dir, None);
        assert_eq!(config.max_batch_size, 32);
        assert!(!config.use_gpu);
        assert_eq!(config.num_threads, None);
    }

    #[test]
    fn test_model_config_builder() {
        let config = ModelConfig::new(EmbeddingModel::E5Small)
            .with_cache_dir("models")
            .with_max_batch_size(16)
            .with_gpu(true)
            .with_num_threads(2);
        assert_eq!(config.model, EmbeddingModel::E5Small);
        assert_eq!(config.cache_dir.as_deref(), Some("models"));
        assert_eq!(config.max_batch_size, 16);
        assert!(config.use_gpu);
        assert_eq!(config.num_threads, Some(2));
    }
}