#![forbid(unsafe_code)]
use std::cell::{Cell, RefCell};
use std::rc::Rc;
use super::observable::{Observable, Subscription};
/// Shared mutable state behind every handle of one `Computed` value.
struct ComputedInner<T> {
/// Produces a fresh value; only invoked when a read finds the cache stale.
compute: Box<dyn Fn() -> T>,
/// Last value produced by `compute`; `None` until the first read.
cached: Option<T>,
/// `true` when `cached` must be recomputed before being returned. It is a `Cell`
/// so subscription callbacks can set it through a shared `borrow()`.
dirty: Cell<bool>,
/// Number of times a freshly computed value has been stored in `cached`.
version: u64,
/// Held only so the subscriptions live as long as this state (presumably dropping
/// a `Subscription` unsubscribes — confirm in `observable`); never read.
_subscriptions: Vec<Subscription>,
}
/// A lazily evaluated, memoized value, typically derived from one or more
/// observables.
///
/// Cloning a `Computed` produces another handle to the same shared cache.
pub struct Computed<T> {
inner: Rc<RefCell<ComputedInner<T>>>,
}
impl<T> Clone for Computed<T> {
fn clone(&self) -> Self {
Self {
inner: Rc::clone(&self.inner),
}
}
}
impl<T: std::fmt::Debug> std::fmt::Debug for Computed<T> {
    /// Shows the cached value, the dirty flag and the version counter.
    ///
    /// The compute closure (a `Box<dyn Fn>`, which has no `Debug` impl) and the
    /// held subscriptions are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state = self.inner.borrow();
        let is_dirty = state.dirty.get();
        let mut builder = f.debug_struct("Computed");
        builder.field("cached", &state.cached);
        builder.field("dirty", &is_dirty);
        builder.field("version", &state.version);
        builder.finish()
    }
}
impl<T: Clone + 'static> Computed<T> {
    /// Creates a computed value derived from a single observable.
    ///
    /// `map` is not called here. It runs lazily on the first `get`/`with`, and again
    /// on the first read after `source` notifies its subscribers.
    pub fn from_observable<S: Clone + PartialEq + 'static>(
        source: &Observable<S>,
        map: impl Fn(&S) -> T + 'static,
    ) -> Self {
        let reader = source.clone();
        let computed = Self::from_fn(move || reader.with(|v| map(v)), Vec::new());
        computed.track(source);
        computed
    }

    /// Creates a computed value derived from two observables.
    ///
    /// A notification from either source marks the value dirty; `map` runs lazily
    /// on the next read.
    pub fn from2<S1, S2>(
        s1: &Observable<S1>,
        s2: &Observable<S2>,
        map: impl Fn(&S1, &S2) -> T + 'static,
    ) -> Self
    where
        S1: Clone + PartialEq + 'static,
        S2: Clone + PartialEq + 'static,
    {
        let (r1, r2) = (s1.clone(), s2.clone());
        let computed = Self::from_fn(
            move || r1.with(|v1| r2.with(|v2| map(v1, v2))),
            Vec::new(),
        );
        computed.track(s1);
        computed.track(s2);
        computed
    }

    /// Creates a computed value derived from three observables.
    ///
    /// A notification from any source marks the value dirty; `map` runs lazily on
    /// the next read.
    pub fn from3<S1, S2, S3>(
        s1: &Observable<S1>,
        s2: &Observable<S2>,
        s3: &Observable<S3>,
        map: impl Fn(&S1, &S2, &S3) -> T + 'static,
    ) -> Self
    where
        S1: Clone + PartialEq + 'static,
        S2: Clone + PartialEq + 'static,
        S3: Clone + PartialEq + 'static,
    {
        let (r1, r2, r3) = (s1.clone(), s2.clone(), s3.clone());
        let computed = Self::from_fn(
            move || r1.with(|v1| r2.with(|v2| r3.with(|v3| map(v1, v2, v3)))),
            Vec::new(),
        );
        computed.track(s1);
        computed.track(s2);
        computed.track(s3);
        computed
    }

    /// Creates a computed value from an arbitrary closure.
    ///
    /// `subscriptions` are only stored, so they live as long as this value. They do
    /// NOT mark it dirty. Callers wiring their own subscriptions must call
    /// [`Computed::invalidate`] when a dependency changes.
    pub fn from_fn(compute: impl Fn() -> T + 'static, subscriptions: Vec<Subscription>) -> Self {
        Self {
            inner: Rc::new(RefCell::new(ComputedInner {
                compute: Box::new(compute),
                cached: None,
                dirty: Cell::new(true),
                version: 0,
                _subscriptions: subscriptions,
            })),
        }
    }

    /// Subscribes to `source` so that each of its notifications marks this value
    /// dirty, and stores the subscription with the shared state.
    ///
    /// The callback holds only a `Weak` reference, so it does not keep this value alive.
    fn track<S: Clone + PartialEq + 'static>(&self, source: &Observable<S>) {
        let weak = Rc::downgrade(&self.inner);
        let subscription = source.subscribe(move |_| {
            if let Some(strong) = weak.upgrade() {
                strong.borrow().dirty.set(true);
            }
        });
        self.inner.borrow_mut()._subscriptions.push(subscription);
    }

    /// Recomputes and caches the value if it is dirty or has never been computed.
    ///
    /// The user closure runs under a *shared* borrow. It, or a callback it triggers,
    /// may therefore call `invalidate`/`is_dirty`/`version` on this value without a
    /// `BorrowError`. The exclusive borrow is taken only to store the result.
    fn refresh(&self) {
        /// Re-marks the value dirty if `compute` unwinds, so an old cached value
        /// is never treated as fresh after a panicking recomputation.
        struct RedirtyOnUnwind<'a>(&'a Cell<bool>);

        impl Drop for RedirtyOnUnwind<'_> {
            fn drop(&mut self) {
                if std::thread::panicking() {
                    self.0.set(true);
                }
            }
        }

        let needs_refresh = {
            let inner = self.inner.borrow();
            inner.dirty.get() || inner.cached.is_none()
        };
        if !needs_refresh {
            return;
        }

        let new_value = {
            let inner = self.inner.borrow();
            // Clear *before* computing: an invalidation that arrives while `compute`
            // runs sets the flag again, and that later invalidation is not lost.
            inner.dirty.set(false);
            let _guard = RedirtyOnUnwind(&inner.dirty);
            (inner.compute)()
        };

        let mut inner = self.inner.borrow_mut();
        inner.cached = Some(new_value);
        inner.version += 1;
    }

    /// Returns a clone of the current value, recomputing first if it is stale.
    ///
    /// # Panics
    /// Propagates a panic from the compute closure; the value then stays dirty.
    #[must_use]
    pub fn get(&self) -> T {
        self.refresh();
        self.inner
            .borrow()
            .cached
            .as_ref()
            .expect("cached is always Some after get()")
            .clone()
    }

    /// Calls `f` with a reference to the current value, recomputing first if stale.
    ///
    /// `f` runs under a shared borrow of the internal state. It may call `get`,
    /// `is_dirty`, `version` or `invalidate` on this value (or a clone of it).
    /// However, a `get` that has to *recompute* inside `f` (e.g. after `invalidate`)
    /// still panics, because storing the new value needs exclusive access.
    pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
        self.refresh();
        let inner = self.inner.borrow();
        f(inner
            .cached
            .as_ref()
            .expect("cached is always Some after refresh"))
    }

    /// Returns `true` if the next read will recompute the value.
    #[must_use]
    pub fn is_dirty(&self) -> bool {
        self.inner.borrow().dirty.get()
    }

    /// Marks the value stale so that the next `get`/`with` recomputes it.
    /// Nothing is computed eagerly.
    pub fn invalidate(&self) {
        self.inner.borrow().dirty.set(true);
    }

    /// Number of times a freshly computed value has been cached (0 before the first read).
    #[must_use]
    pub fn version(&self) -> u64 {
        self.inner.borrow().version
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// Creates a shared invocation counter plus a second handle to move into a closure.
    fn counter() -> (Rc<Cell<u32>>, Rc<Cell<u32>>) {
        let total = Rc::new(Cell::new(0u32));
        let handle = Rc::clone(&total);
        (total, handle)
    }

    #[test]
    fn single_dep_computed() {
        let input = Observable::new(10);
        let doubled = Computed::from_observable(&input, |v| v * 2);

        assert_eq!(doubled.get(), 20);
        assert_eq!(doubled.version(), 1);

        input.set(5);
        assert!(doubled.is_dirty());
        assert_eq!(doubled.get(), 10);
        assert_eq!(doubled.version(), 2);
    }

    #[test]
    fn multi_dep_computed() {
        let width = Observable::new(10);
        let height = Observable::new(20);
        let area = Computed::from2(&width, &height, |w, h| w * h);
        assert_eq!(area.get(), 200);

        width.set(5);
        assert_eq!(area.get(), 100);

        height.set(30);
        assert_eq!(area.get(), 150);
    }

    #[test]
    fn three_dep_computed() {
        let first = Observable::new(1);
        let second = Observable::new(2);
        let third = Observable::new(3);
        let total = Computed::from3(&first, &second, &third, |x, y, z| x + y + z);
        assert_eq!(total.get(), 6);

        first.set(10);
        assert_eq!(total.get(), 15);

        third.set(100);
        assert_eq!(total.get(), 112);
    }

    #[test]
    fn lazy_evaluation() {
        let (calls, calls_in) = counter();
        let source = Observable::new(42);
        let reader = source.clone();
        let computed = Computed::from_fn(
            move || {
                calls_in.set(calls_in.get() + 1);
                reader.get() * 2
            },
            Vec::new(),
        );

        // Construction alone must not run the closure.
        assert_eq!(calls.get(), 0);
        assert_eq!(computed.get(), 84);
        assert_eq!(calls.get(), 1);
        // A second read is served from the cache.
        assert_eq!(computed.get(), 84);
        assert_eq!(calls.get(), 1);
    }

    #[test]
    fn memoization() {
        let (calls, calls_in) = counter();
        let source = Observable::new(10);
        let computed = Computed::from_observable(&source, move |v| {
            calls_in.set(calls_in.get() + 1);
            v * 2
        });

        for _ in 0..2 {
            assert_eq!(computed.get(), 20);
            assert_eq!(calls.get(), 1);
        }

        source.set(20);
        for _ in 0..2 {
            assert_eq!(computed.get(), 40);
            assert_eq!(calls.get(), 2);
        }
    }

    #[test]
    fn invalidate_forces_recompute() {
        let (calls, calls_in) = counter();
        let source = Observable::new(5);
        let computed = Computed::from_observable(&source, move |v| {
            calls_in.set(calls_in.get() + 1);
            *v
        });
        assert_eq!(computed.get(), 5);
        assert_eq!(calls.get(), 1);

        computed.invalidate();
        assert!(computed.is_dirty());
        assert_eq!(computed.get(), 5);
        assert_eq!(calls.get(), 2);
    }

    #[test]
    fn with_access() {
        let source = Observable::new(vec![1, 2, 3]);
        let total = Computed::from_observable(&source, |v| v.iter().sum::<i32>());
        let seen = total.with(|sum| *sum);
        assert_eq!(seen, 6);
    }

    #[test]
    fn version_increments_on_recompute() {
        let source = Observable::new(0);
        let computed = Computed::from_observable(&source, |v| *v);
        assert_eq!(computed.version(), 0);

        let _ = computed.get();
        assert_eq!(computed.version(), 1);
        let _ = computed.get();
        assert_eq!(computed.version(), 1, "a cache hit must not bump the version");

        source.set(1);
        let _ = computed.get();
        assert_eq!(computed.version(), 2);
    }

    #[test]
    fn clone_shares_state() {
        let source = Observable::new(10);
        let original = Computed::from_observable(&source, |v| v + 1);
        let copy = original.clone();
        assert_eq!(original.get(), 11);
        assert_eq!(copy.get(), 11);

        source.set(20);
        assert_eq!(original.get(), 21);
        assert_eq!(copy.get(), 21);
    }

    #[test]
    fn diamond_dependency() {
        let root = Observable::new(10);
        let plus_one = Computed::from_observable(&root, |v| v + 1);
        let times_two = Computed::from_observable(&root, |v| v * 2);
        let left = plus_one.clone();
        let right = times_two.clone();
        let joined = Computed::from_observable(&root, move |_| left.get() + right.get());

        assert_eq!(plus_one.get(), 11);
        assert_eq!(times_two.get(), 20);
        assert_eq!(joined.get(), 31);

        root.set(5);
        assert_eq!(plus_one.get(), 6);
        assert_eq!(times_two.get(), 10);
        assert_eq!(joined.get(), 16);
    }

    #[test]
    fn no_change_same_value() {
        let source = Observable::new(42);
        let (calls, calls_in) = counter();
        let computed = Computed::from_observable(&source, move |v| {
            calls_in.set(calls_in.get() + 1);
            *v
        });
        let _ = computed.get();
        assert_eq!(calls.get(), 1);

        // Re-setting an equal value is expected not to dirty the computed.
        source.set(42);
        assert!(!computed.is_dirty());
        let _ = computed.get();
        assert_eq!(calls.get(), 1);
    }

    #[test]
    fn debug_format() {
        let source = Observable::new(42);
        let computed = Computed::from_observable(&source, |v| *v);
        let _ = computed.get();
        let rendered = format!("{:?}", computed);
        assert!(rendered.contains("Computed"));
        assert!(rendered.contains("42"));
    }

    #[test]
    fn from_fn_with_manual_subscriptions() {
        let source = Observable::new(10);
        let tripled = Computed::from_observable(&source, |v| v * 3);
        assert_eq!(tripled.get(), 30);
        source.set(20);
        assert_eq!(tripled.get(), 60);

        let manual_source = Observable::new(5);
        let reader = manual_source.clone();
        let notified = Rc::new(Cell::new(false));
        let notified_in = Rc::clone(&notified);
        let subscription = manual_source.subscribe(move |_| {
            notified_in.set(true);
        });
        let manual = Computed::from_fn(move || reader.get() * 3, vec![subscription]);
        assert_eq!(manual.get(), 15);

        manual_source.set(10);
        assert!(notified.get());
        // Manually supplied subscriptions do not dirty the value on their own.
        manual.invalidate();
        assert_eq!(manual.get(), 30);
    }

    #[test]
    fn string_computed() {
        let first = Observable::new("John".to_string());
        let last = Observable::new("Doe".to_string());
        let full_name = Computed::from2(&first, &last, |f, l| format!("{} {}", f, l));
        assert_eq!(full_name.get(), "John Doe");

        first.set("Jane".to_string());
        assert_eq!(full_name.get(), "Jane Doe");

        last.set("Smith".to_string());
        assert_eq!(full_name.get(), "Jane Smith");
    }

    #[test]
    fn computed_survives_source_drop() {
        let computed;
        {
            let source = Observable::new(42);
            computed = Computed::from_observable(&source, |v| *v);
            let _ = computed.get();
        }
        assert_eq!(computed.get(), 42);
        assert!(!computed.is_dirty());
    }

    #[test]
    fn is_dirty_initially_true() {
        let source = Observable::new(1);
        let computed = Computed::from_observable(&source, |v| *v);
        assert!(computed.is_dirty());
    }

    #[test]
    fn with_increments_version_on_dirty() {
        let source = Observable::new(10);
        let computed = Computed::from_observable(&source, |v| *v);

        assert_eq!(computed.with(|v| *v), 10);
        assert_eq!(computed.version(), 1);

        source.set(20);
        assert_eq!(computed.with(|v| *v), 20);
        assert_eq!(computed.version(), 2);
    }

    #[test]
    fn invalidate_without_source_change() {
        let source = Observable::new(5);
        let computed = Computed::from_observable(&source, |v| *v);
        let _ = computed.get();
        assert_eq!(computed.version(), 1);
        assert!(!computed.is_dirty());

        computed.invalidate();
        assert!(computed.is_dirty());
        let _ = computed.get();
        assert_eq!(computed.version(), 2);
    }

    #[test]
    fn many_updates_version_monotonic() {
        let source = Observable::new(0);
        let computed = Computed::from_observable(&source, |v| *v);
        (1..=50).for_each(|step| {
            source.set(step);
            let _ = computed.get();
        });
        assert_eq!(computed.version(), 50);
    }
}