Skip to main content

lance_core/cache/
moka.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7
8use async_trait::async_trait;
9use futures::Future;
10
11use crate::Result;
12
13use super::CacheCodec;
14use super::backend::{CacheBackend, CacheEntry, InternalCacheKey};
15
16/// Internal record stored in the moka cache.
17#[derive(Clone, Debug)]
18struct MokaCacheEntry {
19    entry: CacheEntry,
20    size_bytes: usize,
21}
22
23/// Default [`CacheBackend`] backed by a [moka](https://crates.io/crates/moka) cache.
24///
25/// Provides weighted-capacity eviction and concurrent-load deduplication
26/// via moka's built-in `optionally_get_with`.
27pub struct MokaCacheBackend {
28    cache: moka::future::Cache<InternalCacheKey, MokaCacheEntry>,
29}
30
31impl std::fmt::Debug for MokaCacheBackend {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("MokaCacheBackend")
34            .field("entry_count", &self.cache.entry_count())
35            .finish()
36    }
37}
38
39impl MokaCacheBackend {
40    pub fn with_capacity(capacity: usize) -> Self {
41        let cache = moka::future::Cache::builder()
42            .max_capacity(capacity as u64)
43            .weigher(|_, v: &MokaCacheEntry| v.size_bytes.try_into().unwrap_or(u32::MAX))
44            .support_invalidation_closures()
45            .build();
46        Self { cache }
47    }
48
49    pub fn no_cache() -> Self {
50        Self {
51            cache: moka::future::Cache::new(0),
52        }
53    }
54}
55
#[async_trait]
impl CacheBackend for MokaCacheBackend {
    /// Look up `key`. The codec is unused: values live in-process as
    /// [`CacheEntry`] and are never serialized by this backend.
    async fn get(&self, key: &InternalCacheKey, _codec: Option<CacheCodec>) -> Option<CacheEntry> {
        self.cache.get(key).await.map(|r| r.entry)
    }

    /// Insert `entry` under `key`, recording `size_bytes` so the weigher
    /// can charge the entry against the cache's weighted capacity.
    async fn insert(
        &self,
        key: &InternalCacheKey,
        entry: CacheEntry,
        size_bytes: usize,
        _codec: Option<CacheCodec>,
    ) {
        self.cache
            .insert(key.clone(), MokaCacheEntry { entry, size_bytes })
            .await;
    }

    /// Return the entry for `key`, running `loader` to produce it on a miss.
    ///
    /// The returned bool is `true` when this call's loader did NOT run —
    /// i.e. the value was already cached, or was produced by a concurrent
    /// caller's loader that moka deduplicated us against.
    ///
    /// # Errors
    /// Propagates the loader's error on failure. If a *different* caller's
    /// loader failed (so no error ever reached our channel), a generic
    /// internal error is returned instead.
    async fn get_or_insert<'a>(
        &self,
        key: &InternalCacheKey,
        loader: Pin<Box<dyn Future<Output = Result<(CacheEntry, usize)>> + Send + 'a>>,
        _codec: Option<CacheCodec>,
    ) -> Result<(CacheEntry, bool)> {
        // Use moka's built-in dedup: optionally_get_with runs the init future
        // at most once per key, even under concurrent access.
        //
        // The init future cannot return an error directly (it must yield
        // Option<V>), so a failing loader smuggles its error out through a
        // oneshot channel and maps the cache value to None.
        let (error_tx, error_rx) = tokio::sync::oneshot::channel();

        // Track whether the loader actually ran (= cache miss).
        let was_miss = Arc::new(AtomicBool::new(false));
        let was_miss_clone = was_miss.clone();

        let init = async move {
            // Flag the miss before awaiting, so it is recorded even when the
            // loader itself ends up failing.
            was_miss_clone.store(true, Ordering::Relaxed);
            match loader.await {
                Ok((entry, size_bytes)) => Some(MokaCacheEntry { entry, size_bytes }),
                Err(e) => {
                    // Send can only fail if the receiver side was dropped;
                    // nothing useful to do with the error in that case.
                    let _ = error_tx.send(e);
                    None
                }
            }
        };

        let owned_key = key.clone();
        match self.cache.optionally_get_with(owned_key, init).await {
            Some(record) => {
                // If our init future never ran, the value came from the cache
                // (or from another caller's deduplicated load).
                let was_cached = !was_miss.load(Ordering::Relaxed);
                Ok((record.entry, was_cached))
            }
            // None means some loader failed. Our own error arrives on
            // error_rx; if a concurrent caller's loader was the one that
            // failed, our error_tx was dropped unused and recv fails, so we
            // fall back to a generic internal error.
            None => match error_rx.await {
                Ok(err) => Err(err),
                Err(_) => Err(crate::Error::internal(
                    "Failed to retrieve error from cache loader",
                )),
            },
        }
    }

    /// Drop every entry whose key starts with `prefix`.
    ///
    /// Relies on `support_invalidation_closures` having been enabled when
    /// the cache was built; `invalidate_entries_if` returns an error
    /// otherwise, which the `expect` below turns into a panic.
    async fn invalidate_prefix(&self, prefix: &str) {
        let prefix = prefix.to_owned();
        self.cache
            .invalidate_entries_if(move |key, _value| key.starts_with(&prefix))
            .expect("Cache configured correctly");
    }

    /// Remove all entries, then flush moka's pending maintenance so the
    /// removal is reflected in subsequent count/size queries.
    async fn clear(&self) {
        self.cache.invalidate_all();
        self.cache.run_pending_tasks().await;
    }

    /// Exact entry count: flushes pending maintenance first so the counter
    /// is up to date.
    async fn num_entries(&self) -> usize {
        self.cache.run_pending_tasks().await;
        self.cache.entry_count() as usize
    }

    /// Exact weighted size (sum of reported `size_bytes`): flushes pending
    /// maintenance first so the total is up to date.
    async fn size_bytes(&self) -> usize {
        self.cache.run_pending_tasks().await;
        self.cache.weighted_size() as usize
    }

    /// Cheap, possibly stale entry count (no maintenance flush).
    fn approx_num_entries(&self) -> usize {
        self.cache.entry_count() as usize
    }

    /// Possibly inconsistent total size in bytes, computed by a full scan.
    fn approx_size_bytes(&self) -> usize {
        // Iterate rather than using `weighted_size()` because moka's
        // weighted_size can be stale without `run_pending_tasks()`, which
        // is async and can't be called from this synchronous context.
        self.cache.iter().map(|(_, v)| v.size_bytes).sum()
    }
}
142        // weighted_size can be stale without `run_pending_tasks()`, which
143        // is async and can't be called from this synchronous context.
144        self.cache.iter().map(|(_, v)| v.size_bytes).sum()
145    }
146}