1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5use rucora_core::{embed::EmbeddingProvider, error::ProviderError};
6
/// In-memory store used by [`CachedEmbeddingProvider`]: input text mapped to an embedding vector.
type EmbeddingCache = HashMap<String, Vec<f32>>;

/// Wrapper that pairs a provider of type `P` with an in-memory embedding cache.
///
/// The cache is keyed by the exact input text and sits behind a [`Mutex`], so
/// it can be updated through a shared reference.
pub struct CachedEmbeddingProvider<P> {
    /// The wrapped provider, held behind an [`Arc`] so handles to it can be shared.
    inner: Arc<P>,
    /// Embedding vectors stored per input text.
    cache: Mutex<EmbeddingCache>,
}
17
18impl<P> CachedEmbeddingProvider<P> {
19 pub fn new(inner: P) -> Self {
21 Self {
22 inner: Arc::new(inner),
23 cache: Mutex::new(HashMap::new()),
24 }
25 }
26
27 pub fn new_arc(inner: Arc<P>) -> Self {
29 Self {
30 inner,
31 cache: Mutex::new(HashMap::new()),
32 }
33 }
34
35 pub fn inner(&self) -> Arc<P> {
37 self.inner.clone()
38 }
39
40 fn validate_dim(&self, v: &[f32]) -> Result<(), ProviderError>
42 where
43 P: EmbeddingProvider,
44 {
45 if let Some(dim) = self.inner.embedding_dim()
46 && v.len() != dim
47 {
48 return Err(ProviderError::Message(format!(
49 "embedding_dim 校验失败:expected={} got={}",
50 dim,
51 v.len()
52 )));
53 }
54 Ok(())
55 }
56}
57
58#[async_trait]
59impl<P> EmbeddingProvider for CachedEmbeddingProvider<P>
60where
61 P: EmbeddingProvider,
62{
63 async fn embed(&self, text: &str) -> Result<Vec<f32>, ProviderError> {
64 if let Ok(cache) = self.cache.lock()
66 && let Some(v) = cache.get(text)
67 {
68 return Ok(v.clone());
69 }
70
71 let v = self.inner.embed(text).await?;
73 self.validate_dim(&v)?;
75
76 if let Ok(mut cache) = self.cache.lock() {
78 cache.insert(text.to_string(), v.clone());
79 }
80
81 Ok(v)
82 }
83
84 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError> {
85 if texts.is_empty() {
86 return Ok(Vec::new());
87 }
88
89 let mut out: Vec<Vec<f32>> = vec![Vec::new(); texts.len()];
91 let mut missing_texts: Vec<String> = Vec::new();
93 let mut missing_pos: Vec<usize> = Vec::new();
94
95 if let Ok(cache) = self.cache.lock() {
97 for (i, t) in texts.iter().enumerate() {
98 if let Some(v) = cache.get(t) {
99 out[i] = v.clone();
100 } else {
101 missing_texts.push(t.clone());
102 missing_pos.push(i);
103 }
104 }
105 } else {
106 missing_texts.extend_from_slice(texts);
108 missing_pos.extend(0..texts.len());
109 }
110
111 if !missing_texts.is_empty() {
112 let got = self.inner.embed_batch(&missing_texts).await?;
114 if got.len() != missing_texts.len() {
115 return Err(ProviderError::Message(
116 "embed_batch 返回的向量数量与输入不一致".to_string(),
117 ));
118 }
119
120 for (j, v) in got.into_iter().enumerate() {
121 self.validate_dim(&v)?;
122 let pos = missing_pos[j];
124 out[pos] = v.clone();
125 if let Ok(mut cache) = self.cache.lock() {
127 cache.insert(missing_texts[j].clone(), v);
128 }
129 }
130 }
131
132 Ok(out)
133 }
134
135 fn embedding_dim(&self) -> Option<usize> {
136 self.inner.embedding_dim()
138 }
139}