ballista_cache/loading_cache/
driver.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Main data structure, see [`CacheDriver`].
19
20use crate::backend::CacheBackend;
21use crate::loading_cache::{
22    cancellation_safe_future::CancellationSafeFuture,
23    loader::CacheLoader,
24    {CacheGetStatus, LoadingCache},
25};
26use async_trait::async_trait;
27use futures::future::{BoxFuture, Shared};
28use futures::{FutureExt, TryFutureExt};
29use log::debug;
30use parking_lot::Mutex;
31use std::collections::HashMap;
32use std::fmt::Debug;
33use std::future::Future;
34use std::hash::Hash;
35use std::sync::Arc;
36use tokio::{
37    sync::oneshot::{error::RecvError, Sender},
38    task::JoinHandle,
39};
40
41/// Combine a [`CacheBackend`] and a [`Loader`] into a single [`Cache`]
42#[derive(Debug)]
43pub struct CacheDriver<K, V, L>
44where
45    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
46    V: Clone + Debug + Send + 'static,
47    L: CacheLoader<K = K, V = V>,
48{
49    state: Arc<Mutex<CacheState<K, V>>>,
50    loader: Arc<L>,
51}
52
53impl<K, V, L> CacheDriver<K, V, L>
54where
55    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
56    V: Clone + Debug + Send + 'static,
57    L: CacheLoader<K = K, V = V>,
58{
59    /// Create new, empty cache with given loader function.
60    pub fn new(backend: CacheBackend<K, V>, loader: Arc<L>) -> Self {
61        Self {
62            state: Arc::new(Mutex::new(CacheState {
63                cached_entries: backend,
64                loaders: HashMap::new(),
65                next_loader_tag: 0,
66            })),
67            loader,
68        }
69    }
70}
71
72#[async_trait]
73impl<K, V, L> LoadingCache for CacheDriver<K, V, L>
74where
75    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
76    V: Clone + Debug + Send + 'static,
77    L: CacheLoader<K = K, V = V>,
78{
79    type K = K;
80    type V = V;
81    type GetExtra = L::Extra;
82
83    fn get_if_present(&self, k: Self::K) -> Option<Self::V> {
84        self.state.lock().cached_entries.get(&k)
85    }
86
87    async fn get_with_status(
88        &self,
89        k: Self::K,
90        extra: Self::GetExtra,
91    ) -> (Self::V, CacheGetStatus) {
92        // place state locking into its own scope so it doesn't leak into the generator (async
93        // function)
94        let (fut, receiver, status) = {
95            let mut state = self.state.lock();
96
97            // check if the entry has already been cached
98            if let Some(v) = state.cached_entries.get(&k) {
99                return (v, CacheGetStatus::Hit);
100            }
101
102            // check if there is already a running loader for this key
103            if let Some(loader) = state.loaders.get(&k) {
104                (
105                    None,
106                    loader.recv.clone(),
107                    CacheGetStatus::MissAlreadyLoading,
108                )
109            } else {
110                // generate unique tag
111                let loader_tag = state.next_loader_tag();
112
113                // requires new loader
114                let (fut, loader) = create_value_loader(
115                    self.state.clone(),
116                    self.loader.clone(),
117                    loader_tag,
118                    k.clone(),
119                    extra,
120                );
121
122                let receiver = loader.recv.clone();
123                state.loaders.insert(k, loader);
124
125                (Some(fut), receiver, CacheGetStatus::Miss)
126            }
127        };
128
129        // try to run the loader future in this very task context to avoid spawning tokio tasks (which adds latency and
130        // overhead)
131        if let Some(fut) = fut {
132            fut.await;
133        }
134
135        let v = retrieve_from_shared(receiver).await;
136
137        (v, status)
138    }
139
140    async fn put(&self, k: Self::K, v: Self::V) {
141        let maybe_join_handle = {
142            let mut state = self.state.lock();
143
144            let maybe_recv = if let Some(loader) = state.loaders.remove(&k) {
145                // it's OK when the receiver side is gone (likely panicked)
146                loader.set.send(v.clone()).ok();
147
148                // When we side-load data into the running task, the task does NOT modify the
149                // backend, so we have to do that. The reason for not letting the task feed the
150                // side-loaded data back into `cached_entries` is that we would need to drop the
151                // state lock here before the task could acquire it, leading to a lock gap.
152                Some(loader.recv)
153            } else {
154                None
155            };
156
157            state.cached_entries.put(k, v);
158
159            maybe_recv
160        };
161
162        // drive running loader (if any) to completion
163        if let Some(recv) = maybe_join_handle {
164            // we do not care if the loader died (e.g. due to a panic)
165            recv.await.ok();
166        }
167    }
168
169    fn invalidate(&self, k: Self::K) {
170        let mut state = self.state.lock();
171
172        if state.loaders.remove(&k).is_some() {
173            debug!("Running loader for key {:?} is removed", k);
174        }
175
176        state.cached_entries.remove(&k);
177    }
178}
179
180impl<K, V, L> Drop for CacheDriver<K, V, L>
181where
182    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
183    V: Clone + Debug + Send + 'static,
184    L: CacheLoader<K = K, V = V>,
185{
186    fn drop(&mut self) {
187        for (_k, loader) in self.state.lock().loaders.drain() {
188            // It's unlikely that anyone is still using the shared receiver at this point, because
189            // `Cache::get` borrows the `self`. If it is still in use, aborting the task will
190            // cancel the contained future which in turn will drop the sender of the oneshot
191            // channel. The receivers will be notified.
192            let handle = loader.join_handle.lock();
193            if let Some(handle) = handle.as_ref() {
194                handle.abort();
195            }
196        }
197    }
198}
199
200fn create_value_loader<K, V, Extra>(
201    state: Arc<Mutex<CacheState<K, V>>>,
202    loader: Arc<dyn CacheLoader<K = K, V = V, Extra = Extra>>,
203    loader_tag: u64,
204    k: K,
205    extra: Extra,
206) -> (
207    CancellationSafeFuture<impl Future<Output = ()>>,
208    ValueLoader<V>,
209)
210where
211    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
212    V: Clone + Debug + Send + 'static,
213    Extra: Debug + Send + 'static,
214{
215    let (tx_main, rx_main) = tokio::sync::oneshot::channel();
216    let receiver = rx_main
217        .map_ok(|v| Arc::new(Mutex::new(v)))
218        .map_err(Arc::new)
219        .boxed()
220        .shared();
221    let (tx_set, rx_set) = tokio::sync::oneshot::channel();
222
223    // need to wrap the loader into a `CancellationSafeFuture` so that it doesn't get cancelled when
224    // this very request is cancelled
225    let join_handle_receiver = Arc::new(Mutex::new(None));
226    let fut = async move {
227        let loader_fut = async move {
228            let mut submitter = ResultSubmitter::new(state, k.clone(), loader_tag);
229
230            // execute the loader
231            // If we panic here then `tx` will be dropped and the receivers will be
232            // notified.
233            let v = loader.load(k, extra).await;
234
235            // remove "running" state and store result
236            let was_running = submitter.submit(v.clone());
237
238            if !was_running {
239                // value was side-loaded, so we cannot populate `v`. Instead block this
240                // execution branch and wait for `rx_set` to deliver the side-loaded
241                // result.
242                loop {
243                    tokio::task::yield_now().await;
244                }
245            }
246
247            v
248        };
249
250        // prefer the side-loader
251        let v = futures::select_biased! {
252            maybe_v = rx_set.fuse() => {
253                match maybe_v {
254                    Ok(v) => {
255                        // data get side-loaded via `Cache::set`. In this case, we do
256                        // NOT modify the state because there would be a lock-gap. The
257                        // `set` function will do that for us instead.
258                        v
259                    }
260                    Err(_) => {
261                        // sender side is gone, very likely the cache is shutting down
262                        debug!(
263                            "Sender for side-loading data into running loader gone.",
264                        );
265                        return;
266                    }
267                }
268            }
269            v = loader_fut.fuse() => v,
270        };
271
272        // broadcast result
273        // It's OK if the receiver side is gone. This might happen during shutdown
274        tx_main.send(v).ok();
275    };
276    let fut = CancellationSafeFuture::new(fut, Arc::clone(&join_handle_receiver));
277
278    (
279        fut,
280        ValueLoader {
281            recv: receiver,
282            set: tx_set,
283            join_handle: join_handle_receiver,
284            tag: loader_tag,
285        },
286    )
287}
288
289/// Inner cache state that is usually guarded by a lock.
290///
291/// The state parts must be updated in a consistent manner, i.e. while using the same lock guard.
292#[derive(Debug)]
293struct CacheState<K, V>
294where
295    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
296    V: Clone + Debug + Send + 'static,
297{
298    /// Cached entries (i.e. queries completed).
299    cached_entries: CacheBackend<K, V>,
300
301    /// Currently value loaders indexed by cache key.
302    loaders: HashMap<K, ValueLoader<V>>,
303
304    /// Tag used for the next value loader to distinguish loaders for the same key
305    /// (e.g. when starting, side-loading, starting again)
306    next_loader_tag: u64,
307}
308
309impl<K, V> CacheState<K, V>
310where
311    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
312    V: Clone + Debug + Send + 'static,
313{
314    /// To avoid overflow issue, it will begin from 0. It will rarely happen that
315    /// two value loaders share the same key and tag while for different purposes
316    #[inline]
317    fn next_loader_tag(&mut self) -> u64 {
318        let ret = self.next_loader_tag;
319        if self.next_loader_tag != u64::MAX {
320            self.next_loader_tag += 1;
321        } else {
322            self.next_loader_tag = 0;
323        }
324        ret
325    }
326}
327
/// State for coordinating the execution of a single value loader.
///
/// Kept in `CacheState::loaders` for as long as the loader is registered for its key.
#[derive(Debug)]
struct ValueLoader<V> {
    /// A cloneable receiver that resolves with the value the loader broadcasts.
    recv: SharedReceiver<V>,

