// cranpose_core/snapshot_state_observer.rs

1// Observer callbacks use Arc for shared ownership but may capture non-Send types.
2// This is safe because callbacks are always invoked on the UI thread where they were created.
3#![allow(clippy::arc_with_non_send_sync)]
4// Complex types are inherent to the observer pattern with nested callbacks and state tracking
5#![allow(clippy::type_complexity)]
6
7use crate::collections::map::HashSet;
8use crate::snapshot_v2::{register_apply_observer, ReadObserver, StateObjectId};
9use crate::state::StateObject;
10use std::any::Any;
11use std::cell::{Cell, RefCell};
12use std::rc::{Rc, Weak};
13use std::sync::Arc;
14
/// Type-erased scheduler for change notifications.
///
/// The observer hands every boxed notification closure to this function,
/// which decides when the closure actually runs.
type Executor = dyn Fn(Box<dyn FnOnce()>) + 'static;
17
18/// Observer that records state object reads performed inside a given scope and
19/// notifies the caller when any of the observed objects change.
20///
21/// This is a pragmatic Rust translation of Jetpack Compose's
22/// `SnapshotStateObserver`. The implementation focuses on the core behaviour
23/// needed by the Cranpose runtime:
24/// - Tracking state object reads per logical scope.
25/// - Reacting to snapshot apply notifications.
26/// - Scheduling invalidation callbacks via the supplied executor.
27///
28/// Advanced features from the Kotlin version (derived state tracking, change
29/// coalescing, queue minimisation) are deferred
30#[derive(Clone)]
31pub struct SnapshotStateObserver {
32    inner: Rc<SnapshotStateObserverInner>,
33}
34
35impl SnapshotStateObserver {
36    /// Create a new observer that schedules callbacks using `on_changed_executor`.
37    pub fn new(on_changed_executor: impl Fn(Box<dyn FnOnce() + 'static>) + 'static) -> Self {
38        let inner = Rc::new(SnapshotStateObserverInner::new(on_changed_executor));
39        inner.set_self(Rc::downgrade(&inner));
40        Self { inner }
41    }
42
43    /// Observe state object reads performed while executing `block`.
44    ///
45    /// Subsequent calls to `observe_reads` replace any previously recorded
46    /// observations for the provided `scope`. When one of the observed objects
47    /// mutates, `on_value_changed_for_scope` will be invoked on the executor.
48    pub fn observe_reads<T, R>(
49        &self,
50        scope: T,
51        on_value_changed_for_scope: impl Fn(&T) + 'static,
52        block: impl FnOnce() -> R,
53    ) -> R
54    where
55        T: Any + Clone + PartialEq + 'static,
56    {
57        self.inner
58            .observe_reads(scope, on_value_changed_for_scope, block)
59    }
60
61    /// Notify the observer that a new composition frame is starting.
62    pub fn begin_frame(&self) {
63        self.inner.begin_frame();
64    }
65
66    /// Temporarily pause read observation while executing `block`.
67    pub fn with_no_observations<R>(&self, block: impl FnOnce() -> R) -> R {
68        self.inner.with_no_observations(block)
69    }
70
71    /// Remove any recorded reads for `scope`.
72    pub fn clear<T>(&self, scope: &T)
73    where
74        T: Any + PartialEq + 'static,
75    {
76        self.inner.clear(scope);
77    }
78
79    /// Remove recorded reads for scopes that satisfy `predicate`.
80    pub fn clear_if(&self, predicate: impl Fn(&dyn Any) -> bool) {
81        self.inner.clear_if(predicate);
82    }
83
84    /// Remove all recorded observations.
85    pub fn clear_all(&self) {
86        self.inner.clear_all();
87    }
88
89    /// Begin listening for snapshot apply notifications.
90    pub fn start(&self) {
91        let weak = Rc::downgrade(&self.inner);
92        self.inner.start(weak);
93    }
94
95    /// Stop listening for snapshot apply notifications.
96    pub fn stop(&self) {
97        self.inner.stop();
98    }
99
100    /// Test-only helper to simulate snapshot changes.
101    #[cfg(test)]
102    pub fn notify_changes(&self, modified: &[Arc<dyn StateObject>]) {
103        self.inner.handle_apply(modified);
104    }
105}
106
107struct SnapshotStateObserverInner {
108    executor: Rc<Executor>,
109    scopes: RefCell<Vec<Rc<RefCell<ScopeEntry>>>>,
110    fast_scopes: RefCell<Vec<Option<Rc<RefCell<ScopeEntry>>>>>,
111    pause_count: Rc<Cell<usize>>,
112    apply_handle: RefCell<Option<crate::snapshot_v2::ObserverHandle>>,
113    weak_self: RefCell<Weak<SnapshotStateObserverInner>>,
114    frame_version: Cell<u64>,
115}
116
117impl SnapshotStateObserverInner {
118    fn new(on_changed_executor: impl Fn(Box<dyn FnOnce() + 'static>) + 'static) -> Self {
119        Self {
120            executor: Rc::new(on_changed_executor),
121            scopes: RefCell::new(Vec::new()),
122            fast_scopes: RefCell::new(Vec::new()),
123            pause_count: Rc::new(Cell::new(0)),
124            apply_handle: RefCell::new(None),
125            weak_self: RefCell::new(Weak::new()),
126            frame_version: Cell::new(0),
127        }
128    }
129
130    fn set_self(&self, weak: Weak<SnapshotStateObserverInner>) {
131        self.weak_self.replace(weak);
132    }
133
134    fn begin_frame(&self) {
135        let next = self.frame_version.get().wrapping_add(1);
136        self.frame_version.set(next);
137    }
138
139    fn observe_reads<T, R>(
140        &self,
141        scope: T,
142        on_value_changed_for_scope: impl Fn(&T) + 'static,
143        block: impl FnOnce() -> R,
144    ) -> R
145    where
146        T: Any + Clone + PartialEq + 'static,
147    {
148        let frame_version = self.frame_version.get();
149        let has_frame_version = frame_version != 0;
150
151        let on_changed: Rc<dyn Fn(&dyn Any)> = {
152            let callback = Rc::new(on_value_changed_for_scope);
153            Rc::new(move |scope_any: &dyn Any| {
154                if let Some(typed) = scope_any.downcast_ref::<T>() {
155                    callback(typed);
156                }
157            })
158        };
159
160        let entry = self.get_scope_entry(scope.clone(), on_changed.clone());
161
162        let pause_count = self.pause_count.clone();
163
164        let read_observer: ReadObserver = {
165            let mut entry_mut = entry.borrow_mut();
166            entry_mut.update(scope, on_changed);
167
168            let already_observed =
169                has_frame_version && entry_mut.last_seen_version == frame_version;
170            if already_observed || entry_mut.is_stateless {
171                drop(entry_mut);
172                return block();
173            }
174
175            entry_mut.observed.clear();
176            entry_mut.last_seen_version = if has_frame_version {
177                frame_version
178            } else {
179                u64::MAX
180            };
181            entry_mut.is_stateless = false;
182
183            if let Some(observer) = entry_mut.read_observer.clone() {
184                observer
185            } else {
186                let entry_for_observer = entry.clone();
187                let pause_count = pause_count.clone();
188
189                let observer: ReadObserver = Arc::new(move |state| {
190                    if pause_count.get() > 0 {
191                        return;
192                    }
193                    let mut entry_ref = entry_for_observer.borrow_mut();
194                    let id = state.object_id().as_usize();
195                    entry_ref.observed.insert(id);
196                    entry_ref.is_stateless = false;
197                });
198
199                entry_mut.read_observer = Some(observer.clone());
200                observer
201            }
202        };
203
204        let result = self.run_with_read_observer(read_observer, block);
205
206        {
207            let mut entry_mut = entry.borrow_mut();
208            if entry_mut.observed.is_empty() {
209                entry_mut.is_stateless = true;
210            }
211        }
212
213        result
214    }
215
216    fn with_no_observations<R>(&self, block: impl FnOnce() -> R) -> R {
217        self.pause_count.set(self.pause_count.get() + 1);
218        let result = block();
219        self.pause_count
220            .set(self.pause_count.get().saturating_sub(1));
221        result
222    }
223
224    fn clear<T>(&self, scope: &T)
225    where
226        T: Any + PartialEq + 'static,
227    {
228        // Clear from fast_scopes if it's a RecomposeScope
229        if let Some(rc_scope) = (scope as &dyn Any).downcast_ref::<RecomposeScope>() {
230            let id = rc_scope.id();
231            let mut fast = self.fast_scopes.borrow_mut();
232            if id < fast.len() {
233                fast[id] = None;
234            }
235        }
236
237        // Clear from scopes
238        self.scopes
239            .borrow_mut()
240            .retain(|entry| !entry.borrow().matches_scope(scope));
241    }
242
243    fn clear_if(&self, predicate: impl Fn(&dyn Any) -> bool) {
244        // Clear from fast_scopes for any RecomposeScope entries that match predicate
245        let mut fast = self.fast_scopes.borrow_mut();
246        for slot in fast.iter_mut() {
247            if let Some(entry) = slot {
248                let should_clear = {
249                    let entry_ref = entry.borrow();
250                    predicate(entry_ref.scope())
251                };
252                if should_clear {
253                    *slot = None;
254                }
255            }
256        }
257        drop(fast);
258
259        // Clear from scopes
260        self.scopes.borrow_mut().retain(|entry| {
261            let entry_ref = entry.borrow();
262            !predicate(entry_ref.scope())
263        });
264    }
265
266    fn clear_all(&self) {
267        self.fast_scopes.borrow_mut().clear();
268        self.scopes.borrow_mut().clear();
269    }
270
271    // Arc-wrapped closure captures Weak which may not be Send/Sync. This is safe because
272    // the observer callback is only invoked on the UI thread where it was registered.
273    #[allow(clippy::arc_with_non_send_sync)]
274    fn start(&self, weak_self: Weak<SnapshotStateObserverInner>) {
275        if self.apply_handle.borrow().is_some() {
276            return;
277        }
278
279        let handle = register_apply_observer(Arc::new(move |modified, _snapshot_id| {
280            if let Some(inner) = weak_self.upgrade() {
281                inner.handle_apply(modified);
282            }
283        }));
284        self.apply_handle.replace(Some(handle));
285    }
286
287    fn stop(&self) {
288        if let Some(handle) = self.apply_handle.borrow_mut().take() {
289            drop(handle);
290        }
291    }
292
293    fn get_scope_entry(
294        &self,
295        scope: impl Any + Clone + PartialEq + 'static,
296        on_changed: Rc<dyn Fn(&dyn Any)>,
297    ) -> Rc<RefCell<ScopeEntry>> {
298        // ---------- FAST PATH: real compose scope ----------
299        if let Some(rc_scope) = (&scope as &dyn Any).downcast_ref::<RecomposeScope>() {
300            let id: usize = rc_scope.id(); // or `.0` or similar
301
302            let mut fast = self.fast_scopes.borrow_mut();
303
304            if id >= fast.len() {
305                fast.resize_with(id + 1, || None);
306            }
307
308            if let Some(existing) = &fast[id] {
309                return existing.clone();
310            }
311
312            let entry = Rc::new(RefCell::new(ScopeEntry::new(scope, on_changed)));
313            fast[id] = Some(entry.clone());
314            // CRITICAL: Also add to scopes Vec so handle_apply and clear* methods work correctly
315            drop(fast);
316            self.scopes.borrow_mut().push(entry.clone());
317            return entry;
318        }
319
320        // ---------- SLOW / GENERIC PATH ----------
321        let mut scopes = self.scopes.borrow_mut();
322
323        if let Some(existing) = scopes
324            .iter()
325            .find(|entry| entry.borrow().matches_scope(&scope))
326        {
327            return existing.clone();
328        }
329
330        let entry = Rc::new(RefCell::new(ScopeEntry::new(scope, on_changed)));
331        scopes.push(entry.clone());
332        entry
333    }
334
335    fn run_with_read_observer<R>(
336        &self,
337        read_observer: ReadObserver,
338        block: impl FnOnce() -> R,
339    ) -> R {
340        // Kotlin uses Snapshot.observeInternal which creates a TransparentObserverMutableSnapshot,
341        // not a readonly snapshot. This allows writes to happen during observation (composition).
342        use crate::snapshot_v2::take_transparent_observer_mutable_snapshot;
343
344        // Create a transparent mutable snapshot (not readonly!) for observation
345        // This matches Kotlin's Snapshot.observeInternal behavior
346        let snapshot = take_transparent_observer_mutable_snapshot(Some(read_observer), None);
347        let result = snapshot.enter(block);
348        snapshot.dispose();
349        result
350    }
351
352    fn handle_apply(&self, modified: &[Arc<dyn StateObject>]) {
353        if modified.is_empty() {
354            return;
355        }
356
357        let mut modified_ids: SmallVec<[usize; MAX_OBSERVED_STATES]> = SmallVec::new();
358        for state in modified {
359            modified_ids.push(state.object_id().as_usize());
360        }
361
362        let scopes = self.scopes.borrow();
363        let mut to_notify: Vec<Rc<RefCell<ScopeEntry>>> = Vec::new();
364        let mut seen: HashSet<usize> = HashSet::default();
365
366        for entry in scopes.iter() {
367            let entry_ref = entry.borrow();
368            if entry_ref
369                .observed
370                .iter()
371                .any(|id| modified_ids.contains(id))
372            {
373                let ptr = Rc::as_ptr(entry) as usize;
374                if seen.insert(ptr) {
375                    to_notify.push(entry.clone());
376                }
377            }
378        }
379        drop(scopes);
380
381        {
382            let fast_scopes = self.fast_scopes.borrow();
383            for entry in fast_scopes.iter().flatten() {
384                let entry_ref = entry.borrow();
385                if entry_ref
386                    .observed
387                    .iter()
388                    .any(|id| modified_ids.contains(id))
389                {
390                    let ptr = Rc::as_ptr(entry) as usize;
391                    if seen.insert(ptr) {
392                        to_notify.push(entry.clone());
393                    }
394                }
395            }
396        }
397
398        if to_notify.is_empty() {
399            return;
400        }
401
402        for entry in to_notify {
403            let executor = self.executor.clone();
404            executor(Box::new(move || {
405                if let Ok(entry) = entry.try_borrow() {
406                    entry.notify();
407                }
408            }));
409        }
410    }
411}
412
413use cranpose_core::RecomposeScope;
414use smallvec::SmallVec;
415
416enum ObservedIds {
417    Small(SmallVec<[StateObjectId; MAX_OBSERVED_STATES]>),
418    Large(HashSet<StateObjectId>),
419}
420
421impl ObservedIds {
422    fn new() -> Self {
423        ObservedIds::Small(SmallVec::new())
424    }
425
426    fn insert(&mut self, id: StateObjectId) {
427        match self {
428            ObservedIds::Small(small) => {
429                if small.contains(&id) {
430                    return;
431                }
432                if small.len() < MAX_OBSERVED_STATES {
433                    small.push(id);
434                } else {
435                    let mut large =
436                        HashSet::with_capacity_and_hasher(small.len() + 1, Default::default());
437                    for existing in small.iter() {
438                        large.insert(*existing);
439                    }
440                    large.insert(id);
441                    *self = ObservedIds::Large(large);
442                }
443            }
444            ObservedIds::Large(large) => {
445                large.insert(id);
446            }
447        }
448    }
449
450    fn is_empty(&self) -> bool {
451        match self {
452            ObservedIds::Small(small) => small.is_empty(),
453            ObservedIds::Large(large) => large.is_empty(),
454        }
455    }
456
457    fn clear(&mut self) {
458        match self {
459            ObservedIds::Small(small) => small.clear(),
460            ObservedIds::Large(large) => large.clear(),
461        }
462    }
463
464    fn iter(&self) -> Box<dyn Iterator<Item = &StateObjectId> + '_> {
465        match self {
466            ObservedIds::Small(small) => Box::new(small.iter()),
467            ObservedIds::Large(large) => Box::new(large.iter()),
468        }
469    }
470}
471
472const MAX_OBSERVED_STATES: usize = 8;
473struct ScopeEntry {
474    scope: Box<dyn Any>,
475    on_changed: Rc<dyn Fn(&dyn Any)>,
476    observed: ObservedIds,
477    read_observer: Option<ReadObserver>,
478    is_stateless: bool,
479    last_seen_version: u64,
480}
481
482impl ScopeEntry {
483    fn new<T>(scope: T, on_changed: Rc<dyn Fn(&dyn Any)>) -> Self
484    where
485        T: Any + 'static,
486    {
487        Self {
488            scope: Box::new(scope),
489            on_changed,
490            observed: ObservedIds::new(),
491            read_observer: None,
492            is_stateless: false,
493            last_seen_version: u64::MAX,
494        }
495    }
496
497    fn update<T>(&mut self, new_scope: T, on_changed: Rc<dyn Fn(&dyn Any)>)
498    where
499        T: Any + 'static,
500    {
501        self.scope = Box::new(new_scope);
502        self.on_changed = on_changed;
503    }
504
505    fn matches_scope<T>(&self, scope: &T) -> bool
506    where
507        T: Any + PartialEq + 'static,
508    {
509        self.scope
510            .downcast_ref::<T>()
511            .map(|stored| stored == scope)
512            .unwrap_or(false)
513    }
514
515    fn scope(&self) -> &dyn Any {
516        &*self.scope
517    }
518
519    fn notify(&self) {
520        (self.on_changed)(self.scope());
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snapshot_v2::take_mutable_snapshot;
    use crate::snapshot_v2::{reset_runtime_for_tests, TestRuntimeGuard};
    use crate::state::{NeverEqual, SnapshotMutableState};
    use std::cell::Cell;

    fn reset_runtime() -> TestRuntimeGuard {
        reset_runtime_for_tests()
    }

    #[derive(Clone, PartialEq)]
    struct TestScope(&'static str);

    /// Creates a started observer whose executor runs callbacks immediately.
    fn immediate_observer() -> SnapshotStateObserver {
        let observer = SnapshotStateObserver::new(|callback| callback());
        observer.start();
        observer
    }

    /// Returns a shared counter and a scope callback that increments it.
    fn counting_callback() -> (Rc<Cell<i32>>, impl Fn(&TestScope) + 'static) {
        let triggered = Rc::new(Cell::new(0));
        let observer_trigger = triggered.clone();
        let callback = move |_: &TestScope| {
            observer_trigger.set(observer_trigger.get() + 1);
        };
        (triggered, callback)
    }

    /// Runs `write` inside a fresh mutable snapshot and applies it.
    fn write_and_apply(write: impl FnOnce()) {
        let snapshot = take_mutable_snapshot(None, None);
        snapshot.enter(write);
        snapshot.apply().check();
    }

    #[test]
    fn notifies_scope_when_state_changes() {
        let _guard = reset_runtime();

        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (triggered, on_change) = counting_callback();
        let observer = immediate_observer();

        let scope = TestScope("scope");
        observer.observe_reads(scope.clone(), on_change, || {
            let _ = state.get();
        });

        write_and_apply(|| {
            state.set(1);
        });

        assert_eq!(triggered.get(), 1);
        observer.stop();
    }

    #[test]
    fn clear_removes_scope_observation() {
        let _guard = reset_runtime();

        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (triggered, on_change) = counting_callback();
        let observer = immediate_observer();

        let scope = TestScope("scope");
        observer.observe_reads(scope.clone(), on_change, || {
            let _ = state.get();
        });

        observer.clear(&scope);

        write_and_apply(|| {
            state.set(1);
        });

        assert_eq!(triggered.get(), 0);
        observer.stop();
    }

    #[test]
    fn with_no_observations_skips_reads() {
        let _guard = reset_runtime();

        let state = SnapshotMutableState::new_in_arc(0, Arc::new(NeverEqual));
        let (triggered, on_change) = counting_callback();
        let observer = immediate_observer();

        let scope = TestScope("scope");
        observer.observe_reads(scope.clone(), on_change, || {
            observer.with_no_observations(|| {
                let _ = state.get();
            });
        });

        write_and_apply(|| {
            state.set(1);
        });

        assert_eq!(triggered.get(), 0);
        observer.stop();
    }
}