use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Runtime configuration for the embedding subsystem.
///
/// Derives serde `Serialize`/`Deserialize`, so every field below is part of the
/// serialized representation; renaming or reordering fields affects that format.
/// Sensible values are provided by the `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
/// Which embedding model to use; see [`EmbeddingModelType`].
pub model_type: EmbeddingModelType,
/// Upper bound on batch size — presumably the maximum number of inputs embedded
/// per batch (TODO confirm against the batching code).
pub max_batch_size: usize,
/// Whether batch size may be adjusted at runtime. NOTE(review): presumably adapts
/// within `max_batch_size` — confirm.
pub adaptive_batching: bool,
/// Size of the memory pool. NOTE(review): the unit (entries vs. bytes) is not
/// visible here — confirm before tuning.
pub memory_pool_size: usize,
/// Upper bound on operations allowed to run concurrently (assumed; verify against
/// the scheduler that reads it).
pub max_concurrent_ops: usize,
/// Toggles performance monitoring.
pub enable_performance_monitoring: bool,
/// Directory used for the on-disk model cache. A relative path is resolved against
/// the process working directory.
pub cache_dir: PathBuf,
/// Toggles caching (presumably the on-disk cache rooted at `cache_dir`).
pub enable_caching: bool,
/// Per-operation timeout, in seconds.
pub operation_timeout_secs: u64,
}
impl Default for EmbeddingConfig {
    /// Builds the out-of-the-box configuration.
    ///
    /// Only `max_concurrent_ops` depends on the host: it is twice the count reported
    /// by `num_cpus::get()`. All other fields are fixed constants, and the model is
    /// whatever `EmbeddingModelType`'s `#[default]` variant is.
    fn default() -> Self {
        // Oversubscribe the CPUs by 2x.
        let max_concurrent_ops = 2 * num_cpus::get();
        let cache_dir = PathBuf::from("./models_cache");

        Self {
            model_type: Default::default(),
            max_batch_size: 32,
            adaptive_batching: true,
            memory_pool_size: 1000,
            max_concurrent_ops,
            enable_performance_monitoring: true,
            cache_dir,
            enable_caching: true,
            operation_timeout_secs: 30,
        }
    }
}
/// Supported embedding models.
///
/// Each variant's output dimension and repository id are given by
/// `embedding_dimension()` and `model_id()`; `is_bert_based()` / `is_model2vec()`
/// classify the model family. Variant order fixes the discriminant values, so new
/// variants are best appended rather than inserted.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Default)]
pub enum EmbeddingModelType {
/// 1024-dimensional; neither BERT-based nor Model2Vec per the predicates.
/// NOTE(review): `model_id()` maps this to `sentence-transformers/all-MiniLM-L6-v2`,
/// the same id as `MiniLM` (384-dimensional) — confirm which model is intended.
StaticSimilarityMRL,
/// `sentence-transformers/all-MiniLM-L6-v2`, 384-dimensional, BERT-based.
MiniLM,
/// `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2`, 384-dimensional,
/// BERT-based.
MultilingualMiniLM,
/// `huawei-noah/TinyBERT_General_6L_312D`, 312-dimensional, BERT-based.
TinyBERT,
/// `BAAI/bge-small-en-v1.5`, 384-dimensional, BERT-based.
BGESmall,
/// `minishlab/potion-multilingual-128M`, 256-dimensional, Model2Vec.
/// This is the `Default` variant.
#[default]
PotionMultilingual,
/// `minishlab/potion-code-16M`, 512-dimensional, Model2Vec.
PotionCode,
}
impl EmbeddingModelType {
    /// Number of components in each vector this model produces.
    pub fn embedding_dimension(&self) -> usize {
        match self {
            Self::PotionMultilingual => 256,
            Self::TinyBERT => 312,
            Self::MiniLM => 384,
            Self::MultilingualMiniLM => 384,
            Self::BGESmall => 384,
            Self::PotionCode => 512,
            Self::StaticSimilarityMRL => 1024,
        }
    }

    /// Repository identifier of the model weights for this variant.
    pub fn model_id(&self) -> &'static str {
        match self {
            Self::MiniLM => "sentence-transformers/all-MiniLM-L6-v2",
            // NOTE(review): shares MiniLM's repository, yet `embedding_dimension`
            // reports 1024 here versus 384 for MiniLM — confirm the intended model.
            Self::StaticSimilarityMRL => "sentence-transformers/all-MiniLM-L6-v2",
            Self::MultilingualMiniLM => {
                "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
            }
            Self::TinyBERT => "huawei-noah/TinyBERT_General_6L_312D",
            Self::BGESmall => "BAAI/bge-small-en-v1.5",
            Self::PotionMultilingual => "minishlab/potion-multilingual-128M",
            Self::PotionCode => "minishlab/potion-code-16M",
        }
    }

    /// `true` for the BERT-family transformer models.
    ///
    /// Written as an exhaustive `match` so that adding a variant forces an
    /// explicit classification instead of silently falling into `false`.
    pub fn is_bert_based(&self) -> bool {
        match self {
            Self::MiniLM | Self::MultilingualMiniLM | Self::TinyBERT | Self::BGESmall => true,
            Self::StaticSimilarityMRL | Self::PotionMultilingual | Self::PotionCode => false,
        }
    }

    /// `true` for the Model2Vec (`potion-*`) static-embedding models.
    ///
    /// Exhaustive for the same reason as [`Self::is_bert_based`].
    pub fn is_model2vec(&self) -> bool {
        match self {
            Self::PotionMultilingual | Self::PotionCode => true,
            Self::StaticSimilarityMRL
            | Self::MiniLM
            | Self::MultilingualMiniLM
            | Self::TinyBERT
            | Self::BGESmall => false,
        }
    }
}