cache_loader_async/
internal_cache.rs

1use std::hash::Hash;
2use futures::Future;
3use tokio::task::JoinHandle;
4use crate::cache_api::{CacheResult, CacheLoadingError, CacheEntry, CacheCommunicationError, DataWithMeta};
5use crate::backing::CacheBacking;
6use std::fmt::Debug;
7
/// Unwraps the `Ok` value of a backing-store `Result`.
///
/// On `Err`, the *enclosing function* returns early with
/// `CacheResult::Error(err)`, so this macro may only be used inside functions
/// whose return type is a `CacheResult` with a matching error payload.
macro_rules! unwrap_backing {
    ($backing_result:expr) => {
        match $backing_result {
            Err(err) => return CacheResult::Error(err),
            Ok(data) => data,
        }
    };
}
16
17pub(crate) enum CacheAction<
18    K: Clone + Eq + Hash + Send,
19    V: Clone + Sized + Send,
20    E: Debug + Clone + Send,
21    B: CacheBacking<K, CacheEntry<V, E>>
22> {
23    GetIfPresent(K),
24    Get(K),
25    Set(K, V, Option<B::Meta>),
26    Update(K, Option<B::Meta>, Box<dyn FnOnce(V) -> V + Send + 'static>, bool),
27    UpdateMut(K, Box<dyn FnMut(&mut V) -> () + Send + 'static>, bool),
28    Remove(K),
29    RemoveIf(Box<dyn Fn((&K, Option<&V>)) -> bool + Send + Sync + 'static>),
30    Clear(),
31    // Internal use
32    SetAndUnblock(K, V, Option<B::Meta>),
33    Unblock(K),
34}
35
36pub(crate) struct CacheMessage<
37    K: Eq + Hash + Clone + Send,
38    V: Clone + Sized + Send,
39    E: Debug + Clone + Send,
40    B: CacheBacking<K, CacheEntry<V, E>>
41> {
42    pub(crate) action: CacheAction<K, V, E, B>,
43    pub(crate) response: tokio::sync::oneshot::Sender<CacheResult<V, E>>,
44}
45
46pub(crate) struct InternalCacheStore<
47    K: Clone + Eq + Hash + Send,
48    V: Clone + Sized + Send,
49    T,
50    E: Debug + Clone + Send,
51    B: CacheBacking<K, CacheEntry<V, E>>
52> {
53    tx: tokio::sync::mpsc::Sender<CacheMessage<K, V, E, B>>,
54    data: B,
55    loader: T,
56}
57
58impl<
59    K: Eq + Hash + Clone + Send + 'static,
60    V: Clone + Sized + Send + 'static,
61    E: Clone + Sized + Send + Debug + 'static,
62    F: Future<Output=Result<DataWithMeta<K, V, E, B>, E>> + Sized + Send + 'static,
63    T: Fn(K) -> F + Send + 'static,
64    B: CacheBacking<K, CacheEntry<V, E>> + Send + 'static
65> InternalCacheStore<K, V, T, E, B>
66{
67    pub fn new(
68        backing: B,
69        tx: tokio::sync::mpsc::Sender<CacheMessage<K, V, E, B>>,
70        loader: T,
71    ) -> Self {
72        Self {
73            tx,
74            data: backing,
75            loader,
76        }
77    }
78
79    pub(crate) fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<CacheMessage<K, V, E, B>>) -> JoinHandle<()> {
80        tokio::spawn(async move {
81            loop {
82                if let Some(message) = rx.recv().await {
83                    let result = match message.action {
84                        CacheAction::GetIfPresent(key) => self.get_if_present(key),
85                        CacheAction::Get(key) => self.get(key),
86                        CacheAction::Set(key, value, meta) => self.set(key, value, false, meta),
87                        CacheAction::Update(key, meta, update_fn, load) => self.update(key, update_fn, load, meta),
88                        CacheAction::UpdateMut(key, update_mut_fn, load) => self.update_mut(key, update_mut_fn, load),
89                        CacheAction::Remove(key) => self.remove(key),
90                        CacheAction::RemoveIf(predicate) => self.remove_if(predicate),
91                        CacheAction::Clear() => self.clear(),
92                        CacheAction::SetAndUnblock(key, value, meta) => self.set(key, value, true, meta),
93                        CacheAction::Unblock(key) => self.unblock(key),
94                    };
95                    message.response.send(result).ok();
96                }
97            }
98        })
99    }
100
101    fn unblock(&mut self, key: K) -> CacheResult<V, E>{
102        if let Some(entry) = unwrap_backing!(self.data.get(&key)) {
103            if let CacheEntry::Loading(_) = entry {
104                if let Some(entry) = unwrap_backing!(self.data.remove(&key)) {
105                    if let CacheEntry::Loading(waiter) = entry {
106                        std::mem::drop(waiter) // dropping the sender closes the channel
107                    }
108                }
109            }
110        }
111        CacheResult::None
112    }
113
114    fn remove(&mut self, key: K) -> CacheResult<V, E> {
115        if let Some(entry) = unwrap_backing!(self.data.remove(&key)) {
116            match entry {
117                CacheEntry::Loaded(data) => CacheResult::Found(data),
118                CacheEntry::Loading(_) => CacheResult::None
119            }
120        } else {
121            CacheResult::None
122        }
123    }
124
125    fn remove_if(&mut self, predicate: Box<dyn Fn((&K, Option<&V>)) -> bool + Send + Sync + 'static>) -> CacheResult<V, E> {
126        unwrap_backing!(self.data.remove_if(self.to_predicate(predicate)));
127        CacheResult::None
128    }
129
130    fn to_predicate(&self, predicate: Box<dyn Fn((&K, Option<&V>)) -> bool + Send + Sync + 'static>)
131                    -> Box<dyn Fn((&K, &CacheEntry<V, E>)) -> bool + Send + Sync + 'static> {
132        Box::new(move |(key, value)| {
133            match value {
134                CacheEntry::Loaded(value) => {
135                    predicate((key, Some(value)))
136                }
137                CacheEntry::Loading(_) => {
138                    predicate((key, None))
139                }
140            }
141        })
142    }
143
144    fn clear(&mut self) -> CacheResult<V, E> {
145        unwrap_backing!(self.data.clear());
146        CacheResult::None
147    }
148
149    fn update_mut(&mut self, key: K, mut update_mut_fn: Box<dyn FnMut(&mut V) -> () + Send + 'static>, load: bool) -> CacheResult<V, E> {
150        match unwrap_backing!(self.data.get_mut(&key)) {
151            Some(entry) => {
152                match entry {
153                    CacheEntry::Loaded(data) => {
154                        update_mut_fn(data);
155                        CacheResult::Found(data.clone())
156                    }
157                    CacheEntry::Loading(waiter) => {
158                        let mut rx = waiter.subscribe();
159                        let cache_tx = self.tx.clone();
160                        CacheResult::Loading(tokio::spawn(async move {
161                            rx.recv().await.ok(); // result confirmed
162                            let (response_tx, response_rx) = tokio::sync::oneshot::channel();
163                            cache_tx.send(CacheMessage {
164                                action: CacheAction::UpdateMut(key, update_mut_fn, load),
165                                response: response_tx,
166                            }).await.ok();
167                            match response_rx.await.unwrap() {
168                                CacheResult::Found(data) => Ok(data),
169                                _ => Err(CacheLoadingError::NoData())
170                            }
171                        }))
172                    }
173                }
174            }
175            None => {
176                if load {
177                    let result = self.get(key.clone());
178                    match result {
179                        CacheResult::Loading(waiter) => {
180                            let cache_tx = self.tx.clone();
181                            CacheResult::Loading(tokio::spawn(async move {
182                                waiter.await.ok(); // result confirmed
183                                let (response_tx, response_rx) = tokio::sync::oneshot::channel();
184                                cache_tx.send(CacheMessage {
185                                    action: CacheAction::UpdateMut(key, update_mut_fn, load),
186                                    response: response_tx,
187                                }).await.ok();
188                                match response_rx.await.unwrap() {
189                                    CacheResult::Found(data) => Ok(data),
190                                    _ => Err(CacheLoadingError::NoData())
191                                }
192                            }))
193                        }
194                        _ => CacheResult::None,
195                    }
196                } else {
197                    CacheResult::None
198                }
199            }
200        }
201    }
202
203    fn update(&mut self, key: K, update_fn: Box<dyn FnOnce(V) -> V + Send + 'static>, load: bool, meta: Option<B::Meta>) -> CacheResult<V, E> {
204        let data = if load {
205            self.get(key.clone())
206        } else {
207            self.get_if_present(key.clone())
208        };
209
210        match data {
211            CacheResult::Found(data) => {
212                let updated_data = update_fn(data);
213                unwrap_backing!(self.data.set(key, CacheEntry::Loaded(updated_data.clone()), meta));
214                CacheResult::Found(updated_data)
215            }
216            CacheResult::Loading(handle) => {
217                let tx = self.tx.clone();
218                CacheResult::Loading(tokio::spawn(async move {
219                    handle.await.ok(); // set stupidly await the load to be done
220                    // we let the set logic take place which is called from within the future
221                    // and we're invoking a second update on the (now cached) data
222                    let (response_tx, rx) = tokio::sync::oneshot::channel();
223                    tx.send(CacheMessage {
224                        action: CacheAction::Update(key, meta, update_fn, load),
225                        response: response_tx,
226                    }).await.ok();
227                    match rx.await {
228                        Ok(result) => {
229                            match result {
230                                CacheResult::Found(data) => Ok(data),
231                                CacheResult::Loading(_) => Err(CacheLoadingError::CommunicationError(CacheCommunicationError::LookupLoop())),
232                                CacheResult::None => Err(CacheLoadingError::CommunicationError(CacheCommunicationError::LookupLoop())),
233                                CacheResult::Error(err) => Err(CacheLoadingError::BackingError(err)),
234                            }
235                        }
236                        Err(err) => Err(CacheLoadingError::CommunicationError(CacheCommunicationError::TokioOneshotRecvError(err))),
237                    }
238                }))
239            }
240            res => res
241        }
242    }
243
244    fn set(&mut self, key: K, value: V, loading_result: bool, meta: Option<B::Meta>) -> CacheResult<V, E> {
245        let opt_entry = unwrap_backing!(self.data.get(&key));
246        if loading_result {
247            if opt_entry.is_none() {
248                return CacheResult::None; // abort mission, key was deleted via remove
249            }
250            let entry = opt_entry.unwrap(); // it's some, because we return if its none
251            if matches!(entry, CacheEntry::Loaded(_)) {
252                return CacheResult::None; // abort mission, we already have an updated entry!
253            }
254        }
255        unwrap_backing!(self.data.set(key, CacheEntry::Loaded(value), meta))
256            .and_then(|entry| {
257                match entry {
258                    CacheEntry::Loaded(data) => Some(data),
259                    CacheEntry::Loading(_) => None
260                }
261            })
262            .map(|value| CacheResult::Found(value))
263            .unwrap_or(CacheResult::None)
264    }
265
266    fn get_if_present(&mut self, key: K) -> CacheResult<V, E> {
267        if let Some(entry) = unwrap_backing!(self.data.get(&key)) {
268            match entry {
269                CacheEntry::Loaded(data) => CacheResult::Found(data.clone()),
270                CacheEntry::Loading(_) => CacheResult::None,
271            }
272        } else {
273            CacheResult::None
274        }
275    }
276
277    fn get(&mut self, key: K) -> CacheResult<V, E> {
278        if let Some(entry) = unwrap_backing!(self.data.get(&key)) {
279            match entry {
280                CacheEntry::Loaded(value) => {
281                    CacheResult::Found(value.clone())
282                }
283                CacheEntry::Loading(waiter) => {
284                    let waiter = waiter.clone();
285                    CacheResult::Loading(tokio::spawn(async move {
286                        match waiter.subscribe().recv().await {
287                            Ok(result) => {
288                                match result {
289                                    Ok(data) => {
290                                        Ok(data)
291                                    }
292                                    Err(loading_error) => {
293                                        Err(CacheLoadingError::LoadingError(loading_error))
294                                    }
295                                }
296                            }
297                            Err(err) => Err(CacheLoadingError::CommunicationError(CacheCommunicationError::TokioBroadcastRecvError(err)))
298                        }
299                    }))
300                }
301            }
302        } else {
303            let (tx, _) = tokio::sync::broadcast::channel(1);
304            let inner_tx = tx.clone();
305            let cache_tx = self.tx.clone();
306            let loader = (self.loader)(key.clone());
307            let inner_key = key.clone();
308            let join_handle = tokio::spawn(async move {
309                match loader.await {
310                    Ok(value) => {
311                        let meta = value.meta;
312                        let value = value.data;
313                        inner_tx.send(Ok(value.clone())).ok();
314                        let (tx, rx) = tokio::sync::oneshot::channel();
315                        let send_value = value.clone();
316                        cache_tx.send(CacheMessage {
317                            action: CacheAction::SetAndUnblock(inner_key, send_value, meta),
318                            response: tx,
319                        }).await.ok();
320                        rx.await.ok(); // await cache confirmation
321                        Ok(value)
322                    }
323                    Err(loading_error) => {
324                        inner_tx.send(Err(loading_error.clone())).ok();
325                        let (tx, rx) = tokio::sync::oneshot::channel();
326                        cache_tx.send(CacheMessage {
327                            action: CacheAction::Unblock(inner_key),
328                            response: tx,
329                        }).await.ok();
330                        rx.await.ok(); // await cache confirmation
331                        Err(CacheLoadingError::LoadingError(loading_error))
332                    }
333                }
334            });
335            // Loading state is set without any meta
336            unwrap_backing!(self.data.set(key, CacheEntry::Loading(tx), None));
337            CacheResult::Loading(join_handle)
338        }
339    }
340}