1use std::any::{Any, TypeId};
7use std::borrow::Cow;
8use std::sync::{
9 Arc,
10 atomic::{AtomicU64, Ordering},
11};
12
13use futures::{Future, FutureExt};
14use moka::future::Cache;
15
16use crate::Result;
17
18pub use deepsize::{Context, DeepSizeOf};
19
/// Type-erased, shareable pointer used to hold cache values of arbitrary concrete types.
type ArcAny = Arc<dyn Any + Send + Sync>;
21
22#[derive(Clone)]
23pub struct SizedRecord {
24 record: ArcAny,
25 size_accessor: Arc<dyn Fn(&ArcAny) -> usize + Send + Sync>,
26}
27
28impl std::fmt::Debug for SizedRecord {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("SizedRecord")
31 .field("record", &self.record)
32 .finish()
33 }
34}
35
36impl DeepSizeOf for SizedRecord {
37 fn deep_size_of_children(&self, _: &mut Context) -> usize {
38 (self.size_accessor)(&self.record)
39 }
40}
41
42impl SizedRecord {
43 fn new<T: DeepSizeOf + Send + Sync + 'static>(record: Arc<T>) -> Self {
44 let size_accessor =
46 |record: &ArcAny| -> usize { record.downcast_ref::<T>().unwrap().deep_size_of() + 8 };
47 Self {
48 record,
49 size_accessor: Arc::new(size_accessor),
50 }
51 }
52}
53
54#[derive(Clone)]
55pub struct LanceCache {
56 cache: Arc<Cache<(String, TypeId), SizedRecord>>,
57 prefix: String,
58 hits: Arc<AtomicU64>,
59 misses: Arc<AtomicU64>,
60}
61
62impl std::fmt::Debug for LanceCache {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 f.debug_struct("LanceCache")
65 .field("cache", &self.cache)
66 .finish()
67 }
68}
69
70impl DeepSizeOf for LanceCache {
71 fn deep_size_of_children(&self, _: &mut Context) -> usize {
72 self.cache
73 .iter()
74 .map(|(_, v)| (v.size_accessor)(&v.record))
75 .sum()
76 }
77}
78
79impl LanceCache {
80 pub fn with_capacity(capacity: usize) -> Self {
81 let cache = Cache::builder()
82 .max_capacity(capacity as u64)
83 .weigher(|_, v: &SizedRecord| {
84 (v.size_accessor)(&v.record).try_into().unwrap_or(u32::MAX)
85 })
86 .support_invalidation_closures()
87 .build();
88 Self {
89 cache: Arc::new(cache),
90 prefix: String::new(),
91 hits: Arc::new(AtomicU64::new(0)),
92 misses: Arc::new(AtomicU64::new(0)),
93 }
94 }
95
96 pub fn no_cache() -> Self {
97 Self {
98 cache: Arc::new(Cache::new(0)),
99 prefix: String::new(),
100 hits: Arc::new(AtomicU64::new(0)),
101 misses: Arc::new(AtomicU64::new(0)),
102 }
103 }
104
105 pub fn with_key_prefix(&self, prefix: &str) -> Self {
113 Self {
114 cache: self.cache.clone(),
115 prefix: format!("{}{}/", self.prefix, prefix),
116 hits: self.hits.clone(),
117 misses: self.misses.clone(),
118 }
119 }
120
121 fn get_key(&self, key: &str) -> String {
122 if self.prefix.is_empty() {
123 key.to_string()
124 } else {
125 format!("{}/{}", self.prefix, key)
126 }
127 }
128
129 pub fn invalidate_prefix(&self, prefix: &str) {
134 let full_prefix = format!("{}{}", self.prefix, prefix);
135 self.cache
136 .invalidate_entries_if(move |(key, _typeid), _value| key.starts_with(&full_prefix))
137 .expect("Cache configured correctly");
138 }
139
140 pub async fn size(&self) -> usize {
141 self.cache.run_pending_tasks().await;
142 self.cache.entry_count() as usize
143 }
144
145 pub fn approx_size(&self) -> usize {
146 self.cache.entry_count() as usize
147 }
148
149 pub async fn size_bytes(&self) -> usize {
150 self.cache.run_pending_tasks().await;
151 self.approx_size_bytes()
152 }
153
154 pub fn approx_size_bytes(&self) -> usize {
155 self.cache.weighted_size() as usize
156 }
157
158 async fn insert<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str, metadata: Arc<T>) {
159 let key = self.get_key(key);
160 let record = SizedRecord::new(metadata);
161 tracing::trace!(
162 target: "lance_cache::insert",
163 key = key,
164 type_id = std::any::type_name::<T>(),
165 size = (record.size_accessor)(&record.record),
166 );
167 self.cache.insert((key, TypeId::of::<T>()), record).await;
168 }
169
170 pub async fn insert_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
171 &self,
172 key: &str,
173 metadata: Arc<T>,
174 ) {
175 self.insert(key, Arc::new(metadata)).await
177 }
178
179 async fn get<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
180 let key = self.get_key(key);
181 if let Some(metadata) = self.cache.get(&(key, TypeId::of::<T>())).await {
182 self.hits.fetch_add(1, Ordering::Relaxed);
183 Some(metadata.record.clone().downcast::<T>().unwrap())
184 } else {
185 self.misses.fetch_add(1, Ordering::Relaxed);
186 None
187 }
188 }
189
190 pub async fn get_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
191 &self,
192 key: &str,
193 ) -> Option<Arc<T>> {
194 let outer = self.get::<Arc<T>>(key).await?;
195 Some(outer.as_ref().clone())
196 }
197
198 async fn get_or_insert<T: DeepSizeOf + Send + Sync + 'static, F, Fut>(
204 &self,
205 key: String,
206 loader: F,
207 ) -> Result<Arc<T>>
208 where
209 F: FnOnce(&str) -> Fut,
210 Fut: Future<Output = Result<T>> + Send,
211 {
212 let full_key = self.get_key(&key);
213 let cache_key = (full_key, TypeId::of::<T>());
214
215 let hits = self.hits.clone();
217 let misses = self.misses.clone();
218
219 let (error_tx, error_rx) = tokio::sync::oneshot::channel();
221 let (init_run_tx, mut init_run_rx) = tokio::sync::oneshot::channel();
222
223 let init = Box::pin(async move {
224 let _ = init_run_tx.send(());
225 misses.fetch_add(1, Ordering::Relaxed);
226 match loader(&key).await {
227 Ok(value) => Some(SizedRecord::new(Arc::new(value))),
228 Err(e) => {
229 let _ = error_tx.send(e);
230 None
231 }
232 }
233 });
234
235 match self.cache.optionally_get_with(cache_key, init).await {
236 Some(metadata) => {
237 match init_run_rx.try_recv() {
239 Ok(()) => {
240 }
242 Err(_) => {
243 hits.fetch_add(1, Ordering::Relaxed);
245 }
246 }
247 Ok(metadata.record.clone().downcast::<T>().unwrap())
248 }
249 None => {
250 match error_rx.await {
252 Ok(err) => Err(err),
253 Err(_) => Err(crate::Error::internal(
254 "Failed to retrieve error from cache loader",
255 )),
256 }
257 }
258 }
259 }
260
261 pub async fn stats(&self) -> CacheStats {
262 self.cache.run_pending_tasks().await;
263 CacheStats {
264 hits: self.hits.load(Ordering::Relaxed),
265 misses: self.misses.load(Ordering::Relaxed),
266 num_entries: self.cache.entry_count() as usize,
267 size_bytes: self.cache.weighted_size() as usize,
268 }
269 }
270
271 pub async fn clear(&self) {
272 self.cache.invalidate_all();
273 self.cache.run_pending_tasks().await;
274 self.hits.store(0, Ordering::Relaxed);
275 self.misses.store(0, Ordering::Relaxed);
276 }
277
278 pub async fn insert_with_key<K>(&self, cache_key: &K, metadata: Arc<K::ValueType>)
280 where
281 K: CacheKey,
282 K::ValueType: DeepSizeOf + Send + Sync + 'static,
283 {
284 self.insert(&cache_key.key(), metadata).boxed().await
285 }
286
287 pub async fn get_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
288 where
289 K: CacheKey,
290 K::ValueType: DeepSizeOf + Send + Sync + 'static,
291 {
292 self.get::<K::ValueType>(&cache_key.key()).boxed().await
293 }
294
295 pub async fn get_or_insert_with_key<K, F, Fut>(
296 &self,
297 cache_key: K,
298 loader: F,
299 ) -> Result<Arc<K::ValueType>>
300 where
301 K: CacheKey,
302 K::ValueType: DeepSizeOf + Send + Sync + 'static,
303 F: FnOnce() -> Fut,
304 Fut: Future<Output = Result<K::ValueType>> + Send,
305 {
306 let key_str = cache_key.key().into_owned();
307 Box::pin(self.get_or_insert(key_str, |_| loader())).await
308 }
309
310 pub async fn insert_unsized_with_key<K>(&self, cache_key: &K, metadata: Arc<K::ValueType>)
311 where
312 K: UnsizedCacheKey,
313 K::ValueType: DeepSizeOf + Send + Sync + 'static,
314 {
315 self.insert_unsized(&cache_key.key(), metadata)
316 .boxed()
317 .await
318 }
319
320 pub async fn get_unsized_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
321 where
322 K: UnsizedCacheKey,
323 K::ValueType: DeepSizeOf + Send + Sync + 'static,
324 {
325 self.get_unsized::<K::ValueType>(&cache_key.key())
326 .boxed()
327 .await
328 }
329}
330
331#[derive(Clone, Debug)]
334pub struct WeakLanceCache {
335 inner: std::sync::Weak<Cache<(String, TypeId), SizedRecord>>,
336 prefix: String,
337 hits: Arc<AtomicU64>,
338 misses: Arc<AtomicU64>,
339}
340
341impl WeakLanceCache {
342 pub fn from(cache: &LanceCache) -> Self {
344 Self {
345 inner: Arc::downgrade(&cache.cache),
346 prefix: cache.prefix.clone(),
347 hits: cache.hits.clone(),
348 misses: cache.misses.clone(),
349 }
350 }
351
352 pub fn with_key_prefix(&self, prefix: &str) -> Self {
354 Self {
355 inner: self.inner.clone(),
356 prefix: format!("{}{}/", self.prefix, prefix),
357 hits: self.hits.clone(),
358 misses: self.misses.clone(),
359 }
360 }
361
362 fn get_key(&self, key: &str) -> String {
363 if self.prefix.is_empty() {
364 key.to_string()
365 } else {
366 format!("{}/{}", self.prefix, key)
367 }
368 }
369
370 pub async fn get<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
372 let cache = self.inner.upgrade()?;
373 let key = self.get_key(key);
374 if let Some(metadata) = cache.get(&(key, TypeId::of::<T>())).await {
375 self.hits.fetch_add(1, Ordering::Relaxed);
376 Some(metadata.record.clone().downcast::<T>().unwrap())
377 } else {
378 self.misses.fetch_add(1, Ordering::Relaxed);
379 None
380 }
381 }
382
383 pub async fn insert<T: DeepSizeOf + Send + Sync + 'static>(
386 &self,
387 key: &str,
388 value: Arc<T>,
389 ) -> bool {
390 if let Some(cache) = self.inner.upgrade() {
391 let key = self.get_key(key);
392 let record = SizedRecord::new(value);
393 cache.insert((key, TypeId::of::<T>()), record).await;
394 true
395 } else {
396 log::warn!("WeakLanceCache: cache no longer available, unable to insert item");
397 false
398 }
399 }
400
401 pub async fn get_or_insert<T, F, Fut>(&self, key: &str, f: F) -> Result<Arc<T>>
403 where
404 T: DeepSizeOf + Send + Sync + 'static,
405 F: FnOnce() -> Fut,
406 Fut: Future<Output = Result<T>> + Send,
407 {
408 if let Some(cache) = self.inner.upgrade() {
409 let full_key = self.get_key(key);
410 let cache_key = (full_key.clone(), TypeId::of::<T>());
411
412 let hits = self.hits.clone();
414 let misses = self.misses.clone();
415
416 let (init_run_tx, mut init_run_rx) = tokio::sync::oneshot::channel();
418 let (error_tx, error_rx) = tokio::sync::oneshot::channel();
419
420 let init = Box::pin(async move {
421 let _ = init_run_tx.send(());
422 misses.fetch_add(1, Ordering::Relaxed);
423 match f().await {
424 Ok(value) => Some(SizedRecord::new(Arc::new(value))),
425 Err(e) => {
426 let _ = error_tx.send(e);
427 None
428 }
429 }
430 });
431
432 match cache.optionally_get_with(cache_key, init).await {
433 Some(record) => {
434 match init_run_rx.try_recv() {
436 Ok(()) => {
437 }
439 Err(_) => {
440 hits.fetch_add(1, Ordering::Relaxed);
442 }
443 }
444 Ok(record.record.clone().downcast::<T>().unwrap())
445 }
446 None => {
447 match error_rx.await {
449 Ok(e) => Err(e),
450 Err(_) => Err(crate::Error::internal(
451 "Failed to receive error from cache init function".to_string(),
452 )),
453 }
454 }
455 }
456 } else {
457 log::warn!("WeakLanceCache: cache no longer available, computing without caching");
458 f().await.map(Arc::new)
459 }
460 }
461
462 pub async fn get_or_insert_with_key<K, F, Fut>(
464 &self,
465 cache_key: K,
466 loader: F,
467 ) -> Result<Arc<K::ValueType>>
468 where
469 K: CacheKey,
470 K::ValueType: DeepSizeOf + Send + Sync + 'static,
471 F: FnOnce() -> Fut,
472 Fut: Future<Output = Result<K::ValueType>> + Send,
473 {
474 let key_str = cache_key.key().into_owned();
475 self.get_or_insert(&key_str, loader).await
476 }
477
478 pub async fn insert_with_key<K>(&self, cache_key: &K, value: Arc<K::ValueType>) -> bool
481 where
482 K: CacheKey,
483 K::ValueType: DeepSizeOf + Send + Sync + 'static,
484 {
485 let key_str = cache_key.key().into_owned();
486 self.insert(&key_str, value).await
487 }
488
489 pub async fn get_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
491 where
492 K: CacheKey,
493 K::ValueType: DeepSizeOf + Send + Sync + 'static,
494 {
495 let key_str = cache_key.key().into_owned();
496 self.get(&key_str).await
497 }
498
499 pub async fn get_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
501 &self,
502 key: &str,
503 ) -> Option<Arc<T>> {
504 let cache = self.inner.upgrade()?;
506 let key = self.get_key(key);
507 if let Some(metadata) = cache.get(&(key, TypeId::of::<Arc<T>>())).await {
508 metadata
509 .record
510 .clone()
511 .downcast::<Arc<T>>()
512 .ok()
513 .map(|arc| arc.as_ref().clone())
514 } else {
515 None
516 }
517 }
518
519 pub async fn insert_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
521 &self,
522 key: &str,
523 value: Arc<T>,
524 ) {
525 if let Some(cache) = self.inner.upgrade() {
526 let key = self.get_key(key);
527 let record = SizedRecord::new(Arc::new(value));
528 cache.insert((key, TypeId::of::<Arc<T>>()), record).await;
529 } else {
530 log::warn!("WeakLanceCache: cache no longer available, unable to insert unsized item");
531 }
532 }
533
534 pub async fn get_unsized_with_key<K>(&self, cache_key: &K) -> Option<Arc<K::ValueType>>
536 where
537 K: UnsizedCacheKey,
538 K::ValueType: DeepSizeOf + Send + Sync + 'static,
539 {
540 let key_str = cache_key.key();
541 self.get_unsized(&key_str).await
542 }
543
544 pub async fn insert_unsized_with_key<K>(&self, cache_key: &K, value: Arc<K::ValueType>)
546 where
547 K: UnsizedCacheKey,
548 K::ValueType: DeepSizeOf + Send + Sync + 'static,
549 {
550 let key_str = cache_key.key();
551 self.insert_unsized(&key_str, value).await
552 }
553}
554
/// A typed cache key: ties a string key to the type of value stored under it.
pub trait CacheKey {
    /// The type of value this key refers to.
    type ValueType;