    /// A sender that enables setting entries while the query is running.
    ///
    /// `put` sends the new value here to side-load it into the running loader. If the
    /// `ValueLoader` is dropped without sending (i.e. removed from the state), the loader's
    /// receiving end resolves with an error instead.
    set: Sender<V>,

    /// A handle for the task that is currently loading the value.
    ///
    /// The handle can be used to abort the loading, e.g. when dropping the cache.
    /// NOTE(review): shared with the `CancellationSafeFuture` built in `create_value_loader`;
    /// presumably it stays `None` until that future is moved onto a spawned task — confirm
    /// against `cancellation_safe_future`.
    join_handle: Arc<Mutex<Option<JoinHandle<()>>>>,

    /// Tag so that loaders for the same key (e.g. when starting, side-loading, starting again) can
    /// be told apart.
    tag: u64,
}
346
/// A [`tokio::sync::oneshot::Receiver`] that can be cloned and awaited from several places.
///
/// The layers of the type are:
///
/// - `Arc<Mutex<V>>`: lets every awaiting clone obtain `V` via `V: Clone` without requiring
///   `V: Sync`, while the outer `Arc` keeps the output cheaply cloneable, as `Shared` requires.
/// - `Arc<RecvError>`: `RecvError` is not `Clone`, but `Shared` requires a cloneable output.
/// - `BoxFuture`: erases the unwieldy concrete type produced by mapping
///   `Result<V, RecvError>` to `Result<Arc<Mutex<V>>, Arc<RecvError>>`.
/// - `Shared`: allows the receiver to be cloned and awaited from multiple places.
type SharedReceiver<V> =
    Shared<BoxFuture<'static, Result<Arc<Mutex<V>>, Arc<RecvError>>>>;
