// lance_core/cache.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Cache implementation
5
6use std::any::{Any, TypeId};
7use std::borrow::Cow;
8use std::sync::{
9    atomic::{AtomicU64, Ordering},
10    Arc,
11};
12
13use futures::Future;
14use moka::sync::Cache;
15
16use crate::Result;
17
18pub use deepsize::{Context, DeepSizeOf};
19
/// Type-erased, thread-safe, reference-counted cache value.
type ArcAny = Arc<dyn Any + Send + Sync>;
21
/// A cache entry: a type-erased value paired with a closure that can report
/// the entry's deep size, used by the cache's weigher for byte-based eviction.
#[derive(Clone)]
struct SizedRecord {
    // The cached value, erased to `Any` so heterogeneous types share one map.
    record: ArcAny,
    // Reports the deep size (in bytes) of `record`. Captured at insert time,
    // when the concrete type was still known, so it can downcast internally.
    size_accessor: Arc<dyn Fn(&ArcAny) -> usize + Send + Sync>,
}
27
28impl std::fmt::Debug for SizedRecord {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("SizedRecord")
31            .field("record", &self.record)
32            .finish()
33    }
34}
35
36impl SizedRecord {
37    fn new<T: DeepSizeOf + Send + Sync + 'static>(record: Arc<T>) -> Self {
38        // +8 for the size of the Arc pointer itself
39        let size_accessor =
40            |record: &ArcAny| -> usize { record.downcast_ref::<T>().unwrap().deep_size_of() + 8 };
41        Self {
42            record,
43            size_accessor: Arc::new(size_accessor),
44        }
45    }
46}
47
/// A byte-size-bounded cache keyed by `(string key, concrete TypeId)`.
///
/// Cloning is cheap: clones (including prefixed views created via
/// `with_key_prefix`) share the same underlying storage and counters.
#[derive(Clone)]
pub struct LanceCache {
    // Shared underlying map; keyed by (namespaced key, value TypeId).
    cache: Arc<Cache<(String, TypeId), SizedRecord>>,
    // Namespace prepended to every key; see `with_key_prefix`.
    prefix: String,
    // Hit counter, shared across clones and prefixed views.
    hits: Arc<AtomicU64>,
    // Miss counter, shared across clones and prefixed views.
    misses: Arc<AtomicU64>,
}
55
56impl std::fmt::Debug for LanceCache {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("LanceCache")
59            .field("cache", &self.cache)
60            .finish()
61    }
62}
63
64impl DeepSizeOf for LanceCache {
65    fn deep_size_of_children(&self, _: &mut Context) -> usize {
66        self.cache
67            .iter()
68            .map(|(_, v)| (v.size_accessor)(&v.record))
69            .sum()
70    }
71}
72
73impl LanceCache {
74    pub fn with_capacity(capacity: usize) -> Self {
75        let cache = Cache::builder()
76            .max_capacity(capacity as u64)
77            .weigher(|_, v: &SizedRecord| {
78                (v.size_accessor)(&v.record).try_into().unwrap_or(u32::MAX)
79            })
80            .support_invalidation_closures()
81            .build();
82        Self {
83            cache: Arc::new(cache),
84            prefix: String::new(),
85            hits: Arc::new(AtomicU64::new(0)),
86            misses: Arc::new(AtomicU64::new(0)),
87        }
88    }
89
90    pub fn no_cache() -> Self {
91        Self {
92            cache: Arc::new(Cache::new(0)),
93            prefix: String::new(),
94            hits: Arc::new(AtomicU64::new(0)),
95            misses: Arc::new(AtomicU64::new(0)),
96        }
97    }
98
99    /// Appends a prefix to the cache key
100    ///
101    /// If this cache already has a prefix, the new prefix will be appended to
102    /// the existing one.
103    ///
104    /// Prefixes are used to create a namespace for the cache keys to avoid
105    /// collisions between different caches.
106    pub fn with_key_prefix(&self, prefix: &str) -> Self {
107        Self {
108            cache: self.cache.clone(),
109            prefix: format!("{}{}/", self.prefix, prefix),
110            hits: self.hits.clone(),
111            misses: self.misses.clone(),
112        }
113    }
114
115    fn get_key(&self, key: &str) -> String {
116        if self.prefix.is_empty() {
117            key.to_string()
118        } else {
119            format!("{}/{}", self.prefix, key)
120        }
121    }
122
123    /// Invalidate all entries in the cache that start with the given prefix
124    ///
125    /// The given prefix is appended to the existing prefix of the cache. If you
126    /// want to invalidate all at the current prefix, pass an empty string.
127    pub fn invalidate_prefix(&self, prefix: &str) {
128        let full_prefix = format!("{}{}", self.prefix, prefix);
129        self.cache
130            .invalidate_entries_if(move |(key, _typeid), _value| key.starts_with(&full_prefix))
131            .expect("Cache configured correctly");
132    }
133
134    pub fn size(&self) -> usize {
135        self.cache.run_pending_tasks();
136        self.cache.entry_count() as usize
137    }
138
139    pub fn approx_size(&self) -> usize {
140        self.cache.entry_count() as usize
141    }
142
143    pub fn size_bytes(&self) -> usize {
144        self.cache.run_pending_tasks();
145        self.approx_size_bytes()
146    }
147
148    pub fn approx_size_bytes(&self) -> usize {
149        self.cache.weighted_size() as usize
150    }
151
152    pub fn insert<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str, metadata: Arc<T>) {
153        let key = self.get_key(key);
154        let record = SizedRecord::new(metadata);
155        tracing::trace!(
156            target: "lance_cache::insert",
157            key = key,
158            type_id = std::any::type_name::<T>(),
159            size = (record.size_accessor)(&record.record),
160        );
161        self.cache.insert((key, TypeId::of::<T>()), record);
162    }
163
164    pub fn insert_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
165        &self,
166        key: &str,
167        metadata: Arc<T>,
168    ) {
169        // In order to make the data Sized, we wrap in another pointer.
170        self.insert(key, Arc::new(metadata))
171    }
172
173    pub fn get<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
174        let key = self.get_key(key);
175        if let Some(metadata) = self.cache.get(&(key, TypeId::of::<T>())) {
176            self.hits.fetch_add(1, Ordering::Relaxed);
177            Some(metadata.record.clone().downcast::<T>().unwrap())
178        } else {
179            self.misses.fetch_add(1, Ordering::Relaxed);
180            None
181        }
182    }
183
184    pub fn get_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
185        &self,
186        key: &str,
187    ) -> Option<Arc<T>> {
188        let outer = self.get::<Arc<T>>(key)?;
189        Some(outer.as_ref().clone())
190    }
191
192    /// Get an item
193    ///
194    /// If it exists in the cache return that
195    ///
196    /// If it doesn't then run `loader` to load the item, insert into cache, and return
197    pub async fn get_or_insert<T: DeepSizeOf + Send + Sync + 'static, F, Fut>(
198        &self,
199        key: String,
200        loader: F,
201    ) -> Result<Arc<T>>
202    where
203        F: FnOnce(&str) -> Fut,
204        Fut: Future<Output = Result<T>>,
205    {
206        let full_key = self.get_key(&key);
207        if let Some(metadata) = self.cache.get(&(full_key, TypeId::of::<T>())) {
208            self.hits.fetch_add(1, Ordering::Relaxed);
209            return Ok(metadata.record.clone().downcast::<T>().unwrap());
210        }
211
212        self.misses.fetch_add(1, Ordering::Relaxed);
213        let metadata = Arc::new(loader(&key).await?);
214        self.insert(&key, metadata.clone());
215        Ok(metadata)
216    }
217
218    pub fn stats(&self) -> CacheStats {
219        CacheStats {
220            hits: self.hits.load(Ordering::Relaxed),
221            misses: self.misses.load(Ordering::Relaxed),
222        }
223    }
224
225    // CacheKey-based methods
226    pub fn insert_with_key<K>(&self, cache_key: &K, metadata: Arc<K::ValueType>)
227    where
228        K: CacheKey,
229        K::ValueType: DeepSizeOf + Send + Sync + 'static,
230    {
231        self.insert(&cache_key.key(), metadata)
232    }
233
234    pub fn get_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
235    where
236        K: CacheKey,
237        K::ValueType: DeepSizeOf + Send + Sync + 'static,
238    {
239        self.get::<K::ValueType>(&cache_key.key())
240    }
241
242    pub async fn get_or_insert_with_key<K, F, Fut>(
243        &self,
244        cache_key: K,
245        loader: F,
246    ) -> Result<Arc<K::ValueType>>
247    where
248        K: CacheKey,
249        K::ValueType: DeepSizeOf + Send + Sync + 'static,
250        F: FnOnce() -> Fut,
251        Fut: Future<Output = Result<K::ValueType>>,
252    {
253        let key_str = cache_key.key().into_owned();
254        self.get_or_insert(key_str, |_| loader()).await
255    }
256}
257
/// A typed cache key: pairs a string key with the value type it maps to,
/// enabling the `*_with_key` methods on [`LanceCache`] to infer `T`.
pub trait CacheKey {
    /// The type of value stored under this key.
    type ValueType;

