1use crate::backend::CacheBackend;
21use crate::loading_cache::{
22 cancellation_safe_future::CancellationSafeFuture,
23 loader::CacheLoader,
24 {CacheGetStatus, LoadingCache},
25};
26use async_trait::async_trait;
27use futures::future::{BoxFuture, Shared};
28use futures::{FutureExt, TryFutureExt};
29use log::debug;
30use parking_lot::Mutex;
31use std::collections::HashMap;
32use std::fmt::Debug;
33use std::future::Future;
34use std::hash::Hash;
35use std::sync::Arc;
36use tokio::{
37 sync::oneshot::{error::RecvError, Sender},
38 task::JoinHandle,
39};
40
/// Loading cache: serves values from a [`CacheBackend`] and, on a miss, runs the
/// [`CacheLoader`], sharing a single in-flight load between concurrent callers of the
/// same key.
#[derive(Debug)]
pub struct CacheDriver<K, V, L>
where
    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
    V: Clone + Debug + Send + 'static,
    L: CacheLoader<K = K, V = V>,
{
    /// Cached entries plus the loaders currently in flight; the `Arc` is also handed
    /// to every running loader so it can publish its result.
    state: Arc<Mutex<CacheState<K, V>>>,
    /// Produces values for keys that are not cached.
    loader: Arc<L>,
}
52
53impl<K, V, L> CacheDriver<K, V, L>
54where
55 K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
56 V: Clone + Debug + Send + 'static,
57 L: CacheLoader<K = K, V = V>,
58{
59 pub fn new(backend: CacheBackend<K, V>, loader: Arc<L>) -> Self {
61 Self {
62 state: Arc::new(Mutex::new(CacheState {
63 cached_entries: backend,
64 loaders: HashMap::new(),
65 next_loader_tag: 0,
66 })),
67 loader,
68 }
69 }
70}
71
72#[async_trait]
73impl<K, V, L> LoadingCache for CacheDriver<K, V, L>
74where
75 K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
76 V: Clone + Debug + Send + 'static,
77 L: CacheLoader<K = K, V = V>,
78{
79 type K = K;
80 type V = V;
81 type GetExtra = L::Extra;
82
83 fn get_if_present(&self, k: Self::K) -> Option<Self::V> {
84 self.state.lock().cached_entries.get(&k)
85 }
86
87 async fn get_with_status(
88 &self,
89 k: Self::K,
90 extra: Self::GetExtra,
91 ) -> (Self::V, CacheGetStatus) {
92 let (fut, receiver, status) = {
95 let mut state = self.state.lock();
96
97 if let Some(v) = state.cached_entries.get(&k) {
99 return (v, CacheGetStatus::Hit);
100 }
101
102 if let Some(loader) = state.loaders.get(&k) {
104 (
105 None,
106 loader.recv.clone(),
107 CacheGetStatus::MissAlreadyLoading,
108 )
109 } else {
110 let loader_tag = state.next_loader_tag();
112
113 let (fut, loader) = create_value_loader(
115 self.state.clone(),
116 self.loader.clone(),
117 loader_tag,
118 k.clone(),
119 extra,
120 );
121
122 let receiver = loader.recv.clone();
123 state.loaders.insert(k, loader);
124
125 (Some(fut), receiver, CacheGetStatus::Miss)
126 }
127 };
128
129 if let Some(fut) = fut {
132 fut.await;
133 }
134
135 let v = retrieve_from_shared(receiver).await;
136
137 (v, status)
138 }
139
140 async fn put(&self, k: Self::K, v: Self::V) {
141 let maybe_join_handle = {
142 let mut state = self.state.lock();
143
144 let maybe_recv = if let Some(loader) = state.loaders.remove(&k) {
145 loader.set.send(v.clone()).ok();
147
148 Some(loader.recv)
153 } else {
154 None
155 };
156
157 state.cached_entries.put(k, v);
158
159 maybe_recv
160 };
161
162 if let Some(recv) = maybe_join_handle {
164 recv.await.ok();
166 }
167 }
168
169 fn invalidate(&self, k: Self::K) {
170 let mut state = self.state.lock();
171
172 if state.loaders.remove(&k).is_some() {
173 debug!("Running loader for key {:?} is removed", k);
174 }
175
176 state.cached_entries.remove(&k);
177 }
178}
179
180impl<K, V, L> Drop for CacheDriver<K, V, L>
181where
182 K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
183 V: Clone + Debug + Send + 'static,
184 L: CacheLoader<K = K, V = V>,
185{
186 fn drop(&mut self) {
187 for (_k, loader) in self.state.lock().loaders.drain() {
188 let handle = loader.join_handle.lock();
193 if let Some(handle) = handle.as_ref() {
194 handle.abort();
195 }
196 }
197 }
198}
199
200fn create_value_loader<K, V, Extra>(
201 state: Arc<Mutex<CacheState<K, V>>>,
202 loader: Arc<dyn CacheLoader<K = K, V = V, Extra = Extra>>,
203 loader_tag: u64,
204 k: K,
205 extra: Extra,
206) -> (
207 CancellationSafeFuture<impl Future<Output = ()>>,
208 ValueLoader<V>,
209)
210where
211 K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
212 V: Clone + Debug + Send + 'static,
213 Extra: Debug + Send + 'static,
214{
215 let (tx_main, rx_main) = tokio::sync::oneshot::channel();
216 let receiver = rx_main
217 .map_ok(|v| Arc::new(Mutex::new(v)))
218 .map_err(Arc::new)
219 .boxed()
220 .shared();
221 let (tx_set, rx_set) = tokio::sync::oneshot::channel();
222
223 let join_handle_receiver = Arc::new(Mutex::new(None));
226 let fut = async move {
227 let loader_fut = async move {
228 let mut submitter = ResultSubmitter::new(state, k.clone(), loader_tag);
229
230 let v = loader.load(k, extra).await;
234
235 let was_running = submitter.submit(v.clone());
237
238 if !was_running {
239 loop {
243 tokio::task::yield_now().await;
244 }
245 }
246
247 v
248 };
249
250 let v = futures::select_biased! {
252 maybe_v = rx_set.fuse() => {
253 match maybe_v {
254 Ok(v) => {
255 v
259 }
260 Err(_) => {
261 debug!(
263 "Sender for side-loading data into running loader gone.",
264 );
265 return;
266 }
267 }
268 }
269 v = loader_fut.fuse() => v,
270 };
271
272 tx_main.send(v).ok();
275 };
276 let fut = CancellationSafeFuture::new(fut, Arc::clone(&join_handle_receiver));
277
278 (
279 fut,
280 ValueLoader {
281 recv: receiver,
282 set: tx_set,
283 join_handle: join_handle_receiver,
284 tag: loader_tag,
285 },
286 )
287}
288
/// Mutable state of a [`CacheDriver`]; it lives behind a single `Arc<Mutex<_>>`
/// shared by the driver and its running loaders.
#[derive(Debug)]
struct CacheState<K, V>
where
    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    /// Backend holding the values that are currently cached.
    cached_entries: CacheBackend<K, V>,

    /// Loaders currently in flight, at most one per key.
    loaders: HashMap<K, ValueLoader<V>>,

    /// Tag handed to the next loader. A finishing loader only writes its result if
    /// the registered loader for its key still carries its tag.
    next_loader_tag: u64,
}
308
309impl<K, V> CacheState<K, V>
310where
311 K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
312 V: Clone + Debug + Send + 'static,
313{
314 #[inline]
317 fn next_loader_tag(&mut self) -> u64 {
318 let ret = self.next_loader_tag;
319 if self.next_loader_tag != u64::MAX {
320 self.next_loader_tag += 1;
321 } else {
322 self.next_loader_tag = 0;
323 }
324 ret
325 }
326}
327
/// Handle to one in-flight load, stored in `CacheState::loaders` under its key.
#[derive(Debug)]
struct ValueLoader<V> {
    /// Shared future resolving to the value the loader publishes; every waiter clones it.
    recv: SharedReceiver<V>,

