Skip to main content

converge_knowledge/embedding/
mod.rs

//! Embedding generation for text vectorization.
//!
//! Supports multiple embedding backends:
//! - Hash-based (default, for testing and offline use)
//! - OpenAI API (production, high-quality embeddings)
//!
//! # Choosing a Backend
//!
//! | Backend | Quality | Speed  | Cost | Offline |
//! |---------|---------|--------|------|---------|
//! | Hash    | Low     | Fast   | Free | Yes     |
//! | OpenAI  | High    | Medium | Paid | No      |
//!
//! # Example
//! ```ignore
//! use converge_knowledge::embedding::EmbeddingEngine;
//!
//! // Development/testing: use hash embeddings
//! let engine = EmbeddingEngine::new(256);
//!
//! // Production: use OpenAI (reads OPENAI_API_KEY; logs a warning and falls
//! // back to hash embeddings if OpenAI cannot be configured)
//! let engine = EmbeddingEngine::from_env();
//!
//! // Production, returning an error instead of falling back
//! let engine = EmbeddingEngine::from_env_required()?;
//!
//! // Production with explicit key
//! let engine = EmbeddingEngine::with_openai("sk-...", None);
//! ```
27
28mod openai;
29
30pub use openai::{OpenAIConfig, OpenAIEmbedding, OpenAIModel, UsageSnapshot, UsageStats};
31
32use crate::error::{Error, Result};
33use std::collections::hash_map::DefaultHasher;
34use std::hash::{Hash, Hasher};
35
36/// Embedding provider trait for different backends.
37#[async_trait::async_trait]
38pub trait EmbeddingProvider: Send + Sync {
39    /// Generate embedding for text.
40    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
41
42    /// Generate embeddings for multiple texts in batch.
43    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
44        let mut embeddings = Vec::with_capacity(texts.len());
45        for text in texts {
46            embeddings.push(self.embed(text).await?);
47        }
48        Ok(embeddings)
49    }
50
51    /// Get embedding dimensions.
52    fn dimensions(&self) -> usize;
53}
54
55/// Embedding engine for converting text to vectors.
56///
57/// Wraps different embedding providers with a unified interface.
58pub struct EmbeddingEngine {
59    provider: Box<dyn EmbeddingProvider>,
60}
61
62impl EmbeddingEngine {
63    /// Create a new embedding engine with hash-based embeddings.
64    ///
65    /// Use this for development, testing, or offline scenarios.
66    /// Hash embeddings are fast and free but lower quality.
67    pub fn new(dimensions: usize) -> Self {
68        Self {
69            provider: Box::new(HashEmbedding::new(dimensions)),
70        }
71    }
72
73    /// Create from environment variables.
74    ///
75    /// Reads OPENAI_API_KEY. Falls back to hash embeddings if not set.
76    pub fn from_env() -> Self {
77        match OpenAIEmbedding::from_env() {
78            Ok(provider) => Self {
79                provider: Box::new(provider),
80            },
81            Err(_) => {
82                tracing::warn!("OPENAI_API_KEY not set, falling back to hash embeddings");
83                Self::new(1536) // Match OpenAI default dimensions
84            }
85        }
86    }
87
88    /// Create from environment, returning error if not configured.
89    pub fn from_env_required() -> Result<Self> {
90        let provider = OpenAIEmbedding::from_env()?;
91        Ok(Self {
92            provider: Box::new(provider),
93        })
94    }
95
96    /// Create with OpenAI embeddings.
97    pub fn with_openai(api_key: impl Into<String>, model: Option<String>) -> Self {
98        Self {
99            provider: Box::new(OpenAIEmbedding::new(api_key, model)),
100        }
101    }
102
103    /// Create with OpenAI using custom configuration.
104    pub fn with_openai_config(api_key: impl Into<String>, config: OpenAIConfig) -> Self {
105        Self {
106            provider: Box::new(OpenAIEmbedding::with_config(api_key, config)),
107        }
108    }
109
110    /// Create with a custom provider.
111    pub fn with_provider(provider: Box<dyn EmbeddingProvider>) -> Self {
112        Self { provider }
113    }
114
115    /// Get the embedding dimensions.
116    pub fn dimensions(&self) -> usize {
117        self.provider.dimensions()
118    }
119
120    /// Generate an embedding for the given text (sync wrapper).
121    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
122        // For sync compatibility, use tokio's block_on if available
123        // Otherwise fall back to hash embedding
124        if let Some(hash_provider) = self.as_hash_provider() {
125            hash_provider.embed_sync(text)
126        } else {
127            // Create a new runtime for async providers
128            let rt = tokio::runtime::Handle::try_current()
129                .map(|h| h.block_on(self.provider.embed(text)))
130                .unwrap_or_else(|_| {
131                    // Fallback to hash if no runtime
132                    let hash = HashEmbedding::new(self.dimensions());
133                    hash.embed_sync(text)
134                });
135            rt
136        }
137    }
138
139    /// Generate an embedding asynchronously.
140    pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
141        self.provider.embed(text).await
142    }
143
144    /// Generate embeddings for multiple texts.
145    pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
146        self.provider.embed_batch(texts).await
147    }
148
149    /// Try to get underlying hash provider (for sync operations).
150    fn as_hash_provider(&self) -> Option<&HashEmbedding> {
151        // Use Any trait for downcasting
152        None // Simplified - in practice use Any
153    }
154
155    /// Compute similarity between two embeddings.
156    pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
157        if a.len() != b.len() {
158            return 0.0;
159        }
160
161        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
162        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
163        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
164
165        if norm_a == 0.0 || norm_b == 0.0 {
166            0.0
167        } else {
168            dot / (norm_a * norm_b)
169        }
170    }
171}
172
173/// Hash-based embedding for testing and development.
174pub struct HashEmbedding {
175    dimensions: usize,
176}
177
178impl HashEmbedding {
179    /// Create a new hash embedding engine.
180    pub fn new(dimensions: usize) -> Self {
181        Self { dimensions }
182    }
183
184    /// Synchronous embedding for hash-based provider.
185    pub fn embed_sync(&self, text: &str) -> Result<Vec<f32>> {
186        if text.is_empty() {
187            return Err(Error::embedding("Cannot embed empty text"));
188        }
189
190        let mut embedding = vec![0.0f32; self.dimensions];
191        let normalized_text = text.to_lowercase();
192
193        // Hash individual words
194        for word in normalized_text.split_whitespace() {
195            self.add_word_embedding(&mut embedding, word, 1.0);
196        }
197
198        // Hash bigrams for context
199        let words: Vec<&str> = normalized_text.split_whitespace().collect();
200        for window in words.windows(2) {
201            let bigram = format!("{} {}", window[0], window[1]);
202            self.add_word_embedding(&mut embedding, &bigram, 0.5);
203        }
204
205        // Hash trigrams for more context
206        for window in words.windows(3) {
207            let trigram = format!("{} {} {}", window[0], window[1], window[2]);
208            self.add_word_embedding(&mut embedding, &trigram, 0.3);
209        }
210
211        // Character-level features for typo tolerance
212        for word in words.iter() {
213            for char_ngram in word.as_bytes().windows(3) {
214                let hash = self.hash_bytes(char_ngram);
215                let idx = (hash as usize) % self.dimensions;
216                embedding[idx] += 0.1;
217            }
218        }
219
220        // Normalize to unit length
221        self.normalize(&mut embedding);
222
223        Ok(embedding)
224    }
225
226    fn add_word_embedding(&self, embedding: &mut [f32], text: &str, weight: f32) {
227        let hash = self.hash_text(text);
228        for i in 0..8 {
229            let idx = ((hash.wrapping_add(i * 0x9e3779b9)) as usize) % self.dimensions;
230            let sign = if (hash >> i) & 1 == 0 { 1.0 } else { -1.0 };
231            embedding[idx] += sign * weight;
232        }
233    }
234
235    fn hash_text(&self, text: &str) -> u64 {
236        let mut hasher = DefaultHasher::new();
237        text.hash(&mut hasher);
238        hasher.finish()
239    }
240
241    fn hash_bytes(&self, bytes: &[u8]) -> u64 {
242        let mut hasher = DefaultHasher::new();
243        bytes.hash(&mut hasher);
244        hasher.finish()
245    }
246
247    fn normalize(&self, embedding: &mut [f32]) {
248        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
249        if norm > 0.0 {
250            for x in embedding.iter_mut() {
251                *x /= norm;
252            }
253        }
254    }
255}
256
257#[async_trait::async_trait]
258impl EmbeddingProvider for HashEmbedding {
259    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
260        self.embed_sync(text)
261    }
262
263    fn dimensions(&self) -> usize {
264        self.dimensions
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_dimensions() {
        let engine = EmbeddingEngine::new(128);
        let embedding = engine.embed("test text").unwrap();
        assert_eq!(embedding.len(), 128);
    }

    #[test]
    fn test_engine_reports_dimensions() {
        assert_eq!(EmbeddingEngine::new(96).dimensions(), 96);
    }

    #[test]
    fn test_embedding_consistency() {
        let engine = EmbeddingEngine::new(64);
        let emb1 = engine.embed("hello world").unwrap();
        let emb2 = engine.embed("hello world").unwrap();
        assert_eq!(emb1, emb2);
    }

    #[test]
    fn test_embedding_is_case_insensitive() {
        let engine = EmbeddingEngine::new(64);
        let lower = engine.embed("hello world").unwrap();
        let mixed = engine.embed("Hello WORLD").unwrap();
        assert_eq!(lower, mixed);
    }

    #[test]
    fn test_embedding_similarity() {
        let engine = EmbeddingEngine::new(128);

        let emb1 = engine.embed("rust programming language").unwrap();
        let emb2 = engine.embed("rust programming").unwrap();
        let emb3 = engine.embed("cooking recipes").unwrap();

        let sim_similar = engine.similarity(&emb1, &emb2);
        let sim_different = engine.similarity(&emb1, &emb3);

        assert!(sim_similar > sim_different);
    }

    #[test]
    fn test_similarity_edge_cases() {
        let engine = EmbeddingEngine::new(8);

        // Mismatched lengths are treated as unrelated.
        assert_eq!(engine.similarity(&[1.0, 0.0], &[1.0]), 0.0);
        // A zero vector has no direction, so similarity is defined as 0.
        assert_eq!(engine.similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
        // Parallel vectors are maximally similar regardless of magnitude.
        assert!((engine.similarity(&[1.0, 2.0], &[2.0, 4.0]) - 1.0).abs() < 1e-6);
        // Opposite vectors are maximally dissimilar.
        assert!((engine.similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalized_embeddings() {
        let engine = EmbeddingEngine::new(256);
        let embedding = engine.embed("some text here").unwrap();

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_text_error() {
        let engine = EmbeddingEngine::new(64);
        assert!(engine.embed("").is_err());
    }
}