Skip to main content

engram/embedding/
mod.rs

1//! Embedding generation and async queue management (RML-873)
2//!
3//! Supports multiple embedding backends:
4//! - OpenAI API (text-embedding-3-small) - requires `openai` feature
5//! - TF-IDF fallback (no external dependencies)
6//!
7//! Features:
8//! - LRU embedding cache with zero-copy Arc<[f32]> sharing
9//! - Async queue processing for batch operations
10//!
11//! # Feature Flags
12//!
13//! - `openai`: Enables OpenAI embedding backend (requires API key)
14
15mod cache;
16mod provider;
17mod queue;
18mod tfidf;
19
20#[cfg(feature = "multimodal")]
21pub mod clip;
22#[cfg(feature = "cohere")]
23pub mod cohere;
24#[cfg(feature = "ollama")]
25pub mod ollama;
26#[cfg(feature = "onnx-embed")]
27pub mod onnx;
28#[cfg(feature = "onnx-embed")]
29pub mod onnx_registry;
30#[cfg(feature = "voyage")]
31pub mod voyage;
32
33pub use cache::{EmbeddingCache, EmbeddingCacheStats};
34#[cfg(feature = "multimodal")]
35pub use clip::{ClipEmbedder, MultimodalEmbedder, CLIP_PROVIDER_NAME};
36pub use provider::{EmbeddingProvider, EmbeddingProviderInfo, EmbeddingRegistry};
37pub use queue::{
38    drain_pending_embeddings, get_embedding, get_embedding_queue_health, get_embedding_status,
39    requeue_stale_processing_embeddings, run_embedding_queue_hygiene, EmbeddingQueue,
40    EmbeddingQueueHealth, EmbeddingQueueHygieneConfig, EmbeddingQueueHygieneReport,
41    EmbeddingWorker, DEFAULT_COMPLETE_RETENTION, DEFAULT_MAX_EMBEDDING_RETRIES,
42    DEFAULT_STALE_PROCESSING_AFTER,
43};
44pub use tfidf::TfIdfEmbedder;
45
46use std::sync::Arc;
47
48use crate::error::{EngramError, Result};
49use crate::types::EmbeddingConfig;
50
51/// Trait for embedding generators
52pub trait Embedder: Send + Sync {
53    /// Generate embedding for a single text
54    fn embed(&self, text: &str) -> Result<Vec<f32>>;
55
56    /// Generate embeddings for multiple texts (batch)
57    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
58        texts.iter().map(|t| self.embed(t)).collect()
59    }
60
61    /// Get embedding dimensions
62    fn dimensions(&self) -> usize;
63
64    /// Get model name
65    fn model_name(&self) -> &str;
66}
67
/// OpenAI embedding client
///
/// Requires the `openai` feature to be enabled.
/// Supports OpenAI, OpenRouter, Azure OpenAI, and other OpenAI-compatible APIs.
///
/// The async entry points (`embed_async` / `embed_batch_async`) perform the
/// HTTP calls; the synchronous `Embedder` implementation drives them to
/// completion on a tokio runtime.
#[cfg(feature = "openai")]
pub struct OpenAIEmbedder {
    /// HTTP client used for every embeddings request.
    client: reqwest::Client,
    /// Credential sent as `Authorization: Bearer <api_key>`.
    api_key: String,
    /// API root; requests are posted to `{base_url}/embeddings`.
    base_url: String,
    /// Model identifier sent in the request body and reported by `model_name()`.
    model: String,
    /// Expected length of each returned embedding; responses are validated
    /// against it.
    dimensions: usize,
    /// Owned single-threaded runtime used by the sync `embed()` / `embed_batch()`
    /// path when no tokio runtime is in scope (e.g. MCP stdio transport, which
    /// runs the JSON-RPC loop synchronously). When the embedder is called from
    /// inside an existing runtime (HTTP transport), we detect that with
    /// `Handle::try_current()` and use `block_in_place` instead of this owned
    /// runtime to avoid the cost of a nested runtime. Fixes #9.
    runtime: tokio::runtime::Runtime,
}
87
#[cfg(feature = "openai")]
impl OpenAIEmbedder {
    /// API root used when the caller does not supply one.
    const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1";
    /// Model used when the caller does not supply one.
    const DEFAULT_MODEL: &'static str = "text-embedding-3-small";
    /// Embedding length assumed when the caller does not supply one.
    const DEFAULT_DIMENSIONS: usize = 1536;
    /// OpenAI allows up to 2048 inputs per embeddings request.
    const MAX_INPUTS_PER_REQUEST: usize = 2048;

