1#![allow(clippy::arc_with_non_send_sync)]
4#![allow(clippy::type_complexity)]
6
7use crate::collections::map::{HashMap, HashSet};
8use crate::snapshot_v2::{register_apply_observer, ReadObserver, StateObjectId};
9use crate::state::StateObject;
10use crate::{RecomposeScope, RecomposeScopeInner, ScopeId};
11use std::any::Any;
12use std::cell::{Cell, RefCell};
13use std::rc::{Rc, Weak};
14use std::sync::Arc;
15
/// Runner for change notifications: it receives each boxed notification
/// callback and decides when to invoke it (the tests in this file invoke it
/// immediately).
type Executor = dyn Fn(Box<dyn FnOnce() + 'static>) + 'static;
18
/// Records which snapshot state objects each scope reads and notifies that
/// scope (through the configured executor) when an applied snapshot modifies
/// one of them.
///
/// Cloning is cheap: all clones share the same inner state through an `Rc`.
#[derive(Clone)]
pub struct SnapshotStateObserver {
    inner: Rc<SnapshotStateObserverInner>,
}
35
36impl SnapshotStateObserver {
37 pub fn new(on_changed_executor: impl Fn(Box<dyn FnOnce() + 'static>) + 'static) -> Self {
39 let inner = Rc::new(SnapshotStateObserverInner::new(on_changed_executor));
40 inner.set_self(Rc::downgrade(&inner));
41 Self { inner }
42 }
43
44 pub fn observe_reads<T, R>(
50 &self,
51 scope: T,
52 on_value_changed_for_scope: impl Fn(&T) + 'static,
53 block: impl FnOnce() -> R,
54 ) -> R
55 where
56 T: Any + Clone + PartialEq + 'static,
57 {
58 self.inner
59 .observe_reads(scope, on_value_changed_for_scope, block)
60 }
61
62 pub fn begin_frame(&self) {
64 self.inner.begin_frame();
65 }
66
67 pub fn with_no_observations<R>(&self, block: impl FnOnce() -> R) -> R {
69 self.inner.with_no_observations(block)
70 }
71
72 pub fn clear<T>(&self, scope: &T)
74 where
75 T: Any + PartialEq + 'static,
76 {
77 self.inner.clear(scope);
78 }
79
80 pub fn clear_if(&self, predicate: impl Fn(&dyn Any) -> bool) {
82 self.inner.clear_if(predicate);
83 }
84
85 pub fn clear_all(&self) {
87 self.inner.clear_all();
88 }
89
90 pub fn start(&self) {
92 let weak = Rc::downgrade(&self.inner);
93 self.inner.start(weak);
94 }
95
96 pub fn stop(&self) {
98 self.inner.stop();
99 }
100
101 #[cfg(test)]
103 pub fn notify_changes(&self, modified: &[Arc<dyn StateObject>]) {
104 self.inner.handle_apply(modified);
105 }
106}
107
/// Shared state behind `SnapshotStateObserver`.
struct SnapshotStateObserverInner {
    /// Runs the change-notification callbacks built by `handle_apply`.
    executor: Rc<Executor>,
    /// Every tracked scope entry; this is the list `handle_apply` scans.
    scopes: RefCell<Vec<Rc<RefCell<ScopeEntry>>>>,
    /// Index of `RecomposeScope` entries by id for fast lookup in
    /// `get_scope_entry`. Every entry here is also present in `scopes`.
    fast_scopes: RefCell<HashMap<ScopeId, Rc<RefCell<ScopeEntry>>>>,
    /// Nesting depth of `with_no_observations`; the read observer ignores
    /// reads while this is non-zero. Shared via `Rc` so observer closures can
    /// check it.
    pause_count: Rc<Cell<usize>>,
    /// Registration returned by `register_apply_observer`; `Some` while started.
    apply_handle: RefCell<Option<crate::snapshot_v2::ObserverHandle>>,
    /// Weak back-reference to the owning `Rc`, set once at construction.
    weak_self: RefCell<Weak<SnapshotStateObserverInner>>,
    /// Incremented by `begin_frame`; 0 means no frame has been started yet.
    frame_version: Cell<u64>,
}
117
118impl SnapshotStateObserverInner {
119 fn new(on_changed_executor: impl Fn(Box<dyn FnOnce() + 'static>) + 'static) -> Self {
120 Self {
121 executor: Rc::new(on_changed_executor),
122 scopes: RefCell::new(Vec::new()),
123 fast_scopes: RefCell::new(HashMap::default()),
124 pause_count: Rc::new(Cell::new(0)),
125 apply_handle: RefCell::new(None),
126 weak_self: RefCell::new(Weak::new()),
127 frame_version: Cell::new(0),
128 }
129 }
130
131 fn set_self(&self, weak: Weak<SnapshotStateObserverInner>) {
132 self.weak_self.replace(weak);
133 }
134
135 fn begin_frame(&self) {
136 let next = self.frame_version.get().wrapping_add(1);
137 self.frame_version.set(next);
138 self.prune_dead_scopes();
139 }
140
141 fn observe_reads<T, R>(
142 &self,
143 scope: T,
144 on_value_changed_for_scope: impl Fn(&T) + 'static,
145 block: impl FnOnce() -> R,
146 ) -> R
147 where
148 T: Any + Clone + PartialEq + 'static,
149 {
150 let frame_version = self.frame_version.get();
151 let has_frame_version = frame_version != 0;
152
153 let on_changed: Rc<dyn Fn(&dyn Any)> = {
154 let callback = Rc::new(on_value_changed_for_scope);
155 Rc::new(move |scope_any: &dyn Any| {
156 if let Some(typed) = scope_any.downcast_ref::<T>() {
157 callback(typed);
158 }
159 })
160 };
161
162 let entry = self.get_scope_entry(scope.clone(), on_changed.clone());
163
164 let pause_count = self.pause_count.clone();
165
166 let read_observer: ReadObserver = {
167 let mut entry_mut = entry.borrow_mut();
168 entry_mut.update(scope, on_changed);
169
170 let already_observed =
171 has_frame_version && entry_mut.last_seen_version == frame_version;
172 if already_observed || entry_mut.is_stateless {
173 drop(entry_mut);
174 return block();
175 }
176
177 entry_mut.observed.clear();
178 entry_mut.last_seen_version = if has_frame_version {
179 frame_version
180 } else {
181 u64::MAX
182 };
183 entry_mut.is_stateless = false;
184
185 if let Some(observer) = entry_mut.read_observer.clone() {
186 observer
187 } else {
188 let entry_for_observer = entry.clone();
189 let pause_count = pause_count.clone();
190
191 let observer: ReadObserver = Arc::new(move |state| {
192 if pause_count.get() > 0 {
193 return;
194 }
195 let mut entry_ref = entry_for_observer.borrow_mut();
196 let id = state.object_id().as_usize();
197 entry_ref.observed.insert(id);
198 entry_ref.is_stateless = false;
199 });
200
201 entry_mut.read_observer = Some(observer.clone());
202 observer
203 }
204 };
205
206 let result = self.run_with_read_observer(read_observer, block);
207
208 {
209 let mut entry_mut = entry.borrow_mut();
210 if entry_mut.observed.is_empty() {
211 entry_mut.is_stateless = true;
212 }
213 }
214
215 result
216 }
217
218 fn with_no_observations<R>(&self, block: impl FnOnce() -> R) -> R {
219 self.pause_count.set(self.pause_count.get() + 1);
220 let result = block();
221 self.pause_count
222 .set(self.pause_count.get().saturating_sub(1));
223 result
224 }
225
226 fn clear<T>(&self, scope: &T)
227 where
228 T: Any + PartialEq + 'static,
229 {
230 if let Some(rc_scope) = (scope as &dyn Any).downcast_ref::<RecomposeScope>() {
231 self.fast_scopes.borrow_mut().remove(&rc_scope.id());
232 }
233
234 self.scopes
235 .borrow_mut()
236 .retain(|entry| !entry.borrow().matches_scope(scope));
237 }
238
239 fn clear_if(&self, predicate: impl Fn(&dyn Any) -> bool) {
240 self.fast_scopes
241 .borrow_mut()
242 .retain(|_, entry| !entry.borrow().matches_predicate(&predicate));
243 self.scopes
244 .borrow_mut()
245 .retain(|entry| !entry.borrow().matches_predicate(&predicate));
246 }
247
248 fn clear_all(&self) {
249 self.fast_scopes.borrow_mut().clear();
250 self.scopes.borrow_mut().clear();
251 }
252
253 #[allow(clippy::arc_with_non_send_sync)]
256 fn start(&self, weak_self: Weak<SnapshotStateObserverInner>) {
257 if self.apply_handle.borrow().is_some() {
258 return;
259 }
260
261 let handle = register_apply_observer(Arc::new(move |modified, _snapshot_id| {
262 if let Some(inner) = weak_self.upgrade() {
263 inner.handle_apply(modified);
264 }
265 }));
266 self.apply_handle.replace(Some(handle));
267 }
268
269 fn stop(&self) {
270 if let Some(handle) = self.apply_handle.borrow_mut().take() {
271 drop(handle);
272 }
273 }
274
275 fn get_scope_entry(
276 &self,
277 scope: impl Any + Clone + PartialEq + 'static,
278 on_changed: Rc<dyn Fn(&dyn Any)>,
279 ) -> Rc<RefCell<ScopeEntry>> {
280 let recompose_scope_id = (&scope as &dyn Any)
281 .downcast_ref::<RecomposeScope>()
282 .map(RecomposeScope::id);
283 if let Some(scope_id) = recompose_scope_id {
284 let mut fast = self.fast_scopes.borrow_mut();
285 if let Some(existing) = fast.get(&scope_id) {
286 return existing.clone();
287 }
288
289 let entry = Rc::new(RefCell::new(ScopeEntry::new(scope, on_changed)));
290 fast.insert(scope_id, entry.clone());
291 self.scopes.borrow_mut().push(entry.clone());
292 return entry;
293 }
294
295 let mut scopes = self.scopes.borrow_mut();
297
298 if let Some(existing) = scopes
299 .iter()
300 .find(|entry| entry.borrow().matches_scope(&scope))
301 {
302 return existing.clone();
303 }
304
305 let entry = Rc::new(RefCell::new(ScopeEntry::new(scope, on_changed)));
306 scopes.push(entry.clone());
307 entry
308 }
309
310 fn prune_dead_scopes(&self) {
311 self.fast_scopes
312 .borrow_mut()
313 .retain(|_, entry| entry.borrow().should_retain());
314 self.scopes
315 .borrow_mut()
316 .retain(|entry| entry.borrow().should_retain());
317 }
318
319 fn run_with_read_observer<R>(
320 &self,
321 read_observer: ReadObserver,
322 block: impl FnOnce() -> R,
323 ) -> R {
324 use crate::snapshot_v2::take_transparent_observer_mutable_snapshot;
327
328 let snapshot = take_transparent_observer_mutable_snapshot(Some(read_observer), None);
331 let result = snapshot.enter(block);
332 snapshot.dispose();
333 result
334 }
335
336 fn handle_apply(&self, modified: &[Arc<dyn StateObject>]) {
337 if modified.is_empty() {
338 return;
339 }
340
341 let mut modified_ids: SmallVec<[usize; MAX_OBSERVED_STATES]> = SmallVec::new();
342 for state in modified {
343 modified_ids.push(state.object_id().as_usize());
344 }
345
346 let scopes = self.scopes.borrow();
347 let mut to_notify: Vec<Rc<RefCell<ScopeEntry>>> = Vec::new();
348 let mut seen: HashSet<usize> = HashSet::default();
349
350 for entry in scopes.iter() {
351 let entry_ref = entry.borrow();
352 if entry_ref
353 .observed
354 .iter()
355 .any(|id| modified_ids.contains(id))
356 {
357 let ptr = Rc::as_ptr(entry) as usize;
358 if seen.insert(ptr) {
359 to_notify.push(entry.clone());
360 }
361 }
362 }
363 drop(scopes);
364
365 if to_notify.is_empty() {
366 return;
367 }
368
369 for entry in to_notify {
370 let executor = self.executor.clone();
371 executor(Box::new(move || {
372 if let Ok(entry) = entry.try_borrow() {
373 entry.notify();
374 }
375 }));
376 }
377 }
378}
379use smallvec::SmallVec;
380
/// Set of state-object ids read by one scope.
///
/// Starts as an inline `SmallVec` and is promoted to a `HashSet` once more
/// than `MAX_OBSERVED_STATES` distinct ids are inserted.
enum ObservedIds {
    /// Up to `MAX_OBSERVED_STATES` ids, searched linearly.
    Small(SmallVec<[StateObjectId; MAX_OBSERVED_STATES]>),
    /// Hash-based storage used after promotion.
    Large(HashSet<StateObjectId>),
}
385
386impl ObservedIds {
387 fn new() -> Self {
388 ObservedIds::Small(SmallVec::new())
389 }
390
391 fn insert(&mut self, id: StateObjectId) {
392 match self {
393 ObservedIds::Small(small) => {
394 if small.contains(&id) {
395 return;
396 }
397 if small.len() < MAX_OBSERVED_STATES {
398 small.push(id);
399 } else {
400 let mut large =
401 HashSet::with_capacity_and_hasher(small.len() + 1, Default::default());
402 for existing in small.iter() {
403 large.insert(*existing);
404 }
405 large.insert(id);
406 *self = ObservedIds::Large(large);
407 }
408 }
409 ObservedIds::Large(large) => {
410 large.insert(id);
411 }
412 }
413 }
414
415 fn is_empty(&self) -> bool {
416 match self {
417 ObservedIds::Small(small) => small.is_empty(),
418 ObservedIds::Large(large) => large.is_empty(),
419 }
420 }
421
422 fn clear(&mut self) {
423 match self {
424 ObservedIds::Small(small) => small.clear(),
425 ObservedIds::Large(large) => large.clear(),
426 }
427 }
428
429 fn iter(&self) -> Box<dyn Iterator<Item = &StateObjectId> + '_> {
430 match self {
431 ObservedIds::Small(small) => Box::new(small.iter()),
432 ObservedIds::Large(large) => Box::new(large.iter()),
433 }
434 }
435}
436
/// Inline capacity of the `SmallVec` buffers that hold state ids: the `Small`
/// variant of `ObservedIds` and the modified-id buffer in `handle_apply`.
/// `ObservedIds` switches to a hash set once it would exceed this.
const MAX_OBSERVED_STATES: usize = 8;
438
/// How a `ScopeEntry` holds on to its scope value.
enum ScopeStorage {
    /// Any scope type other than `RecomposeScope`, boxed and owned by the entry.
    Owned(Box<dyn Any>),
    /// A `RecomposeScope`, stored only as its id plus a weak reference so the
    /// observer does not keep the scope alive. Once the scope is dropped, the
    /// entry is no longer retained by `should_retain`.
    RecomposeScope {
        id: ScopeId,
        weak: Weak<RecomposeScopeInner>,
    },
}
446
/// Per-scope bookkeeping: the scope itself, its change callback, and the
/// state ids recorded during its most recent recorded observation.
struct ScopeEntry {
    /// The observed scope (owned, or weakly referenced for `RecomposeScope`).
    scope: ScopeStorage,
    /// Type-erased change callback; `notify` invokes it with the scope value.
    on_changed: Rc<dyn Fn(&dyn Any)>,
    /// Ids of the state objects read during the latest recorded observation.
    observed: ObservedIds,
    /// Read observer cached across observations of this scope; created on
    /// first use.
    read_observer: Option<ReadObserver>,
    /// `true` when the latest recorded observation read no state.
    is_stateless: bool,
    /// Frame version of the latest recorded observation; `u64::MAX` when the
    /// scope has not been observed during a frame.
    last_seen_version: u64,
}
455
456impl ScopeEntry {
457 fn new<T>(scope: T, on_changed: Rc<dyn Fn(&dyn Any)>) -> Self
458 where
459 T: Any + 'static,
460 {
461 Self {
462 scope: ScopeStorage::from_value(scope),
463 on_changed,
464 observed: ObservedIds::new(),
465 read_observer: None,
466 is_stateless: false,
467 last_seen_version: u64::MAX,
468 }
469 }
470
471 fn update<T>(&mut self, new_scope: T, on_changed: Rc<dyn Fn(&dyn Any)>)
472 where
473 T: Any + 'static,
474 {
475 self.scope = ScopeStorage::from_value(new_scope);
476 self.on_changed = on_changed;
477 }
478
479 fn matches_scope<T>(&self, scope: &T) -> bool
480 where
481 T: Any + PartialEq + 'static,
482 {
483 if let Some(scope) = (scope as &dyn Any).downcast_ref::<RecomposeScope>() {
484 return matches!(
485 &self.scope,
486 ScopeStorage::RecomposeScope { id, .. } if *id == scope.id()
487 );
488 }
489
490 match &self.scope {
491 ScopeStorage::Owned(stored) => stored
492 .downcast_ref::<T>()
493 .map(|stored| stored == scope)
494 .unwrap_or(false),
495 ScopeStorage::RecomposeScope { .. } => false,
496 }
497 }
498
499 fn matches_predicate(&self, predicate: &impl Fn(&dyn Any) -> bool) -> bool {
500 match &self.scope {
501 ScopeStorage::Owned(scope) => predicate(scope.as_ref()),
502 ScopeStorage::RecomposeScope { weak, .. } => weak
503 .upgrade()
504 .map(|inner| predicate(&RecomposeScope { inner }))
505 .unwrap_or(true),
506 }
507 }
508
509 fn should_retain(&self) -> bool {
510 match &self.scope {
511 ScopeStorage::Owned(_) => true,
512 ScopeStorage::RecomposeScope { weak, .. } => weak.upgrade().is_some(),
513 }
514 }
515
516 fn notify(&self) {
517 match &self.scope {
518 ScopeStorage::Owned(scope) => (self.on_changed)(scope.as_ref()),
519 ScopeStorage::RecomposeScope { weak, .. } => {
520 if let Some(inner) = weak.upgrade() {
521 (self.on_changed)(&RecomposeScope { inner });
522 }
523 }
524 }
525 }
526}
527
528impl ScopeStorage {
529 fn from_value<T>(value: T) -> Self
530 where
531 T: Any + 'static,
532 {
533 let any = &value as &dyn Any;
534 if let Some(scope) = any.downcast_ref::<RecomposeScope>() {
535 Self::RecomposeScope {
536 id: scope.id(),
537 weak: scope.downgrade(),
538 }
539 } else {
540 Self::Owned(Box::new(value))
541 }
542 }
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snapshot_v2::take_mutable_snapshot;
    use crate::snapshot_v2::{reset_runtime_for_tests, TestRuntimeGuard};
    use crate::state::{NeverEqual, SnapshotMutableState};
    use std::cell::Cell;

    /// Plain (non-`RecomposeScope`) scope key used by the tests.
    #[derive(Clone, PartialEq)]
    struct TestScope(&'static str);

    /// Resets the snapshot runtime; keep the returned guard alive for the
    /// whole test.
    fn reset_runtime() -> TestRuntimeGuard {
        reset_runtime_for_tests()
    }

    /// Builds an observer whose executor runs each notification immediately.
    fn inline_observer() -> SnapshotStateObserver {
        SnapshotStateObserver::new(|callback| callback())
    }

    /// Returns a notification counter and a scope callback that bumps it.
    fn counting_callback() -> (Rc<Cell<i32>>, impl Fn(&TestScope) + 'static) {
        let hits = Rc::new(Cell::new(0));
        let sink = Rc::clone(&hits);
        let on_change = move |_: &TestScope| sink.set(sink.get() + 1);
        (hits, on_change)
    }

    /// Runs `write` inside a fresh mutable snapshot and applies it.
    fn apply_in_snapshot(write: impl FnOnce()) {
        let snapshot = take_mutable_snapshot(None, None);
        snapshot.enter(write);
        snapshot.apply().check();
    }

    #[test]
    fn notifies_scope_when_state_changes() {
        let _guard = reset_runtime();

        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (triggered, on_change) = counting_callback();

        let observer = inline_observer();
        observer.start();

        observer.observe_reads(TestScope("scope"), on_change, || {
            let _ = state.get();
        });

        apply_in_snapshot(|| {
            state.set(1);
        });

        assert_eq!(triggered.get(), 1);
        observer.stop();
    }

    #[test]
    fn clear_removes_scope_observation() {
        let _guard = reset_runtime();

        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (triggered, on_change) = counting_callback();

        let observer = inline_observer();
        observer.start();

        let scope = TestScope("scope");
        observer.observe_reads(scope.clone(), on_change, || {
            let _ = state.get();
        });
        observer.clear(&scope);

        apply_in_snapshot(|| {
            state.set(1);
        });

        assert_eq!(triggered.get(), 0);
        observer.stop();
    }

    #[test]
    fn with_no_observations_skips_reads() {
        let _guard = reset_runtime();

        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (triggered, on_change) = counting_callback();

        let observer = inline_observer();
        observer.start();

        observer.observe_reads(TestScope("scope"), on_change, || {
            observer.with_no_observations(|| {
                let _ = state.get();
            });
        });

        apply_in_snapshot(|| {
            state.set(1);
        });

        assert_eq!(triggered.get(), 0);
        observer.stop();
    }

    #[test]
    fn begin_frame_prunes_dropped_recompose_scope_entries() {
        let _guard = reset_runtime();

        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let observer = inline_observer();
        let runtime = crate::TestRuntime::new();
        let scope = RecomposeScope::new_for_test(runtime.handle());

        observer.observe_reads(
            scope.clone(),
            |_| {},
            || {
                let _ = state.get();
            },
        );

        // (entries in `scopes`, entries in `fast_scopes`)
        let tracked = || {
            (
                observer.inner.scopes.borrow().len(),
                observer.inner.fast_scopes.borrow().len(),
            )
        };
        assert_eq!(tracked(), (1, 1));

        drop(scope);
        observer.begin_frame();

        assert_eq!(tracked(), (0, 0));
    }
}