cranpose_core/snapshot_v2/
mod.rs

1//! Snapshot system for managing isolated state changes.
2//!
3//! This module implements Jetpack Compose's snapshot isolation system, allowing
4//! state changes to be isolated, composed, and atomically applied.
5//!
6//! # Snapshot Types
7//!
8//! - **ReadonlySnapshot**: Immutable view of state at a point in time
9//! - **MutableSnapshot**: Allows isolated state mutations
10//! - **NestedReadonlySnapshot**: Readonly snapshot nested in a parent
11//! - **NestedMutableSnapshot**: Mutable snapshot nested in a parent
12//! - **GlobalSnapshot**: Special global mutable snapshot
13//! - **TransparentObserverMutableSnapshot**: Optimized for observer chaining
14//! - **TransparentObserverSnapshot**: Readonly version of transparent observer
15//!
16//! # Thread Local Storage
17//!
18//! The current snapshot is stored in thread-local storage and automatically
19//! managed by the snapshot system.
20
21// All snapshot types use Arc with Cell/RefCell for single-threaded shared ownership.
22// This is safe because snapshots are thread-local and never cross thread boundaries.
23#![allow(clippy::arc_with_non_send_sync)]
24
25use crate::collections::map::HashMap; // FUTURE(no_std): replace HashMap/HashSet with arena-backed maps.
26use crate::collections::map::HashSet;
27use crate::snapshot_id_set::{SnapshotId, SnapshotIdSet};
28use crate::snapshot_pinning::{self, PinHandle};
29use crate::state::{StateObject, StateRecord};
30use std::cell::{Cell, RefCell};
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::sync::{Arc, Weak};
33
34mod global;
35mod mutable;
36mod nested;
37mod readonly;
38mod runtime;
39mod transparent;
40
41#[cfg(test)]
42mod integration_tests;
43
44pub use global::{advance_global_snapshot, GlobalSnapshot};
45pub use mutable::MutableSnapshot;
46pub use nested::{NestedMutableSnapshot, NestedReadonlySnapshot};
47pub use readonly::ReadonlySnapshot;
48pub use transparent::{TransparentObserverMutableSnapshot, TransparentObserverSnapshot};
49
50pub(crate) use runtime::{allocate_snapshot, close_snapshot, with_runtime};
51#[cfg(test)]
52pub(crate) use runtime::{reset_runtime_for_tests, TestRuntimeGuard};
53
/// Observer that is called when a state object is read.
///
/// Invoked with the state object being read whenever a snapshot records a read
/// (see `SnapshotState::record_read`).
pub type ReadObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;

/// Observer that is called when a state object is written.
///
/// `SnapshotState::record_write` invokes it only for the first write of a given
/// state object within a snapshot.
pub type WriteObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;

/// Apply observer that is called when a snapshot is applied.
///
/// Receives the modified state objects and the id of the applied snapshot
/// (see `notify_apply_observers`).
pub type ApplyObserver = Arc<dyn Fn(&[Arc<dyn StateObject>], SnapshotId) + 'static>;
62
/// Outcome of applying a mutable snapshot.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SnapshotApplyResult {
    /// Every change in the snapshot was applied.
    Success,
    /// The snapshot could not be applied because of conflicting writes.
    Failure,
}

impl SnapshotApplyResult {
    /// Returns `true` when the apply succeeded.
    pub fn is_success(&self) -> bool {
        *self == Self::Success
    }

    /// Returns `true` when the apply failed.
    ///
    /// The enum has exactly two variants, so this is the negation of
    /// [`Self::is_success`].
    pub fn is_failure(&self) -> bool {
        !self.is_success()
    }

