use std::any::Any;
use std::cell::{Ref, RefCell};
use std::collections::HashSet;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::rc::Rc;
/// A single slot in the cache.
///
/// Each slot holds a type-erased value plus the bookkeeping used to decide whether dependents
/// must recompute.
struct CacheNode {
    /// `true` when the pass that last visited this node created it or reported its value as
    /// changed; `compute*` calls read this flag on their dependencies.
    modify: bool,
    /// Indices of the nodes this one was derived from.
    ///
    /// For `compute*` nodes, these are the nodes read during the closure. Input nodes store
    /// their own index here.
    dependencies: Vec<usize>,
    /// The stored value; its concrete type is recovered with `downcast_ref`.
    value: Box<dyn Any>,
}
/// Mutable state behind the cache: all nodes ever created, plus the cursor of the current pass.
#[derive(Default)]
struct CacheInner {
    /// Nodes in call order. Nodes are never removed, only overwritten or appended.
    nodes: Vec<CacheNode>,
    /// Index of the node that the next `modify*`/`compute*` call will visit.
    /// Reset to 0 by `CacheFlow::begin`.
    progress: usize,
}
impl CacheInner {
    /// Reports whether at least one node listed in `dependencies` carries the `modify` flag.
    ///
    /// Panics if an index is out of range for `nodes`.
    fn any_modify(&self, dependencies: &[usize]) -> bool {
        for &dep in dependencies {
            if self.nodes[dep].modify {
                return true;
            }
        }
        false
    }
}
/// Shared storage behind a `CacheFlow` and every `CacheValue` handle it hands out.
#[derive(Default)]
struct Cache {
    /// The nodes and the pass cursor.
    inner: RefCell<CacheInner>,
    /// `Some` only while a `compute*` closure runs.
    ///
    /// `CacheValue::get` adds every index it reads to this set, and the set becomes the
    /// dependency list of the node being computed.
    dependencies: RefCell<Option<HashSet<usize>>>,
}
/// Driver for incremental recomputation.
///
/// Each pass starts with `begin` and then issues `modify*`/`compute*` calls. Nodes are matched
/// to calls purely by position within the pass. A `compute*` call reuses its previous result
/// when none of the nodes it read last time were modified in the current pass.
#[derive(Default)]
pub struct CacheFlow {
    /// Storage shared with every `CacheValue` produced by this flow.
    cache: Rc<Cache>,
}
impl CacheFlow {
pub fn new() -> Self {
Self::default()
}
pub fn reset(&mut self) {
*self = Self::default();
}
pub fn begin(&mut self) {
let mut inner = self.cache.inner.borrow_mut();
inner.progress = 0;
}
pub fn modify_cmp_by_fn<T, Cmp: FnOnce(&T, &T) -> bool>(
&self,
value: T,
cmp: Cmp,
) -> CacheValue<T> {
let mut inner = self.cache.inner.borrow_mut();
let index = inner.progress;
if let Some(node) = inner.nodes.get_mut(index) {
let old_value = node.value.deref().downcast_ref().unwrap();
node.modify = !cmp(old_value, &value);
node.dependencies = vec![index];
node.value = Box::new(value);
} else {
inner.nodes.push(CacheNode {
modify: true,
dependencies: vec![index],
value: Box::new(value),
});
}
inner.progress = index + 1;
CacheValue {
cache: self.cache.clone(),
index,
marker: PhantomData,
}
}
pub fn modify<T>(&mut self, value: T) -> CacheValue<T> {
self.modify_cmp_by_fn(value, |_, _| true)
}
pub fn modify_cmp<T: PartialEq>(&mut self, value: T) -> CacheValue<T> {
self.modify_cmp_by_fn(value, T::eq)
}
pub fn compute_cmp_by_fn<F: FnOnce() -> T, T, Cmp: FnOnce(&T, &T) -> bool>(
&mut self,
f: F,
cmp: Cmp,
) -> CacheValue<T> {
let mut inner = self.cache.inner.borrow_mut();
let index = inner.progress;
if let Some(node) = inner.nodes.get(index) {
node.value.deref().downcast_ref::<T>().unwrap();
if !inner.any_modify(&node.dependencies) {
inner.nodes[index].modify = false;
inner.progress = index + 1;
return CacheValue {
cache: self.cache.clone(),
index,
marker: PhantomData,
};
} else {
drop(inner);
}
} else {
drop(inner);
}
*self.cache.dependencies.borrow_mut() = Some(HashSet::new());
let value = f();
let dependencies = self
.cache
.dependencies
.borrow_mut()
.take()
.unwrap()
.into_iter()
.collect();
let mut inner = self.cache.inner.borrow_mut();
if let Some(node) = inner.nodes.get_mut(index) {
let old_value = node.value.deref().downcast_ref().unwrap();
node.modify = !cmp(old_value, &value);
node.dependencies = dependencies;
node.value = Box::new(value);
} else {
inner.nodes.push(CacheNode {
modify: true,
dependencies,
value: Box::new(value),
});
}
inner.progress = index + 1;
CacheValue {
cache: self.cache.clone(),
index,
marker: PhantomData,
}
}
pub fn compute<F: FnOnce() -> T, T>(&mut self, f: F) -> CacheValue<T> {
self.compute_cmp_by_fn(f, |_, _| true)
}
pub fn compute_cmp<F: FnOnce() -> T, T: PartialEq>(&mut self, f: F) -> CacheValue<T> {
self.compute_cmp_by_fn(f, T::eq)
}
}
/// Typed handle to one node of a `CacheFlow`'s cache.
///
/// A handle only stores an `Rc` to the shared cache and a node index. Read the value with
/// `get` or `inner_clone`.
// NOTE(review): `CacheFlow`'s generic methods declare `T` without `'static` yet box it as
// `dyn Any`. They appear to rely on this bound being implied by their `CacheValue<T>` return
// type, so confirm before removing it from the struct.
pub struct CacheValue<T: Sized + 'static> {
    /// Storage shared with the flow that created this handle.
    cache: Rc<Cache>,
    /// Position of the node inside `CacheInner::nodes`.
    index: usize,
    /// Ties the handle to the stored type without owning a `T`.
    marker: PhantomData<T>,
}
impl<T: Sized + 'static> Clone for CacheValue<T> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
index: self.index,
marker: PhantomData,
}
}
}
impl<T: Sized + 'static> CacheValue<T> {
    /// Borrows the cached value.
    ///
    /// While a `compute*` closure is running, the read is recorded as a dependency of the node
    /// being computed. The returned guard keeps the node storage borrowed. Drop it before the
    /// next `modify*`/`compute*` call, or that call's `borrow_mut` panics.
    pub fn get(&self) -> CacheRef<T> {
        let guard = self.cache.inner.borrow();
        {
            let mut tracking = self.cache.dependencies.borrow_mut();
            if let Some(reads) = tracking.deref_mut() {
                reads.insert(self.index);
            }
        }
        CacheRef {
            index: self.index,
            inner: guard,
            marker: PhantomData,
        }
    }
}
impl<T: Clone + Sized + 'static> CacheValue<T> {
    /// Returns an owned copy of the cached value.
    ///
    /// The read goes through `get`, so it is tracked as a dependency in the same way. The
    /// borrow is released before the copy is handed back.
    pub fn inner_clone(&self) -> T {
        let borrowed = self.get();
        T::clone(&borrowed)
    }
}
/// Read guard returned by `CacheValue::get`; dereferences to the cached `T`.
///
/// It holds a shared borrow of the whole node storage for as long as it lives.
pub struct CacheRef<'b, T: Sized + 'static> {
    /// Shared borrow of the cache state; keeps the node alive and unmodified.
    inner: Ref<'b, CacheInner>,
    /// Node to read from `inner.nodes`.
    index: usize,
    /// Records that this guard hands out `&'b T`.
    marker: PhantomData<&'b T>,
}
impl<'b, T: Sized + 'static> Deref for CacheRef<'b, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.inner.nodes[self.index]
.value
.deref()
.downcast_ref()
.unwrap()
}
}
/// Runs three passes over an `x -> y -> z` chain.
///
/// Pass 2 repeats the input of pass 1, so `y` and `z` must be reused. Pass 3 changes the
/// input, so both must recompute.
#[test]
fn test() {
    let mut flow = CacheFlow::new();
    let y_runs = std::cell::Cell::new(0);
    let z_runs = std::cell::Cell::new(0);
    // Iterate every input; the previous `0..2` never reached the pass where the input changes.
    for input in [0, 0, 1] {
        flow.begin();
        let x = flow.modify_cmp(input);
        assert_eq!(*x.get(), input);
        let y = flow.compute_cmp(|| {
            y_runs.set(y_runs.get() + 1);
            *x.get() + 1
        });
        assert_eq!(*y.get(), input + 1);
        let z = flow.compute_cmp(|| {
            z_runs.set(z_runs.get() + 1);
            *y.get() + 1
        });
        assert_eq!(*z.get(), input + 2);
    }
    // Computed on the first pass, reused on the second, recomputed on the third.
    assert_eq!(y_runs.get(), 2);
    assert_eq!(z_runs.get(), 2);
}