Skip to main content

aonyx_memory/
embed.rs

1//! Text embeddings for semantic retrieval (RG1).
2//!
3//! [`Embedder`] is provider-agnostic. [`LocalEmbedder`] (feature `rag`) runs
4//! fastembed-rs (ONNX, offline after a one-time model download) — see ADR-005
5//! / ADR-009. A provider-backed embedder (OpenAI / Ollama) lands in Phase 1.
6
7use aonyx_core::Result;
8use async_trait::async_trait;
9
10/// Produces dense vectors for text. Must be deterministic for a given
11/// `(model, input)` so persisted vectors stay comparable across runs.
12#[async_trait]
13pub trait Embedder: Send + Sync {
14    /// Stable model id (e.g. `"bge-m3"`), persisted beside each vector so a
15    /// model change can be detected and the corpus re-indexed.
16    fn model_id(&self) -> &str;
17
18    /// Embedding dimensionality.
19    fn dim(&self) -> usize;
20
21    /// Embed a batch of texts → one vector per input, in order.
22    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
23}
24
25#[cfg(feature = "rag")]
26pub use local::LocalEmbedder;
27
#[cfg(feature = "rag")]
mod local {
    use std::path::PathBuf;
    use std::sync::{Arc, Mutex};

    use aonyx_core::{AonyxError, Result};
    use async_trait::async_trait;
    use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

    use super::Embedder;

    /// Local, offline embedder via fastembed-rs (ONNX). Quality-first default:
    /// **BAAI/bge-m3** (multilingual, dim 1024, no query/passage prefix
    /// needed). The model downloads once into `cache_dir`, then runs offline.
    pub struct LocalEmbedder {
        /// Shared handle to the loaded model. The mutex is held for the whole
        /// of each inference call, so concurrent `embed` calls run one at a
        /// time.
        model: Arc<Mutex<TextEmbedding>>,
        /// Identifier reported by [`Embedder::model_id`].
        model_id: String,
        /// Dimensionality reported by [`Embedder::dim`]; supplied by the
        /// caller and not checked against the model's actual output.
        dim: usize,
    }

    impl LocalEmbedder {
        /// Load the default quality model (bge-m3), downloading on first use
        /// into `cache_dir` (e.g. `~/.aonyx/models`).
        ///
        /// # Errors
        ///
        /// Same as [`LocalEmbedder::with_model`].
        pub fn new(cache_dir: PathBuf) -> Result<Self> {
            Self::with_model(EmbeddingModel::BGEM3, "bge-m3", 1024, cache_dir)
        }

        /// Load a specific fastembed model. `dim` must match the model's output.
        ///
        /// `id` is the value later returned by [`Embedder::model_id`], and
        /// `cache_dir` is created if it does not exist yet.
        ///
        /// # Errors
        ///
        /// Returns [`AonyxError::Memory`] when `cache_dir` cannot be created
        /// or when fastembed fails to initialise the model.
        pub fn with_model(
            model: EmbeddingModel,
            id: &str,
            dim: usize,
            cache_dir: PathBuf,
        ) -> Result<Self> {
            // Report an unusable cache directory here, naming the path,
            // rather than letting it show up later as an opaque model-load
            // failure.
            std::fs::create_dir_all(&cache_dir).map_err(|e| {
                AonyxError::Memory(format!(
                    "create embedder cache dir '{}': {e}",
                    cache_dir.display()
                ))
            })?;
            let te = TextEmbedding::try_new(
                InitOptions::new(model)
                    .with_cache_dir(cache_dir)
                    .with_show_download_progress(true),
            )
            .map_err(|e| AonyxError::Memory(format!("load embedder '{id}': {e}")))?;
            Ok(Self {
                model: Arc::new(Mutex::new(te)),
                model_id: id.to_string(),
                dim,
            })
        }
    }

    #[async_trait]
    impl Embedder for LocalEmbedder {
        fn model_id(&self) -> &str {
            &self.model_id
        }

        fn dim(&self) -> usize {
            self.dim
        }

        /// Embeds `texts` on a blocking worker thread and returns the vectors
        /// produced by fastembed. An empty input returns an empty `Vec`
        /// without touching the model.
        ///
        /// # Errors
        ///
        /// Returns [`AonyxError::Memory`] when the model mutex is poisoned,
        /// when fastembed reports an inference error, or when the blocking
        /// task cannot be joined.
        async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
            if texts.is_empty() {
                return Ok(Vec::new());
            }
            // The blocking closure must own its data, so clone the handle and
            // copy the inputs before moving them in.
            let model = Arc::clone(&self.model);
            let texts = texts.to_vec();
            // fastembed's embed() is blocking (ONNX inference) — keep it off
            // the async runtime.
            tokio::task::spawn_blocking(move || {
                let mut model = model
                    .lock()
                    .map_err(|_| AonyxError::Memory("embedder mutex poisoned".into()))?;
                model
                    .embed(texts, None)
                    .map_err(|e| AonyxError::Memory(format!("embed: {e}")))
            })
            .await
            // The outer `?` unwraps the join result; the inner embed result
            // is the function's return value.
            .map_err(|e| AonyxError::Memory(format!("embed task join: {e}")))?
        }
    }
}