    /// Build the owned fallback runtime. Single-threaded with all drivers
    /// enabled is enough for the rare case where there's no ambient runtime.
    ///
    /// # Panics
    /// Panics if tokio cannot construct the runtime.
    fn build_fallback_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("OpenAIEmbedder: failed to build fallback tokio runtime")
    }

    /// Create a new OpenAI embedder with default settings
    /// (OpenAI endpoint, `text-embedding-3-small`, 1536 dimensions).
    ///
    /// # Panics
    /// Panics if the fallback tokio runtime cannot be built.
    pub fn new(api_key: String) -> Self {
        Self::with_config(api_key, None, None, None)
    }

    /// Create a new OpenAI embedder with custom settings
    ///
    /// Any argument passed as `None` falls back to the same default that
    /// [`OpenAIEmbedder::new`] uses.
    ///
    /// # Arguments
    /// * `api_key` - API key for authentication
    /// * `base_url` - API base URL (e.g., `https://openrouter.ai/api/v1` for OpenRouter)
    /// * `model` - Model name (e.g., "openai/text-embedding-3-small" for OpenRouter)
    /// * `dimensions` - Expected embedding dimensions (must match model output)
    ///
    /// # Panics
    /// Panics if the fallback tokio runtime cannot be built.
    pub fn with_config(
        api_key: String,
        base_url: Option<String>,
        model: Option<String>,
        dimensions: Option<usize>,
    ) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            base_url: base_url.unwrap_or_else(|| Self::DEFAULT_BASE_URL.to_string()),
            model: model.unwrap_or_else(|| Self::DEFAULT_MODEL.to_string()),
            dimensions: dimensions.unwrap_or(Self::DEFAULT_DIMENSIONS),
            runtime: Self::build_fallback_runtime(),
        }
    }

    /// Legacy constructor for backwards compatibility
    ///
    /// Uses the default OpenAI endpoint with the given model and dimensions.
    ///
    /// # Panics
    /// Panics if the fallback tokio runtime cannot be built.
    pub fn with_model(api_key: String, model: String, dimensions: usize) -> Self {
        Self::with_config(api_key, None, Some(model), Some(dimensions))
    }

    /// Error returned when the response body does not have the expected shape.
    fn invalid_response() -> EngramError {
        EngramError::Embedding("Invalid response format".to_string())
    }

    /// POST one embeddings request and return the decoded JSON body.
    ///
    /// `input` is placed verbatim in the request's `"input"` field, so it may
    /// be a single string or an array of strings.
    ///
    /// # Errors
    /// Returns an error if the request fails, the server answers with a
    /// non-success status, or the body is not valid JSON.
    async fn request_embeddings(&self, input: serde_json::Value) -> Result<serde_json::Value> {
        // Tolerate a configured base URL that ends with '/', which would
        // otherwise yield ".../v1//embeddings".
        let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            // OpenRouter requires HTTP-Referer header
            .header("HTTP-Referer", "https://github.com/engram")
            // Optional: helps OpenRouter track usage
            .header("X-Title", "Engram Memory")
            .json(&serde_json::json!({
                "input": input,
                "model": self.model,
            }))
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            return Err(EngramError::Embedding(format!(
                "Embedding API error {}: {}",
                status, body
            )));
        }

        let payload: serde_json::Value = response.json().await?;
        Ok(payload)
    }

    /// Extract and validate the `"embedding"` array of one response item.
    ///
    /// # Errors
    /// Returns an error if the array is missing, contains a non-numeric
    /// entry, or its length differs from the configured dimensions.
    fn parse_embedding(&self, item: &serde_json::Value) -> Result<Vec<f32>> {
        let raw = item["embedding"]
            .as_array()
            .ok_or_else(Self::invalid_response)?;

        // Reject non-numeric entries instead of silently dropping them: a
        // shortened vector would otherwise be reported as a dimension mismatch
        // (or, worse, accepted with shifted components).
        let embedding = raw
            .iter()
            .map(|value| {
                value
                    .as_f64()
                    .map(|f| f as f32)
                    .ok_or_else(Self::invalid_response)
            })
            .collect::<Result<Vec<f32>>>()?;

        if embedding.len() != self.dimensions {
            return Err(EngramError::Embedding(format!(
                "Embedding dimensions mismatch: expected {}, got {}. Set OPENAI_EMBEDDING_DIMENSIONS={} to match your model.",
                self.dimensions, embedding.len(), embedding.len()
            )));
        }

        Ok(embedding)
    }

    /// Async embedding call to OpenAI-compatible API
    ///
    /// # Errors
    /// Returns an error if the request fails, the API reports a non-success
    /// status, the response has an unexpected shape, or the embedding length
    /// does not match the configured dimensions.
    pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
        let payload = self.request_embeddings(serde_json::json!(text)).await?;
        self.parse_embedding(&payload["data"][0])
    }

    /// Async batch embedding (up to 2048 inputs per call)
    ///
    /// Larger inputs are split into consecutive requests of at most 2048
    /// texts each; results are concatenated in request order. An empty input
    /// returns an empty vector without contacting the API.
    ///
    /// # Errors
    /// Returns an error if any request fails, the API reports a non-success
    /// status, a response has an unexpected shape, the number of returned
    /// embeddings differs from the number of inputs in that request, or any
    /// embedding's length does not match the configured dimensions.
    pub async fn embed_batch_async(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(Self::MAX_INPUTS_PER_REQUEST) {
            let payload = self.request_embeddings(serde_json::json!(chunk)).await?;
            let items = payload["data"]
                .as_array()
                .ok_or_else(Self::invalid_response)?;

            // Callers pair results with inputs by position, so a short or
            // long response would silently misalign them.
            if items.len() != chunk.len() {
                return Err(EngramError::Embedding(format!(
                    "Embedding API returned {} embeddings for {} inputs",
                    items.len(),
                    chunk.len()
                )));
            }

            // Validate every item, not just the first one of the chunk.
            for item in items {
                all_embeddings.push(self.parse_embedding(item)?);
            }
        }

        Ok(all_embeddings)
    }
}
259
#[cfg(feature = "openai")]
impl OpenAIEmbedder {
    /// Drive the future produced by `make_future` to completion from
    /// synchronous code, whatever runtime (if any) surrounds the caller.
    ///
    /// * No ambient runtime (e.g. MCP stdio): block on the owned fallback
    ///   runtime — without this, `Handle::current()` panics with "there is
    ///   no reactor running". Fixes #9.
    /// * Ambient multi-threaded runtime (e.g. HTTP transport): use
    ///   `block_in_place` so no nested runtime is created.
    /// * Ambient `current_thread` runtime: `block_in_place` panics there, so
    ///   the future is run on the owned fallback runtime from a short-lived
    ///   scoped thread while the calling thread waits for it.
    ///
    /// The future is built lazily via `make_future` so that it is created on
    /// the thread that polls it and does not itself have to be `Send`.
    fn run_blocking<T, Fut, Make>(&self, make_future: Make) -> T
    where
        Make: FnOnce() -> Fut + Send,
        Fut: std::future::Future<Output = T>,
        T: Send,
    {
        use tokio::runtime::{Handle, RuntimeFlavor};

        match Handle::try_current() {
            Ok(handle) if matches!(handle.runtime_flavor(), RuntimeFlavor::CurrentThread) => {
                let runtime = &self.runtime;
                std::thread::scope(|scope| {
                    let worker = scope.spawn(move || runtime.block_on(make_future()));
                    match worker.join() {
                        Ok(output) => output,
                        // Re-raise a panic from the helper thread unchanged.
                        Err(payload) => std::panic::resume_unwind(payload),
                    }
                })
            }
            Ok(handle) => tokio::task::block_in_place(move || handle.block_on(make_future())),
            Err(_) => self.runtime.block_on(make_future()),
        }
    }
}

