Skip to main content

rucora_embed/
cache.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use async_trait::async_trait;
5use rucora_core::{embed::EmbeddingProvider, error::ProviderError};
6
7// EmbeddingProvider 的简单内存缓存包装器:
8// - 对单条 `embed`:以 `text` 作为 key 缓存向量结果
9// - 对批量 `embed_batch`:优先从缓存命中,未命中的部分再交给 inner 计算并回填缓存
10// - 使用 `Mutex<HashMap<..>>` 做线程安全保护;锁获取失败时会退化为不读/不写缓存
11pub struct CachedEmbeddingProvider<P> {
12    // 实际执行 embedding 的底层 provider
13    inner: Arc<P>,
14    // 进程内缓存(key 为原始文本,value 为 embedding 向量)
15    cache: Mutex<HashMap<String, Vec<f32>>>,
16}
17
18impl<P> CachedEmbeddingProvider<P> {
19    // 通过传入一个 provider 创建缓存包装器(内部会包一层 Arc 便于 clone)
20    pub fn new(inner: P) -> Self {
21        Self {
22            inner: Arc::new(inner),
23            cache: Mutex::new(HashMap::new()),
24        }
25    }
26
27    // 直接用已有的 Arc provider 创建缓存包装器
28    pub fn new_arc(inner: Arc<P>) -> Self {
29        Self {
30            inner,
31            cache: Mutex::new(HashMap::new()),
32        }
33    }
34
35    // 获取底层 provider 的 Arc(用于外部复用/共享同一个 provider)
36    pub fn inner(&self) -> Arc<P> {
37        self.inner.clone()
38    }
39
40    // 若底层 provider 声明了固定维度,则对返回向量做维度校验,避免混入错误数据到缓存。
41    fn validate_dim(&self, v: &[f32]) -> Result<(), ProviderError>
42    where
43        P: EmbeddingProvider,
44    {
45        if let Some(dim) = self.inner.embedding_dim()
46            && v.len() != dim
47        {
48            return Err(ProviderError::Message(format!(
49                "embedding_dim 校验失败:expected={} got={}",
50                dim,
51                v.len()
52            )));
53        }
54        Ok(())
55    }
56}
57
58#[async_trait]
59impl<P> EmbeddingProvider for CachedEmbeddingProvider<P>
60where
61    P: EmbeddingProvider,
62{
63    async fn embed(&self, text: &str) -> Result<Vec<f32>, ProviderError> {
64        // 先尝试从缓存读取;如果锁获取失败,则直接跳过缓存读取(保证功能可用)。
65        if let Ok(cache) = self.cache.lock()
66            && let Some(v) = cache.get(text)
67        {
68            return Ok(v.clone());
69        }
70
71        // 未命中则调用底层 provider 计算
72        let v = self.inner.embed(text).await?;
73        // 写入缓存前先校验维度,避免缓存不合法向量
74        self.validate_dim(&v)?;
75
76        // 尝试写回缓存;锁获取失败时静默跳过(不影响返回结果)
77        if let Ok(mut cache) = self.cache.lock() {
78            cache.insert(text.to_string(), v.clone());
79        }
80
81        Ok(v)
82    }
83
84    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, ProviderError> {
85        if texts.is_empty() {
86            return Ok(Vec::new());
87        }
88
89        // 输出向量与输入 texts 一一对应:out[i] 对应 texts[i]
90        let mut out: Vec<Vec<f32>> = vec![Vec::new(); texts.len()];
91        // missing_* 用于记录“未命中缓存”的文本及其在 out 中的位置,稍后批量补齐
92        let mut missing_texts: Vec<String> = Vec::new();
93        let mut missing_pos: Vec<usize> = Vec::new();
94
95        // 尽量在一次加锁期间完成所有读取,降低锁竞争
96        if let Ok(cache) = self.cache.lock() {
97            for (i, t) in texts.iter().enumerate() {
98                if let Some(v) = cache.get(t) {
99                    out[i] = v.clone();
100                } else {
101                    missing_texts.push(t.clone());
102                    missing_pos.push(i);
103                }
104            }
105        } else {
106            // 锁获取失败则视为全部未命中:直接走底层批量计算
107            missing_texts.extend_from_slice(texts);
108            missing_pos.extend(0..texts.len());
109        }
110
111        if !missing_texts.is_empty() {
112            // 仅对缺失部分调用底层 provider,减少重复计算
113            let got = self.inner.embed_batch(&missing_texts).await?;
114            if got.len() != missing_texts.len() {
115                return Err(ProviderError::Message(
116                    "embed_batch 返回的向量数量与输入不一致".to_string(),
117                ));
118            }
119
120            for (j, v) in got.into_iter().enumerate() {
121                self.validate_dim(&v)?;
122                // 将缺失项回填到对应位置,保证输出顺序与输入一致
123                let pos = missing_pos[j];
124                out[pos] = v.clone();
125                // 同时回写缓存(失败则跳过,不影响返回)
126                if let Ok(mut cache) = self.cache.lock() {
127                    cache.insert(missing_texts[j].clone(), v);
128                }
129            }
130        }
131
132        Ok(out)
133    }
134
135    fn embedding_dim(&self) -> Option<usize> {
136        // 缓存层不改变维度信息,直接透传
137        self.inner.embedding_dim()
138    }
139}