use std::{cell::RefCell, hash::Hash, pin::Pin};
use rustc_data_structures::fx::FxHashMap as HashMap;
/// A memoizing cache that hands out references to its values.
///
/// Each value is stored in a `Pin<Box<_>>`, so its address never changes once it
/// has been computed; this is what lets [`Cache::get`] return a `&Out` tied to
/// `&self` rather than to a short-lived `RefCell` borrow.
///
/// While a key's value is being computed the key maps to `None`. Requesting that
/// key again from inside its own computation is treated as recursion.
pub struct Cache<In, Out>(RefCell<HashMap<In, Option<Pin<Box<Out>>>>>);

impl<In, Out> Cache<In, Out>
where
    In: Hash + Eq + Clone,
{
    /// Returns the number of keys in the cache, including keys whose value is
    /// still being computed.
    pub fn len(&self) -> usize {
        self.0.borrow().len()
    }

    /// Returns `true` if the cache contains no keys.
    pub fn is_empty(&self) -> bool {
        self.0.borrow().is_empty()
    }

    /// Returns `true` if `key` has a cached value or is currently being computed.
    pub fn contains_key(&self, key: &In) -> bool {
        self.0.borrow().contains_key(key)
    }

    /// Returns the value cached for `key`, calling `compute` with a clone of `key`
    /// to produce it on first use. Later calls for the same key ignore `compute`.
    ///
    /// # Panics
    ///
    /// Panics if `compute` (directly or transitively) asks for `key` again before
    /// it has returned. Use [`Cache::get_maybe_recursive`] to handle that case.
    pub fn get(&self, key: &In, compute: impl FnOnce(In) -> Out) -> &Out {
        self.get_maybe_recursive(key, compute)
            .unwrap_or_else(recursion_panic)
    }

    /// Like [`Cache::get`], but returns `None` instead of panicking when `key` is
    /// requested while its own value is still being computed.
    ///
    /// If `compute` panics, the in-progress marker for `key` is removed during
    /// unwinding, so a later request recomputes the value instead of being
    /// misreported as recursion.
    pub fn get_maybe_recursive<'a>(
        &'a self,
        key: &In,
        compute: impl FnOnce(In) -> Out,
    ) -> Option<&'a Out> {
        // Drop guard that removes the `None` in-progress marker for `key` if the
        // computation unwinds. It is disarmed with `mem::forget` on success.
        struct Pending<'m, K: Hash + Eq, V> {
            map: &'m RefCell<HashMap<K, Option<V>>>,
            key: &'m K,
        }

        impl<K: Hash + Eq, V> Drop for Pending<'_, K, V> {
            fn drop(&mut self) {
                // `try_borrow_mut`: a failed borrow while unwinding must not turn
                // into a second panic (which would abort the process).
                if let Ok(mut map) = self.map.try_borrow_mut() {
                    // Only ever drop a marker. A finished `Some` value may still be
                    // referenced by callers and must never be removed.
                    if matches!(map.get(self.key), Some(None)) {
                        map.remove(self.key);
                    }
                }
            }
        }

        if !self.0.borrow().contains_key(key) {
            // Mark the key as in progress so re-entrant requests observe `None`.
            self.0.borrow_mut().insert(key.clone(), None);
            let pending = Pending { map: &self.0, key };
            // No `RefCell` borrow is held here, because `compute` may re-enter.
            let out = Box::pin(compute(key.clone()));
            std::mem::forget(pending);
            self.0.borrow_mut().insert(key.clone(), Some(out));
        }

        let cache = self.0.borrow();
        let entry = cache.get(key).expect("invariant broken").as_ref()?;
        let value: *const Out = &**entry;
        // SAFETY: `value` points into the heap allocation owned by a
        // `Pin<Box<Out>>`. Moving or rehashing map entries moves only the box
        // pointer, never the allocation. Once an entry holds `Some`, nothing
        // reachable through `&self` overwrites or removes it: a `Some` is only
        // ever written over the `None` marker inserted by the same call, and
        // `Pending` removes `None` markers only. The allocation therefore lives
        // as long as `self` (i.e. for `'a`), even after this borrow is released.
        Some(unsafe { &*value })
    }
}

/// Panics with a diagnostic explaining that a cached computation requested its
/// own key. Generic over the return type so it can be handed to
/// [`Option::unwrap_or_else`] for any value type.
#[cold]
fn recursion_panic<A>() -> A {
    panic!(
        "Recursion detected! The computation of a value tried to retrieve the same value from the cache. Use `get_maybe_recursive` to handle this case gracefully."
    )
}

impl<In, Out> Default for Cache<In, Out> {
    /// Creates an empty cache.
    fn default() -> Self {
        Cache(RefCell::new(HashMap::default()))
    }
}
/// A memoizing cache for `Copy` values, which are returned by value.
///
/// While a key's value is being computed the key maps to `None`. Requesting that
/// key again from inside its own computation is treated as recursion.
pub struct CopyCache<In, Out>(RefCell<HashMap<In, Option<Out>>>);

