// engram/embedding/mod.rs
//! Embedding generation and async queue management (RML-873)
//!
//! Supports multiple embedding backends:
//! - OpenAI API (text-embedding-3-small) - requires `openai` feature
//! - TF-IDF fallback (no external dependencies)
//!
//! Features:
//! - LRU embedding cache with zero-copy Arc<[f32]> sharing
//! - Async queue processing for batch operations
//!
//! # Feature Flags
//!
//! - `openai`: Enables OpenAI embedding backend (requires API key)

mod cache;
mod queue;
mod tfidf;

pub use cache::{EmbeddingCache, EmbeddingCacheStats};
pub use queue::{get_embedding, get_embedding_status, EmbeddingQueue, EmbeddingWorker};
pub use tfidf::TfIdfEmbedder;

use std::sync::Arc;

use crate::error::{EngramError, Result};
use crate::types::EmbeddingConfig;

28/// Trait for embedding generators
29pub trait Embedder: Send + Sync {
30    /// Generate embedding for a single text
31    fn embed(&self, text: &str) -> Result<Vec<f32>>;
32
33    /// Generate embeddings for multiple texts (batch)
34    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
35        texts.iter().map(|t| self.embed(t)).collect()
36    }
37
38    /// Get embedding dimensions
39    fn dimensions(&self) -> usize;
40
41    /// Get model name
42    fn model_name(&self) -> &str;
43}
44
/// OpenAI embedding client
///
/// Requires the `openai` feature to be enabled.
/// Supports OpenAI, OpenRouter, Azure OpenAI, and other OpenAI-compatible APIs.
#[cfg(feature = "openai")]
pub struct OpenAIEmbedder {
    // Reused HTTP client (reqwest pools connections across requests)
    client: reqwest::Client,
    // Bearer token sent in the Authorization header
    api_key: String,
    // API root, e.g. "https://api.openai.com/v1"; "/embeddings" is appended
    base_url: String,
    // Model name sent in the request body
    model: String,
    // Expected embedding length; responses of a different length are rejected
    dimensions: usize,
}

#[cfg(feature = "openai")]
impl OpenAIEmbedder {
    /// Create a new OpenAI embedder with default settings
    /// (api.openai.com, "text-embedding-3-small", 1536 dimensions).
    pub fn new(api_key: String) -> Self {
        Self::with_config(api_key, None, None, None)
    }

    /// Create a new OpenAI embedder with custom settings
    ///
    /// # Arguments
    /// * `api_key` - API key for authentication
    /// * `base_url` - API base URL (e.g., "https://openrouter.ai/api/v1" for OpenRouter)
    /// * `model` - Model name (e.g., "openai/text-embedding-3-small" for OpenRouter)
    /// * `dimensions` - Expected embedding dimensions (must match model output)
    pub fn with_config(
        api_key: String,
        base_url: Option<String>,
        model: Option<String>,
        dimensions: Option<usize>,
    ) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
            model: model.unwrap_or_else(|| "text-embedding-3-small".to_string()),
            dimensions: dimensions.unwrap_or(1536),
        }
    }

    /// Legacy constructor for backwards compatibility
    pub fn with_model(api_key: String, model: String, dimensions: usize) -> Self {
        // Delegates to with_config so defaults live in exactly one place.
        Self::with_config(api_key, None, Some(model), Some(dimensions))
    }

    /// POST `input` to the `/embeddings` endpoint and return the parsed
    /// JSON body. Non-2xx responses become `EngramError::Embedding`.
    async fn post_embeddings(&self, input: serde_json::Value) -> Result<serde_json::Value> {
        let url = format!("{}/embeddings", self.base_url);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            // OpenRouter requires HTTP-Referer header
            .header("HTTP-Referer", "https://github.com/engram")
            // Optional: helps OpenRouter track usage
            .header("X-Title", "Engram Memory")
            .json(&serde_json::json!({
                "input": input,
                "model": self.model,
            }))
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response.text().await.unwrap_or_default();
            return Err(EngramError::Embedding(format!(
                "Embedding API error {}: {}",
                status, text
            )));
        }

        Ok(response.json().await?)
    }

    /// Extract and validate one embedding vector from a response `data[i]`
    /// item.
    ///
    /// Rejects non-numeric values and wrong-length vectors instead of
    /// silently skipping them, so corrupt responses can never flow through
    /// as a truncated or empty embedding.
    fn parse_embedding(&self, item: &serde_json::Value) -> Result<Vec<f32>> {
        let values = item["embedding"]
            .as_array()
            .ok_or_else(|| EngramError::Embedding("Invalid response format".to_string()))?;

        let mut embedding = Vec::with_capacity(values.len());
        for v in values {
            let f = v.as_f64().ok_or_else(|| {
                EngramError::Embedding("Non-numeric value in embedding".to_string())
            })?;
            embedding.push(f as f32);
        }

        // Validate dimensions match configuration
        if embedding.len() != self.dimensions {
            return Err(EngramError::Embedding(format!(
                "Embedding dimensions mismatch: expected {}, got {}. Set OPENAI_EMBEDDING_DIMENSIONS={} to match your model.",
                self.dimensions, embedding.len(), embedding.len()
            )));
        }

        Ok(embedding)
    }

    /// Async embedding call to OpenAI-compatible API
    pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
        let data = self
            .post_embeddings(serde_json::Value::String(text.to_string()))
            .await?;
        self.parse_embedding(&data["data"][0])
    }

    /// Async batch embedding (up to 2048 inputs per call)
    pub async fn embed_batch_async(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let mut all_embeddings = Vec::with_capacity(texts.len());

        // OpenAI allows up to 2048 inputs per batch
        for chunk in texts.chunks(2048) {
            let data = self.post_embeddings(serde_json::json!(chunk)).await?;
            let items = data["data"]
                .as_array()
                .ok_or_else(|| EngramError::Embedding("Invalid response format".to_string()))?;

            // A short response would silently misalign inputs and outputs
            // for callers that zip them back together.
            if items.len() != chunk.len() {
                return Err(EngramError::Embedding(format!(
                    "Embedding count mismatch: sent {} inputs, got {} embeddings",
                    chunk.len(),
                    items.len()
                )));
            }

            // NOTE(review): assumes the API returns embeddings in input
            // order (OpenAI documents this); confirm for other providers,
            // or sort by each item's "index" field if not guaranteed.
            for item in items {
                // Every embedding is validated, not just the first per batch.
                all_embeddings.push(self.parse_embedding(item)?);
            }
        }

        Ok(all_embeddings)
    }
}

