lance_core/
cache.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Cache implementation
5
6use std::any::{Any, TypeId};
7use std::borrow::Cow;
8use std::sync::{
9    atomic::{AtomicU64, Ordering},
10    Arc,
11};
12
13use futures::{Future, FutureExt};
14use moka::future::Cache;
15use snafu::location;
16
17use crate::Result;
18
19pub use deepsize::{Context, DeepSizeOf};
20
/// Type-erased, shareable handle to a cached value.
type ArcAny = Arc<dyn Any + Send + Sync>;

/// A type-erased cache value bundled with a closure that reports its size.
#[derive(Clone)]
struct SizedRecord {
    /// The cached value, stored behind `dyn Any`.
    record: ArcAny,
    /// Computes the size of `record`; callers invoke it as
    /// `(size_accessor)(&record)`.
    size_accessor: Arc<dyn Fn(&ArcAny) -> usize + Send + Sync>,
}
28
29impl std::fmt::Debug for SizedRecord {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("SizedRecord")
32            .field("record", &self.record)
33            .finish()
34    }
35}
36
37impl SizedRecord {
38    fn new<T: DeepSizeOf + Send + Sync + 'static>(record: Arc<T>) -> Self {
39        // +8 for the size of the Arc pointer itself
40        let size_accessor =
41            |record: &ArcAny| -> usize { record.downcast_ref::<T>().unwrap().deep_size_of() + 8 };
42        Self {
43            record,
44            size_accessor: Arc::new(size_accessor),
45        }
46    }
47}
48
/// A cache of type-erased values keyed by a string and the value's `TypeId`.
///
/// The underlying cache and the hit/miss counters sit behind `Arc`s, so clones
/// of a `LanceCache` share the same entries and statistics.
#[derive(Clone)]
pub struct LanceCache {
    /// Shared store; each key is a string key paired with the value's type.
    cache: Arc<Cache<(String, TypeId), SizedRecord>>,
    /// Namespace applied to every key used through this handle.
    prefix: String,
    /// Number of lookups that found an entry.
    hits: Arc<AtomicU64>,
    /// Number of lookups that did not find an entry.
    misses: Arc<AtomicU64>,
}
56
57impl std::fmt::Debug for LanceCache {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("LanceCache")
60            .field("cache", &self.cache)
61            .finish()
62    }
63}
64
65impl DeepSizeOf for LanceCache {
66    fn deep_size_of_children(&self, _: &mut Context) -> usize {
67        self.cache
68            .iter()
69            .map(|(_, v)| (v.size_accessor)(&v.record))
70            .sum()
71    }
72}
73
74impl LanceCache {
75    pub fn with_capacity(capacity: usize) -> Self {
76        let cache = Cache::builder()
77            .max_capacity(capacity as u64)
78            .weigher(|_, v: &SizedRecord| {
79                (v.size_accessor)(&v.record).try_into().unwrap_or(u32::MAX)
80            })
81            .support_invalidation_closures()
82            .build();
83        Self {
84            cache: Arc::new(cache),
85            prefix: String::new(),
86            hits: Arc::new(AtomicU64::new(0)),
87            misses: Arc::new(AtomicU64::new(0)),
88        }
89    }
90
91    pub fn no_cache() -> Self {
92        Self {
93            cache: Arc::new(Cache::new(0)),
94            prefix: String::new(),
95            hits: Arc::new(AtomicU64::new(0)),
96            misses: Arc::new(AtomicU64::new(0)),
97        }
98    }
99
100    /// Appends a prefix to the cache key
101    ///
102    /// If this cache already has a prefix, the new prefix will be appended to
103    /// the existing one.
104    ///
105    /// Prefixes are used to create a namespace for the cache keys to avoid
106    /// collisions between different caches.
107    pub fn with_key_prefix(&self, prefix: &str) -> Self {
108        Self {
109            cache: self.cache.clone(),
110            prefix: format!("{}{}/", self.prefix, prefix),
111            hits: self.hits.clone(),
112            misses: self.misses.clone(),
113        }
114    }
115
116    fn get_key(&self, key: &str) -> String {
117        if self.prefix.is_empty() {
118            key.to_string()
119        } else {
120            format!("{}/{}", self.prefix, key)
121        }
122    }
123
124    /// Invalidate all entries in the cache that start with the given prefix
125    ///
126    /// The given prefix is appended to the existing prefix of the cache. If you
127    /// want to invalidate all at the current prefix, pass an empty string.
128    pub fn invalidate_prefix(&self, prefix: &str) {
129        let full_prefix = format!("{}{}", self.prefix, prefix);
130        self.cache
131            .invalidate_entries_if(move |(key, _typeid), _value| key.starts_with(&full_prefix))
132            .expect("Cache configured correctly");
133    }
134
135    pub async fn size(&self) -> usize {
136        self.cache.run_pending_tasks().await;
137        self.cache.entry_count() as usize
138    }
139
140    pub fn approx_size(&self) -> usize {
141        self.cache.entry_count() as usize
142    }
143
144    pub async fn size_bytes(&self) -> usize {
145        self.cache.run_pending_tasks().await;
146        self.approx_size_bytes()
147    }
148
149    pub fn approx_size_bytes(&self) -> usize {
150        self.cache.weighted_size() as usize
151    }
152
153    async fn insert<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str, metadata: Arc<T>) {
154        let key = self.get_key(key);
155        let record = SizedRecord::new(metadata);
156        tracing::trace!(
157            target: "lance_cache::insert",
158            key = key,
159            type_id = std::any::type_name::<T>(),
160            size = (record.size_accessor)(&record.record),
161        );
162        self.cache.insert((key, TypeId::of::<T>()), record).await;
163    }
164
165    pub async fn insert_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
166        &self,
167        key: &str,
168        metadata: Arc<T>,
169    ) {
170        // In order to make the data Sized, we wrap in another pointer.
171        self.insert(key, Arc::new(metadata)).await
172    }
173
174    async fn get<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
175        let key = self.get_key(key);
176        if let Some(metadata) = self.cache.get(&(key, TypeId::of::<T>())).await {
177            self.hits.fetch_add(1, Ordering::Relaxed);
178            Some(metadata.record.clone().downcast::<T>().unwrap())
179        } else {
180            self.misses.fetch_add(1, Ordering::Relaxed);
181            None
182        }
183    }
184
185    pub async fn get_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
186        &self,
187        key: &str,
188    ) -> Option<Arc<T>> {
189        let outer = self.get::<Arc<T>>(key).await?;
190        Some(outer.as_ref().clone())
191    }
192
193    /// Get an item
194    ///
195    /// If it exists in the cache return that
196    ///
197    /// If it doesn't then run `loader` to load the item, insert into cache, and return
198    async fn get_or_insert<T: DeepSizeOf + Send + Sync + 'static, F, Fut>(
199        &self,
200        key: String,
201        loader: F,
202    ) -> Result<Arc<T>>
203    where
204        F: FnOnce(&str) -> Fut,
205        Fut: Future<Output = Result<T>> + Send,
206    {
207        let full_key = self.get_key(&key);
208        let cache_key = (full_key, TypeId::of::<T>());
209
210        // Use optionally_get_with to handle concurrent requests
211        let hits = self.hits.clone();
212        let misses = self.misses.clone();
213
214        // Use oneshot channels to track both errors and whether init was run
215        let (error_tx, error_rx) = tokio::sync::oneshot::channel();
216        let (init_run_tx, mut init_run_rx) = tokio::sync::oneshot::channel();
217
218        let init = Box::pin(async move {
219            let _ = init_run_tx.send(());
220            misses.fetch_add(1, Ordering::Relaxed);
221            match loader(&key).await {
222                Ok(value) => Some(SizedRecord::new(Arc::new(value))),
223                Err(e) => {
224                    let _ = error_tx.send(e);
225                    None
226                }
227            }
228        });
229
230        match self.cache.optionally_get_with(cache_key, init).await {
231            Some(metadata) => {
232                // Check if init was run or if this was a cache hit
233                match init_run_rx.try_recv() {
234                    Ok(()) => {
235                        // Init was run, miss was already recorded
236                    }
237                    Err(_) => {
238                        // Init was not run, this is a cache hit
239                        hits.fetch_add(1, Ordering::Relaxed);
240                    }
241                }
242                Ok(metadata.record.clone().downcast::<T>().unwrap())
243            }
244            None => {
245                // The loader returned an error, retrieve it from the channel
246                match error_rx.await {
247                    Ok(err) => Err(err),
248                    Err(_) => Err(crate::Error::Internal {
249                        message: "Failed to retrieve error from cache loader".into(),
250                        location: location!(),
251                    }),
252                }
253            }
254        }
255    }
256
257    pub async fn stats(&self) -> CacheStats {
258        self.cache.run_pending_tasks().await;
259        CacheStats {
260            hits: self.hits.load(Ordering::Relaxed),
261            misses: self.misses.load(Ordering::Relaxed),
262            num_entries: self.cache.entry_count() as usize,
263            size_bytes: self.cache.weighted_size() as usize,
264        }
265    }
266
267    pub async fn clear(&self) {
268        self.cache.invalidate_all();
269        self.cache.run_pending_tasks().await;
270        self.hits.store(0, Ordering::Relaxed);
271        self.misses.store(0, Ordering::Relaxed);
272    }
273
274    // CacheKey-based methods
275    pub async fn insert_with_key<K>(&self, cache_key: &K, metadata: Arc<K::ValueType>)
276    where
277        K: CacheKey,
278        K::ValueType: DeepSizeOf + Send + Sync + 'static,
279    {
280        self.insert(&cache_key.key(), metadata).boxed().await
281    }
282
283    pub async fn get_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
284    where
285        K: CacheKey,
286        K::ValueType: DeepSizeOf + Send + Sync + 'static,
287    {
288        self.get::<K::ValueType>(&cache_key.key()).boxed().await
289    }
290
291    pub async fn get_or_insert_with_key<K, F, Fut>(
292        &self,
293        cache_key: K,
294        loader: F,
295    ) -> Result<Arc<K::ValueType>>
296    where
297        K: CacheKey,
298        K::ValueType: DeepSizeOf + Send + Sync + 'static,
299        F: FnOnce() -> Fut,
300        Fut: Future<Output = Result<K::ValueType>> + Send,
301    {
302        let key_str = cache_key.key().into_owned();
303        Box::pin(self.get_or_insert(key_str, |_| loader())).await
304    }
305
306    pub async fn insert_unsized_with_key<K>(&self, cache_key: &K, metadata: Arc<K::ValueType>)
307    where
308        K: UnsizedCacheKey,
309        K::ValueType: DeepSizeOf + Send + Sync + 'static,
310    {
311        self.insert_unsized(&cache_key.key(), metadata)
312            .boxed()
313            .await
314    }
315
316    pub async fn get_unsized_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
317    where
318        K: UnsizedCacheKey,
319        K::ValueType: DeepSizeOf + Send + Sync + 'static,
320    {
321        self.get_unsized::<K::ValueType>(&cache_key.key())
322            .boxed()
323            .await
324    }
325}
326
/// A typed cache key: pairs a string key with the type of value stored under it.
pub trait CacheKey {
    /// Type of the value associated with this key.
    type ValueType;

