use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use std::sync::{Arc, Mutex};
use tokio::sync::OnceCell;
use tokio::task;
use tracing::info;
/// Shared handle to the loaded embedding model.
///
/// The `Mutex` serialises access: only one blocking task at a time holds the
/// lock and runs inference on the model.
type SharedModel = Arc<Mutex<TextEmbedding>>;

/// Lazily-initialised text embedding service backed by `fastembed`
/// (`AllMiniLML6V2`).
///
/// The model is not loaded by [`EmbeddingService::new`]; it is loaded on the
/// first call that needs it. All clones of a service share the same cell, so
/// the model is loaded at most once per original service.
#[derive(Clone, Default)]
pub struct EmbeddingService {
    /// Filled on first successful model load; shared by every clone.
    model: Arc<OnceCell<SharedModel>>,
}

impl EmbeddingService {
    /// Creates a service without loading the model.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` return type is kept for API
    /// compatibility with existing callers.
    pub fn new() -> anyhow::Result<Self> {
        Ok(Self::default())
    }

    /// Returns the shared model, loading it on the first call.
    ///
    /// `TextEmbedding::try_new` is synchronous (and, given the download-progress
    /// option, presumably does file/network I/O), so it runs on tokio's blocking
    /// pool. If initialisation fails the cell stays empty and a later call
    /// retries.
    ///
    /// # Errors
    ///
    /// Returns an error if the model fails to load or the loading task panics.
    async fn get_model(&self) -> anyhow::Result<SharedModel> {
        let model = self
            .model
            .get_or_try_init(|| async {
                task::spawn_blocking(|| {
                    let mut options = InitOptions::default();
                    options.model_name = EmbeddingModel::AllMiniLML6V2;
                    options.show_download_progress = true;
                    let model = TextEmbedding::try_new(options)?;
                    info!("Embedding model loaded (AllMiniLML6V2)");
                    Ok::<_, anyhow::Error>(Arc::new(Mutex::new(model)))
                })
                .await?
            })
            .await?;
        Ok(Arc::clone(model))
    }

    /// Embeds a single text and returns its vector.
    ///
    /// Implemented as a one-element [`EmbeddingService::embed_batch`].
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be loaded, the model mutex is
    /// poisoned, inference fails, or the model returns no vector for the input
    /// (previously this last case panicked on an out-of-bounds index).
    pub async fn embed(&self, text: String) -> anyhow::Result<Vec<f32>> {
        self.embed_batch(vec![text])
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("embedding model returned no vector"))
    }

    /// Embeds a batch of texts, returning the vectors produced by
    /// `TextEmbedding::embed` for `texts` (default batch size).
    ///
    /// An empty batch returns an empty result immediately, without loading the
    /// model.
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be loaded, the model mutex is
    /// poisoned, inference fails, or the blocking task panics.
    pub async fn embed_batch(&self, texts: Vec<String>) -> anyhow::Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let model = self.get_model().await?;
        // Inference is synchronous; keep it off the async executor threads.
        task::spawn_blocking(move || {
            let guard = model
                .lock()
                .map_err(|e| anyhow::anyhow!("embedding model mutex poisoned: {e}"))?;
            guard.embed(texts, None)
        })
        .await?
    }
}