cache_flow/
lib.rs

1use std::any::Any;
2use std::cell::{Ref, RefCell};
3use std::collections::HashSet;
4use std::marker::PhantomData;
5use std::ops::{Deref, DerefMut};
6use std::rc::Rc;
7
/// One cache slot: a type-erased value together with the bookkeeping used to
/// decide whether values derived from it are stale.
struct CacheNode {
    /// `true` when the slot was just created, or when its latest value did
    /// not compare as equal to the value it replaced.
    modify: bool,
    /// Indices of the slots whose `modify` flags are checked before this
    /// slot's value is reused.
    dependencies: Vec<usize>,
    /// The stored value; readers downcast it back to its concrete type.
    value: Box<dyn Any>,
}
13
14#[derive(Default)]
15struct CacheInner {
16    nodes: Vec<CacheNode>,
17    progress: usize,
18}
19
20impl CacheInner {
21    fn any_modify(&self, dependencies: &[usize]) -> bool {
22        dependencies.iter().copied().any(|i| self.nodes[i].modify)
23    }
24}
25
26#[derive(Default)]
27struct Cache {
28    inner: RefCell<CacheInner>,
29    dependencies: RefCell<Option<HashSet<usize>>>,
30}
31
32#[derive(Default)]
33pub struct CacheFlow {
34    cache: Rc<Cache>,
35}
36
37impl CacheFlow {
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    pub fn reset(&mut self) {
43        *self = Self::default();
44    }
45
46    pub fn begin(&mut self) {
47        let mut inner = self.cache.inner.borrow_mut();
48        inner.progress = 0;
49    }
50
51    pub fn modify_cmp_by_fn<T, Cmp: FnOnce(&T, &T) -> bool>(
52        &self,
53        value: T,
54        cmp: Cmp,
55    ) -> CacheValue<T> {
56        let mut inner = self.cache.inner.borrow_mut();
57        let index = inner.progress;
58        if let Some(node) = inner.nodes.get_mut(index) {
59            let old_value = node.value.deref().downcast_ref().unwrap();
60            node.modify = !cmp(old_value, &value);
61            node.dependencies = vec![index];
62            node.value = Box::new(value);
63        } else {
64            inner.nodes.push(CacheNode {
65                modify: true,
66                dependencies: vec![index],
67                value: Box::new(value),
68            });
69        }
70        inner.progress = index + 1;
71        CacheValue {
72            cache: self.cache.clone(),
73            index,
74            marker: PhantomData,
75        }
76    }
77
78    pub fn modify<T>(&mut self, value: T) -> CacheValue<T> {
79        self.modify_cmp_by_fn(value, |_, _| true)
80    }
81
82    pub fn modify_cmp<T: PartialEq>(&mut self, value: T) -> CacheValue<T> {
83        self.modify_cmp_by_fn(value, T::eq)
84    }
85
86    pub fn compute_cmp_by_fn<F: FnOnce() -> T, T, Cmp: FnOnce(&T, &T) -> bool>(
87        &mut self,
88        f: F,
89        cmp: Cmp,
90    ) -> CacheValue<T> {
91        let mut inner = self.cache.inner.borrow_mut();
92        let index = inner.progress;
93        if let Some(node) = inner.nodes.get(index) {
94            node.value.deref().downcast_ref::<T>().unwrap();
95            if !inner.any_modify(&node.dependencies) {
96                inner.nodes[index].modify = false;
97                inner.progress = index + 1;
98                return CacheValue {
99                    cache: self.cache.clone(),
100                    index,
101                    marker: PhantomData,
102                };
103            } else {
104                drop(inner);
105            }
106        } else {
107            drop(inner);
108        }
109
110        *self.cache.dependencies.borrow_mut() = Some(HashSet::new());
111        let value = f();
112        let dependencies = self
113            .cache
114            .dependencies
115            .borrow_mut()
116            .take()
117            .unwrap()
118            .into_iter()
119            .collect();
120
121        let mut inner = self.cache.inner.borrow_mut();
122        if let Some(node) = inner.nodes.get_mut(index) {
123            let old_value = node.value.deref().downcast_ref().unwrap();
124            node.modify = !cmp(old_value, &value);
125            node.dependencies = dependencies;
126            node.value = Box::new(value);
127        } else {
128            inner.nodes.push(CacheNode {
129                modify: true,
130                dependencies,
131                value: Box::new(value),
132            });
133        }
134        inner.progress = index + 1;
135
136        CacheValue {
137            cache: self.cache.clone(),
138            index,
139            marker: PhantomData,
140        }
141    }
142
143    pub fn compute<F: FnOnce() -> T, T>(&mut self, f: F) -> CacheValue<T> {
144        self.compute_cmp_by_fn(f, |_, _| true)
145    }
146
147    pub fn compute_cmp<F: FnOnce() -> T, T: PartialEq>(&mut self, f: F) -> CacheValue<T> {
148        self.compute_cmp_by_fn(f, T::eq)
149    }
150}
151
152pub struct CacheValue<T: Sized + 'static> {
153    cache: Rc<Cache>,
154    index: usize,
155    marker: PhantomData<T>,
156}
157
158impl<T: Sized + 'static> Clone for CacheValue<T> {
159    fn clone(&self) -> Self {
160        Self {
161            cache: self.cache.clone(),
162            index: self.index,
163            marker: PhantomData,
164        }
165    }
166}
167
168impl<T: Sized + 'static> CacheValue<T> {
169    pub fn get(&self) -> CacheRef<T> {
170        let inner = self.cache.inner.borrow();
171        if let Some(dependencies) = self.cache.dependencies.borrow_mut().deref_mut() {
172            dependencies.insert(self.index);
173        }
174        CacheRef {
175            inner,
176            index: self.index,
177            marker: PhantomData,
178        }
179    }
180}
181
182impl<T: Clone + Sized + 'static> CacheValue<T> {
183    pub fn inner_clone(&self) -> T {
184        T::clone(self.get().deref())
185    }
186}
187
188pub struct CacheRef<'b, T: Sized + 'static> {
189    inner: Ref<'b, CacheInner>,
190    index: usize,
191    marker: PhantomData<&'b T>,
192}
193
194impl<'b, T: Sized + 'static> Deref for CacheRef<'b, T> {
195    type Target = T;
196
197    fn deref(&self) -> &Self::Target {
198        self.inner.nodes[self.index]
199            .value
200            .deref()
201            .downcast_ref()
202            .unwrap()
203    }
204}
205
// Runs three passes over the chain x -> y -> z. The second pass repeats the
// first input, so both computations must be served from the cache; the third
// pass changes the input, so both must run again.
#[test]
fn test() {
    let mut flow = CacheFlow::new();
    // Count how often each closure really runs, so the caching is verified
    // rather than only printed.
    let y_runs = std::cell::Cell::new(0);
    let z_runs = std::cell::Cell::new(0);

    // Iterate over every input; an index range of `0..2` never reached the
    // last element, which is the only one that changes the value.
    for &input in [0, 0, 1].iter() {
        flow.begin();

        let x = flow.modify_cmp(input);
        assert_eq!(*x.get(), input);

        let y = flow.compute_cmp(|| {
            println!("compute y");
            y_runs.set(y_runs.get() + 1);
            *x.get() + 1
        });
        assert_eq!(*y.get(), input + 1);

        let z = flow.compute_cmp(|| {
            println!("compute z");
            z_runs.set(z_runs.get() + 1);
            *y.get() + 1
        });
        assert_eq!(*z.get(), input + 2);
    }

    // First pass computes, second pass reuses, third pass recomputes.
    assert_eq!(y_runs.get(), 2);
    assert_eq!(z_runs.get(), 2);
}