Skip to main content

converge_knowledge/embedding/
mod.rs

//! Embedding generation for text vectorization.
//!
//! Supports multiple embedding backends:
//! - Hash-based (default, for testing and offline use)
//! - OpenAI API (production, high-quality embeddings)
//!
//! # Choosing a Backend
//!
//! | Backend | Quality | Speed  | Cost | Offline |
//! |---------|---------|--------|------|---------|
//! | Hash    | Low     | Fast   | Free | Yes     |
//! | OpenAI  | High    | Medium | Paid | No      |
//!
//! # Example
//! ```ignore
//! use converge_knowledge::embedding::EmbeddingEngine;
//!
//! // Development/testing: use hash embeddings
//! let engine = EmbeddingEngine::new(256);
//!
//! // Production: use OpenAI (reads OPENAI_API_KEY, errors if it is not configured)
//! let engine = EmbeddingEngine::from_env_required()?;
//!
//! // Best effort: OpenAI when OPENAI_API_KEY is set, hash embeddings otherwise
//! let engine = EmbeddingEngine::from_env();
//!
//! // Production with explicit key
//! let engine = EmbeddingEngine::with_openai("sk-...", None);
//! ```
27
28mod openai;
29
30pub use openai::{OpenAIConfig, OpenAIEmbedding, OpenAIModel, UsageSnapshot, UsageStats};
31
32use crate::error::{Error, Result};
33use std::any::Any;
34use std::collections::hash_map::DefaultHasher;
35use std::hash::{Hash, Hasher};
36
37/// Embedding provider trait for different backends.
38#[async_trait::async_trait]
39pub trait EmbeddingProvider: Any + Send + Sync {
40    /// Generate embedding for text.
41    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
42
43    /// Generate embeddings for multiple texts in batch.
44    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
45        let mut embeddings = Vec::with_capacity(texts.len());
46        for text in texts {
47            embeddings.push(self.embed(text).await?);
48        }
49        Ok(embeddings)
50    }
51
52    /// Get embedding dimensions.
53    fn dimensions(&self) -> usize;
54
55    /// Downcast support for provider-specific fast paths.
56    fn as_any(&self) -> &dyn Any;
57}
58
59/// Embedding engine for converting text to vectors.
60///
61/// Wraps different embedding providers with a unified interface.
62pub struct EmbeddingEngine {
63    provider: Box<dyn EmbeddingProvider>,
64}
65
66impl EmbeddingEngine {
67    /// Create a new embedding engine with hash-based embeddings.
68    ///
69    /// Use this for development, testing, or offline scenarios.
70    /// Hash embeddings are fast and free but lower quality.
71    pub fn new(dimensions: usize) -> Self {
72        Self {
73            provider: Box::new(HashEmbedding::new(dimensions)),
74        }
75    }
76
77    /// Create from environment variables.
78    ///
79    /// Reads OPENAI_API_KEY. Falls back to hash embeddings if not set.
80    pub fn from_env() -> Self {
81        match OpenAIEmbedding::from_env() {
82            Ok(provider) => Self {
83                provider: Box::new(provider),
84            },
85            Err(_) => {
86                tracing::warn!("OPENAI_API_KEY not set, falling back to hash embeddings");
87                Self::new(1536) // Match OpenAI default dimensions
88            }
89        }
90    }
91
92    /// Create from environment, returning error if not configured.
93    pub fn from_env_required() -> Result<Self> {
94        let provider = OpenAIEmbedding::from_env()?;
95        Ok(Self {
96            provider: Box::new(provider),
97        })
98    }
99
100    /// Create with OpenAI embeddings.
101    pub fn with_openai(api_key: impl Into<String>, model: Option<String>) -> Self {
102        Self {
103            provider: Box::new(OpenAIEmbedding::new(api_key, model)),
104        }
105    }
106
107    /// Create with OpenAI using custom configuration.
108    pub fn with_openai_config(api_key: impl Into<String>, config: OpenAIConfig) -> Self {
109        Self {
110            provider: Box::new(OpenAIEmbedding::with_config(api_key, config)),
111        }
112    }
113
114    /// Create with a custom provider.
115    pub fn with_provider(provider: Box<dyn EmbeddingProvider>) -> Self {
116        Self { provider }
117    }
118
119    /// Get the embedding dimensions.
120    pub fn dimensions(&self) -> usize {
121        self.provider.dimensions()
122    }
123
124    /// Generate an embedding for the given text (sync wrapper).
125    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
126        // For sync compatibility, use tokio's block_on if available
127        // Otherwise fall back to hash embedding
128        if let Some(hash_provider) = self.as_hash_provider() {
129            hash_provider.embed_sync(text)
130        } else {
131            // Create a new runtime for async providers
132
133            tokio::runtime::Handle::try_current()
134                .map(|h| h.block_on(self.provider.embed(text)))
135                .unwrap_or_else(|_| {
136                    // Fallback to hash if no runtime
137                    let hash = HashEmbedding::new(self.dimensions());
138                    hash.embed_sync(text)
139                })
140        }
141    }
142
143    /// Generate an embedding asynchronously.
144    pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
145        self.provider.embed(text).await
146    }
147
148    /// Generate embeddings for multiple texts.
149    pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
150        self.provider.embed_batch(texts).await
151    }
152
153    /// Try to get underlying hash provider (for sync operations).
154    fn as_hash_provider(&self) -> Option<&HashEmbedding> {
155        self.provider.as_any().downcast_ref::<HashEmbedding>()
156    }
157
158    /// Compute similarity between two embeddings.
159    pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
160        if a.len() != b.len() {
161            return 0.0;
162        }
163
164        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
165        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
166        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
167
168        if norm_a == 0.0 || norm_b == 0.0 {
169            0.0
170        } else {
171            dot / (norm_a * norm_b)
172        }
173    }
174}
175
176/// Hash-based embedding for testing and development.
177pub struct HashEmbedding {
178    dimensions: usize,
179}
180
181impl HashEmbedding {
182    /// Create a new hash embedding engine.
183    pub fn new(dimensions: usize) -> Self {
184        Self { dimensions }
185    }
186
187    /// Synchronous embedding for hash-based provider.
188    pub fn embed_sync(&self, text: &str) -> Result<Vec<f32>> {
189        if text.is_empty() {
190            return Err(Error::embedding("Cannot embed empty text"));
191        }
192
193        let mut embedding = vec![0.0f32; self.dimensions];
194        let normalized_text = text.to_lowercase();
195
196        // Hash individual words
197        for word in normalized_text.split_whitespace() {
198            self.add_word_embedding(&mut embedding, word, 1.0);
199        }
200
201        // Hash bigrams for context
202        let words: Vec<&str> = normalized_text.split_whitespace().collect();
203        for window in words.windows(2) {
204            let bigram = format!("{} {}", window[0], window[1]);
205            self.add_word_embedding(&mut embedding, &bigram, 0.5);
206        }
207
208        // Hash trigrams for more context
209        for window in words.windows(3) {
210            let trigram = format!("{} {} {}", window[0], window[1], window[2]);
211            self.add_word_embedding(&mut embedding, &trigram, 0.3);
212        }
213
214        // Character-level features for typo tolerance
215        for word in words.iter() {
216            for char_ngram in word.as_bytes().windows(3) {
217                let hash = self.hash_bytes(char_ngram);
218                let idx = (hash as usize) % self.dimensions;
219                embedding[idx] += 0.1;
220            }
221        }
222
223        // Normalize to unit length
224        self.normalize(&mut embedding);
225
226        Ok(embedding)
227    }
228
229    fn add_word_embedding(&self, embedding: &mut [f32], text: &str, weight: f32) {
230        let hash = self.hash_text(text);
231        for i in 0..8 {
232            let idx = ((hash.wrapping_add(i * 0x9e3779b9)) as usize) % self.dimensions;
233            let sign = if (hash >> i) & 1 == 0 { 1.0 } else { -1.0 };
234            embedding[idx] += sign * weight;
235        }
236    }
237
238    fn hash_text(&self, text: &str) -> u64 {
239        let mut hasher = DefaultHasher::new();
240        text.hash(&mut hasher);
241        hasher.finish()
242    }
243
244    fn hash_bytes(&self, bytes: &[u8]) -> u64 {
245        let mut hasher = DefaultHasher::new();
246        bytes.hash(&mut hasher);
247        hasher.finish()
248    }
249
250    fn normalize(&self, embedding: &mut [f32]) {
251        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
252        if norm > 0.0 {
253            for x in embedding.iter_mut() {
254                *x /= norm;
255            }
256        }
257    }
258}
259
260#[async_trait::async_trait]
261impl EmbeddingProvider for HashEmbedding {
262    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
263        self.embed_sync(text)
264    }
265
266    fn dimensions(&self) -> usize {
267        self.dimensions
268    }
269
270    fn as_any(&self) -> &dyn Any {
271        self
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// The hash engine returns vectors of exactly the requested length.
    #[test]
    fn test_embedding_dimensions() {
        let vector = EmbeddingEngine::new(128).embed("test text").unwrap();
        assert_eq!(vector.len(), 128);
    }

    /// Embedding the same text twice yields identical vectors.
    #[test]
    fn test_embedding_consistency() {
        let engine = EmbeddingEngine::new(64);
        let first = engine.embed("hello world").unwrap();
        let second = engine.embed("hello world").unwrap();
        assert_eq!(first, second);
    }

    /// Texts sharing words score higher than unrelated texts.
    #[test]
    fn test_embedding_similarity() {
        let engine = EmbeddingEngine::new(128);

        let base = engine.embed("rust programming language").unwrap();
        let related = engine.embed("rust programming").unwrap();
        let unrelated = engine.embed("cooking recipes").unwrap();

        let sim_similar = engine.similarity(&base, &related);
        let sim_different = engine.similarity(&base, &unrelated);

        assert!(sim_similar > sim_different);
    }

    /// Hash embeddings come back with unit L2 norm.
    #[test]
    fn test_normalized_embeddings() {
        let vector = EmbeddingEngine::new(256).embed("some text here").unwrap();

        let squared_sum: f32 = vector.iter().map(|x| x * x).sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    /// Empty input is rejected with an error.
    #[test]
    fn test_empty_text_error() {
        let outcome = EmbeddingEngine::new(64).embed("");
        assert!(outcome.is_err());
    }
}