fn_memo/
unsync.rs

1use crate::FnMemo;
2use recur_fn::RecurFn;
3use std::cell::UnsafeCell;
4use std::collections::HashMap;
5use std::hash::Hash;
6
/// Storage backend used by the single-threaded memoizer.
///
/// An implementor maps arguments (`Arg`) to previously computed
/// results (`Output`).
pub trait Cache {
    /// Type of the key under which a result is stored.
    type Arg;
    /// Type of the stored result.
    type Output;

    /// Builds a cache that holds no entries.
    fn new() -> Self;

    /// Looks up the result stored for `arg`.
    ///
    /// Returns `None` when nothing has been cached for `arg`.
    fn get(&self, arg: &Self::Arg) -> Option<&Self::Output>;

    /// Stores `result` as the cached value for `arg`.
    fn cache(&mut self, arg: Self::Arg, result: Self::Output);

    /// Removes every cached entry.
    fn clear(&mut self);
}
22
23/// Use `HashMap` as `Cache`.
24impl<Arg: Clone + Eq + Hash, Output: Clone> Cache for HashMap<Arg, Output> {
25    type Arg = Arg;
26    type Output = Output;
27
28    fn new() -> Self {
29        HashMap::new()
30    }
31    fn get(&self, arg: &Arg) -> Option<&Output> {
32        HashMap::get(self, arg);
33        self.get(arg)
34    }
35
36    fn cache(&mut self, arg: Arg, result: Output) {
37        self.insert(arg, result);
38    }
39
40    fn clear(&mut self) {
41        self.clear();
42    }
43}
44
45/// Use `Vec` as `Cache` for sequences.
46impl<Output: Clone> Cache for Vec<Option<Output>> {
47    type Arg = usize;
48    type Output = Output;
49
50    fn new() -> Self {
51        Vec::new()
52    }
53
54    fn get(&self, arg: &usize) -> Option<&Output> {
55        self.as_slice().get(*arg)?.as_ref()
56    }
57
58    fn cache(&mut self, arg: usize, result: Output) {
59        if arg >= self.len() {
60            self.resize(arg + 1, None);
61        }
62        self[arg] = Some(result);
63    }
64
65    fn clear(&mut self) {
66        self.clear();
67    }
68}
69
/// Single-threaded `FnMemo` implementation: pairs a function with the
/// cache that stores its results.
pub struct Memo<C, F> {
    /// Result storage, wrapped in `UnsafeCell` so it can be updated
    /// through a shared `&self` reference.
    cache: UnsafeCell<C>,
    /// The function whose results are cached.
    f: F,
}
75
76impl<C: Cache, F: RecurFn<C::Arg, C::Output>> Memo<C, F>
77where
78    C::Arg: Clone,
79    C::Output: Clone,
80{
81    /// Constructs a `Memo` using `C` as cache, caching function `f`.
82    pub fn new(f: F) -> Memo<C, F> {
83        Memo {
84            cache: UnsafeCell::new(C::new()),
85            f,
86        }
87    }
88}
89
90impl<C: Cache, F: RecurFn<C::Arg, C::Output>> FnMemo<C::Arg, C::Output> for Memo<C, F>
91where
92    C::Arg: Clone,
93    C::Output: Clone,
94{
95    fn call(&self, arg: C::Arg) -> C::Output {
96        if let Some(result) = unsafe { &*self.cache.get() }.get(&arg) {
97            return result.clone();
98        }
99
100        let result = self.f.body(|arg| self.call(arg), arg.clone());
101        unsafe { &mut *self.cache.get() }.cache(arg, result.clone());
102        result
103    }
104
105    fn clear_cache(&self) {
106        unsafe { &mut *self.cache.get() }.clear()
107    }
108}
109
110/// Creates a memoization of `f` using `HashMap` as cache.
111pub fn memoize<Arg, Output, F>(f: F) -> impl FnMemo<Arg, Output>
112where
113    Arg: Clone + Eq + Hash,
114    Output: Clone,
115    F: RecurFn<Arg, Output>,
116{
117    Memo::<std::collections::HashMap<_, _>, _>::new(f)
118}
119
120/// Creates a memoization of the sequence `f` using `Vec` as cache.
121pub fn memoize_seq<Output, F>(f: F) -> impl FnMemo<usize, Output>
122where
123    Output: Clone,
124    F: RecurFn<usize, Output>,
125{
126    Memo::<Vec<_>, _>::new(f)
127}