#[cfg(feature = "openai")]
impl Embedder for OpenAIEmbedder {
    /// Synchronous wrapper around `embed_async`.
    ///
    /// Blocks the calling thread until the request completes and returns
    /// whatever `embed_async` returns.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.run_blocking(move || self.embed_async(text))
    }

    /// Synchronous wrapper around `embed_batch_async`.
    ///
    /// Blocks the calling thread until all requests complete and returns
    /// whatever `embed_batch_async` returns.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        self.run_blocking(move || self.embed_batch_async(texts))
    }

    /// Configured embedding length.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Configured model identifier.
    fn model_name(&self) -> &str {
        &self.model
    }
}
291
292/// Create an embedder from configuration
293///
294/// Available models depend on enabled features:
295/// - `"tfidf"`: Always available, no external dependencies
296/// - `"openai"`: Requires `openai` feature and API key
297/// - `"local"` / `"onnx"`: Requires `local-embeddings` feature and a downloaded ONNX model
298///
299/// For OpenAI-compatible APIs (OpenRouter, Azure, etc.), set:
300/// - `base_url`: API endpoint (e.g., `<https://openrouter.ai/api/v1>`)
301/// - `embedding_model`: Model name (e.g., "openai/text-embedding-3-small")
302/// - `dimensions`: Expected output dimensions
303pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>> {
304    match config.model.as_str() {
305        #[cfg(feature = "multimodal")]
306        "clip" => {
307            clip::create_clip_embedder()
308                .map(|e| e as Arc<dyn Embedder>)
309        }
310        #[cfg(feature = "openai")]
311        "openai" => {
312            let api_key = config
313                .api_key
314                .clone()
315                .ok_or_else(|| EngramError::Config(
316                    "OPENAI_API_KEY required when ENGRAM_EMBEDDING_MODEL=openai".to_string()
317                ))?;
318            Ok(Arc::new(OpenAIEmbedder::with_config(
319                api_key,
320                config.base_url.clone(),
321                config.embedding_model.clone(),
322                Some(config.dimensions),
323            )))
324        }
325        #[cfg(not(feature = "openai"))]
326        "openai" => Err(EngramError::Config(
327            "OpenAI embeddings require the 'openai' feature to be enabled. Build with: cargo build --features openai".to_string(),
328        )),
329        #[cfg(feature = "onnx-embed")]
330        "local" | "onnx" => {
331            let model_dir = onnx::resolve_model_dir(config.model_path.as_deref());
332            Ok(Arc::new(onnx::OnnxEmbedder::from_dir(&model_dir)?))
333        }
334        #[cfg(not(feature = "onnx-embed"))]
335        "local" | "onnx" => Err(EngramError::Config(
336            "Local sentence-transformer embeddings require the 'local-embeddings' feature. Build with: cargo build --features local-embeddings, then run: engram-cli model download minilm-l6-v2 and set ENGRAM_EMBEDDING_MODEL=local".to_string(),
337        )),
338        "tfidf" => Ok(Arc::new(TfIdfEmbedder::new(config.dimensions))),
339        _ => Err(EngramError::Config(format!(
340            "Unknown embedding model: '{}'. Use 'tfidf', 'local', or 'openai'",
341            config.model
342        ))),
343    }
344}
345
/// Cosine of the angle between `a` and `b`.
///
/// Returns `0.0` when the slices are empty, differ in length, or when either
/// vector has zero magnitude; otherwise returns `dot(a, b) / (|a| * |b|)`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Empty or mismatched inputs have no defined angle.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Euclidean length of a vector.
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let (len_a, len_b) = (magnitude(a), magnitude(b));
    if len_a == 0.0 || len_b == 0.0 {
        // A zero vector would make the denominator zero.
        return 0.0;
    }

    let inner: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    inner / (len_a * len_b)
}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the floating-point comparisons below.
    const TOLERANCE: f32 = 0.001;

    /// Ask `create_embedder` for `model` using an otherwise default config and
    /// return the resulting error text; panics with `unexpected_success` if an
    /// embedder is created instead.
    #[cfg(not(feature = "local-embeddings"))]
    fn rejection_message(model: &str, unexpected_success: &str) -> String {
        let config = EmbeddingConfig {
            model: model.to_string(),
            ..EmbeddingConfig::default()
        };

        match create_embedder(&config) {
            Err(err) => err.to_string(),
            Ok(_) => panic!("{}", unexpected_success),
        }
    }

    /// Identical, orthogonal and opposite vectors score 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = [1.0_f32, 0.0, 0.0];

        let parallel = cosine_similarity(&x_axis, &[1.0, 0.0, 0.0]);
        assert!((parallel - 1.0).abs() < TOLERANCE);

        let orthogonal = cosine_similarity(&x_axis, &[0.0, 1.0, 0.0]);
        assert!(orthogonal.abs() < TOLERANCE);

        let opposite = cosine_similarity(&x_axis, &[-1.0, 0.0, 0.0]);
        assert!((opposite + 1.0).abs() < TOLERANCE);
    }

    /// The TF-IDF embedder returns vectors of the requested length.
    #[test]
    fn test_tfidf_embedder() {
        let dims = 384;
        let vector = TfIdfEmbedder::new(dims).embed("Hello world").unwrap();
        assert_eq!(vector.len(), dims);
    }

    /// Requesting "local" without the feature yields an actionable config error.
    #[cfg(not(feature = "local-embeddings"))]
    #[test]
    fn test_local_embedder_requires_feature_when_disabled() {
        let msg = rejection_message("local", "local backend should require opt-in feature");

        assert!(msg.contains("local-embeddings"), "{msg}");
        assert!(msg.contains("ENGRAM_EMBEDDING_MODEL=local"), "{msg}");
    }

    /// The "onnx" alias is rejected the same way as "local".
    #[cfg(not(feature = "local-embeddings"))]
    #[test]
    fn test_onnx_alias_requires_feature_when_disabled() {
        let msg = rejection_message("onnx", "onnx alias should require opt-in feature");

        assert!(msg.contains("local-embeddings"), "{msg}");
        assert!(msg.contains("ENGRAM_EMBEDDING_MODEL=local"), "{msg}");
    }
}
422}