ceylon_next/memory/vector/
embedding.rs1use super::EmbeddingProvider;
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9pub struct CachedEmbeddings {
44 provider: Arc<dyn EmbeddingProvider>,
46 cache: Arc<RwLock<HashMap<String, Vec<f32>>>>,
48}
49
50impl CachedEmbeddings {
51 pub fn new(provider: Arc<dyn EmbeddingProvider>) -> Self {
57 Self {
58 provider,
59 cache: Arc::new(RwLock::new(HashMap::new())),
60 }
61 }
62
63 pub fn with_capacity(provider: Arc<dyn EmbeddingProvider>, capacity: usize) -> Self {
70 Self {
71 provider,
72 cache: Arc::new(RwLock::new(HashMap::with_capacity(capacity))),
73 }
74 }
75
76 pub async fn cache_size(&self) -> usize {
78 self.cache.read().await.len()
79 }
80
81 pub async fn clear_cache(&self) {
83 self.cache.write().await.clear();
84 }
85
86 pub async fn is_cached(&self, text: &str) -> bool {
88 self.cache.read().await.contains_key(text)
89 }
90
91 pub async fn preload(&self, texts: &[String]) -> Result<(), String> {
99 let embeddings = self.provider.embed_batch(texts).await?;
100
101 let mut cache = self.cache.write().await;
102 for (text, embedding) in texts.iter().zip(embeddings.iter()) {
103 cache.insert(text.clone(), embedding.clone());
104 }
105
106 Ok(())
107 }
108}
109
110#[async_trait]
111impl EmbeddingProvider for CachedEmbeddings {
112 async fn embed(&self, text: &str) -> Result<Vec<f32>, String> {
113 {
115 let cache = self.cache.read().await;
116 if let Some(embedding) = cache.get(text) {
117 return Ok(embedding.clone());
118 }
119 }
120
121 let embedding = self.provider.embed(text).await?;
123
124 {
126 let mut cache = self.cache.write().await;
127 cache.insert(text.to_string(), embedding.clone());
128 }
129
130 Ok(embedding)
131 }
132
133 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, String> {
134 let mut results = Vec::with_capacity(texts.len());
135 let mut uncached_indices = Vec::new();
136 let mut uncached_texts = Vec::new();
137
138 {
140 let cache = self.cache.read().await;
141 for (i, text) in texts.iter().enumerate() {
142 if let Some(embedding) = cache.get(text) {
143 results.push((i, embedding.clone()));
144 } else {
145 uncached_indices.push(i);
146 uncached_texts.push(text.clone());
147 }
148 }
149 }
150
151 if !uncached_texts.is_empty() {
153 let new_embeddings = self.provider.embed_batch(&uncached_texts).await?;
154
155 {
157 let mut cache = self.cache.write().await;
158 for (text, embedding) in uncached_texts.iter().zip(new_embeddings.iter()) {
159 cache.insert(text.clone(), embedding.clone());
160 }
161 }
162
163 for (i, embedding) in uncached_indices.iter().zip(new_embeddings.iter()) {
164 results.push((*i, embedding.clone()));
165 }
166 }
167
168 results.sort_by_key(|(i, _)| *i);
170 Ok(results.into_iter().map(|(_, emb)| emb).collect())
171 }
172
173 fn dimension(&self) -> usize {
174 self.provider.dimension()
175 }
176
177 fn model_name(&self) -> &str {
178 self.provider.model_name()
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Deterministic test provider: every embedding is `dimension` copies of
    /// the text's byte length, and each `embed` call is counted.
    struct MockEmbedder {
        dimension: usize,
        call_count: AtomicUsize,
    }

    impl MockEmbedder {
        fn new(dimension: usize) -> Self {
            Self {
                dimension,
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of times `embed` has been invoked on this mock.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockEmbedder {
        async fn embed(&self, text: &str) -> Result<Vec<f32>, String> {
            self.call_count.fetch_add(1, Ordering::SeqCst);

            // The text length stands in for a real embedding value.
            let value = text.len() as f32;
            Ok(vec![value; self.dimension])
        }

        fn dimension(&self) -> usize {
            self.dimension
        }

        fn model_name(&self) -> &str {
            "mock"
        }
    }

    #[tokio::test]
    async fn test_caching() {
        let mock = Arc::new(MockEmbedder::new(3));
        let cached = CachedEmbeddings::new(mock.clone());

        // First request reaches the provider.
        let emb1 = cached.embed("test").await.unwrap();
        assert_eq!(mock.calls(), 1);
        assert_eq!(emb1.len(), 3);

        // Repeating it is served from the cache.
        let emb2 = cached.embed("test").await.unwrap();
        assert_eq!(mock.calls(), 1);
        assert_eq!(emb1, emb2);

        // A different text reaches the provider again.
        let _emb3 = cached.embed("other").await.unwrap();
        assert_eq!(mock.calls(), 2);
    }

    #[tokio::test]
    async fn test_batch_caching() {
        let mock = Arc::new(MockEmbedder::new(2));
        let cached = CachedEmbeddings::new(mock.clone());

        // Texts have distinct lengths so each embedding is distinguishable and
        // ordering mistakes are detectable.
        // NOTE(review): relies on the trait's default `embed_batch` (not shown
        // in this file) producing one `embed` result per text, in order.
        let texts1 = vec!["a".to_string(), "bb".to_string(), "ccc".to_string()];
        let embs1 = cached.embed_batch(&texts1).await.unwrap();
        assert_eq!(embs1, vec![vec![1.0; 2], vec![2.0; 2], vec![3.0; 2]]);

        assert_eq!(cached.cache_size().await, 3);

        // Cached and uncached texts interleaved, in a different order.
        let texts2 = vec!["ccc".to_string(), "dddd".to_string(), "a".to_string()];
        let embs2 = cached.embed_batch(&texts2).await.unwrap();

        assert_eq!(cached.cache_size().await, 4);

        // Results follow the input order, not cache-hit order.
        assert_eq!(embs2, vec![vec![3.0; 2], vec![4.0; 2], vec![1.0; 2]]);
        assert_eq!(embs1[2], embs2[0]);
        assert_eq!(embs1[0], embs2[2]);
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let mock = Arc::new(MockEmbedder::new(2));
        let cached = CachedEmbeddings::new(mock);

        assert_eq!(cached.cache_size().await, 0);
        assert!(!cached.is_cached("test").await);

        cached.embed("test").await.unwrap();

        assert_eq!(cached.cache_size().await, 1);
        assert!(cached.is_cached("test").await);

        cached.clear_cache().await;

        assert_eq!(cached.cache_size().await, 0);
        assert!(!cached.is_cached("test").await);
    }

    #[tokio::test]
    async fn test_preload() {
        let mock = Arc::new(MockEmbedder::new(2));
        let cached = CachedEmbeddings::new(mock.clone());

        let texts = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        cached.preload(&texts).await.unwrap();

        assert_eq!(cached.cache_size().await, 3);
        for text in &texts {
            assert!(cached.is_cached(text).await);
        }

        let initial_calls = mock.calls();

        // Everything was preloaded, so none of these may reach the provider.
        for text in &texts {
            cached.embed(text).await.unwrap();
        }

        assert_eq!(mock.calls(), initial_calls);
    }
}