Skip to main content

lattice_embed/service/
cached.rs

1//! Caching wrapper for embedding services.
2
3use super::{DEFAULT_MAX_BATCH_SIZE, EmbeddingService, MAX_TEXT_CHARS};
4use crate::error::Result;
5use crate::model::EmbeddingModel;
6use async_trait::async_trait;
7use std::sync::Arc;
8use tracing::debug;
9
10/// **Unstable**: caching strategy and constructor API may change; foundation-internal use only.
11///
12/// Caching wrapper around an embedding service.
13///
14/// Wraps any `EmbeddingService` implementation with LRU caching. Identical
15/// texts (with the same model) will return cached embeddings instead of
16/// recomputing.
17///
18/// # Example
19///
20/// ```rust,no_run
21/// use lattice_embed::{
22///     CachedEmbeddingService, NativeEmbeddingService, EmbeddingService,
23///     EmbeddingModel, EmbeddingCache,
24/// };
25/// use std::sync::Arc;
26///
27/// #[tokio::main]
28/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
29///     let inner = Arc::new(NativeEmbeddingService::new());
30///     let cached = CachedEmbeddingService::new(inner, 1000);
31///
32///     // First call - computes and caches
33///     let emb1 = cached.embed_one("Hello", EmbeddingModel::default()).await?;
34///
35///     // Second call - returns from cache
36///     let emb2 = cached.embed_one("Hello", EmbeddingModel::default()).await?;
37///
38///     assert_eq!(emb1, emb2);
39///     Ok(())
40/// }
41/// ```
42pub struct CachedEmbeddingService<S> {
43    inner: Arc<S>,
44    cache: crate::cache::EmbeddingCache,
45}
46
47impl<S: EmbeddingService> CachedEmbeddingService<S> {
48    /// **Unstable**: constructor signature may change when cache config becomes a struct.
49    ///
50    /// # Arguments
51    ///
52    /// * `inner` - The underlying embedding service
53    /// * `cache_capacity` - Maximum number of embeddings to cache
54    pub fn new(inner: Arc<S>, cache_capacity: usize) -> Self {
55        Self {
56            inner,
57            cache: crate::cache::EmbeddingCache::new(cache_capacity),
58        }
59    }
60
61    /// **Unstable**: constructor signature may change when cache config becomes a struct.
62    pub fn with_default_cache(inner: Arc<S>) -> Self {
63        Self {
64            inner,
65            cache: crate::cache::EmbeddingCache::with_default_capacity(),
66        }
67    }
68
69    /// **Unstable**: returns internal `CacheStats` type which is itself Unstable.
70    pub fn cache_stats(&self) -> crate::cache::CacheStats {
71        self.cache.stats()
72    }
73
74    /// **Unstable**: internal cache management; API subject to change.
75    pub fn clear_cache(&self) {
76        self.cache.clear();
77    }
78}
79
80#[async_trait]
81impl<S: EmbeddingService + 'static> EmbeddingService for CachedEmbeddingService<S> {
82    async fn embed(&self, texts: &[String], model: EmbeddingModel) -> Result<Vec<Vec<f32>>> {
83        use crate::error::EmbedError;
84
85        // Validate inputs before any cache interaction so callers always get
86        // consistent errors regardless of whether the result is fully cached.
87        if texts.is_empty() {
88            return Err(EmbedError::InvalidInput("no texts provided".into()));
89        }
90        if texts.len() > DEFAULT_MAX_BATCH_SIZE {
91            return Err(EmbedError::InvalidInput(format!(
92                "batch size {} exceeds maximum {}",
93                texts.len(),
94                DEFAULT_MAX_BATCH_SIZE
95            )));
96        }
97        for text in texts {
98            if text.len() > MAX_TEXT_CHARS {
99                return Err(EmbedError::TextTooLong {
100                    length: text.len(),
101                    max: MAX_TEXT_CHARS,
102                });
103            }
104        }
105
106        // Fast path: bypass cache entirely when disabled (no key computation, no locking)
107        if !self.cache.is_enabled() {
108            return self.inner.embed(texts, model).await;
109        }
110
111        // Compute cache keys — include the active dimension (for MRL models).
112        let model_config = self.inner.model_config(model);
113        let keys: Vec<_> = texts
114            .iter()
115            .map(|t| self.cache.compute_key(t, model_config))
116            .collect();
117
118        // Check cache for all texts — returns Arc<[f32]> refs (O(1) per hit)
119        let cached = self.cache.get_many(&keys);
120
121        // Identify which texts need embedding
122        let mut to_embed: Vec<(usize, &String)> = Vec::new();
123        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
124
125        for (i, (text, cached_emb)) in texts.iter().zip(cached.into_iter()).enumerate() {
126            if let Some(arc) = cached_emb {
127                results[i] = Some(arc.to_vec());
128            } else {
129                to_embed.push((i, text));
130            }
131        }
132
133        // If all cached, return immediately
134        if to_embed.is_empty() {
135            debug!("all {} texts found in cache", texts.len());
136            // SAFETY: All slots are Some because we only reach here when to_embed is empty,
137            // meaning every text was found in cache and had results[i] = Some(...) assigned.
138            return Ok(results.into_iter().flatten().collect());
139        }
140
141        debug!(
142            "{} texts cached, {} need embedding",
143            texts.len() - to_embed.len(),
144            to_embed.len()
145        );
146
147        // Embed missing texts
148        let texts_to_embed: Vec<String> = to_embed.iter().map(|(_, t)| (*t).clone()).collect();
149        let new_embeddings = self.inner.embed(&texts_to_embed, model).await?;
150
151        // FP-035: validate count before zipping — a count mismatch would silently
152        // drop slots via zip() and return fewer embeddings than requested.
153        if new_embeddings.len() != to_embed.len() {
154            return Err(EmbedError::InferenceFailed(format!(
155                "embedding service returned {} vectors for {} inputs",
156                new_embeddings.len(),
157                to_embed.len()
158            )));
159        }
160
161        // Store in cache and populate results
162        let mut cache_entries = Vec::with_capacity(to_embed.len());
163        for ((i, _), embedding) in to_embed.into_iter().zip(new_embeddings.into_iter()) {
164            cache_entries.push((keys[i], embedding.clone()));
165            results[i] = Some(embedding);
166        }
167        self.cache.put_many(cache_entries);
168
169        // Return all results
170        // SAFETY: All slots are guaranteed to be Some at this point:
171        // - Cached items were assigned via results[i] = Some(arc.to_vec())
172        // - Non-cached items were assigned via results[i] = Some(embedding) in the loop above
173        Ok(results.into_iter().flatten().collect())
174    }
175
176    fn supports_model(&self, model: EmbeddingModel) -> bool {
177        self.inner.supports_model(model)
178    }
179
180    fn name(&self) -> &'static str {
181        "cached-embedding"
182    }
183}
184
185// Suppress dead code warnings for constants that are used by other modules
186const _: () = {
187    let _ = DEFAULT_MAX_BATCH_SIZE;
188    let _ = MAX_TEXT_CHARS;
189};