Skip to main content

lattice_embed/service/
cached.rs

1//! Caching wrapper for embedding services.
2
3use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingRole, EmbeddingService, MAX_TEXT_CHARS};
4use crate::error::Result;
5use crate::model::EmbeddingModel;
6use async_trait::async_trait;
7use std::sync::Arc;
8use tracing::debug;
9
10/// **Unstable**: caching strategy and constructor API may change; foundation-internal use only.
11///
12/// Caching wrapper around an embedding service.
13///
14/// Wraps any `EmbeddingService` implementation with LRU caching. Identical
15/// texts (with the same model) will return cached embeddings instead of
16/// recomputing.
17///
18/// # Example
19///
20/// ```rust,no_run
21/// use lattice_embed::{
22///     CachedEmbeddingService, NativeEmbeddingService, EmbeddingService,
23///     EmbeddingModel, EmbeddingCache,
24/// };
25/// use std::sync::Arc;
26///
27/// #[tokio::main]
28/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
29///     let inner = Arc::new(NativeEmbeddingService::new());
30///     let cached = CachedEmbeddingService::new(inner, 1000);
31///
32///     // First call - computes and caches
33///     let emb1 = cached.embed_one("Hello", EmbeddingModel::default()).await?;
34///
35///     // Second call - returns from cache
36///     let emb2 = cached.embed_one("Hello", EmbeddingModel::default()).await?;
37///
38///     assert_eq!(emb1, emb2);
39///     Ok(())
40/// }
41/// ```
42pub struct CachedEmbeddingService<S> {
43    inner: Arc<S>,
44    cache: crate::cache::EmbeddingCache,
45}
46
47impl<S: EmbeddingService> CachedEmbeddingService<S> {
48    /// **Unstable**: constructor signature may change when cache config becomes a struct.
49    ///
50    /// # Arguments
51    ///
52    /// * `inner` - The underlying embedding service
53    /// * `cache_capacity` - Maximum number of embeddings to cache
54    pub fn new(inner: Arc<S>, cache_capacity: usize) -> Self {
55        Self {
56            inner,
57            cache: crate::cache::EmbeddingCache::new(cache_capacity),
58        }
59    }
60
61    /// **Unstable**: constructor signature may change when cache config becomes a struct.
62    pub fn with_default_cache(inner: Arc<S>) -> Self {
63        Self {
64            inner,
65            cache: crate::cache::EmbeddingCache::with_default_capacity(),
66        }
67    }
68
69    /// **Unstable**: returns internal `CacheStats` type which is itself Unstable.
70    pub fn cache_stats(&self) -> crate::cache::CacheStats {
71        self.cache.stats()
72    }
73
74    /// **Unstable**: internal cache management; API subject to change.
75    pub fn clear_cache(&self) {
76        self.cache.clear();
77    }
78}
79
80#[async_trait]
81impl<S: EmbeddingService + 'static> EmbeddingService for CachedEmbeddingService<S> {
82    async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
83        // Generic role: cache key does NOT include a role tag, maintaining
84        // backwards compatibility with any on-disk cache entries written before
85        // role-aware keys were introduced.
86        self.embed_with_role(texts, model, EmbeddingRole::Generic)
87            .await
88    }
89
90    /// Override: apply query prompt then cache with `Query` role key.
91    async fn embed_query(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
92        let prefix = model.query_instruction();
93        let prompted = super::apply_prefix(texts, prefix);
94        self.embed_with_role(&prompted, model, EmbeddingRole::Query)
95            .await
96    }
97
98    /// Override: apply passage prompt then cache with `Passage` role key.
99    async fn embed_passage(
100        &self,
101        texts: &[String],
102        model: EmbeddingModel,
103    ) -> Result<Vec<Vec<f32>>> {
104        let prefix = model.document_instruction();
105        let prompted = super::apply_prefix(texts, prefix);
106        self.embed_with_role(&prompted, model, EmbeddingRole::Passage)
107            .await
108    }
109
110    fn supports_model(&self, model: EmbeddingModel) -> bool {
111        self.inner.supports_model(model)
112    }
113
114    fn name(&self) -> &'static str {
115        "cached-embedding"
116    }
117}
118
119impl<S: EmbeddingService + 'static> CachedEmbeddingService<S> {
120    /// Core cache-and-embed implementation shared by `embed`, `embed_query`, and
121    /// `embed_passage`.  `texts` must already have the prompt prefix applied; `role`
122    /// is used only as part of the cache key so that different roles produce separate
123    /// cache entries for the same raw text.
124    async fn embed_with_role(
125        &self,
126        texts: &[String],
127        model: EmbeddingModel,
128        role: EmbeddingRole,
129    ) -> Result<Vec<Vec<f32>>> {
130        use crate::error::EmbedError;
131
132        // Validate inputs before any cache interaction so callers always get
133        // consistent errors regardless of whether the result is fully cached.
134        if texts.is_empty() {
135            return Err(EmbedError::InvalidInput("no texts provided".into()));
136        }
137        if texts.len() > DEFAULT_MAX_BATCH_SIZE {
138            return Err(EmbedError::InvalidInput(format!(
139                "batch size {} exceeds maximum {}",
140                texts.len(),
141                DEFAULT_MAX_BATCH_SIZE
142            )));
143        }
144        for text in texts {
145            if text.len() > MAX_TEXT_CHARS {
146                return Err(EmbedError::TextTooLong {
147                    length: text.len(),
148                    max: MAX_TEXT_CHARS,
149                });
150            }
151        }
152
153        // Fast path: bypass cache entirely when disabled (no key computation, no locking)
154        if !self.cache.is_enabled() {
155            return self.inner.embed(texts, model).await;
156        }
157
158        // Compute cache keys — include the active dimension (for MRL models) and role.
159        let model_config = self.inner.model_config(model);
160        let keys: Vec<_> = texts
161            .iter()
162            .map(|t| self.cache.compute_key(t, model_config, role))
163            .collect();
164
165        // Check cache for all texts — returns Arc<[f32]> refs (O(1) per hit)
166        let cached = self.cache.get_many(&keys);
167
168        // Identify which texts need embedding
169        let mut to_embed: Vec<(usize, &String)> = Vec::new();
170        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
171
172        for (i, (text, cached_emb)) in texts.iter().zip(cached.into_iter()).enumerate() {
173            if let Some(arc) = cached_emb {
174                results[i] = Some(arc.to_vec());
175            } else {
176                to_embed.push((i, text));
177            }
178        }
179
180        // If all cached, return immediately
181        if to_embed.is_empty() {
182            debug!("all {} texts found in cache", texts.len());
183            // SAFETY: All slots are Some because we only reach here when to_embed is empty,
184            // meaning every text was found in cache and had results[i] = Some(...) assigned.
185            return Ok(results.into_iter().flatten().collect());
186        }
187
188        debug!(
189            "{} texts cached, {} need embedding",
190            texts.len() - to_embed.len(),
191            to_embed.len()
192        );
193
194        // Embed missing texts (after prompt is already applied in texts)
195        let texts_to_embed: Vec<String> = to_embed.iter().map(|(_, t)| (*t).clone()).collect();
196        let new_embeddings = self.inner.embed(&texts_to_embed, model).await?;
197
198        // FP-035: validate count before zipping — a count mismatch would silently
199        // drop slots via zip() and return fewer embeddings than requested.
200        if new_embeddings.len() != to_embed.len() {
201            return Err(EmbedError::InferenceFailed(format!(
202                "embedding service returned {} vectors for {} inputs",
203                new_embeddings.len(),
204                to_embed.len()
205            )));
206        }
207
208        // Store in cache and populate results
209        let mut cache_entries = Vec::with_capacity(to_embed.len());
210        for ((i, _), embedding) in to_embed.into_iter().zip(new_embeddings.into_iter()) {
211            cache_entries.push((keys[i], embedding.clone()));
212            results[i] = Some(embedding);
213        }
214        self.cache.put_many(cache_entries);
215
216        // Return all results
217        // SAFETY: All slots are guaranteed to be Some at this point:
218        // - Cached items were assigned via results[i] = Some(arc.to_vec())
219        // - Non-cached items were assigned via results[i] = Some(embedding) in the loop above
220        Ok(results.into_iter().flatten().collect())
221    }
222}
223
224// Suppress dead code warnings for constants that are used by other modules
225const _: () = {
226    let _ = DEFAULT_MAX_BATCH_SIZE;
227    let _ = MAX_TEXT_CHARS;
228};