1use std::any::{Any, TypeId};
7use std::borrow::Cow;
8use std::sync::{
9 atomic::{AtomicU64, Ordering},
10 Arc,
11};
12
13use futures::Future;
14use moka::sync::Cache;
15
16use crate::Result;
17
18pub use deepsize::{Context, DeepSizeOf};
19
/// Type-erased, thread-safe shared pointer used to store values of arbitrary
/// concrete types in a single cache.
type ArcAny = Arc<dyn Any + Send + Sync>;
21
22#[derive(Clone)]
23struct SizedRecord {
24 record: ArcAny,
25 size_accessor: Arc<dyn Fn(&ArcAny) -> usize + Send + Sync>,
26}
27
28impl std::fmt::Debug for SizedRecord {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("SizedRecord")
31 .field("record", &self.record)
32 .finish()
33 }
34}
35
36impl SizedRecord {
37 fn new<T: DeepSizeOf + Send + Sync + 'static>(record: Arc<T>) -> Self {
38 let size_accessor =
40 |record: &ArcAny| -> usize { record.downcast_ref::<T>().unwrap().deep_size_of() + 8 };
41 Self {
42 record,
43 size_accessor: Arc::new(size_accessor),
44 }
45 }
46}
47
48#[derive(Clone)]
49pub struct LanceCache {
50 cache: Arc<Cache<(String, TypeId), SizedRecord>>,
51 prefix: String,
52 hits: Arc<AtomicU64>,
53 misses: Arc<AtomicU64>,
54}
55
56impl std::fmt::Debug for LanceCache {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("LanceCache")
59 .field("cache", &self.cache)
60 .finish()
61 }
62}
63
64impl DeepSizeOf for LanceCache {
65 fn deep_size_of_children(&self, _: &mut Context) -> usize {
66 self.cache
67 .iter()
68 .map(|(_, v)| (v.size_accessor)(&v.record))
69 .sum()
70 }
71}
72
73impl LanceCache {
74 pub fn with_capacity(capacity: usize) -> Self {
75 let cache = Cache::builder()
76 .max_capacity(capacity as u64)
77 .weigher(|_, v: &SizedRecord| {
78 (v.size_accessor)(&v.record).try_into().unwrap_or(u32::MAX)
79 })
80 .support_invalidation_closures()
81 .build();
82 Self {
83 cache: Arc::new(cache),
84 prefix: String::new(),
85 hits: Arc::new(AtomicU64::new(0)),
86 misses: Arc::new(AtomicU64::new(0)),
87 }
88 }
89
90 pub fn no_cache() -> Self {
91 Self {
92 cache: Arc::new(Cache::new(0)),
93 prefix: String::new(),
94 hits: Arc::new(AtomicU64::new(0)),
95 misses: Arc::new(AtomicU64::new(0)),
96 }
97 }
98
99 pub fn with_key_prefix(&self, prefix: &str) -> Self {
107 Self {
108 cache: self.cache.clone(),
109 prefix: format!("{}{}/", self.prefix, prefix),
110 hits: self.hits.clone(),
111 misses: self.misses.clone(),
112 }
113 }
114
115 fn get_key(&self, key: &str) -> String {
116 if self.prefix.is_empty() {
117 key.to_string()
118 } else {
119 format!("{}/{}", self.prefix, key)
120 }
121 }
122
123 pub fn invalidate_prefix(&self, prefix: &str) {
128 let full_prefix = format!("{}{}", self.prefix, prefix);
129 self.cache
130 .invalidate_entries_if(move |(key, _typeid), _value| key.starts_with(&full_prefix))
131 .expect("Cache configured correctly");
132 }
133
134 pub fn size(&self) -> usize {
135 self.cache.run_pending_tasks();
136 self.cache.entry_count() as usize
137 }
138
139 pub fn approx_size(&self) -> usize {
140 self.cache.entry_count() as usize
141 }
142
143 pub fn size_bytes(&self) -> usize {
144 self.cache.run_pending_tasks();
145 self.approx_size_bytes()
146 }
147
148 pub fn approx_size_bytes(&self) -> usize {
149 self.cache.weighted_size() as usize
150 }
151
152 pub fn insert<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str, metadata: Arc<T>) {
153 let key = self.get_key(key);
154 let record = SizedRecord::new(metadata);
155 tracing::trace!(
156 target: "lance_cache::insert",
157 key = key,
158 type_id = std::any::type_name::<T>(),
159 size = (record.size_accessor)(&record.record),
160 );
161 self.cache.insert((key, TypeId::of::<T>()), record);
162 }
163
164 pub fn insert_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
165 &self,
166 key: &str,
167 metadata: Arc<T>,
168 ) {
169 self.insert(key, Arc::new(metadata))
171 }
172
173 pub fn get<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
174 let key = self.get_key(key);
175 if let Some(metadata) = self.cache.get(&(key, TypeId::of::<T>())) {
176 self.hits.fetch_add(1, Ordering::Relaxed);
177 Some(metadata.record.clone().downcast::<T>().unwrap())
178 } else {
179 self.misses.fetch_add(1, Ordering::Relaxed);
180 None
181 }
182 }
183
184 pub fn get_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
185 &self,
186 key: &str,
187 ) -> Option<Arc<T>> {
188 let outer = self.get::<Arc<T>>(key)?;
189 Some(outer.as_ref().clone())
190 }
191
192 pub async fn get_or_insert<T: DeepSizeOf + Send + Sync + 'static, F, Fut>(
198 &self,
199 key: String,
200 loader: F,
201 ) -> Result<Arc<T>>
202 where
203 F: FnOnce(&str) -> Fut,
204 Fut: Future<Output = Result<T>>,
205 {
206 let full_key = self.get_key(&key);
207 if let Some(metadata) = self.cache.get(&(full_key, TypeId::of::<T>())) {
208 self.hits.fetch_add(1, Ordering::Relaxed);
209 return Ok(metadata.record.clone().downcast::<T>().unwrap());
210 }
211
212 self.misses.fetch_add(1, Ordering::Relaxed);
213 let metadata = Arc::new(loader(&key).await?);
214 self.insert(&key, metadata.clone());
215 Ok(metadata)
216 }
217
218 pub fn stats(&self) -> CacheStats {
219 CacheStats {
220 hits: self.hits.load(Ordering::Relaxed),
221 misses: self.misses.load(Ordering::Relaxed),
222 }
223 }
224
225 pub fn insert_with_key<K>(&self, cache_key: &K, metadata: Arc<K::ValueType>)
227 where
228 K: CacheKey,
229 K::ValueType: DeepSizeOf + Send + Sync + 'static,
230 {
231 self.insert(&cache_key.key(), metadata)
232 }
233
234 pub fn get_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
235 where
236 K: CacheKey,
237 K::ValueType: DeepSizeOf + Send + Sync + 'static,
238 {
239 self.get::<K::ValueType>(&cache_key.key())
240 }
241
242 pub async fn get_or_insert_with_key<K, F, Fut>(
243 &self,
244 cache_key: K,
245 loader: F,
246 ) -> Result<Arc<K::ValueType>>
247 where
248 K: CacheKey,
249 K::ValueType: DeepSizeOf + Send + Sync + 'static,
250 F: FnOnce() -> Fut,
251 Fut: Future<Output = Result<K::ValueType>>,
252 {
253 let key_str = cache_key.key().into_owned();
254 self.get_or_insert(key_str, |_| loader()).await
255 }
256}
257
/// A typed cache key: ties a string key to the type of value stored under it.
pub trait CacheKey {
    /// The type of the value cached under this key.
    type ValueType;

