1use std::any::{Any, TypeId};
7use std::borrow::Cow;
8use std::sync::{
9 atomic::{AtomicU64, Ordering},
10 Arc,
11};
12
13use futures::{Future, FutureExt};
14use moka::future::Cache;
15use snafu::location;
16
17use crate::Result;
18
19pub use deepsize::{Context, DeepSizeOf};
20
/// A shared, thread-safe, type-erased value; the concrete type is recovered by downcasting.
type ArcAny = Arc<dyn Any + Sync + Send>;
22
23#[derive(Clone)]
24struct SizedRecord {
25 record: ArcAny,
26 size_accessor: Arc<dyn Fn(&ArcAny) -> usize + Send + Sync>,
27}
28
29impl std::fmt::Debug for SizedRecord {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("SizedRecord")
32 .field("record", &self.record)
33 .finish()
34 }
35}
36
37impl SizedRecord {
38 fn new<T: DeepSizeOf + Send + Sync + 'static>(record: Arc<T>) -> Self {
39 let size_accessor =
41 |record: &ArcAny| -> usize { record.downcast_ref::<T>().unwrap().deep_size_of() + 8 };
42 Self {
43 record,
44 size_accessor: Arc::new(size_accessor),
45 }
46 }
47}
48
49#[derive(Clone)]
50pub struct LanceCache {
51 cache: Arc<Cache<(String, TypeId), SizedRecord>>,
52 prefix: String,
53 hits: Arc<AtomicU64>,
54 misses: Arc<AtomicU64>,
55}
56
57impl std::fmt::Debug for LanceCache {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("LanceCache")
60 .field("cache", &self.cache)
61 .finish()
62 }
63}
64
65impl DeepSizeOf for LanceCache {
66 fn deep_size_of_children(&self, _: &mut Context) -> usize {
67 self.cache
68 .iter()
69 .map(|(_, v)| (v.size_accessor)(&v.record))
70 .sum()
71 }
72}
73
74impl LanceCache {
75 pub fn with_capacity(capacity: usize) -> Self {
76 let cache = Cache::builder()
77 .max_capacity(capacity as u64)
78 .weigher(|_, v: &SizedRecord| {
79 (v.size_accessor)(&v.record).try_into().unwrap_or(u32::MAX)
80 })
81 .support_invalidation_closures()
82 .build();
83 Self {
84 cache: Arc::new(cache),
85 prefix: String::new(),
86 hits: Arc::new(AtomicU64::new(0)),
87 misses: Arc::new(AtomicU64::new(0)),
88 }
89 }
90
91 pub fn no_cache() -> Self {
92 Self {
93 cache: Arc::new(Cache::new(0)),
94 prefix: String::new(),
95 hits: Arc::new(AtomicU64::new(0)),
96 misses: Arc::new(AtomicU64::new(0)),
97 }
98 }
99
100 pub fn with_key_prefix(&self, prefix: &str) -> Self {
108 Self {
109 cache: self.cache.clone(),
110 prefix: format!("{}{}/", self.prefix, prefix),
111 hits: self.hits.clone(),
112 misses: self.misses.clone(),
113 }
114 }
115
116 fn get_key(&self, key: &str) -> String {
117 if self.prefix.is_empty() {
118 key.to_string()
119 } else {
120 format!("{}/{}", self.prefix, key)
121 }
122 }
123
124 pub fn invalidate_prefix(&self, prefix: &str) {
129 let full_prefix = format!("{}{}", self.prefix, prefix);
130 self.cache
131 .invalidate_entries_if(move |(key, _typeid), _value| key.starts_with(&full_prefix))
132 .expect("Cache configured correctly");
133 }
134
135 pub async fn size(&self) -> usize {
136 self.cache.run_pending_tasks().await;
137 self.cache.entry_count() as usize
138 }
139
140 pub fn approx_size(&self) -> usize {
141 self.cache.entry_count() as usize
142 }
143
144 pub async fn size_bytes(&self) -> usize {
145 self.cache.run_pending_tasks().await;
146 self.approx_size_bytes()
147 }
148
149 pub fn approx_size_bytes(&self) -> usize {
150 self.cache.weighted_size() as usize
151 }
152
153 async fn insert<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str, metadata: Arc<T>) {
154 let key = self.get_key(key);
155 let record = SizedRecord::new(metadata);
156 tracing::trace!(
157 target: "lance_cache::insert",
158 key = key,
159 type_id = std::any::type_name::<T>(),
160 size = (record.size_accessor)(&record.record),
161 );
162 self.cache.insert((key, TypeId::of::<T>()), record).await;
163 }
164
165 pub async fn insert_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
166 &self,
167 key: &str,
168 metadata: Arc<T>,
169 ) {
170 self.insert(key, Arc::new(metadata)).await
172 }
173
174 async fn get<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
175 let key = self.get_key(key);
176 if let Some(metadata) = self.cache.get(&(key, TypeId::of::<T>())).await {
177 self.hits.fetch_add(1, Ordering::Relaxed);
178 Some(metadata.record.clone().downcast::<T>().unwrap())
179 } else {
180 self.misses.fetch_add(1, Ordering::Relaxed);
181 None
182 }
183 }
184
185 pub async fn get_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
186 &self,
187 key: &str,
188 ) -> Option<Arc<T>> {
189 let outer = self.get::<Arc<T>>(key).await?;
190 Some(outer.as_ref().clone())
191 }
192
193 async fn get_or_insert<T: DeepSizeOf + Send + Sync + 'static, F, Fut>(
199 &self,
200 key: String,
201 loader: F,
202 ) -> Result<Arc<T>>
203 where
204 F: FnOnce(&str) -> Fut,
205 Fut: Future<Output = Result<T>> + Send,
206 {
207 let full_key = self.get_key(&key);
208 let cache_key = (full_key, TypeId::of::<T>());
209
210 let hits = self.hits.clone();
212 let misses = self.misses.clone();
213
214 let (error_tx, error_rx) = tokio::sync::oneshot::channel();
216 let (init_run_tx, mut init_run_rx) = tokio::sync::oneshot::channel();
217
218 let init = Box::pin(async move {
219 let _ = init_run_tx.send(());
220 misses.fetch_add(1, Ordering::Relaxed);
221 match loader(&key).await {
222 Ok(value) => Some(SizedRecord::new(Arc::new(value))),
223 Err(e) => {
224 let _ = error_tx.send(e);
225 None
226 }
227 }
228 });
229
230 match self.cache.optionally_get_with(cache_key, init).await {
231 Some(metadata) => {
232 match init_run_rx.try_recv() {
234 Ok(()) => {
235 }
237 Err(_) => {
238 hits.fetch_add(1, Ordering::Relaxed);
240 }
241 }
242 Ok(metadata.record.clone().downcast::<T>().unwrap())
243 }
244 None => {
245 match error_rx.await {
247 Ok(err) => Err(err),
248 Err(_) => Err(crate::Error::Internal {
249 message: "Failed to retrieve error from cache loader".into(),
250 location: location!(),
251 }),
252 }
253 }
254 }
255 }
256
257 pub async fn stats(&self) -> CacheStats {
258 self.cache.run_pending_tasks().await;
259 CacheStats {
260 hits: self.hits.load(Ordering::Relaxed),
261 misses: self.misses.load(Ordering::Relaxed),
262 num_entries: self.cache.entry_count() as usize,
263 size_bytes: self.cache.weighted_size() as usize,
264 }
265 }
266
267 pub async fn clear(&self) {
268 self.cache.invalidate_all();
269 self.cache.run_pending_tasks().await;
270 self.hits.store(0, Ordering::Relaxed);
271 self.misses.store(0, Ordering::Relaxed);
272 }
273
274 pub async fn insert_with_key<K>(&self, cache_key: &K, metadata: Arc<K::ValueType>)
276 where
277 K: CacheKey,
278 K::ValueType: DeepSizeOf + Send + Sync + 'static,
279 {
280 self.insert(&cache_key.key(), metadata).boxed().await
281 }
282
283 pub async fn get_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
284 where
285 K: CacheKey,
286 K::ValueType: DeepSizeOf + Send + Sync + 'static,
287 {
288 self.get::<K::ValueType>(&cache_key.key()).boxed().await
289 }
290
291 pub async fn get_or_insert_with_key<K, F, Fut>(
292 &self,
293 cache_key: K,
294 loader: F,
295 ) -> Result<Arc<K::ValueType>>
296 where
297 K: CacheKey,
298 K::ValueType: DeepSizeOf + Send + Sync + 'static,
299 F: FnOnce() -> Fut,
300 Fut: Future<Output = Result<K::ValueType>> + Send,
301 {
302 let key_str = cache_key.key().into_owned();
303 Box::pin(self.get_or_insert(key_str, |_| loader())).await
304 }
305
306 pub async fn insert_unsized_with_key<K>(&self, cache_key: &K, metadata: Arc<K::ValueType>)
307 where
308 K: UnsizedCacheKey,
309 K::ValueType: DeepSizeOf + Send + Sync + 'static,
310 {
311 self.insert_unsized(&cache_key.key(), metadata)
312 .boxed()
313 .await
314 }
315
316 pub async fn get_unsized_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
317 where
318 K: UnsizedCacheKey,
319 K::ValueType: DeepSizeOf + Send + Sync + 'static,
320 {
321 self.get_unsized::<K::ValueType>(&cache_key.key())
322 .boxed()
323 .await
324 }
325}
326
/// A typed cache key: ties a string key to the sized value type stored under it.
pub trait CacheKey {
    /// Type of the value cached under this key.
    type ValueType;

