oxify_vector/embeddings.rs

//! Embedding Management
//!
//! Provides embedding generation from various providers (OpenAI, local models)
//! with caching and batch processing support.
//!
//! ## Features
//!
//! - **Multiple Providers**: OpenAI, local models, custom implementations
//! - **Caching**: TTL-based cache to reduce redundant API calls
//! - **Batch Processing**: Efficient batch embedding generation
//! - **Async Support**: Non-blocking embedding generation
//!
//! ## Example
//!
//! ```rust
//! use oxify_vector::embeddings::{EmbeddingProvider, MockEmbeddingProvider};
//!
//! # fn example() -> anyhow::Result<()> {
//! // Use mock provider for testing
//! let provider = MockEmbeddingProvider::new(384);
//!
//! let text = "Hello, world!";
//! let embedding = provider.embed(text)?;
//!
//! assert_eq!(embedding.len(), 384);
//! # Ok(())
//! # }
//! ```

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};

/// Trait for embedding providers
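///
/// Implementors only need to supply [`embed`](EmbeddingProvider::embed);
/// `embed_batch` has a default implementation that embeds texts one by one.
/// A minimal sketch using the in-crate [`MockEmbeddingProvider`]:
///
/// ```rust
/// use oxify_vector::embeddings::{EmbeddingProvider, MockEmbeddingProvider};
///
/// # fn example() -> anyhow::Result<()> {
/// let provider = MockEmbeddingProvider::new(128);
/// let embeddings = provider.embed_batch(&["hello", "world"])?;
/// assert_eq!(embeddings.len(), 2);
/// assert_eq!(embeddings[0].len(), provider.dimension());
/// # Ok(())
/// # }
/// ```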
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding for a single text
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for multiple texts (batch processing)
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        texts.iter().map(|text| self.embed(text)).collect()
    }

    /// Get the dimension of embeddings produced by this provider
    fn dimension(&self) -> usize;

    /// Get the provider name
    fn name(&self) -> &str;
}

/// Mock embedding provider for testing
///
/// Generates deterministic embeddings based on text hash
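///
/// # Example
///
/// A small sketch of the determinism guarantee: the same input text always
/// yields the same (unit-normalized) vector.
///
/// ```rust
/// use oxify_vector::embeddings::{EmbeddingProvider, MockEmbeddingProvider};
///
/// # fn example() -> anyhow::Result<()> {
/// let provider = MockEmbeddingProvider::new(64);
/// let a = provider.embed("same text")?;
/// let b = provider.embed("same text")?;
/// assert_eq!(a, b);
///
/// // Vectors are normalized, so the L2 norm is ~1.0
/// let norm: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
/// assert!((norm - 1.0).abs() < 1e-5);
/// # Ok(())
/// # }
/// ```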
#[derive(Debug, Clone)]
pub struct MockEmbeddingProvider {
    dimension: usize,
}

impl MockEmbeddingProvider {
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }

    fn hash_text(&self, text: &str) -> u64 {
        // Simple djb2-style hash (hash = hash * 33 + byte) for deterministic embeddings
        let mut hash: u64 = 5381;
        for byte in text.as_bytes() {
            hash = hash.wrapping_mul(33).wrapping_add(*byte as u64);
        }
        hash
    }
}

impl EmbeddingProvider for MockEmbeddingProvider {
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let hash = self.hash_text(text);
        let mut embedding = Vec::with_capacity(self.dimension);

        for i in 0..self.dimension {
            let val = ((hash.wrapping_add(i as u64) % 1000) as f32) / 1000.0;
            embedding.push(val);
        }

        // Normalize to unit vector
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut embedding {
                *val /= norm;
            }
        }

        Ok(embedding)
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn name(&self) -> &str {
        "mock"
    }
}

/// OpenAI embedding provider configuration
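///
/// # Example
///
/// A configuration sketch (the API key shown is a placeholder); struct-update
/// syntax with `..Default::default()` keeps the remaining fields at their
/// defaults, i.e. `endpoint: None`, which means OpenAI's own endpoint is used:
///
/// ```rust
/// use oxify_vector::embeddings::OpenAIConfig;
///
/// let config = OpenAIConfig {
///     api_key: "your-api-key".to_string(),
///     model: "text-embedding-3-large".to_string(),
///     ..Default::default()
/// };
/// assert!(config.endpoint.is_none());
/// ```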
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
    /// API key for OpenAI
    pub api_key: String,
    /// Model name (e.g., "text-embedding-ada-002", "text-embedding-3-small")
    pub model: String,
    /// API endpoint (defaults to OpenAI's endpoint)
    pub endpoint: Option<String>,
}

impl Default for OpenAIConfig {
    fn default() -> Self {
        Self {
            api_key: String::new(),
            model: "text-embedding-3-small".to_string(),
            endpoint: None,
        }
    }
}

/// OpenAI embedding provider (stub for future implementation)
///
/// Note: Real API calls require an HTTP client (e.g. the `reqwest` dependency);
/// this is a placeholder whose `embed` currently always returns an error.
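///
/// # Example
///
/// Construction already works and resolves the embedding dimension from the
/// model name; `embed` itself currently errors (see the note above). A short
/// sketch with a placeholder API key:
///
/// ```rust
/// use oxify_vector::embeddings::{EmbeddingProvider, OpenAIConfig, OpenAIEmbeddingProvider};
///
/// # fn example() -> anyhow::Result<()> {
/// let provider = OpenAIEmbeddingProvider::new(OpenAIConfig {
///     api_key: "your-api-key".to_string(),
///     model: "text-embedding-3-large".to_string(),
///     endpoint: None,
/// })?;
/// assert_eq!(provider.dimension(), 3072);
/// assert!(provider.embed("hello").is_err()); // stub: no HTTP client yet
/// # Ok(())
/// # }
/// ```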
#[derive(Debug, Clone)]
pub struct OpenAIEmbeddingProvider {
    #[allow(dead_code)]
    config: OpenAIConfig,
    dimension: usize,
}

impl OpenAIEmbeddingProvider {
    pub fn new(config: OpenAIConfig) -> Result<Self> {
        // Determine dimension based on model
        let dimension = match config.model.as_str() {
            "text-embedding-ada-002" => 1536,
            "text-embedding-3-small" => 1536,
            "text-embedding-3-large" => 3072,
            _ => anyhow::bail!("Unknown OpenAI model: {}", config.model),
        };

        Ok(Self { config, dimension })
    }
}

impl EmbeddingProvider for OpenAIEmbeddingProvider {
    fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        // Placeholder: in production this would make an HTTP request to the OpenAI API.
        // For now, return an error rather than a fake embedding.
        anyhow::bail!(
            "OpenAI provider requires HTTP client implementation (add reqwest dependency)"
        )
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn name(&self) -> &str {
        "openai"
    }
}

/// Cache entry with TTL
#[derive(Debug, Clone)]
struct CacheEntry {
    embedding: Vec<f32>,
    created_at: SystemTime,
}

/// Embedding cache with TTL
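///
/// Entries expire after the configured TTL, and the oldest entry is evicted
/// once `max_entries` is reached. A minimal usage sketch:
///
/// ```rust
/// use oxify_vector::embeddings::EmbeddingCache;
/// use std::time::Duration;
///
/// let cache = EmbeddingCache::new(Duration::from_secs(60), 1000);
/// cache.put("hello".to_string(), vec![0.1, 0.2, 0.3]);
///
/// assert_eq!(cache.size(), 1);
/// assert_eq!(cache.get("hello"), Some(vec![0.1, 0.2, 0.3]));
/// assert!(cache.get("missing").is_none());
/// ```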
#[derive(Debug, Clone)]
pub struct EmbeddingCache {
    cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
    ttl: Duration,
    max_entries: usize,
}

impl EmbeddingCache {
    pub fn new(ttl: Duration, max_entries: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            ttl,
            max_entries,
        }
    }

    /// Get embedding from cache if available and not expired
    pub fn get(&self, text: &str) -> Option<Vec<f32>> {
        let cache = self.cache.lock().unwrap();
        if let Some(entry) = cache.get(text) {
            let elapsed = SystemTime::now()
                .duration_since(entry.created_at)
                .unwrap_or(Duration::MAX);

            if elapsed < self.ttl {
                return Some(entry.embedding.clone());
            }
        }
        None
    }

    /// Store embedding in cache
    pub fn put(&self, text: String, embedding: Vec<f32>) {
        let mut cache = self.cache.lock().unwrap();

        // Evict the oldest entry if the cache is full
        if cache.len() >= self.max_entries {
            self.evict_oldest(&mut cache);
        }

        cache.insert(
            text,
            CacheEntry {
                embedding,
                created_at: SystemTime::now(),
            },
        );
    }

    /// Clear all entries from cache
    pub fn clear(&self) {
        let mut cache = self.cache.lock().unwrap();
        cache.clear();
    }

    /// Get cache size
    pub fn size(&self) -> usize {
        let cache = self.cache.lock().unwrap();
        cache.len()
    }

    fn evict_oldest(&self, cache: &mut HashMap<String, CacheEntry>) {
        if cache.is_empty() {
            return;
        }

        // Find oldest entry
        let oldest_key = cache
            .iter()
            .min_by_key(|(_, entry)| entry.created_at)
            .map(|(key, _)| key.clone());

        if let Some(key) = oldest_key {
            cache.remove(&key);
        }
    }
}

impl Default for EmbeddingCache {
    fn default() -> Self {
        Self::new(Duration::from_secs(3600), 10000) // 1 hour TTL, 10k entries
    }
}

/// Cached embedding provider
///
/// Wraps any embedding provider with caching
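///
/// # Example
///
/// A sketch wrapping the in-crate [`MockEmbeddingProvider`]; the second `embed`
/// call for the same text is served from the cache:
///
/// ```rust
/// use oxify_vector::embeddings::{
///     CachedEmbeddingProvider, EmbeddingCache, EmbeddingProvider, MockEmbeddingProvider,
/// };
/// use std::time::Duration;
///
/// # fn example() -> anyhow::Result<()> {
/// let provider = CachedEmbeddingProvider::new(
///     MockEmbeddingProvider::new(128),
///     EmbeddingCache::new(Duration::from_secs(60), 1000),
/// );
///
/// let first = provider.embed("hello")?;
/// let second = provider.embed("hello")?; // cache hit
/// assert_eq!(first, second);
/// assert_eq!(provider.cache_size(), 1);
/// # Ok(())
/// # }
/// ```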
pub struct CachedEmbeddingProvider<P: EmbeddingProvider> {
    provider: P,
    cache: EmbeddingCache,
}

impl<P: EmbeddingProvider> CachedEmbeddingProvider<P> {
    pub fn new(provider: P, cache: EmbeddingCache) -> Self {
        Self { provider, cache }
    }

    pub fn clear_cache(&self) {
        self.cache.clear();
    }

    pub fn cache_size(&self) -> usize {
        self.cache.size()
    }
}

impl<P: EmbeddingProvider> EmbeddingProvider for CachedEmbeddingProvider<P> {
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Check cache first
        if let Some(embedding) = self.cache.get(text) {
            return Ok(embedding);
        }

        // Generate embedding
        let embedding = self.provider.embed(text)?;

        // Store in cache
        self.cache.put(text.to_string(), embedding.clone());

        Ok(embedding)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut results = Vec::with_capacity(texts.len());
        let mut uncached_indices = Vec::new();
        let mut uncached_texts = Vec::new();

        // Check cache for each text
        for (i, text) in texts.iter().enumerate() {
            if let Some(embedding) = self.cache.get(text) {
                results.push(Some(embedding));
            } else {
                results.push(None);
                uncached_indices.push(i);
                uncached_texts.push(*text);
            }
        }

        // Generate embeddings for uncached texts
        if !uncached_texts.is_empty() {
            let new_embeddings = self.provider.embed_batch(&uncached_texts)?;

            for (idx, embedding) in uncached_indices.iter().zip(new_embeddings.iter()) {
                self.cache.put(texts[*idx].to_string(), embedding.clone());
                results[*idx] = Some(embedding.clone());
            }
        }

        // Unwrap all results
        results
            .into_iter()
            .map(|opt| opt.context("Missing embedding"))
            .collect()
    }

    fn dimension(&self) -> usize {
        self.provider.dimension()
    }

    fn name(&self) -> &str {
        self.provider.name()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mock_provider() {
        let provider = MockEmbeddingProvider::new(384);

        assert_eq!(provider.dimension(), 384);
        assert_eq!(provider.name(), "mock");

        let embedding = provider.embed("Hello, world!").unwrap();
        assert_eq!(embedding.len(), 384);

        // Embeddings should be normalized
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_mock_provider_deterministic() {
        let provider = MockEmbeddingProvider::new(128);

        let embedding1 = provider.embed("test").unwrap();
        let embedding2 = provider.embed("test").unwrap();

        // Same text should produce same embedding
        assert_eq!(embedding1, embedding2);
    }

    #[test]
    fn test_mock_provider_different_texts() {
        let provider = MockEmbeddingProvider::new(128);

        let embedding1 = provider.embed("hello").unwrap();
        let embedding2 = provider.embed("world").unwrap();

        // Different texts should produce different embeddings
        assert_ne!(embedding1, embedding2);
    }

    #[test]
    fn test_mock_provider_batch() {
        let provider = MockEmbeddingProvider::new(256);

        let texts = vec!["text1", "text2", "text3"];
        let embeddings = provider.embed_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert_eq!(embeddings[0].len(), 256);
        assert_eq!(embeddings[1].len(), 256);
        assert_eq!(embeddings[2].len(), 256);
    }

    #[test]
    fn test_openai_provider_creation() {
        let config = OpenAIConfig {
            api_key: "test-key".to_string(),
            model: "text-embedding-ada-002".to_string(),
            endpoint: None,
        };

        let provider = OpenAIEmbeddingProvider::new(config).unwrap();
        assert_eq!(provider.dimension(), 1536);
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_openai_provider_unknown_model() {
        let config = OpenAIConfig {
            api_key: "test-key".to_string(),
            model: "unknown-model".to_string(),
            endpoint: None,
        };

        let result = OpenAIEmbeddingProvider::new(config);
        assert!(result.is_err());
    }

    #[test]
    fn test_embedding_cache() {
        let cache = EmbeddingCache::new(Duration::from_secs(10), 100);

        // Initially empty
        assert_eq!(cache.size(), 0);
        assert!(cache.get("test").is_none());

        // Store and retrieve
        cache.put("test".to_string(), vec![1.0, 2.0, 3.0]);
        assert_eq!(cache.size(), 1);

        let embedding = cache.get("test").unwrap();
        assert_eq!(embedding, vec![1.0, 2.0, 3.0]);

        // Clear cache
        cache.clear();
        assert_eq!(cache.size(), 0);
        assert!(cache.get("test").is_none());
    }

    #[test]
    fn test_embedding_cache_max_entries() {
        let cache = EmbeddingCache::new(Duration::from_secs(10), 3);

        cache.put("key1".to_string(), vec![1.0]);
        cache.put("key2".to_string(), vec![2.0]);
        cache.put("key3".to_string(), vec![3.0]);

        assert_eq!(cache.size(), 3);

        // Adding 4th entry should evict oldest
        cache.put("key4".to_string(), vec![4.0]);
        assert_eq!(cache.size(), 3);

        // key1 should be evicted
        assert!(cache.get("key1").is_none());
        assert!(cache.get("key4").is_some());
    }

    #[test]
    fn test_cached_provider() {
        let mock_provider = MockEmbeddingProvider::new(128);
        let cache = EmbeddingCache::new(Duration::from_secs(10), 100);
        let cached_provider = CachedEmbeddingProvider::new(mock_provider, cache);

        assert_eq!(cached_provider.cache_size(), 0);

        // First call - not cached
        let embedding1 = cached_provider.embed("test").unwrap();
        assert_eq!(cached_provider.cache_size(), 1);

        // Second call - should be cached
        let embedding2 = cached_provider.embed("test").unwrap();
        assert_eq!(cached_provider.cache_size(), 1);

        assert_eq!(embedding1, embedding2);
    }

    #[test]
    fn test_cached_provider_batch() {
        let mock_provider = MockEmbeddingProvider::new(64);
        let cache = EmbeddingCache::new(Duration::from_secs(10), 100);
        let cached_provider = CachedEmbeddingProvider::new(mock_provider, cache);

        let texts = vec!["text1", "text2", "text3"];

        // First batch call - nothing cached
        let embeddings1 = cached_provider.embed_batch(&texts).unwrap();
        assert_eq!(cached_provider.cache_size(), 3);

        // Second batch call - all cached
        let embeddings2 = cached_provider.embed_batch(&texts).unwrap();
        assert_eq!(cached_provider.cache_size(), 3);

        assert_eq!(embeddings1, embeddings2);
    }

    #[test]
    fn test_cached_provider_partial_cache() {
        let mock_provider = MockEmbeddingProvider::new(32);
        let cache = EmbeddingCache::new(Duration::from_secs(10), 100);
        let cached_provider = CachedEmbeddingProvider::new(mock_provider, cache);

        // Cache some texts
        cached_provider.embed("text1").unwrap();
        cached_provider.embed("text2").unwrap();
        assert_eq!(cached_provider.cache_size(), 2);

        // Batch with mix of cached and uncached
        let texts = vec!["text1", "text2", "text3", "text4"];
        let embeddings = cached_provider.embed_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 4);
        assert_eq!(cached_provider.cache_size(), 4);
    }

    #[test]
    fn test_cache_clear() {
        let mock_provider = MockEmbeddingProvider::new(16);
        let cache = EmbeddingCache::new(Duration::from_secs(10), 100);
        let cached_provider = CachedEmbeddingProvider::new(mock_provider, cache);

        cached_provider.embed("test1").unwrap();
        cached_provider.embed("test2").unwrap();
        assert_eq!(cached_provider.cache_size(), 2);

        cached_provider.clear_cache();
        assert_eq!(cached_provider.cache_size(), 0);
    }
}