use std::sync::{
Arc, RwLock,
atomic::{AtomicBool, Ordering},
};
use crate::signal::Signal;
/// A type that exposes a shared dirty flag (for example [`Computed`], which
/// re-evaluates its cached value once the flag has been set to `true`).
pub trait DirtySink: Sync + Send {
    /// Returns a shared handle to this sink's dirty flag.
    fn dirty_flag(&self) -> Arc<AtomicBool>;
}
/// A lazily evaluated, memoised value derived from other state.
///
/// The closure given to [`Computed::new`] is evaluated on the first call to
/// [`Computed::get`], and again whenever the shared dirty flag has been set to
/// `true` (for example by a signal registered through [`Computed::track`] or
/// [`Computed::track_any`]). Otherwise `get` returns a clone of the cached value.
pub struct Computed<T: Clone + Send + Sync + 'static> {
    /// The derivation closure.
    func: Arc<dyn Fn() -> T + Send + Sync>,
    /// Last evaluated value; `None` until the first successful evaluation.
    cached: Arc<RwLock<Option<T>>>,
    /// `true` when `cached` must be recomputed before it is returned.
    dirty: Arc<AtomicBool>,
}

impl<T: Clone + Send + Sync + 'static> Computed<T> {
    /// Creates a computed value backed by `f`.
    ///
    /// `f` is not called here; the first [`get`](Self::get) evaluates it.
    pub fn new(f: impl Fn() -> T + Send + Sync + 'static) -> Self {
        Self {
            func: Arc::new(f),
            cached: Arc::new(RwLock::new(None)),
            // Start dirty so the first `get` evaluates `f`.
            dirty: Arc::new(AtomicBool::new(true)),
        }
    }

    /// Returns the current value, re-evaluating the closure if the cache is
    /// empty or has been invalidated.
    ///
    /// Evaluation happens while holding the cache's write lock. As a result,
    /// concurrent callers never evaluate twice for one invalidation, and an
    /// older result can never overwrite a newer one.
    ///
    /// The dirty flag is cleared *before* the closure runs. An invalidation
    /// that arrives while the closure is running sets the flag again, so the
    /// next `get` recomputes instead of losing the update.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned, which happens if the closure
    /// panicked during an earlier evaluation. The closure must not call `get`
    /// on this same `Computed` (or a clone of it): re-acquiring the lock on the
    /// same thread may deadlock or panic.
    pub fn get(&self) -> T {
        // Fast path: cache is clean, only a shared read lock is needed.
        if !self.dirty.load(Ordering::Acquire) {
            if let Some(cached) = self.cached.read().expect("Computed cache poisoned").as_ref() {
                return cached.clone();
            }
        }

        let mut guard = self.cached.write().expect("Computed cache poisoned");

        // Clear the flag before evaluating, so a concurrent invalidation during
        // `func` is preserved rather than overwritten afterwards.
        let was_dirty = self.dirty.swap(false, Ordering::AcqRel);

        // Another thread may have refreshed the cache while we waited for the
        // write lock; if nothing invalidated it since, reuse that value.
        if !was_dirty {
            if let Some(cached) = guard.as_ref() {
                return cached.clone();
            }
        }

        let fresh = (self.func)();
        *guard = Some(fresh.clone());
        fresh
    }

    /// Returns a handle to the shared dirty flag.
    ///
    /// Storing `true` into it invalidates the cache; the next
    /// [`get`](Self::get) re-evaluates the closure.
    pub fn dirty_flag(&self) -> Arc<AtomicBool> {
        Arc::clone(&self.dirty)
    }

    /// Registers this value's dirty flag with `signal`, whose value type must
    /// be the same `T`. For a signal of a different type use
    /// [`track_any`](Self::track_any).
    pub fn track(&self, signal: &Signal<T>) {
        signal.subscribe_with_flag(Arc::clone(&self.dirty));
    }

    /// Registers this value's dirty flag with any [`SignalSubscribable`]
    /// source, regardless of that source's value type.
    pub fn track_any<S: SignalSubscribable>(&self, signal: &S) {
        signal.subscribe_with_flag(Arc::clone(&self.dirty));
    }
}
impl<T: Clone + Send + Sync + 'static> DirtySink for Computed<T> {
    /// Hands out another reference to the same flag that `Computed::get`
    /// consults; setting it to `true` makes the next `get` re-evaluate.
    fn dirty_flag(&self) -> Arc<AtomicBool> {
        self.dirty.clone()
    }
}
impl<T: Clone + Send + Sync + 'static> Clone for Computed<T> {
    /// Returns a second handle that shares the closure, the cache and the
    /// dirty flag with `self`. Invalidating or refreshing through either
    /// handle is visible through both.
    fn clone(&self) -> Self {
        let Self { func, cached, dirty } = self;
        Self {
            func: func.clone(),
            cached: cached.clone(),
            dirty: dirty.clone(),
        }
    }
}
impl<T: Clone + Send + Sync + std::fmt::Debug + 'static> std::fmt::Debug for Computed<T> {
    /// Formats as `Computed(dirty=…, cached=…)` without blocking.
    ///
    /// If the cache lock cannot be acquired immediately (held for writing, or
    /// poisoned), the value is shown as `<locked>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let is_dirty = self.dirty.load(Ordering::Relaxed);
        if let Ok(cache) = self.cached.try_read() {
            write!(f, "Computed(dirty={}, cached={:?})", is_dirty, *cache)
        } else {
            write!(f, "Computed(dirty={}, cached=<locked>)", is_dirty)
        }
    }
}
/// A source that accepts dirty flags from dependents.
///
/// `Computed::track_any` passes its flag here, so a source of any value type
/// can invalidate a computed value. Implementors are expected to set the flag
/// to `true` when they change.
pub trait SignalSubscribable: Sync + Send {
    /// Registers `flag` with this source.
    fn subscribe_with_flag(&self, flag: Arc<AtomicBool>);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU32;

