1use anyhow::{Context, Result};
31use serde::{Deserialize, Serialize};
32use std::collections::HashMap;
33use std::sync::{Arc, Mutex};
34use std::time::{Duration, SystemTime};
35
36pub trait EmbeddingProvider: Send + Sync {
38 fn embed(&self, text: &str) -> Result<Vec<f32>>;
40
41 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
43 texts.iter().map(|text| self.embed(text)).collect()
44 }
45
46 fn dimension(&self) -> usize;
48
49 fn name(&self) -> &str;
51}
52
/// Deterministic, hash-seeded embedding provider for tests and offline use.
#[derive(Debug, Clone)]
pub struct MockEmbeddingProvider {
    // Length of every vector this provider emits.
    dimension: usize,
}

impl MockEmbeddingProvider {
    /// Creates a mock provider emitting vectors of `dimension` elements.
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }

    /// djb2-style hash (seed 5381, multiplier 33) over the UTF-8 bytes of
    /// `text`; wrapping arithmetic keeps it overflow-safe.
    fn hash_text(&self, text: &str) -> u64 {
        text.as_bytes()
            .iter()
            .fold(5381u64, |acc, &b| acc.wrapping_mul(33).wrapping_add(b as u64))
    }
}
75
76impl EmbeddingProvider for MockEmbeddingProvider {
77 fn embed(&self, text: &str) -> Result<Vec<f32>> {
78 let hash = self.hash_text(text);
79 let mut embedding = Vec::with_capacity(self.dimension);
80
81 for i in 0..self.dimension {
82 let val = ((hash.wrapping_add(i as u64) % 1000) as f32) / 1000.0;
83 embedding.push(val);
84 }
85
86 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
88 if norm > 0.0 {
89 for val in &mut embedding {
90 *val /= norm;
91 }
92 }
93
94 Ok(embedding)
95 }
96
97 fn dimension(&self) -> usize {
98 self.dimension
99 }
100
101 fn name(&self) -> &str {
102 "mock"
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct OpenAIConfig {
109 pub api_key: String,
111 pub model: String,
113 pub endpoint: Option<String>,
115}
116
117impl Default for OpenAIConfig {
118 fn default() -> Self {
119 Self {
120 api_key: String::new(),
121 model: "text-embedding-3-small".to_string(),
122 endpoint: None,
123 }
124 }
125}
126
127#[derive(Debug, Clone)]
132pub struct OpenAIEmbeddingProvider {
133 #[allow(dead_code)]
134 config: OpenAIConfig,
135 dimension: usize,
136}
137
138impl OpenAIEmbeddingProvider {
139 pub fn new(config: OpenAIConfig) -> Result<Self> {
140 let dimension = match config.model.as_str() {
142 "text-embedding-ada-002" => 1536,
143 "text-embedding-3-small" => 1536,
144 "text-embedding-3-large" => 3072,
145 _ => anyhow::bail!("Unknown OpenAI model: {}", config.model),
146 };
147
148 Ok(Self { config, dimension })
149 }
150}
151
152impl EmbeddingProvider for OpenAIEmbeddingProvider {
153 fn embed(&self, _text: &str) -> Result<Vec<f32>> {
154 anyhow::bail!(
157 "OpenAI provider requires HTTP client implementation (add reqwest dependency)"
158 )
159 }
160
161 fn dimension(&self) -> usize {
162 self.dimension
163 }
164
165 fn name(&self) -> &str {
166 "openai"
167 }
168}
169
/// A cached embedding together with its insertion timestamp.
#[derive(Debug, Clone)]
struct CacheEntry {
    embedding: Vec<f32>,
    created_at: SystemTime,
}

/// Thread-safe, TTL-based in-memory cache for embeddings.
///
/// Entries expire `ttl` after insertion and the cache holds at most
/// `max_entries` entries; inserting a new key into a full cache evicts
/// the oldest entry. Clones are cheap and share the same underlying map.
#[derive(Debug, Clone)]
pub struct EmbeddingCache {
    cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
    ttl: Duration,
    max_entries: usize,
}

impl EmbeddingCache {
    /// Creates a cache with the given time-to-live and capacity.
    pub fn new(ttl: Duration, max_entries: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            ttl,
            max_entries,
        }
    }

    /// Returns a clone of the cached embedding for `text`, or `None` if it
    /// is absent or expired. Expired entries are treated as misses but stay
    /// in the map until evicted by capacity pressure.
    pub fn get(&self, text: &str) -> Option<Vec<f32>> {
        let cache = self.cache.lock().unwrap();
        if let Some(entry) = cache.get(text) {
            // duration_since fails if the clock went backwards; treat such
            // entries as expired instead of panicking.
            let elapsed = SystemTime::now()
                .duration_since(entry.created_at)
                .unwrap_or(Duration::MAX);

            if elapsed < self.ttl {
                return Some(entry.embedding.clone());
            }
        }
        None
    }

    /// Inserts (or refreshes) the embedding for `text`.
    ///
    /// Evicts the oldest entry only when inserting a NEW key into a full
    /// cache. (Bug fix: previously a refresh of an existing key in a full
    /// cache also triggered eviction, silently dropping an unrelated live
    /// entry.)
    pub fn put(&self, text: String, embedding: Vec<f32>) {
        let mut cache = self.cache.lock().unwrap();

        // Only evict when this insert would actually grow the map.
        if cache.len() >= self.max_entries && !cache.contains_key(&text) {
            self.evict_oldest(&mut cache);
        }

        cache.insert(
            text,
            CacheEntry {
                embedding,
                created_at: SystemTime::now(),
            },
        );
    }

    /// Removes all entries.
    pub fn clear(&self) {
        let mut cache = self.cache.lock().unwrap();
        cache.clear();
    }

    /// Number of entries currently stored (expired-but-not-evicted entries
    /// are included).
    pub fn size(&self) -> usize {
        let cache = self.cache.lock().unwrap();
        cache.len()
    }

    /// Evicts the entry with the oldest `created_at`. O(n) scan, which is
    /// fine for the intended cache sizes; a no-op on an empty map since
    /// `min_by_key` yields `None`.
    fn evict_oldest(&self, cache: &mut HashMap<String, CacheEntry>) {
        let oldest_key = cache
            .iter()
            .min_by_key(|(_, entry)| entry.created_at)
            .map(|(key, _)| key.clone());

        if let Some(key) = oldest_key {
            cache.remove(&key);
        }
    }
}

impl Default for EmbeddingCache {
    /// One-hour TTL, 10 000 entries.
    fn default() -> Self {
        Self::new(Duration::from_secs(3600), 10000)
    }
}
261
262pub struct CachedEmbeddingProvider<P: EmbeddingProvider> {
266 provider: P,
267 cache: EmbeddingCache,
268}
269
270impl<P: EmbeddingProvider> CachedEmbeddingProvider<P> {
271 pub fn new(provider: P, cache: EmbeddingCache) -> Self {
272 Self { provider, cache }
273 }
274
275 pub fn clear_cache(&self) {
276 self.cache.clear();
277 }
278
279 pub fn cache_size(&self) -> usize {
280 self.cache.size()
281 }
282}
283
284impl<P: EmbeddingProvider> EmbeddingProvider for CachedEmbeddingProvider<P> {
285 fn embed(&self, text: &str) -> Result<Vec<f32>> {
286 if let Some(embedding) = self.cache.get(text) {
288 return Ok(embedding);
289 }
290
291 let embedding = self.provider.embed(text)?;
293
294 self.cache.put(text.to_string(), embedding.clone());
296
297 Ok(embedding)
298 }
299
300 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
301 let mut results = Vec::with_capacity(texts.len());
302 let mut uncached_indices = Vec::new();
303 let mut uncached_texts = Vec::new();
304
305 for (i, text) in texts.iter().enumerate() {
307 if let Some(embedding) = self.cache.get(text) {
308 results.push(Some(embedding));
309 } else {
310 results.push(None);
311 uncached_indices.push(i);
312 uncached_texts.push(*text);
313 }
314 }
315
316 if !uncached_texts.is_empty() {
318 let new_embeddings = self.provider.embed_batch(&uncached_texts)?;
319
320 for (idx, embedding) in uncached_indices.iter().zip(new_embeddings.iter()) {
321 self.cache.put(texts[*idx].to_string(), embedding.clone());
322 results[*idx] = Some(embedding.clone());
323 }
324 }
325
326 results
328 .into_iter()
329 .map(|opt| opt.context("Missing embedding"))
330 .collect()
331 }
332
333 fn dimension(&self) -> usize {
334 self.provider.dimension()
335 }
336
337 fn name(&self) -> &str {
338 self.provider.name()
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    // Mock provider reports its configured dimension and name, and
    // embed() returns a vector of that length.
    #[test]
    fn test_mock_provider() {
        let provider = MockEmbeddingProvider::new(384);

        assert_eq!(provider.dimension(), 384);
        assert_eq!(provider.name(), "mock");

        let embedding = provider.embed("Hello, world!").unwrap();
        assert_eq!(embedding.len(), 384);

        // embed() L2-normalizes its output, so the norm should be ~1.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    // Same input text must always produce the same embedding (the mock
    // is seeded by a deterministic hash of the text).
    #[test]
    fn test_mock_provider_deterministic() {
        let provider = MockEmbeddingProvider::new(128);

        let embedding1 = provider.embed("test").unwrap();
        let embedding2 = provider.embed("test").unwrap();

        assert_eq!(embedding1, embedding2);
    }

    // Different texts hash differently here, so their embeddings differ.
    #[test]
    fn test_mock_provider_different_texts() {
        let provider = MockEmbeddingProvider::new(128);

        let embedding1 = provider.embed("hello").unwrap();
        let embedding2 = provider.embed("world").unwrap();

        assert_ne!(embedding1, embedding2);
    }

    // The trait's default embed_batch maps embed() over every input.
    #[test]
    fn test_mock_provider_batch() {
        let provider = MockEmbeddingProvider::new(256);

        let texts = vec!["text1", "text2", "text3"];
        let embeddings = provider.embed_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert_eq!(embeddings[0].len(), 256);
        assert_eq!(embeddings[1].len(), 256);
        assert_eq!(embeddings[2].len(), 256);
    }

    // Known OpenAI models resolve to their expected dimensions.
    #[test]
    fn test_openai_provider_creation() {
        let config = OpenAIConfig {
            api_key: "test-key".to_string(),
            model: "text-embedding-ada-002".to_string(),
            endpoint: None,
        };

        let provider = OpenAIEmbeddingProvider::new(config).unwrap();
        assert_eq!(provider.dimension(), 1536);
        assert_eq!(provider.name(), "openai");
    }

    // Construction fails fast for unrecognized model names.
    #[test]
    fn test_openai_provider_unknown_model() {
        let config = OpenAIConfig {
            api_key: "test-key".to_string(),
            model: "unknown-model".to_string(),
            endpoint: None,
        };

        let result = OpenAIEmbeddingProvider::new(config);
        assert!(result.is_err());
    }

    // Basic put/get/clear round-trip on the cache.
    #[test]
    fn test_embedding_cache() {
        let cache = EmbeddingCache::new(Duration::from_secs(10), 100);

        assert_eq!(cache.size(), 0);
        assert!(cache.get("test").is_none());

        cache.put("test".to_string(), vec![1.0, 2.0, 3.0]);
        assert_eq!(cache.size(), 1);

        let embedding = cache.get("test").unwrap();
        assert_eq!(embedding, vec![1.0, 2.0, 3.0]);

        cache.clear();
        assert_eq!(cache.size(), 0);
        assert!(cache.get("test").is_none());
    }

    // Inserting a new key into a full cache evicts the oldest entry so
    // the cache never exceeds max_entries.
    #[test]
    fn test_embedding_cache_max_entries() {
        let cache = EmbeddingCache::new(Duration::from_secs(10), 3);

        cache.put("key1".to_string(), vec![1.0]);
        cache.put("key2".to_string(), vec![2.0]);
        cache.put("key3".to_string(), vec![3.0]);

        assert_eq!(cache.size(), 3);

        cache.put("key4".to_string(), vec![4.0]);
        assert_eq!(cache.size(), 3);

        // key1 was the oldest entry, so it should be the one evicted.
        assert!(cache.get("key1").is_none());
        assert!(cache.get("key4").is_some());
    }

    // Second embed() of the same text is served from the cache and does
    // not grow it; both calls return the same vector.
    #[test]
    fn test_cached_provider() {
        let mock_provider = MockEmbeddingProvider::new(128);
        let cache = EmbeddingCache::new(Duration::from_secs(10), 100);
        let cached_provider = CachedEmbeddingProvider::new(mock_provider, cache);

        assert_eq!(cached_provider.cache_size(), 0);

        let embedding1 = cached_provider.embed("test").unwrap();
        assert_eq!(cached_provider.cache_size(), 1);

        let embedding2 = cached_provider.embed("test").unwrap();
        assert_eq!(cached_provider.cache_size(), 1);

        assert_eq!(embedding1, embedding2);
    }

    // Batch embedding populates the cache; a repeat batch hits it and
    // returns identical results.
    #[test]
    fn test_cached_provider_batch() {
        let mock_provider = MockEmbeddingProvider::new(64);
        let cache = EmbeddingCache::new(Duration::from_secs(10), 100);
        let cached_provider = CachedEmbeddingProvider::new(mock_provider, cache);

        let texts = vec!["text1", "text2", "text3"];

        let embeddings1 = cached_provider.embed_batch(&texts).unwrap();
        assert_eq!(cached_provider.cache_size(), 3);

        let embeddings2 = cached_provider.embed_batch(&texts).unwrap();
        assert_eq!(cached_provider.cache_size(), 3);

        assert_eq!(embeddings1, embeddings2);
    }

    // A batch mixing cached and new texts computes only the missing ones
    // and still returns all four embeddings in input order.
    #[test]
    fn test_cached_provider_partial_cache() {
        let mock_provider = MockEmbeddingProvider::new(32);
        let cache = EmbeddingCache::new(Duration::from_secs(10), 100);
        let cached_provider = CachedEmbeddingProvider::new(mock_provider, cache);

        cached_provider.embed("text1").unwrap();
        cached_provider.embed("text2").unwrap();
        assert_eq!(cached_provider.cache_size(), 2);

        let texts = vec!["text1", "text2", "text3", "text4"];
        let embeddings = cached_provider.embed_batch(&texts).unwrap();

        assert_eq!(embeddings.len(), 4);
        assert_eq!(cached_provider.cache_size(), 4);
    }

    // clear_cache() empties the underlying cache.
    #[test]
    fn test_cache_clear() {
        let mock_provider = MockEmbeddingProvider::new(16);
        let cache = EmbeddingCache::new(Duration::from_secs(10), 100);
        let cached_provider = CachedEmbeddingProvider::new(mock_provider, cache);

        cached_provider.embed("test1").unwrap();
        cached_provider.embed("test2").unwrap();
        assert_eq!(cached_provider.cache_size(), 2);

        cached_provider.clear_cache();
        assert_eq!(cached_provider.cache_size(), 0);
    }
}