use crate::types::{AppError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;
use std::str::FromStr;
use std::sync::{Arc, Mutex, OnceLock};
use tokio::task::spawn_blocking;
pub use fastembed::{EmbeddingModel as FastEmbedModel, InitOptions, SparseModel, TextEmbedding};
/// Process-wide registry of per-model initialization locks, keyed by the
/// fastembed model's Debug name (see `EmbeddingService::new`).
static MODEL_INIT_LOCKS: OnceLock<Mutex<HashMap<String, Arc<Mutex<()>>>>> = OnceLock::new();
/// Returns the shared initialization lock for `model_name`, creating it on
/// first use.
///
/// Holding the returned lock serializes model initialization so two threads
/// never race to download/materialize the same model at once.
fn get_model_lock(model_name: &str) -> Arc<Mutex<()>> {
    let locks = MODEL_INIT_LOCKS.get_or_init(|| Mutex::new(HashMap::new()));
    // Recover from a poisoned registry lock instead of panicking: the guarded
    // data is only a map of Arc handles, which remains valid even if a holder
    // panicked mid-operation.
    let mut map = locks.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
    map.entry(model_name.to_string())
        .or_insert_with(|| Arc::new(Mutex::new(())))
        .clone()
}
/// Dense text-embedding models supported by this service.
///
/// Variant names mirror the fastembed model catalog; a trailing `Q` marks a
/// quantized build of the same model (see `is_quantized`). Serialized in
/// kebab-case; `Display`/`FromStr` provide canonical names and aliases.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum EmbeddingModelType {
    /// Default model: BGE small English v1.5 (384-dimensional).
    #[default]
    BgeSmallEnV15,
    BgeSmallEnV15Q,
    AllMiniLmL6V2,
    AllMiniLmL6V2Q,
    AllMiniLmL12V2,
    AllMiniLmL12V2Q,
    AllMpnetBaseV2,
    BgeBaseEnV15,
    BgeBaseEnV15Q,
    BgeLargeEnV15,
    BgeLargeEnV15Q,
    MultilingualE5Small,
    MultilingualE5Base,
    MultilingualE5Large,
    ParaphraseMiniLmL12V2,
    ParaphraseMiniLmL12V2Q,
    ParaphraseMultilingualMpnetBaseV2,
    BgeSmallZhV15,
    BgeLargeZhV15,
    NomicEmbedTextV1,
    NomicEmbedTextV15,
    NomicEmbedTextV15Q,
    MxbaiEmbedLargeV1,
    MxbaiEmbedLargeV1Q,
    GteBaseEnV15,
    GteBaseEnV15Q,
    GteLargeEnV15,
    GteLargeEnV15Q,
    ClipVitB32,
    JinaEmbeddingsV2BaseCode,
    EmbeddingGemma300M,
    ModernBertEmbedLarge,
    SnowflakeArcticEmbedXs,
    SnowflakeArcticEmbedXsQ,
    SnowflakeArcticEmbedS,
    SnowflakeArcticEmbedSQ,
    SnowflakeArcticEmbedM,
    SnowflakeArcticEmbedMQ,
    SnowflakeArcticEmbedMLong,
    SnowflakeArcticEmbedMLongQ,
    SnowflakeArcticEmbedL,
    SnowflakeArcticEmbedLQ,
}
impl EmbeddingModelType {
    /// Maps this variant to the corresponding fastembed model identifier.
    pub fn to_fastembed_model(&self) -> FastEmbedModel {
        match self {
            Self::BgeSmallEnV15 => FastEmbedModel::BGESmallENV15,
            Self::BgeSmallEnV15Q => FastEmbedModel::BGESmallENV15Q,
            Self::AllMiniLmL6V2 => FastEmbedModel::AllMiniLML6V2,
            Self::AllMiniLmL6V2Q => FastEmbedModel::AllMiniLML6V2Q,
            Self::AllMiniLmL12V2 => FastEmbedModel::AllMiniLML12V2,
            Self::AllMiniLmL12V2Q => FastEmbedModel::AllMiniLML12V2Q,
            Self::AllMpnetBaseV2 => FastEmbedModel::AllMpnetBaseV2,
            Self::BgeBaseEnV15 => FastEmbedModel::BGEBaseENV15,
            Self::BgeBaseEnV15Q => FastEmbedModel::BGEBaseENV15Q,
            Self::BgeLargeEnV15 => FastEmbedModel::BGELargeENV15,
            Self::BgeLargeEnV15Q => FastEmbedModel::BGELargeENV15Q,
            Self::MultilingualE5Small => FastEmbedModel::MultilingualE5Small,
            Self::MultilingualE5Base => FastEmbedModel::MultilingualE5Base,
            Self::MultilingualE5Large => FastEmbedModel::MultilingualE5Large,
            Self::ParaphraseMiniLmL12V2 => FastEmbedModel::ParaphraseMLMiniLML12V2,
            Self::ParaphraseMiniLmL12V2Q => FastEmbedModel::ParaphraseMLMiniLML12V2Q,
            Self::ParaphraseMultilingualMpnetBaseV2 => FastEmbedModel::ParaphraseMLMpnetBaseV2,
            Self::BgeSmallZhV15 => FastEmbedModel::BGESmallZHV15,
            Self::BgeLargeZhV15 => FastEmbedModel::BGELargeZHV15,
            Self::NomicEmbedTextV1 => FastEmbedModel::NomicEmbedTextV1,
            Self::NomicEmbedTextV15 => FastEmbedModel::NomicEmbedTextV15,
            Self::NomicEmbedTextV15Q => FastEmbedModel::NomicEmbedTextV15Q,
            Self::MxbaiEmbedLargeV1 => FastEmbedModel::MxbaiEmbedLargeV1,
            Self::MxbaiEmbedLargeV1Q => FastEmbedModel::MxbaiEmbedLargeV1Q,
            Self::GteBaseEnV15 => FastEmbedModel::GTEBaseENV15,
            Self::GteBaseEnV15Q => FastEmbedModel::GTEBaseENV15Q,
            Self::GteLargeEnV15 => FastEmbedModel::GTELargeENV15,
            Self::GteLargeEnV15Q => FastEmbedModel::GTELargeENV15Q,
            Self::ClipVitB32 => FastEmbedModel::ClipVitB32,
            Self::JinaEmbeddingsV2BaseCode => FastEmbedModel::JinaEmbeddingsV2BaseCode,
            Self::EmbeddingGemma300M => FastEmbedModel::EmbeddingGemma300M,
            Self::ModernBertEmbedLarge => FastEmbedModel::ModernBertEmbedLarge,
            Self::SnowflakeArcticEmbedXs => FastEmbedModel::SnowflakeArcticEmbedXS,
            Self::SnowflakeArcticEmbedXsQ => FastEmbedModel::SnowflakeArcticEmbedXSQ,
            Self::SnowflakeArcticEmbedS => FastEmbedModel::SnowflakeArcticEmbedS,
            Self::SnowflakeArcticEmbedSQ => FastEmbedModel::SnowflakeArcticEmbedSQ,
            Self::SnowflakeArcticEmbedM => FastEmbedModel::SnowflakeArcticEmbedM,
            Self::SnowflakeArcticEmbedMQ => FastEmbedModel::SnowflakeArcticEmbedMQ,
            Self::SnowflakeArcticEmbedMLong => FastEmbedModel::SnowflakeArcticEmbedMLong,
            Self::SnowflakeArcticEmbedMLongQ => FastEmbedModel::SnowflakeArcticEmbedMLongQ,
            Self::SnowflakeArcticEmbedL => FastEmbedModel::SnowflakeArcticEmbedL,
            Self::SnowflakeArcticEmbedLQ => FastEmbedModel::SnowflakeArcticEmbedLQ,
        }
    }
    /// Dimensionality of the embedding vectors this model produces.
    /// Grouped by size: 384, 512, 768, or 1024 dimensions.
    pub fn dimensions(&self) -> usize {
        match self {
            Self::BgeSmallEnV15
            | Self::BgeSmallEnV15Q
            | Self::AllMiniLmL6V2
            | Self::AllMiniLmL6V2Q
            | Self::AllMiniLmL12V2
            | Self::AllMiniLmL12V2Q
            | Self::MultilingualE5Small
            | Self::SnowflakeArcticEmbedXs
            | Self::SnowflakeArcticEmbedXsQ
            | Self::SnowflakeArcticEmbedS
            | Self::SnowflakeArcticEmbedSQ => 384,
            Self::BgeSmallZhV15 | Self::ClipVitB32 => 512,
            Self::AllMpnetBaseV2
            | Self::BgeBaseEnV15
            | Self::BgeBaseEnV15Q
            | Self::MultilingualE5Base
            | Self::ParaphraseMiniLmL12V2
            | Self::ParaphraseMiniLmL12V2Q
            | Self::ParaphraseMultilingualMpnetBaseV2
            | Self::NomicEmbedTextV1
            | Self::NomicEmbedTextV15
            | Self::NomicEmbedTextV15Q
            | Self::GteBaseEnV15
            | Self::GteBaseEnV15Q
            | Self::JinaEmbeddingsV2BaseCode
            | Self::EmbeddingGemma300M
            | Self::SnowflakeArcticEmbedM
            | Self::SnowflakeArcticEmbedMQ
            | Self::SnowflakeArcticEmbedMLong
            | Self::SnowflakeArcticEmbedMLongQ => 768,
            Self::BgeLargeEnV15
            | Self::BgeLargeEnV15Q
            | Self::BgeLargeZhV15
            | Self::MultilingualE5Large
            | Self::MxbaiEmbedLargeV1
            | Self::MxbaiEmbedLargeV1Q
            | Self::GteLargeEnV15
            | Self::GteLargeEnV15Q
            | Self::ModernBertEmbedLarge
            | Self::SnowflakeArcticEmbedL
            | Self::SnowflakeArcticEmbedLQ => 1024,
        }
    }
    /// True for the quantized (`…Q`) model builds.
    pub fn is_quantized(&self) -> bool {
        matches!(
            self,
            Self::BgeSmallEnV15Q
                | Self::AllMiniLmL6V2Q
                | Self::AllMiniLmL12V2Q
                | Self::BgeBaseEnV15Q
                | Self::BgeLargeEnV15Q
                | Self::ParaphraseMiniLmL12V2Q
                | Self::NomicEmbedTextV15Q
                | Self::MxbaiEmbedLargeV1Q
                | Self::GteBaseEnV15Q
                | Self::GteLargeEnV15Q
                | Self::SnowflakeArcticEmbedXsQ
                | Self::SnowflakeArcticEmbedSQ
                | Self::SnowflakeArcticEmbedMQ
                | Self::SnowflakeArcticEmbedMLongQ
                | Self::SnowflakeArcticEmbedLQ
        )
    }
    /// Whether this model targets multiple languages (the multilingual E5
    /// family, multilingual paraphrase MPNet, and the Chinese BGE models).
    pub fn is_multilingual(&self) -> bool {
        matches!(
            self,
            Self::MultilingualE5Small
                | Self::MultilingualE5Base
                | Self::MultilingualE5Large
                | Self::ParaphraseMultilingualMpnetBaseV2
                | Self::BgeSmallZhV15
                | Self::BgeLargeZhV15
        )
    }
    /// Maximum input context length the model accepts (presumably tokens, per
    /// the upstream model specs — values here are hard-coded, not queried from
    /// fastembed). Nomic models: 8192; Snowflake "long" variants: 2048;
    /// everything else: 512.
    pub fn max_context_length(&self) -> usize {
        match self {
            Self::NomicEmbedTextV1 | Self::NomicEmbedTextV15 | Self::NomicEmbedTextV15Q => 8192,
            Self::SnowflakeArcticEmbedMLong | Self::SnowflakeArcticEmbedMLongQ => 2048,
            _ => 512,
        }
    }
    /// Every supported model, in declaration order.
    ///
    /// NOTE(review): maintained by hand — keep in sync with the enum when
    /// adding variants (`test_all_models_listed` only checks a lower bound).
    pub fn all() -> Vec<Self> {
        vec![
            Self::BgeSmallEnV15,
            Self::BgeSmallEnV15Q,
            Self::AllMiniLmL6V2,
            Self::AllMiniLmL6V2Q,
            Self::AllMiniLmL12V2,
            Self::AllMiniLmL12V2Q,
            Self::AllMpnetBaseV2,
            Self::BgeBaseEnV15,
            Self::BgeBaseEnV15Q,
            Self::BgeLargeEnV15,
            Self::BgeLargeEnV15Q,
            Self::MultilingualE5Small,
            Self::MultilingualE5Base,
            Self::MultilingualE5Large,
            Self::ParaphraseMiniLmL12V2,
            Self::ParaphraseMiniLmL12V2Q,
            Self::ParaphraseMultilingualMpnetBaseV2,
            Self::BgeSmallZhV15,
            Self::BgeLargeZhV15,
            Self::NomicEmbedTextV1,
            Self::NomicEmbedTextV15,
            Self::NomicEmbedTextV15Q,
            Self::MxbaiEmbedLargeV1,
            Self::MxbaiEmbedLargeV1Q,
            Self::GteBaseEnV15,
            Self::GteBaseEnV15Q,
            Self::GteLargeEnV15,
            Self::GteLargeEnV15Q,
            Self::ClipVitB32,
            Self::JinaEmbeddingsV2BaseCode,
            Self::EmbeddingGemma300M,
            Self::ModernBertEmbedLarge,
            Self::SnowflakeArcticEmbedXs,
            Self::SnowflakeArcticEmbedXsQ,
            Self::SnowflakeArcticEmbedS,
            Self::SnowflakeArcticEmbedSQ,
            Self::SnowflakeArcticEmbedM,
            Self::SnowflakeArcticEmbedMQ,
            Self::SnowflakeArcticEmbedMLong,
            Self::SnowflakeArcticEmbedMLongQ,
            Self::SnowflakeArcticEmbedL,
            Self::SnowflakeArcticEmbedLQ,
        ]
    }
}
impl Display for EmbeddingModelType {
    /// Writes the canonical kebab-case model name. Every name here is also
    /// accepted by `FromStr`, so `Display` output round-trips through parsing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::BgeSmallEnV15 => "bge-small-en-v1.5",
            Self::BgeSmallEnV15Q => "bge-small-en-v1.5-q",
            Self::AllMiniLmL6V2 => "all-minilm-l6-v2",
            Self::AllMiniLmL6V2Q => "all-minilm-l6-v2-q",
            Self::AllMiniLmL12V2 => "all-minilm-l12-v2",
            Self::AllMiniLmL12V2Q => "all-minilm-l12-v2-q",
            Self::AllMpnetBaseV2 => "all-mpnet-base-v2",
            Self::BgeBaseEnV15 => "bge-base-en-v1.5",
            Self::BgeBaseEnV15Q => "bge-base-en-v1.5-q",
            Self::BgeLargeEnV15 => "bge-large-en-v1.5",
            Self::BgeLargeEnV15Q => "bge-large-en-v1.5-q",
            Self::MultilingualE5Small => "multilingual-e5-small",
            Self::MultilingualE5Base => "multilingual-e5-base",
            Self::MultilingualE5Large => "multilingual-e5-large",
            Self::ParaphraseMiniLmL12V2 => "paraphrase-minilm-l12-v2",
            Self::ParaphraseMiniLmL12V2Q => "paraphrase-minilm-l12-v2-q",
            Self::ParaphraseMultilingualMpnetBaseV2 => "paraphrase-multilingual-mpnet-base-v2",
            Self::BgeSmallZhV15 => "bge-small-zh-v1.5",
            Self::BgeLargeZhV15 => "bge-large-zh-v1.5",
            Self::NomicEmbedTextV1 => "nomic-embed-text-v1",
            Self::NomicEmbedTextV15 => "nomic-embed-text-v1.5",
            Self::NomicEmbedTextV15Q => "nomic-embed-text-v1.5-q",
            Self::MxbaiEmbedLargeV1 => "mxbai-embed-large-v1",
            Self::MxbaiEmbedLargeV1Q => "mxbai-embed-large-v1-q",
            Self::GteBaseEnV15 => "gte-base-en-v1.5",
            Self::GteBaseEnV15Q => "gte-base-en-v1.5-q",
            Self::GteLargeEnV15 => "gte-large-en-v1.5",
            Self::GteLargeEnV15Q => "gte-large-en-v1.5-q",
            Self::ClipVitB32 => "clip-vit-b-32",
            Self::JinaEmbeddingsV2BaseCode => "jina-embeddings-v2-base-code",
            Self::EmbeddingGemma300M => "embedding-gemma-300m",
            Self::ModernBertEmbedLarge => "modernbert-embed-large",
            Self::SnowflakeArcticEmbedXs => "snowflake-arctic-embed-xs",
            Self::SnowflakeArcticEmbedXsQ => "snowflake-arctic-embed-xs-q",
            Self::SnowflakeArcticEmbedS => "snowflake-arctic-embed-s",
            Self::SnowflakeArcticEmbedSQ => "snowflake-arctic-embed-s-q",
            Self::SnowflakeArcticEmbedM => "snowflake-arctic-embed-m",
            Self::SnowflakeArcticEmbedMQ => "snowflake-arctic-embed-m-q",
            Self::SnowflakeArcticEmbedMLong => "snowflake-arctic-embed-m-long",
            Self::SnowflakeArcticEmbedMLongQ => "snowflake-arctic-embed-m-long-q",
            Self::SnowflakeArcticEmbedL => "snowflake-arctic-embed-l",
            Self::SnowflakeArcticEmbedLQ => "snowflake-arctic-embed-l-q",
        };
        write!(f, "{}", name)
    }
}
impl FromStr for EmbeddingModelType {
    type Err = AppError;
    /// Parses a model name case-insensitively. Accepts each canonical
    /// kebab-case name (matching `Display`) plus short aliases such as
    /// "bge-small" or "nomic".
    ///
    /// # Errors
    /// `AppError::Internal` for unknown names; the message lists every
    /// canonical model name.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "bge-small-en-v1.5" | "bge-small-en" | "bge-small" => Ok(Self::BgeSmallEnV15),
            "bge-small-en-v1.5-q" => Ok(Self::BgeSmallEnV15Q),
            "all-minilm-l6-v2" | "minilm-l6" => Ok(Self::AllMiniLmL6V2),
            "all-minilm-l6-v2-q" => Ok(Self::AllMiniLmL6V2Q),
            "all-minilm-l12-v2" | "minilm-l12" => Ok(Self::AllMiniLmL12V2),
            "all-minilm-l12-v2-q" => Ok(Self::AllMiniLmL12V2Q),
            "all-mpnet-base-v2" | "mpnet" => Ok(Self::AllMpnetBaseV2),
            "bge-base-en-v1.5" | "bge-base-en" | "bge-base" => Ok(Self::BgeBaseEnV15),
            "bge-base-en-v1.5-q" => Ok(Self::BgeBaseEnV15Q),
            "bge-large-en-v1.5" | "bge-large-en" | "bge-large" => Ok(Self::BgeLargeEnV15),
            "bge-large-en-v1.5-q" => Ok(Self::BgeLargeEnV15Q),
            "multilingual-e5-small" | "e5-small" => Ok(Self::MultilingualE5Small),
            "multilingual-e5-base" | "e5-base" => Ok(Self::MultilingualE5Base),
            "multilingual-e5-large" | "e5-large" => Ok(Self::MultilingualE5Large),
            "paraphrase-minilm-l12-v2" => Ok(Self::ParaphraseMiniLmL12V2),
            "paraphrase-minilm-l12-v2-q" => Ok(Self::ParaphraseMiniLmL12V2Q),
            "paraphrase-multilingual-mpnet-base-v2" => Ok(Self::ParaphraseMultilingualMpnetBaseV2),
            "bge-small-zh-v1.5" | "bge-small-zh" => Ok(Self::BgeSmallZhV15),
            "bge-large-zh-v1.5" | "bge-large-zh" => Ok(Self::BgeLargeZhV15),
            "nomic-embed-text-v1" | "nomic-v1" => Ok(Self::NomicEmbedTextV1),
            "nomic-embed-text-v1.5" | "nomic-v1.5" | "nomic" => Ok(Self::NomicEmbedTextV15),
            "nomic-embed-text-v1.5-q" => Ok(Self::NomicEmbedTextV15Q),
            "mxbai-embed-large-v1" | "mxbai" => Ok(Self::MxbaiEmbedLargeV1),
            "mxbai-embed-large-v1-q" => Ok(Self::MxbaiEmbedLargeV1Q),
            "gte-base-en-v1.5" | "gte-base" => Ok(Self::GteBaseEnV15),
            "gte-base-en-v1.5-q" => Ok(Self::GteBaseEnV15Q),
            "gte-large-en-v1.5" | "gte-large" => Ok(Self::GteLargeEnV15),
            "gte-large-en-v1.5-q" => Ok(Self::GteLargeEnV15Q),
            "clip-vit-b-32" | "clip" => Ok(Self::ClipVitB32),
            "jina-embeddings-v2-base-code" | "jina-code" => Ok(Self::JinaEmbeddingsV2BaseCode),
            "embedding-gemma-300m" | "gemma-300m" | "gemma" => Ok(Self::EmbeddingGemma300M),
            "modernbert-embed-large" | "modernbert" => Ok(Self::ModernBertEmbedLarge),
            "snowflake-arctic-embed-xs" => Ok(Self::SnowflakeArcticEmbedXs),
            "snowflake-arctic-embed-xs-q" => Ok(Self::SnowflakeArcticEmbedXsQ),
            "snowflake-arctic-embed-s" => Ok(Self::SnowflakeArcticEmbedS),
            "snowflake-arctic-embed-s-q" => Ok(Self::SnowflakeArcticEmbedSQ),
            "snowflake-arctic-embed-m" => Ok(Self::SnowflakeArcticEmbedM),
            "snowflake-arctic-embed-m-q" => Ok(Self::SnowflakeArcticEmbedMQ),
            "snowflake-arctic-embed-m-long" => Ok(Self::SnowflakeArcticEmbedMLong),
            "snowflake-arctic-embed-m-long-q" => Ok(Self::SnowflakeArcticEmbedMLongQ),
            "snowflake-arctic-embed-l" | "snowflake-l" => Ok(Self::SnowflakeArcticEmbedL),
            "snowflake-arctic-embed-l-q" => Ok(Self::SnowflakeArcticEmbedLQ),
            _ => Err(AppError::Internal(format!(
                "Unknown embedding model: {}. Use one of: {}",
                s,
                EmbeddingModelType::all()
                    .iter()
                    .map(|m| m.to_string())
                    .collect::<Vec<_>>()
                    .join(", ")
            ))),
        }
    }
}
/// Sparse (lexical) embedding models supported by this service.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum SparseModelType {
    /// SPLADE++ v1 — the only sparse model currently wired up.
    #[default]
    SpladePpV1,
}
impl SparseModelType {
    /// Maps this variant to the corresponding fastembed sparse model.
    pub fn to_fastembed_model(&self) -> SparseModel {
        match self {
            Self::SpladePpV1 => SparseModel::SPLADEPPV1,
        }
    }
}
impl Display for SparseModelType {
    /// Writes the canonical kebab-case name; output round-trips through `FromStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SpladePpV1 => f.write_str("splade-pp-v1"),
        }
    }
}
impl FromStr for SparseModelType {
    type Err = AppError;
    /// Parses a sparse model name case-insensitively. Accepts the canonical
    /// "splade-pp-v1" or the shorthand "splade"; anything else is an error.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        if normalized == "splade-pp-v1" || normalized == "splade" {
            Ok(Self::SpladePpV1)
        } else {
            Err(AppError::Internal(format!(
                "Unknown sparse model: {}. Use: splade-pp-v1",
                s
            )))
        }
    }
}
/// Configuration for `EmbeddingService`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Dense model to load (defaults to BGE small English v1.5).
    #[serde(default)]
    pub model: EmbeddingModelType,
    /// Number of texts passed to fastembed per batch (default 32).
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
    /// Show a progress indicator while model files download (default true).
    #[serde(default = "default_show_progress")]
    pub show_download_progress: bool,
    /// Also initialize a sparse model so `embed_sparse` works (default false).
    #[serde(default)]
    pub sparse_enabled: bool,
    /// Sparse model to load when `sparse_enabled` is set.
    #[serde(default)]
    pub sparse_model: SparseModelType,
}
/// Serde default for `EmbeddingConfig::batch_size`.
fn default_batch_size() -> usize {
    32
}
/// Serde default for `EmbeddingConfig::show_download_progress`.
fn default_show_progress() -> bool {
    true
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
model: EmbeddingModelType::default(),
batch_size: default_batch_size(),
show_download_progress: default_show_progress(),
sparse_enabled: false,
sparse_model: SparseModelType::default(),
}
}
}
/// Thread-safe wrapper around a fastembed dense model (and, optionally, a
/// sparse model). Models sit behind mutexes because fastembed's `embed` needs
/// mutable access; embedding runs on tokio's blocking pool (see `embed_texts`).
pub struct EmbeddingService {
    // Dense model; locked per embed call and shared with spawned blocking tasks.
    model: Arc<Mutex<TextEmbedding>>,
    // Present only when `config.sparse_enabled` was set at construction time.
    sparse_model: Option<Arc<Mutex<fastembed::SparseTextEmbedding>>>,
    // Configuration captured at construction; exposed read-only via `config()`.
    config: EmbeddingConfig,
}
impl EmbeddingService {
    /// Creates the service, initializing the dense model (and the sparse model
    /// when `config.sparse_enabled`) under a per-model lock so concurrent
    /// constructions never race to download/initialize the same model.
    ///
    /// # Errors
    /// `AppError::Internal` if a lock is poisoned or fastembed fails to
    /// initialize a model.
    pub fn new(config: EmbeddingConfig) -> Result<Self> {
        // The lock key is the Debug name of the fastembed model, which
        // uniquely identifies the artifacts being initialized.
        let model_name = format!("{:?}", config.model.to_fastembed_model());
        let model_lock = get_model_lock(&model_name);
        let _guard = model_lock.lock().map_err(|e| {
            AppError::Internal(format!(
                "Failed to acquire model initialization lock: {}",
                e
            ))
        })?;
        let model = TextEmbedding::try_new(
            InitOptions::new(config.model.to_fastembed_model())
                .with_show_download_progress(config.show_download_progress),
        )
        .map_err(|e| AppError::Internal(format!("Failed to initialize embedding model: {}", e)))?;
        let sparse_model = if config.sparse_enabled {
            let sparse_model_name = format!("{:?}", config.sparse_model.to_fastembed_model());
            let sparse_lock = get_model_lock(&sparse_model_name);
            let _sparse_guard = sparse_lock.lock().map_err(|e| {
                AppError::Internal(format!("Failed to acquire sparse model lock: {}", e))
            })?;
            Some(
                fastembed::SparseTextEmbedding::try_new(
                    fastembed::SparseInitOptions::new(config.sparse_model.to_fastembed_model())
                        .with_show_download_progress(config.show_download_progress),
                )
                .map_err(|e| {
                    AppError::Internal(format!(
                        "Failed to initialize sparse embedding model: {}",
                        e
                    ))
                })?,
            )
        } else {
            None
        };
        Ok(Self {
            model: Arc::new(Mutex::new(model)),
            sparse_model: sparse_model.map(|m| Arc::new(Mutex::new(m))),
            config,
        })
    }
    /// Creates the service with `EmbeddingConfig::default()`.
    pub fn with_default_model() -> Result<Self> {
        Self::new(EmbeddingConfig::default())
    }
    /// Creates the service for a specific dense model, other settings default.
    pub fn with_model(model: EmbeddingModelType) -> Result<Self> {
        Self::new(EmbeddingConfig {
            model,
            ..Default::default()
        })
    }
    /// The dense model this service was constructed with.
    pub fn model_type(&self) -> EmbeddingModelType {
        self.config.model
    }
    /// Dimensionality of the dense embeddings this service produces.
    pub fn dimensions(&self) -> usize {
        self.config.model.dimensions()
    }
    /// The configuration captured at construction.
    pub fn config(&self) -> &EmbeddingConfig {
        &self.config
    }
    /// Embeds a single text, returning its dense vector.
    ///
    /// # Errors
    /// Propagates `embed_texts` errors, plus `AppError::Internal` if the
    /// backend returns no embedding.
    pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_texts(&[text.to_string()]).await?;
        embeddings
            .into_iter()
            .next()
            .ok_or_else(|| AppError::Internal("No embedding generated".to_string()))
    }
    /// Embeds a batch of texts, returning one dense vector per input, in
    /// input order. Empty input returns `Ok(vec![])` without touching the
    /// model. Work runs via `spawn_blocking` so the async runtime is not
    /// blocked by model inference.
    pub async fn embed_texts<S: AsRef<str> + Send + Sync + 'static>(
        &self,
        texts: &[S],
    ) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }
        // Copy the inputs so the blocking task owns 'static data.
        let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
        let batch_size = self.config.batch_size;
        let model = Arc::clone(&self.model);
        spawn_blocking(move || {
            let mut model_guard = model
                .lock()
                .map_err(|e| AppError::Internal(format!("Failed to acquire model lock: {}", e)))?;
            let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
            model_guard
                .embed(refs, Some(batch_size))
                .map_err(|e| AppError::Internal(format!("Embedding failed: {}", e)))
        })
        .await
        .map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
    }
    /// Embeds a batch of texts with the sparse model, in input order.
    ///
    /// # Errors
    /// `AppError::Internal` if sparse embeddings were not enabled in the
    /// config, the lock is poisoned, or the backend fails.
    pub async fn embed_sparse<S: AsRef<str> + Send + Sync + 'static>(
        &self,
        texts: &[S],
    ) -> Result<Vec<fastembed::SparseEmbedding>> {
        let sparse_model = self.sparse_model.as_ref().ok_or_else(|| {
            AppError::Internal(
                "Sparse embeddings not enabled. Set sparse_enabled: true in config.".to_string(),
            )
        })?;
        // Mirror embed_texts: nothing to do for an empty batch, so skip the
        // blocking task and model lock entirely. (Checked after the enabled
        // check so a misconfigured service still errors.)
        if texts.is_empty() {
            return Ok(vec![]);
        }
        let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
        let batch_size = self.config.batch_size;
        let model = Arc::clone(sparse_model);
        spawn_blocking(move || {
            let mut model_guard = model.lock().map_err(|e| {
                AppError::Internal(format!("Failed to acquire sparse model lock: {}", e))
            })?;
            let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
            model_guard
                .embed(refs, Some(batch_size))
                .map_err(|e| AppError::Internal(format!("Sparse embedding failed: {}", e)))
        })
        .await
        .map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
    }
}
use crate::rag::cache::{CacheConfig, CacheStats, EmbeddingCache, LruEmbeddingCache, NoOpCache};
/// `EmbeddingService` wrapper that memoizes dense embeddings, keyed by
/// (text, model name). Only the dense `embed_text`/`embed_texts` paths are
/// exposed and cached here.
pub struct CachedEmbeddingService {
    inner: EmbeddingService,
    // LRU cache when enabled in the cache config, otherwise a no-op stub.
    cache: Box<dyn EmbeddingCache>,
}
impl CachedEmbeddingService {
pub fn new(embedding_config: EmbeddingConfig, cache_config: CacheConfig) -> Result<Self> {
let inner = EmbeddingService::new(embedding_config)?;
let cache: Box<dyn EmbeddingCache> = if cache_config.enabled {
Box::new(LruEmbeddingCache::new(cache_config))
} else {
Box::new(NoOpCache::new())
};
Ok(Self { inner, cache })
}
pub fn with_defaults() -> Result<Self> {
Self::new(EmbeddingConfig::default(), CacheConfig::default())
}
pub fn with_model(model: EmbeddingModelType) -> Result<Self> {
Self::new(
EmbeddingConfig {
model,
..Default::default()
},
CacheConfig::default(),
)
}
pub fn without_cache(embedding_config: EmbeddingConfig) -> Result<Self> {
Self::new(
embedding_config,
CacheConfig {
enabled: false,
..Default::default()
},
)
}
fn model_name(&self) -> String {
self.inner.model_type().to_string()
}
pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
let cache_key = self.cache.compute_key(text, &self.model_name());
if let Some(cached) = self.cache.get(&cache_key) {
return Ok(cached);
}
let embedding = self.inner.embed_text(text).await?;
self.cache.set(&cache_key, embedding.clone(), None)?;
Ok(embedding)
}
pub async fn embed_texts<S: AsRef<str> + Send + Sync + 'static>(
&self,
texts: &[S],
) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(vec![]);
}
let model_name = self.model_name();
let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
let mut uncached_indices: Vec<usize> = Vec::new();
let mut uncached_texts: Vec<String> = Vec::new();
for (i, text) in texts.iter().enumerate() {
let text_str = text.as_ref();
let cache_key = self.cache.compute_key(text_str, &model_name);
if let Some(cached) = self.cache.get(&cache_key) {
results[i] = Some(cached);
} else {
uncached_indices.push(i);
uncached_texts.push(text_str.to_string());
}
}
if !uncached_texts.is_empty() {
let new_embeddings = self.inner.embed_texts(&uncached_texts).await?;
for (j, embedding) in new_embeddings.into_iter().enumerate() {
let idx = uncached_indices[j];
let cache_key = self.cache.compute_key(&uncached_texts[j], &model_name);
self.cache.set(&cache_key, embedding.clone(), None)?;
results[idx] = Some(embedding);
}
}
Ok(results.into_iter().flatten().collect())
}
pub fn model_type(&self) -> EmbeddingModelType {
self.inner.model_type()
}
pub fn dimensions(&self) -> usize {
self.inner.dimensions()
}
pub fn config(&self) -> &EmbeddingConfig {
self.inner.config()
}
pub fn cache_stats(&self) -> CacheStats {
self.cache.stats()
}
pub fn clear_cache(&self) -> Result<()> {
self.cache.clear()
}
pub fn invalidate(&self, text: &str) -> Result<()> {
let cache_key = self.cache.compute_key(text, &self.model_name());
self.cache.invalidate(&cache_key)
}
pub fn is_cache_enabled(&self) -> bool {
self.cache.is_enabled()
}
}
/// Hardware acceleration backend selection. Marked `dead_code`: nothing in
/// this module references it yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[allow(dead_code)]
pub enum AccelerationBackend {
    /// Pure-CPU inference (the default).
    #[default]
    Cpu,
    /// NVIDIA CUDA on the given device index.
    Cuda { device_id: usize },
    /// Apple Metal.
    Metal,
    /// Vulkan compute.
    Vulkan,
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Spot-checks the per-model dimension table.
    #[test]
    fn test_model_dimensions() {
        assert_eq!(EmbeddingModelType::BgeSmallEnV15.dimensions(), 384);
        assert_eq!(EmbeddingModelType::BgeBaseEnV15.dimensions(), 768);
        assert_eq!(EmbeddingModelType::BgeLargeEnV15.dimensions(), 1024);
        assert_eq!(EmbeddingModelType::MultilingualE5Large.dimensions(), 1024);
    }
    /// Parsing accepts canonical names and short aliases.
    #[test]
    fn test_model_from_str() {
        assert_eq!(
            "bge-small-en-v1.5".parse::<EmbeddingModelType>().unwrap(),
            EmbeddingModelType::BgeSmallEnV15
        );
        assert_eq!(
            "multilingual-e5-large"
                .parse::<EmbeddingModelType>()
                .unwrap(),
            EmbeddingModelType::MultilingualE5Large
        );
        assert_eq!(
            "minilm-l6".parse::<EmbeddingModelType>().unwrap(),
            EmbeddingModelType::AllMiniLmL6V2
        );
    }
    /// Every canonical Display name must parse back to the same variant.
    #[test]
    fn test_display_from_str_roundtrip() {
        for model in EmbeddingModelType::all() {
            let parsed: EmbeddingModelType = model.to_string().parse().unwrap();
            assert_eq!(parsed, model);
        }
    }
    #[test]
    fn test_model_is_multilingual() {
        assert!(EmbeddingModelType::MultilingualE5Small.is_multilingual());
        assert!(EmbeddingModelType::MultilingualE5Large.is_multilingual());
        assert!(!EmbeddingModelType::BgeSmallEnV15.is_multilingual());
    }
    #[test]
    fn test_model_max_context() {
        assert_eq!(
            EmbeddingModelType::NomicEmbedTextV15.max_context_length(),
            8192
        );
        assert_eq!(
            EmbeddingModelType::NomicEmbedTextV1.max_context_length(),
            8192
        );
        assert_eq!(EmbeddingModelType::BgeSmallEnV15.max_context_length(), 512);
    }
    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.model, EmbeddingModelType::BgeSmallEnV15);
        assert_eq!(config.batch_size, 32);
        assert!(config.show_download_progress);
        assert!(!config.sparse_enabled);
    }
    #[test]
    fn test_all_models_listed() {
        let all = EmbeddingModelType::all();
        assert!(all.len() >= 38);
        assert!(all.contains(&EmbeddingModelType::BgeSmallEnV15));
        assert!(all.contains(&EmbeddingModelType::MultilingualE5Large));
        // `all()` is maintained by hand; catch accidental duplicate entries.
        let unique: std::collections::HashSet<_> = all.iter().collect();
        assert_eq!(unique.len(), all.len());
    }
}