use super::tracker::{dispose_subscriber, start_tracking, stop_tracking, Subscriber, SubscriberId};
use crate::utils::lock::{lock_or_recover, read_or_recover, write_or_recover};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
/// A lazily evaluated, cached value produced by a closure.
///
/// Cloning yields another handle onto the same cache, dirty flag and
/// recompute lock. Handles share one reference count, and the last handle
/// dropped calls `dispose_subscriber` for this computed's id.
pub struct Computed<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Identity under which this value registers with the dependency tracker.
    id: SubscriberId,
    /// The user-supplied function that produces the value.
    compute: Arc<dyn Fn() -> T + Send + Sync>,
    /// The cached value; `None` when nothing is cached.
    cached: Arc<RwLock<Option<T>>>,
    /// When set, the cached value must not be returned without recomputing.
    dirty: Arc<AtomicBool>,
    /// Held while recomputing so concurrent `get` calls do not run `compute` simultaneously.
    recompute_lock: Arc<Mutex<()>>,
    /// Number of live handles sharing this state.
    ref_count: Arc<AtomicUsize>,
}
impl<T: Clone + Send + Sync + 'static> Computed<T> {
pub fn new(f: impl Fn() -> T + Send + Sync + 'static) -> Self {
let id = SubscriberId::new();
let compute = Arc::new(f);
Self {
id,
compute,
cached: Arc::new(RwLock::new(None)),
dirty: Arc::new(AtomicBool::new(true)),
recompute_lock: Arc::new(Mutex::new(())),
ref_count: Arc::new(AtomicUsize::new(1)),
}
}
pub fn get(&self) -> T {
if !self.needs_recompute() {
if let Some(value) = self.get_cached() {
return value;
}
}
let _guard = lock_or_recover(&self.recompute_lock);
if self.needs_recompute() {
self.recompute_and_cache()
} else {
match self.get_cached() {
Some(value) => value,
None => self.recompute_and_cache(),
}
}
}
fn needs_recompute(&self) -> bool {
let is_dirty = self.dirty.load(Ordering::SeqCst);
let has_cache = read_or_recover(&self.cached).is_some();
is_dirty || !has_cache
}
fn recompute_and_cache(&self) -> T {
let dirty_flag = self.dirty.clone();
let subscriber = Subscriber {
id: self.id,
callback: Arc::new(move || {
dirty_flag.store(true, Ordering::SeqCst);
}),
};
start_tracking(subscriber);
let value = (self.compute)();
stop_tracking();
*write_or_recover(&self.cached) = Some(value.clone());
self.dirty.store(false, Ordering::SeqCst);
value
}
fn get_cached(&self) -> Option<T> {
read_or_recover(&self.cached).as_ref().cloned()
}
pub fn invalidate(&self) {
self.dirty.store(true, Ordering::SeqCst);
}
pub fn is_dirty(&self) -> bool {
self.dirty.load(Ordering::SeqCst)
}
}
impl<T: Clone + Send + Sync + 'static> Clone for Computed<T> {
fn clone(&self) -> Self {
self.ref_count.fetch_add(1, Ordering::SeqCst);
Self {
id: self.id,
compute: self.compute.clone(),
cached: self.cached.clone(),
dirty: self.dirty.clone(),
recompute_lock: self.recompute_lock.clone(),
ref_count: self.ref_count.clone(),
}
}
}
impl<T: Clone + Send + Sync + 'static> Drop for Computed<T> {
    /// Releases this handle; the final handle unregisters the subscriber id.
    fn drop(&mut self) {
        // `fetch_sub` yields the count *before* the decrement, so a previous
        // value of 1 means no other handle remains.
        let remaining_before = self.ref_count.fetch_sub(1, Ordering::SeqCst);
        let was_last_handle = remaining_before == 1;
        if was_last_handle {
            dispose_subscriber(self.id);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Builds a computed value whose closure counts its own invocations and
    /// returns that running count (1 on the first evaluation, 2 on the second, ...).
    /// The shared counter is returned alongside so tests can inspect it.
    fn counting_computed() -> (Computed<usize>, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let computed = Computed::new(move || counter.fetch_add(1, Ordering::SeqCst) + 1);
        (computed, calls)
    }

    #[test]
    fn test_computed_new() {
        let c = Computed::new(|| 42);
        assert_eq!(c.get(), 42);
    }
    #[test]
    fn test_computed_new_string() {
        let c = Computed::new(|| "hello".to_string());
        assert_eq!(c.get(), "hello");
    }
    #[test]
    fn test_computed_new_vec() {
        let c = Computed::new(|| vec![1, 2, 3]);
        assert_eq!(c.get(), vec![1, 2, 3]);
    }
    #[test]
    fn test_computed_new_closure_with_captures() {
        let x = 10;
        let c = Computed::new(move || x * 2);
        assert_eq!(c.get(), 20);
    }
    #[test]
    fn test_computed_new_is_lazy() {
        let (c, calls) = counting_computed();
        assert_eq!(calls.load(Ordering::SeqCst), 0, "closure must not run before get");
        assert_eq!(c.get(), 1);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
    #[test]
    fn test_computed_get_basic() {
        let c = Computed::new(|| 42);
        let result = c.get();
        assert_eq!(result, 42);
    }
    #[test]
    fn test_computed_get_multiple_calls() {
        let c = Computed::new(|| 42);
        assert_eq!(c.get(), 42);
        assert_eq!(c.get(), 42);
        assert_eq!(c.get(), 42);
    }
    #[test]
    fn test_computed_get_caches_between_calls() {
        let (c, calls) = counting_computed();
        assert_eq!(c.get(), 1);
        assert_eq!(c.get(), 1);
        assert_eq!(c.get(), 1);
        assert_eq!(calls.load(Ordering::SeqCst), 1, "clean cache must not re-evaluate");
    }
    #[test]
    fn test_computed_get_returns_cloned_value() {
        let c = Computed::new(|| vec![1, 2, 3]);
        let v1 = c.get();
        let v2 = c.get();
        assert_eq!(v1, vec![1, 2, 3]);
        assert_eq!(v2, vec![1, 2, 3]);
    }
    #[test]
    fn test_computed_invalidate() {
        let c = Computed::new(|| 42);
        assert_eq!(c.get(), 42);
        c.invalidate();
        assert!(c.is_dirty());
        assert_eq!(c.get(), 42);
    }
    #[test]
    fn test_computed_invalidate_forces_recompute() {
        let (c, calls) = counting_computed();
        assert_eq!(c.get(), 1);
        c.invalidate();
        assert_eq!(c.get(), 2, "invalidate must trigger a fresh evaluation");
        assert_eq!(c.get(), 2, "value must be cached again after recompute");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        assert!(!c.is_dirty());
    }
    #[test]
    fn test_computed_invalidate_multiple() {
        let c = Computed::new(|| 42);
        c.invalidate();
        c.invalidate();
        c.invalidate();
        assert!(c.is_dirty());
    }
    #[test]
    fn test_computed_is_dirty_initially() {
        let c = Computed::new(|| 42);
        assert!(c.is_dirty(), "Should be dirty initially (cache empty)");
    }
    #[test]
    fn test_computed_is_dirty_after_get() {
        let c = Computed::new(|| 42);
        c.get();
        assert!(!c.is_dirty(), "Should be clean after first get");
    }
    #[test]
    fn test_computed_is_dirty_after_invalidate() {
        let c = Computed::new(|| 42);
        c.get();
        assert!(!c.is_dirty());
        c.invalidate();
        assert!(c.is_dirty());
    }
    #[test]
    fn test_computed_clone_shares_cache() {
        let c1 = Computed::new(|| 42);
        assert_eq!(c1.get(), 42);
        let c2 = c1.clone();
        assert_eq!(c2.get(), 42);
    }
    #[test]
    fn test_computed_clone_shares_computation() {
        let (c1, calls) = counting_computed();
        assert_eq!(c1.get(), 1);
        let c2 = c1.clone();
        assert_eq!(c2.get(), 1);
        drop(c1);
        assert_eq!(c2.get(), 1, "remaining handle keeps the shared cache");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
    #[test]
    fn test_computed_clone_invalidate_affects_both() {
        let c1 = Computed::new(|| 42);
        c1.get();
        let c2 = c1.clone();
        c1.invalidate();
        assert!(c2.is_dirty());
    }
    #[test]
    fn test_computed_with_i32() {
        let c = Computed::new(|| 123_i32);
        assert_eq!(c.get(), 123);
    }
    #[test]
    fn test_computed_with_u64() {
        let c = Computed::new(|| 999_u64);
        assert_eq!(c.get(), 999);
    }
    #[test]
    fn test_computed_with_f64() {
        // 2.5 rather than 3.14: the latter trips clippy's deny-by-default
        // `approx_constant` lint (it looks like a sloppy PI).
        let c = Computed::new(|| 2.5_f64);
        assert!((c.get() - 2.5).abs() < 0.001);
    }
    #[test]
    fn test_computed_with_bool() {
        let c = Computed::new(|| true);
        assert!(c.get());
    }
    #[test]
    fn test_computed_with_option() {
        let c = Computed::new(|| Some(42));
        assert_eq!(c.get(), Some(42));
    }
    #[test]
    fn test_computed_with_result() {
        let c = Computed::new(|| Ok::<i32, &str>(42));
        assert_eq!(c.get(), Ok(42));
    }
    #[test]
    fn test_computed_with_complex_calculation() {
        let c = Computed::new(|| {
            let mut sum = 0;
            for i in 1..=100 {
                sum += i;
            }
            sum
        });
        assert_eq!(c.get(), 5050);
    }
    #[test]
    fn test_computed_with_string_concat() {
        let c = Computed::new(|| {
            let mut s = String::new();
            for i in 1..=5 {
                s.push_str(&i.to_string());
            }
            s
        });
        assert_eq!(c.get(), "12345");
    }
    #[test]
    fn test_computed_send_sync() {
        fn is_send_sync<T: Send + Sync>() {}
        is_send_sync::<Computed<i32>>();
    }
}