359
360/// Retrieve data from shared receiver.
361async fn retrieve_from_shared<V>(receiver: SharedReceiver<V>) -> V
362where
363    V: Clone + Send,
364{
365    receiver
366        .await
367        .expect("cache loader panicked, see logs")
368        .lock()
369        .clone()
370}
371
/// Helper to submit results of running queries.
///
/// Ensures that running loader is removed when dropped (e.g. during panic).
struct ResultSubmitter<K, V>
where
    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
    V: Clone + Debug + Send + 'static,
{
    // Shared cache state the result is written into.
    state: Arc<Mutex<CacheState<K, V>>>,
    // Tag of the loader this submitter belongs to; compared against the registered loader.
    tag: u64,
    // Key being loaded; `None` once the submitter has been finalized.
    k: Option<K>,
    // Value to store in the backend on finalize; set by `submit`.
    v: Option<V>,
}
385
386impl<K, V> ResultSubmitter<K, V>
387where
388    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
389    V: Clone + Debug + Send + 'static,
390{
391    fn new(state: Arc<Mutex<CacheState<K, V>>>, k: K, tag: u64) -> Self {
392        Self {
393            state,
394            tag,
395            k: Some(k),
396            v: None,
397        }
398    }
399
400    /// Submit value.
401    ///
402    /// Returns `true` if this very loader was running.
403    fn submit(&mut self, v: V) -> bool {
404        assert!(self.v.is_none());
405        self.v = Some(v);
406        self.finalize()
407    }
408
409    /// Finalize request.
410    ///
411    /// Returns `true` if this very loader was running.
412    fn finalize(&mut self) -> bool {
413        let k = self.k.take().expect("finalized twice");
414        let mut state = self.state.lock();
415
416        match state.loaders.get(&k) {
417            Some(loader) if loader.tag == self.tag => {
418                state.loaders.remove(&k);
419
420                if let Some(v) = self.v.take() {
421                    // this very loader is in charge of the key, so store in in the
422                    // underlying cache
423                    state.cached_entries.put(k, v);
424                }
425
426                true
427            }
428            _ => {
429                // This loader is actually not really running any longer but got
430                // shut down, e.g. due to side loading. Do NOT store the
431                // generated value in the underlying cache.
432
433                false
434            }
435        }
436    }
437}
438
439impl<K, V> Drop for ResultSubmitter<K, V>
440where
441    K: Clone + Eq + Hash + Ord + Debug + Send + 'static,
442    V: Clone + Debug + Send + 'static,
443{
444    fn drop(&mut self) {
445        if self.k.is_some() {
446            // not finalized yet
447            self.finalize();
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    use crate::backend::policy::lru::lru_cache::LruCache;
    use crate::backend::policy::lru::DefaultResourceCounter;
    use crate::listener::cache_policy::CachePolicyListener;
    use crate::loading_cache::LoadingCache;
    use crate::{CacheBackend, CacheDriver, CacheLoader, CachePolicyWithListener};
    use async_trait::async_trait;
    use parking_lot::Mutex;
    use std::sync::mpsc::{channel, Sender};
    use std::sync::Arc;

    /// Evictions and explicit invalidations must be reported to the removal listener, and the
    /// removed keys must no longer be visible through `get_if_present`.
    #[tokio::test]
    async fn test_removal_entries() {
        // LRU backend bounded by `DefaultResourceCounter::new(3)`, whose listener forwards every
        // removed entry to `removed`.
        let lru = LruCache::with_resource_counter(DefaultResourceCounter::new(3));
        let (sender, removed) = channel::<(String, String)>();
        let listener = Arc::new(EntryRemovalListener::new(sender));
        let policy = CachePolicyWithListener::new(lru, vec![listener.clone()]);
        let loader = Arc::new(TestStringCacheLoader {
            prefix: "file".to_string(),
        });
        let loading_cache = CacheDriver::new(CacheBackend::new(policy), loader);

        // Loading a fourth key evicts the least recently used one, "1".
        for idx in 1..=4 {
            let expected = format!("file{idx}");
            assert_eq!(expected, loading_cache.get(idx.to_string(), ()).await);
        }
        assert_eq!(Ok(("1".to_string(), "file1".to_string())), removed.recv());
        assert!(loading_cache.get_if_present("1".to_string()).is_none());

        // After overwriting "2", loading "5" evicts "3" rather than "2".
        loading_cache
            .put("2".to_string(), "file2-bak".to_string())
            .await;
        assert_eq!(
            "file5".to_string(),
            loading_cache.get("5".to_string(), ()).await
        );
        assert_eq!(Ok(("3".to_string(), "file3".to_string())), removed.recv());
        assert!(loading_cache.get_if_present("3".to_string()).is_none());
        assert!(loading_cache.get_if_present("2".to_string()).is_some());

        // Invalidation reports the overwritten value of "2".
        loading_cache.invalidate("2".to_string());
        assert_eq!(
            Ok(("2".to_string(), "file2-bak".to_string())),
            removed.recv()
        );
        assert!(loading_cache.get_if_present("2".to_string()).is_none());
    }

    /// Listener that forwards every removed or popped entry over an mpsc channel.
    #[derive(Debug)]
    struct EntryRemovalListener {
        sender: Arc<Mutex<Sender<(String, String)>>>,
    }

    impl EntryRemovalListener {
        pub fn new(sender: Sender<(String, String)>) -> Self {
            let sender = Arc::new(Mutex::new(sender));
            Self { sender }
        }
    }

    impl CachePolicyListener for EntryRemovalListener {
        type K = String;
        type V = String;

        fn listen_on_get(&self, _k: Self::K, _v: Option<Self::V>) {}

        fn listen_on_peek(&self, _k: Self::K, _v: Option<Self::V>) {}

        fn listen_on_put(&self, _k: Self::K, _v: Self::V, _old_v: Option<Self::V>) {}

        fn listen_on_remove(&self, k: Self::K, v: Option<Self::V>) {
            // Only removals that actually dropped a value are reported.
            if let Some(removed_v) = v {
                self.listen_on_pop((k, removed_v));
            }
        }

        fn listen_on_pop(&self, entry: (Self::K, Self::V)) {
            self.sender.lock().send(entry).unwrap();
        }
    }

    /// Loader that produces `prefix` followed by the key.
    #[derive(Debug)]
    struct TestStringCacheLoader {
        prefix: String,
    }

    #[async_trait]
    impl CacheLoader for TestStringCacheLoader {
        type K = String;
        type V = String;
        type Extra = ();

        async fn load(&self, k: Self::K, _extra: Self::Extra) -> Self::V {
            format!("{}{k}", self.prefix)
        }
    }
}