contrag_core/embedders/
mod.rs1pub mod openai;
2pub mod gemini;
3pub mod http_client;
4
5use crate::error::Result;
6use crate::types::ConnectionTestResult;
7
8#[async_trait::async_trait]
12pub trait Embedder: Send + Sync {
13 fn name(&self) -> &str;
15
16 async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;
18
19 fn dimensions(&self) -> usize;
21
22 async fn test_connection(&self) -> Result<ConnectionTestResult>;
24
25 async fn generate_with_prompt(
27 &self,
28 _text: String,
29 _system_prompt: String,
30 ) -> Result<String> {
31 Ok(String::new())
32 }
33}
34
/// In-memory cache mapping an input text to its embedding vector.
///
/// The cache never holds more than `max_size` entries. When a new key arrives
/// and the cache is full, one existing entry is evicted. The victim is
/// whichever key the underlying `HashMap` yields first, which is unspecified,
/// so this is neither an LRU nor a FIFO cache.
pub struct EmbeddingCache {
    cache: std::collections::HashMap<String, Vec<f32>>,
    max_size: usize,
}

impl EmbeddingCache {
    /// Creates an empty cache that holds at most `max_size` entries.
    ///
    /// A `max_size` of `0` disables caching: every [`EmbeddingCache::insert`]
    /// becomes a no-op.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: std::collections::HashMap::new(),
            max_size,
        }
    }

    /// Returns a clone of the embedding stored for `text`, or `None` if the
    /// text is not cached.
    pub fn get(&self, text: &str) -> Option<Vec<f32>> {
        self.cache.get(text).cloned()
    }

    /// Stores `embedding` under `text`, replacing any previous value.
    ///
    /// If `text` is a new key and the cache is already full, one arbitrary
    /// entry is evicted first. Replacing the value of an existing key never
    /// evicts anything, because the map does not grow.
    pub fn insert(&mut self, text: String, embedding: Vec<f32>) {
        // A zero-capacity cache must stay empty; without this guard the
        // eviction below finds nothing to remove and one entry slips in.
        if self.max_size == 0 {
            return;
        }

        // Only a genuinely new key can push the map past its limit.
        if self.cache.len() >= self.max_size && !self.cache.contains_key(&text) {
            if let Some(victim) = self.cache.keys().next().cloned() {
                self.cache.remove(&victim);
            }
        }

        self.cache.insert(text, embedding);
    }

    /// Removes every entry from the cache.
    pub fn clear(&mut self) {
        self.cache.clear();
    }
}
67
68pub struct CachedEmbedder<E: Embedder> {
70 embedder: E,
71 cache: EmbeddingCache,
72}
73
74impl<E: Embedder> CachedEmbedder<E> {
75 pub fn new(embedder: E, cache_size: usize) -> Self {
76 Self {
77 embedder,
78 cache: EmbeddingCache::new(cache_size),
79 }
80 }
81
82 pub async fn embed_with_cache(&mut self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
83 let mut results = vec![];
84 let mut to_embed = vec![];
85 let mut indices = vec![];
86
87 for (idx, text) in texts.iter().enumerate() {
89 if let Some(cached) = self.cache.get(text) {
90 results.push((idx, cached));
91 } else {
92 to_embed.push(text.clone());
93 indices.push(idx);
94 }
95 }
96
97 if !to_embed.is_empty() {
99 let embeddings = self.embedder.embed(to_embed.clone()).await?;
100
101 for (text, embedding) in to_embed.iter().zip(embeddings.iter()) {
103 self.cache.insert(text.clone(), embedding.clone());
104 }
105
106 for (idx, embedding) in indices.iter().zip(embeddings) {
108 results.push((*idx, embedding));
109 }
110 }
111
112 results.sort_by_key(|(idx, _)| *idx);
114 Ok(results.into_iter().map(|(_, emb)| emb).collect())
115 }
116}