1#![forbid(unsafe_code)]
2
3use std::cell::{Cell, RefCell};
31use std::rc::Rc;
32
33use super::observable::{Observable, Subscription};
34
35struct ComputedInner<T> {
37 compute: Box<dyn Fn() -> T>,
39 cached: Option<T>,
41 dirty: Cell<bool>,
43 version: u64,
45 _subscriptions: Vec<Subscription>,
48}
49
50pub struct Computed<T> {
62 inner: Rc<RefCell<ComputedInner<T>>>,
63}
64
65impl<T> Clone for Computed<T> {
66 fn clone(&self) -> Self {
67 Self {
68 inner: Rc::clone(&self.inner),
69 }
70 }
71}
72
73impl<T: std::fmt::Debug> std::fmt::Debug for Computed<T> {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 let inner = self.inner.borrow();
76 f.debug_struct("Computed")
77 .field("cached", &inner.cached)
78 .field("dirty", &inner.dirty.get())
79 .field("version", &inner.version)
80 .finish()
81 }
82}
83
84impl<T: Clone + 'static> Computed<T> {
85 pub fn from_observable<S: Clone + PartialEq + 'static>(
90 source: &Observable<S>,
91 map: impl Fn(&S) -> T + 'static,
92 ) -> Self {
93 let source_clone = source.clone();
94 let compute = Box::new(move || source_clone.with(|v| map(v)));
95
96 let inner = Rc::new(RefCell::new(ComputedInner {
97 compute,
98 cached: None,
99 dirty: Cell::new(true), version: 0,
101 _subscriptions: Vec::new(),
102 }));
103
104 let weak_inner = Rc::downgrade(&inner);
106 let sub = source.subscribe(move |_| {
107 if let Some(strong) = weak_inner.upgrade() {
108 strong.borrow().dirty.set(true);
109 }
110 });
111
112 inner.borrow_mut()._subscriptions.push(sub);
113
114 Self { inner }
115 }
116
117 pub fn from2<S1, S2>(
119 s1: &Observable<S1>,
120 s2: &Observable<S2>,
121 map: impl Fn(&S1, &S2) -> T + 'static,
122 ) -> Self
123 where
124 S1: Clone + PartialEq + 'static,
125 S2: Clone + PartialEq + 'static,
126 {
127 let s1_clone = s1.clone();
128 let s2_clone = s2.clone();
129 let compute = Box::new(move || s1_clone.with(|v1| s2_clone.with(|v2| map(v1, v2))));
130
131 let inner = Rc::new(RefCell::new(ComputedInner {
132 compute,
133 cached: None,
134 dirty: Cell::new(true),
135 version: 0,
136 _subscriptions: Vec::new(),
137 }));
138
139 let weak1 = Rc::downgrade(&inner);
141 let sub1 = s1.subscribe(move |_| {
142 if let Some(strong) = weak1.upgrade() {
143 strong.borrow().dirty.set(true);
144 }
145 });
146
147 let weak2 = Rc::downgrade(&inner);
148 let sub2 = s2.subscribe(move |_| {
149 if let Some(strong) = weak2.upgrade() {
150 strong.borrow().dirty.set(true);
151 }
152 });
153
154 {
155 let mut inner_mut = inner.borrow_mut();
156 inner_mut._subscriptions.push(sub1);
157 inner_mut._subscriptions.push(sub2);
158 }
159
160 Self { inner }
161 }
162
163 pub fn from3<S1, S2, S3>(
165 s1: &Observable<S1>,
166 s2: &Observable<S2>,
167 s3: &Observable<S3>,
168 map: impl Fn(&S1, &S2, &S3) -> T + 'static,
169 ) -> Self
170 where
171 S1: Clone + PartialEq + 'static,
172 S2: Clone + PartialEq + 'static,
173 S3: Clone + PartialEq + 'static,
174 {
175 let s1_clone = s1.clone();
176 let s2_clone = s2.clone();
177 let s3_clone = s3.clone();
178 let compute = Box::new(move || {
179 s1_clone.with(|v1| s2_clone.with(|v2| s3_clone.with(|v3| map(v1, v2, v3))))
180 });
181
182 let inner = Rc::new(RefCell::new(ComputedInner {
183 compute,
184 cached: None,
185 dirty: Cell::new(true),
186 version: 0,
187 _subscriptions: Vec::new(),
188 }));
189
190 let weak1 = Rc::downgrade(&inner);
191 let sub1 = s1.subscribe(move |_| {
192 if let Some(strong) = weak1.upgrade() {
193 strong.borrow().dirty.set(true);
194 }
195 });
196
197 let weak2 = Rc::downgrade(&inner);
198 let sub2 = s2.subscribe(move |_| {
199 if let Some(strong) = weak2.upgrade() {
200 strong.borrow().dirty.set(true);
201 }
202 });
203
204 let weak3 = Rc::downgrade(&inner);
205 let sub3 = s3.subscribe(move |_| {
206 if let Some(strong) = weak3.upgrade() {
207 strong.borrow().dirty.set(true);
208 }
209 });
210
211 {
212 let mut inner_mut = inner.borrow_mut();
213 inner_mut._subscriptions.push(sub1);
214 inner_mut._subscriptions.push(sub2);
215 inner_mut._subscriptions.push(sub3);
216 }
217
218 Self { inner }
219 }
220
221 pub fn from_fn(compute: impl Fn() -> T + 'static, subscriptions: Vec<Subscription>) -> Self {
227 Self {
228 inner: Rc::new(RefCell::new(ComputedInner {
229 compute: Box::new(compute),
230 cached: None,
231 dirty: Cell::new(true),
232 version: 0,
233 _subscriptions: subscriptions,
234 })),
235 }
236 }
237
238 #[must_use]
243 pub fn get(&self) -> T {
244 let mut inner = self.inner.borrow_mut();
245 if inner.dirty.get() || inner.cached.is_none() {
246 let new_value = (inner.compute)();
247 inner.cached = Some(new_value);
248 inner.dirty.set(false);
249 inner.version += 1;
250 }
251 inner
252 .cached
253 .as_ref()
254 .expect("cached is always Some after get()")
255 .clone()
256 }
257
258 pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
268 {
270 let mut inner = self.inner.borrow_mut();
271 if inner.dirty.get() || inner.cached.is_none() {
272 let new_value = (inner.compute)();
273 inner.cached = Some(new_value);
274 inner.dirty.set(false);
275 inner.version += 1;
276 }
277 }
278 let inner = self.inner.borrow();
279 f(inner
280 .cached
281 .as_ref()
282 .expect("cached is always Some after refresh"))
283 }
284
285 #[must_use]
287 pub fn is_dirty(&self) -> bool {
288 self.inner.borrow().dirty.get()
289 }
290
291 pub fn invalidate(&self) {
294 self.inner.borrow().dirty.set(true);
295 }
296
297 #[must_use]
299 pub fn version(&self) -> u64 {
300 self.inner.borrow().version
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// Creates a call counter plus a second handle to move into a closure.
    fn counter() -> (Rc<Cell<u32>>, Rc<Cell<u32>>) {
        let count = Rc::new(Cell::new(0u32));
        let handle = Rc::clone(&count);
        (count, handle)
    }

    #[test]
    fn single_dep_computed() {
        let source = Observable::new(10);
        let computed = Computed::from_observable(&source, |v| v * 2);

        assert_eq!(computed.get(), 20);
        assert_eq!(computed.version(), 1);

        source.set(5);
        assert!(computed.is_dirty());
        assert_eq!(computed.get(), 10);
        assert_eq!(computed.version(), 2);
    }

    #[test]
    fn multi_dep_computed() {
        let width = Observable::new(10);
        let height = Observable::new(20);
        let area = Computed::from2(&width, &height, |w, h| w * h);

        assert_eq!(area.get(), 200);

        width.set(5);
        assert_eq!(area.get(), 100);

        height.set(30);
        assert_eq!(area.get(), 150);
    }

    #[test]
    fn three_dep_computed() {
        let a = Observable::new(1);
        let b = Observable::new(2);
        let c = Observable::new(3);
        let sum = Computed::from3(&a, &b, &c, |x, y, z| x + y + z);

        assert_eq!(sum.get(), 6);

        a.set(10);
        assert_eq!(sum.get(), 15);

        c.set(100);
        assert_eq!(sum.get(), 112);
    }

    #[test]
    fn lazy_evaluation() {
        let (compute_count, count_clone) = counter();

        let source = Observable::new(42);
        let source_clone = source.clone();
        let computed = Computed::from_fn(
            move || {
                count_clone.set(count_clone.get() + 1);
                source_clone.get() * 2
            },
            vec![],
        );

        // Nothing is computed until the first read.
        assert_eq!(compute_count.get(), 0);

        assert_eq!(computed.get(), 84);
        assert_eq!(compute_count.get(), 1);

        // A second read is served from the cache.
        assert_eq!(computed.get(), 84);
        assert_eq!(compute_count.get(), 1);
    }

    #[test]
    fn memoization() {
        let (compute_count, count_clone) = counter();

        let source = Observable::new(10);
        let computed = Computed::from_observable(&source, move |v| {
            count_clone.set(count_clone.get() + 1);
            v * 2
        });

        assert_eq!(computed.get(), 20);
        assert_eq!(compute_count.get(), 1);

        assert_eq!(computed.get(), 20);
        assert_eq!(compute_count.get(), 1);

        source.set(20);
        assert_eq!(computed.get(), 40);
        assert_eq!(compute_count.get(), 2);

        assert_eq!(computed.get(), 40);
        assert_eq!(compute_count.get(), 2);
    }

    #[test]
    fn invalidate_forces_recompute() {
        let (compute_count, count_clone) = counter();

        let source = Observable::new(5);
        let computed = Computed::from_observable(&source, move |v| {
            count_clone.set(count_clone.get() + 1);
            *v
        });

        assert_eq!(computed.get(), 5);
        assert_eq!(compute_count.get(), 1);

        computed.invalidate();
        assert!(computed.is_dirty());

        assert_eq!(computed.get(), 5);
        assert_eq!(compute_count.get(), 2);
    }

    #[test]
    fn with_access() {
        let source = Observable::new(vec![1, 2, 3]);
        let computed = Computed::from_observable(&source, |v| v.iter().sum::<i32>());

        let result = computed.with(|sum| *sum);
        assert_eq!(result, 6);
    }

    #[test]
    fn version_increments_on_recompute() {
        let source = Observable::new(0);
        let computed = Computed::from_observable(&source, |v| *v);

        assert_eq!(computed.version(), 0);

        let _ = computed.get();
        assert_eq!(computed.version(), 1);

        let _ = computed.get();
        assert_eq!(computed.version(), 1);

        source.set(1);
        let _ = computed.get();
        assert_eq!(computed.version(), 2);
    }

    #[test]
    fn clone_shares_state() {
        let source = Observable::new(10);
        let c1 = Computed::from_observable(&source, |v| v + 1);
        let c2 = c1.clone();

        assert_eq!(c1.get(), 11);
        assert_eq!(c2.get(), 11);

        source.set(20);
        assert_eq!(c1.get(), 21);
        assert_eq!(c2.get(), 21);
        // Both handles share one cache, so only c1's read recomputed.
        assert_eq!(c2.version(), 2);
    }

    #[test]
    fn diamond_dependency() {
        let a = Observable::new(10);

        let b = Computed::from_observable(&a, |v| v + 1);
        let c = Computed::from_observable(&a, |v| v * 2);

        let b_clone = b.clone();
        let c_clone = c.clone();
        let d = Computed::from_observable(&a, move |_| b_clone.get() + c_clone.get());

        assert_eq!(b.get(), 11);
        assert_eq!(c.get(), 20);
        assert_eq!(d.get(), 31);

        a.set(5);
        assert_eq!(b.get(), 6);
        assert_eq!(c.get(), 10);
        assert_eq!(d.get(), 16);
    }

    #[test]
    fn no_change_same_value() {
        let source = Observable::new(42);
        let (compute_count, count_clone) = counter();

        let computed = Computed::from_observable(&source, move |v| {
            count_clone.set(count_clone.get() + 1);
            *v
        });

        let _ = computed.get();
        assert_eq!(compute_count.get(), 1);

        // Setting an equal value must not notify, so the cache stays clean.
        source.set(42);
        assert!(!computed.is_dirty());
        let _ = computed.get();
        assert_eq!(compute_count.get(), 1);
    }

    #[test]
    fn debug_format() {
        let source = Observable::new(42);
        let computed = Computed::from_observable(&source, |v| *v);
        let _ = computed.get();
        let dbg = format!("{:?}", computed);
        assert!(dbg.contains("Computed"));
        assert!(dbg.contains("42"));
    }

    #[test]
    fn from_observable_tracks_updates() {
        let source = Observable::new(10);
        let computed = Computed::from_observable(&source, |v| v * 3);

        assert_eq!(computed.get(), 30);

        source.set(20);
        assert_eq!(computed.get(), 60);
    }

    #[test]
    fn from_fn_with_manual_subscriptions() {
        let source = Observable::new(5);
        let source_clone = source.clone();

        let notified = Rc::new(Cell::new(false));
        let notified_for_sub = Rc::clone(&notified);
        let sub = source.subscribe(move |_| {
            notified_for_sub.set(true);
        });

        let computed = Computed::from_fn(move || source_clone.get() * 3, vec![sub]);
        assert_eq!(computed.get(), 15);

        source.set(10);
        assert!(notified.get());

        // `from_fn` keeps the subscription alive but does not wire it to the
        // dirty flag, so the cached value stays stale until invalidated.
        assert!(!computed.is_dirty());
        assert_eq!(computed.get(), 15);

        computed.invalidate();
        assert_eq!(computed.get(), 30);
    }

    #[test]
    fn string_computed() {
        let first = Observable::new("John".to_string());
        let last = Observable::new("Doe".to_string());
        let full_name = Computed::from2(&first, &last, |f, l| format!("{} {}", f, l));

        assert_eq!(full_name.get(), "John Doe");

        first.set("Jane".to_string());
        assert_eq!(full_name.get(), "Jane Doe");

        last.set("Smith".to_string());
        assert_eq!(full_name.get(), "Jane Smith");
    }

    #[test]
    fn computed_survives_source_drop() {
        let computed;
        {
            let source = Observable::new(42);
            computed = Computed::from_observable(&source, |v| *v);
            let _ = computed.get();
        }
        // The compute closure owns its own clone of the source, so the value
        // stays readable and clean after the original handle is dropped.
        assert_eq!(computed.get(), 42);
        assert!(!computed.is_dirty());
    }

    #[test]
    fn many_updates_version_monotonic() {
        let source = Observable::new(0);
        let computed = Computed::from_observable(&source, |v| *v);

        let mut previous = computed.version();
        for i in 1..=50 {
            source.set(i);
            let _ = computed.get();
            let current = computed.version();
            assert_eq!(current, previous + 1);
            previous = current;
        }
        assert_eq!(computed.version(), 50);
    }
}
611}