lattice_embed/service/cached.rs
1//! Caching wrapper for embedding services.
2
3use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS};
4use crate::error::Result;
5use crate::model::EmbeddingModel;
6use async_trait::async_trait;
7use std::sync::Arc;
8use tracing::debug;
9
10/// **Unstable**: caching strategy and constructor API may change; foundation-internal use only.
11///
12/// Caching wrapper around an embedding service.
13///
14/// Wraps any `EmbeddingService` implementation with LRU caching. Identical
15/// texts (with the same model) will return cached embeddings instead of
16/// recomputing.
17///
18/// # Example
19///
20/// ```rust,no_run
21/// use lattice_embed::{
22/// CachedEmbeddingService, NativeEmbeddingService, EmbeddingService,
23/// EmbeddingModel, EmbeddingCache,
24/// };
25/// use std::sync::Arc;
26///
27/// #[tokio::main]
28/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
29/// let inner = Arc::new(NativeEmbeddingService::new());
30/// let cached = CachedEmbeddingService::new(inner, 1000);
31///
32/// // First call - computes and caches
33/// let emb1 = cached.embed_one("Hello", EmbeddingModel::default()).await?;
34///
35/// // Second call - returns from cache
36/// let emb2 = cached.embed_one("Hello", EmbeddingModel::default()).await?;
37///
38/// assert_eq!(emb1, emb2);
39/// Ok(())
40/// }
41/// ```
42pub struct CachedEmbeddingService<S> {
43 inner: Arc<S>,
44 cache: crate::cache::EmbeddingCache,
45}
46
47impl<S: EmbeddingService> CachedEmbeddingService<S> {
48 /// **Unstable**: constructor signature may change when cache config becomes a struct.
49 ///
50 /// # Arguments
51 ///
52 /// * `inner` - The underlying embedding service
53 /// * `cache_capacity` - Maximum number of embeddings to cache
54 pub fn new(inner: Arc<S>, cache_capacity: usize) -> Self {
55 Self {
56 inner,
57 cache: crate::cache::EmbeddingCache::new(cache_capacity),
58 }
59 }
60
61 /// **Unstable**: constructor signature may change when cache config becomes a struct.
62 pub fn with_default_cache(inner: Arc<S>) -> Self {
63 Self {
64 inner,
65 cache: crate::cache::EmbeddingCache::with_default_capacity(),
66 }
67 }
68
69 /// **Unstable**: returns internal `CacheStats` type which is itself Unstable.
70 pub fn cache_stats(&self) -> crate::cache::CacheStats {
71 self.cache.stats()
72 }
73
74 /// **Unstable**: internal cache management; API subject to change.
75 pub fn clear_cache(&self) {
76 self.cache.clear();
77 }
78}
79
80#[async_trait]
81impl<S: EmbeddingService + 'static> EmbeddingService for CachedEmbeddingService<S> {
82 async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
83 use crate::error::EmbedError;
84
85 // Validate inputs before any cache interaction so callers always get
86 // consistent errors regardless of whether the result is fully cached.
87 if texts.is_empty() {
88 return Err(EmbedError::InvalidInput("no texts provided".into()));
89 }
90 if texts.len() > DEFAULT_MAX_BATCH_SIZE {
91 return Err(EmbedError::InvalidInput(format!(
92 "batch size {} exceeds maximum {}",
93 texts.len(),
94 DEFAULT_MAX_BATCH_SIZE
95 )));
96 }
97 for text in texts {
98 if text.len() > MAX_TEXT_CHARS {
99 return Err(EmbedError::TextTooLong {
100 length: text.len(),
101 max: MAX_TEXT_CHARS,
102 });
103 }
104 }
105
106 // Fast path: bypass cache entirely when disabled (no key computation, no locking)
107 if !self.cache.is_enabled() {
108 return self.inner.embed(texts, model).await;
109 }
110
111 // Compute cache keys — include the active dimension (for MRL models).
112 let model_config = self.inner.model_config(model);
113 let keys: Vec<_> = texts
114 .iter()
115 .map(|t| self.cache.compute_key(t, model_config))
116 .collect();
117
118 // Check cache for all texts — returns Arc<[f32]> refs (O(1) per hit)
119 let cached = self.cache.get_many(&keys);
120
121 // Identify which texts need embedding
122 let mut to_embed: Vec<(usize, &String)> = Vec::new();
123 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
124
125 for (i, (text, cached_emb)) in texts.iter().zip(cached.into_iter()).enumerate() {
126 if let Some(arc) = cached_emb {
127 results[i] = Some(arc.to_vec());
128 } else {
129 to_embed.push((i, text));
130 }
131 }
132
133 // If all cached, return immediately
134 if to_embed.is_empty() {
135 debug!("all {} texts found in cache", texts.len());
136 // SAFETY: All slots are Some because we only reach here when to_embed is empty,
137 // meaning every text was found in cache and had results[i] = Some(...) assigned.
138 return Ok(results.into_iter().flatten().collect());
139 }
140
141 debug!(
142 "{} texts cached, {} need embedding",
143 texts.len() - to_embed.len(),
144 to_embed.len()
145 );
146
147 // Embed missing texts
148 let texts_to_embed: Vec<String> = to_embed.iter().map(|(_, t)| (*t).clone()).collect();
149 let new_embeddings = self.inner.embed(&texts_to_embed, model).await?;
150
151 // FP-035: validate count before zipping — a count mismatch would silently
152 // drop slots via zip() and return fewer embeddings than requested.
153 if new_embeddings.len() != to_embed.len() {
154 return Err(EmbedError::InferenceFailed(format!(
155 "embedding service returned {} vectors for {} inputs",
156 new_embeddings.len(),
157 to_embed.len()
158 )));
159 }
160
161 // Store in cache and populate results
162 let mut cache_entries = Vec::with_capacity(to_embed.len());
163 for ((i, _), embedding) in to_embed.into_iter().zip(new_embeddings.into_iter()) {
164 cache_entries.push((keys[i], embedding.clone()));
165 results[i] = Some(embedding);
166 }
167 self.cache.put_many(cache_entries);
168
169 // Return all results
170 // SAFETY: All slots are guaranteed to be Some at this point:
171 // - Cached items were assigned via results[i] = Some(arc.to_vec())
172 // - Non-cached items were assigned via results[i] = Some(embedding) in the loop above
173 Ok(results.into_iter().flatten().collect())
174 }
175
176 fn supports_model(&self, model: EmbeddingModel) -> bool {
177 self.inner.supports_model(model)
178 }
179
180 fn name(&self) -> &'static str {
181 "cached-embedding"
182 }
183}
184
185// Suppress dead code warnings for constants that are used by other modules
186const _: () = {
187 let _ = DEFAULT_MAX_BATCH_SIZE;
188 let _ = MAX_TEXT_CHARS;
189};