1use std::any::{Any, TypeId};
7use std::sync::{
8 atomic::{AtomicU64, Ordering},
9 Arc,
10};
11
12use futures::Future;
13use moka::sync::Cache;
14
15use crate::Result;
16
17pub use deepsize::{Context, DeepSizeOf};
18
/// Shared, type-erased pointer to a `Send + Sync` value; the concrete type is
/// recovered later by downcasting.
type ArcAny = Arc<dyn Any + Send + Sync>;
20
/// A type-erased cache value bundled with a function that reports its size.
#[derive(Clone)]
struct SizedRecord {
    /// The cached value, stored behind a type-erased `Arc`.
    record: ArcAny,
    /// Computes the size of `record`; needed because the concrete type of
    /// `record` is no longer known once it has been erased.
    size_accessor: Arc<dyn Fn(&ArcAny) -> usize + Send + Sync>,
}
26
27impl std::fmt::Debug for SizedRecord {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("SizedRecord")
30 .field("record", &self.record)
31 .finish()
32 }
33}
34
35impl SizedRecord {
36 fn new<T: DeepSizeOf + Send + Sync + 'static>(record: Arc<T>) -> Self {
37 let size_accessor =
39 |record: &ArcAny| -> usize { record.downcast_ref::<T>().unwrap().deep_size_of() + 8 };
40 Self {
41 record,
42 size_accessor: Arc::new(size_accessor),
43 }
44 }
45}
46
/// A cache of type-erased values addressed by a string key and the value's type.
///
/// All fields except `prefix` are behind `Arc`s, so clones and handles created
/// by `with_key_prefix` share the same storage and the same hit/miss counters.
#[derive(Clone)]
pub struct LanceCache {
    /// Underlying storage, keyed by `(full key, TypeId of the stored value)`.
    cache: Arc<Cache<(String, TypeId), SizedRecord>>,
    /// Key prefix applied by this handle; empty for a cache created by
    /// `with_capacity` or `no_cache`, otherwise a `/`-terminated string built
    /// by `with_key_prefix`.
    prefix: String,
    /// Number of lookups that found an entry.
    hits: Arc<AtomicU64>,
    /// Number of lookups that did not find an entry.
    misses: Arc<AtomicU64>,
}
54
55impl std::fmt::Debug for LanceCache {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("LanceCache")
58 .field("cache", &self.cache)
59 .finish()
60 }
61}
62
63impl DeepSizeOf for LanceCache {
64 fn deep_size_of_children(&self, _: &mut Context) -> usize {
65 self.cache
66 .iter()
67 .map(|(_, v)| (v.size_accessor)(&v.record))
68 .sum()
69 }
70}
71
72impl LanceCache {
73 pub fn with_capacity(capacity: usize) -> Self {
74 let cache = Cache::builder()
75 .max_capacity(capacity as u64)
76 .weigher(|_, v: &SizedRecord| {
77 (v.size_accessor)(&v.record).try_into().unwrap_or(u32::MAX)
78 })
79 .support_invalidation_closures()
80 .build();
81 Self {
82 cache: Arc::new(cache),
83 prefix: String::new(),
84 hits: Arc::new(AtomicU64::new(0)),
85 misses: Arc::new(AtomicU64::new(0)),
86 }
87 }
88
89 pub fn no_cache() -> Self {
90 Self {
91 cache: Arc::new(Cache::new(0)),
92 prefix: String::new(),
93 hits: Arc::new(AtomicU64::new(0)),
94 misses: Arc::new(AtomicU64::new(0)),
95 }
96 }
97
98 pub fn with_key_prefix(&self, prefix: &str) -> Self {
106 Self {
107 cache: self.cache.clone(),
108 prefix: format!("{}{}/", self.prefix, prefix),
109 hits: self.hits.clone(),
110 misses: self.misses.clone(),
111 }
112 }
113
114 fn get_key(&self, key: &str) -> String {
115 if self.prefix.is_empty() {
116 key.to_string()
117 } else {
118 format!("{}/{}", self.prefix, key)
119 }
120 }
121
122 pub fn invalidate_prefix(&self, prefix: &str) {
127 let full_prefix = format!("{}{}", self.prefix, prefix);
128 self.cache
129 .invalidate_entries_if(move |(key, _typeid), _value| key.starts_with(&full_prefix))
130 .expect("Cache configured correctly");
131 }
132
133 pub fn size(&self) -> usize {
134 self.cache.run_pending_tasks();
135 self.cache.entry_count() as usize
136 }
137
138 pub fn approx_size(&self) -> usize {
139 self.cache.entry_count() as usize
140 }
141
142 pub fn size_bytes(&self) -> usize {
143 self.cache.run_pending_tasks();
144 self.approx_size_bytes()
145 }
146
147 pub fn approx_size_bytes(&self) -> usize {
148 self.cache.weighted_size() as usize
149 }
150
151 pub fn insert<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str, metadata: Arc<T>) {
152 let key = self.get_key(key);
153 self.cache
154 .insert((key, TypeId::of::<T>()), SizedRecord::new(metadata));
155 }
156
157 pub fn insert_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
158 &self,
159 key: &str,
160 metadata: Arc<T>,
161 ) {
162 self.insert(key, Arc::new(metadata))
164 }
165
166 pub fn get<T: DeepSizeOf + Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
167 let key = self.get_key(key);
168 if let Some(metadata) = self.cache.get(&(key, TypeId::of::<T>())) {
169 self.hits.fetch_add(1, Ordering::Relaxed);
170 Some(metadata.record.clone().downcast::<T>().unwrap())
171 } else {
172 self.misses.fetch_add(1, Ordering::Relaxed);
173 None
174 }
175 }
176
177 pub fn get_unsized<T: DeepSizeOf + Send + Sync + 'static + ?Sized>(
178 &self,
179 key: &str,
180 ) -> Option<Arc<T>> {
181 let outer = self.get::<Arc<T>>(key)?;
182 Some(outer.as_ref().clone())
183 }
184
185 pub async fn get_or_insert<T: DeepSizeOf + Send + Sync + 'static, F, Fut>(
191 &self,
192 key: String,
193 loader: F,
194 ) -> Result<Arc<T>>
195 where
196 F: FnOnce(&str) -> Fut,
197 Fut: Future<Output = Result<T>>,
198 {
199 let full_key = self.get_key(&key);
200 if let Some(metadata) = self.cache.get(&(full_key, TypeId::of::<T>())) {
201 self.hits.fetch_add(1, Ordering::Relaxed);
202 return Ok(metadata.record.clone().downcast::<T>().unwrap());
203 }
204
205 self.misses.fetch_add(1, Ordering::Relaxed);
206 let metadata = Arc::new(loader(&key).await?);
207 self.insert(&key, metadata.clone());
208 Ok(metadata)
209 }
210
211 pub fn stats(&self) -> CacheStats {
212 CacheStats {
213 hits: self.hits.load(Ordering::Relaxed),
214 misses: self.misses.load(Ordering::Relaxed),
215 }
216 }
217}
218
/// Point-in-time snapshot of a cache's hit and miss counters.
///
/// Implements `PartialEq`/`Eq` so snapshots can be compared directly, and
/// `Default` for an all-zero snapshot.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that did not find an entry.
    pub misses: u64,
}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `cache` currently reports the expected hit and miss counts.
    #[track_caller]
    fn assert_stats(cache: &LanceCache, expected_hits: u64, expected_misses: u64) {
        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, expected_hits);
        assert_eq!(snapshot.misses, expected_misses);
    }

    #[test]
    fn test_cache_bytes() {
        // Measure a single entry, then size the cache to hold ten of them.
        let sample: Arc<Vec<i32>> = Arc::new(vec![1, 2, 3]);
        let item_size = sample.deep_size_of();
        let capacity = 10 * item_size;

        let cache = LanceCache::with_capacity(capacity);
        assert_eq!(cache.size_bytes(), 0);
        assert_eq!(cache.approx_size_bytes(), 0);

        let item: Arc<Vec<i32>> = Arc::new(vec![1, 2, 3]);
        cache.insert("key", Arc::clone(&item));
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.size_bytes(), item_size);
        assert_eq!(cache.approx_size_bytes(), item_size);

        let retrieved = cache.get::<Vec<i32>>("key").unwrap();
        assert_eq!(*retrieved, *item);

        // Insert twice as many entries as fit; the cache must stay at capacity.
        for i in 0..20 {
            let key = format!("key_{}", i);
            cache.insert(&key, Arc::new(vec![i, i, i]));
        }
        assert_eq!(cache.size_bytes(), capacity);
        assert_eq!(cache.size(), 10);
    }

    #[test]
    fn test_cache_trait_objects() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {
            fn as_any(&self) -> &dyn Any;
        }

        impl MyTrait for MyType {
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        let concrete = Arc::new(MyType(42));
        let erased: Arc<dyn MyTrait> = concrete;

        let cache = LanceCache::with_capacity(1000);
        cache.insert_unsized("test", erased);

        // The trait object must round-trip and still downcast to the concrete type.
        let fetched = cache.get_unsized::<dyn MyTrait>("test").unwrap();
        let recovered = fetched.as_any().downcast_ref::<MyType>().unwrap();
        assert_eq!(recovered.0, 42);
    }

    #[test]
    fn test_cache_stats_basic() {
        let cache = LanceCache::with_capacity(1000);

        // A fresh cache has no recorded lookups.
        assert_stats(&cache, 0, 0);

        // A lookup of a missing key counts as a miss.
        assert!(cache.get::<Vec<i32>>("nonexistent").is_none());
        assert_stats(&cache, 0, 1);

        // A lookup after an insert counts as a hit.
        cache.insert("key1", Arc::new(vec![1, 2, 3]));
        assert!(cache.get::<Vec<i32>>("key1").is_some());
        assert_stats(&cache, 1, 1);

        // A repeated lookup counts as another hit.
        assert!(cache.get::<Vec<i32>>("key1").is_some());
        assert_stats(&cache, 2, 1);

        // A lookup of another missing key counts as another miss.
        assert!(cache.get::<Vec<i32>>("nonexistent2").is_none());
        assert_stats(&cache, 2, 2);
    }

    #[test]
    fn test_cache_stats_with_prefixes() {
        let base_cache = LanceCache::with_capacity(1000);
        let prefixed_cache = base_cache.with_key_prefix("test");

        // Both handles start with zeroed counters.
        assert_stats(&base_cache, 0, 0);
        assert_stats(&prefixed_cache, 0, 0);

        // A miss through the prefixed handle is visible through both handles.
        assert!(prefixed_cache.get::<Vec<i32>>("key1").is_none());
        assert_stats(&base_cache, 0, 1);
        assert_stats(&prefixed_cache, 0, 1);

        // A hit through the prefixed handle is visible through both handles.
        prefixed_cache.insert("key1", Arc::new(vec![1, 2, 3]));
        assert!(prefixed_cache.get::<Vec<i32>>("key1").is_some());
        assert_stats(&base_cache, 1, 1);
        assert_stats(&prefixed_cache, 1, 1);
    }

    #[test]
    fn test_cache_stats_unsized() {
        #[derive(Debug, DeepSizeOf)]
        struct MyType(i32);

        trait MyTrait: DeepSizeOf + Send + Sync + Any {}

        impl MyTrait for MyType {}

        let cache = LanceCache::with_capacity(1000);

        // A missing trait object counts as exactly one miss.
        assert!(cache.get_unsized::<dyn MyTrait>("test").is_none());
        assert_stats(&cache, 0, 1);

        let concrete = Arc::new(MyType(42));
        let erased: Arc<dyn MyTrait> = concrete;
        cache.insert_unsized("test", erased);

        // A stored trait object counts as exactly one hit.
        assert!(cache.get_unsized::<dyn MyTrait>("test").is_some());
        assert_stats(&cache, 1, 1);
    }

    #[tokio::test]
    async fn test_cache_stats_get_or_insert() {
        let cache = LanceCache::with_capacity(1000);

        // First call: the key is absent, so the loader runs and a miss is recorded.
        let first: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async { Ok(vec![1, 2, 3]) })
            .await
            .unwrap();
        assert_eq!(*first, vec![1, 2, 3]);
        assert_stats(&cache, 0, 1);

        // Second call: the key is present, so the loader must not run.
        let second: Arc<Vec<i32>> = cache
            .get_or_insert("key1".to_string(), |_key| async {
                panic!("Should not be called")
            })
            .await
            .unwrap();
        assert_eq!(*second, vec![1, 2, 3]);
        assert_stats(&cache, 1, 1);

        // A different key is loaded and recorded as another miss.
        let third: Arc<Vec<i32>> = cache
            .get_or_insert("key2".to_string(), |_key| async { Ok(vec![4, 5, 6]) })
            .await
            .unwrap();
        assert_eq!(*third, vec![4, 5, 6]);
        assert_stats(&cache, 1, 2);
    }
}