reflex/cache/l2/
cache.rs

1use std::sync::Arc;
2
3use futures_util::future::join_all;
4use tokio::sync::RwLock;
5use tracing::{debug, info, instrument, warn};
6
7use crate::embedding::sinter::SinterEmbedder;
8use crate::vectordb::VectorPoint;
9use crate::vectordb::rescoring::{CandidateEntry, RescorerConfig, VectorRescorer};
10
11use super::backend::BqSearchBackend;
12use super::config::L2Config;
13use super::error::{L2CacheError, L2CacheResult};
14use super::loader::StorageLoader;
15use super::types::L2LookupResult;
16
17/// L2 semantic cache: embed → BQ search → load entries → rescore.
18pub struct L2SemanticCache<B: BqSearchBackend, S: StorageLoader> {
19    embedder: SinterEmbedder,
20    bq_backend: B,
21    storage: S,
22    rescorer: VectorRescorer,
23    config: L2Config,
24}
25
26impl<B: BqSearchBackend, S: StorageLoader> std::fmt::Debug for L2SemanticCache<B, S> {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("L2SemanticCache")
29            .field("embedder", &self.embedder)
30            .field("rescorer", &self.rescorer)
31            .field("config", &self.config)
32            .finish_non_exhaustive()
33    }
34}
35
36impl<B: BqSearchBackend, S: StorageLoader> L2SemanticCache<B, S> {
37    /// Creates a new L2 cache.
38    pub fn new(
39        embedder: SinterEmbedder,
40        bq_backend: B,
41        storage: S,
42        config: L2Config,
43    ) -> L2CacheResult<Self> {
44        config.validate()?;
45
46        let rescorer_config = RescorerConfig {
47            top_k: config.top_k_final,
48            validate_dimensions: config.validate_dimensions,
49        };
50
51        Ok(Self {
52            embedder,
53            bq_backend,
54            storage,
55            rescorer: VectorRescorer::with_config(rescorer_config),
56            config,
57        })
58    }
59
60    /// Returns the active configuration.
61    pub fn config(&self) -> &L2Config {
62        &self.config
63    }
64
65    /// Returns the embedder.
66    pub fn embedder(&self) -> &SinterEmbedder {
67        &self.embedder
68    }
69
70    /// Returns `true` if the embedder is in stub mode.
71    pub fn is_embedder_stub(&self) -> bool {
72        self.embedder.is_stub()
73    }
74
75    /// Returns the storage loader.
76    pub fn storage(&self) -> &S {
77        &self.storage
78    }
79
80    /// Returns the BQ backend.
81    pub fn bq_backend(&self) -> &B {
82        &self.bq_backend
83    }
84
85    /// Searches for semantic matches for `prompt` within `tenant_id`.
86    #[instrument(skip(self, prompt), fields(tenant_id = tenant_id, prompt_len = prompt.len()))]
87    pub async fn search(&self, prompt: &str, tenant_id: u64) -> L2CacheResult<L2LookupResult> {
88        debug!("Generating embedding for prompt");
89        let embedding_f16 =
90            self.embedder
91                .embed(prompt)
92                .map_err(|e| L2CacheError::EmbeddingFailed {
93                    reason: e.to_string(),
94                })?;
95
96        let embedding_f32: Vec<f32> = embedding_f16.iter().map(|v| v.to_f32()).collect();
97
98        debug!(
99            embedding_dim = embedding_f32.len(),
100            "Embedding generated, starting BQ search"
101        );
102
103        let bq_results = self
104            .bq_backend
105            .search_bq(
106                &self.config.collection_name,
107                embedding_f32,
108                self.config.top_k_bq,
109                Some(tenant_id),
110            )
111            .await?;
112
113        let bq_candidates_count = bq_results.len();
114        debug!(
115            candidates = bq_candidates_count,
116            "BQ search complete, loading storage entries"
117        );
118
119        if bq_results.is_empty() {
120            return Err(L2CacheError::NoCandidates);
121        }
122
123        // Parallel storage loading: spawn all loads concurrently to reduce latency
124        // from O(n) sequential to O(1) bounded by the slowest single load
125        let load_futures: Vec<_> = bq_results
126            .iter()
127            .filter_map(|result| {
128                result.storage_key.as_ref().map(|key| {
129                    let key = key.clone();
130                    let id = result.id;
131                    let score = result.score;
132                    async move {
133                        let entry = self.storage.load(&key, tenant_id).await;
134                        (id, score, key, entry)
135                    }
136                })
137            })
138            .collect();
139
140        let load_results = join_all(load_futures).await;
141
142        let mut candidate_entries = Vec::with_capacity(load_results.len());
143        for (id, score, storage_key, entry) in load_results {
144            if let Some(entry) = entry {
145                candidate_entries.push(CandidateEntry::with_bq_score(id, entry, score));
146            } else {
147                warn!(
148                    storage_key = storage_key,
149                    "Storage entry not found or tenant mismatch, skipping candidate"
150                );
151            }
152        }
153
154        if candidate_entries.is_empty() {
155            return Err(L2CacheError::NoCandidates);
156        }
157
158        debug!(
159            loaded = candidate_entries.len(),
160            "Storage entries loaded, starting rescore"
161        );
162
163        let scored_candidates = self
164            .rescorer
165            .rescore(&embedding_f16, candidate_entries)
166            .map_err(|e| L2CacheError::RescoringFailed {
167                reason: e.to_string(),
168            })?;
169
170        info!(
171            tenant_id = tenant_id,
172            bq_candidates = bq_candidates_count,
173            rescored = scored_candidates.len(),
174            best_score = scored_candidates.first().map(|c| c.score),
175            "L2 search complete"
176        );
177
178        Ok(L2LookupResult::new(
179            embedding_f16,
180            scored_candidates,
181            tenant_id,
182            bq_candidates_count,
183        ))
184    }
185
186    /// Indexes `prompt` with metadata and a storage key.
187    #[instrument(skip(self, prompt, storage_key), fields(tenant_id = tenant_id, context_hash = context_hash))]
188    pub async fn index(
189        &self,
190        prompt: &str,
191        tenant_id: u64,
192        context_hash: u64,
193        storage_key: &str,
194        timestamp: i64,
195    ) -> L2CacheResult<u64> {
196        let embedding_f16 =
197            self.embedder
198                .embed(prompt)
199                .map_err(|e| L2CacheError::EmbeddingFailed {
200                    reason: e.to_string(),
201                })?;
202
203        let embedding_f32: Vec<f32> = embedding_f16.iter().map(|v| v.to_f32()).collect();
204
205        let point_id = crate::vectordb::generate_point_id(tenant_id, context_hash);
206
207        let point = VectorPoint {
208            id: point_id,
209            vector: embedding_f32,
210            tenant_id,
211            context_hash,
212            timestamp,
213            storage_key: Some(storage_key.to_string()),
214        };
215
216        self.bq_backend
217            .upsert_points(
218                &self.config.collection_name,
219                vec![point],
220                crate::vectordb::WriteConsistency::Eventual,
221            )
222            .await?;
223
224        debug!(point_id = point_id, "Entry indexed in L2 cache");
225
226        Ok(point_id)
227    }
228
229    /// Ensures the configured collection exists.
230    pub async fn ensure_collection(&self) -> L2CacheResult<()> {
231        self.bq_backend
232            .ensure_collection(&self.config.collection_name, self.config.vector_size)
233            .await?;
234        Ok(())
235    }
236
237    /// Returns `true` if the backend reports readiness.
238    pub async fn is_ready(&self) -> bool {
239        self.bq_backend.is_ready().await
240    }
241}
242
243#[derive(Clone)]
244/// Shared handle to an [`L2SemanticCache`].
245pub struct L2SemanticCacheHandle<B: BqSearchBackend, S: StorageLoader> {
246    inner: Arc<RwLock<L2SemanticCache<B, S>>>,
247}
248
249impl<B: BqSearchBackend, S: StorageLoader> L2SemanticCacheHandle<B, S> {
250    /// Wraps an L2 cache in an `Arc<RwLock<...>>` for shared async access.
251    pub fn new(cache: L2SemanticCache<B, S>) -> Self {
252        Self {
253            inner: Arc::new(RwLock::new(cache)),
254        }
255    }
256
257    /// Delegates to [`L2SemanticCache::search`].
258    pub async fn search(&self, prompt: &str, tenant_id: u64) -> L2CacheResult<L2LookupResult> {
259        self.inner.read().await.search(prompt, tenant_id).await
260    }
261
262    /// Delegates to [`L2SemanticCache::index`].
263    pub async fn index(
264        &self,
265        prompt: &str,
266        tenant_id: u64,
267        context_hash: u64,
268        storage_key: &str,
269        timestamp: i64,
270    ) -> L2CacheResult<u64> {
271        self.inner
272            .read()
273            .await
274            .index(prompt, tenant_id, context_hash, storage_key, timestamp)
275            .await
276    }
277
278    /// Returns the number of strong references to the underlying handle.
279    pub fn strong_count(&self) -> usize {
280        Arc::strong_count(&self.inner)
281    }
282}
283
284impl<B: BqSearchBackend, S: StorageLoader> std::fmt::Debug for L2SemanticCacheHandle<B, S> {
285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        f.debug_struct("L2SemanticCacheHandle")
287            .field("strong_count", &self.strong_count())
288            .finish()
289    }
290}