ceylon_next/memory/vector/
embedding.rs

1//! Embedding provider implementations and utilities.
2
3use super::EmbeddingProvider;
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9/// A caching wrapper for embedding providers.
10///
11/// This wrapper adds a caching layer on top of any embedding provider to avoid
12/// regenerating embeddings for the same text. This is especially useful when:
13/// - The same queries or texts are embedded multiple times
14/// - Embeddings are expensive to compute (API calls, local models)
15/// - You want to reduce API costs and latency
16///
17/// # Cache Strategy
18///
19/// - Uses a simple in-memory HashMap with RwLock for thread-safety
20/// - Cache key is the exact text string
21/// - No cache eviction (unbounded cache - use with caution for large datasets)
22/// - Thread-safe for concurrent access
23///
24/// # Example
25///
26/// ```rust,no_run
27/// use ceylon_next::memory::vector::{EmbeddingProvider, CachedEmbeddings};
28/// use std::sync::Arc;
29///
30/// #[tokio::main]
31/// async fn main() {
32///     // Wrap any embedding provider with caching
33///     // let base_provider = OpenAIEmbeddings::new("api-key");
34///     // let cached = CachedEmbeddings::new(Arc::new(base_provider));
35///
36///     // First call - computes embedding
37///     // let embedding1 = cached.embed("hello world").await.unwrap();
38///
39///     // Second call - returns cached result
40///     // let embedding2 = cached.embed("hello world").await.unwrap();
41/// }
42/// ```
pub struct CachedEmbeddings {
    /// The wrapped provider that actually computes embeddings on cache misses.
    provider: Arc<dyn EmbeddingProvider>,
    /// In-memory cache keyed by the exact input text string.
    /// Unbounded — there is no eviction, so memory grows with distinct inputs.
    cache: Arc<RwLock<HashMap<String, Vec<f32>>>>,
}
49
50impl CachedEmbeddings {
51    /// Creates a new cached embedding provider.
52    ///
53    /// # Arguments
54    ///
55    /// * `provider` - The underlying embedding provider to wrap
56    pub fn new(provider: Arc<dyn EmbeddingProvider>) -> Self {
57        Self {
58            provider,
59            cache: Arc::new(RwLock::new(HashMap::new())),
60        }
61    }
62
63    /// Creates a new cached embedding provider with pre-allocated cache capacity.
64    ///
65    /// # Arguments
66    ///
67    /// * `provider` - The underlying embedding provider to wrap
68    /// * `capacity` - Initial capacity for the cache
69    pub fn with_capacity(provider: Arc<dyn EmbeddingProvider>, capacity: usize) -> Self {
70        Self {
71            provider,
72            cache: Arc::new(RwLock::new(HashMap::with_capacity(capacity))),
73        }
74    }
75
76    /// Returns the number of cached embeddings.
77    pub async fn cache_size(&self) -> usize {
78        self.cache.read().await.len()
79    }
80
81    /// Clears all cached embeddings.
82    pub async fn clear_cache(&self) {
83        self.cache.write().await.clear();
84    }
85
86    /// Checks if a text is in the cache.
87    pub async fn is_cached(&self, text: &str) -> bool {
88        self.cache.read().await.contains_key(text)
89    }
90
91    /// Preloads embeddings into the cache.
92    ///
93    /// This can be useful for warming up the cache with known queries.
94    ///
95    /// # Arguments
96    ///
97    /// * `texts` - Texts to preload into cache
98    pub async fn preload(&self, texts: &[String]) -> Result<(), String> {
99        let embeddings = self.provider.embed_batch(texts).await?;
100
101        let mut cache = self.cache.write().await;
102        for (text, embedding) in texts.iter().zip(embeddings.iter()) {
103            cache.insert(text.clone(), embedding.clone());
104        }
105
106        Ok(())
107    }
108}
109
110#[async_trait]
111impl EmbeddingProvider for CachedEmbeddings {
112    async fn embed(&self, text: &str) -> Result<Vec<f32>, String> {
113        // Check cache first
114        {
115            let cache = self.cache.read().await;
116            if let Some(embedding) = cache.get(text) {
117                return Ok(embedding.clone());
118            }
119        }
120
121        // Not in cache, compute embedding
122        let embedding = self.provider.embed(text).await?;
123
124        // Store in cache
125        {
126            let mut cache = self.cache.write().await;
127            cache.insert(text.to_string(), embedding.clone());
128        }
129
130        Ok(embedding)
131    }
132
133    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, String> {
134        let mut results = Vec::with_capacity(texts.len());
135        let mut uncached_indices = Vec::new();
136        let mut uncached_texts = Vec::new();
137
138        // Check which texts are cached
139        {
140            let cache = self.cache.read().await;
141            for (i, text) in texts.iter().enumerate() {
142                if let Some(embedding) = cache.get(text) {
143                    results.push((i, embedding.clone()));
144                } else {
145                    uncached_indices.push(i);
146                    uncached_texts.push(text.clone());
147                }
148            }
149        }
150
151        // Compute embeddings for uncached texts
152        if !uncached_texts.is_empty() {
153            let new_embeddings = self.provider.embed_batch(&uncached_texts).await?;
154
155            // Store new embeddings in cache and results
156            {
157                let mut cache = self.cache.write().await;
158                for (text, embedding) in uncached_texts.iter().zip(new_embeddings.iter()) {
159                    cache.insert(text.clone(), embedding.clone());
160                }
161            }
162
163            for (i, embedding) in uncached_indices.iter().zip(new_embeddings.iter()) {
164                results.push((*i, embedding.clone()));
165            }
166        }
167
168        // Sort by original index and extract embeddings
169        results.sort_by_key(|(i, _)| *i);
170        Ok(results.into_iter().map(|(_, emb)| emb).collect())
171    }
172
173    fn dimension(&self) -> usize {
174        self.provider.dimension()
175    }
176
177    fn model_name(&self) -> &str {
178        self.provider.model_name()
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic stand-in provider that counts how often `embed` runs.
    struct MockEmbedder {
        dimension: usize,
        call_count: Arc<RwLock<usize>>,
    }

    impl MockEmbedder {
        fn new(dimension: usize) -> Self {
            Self {
                dimension,
                call_count: Arc::new(RwLock::new(0)),
            }
        }

        /// Number of times `embed` has been invoked so far.
        async fn calls(&self) -> usize {
            *self.call_count.read().await
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockEmbedder {
        async fn embed(&self, text: &str) -> Result<Vec<f32>, String> {
            *self.call_count.write().await += 1;

            // Length-based pseudo-embedding: deterministic and cheap.
            Ok(vec![text.len() as f32; self.dimension])
        }

        fn dimension(&self) -> usize {
            self.dimension
        }

        fn model_name(&self) -> &str {
            "mock"
        }
    }

    #[tokio::test]
    async fn test_caching() {
        let backend = Arc::new(MockEmbedder::new(3));
        let wrapper = CachedEmbeddings::new(backend.clone());

        // A fresh text must go through the backend exactly once.
        let first = wrapper.embed("test").await.unwrap();
        assert_eq!(backend.calls().await, 1);
        assert_eq!(first.len(), 3);

        // Repeating the same text is served from cache: no extra call.
        let second = wrapper.embed("test").await.unwrap();
        assert_eq!(backend.calls().await, 1);
        assert_eq!(first, second);

        // A different text triggers a new computation.
        let _third = wrapper.embed("other").await.unwrap();
        assert_eq!(backend.calls().await, 2);
    }

    #[tokio::test]
    async fn test_batch_caching() {
        let backend = Arc::new(MockEmbedder::new(2));
        let wrapper = CachedEmbeddings::new(backend.clone());

        let first_batch: Vec<String> =
            ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        let first_embs = wrapper.embed_batch(&first_batch).await.unwrap();
        assert_eq!(first_embs.len(), 3);

        // Every text from the first batch should now be cached.
        assert_eq!(wrapper.cache_size().await, 3);

        // A second batch overlapping the first only adds the new text "d".
        let second_batch: Vec<String> =
            ["a", "b", "d"].iter().map(|s| s.to_string()).collect();
        let second_embs = wrapper.embed_batch(&second_batch).await.unwrap();
        assert_eq!(wrapper.cache_size().await, 4);

        // Overlapping texts must come back identical to the first batch.
        assert_eq!(first_embs[0], second_embs[0]); // "a"
        assert_eq!(first_embs[1], second_embs[1]); // "b"
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let backend = Arc::new(MockEmbedder::new(2));
        let wrapper = CachedEmbeddings::new(backend);

        // Cache starts empty.
        assert_eq!(wrapper.cache_size().await, 0);
        assert!(!wrapper.is_cached("test").await);

        // A single embed populates one entry.
        wrapper.embed("test").await.unwrap();
        assert_eq!(wrapper.cache_size().await, 1);
        assert!(wrapper.is_cached("test").await);

        // Clearing empties the cache again.
        wrapper.clear_cache().await;
        assert_eq!(wrapper.cache_size().await, 0);
        assert!(!wrapper.is_cached("test").await);
    }

    #[tokio::test]
    async fn test_preload() {
        let backend = Arc::new(MockEmbedder::new(2));
        let wrapper = CachedEmbeddings::new(backend.clone());

        let warmup: Vec<String> =
            ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        wrapper.preload(&warmup).await.unwrap();

        // All preloaded texts must now be cached.
        assert_eq!(wrapper.cache_size().await, 3);
        assert!(wrapper.is_cached("a").await);
        assert!(wrapper.is_cached("b").await);
        assert!(wrapper.is_cached("c").await);

        let baseline = backend.calls().await;

        // Embedding the preloaded texts must not hit the backend again.
        wrapper.embed("a").await.unwrap();
        wrapper.embed("b").await.unwrap();
        wrapper.embed("c").await.unwrap();
        assert_eq!(backend.calls().await, baseline);
    }
}