    /// String form of the key.
    fn key(&self) -> Cow<'_, str>;
}
332
/// Like `CacheKey`, but the associated value type may be unsized (for example
/// a trait object).
pub trait UnsizedCacheKey {
    /// Type of the value associated with this key; not required to be `Sized`.
    type ValueType: ?Sized;

    /// String form of the key.
    fn key(&self) -> Cow<'_, str>;
}
338
/// Point-in-time counters describing a cache.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// How many `get`, `get_unsized`, or `get_or_insert` calls found their item.
    pub hits: u64,
    /// How many `get`, `get_unsized`, or `get_or_insert` calls did not find their item.
    pub misses: u64,
    /// How many entries the cache currently holds.
    pub num_entries: usize,
    /// Combined size, in bytes, of all entries in the cache.
    pub size_bytes: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits; `0.0` when there were no lookups.
    pub fn hit_ratio(&self) -> f32 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f32 / lookups as f32,
        }
    }

    /// Fraction of lookups that were misses; `0.0` when there were no lookups.
    pub fn miss_ratio(&self) -> f32 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.misses as f32 / lookups as f32,
        }
    }
}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts the hit and miss counters currently reported by `cache.stats()`.
    async fn assert_counts(cache: &LanceCache, hits: u64, misses: u64) {
        let snapshot = cache.stats().await;
        assert_eq!(snapshot.hits, hits);
        assert_eq!(snapshot.misses, misses);
    }

    #[tokio::test]
    async fn test_cache_bytes() {
        // Size of Arc<Vec<i32>> holding three elements.
        let item_size = Arc::new(vec![1, 2, 3]).deep_size_of();
        let capacity = 10 * item_size;

        let cache = LanceCache::with_capacity(capacity);
        assert_eq!(cache.size_bytes().await, 0);
        assert_eq!(cache.approx_size_bytes(), 0);

        let stored = Arc::new(vec![1, 2, 3]);
        cache.insert("key", stored.clone()).await;
        assert_eq!(cache.size().await, 1);
        assert_eq!(cache.size_bytes().await, item_size);
        assert_eq!(cache.approx_size_bytes(), item_size);

        let fetched = cache.get::<Vec<i32>>("key").await.unwrap();
        assert_eq!(*fetched, *stored);

        // Insert more than fits so that size-based eviction kicks in.
        for i in 0..20 {
            let key = format!("key_{}", i);
            cache.insert(&key, Arc::new(vec![i, i, i])).await;
        }
        assert_eq!(cache.size_bytes().await, capacity);
        assert_eq!(cache.size().await, 10);
    }

    #[tokio::test]
    async fn test_cache_trait_objects() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {
            fn as_any(&self) -> &dyn Any;
        }

        impl MyTrait for MyType {
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        let item_dyn: Arc<dyn MyTrait> = Arc::new(MyType(42));

        let cache = LanceCache::with_capacity(1000);
        cache.insert_unsized("test", item_dyn).await;

        let fetched = cache.get_unsized::<dyn MyTrait>("test").await.unwrap();
        let concrete = fetched.as_any().downcast_ref::<MyType>().unwrap();
        assert_eq!(concrete.0, 42);
    }

    #[tokio::test]
    async fn test_cache_stats_basic() {
        let cache = LanceCache::with_capacity(1000);

        // A fresh cache has recorded nothing.
        assert_counts(&cache, 0, 0).await;

        // Looking up a missing key counts as a miss.
        assert!(cache.get::<Vec<i32>>("nonexistent").await.is_none());
        assert_counts(&cache, 0, 1).await;

        // Once inserted, a lookup counts as a hit.
        cache.insert("key1", Arc::new(vec![1, 2, 3])).await;
        assert!(cache.get::<Vec<i32>>("key1").await.is_some());
        assert_counts(&cache, 1, 1).await;

        // A second lookup of the same key is another hit.
        assert!(cache.get::<Vec<i32>>("key1").await.is_some());
        assert_counts(&cache, 2, 1).await;

        // A different missing key is another miss.
        assert!(cache.get::<Vec<i32>>("nonexistent2").await.is_none());
        assert_counts(&cache, 2, 2).await;
    }

    #[tokio::test]
    async fn test_cache_stats_with_prefixes() {
        let base_cache = LanceCache::with_capacity(1000);
        let prefixed_cache = base_cache.with_key_prefix("test");

        // The base and prefixed handles report the same counters.
        assert_counts(&base_cache, 0, 0).await;
        assert_counts(&prefixed_cache, 0, 0).await;

        // A miss through the prefixed handle...
        assert!(prefixed_cache.get::<Vec<i32>>("key1").await.is_none());

        // ...is visible through both handles.
        assert_counts(&base_cache, 0, 1).await;
        assert_counts(&prefixed_cache, 0, 1).await;

        // Insert and hit through the prefixed handle...
        prefixed_cache.insert("key1", Arc::new(vec![1, 2, 3])).await;
        assert!(prefixed_cache.get::<Vec<i32>>("key1").await.is_some());

        // ...and the hit is visible through both handles too.
        assert_counts(&base_cache, 1, 1).await;
        assert_counts(&prefixed_cache, 1, 1).await;
    }

    #[tokio::test]
    async fn test_cache_stats_unsized() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {}

        impl MyTrait for MyType {}

        let cache = LanceCache::with_capacity(1000);

        // An unsized lookup of a missing key counts as a miss.
        assert!(cache.get_unsized::<dyn MyTrait>("test").await.is_none());
        assert_counts(&cache, 0, 1).await;

        // After inserting a trait object, the lookup counts as a hit.
        let item_dyn: Arc<dyn MyTrait> = Arc::new(MyType(42));
        cache.insert_unsized("test", item_dyn).await;

        assert!(cache.get_unsized::<dyn MyTrait>("test").await.is_some());
        assert_counts(&cache, 1, 1).await;
    }

    #[tokio::test]
    async fn test_cache_stats_get_or_insert() {
        let cache = LanceCache::with_capacity(1000);

        // The first call runs the loader and is recorded as a miss.
        let loaded: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async { Ok(vec![1, 2, 3]) })
            .await
            .unwrap();
        assert_eq!(*loaded, vec![1, 2, 3]);
        assert_counts(&cache, 0, 1).await;

        // The second call is served from the cache; the loader must not run.
        let cached: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async {
                panic!("Should not be called")
            })
            .await
            .unwrap();
        assert_eq!(*cached, vec![1, 2, 3]);
        assert_counts(&cache, 1, 1).await;

        // A different key runs its loader and is recorded as another miss.
        let other: Arc<Vec<i32>> = cache
            .get_or_insert("key2".to_string(), |_key| async { Ok(vec![4, 5, 6]) })
            .await
            .unwrap();
        assert_eq!(*other, vec![4, 5, 6]);
        assert_counts(&cache, 1, 2).await;
    }
}