lattice_embed/service/cached.rs
1//! Caching wrapper for embedding services.
2
3use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingRole, EmbeddingService, MAX_TEXT_CHARS};
4use crate::error::Result;
5use crate::model::EmbeddingModel;
6use async_trait::async_trait;
7use std::sync::Arc;
8use tracing::debug;
9
10/// **Unstable**: caching strategy and constructor API may change; foundation-internal use only.
11///
12/// Caching wrapper around an embedding service.
13///
14/// Wraps any `EmbeddingService` implementation with LRU caching. Identical
15/// texts (with the same model) will return cached embeddings instead of
16/// recomputing.
17///
18/// # Example
19///
20/// ```rust,no_run
21/// use lattice_embed::{
22/// CachedEmbeddingService, NativeEmbeddingService, EmbeddingService,
23/// EmbeddingModel, EmbeddingCache,
24/// };
25/// use std::sync::Arc;
26///
27/// #[tokio::main]
28/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
29/// let inner = Arc::new(NativeEmbeddingService::new());
30/// let cached = CachedEmbeddingService::new(inner, 1000);
31///
32/// // First call - computes and caches
33/// let emb1 = cached.embed_one("Hello", EmbeddingModel::default()).await?;
34///
35/// // Second call - returns from cache
36/// let emb2 = cached.embed_one("Hello", EmbeddingModel::default()).await?;
37///
38/// assert_eq!(emb1, emb2);
39/// Ok(())
40/// }
41/// ```
42pub struct CachedEmbeddingService<S> {
43 inner: Arc<S>,
44 cache: crate::cache::EmbeddingCache,
45}
46
47impl<S: EmbeddingService> CachedEmbeddingService<S> {
48 /// **Unstable**: constructor signature may change when cache config becomes a struct.
49 ///
50 /// # Arguments
51 ///
52 /// * `inner` - The underlying embedding service
53 /// * `cache_capacity` - Maximum number of embeddings to cache
54 pub fn new(inner: Arc<S>, cache_capacity: usize) -> Self {
55 Self {
56 inner,
57 cache: crate::cache::EmbeddingCache::new(cache_capacity),
58 }
59 }
60
61 /// **Unstable**: constructor signature may change when cache config becomes a struct.
62 pub fn with_default_cache(inner: Arc<S>) -> Self {
63 Self {
64 inner,
65 cache: crate::cache::EmbeddingCache::with_default_capacity(),
66 }
67 }
68
69 /// **Unstable**: returns internal `CacheStats` type which is itself Unstable.
70 pub fn cache_stats(&self) -> crate::cache::CacheStats {
71 self.cache.stats()
72 }
73
74 /// **Unstable**: internal cache management; API subject to change.
75 pub fn clear_cache(&self) {
76 self.cache.clear();
77 }
78}
79
80#[async_trait]
81impl<S: EmbeddingService + 'static> EmbeddingService for CachedEmbeddingService<S> {
82 async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
83 // Generic role: cache key does NOT include a role tag, maintaining
84 // backwards compatibility with any on-disk cache entries written before
85 // role-aware keys were introduced.
86 self.embed_with_role(texts, model, EmbeddingRole::Generic)
87 .await
88 }
89
90 /// Override: apply query prompt then cache with `Query` role key.
91 async fn embed_query(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
92 let prefix = model.query_instruction();
93 let prompted = super::apply_prefix(texts, prefix);
94 self.embed_with_role(&prompted, model, EmbeddingRole::Query)
95 .await
96 }
97
98 /// Override: apply passage prompt then cache with `Passage` role key.
99 async fn embed_passage(
100 &self,
101 texts: &[String],
102 model: EmbeddingModel,
103 ) -> Result<Vec<Vec<f32>>> {
104 let prefix = model.document_instruction();
105 let prompted = super::apply_prefix(texts, prefix);
106 self.embed_with_role(&prompted, model, EmbeddingRole::Passage)
107 .await
108 }
109
110 fn supports_model(&self, model: EmbeddingModel) -> bool {
111 self.inner.supports_model(model)
112 }
113
114 fn name(&self) -> &'static str {
115 "cached-embedding"
116 }
117}
118
119impl<S: EmbeddingService + 'static> CachedEmbeddingService<S> {
120 /// Core cache-and-embed implementation shared by `embed`, `embed_query`, and
121 /// `embed_passage`. `texts` must already have the prompt prefix applied; `role`
122 /// is used only as part of the cache key so that different roles produce separate
123 /// cache entries for the same raw text.
124 async fn embed_with_role(
125 &self,
126 texts: &[String],
127 model: EmbeddingModel,
128 role: EmbeddingRole,
129 ) -> Result<Vec<Vec<f32>>> {
130 use crate::error::EmbedError;
131
132 // Validate inputs before any cache interaction so callers always get
133 // consistent errors regardless of whether the result is fully cached.
134 if texts.is_empty() {
135 return Err(EmbedError::InvalidInput("no texts provided".into()));
136 }
137 if texts.len() > DEFAULT_MAX_BATCH_SIZE {
138 return Err(EmbedError::InvalidInput(format!(
139 "batch size {} exceeds maximum {}",
140 texts.len(),
141 DEFAULT_MAX_BATCH_SIZE
142 )));
143 }
144 for text in texts {
145 if text.len() > MAX_TEXT_CHARS {
146 return Err(EmbedError::TextTooLong {
147 length: text.len(),
148 max: MAX_TEXT_CHARS,
149 });
150 }
151 }
152
153 // Fast path: bypass cache entirely when disabled (no key computation, no locking)
154 if !self.cache.is_enabled() {
155 return self.inner.embed(texts, model).await;
156 }
157
158 // Compute cache keys — include the active dimension (for MRL models) and role.
159 let model_config = self.inner.model_config(model);
160 let keys: Vec<_> = texts
161 .iter()
162 .map(|t| self.cache.compute_key(t, model_config, role))
163 .collect();
164
165 // Check cache for all texts — returns Arc<[f32]> refs (O(1) per hit)
166 let cached = self.cache.get_many(&keys);
167
168 // Identify which texts need embedding
169 let mut to_embed: Vec<(usize, &String)> = Vec::new();
170 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
171
172 for (i, (text, cached_emb)) in texts.iter().zip(cached.into_iter()).enumerate() {
173 if let Some(arc) = cached_emb {
174 results[i] = Some(arc.to_vec());
175 } else {
176 to_embed.push((i, text));
177 }
178 }
179
180 // If all cached, return immediately
181 if to_embed.is_empty() {
182 debug!("all {} texts found in cache", texts.len());
183 // SAFETY: All slots are Some because we only reach here when to_embed is empty,
184 // meaning every text was found in cache and had results[i] = Some(...) assigned.
185 return Ok(results.into_iter().flatten().collect());
186 }
187
188 debug!(
189 "{} texts cached, {} need embedding",
190 texts.len() - to_embed.len(),
191 to_embed.len()
192 );
193
194 // Embed missing texts (after prompt is already applied in texts)
195 let texts_to_embed: Vec<String> = to_embed.iter().map(|(_, t)| (*t).clone()).collect();
196 let new_embeddings = self.inner.embed(&texts_to_embed, model).await?;
197
198 // FP-035: validate count before zipping — a count mismatch would silently
199 // drop slots via zip() and return fewer embeddings than requested.
200 if new_embeddings.len() != to_embed.len() {
201 return Err(EmbedError::InferenceFailed(format!(
202 "embedding service returned {} vectors for {} inputs",
203 new_embeddings.len(),
204 to_embed.len()
205 )));
206 }
207
208 // Store in cache and populate results
209 let mut cache_entries = Vec::with_capacity(to_embed.len());
210 for ((i, _), embedding) in to_embed.into_iter().zip(new_embeddings.into_iter()) {
211 cache_entries.push((keys[i], embedding.clone()));
212 results[i] = Some(embedding);
213 }
214 self.cache.put_many(cache_entries);
215
216 // Return all results
217 // SAFETY: All slots are guaranteed to be Some at this point:
218 // - Cached items were assigned via results[i] = Some(arc.to_vec())
219 // - Non-cached items were assigned via results[i] = Some(embedding) in the loop above
220 Ok(results.into_iter().flatten().collect())
221 }
222}
223
224// Suppress dead code warnings for constants that are used by other modules
225const _: () = {
226 let _ = DEFAULT_MAX_BATCH_SIZE;
227 let _ = MAX_TEXT_CHARS;
228};