use crate::types::{AppError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;
use std::str::FromStr;
use std::sync::{Arc, Mutex, OnceLock};
use tokio::task::spawn_blocking;
pub use fastembed::{EmbeddingModel as FastEmbedModel, InitOptions, SparseModel, TextEmbedding};
/// Process-wide registry of per-model initialization locks, created lazily.
static MODEL_INIT_LOCKS: OnceLock<Mutex<HashMap<String, Arc<Mutex<()>>>>> = OnceLock::new();

/// Returns the initialization lock for `model_name`, creating it on first use.
///
/// The same `Arc` is handed out for repeated calls with the same name, so
/// callers can serialize model setup by locking it.
fn get_model_lock(model_name: &str) -> Arc<Mutex<()>> {
    let registry = MODEL_INIT_LOCKS.get_or_init(Default::default);
    let mut table = registry.lock().unwrap();
    let slot = table.entry(model_name.to_owned()).or_default();
    Arc::clone(slot)
}
/// Best-effort prefetch of `files` from the HuggingFace repo `repo_id` into a
/// HF-hub-style cache layout under `cache_dir`
/// (`models--{org}--{name}/snapshots/<hash>/...`), so a later model load can
/// be served from disk.
///
/// Individual file failures are logged at `warn` and skipped; only failure to
/// construct the hub client is propagated as an error.
///
/// NOTE(review): `tokio::runtime::Handle::current()` panics outside a tokio
/// runtime, and `block_in_place` panics on a current-thread runtime — callers
/// must invoke this from within a multi-threaded tokio context. TODO confirm
/// all call sites satisfy this.
pub(crate) fn pre_download_model(
    repo_id: &str,
    files: &[&str],
    cache_dir: &std::path::Path,
) -> Result<()> {
    // Mirror the HF hub cache naming convention: "models--{org}--{name}".
    let folder_name = format!("models--{}", repo_id.replace('/', "--"));
    // Fixed pseudo-snapshot hash so repeated prefetches reuse one directory.
    let snapshot_hash = "lancor-prefetch"; let snapshot_dir = cache_dir.join(&folder_name).join("snapshots").join(snapshot_hash);
    let refs_dir = cache_dir.join(&folder_name).join("refs");
    // Directory creation is best-effort; real problems surface at write time.
    std::fs::create_dir_all(&snapshot_dir).ok();
    std::fs::create_dir_all(&refs_dir).ok();
    // Point the "main" ref at our pseudo-snapshot, as the HF cache would.
    let ref_path = refs_dir.join("main");
    if !ref_path.exists() {
        std::fs::write(&ref_path, snapshot_hash).ok();
    }
    let hub = lancor::hub::HubClient::with_cache_dir(cache_dir.to_path_buf())
        .map_err(|e| AppError::Internal(format!("Failed to create hub client: {}", e)))?;
    let rt = tokio::runtime::Handle::current();
    for filename in files {
        let target = snapshot_dir.join(filename);
        // Skip files that are already present and non-empty.
        if target.exists() && std::fs::metadata(&target).map(|m| m.len() > 0).unwrap_or(false) {
            tracing::debug!("Already cached: {}/{}", repo_id, filename);
            continue;
        }
        // Filenames may contain subdirectories (e.g. "onnx/model.onnx").
        if let Some(parent) = target.parent() {
            std::fs::create_dir_all(parent).ok();
        }
        // Run the async download synchronously on the current runtime.
        match tokio::task::block_in_place(|| {
            rt.block_on(hub.download(repo_id, filename, None))
        }) {
            Ok(downloaded_path) => {
                // The hub client may place the file elsewhere in its own
                // cache; copy it into our snapshot layout if so.
                if downloaded_path != target {
                    std::fs::copy(&downloaded_path, &target).ok();
                }
                tracing::info!("Pre-downloaded {}/{} ({} bytes)", repo_id, filename,
                    std::fs::metadata(&target).map(|m| m.len()).unwrap_or(0));
            }
            Err(e) => {
                // Best-effort: leave the file missing and let the caller's
                // fallback path (direct HF download) handle it.
                tracing::warn!("Could not pre-download {}/{}: {}", repo_id, filename, e);
            }
        }
    }
    Ok(())
}
/// Dense embedding models supported via fastembed.
///
/// Serialized with serde `kebab-case` variant names. NOTE(review): the serde
/// form (e.g. "bge-small-en-v15") differs from the `Display`/`FromStr` form
/// ("bge-small-en-v1.5") — confirm which spelling external configs rely on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum EmbeddingModelType {
    // Variants ending in `Q` are the quantized builds of the base model.
    #[default]
    BgeSmallEnV15,
    BgeSmallEnV15Q,
    AllMiniLmL6V2,
    AllMiniLmL6V2Q,
    AllMiniLmL12V2,
    AllMiniLmL12V2Q,
    AllMpnetBaseV2,
    BgeBaseEnV15,
    BgeBaseEnV15Q,
    BgeLargeEnV15,
    BgeLargeEnV15Q,
    MultilingualE5Small,
    MultilingualE5Base,
    MultilingualE5Large,
    ParaphraseMiniLmL12V2,
    ParaphraseMiniLmL12V2Q,
    ParaphraseMultilingualMpnetBaseV2,
    BgeSmallZhV15,
    BgeLargeZhV15,
    NomicEmbedTextV1,
    NomicEmbedTextV15,
    NomicEmbedTextV15Q,
    MxbaiEmbedLargeV1,
    MxbaiEmbedLargeV1Q,
    GteBaseEnV15,
    GteBaseEnV15Q,
    GteLargeEnV15,
    GteLargeEnV15Q,
    ClipVitB32,
    JinaEmbeddingsV2BaseCode,
    EmbeddingGemma300M,
    ModernBertEmbedLarge,
    SnowflakeArcticEmbedXs,
    SnowflakeArcticEmbedXsQ,
    SnowflakeArcticEmbedS,
    SnowflakeArcticEmbedSQ,
    SnowflakeArcticEmbedM,
    SnowflakeArcticEmbedMQ,
    SnowflakeArcticEmbedMLong,
    SnowflakeArcticEmbedMLongQ,
    SnowflakeArcticEmbedL,
    SnowflakeArcticEmbedLQ,
}
impl EmbeddingModelType {
    /// Maps this variant to the corresponding fastembed model enum value.
    pub fn to_fastembed_model(&self) -> FastEmbedModel {
        match self {
            Self::BgeSmallEnV15 => FastEmbedModel::BGESmallENV15,
            Self::BgeSmallEnV15Q => FastEmbedModel::BGESmallENV15Q,
            Self::AllMiniLmL6V2 => FastEmbedModel::AllMiniLML6V2,
            Self::AllMiniLmL6V2Q => FastEmbedModel::AllMiniLML6V2Q,
            Self::AllMiniLmL12V2 => FastEmbedModel::AllMiniLML12V2,
            Self::AllMiniLmL12V2Q => FastEmbedModel::AllMiniLML12V2Q,
            Self::AllMpnetBaseV2 => FastEmbedModel::AllMpnetBaseV2,
            Self::BgeBaseEnV15 => FastEmbedModel::BGEBaseENV15,
            Self::BgeBaseEnV15Q => FastEmbedModel::BGEBaseENV15Q,
            Self::BgeLargeEnV15 => FastEmbedModel::BGELargeENV15,
            Self::BgeLargeEnV15Q => FastEmbedModel::BGELargeENV15Q,
            Self::MultilingualE5Small => FastEmbedModel::MultilingualE5Small,
            Self::MultilingualE5Base => FastEmbedModel::MultilingualE5Base,
            Self::MultilingualE5Large => FastEmbedModel::MultilingualE5Large,
            Self::ParaphraseMiniLmL12V2 => FastEmbedModel::ParaphraseMLMiniLML12V2,
            Self::ParaphraseMiniLmL12V2Q => FastEmbedModel::ParaphraseMLMiniLML12V2Q,
            Self::ParaphraseMultilingualMpnetBaseV2 => FastEmbedModel::ParaphraseMLMpnetBaseV2,
            Self::BgeSmallZhV15 => FastEmbedModel::BGESmallZHV15,
            Self::BgeLargeZhV15 => FastEmbedModel::BGELargeZHV15,
            Self::NomicEmbedTextV1 => FastEmbedModel::NomicEmbedTextV1,
            Self::NomicEmbedTextV15 => FastEmbedModel::NomicEmbedTextV15,
            Self::NomicEmbedTextV15Q => FastEmbedModel::NomicEmbedTextV15Q,
            Self::MxbaiEmbedLargeV1 => FastEmbedModel::MxbaiEmbedLargeV1,
            Self::MxbaiEmbedLargeV1Q => FastEmbedModel::MxbaiEmbedLargeV1Q,
            Self::GteBaseEnV15 => FastEmbedModel::GTEBaseENV15,
            Self::GteBaseEnV15Q => FastEmbedModel::GTEBaseENV15Q,
            Self::GteLargeEnV15 => FastEmbedModel::GTELargeENV15,
            Self::GteLargeEnV15Q => FastEmbedModel::GTELargeENV15Q,
            Self::ClipVitB32 => FastEmbedModel::ClipVitB32,
            Self::JinaEmbeddingsV2BaseCode => FastEmbedModel::JinaEmbeddingsV2BaseCode,
            Self::EmbeddingGemma300M => FastEmbedModel::EmbeddingGemma300M,
            Self::ModernBertEmbedLarge => FastEmbedModel::ModernBertEmbedLarge,
            Self::SnowflakeArcticEmbedXs => FastEmbedModel::SnowflakeArcticEmbedXS,
            Self::SnowflakeArcticEmbedXsQ => FastEmbedModel::SnowflakeArcticEmbedXSQ,
            Self::SnowflakeArcticEmbedS => FastEmbedModel::SnowflakeArcticEmbedS,
            Self::SnowflakeArcticEmbedSQ => FastEmbedModel::SnowflakeArcticEmbedSQ,
            Self::SnowflakeArcticEmbedM => FastEmbedModel::SnowflakeArcticEmbedM,
            Self::SnowflakeArcticEmbedMQ => FastEmbedModel::SnowflakeArcticEmbedMQ,
            Self::SnowflakeArcticEmbedMLong => FastEmbedModel::SnowflakeArcticEmbedMLong,
            Self::SnowflakeArcticEmbedMLongQ => FastEmbedModel::SnowflakeArcticEmbedMLongQ,
            Self::SnowflakeArcticEmbedL => FastEmbedModel::SnowflakeArcticEmbedL,
            Self::SnowflakeArcticEmbedLQ => FastEmbedModel::SnowflakeArcticEmbedLQ,
        }
    }
    /// HuggingFace repo id used for manual prefetch (`pre_download_model`).
    /// Quantized variants share the base model's repo.
    ///
    /// NOTE(review): only the three listed families have accurate repos —
    /// every other model falls through to "Xenova/bge-small-en-v1.5", which
    /// prefetches the WRONG weights for those models (the load-time snapshot
    /// check will then fail and fall back to a direct HF download). TODO:
    /// fill in per-model repo ids.
    pub fn hf_repo_id(&self) -> &'static str {
        match self {
            Self::BgeSmallEnV15 | Self::BgeSmallEnV15Q => "Xenova/bge-small-en-v1.5",
            Self::AllMiniLmL6V2 | Self::AllMiniLmL6V2Q => "sentence-transformers/all-MiniLM-L6-v2",
            Self::AllMiniLmL12V2 | Self::AllMiniLmL12V2Q => "sentence-transformers/all-MiniLM-L12-v2",
            _ => "Xenova/bge-small-en-v1.5", }
    }
    /// Output embedding width (number of f32 components) for this model.
    pub fn dimensions(&self) -> usize {
        match self {
            Self::BgeSmallEnV15
            | Self::BgeSmallEnV15Q
            | Self::AllMiniLmL6V2
            | Self::AllMiniLmL6V2Q
            | Self::AllMiniLmL12V2
            | Self::AllMiniLmL12V2Q
            | Self::MultilingualE5Small
            | Self::SnowflakeArcticEmbedXs
            | Self::SnowflakeArcticEmbedXsQ
            | Self::SnowflakeArcticEmbedS
            | Self::SnowflakeArcticEmbedSQ => 384,
            Self::BgeSmallZhV15 | Self::ClipVitB32 => 512,
            Self::AllMpnetBaseV2
            | Self::BgeBaseEnV15
            | Self::BgeBaseEnV15Q
            | Self::MultilingualE5Base
            | Self::ParaphraseMiniLmL12V2
            | Self::ParaphraseMiniLmL12V2Q
            | Self::ParaphraseMultilingualMpnetBaseV2
            | Self::NomicEmbedTextV1
            | Self::NomicEmbedTextV15
            | Self::NomicEmbedTextV15Q
            | Self::GteBaseEnV15
            | Self::GteBaseEnV15Q
            | Self::JinaEmbeddingsV2BaseCode
            | Self::EmbeddingGemma300M
            | Self::SnowflakeArcticEmbedM
            | Self::SnowflakeArcticEmbedMQ
            | Self::SnowflakeArcticEmbedMLong
            | Self::SnowflakeArcticEmbedMLongQ => 768,
            Self::BgeLargeEnV15
            | Self::BgeLargeEnV15Q
            | Self::BgeLargeZhV15
            | Self::MultilingualE5Large
            | Self::MxbaiEmbedLargeV1
            | Self::MxbaiEmbedLargeV1Q
            | Self::GteLargeEnV15
            | Self::GteLargeEnV15Q
            | Self::ModernBertEmbedLarge
            | Self::SnowflakeArcticEmbedL
            | Self::SnowflakeArcticEmbedLQ => 1024,
        }
    }
    /// True for the quantized (`…Q`) model builds.
    pub fn is_quantized(&self) -> bool {
        matches!(
            self,
            Self::BgeSmallEnV15Q
                | Self::AllMiniLmL6V2Q
                | Self::AllMiniLmL12V2Q
                | Self::BgeBaseEnV15Q
                | Self::BgeLargeEnV15Q
                | Self::ParaphraseMiniLmL12V2Q
                | Self::NomicEmbedTextV15Q
                | Self::MxbaiEmbedLargeV1Q
                | Self::GteBaseEnV15Q
                | Self::GteLargeEnV15Q
                | Self::SnowflakeArcticEmbedXsQ
                | Self::SnowflakeArcticEmbedSQ
                | Self::SnowflakeArcticEmbedMQ
                | Self::SnowflakeArcticEmbedMLongQ
                | Self::SnowflakeArcticEmbedLQ
        )
    }
    /// True for models trained on multilingual (or Chinese-specific) corpora.
    pub fn is_multilingual(&self) -> bool {
        matches!(
            self,
            Self::MultilingualE5Small
                | Self::MultilingualE5Base
                | Self::MultilingualE5Large
                | Self::ParaphraseMultilingualMpnetBaseV2
                | Self::BgeSmallZhV15
                | Self::BgeLargeZhV15
        )
    }
    /// Maximum input length in tokens; most models are capped at 512.
    pub fn max_context_length(&self) -> usize {
        match self {
            Self::NomicEmbedTextV1 | Self::NomicEmbedTextV15 | Self::NomicEmbedTextV15Q => 8192,
            Self::SnowflakeArcticEmbedMLong | Self::SnowflakeArcticEmbedMLongQ => 2048,
            _ => 512,
        }
    }
    /// Every supported model, in declaration order (used for error messages
    /// and exhaustive tests).
    pub fn all() -> Vec<Self> {
        vec![
            Self::BgeSmallEnV15,
            Self::BgeSmallEnV15Q,
            Self::AllMiniLmL6V2,
            Self::AllMiniLmL6V2Q,
            Self::AllMiniLmL12V2,
            Self::AllMiniLmL12V2Q,
            Self::AllMpnetBaseV2,
            Self::BgeBaseEnV15,
            Self::BgeBaseEnV15Q,
            Self::BgeLargeEnV15,
            Self::BgeLargeEnV15Q,
            Self::MultilingualE5Small,
            Self::MultilingualE5Base,
            Self::MultilingualE5Large,
            Self::ParaphraseMiniLmL12V2,
            Self::ParaphraseMiniLmL12V2Q,
            Self::ParaphraseMultilingualMpnetBaseV2,
            Self::BgeSmallZhV15,
            Self::BgeLargeZhV15,
            Self::NomicEmbedTextV1,
            Self::NomicEmbedTextV15,
            Self::NomicEmbedTextV15Q,
            Self::MxbaiEmbedLargeV1,
            Self::MxbaiEmbedLargeV1Q,
            Self::GteBaseEnV15,
            Self::GteBaseEnV15Q,
            Self::GteLargeEnV15,
            Self::GteLargeEnV15Q,
            Self::ClipVitB32,
            Self::JinaEmbeddingsV2BaseCode,
            Self::EmbeddingGemma300M,
            Self::ModernBertEmbedLarge,
            Self::SnowflakeArcticEmbedXs,
            Self::SnowflakeArcticEmbedXsQ,
            Self::SnowflakeArcticEmbedS,
            Self::SnowflakeArcticEmbedSQ,
            Self::SnowflakeArcticEmbedM,
            Self::SnowflakeArcticEmbedMQ,
            Self::SnowflakeArcticEmbedMLong,
            Self::SnowflakeArcticEmbedMLongQ,
            Self::SnowflakeArcticEmbedL,
            Self::SnowflakeArcticEmbedLQ,
        ]
    }
}
/// Canonical human-readable kebab-case names. Every name round-trips through
/// `FromStr` (verified by `test_display_roundtrip_all_models`).
impl Display for EmbeddingModelType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::BgeSmallEnV15 => "bge-small-en-v1.5",
            Self::BgeSmallEnV15Q => "bge-small-en-v1.5-q",
            Self::AllMiniLmL6V2 => "all-minilm-l6-v2",
            Self::AllMiniLmL6V2Q => "all-minilm-l6-v2-q",
            Self::AllMiniLmL12V2 => "all-minilm-l12-v2",
            Self::AllMiniLmL12V2Q => "all-minilm-l12-v2-q",
            Self::AllMpnetBaseV2 => "all-mpnet-base-v2",
            Self::BgeBaseEnV15 => "bge-base-en-v1.5",
            Self::BgeBaseEnV15Q => "bge-base-en-v1.5-q",
            Self::BgeLargeEnV15 => "bge-large-en-v1.5",
            Self::BgeLargeEnV15Q => "bge-large-en-v1.5-q",
            Self::MultilingualE5Small => "multilingual-e5-small",
            Self::MultilingualE5Base => "multilingual-e5-base",
            Self::MultilingualE5Large => "multilingual-e5-large",
            Self::ParaphraseMiniLmL12V2 => "paraphrase-minilm-l12-v2",
            Self::ParaphraseMiniLmL12V2Q => "paraphrase-minilm-l12-v2-q",
            Self::ParaphraseMultilingualMpnetBaseV2 => "paraphrase-multilingual-mpnet-base-v2",
            Self::BgeSmallZhV15 => "bge-small-zh-v1.5",
            Self::BgeLargeZhV15 => "bge-large-zh-v1.5",
            Self::NomicEmbedTextV1 => "nomic-embed-text-v1",
            Self::NomicEmbedTextV15 => "nomic-embed-text-v1.5",
            Self::NomicEmbedTextV15Q => "nomic-embed-text-v1.5-q",
            Self::MxbaiEmbedLargeV1 => "mxbai-embed-large-v1",
            Self::MxbaiEmbedLargeV1Q => "mxbai-embed-large-v1-q",
            Self::GteBaseEnV15 => "gte-base-en-v1.5",
            Self::GteBaseEnV15Q => "gte-base-en-v1.5-q",
            Self::GteLargeEnV15 => "gte-large-en-v1.5",
            Self::GteLargeEnV15Q => "gte-large-en-v1.5-q",
            Self::ClipVitB32 => "clip-vit-b-32",
            Self::JinaEmbeddingsV2BaseCode => "jina-embeddings-v2-base-code",
            Self::EmbeddingGemma300M => "embedding-gemma-300m",
            Self::ModernBertEmbedLarge => "modernbert-embed-large",
            Self::SnowflakeArcticEmbedXs => "snowflake-arctic-embed-xs",
            Self::SnowflakeArcticEmbedXsQ => "snowflake-arctic-embed-xs-q",
            Self::SnowflakeArcticEmbedS => "snowflake-arctic-embed-s",
            Self::SnowflakeArcticEmbedSQ => "snowflake-arctic-embed-s-q",
            Self::SnowflakeArcticEmbedM => "snowflake-arctic-embed-m",
            Self::SnowflakeArcticEmbedMQ => "snowflake-arctic-embed-m-q",
            Self::SnowflakeArcticEmbedMLong => "snowflake-arctic-embed-m-long",
            Self::SnowflakeArcticEmbedMLongQ => "snowflake-arctic-embed-m-long-q",
            Self::SnowflakeArcticEmbedL => "snowflake-arctic-embed-l",
            Self::SnowflakeArcticEmbedLQ => "snowflake-arctic-embed-l-q",
        };
        write!(f, "{}", name)
    }
}
/// Case-insensitive parsing of canonical names plus short aliases
/// (e.g. "bge-small", "nomic", "clip"). Unknown names produce an
/// `AppError::Internal` listing every supported model.
impl FromStr for EmbeddingModelType {
    type Err = AppError;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "bge-small-en-v1.5" | "bge-small-en" | "bge-small" => Ok(Self::BgeSmallEnV15),
            "bge-small-en-v1.5-q" => Ok(Self::BgeSmallEnV15Q),
            "all-minilm-l6-v2" | "minilm-l6" => Ok(Self::AllMiniLmL6V2),
            "all-minilm-l6-v2-q" => Ok(Self::AllMiniLmL6V2Q),
            "all-minilm-l12-v2" | "minilm-l12" => Ok(Self::AllMiniLmL12V2),
            "all-minilm-l12-v2-q" => Ok(Self::AllMiniLmL12V2Q),
            "all-mpnet-base-v2" | "mpnet" => Ok(Self::AllMpnetBaseV2),
            "bge-base-en-v1.5" | "bge-base-en" | "bge-base" => Ok(Self::BgeBaseEnV15),
            "bge-base-en-v1.5-q" => Ok(Self::BgeBaseEnV15Q),
            "bge-large-en-v1.5" | "bge-large-en" | "bge-large" => Ok(Self::BgeLargeEnV15),
            "bge-large-en-v1.5-q" => Ok(Self::BgeLargeEnV15Q),
            "multilingual-e5-small" | "e5-small" => Ok(Self::MultilingualE5Small),
            "multilingual-e5-base" | "e5-base" => Ok(Self::MultilingualE5Base),
            "multilingual-e5-large" | "e5-large" => Ok(Self::MultilingualE5Large),
            "paraphrase-minilm-l12-v2" => Ok(Self::ParaphraseMiniLmL12V2),
            "paraphrase-minilm-l12-v2-q" => Ok(Self::ParaphraseMiniLmL12V2Q),
            "paraphrase-multilingual-mpnet-base-v2" => Ok(Self::ParaphraseMultilingualMpnetBaseV2),
            "bge-small-zh-v1.5" | "bge-small-zh" => Ok(Self::BgeSmallZhV15),
            "bge-large-zh-v1.5" | "bge-large-zh" => Ok(Self::BgeLargeZhV15),
            "nomic-embed-text-v1" | "nomic-v1" => Ok(Self::NomicEmbedTextV1),
            "nomic-embed-text-v1.5" | "nomic-v1.5" | "nomic" => Ok(Self::NomicEmbedTextV15),
            "nomic-embed-text-v1.5-q" => Ok(Self::NomicEmbedTextV15Q),
            "mxbai-embed-large-v1" | "mxbai" => Ok(Self::MxbaiEmbedLargeV1),
            "mxbai-embed-large-v1-q" => Ok(Self::MxbaiEmbedLargeV1Q),
            "gte-base-en-v1.5" | "gte-base" => Ok(Self::GteBaseEnV15),
            "gte-base-en-v1.5-q" => Ok(Self::GteBaseEnV15Q),
            "gte-large-en-v1.5" | "gte-large" => Ok(Self::GteLargeEnV15),
            "gte-large-en-v1.5-q" => Ok(Self::GteLargeEnV15Q),
            "clip-vit-b-32" | "clip" => Ok(Self::ClipVitB32),
            "jina-embeddings-v2-base-code" | "jina-code" => Ok(Self::JinaEmbeddingsV2BaseCode),
            "embedding-gemma-300m" | "gemma-300m" | "gemma" => Ok(Self::EmbeddingGemma300M),
            "modernbert-embed-large" | "modernbert" => Ok(Self::ModernBertEmbedLarge),
            "snowflake-arctic-embed-xs" => Ok(Self::SnowflakeArcticEmbedXs),
            "snowflake-arctic-embed-xs-q" => Ok(Self::SnowflakeArcticEmbedXsQ),
            "snowflake-arctic-embed-s" => Ok(Self::SnowflakeArcticEmbedS),
            "snowflake-arctic-embed-s-q" => Ok(Self::SnowflakeArcticEmbedSQ),
            "snowflake-arctic-embed-m" => Ok(Self::SnowflakeArcticEmbedM),
            "snowflake-arctic-embed-m-q" => Ok(Self::SnowflakeArcticEmbedMQ),
            "snowflake-arctic-embed-m-long" => Ok(Self::SnowflakeArcticEmbedMLong),
            "snowflake-arctic-embed-m-long-q" => Ok(Self::SnowflakeArcticEmbedMLongQ),
            "snowflake-arctic-embed-l" | "snowflake-l" => Ok(Self::SnowflakeArcticEmbedL),
            "snowflake-arctic-embed-l-q" => Ok(Self::SnowflakeArcticEmbedLQ),
            _ => Err(AppError::Internal(format!(
                "Unknown embedding model: {}. Use one of: {}",
                s,
                EmbeddingModelType::all()
                    .iter()
                    .map(|m| m.to_string())
                    .collect::<Vec<_>>()
                    .join(", ")
            ))),
        }
    }
}
/// Sparse (lexical) embedding models supported via fastembed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum SparseModelType {
    /// SPLADE++ v1 — currently the only supported sparse model.
    #[default]
    SpladePpV1,
}
impl SparseModelType {
pub fn to_fastembed_model(&self) -> SparseModel {
match self {
Self::SpladePpV1 => SparseModel::SPLADEPPV1,
}
}
}
impl Display for SparseModelType {
    /// Writes the canonical kebab-case name (round-trips via `FromStr`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::SpladePpV1 => "splade-pp-v1",
        })
    }
}
impl FromStr for SparseModelType {
    type Err = AppError;
    /// Case-insensitive parse; accepts "splade-pp-v1" or the alias "splade".
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        if normalized == "splade-pp-v1" || normalized == "splade" {
            Ok(Self::SpladePpV1)
        } else {
            Err(AppError::Internal(format!(
                "Unknown sparse model: {}. Use: splade-pp-v1",
                s
            )))
        }
    }
}
/// Configuration for `EmbeddingService`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Dense embedding model to load (defaults to `BgeSmallEnV15`).
    #[serde(default)]
    pub model: EmbeddingModelType,
    /// Number of texts embedded per fastembed batch (default 32).
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
    /// Show a progress bar when fastembed downloads model files (default true).
    #[serde(default = "default_show_progress")]
    pub show_download_progress: bool,
    /// When true, additionally load `sparse_model` for sparse embeddings.
    #[serde(default)]
    pub sparse_enabled: bool,
    /// Sparse model to load when `sparse_enabled` is set.
    #[serde(default)]
    pub sparse_model: SparseModelType,
}
/// Serde default for `EmbeddingConfig::batch_size`.
fn default_batch_size() -> usize {
    32
}
/// Serde default for `EmbeddingConfig::show_download_progress`.
fn default_show_progress() -> bool {
    true
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
model: EmbeddingModelType::default(),
batch_size: default_batch_size(),
show_download_progress: default_show_progress(),
sparse_enabled: false,
sparse_model: SparseModelType::default(),
}
}
}
/// Owns the loaded dense (and optional sparse) fastembed models and runs
/// embedding work on blocking threads.
pub struct EmbeddingService {
    // Mutex because fastembed's `embed` takes `&mut self`; Arc so the model
    // can be moved into `spawn_blocking` closures.
    model: Arc<Mutex<TextEmbedding>>,
    // Present only when `config.sparse_enabled` was set at construction.
    sparse_model: Option<Arc<Mutex<fastembed::SparseTextEmbedding>>>,
    config: EmbeddingConfig,
}
impl EmbeddingService {
pub fn new(config: EmbeddingConfig) -> Result<Self> {
let model_name = format!("{:?}", config.model.to_fastembed_model());
let model_lock = get_model_lock(&model_name);
let _guard = model_lock.lock().map_err(|e| {
AppError::Internal(format!(
"Failed to acquire model initialization lock: {}",
e
))
})?;
let cache_dir = std::env::var("FASTEMBED_CACHE_DIR")
.map(std::path::PathBuf::from)
.unwrap_or_else(|_| std::path::PathBuf::from(".fastembed_cache"));
std::fs::create_dir_all(&cache_dir).ok();
let model_repo = config.model.hf_repo_id();
pre_download_model(model_repo, &["onnx/model.onnx", "tokenizer.json", "config.json"], &cache_dir)?;
let folder_name = format!("models--{}", model_repo.replace('/', "--"));
let model_base = cache_dir.join(&folder_name).join("snapshots");
let snapshot_dir = if model_base.exists() {
std::fs::read_dir(&model_base).ok().and_then(|entries| {
entries.filter_map(|e| e.ok()).find(|e| {
let p = e.path();
p.join("onnx").join("model.onnx").exists()
&& p.join("tokenizer.json").exists()
&& p.join("config.json").exists()
&& p.join("special_tokens_map.json").exists()
}).map(|e| e.path())
})
} else {
let native = cache_dir.join(model_repo.replace('/', "--"));
if native.join("onnx").join("model.onnx").exists() { Some(native) } else { None }
};
let model = if let Some(ref snap) = snapshot_dir {
tracing::info!("Loading embedding model from local cache: {}", snap.display());
let onnx_bytes = std::fs::read(snap.join("onnx").join("model.onnx"))
.map_err(|e| AppError::Internal(format!("Failed to read ONNX: {}", e)))?;
let tokenizer_file = std::fs::read(snap.join("tokenizer.json"))
.map_err(|e| AppError::Internal(format!("Failed to read tokenizer.json: {}", e)))?;
let config_file = std::fs::read(snap.join("config.json"))
.map_err(|e| AppError::Internal(format!("Failed to read config.json: {}", e)))?;
let special_tokens_map_file = std::fs::read(snap.join("special_tokens_map.json"))
.map_err(|e| AppError::Internal(format!("Failed to read special_tokens_map.json: {}", e)))?;
let tokenizer_config_file = std::fs::read(snap.join("tokenizer_config.json"))
.map_err(|e| AppError::Internal(format!("Failed to read tokenizer_config.json: {}", e)))?;
let tokenizer_files = fastembed::TokenizerFiles {
tokenizer_file,
config_file,
special_tokens_map_file,
tokenizer_config_file,
};
let user_model = fastembed::UserDefinedEmbeddingModel::new(onnx_bytes, tokenizer_files);
TextEmbedding::try_new_from_user_defined(user_model, fastembed::InitOptionsUserDefined::new())
.map_err(|e| AppError::Internal(format!("Failed to load local model: {}", e)))?
} else {
tracing::warn!("No local ONNX cache, attempting HF download (may fail on xethub)");
TextEmbedding::try_new(
InitOptions::new(config.model.to_fastembed_model())
.with_cache_dir(cache_dir.clone())
.with_show_download_progress(true),
)
.map_err(|e| AppError::Internal(format!("Failed to init embedding model: {}", e)))?
};
let sparse_model = if config.sparse_enabled {
let sparse_model_name = format!("{:?}", config.sparse_model.to_fastembed_model());
let sparse_lock = get_model_lock(&sparse_model_name);
let _sparse_guard = sparse_lock.lock().map_err(|e| {
AppError::Internal(format!("Failed to acquire sparse model lock: {}", e))
})?;
Some(
fastembed::SparseTextEmbedding::try_new(
fastembed::SparseInitOptions::new(config.sparse_model.to_fastembed_model())
.with_show_download_progress(config.show_download_progress),
)
.map_err(|e| {
AppError::Internal(format!(
"Failed to initialize sparse embedding model: {}",
e
))
})?,
)
} else {
None
};
Ok(Self {
model: Arc::new(Mutex::new(model)),
sparse_model: sparse_model.map(|m| Arc::new(Mutex::new(m))),
config,
})
}
pub fn with_default_model() -> Result<Self> {
Self::new(EmbeddingConfig::default())
}
pub fn with_model(model: EmbeddingModelType) -> Result<Self> {
Self::new(EmbeddingConfig {
model,
..Default::default()
})
}
pub fn model_type(&self) -> EmbeddingModelType {
self.config.model
}
pub fn dimensions(&self) -> usize {
self.config.model.dimensions()
}
pub fn config(&self) -> &EmbeddingConfig {
&self.config
}
pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
let embeddings = self.embed_texts(&[text.to_string()]).await?;
embeddings
.into_iter()
.next()
.ok_or_else(|| AppError::Internal("No embedding generated".to_string()))
}
pub async fn embed_texts<S: AsRef<str> + Send + Sync + 'static>(
&self,
texts: &[S],
) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(vec![]);
}
let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
let batch_size = self.config.batch_size;
let model = Arc::clone(&self.model);
spawn_blocking(move || {
let mut model_guard = model
.lock()
.map_err(|e| AppError::Internal(format!("Failed to acquire model lock: {}", e)))?;
let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
model_guard
.embed(refs, Some(batch_size))
.map_err(|e| AppError::Internal(format!("Embedding failed: {}", e)))
})
.await
.map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
}
pub async fn embed_sparse<S: AsRef<str> + Send + Sync + 'static>(
&self,
texts: &[S],
) -> Result<Vec<fastembed::SparseEmbedding>> {
let sparse_model = self.sparse_model.as_ref().ok_or_else(|| {
AppError::Internal(
"Sparse embeddings not enabled. Set sparse_enabled: true in config.".to_string(),
)
})?;
let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
let batch_size = self.config.batch_size;
let model = Arc::clone(sparse_model);
spawn_blocking(move || {
let mut model_guard = model.lock().map_err(|e| {
AppError::Internal(format!("Failed to acquire sparse model lock: {}", e))
})?;
let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
model_guard
.embed(refs, Some(batch_size))
.map_err(|e| AppError::Internal(format!("Sparse embedding failed: {}", e)))
})
.await
.map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
}
}
use crate::rag::cache::{CacheConfig, CacheStats, EmbeddingCache, LruEmbeddingCache, NoOpCache};
/// `EmbeddingService` wrapper that memoizes embeddings through an
/// `EmbeddingCache` (LRU-backed, or a no-op when caching is disabled).
pub struct CachedEmbeddingService {
    inner: EmbeddingService,
    cache: Box<dyn EmbeddingCache>,
}
impl CachedEmbeddingService {
pub fn new(embedding_config: EmbeddingConfig, cache_config: CacheConfig) -> Result<Self> {
let inner = EmbeddingService::new(embedding_config)?;
let cache: Box<dyn EmbeddingCache> = if cache_config.enabled {
Box::new(LruEmbeddingCache::new(cache_config))
} else {
Box::new(NoOpCache::new())
};
Ok(Self { inner, cache })
}
pub fn with_defaults() -> Result<Self> {
Self::new(EmbeddingConfig::default(), CacheConfig::default())
}
pub fn with_model(model: EmbeddingModelType) -> Result<Self> {
Self::new(
EmbeddingConfig {
model,
..Default::default()
},
CacheConfig::default(),
)
}
pub fn without_cache(embedding_config: EmbeddingConfig) -> Result<Self> {
Self::new(
embedding_config,
CacheConfig {
enabled: false,
..Default::default()
},
)
}
fn model_name(&self) -> String {
self.inner.model_type().to_string()
}
pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
let cache_key = self.cache.compute_key(text, &self.model_name());
if let Some(cached) = self.cache.get(&cache_key) {
return Ok(cached);
}
let embedding = self.inner.embed_text(text).await?;
self.cache.set(&cache_key, embedding.clone(), None)?;
Ok(embedding)
}
pub async fn embed_texts<S: AsRef<str> + Send + Sync + 'static>(
&self,
texts: &[S],
) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(vec![]);
}
let model_name = self.model_name();
let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
let mut uncached_indices: Vec<usize> = Vec::new();
let mut uncached_texts: Vec<String> = Vec::new();
for (i, text) in texts.iter().enumerate() {
let text_str = text.as_ref();
let cache_key = self.cache.compute_key(text_str, &model_name);
if let Some(cached) = self.cache.get(&cache_key) {
results[i] = Some(cached);
} else {
uncached_indices.push(i);
uncached_texts.push(text_str.to_string());
}
}
if !uncached_texts.is_empty() {
let new_embeddings = self.inner.embed_texts(&uncached_texts).await?;
for (j, embedding) in new_embeddings.into_iter().enumerate() {
let idx = uncached_indices[j];
let cache_key = self.cache.compute_key(&uncached_texts[j], &model_name);
self.cache.set(&cache_key, embedding.clone(), None)?;
results[idx] = Some(embedding);
}
}
Ok(results.into_iter().flatten().collect())
}
pub fn model_type(&self) -> EmbeddingModelType {
self.inner.model_type()
}
pub fn dimensions(&self) -> usize {
self.inner.dimensions()
}
pub fn config(&self) -> &EmbeddingConfig {
self.inner.config()
}
pub fn cache_stats(&self) -> CacheStats {
self.cache.stats()
}
pub fn clear_cache(&self) -> Result<()> {
self.cache.clear()
}
pub fn invalidate(&self, text: &str) -> Result<()> {
let cache_key = self.cache.compute_key(text, &self.model_name());
self.cache.invalidate(&cache_key)
}
pub fn is_cache_enabled(&self) -> bool {
self.cache.is_enabled()
}
}
/// Compute backend selection for model inference.
///
/// Nothing in this file consumes these variants yet (hence
/// `#[allow(dead_code)]`). NOTE(review): serde's external tagging means unit
/// variants serialize as plain strings ("cpu", "metal", "vulkan") while
/// `Cuda` serializes as `{"cuda":{"device_id":N}}` — confirm consumers
/// expect that mixed shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[allow(dead_code)]
#[derive(Default)]
pub enum AccelerationBackend {
    #[default]
    Cpu,
    Cuda {
        device_id: usize,
    },
    Metal,
    Vulkan,
}
#[cfg(test)]
mod tests {
    // Unit tests for the model-type metadata, parsing, config defaults, the
    // prefetch cache layout, and the per-model lock registry. None of these
    // load an actual model.
    use super::*;
    #[test]
    fn test_model_dimensions() {
        assert_eq!(EmbeddingModelType::BgeSmallEnV15.dimensions(), 384);
        assert_eq!(EmbeddingModelType::BgeBaseEnV15.dimensions(), 768);
        assert_eq!(EmbeddingModelType::BgeLargeEnV15.dimensions(), 1024);
        assert_eq!(EmbeddingModelType::MultilingualE5Large.dimensions(), 1024);
    }
    #[test]
    fn test_model_from_str() {
        assert_eq!(
            "bge-small-en-v1.5".parse::<EmbeddingModelType>().unwrap(),
            EmbeddingModelType::BgeSmallEnV15
        );
        assert_eq!(
            "multilingual-e5-large"
                .parse::<EmbeddingModelType>()
                .unwrap(),
            EmbeddingModelType::MultilingualE5Large
        );
        assert_eq!(
            "minilm-l6".parse::<EmbeddingModelType>().unwrap(),
            EmbeddingModelType::AllMiniLmL6V2
        );
    }
    #[test]
    fn test_model_is_multilingual() {
        assert!(EmbeddingModelType::MultilingualE5Small.is_multilingual());
        assert!(EmbeddingModelType::MultilingualE5Large.is_multilingual());
        assert!(!EmbeddingModelType::BgeSmallEnV15.is_multilingual());
    }
    #[test]
    fn test_model_max_context() {
        assert_eq!(
            EmbeddingModelType::NomicEmbedTextV15.max_context_length(),
            8192
        );
        assert_eq!(
            EmbeddingModelType::NomicEmbedTextV1.max_context_length(),
            8192
        );
        assert_eq!(EmbeddingModelType::BgeSmallEnV15.max_context_length(), 512);
    }
    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.model, EmbeddingModelType::BgeSmallEnV15);
        assert_eq!(config.batch_size, 32);
        assert!(config.show_download_progress);
        assert!(!config.sparse_enabled);
    }
    #[test]
    fn test_all_models_listed() {
        let all = EmbeddingModelType::all();
        // Lower bound rather than exact count so adding models doesn't break
        // this test.
        assert!(all.len() >= 38); assert!(all.contains(&EmbeddingModelType::BgeSmallEnV15));
        assert!(all.contains(&EmbeddingModelType::MultilingualE5Large));
    }
    #[test]
    fn test_display_roundtrip_all_models() {
        // Every Display name must parse back to the same variant.
        for model in EmbeddingModelType::all() {
            let display = model.to_string();
            let parsed: EmbeddingModelType = display.parse().unwrap_or_else(|_| {
                panic!("Display→FromStr roundtrip failed for {:?} ('{}')", model, display)
            });
            assert_eq!(parsed, model, "Roundtrip mismatch for {}", display);
        }
    }
    #[test]
    fn test_from_str_aliases() {
        let aliases = vec![
            ("bge-small", EmbeddingModelType::BgeSmallEnV15),
            ("bge-small-en", EmbeddingModelType::BgeSmallEnV15),
            ("bge-base", EmbeddingModelType::BgeBaseEnV15),
            ("bge-large", EmbeddingModelType::BgeLargeEnV15),
            ("e5-small", EmbeddingModelType::MultilingualE5Small),
            ("e5-large", EmbeddingModelType::MultilingualE5Large),
            ("mpnet", EmbeddingModelType::AllMpnetBaseV2),
            ("nomic", EmbeddingModelType::NomicEmbedTextV15),
            ("mxbai", EmbeddingModelType::MxbaiEmbedLargeV1),
            ("gte-base", EmbeddingModelType::GteBaseEnV15),
            ("gte-large", EmbeddingModelType::GteLargeEnV15),
            ("clip", EmbeddingModelType::ClipVitB32),
            ("jina-code", EmbeddingModelType::JinaEmbeddingsV2BaseCode),
            ("gemma", EmbeddingModelType::EmbeddingGemma300M),
            ("modernbert", EmbeddingModelType::ModernBertEmbedLarge),
            ("snowflake-l", EmbeddingModelType::SnowflakeArcticEmbedL),
        ];
        for (alias, expected) in aliases {
            let parsed: EmbeddingModelType = alias.parse().unwrap_or_else(|_| {
                panic!("Alias '{}' should parse", alias)
            });
            assert_eq!(parsed, expected, "Alias '{}' mismatch", alias);
        }
    }
    #[test]
    fn test_from_str_case_insensitive() {
        let upper: EmbeddingModelType = "BGE-SMALL-EN-V1.5".parse().unwrap();
        assert_eq!(upper, EmbeddingModelType::BgeSmallEnV15);
        let mixed: EmbeddingModelType = "Nomic-Embed-Text-V1.5".parse().unwrap();
        assert_eq!(mixed, EmbeddingModelType::NomicEmbedTextV15);
    }
    #[test]
    fn test_from_str_invalid_model() {
        let result = "totally-fake-model".parse::<EmbeddingModelType>();
        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Unknown embedding model"), "Error should mention 'Unknown': {}", msg);
    }
    #[test]
    fn test_hf_repo_id_known_models() {
        assert_eq!(EmbeddingModelType::BgeSmallEnV15.hf_repo_id(), "Xenova/bge-small-en-v1.5");
        assert_eq!(EmbeddingModelType::AllMiniLmL6V2.hf_repo_id(), "sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(EmbeddingModelType::AllMiniLmL12V2.hf_repo_id(), "sentence-transformers/all-MiniLM-L12-v2");
    }
    #[test]
    fn test_hf_repo_id_quantized_same_as_base() {
        // Quantized builds live in the same repo as their base model.
        assert_eq!(
            EmbeddingModelType::BgeSmallEnV15.hf_repo_id(),
            EmbeddingModelType::BgeSmallEnV15Q.hf_repo_id()
        );
        assert_eq!(
            EmbeddingModelType::AllMiniLmL6V2.hf_repo_id(),
            EmbeddingModelType::AllMiniLmL6V2Q.hf_repo_id()
        );
    }
    #[test]
    fn test_dimensions_categories() {
        // Every model must report one of the known embedding widths.
        for model in EmbeddingModelType::all() {
            let dim = model.dimensions();
            assert!(
                dim == 384 || dim == 512 || dim == 768 || dim == 1024,
                "{:?} has unexpected dimension {}",
                model,
                dim
            );
        }
    }
    #[test]
    fn test_sparse_model_display_roundtrip() {
        let model = SparseModelType::SpladePpV1;
        let display = model.to_string();
        assert_eq!(display, "splade-pp-v1");
        let parsed: SparseModelType = display.parse().unwrap();
        assert_eq!(parsed, model);
    }
    #[test]
    fn test_sparse_model_alias() {
        let parsed: SparseModelType = "splade".parse().unwrap();
        assert_eq!(parsed, SparseModelType::SpladePpV1);
    }
    #[test]
    fn test_sparse_model_invalid() {
        let result = "nonexistent-sparse".parse::<SparseModelType>();
        assert!(result.is_err());
    }
    #[test]
    fn test_embedding_config_serialization_roundtrip() {
        let config = EmbeddingConfig {
            model: EmbeddingModelType::NomicEmbedTextV15,
            batch_size: 64,
            show_download_progress: false,
            sparse_enabled: true,
            sparse_model: SparseModelType::SpladePpV1,
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: EmbeddingConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model, EmbeddingModelType::NomicEmbedTextV15);
        assert_eq!(parsed.batch_size, 64);
        assert!(!parsed.show_download_progress);
        assert!(parsed.sparse_enabled);
    }
    #[test]
    fn test_to_fastembed_model_all_variants() {
        // Smoke test: the mapping is exhaustive and does not panic.
        for model in EmbeddingModelType::all() {
            let _ = model.to_fastembed_model(); }
    }
    // multi_thread flavor is required: pre_download_model uses
    // block_in_place, which panics on a current-thread runtime.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_pre_download_creates_cache_structure() {
        let tmp = tempfile::TempDir::new().unwrap();
        let cache_dir = tmp.path().to_path_buf();
        // Download itself is expected to fail (fake repo); only the on-disk
        // layout side effects are asserted.
        let _ = pre_download_model(
            "fake-org/fake-model",
            &["onnx/model.onnx"],
            &cache_dir,
        );
        let folder = cache_dir.join("models--fake-org--fake-model");
        assert!(folder.join("snapshots").join("lancor-prefetch").exists(),
            "snapshot dir should be created");
        assert!(folder.join("refs").exists(),
            "refs dir should be created");
        let ref_main = folder.join("refs").join("main");
        if ref_main.exists() {
            let content = std::fs::read_to_string(&ref_main).unwrap();
            assert_eq!(content, "lancor-prefetch");
        }
    }
    #[test]
    fn test_model_lock_creation() {
        let lock1 = get_model_lock("test-model");
        let lock2 = get_model_lock("test-model");
        assert!(Arc::ptr_eq(&lock1, &lock2), "Same model should return same lock");
        let lock3 = get_model_lock("other-model");
        assert!(!Arc::ptr_eq(&lock1, &lock3), "Different models should have different locks");
    }
}