    /// Returns the string form of this key, borrowed or owned.
    fn key(&self) -> Cow<'_, str>;
}
263
/// Snapshot of cache lookup counters.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts the hit/miss counters currently reported by `cache`.
    fn assert_stats(cache: &LanceCache, expected_hits: u64, expected_misses: u64) {
        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, expected_hits);
        assert_eq!(snapshot.misses, expected_misses);
    }

    /// Byte accounting: an empty cache reports zero, one entry reports the
    /// item's size, and overfilling settles at exactly the capacity.
    #[test]
    fn test_cache_bytes() {
        let probe = Arc::new(vec![1, 2, 3]);
        let item_size = probe.deep_size_of();
        let capacity = 10 * item_size;

        let cache = LanceCache::with_capacity(capacity);
        assert_eq!(cache.size_bytes(), 0);
        assert_eq!(cache.approx_size_bytes(), 0);

        let item = Arc::new(vec![1, 2, 3]);
        cache.insert("key", item.clone());
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.size_bytes(), item_size);
        assert_eq!(cache.approx_size_bytes(), item_size);

        let retrieved = cache.get::<Vec<i32>>("key").unwrap();
        assert_eq!(*retrieved, *item);

        // Insert twice as many entries as fit; the cache must evict down to capacity.
        for i in 0..20 {
            cache.insert(&format!("key_{}", i), Arc::new(vec![i, i, i]));
        }
        assert_eq!(cache.size_bytes(), capacity);
        assert_eq!(cache.size(), 10);
    }

    /// Trait objects round-trip through `insert_unsized` / `get_unsized`.
    #[test]
    fn test_cache_trait_objects() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {
            fn as_any(&self) -> &dyn Any;
        }

        impl MyTrait for MyType {
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        let concrete = Arc::new(MyType(42));
        let as_trait_object: Arc<dyn MyTrait> = concrete;

        let cache = LanceCache::with_capacity(1000);
        cache.insert_unsized("test", as_trait_object);

        let fetched = cache.get_unsized::<dyn MyTrait>("test").unwrap();
        let fetched = fetched.as_any().downcast_ref::<MyType>().unwrap();
        assert_eq!(fetched.0, 42);
    }

    /// Hits and misses are counted per lookup.
    #[test]
    fn test_cache_stats_basic() {
        let cache = LanceCache::with_capacity(1000);

        // Fresh cache: nothing recorded yet.
        assert_stats(&cache, 0, 0);

        // Lookup of an absent key is a miss.
        assert!(cache.get::<Vec<i32>>("nonexistent").is_none());
        assert_stats(&cache, 0, 1);

        // Lookup after insert is a hit.
        cache.insert("key1", Arc::new(vec![1, 2, 3]));
        assert!(cache.get::<Vec<i32>>("key1").is_some());
        assert_stats(&cache, 1, 1);

        // A second lookup of the same key is another hit.
        assert!(cache.get::<Vec<i32>>("key1").is_some());
        assert_stats(&cache, 2, 1);

        // Another absent key is another miss.
        assert!(cache.get::<Vec<i32>>("nonexistent2").is_none());
        assert_stats(&cache, 2, 2);
    }

    /// A prefixed handle shares its counters with the cache it was derived from.
    #[test]
    fn test_cache_stats_with_prefixes() {
        let base_cache = LanceCache::with_capacity(1000);
        let prefixed_cache = base_cache.with_key_prefix("test");

        assert_stats(&base_cache, 0, 0);
        assert_stats(&prefixed_cache, 0, 0);

        // A miss through the prefixed handle is visible on both handles.
        assert!(prefixed_cache.get::<Vec<i32>>("key1").is_none());
        assert_stats(&base_cache, 0, 1);
        assert_stats(&prefixed_cache, 0, 1);

        // A hit through the prefixed handle is visible on both handles.
        prefixed_cache.insert("key1", Arc::new(vec![1, 2, 3]));
        assert!(prefixed_cache.get::<Vec<i32>>("key1").is_some());
        assert_stats(&base_cache, 1, 1);
        assert_stats(&prefixed_cache, 1, 1);
    }

    /// Unsized lookups are counted like sized ones.
    #[test]
    fn test_cache_stats_unsized() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {}

        impl MyTrait for MyType {}

        let cache = LanceCache::with_capacity(1000);

        // Absent trait object: one miss.
        assert!(cache.get_unsized::<dyn MyTrait>("test").is_none());
        assert_stats(&cache, 0, 1);

        let concrete = Arc::new(MyType(42));
        let as_trait_object: Arc<dyn MyTrait> = concrete;
        cache.insert_unsized("test", as_trait_object);

        // Present trait object: one hit.
        assert!(cache.get_unsized::<dyn MyTrait>("test").is_some());
        assert_stats(&cache, 1, 1);
    }

    /// `get_or_insert` counts a miss when it has to load and a hit when it does not.
    #[tokio::test]
    async fn test_cache_stats_get_or_insert() {
        let cache = LanceCache::with_capacity(1000);

        // First call loads the value: one miss.
        let loaded: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async { Ok(vec![1, 2, 3]) })
            .await
            .unwrap();
        assert_eq!(*loaded, vec![1, 2, 3]);
        assert_stats(&cache, 0, 1);

        // Second call is served from the cache; the loader must not run.
        let cached: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async {
                panic!("Should not be called")
            })
            .await
            .unwrap();
        assert_eq!(*cached, vec![1, 2, 3]);
        assert_stats(&cache, 1, 1);

        // A different key loads again: second miss.
        let other: Arc<Vec<i32>> = cache
            .get_or_insert("key2".to_string(), |_key| async { Ok(vec![4, 5, 6]) })
            .await
            .unwrap();
        assert_eq!(*other, vec![4, 5, 6]);
        assert_stats(&cache, 1, 2);
    }
}