1use crate::FnMemo;
2use recur_fn::RecurFn;
3use std::cell::UnsafeCell;
4use std::collections::HashMap;
5use std::hash::Hash;
6
/// Result storage used by a memoizer: maps each argument to the result computed for it.
pub trait Cache {
    /// Key type: the argument of the memoized function.
    type Arg;
    /// Value type: the result stored for each argument.
    type Output;

    /// Creates an empty cache.
    fn new() -> Self;

    /// Returns the stored result for `key`, or `None` if nothing has been cached for it.
    fn get(&self, key: &Self::Arg) -> Option<&Self::Output>;

    /// Records `value` as the result for `key`.
    fn cache(&mut self, key: Self::Arg, value: Self::Output);

    /// Discards every stored result.
    fn clear(&mut self);
}
22
23impl<Arg: Clone + Eq + Hash, Output: Clone> Cache for HashMap<Arg, Output> {
25 type Arg = Arg;
26 type Output = Output;
27
28 fn new() -> Self {
29 HashMap::new()
30 }
31 fn get(&self, arg: &Arg) -> Option<&Output> {
32 HashMap::get(self, arg);
33 self.get(arg)
34 }
35
36 fn cache(&mut self, arg: Arg, result: Output) {
37 self.insert(arg, result);
38 }
39
40 fn clear(&mut self) {
41 self.clear();
42 }
43}
44
45impl<Output: Clone> Cache for Vec<Option<Output>> {
47 type Arg = usize;
48 type Output = Output;
49
50 fn new() -> Self {
51 Vec::new()
52 }
53
54 fn get(&self, arg: &usize) -> Option<&Output> {
55 self.as_slice().get(*arg)?.as_ref()
56 }
57
58 fn cache(&mut self, arg: usize, result: Output) {
59 if arg >= self.len() {
60 self.resize(arg + 1, None);
61 }
62 self[arg] = Some(result);
63 }
64
65 fn clear(&mut self) {
66 self.clear();
67 }
68}
69
70pub struct Memo<C, F> {
72 cache: UnsafeCell<C>,
73 f: F,
74}
75
76impl<C: Cache, F: RecurFn<C::Arg, C::Output>> Memo<C, F>
77where
78 C::Arg: Clone,
79 C::Output: Clone,
80{
81 pub fn new(f: F) -> Memo<C, F> {
83 Memo {
84 cache: UnsafeCell::new(C::new()),
85 f,
86 }
87 }
88}
89
90impl<C: Cache, F: RecurFn<C::Arg, C::Output>> FnMemo<C::Arg, C::Output> for Memo<C, F>
91where
92 C::Arg: Clone,
93 C::Output: Clone,
94{
95 fn call(&self, arg: C::Arg) -> C::Output {
96 if let Some(result) = unsafe { &*self.cache.get() }.get(&arg) {
97 return result.clone();
98 }
99
100 let result = self.f.body(|arg| self.call(arg), arg.clone());
101 unsafe { &mut *self.cache.get() }.cache(arg, result.clone());
102 result
103 }
104
105 fn clear_cache(&self) {
106 unsafe { &mut *self.cache.get() }.clear()
107 }
108}
109
110pub fn memoize<Arg, Output, F>(f: F) -> impl FnMemo<Arg, Output>
112where
113 Arg: Clone + Eq + Hash,
114 Output: Clone,
115 F: RecurFn<Arg, Output>,
116{
117 Memo::<std::collections::HashMap<_, _>, _>::new(f)
118}
119
120pub fn memoize_seq<Output, F>(f: F) -> impl FnMemo<usize, Output>
122where
123 Output: Clone,
124 F: RecurFn<usize, Output>,
125{
126 Memo::<Vec<_>, _>::new(f)
127}