1use std::any::Any;
2use std::cell::{Ref, RefCell};
3use std::collections::HashSet;
4use std::marker::PhantomData;
5use std::ops::{Deref, DerefMut};
6use std::rc::Rc;
7
/// One slot of the cache: a type-erased value together with the bookkeeping
/// used to decide whether values derived from it are out of date.
struct CacheNode {
    /// Set when the value is first stored, or when it is replaced by a value
    /// the comparator reported as different from the previous one.
    modify: bool,
    /// Indices (into the node list) of the nodes this value depends on.
    dependencies: Vec<usize>,
    /// The stored value; readers get it back with `downcast_ref`.
    value: Box<dyn Any>,
}
13
14#[derive(Default)]
15struct CacheInner {
16 nodes: Vec<CacheNode>,
17 progress: usize,
18}
19
20impl CacheInner {
21 fn any_modify(&self, dependencies: &[usize]) -> bool {
22 dependencies.iter().copied().any(|i| self.nodes[i].modify)
23 }
24}
25
26#[derive(Default)]
27struct Cache {
28 inner: RefCell<CacheInner>,
29 dependencies: RefCell<Option<HashSet<usize>>>,
30}
31
32#[derive(Default)]
33pub struct CacheFlow {
34 cache: Rc<Cache>,
35}
36
37impl CacheFlow {
38 pub fn new() -> Self {
39 Self::default()
40 }
41
42 pub fn reset(&mut self) {
43 *self = Self::default();
44 }
45
46 pub fn begin(&mut self) {
47 let mut inner = self.cache.inner.borrow_mut();
48 inner.progress = 0;
49 }
50
51 pub fn modify_cmp_by_fn<T, Cmp: FnOnce(&T, &T) -> bool>(
52 &self,
53 value: T,
54 cmp: Cmp,
55 ) -> CacheValue<T> {
56 let mut inner = self.cache.inner.borrow_mut();
57 let index = inner.progress;
58 if let Some(node) = inner.nodes.get_mut(index) {
59 let old_value = node.value.deref().downcast_ref().unwrap();
60 node.modify = !cmp(old_value, &value);
61 node.dependencies = vec![index];
62 node.value = Box::new(value);
63 } else {
64 inner.nodes.push(CacheNode {
65 modify: true,
66 dependencies: vec![index],
67 value: Box::new(value),
68 });
69 }
70 inner.progress = index + 1;
71 CacheValue {
72 cache: self.cache.clone(),
73 index,
74 marker: PhantomData,
75 }
76 }
77
78 pub fn modify<T>(&mut self, value: T) -> CacheValue<T> {
79 self.modify_cmp_by_fn(value, |_, _| true)
80 }
81
82 pub fn modify_cmp<T: PartialEq>(&mut self, value: T) -> CacheValue<T> {
83 self.modify_cmp_by_fn(value, T::eq)
84 }
85
86 pub fn compute_cmp_by_fn<F: FnOnce() -> T, T, Cmp: FnOnce(&T, &T) -> bool>(
87 &mut self,
88 f: F,
89 cmp: Cmp,
90 ) -> CacheValue<T> {
91 let mut inner = self.cache.inner.borrow_mut();
92 let index = inner.progress;
93 if let Some(node) = inner.nodes.get(index) {
94 node.value.deref().downcast_ref::<T>().unwrap();
95 if !inner.any_modify(&node.dependencies) {
96 inner.nodes[index].modify = false;
97 inner.progress = index + 1;
98 return CacheValue {
99 cache: self.cache.clone(),
100 index,
101 marker: PhantomData,
102 };
103 } else {
104 drop(inner);
105 }
106 } else {
107 drop(inner);
108 }
109
110 *self.cache.dependencies.borrow_mut() = Some(HashSet::new());
111 let value = f();
112 let dependencies = self
113 .cache
114 .dependencies
115 .borrow_mut()
116 .take()
117 .unwrap()
118 .into_iter()
119 .collect();
120
121 let mut inner = self.cache.inner.borrow_mut();
122 if let Some(node) = inner.nodes.get_mut(index) {
123 let old_value = node.value.deref().downcast_ref().unwrap();
124 node.modify = !cmp(old_value, &value);
125 node.dependencies = dependencies;
126 node.value = Box::new(value);
127 } else {
128 inner.nodes.push(CacheNode {
129 modify: true,
130 dependencies,
131 value: Box::new(value),
132 });
133 }
134 inner.progress = index + 1;
135
136 CacheValue {
137 cache: self.cache.clone(),
138 index,
139 marker: PhantomData,
140 }
141 }
142
143 pub fn compute<F: FnOnce() -> T, T>(&mut self, f: F) -> CacheValue<T> {
144 self.compute_cmp_by_fn(f, |_, _| true)
145 }
146
147 pub fn compute_cmp<F: FnOnce() -> T, T: PartialEq>(&mut self, f: F) -> CacheValue<T> {
148 self.compute_cmp_by_fn(f, T::eq)
149 }
150}
151
152pub struct CacheValue<T: Sized + 'static> {
153 cache: Rc<Cache>,
154 index: usize,
155 marker: PhantomData<T>,
156}
157
158impl<T: Sized + 'static> Clone for CacheValue<T> {
159 fn clone(&self) -> Self {
160 Self {
161 cache: self.cache.clone(),
162 index: self.index,
163 marker: PhantomData,
164 }
165 }
166}
167
168impl<T: Sized + 'static> CacheValue<T> {
169 pub fn get(&self) -> CacheRef<T> {
170 let inner = self.cache.inner.borrow();
171 if let Some(dependencies) = self.cache.dependencies.borrow_mut().deref_mut() {
172 dependencies.insert(self.index);
173 }
174 CacheRef {
175 inner,
176 index: self.index,
177 marker: PhantomData,
178 }
179 }
180}
181
182impl<T: Clone + Sized + 'static> CacheValue<T> {
183 pub fn inner_clone(&self) -> T {
184 T::clone(self.get().deref())
185 }
186}
187
188pub struct CacheRef<'b, T: Sized + 'static> {
189 inner: Ref<'b, CacheInner>,
190 index: usize,
191 marker: PhantomData<&'b T>,
192}
193
194impl<'b, T: Sized + 'static> Deref for CacheRef<'b, T> {
195 type Target = T;
196
197 fn deref(&self) -> &Self::Target {
198 self.inner.nodes[self.index]
199 .value
200 .deref()
201 .downcast_ref()
202 .unwrap()
203 }
204}
205
/// Runs three frames (inputs 0, 0, 1) through an `x -> y -> z` chain.
///
/// The second frame repeats the input, so neither closure may run again; the
/// third frame changes it, so both must run again. Earlier the loop stopped
/// after two frames and never reached the changed input.
#[test]
fn test() {
    let mut flow = CacheFlow::new();
    let y_runs = std::cell::Cell::new(0);
    let z_runs = std::cell::Cell::new(0);

    // (input, number of closure runs expected so far)
    for (input, expected_runs) in [(0, 1), (0, 1), (1, 2)] {
        flow.begin();

        let x = flow.modify_cmp(input);
        assert_eq!(*x.get(), input);

        let y = flow.compute_cmp(|| {
            println!("compute y");
            y_runs.set(y_runs.get() + 1);
            *x.get() + 1
        });
        assert_eq!(*y.get(), input + 1);
        assert_eq!(y_runs.get(), expected_runs);

        let z = flow.compute_cmp(|| {
            println!("compute z");
            z_runs.set(z_runs.get() + 1);
            *y.get() + 1
        });
        assert_eq!(*z.get(), input + 2);
        assert_eq!(z_runs.get(), expected_runs);
    }
}