1use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7
8use async_trait::async_trait;
9use futures::Future;
10
11use crate::Result;
12
13use super::CacheCodec;
14use super::backend::{CacheBackend, CacheEntry, InternalCacheKey};
15
/// Internal wrapper stored in the moka cache: pairs the cached entry with
/// the byte size reported at insertion time.
#[derive(Clone, Debug)]
struct MokaCacheEntry {
    // The cached payload itself.
    entry: CacheEntry,
    // Weight handed to moka's weigher so the cache can enforce a
    // byte-based capacity limit (see `MokaCacheBackend::with_capacity`).
    size_bytes: usize,
}
22
/// In-memory `CacheBackend` implementation backed by `moka::future::Cache`.
pub struct MokaCacheBackend {
    // Entries are wrapped in `MokaCacheEntry` so moka can weigh them by
    // their reported byte size.
    cache: moka::future::Cache<InternalCacheKey, MokaCacheEntry>,
}
30
31impl std::fmt::Debug for MokaCacheBackend {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("MokaCacheBackend")
34 .field("entry_count", &self.cache.entry_count())
35 .finish()
36 }
37}
38
39impl MokaCacheBackend {
40 pub fn with_capacity(capacity: usize) -> Self {
41 let cache = moka::future::Cache::builder()
42 .max_capacity(capacity as u64)
43 .weigher(|_, v: &MokaCacheEntry| v.size_bytes.try_into().unwrap_or(u32::MAX))
44 .support_invalidation_closures()
45 .build();
46 Self { cache }
47 }
48
49 pub fn no_cache() -> Self {
50 Self {
51 cache: moka::future::Cache::new(0),
52 }
53 }
54}
55
56#[async_trait]
57impl CacheBackend for MokaCacheBackend {
58 async fn get(&self, key: &InternalCacheKey, _codec: Option<CacheCodec>) -> Option<CacheEntry> {
59 self.cache.get(key).await.map(|r| r.entry)
60 }
61
62 async fn insert(
63 &self,
64 key: &InternalCacheKey,
65 entry: CacheEntry,
66 size_bytes: usize,
67 _codec: Option<CacheCodec>,
68 ) {
69 self.cache
70 .insert(key.clone(), MokaCacheEntry { entry, size_bytes })
71 .await;
72 }
73
74 async fn get_or_insert<'a>(
75 &self,
76 key: &InternalCacheKey,
77 loader: Pin<Box<dyn Future<Output = Result<(CacheEntry, usize)>> + Send + 'a>>,
78 _codec: Option<CacheCodec>,
79 ) -> Result<(CacheEntry, bool)> {
80 let (error_tx, error_rx) = tokio::sync::oneshot::channel();
83
84 let was_miss = Arc::new(AtomicBool::new(false));
86 let was_miss_clone = was_miss.clone();
87
88 let init = async move {
89 was_miss_clone.store(true, Ordering::Relaxed);
90 match loader.await {
91 Ok((entry, size_bytes)) => Some(MokaCacheEntry { entry, size_bytes }),
92 Err(e) => {
93 let _ = error_tx.send(e);
94 None
95 }
96 }
97 };
98
99 let owned_key = key.clone();
100 match self.cache.optionally_get_with(owned_key, init).await {
101 Some(record) => {
102 let was_cached = !was_miss.load(Ordering::Relaxed);
103 Ok((record.entry, was_cached))
104 }
105 None => match error_rx.await {
106 Ok(err) => Err(err),
107 Err(_) => Err(crate::Error::internal(
108 "Failed to retrieve error from cache loader",
109 )),
110 },
111 }
112 }
113
114 async fn invalidate_prefix(&self, prefix: &str) {
115 let prefix = prefix.to_owned();
116 self.cache
117 .invalidate_entries_if(move |key, _value| key.starts_with(&prefix))
118 .expect("Cache configured correctly");
119 }
120
121 async fn clear(&self) {
122 self.cache.invalidate_all();
123 self.cache.run_pending_tasks().await;
124 }
125
126 async fn num_entries(&self) -> usize {
127 self.cache.run_pending_tasks().await;
128 self.cache.entry_count() as usize
129 }
130
131 async fn size_bytes(&self) -> usize {
132 self.cache.run_pending_tasks().await;
133 self.cache.weighted_size() as usize
134 }
135
136 fn approx_num_entries(&self) -> usize {
137 self.cache.entry_count() as usize
138 }
139
140 fn approx_size_bytes(&self) -> usize {
141 self.cache.iter().map(|(_, v)| v.size_bytes).sum()
145 }
146}