// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Lance cache system.
//!
//! ## For cache users
//!
//! Use [`LanceCache`] (or [`WeakLanceCache`]) to store and retrieve typed
//! values. Define a [`CacheKey`] (or [`UnsizedCacheKey`] for trait objects) to
//! describe what you're caching and its type.
//!
//! To make a value type serializable (so persistent backends can store it),
//! implement [`CacheCodecImpl`] on the type, then override [`CacheKey::codec`]:
//!
//! ```ignore
//! impl CacheCodecImpl for MyData {
//!     fn serialize(&self, w: &mut dyn Write) -> Result<()> { /* ... */ }
//!     fn deserialize(data: &Bytes) -> Result<Self> { /* ... */ }
//! }
//!
//! impl CacheKey for MyDataKey {
//!     type ValueType = MyData;
//!     fn key(&self) -> Cow<'_, str> { /* ... */ }
//!     fn type_name() -> &'static str { "MyData" }
//!     fn codec() -> Option<CacheCodec> {
//!         Some(CacheCodec::from_impl::<MyData>())
//!     }
//! }
//! ```
//!
//! ## For backend implementors
//!
//! Implement [`CacheBackend`] to provide a custom storage layer (disk, Redis,
//! etc.). Backends receive [`InternalCacheKey`] keys and type-erased
//! [`CacheEntry`] values — the typed wrapping is handled by [`LanceCache`].
//! See the [`backend`] module for details.
//!
//! ## Serialization flow
//!
//! When a [`CacheKey`] provides a codec via [`CacheKey::codec`]:
//!
//! 1. [`LanceCache`] wraps the [`CacheCodec`] and passes it to the backend
//!    alongside the entry on `insert` and `get` calls.
//! 2. In-memory backends (like [`MokaCacheBackend`]) ignore the codec.
//! 3. Persistent backends use `codec.serialize(entry, writer)` on insert and
//!    `codec.deserialize(reader)` on get to persist entries across restarts.
pub mod backend;
pub mod codec;
mod moka;

pub use backend::{CacheBackend, CacheEntry, InternalCacheKey};
pub use codec::{CacheCodec, CacheCodecImpl};
pub use moka::MokaCacheBackend;

use std::borrow::Cow;
use std::sync::{
    Arc,
    atomic::{AtomicU64, Ordering},
};

use futures::{Future, FutureExt};

use crate::Result;

pub use deepsize::{Context, DeepSizeOf};
// ---------------------------------------------------------------------------
// CacheKey / UnsizedCacheKey — typed key traits for cache users
// ---------------------------------------------------------------------------
72/// Typed cache key for sized value types.
73///
74/// Implement this trait to define a new type of cached entry. [`LanceCache`]
75/// uses the key string and type name to construct an [`InternalCacheKey`]
76/// for the backend.
77///
78/// # Example
79///
80/// ```ignore
81/// struct MyKey { id: u64 }
82///
83/// impl CacheKey for MyKey {
84///     type ValueType = MyData;
85///     fn key(&self) -> Cow<'_, str> { self.id.to_string().into() }
86///     fn type_name() -> &'static str { "MyData" }
87/// }
88/// ```
89pub trait CacheKey {
90    type ValueType: 'static;
91
92    fn key(&self) -> Cow<'_, str>;
93
94    /// Short, stable string identifying this value type.
95    ///
96    /// Two `CacheKey` impls that store different `ValueType`s **must** return
97    /// different type names; if they collide, gets will silently return `None`
98    /// due to failed downcasts.
99    ///
100    /// Use a short literal (e.g. `"Vec<IndexMetadata>"`), not
101    /// `std::any::type_name` — the latter is not guaranteed stable across
102    /// compiler versions or build configurations.
103    fn type_name() -> &'static str;
104
105    /// Optional codec for serializing/deserializing this key's value type.
106    ///
107    /// Returns `None` by default. Cache backends that support persistence
108    /// (e.g. disk-backed caches) use this to serialize entries on insert and
109    /// deserialize on get. Types without a codec will only be stored in-memory.
110    ///
111    /// [`CacheCodec`] is `Copy` (two plain function pointers), so returning it
112    /// by value is cheap — no allocation needed.
113    fn codec() -> Option<CacheCodec> {
114        None
115    }
116}
117
/// Like [`CacheKey`] but for unsized value types (e.g. `dyn Trait`).
///
/// The cache wraps values in an extra `Arc` layer internally; callers pass
/// and receive `Arc<T>` where `T: ?Sized`.
///
/// Unsized cache entries are always in-memory only (no serialization codec).
/// For serializable entries, use a sized [`CacheKey`] instead.
pub trait UnsizedCacheKey {
    /// The (possibly unsized) value type stored under this key.
    type ValueType: 'static + ?Sized;

