lance_core/
cache.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Cache implementation
5
6use std::any::{Any, TypeId};
7use std::sync::{
8    atomic::{AtomicU64, Ordering},
9    Arc,
10};
11
12use futures::Future;
13use moka::sync::Cache;
14
15use crate::Result;
16
17pub use deepsize::{Context, DeepSizeOf};
18
19type ArcAny = Arc<dyn Any + Send + Sync>;
20
/// A type-erased cache entry paired with a closure that can report its deep
/// size in bytes. The closure is created while the concrete type is still
/// known (see `SizedRecord::new`), so the size can be computed even after the
/// value has been erased to `dyn Any`.
#[derive(Clone)]
struct SizedRecord {
    /// The cached value, erased to `Arc<dyn Any + Send + Sync>`.
    record: ArcAny,
    /// Returns the deep size of `record`; downcasts back to the concrete type.
    size_accessor: Arc<dyn Fn(&ArcAny) -> usize + Send + Sync>,
}
26
// Manual `Debug` impl: `size_accessor` is a closure and has no `Debug`
// implementation, so only the `record` field is shown.
impl std::fmt::Debug for SizedRecord {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SizedRecord")
            .field("record", &self.record)
            .finish()
    }
}
34
35impl SizedRecord {
36    fn new<T: DeepSizeOf + Send + Sync + 'static>(record: Arc<T>) -> Self {
37        // +8 for the size of the Arc pointer itself
38        let size_accessor =
39            |record: &ArcAny| -> usize { record.downcast_ref::<T>().unwrap().deep_size_of() + 8 };
40        Self {
41            record,
42            size_accessor: Arc::new(size_accessor),
43        }
44    }
45}
46
/// A cache keyed by `(string key, TypeId)`, so the same string key can safely
/// hold values of different types. Clones share the underlying cache and the
/// hit/miss counters; only the key `prefix` differs (see `with_key_prefix`).
#[derive(Clone)]
pub struct LanceCache {
    cache: Arc<Cache<(String, TypeId), SizedRecord>>,
    /// Namespace prepended to every key; empty for the root cache.
    prefix: String,
    /// Hit counter, shared across all clones/prefixed views of this cache.
    hits: Arc<AtomicU64>,
    /// Miss counter, shared across all clones/prefixed views of this cache.
    misses: Arc<AtomicU64>,
}
54
// Manual `Debug` impl: only the cache handle is shown; the prefix and the
// atomic counters are omitted from the debug output.
impl std::fmt::Debug for LanceCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanceCache")
            .field("cache", &self.cache)
            .finish()
    }
}
62
63impl DeepSizeOf for LanceCache {
64    fn deep_size_of_children(&self, _: &mut Context) -> usize {
65        self.cache
66            .iter()
67            .map(|(_, v)| (v.size_accessor)(&v.record))
68            .sum()
69    }
70}
71
72impl LanceCache {
73    pub fn with_capacity(capacity: usize) -> Self {
74        let cache = Cache::builder()
75            .max_capacity(capacity as u64)
76            .weigher(|_, v: &SizedRecord| {
77                (v.size_accessor)(&v.record).try_into().unwrap_or(u32::MAX)
78            })
79            .support_invalidation_closures()
80            .build();
81        Self {
82            cache: Arc::new(cache),
83            prefix: String::new(),
84            hits: Arc::new(AtomicU64::new(0)),
85            misses: Arc::new(AtomicU64::new(0)),
86        }
87    }
88
89    pub fn no_cache() -> Self {
90        Self {
91            cache: Arc::new(Cache::new(0)),
92            prefix: String::new(),
93            hits: Arc::new(AtomicU64::new(0)),
94            misses: Arc::new(AtomicU64::new(0)),
95        }
96    }
97
98    /// Appends a prefix to the cache key
99    ///
100    /// If this cache already has a prefix, the new prefix will be appended to
101    /// the existing one.
102    ///
103    /// Prefixes are used to create a namespace for the cache keys to avoid
104    /// collisions between different caches.
105    pub fn with_key_prefix(&self, prefix: &str) -> Self {
106        Self {
107            cache: self.cache.clone(),
108            prefix: format!("{}{}/", self.prefix, prefix),
109            hits: self.hits.clone(),
110            misses: self.misses.clone(),
111        }
112    }
113
114    fn get_key(&self, key: &str) -> String {
115        if self.prefix.is_empty() {
116            key.to_string()
117        } else {
118            format!("{}/{}", self.prefix, key)
119        }
120    }
121
122    /// Invalidate all entries in the cache that start with the given prefix
123    ///
124    /// The given prefix is appended to the existing prefix of the cache. If you
125    /// want to invalidate all at the current prefix, pass an empty string.
126    pub fn invalidate_prefix(&self, prefix: &str) {
127        let full_prefix = format!("{}{}", self.prefix, prefix);
128        self.cache
129            .invalidate_entries_if(move |(key, _typeid), _value| key.starts_with(&full_prefix))
130            .expect("Cache configured correctly");
131    }
132
133    pub fn size(&self) -> usize {
134        self.cache.run_pending_tasks();
135        self.cache.entry_count() as usize
136    }
137
138    pub fn approx_size(&self) -> usize {
139        self.cache.entry_count() as usize
140    }
141
142    pub fn size_bytes(&self) -> usize {
143        self.cache.run_pending_tasks();
144        self.approx_size_bytes()
145    }
146
147    pub fn approx_size_bytes(&self) -> usize {
148        self.cache.weighted_size() as usize
149    }
150
151    pub fn insert<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str, metadata: Arc<T>) {
152        let key = self.get_key(key);
153        self.cache
154            .insert((key, TypeId::of::<T>()), SizedRecord::new(metadata));
155    }
156
157    pub fn insert_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
158        &self,
159        key: &str,
160        metadata: Arc<T>,
161    ) {
162        // In order to make the data Sized, we wrap in another pointer.
163        self.insert(key, Arc::new(metadata))
164    }
165
166    pub fn get<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
167        let key = self.get_key(key);
168        if let Some(metadata) = self.cache.get(&(key, TypeId::of::<T>())) {
169            self.hits.fetch_add(1, Ordering::Relaxed);
170            Some(metadata.record.clone().downcast::<T>().unwrap())
171        } else {
172            self.misses.fetch_add(1, Ordering::Relaxed);
173            None
174        }
175    }
176
177    pub fn get_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
178        &self,
179        key: &str,
180    ) -> Option<Arc<T>> {
181        let outer = self.get::<Arc<T>>(key)?;
182        Some(outer.as_ref().clone())
183    }
184
185    /// Get an item
186    ///
187    /// If it exists in the cache return that
188    ///
189    /// If it doesn't then run `loader` to load the item, insert into cache, and return
190    pub async fn get_or_insert<T: DeepSizeOf + Send + Sync + 'static, F, Fut>(
191        &self,
192        key: String,
193        loader: F,
194    ) -> Result<Arc<T>>
195    where
196        F: FnOnce(&str) -> Fut,
197        Fut: Future<Output = Result<T>>,
198    {
199        let full_key = self.get_key(&key);
200        if let Some(metadata) = self.cache.get(&(full_key, TypeId::of::<T>())) {
201            self.hits.fetch_add(1, Ordering::Relaxed);
202            return Ok(metadata.record.clone().downcast::<T>().unwrap());
203        }
204
205        self.misses.fetch_add(1, Ordering::Relaxed);
206        let metadata = Arc::new(loader(&key).await?);
207        self.insert(&key, metadata.clone());
208        Ok(metadata)
209    }
210
211    pub fn stats(&self) -> CacheStats {
212        CacheStats {
213            hits: self.hits.load(Ordering::Relaxed),
214            misses: self.misses.load(Ordering::Relaxed),
215        }
216    }
217}
218
/// A point-in-time snapshot of the cache's hit/miss counters, as returned by
/// `LanceCache::stats`. Counters are shared across all prefixed views of the
/// same underlying cache.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of times `get`, `get_unsized`, or `get_or_insert` found an item in the cache.
    pub hits: u64,
    /// Number of times `get`, `get_unsized`, or `get_or_insert` did not find an item in the cache.
    pub misses: u64,
}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte-based sizing, retrieval, and weight-based eviction.
    #[test]
    fn test_cache_bytes() {
        let item = Arc::new(vec![1, 2, 3]);
        let item_size = item.deep_size_of(); // Size of Arc<Vec<i32>>
        let capacity = 10 * item_size;

        let cache = LanceCache::with_capacity(capacity);
        assert_eq!(cache.size_bytes(), 0);
        assert_eq!(cache.approx_size_bytes(), 0);

        let item = Arc::new(vec![1, 2, 3]);
        cache.insert("key", item.clone());
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.size_bytes(), item_size);
        assert_eq!(cache.approx_size_bytes(), item_size);

        let retrieved = cache.get::<Vec<i32>>("key").unwrap();
        assert_eq!(*retrieved, *item);

        // Test eviction based on size: 20 equally-sized inserts into a cache
        // that can hold 10 should settle at exactly the capacity.
        for i in 0..20 {
            cache.insert(&format!("key_{}", i), Arc::new(vec![i, i, i]));
        }
        assert_eq!(cache.size_bytes(), capacity);
        assert_eq!(cache.size(), 10);
    }

    /// Trait objects round-trip through `insert_unsized`/`get_unsized`.
    #[test]
    fn test_cache_trait_objects() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {
            fn as_any(&self) -> &dyn Any;
        }

        impl MyTrait for MyType {
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        let item = Arc::new(MyType(42));
        let item_dyn: Arc<dyn MyTrait> = item;

        let cache = LanceCache::with_capacity(1000);
        cache.insert_unsized("test", item_dyn);

        let retrieved = cache.get_unsized::<dyn MyTrait>("test").unwrap();
        let retrieved = retrieved.as_any().downcast_ref::<MyType>().unwrap();
        assert_eq!(retrieved.0, 42);
    }

    /// Hit/miss counters advance exactly once per `get` call.
    /// NOTE: assertions depend on the exact order/count of cache calls.
    #[test]
    fn test_cache_stats_basic() {
        let cache = LanceCache::with_capacity(1000);

        // Initially no hits or misses
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);

        // Miss on first get
        let result = cache.get::<Vec<i32>>("nonexistent");
        assert!(result.is_none());
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        // Insert and then hit
        cache.insert("key1", Arc::new(vec![1, 2, 3]));
        let result = cache.get::<Vec<i32>>("key1");
        assert!(result.is_some());
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        // Another hit
        let result = cache.get::<Vec<i32>>("key1");
        assert!(result.is_some());
        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);

        // Another miss
        let result = cache.get::<Vec<i32>>("nonexistent2");
        assert!(result.is_none());
        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 2);
    }

    /// Counters are shared between a cache and its prefixed views.
    #[test]
    fn test_cache_stats_with_prefixes() {
        let base_cache = LanceCache::with_capacity(1000);
        let prefixed_cache = base_cache.with_key_prefix("test");

        // Stats should be shared between base and prefixed cache
        let stats = base_cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);

        let stats = prefixed_cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);

        // Miss on prefixed cache
        let result = prefixed_cache.get::<Vec<i32>>("key1");
        assert!(result.is_none());

        // Both should show the miss
        let stats = base_cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        let stats = prefixed_cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        // Insert through prefixed cache and hit
        prefixed_cache.insert("key1", Arc::new(vec![1, 2, 3]));
        let result = prefixed_cache.get::<Vec<i32>>("key1");
        assert!(result.is_some());

        // Both should show the hit
        let stats = base_cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        let stats = prefixed_cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    /// `get_unsized` misses and hits also feed the shared counters.
    #[test]
    fn test_cache_stats_unsized() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {}

        impl MyTrait for MyType {}

        let cache = LanceCache::with_capacity(1000);

        // Miss on unsized get
        let result = cache.get_unsized::<dyn MyTrait>("test");
        assert!(result.is_none());
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        // Insert and hit on unsized
        let item = Arc::new(MyType(42));
        let item_dyn: Arc<dyn MyTrait> = item;
        cache.insert_unsized("test", item_dyn);

        let result = cache.get_unsized::<dyn MyTrait>("test");
        assert!(result.is_some());
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    /// `get_or_insert` runs the loader only on a miss and counts hits/misses.
    #[tokio::test]
    async fn test_cache_stats_get_or_insert() {
        let cache = LanceCache::with_capacity(1000);

        // First call should be a miss and load the value
        let result: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async { Ok(vec![1, 2, 3]) })
            .await
            .unwrap();
        assert_eq!(*result, vec![1, 2, 3]);

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        // Second call should be a hit
        let result: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async {
                panic!("Should not be called")
            })
            .await
            .unwrap();
        assert_eq!(*result, vec![1, 2, 3]);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        // Different key should be another miss
        let result: Arc<Vec<i32>> = cache
            .get_or_insert("key2".to_string(), |_key| async { Ok(vec![4, 5, 6]) })
            .await
            .unwrap();
        assert_eq!(*result, vec![4, 5, 6]);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 2);
    }
}