engram/embedding/mod.rs

//! Embedding generation and async queue management (RML-873)
//!
//! Supports multiple embedding backends:
//! - OpenAI API (text-embedding-3-small) - requires `openai` feature
//! - TF-IDF fallback (no external dependencies)
//! - Additional feature-gated backends: Cohere, CLIP, Ollama, ONNX, Voyage
//!
//! Features:
//! - LRU embedding cache with zero-copy Arc<[f32]> sharing
//! - Async queue processing for batch operations
//!
//! # Feature Flags
//!
//! - `openai`: Enables OpenAI embedding backend (requires API key)
//! - `cohere`: Enables the Cohere embedding backend
//! - `multimodal`: Enables the CLIP multimodal embedding backend
//! - `ollama`: Enables the Ollama embedding backend
//! - `onnx-embed`: Enables the ONNX embedding backend
//! - `voyage`: Enables the Voyage embedding backend
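//!
//! # Example
//!
//! A minimal usage sketch. It assumes this module is reachable as
//! `engram::embedding`, that `EmbeddingConfig` lives in `engram::types`, and
//! that it implements `Default`; those details are illustrative, not
//! guaranteed by this file.
//!
//! ```ignore
//! use engram::embedding::{create_embedder, cosine_similarity};
//! use engram::types::EmbeddingConfig;
//!
//! // TF-IDF backend: always available, no API key required.
//! let config = EmbeddingConfig {
//!     model: "tfidf".to_string(),
//!     dimensions: 384,
//!     ..Default::default()
//! };
//! let embedder = create_embedder(&config)?;
//!
//! let a = embedder.embed("graph databases")?;
//! let b = embedder.embed("vector search")?;
//! println!("similarity = {:.3}", cosine_similarity(&a, &b));
//! ```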

mod cache;
mod provider;
mod queue;
mod tfidf;

#[cfg(feature = "cohere")]
pub mod cohere;
#[cfg(feature = "multimodal")]
pub mod clip;
#[cfg(feature = "ollama")]
pub mod ollama;
#[cfg(feature = "onnx-embed")]
pub mod onnx;
#[cfg(feature = "voyage")]
pub mod voyage;

pub use cache::{EmbeddingCache, EmbeddingCacheStats};
#[cfg(feature = "multimodal")]
pub use clip::{ClipEmbedder, MultimodalEmbedder, CLIP_PROVIDER_NAME};
pub use provider::{EmbeddingProvider, EmbeddingProviderInfo, EmbeddingRegistry};
pub use queue::{get_embedding, get_embedding_status, EmbeddingQueue, EmbeddingWorker};
pub use tfidf::TfIdfEmbedder;

use std::sync::Arc;

use crate::error::{EngramError, Result};
use crate::types::EmbeddingConfig;

/// Trait for embedding generators
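///
/// # Example
///
/// A minimal sketch of a custom implementor (the zero-vector embedder below is
/// illustrative only; real backends live in the feature-gated submodules):
///
/// ```ignore
/// struct ZeroEmbedder {
///     dims: usize,
/// }
///
/// impl Embedder for ZeroEmbedder {
///     fn embed(&self, _text: &str) -> Result<Vec<f32>> {
///         // Every input maps to the zero vector of the configured size.
///         Ok(vec![0.0; self.dims])
///     }
///
///     fn dimensions(&self) -> usize {
///         self.dims
///     }
///
///     fn model_name(&self) -> &str {
///         "zero"
///     }
/// }
/// ```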
pub trait Embedder: Send + Sync {
    /// Generate embedding for a single text
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for multiple texts (batch)
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Get embedding dimensions
    fn dimensions(&self) -> usize;

    /// Get model name
    fn model_name(&self) -> &str;
}

/// OpenAI embedding client
///
/// Requires the `openai` feature to be enabled.
/// Supports OpenAI, OpenRouter, Azure OpenAI, and other OpenAI-compatible APIs.
#[cfg(feature = "openai")]
pub struct OpenAIEmbedder {
    client: reqwest::Client,
    api_key: String,
    base_url: String,
    model: String,
    dimensions: usize,
}

#[cfg(feature = "openai")]
impl OpenAIEmbedder {
    /// Create a new OpenAI embedder with default settings
    pub fn new(api_key: String) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
            model: "text-embedding-3-small".to_string(),
            dimensions: 1536,
        }
    }

    /// Create a new OpenAI embedder with custom settings
    ///
    /// # Arguments
    /// * `api_key` - API key for authentication
    /// * `base_url` - API base URL (e.g., "https://openrouter.ai/api/v1" for OpenRouter)
    /// * `model` - Model name (e.g., "openai/text-embedding-3-small" for OpenRouter)
    /// * `dimensions` - Expected embedding dimensions (must match model output)
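    ///
    /// # Example
    ///
    /// A sketch of targeting OpenRouter; the model name and the 1536-dimension
    /// value are illustrative and must match what the chosen endpoint returns:
    ///
    /// ```ignore
    /// let embedder = OpenAIEmbedder::with_config(
    ///     std::env::var("OPENAI_API_KEY").expect("API key required"),
    ///     Some("https://openrouter.ai/api/v1".to_string()),
    ///     Some("openai/text-embedding-3-small".to_string()),
    ///     Some(1536),
    /// );
    /// ```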
    pub fn with_config(
        api_key: String,
        base_url: Option<String>,
        model: Option<String>,
        dimensions: Option<usize>,
    ) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
            model: model.unwrap_or_else(|| "text-embedding-3-small".to_string()),
            dimensions: dimensions.unwrap_or(1536),
        }
    }

    /// Legacy constructor for backwards compatibility
    pub fn with_model(api_key: String, model: String, dimensions: usize) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
            model,
            dimensions,
        }
    }

    /// Async embedding call to OpenAI-compatible API
    pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
        let url = format!("{}/embeddings", self.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            // Optional OpenRouter headers for app attribution (harmless for
            // other OpenAI-compatible providers)
            .header("HTTP-Referer", "https://github.com/engram")
            .header("X-Title", "Engram Memory")
            .json(&serde_json::json!({
                "input": text,
                "model": self.model,
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(EngramError::Embedding(format!(
                "Embedding API error {}: {}",
                status, text
            )));
        }

        let data: serde_json::Value = response.json().await?;
        let embedding: Vec<f32> = data["data"][0]["embedding"]
            .as_array()
            .ok_or_else(|| EngramError::Embedding("Invalid response format".to_string()))?
            .iter()
            .filter_map(|v| v.as_f64().map(|f| f as f32))
            .collect();

        // Validate dimensions match configuration
        if embedding.len() != self.dimensions {
            return Err(EngramError::Embedding(format!(
                "Embedding dimensions mismatch: expected {}, got {}. Set OPENAI_EMBEDDING_DIMENSIONS={} to match your model.",
                self.dimensions, embedding.len(), embedding.len()
            )));
        }

        Ok(embedding)
    }

    /// Async batch embedding (up to 2048 inputs per call)
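    ///
    /// # Example
    ///
    /// A sketch (async context assumed; inputs beyond 2048 per call are split
    /// into multiple requests automatically):
    ///
    /// ```ignore
    /// let vectors = embedder
    ///     .embed_batch_async(&["first text", "second text"])
    ///     .await?;
    /// assert_eq!(vectors.len(), 2);
    /// ```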
    pub async fn embed_batch_async(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let url = format!("{}/embeddings", self.base_url);

        // OpenAI allows up to 2048 inputs per batch
        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(2048) {
            let response = self
                .client
                .post(&url)
                .header("Authorization", format!("Bearer {}", self.api_key))
                // Optional OpenRouter headers for app attribution
                .header("HTTP-Referer", "https://github.com/engram")
                .header("X-Title", "Engram Memory")
                .json(&serde_json::json!({
                    "input": chunk,
                    "model": self.model,
                }))
                .send()
                .await?;

            if !response.status().is_success() {
                let status = response.status();
                let text = response.text().await.unwrap_or_default();
                return Err(EngramError::Embedding(format!(
                    "Embedding API error {}: {}",
                    status, text
                )));
            }

            let data: serde_json::Value = response.json().await?;
            let embeddings: Vec<Vec<f32>> = data["data"]
                .as_array()
                .ok_or_else(|| EngramError::Embedding("Invalid response format".to_string()))?
                .iter()
                .map(|item| {
                    item["embedding"]
                        .as_array()
                        .map(|arr| {
                            arr.iter()
                                .filter_map(|v| v.as_f64().map(|f| f as f32))
                                .collect()
                        })
                        .unwrap_or_default()
                })
                .collect();

            // Validate dimensions on first batch
            if !embeddings.is_empty() && embeddings[0].len() != self.dimensions {
                return Err(EngramError::Embedding(format!(
                    "Embedding dimensions mismatch: expected {}, got {}. Set OPENAI_EMBEDDING_DIMENSIONS={} to match your model.",
                    self.dimensions, embeddings[0].len(), embeddings[0].len()
                )));
            }

            all_embeddings.extend(embeddings);
        }

        Ok(all_embeddings)
    }
}

#[cfg(feature = "openai")]
impl Embedder for OpenAIEmbedder {
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Blocking bridge for the sync trait. Requires a multi-threaded Tokio
        // runtime: `block_in_place` panics on a current-thread runtime.
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(self.embed_async(text))
        })
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(self.embed_batch_async(texts))
        })
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn model_name(&self) -> &str {
        &self.model
    }
}

/// Create an embedder from configuration
///
/// Available models depend on enabled features:
/// - `"tfidf"`: Always available, no external dependencies
/// - `"openai"`: Requires `openai` feature and API key
/// - `"clip"`: Requires `multimodal` feature
///
/// For OpenAI-compatible APIs (OpenRouter, Azure, etc.), set:
/// - `base_url`: API endpoint (e.g., "https://openrouter.ai/api/v1")
/// - `embedding_model`: Model name (e.g., "openai/text-embedding-3-small")
/// - `dimensions`: Expected output dimensions
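///
/// # Example
///
/// A sketch of configuring an OpenAI-compatible backend. The field names are
/// the ones this function reads; `..Default::default()` assumes
/// `EmbeddingConfig` implements `Default` and is illustrative only:
///
/// ```ignore
/// let config = EmbeddingConfig {
///     model: "openai".to_string(),
///     api_key: std::env::var("OPENAI_API_KEY").ok(),
///     base_url: Some("https://openrouter.ai/api/v1".to_string()),
///     embedding_model: Some("openai/text-embedding-3-small".to_string()),
///     dimensions: 1536,
///     ..Default::default()
/// };
/// let embedder = create_embedder(&config)?;
/// assert_eq!(embedder.dimensions(), 1536);
/// ```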
pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>> {
    match config.model.as_str() {
        #[cfg(feature = "multimodal")]
        "clip" => {
            clip::create_clip_embedder()
                .map(|e| e as Arc<dyn Embedder>)
        }
        #[cfg(feature = "openai")]
        "openai" => {
            let api_key = config
                .api_key
                .clone()
                .ok_or_else(|| EngramError::Config(
                    "OPENAI_API_KEY required when ENGRAM_EMBEDDING_MODEL=openai".to_string()
                ))?;
            Ok(Arc::new(OpenAIEmbedder::with_config(
                api_key,
                config.base_url.clone(),
                config.embedding_model.clone(),
                Some(config.dimensions),
            )))
        }
        #[cfg(not(feature = "openai"))]
        "openai" => Err(EngramError::Config(
            "OpenAI embeddings require the 'openai' feature to be enabled. Build with: cargo build --features openai".to_string(),
        )),
        "tfidf" => Ok(Arc::new(TfIdfEmbedder::new(config.dimensions))),
        _ => Err(EngramError::Config(format!(
            "Unknown embedding model: '{}'. Use 'openai' or 'tfidf'",
            config.model
        ))),
    }
}

/// Cosine similarity between two vectors
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 0.001);

        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &d) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_tfidf_embedder() {
        let embedder = TfIdfEmbedder::new(384);
        let embedding = embedder.embed("Hello world").unwrap();
        assert_eq!(embedding.len(), 384);
    }
}