    /// The key string identifying the cached entry (within this type's namespace).
    fn key(&self) -> Cow<'_, str>;

    /// Short, stable string identifying this value type.
    /// See [`CacheKey::type_name`] for requirements.
    fn type_name() -> &'static str;
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

139/// Size of a cached `Arc<T>`, accounting for the Arc overhead (two atomic counters).
140fn cache_entry_size<T: DeepSizeOf + ?Sized>(value: &T) -> usize {
141    value.deep_size_of() + std::mem::size_of::<std::sync::atomic::AtomicUsize>() * 2
142}
143
144/// Build an [`InternalCacheKey`] from a cache's prefix, a user key string,
145/// and a type name.
146fn build_key(prefix: &Arc<str>, key: &str, type_name: &'static str) -> InternalCacheKey {
147    InternalCacheKey::new(prefix.clone(), Arc::from(key), type_name)
148}
149
// ---------------------------------------------------------------------------
// LanceCache — typed wrapper around dyn CacheBackend
// ---------------------------------------------------------------------------

154/// Typed cache wrapper that handles key construction and type safety.
155///
156/// Internally delegates to a [`CacheBackend`]. The default backend is
157/// [`MokaCacheBackend`]; pass a custom backend via [`LanceCache::with_backend`].
158#[derive(Clone)]
159pub struct LanceCache {
160    cache: Arc<dyn CacheBackend>,
161    prefix: Arc<str>,
162    hits: Arc<AtomicU64>,
163    misses: Arc<AtomicU64>,
164}
165
166impl std::fmt::Debug for LanceCache {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        f.debug_struct("LanceCache")
169            .field("cache", &self.cache)
170            .finish()
171    }
172}
173
174impl DeepSizeOf for LanceCache {
175    fn deep_size_of_children(&self, _: &mut Context) -> usize {
176        self.cache.approx_size_bytes()
177    }
178}
179
180impl LanceCache {
181    pub fn with_capacity(capacity: usize) -> Self {
182        Self {
183            cache: Arc::new(MokaCacheBackend::with_capacity(capacity)),
184            prefix: Arc::from(""),
185            hits: Arc::new(AtomicU64::new(0)),
186            misses: Arc::new(AtomicU64::new(0)),
187        }
188    }
189
190    /// Create a cache backed by a custom [`CacheBackend`].
191    pub fn with_backend(backend: Arc<dyn CacheBackend>) -> Self {
192        Self {
193            cache: backend,
194            prefix: Arc::from(""),
195            hits: Arc::new(AtomicU64::new(0)),
196            misses: Arc::new(AtomicU64::new(0)),
197        }
198    }
199
200    pub fn no_cache() -> Self {
201        Self {
202            cache: Arc::new(MokaCacheBackend::no_cache()),
203            prefix: Arc::from(""),
204            hits: Arc::new(AtomicU64::new(0)),
205            misses: Arc::new(AtomicU64::new(0)),
206        }
207    }
208
209    /// Create a cache with the given backend and an exact prefix string.
210    /// Unlike `with_key_prefix`, this sets the prefix verbatim (no trailing slash added).
211    pub fn with_backend_and_prefix(backend: Arc<dyn CacheBackend>, prefix: String) -> Self {
212        Self {
213            cache: backend,
214            prefix: Arc::from(prefix),
215            hits: Arc::new(AtomicU64::new(0)),
216            misses: Arc::new(AtomicU64::new(0)),
217        }
218    }
219
220    /// Appends a prefix to the cache key.
221    pub fn with_key_prefix(&self, prefix: &str) -> Self {
222        Self {
223            cache: self.cache.clone(),
224            prefix: Arc::from(format!("{}{}/", self.prefix, prefix)),
225            hits: self.hits.clone(),
226            misses: self.misses.clone(),
227        }
228    }
229
230    /// Invalidate all entries whose prefix starts with the given string.
231    pub async fn invalidate_prefix(&self, prefix: &str) {
232        let full_prefix = format!("{}{}", self.prefix, prefix);
233        self.cache.invalidate_prefix(&full_prefix).await;
234    }
235
236    pub async fn size(&self) -> usize {
237        self.cache.num_entries().await
238    }
239
240    pub fn approx_size(&self) -> usize {
241        self.cache.approx_num_entries()
242    }
243
244    pub async fn size_bytes(&self) -> usize {
245        self.cache.size_bytes().await
246    }
247
248    // -- Sized insert/get (internal, shared by sized and unsized paths) --------
249
250    async fn insert_with_id<T: DeepSizeOf + Send + Sync + 'static>(
251        &self,
252        key: &str,
253        type_name: &'static str,
254        codec: Option<CacheCodec>,
255        metadata: Arc<T>,
256    ) {
257        let size = cache_entry_size(&*metadata);
258        let cache_key = build_key(&self.prefix, key, type_name);
259        self.cache.insert(&cache_key, metadata, size, codec).await;
260    }
261
262    async fn get_with_id<T: Send + Sync + 'static>(
263        &self,
264        key: &str,
265        type_name: &'static str,
266        codec: Option<CacheCodec>,
267    ) -> Option<Arc<T>> {
268        let cache_key = build_key(&self.prefix, key, type_name);
269        if let Some(entry) = self.cache.get(&cache_key, codec).await {
270            match entry.downcast::<T>() {
271                Ok(val) => {
272                    self.hits.fetch_add(1, Ordering::Relaxed);
273                    Some(val)
274                }
275                Err(_) => {
276                    // Type mismatch: the backend returned a different concrete
277                    // type than expected (e.g. a disk cache may store
278                    // intermediate state). Treat as a miss.
279                    self.misses.fetch_add(1, Ordering::Relaxed);
280                    None
281                }
282            }
283        } else {
284            self.misses.fetch_add(1, Ordering::Relaxed);
285            None
286        }
287    }
288
289    // -- Stats / clear --------------------------------------------------------
290
291    pub async fn stats(&self) -> CacheStats {
292        CacheStats {
293            hits: self.hits.load(Ordering::Relaxed),
294            misses: self.misses.load(Ordering::Relaxed),
295            num_entries: self.cache.num_entries().await,
296            size_bytes: self.cache.size_bytes().await,
297        }
298    }
299
300    pub async fn clear(&self) {
301        self.cache.clear().await;
302        self.hits.store(0, Ordering::Relaxed);
303        self.misses.store(0, Ordering::Relaxed);
304    }
305
306    // -- CacheKey-based methods -----------------------------------------------
307
308    pub async fn insert_with_key<K>(&self, cache_key: &K, metadata: Arc<K::ValueType>)
309    where
310        K: CacheKey,
311        K::ValueType: DeepSizeOf + Send + Sync + 'static,
312    {
313        self.insert_with_id(&cache_key.key(), K::type_name(), K::codec(), metadata)
314            .boxed()
315            .await
316    }
317
318    pub async fn get_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
319    where
320        K: CacheKey,
321        K::ValueType: DeepSizeOf + Send + Sync + 'static,
322    {
323        self.get_with_id::<K::ValueType>(&cache_key.key(), K::type_name(), K::codec())
324            .boxed()
325            .await
326    }
327
328    pub async fn get_or_insert_with_key<K, F, Fut>(
329        &self,
330        cache_key: K,
331        loader: F,
332    ) -> Result<Arc<K::ValueType>>
333    where
334        K: CacheKey,
335        K::ValueType: DeepSizeOf + Send + Sync + 'static,
336        F: FnOnce() -> Fut + Send,
337        Fut: Future<Output = Result<K::ValueType>> + Send,
338    {
339        let key = build_key(&self.prefix, &cache_key.key(), K::type_name());
340
341        let typed_loader = Box::pin(async move {
342            let value = loader().await?;
343            let arc = Arc::new(value);
344            let size = cache_entry_size(&*arc);
345            Ok((arc as CacheEntry, size))
346        });
347
348        let (entry, was_cached) = self
349            .cache
350            .get_or_insert(&key, typed_loader, K::codec())
351            .await?;
352
353        if was_cached {
354            self.hits.fetch_add(1, Ordering::Relaxed);
355        } else {
356            self.misses.fetch_add(1, Ordering::Relaxed);
357        }
358
359        Ok(entry.downcast::<K::ValueType>().unwrap())
360    }
361
362    pub async fn insert_unsized_with_key<K>(&self, cache_key: &K, metadata: Arc<K::ValueType>)
363    where
364        K: UnsizedCacheKey,
365        K::ValueType: DeepSizeOf + Send + Sync + 'static,
366    {
367        self.insert_with_id(&cache_key.key(), K::type_name(), None, Arc::new(metadata))
368            .boxed()
369            .await
370    }
371
372    pub async fn get_unsized_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
373    where
374        K: UnsizedCacheKey,
375        K::ValueType: DeepSizeOf + Send + Sync + 'static,
376    {
377        let outer = self
378            .get_with_id::<Arc<K::ValueType>>(&cache_key.key(), K::type_name(), None)
379            .boxed()
380            .await?;
381        Some(outer.as_ref().clone())
382    }
383}
384
// ---------------------------------------------------------------------------
// WeakLanceCache
// ---------------------------------------------------------------------------