    /// The string form of the key (borrowed when possible).
    fn key(&self) -> Cow<'_, str>;
}
263
/// Snapshot of the cache's hit/miss counters, returned by [`LanceCache::stats`].
///
/// Counters are shared across clones and prefixed views of the same cache.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of times `get`, `get_unsized`, or `get_or_insert` found an item in the cache.
    pub hits: u64,
    /// Number of times `get`, `get_unsized`, or `get_or_insert` did not find an item in the cache.
    pub misses: u64,
}
271
// Unit tests. Note: several tests assert exact hit/miss counts and entry
// counts, so the order of operations within each test is significant.
#[cfg(test)]
mod tests {
    use super::*;

    // Byte-weighted capacity: entries are weighed by deep size and evicted
    // once the weighted total exceeds the configured capacity.
    #[test]
    fn test_cache_bytes() {
        let item = Arc::new(vec![1, 2, 3]);
        let item_size = item.deep_size_of(); // Size of Arc<Vec<i32>>
        let capacity = 10 * item_size;

        let cache = LanceCache::with_capacity(capacity);
        assert_eq!(cache.size_bytes(), 0);
        assert_eq!(cache.approx_size_bytes(), 0);

        let item = Arc::new(vec![1, 2, 3]);
        cache.insert("key", item.clone());
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.size_bytes(), item_size);
        assert_eq!(cache.approx_size_bytes(), item_size);

