Skip to main content

mnemos/embedding/
mod.rs

1//! Embedding generation for text vectorization.
2//!
3//! Supports multiple embedding backends:
4//! - Hash-based (default, for testing and offline use)
5//! - OpenAI API (production, high-quality embeddings)
6//!
7//! # Choosing a Backend
8//!
9//! | Backend | Quality | Speed | Cost | Offline |
10//! |---------|---------|-------|------|---------|
11//! | Hash    | Low     | Fast  | Free | Yes     |
12//! | OpenAI  | High    | Medium| Paid | No      |
13//!
14//! # Example
15//! ```ignore
16//! use mnemos::embedding::EmbeddingEngine;
17//!
18//! // Development/testing: use hash embeddings
19//! let engine = EmbeddingEngine::new(256);
20//!
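//! // Production: require OpenAI (reads OPENAI_API_KEY, errors if unavailable)
//! let engine = EmbeddingEngine::from_env_required()?;
//!
//! // Or: use OpenAI when configured, otherwise fall back to hash embeddings
//! let engine = EmbeddingEngine::from_env();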
23//!
24//! // Production with explicit key
25//! let engine = EmbeddingEngine::with_openai("sk-...", None);
26//! ```
27
28mod openai;
29
30pub use openai::{OpenAIConfig, OpenAIEmbedding, OpenAIModel, UsageSnapshot, UsageStats};
31
32use crate::error::{Error, Result};
33use std::any::Any;
34use std::collections::hash_map::DefaultHasher;
35use std::hash::{Hash, Hasher};
36
37/// Embedding provider trait for different backends.
38#[async_trait::async_trait]
39pub trait EmbeddingProvider: Any + Send + Sync {
40    /// Generate embedding for text.
41    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
42
43    /// Generate embeddings for multiple texts in batch.
44    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
45        let mut embeddings = Vec::with_capacity(texts.len());
46        for text in texts {
47            embeddings.push(self.embed(text).await?);
48        }
49        Ok(embeddings)
50    }
51
52    /// Get embedding dimensions.
53    fn dimensions(&self) -> usize;
54
55    /// Downcast support for provider-specific fast paths.
56    fn as_any(&self) -> &dyn Any;
57}
58
59/// Embedding engine for converting text to vectors.
60///
61/// Wraps different embedding providers with a unified interface.
62pub struct EmbeddingEngine {
63    provider: Box<dyn EmbeddingProvider>,
64}
65
66impl EmbeddingEngine {
67    /// Create a new embedding engine with hash-based embeddings.
68    ///
69    /// Use this for development, testing, or offline scenarios.
70    /// Hash embeddings are fast and free but lower quality.
71    pub fn new(dimensions: usize) -> Self {
72        Self {
73            provider: Box::new(HashEmbedding::new(dimensions)),
74        }
75    }
76
77    /// Create from environment variables.
78    ///
79    /// Reads OPENAI_API_KEY. Falls back to hash embeddings if not set.
80    pub fn from_env() -> Self {
81        if let Ok(provider) = OpenAIEmbedding::from_env() {
82            Self {
83                provider: Box::new(provider),
84            }
85        } else {
86            tracing::warn!("OPENAI_API_KEY not set, falling back to hash embeddings");
87            Self::new(1536) // Match OpenAI default dimensions
88        }
89    }
90
91    /// Create from environment, returning error if not configured.
92    pub fn from_env_required() -> Result<Self> {
93        let provider = OpenAIEmbedding::from_env()?;
94        Ok(Self {
95            provider: Box::new(provider),
96        })
97    }
98
99    /// Create with OpenAI embeddings.
100    pub fn with_openai(api_key: impl Into<String>, model: Option<String>) -> Self {
101        Self {
102            provider: Box::new(OpenAIEmbedding::new(api_key, model)),
103        }
104    }
105
106    /// Create with OpenAI using custom configuration.
107    pub fn with_openai_config(api_key: impl Into<String>, config: OpenAIConfig) -> Self {
108        Self {
109            provider: Box::new(OpenAIEmbedding::with_config(api_key, config)),
110        }
111    }
112
113    /// Create with a custom provider.
114    pub fn with_provider(provider: Box<dyn EmbeddingProvider>) -> Self {
115        Self { provider }
116    }
117
118    /// Get the embedding dimensions.
119    pub fn dimensions(&self) -> usize {
120        self.provider.dimensions()
121    }
122
123    /// Generate an embedding for the given text.
124    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
125        self.provider.embed(text).await
126    }
127
128    /// Generate embeddings for multiple texts.
129    pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
130        self.provider.embed_batch(texts).await
131    }
132
133    /// Compute similarity between two embeddings.
134    pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
135        if a.len() != b.len() {
136            return 0.0;
137        }
138
139        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
140        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
141        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
142
143        if norm_a == 0.0 || norm_b == 0.0 {
144            0.0
145        } else {
146            dot / (norm_a * norm_b)
147        }
148    }
149}
150
151/// Hash-based embedding for testing and development.
152pub struct HashEmbedding {
153    dimensions: usize,
154}
155
156impl HashEmbedding {
157    /// Create a new hash embedding engine.
158    pub fn new(dimensions: usize) -> Self {
159        Self { dimensions }
160    }
161
162    fn embed_sync(&self, text: &str) -> Result<Vec<f32>> {
163        if text.is_empty() {
164            return Err(Error::embedding("Cannot embed empty text"));
165        }
166
167        let mut embedding = vec![0.0f32; self.dimensions];
168        let normalized_text = text.to_lowercase();
169
170        // Hash individual words
171        for word in normalized_text.split_whitespace() {
172            self.add_word_embedding(&mut embedding, word, 1.0);
173        }
174
175        // Hash bigrams for context
176        let words: Vec<&str> = normalized_text.split_whitespace().collect();
177        for window in words.windows(2) {
178            let bigram = format!("{} {}", window[0], window[1]);
179            self.add_word_embedding(&mut embedding, &bigram, 0.5);
180        }
181
182        // Hash trigrams for more context
183        for window in words.windows(3) {
184            let trigram = format!("{} {} {}", window[0], window[1], window[2]);
185            self.add_word_embedding(&mut embedding, &trigram, 0.3);
186        }
187
188        // Character-level features for typo tolerance
189        for word in &words {
190            for char_ngram in word.as_bytes().windows(3) {
191                let hash = self.hash_bytes(char_ngram);
192                let idx = (hash as usize) % self.dimensions;
193                embedding[idx] += 0.1;
194            }
195        }
196
197        // Normalize to unit length
198        self.normalize(&mut embedding);
199
200        Ok(embedding)
201    }
202
203    fn add_word_embedding(&self, embedding: &mut [f32], text: &str, weight: f32) {
204        let hash = self.hash_text(text);
205        for i in 0..8 {
206            let idx = ((hash.wrapping_add(i * 0x9e37_79b9)) as usize) % self.dimensions;
207            let sign = if (hash >> i) & 1 == 0 { 1.0 } else { -1.0 };
208            embedding[idx] += sign * weight;
209        }
210    }
211
212    fn hash_text(&self, text: &str) -> u64 {
213        let mut hasher = DefaultHasher::new();
214        text.hash(&mut hasher);
215        hasher.finish()
216    }
217
218    fn hash_bytes(&self, bytes: &[u8]) -> u64 {
219        let mut hasher = DefaultHasher::new();
220        bytes.hash(&mut hasher);
221        hasher.finish()
222    }
223
224    fn normalize(&self, embedding: &mut [f32]) {
225        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
226        if norm > 0.0 {
227            for x in embedding.iter_mut() {
228                *x /= norm;
229            }
230        }
231    }
232}
233
234#[async_trait::async_trait]
235impl EmbeddingProvider for HashEmbedding {
236    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
237        self.embed_sync(text)
238    }
239
240    fn dimensions(&self) -> usize {
241        self.dimensions
242    }
243
244    fn as_any(&self) -> &dyn Any {
245        self
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a hash-backed engine producing vectors of width `dims`.
    fn hash_engine(dims: usize) -> EmbeddingEngine {
        EmbeddingEngine::new(dims)
    }

    #[tokio::test]
    async fn test_embedding_dimensions() {
        let vector = hash_engine(128).embed("test text").await.unwrap();
        assert_eq!(vector.len(), 128);
    }

    #[tokio::test]
    async fn test_embedding_consistency() {
        let engine = hash_engine(64);
        let first = engine.embed("hello world").await.unwrap();
        let second = engine.embed("hello world").await.unwrap();
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn test_embedding_similarity() {
        let engine = hash_engine(128);

        let base = engine.embed("rust programming language").await.unwrap();
        let near = engine.embed("rust programming").await.unwrap();
        let far = engine.embed("cooking recipes").await.unwrap();

        assert!(engine.similarity(&base, &near) > engine.similarity(&base, &far));
    }

    #[tokio::test]
    async fn test_normalized_embeddings() {
        let vector = hash_engine(256).embed("some text here").await.unwrap();
        let length = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn test_empty_text_error() {
        let outcome = hash_engine(64).embed("").await;
        assert!(outcome.is_err());
    }
}