389/// A weak reference to a LanceCache, used by indices to avoid circular references.
390/// When the original cache is dropped, operations on this will gracefully no-op.
391#[derive(Clone, Debug)]
392pub struct WeakLanceCache {
393    inner: std::sync::Weak<dyn CacheBackend>,
394    prefix: Arc<str>,
395    hits: Arc<AtomicU64>,
396    misses: Arc<AtomicU64>,
397}
398
399impl WeakLanceCache {
400    pub fn from(cache: &LanceCache) -> Self {
401        Self {
402            inner: Arc::downgrade(&cache.cache),
403            prefix: cache.prefix.clone(),
404            hits: cache.hits.clone(),
405            misses: cache.misses.clone(),
406        }
407    }
408
409    pub fn with_key_prefix(&self, prefix: &str) -> Self {
410        Self {
411            inner: self.inner.clone(),
412            prefix: Arc::from(format!("{}{}/", self.prefix, prefix)),
413            hits: self.hits.clone(),
414            misses: self.misses.clone(),
415        }
416    }
417
418    /// The key prefix used for all entries in this cache.
419    pub fn prefix(&self) -> &str {
420        &self.prefix
421    }
422
423    pub async fn get_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
424    where
425        K: CacheKey,
426        K::ValueType: DeepSizeOf + Send + Sync + 'static,
427    {
428        let cache = self.inner.upgrade()?;
429        let key = build_key(&self.prefix, &cache_key.key(), K::type_name());
430        if let Some(entry) = cache.get(&key, K::codec()).await {
431            self.hits.fetch_add(1, Ordering::Relaxed);
432            Some(entry.downcast::<K::ValueType>().unwrap())
433        } else {
434            self.misses.fetch_add(1, Ordering::Relaxed);
435            None
436        }
437    }
438
439    pub async fn insert_with_key<K>(&self, cache_key: &K, value: Arc<K::ValueType>) -> bool
440    where
441        K: CacheKey,
442        K::ValueType: DeepSizeOf + Send + Sync + 'static,
443    {
444        if let Some(cache) = self.inner.upgrade() {
445            let size = cache_entry_size(&*value);
446            let key = build_key(&self.prefix, &cache_key.key(), K::type_name());
447            cache.insert(&key, value, size, K::codec()).await;
448            true
449        } else {
450            log::warn!("WeakLanceCache: cache no longer available, unable to insert item");
451            false
452        }
453    }
454
455    /// Get or insert an item, computing it if necessary.
456    ///
457    /// Deduplication of concurrent loads is handled by the backend.
458    pub async fn get_or_insert_with_key<K, F, Fut>(
459        &self,
460        cache_key: K,
461        loader: F,
462    ) -> Result<Arc<K::ValueType>>
463    where
464        K: CacheKey,
465        K::ValueType: DeepSizeOf + Send + Sync + 'static,
466        F: FnOnce() -> Fut + Send,
467        Fut: Future<Output = Result<K::ValueType>> + Send,
468    {
469        if let Some(cache) = self.inner.upgrade() {
470            let key = build_key(&self.prefix, &cache_key.key(), K::type_name());
471            let typed_loader = Box::pin(async move {
472                let value = loader().await?;
473                let arc = Arc::new(value);
474                let size = cache_entry_size(&*arc);
475                Ok((arc as CacheEntry, size))
476            });
477            let (entry, was_cached) = cache.get_or_insert(&key, typed_loader, K::codec()).await?;
478            if was_cached {
479                self.hits.fetch_add(1, Ordering::Relaxed);
480            } else {
481                self.misses.fetch_add(1, Ordering::Relaxed);
482            }
483            Ok(entry.downcast::<K::ValueType>().unwrap())
484        } else {
485            log::warn!("WeakLanceCache: cache no longer available, computing without caching");
486            loader().await.map(Arc::new)
487        }
488    }
489
490    pub async fn get_unsized_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
491    where
492        K: UnsizedCacheKey,
493        K::ValueType: DeepSizeOf + Send + Sync + 'static,
494    {
495        let cache = self.inner.upgrade()?;
496        let key = build_key(&self.prefix, &cache_key.key(), K::type_name());
497        if let Some(entry) = cache.get(&key, None).await {
498            entry
499                .downcast::<Arc<K::ValueType>>()
500                .ok()
501                .map(|arc| arc.as_ref().clone())
502        } else {
503            None
504        }
505    }
506
507    pub async fn insert_unsized_with_key<K>(&self, cache_key: &K, value: Arc<K::ValueType>)
508    where
509        K: UnsizedCacheKey,
510        K::ValueType: DeepSizeOf + Send + Sync + 'static,
511    {
512        if let Some(cache) = self.inner.upgrade() {
513            let wrapper = Arc::new(value);
514            let size = cache_entry_size(&*wrapper);
515            let key = build_key(&self.prefix, &cache_key.key(), K::type_name());
516            cache.insert(&key, wrapper, size, None).await;
517        } else {
518            log::warn!("WeakLanceCache: cache no longer available, unable to insert unsized item");
519        }
520    }
521}
522
// ---------------------------------------------------------------------------
// CacheStats
// ---------------------------------------------------------------------------

