contrag_core/embedders/
mod.rs

1pub mod openai;
2pub mod gemini;
3pub mod http_client;
4
5use crate::error::Result;
6use crate::types::ConnectionTestResult;
7
/// Trait for embedding providers.
///
/// Implement this trait to add support for additional embedding APIs
/// (see the `openai` and `gemini` submodules for existing implementations).
#[async_trait::async_trait]
pub trait Embedder: Send + Sync {
    /// Human-readable name of this embedder (e.g. for logging/diagnostics).
    fn name(&self) -> &str;

    /// Generate embeddings for a batch of texts.
    ///
    /// Implementations are expected to return one embedding vector per
    /// input text, in the same order as `texts`.
    async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;

    /// Dimensionality of the vectors produced by [`Embedder::embed`].
    fn dimensions(&self) -> usize;

    /// Test the connection to the embedding service.
    async fn test_connection(&self) -> Result<ConnectionTestResult>;

    /// Optional: generate text from a prompt (for LLM features).
    ///
    /// The default implementation returns `Ok` with an empty string, so
    /// embedding-only providers do not need to override it; callers should
    /// treat an empty result as "not supported".
    async fn generate_with_prompt(
        &self,
        _text: String,
        _system_prompt: String,
    ) -> Result<String> {
        Ok(String::new())
    }
}
34
/// Bounded in-memory cache mapping text to its embedding, to reduce API calls.
///
/// Eviction is *arbitrary*, not LRU: when the cache is full, whichever entry
/// the `HashMap` yields first is dropped to make room. This only bounds
/// memory use; it makes no recency guarantees.
pub struct EmbeddingCache {
    // text -> embedding vector
    cache: std::collections::HashMap<String, Vec<f32>>,
    // Maximum number of entries; 0 disables caching entirely.
    max_size: usize,
}

impl EmbeddingCache {
    /// Create an empty cache holding at most `max_size` entries.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: std::collections::HashMap::new(),
            max_size,
        }
    }

    /// Return a clone of the cached embedding for `text`, if present.
    pub fn get(&self, text: &str) -> Option<Vec<f32>> {
        self.cache.get(text).cloned()
    }

    /// Insert (or update) the embedding for `text`.
    ///
    /// Updating an already-cached key never triggers eviction. When a *new*
    /// key would exceed `max_size`, one arbitrary entry is evicted first.
    /// With `max_size == 0` this is a no-op, so the cache stays empty.
    pub fn insert(&mut self, text: String, embedding: Vec<f32>) {
        if self.max_size == 0 {
            // Caching disabled; storing anything would break the size bound.
            return;
        }
        // Only evict when inserting a genuinely new key at capacity —
        // a pure update replaces the value in place.
        if !self.cache.contains_key(&text) && self.cache.len() >= self.max_size {
            // Arbitrary victim: HashMap iteration order is unspecified,
            // so this is NOT LRU — just a cap on entry count.
            if let Some(victim) = self.cache.keys().next().cloned() {
                self.cache.remove(&victim);
            }
        }
        self.cache.insert(text, embedding);
    }

    /// Remove all cached entries.
    pub fn clear(&mut self) {
        self.cache.clear();
    }
}
67
/// Embedder wrapper with caching support.
///
/// Wraps any [`Embedder`] and consults an [`EmbeddingCache`] before issuing
/// API calls, so texts already embedded (within cache capacity) are not
/// re-sent to the provider.
pub struct CachedEmbedder<E: Embedder> {
    // Underlying provider used for cache misses.
    embedder: E,
    // text -> embedding cache consulted before each provider call.
    cache: EmbeddingCache,
}
73
74impl<E: Embedder> CachedEmbedder<E> {
75    pub fn new(embedder: E, cache_size: usize) -> Self {
76        Self {
77            embedder,
78            cache: EmbeddingCache::new(cache_size),
79        }
80    }
81
82    pub async fn embed_with_cache(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
83        let mut results = vec![];
84        let mut to_embed = vec![];
85        let mut indices = vec![];
86
87        // Check cache
88        for (idx, text) in texts.iter().enumerate() {
89            if let Some(cached) = self.cache.get(text) {
90                results.push((idx, cached));
91            } else {
92                to_embed.push(text.clone());
93                indices.push(idx);
94            }
95        }
96
97        // Embed uncached texts
98        if !to_embed.is_empty() {
99            let embeddings = self.embedder.embed(to_embed.clone()).await?;
100            
101            // Cache results
102            for (text, embedding) in to_embed.iter().zip(embeddings.iter()) {
103                self.cache.insert(text.clone(), embedding.clone());
104            }
105
106            // Add to results
107            for (idx, embedding) in indices.iter().zip(embeddings) {
108                results.push((*idx, embedding));
109            }
110        }
111
112        // Sort by original index and return
113        results.sort_by_key(|(idx, _)| *idx);
114        Ok(results.into_iter().map(|(_, emb)| emb).collect())
115    }
116}