    /// Returns the string form of this key.
    fn key(&self) -> Cow<'_, str>;
}
560
/// A typed cache key whose value type may be unsized (for example a trait object).
pub trait UnsizedCacheKey {
    /// The type of value this key refers to; it is not required to be `Sized`.
    type ValueType: ?Sized;

    /// Returns the string form of this key.
    fn key(&self) -> Cow<'_, str>;
}
566
/// A snapshot of cache counters and size information.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of lookups counted as hits.
    pub hits: u64,
    /// Number of lookups counted as misses.
    pub misses: u64,
    /// Number of entries in the cache when the snapshot was taken.
    pub num_entries: usize,
    /// Weighted size of the cache when the snapshot was taken.
    pub size_bytes: usize,
}

impl CacheStats {
    /// Returns `hits / (hits + misses)`, or `0.0` when no lookups were recorded.
    pub fn hit_ratio(&self) -> f32 {
        self.fraction_of_lookups(self.hits)
    }

    /// Returns `misses / (hits + misses)`, or `0.0` when no lookups were recorded.
    pub fn miss_ratio(&self) -> f32 {
        self.fraction_of_lookups(self.misses)
    }

    /// Divides `count` by the total number of lookups, yielding `0.0` for an empty total.
    fn fraction_of_lookups(&self, count: u64) -> f32 {
        // Widen before adding so that `hits + misses` cannot overflow `u64`.
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            count as f32 / total as f32
        }
    }
}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts the hit and miss counters currently reported by `cache`.
    async fn assert_counts(cache: &LanceCache, hits: u64, misses: u64) {
        let stats = cache.stats().await;
        assert_eq!(stats.hits, hits);
        assert_eq!(stats.misses, misses);
    }

    /// Byte accounting follows the inserted items and stops growing at capacity.
    #[tokio::test]
    async fn test_cache_bytes() {
        let sample = Arc::new(vec![1, 2, 3]);
        let item_size = sample.deep_size_of();
        let capacity = 10 * item_size;

        let cache = LanceCache::with_capacity(capacity);
        assert_eq!(cache.size_bytes().await, 0);
        assert_eq!(cache.approx_size_bytes(), 0);

        let item = Arc::new(vec![1, 2, 3]);
        cache.insert("key", item.clone()).await;
        assert_eq!(cache.size().await, 1);
        assert_eq!(cache.size_bytes().await, item_size);
        assert_eq!(cache.approx_size_bytes(), item_size);

        let retrieved = cache.get::<Vec<i32>>("key").await.unwrap();
        assert_eq!(*retrieved, *item);

        // Insert twice as many items as fit; the cache must stay at capacity.
        for i in 0..20 {
            let key = format!("key_{}", i);
            cache.insert(&key, Arc::new(vec![i, i, i])).await;
        }
        assert_eq!(cache.size_bytes().await, capacity);
        assert_eq!(cache.size().await, 10);
    }

    /// Trait objects can be stored and retrieved through the unsized API.
    #[tokio::test]
    async fn test_cache_trait_objects() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {
            fn as_any(&self) -> &dyn Any;
        }

        impl MyTrait for MyType {
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        let item_dyn: Arc<dyn MyTrait> = Arc::new(MyType(42));

        let cache = LanceCache::with_capacity(1000);
        cache.insert_unsized("test", item_dyn).await;

        let retrieved = cache.get_unsized::<dyn MyTrait>("test").await.unwrap();
        let concrete = retrieved.as_any().downcast_ref::<MyType>().unwrap();
        assert_eq!(concrete.0, 42);
    }

    /// Plain lookups update the hit and miss counters.
    #[tokio::test]
    async fn test_cache_stats_basic() {
        let cache = LanceCache::with_capacity(1000);
        assert_counts(&cache, 0, 0).await;

        assert!(cache.get::<Vec<i32>>("nonexistent").await.is_none());
        assert_counts(&cache, 0, 1).await;

        cache.insert("key1", Arc::new(vec![1, 2, 3])).await;
        assert!(cache.get::<Vec<i32>>("key1").await.is_some());
        assert_counts(&cache, 1, 1).await;

        assert!(cache.get::<Vec<i32>>("key1").await.is_some());
        assert_counts(&cache, 2, 1).await;

        assert!(cache.get::<Vec<i32>>("nonexistent2").await.is_none());
        assert_counts(&cache, 2, 2).await;
    }

    /// A prefixed handle shares its counters with the cache it was derived from.
    #[tokio::test]
    async fn test_cache_stats_with_prefixes() {
        let base_cache = LanceCache::with_capacity(1000);
        let prefixed_cache = base_cache.with_key_prefix("test");

        assert_counts(&base_cache, 0, 0).await;
        assert_counts(&prefixed_cache, 0, 0).await;

        assert!(prefixed_cache.get::<Vec<i32>>("key1").await.is_none());
        assert_counts(&base_cache, 0, 1).await;
        assert_counts(&prefixed_cache, 0, 1).await;

        prefixed_cache.insert("key1", Arc::new(vec![1, 2, 3])).await;
        assert!(prefixed_cache.get::<Vec<i32>>("key1").await.is_some());
        assert_counts(&base_cache, 1, 1).await;
        assert_counts(&prefixed_cache, 1, 1).await;
    }

    /// Unsized lookups update the hit and miss counters.
    #[tokio::test]
    async fn test_cache_stats_unsized() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {}

        impl MyTrait for MyType {}

        let cache = LanceCache::with_capacity(1000);

        assert!(cache.get_unsized::<dyn MyTrait>("test").await.is_none());
        assert_counts(&cache, 0, 1).await;

        let item_dyn: Arc<dyn MyTrait> = Arc::new(MyType(42));
        cache.insert_unsized("test", item_dyn).await;

        assert!(cache.get_unsized::<dyn MyTrait>("test").await.is_some());
        assert_counts(&cache, 1, 1).await;
    }

    /// `get_or_insert` counts a miss when the loader runs and a hit when it does not.
    #[tokio::test]
    async fn test_cache_stats_get_or_insert() {
        let cache = LanceCache::with_capacity(1000);

        // First call: the loader runs, which counts as a miss.
        let first: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async { Ok(vec![1, 2, 3]) })
            .await
            .unwrap();
        assert_eq!(*first, vec![1, 2, 3]);
        assert_counts(&cache, 0, 1).await;

        // Second call with the same key: served from the cache, loader must not run.
        let second: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async {
                panic!("Should not be called")
            })
            .await
            .unwrap();
        assert_eq!(*second, vec![1, 2, 3]);
        assert_counts(&cache, 1, 1).await;

        // A different key runs its loader and counts as another miss.
        let third: Arc<Vec<i32>> = cache
            .get_or_insert("key2".to_string(), |_key| async { Ok(vec![4, 5, 6]) })
            .await
            .unwrap();
        assert_eq!(*third, vec![4, 5, 6]);
        assert_counts(&cache, 1, 2).await;
    }
}