use fastembed::{
EmbeddingModel, InitOptions, RerankInitOptions, RerankerModel, TextEmbedding, TextRerank,
};
use std::sync::{Arc, Mutex};
use tokio::sync::OnceCell;
use tokio::task;
use tracing::info;
/// Handle to the single embedding model instance: reference-counted so it can be moved into
/// blocking tasks, and wrapped in a `Mutex` so only one task uses the model at a time.
type SharedModel = Arc<Mutex<TextEmbedding>>;
/// Handle to the single reranker instance, shared the same way as [`SharedModel`].
type SharedReranker = Arc<Mutex<TextRerank>>;
/// Reranker variant loaded by `EmbeddingService::get_reranker`.
// NOTE(review): the "Multiligual" spelling is the identifier exactly as this file references it
// from `fastembed`; check the crate's enum before "correcting" it.
const RERANKER_MODEL: RerankerModel = RerankerModel::JINARerankerV2BaseMultiligual;
/// Lazily-initialised access to a text-embedding model and a reranker model.
///
/// Both fields are `Arc`s, so the derived `Clone` yields a handle that shares the same cells
/// (and therefore the same loaded models) as the value it was cloned from.
#[derive(Clone)]
pub struct EmbeddingService {
// Holds the embedding model once `get_model` has loaded it.
model: Arc<OnceCell<SharedModel>>,
// Holds the reranker once `get_reranker` has loaded it.
reranker: Arc<OnceCell<SharedReranker>>,
}
impl EmbeddingService {
pub fn new() -> anyhow::Result<Self> {
Ok(Self {
model: Arc::new(OnceCell::new()),
reranker: Arc::new(OnceCell::new()),
})
}
async fn get_model(&self) -> anyhow::Result<SharedModel> {
let model = self
.model
.get_or_try_init(|| async {
task::spawn_blocking(|| {
let mut options = InitOptions::default();
options.model_name = EmbeddingModel::AllMiniLML6V2;
options.show_download_progress = true;
let model = TextEmbedding::try_new(options)?;
info!("Embedding model loaded (AllMiniLML6V2)");
Ok::<_, anyhow::Error>(Arc::new(Mutex::new(model)))
})
.await?
})
.await?;
Ok(model.clone())
}
pub async fn embed(&self, text: String) -> anyhow::Result<Vec<f32>> {
let model = self.get_model().await?;
task::spawn_blocking(move || {
let guard = model
.lock()
.map_err(|e| anyhow::anyhow!("embedding model mutex poisoned: {e}"))?;
let embeddings = guard.embed(vec![text], None)?;
Ok(embeddings[0].clone())
})
.await?
}
#[allow(dead_code)]
pub async fn embed_batch(&self, texts: Vec<String>) -> anyhow::Result<Vec<Vec<f32>>> {
let model = self.get_model().await?;
task::spawn_blocking(move || {
let guard = model
.lock()
.map_err(|e| anyhow::anyhow!("embedding model mutex poisoned: {e}"))?;
guard.embed(texts, None)
})
.await?
}
async fn get_reranker(&self) -> anyhow::Result<SharedReranker> {
let reranker = self
.reranker
.get_or_try_init(|| async {
task::spawn_blocking(|| {
let options =
RerankInitOptions::new(RERANKER_MODEL).with_show_download_progress(true);
let model = TextRerank::try_new(options)?;
info!("Reranker model loaded ({:?})", RERANKER_MODEL);
Ok::<_, anyhow::Error>(Arc::new(Mutex::new(model)))
})
.await?
})
.await?;
Ok(reranker.clone())
}
pub async fn rerank(
&self,
query: String,
documents: Vec<String>,
) -> anyhow::Result<Vec<(usize, f32)>> {
if documents.is_empty() {
return Ok(vec![]);
}
if cfg!(test) {
anyhow::bail!("reranker disabled in test builds");
}
let reranker = self.get_reranker().await?;
task::spawn_blocking(move || {
let guard = reranker
.lock()
.map_err(|e| anyhow::anyhow!("reranker model mutex poisoned: {e}"))?;
let results = guard.rerank(query, documents, false, None)?;
let mut ranked: Vec<(usize, f32)> =
results.into_iter().map(|r| (r.index, r.score)).collect();
ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(ranked)
})
.await?
}
}