use std::sync::Arc;

use futures_util::future::join_all;
use tokio::sync::RwLock;
use tracing::{debug, info, instrument, warn};

use crate::embedding::sinter::SinterEmbedder;
use crate::vectordb::VectorPoint;
use crate::vectordb::rescoring::{CandidateEntry, RescorerConfig, VectorRescorer};

use super::backend::BqSearchBackend;
use super::config::L2Config;
use super::error::{L2CacheError, L2CacheResult};
use super::loader::StorageLoader;
use super::types::L2LookupResult;

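/// Two-stage L2 semantic cache.
///
/// A lookup first runs a coarse BQ vector search against the backend `B`,
/// then loads the candidate entries through the storage loader `S`, and
/// finally rescores them against the full f16 query embedding.
///
/// A minimal usage sketch (not a doctest): `backend`, `loader`, `embedder`,
/// and `config` are assumed to be constructed elsewhere and are not defined
/// in this module.
///
/// ```ignore
/// // Hypothetical setup: `backend: impl BqSearchBackend`, `loader: impl StorageLoader`.
/// let cache = L2SemanticCache::new(embedder, backend, loader, config)?;
/// cache.ensure_collection().await?;
/// let result = cache.search("example prompt", 42).await?;
/// ```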
pub struct L2SemanticCache<B: BqSearchBackend, S: StorageLoader> {
    embedder: SinterEmbedder,
    bq_backend: B,
    storage: S,
    rescorer: VectorRescorer,
    config: L2Config,
}

impl<B: BqSearchBackend, S: StorageLoader> std::fmt::Debug for L2SemanticCache<B, S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("L2SemanticCache")
            .field("embedder", &self.embedder)
            .field("rescorer", &self.rescorer)
            .field("config", &self.config)
            .finish_non_exhaustive()
    }
}

impl<B: BqSearchBackend, S: StorageLoader> L2SemanticCache<B, S> {
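    /// Builds a cache from its parts after validating `config`; the rescorer
    /// is derived from `top_k_final` and `validate_dimensions`.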
    pub fn new(
        embedder: SinterEmbedder,
        bq_backend: B,
        storage: S,
        config: L2Config,
    ) -> L2CacheResult<Self> {
        config.validate()?;

        let rescorer_config = RescorerConfig {
            top_k: config.top_k_final,
            validate_dimensions: config.validate_dimensions,
        };

        Ok(Self {
            embedder,
            bq_backend,
            storage,
            rescorer: VectorRescorer::with_config(rescorer_config),
            config,
        })
    }

    pub fn config(&self) -> &L2Config {
        &self.config
    }

    pub fn embedder(&self) -> &SinterEmbedder {
        &self.embedder
    }

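    /// Returns `true` if the underlying embedder is a stub rather than a real model.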
    pub fn is_embedder_stub(&self) -> bool {
        self.embedder.is_stub()
    }

    pub fn storage(&self) -> &S {
        &self.storage
    }

    pub fn bq_backend(&self) -> &B {
        &self.bq_backend
    }

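    /// Searches the cache for entries semantically similar to `prompt` within
    /// the given tenant's namespace.
    ///
    /// Pipeline: embed the prompt, run a coarse BQ search for up to
    /// `top_k_bq` candidates, load their entries from storage concurrently,
    /// then rescore against the f16 embedding. Fails with
    /// `L2CacheError::NoCandidates` when no candidate survives.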
    #[instrument(skip(self, prompt), fields(tenant_id = tenant_id, prompt_len = prompt.len()))]
    pub async fn search(&self, prompt: &str, tenant_id: u64) -> L2CacheResult<L2LookupResult> {
        debug!("Generating embedding for prompt");
        let embedding_f16 =
            self.embedder
                .embed(prompt)
                .map_err(|e| L2CacheError::EmbeddingFailed {
                    reason: e.to_string(),
                })?;

        let embedding_f32: Vec<f32> = embedding_f16.iter().map(|v| v.to_f32()).collect();

        debug!(
            embedding_dim = embedding_f32.len(),
            "Embedding generated, starting BQ search"
        );

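        // Coarse stage: fetch up to `top_k_bq` candidate matches from the BQ
        // backend, scoped to this tenant.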
        let bq_results = self
            .bq_backend
            .search_bq(
                &self.config.collection_name,
                embedding_f32,
                self.config.top_k_bq,
                Some(tenant_id),
            )
            .await?;

        let bq_candidates_count = bq_results.len();
        debug!(
            candidates = bq_candidates_count,
            "BQ search complete, loading storage entries"
        );

        if bq_results.is_empty() {
            return Err(L2CacheError::NoCandidates);
        }

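        // Load the full entries for all candidates concurrently; results without
        // a storage key are skipped here, and failed loads are dropped below.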
        let load_futures: Vec<_> = bq_results
            .iter()
            .filter_map(|result| {
                result.storage_key.as_ref().map(|key| {
                    let key = key.clone();
                    let id = result.id;
                    let score = result.score;
                    async move {
                        let entry = self.storage.load(&key, tenant_id).await;
                        (id, score, key, entry)
                    }
                })
            })
            .collect();

        let load_results = join_all(load_futures).await;

        let mut candidate_entries = Vec::with_capacity(load_results.len());
        for (id, score, storage_key, entry) in load_results {
            if let Some(entry) = entry {
                candidate_entries.push(CandidateEntry::with_bq_score(id, entry, score));
            } else {
                warn!(
                    storage_key = storage_key,
                    "Storage entry not found or tenant mismatch, skipping candidate"
                );
            }
        }

        if candidate_entries.is_empty() {
            return Err(L2CacheError::NoCandidates);
        }

        debug!(
            loaded = candidate_entries.len(),
            "Storage entries loaded, starting rescore"
        );

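        // Exact stage: rescore the surviving candidates against the f16 query
        // embedding; the rescorer is configured with `top_k_final`.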
        let scored_candidates = self
            .rescorer
            .rescore(&embedding_f16, candidate_entries)
            .map_err(|e| L2CacheError::RescoringFailed {
                reason: e.to_string(),
            })?;

        info!(
            tenant_id = tenant_id,
            bq_candidates = bq_candidates_count,
            rescored = scored_candidates.len(),
            best_score = scored_candidates.first().map(|c| c.score),
            "L2 search complete"
        );

        Ok(L2LookupResult::new(
            embedding_f16,
            scored_candidates,
            tenant_id,
            bq_candidates_count,
        ))
    }

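    /// Embeds `prompt` and upserts it into the configured collection as a
    /// `VectorPoint` whose id is derived from `tenant_id` and `context_hash`,
    /// using eventual write consistency. `storage_key` links the vector back
    /// to the full cache entry. Returns the point id.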
    #[instrument(skip(self, prompt, storage_key), fields(tenant_id = tenant_id, context_hash = context_hash))]
    pub async fn index(
        &self,
        prompt: &str,
        tenant_id: u64,
        context_hash: u64,
        storage_key: &str,
        timestamp: i64,
    ) -> L2CacheResult<u64> {
        let embedding_f16 =
            self.embedder
                .embed(prompt)
                .map_err(|e| L2CacheError::EmbeddingFailed {
                    reason: e.to_string(),
                })?;

        let embedding_f32: Vec<f32> = embedding_f16.iter().map(|v| v.to_f32()).collect();

        let point_id = crate::vectordb::generate_point_id(tenant_id, context_hash);

        let point = VectorPoint {
            id: point_id,
            vector: embedding_f32,
            tenant_id,
            context_hash,
            timestamp,
            storage_key: Some(storage_key.to_string()),
        };

        self.bq_backend
            .upsert_points(
                &self.config.collection_name,
                vec![point],
                crate::vectordb::WriteConsistency::Eventual,
            )
            .await?;

        debug!(point_id = point_id, "Entry indexed in L2 cache");

        Ok(point_id)
    }

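    /// Makes sure the configured collection exists on the BQ backend with the
    /// configured vector size.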
    pub async fn ensure_collection(&self) -> L2CacheResult<()> {
        self.bq_backend
            .ensure_collection(&self.config.collection_name, self.config.vector_size)
            .await?;
        Ok(())
    }

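    /// Reports whether the BQ backend is ready to serve requests.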
    pub async fn is_ready(&self) -> bool {
        self.bq_backend.is_ready().await
    }
}

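/// Cloneable, thread-safe handle sharing one `L2SemanticCache` behind an
/// `Arc<RwLock<_>>`. Both `search` and `index` only take the read half of the
/// lock, since the underlying methods take `&self`.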
#[derive(Clone)]
pub struct L2SemanticCacheHandle<B: BqSearchBackend, S: StorageLoader> {
    inner: Arc<RwLock<L2SemanticCache<B, S>>>,
}

impl<B: BqSearchBackend, S: StorageLoader> L2SemanticCacheHandle<B, S> {
    pub fn new(cache: L2SemanticCache<B, S>) -> Self {
        Self {
            inner: Arc::new(RwLock::new(cache)),
        }
    }

    pub async fn search(&self, prompt: &str, tenant_id: u64) -> L2CacheResult<L2LookupResult> {
        self.inner.read().await.search(prompt, tenant_id).await
    }

    pub async fn index(
        &self,
        prompt: &str,
        tenant_id: u64,
        context_hash: u64,
        storage_key: &str,
        timestamp: i64,
    ) -> L2CacheResult<u64> {
        self.inner
            .read()
            .await
            .index(prompt, tenant_id, context_hash, storage_key, timestamp)
            .await
    }

    pub fn strong_count(&self) -> usize {
        Arc::strong_count(&self.inner)
    }
}

impl<B: BqSearchBackend, S: StorageLoader> std::fmt::Debug for L2SemanticCacheHandle<B, S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("L2SemanticCacheHandle")
            .field("strong_count", &self.strong_count())
            .finish()
    }
}