use std::borrow::Borrow;
use std::cell::RefCell;
use std::hash::Hash;
use hashbrown::HashMap;
/// A lookup table that maps inputs of type `I` to outputs of type `O`.
///
/// The query does not have to be an owned `I`: any borrowed form `Q` that
/// hashes and compares like `I` and can be turned into an owned `I` is
/// accepted (for example `&str` when `I` is `String`).
pub trait Cache<I, O> {
    /// Returns the output associated with `q`.
    fn find<Q>(&self, q: &Q) -> O
    where
        Q: ?Sized + Eq + Hash + ToOwned<Owned = I>,
        I: Borrow<Q>;
}
/// Handle passed to a memoized function so that it can call itself through
/// the cache instead of recursing directly.
///
/// It wraps a shared reference to a callback taking `&I` and returning `O`.
/// Because only a reference is stored, the handle is `Copy` for every `I`
/// and `O`, including types that are neither `Clone` nor `Copy` themselves.
pub struct Recur<'a, I, O>(pub &'a dyn for<'b> Fn(&'b I) -> O);

// `Clone`/`Copy` are implemented by hand on purpose: `#[derive(Clone, Copy)]`
// would add `I: Clone + Copy` and `O: Clone + Copy` bounds, which are not
// needed to copy a shared reference and would make the handle non-`Copy`
// for inputs such as `String`.
impl<'a, I, O> Clone for Recur<'a, I, O> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<'a, I, O> Copy for Recur<'a, I, O> {}

impl<'a, I, O> Recur<'a, I, O> {
    /// Invokes the wrapped callback with `i` and returns its result.
    pub fn r(&self, i: &I) -> O {
        (self.0)(i)
    }
}
/// Memoization state: a function together with the results it has produced
/// so far. This is the concrete type behind the value returned by [`cached`].
struct CacheImpl<I, O, F> {
/// Outputs computed so far, keyed by input. Kept in a `RefCell` so that
/// `find` can record new results while only holding `&self`.
store: RefCell<HashMap<I, O>>,
/// The wrapped function. It is called with an owned input and a [`Recur`]
/// handle that routes recursive calls back through this cache.
closure: F,
}
impl<I, O, F> CacheImpl<I, O, F>
where
    I: Eq + Hash + Clone,
    O: Clone,
    F: for<'a> Fn(I, Recur<'a, I, O>) -> O,
{
    /// Wraps `closure` in a cache whose result store starts out empty.
    ///
    /// `closure` is not invoked here; it only runs when a lookup misses.
    pub fn new(closure: F) -> Self {
        let store = RefCell::new(HashMap::new());
        Self { store, closure }
    }
}
impl<I, O, F> Cache<I, O> for CacheImpl<I, O, F>
where
    I: Eq + Hash + Clone,
    O: Clone,
    F: for<'a> Fn(I, Recur<'a, I, O>) -> O,
{
    /// Returns the memoized output for `q`, computing and storing it first
    /// if it is not in the store yet.
    ///
    /// On a miss the wrapped function receives an owned copy of `q` and a
    /// [`Recur`] handle whose calls come back into this method, so recursive
    /// sub-results are memoized as well. If an entry for `q` already exists
    /// by the time the computation finishes, that entry is kept and returned.
    fn find<Q>(&self, q: &Q) -> O
    where
        Q: ?Sized + Eq + Hash + ToOwned<Owned = I>,
        I: Borrow<Q>,
    {
        // Fast path: the shared borrow ends with this statement.
        let known = self.store.borrow().get(q).cloned();
        if let Some(hit) = known {
            return hit;
        }

        // No borrow of `store` is held while the function runs, so the
        // recursive handle can re-enter `find` freely.
        let computed = (self.closure)(q.to_owned(), Recur(&|i: &I| self.find::<I>(i)));

        // Record the result; the mutable borrow is taken only now that all
        // recursive calls have returned.
        let mut memo = self.store.borrow_mut();
        let (_, stored) = memo
            .raw_entry_mut()
            .from_key(q)
            .or_insert_with(|| (q.to_owned(), computed));
        stored.clone()
    }
}
/// Consuming iteration over everything the cache has memoized.
impl<I, O, F> IntoIterator for CacheImpl<I, O, F> {
    type Item = (I, O);
    type IntoIter = hashbrown::hash_map::IntoIter<I, O>;

    /// Discards the wrapped function and yields the stored
    /// `(input, output)` pairs by value.
    fn into_iter(self) -> Self::IntoIter {
        self.store.into_inner().into_iter()
    }
}
/// Builds a memoizing [`Cache`] around `func`.
///
/// `func` receives the owned input plus a [`Recur`] handle; calling
/// `Recur::r` from inside `func` looks the argument up in the same cache, so
/// recursive definitions evaluate each distinct input only once.
pub fn cached<I, O>(func: impl for<'a> Fn(I, Recur<'a, I, O>) -> O) -> impl Cache<I, O>
where
    I: Eq + Hash + Clone,
    O: Clone,
{
    CacheImpl::new(func)
}
#[cfg(test)]
mod test {
    use std::cell::RefCell;

    use super::{Cache, CacheImpl};

    /// A non-recursive function is evaluated once per distinct input and the
    /// stored result is reused afterwards.
    #[test]
    fn works() {
        let calls = RefCell::new(0);
        let cache = CacheImpl::new(|arg, _| {
            *calls.borrow_mut() += 1;
            arg * 2
        });
        // Reads the number of times the callback has run so far.
        let seen = || *calls.borrow();

        assert_eq!(seen(), 0, "callback ran without argument??");
        cache.find(&1);
        assert_eq!(seen(), 1, "callback ran once and was recorded");
        cache.find(&1);
        assert_eq!(seen(), 1, "cached value reused");
        cache.find(&2);
        assert_eq!(seen(), 2, "new value passed to callback");
        cache.find(&1);
        assert_eq!(seen(), 2, "cache used after unrelated call");
    }

    /// Recursive calls made through the handle are memoized too, so the
    /// naive Fibonacci definition runs once per distinct argument.
    #[test]
    fn recursive() {
        let calls = RefCell::new(0);
        let fib = CacheImpl::new(|val, rec| {
            *calls.borrow_mut() += 1;
            match val {
                1 | 2 => 1,
                n => rec.r(&(n - 1)) + rec.r(&(n - 2)),
            }
        });

        assert_eq!(fib.find(&15), 610, "correct fibonacci number");
        assert_eq!(*calls.borrow(), 15, "evaluated for numbers 1-15 exactly once");
    }
}