use core::hash::Hash;
use std::collections::HashMap;
/// A function wrapper that caches its results, keyed by argument.
///
/// Arguments must be hashable, comparable and cheaply copyable (`Copy`), because they
/// serve as cache keys.
pub trait Memoized<A, R>
where
    A: Hash + Eq + Copy,
{
    /// Returns a reference to the result for `arg`, borrowed from the wrapper.
    fn call(&mut self, arg: A) -> &R;

    /// Same as [`Memoized::call`], but hands back an owned clone of the result instead of
    /// a borrow, so the wrapper can be used again while the value is still held.
    #[inline]
    fn call_cloned(&mut self, arg: A) -> R
    where
        R: Clone,
    {
        let cached: &R = self.call(arg);
        R::clone(cached)
    }
}
#[inline]
pub fn memoize<F, A, R>(f: F) -> impl Memoized<A, R>
where
F: FnMut(A) -> R,
A: Hash + Eq + Copy,
{
Closure(f, Default::default())
}
#[inline(always)]
pub fn memoize_rec<A, R, F>(f: F) -> impl Memoized<A, R>
where
F: Fn(&mut dyn FnMut(A) -> R, A) -> R,
R: Clone,
A: Hash + Eq + Copy,
{
RecursiveClosure(f, Default::default())
}
/// Memoizing wrapper produced by [`memoize`]. It holds the wrapped function `F` and a cache
/// mapping each argument seen so far to its result.
pub struct Closure<A, R, F>(F, HashMap<A, R>);

impl<A, R, F> Memoized<A, R> for Closure<A, R, F>
where
    A: Hash + Eq + Copy,
    F: FnMut(A) -> R,
{
    /// Returns the cached result for `arg`. The wrapped function is invoked only on a
    /// cache miss, and its result is stored before being returned.
    fn call(&mut self, arg: A) -> &R {
        let Self(f, mem) = self;
        // The entry API detects the miss and yields the stored value in a single hash
        // lookup, with no follow-up `get(..).unwrap()` needed.
        mem.entry(arg).or_insert_with(|| f(arg))
    }
}
/// Memoizing wrapper produced by [`memoize_rec`]. It holds the wrapped function `F` and a
/// cache shared by the top-level call and every recursive call it makes.
pub struct RecursiveClosure<A, R, F>(F, HashMap<A, R>);

/// Returns the memoized result for `arg`, evaluating `f` on a cache miss.
///
/// `f` receives a callback that resolves other arguments through this same function, so
/// every recursive sub-result is cached in `mem` as well. While `f` runs, the callback
/// holds the mutable borrow of `mem`, so the callback cannot hand out references into the
/// cache. It returns clones instead, hence `R: Clone`.
///
/// Note: if `f` unconditionally requests `arg` itself while computing `arg`, nothing has
/// been cached for it yet and the recursion never terminates.
fn call_recursive<'a, A, R, F>(f: &'a F, mem: &'a mut HashMap<A, R>, arg: A) -> &'a R
where
    A: Hash + Eq + Copy,
    R: Clone,
    F: Fn(&mut dyn FnMut(A) -> R, A) -> R,
{
    // Membership is tested before borrowing the value on purpose. An early
    // `if let Some(v) = mem.get(&arg) { return v; }` is rejected by the current borrow
    // checker, because that `'a` borrow would still be live on the miss path below.
    if mem.contains_key(&arg) {
        return &mem[&arg];
    }
    let val = f(&mut |a| call_recursive(f, mem, a).clone(), arg);
    // `or_insert` returns a reference to the stored value, so no extra lookup is needed.
    mem.entry(arg).or_insert(val)
}

impl<A, R, F> Memoized<A, R> for RecursiveClosure<A, R, F>
where
    A: Hash + Eq + Copy,
    R: Clone,
    F: Fn(&mut dyn FnMut(A) -> R, A) -> R,
{
    /// Returns the cached result for `arg`, computing it (and any sub-results it requests
    /// through the recursion callback) on a miss.
    fn call(&mut self, arg: A) -> &R {
        let Self(f, mem) = self;
        call_recursive(f, mem, arg)
    }
}