/// Snapshot of cache counters, as returned by `LanceCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of times `get`, `get_unsized`, or `get_or_insert` found an item in the cache.
    pub hits: u64,
    /// Number of times `get`, `get_unsized`, or `get_or_insert` did not find an item in the cache.
    pub misses: u64,
    /// Number of entries currently in the cache.
    pub num_entries: usize,
    /// Total size in bytes of all entries in the cache.
    pub size_bytes: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits; `0.0` when no lookups were recorded.
    pub fn hit_ratio(&self) -> f32 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f32 / total as f32
        }
    }

    /// Fraction of lookups that were misses; `0.0` when no lookups were recorded.
    pub fn miss_ratio(&self) -> f32 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.misses as f32 / total as f32
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::marker::PhantomData;

    /// Test helper: a CacheKey parameterized on its value type, using
    /// `std::any::type_name` as the type name (acceptable in tests only).
    struct TestKey<T: 'static> {
        key: String,
        _phantom: PhantomData<T>,
    }

    impl<T: 'static> TestKey<T> {
        fn new(key: &str) -> Self {
            Self {
                key: key.to_string(),
                _phantom: PhantomData,
            }
        }
    }

    impl<T: 'static> CacheKey for TestKey<T> {
        type ValueType = T;
        fn key(&self) -> std::borrow::Cow<'_, str> {
            std::borrow::Cow::Borrowed(&self.key)
        }
        fn type_name() -> &'static str {
            std::any::type_name::<T>()
        }
    }

    /// Test helper: an UnsizedCacheKey for trait object values.
    struct TestUnsizedKey<T: 'static + ?Sized> {
        key: String,
        _phantom: PhantomData<T>,
    }

    impl<T: 'static + ?Sized> TestUnsizedKey<T> {
        fn new(key: &str) -> Self {
            Self {
                key: key.to_string(),
                _phantom: PhantomData,
            }
        }
    }

    impl<T: 'static + ?Sized> UnsizedCacheKey for TestUnsizedKey<T> {
        type ValueType = T;
        fn key(&self) -> std::borrow::Cow<'_, str> {
            std::borrow::Cow::Borrowed(&self.key)
        }
        fn type_name() -> &'static str {
            std::any::type_name::<T>()
        }
    }

    #[tokio::test]
    async fn test_cache_bytes() {
        let item = Arc::new(vec![1, 2, 3]);
        let item_size = item.deep_size_of();
        let capacity = 10 * item_size;
        let cache = LanceCache::with_capacity(capacity);

        cache
            .insert_with_key(&TestKey::<Vec<i32>>::new("key"), item.clone())
            .await;
        assert_eq!(cache.size().await, 1);

        let retrieved = cache
            .get_with_key(&TestKey::<Vec<i32>>::new("key"))
            .await
            .unwrap();
        assert_eq!(*retrieved, *item);

        // Overfill the cache and verify the capacity bound is enforced.
        for i in 0..20 {
            cache
                .insert_with_key(
                    &TestKey::<Vec<i32>>::new(&format!("key_{}", i)),
                    Arc::new(vec![i, i, i]),
                )
                .await;
        }
        assert!(cache.size_bytes().await <= capacity);
    }

    #[tokio::test]
    async fn test_cache_trait_objects() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + std::any::Any {
            fn as_any(&self) -> &dyn std::any::Any;
        }

        impl MyTrait for MyType {
            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }

        let item: Arc<dyn MyTrait> = Arc::new(MyType(42));
        let cache = LanceCache::with_capacity(1000);
        cache
            .insert_unsized_with_key(&TestUnsizedKey::<dyn MyTrait>::new("test"), item)
            .await;

        let retrieved = cache
            .get_unsized_with_key(&TestUnsizedKey::<dyn MyTrait>::new("test"))
            .await
            .unwrap();
        assert_eq!(retrieved.as_any().downcast_ref::<MyType>().unwrap().0, 42);
    }

    #[tokio::test]
    async fn test_cache_stats_basic() {
        let cache = LanceCache::with_capacity(1000);
        assert_eq!(cache.stats().await.hits, 0);

        // Miss
        assert!(
            cache
                .get_with_key(&TestKey::<Vec<i32>>::new("x"))
                .await
                .is_none()
        );
        assert_eq!(cache.stats().await.misses, 1);

        // Insert then hit
        cache
            .insert_with_key(&TestKey::new("k"), Arc::new(vec![1, 2, 3]))
            .await;
        assert!(
            cache
                .get_with_key(&TestKey::<Vec<i32>>::new("k"))
                .await
                .is_some()
        );
        assert_eq!(cache.stats().await.hits, 1);
    }

    #[tokio::test]
    async fn test_cache_stats_with_prefixes() {
        let base = LanceCache::with_capacity(1000);
        let prefixed = base.with_key_prefix("ns");

        // Stats counters are shared between base and prefixed views.
        assert!(
            prefixed
                .get_with_key(&TestKey::<Vec<i32>>::new("k"))
                .await
                .is_none()
        );
        assert_eq!(base.stats().await.misses, 1);

        prefixed
            .insert_with_key(&TestKey::new("k"), Arc::new(vec![1]))
            .await;
        assert!(
            prefixed
                .get_with_key(&TestKey::<Vec<i32>>::new("k"))
                .await
                .is_some()
        );
        assert_eq!(base.stats().await.hits, 1);
    }

    #[tokio::test]
    async fn test_cache_get_or_insert() {
        let cache = LanceCache::with_capacity(1000);

        let v: Arc<Vec<i32>> = cache
            .get_or_insert_with_key(TestKey::<Vec<i32>>::new("k"), || async {
                Ok(vec![1, 2, 3])
            })
            .await
            .unwrap();
        assert_eq!(*v, vec![1, 2, 3]);
        assert_eq!(cache.stats().await.misses, 1);
        assert_eq!(cache.stats().await.hits, 0);

        // Second call should not invoke loader and should be a hit
        let v: Arc<Vec<i32>> = cache
            .get_or_insert_with_key(TestKey::<Vec<i32>>::new("k"), || async {
                panic!("should not be called")
            })
            .await
            .unwrap();
        assert_eq!(*v, vec![1, 2, 3]);
        assert_eq!(cache.stats().await.hits, 1);
    }

    #[tokio::test]
    async fn test_custom_backend() {
        use async_trait::async_trait;
        use tokio::sync::Mutex;

        /// Minimal backend: an unbounded HashMap guarded by an async mutex.
        #[derive(Debug)]
        struct HashMapBackend {
            map: Mutex<HashMap<InternalCacheKey, (CacheEntry, usize)>>,
        }

        impl HashMapBackend {
            fn new() -> Self {
                Self {
                    map: Mutex::new(HashMap::new()),
                }
            }
        }

        #[async_trait]
        impl CacheBackend for HashMapBackend {
            async fn get(
                &self,
                key: &InternalCacheKey,
                _codec: Option<CacheCodec>,
            ) -> Option<CacheEntry> {
                self.map.lock().await.get(key).map(|(e, _)| e.clone())
            }
            async fn insert(
                &self,
                key: &InternalCacheKey,
                entry: CacheEntry,
                size_bytes: usize,
                _codec: Option<CacheCodec>,
            ) {
                self.map
                    .lock()
                    .await
                    .insert(key.clone(), (entry, size_bytes));
            }
            async fn get_or_insert<'a>(
                &self,
                key: &InternalCacheKey,
                loader: std::pin::Pin<
                    Box<dyn futures::Future<Output = Result<(CacheEntry, usize)>> + Send + 'a>,
                >,
                _codec: Option<CacheCodec>,
            ) -> Result<(CacheEntry, bool)> {
                if let Some((entry, _)) = self.map.lock().await.get(key) {
                    Ok((entry.clone(), true))
                } else {
                    let (entry, size) = loader.await?;
                    self.map
                        .lock()
                        .await
                        .insert(key.clone(), (entry.clone(), size));
                    Ok((entry, false))
                }
            }
            async fn invalidate_prefix(&self, prefix: &str) {
                self.map.lock().await.retain(|k, _| !k.starts_with(prefix));
            }
            async fn clear(&self) {
                self.map.lock().await.clear();
            }
            async fn num_entries(&self) -> usize {
                self.map.lock().await.len()
            }
            async fn size_bytes(&self) -> usize {
                self.map.lock().await.values().map(|(_, s)| *s).sum()
            }
        }

        let cache = LanceCache::with_backend(Arc::new(HashMapBackend::new()));

        cache
            .insert_with_key(&TestKey::new("k"), Arc::new(vec![1, 2, 3]))
            .await;
        assert!(
            cache
                .get_with_key(&TestKey::<Vec<i32>>::new("k"))
                .await
                .is_some()
        );
        // Different type at same key = miss
        assert!(
            cache
                .get_with_key(&TestKey::<Vec<u8>>::new("k"))
                .await
                .is_none()
        );
    }

    #[tokio::test]
    async fn test_get_or_insert_dedup() {
        use std::sync::atomic::AtomicUsize;

        let load_count = Arc::new(AtomicUsize::new(0));
        let cache = LanceCache::with_capacity(10000);

        // Release all tasks at once to maximize contention on the same key.
        let (barrier_tx, _) = tokio::sync::broadcast::channel::<()>(1);
        let mut handles = Vec::new();
        for _ in 0..5 {
            let cache = cache.clone();
            let load_count = load_count.clone();
            let mut barrier_rx = barrier_tx.subscribe();
            handles.push(tokio::spawn(async move {
                barrier_rx.recv().await.ok();
                cache
                    .get_or_insert_with_key(TestKey::<Vec<i32>>::new("key"), || {
                        let load_count = load_count.clone();
                        async move {
                            load_count.fetch_add(1, Ordering::SeqCst);
                            tokio::task::yield_now().await;
                            Ok(vec![1, 2, 3])
                        }
                    })
                    .await
            }));
        }
        barrier_tx.send(()).unwrap();
        for h in handles {
            let result: Arc<Vec<i32>> = h.await.unwrap().unwrap();
            assert_eq!(*result, vec![1, 2, 3]);
        }

        // The backend must have deduplicated concurrent loads of the same key.
        assert_eq!(load_count.load(Ordering::SeqCst), 1);
    }
}