Skip to main content

cranpose_core/
snapshot_state_observer.rs

1// Observer callbacks use Arc for shared ownership but may capture non-Send types.
2// This is safe because callbacks are always invoked on the UI thread where they were created.
3#![allow(clippy::arc_with_non_send_sync)]
4// Complex types are inherent to the observer pattern with nested callbacks and state tracking
5#![allow(clippy::type_complexity)]
6
7use crate::collections::map::{HashMap, HashSet};
8use crate::snapshot_v2::{register_apply_observer, ReadObserver, StateObjectId};
9use crate::state::StateObject;
10use crate::{RecomposeScope, RecomposeScopeInner, ScopeId};
11use std::any::Any;
12use std::cell::{Cell, RefCell};
13use std::rc::{Rc, Weak};
14use std::sync::Arc;
15
/// Dispatcher for change notifications.
///
/// The observer hands every pending notification to the executor as a boxed
/// one-shot closure; the executor is responsible for running it. Both trait
/// objects use the default `'static` object lifetime.
type Executor = dyn Fn(Box<dyn FnOnce()>);
18
19/// Observer that records state object reads performed inside a given scope and
20/// notifies the caller when any of the observed objects change.
21///
22/// This is a pragmatic Rust translation of Jetpack Compose's
23/// `SnapshotStateObserver`. The implementation focuses on the core behaviour
24/// needed by the Cranpose runtime:
25/// - Tracking state object reads per logical scope.
26/// - Reacting to snapshot apply notifications.
27/// - Scheduling invalidation callbacks via the supplied executor.
28///
29/// Advanced features from the Kotlin version (derived state tracking, change
30/// coalescing, queue minimisation) are deferred
31#[derive(Clone)]
32pub struct SnapshotStateObserver {
33    inner: Rc<SnapshotStateObserverInner>,
34}
35
36impl SnapshotStateObserver {
37    /// Create a new observer that schedules callbacks using `on_changed_executor`.
38    pub fn new(on_changed_executor: impl Fn(Box<dyn FnOnce() + 'static>) + 'static) -> Self {
39        let inner = Rc::new(SnapshotStateObserverInner::new(on_changed_executor));
40        inner.set_self(Rc::downgrade(&inner));
41        Self { inner }
42    }
43
44    /// Observe state object reads performed while executing `block`.
45    ///
46    /// Subsequent calls to `observe_reads` replace any previously recorded
47    /// observations for the provided `scope`. When one of the observed objects
48    /// mutates, `on_value_changed_for_scope` will be invoked on the executor.
49    pub fn observe_reads<T, R>(
50        &self,
51        scope: T,
52        on_value_changed_for_scope: impl Fn(&T) + 'static,
53        block: impl FnOnce() -> R,
54    ) -> R
55    where
56        T: Any + Clone + PartialEq + 'static,
57    {
58        self.inner
59            .observe_reads(scope, on_value_changed_for_scope, block)
60    }
61
62    /// Notify the observer that a new composition frame is starting.
63    pub fn begin_frame(&self) {
64        self.inner.begin_frame();
65    }
66
67    /// Temporarily pause read observation while executing `block`.
68    pub fn with_no_observations<R>(&self, block: impl FnOnce() -> R) -> R {
69        self.inner.with_no_observations(block)
70    }
71
72    /// Remove any recorded reads for `scope`.
73    pub fn clear<T>(&self, scope: &T)
74    where
75        T: Any + PartialEq + 'static,
76    {
77        self.inner.clear(scope);
78    }
79
80    /// Remove recorded reads for scopes that satisfy `predicate`.
81    pub fn clear_if(&self, predicate: impl Fn(&dyn Any) -> bool) {
82        self.inner.clear_if(predicate);
83    }
84
85    /// Remove all recorded observations.
86    pub fn clear_all(&self) {
87        self.inner.clear_all();
88    }
89
90    /// Begin listening for snapshot apply notifications.
91    pub fn start(&self) {
92        let weak = Rc::downgrade(&self.inner);
93        self.inner.start(weak);
94    }
95
96    /// Stop listening for snapshot apply notifications.
97    pub fn stop(&self) {
98        self.inner.stop();
99    }
100
101    /// Test-only helper to simulate snapshot changes.
102    #[cfg(test)]
103    pub fn notify_changes(&self, modified: &[Arc<dyn StateObject>]) {
104        self.inner.handle_apply(modified);
105    }
106}
107
108struct SnapshotStateObserverInner {
109    executor: Rc<Executor>,
110    scopes: RefCell<Vec<Rc<RefCell<ScopeEntry>>>>,
111    fast_scopes: RefCell<HashMap<ScopeId, Rc<RefCell<ScopeEntry>>>>,
112    pause_count: Rc<Cell<usize>>,
113    apply_handle: RefCell<Option<crate::snapshot_v2::ObserverHandle>>,
114    weak_self: RefCell<Weak<SnapshotStateObserverInner>>,
115    frame_version: Cell<u64>,
116}
117
118impl SnapshotStateObserverInner {
119    fn new(on_changed_executor: impl Fn(Box<dyn FnOnce() + 'static>) + 'static) -> Self {
120        Self {
121            executor: Rc::new(on_changed_executor),
122            scopes: RefCell::new(Vec::new()),
123            fast_scopes: RefCell::new(HashMap::default()),
124            pause_count: Rc::new(Cell::new(0)),
125            apply_handle: RefCell::new(None),
126            weak_self: RefCell::new(Weak::new()),
127            frame_version: Cell::new(0),
128        }
129    }
130
131    fn set_self(&self, weak: Weak<SnapshotStateObserverInner>) {
132        self.weak_self.replace(weak);
133    }
134
135    fn begin_frame(&self) {
136        let next = self.frame_version.get().wrapping_add(1);
137        self.frame_version.set(next);
138        self.prune_dead_scopes();
139    }
140
141    fn observe_reads<T, R>(
142        &self,
143        scope: T,
144        on_value_changed_for_scope: impl Fn(&T) + 'static,
145        block: impl FnOnce() -> R,
146    ) -> R
147    where
148        T: Any + Clone + PartialEq + 'static,
149    {
150        let frame_version = self.frame_version.get();
151        let has_frame_version = frame_version != 0;
152
153        let on_changed: Rc<dyn Fn(&dyn Any)> = {
154            let callback = Rc::new(on_value_changed_for_scope);
155            Rc::new(move |scope_any: &dyn Any| {
156                if let Some(typed) = scope_any.downcast_ref::<T>() {
157                    callback(typed);
158                }
159            })
160        };
161
162        let entry = self.get_scope_entry(scope.clone(), on_changed.clone());
163
164        let pause_count = self.pause_count.clone();
165
166        let read_observer: ReadObserver = {
167            let mut entry_mut = entry.borrow_mut();
168            entry_mut.update(scope, on_changed);
169
170            let already_observed =
171                has_frame_version && entry_mut.last_seen_version == frame_version;
172            if already_observed || entry_mut.is_stateless {
173                drop(entry_mut);
174                return block();
175            }
176
177            entry_mut.observed.clear();
178            entry_mut.last_seen_version = if has_frame_version {
179                frame_version
180            } else {
181                u64::MAX
182            };
183            entry_mut.is_stateless = false;
184
185            if let Some(observer) = entry_mut.read_observer.clone() {
186                observer
187            } else {
188                let entry_for_observer = entry.clone();
189                let pause_count = pause_count.clone();
190
191                let observer: ReadObserver = Arc::new(move |state| {
192                    if pause_count.get() > 0 {
193                        return;
194                    }
195                    let mut entry_ref = entry_for_observer.borrow_mut();
196                    let id = state.object_id().as_usize();
197                    entry_ref.observed.insert(id);
198                    entry_ref.is_stateless = false;
199                });
200
201                entry_mut.read_observer = Some(observer.clone());
202                observer
203            }
204        };
205
206        let result = self.run_with_read_observer(read_observer, block);
207
208        {
209            let mut entry_mut = entry.borrow_mut();
210            if entry_mut.observed.is_empty() {
211                entry_mut.is_stateless = true;
212            }
213        }
214
215        result
216    }
217
218    fn with_no_observations<R>(&self, block: impl FnOnce() -> R) -> R {
219        self.pause_count.set(self.pause_count.get() + 1);
220        let result = block();
221        self.pause_count
222            .set(self.pause_count.get().saturating_sub(1));
223        result
224    }
225
226    fn clear<T>(&self, scope: &T)
227    where
228        T: Any + PartialEq + 'static,
229    {
230        if let Some(rc_scope) = (scope as &dyn Any).downcast_ref::<RecomposeScope>() {
231            self.fast_scopes.borrow_mut().remove(&rc_scope.id());
232        }
233
234        self.scopes
235            .borrow_mut()
236            .retain(|entry| !entry.borrow().matches_scope(scope));
237    }
238
239    fn clear_if(&self, predicate: impl Fn(&dyn Any) -> bool) {
240        self.fast_scopes
241            .borrow_mut()
242            .retain(|_, entry| !entry.borrow().matches_predicate(&predicate));
243        self.scopes
244            .borrow_mut()
245            .retain(|entry| !entry.borrow().matches_predicate(&predicate));
246    }
247
248    fn clear_all(&self) {
249        self.fast_scopes.borrow_mut().clear();
250        self.scopes.borrow_mut().clear();
251    }
252
253    // Arc-wrapped closure captures Weak which may not be Send/Sync. This is safe because
254    // the observer callback is only invoked on the UI thread where it was registered.
255    #[allow(clippy::arc_with_non_send_sync)]
256    fn start(&self, weak_self: Weak<SnapshotStateObserverInner>) {
257        if self.apply_handle.borrow().is_some() {
258            return;
259        }
260
261        let handle = register_apply_observer(Arc::new(move |modified, _snapshot_id| {
262            if let Some(inner) = weak_self.upgrade() {
263                inner.handle_apply(modified);
264            }
265        }));
266        self.apply_handle.replace(Some(handle));
267    }
268
269    fn stop(&self) {
270        if let Some(handle) = self.apply_handle.borrow_mut().take() {
271            drop(handle);
272        }
273    }
274
275    fn get_scope_entry(
276        &self,
277        scope: impl Any + Clone + PartialEq + 'static,
278        on_changed: Rc<dyn Fn(&dyn Any)>,
279    ) -> Rc<RefCell<ScopeEntry>> {
280        let recompose_scope_id = (&scope as &dyn Any)
281            .downcast_ref::<RecomposeScope>()
282            .map(RecomposeScope::id);
283        if let Some(scope_id) = recompose_scope_id {
284            let mut fast = self.fast_scopes.borrow_mut();
285            if let Some(existing) = fast.get(&scope_id) {
286                return existing.clone();
287            }
288
289            let entry = Rc::new(RefCell::new(ScopeEntry::new(scope, on_changed)));
290            fast.insert(scope_id, entry.clone());
291            self.scopes.borrow_mut().push(entry.clone());
292            return entry;
293        }
294
295        // ---------- SLOW / GENERIC PATH ----------
296        let mut scopes = self.scopes.borrow_mut();
297
298        if let Some(existing) = scopes
299            .iter()
300            .find(|entry| entry.borrow().matches_scope(&scope))
301        {
302            return existing.clone();
303        }
304
305        let entry = Rc::new(RefCell::new(ScopeEntry::new(scope, on_changed)));
306        scopes.push(entry.clone());
307        entry
308    }
309
310    fn prune_dead_scopes(&self) {
311        self.fast_scopes
312            .borrow_mut()
313            .retain(|_, entry| entry.borrow().should_retain());
314        self.scopes
315            .borrow_mut()
316            .retain(|entry| entry.borrow().should_retain());
317    }
318
319    fn run_with_read_observer<R>(
320        &self,
321        read_observer: ReadObserver,
322        block: impl FnOnce() -> R,
323    ) -> R {
324        // Kotlin uses Snapshot.observeInternal which creates a TransparentObserverMutableSnapshot,
325        // not a readonly snapshot. This allows writes to happen during observation (composition).
326        use crate::snapshot_v2::take_transparent_observer_mutable_snapshot;
327
328        // Create a transparent mutable snapshot (not readonly!) for observation
329        // This matches Kotlin's Snapshot.observeInternal behavior
330        let snapshot = take_transparent_observer_mutable_snapshot(Some(read_observer), None);
331        let result = snapshot.enter(block);
332        snapshot.dispose();
333        result
334    }
335
336    fn handle_apply(&self, modified: &[Arc<dyn StateObject>]) {
337        if modified.is_empty() {
338            return;
339        }
340
341        let mut modified_ids: SmallVec<[usize; MAX_OBSERVED_STATES]> = SmallVec::new();
342        for state in modified {
343            modified_ids.push(state.object_id().as_usize());
344        }
345
346        let scopes = self.scopes.borrow();
347        let mut to_notify: Vec<Rc<RefCell<ScopeEntry>>> = Vec::new();
348        let mut seen: HashSet<usize> = HashSet::default();
349
350        for entry in scopes.iter() {
351            let entry_ref = entry.borrow();
352            if entry_ref
353                .observed
354                .iter()
355                .any(|id| modified_ids.contains(id))
356            {
357                let ptr = Rc::as_ptr(entry) as usize;
358                if seen.insert(ptr) {
359                    to_notify.push(entry.clone());
360                }
361            }
362        }
363        drop(scopes);
364
365        if to_notify.is_empty() {
366            return;
367        }
368
369        for entry in to_notify {
370            let executor = self.executor.clone();
371            executor(Box::new(move || {
372                if let Ok(entry) = entry.try_borrow() {
373                    entry.notify();
374                }
375            }));
376        }
377    }
378}
379use smallvec::SmallVec;
380
381enum ObservedIds {
382    Small(SmallVec<[StateObjectId; MAX_OBSERVED_STATES]>),
383    Large(HashSet<StateObjectId>),
384}
385
386impl ObservedIds {
387    fn new() -> Self {
388        ObservedIds::Small(SmallVec::new())
389    }
390
391    fn insert(&mut self, id: StateObjectId) {
392        match self {
393            ObservedIds::Small(small) => {
394                if small.contains(&id) {
395                    return;
396                }
397                if small.len() < MAX_OBSERVED_STATES {
398                    small.push(id);
399                } else {
400                    let mut large =
401                        HashSet::with_capacity_and_hasher(small.len() + 1, Default::default());
402                    for existing in small.iter() {
403                        large.insert(*existing);
404                    }
405                    large.insert(id);
406                    *self = ObservedIds::Large(large);
407                }
408            }
409            ObservedIds::Large(large) => {
410                large.insert(id);
411            }
412        }
413    }
414
415    fn is_empty(&self) -> bool {
416        match self {
417            ObservedIds::Small(small) => small.is_empty(),
418            ObservedIds::Large(large) => large.is_empty(),
419        }
420    }
421
422    fn clear(&mut self) {
423        match self {
424            ObservedIds::Small(small) => small.clear(),
425            ObservedIds::Large(large) => large.clear(),
426        }
427    }
428
429    fn iter(&self) -> Box<dyn Iterator<Item = &StateObjectId> + '_> {
430        match self {
431            ObservedIds::Small(small) => Box::new(small.iter()),
432            ObservedIds::Large(large) => Box::new(large.iter()),
433        }
434    }
435}
436
437const MAX_OBSERVED_STATES: usize = 8;
438
439enum ScopeStorage {
440    Owned(Box<dyn Any>),
441    RecomposeScope {
442        id: ScopeId,
443        weak: Weak<RecomposeScopeInner>,
444    },
445}
446
447struct ScopeEntry {
448    scope: ScopeStorage,
449    on_changed: Rc<dyn Fn(&dyn Any)>,
450    observed: ObservedIds,
451    read_observer: Option<ReadObserver>,
452    is_stateless: bool,
453    last_seen_version: u64,
454}
455
456impl ScopeEntry {
457    fn new<T>(scope: T, on_changed: Rc<dyn Fn(&dyn Any)>) -> Self
458    where
459        T: Any + 'static,
460    {
461        Self {
462            scope: ScopeStorage::from_value(scope),
463            on_changed,
464            observed: ObservedIds::new(),
465            read_observer: None,
466            is_stateless: false,
467            last_seen_version: u64::MAX,
468        }
469    }
470
471    fn update<T>(&mut self, new_scope: T, on_changed: Rc<dyn Fn(&dyn Any)>)
472    where
473        T: Any + 'static,
474    {
475        self.scope = ScopeStorage::from_value(new_scope);
476        self.on_changed = on_changed;
477    }
478
479    fn matches_scope<T>(&self, scope: &T) -> bool
480    where
481        T: Any + PartialEq + 'static,
482    {
483        if let Some(scope) = (scope as &dyn Any).downcast_ref::<RecomposeScope>() {
484            return matches!(
485                &self.scope,
486                ScopeStorage::RecomposeScope { id, .. } if *id == scope.id()
487            );
488        }
489
490        match &self.scope {
491            ScopeStorage::Owned(stored) => stored
492                .downcast_ref::<T>()
493                .map(|stored| stored == scope)
494                .unwrap_or(false),
495            ScopeStorage::RecomposeScope { .. } => false,
496        }
497    }
498
499    fn matches_predicate(&self, predicate: &impl Fn(&dyn Any) -> bool) -> bool {
500        match &self.scope {
501            ScopeStorage::Owned(scope) => predicate(scope.as_ref()),
502            ScopeStorage::RecomposeScope { weak, .. } => weak
503                .upgrade()
504                .map(|inner| predicate(&RecomposeScope { inner }))
505                .unwrap_or(true),
506        }
507    }
508
509    fn should_retain(&self) -> bool {
510        match &self.scope {
511            ScopeStorage::Owned(_) => true,
512            ScopeStorage::RecomposeScope { weak, .. } => weak.upgrade().is_some(),
513        }
514    }
515
516    fn notify(&self) {
517        match &self.scope {
518            ScopeStorage::Owned(scope) => (self.on_changed)(scope.as_ref()),
519            ScopeStorage::RecomposeScope { weak, .. } => {
520                if let Some(inner) = weak.upgrade() {
521                    (self.on_changed)(&RecomposeScope { inner });
522                }
523            }
524        }
525    }
526}
527
528impl ScopeStorage {
529    fn from_value<T>(value: T) -> Self
530    where
531        T: Any + 'static,
532    {
533        let any = &value as &dyn Any;
534        if let Some(scope) = any.downcast_ref::<RecomposeScope>() {
535            Self::RecomposeScope {
536                id: scope.id(),
537                weak: scope.downgrade(),
538            }
539        } else {
540            Self::Owned(Box::new(value))
541        }
542    }
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snapshot_v2::take_mutable_snapshot;
    use crate::snapshot_v2::{reset_runtime_for_tests, TestRuntimeGuard};
    use crate::state::{NeverEqual, SnapshotMutableState};
    use std::cell::Cell;

    /// Resets the snapshot runtime; keep the guard alive for the whole test.
    fn reset_runtime() -> TestRuntimeGuard {
        reset_runtime_for_tests()
    }

    #[derive(Clone, PartialEq)]
    struct TestScope(&'static str);

    /// Observer whose executor runs every scheduled notification immediately.
    fn immediate_observer() -> SnapshotStateObserver {
        SnapshotStateObserver::new(|job| job())
    }

    /// Returns a hit counter plus a change callback that increments it.
    fn counting_callback() -> (Rc<Cell<i32>>, impl Fn(&TestScope) + 'static) {
        let hits = Rc::new(Cell::new(0));
        let sink = Rc::clone(&hits);
        let callback = move |_: &TestScope| sink.set(sink.get() + 1);
        (hits, callback)
    }

    /// Performs `mutation` inside a fresh mutable snapshot and applies it.
    fn apply_mutation(mutation: impl Fn()) {
        let snapshot = take_mutable_snapshot(None, None);
        snapshot.enter(mutation);
        snapshot.apply().check();
    }

    #[test]
    fn notifies_scope_when_state_changes() {
        let _guard = reset_runtime();
        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (triggered, on_change) = counting_callback();

        let observer = immediate_observer();
        observer.start();

        observer.observe_reads(TestScope("scope"), on_change, || {
            let _ = state.get();
        });
        apply_mutation(|| {
            state.set(1);
        });

        assert_eq!(triggered.get(), 1);
        observer.stop();
    }

    #[test]
    fn clear_removes_scope_observation() {
        let _guard = reset_runtime();
        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (triggered, on_change) = counting_callback();

        let observer = immediate_observer();
        observer.start();

        let scope = TestScope("scope");
        observer.observe_reads(scope.clone(), on_change, || {
            let _ = state.get();
        });
        observer.clear(&scope);

        apply_mutation(|| {
            state.set(1);
        });

        assert_eq!(triggered.get(), 0);
        observer.stop();
    }

    #[test]
    fn with_no_observations_skips_reads() {
        let _guard = reset_runtime();
        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (triggered, on_change) = counting_callback();

        let observer = immediate_observer();
        observer.start();

        observer.observe_reads(TestScope("scope"), on_change, || {
            observer.with_no_observations(|| {
                let _ = state.get();
            });
        });
        apply_mutation(|| {
            state.set(1);
        });

        assert_eq!(triggered.get(), 0);
        observer.stop();
    }

    #[test]
    fn begin_frame_prunes_dropped_recompose_scope_entries() {
        let _guard = reset_runtime();
        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let observer = immediate_observer();
        let runtime = crate::TestRuntime::new();
        let scope = RecomposeScope::new_for_test(runtime.handle());

        // (entries in `scopes`, entries in `fast_scopes`)
        let tracked = |obs: &SnapshotStateObserver| {
            (
                obs.inner.scopes.borrow().len(),
                obs.inner.fast_scopes.borrow().len(),
            )
        };

        observer.observe_reads(scope.clone(), |_| {}, || {
            let _ = state.get();
        });
        assert_eq!(tracked(&observer), (1, 1));

        drop(scope);
        observer.begin_frame();
        assert_eq!(tracked(&observer), (0, 0));
    }
}