        let retrieved = cache.get::<Vec<i32>>("key").unwrap();
        assert_eq!(*retrieved, *item);

        // Test eviction based on size
        for i in 0..20 {
            cache.insert(&format!("key_{}", i), Arc::new(vec![i, i, i]));
        }
        assert_eq!(cache.size_bytes(), capacity);
        assert_eq!(cache.size(), 10);
    }

    // Trait objects round-trip through insert_unsized/get_unsized, which
    // wrap the Arc<dyn T> in an outer Arc to make it Sized.
    #[test]
    fn test_cache_trait_objects() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {
            fn as_any(&self) -> &dyn Any;
        }

        impl MyTrait for MyType {
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        let item = Arc::new(MyType(42));
        let item_dyn: Arc<dyn MyTrait> = item;

        let cache = LanceCache::with_capacity(1000);
        cache.insert_unsized("test", item_dyn);

        let retrieved = cache.get_unsized::<dyn MyTrait>("test").unwrap();
        let retrieved = retrieved.as_any().downcast_ref::<MyType>().unwrap();
        assert_eq!(retrieved.0, 42);
    }

    // Hit/miss counters advance exactly once per get.
    #[test]
    fn test_cache_stats_basic() {
        let cache = LanceCache::with_capacity(1000);

        // Initially no hits or misses
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);

        // Miss on first get
        let result = cache.get::<Vec<i32>>("nonexistent");
        assert!(result.is_none());
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        // Insert and then hit
        cache.insert("key1", Arc::new(vec![1, 2, 3]));
        let result = cache.get::<Vec<i32>>("key1");
        assert!(result.is_some());
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        // Another hit
        let result = cache.get::<Vec<i32>>("key1");
        assert!(result.is_some());
        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);

        // Another miss
        let result = cache.get::<Vec<i32>>("nonexistent2");
        assert!(result.is_none());
        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 2);
    }

    // Counters are shared between a cache and its prefixed views.
    #[test]
    fn test_cache_stats_with_prefixes() {
        let base_cache = LanceCache::with_capacity(1000);
        let prefixed_cache = base_cache.with_key_prefix("test");

        // Stats should be shared between base and prefixed cache
        let stats = base_cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);

        let stats = prefixed_cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);

        // Miss on prefixed cache
        let result = prefixed_cache.get::<Vec<i32>>("key1");
        assert!(result.is_none());

        // Both should show the miss
        let stats = base_cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        let stats = prefixed_cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        // Insert through prefixed cache and hit
        prefixed_cache.insert("key1", Arc::new(vec![1, 2, 3]));
        let result = prefixed_cache.get::<Vec<i32>>("key1");
        assert!(result.is_some());

        // Both should show the hit
        let stats = base_cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        let stats = prefixed_cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    // get_unsized goes through get(), so it records hits/misses too.
    #[test]
    fn test_cache_stats_unsized() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {}

        impl MyTrait for MyType {}

        let cache = LanceCache::with_capacity(1000);

        // Miss on unsized get
        let result = cache.get_unsized::<dyn MyTrait>("test");
        assert!(result.is_none());
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        // Insert and hit on unsized
        let item = Arc::new(MyType(42));
        let item_dyn: Arc<dyn MyTrait> = item;
        cache.insert_unsized("test", item_dyn);

        let result = cache.get_unsized::<dyn MyTrait>("test");
        assert!(result.is_some());
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    // get_or_insert: miss runs the loader; hit must NOT run the loader.
    #[tokio::test]
    async fn test_cache_stats_get_or_insert() {
        let cache = LanceCache::with_capacity(1000);

        // First call should be a miss and load the value
        let result: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async { Ok(vec![1, 2, 3]) })
            .await
            .unwrap();
        assert_eq!(*result, vec![1, 2, 3]);

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        // Second call should be a hit
        let result: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async {
                panic!("Should not be called")
            })
            .await
            .unwrap();
        assert_eq!(*result, vec![1, 2, 3]);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        // Different key should be another miss
        let result: Arc<Vec<i32>> = cache
            .get_or_insert("key2".to_string(), |_key| async { Ok(vec![4, 5, 6]) })
            .await
            .unwrap();
        assert_eq!(*result, vec![4, 5, 6]);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 2);
    }
}
479}