1#![allow(clippy::arc_with_non_send_sync)]
4#![allow(clippy::type_complexity)]
6
7use crate::collections::map::HashSet;
8use crate::snapshot_v2::{register_apply_observer, ReadObserver, StateObjectId};
9use crate::state::StateObject;
10use std::any::Any;
11use std::cell::{Cell, RefCell};
12use std::rc::{Rc, Weak};
13use std::sync::Arc;
14
/// Type-erased executor for change notifications.
///
/// It is handed each notification as a boxed, run-once closure; when (or
/// whether) that closure runs is entirely up to the executor.
type Executor = dyn Fn(Box<dyn FnOnce() + 'static>) + 'static;
17
18#[derive(Clone)]
31pub struct SnapshotStateObserver {
32 inner: Rc<SnapshotStateObserverInner>,
33}
34
35impl SnapshotStateObserver {
36 pub fn new(on_changed_executor: impl Fn(Box<dyn FnOnce() + 'static>) + 'static) -> Self {
38 let inner = Rc::new(SnapshotStateObserverInner::new(on_changed_executor));
39 inner.set_self(Rc::downgrade(&inner));
40 Self { inner }
41 }
42
43 pub fn observe_reads<T, R>(
49 &self,
50 scope: T,
51 on_value_changed_for_scope: impl Fn(&T) + 'static,
52 block: impl FnOnce() -> R,
53 ) -> R
54 where
55 T: Any + Clone + PartialEq + 'static,
56 {
57 self.inner
58 .observe_reads(scope, on_value_changed_for_scope, block)
59 }
60
61 pub fn begin_frame(&self) {
63 self.inner.begin_frame();
64 }
65
66 pub fn with_no_observations<R>(&self, block: impl FnOnce() -> R) -> R {
68 self.inner.with_no_observations(block)
69 }
70
71 pub fn clear<T>(&self, scope: &T)
73 where
74 T: Any + PartialEq + 'static,
75 {
76 self.inner.clear(scope);
77 }
78
79 pub fn clear_if(&self, predicate: impl Fn(&dyn Any) -> bool) {
81 self.inner.clear_if(predicate);
82 }
83
84 pub fn clear_all(&self) {
86 self.inner.clear_all();
87 }
88
89 pub fn start(&self) {
91 let weak = Rc::downgrade(&self.inner);
92 self.inner.start(weak);
93 }
94
95 pub fn stop(&self) {
97 self.inner.stop();
98 }
99
100 #[cfg(test)]
102 pub fn notify_changes(&self, modified: &[Arc<dyn StateObject>]) {
103 self.inner.handle_apply(modified);
104 }
105}
106
107struct SnapshotStateObserverInner {
108 executor: Rc<Executor>,
109 scopes: RefCell<Vec<Rc<RefCell<ScopeEntry>>>>,
110 fast_scopes: RefCell<Vec<Option<Rc<RefCell<ScopeEntry>>>>>,
111 pause_count: Rc<Cell<usize>>,
112 apply_handle: RefCell<Option<crate::snapshot_v2::ObserverHandle>>,
113 weak_self: RefCell<Weak<SnapshotStateObserverInner>>,
114 frame_version: Cell<u64>,
115}
116
117impl SnapshotStateObserverInner {
118 fn new(on_changed_executor: impl Fn(Box<dyn FnOnce() + 'static>) + 'static) -> Self {
119 Self {
120 executor: Rc::new(on_changed_executor),
121 scopes: RefCell::new(Vec::new()),
122 fast_scopes: RefCell::new(Vec::new()),
123 pause_count: Rc::new(Cell::new(0)),
124 apply_handle: RefCell::new(None),
125 weak_self: RefCell::new(Weak::new()),
126 frame_version: Cell::new(0),
127 }
128 }
129
130 fn set_self(&self, weak: Weak<SnapshotStateObserverInner>) {
131 self.weak_self.replace(weak);
132 }
133
134 fn begin_frame(&self) {
135 let next = self.frame_version.get().wrapping_add(1);
136 self.frame_version.set(next);
137 }
138
139 fn observe_reads<T, R>(
140 &self,
141 scope: T,
142 on_value_changed_for_scope: impl Fn(&T) + 'static,
143 block: impl FnOnce() -> R,
144 ) -> R
145 where
146 T: Any + Clone + PartialEq + 'static,
147 {
148 let frame_version = self.frame_version.get();
149 let has_frame_version = frame_version != 0;
150
151 let on_changed: Rc<dyn Fn(&dyn Any)> = {
152 let callback = Rc::new(on_value_changed_for_scope);
153 Rc::new(move |scope_any: &dyn Any| {
154 if let Some(typed) = scope_any.downcast_ref::<T>() {
155 callback(typed);
156 }
157 })
158 };
159
160 let entry = self.get_scope_entry(scope.clone(), on_changed.clone());
161
162 let pause_count = self.pause_count.clone();
163
164 let read_observer: ReadObserver = {
165 let mut entry_mut = entry.borrow_mut();
166 entry_mut.update(scope, on_changed);
167
168 let already_observed =
169 has_frame_version && entry_mut.last_seen_version == frame_version;
170 if already_observed || entry_mut.is_stateless {
171 drop(entry_mut);
172 return block();
173 }
174
175 entry_mut.observed.clear();
176 entry_mut.last_seen_version = if has_frame_version {
177 frame_version
178 } else {
179 u64::MAX
180 };
181 entry_mut.is_stateless = false;
182
183 if let Some(observer) = entry_mut.read_observer.clone() {
184 observer
185 } else {
186 let entry_for_observer = entry.clone();
187 let pause_count = pause_count.clone();
188
189 let observer: ReadObserver = Arc::new(move |state| {
190 if pause_count.get() > 0 {
191 return;
192 }
193 let mut entry_ref = entry_for_observer.borrow_mut();
194 let id = state.object_id().as_usize();
195 entry_ref.observed.insert(id);
196 entry_ref.is_stateless = false;
197 });
198
199 entry_mut.read_observer = Some(observer.clone());
200 observer
201 }
202 };
203
204 let result = self.run_with_read_observer(read_observer, block);
205
206 {
207 let mut entry_mut = entry.borrow_mut();
208 if entry_mut.observed.is_empty() {
209 entry_mut.is_stateless = true;
210 }
211 }
212
213 result
214 }
215
216 fn with_no_observations<R>(&self, block: impl FnOnce() -> R) -> R {
217 self.pause_count.set(self.pause_count.get() + 1);
218 let result = block();
219 self.pause_count
220 .set(self.pause_count.get().saturating_sub(1));
221 result
222 }
223
224 fn clear<T>(&self, scope: &T)
225 where
226 T: Any + PartialEq + 'static,
227 {
228 if let Some(rc_scope) = (scope as &dyn Any).downcast_ref::<RecomposeScope>() {
230 let id = rc_scope.id();
231 let mut fast = self.fast_scopes.borrow_mut();
232 if id < fast.len() {
233 fast[id] = None;
234 }
235 }
236
237 self.scopes
239 .borrow_mut()
240 .retain(|entry| !entry.borrow().matches_scope(scope));
241 }
242
243 fn clear_if(&self, predicate: impl Fn(&dyn Any) -> bool) {
244 let mut fast = self.fast_scopes.borrow_mut();
246 for slot in fast.iter_mut() {
247 if let Some(entry) = slot {
248 let should_clear = {
249 let entry_ref = entry.borrow();
250 predicate(entry_ref.scope())
251 };
252 if should_clear {
253 *slot = None;
254 }
255 }
256 }
257 drop(fast);
258
259 self.scopes.borrow_mut().retain(|entry| {
261 let entry_ref = entry.borrow();
262 !predicate(entry_ref.scope())
263 });
264 }
265
266 fn clear_all(&self) {
267 self.fast_scopes.borrow_mut().clear();
268 self.scopes.borrow_mut().clear();
269 }
270
271 #[allow(clippy::arc_with_non_send_sync)]
274 fn start(&self, weak_self: Weak<SnapshotStateObserverInner>) {
275 if self.apply_handle.borrow().is_some() {
276 return;
277 }
278
279 let handle = register_apply_observer(Arc::new(move |modified, _snapshot_id| {
280 if let Some(inner) = weak_self.upgrade() {
281 inner.handle_apply(modified);
282 }
283 }));
284 self.apply_handle.replace(Some(handle));
285 }
286
287 fn stop(&self) {
288 if let Some(handle) = self.apply_handle.borrow_mut().take() {
289 drop(handle);
290 }
291 }
292
293 fn get_scope_entry(
294 &self,
295 scope: impl Any + Clone + PartialEq + 'static,
296 on_changed: Rc<dyn Fn(&dyn Any)>,
297 ) -> Rc<RefCell<ScopeEntry>> {
298 if let Some(rc_scope) = (&scope as &dyn Any).downcast_ref::<RecomposeScope>() {
300 let id: usize = rc_scope.id(); let mut fast = self.fast_scopes.borrow_mut();
303
304 if id >= fast.len() {
305 fast.resize_with(id + 1, || None);
306 }
307
308 if let Some(existing) = &fast[id] {
309 return existing.clone();
310 }
311
312 let entry = Rc::new(RefCell::new(ScopeEntry::new(scope, on_changed)));
313 fast[id] = Some(entry.clone());
314 drop(fast);
316 self.scopes.borrow_mut().push(entry.clone());
317 return entry;
318 }
319
320 let mut scopes = self.scopes.borrow_mut();
322
323 if let Some(existing) = scopes
324 .iter()
325 .find(|entry| entry.borrow().matches_scope(&scope))
326 {
327 return existing.clone();
328 }
329
330 let entry = Rc::new(RefCell::new(ScopeEntry::new(scope, on_changed)));
331 scopes.push(entry.clone());
332 entry
333 }
334
335 fn run_with_read_observer<R>(
336 &self,
337 read_observer: ReadObserver,
338 block: impl FnOnce() -> R,
339 ) -> R {
340 use crate::snapshot_v2::take_transparent_observer_mutable_snapshot;
343
344 let snapshot = take_transparent_observer_mutable_snapshot(Some(read_observer), None);
347 let result = snapshot.enter(block);
348 snapshot.dispose();
349 result
350 }
351
352 fn handle_apply(&self, modified: &[Arc<dyn StateObject>]) {
353 if modified.is_empty() {
354 return;
355 }
356
357 let mut modified_ids: SmallVec<[usize; MAX_OBSERVED_STATES]> = SmallVec::new();
358 for state in modified {
359 modified_ids.push(state.object_id().as_usize());
360 }
361
362 let scopes = self.scopes.borrow();
363 let mut to_notify: Vec<Rc<RefCell<ScopeEntry>>> = Vec::new();
364 let mut seen: HashSet<usize> = HashSet::default();
365
366 for entry in scopes.iter() {
367 let entry_ref = entry.borrow();
368 if entry_ref
369 .observed
370 .iter()
371 .any(|id| modified_ids.contains(id))
372 {
373 let ptr = Rc::as_ptr(entry) as usize;
374 if seen.insert(ptr) {
375 to_notify.push(entry.clone());
376 }
377 }
378 }
379 drop(scopes);
380
381 {
382 let fast_scopes = self.fast_scopes.borrow();
383 for entry in fast_scopes.iter().flatten() {
384 let entry_ref = entry.borrow();
385 if entry_ref
386 .observed
387 .iter()
388 .any(|id| modified_ids.contains(id))
389 {
390 let ptr = Rc::as_ptr(entry) as usize;
391 if seen.insert(ptr) {
392 to_notify.push(entry.clone());
393 }
394 }
395 }
396 }
397
398 if to_notify.is_empty() {
399 return;
400 }
401
402 for entry in to_notify {
403 let executor = self.executor.clone();
404 executor(Box::new(move || {
405 if let Ok(entry) = entry.try_borrow() {
406 entry.notify();
407 }
408 }));
409 }
410 }
411}
412
413use cranpose_core::RecomposeScope;
414use smallvec::SmallVec;
415
416enum ObservedIds {
417 Small(SmallVec<[StateObjectId; MAX_OBSERVED_STATES]>),
418 Large(HashSet<StateObjectId>),
419}
420
421impl ObservedIds {
422 fn new() -> Self {
423 ObservedIds::Small(SmallVec::new())
424 }
425
426 fn insert(&mut self, id: StateObjectId) {
427 match self {
428 ObservedIds::Small(small) => {
429 if small.contains(&id) {
430 return;
431 }
432 if small.len() < MAX_OBSERVED_STATES {
433 small.push(id);
434 } else {
435 let mut large =
436 HashSet::with_capacity_and_hasher(small.len() + 1, Default::default());
437 for existing in small.iter() {
438 large.insert(*existing);
439 }
440 large.insert(id);
441 *self = ObservedIds::Large(large);
442 }
443 }
444 ObservedIds::Large(large) => {
445 large.insert(id);
446 }
447 }
448 }
449
450 fn is_empty(&self) -> bool {
451 match self {
452 ObservedIds::Small(small) => small.is_empty(),
453 ObservedIds::Large(large) => large.is_empty(),
454 }
455 }
456
457 fn clear(&mut self) {
458 match self {
459 ObservedIds::Small(small) => small.clear(),
460 ObservedIds::Large(large) => large.clear(),
461 }
462 }
463
464 fn iter(&self) -> Box<dyn Iterator<Item = &StateObjectId> + '_> {
465 match self {
466 ObservedIds::Small(small) => Box::new(small.iter()),
467 ObservedIds::Large(large) => Box::new(large.iter()),
468 }
469 }
470}
471
472const MAX_OBSERVED_STATES: usize = 8;
473struct ScopeEntry {
474 scope: Box<dyn Any>,
475 on_changed: Rc<dyn Fn(&dyn Any)>,
476 observed: ObservedIds,
477 read_observer: Option<ReadObserver>,
478 is_stateless: bool,
479 last_seen_version: u64,
480}
481
482impl ScopeEntry {
483 fn new<T>(scope: T, on_changed: Rc<dyn Fn(&dyn Any)>) -> Self
484 where
485 T: Any + 'static,
486 {
487 Self {
488 scope: Box::new(scope),
489 on_changed,
490 observed: ObservedIds::new(),
491 read_observer: None,
492 is_stateless: false,
493 last_seen_version: u64::MAX,
494 }
495 }
496
497 fn update<T>(&mut self, new_scope: T, on_changed: Rc<dyn Fn(&dyn Any)>)
498 where
499 T: Any + 'static,
500 {
501 self.scope = Box::new(new_scope);
502 self.on_changed = on_changed;
503 }
504
505 fn matches_scope<T>(&self, scope: &T) -> bool
506 where
507 T: Any + PartialEq + 'static,
508 {
509 self.scope
510 .downcast_ref::<T>()
511 .map(|stored| stored == scope)
512 .unwrap_or(false)
513 }
514
515 fn scope(&self) -> &dyn Any {
516 &*self.scope
517 }
518
519 fn notify(&self) {
520 (self.on_changed)(self.scope());
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snapshot_v2::take_mutable_snapshot;
    use crate::snapshot_v2::{reset_runtime_for_tests, TestRuntimeGuard};
    use crate::state::{NeverEqual, SnapshotMutableState};
    use std::cell::Cell;

    /// Resets the snapshot runtime and returns the guard from the reset.
    fn reset_runtime() -> TestRuntimeGuard {
        reset_runtime_for_tests()
    }

    #[derive(Clone, PartialEq)]
    struct TestScope(&'static str);

    /// Returns a shared hit counter together with a scope callback that bumps it.
    fn counting_callback() -> (Rc<Cell<i32>>, impl Fn(&TestScope) + 'static) {
        let hits = Rc::new(Cell::new(0));
        let sink = Rc::clone(&hits);
        (hits, move |_: &TestScope| sink.set(sink.get() + 1))
    }

    /// A write applied after an observed read triggers the scope callback once.
    #[test]
    fn notifies_scope_when_state_changes() {
        let _guard = reset_runtime();

        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (hits, on_change) = counting_callback();

        let observer = SnapshotStateObserver::new(|task| task());
        observer.start();

        observer.observe_reads(TestScope("scope"), on_change, || {
            let _ = state.get();
        });

        let writer = take_mutable_snapshot(None, None);
        writer.enter(|| {
            state.set(1);
        });
        writer.apply().check();

        assert_eq!(hits.get(), 1);
        observer.stop();
    }

    /// Clearing a scope before the write suppresses its callback.
    #[test]
    fn clear_removes_scope_observation() {
        let _guard = reset_runtime();

        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (hits, on_change) = counting_callback();

        let observer = SnapshotStateObserver::new(|task| task());
        observer.start();

        let scope = TestScope("scope");
        observer.observe_reads(scope.clone(), on_change, || {
            let _ = state.get();
        });

        observer.clear(&scope);

        let writer = take_mutable_snapshot(None, None);
        writer.enter(|| {
            state.set(1);
        });
        writer.apply().check();

        assert_eq!(hits.get(), 0);
        observer.stop();
    }

    /// Reads made inside `with_no_observations` are not recorded.
    #[test]
    fn with_no_observations_skips_reads() {
        let _guard = reset_runtime();

        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (hits, on_change) = counting_callback();

        let observer = SnapshotStateObserver::new(|task| task());
        observer.start();

        observer.observe_reads(TestScope("scope"), on_change, || {
            observer.with_no_observations(|| {
                let _ = state.get();
            });
        });

        let writer = take_mutable_snapshot(None, None);
        writer.enter(|| {
            state.set(1);
        });
        writer.apply().check();

        assert_eq!(hits.get(), 0);
        observer.stop();
    }
}