use aonyx_core::Result;
use async_trait::async_trait;
/// Converts text into dense `f32` vectors.
///
/// The `Send + Sync` bound lets a single embedder be shared across threads
/// and async tasks (e.g. behind an `Arc<dyn Embedder>`).
#[async_trait]
pub trait Embedder: Send + Sync {
    /// Identifier of the underlying model (for example `"bge-m3"`).
    fn model_id(&self) -> &str;

    /// Expected length of each vector produced by [`Embedder::embed`].
    fn dim(&self) -> usize;

    /// Embeds every element of `texts`.
    ///
    /// Implementations are expected to return one vector of length
    /// [`Embedder::dim`] per input text.
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
}
#[cfg(feature = "rag")]
pub use local::LocalEmbedder;
#[cfg(feature = "rag")]
mod local {
    use std::path::PathBuf;
    use std::sync::{Arc, Mutex};

    use aonyx_core::{AonyxError, Result};
    use async_trait::async_trait;
    use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

    use super::Embedder;

    /// [`Embedder`] that runs a `fastembed` [`TextEmbedding`] model in-process.
    ///
    /// The model sits behind an `Arc<Mutex<_>>` so inference can be moved onto
    /// tokio's blocking thread pool. Concurrent [`Embedder::embed`] calls on the
    /// same instance are serialised by that mutex.
    pub struct LocalEmbedder {
        model: Arc<Mutex<TextEmbedding>>,
        model_id: String,
        /// Advertised vector length; every `embed` result is checked against it.
        dim: usize,
    }

    impl LocalEmbedder {
        /// Loads the default model: `EmbeddingModel::BGEM3`, reported as
        /// `"bge-m3"` with 1024-dimensional vectors.
        ///
        /// # Errors
        /// Same as [`LocalEmbedder::with_model`].
        pub fn new(cache_dir: PathBuf) -> Result<Self> {
            Self::with_model(EmbeddingModel::BGEM3, "bge-m3", 1024, cache_dir)
        }

        /// Loads `model` and reports it as `id` producing `dim`-length vectors.
        ///
        /// `cache_dir` is created if missing and passed to fastembed as its
        /// model cache directory, with download progress display enabled.
        /// `dim` must match what `model` actually produces. A mismatch is
        /// reported as an error by [`Embedder::embed`].
        ///
        /// # Errors
        /// Returns [`AonyxError::Memory`] if the cache directory cannot be
        /// created or the model fails to initialise.
        pub fn with_model(
            model: EmbeddingModel,
            id: &str,
            dim: usize,
            cache_dir: PathBuf,
        ) -> Result<Self> {
            // `create_dir_all` is Ok when the directory already exists, so any
            // error here is real. Report it precisely instead of letting it
            // surface later as a less specific model-load failure.
            std::fs::create_dir_all(&cache_dir).map_err(|e| {
                AonyxError::Memory(format!(
                    "create embedder cache dir '{}': {e}",
                    cache_dir.display()
                ))
            })?;
            let te = TextEmbedding::try_new(
                InitOptions::new(model)
                    .with_cache_dir(cache_dir)
                    .with_show_download_progress(true),
            )
            .map_err(|e| AonyxError::Memory(format!("load embedder '{id}': {e}")))?;
            Ok(Self {
                model: Arc::new(Mutex::new(te)),
                model_id: id.to_string(),
                dim,
            })
        }
    }

    #[async_trait]
    impl Embedder for LocalEmbedder {
        fn model_id(&self) -> &str {
            &self.model_id
        }

        fn dim(&self) -> usize {
            self.dim
        }

        /// Embeds `texts` on tokio's blocking pool.
        ///
        /// An empty slice returns an empty result without touching the model.
        ///
        /// # Errors
        /// Returns [`AonyxError::Memory`] if the model mutex is poisoned,
        /// inference fails, or the blocking task cannot be joined. It also
        /// errors if the output is not exactly one `dim`-length vector per
        /// input text.
        async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
            if texts.is_empty() {
                return Ok(Vec::new());
            }
            let expected = texts.len();
            // `spawn_blocking` requires 'static captures: clone the Arc and
            // take an owned copy of the inputs.
            let model = Arc::clone(&self.model);
            let texts = texts.to_vec();
            let vectors = tokio::task::spawn_blocking(move || {
                let mut model = model
                    .lock()
                    .map_err(|_| AonyxError::Memory("embedder mutex poisoned".into()))?;
                model
                    .embed(texts, None)
                    .map_err(|e| AonyxError::Memory(format!("embed: {e}")))
            })
            .await
            .map_err(|e| AonyxError::Memory(format!("embed task join: {e}")))??;
            check_output(&vectors, expected, self.dim)?;
            Ok(vectors)
        }
    }

    /// Verifies the model returned `expected` vectors, each of length `dim`.
    fn check_output(vectors: &[Vec<f32>], expected: usize, dim: usize) -> Result<()> {
        if vectors.len() != expected {
            return Err(AonyxError::Memory(format!(
                "embed: model returned {} vectors for {expected} texts",
                vectors.len()
            )));
        }
        if let Some((i, v)) = vectors.iter().enumerate().find(|(_, v)| v.len() != dim) {
            return Err(AonyxError::Memory(format!(
                "embed: vector {i} has dimension {}, expected {dim}",
                v.len()
            )));
        }
        Ok(())
    }
}