    /// Builds a closure that returns `value` and counts its invocations in `counter`.
    fn counting(counter: &Arc<AtomicU32>, value: i32) -> impl Fn() -> i32 + Send + Sync + 'static {
        let counter = Arc::clone(counter);
        move || {
            counter.fetch_add(1, Ordering::SeqCst);
            value
        }
    }

    #[test]
    fn computed_evaluates_on_first_get() {
        let computed = Computed::new(|| 42i32);
        assert_eq!(computed.get(), 42);
    }

    #[test]
    fn computed_caches_after_first_get() {
        let calls = Arc::new(AtomicU32::new(0));
        let computed = Computed::new(counting(&calls, 99));

        assert_eq!(computed.get(), 99);
        assert_eq!(computed.get(), 99);
        // The second `get` must be served from the cache.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn dirty_flag_triggers_recompute() {
        let calls = Arc::new(AtomicU32::new(0));
        let computed = Computed::new(counting(&calls, 7));

        let _ = computed.get();
        computed.dirty_flag().store(true, Ordering::Release);
        let _ = computed.get();

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn clone_shares_cache_and_dirty() {
        let original = Computed::new(|| 1i32);
        let twin = original.clone();

        let _ = original.get();
        twin.dirty_flag().store(true, Ordering::Release);

        // Invalidating through the clone is visible through the original.
        assert!(original.dirty_flag().load(Ordering::Acquire));
    }

    #[test]
    fn track_wires_signal_to_computed() {
        let source = Signal::new(10i32);
        let reader = source.clone();
        let doubled = Computed::new(move || reader.get() * 2);
        doubled.track(&source);

        assert_eq!(doubled.get(), 20);
        assert_eq!(doubled.get(), 20);

        // Updating the signal must invalidate the cached value.
        source.set(5);
        assert_eq!(doubled.get(), 10);
        assert_eq!(doubled.get(), 10);
    }

    #[test]
    fn track_any_cross_type() {
        let source: Signal<i32> = Signal::new(3);
        let reader = source.clone();
        let label: Computed<String> = Computed::new(move || format!("val={}", reader.get()));
        label.track_any(&source);

        assert_eq!(label.get(), "val=3");

        source.set(7);
        assert_eq!(label.get(), "val=7");
    }
}