    /// Side-loads a value into the running loader (used by `put`).
    set: Sender<V>,

    /// Slot for the join handle of the task running the load. It is aborted when the
    /// driver is dropped. NOTE(review): filled in by `CancellationSafeFuture`; confirm
    /// when it becomes `Some`.
    join_handle: Arc<Mutex<Option<JoinHandle<()>>>>,

    /// Tag identifying this loader (taken from `CacheState::next_loader_tag`).
    tag: u64,
}
346
/// Cloneable handle to a loader's eventual result; all waiters for a key await a clone.
///
/// `Ok` carries the value behind `Arc<Mutex<_>>` so each waiter can clone it out
/// (see `retrieve_from_shared`). `Err` means the sending side was dropped without a value.
/// NOTE(review): the `Arc`/`Mutex` wrapping presumably exists because `Shared` needs a
/// `Clone` (and, for `Sync`, a `Sync`) output while `V` is only bounded by `Send`. Confirm.
type SharedReceiver<V> =
    Shared<BoxFuture<'static, Result<Arc<Mutex<V>>, Arc<RecvError>>>>;
359
360async fn retrieve_from_shared<V>(receiver: SharedReceiver<V>) -> V
362where
363 V: Clone + Send,
364{
365 receiver
366 .await
367 .expect("cache loader panicked, see logs")
368 .lock()
369 .clone()
370}
371
/// Publishes a finished load into the [`CacheState`], but only while the loader it
/// belongs to is still registered for its key under the same tag.
///
/// If it is dropped before anything was submitted, the `Drop` impl still unregisters
/// the loader, and nothing is cached.
struct ResultSubmitter<K, V>
where
    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    /// Shared driver state.
    state: Arc<Mutex<CacheState<K, V>>>,
    /// Tag of the loader this submitter belongs to.
    tag: u64,
    /// Key being loaded; `None` once finalized.
    k: Option<K>,
    /// Value waiting to be written; `None` until `submit` is called.
    v: Option<V>,
}
385
386impl<K, V> ResultSubmitter<K, V>
387where
388 K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
389 V: Clone + Debug + Send + 'static,
390{
391 fn new(state: Arc<Mutex<CacheState<K, V>>>, k: K, tag: u64) -> Self {
392 Self {
393 state,
394 tag,
395 k: Some(k),
396 v: None,
397 }
398 }
399
400 fn submit(&mut self, v: V) -> bool {
404 assert!(self.v.is_none());
405 self.v = Some(v);
406 self.finalize()
407 }
408
409 fn finalize(&mut self) -> bool {
413 let k = self.k.take().expect("finalized twice");
414 let mut state = self.state.lock();
415
416 match state.loaders.get(&k) {
417 Some(loader) if loader.tag == self.tag => {
418 state.loaders.remove(&k);
419
420 if let Some(v) = self.v.take() {
421 state.cached_entries.put(k, v);
424 }
425
426 true
427 }
428 _ => {
429 false
434 }
435 }
436 }
437}
438
439impl<K, V> Drop for ResultSubmitter<K, V>
440where
441 K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
442 V: Clone + Debug + Send + 'static,
443{
444 fn drop(&mut self) {
445 if self.k.is_some() {
446 self.finalize();
448 }
449 }
450}
451
#[cfg(test)]
mod tests {