    /// Asserts that the apply succeeded (intended for tests).
    ///
    /// # Panics
    ///
    /// Panics with `"Snapshot apply failed"`, reported at the caller's location,
    /// when the result is [`SnapshotApplyResult::Failure`].
    #[track_caller]
    pub fn check(&self) {
        assert!(self.is_success(), "Snapshot apply failed");
    }
}
91
/// Unique identifier for a state object in the modified set.
///
/// Produced from `StateObject::object_id().as_usize()` and used as the key of
/// `SnapshotState::modified` and of the per-thread last-write registry.
pub type StateObjectId = usize;
94
95/// Enum wrapper for all snapshot types.
96///
97/// This provides a type-safe way to work with different snapshot types
98/// without requiring trait objects, which avoids object-safety issues.
99#[derive(Clone)]
100pub enum AnySnapshot {
101    Readonly(Arc<ReadonlySnapshot>),
102    Mutable(Arc<MutableSnapshot>),
103    NestedReadonly(Arc<NestedReadonlySnapshot>),
104    NestedMutable(Arc<NestedMutableSnapshot>),
105    Global(Arc<GlobalSnapshot>),
106    TransparentMutable(Arc<TransparentObserverMutableSnapshot>),
107    TransparentReadonly(Arc<TransparentObserverSnapshot>),
108}
109
110impl AnySnapshot {
111    /// Get the snapshot ID.
112    pub fn snapshot_id(&self) -> SnapshotId {
113        match self {
114            AnySnapshot::Readonly(s) => s.snapshot_id(),
115            AnySnapshot::Mutable(s) => s.snapshot_id(),
116            AnySnapshot::NestedReadonly(s) => s.snapshot_id(),
117            AnySnapshot::NestedMutable(s) => s.snapshot_id(),
118            AnySnapshot::Global(s) => s.snapshot_id(),
119            AnySnapshot::TransparentMutable(s) => s.snapshot_id(),
120            AnySnapshot::TransparentReadonly(s) => s.snapshot_id(),
121        }
122    }
123
124    /// Get the set of invalid snapshot IDs.
125    pub fn invalid(&self) -> SnapshotIdSet {
126        match self {
127            AnySnapshot::Readonly(s) => s.invalid(),
128            AnySnapshot::Mutable(s) => s.invalid(),
129            AnySnapshot::NestedReadonly(s) => s.invalid(),
130            AnySnapshot::NestedMutable(s) => s.invalid(),
131            AnySnapshot::Global(s) => s.invalid(),
132            AnySnapshot::TransparentMutable(s) => s.invalid(),
133            AnySnapshot::TransparentReadonly(s) => s.invalid(),
134        }
135    }
136
137    /// Check if a snapshot ID is valid in this snapshot.
138    pub fn is_valid(&self, id: SnapshotId) -> bool {
139        let snapshot_id = self.snapshot_id();
140        id <= snapshot_id && !self.invalid().get(id)
141    }
142
143    /// Check if this is a read-only snapshot.
144    pub fn read_only(&self) -> bool {
145        match self {
146            AnySnapshot::Readonly(_) => true,
147            AnySnapshot::Mutable(_) => false,
148            AnySnapshot::NestedReadonly(_) => true,
149            AnySnapshot::NestedMutable(_) => false,
150            AnySnapshot::Global(_) => false,
151            AnySnapshot::TransparentMutable(_) => false,
152            AnySnapshot::TransparentReadonly(_) => true,
153        }
154    }
155
156    /// Get the root snapshot.
157    pub fn root(&self) -> AnySnapshot {
158        match self {
159            AnySnapshot::Readonly(s) => AnySnapshot::Readonly(s.root_readonly()),
160            AnySnapshot::Mutable(s) => AnySnapshot::Mutable(s.root_mutable()),
161            AnySnapshot::NestedReadonly(s) => AnySnapshot::NestedReadonly(s.root_nested_readonly()),
162            AnySnapshot::NestedMutable(s) => AnySnapshot::Mutable(s.root_mutable()),
163            AnySnapshot::Global(s) => AnySnapshot::Global(s.root_global()),
164            AnySnapshot::TransparentMutable(s) => {
165                AnySnapshot::TransparentMutable(s.root_transparent_mutable())
166            }
167            AnySnapshot::TransparentReadonly(s) => {
168                AnySnapshot::TransparentReadonly(s.root_transparent_readonly())
169            }
170        }
171    }
172
173    /// Check if this snapshot refers to the same transparent snapshot.
174    pub fn is_same_transparent(&self, other: &Arc<TransparentObserverMutableSnapshot>) -> bool {
175        matches!(self, AnySnapshot::TransparentMutable(snapshot) if Arc::ptr_eq(snapshot, other))
176    }
177
178    /// Check if this snapshot refers to the same transparent mutable snapshot.
179    pub fn is_same_transparent_mutable(
180        &self,
181        other: &Arc<TransparentObserverMutableSnapshot>,
182    ) -> bool {
183        self.is_same_transparent(other)
184    }
185
186    /// Check if this snapshot refers to the same transparent readonly snapshot.
187    pub fn is_same_transparent_readonly(&self, other: &Arc<TransparentObserverSnapshot>) -> bool {
188        matches!(self, AnySnapshot::TransparentReadonly(snapshot) if Arc::ptr_eq(snapshot, other))
189    }
190
191    /// Enter this snapshot, making it current for the duration of the closure.
192    pub fn enter<T>(&self, f: impl FnOnce() -> T) -> T {
193        match self {
194            AnySnapshot::Readonly(s) => s.enter(f),
195            AnySnapshot::Mutable(s) => s.enter(f),
196            AnySnapshot::NestedReadonly(s) => s.enter(f),
197            AnySnapshot::NestedMutable(s) => s.enter(f),
198            AnySnapshot::Global(s) => s.enter(f),
199            AnySnapshot::TransparentMutable(s) => s.enter(f),
200            AnySnapshot::TransparentReadonly(s) => s.enter(f),
201        }
202    }
203
204    /// Take a nested read-only snapshot.
205    pub fn take_nested_snapshot(&self, read_observer: Option<ReadObserver>) -> AnySnapshot {
206        match self {
207            AnySnapshot::Readonly(s) => {
208                AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
209            }
210            AnySnapshot::Mutable(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
211            AnySnapshot::NestedReadonly(s) => {
212                AnySnapshot::NestedReadonly(s.take_nested_snapshot(read_observer))
213            }
214            AnySnapshot::NestedMutable(s) => {
215                AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
216            }
217            AnySnapshot::Global(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
218            AnySnapshot::TransparentMutable(s) => {
219                AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
220            }
221            AnySnapshot::TransparentReadonly(s) => {
222                AnySnapshot::TransparentReadonly(s.take_nested_snapshot(read_observer))
223            }
224        }
225    }
226
227    /// Check if there are pending changes.
228    pub fn has_pending_changes(&self) -> bool {
229        match self {
230            AnySnapshot::Readonly(s) => s.has_pending_changes(),
231            AnySnapshot::Mutable(s) => s.has_pending_changes(),
232            AnySnapshot::NestedReadonly(s) => s.has_pending_changes(),
233            AnySnapshot::NestedMutable(s) => s.has_pending_changes(),
234            AnySnapshot::Global(s) => s.has_pending_changes(),
235            AnySnapshot::TransparentMutable(s) => s.has_pending_changes(),
236            AnySnapshot::TransparentReadonly(s) => s.has_pending_changes(),
237        }
238    }
239
240    /// Dispose of this snapshot.
241    pub fn dispose(&self) {
242        match self {
243            AnySnapshot::Readonly(s) => s.dispose(),
244            AnySnapshot::Mutable(s) => s.dispose(),
245            AnySnapshot::NestedReadonly(s) => s.dispose(),
246            AnySnapshot::NestedMutable(s) => s.dispose(),
247            AnySnapshot::Global(s) => s.dispose(),
248            AnySnapshot::TransparentMutable(s) => s.dispose(),
249            AnySnapshot::TransparentReadonly(s) => s.dispose(),
250        }
251    }
252
253    /// Check if disposed.
254    pub fn is_disposed(&self) -> bool {
255        match self {
256            AnySnapshot::Readonly(s) => s.is_disposed(),
257            AnySnapshot::Mutable(s) => s.is_disposed(),
258            AnySnapshot::NestedReadonly(s) => s.is_disposed(),
259            AnySnapshot::NestedMutable(s) => s.is_disposed(),
260            AnySnapshot::Global(s) => s.is_disposed(),
261            AnySnapshot::TransparentMutable(s) => s.is_disposed(),
262            AnySnapshot::TransparentReadonly(s) => s.is_disposed(),
263        }
264    }
265
266    /// Record a read.
267    pub fn record_read(&self, state: &dyn StateObject) {
268        match self {
269            AnySnapshot::Readonly(s) => s.record_read(state),
270            AnySnapshot::Mutable(s) => s.record_read(state),
271            AnySnapshot::NestedReadonly(s) => s.record_read(state),
272            AnySnapshot::NestedMutable(s) => s.record_read(state),
273            AnySnapshot::Global(s) => s.record_read(state),
274            AnySnapshot::TransparentMutable(s) => s.record_read(state),
275            AnySnapshot::TransparentReadonly(s) => s.record_read(state),
276        }
277    }
278
279    /// Record a write.
280    pub fn record_write(&self, state: Arc<dyn StateObject>) {
281        match self {
282            AnySnapshot::Readonly(s) => s.record_write(state),
283            AnySnapshot::Mutable(s) => s.record_write(state),
284            AnySnapshot::NestedReadonly(s) => s.record_write(state),
285            AnySnapshot::NestedMutable(s) => s.record_write(state),
286            AnySnapshot::Global(s) => s.record_write(state),
287            AnySnapshot::TransparentMutable(s) => s.record_write(state),
288            AnySnapshot::TransparentReadonly(s) => s.record_write(state),
289        }
290    }
291
292    /// Apply changes (only valid for mutable snapshots).
293    pub fn apply(&self) -> SnapshotApplyResult {
294        match self {
295            AnySnapshot::Mutable(s) => s.apply(),
296            AnySnapshot::NestedMutable(s) => s.apply(),
297            AnySnapshot::Global(s) => s.apply(),
298            AnySnapshot::TransparentMutable(s) => s.apply(),
299            _ => panic!("Cannot apply a read-only snapshot"),
300        }
301    }
302
303    /// Take a nested mutable snapshot (only valid for mutable snapshots).
304    pub fn take_nested_mutable_snapshot(
305        &self,
306        read_observer: Option<ReadObserver>,
307        write_observer: Option<WriteObserver>,
308    ) -> AnySnapshot {
309        match self {
310            AnySnapshot::Mutable(s) => AnySnapshot::NestedMutable(
311                s.take_nested_mutable_snapshot(read_observer, write_observer),
312            ),
313            AnySnapshot::NestedMutable(s) => AnySnapshot::NestedMutable(
314                s.take_nested_mutable_snapshot(read_observer, write_observer),
315            ),
316            AnySnapshot::Global(s) => {
317                AnySnapshot::Mutable(s.take_nested_mutable_snapshot(read_observer, write_observer))
318            }
319            AnySnapshot::TransparentMutable(s) => AnySnapshot::TransparentMutable(
320                s.take_nested_mutable_snapshot(read_observer, write_observer),
321            ),
322            _ => panic!("Cannot take nested mutable snapshot from read-only snapshot"),
323        }
324    }
325}
326
327thread_local! {
328    // Thread-local storage for the current snapshot.
329    static CURRENT_SNAPSHOT: RefCell<Option<AnySnapshot>> = const { RefCell::new(None) };
330}
331
332/// Get the current snapshot, or None if not in a snapshot context.
333pub fn current_snapshot() -> Option<AnySnapshot> {
334    CURRENT_SNAPSHOT
335        .try_with(|cell| cell.borrow().clone())
336        .unwrap_or(None)
337}
338
339/// Set the current snapshot (internal use only).
340pub(crate) fn set_current_snapshot(snapshot: Option<AnySnapshot>) {
341    let _ = CURRENT_SNAPSHOT.try_with(|cell| {
342        *cell.borrow_mut() = snapshot;
343    });
344}
345
346/// Convenience helper that mirrors the legacy `take_mutable_snapshot` API.
347///
348/// Returns a mutable snapshot rooted at the global snapshot with the provided
349/// read/write observers installed.
350pub fn take_mutable_snapshot(
351    read_observer: Option<ReadObserver>,
352    write_observer: Option<WriteObserver>,
353) -> Arc<MutableSnapshot> {
354    GlobalSnapshot::get_or_create().take_nested_mutable_snapshot(read_observer, write_observer)
355}
356
357/// Take a transparent observer mutable snapshot with optional observers.
358///
359/// This type of snapshot is used for read observation during composition,
360/// matching Kotlin's Snapshot.observeInternal behavior. It allows writes
361/// to happen during observation.
362///
363/// Transparent snapshots DO NOT allocate new IDs - they delegate to the
364/// current/global snapshot, making them "transparent" to the snapshot system.
365pub fn take_transparent_observer_mutable_snapshot(
366    read_observer: Option<ReadObserver>,
367    write_observer: Option<WriteObserver>,
368) -> Arc<TransparentObserverMutableSnapshot> {
369    let parent = current_snapshot();
370    match parent {
371        Some(AnySnapshot::TransparentMutable(transparent)) if transparent.can_reuse() => {
372            // Reuse the existing transparent snapshot
373            transparent
374        }
375        _ => {
376            // Create a new transparent snapshot using the current snapshot's ID
377            // Transparent snapshots do NOT allocate new IDs!
378            let current = current_snapshot()
379                .unwrap_or_else(|| AnySnapshot::Global(GlobalSnapshot::get_or_create()));
380            let id = current.snapshot_id();
381            let invalid = current.invalid();
382            TransparentObserverMutableSnapshot::new(
383                id,
384                invalid,
385                read_observer,
386                write_observer,
387                None,
388            )
389        }
390    }
391}
392
393/// Allocate a new record identifier that is distinct from any active snapshot id.
394pub fn allocate_record_id() -> SnapshotId {
395    runtime::allocate_record_id()
396}
397
398/// Get the next snapshot ID that will be allocated without incrementing the counter.
399///
400/// This is used for cleanup operations to determine the reuse limit.
401/// Mirrors Kotlin's `nextSnapshotId` field access.
402pub(crate) fn peek_next_snapshot_id() -> SnapshotId {
403    runtime::peek_next_snapshot_id()
404}
405
/// Process-wide counter for unique observer IDs.
///
/// This is a plain `static` atomic shared by all threads, so ids never collide
/// even though each thread keeps its own `APPLY_OBSERVERS` map. It starts at 1.
static NEXT_OBSERVER_ID: AtomicUsize = AtomicUsize::new(1);

thread_local! {
    // Per-thread map of apply observers, keyed by the id drawn from
    // `NEXT_OBSERVER_ID`. This is thread-local, not global: an observer is
    // only notified by `notify_apply_observers` calls made on the thread that
    // registered it.
    static APPLY_OBSERVERS: RefCell<HashMap<usize, ApplyObserver>> = RefCell::new(HashMap::default());
}

thread_local! {
    // Thread-local last-writer registry used for conflict detection in v2.
    //
    // Maps a state object id to the snapshot id of the most recent successful apply
    // that modified the object. This is a simplified conflict tracking mechanism
    // for Phase 2.1 before full record-chain merging is implemented.
    //
    // Thread-local ensures test isolation - each test thread has its own registry.
    static LAST_WRITES: RefCell<HashMap<StateObjectId, SnapshotId>> = RefCell::new(HashMap::default());
}

thread_local! {
    // Thread-local weak set of state objects with multiple records for periodic garbage collection.
    // Mirrors Kotlin's `extraStateObjects` WeakSet.
    static EXTRA_STATE_OBJECTS: RefCell<crate::snapshot_weak_set::SnapshotWeakSet> = RefCell::new(crate::snapshot_weak_set::SnapshotWeakSet::new());
}
430
431/// Register an apply observer.
432///
433/// Returns a handle that will automatically unregister the observer when dropped.
434pub fn register_apply_observer(observer: ApplyObserver) -> ObserverHandle {
435    let id = NEXT_OBSERVER_ID.fetch_add(1, Ordering::SeqCst);
436    APPLY_OBSERVERS.with(|cell| {
437        cell.borrow_mut().insert(id, observer);
438    });
439    ObserverHandle {
440        kind: ObserverKind::Apply,
441        id,
442    }
443}
444
445/// Handle for unregistering observers.
446///
447/// When dropped, automatically removes the associated observer.
448pub struct ObserverHandle {
449    kind: ObserverKind,
450    id: usize,
451}
452
453enum ObserverKind {
454    Apply,
455}
456
457impl Drop for ObserverHandle {
458    fn drop(&mut self) {
459        match self.kind {
460            ObserverKind::Apply => {
461                APPLY_OBSERVERS.with(|cell| {
462                    cell.borrow_mut().remove(&self.id);
463                });
464            }
465        }
466    }
467}
468
469/// Notify apply observers that a snapshot was applied.
470pub(crate) fn notify_apply_observers(modified: &[Arc<dyn StateObject>], snapshot_id: SnapshotId) {
471    // Copy observers so callbacks run outside the borrow
472    APPLY_OBSERVERS.with(|cell| {
473        let observers: Vec<ApplyObserver> = cell.borrow().values().cloned().collect();
474        for observer in observers.into_iter() {
475            observer(modified, snapshot_id);
476        }
477    });
478}
479
480/// Get the last successful writer snapshot id for a given object id.
481#[allow(dead_code)]
482pub(crate) fn get_last_write(id: StateObjectId) -> Option<SnapshotId> {
483    LAST_WRITES.with(|cell| cell.borrow().get(&id).copied())
484}
485
486/// Record the last successful writer snapshot id for a given object id.
487pub(crate) fn set_last_write(id: StateObjectId, snapshot_id: SnapshotId) {
488    LAST_WRITES.with(|cell| {
489        cell.borrow_mut().insert(id, snapshot_id);
490    });
491}
492
493/// Clear all last write records (for testing).
494#[cfg(test)]
495pub(crate) fn clear_last_writes() {
496    LAST_WRITES.with(|cell| {
497        cell.borrow_mut().clear();
498    });
499}
500
/// Check and overwrite unused records for all tracked state objects.
///
/// Mirrors Kotlin's `checkAndOverwriteUnusedRecordsLocked()`. This method:
/// 1. Iterates through all state objects in `EXTRA_STATE_OBJECTS`
/// 2. Calls `overwrite_unused_records()` on each
/// 3. Removes states that no longer need tracking (down to 1 or fewer records)
/// 4. Automatically cleans up dead weak references
///
/// `overwrite_unused_records()` runs while `EXTRA_STATE_OBJECTS` is mutably
/// borrowed. It must not call back into this registry (for example via
/// `process_for_unused_records_locked`), or the `RefCell` will panic.
///
/// NOTE(review): the predicate passed to `remove_if` is the raw result of
/// `overwrite_unused_records()`. `process_for_unused_records_locked` treats
/// `true` as "still has multiple records", and the inline comment says `true`
/// means keep. Kotlin's version is understood to call
/// `extraStateObjects.removeIf { !overwriteUnusedRecordsLocked(it) }`, i.e.
/// with the predicate negated. Verify that `SnapshotWeakSet::remove_if` uses
/// keep-on-`true` (retain) semantics. If it follows Kotlin's remove-on-`true`
/// contract, this predicate is inverted.
pub(crate) fn check_and_overwrite_unused_records_locked() {
    EXTRA_STATE_OBJECTS.with(|cell| {
        cell.borrow_mut().remove_if(|state| {
            // Returns true to keep, false to remove
            state.overwrite_unused_records()
        });
    });
}
516
517/// Process a state object for unused record cleanup, tracking it if needed.
518///
519/// Mirrors Kotlin's `processForUnusedRecordsLocked()`. After a state is modified:
520/// 1. Calls `overwrite_unused_records()` to clean up old records
521/// 2. If the state has multiple records, adds it to `EXTRA_STATE_OBJECTS` for future cleanup
522#[allow(dead_code)]
523pub(crate) fn process_for_unused_records_locked(state: &Arc<dyn crate::state::StateObject>) {
524    if state.overwrite_unused_records() {
525        // State has multiple records - track it for future cleanup
526        EXTRA_STATE_OBJECTS.with(|cell| {
527            cell.borrow_mut().add_trait_object(state);
528        });
529    }
530}
531
/// Pre-compute merged records for state objects written by an applying snapshot.
///
/// For each `(object id, state, writer id)` in `modified_objects`:
/// - `current` is the record readable at `current_snapshot_id` under
///   `invalid_snapshots`.
/// - `previous` is the record returned by `mutable::find_previous_record` for
///   `base_parent_id`.
/// - `applied` is the record written by `writer_id`.
///
/// When `current` and `previous` are different records, `state.merge_records`
/// is asked to merge `(previous, current, applied)`.
///
/// Returns a map from the address of the `current` record
/// (`Arc::as_ptr(&current) as usize`) to its merged record. Returns `None` when:
/// - `modified_objects` is empty;
/// - nothing needed merging;
/// - `find_previous_record` yields no record;
/// - `find_record_by_id` finds no record for the writer;
/// - `merge_records` returns `None`.
///
/// In the last three cases, merges already computed for earlier objects are
/// discarded.
///
/// NOTE(review): a missing `current` record or a missing base (`!found_base`)
/// only skips that object (`continue`). A missing `previous` record aborts the
/// whole function (`?`). Presumably this mirrors Kotlin's `optimisticMerges`,
/// where the result is only a cache and a missing previous record merely
/// skips the object. Confirm the early exit is intended and that callers
/// treat `None` as "no cached merges" rather than as a conflict.
pub(crate) fn optimistic_merges(
    current_snapshot_id: SnapshotId,
    base_parent_id: SnapshotId,
    modified_objects: &[(StateObjectId, Arc<dyn StateObject>, SnapshotId)],
    invalid_snapshots: &SnapshotIdSet,
) -> Option<HashMap<usize, Arc<StateRecord>>> {
    if modified_objects.is_empty() {
        return None;
    }

    // Allocated lazily so the common "nothing to merge" case returns `None`.
    let mut result: Option<HashMap<usize, Arc<StateRecord>>> = None;

    for (_, state, writer_id) in modified_objects.iter() {
        let head = state.first_record();

        // Record visible to the snapshot being applied into. Nothing readable: skip.
        let current = match crate::state::readable_record_for(
            &head,
            current_snapshot_id,
            invalid_snapshots,
        ) {
            Some(record) => record,
            None => continue,
        };

        // Record the applying snapshot started from (relative to `base_parent_id`).
        let (previous_opt, found_base) = mutable::find_previous_record(&head, base_parent_id);
        let previous = previous_opt?;

        if !found_base || previous.snapshot_id() == crate::state::PREEXISTING_SNAPSHOT_ID {
            continue;
        }

        // Nobody else changed the object since the base: no merge needed.
        if Arc::ptr_eq(&current, &previous) {
            continue;
        }

        // The record this snapshot itself wrote.
        let applied = mutable::find_record_by_id(&head, *writer_id)?;

        let merged = state.merge_records(
            Arc::clone(&previous),
            Arc::clone(&current),
            Arc::clone(&applied),
        )?;

        // Keyed by record identity (pointer address), not by value.
        result
            .get_or_insert_with(HashMap::default)
            .insert(Arc::as_ptr(&current) as usize, merged);
    }

    result
}
582
583/// Merge two read observers into one.
584///
585/// # Thread Safety
586/// The resulting Arc-wrapped closure may capture non-Send closures. This is safe
587/// because observers are only invoked on the UI thread where they were created.
588#[allow(clippy::arc_with_non_send_sync)]
589pub fn merge_read_observers(
590    a: Option<ReadObserver>,
591    b: Option<ReadObserver>,
592) -> Option<ReadObserver> {
593    match (a, b) {
594        (None, None) => None,
595        (Some(a), None) => Some(a),
596        (None, Some(b)) => Some(b),
597        (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
598            a(state);
599            b(state);
600        })),
601    }
602}
603
604/// Merge two write observers into one.
605///
606/// # Thread Safety
607/// The resulting Arc-wrapped closure may capture non-Send closures. This is safe
608/// because observers are only invoked on the UI thread where they were created.
609#[allow(clippy::arc_with_non_send_sync)]
610pub fn merge_write_observers(
611    a: Option<WriteObserver>,
612    b: Option<WriteObserver>,
613) -> Option<WriteObserver> {
614    match (a, b) {
615        (None, None) => None,
616        (Some(a), None) => Some(a),
617        (None, Some(b)) => Some(b),
618        (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
619            a(state);
620            b(state);
621        })),
622    }
623}
624
625/// Shared state for all snapshots.
626pub(crate) struct SnapshotState {
627    /// The snapshot ID.
628    pub(crate) id: Cell<SnapshotId>,
629    /// Set of invalid snapshot IDs.
630    pub(crate) invalid: RefCell<SnapshotIdSet>,
631    /// Pin handle to keep this snapshot alive.
632    pub(crate) pin_handle: Cell<PinHandle>,
633    /// Whether this snapshot has been disposed.
634    pub(crate) disposed: Cell<bool>,
635    /// Read observer, if any.
636    pub(crate) read_observer: Option<ReadObserver>,
637    /// Write observer, if any.
638    pub(crate) write_observer: Option<WriteObserver>,
639    /// Modified state objects.
640    #[allow(clippy::type_complexity)]
641    // HashMap value is (Arc, SnapshotId) - reasonable for tracking state
642    pub(crate) modified: RefCell<HashMap<StateObjectId, (Arc<dyn StateObject>, SnapshotId)>>,
643    /// Optional callback invoked once when disposed.
644    on_dispose: RefCell<Option<Box<dyn FnOnce()>>>,
645    /// Whether this snapshot's lifecycle is tracked in the global runtime.
646    runtime_tracked: bool,
647    /// Set of child snapshot ids that are still pending.
648    pending_children: RefCell<HashSet<SnapshotId>>,
649}
650
651impl SnapshotState {
652    pub(crate) fn new(
653        id: SnapshotId,
654        invalid: SnapshotIdSet,
655        read_observer: Option<ReadObserver>,
656        write_observer: Option<WriteObserver>,
657        runtime_tracked: bool,
658    ) -> Self {
659        let pin_handle = snapshot_pinning::track_pinning(id, &invalid);
660        Self {
661            id: Cell::new(id),
662            invalid: RefCell::new(invalid),
663            pin_handle: Cell::new(pin_handle),
664            disposed: Cell::new(false),
665            read_observer,
666            write_observer,
667            modified: RefCell::new(HashMap::default()),
668            on_dispose: RefCell::new(None),
669            runtime_tracked,
670            pending_children: RefCell::new(HashSet::default()),
671        }
672    }
673
674    pub(crate) fn record_read(&self, state: &dyn StateObject) {
675        if let Some(ref observer) = self.read_observer {
676            observer(state);
677        }
678    }
679
680    pub(crate) fn record_write(&self, state: Arc<dyn StateObject>, writer_id: SnapshotId) {
681        // Get the unique ID for this state object
682        let state_id = state.object_id().as_usize();
683
684        let mut modified = self.modified.borrow_mut();
685
686        // Only call observer on first write
687        match modified.entry(state_id) {
688            std::collections::hash_map::Entry::Vacant(e) => {
689                if let Some(ref observer) = self.write_observer {
690                    observer(&*state);
691                }
692                // Store the Arc and writer id in the modified set
693                e.insert((state, writer_id));
694            }
695            std::collections::hash_map::Entry::Occupied(mut e) => {
696                // Update the writer id to reflect the most recent writer for this state.
697                e.insert((state, writer_id));
698            }
699        }
700    }
701
702    pub(crate) fn dispose(&self) {
703        if !self.disposed.replace(true) {
704            let pin_handle = self.pin_handle.get();
705            snapshot_pinning::release_pinning(pin_handle);
706            if let Some(cb) = self.on_dispose.borrow_mut().take() {
707                cb();
708            }
709            if self.runtime_tracked {
710                close_snapshot(self.id.get());
711            }
712        }
713    }
714
715    pub(crate) fn add_pending_child(&self, id: SnapshotId) {
716        self.pending_children.borrow_mut().insert(id);
717    }
718
719    pub(crate) fn remove_pending_child(&self, id: SnapshotId) {
720        self.pending_children.borrow_mut().remove(&id);
721    }
722
723    pub(crate) fn has_pending_children(&self) -> bool {
724        !self.pending_children.borrow().is_empty()
725    }
726
727    pub(crate) fn pending_children(&self) -> Vec<SnapshotId> {
728        self.pending_children.borrow().iter().copied().collect()
729    }
730
731    pub(crate) fn set_on_dispose<F>(&self, f: F)
732    where
733        F: FnOnce() + 'static,
734    {
735        *self.on_dispose.borrow_mut() = Some(Box::new(f));
736    }
737}
738
739#[cfg(test)]
740mod tests {
741    use super::*;
742
743    #[test]
744    fn test_apply_result_is_success() {
745        assert!(SnapshotApplyResult::Success.is_success());
746        assert!(!SnapshotApplyResult::Failure.is_success());
747    }
748
749    #[test]
750    fn test_apply_result_is_failure() {
751        assert!(!SnapshotApplyResult::Success.is_failure());
752        assert!(SnapshotApplyResult::Failure.is_failure());
753    }
754
755    #[test]
756    fn test_apply_result_check_success() {
757        SnapshotApplyResult::Success.check(); // Should not panic
758    }
759
760    #[test]
761    #[should_panic(expected = "Snapshot apply failed")]
762    fn test_apply_result_check_failure() {
763        SnapshotApplyResult::Failure.check(); // Should panic
764    }
765
766    #[test]
767    fn test_merge_read_observers_both_none() {
768        let result = merge_read_observers(None, None);
769        assert!(result.is_none());
770    }
771
772    #[test]
773    fn test_merge_read_observers_one_some() {
774        let observer = Arc::new(|_: &dyn StateObject| {});
775        let result = merge_read_observers(Some(observer.clone()), None);
776        assert!(result.is_some());
777
778        let result = merge_read_observers(None, Some(observer));
779        assert!(result.is_some());
780    }
781
782    #[test]
783    fn test_merge_write_observers_both_none() {
784        let result = merge_write_observers(None, None);
785        assert!(result.is_none());
786    }
787
788    #[test]
789    fn test_merge_write_observers_one_some() {
790        let observer = Arc::new(|_: &dyn StateObject| {});
791        let result = merge_write_observers(Some(observer.clone()), None);
792        assert!(result.is_some());
793
794        let result = merge_write_observers(None, Some(observer));
795        assert!(result.is_some());
796    }
797
798    #[test]
799    fn test_current_snapshot_none_initially() {
800        set_current_snapshot(None);
801        assert!(current_snapshot().is_none());
802    }
803
804    // Test helper: Simple state object for testing
805    struct TestStateObject {
806        id: usize,
807    }
808
809    impl TestStateObject {
810        fn new(id: usize) -> Arc<Self> {
811            Arc::new(Self { id })
812        }
813    }
814
815    impl StateObject for TestStateObject {
816        fn object_id(&self) -> crate::state::ObjectId {
817            crate::state::ObjectId(self.id)
818        }
819
820        fn first_record(&self) -> Arc<crate::state::StateRecord> {
821            unimplemented!("Not needed for observer tests")
822        }
823
824        fn readable_record(
825            &self,
826            _snapshot_id: SnapshotId,
827            _invalid: &SnapshotIdSet,
828        ) -> Arc<crate::state::StateRecord> {
829            unimplemented!("Not needed for observer tests")
830        }
831
832        fn prepend_state_record(&self, _record: Arc<crate::state::StateRecord>) {
833            unimplemented!("Not needed for observer tests")
834        }
835
836        fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
837            unimplemented!("Not needed for observer tests")
838        }
839
840        fn as_any(&self) -> &dyn std::any::Any {
841            self
842        }
843    }
844
845    #[test]
846    fn test_apply_observer_receives_correct_modified_objects() {
847        use std::sync::Mutex;
848
849        // Setup: Track what the observer receives
850        let received_count = Arc::new(Mutex::new(0));
851        let received_snapshot_id = Arc::new(Mutex::new(0));
852
853        let received_count_clone = received_count.clone();
854        let received_snapshot_id_clone = received_snapshot_id.clone();
855
856        // Register observer
857        let _handle = register_apply_observer(Arc::new(move |modified, snapshot_id| {
858            *received_snapshot_id_clone.lock().unwrap() = snapshot_id;
859            *received_count_clone.lock().unwrap() = modified.len();
860        }));
861
862        // Create test objects
863        let obj1: Arc<dyn StateObject> = TestStateObject::new(42);
864        let obj2: Arc<dyn StateObject> = TestStateObject::new(99);
865        let modified = vec![obj1, obj2];
866
867        // Notify observers
868        notify_apply_observers(&modified, 123);
869
870        // Verify
871        assert_eq!(*received_snapshot_id.lock().unwrap(), 123);
872        assert_eq!(*received_count.lock().unwrap(), 2);
873    }
874
875    #[test]
876    fn test_apply_observer_receives_correct_snapshot_id() {
877        use std::sync::Mutex;
878
879        let received_id = Arc::new(Mutex::new(0));
880        let received_id_clone = received_id.clone();
881
882        let _handle = register_apply_observer(Arc::new(move |_, snapshot_id| {
883            *received_id_clone.lock().unwrap() = snapshot_id;
884        }));
885
886        // Notify with specific snapshot ID
887        notify_apply_observers(&[], 456);
888
889        assert_eq!(*received_id.lock().unwrap(), 456);
890    }
891
892    #[test]
893    fn test_multiple_apply_observers_all_called() {
894        use std::sync::Mutex;
895
896        let call_count1 = Arc::new(Mutex::new(0));
897        let call_count2 = Arc::new(Mutex::new(0));
898        let call_count3 = Arc::new(Mutex::new(0));
899
900        let call_count1_clone = call_count1.clone();
901        let call_count2_clone = call_count2.clone();
902        let call_count3_clone = call_count3.clone();
903
904        // Register three observers
905        let _handle1 = register_apply_observer(Arc::new(move |_, _| {
906            *call_count1_clone.lock().unwrap() += 1;
907        }));
908
909        let _handle2 = register_apply_observer(Arc::new(move |_, _| {
910            *call_count2_clone.lock().unwrap() += 1;
911        }));
912
913        let _handle3 = register_apply_observer(Arc::new(move |_, _| {
914            *call_count3_clone.lock().unwrap() += 1;
915        }));
916
917        // Notify observers
918        notify_apply_observers(&[], 1);
919
920        // All three should have been called
921        assert_eq!(*call_count1.lock().unwrap(), 1);
922        assert_eq!(*call_count2.lock().unwrap(), 1);
923        assert_eq!(*call_count3.lock().unwrap(), 1);
924
925        // Notify again
926        notify_apply_observers(&[], 2);
927
928        // All should have been called twice
929        assert_eq!(*call_count1.lock().unwrap(), 2);
930        assert_eq!(*call_count2.lock().unwrap(), 2);
931        assert_eq!(*call_count3.lock().unwrap(), 2);
932    }
933
934    #[test]
935    fn test_apply_observer_not_called_for_empty_modifications() {
936        use std::sync::Mutex;
937
938        let call_count = Arc::new(Mutex::new(0));
939        let call_count_clone = call_count.clone();
940
941        let _handle = register_apply_observer(Arc::new(move |modified, _| {
942            // Observer should still be called, but with empty array
943            *call_count_clone.lock().unwrap() += 1;
944            assert_eq!(modified.len(), 0);
945        }));
946
947        // Notify with no modifications
948        notify_apply_observers(&[], 1);
949
950        // Observer should have been called
951        assert_eq!(*call_count.lock().unwrap(), 1);
952    }
953
954    #[test]
955    fn test_observer_handle_drop_removes_correct_observer() {
956        use std::sync::Mutex;
957
958        // Register three observers that track their IDs
959        let calls = Arc::new(Mutex::new(Vec::new()));
960
961        let calls1 = calls.clone();
962        let handle1 = register_apply_observer(Arc::new(move |_, _| {
963            calls1.lock().unwrap().push(1);
964        }));
965
966        let calls2 = calls.clone();
967        let handle2 = register_apply_observer(Arc::new(move |_, _| {
968            calls2.lock().unwrap().push(2);
969        }));
970
971        let calls3 = calls.clone();
972        let handle3 = register_apply_observer(Arc::new(move |_, _| {
973            calls3.lock().unwrap().push(3);
974        }));
975
976        // All three should be called
977        notify_apply_observers(&[], 1);
978        let result = calls.lock().unwrap().clone();
979        assert_eq!(result.len(), 3);
980        assert!(result.contains(&1));
981        assert!(result.contains(&2));
982        assert!(result.contains(&3));
983        calls.lock().unwrap().clear();
984
985        // Drop handle2 (middle one)
986        drop(handle2);
987
988        // Only 1 and 3 should be called now
989        notify_apply_observers(&[], 2);
990        let result = calls.lock().unwrap().clone();
991        assert_eq!(result.len(), 2);
992        assert!(result.contains(&1));
993        assert!(result.contains(&3));
994        assert!(!result.contains(&2));
995        calls.lock().unwrap().clear();
996
997        // Drop handle1
998        drop(handle1);
999
1000        // Only 3 should be called
1001        notify_apply_observers(&[], 3);
1002        let result = calls.lock().unwrap().clone();
1003        assert_eq!(result.len(), 1);
1004        assert!(result.contains(&3));
1005        calls.lock().unwrap().clear();
1006
1007        // Drop handle3
1008        drop(handle3);
1009
1010        // None should be called
1011        notify_apply_observers(&[], 4);
1012        assert_eq!(calls.lock().unwrap().len(), 0);
1013    }
1014
1015    #[test]
1016    fn test_observer_handle_drop_in_different_orders() {
1017        use std::sync::Mutex;
1018
1019        // Test 1: Drop in reverse order (3, 2, 1)
1020        {
1021            let calls = Arc::new(Mutex::new(Vec::new()));
1022
1023            let calls1 = calls.clone();
1024            let h1 = register_apply_observer(Arc::new(move |_, _| {
1025                calls1.lock().unwrap().push(1);
1026            }));
1027
1028            let calls2 = calls.clone();
1029            let h2 = register_apply_observer(Arc::new(move |_, _| {
1030                calls2.lock().unwrap().push(2);
1031            }));
1032
1033            let calls3 = calls.clone();
1034            let h3 = register_apply_observer(Arc::new(move |_, _| {
1035                calls3.lock().unwrap().push(3);
1036            }));
1037
1038            drop(h3);
1039            notify_apply_observers(&[], 1);
1040            let result = calls.lock().unwrap().clone();
1041            assert!(result.contains(&1) && result.contains(&2) && !result.contains(&3));
1042            calls.lock().unwrap().clear();
1043
1044            drop(h2);
1045            notify_apply_observers(&[], 2);
1046            let result = calls.lock().unwrap().clone();
1047            assert_eq!(result.len(), 1);
1048            assert!(result.contains(&1));
1049            calls.lock().unwrap().clear();
1050
1051            drop(h1);
1052            notify_apply_observers(&[], 3);
1053            assert_eq!(calls.lock().unwrap().len(), 0);
1054        }
1055
1056        // Test 2: Drop in forward order (1, 2, 3)
1057        {
1058            let calls = Arc::new(Mutex::new(Vec::new()));
1059
1060            let calls1 = calls.clone();
1061            let h1 = register_apply_observer(Arc::new(move |_, _| {
1062                calls1.lock().unwrap().push(1);
1063            }));
1064
1065            let calls2 = calls.clone();
1066            let h2 = register_apply_observer(Arc::new(move |_, _| {
1067                calls2.lock().unwrap().push(2);
1068            }));
1069
1070            let calls3 = calls.clone();
1071            let h3 = register_apply_observer(Arc::new(move |_, _| {
1072                calls3.lock().unwrap().push(3);
1073            }));
1074
1075            drop(h1);
1076            notify_apply_observers(&[], 1);
1077            let result = calls.lock().unwrap().clone();
1078            assert!(!result.contains(&1) && result.contains(&2) && result.contains(&3));
1079            calls.lock().unwrap().clear();
1080
1081            drop(h2);
1082            notify_apply_observers(&[], 2);
1083            let result = calls.lock().unwrap().clone();
1084            assert_eq!(result.len(), 1);
1085            assert!(result.contains(&3));
1086            calls.lock().unwrap().clear();
1087
1088            drop(h3);
1089            notify_apply_observers(&[], 3);
1090            assert_eq!(calls.lock().unwrap().len(), 0);
1091        }
1092    }
1093
1094    #[test]
1095    fn test_remaining_observers_still_work_after_drop() {
1096        use std::sync::Mutex;
1097
1098        let calls = Arc::new(Mutex::new(Vec::new()));
1099
1100        let calls1 = calls.clone();
1101        let handle1 = register_apply_observer(Arc::new(move |_, snapshot_id| {
1102            calls1.lock().unwrap().push((1, snapshot_id));
1103        }));
1104
1105        let calls2 = calls.clone();
1106        let handle2 = register_apply_observer(Arc::new(move |_, snapshot_id| {
1107            calls2.lock().unwrap().push((2, snapshot_id));
1108        }));
1109
1110        // Both work
1111        notify_apply_observers(&[], 100);
1112        assert_eq!(calls.lock().unwrap().len(), 2);
1113        calls.lock().unwrap().clear();
1114
1115        // Drop handle1
1116        drop(handle1);
1117
1118        // handle2 still works with new snapshot ID
1119        notify_apply_observers(&[], 200);
1120        assert_eq!(*calls.lock().unwrap(), vec![(2, 200)]);
1121        calls.lock().unwrap().clear();
1122
1123        // Register new observer after dropping handle1
1124        let calls3 = calls.clone();
1125        let _handle3 = register_apply_observer(Arc::new(move |_, snapshot_id| {
1126            calls3.lock().unwrap().push((3, snapshot_id));
1127        }));
1128
1129        // Both handle2 and handle3 work
1130        notify_apply_observers(&[], 300);
1131        let result = calls.lock().unwrap().clone();
1132        assert_eq!(result.len(), 2);
1133        assert!(result.contains(&(2, 300)));
1134        assert!(result.contains(&(3, 300)));
1135
1136        drop(handle2);
1137    }
1138
1139    #[test]
1140    fn test_observer_ids_are_unique() {
1141        use std::sync::Mutex;
1142
1143        let ids = Arc::new(Mutex::new(std::collections::HashSet::new()));
1144
1145        let mut handles = Vec::new();
1146
1147        // Register 100 observers and track their IDs through side channel
1148        // Since we can't directly access the ID from the handle, we'll verify
1149        // uniqueness by ensuring all observers get called
1150        for i in 0..100 {
1151            let ids_clone = ids.clone();
1152            let handle = register_apply_observer(Arc::new(move |_, _| {
1153                ids_clone.lock().unwrap().insert(i);
1154            }));
1155            handles.push(handle);
1156        }
1157
1158        // Notify once - all 100 should be called
1159        notify_apply_observers(&[], 1);
1160        assert_eq!(ids.lock().unwrap().len(), 100);
1161
1162        // Drop every other handle
1163        for i in (0..100).step_by(2) {
1164            handles.remove(i / 2);
1165        }
1166
1167        // Clear and notify again - only 50 should be called
1168        ids.lock().unwrap().clear();
1169        notify_apply_observers(&[], 2);
1170        assert_eq!(ids.lock().unwrap().len(), 50);
1171    }
1172
1173    #[test]
1174    fn test_state_object_storage_in_modified_set() {
1175        use crate::state::StateObject;
1176        use std::cell::Cell;
1177
1178        // Mock StateObject for testing
1179        #[allow(dead_code)]
1180        struct TestState {
1181            value: Cell<i32>,
1182        }
1183
1184        impl StateObject for TestState {
1185            fn object_id(&self) -> crate::state::ObjectId {
1186                crate::state::ObjectId(12345)
1187            }
1188
1189            fn first_record(&self) -> Arc<crate::state::StateRecord> {
1190                unimplemented!("Not needed for this test")
1191            }
1192
1193            fn readable_record(
1194                &self,
1195                _snapshot_id: SnapshotId,
1196                _invalid: &SnapshotIdSet,
1197            ) -> Arc<crate::state::StateRecord> {
1198                unimplemented!("Not needed for this test")
1199            }
1200
1201            fn prepend_state_record(&self, _record: Arc<crate::state::StateRecord>) {
1202                unimplemented!("Not needed for this test")
1203            }
1204
1205            fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
1206                unimplemented!("Not needed for this test")
1207            }
1208
1209            fn as_any(&self) -> &dyn std::any::Any {
1210                self
1211            }
1212        }
1213
1214        let state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);
1215
1216        // Create Arc to state object
1217        let state_obj = Arc::new(TestState {
1218            value: Cell::new(42),
1219        }) as Arc<dyn StateObject>;
1220
1221        // Record write should store the Arc
1222        state.record_write(state_obj.clone(), 1);
1223
1224        // Verify it was stored in the modified set
1225        let modified = state.modified.borrow();
1226        assert_eq!(modified.len(), 1);
1227        assert!(modified.contains_key(&12345));
1228
1229        // Verify the Arc is the same object
1230        let (stored, writer_id) = modified.get(&12345).unwrap();
1231        assert_eq!(stored.object_id().as_usize(), 12345);
1232        assert_eq!(*writer_id, 1);
1233    }
1234
1235    #[test]
1236    fn test_multiple_writes_to_same_state_object() {
1237        use crate::state::StateObject;
1238        use std::cell::Cell;
1239
1240        #[allow(dead_code)]
1241        struct TestState {
1242            value: Cell<i32>,
1243        }
1244
1245        impl StateObject for TestState {
1246            fn object_id(&self) -> crate::state::ObjectId {
1247                crate::state::ObjectId(99999)
1248            }
1249
1250            fn first_record(&self) -> Arc<crate::state::StateRecord> {
1251                unimplemented!()
1252            }
1253
1254            fn readable_record(
1255                &self,
1256                _snapshot_id: SnapshotId,
1257                _invalid: &SnapshotIdSet,
1258            ) -> Arc<crate::state::StateRecord> {
1259                unimplemented!()
1260            }
1261
1262            fn prepend_state_record(&self, _record: Arc<crate::state::StateRecord>) {
1263                unimplemented!()
1264            }
1265
1266            fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
1267                unimplemented!()
1268            }
1269
1270            fn as_any(&self) -> &dyn std::any::Any {
1271                self
1272            }
1273        }
1274
1275        let state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);
1276        let state_obj = Arc::new(TestState {
1277            value: Cell::new(100),
1278        }) as Arc<dyn StateObject>;
1279
1280        // First write
1281        state.record_write(state_obj.clone(), 1);
1282        assert_eq!(state.modified.borrow().len(), 1);
1283
1284        // Second write to same object should not add a new entry but updates writer id
1285        state.record_write(state_obj.clone(), 2);
1286        let modified = state.modified.borrow();
1287        assert_eq!(modified.len(), 1);
1288        assert!(modified.contains_key(&99999));
1289        let (_, writer_id) = modified.get(&99999).unwrap();
1290        assert_eq!(*writer_id, 2);
1291    }
1292}