#[cfg(feature = "openai")]
impl Embedder for OpenAIEmbedder {
    /// Synchronous wrapper over [`OpenAIEmbedder::embed_async`].
    ///
    /// NOTE(review): `Handle::current()` panics when called outside a
    /// tokio runtime, and `block_in_place` panics on a current-thread
    /// runtime — this assumes callers run on a multi-threaded tokio
    /// runtime; confirm at the call sites.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Blocking call for sync interface
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(self.embed_async(text))
        })
    }

    /// Synchronous wrapper over [`OpenAIEmbedder::embed_batch_async`];
    /// same tokio runtime requirements as `embed` above.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(self.embed_batch_async(texts))
        })
    }

    /// Configured embedding dimensionality (not queried from the API).
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Configured model name, as sent in request bodies.
    fn model_name(&self) -> &str {
        &self.model
    }
}

242/// Create an embedder from configuration
243///
244/// Available models depend on enabled features:
245/// - `"tfidf"`: Always available, no external dependencies
246/// - `"openai"`: Requires `openai` feature and API key
247///
248/// For OpenAI-compatible APIs (OpenRouter, Azure, etc.), set:
249/// - `base_url`: API endpoint (e.g., "https://openrouter.ai/api/v1")
250/// - `embedding_model`: Model name (e.g., "openai/text-embedding-3-small")
251/// - `dimensions`: Expected output dimensions
252pub fn create_embedder(config: &EmbeddingConfig) -> Result<Arc<dyn Embedder>> {
253    match config.model.as_str() {
254        #[cfg(feature = "openai")]
255        "openai" => {
256            let api_key = config
257                .api_key
258                .clone()
259                .ok_or_else(|| EngramError::Config(
260                    "OPENAI_API_KEY required when ENGRAM_EMBEDDING_MODEL=openai".to_string()
261                ))?;
262            Ok(Arc::new(OpenAIEmbedder::with_config(
263                api_key,
264                config.base_url.clone(),
265                config.embedding_model.clone(),
266                Some(config.dimensions),
267            )))
268        }
269        #[cfg(not(feature = "openai"))]
270        "openai" => Err(EngramError::Config(
271            "OpenAI embeddings require the 'openai' feature to be enabled. Build with: cargo build --features openai".to_string(),
272        )),
273        "tfidf" => Ok(Arc::new(TfIdfEmbedder::new(config.dimensions))),
274        _ => Err(EngramError::Config(format!(
275            "Unknown embedding model: '{}'. Use 'openai' or 'tfidf'",
276            config.model
277        ))),
278    }
279}
280
/// Cosine similarity between two vectors
///
/// Returns 0.0 when the vectors differ in length, are empty, or either
/// has zero magnitude (similarity is undefined in those cases).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass: accumulate dot product and both squared norms together.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons.
    const EPS: f32 = 0.001;

    #[test]
    fn test_cosine_similarity() {
        let unit_x = vec![1.0, 0.0, 0.0];

        // Identical vectors -> similarity 1
        assert!((cosine_similarity(&unit_x, &unit_x) - 1.0).abs() < EPS);

        // Orthogonal vectors -> similarity 0
        let unit_y = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&unit_x, &unit_y).abs() < EPS);

        // Opposite vectors -> similarity -1
        let neg_x = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&unit_x, &neg_x) + 1.0).abs() < EPS);
    }

    #[test]
    fn test_tfidf_embedder() {
        // TF-IDF output length must match the requested dimensionality.
        let dims = 384;
        let embedding = TfIdfEmbedder::new(dims).embed("Hello world").unwrap();
        assert_eq!(embedding.len(), dims);
    }
}