    use crate::backend::policy::lru::lru_cache::LruCache;
    use crate::listener::cache_policy::CachePolicyListener;
    use crate::{CacheBackend, CacheDriver, CacheLoader, CachePolicyWithListener};

    use crate::backend::policy::lru::DefaultResourceCounter;
    use crate::loading_cache::LoadingCache;
    use async_trait::async_trait;
    use parking_lot::Mutex;
    use std::sync::mpsc::{channel, Sender};
    use std::sync::Arc;

    /// Drives a loading cache over an LRU backend (`DefaultResourceCounter::new(3)`).
    /// It checks that eviction, `invalidate` and refreshing via `put` behave as expected,
    /// and that each removal is reported to the listener.
    #[tokio::test]
    async fn test_removal_entries() {
        let cache_policy =
            LruCache::with_resource_counter(DefaultResourceCounter::new(3));
        let loader = TestStringCacheLoader {
            prefix: "file".to_string(),
        };
        let (sender, receiver) = channel::<(String, String)>();
        let listener = Arc::new(EntryRemovalListener::new(sender));
        let policy_with_listener =
            CachePolicyWithListener::new(cache_policy, vec![listener.clone()]);
        let cache_backend = CacheBackend::new(policy_with_listener);
        let loading_cache = CacheDriver::new(cache_backend, Arc::new(loader));

        // Load keys "1".."4"; the fourth load evicts the least recently used key "1".
        for i in 1..=4 {
            assert_eq!(
                format!("file{i}"),
                loading_cache.get(i.to_string(), ()).await
            );
        }
        assert_eq!(Ok(("1".to_string(), "file1".to_string())), receiver.recv());
        assert!(loading_cache.get_if_present("1".to_string()).is_none());

        // Overwriting "2" makes it most recently used, so loading "5" evicts "3".
        loading_cache
            .put("2".to_string(), "file2-bak".to_string())
            .await;
        assert_eq!(
            "file5".to_string(),
            loading_cache.get("5".to_string(), ()).await
        );
        assert_eq!(Ok(("3".to_string(), "file3".to_string())), receiver.recv());
        assert!(loading_cache.get_if_present("3".to_string()).is_none());
        assert!(loading_cache.get_if_present("2".to_string()).is_some());

        // Invalidating "2" reports its current (overwritten) value.
        loading_cache.invalidate("2".to_string());
        assert_eq!(
            Ok(("2".to_string(), "file2-bak".to_string())),
            receiver.recv()
        );
        assert!(loading_cache.get_if_present("2".to_string()).is_none());
    }