    /// String form of the key, before any cache prefix is applied.
    fn key(&self) -> Cow<'_, str>;
}
332
/// Like `CacheKey`, but the value type may be unsized (e.g. `dyn Trait` or `str`).
pub trait UnsizedCacheKey {
    /// Type of the value cached under this key; need not be `Sized`.
    type ValueType: ?Sized;

    /// String form of the key, before any cache prefix is applied.
    fn key(&self) -> Cow<'_, str>;
}
338
/// Point-in-time snapshot of cache counters and occupancy.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of lookups that found a cached value.
    pub hits: u64,
    /// Number of lookups that did not find a cached value.
    pub misses: u64,
    /// Number of entries currently stored.
    pub num_entries: usize,
    /// Total weighted size of the stored entries, as reported by the cache.
    pub size_bytes: usize,
}

impl CacheStats {
    /// Share of lookups that were hits; `0.0` when no lookups were recorded.
    pub fn hit_ratio(&self) -> f32 {
        self.fraction_of_lookups(self.hits)
    }

    /// Share of lookups that were misses; `0.0` when no lookups were recorded.
    pub fn miss_ratio(&self) -> f32 {
        self.fraction_of_lookups(self.misses)
    }

    /// Divides `count` by the total number of lookups, returning `0.0` rather
    /// than dividing by zero.
    fn fraction_of_lookups(&self, count: u64) -> f32 {
        match self.hits + self.misses {
            0 => 0.0,
            total => count as f32 / total as f32,
        }
    }
}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts the shared hit/miss counters observed through `cache`.
    async fn expect_counters(cache: &LanceCache, hits: u64, misses: u64) {
        let stats = cache.stats().await;
        assert_eq!(stats.hits, hits);
        assert_eq!(stats.misses, misses);
    }

    #[tokio::test]
    async fn test_cache_bytes() {
        let item = Arc::new(vec![1, 2, 3]);
        // Weight of one `Arc<Vec<i32>>` entry; room for exactly ten of them.
        let item_size = item.deep_size_of();
        let capacity = 10 * item_size;

        let cache = LanceCache::with_capacity(capacity);
        assert_eq!(cache.size_bytes().await, 0);
        assert_eq!(cache.approx_size_bytes(), 0);

        cache.insert("key", item.clone()).await;
        assert_eq!(cache.size().await, 1);
        assert_eq!(cache.size_bytes().await, item_size);
        assert_eq!(cache.approx_size_bytes(), item_size);

        let fetched = cache.get::<Vec<i32>>("key").await.unwrap();
        assert_eq!(*fetched, *item);

        // Overfill: eviction must keep the total weight at exactly `capacity`.
        for n in 0..20 {
            let key = format!("key_{}", n);
            cache.insert(&key, Arc::new(vec![n, n, n])).await;
        }
        assert_eq!(cache.size_bytes().await, capacity);
        assert_eq!(cache.size().await, 10);
    }

    #[tokio::test]
    async fn test_cache_trait_objects() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {
            fn as_any(&self) -> &dyn Any;
        }

        impl MyTrait for MyType {
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        let cache = LanceCache::with_capacity(1000);
        let erased: Arc<dyn MyTrait> = Arc::new(MyType(42));
        cache.insert_unsized("test", erased).await;

        let fetched = cache.get_unsized::<dyn MyTrait>("test").await.unwrap();
        let concrete = fetched.as_any().downcast_ref::<MyType>().unwrap();
        assert_eq!(concrete.0, 42);
    }

    #[tokio::test]
    async fn test_cache_stats_basic() {
        let cache = LanceCache::with_capacity(1000);
        expect_counters(&cache, 0, 0).await;

        assert!(cache.get::<Vec<i32>>("nonexistent").await.is_none());
        expect_counters(&cache, 0, 1).await;

        cache.insert("key1", Arc::new(vec![1, 2, 3])).await;
        assert!(cache.get::<Vec<i32>>("key1").await.is_some());
        expect_counters(&cache, 1, 1).await;

        assert!(cache.get::<Vec<i32>>("key1").await.is_some());
        expect_counters(&cache, 2, 1).await;

        assert!(cache.get::<Vec<i32>>("nonexistent2").await.is_none());
        expect_counters(&cache, 2, 2).await;
    }

    #[tokio::test]
    async fn test_cache_stats_with_prefixes() {
        let base_cache = LanceCache::with_capacity(1000);
        let prefixed_cache = base_cache.with_key_prefix("test");

        // Counters are shared, so both handles must always report the same values.
        for handle in [&base_cache, &prefixed_cache] {
            expect_counters(handle, 0, 0).await;
        }

        assert!(prefixed_cache.get::<Vec<i32>>("key1").await.is_none());
        for handle in [&base_cache, &prefixed_cache] {
            expect_counters(handle, 0, 1).await;
        }

        prefixed_cache.insert("key1", Arc::new(vec![1, 2, 3])).await;
        assert!(prefixed_cache.get::<Vec<i32>>("key1").await.is_some());
        for handle in [&base_cache, &prefixed_cache] {
            expect_counters(handle, 1, 1).await;
        }
    }

    #[tokio::test]
    async fn test_cache_stats_unsized() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {}

        impl MyTrait for MyType {}

        let cache = LanceCache::with_capacity(1000);

        assert!(cache.get_unsized::<dyn MyTrait>("test").await.is_none());
        expect_counters(&cache, 0, 1).await;

        let erased: Arc<dyn MyTrait> = Arc::new(MyType(42));
        cache.insert_unsized("test", erased).await;

        assert!(cache.get_unsized::<dyn MyTrait>("test").await.is_some());
        expect_counters(&cache, 1, 1).await;
    }

    #[tokio::test]
    async fn test_cache_stats_get_or_insert() {
        let cache = LanceCache::with_capacity(1000);

        // First load runs the loader: one miss.
        let first: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async { Ok(vec![1, 2, 3]) })
            .await
            .unwrap();
        assert_eq!(*first, vec![1, 2, 3]);
        expect_counters(&cache, 0, 1).await;

        // Second load is served from the cache: one hit, loader never called.
        let second: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async {
                panic!("Should not be called")
            })
            .await
            .unwrap();
        assert_eq!(*second, vec![1, 2, 3]);
        expect_counters(&cache, 1, 1).await;

        // A different key runs its loader again: another miss.
        let third: Arc<Vec<i32>> = cache
            .get_or_insert("key2".to_string(), |_key| async { Ok(vec![4, 5, 6]) })
            .await
            .unwrap();
        assert_eq!(*third, vec![4, 5, 6]);
        expect_counters(&cache, 1, 2).await;
    }
}