impl<In, Out> CopyCache<In, Out>
where
    In: Hash + Eq + Clone,
    Out: Copy,
{
    /// Returns the number of keys in the cache, including keys whose value is
    /// still being computed.
    pub fn len(&self) -> usize {
        self.0.borrow().len()
    }

    /// Returns `true` if the cache contains no keys.
    pub fn is_empty(&self) -> bool {
        self.0.borrow().is_empty()
    }

    /// Returns `true` if `key` has a cached value or is currently being computed.
    pub fn contains_key(&self, key: &In) -> bool {
        self.0.borrow().contains_key(key)
    }

    /// Returns the value cached for `key`, calling `compute` with a clone of `key`
    /// to produce it on first use. Later calls for the same key ignore `compute`.
    ///
    /// # Panics
    ///
    /// Panics if `compute` (directly or transitively) asks for `key` again before
    /// it has returned. Use [`CopyCache::get_maybe_recursive`] to handle that case.
    pub fn get(&self, key: &In, compute: impl FnOnce(In) -> Out) -> Out {
        self.get_maybe_recursive(key, compute)
            .unwrap_or_else(recursion_panic)
    }

    /// Like [`CopyCache::get`], but returns `None` instead of panicking when `key`
    /// is requested while its own value is still being computed.
    ///
    /// If `compute` panics, the in-progress marker for `key` is removed during
    /// unwinding, so a later request recomputes the value instead of being
    /// misreported as recursion.
    pub fn get_maybe_recursive(
        &self,
        key: &In,
        compute: impl FnOnce(In) -> Out,
    ) -> Option<Out> {
        // Drop guard that removes the `None` in-progress marker for `key` if the
        // computation unwinds. It is disarmed with `mem::forget` on success.
        struct Pending<'m, K: Hash + Eq, V> {
            map: &'m RefCell<HashMap<K, Option<V>>>,
            key: &'m K,
        }

        impl<K: Hash + Eq, V> Drop for Pending<'_, K, V> {
            fn drop(&mut self) {
                // `try_borrow_mut`: a failed borrow while unwinding must not turn
                // into a second panic (which would abort the process).
                if let Ok(mut map) = self.map.try_borrow_mut() {
                    // Only ever drop a marker; finished values are kept.
                    if matches!(map.get(self.key), Some(None)) {
                        map.remove(self.key);
                    }
                }
            }
        }

        if !self.0.borrow().contains_key(key) {
            // Mark the key as in progress so re-entrant requests observe `None`.
            self.0.borrow_mut().insert(key.clone(), None);
            let pending = Pending { map: &self.0, key };
            // No `RefCell` borrow is held here, because `compute` may re-enter.
            let out = compute(key.clone());
            std::mem::forget(pending);
            self.0.borrow_mut().insert(key.clone(), Some(out));
        }

        // Reading only needs a shared borrow; a mutable one could conflict needlessly.
        *self.0.borrow().get(key).expect("invariant broken")
    }
}

impl<In, Out> Default for CopyCache<In, Out> {
    /// Creates an empty cache.
    fn default() -> Self {
        CopyCache(RefCell::new(HashMap::default()))
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_cached() {
        let cache = Cache::<usize, usize>::default();

        // The first computation for a key wins; later closures for it are ignored.
        let first_zero = cache.get(&0, |_| 0);
        let one = cache.get(&1, |_| 1);
        let second_zero = cache.get(&0, |_| 2);

        assert_eq!(*first_zero, 0);
        assert_eq!(*one, 1);
        assert_eq!(*second_zero, 0);
        // Repeated lookups of one key hand back the very same allocation.
        assert!(std::ptr::eq(first_zero, second_zero));
    }

    #[test]
    fn test_recursion_breaking() {
        struct RecursiveUse(Cache<i32, i32>);

        impl RecursiveUse {
            // Re-enters the cache for the key currently being computed; the inner
            // lookup sees the in-progress entry and falls back to -18.
            fn get_infinite_recursion(&self, i: i32) -> i32 {
                let compute = |_: i32| i + self.get_infinite_recursion(i);
                match self.0.get_maybe_recursive(&i, compute) {
                    Some(value) => *value,
                    None => -18,
                }
            }

            // Sums 0..=i, each step requesting a *different* key, so no recursion
            // on a single key ever happens.
            fn get_safe_recursion(&self, i: i32) -> i32 {
                let value = self.0.get(&i, |_| match i {
                    0 => 0,
                    _ => self.get_safe_recursion(i - 1) + i,
                });
                *value
            }
        }

        let cache = RecursiveUse(Cache::default());
        // Outer computation: 60 + (recursive fallback of -18) = 42.
        assert_eq!(cache.get_infinite_recursion(60), 42);
        // 5 + 4 + 3 + 2 + 1 + 0 = 15.
        assert_eq!(cache.get_safe_recursion(5), 15);
    }
}