    /// Listener that forwards every removed `(key, value)` pair over a channel.
    #[derive(Debug)]
    struct EntryRemovalListener {
        sender: Arc<Mutex<Sender<(String, String)>>>,
    }

    impl EntryRemovalListener {
        pub fn new(sender: Sender<(String, String)>) -> Self {
            let sender = Arc::new(Mutex::new(sender));
            Self { sender }
        }
    }

    impl CachePolicyListener for EntryRemovalListener {
        type K = String;
        type V = String;

        fn listen_on_get(&self, _k: Self::K, _v: Option<Self::V>) {}

        fn listen_on_peek(&self, _k: Self::K, _v: Option<Self::V>) {}

        fn listen_on_put(&self, _k: Self::K, _v: Self::V, _old_v: Option<Self::V>) {}

        fn listen_on_remove(&self, k: Self::K, v: Option<Self::V>) {
            // Only removals that actually dropped a value are reported.
            if let Some(v) = v {
                self.listen_on_pop((k, v));
            }
        }

        fn listen_on_pop(&self, entry: (Self::K, Self::V)) {
            self.sender.lock().send(entry).unwrap();
        }
    }

    /// Loader that produces `prefix` followed by the key.
    #[derive(Debug)]
    struct TestStringCacheLoader {
        prefix: String,
    }

    #[async_trait]
    impl CacheLoader for TestStringCacheLoader {
        type K = String;
        type V = String;
        type Extra = ();

        async fn load(&self, k: Self::K, _extra: Self::Extra) -> Self::V {
            let mut value = self.prefix.clone();
            value.push_str(&k);
            value
        }
    }
}