// codelens_engine/embedding/engine_impl/query_cache.rs

1use anyhow::{Context, Result};
2use sha2::{Digest, Sha256};
3use std::collections::HashMap;
4
5use super::super::{EmbeddingEngine, QueryEmbeddingCacheStats};
6
7impl EmbeddingEngine {
8    pub(crate) fn configured_query_embed_cache_size() -> usize {
9        std::env::var("CODELENS_QUERY_EMBED_CACHE_SIZE")
10            .ok()
11            .and_then(|value| value.trim().parse::<usize>().ok())
12            .unwrap_or(4096)
13            .min(50_000)
14    }
15
16    pub(crate) fn normalize_query_for_cache(query: &str) -> String {
17        query.split_whitespace().collect::<Vec<_>>().join(" ")
18    }
19
20    pub(crate) fn query_cache_key(&self, query: &str) -> String {
21        let normalized = Self::normalize_query_for_cache(query);
22        let mut hasher = Sha256::new();
23        hasher.update(b"cache-v1\n");
24        hasher.update(self.model_name.as_bytes());
25        hasher.update(b"\n");
26        hasher.update(self.runtime_info.backend.as_bytes());
27        hasher.update(b"\n");
28        hasher.update(self.runtime_info.max_length.to_string().as_bytes());
29        hasher.update(b"\n");
30        hasher.update(normalized.as_bytes());
31        format!("{:x}", hasher.finalize())
32    }
33
34    pub(crate) fn embed_texts_cached(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
35        if texts.is_empty() {
36            return Ok(Vec::new());
37        }
38
39        let mut resolved: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
40        let mut missing_order: Vec<String> = Vec::new();
41        let mut missing_positions: HashMap<String, Vec<usize>> = HashMap::new();
42
43        {
44            let mut cache = self
45                .text_embed_cache
46                .lock()
47                .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
48            for (index, text) in texts.iter().enumerate() {
49                if let Some(cached) = cache.get(text) {
50                    resolved[index] = Some(cached);
51                } else {
52                    let key = (*text).to_owned();
53                    if !missing_positions.contains_key(&key) {
54                        missing_order.push(key.clone());
55                    }
56                    missing_positions.entry(key).or_default().push(index);
57                }
58            }
59        }
60
61        if !missing_order.is_empty() {
62            let missing_refs: Vec<&str> = missing_order.iter().map(String::as_str).collect();
63            let embeddings = self
64                .model
65                .lock()
66                .map_err(|_| anyhow::anyhow!("model lock"))?
67                .embed(missing_refs, None)
68                .context("text embedding failed")?;
69
70            let mut cache = self
71                .text_embed_cache
72                .lock()
73                .map_err(|_| anyhow::anyhow!("text embedding cache lock"))?;
74            for (text, embedding) in missing_order.into_iter().zip(embeddings) {
75                cache.insert(text.clone(), embedding.clone());
76                if let Some(indices) = missing_positions.remove(&text) {
77                    for index in indices {
78                        resolved[index] = Some(embedding.clone());
79                    }
80                }
81            }
82        }
83
84        resolved
85            .into_iter()
86            .map(|item| item.ok_or_else(|| anyhow::anyhow!("missing embedding cache entry")))
87            .collect()
88    }
89
90    pub fn embed_query_cached(&self, query: &str) -> Result<Vec<f32>> {
91        let max_entries = Self::configured_query_embed_cache_size();
92        if max_entries == 0 {
93            return self
94                .embed_texts_cached(&[query])?
95                .into_iter()
96                .next()
97                .ok_or_else(|| anyhow::anyhow!("missing query embedding"));
98        }
99        let normalized = Self::normalize_query_for_cache(query);
100        let cache_key = self.query_cache_key(&normalized);
101        if let Some(embedding) = self.store.get_query_embedding(&cache_key)? {
102            return Ok(embedding);
103        }
104        let embedding = self
105            .embed_texts_cached(&[normalized.as_str()])?
106            .into_iter()
107            .next()
108            .ok_or_else(|| anyhow::anyhow!("missing query embedding"))?;
109        self.store
110            .put_query_embedding(&cache_key, &normalized, &embedding)?;
111        let _ = self.store.prune_query_embeddings(max_entries)?;
112        Ok(embedding)
113    }
114
115    pub fn prewarm_queries(&self, queries: &[String]) -> Result<usize> {
116        let max_entries = Self::configured_query_embed_cache_size();
117        if max_entries == 0 || queries.is_empty() {
118            return Ok(0);
119        }
120        let mut prewarmed = 0usize;
121        for query in queries {
122            if query.trim().is_empty() {
123                continue;
124            }
125            let _ = self.embed_query_cached(query)?;
126            prewarmed += 1;
127        }
128        Ok(prewarmed)
129    }
130
131    pub fn query_cache_stats(&self) -> Result<QueryEmbeddingCacheStats> {
132        let max_entries = Self::configured_query_embed_cache_size();
133        let entries = if max_entries == 0 {
134            0
135        } else {
136            self.store.query_cache_count()?
137        };
138        Ok(QueryEmbeddingCacheStats {
139            enabled: max_entries > 0,
140            entries,
141            max_entries,
142        })
143    }
144}