// engram/embedding/mod.rs
1//! Embedding generation and async queue management (RML-873)
2//!
3//! Supports multiple embedding backends:
4//! - OpenAI API (text-embedding-3-small) - requires `openai` feature
5//! - TF-IDF fallback (no external dependencies)
6//!
7//! Features:
8//! - LRU embedding cache with zero-copy Arc<[f32]> sharing
9//! - Async queue processing for batch operations
10//!
11//! # Feature Flags
12//!
13//! - `openai`: Enables OpenAI embedding backend (requires API key)
14
15mod cache;
16mod provider;
17mod queue;
18mod tfidf;
19
20#[cfg(feature = "cohere")]
21pub mod cohere;
22#[cfg(feature = "ollama")]
23pub mod ollama;
24#[cfg(feature = "onnx-embed")]
25pub mod onnx;
26#[cfg(feature = "voyage")]
27pub mod voyage;
28
29pub use cache::{EmbeddingCache, EmbeddingCacheStats};
30pub use provider::{EmbeddingProvider, EmbeddingProviderInfo, EmbeddingRegistry};
31pub use queue::{get_embedding, get_embedding_status, EmbeddingQueue, EmbeddingWorker};
32pub use tfidf::TfIdfEmbedder;
33
34use std::sync::Arc;
35
36use crate::error::{EngramError, Result};
37use crate::types::EmbeddingConfig;
38
39/// Trait for embedding generators
40pub trait Embedder: Send + Sync {
41    /// Generate embedding for a single text
42    fn embed(&self, text: &str) -> Result<Vec<f32>>;
43
44    /// Generate embeddings for multiple texts (batch)
45    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
46        texts.iter().map(|t| self.embed(t)).collect()
47    }
48
49    /// Get embedding dimensions
50    fn dimensions(&self) -> usize;
51
52    /// Get model name
53    fn model_name(&self) -> &str;
54}
55
/// OpenAI embedding client
///
/// Requires the `openai` feature to be enabled.
/// Supports OpenAI, OpenRouter, Azure OpenAI, and other OpenAI-compatible APIs.
#[cfg(feature = "openai")]
pub struct OpenAIEmbedder {
    // Reusable HTTP client shared by all requests.
    client: reqwest::Client,
    // Bearer token sent in the Authorization header of every request.
    api_key: String,
    // API root, e.g. "https://api.openai.com/v1"; "/embeddings" is appended per call.
    base_url: String,
    // Model identifier sent in each request body.
    model: String,
    // Expected embedding length; responses with a different length are rejected.
    dimensions: usize,
}
68
#[cfg(feature = "openai")]
impl OpenAIEmbedder {
    /// Create a new OpenAI embedder with default settings
    /// (api.openai.com, `text-embedding-3-small`, 1536 dimensions).
    pub fn new(api_key: String) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
            model: "text-embedding-3-small".to_string(),
            dimensions: 1536,
        }
    }

    /// Create a new OpenAI embedder with custom settings
    ///
    /// # Arguments
    /// * `api_key` - API key for authentication
    /// * `base_url` - API base URL (e.g., "https://openrouter.ai/api/v1" for OpenRouter)
    /// * `model` - Model name (e.g., "openai/text-embedding-3-small" for OpenRouter)
    /// * `dimensions` - Expected embedding dimensions (must match model output)
    pub fn with_config(
        api_key: String,
        base_url: Option<String>,
        model: Option<String>,
        dimensions: Option<usize>,
    ) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
            model: model.unwrap_or_else(|| "text-embedding-3-small".to_string()),
            dimensions: dimensions.unwrap_or(1536),
        }
    }

    /// Legacy constructor for backwards compatibility
    pub fn with_model(api_key: String, model: String, dimensions: usize) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
            model,
            dimensions,
        }
    }

    /// POST `input` (a JSON string or array of strings) to the embeddings
    /// endpoint and return the parsed JSON response body.
    ///
    /// Shared request/error-handling path for `embed_async` and
    /// `embed_batch_async` — keeps headers and error messages in one place.
    async fn post_embeddings(&self, input: serde_json::Value) -> Result<serde_json::Value> {
        let url = format!("{}/embeddings", self.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            // OpenRouter requires HTTP-Referer header
            .header("HTTP-Referer", "https://github.com/engram")
            // Optional: helps OpenRouter track usage
            .header("X-Title", "Engram Memory")
            .json(&serde_json::json!({
                "input": input,
                "model": self.model,
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(EngramError::Embedding(format!(
                "Embedding API error {}: {}",
                status, text
            )));
        }

        Ok(response.json().await?)
    }

    /// Convert a JSON `embedding` value into `Vec<f32>`, skipping any
    /// non-numeric entries. Returns an empty vec if the value is not an array.
    fn parse_vector(value: &serde_json::Value) -> Vec<f32> {
        value
            .as_array()
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_f64().map(|f| f as f32))
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Return an error if `got` does not match the configured dimensions.
    fn check_dimensions(&self, got: usize) -> Result<()> {
        if got != self.dimensions {
            return Err(EngramError::Embedding(format!(
                "Embedding dimensions mismatch: expected {}, got {}. Set OPENAI_EMBEDDING_DIMENSIONS={} to match your model.",
                self.dimensions, got, got
            )));
        }
        Ok(())
    }

    /// Async embedding call to OpenAI-compatible API
    ///
    /// # Errors
    /// Fails on transport errors, non-success HTTP status, a malformed
    /// response body, or a dimension mismatch with the configuration.
    pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
        let data = self.post_embeddings(serde_json::json!(text)).await?;

        let value = &data["data"][0]["embedding"];
        if !value.is_array() {
            return Err(EngramError::Embedding("Invalid response format".to_string()));
        }

        let embedding = Self::parse_vector(value);
        // Validate dimensions match configuration
        self.check_dimensions(embedding.len())?;

        Ok(embedding)
    }

    /// Async batch embedding (up to 2048 inputs per call)
    ///
    /// Splits `texts` into API-sized chunks and concatenates the results in
    /// input order. Returns an empty vec for empty input without any request.
    pub async fn embed_batch_async(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());

        // OpenAI allows up to 2048 inputs per batch
        for chunk in texts.chunks(2048) {
            let data = self.post_embeddings(serde_json::json!(chunk)).await?;

            let embeddings: Vec<Vec<f32>> = data["data"]
                .as_array()
                .ok_or_else(|| EngramError::Embedding("Invalid response format".to_string()))?
                .iter()
                .map(|item| Self::parse_vector(&item["embedding"]))
                .collect();

            // Validate dimensions against the first embedding of each chunk.
            if let Some(first) = embeddings.first() {
                self.check_dimensions(first.len())?;
            }

            all_embeddings.extend(embeddings);
        }

        Ok(all_embeddings)
    }
}
228
#[cfg(feature = "openai")]
impl Embedder for OpenAIEmbedder {
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Bridge the async HTTP client to the sync `Embedder` trait by
        // blocking on the current tokio runtime.
        // NOTE(review): `block_in_place` and `Handle::current()` panic when
        // not running inside a tokio multi-thread runtime — confirm all
        // callers invoke this from such a context.
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(self.embed_async(text))
        })
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        // Same sync-over-async bridge as `embed`, for the batch endpoint.
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(self.embed_batch_async(texts))
        })
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn model_name(&self) -> &str {
        &self.model
    }
}
252
253/// Create an embedder from configuration
254///
255/// Available models depend on enabled features:
256/// - `"tfidf"`: Always available, no external dependencies
257/// - `"openai"`: Requires `openai` feature and API key
258///
259/// For OpenAI-compatible APIs (OpenRouter, Azure, etc.), set:
260/// - `base_url`: API endpoint (e.g., "https://openrouter.ai/api/v1")
261/// - `embedding_model`: Model name (e.g., "openai/text-embedding-3-small")
262/// - `dimensions`: Expected output dimensions
263pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>> {
264    match config.model.as_str() {
265        #[cfg(feature = "openai")]
266        "openai" => {
267            let api_key = config
268                .api_key
269                .clone()
270                .ok_or_else(|| EngramError::Config(
271                    "OPENAI_API_KEY required when ENGRAM_EMBEDDING_MODEL=openai".to_string()
272                ))?;
273            Ok(Arc::new(OpenAIEmbedder::with_config(
274                api_key,
275                config.base_url.clone(),
276                config.embedding_model.clone(),
277                Some(config.dimensions),
278            )))
279        }
280        #[cfg(not(feature = "openai"))]
281        "openai" => Err(EngramError::Config(
282            "OpenAI embeddings require the 'openai' feature to be enabled. Build with: cargo build --features openai".to_string(),
283        )),
284        "tfidf" => Ok(Arc::new(TfIdfEmbedder::new(config.dimensions))),
285        _ => Err(EngramError::Config(format!(
286            "Unknown embedding model: '{}'. Use 'openai' or 'tfidf'",
287            config.model
288        ))),
289    }
290}
291
/// Cosine similarity between two vectors
///
/// Returns a value in [-1, 1] for well-formed input, and 0.0 when the
/// lengths differ, either vector is empty, or either vector has zero norm.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate dot product and squared norms in a single pass.
    let mut dot = 0.0_f32;
    let mut norm_sq_a = 0.0_f32;
    let mut norm_sq_b = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_sq_a += x * x;
        norm_sq_b += y * y;
    }

    // A zero-norm vector has no direction; define similarity as 0.
    if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
        return 0.0;
    }

    dot / (norm_sq_a.sqrt() * norm_sq_b.sqrt())
}
308
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        let unit_x = vec![1.0, 0.0, 0.0];
        let unit_y = vec![0.0, 1.0, 0.0];
        let neg_x = vec![-1.0, 0.0, 0.0];

        // Identical vectors are maximally similar.
        assert!((cosine_similarity(&unit_x, &unit_x) - 1.0).abs() < 0.001);

        // Orthogonal vectors have zero similarity.
        assert!(cosine_similarity(&unit_x, &unit_y).abs() < 0.001);

        // Opposite vectors are maximally dissimilar.
        assert!((cosine_similarity(&unit_x, &neg_x) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_tfidf_embedder() {
        // The TF-IDF backend must produce a vector of the requested size.
        let dims = 384;
        let vector = TfIdfEmbedder::new(dims).embed("Hello world").unwrap();
        assert_eq!(vector.len(), dims);
    }
}