1#![allow(clippy::arc_with_non_send_sync)]
24
25use crate::collections::map::HashMap; use crate::collections::map::HashSet;
27use crate::snapshot_id_set::{SnapshotId, SnapshotIdSet};
28use crate::snapshot_pinning::{self, PinHandle};
29use crate::state::{StateObject, StateRecord};
30use std::cell::{Cell, RefCell};
31use std::rc::Rc;
32use std::sync::atomic::{AtomicUsize, Ordering};
33use std::sync::{Arc, Weak};
34
35mod global;
36mod mutable;
37mod nested;
38mod readonly;
39mod runtime;
40mod transparent;
41
42#[cfg(test)]
43mod integration_tests;
44
45pub use global::{advance_global_snapshot, GlobalSnapshot};
46pub use mutable::MutableSnapshot;
47pub use nested::{NestedMutableSnapshot, NestedReadonlySnapshot};
48pub use readonly::ReadonlySnapshot;
49pub use transparent::{TransparentObserverMutableSnapshot, TransparentObserverSnapshot};
50
51pub(crate) use runtime::{allocate_snapshot, close_snapshot, with_runtime};
52
53#[cfg(test)]
54pub(crate) use runtime::{reset_runtime_for_tests, TestRuntimeGuard};
55
/// Callback invoked with a state object that was read.
///
/// Reference-counted so it can be cloned cheaply and composed with another
/// observer (see [`merge_read_observers`]). The trait object carries no
/// `Send`/`Sync` bound, which is why this file allows
/// `clippy::arc_with_non_send_sync`.
pub type ReadObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;

/// Callback invoked with a state object that was written.
///
/// `SnapshotState::record_write` in this module calls it only for the first
/// write of each state object within a snapshot.
pub type WriteObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;

/// Callback invoked with the list of modified state objects and a snapshot id
/// whenever `notify_apply_observers` runs on the registering thread.
pub type ApplyObserver = Arc<dyn Fn(&[Arc<dyn StateObject>], SnapshotId) + 'static>;
64
/// Outcome of applying a snapshot's changes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SnapshotApplyResult {
    /// The apply succeeded.
    Success,
    /// The apply failed; [`SnapshotApplyResult::check`] panics on this value.
    Failure,
}
73
74impl SnapshotApplyResult {
75 pub fn is_success(&self) -> bool {
77 matches!(self, SnapshotApplyResult::Success)
78 }
79
80 pub fn is_failure(&self) -> bool {
82 matches!(self, SnapshotApplyResult::Failure)
83 }
84
85 #[track_caller]
87 pub fn check(&self) {
88 if self.is_failure() {
89 panic!("Snapshot apply failed");
90 }
91 }
92}
93
/// Key identifying a state object in this module's bookkeeping maps.
///
/// `SnapshotState::record_write` derives it from `state.object_id().as_usize()`.
pub type StateObjectId = usize;
96
/// Type-erased handle to any kind of snapshot.
///
/// Every variant wraps an `Arc`, so cloning an `AnySnapshot` only bumps a
/// reference count. The variants `Readonly`, `NestedReadonly` and
/// `TransparentReadonly` are the read-only ones (see [`AnySnapshot::read_only`]).
#[derive(Clone)]
pub enum AnySnapshot {
    /// A read-only snapshot.
    Readonly(Arc<ReadonlySnapshot>),
    /// A mutable snapshot (the `Root` kind of [`AnyMutableSnapshot`]).
    Mutable(Arc<MutableSnapshot>),
    /// A read-only snapshot of the nested kind.
    NestedReadonly(Arc<NestedReadonlySnapshot>),
    /// A mutable snapshot nested inside another mutable snapshot.
    NestedMutable(Arc<NestedMutableSnapshot>),
    /// The global snapshot.
    Global(Arc<GlobalSnapshot>),
    /// A mutable transparent-observer snapshot.
    TransparentMutable(Arc<TransparentObserverMutableSnapshot>),
    /// A read-only transparent-observer snapshot.
    TransparentReadonly(Arc<TransparentObserverSnapshot>),
}
111
/// Handle to a mutable snapshot, as returned by [`take_mutable_snapshot`].
#[derive(Clone)]
pub enum AnyMutableSnapshot {
    /// A mutable snapshot taken from the global snapshot.
    Root(Arc<MutableSnapshot>),
    /// A mutable snapshot nested inside another mutable snapshot.
    Nested(Arc<NestedMutableSnapshot>),
}
123
124impl AnyMutableSnapshot {
125 pub fn snapshot_id(&self) -> SnapshotId {
127 match self {
128 AnyMutableSnapshot::Root(s) => s.snapshot_id(),
129 AnyMutableSnapshot::Nested(s) => s.snapshot_id(),
130 }
131 }
132
133 pub fn invalid(&self) -> SnapshotIdSet {
135 match self {
136 AnyMutableSnapshot::Root(s) => s.invalid(),
137 AnyMutableSnapshot::Nested(s) => s.invalid(),
138 }
139 }
140
141 pub fn enter<T>(&self, f: impl FnOnce() -> T) -> T {
143 match self {
144 AnyMutableSnapshot::Root(s) => s.enter(f),
145 AnyMutableSnapshot::Nested(s) => s.enter(f),
146 }
147 }
148
149 pub fn apply(&self) -> SnapshotApplyResult {
151 match self {
152 AnyMutableSnapshot::Root(s) => s.apply(),
153 AnyMutableSnapshot::Nested(s) => s.apply(),
154 }
155 }
156
157 pub fn dispose(&self) {
159 match self {
160 AnyMutableSnapshot::Root(s) => s.dispose(),
161 AnyMutableSnapshot::Nested(s) => s.dispose(),
162 }
163 }
164}
165
166impl AnySnapshot {
167 pub fn snapshot_id(&self) -> SnapshotId {
169 match self {
170 AnySnapshot::Readonly(s) => s.snapshot_id(),
171 AnySnapshot::Mutable(s) => s.snapshot_id(),
172 AnySnapshot::NestedReadonly(s) => s.snapshot_id(),
173 AnySnapshot::NestedMutable(s) => s.snapshot_id(),
174 AnySnapshot::Global(s) => s.snapshot_id(),
175 AnySnapshot::TransparentMutable(s) => s.snapshot_id(),
176 AnySnapshot::TransparentReadonly(s) => s.snapshot_id(),
177 }
178 }
179
180 pub fn invalid(&self) -> SnapshotIdSet {
182 match self {
183 AnySnapshot::Readonly(s) => s.invalid(),
184 AnySnapshot::Mutable(s) => s.invalid(),
185 AnySnapshot::NestedReadonly(s) => s.invalid(),
186 AnySnapshot::NestedMutable(s) => s.invalid(),
187 AnySnapshot::Global(s) => s.invalid(),
188 AnySnapshot::TransparentMutable(s) => s.invalid(),
189 AnySnapshot::TransparentReadonly(s) => s.invalid(),
190 }
191 }
192
193 pub fn is_valid(&self, id: SnapshotId) -> bool {
195 let snapshot_id = self.snapshot_id();
196 id <= snapshot_id && !self.invalid().get(id)
197 }
198
199 pub fn read_only(&self) -> bool {
201 match self {
202 AnySnapshot::Readonly(_) => true,
203 AnySnapshot::Mutable(_) => false,
204 AnySnapshot::NestedReadonly(_) => true,
205 AnySnapshot::NestedMutable(_) => false,
206 AnySnapshot::Global(_) => false,
207 AnySnapshot::TransparentMutable(_) => false,
208 AnySnapshot::TransparentReadonly(_) => true,
209 }
210 }
211
212 pub fn root(&self) -> AnySnapshot {
214 match self {
215 AnySnapshot::Readonly(s) => AnySnapshot::Readonly(s.root_readonly()),
216 AnySnapshot::Mutable(s) => AnySnapshot::Mutable(s.root_mutable()),
217 AnySnapshot::NestedReadonly(s) => AnySnapshot::NestedReadonly(s.root_nested_readonly()),
218 AnySnapshot::NestedMutable(s) => AnySnapshot::Mutable(s.root_mutable()),
219 AnySnapshot::Global(s) => AnySnapshot::Global(s.root_global()),
220 AnySnapshot::TransparentMutable(s) => {
221 AnySnapshot::TransparentMutable(s.root_transparent_mutable())
222 }
223 AnySnapshot::TransparentReadonly(s) => {
224 AnySnapshot::TransparentReadonly(s.root_transparent_readonly())
225 }
226 }
227 }
228
229 pub fn is_same_transparent(&self, other: &Arc<TransparentObserverMutableSnapshot>) -> bool {
231 matches!(self, AnySnapshot::TransparentMutable(snapshot) if Arc::ptr_eq(snapshot, other))
232 }
233
234 pub fn is_same_transparent_mutable(
236 &self,
237 other: &Arc<TransparentObserverMutableSnapshot>,
238 ) -> bool {
239 self.is_same_transparent(other)
240 }
241
242 pub fn is_same_transparent_readonly(&self, other: &Arc<TransparentObserverSnapshot>) -> bool {
244 matches!(self, AnySnapshot::TransparentReadonly(snapshot) if Arc::ptr_eq(snapshot, other))
245 }
246
247 pub fn enter<T>(&self, f: impl FnOnce() -> T) -> T {
249 match self {
250 AnySnapshot::Readonly(s) => s.enter(f),
251 AnySnapshot::Mutable(s) => s.enter(f),
252 AnySnapshot::NestedReadonly(s) => s.enter(f),
253 AnySnapshot::NestedMutable(s) => s.enter(f),
254 AnySnapshot::Global(s) => s.enter(f),
255 AnySnapshot::TransparentMutable(s) => s.enter(f),
256 AnySnapshot::TransparentReadonly(s) => s.enter(f),
257 }
258 }
259
260 pub fn take_nested_snapshot(&self, read_observer: Option<ReadObserver>) -> AnySnapshot {
262 match self {
263 AnySnapshot::Readonly(s) => {
264 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
265 }
266 AnySnapshot::Mutable(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
267 AnySnapshot::NestedReadonly(s) => {
268 AnySnapshot::NestedReadonly(s.take_nested_snapshot(read_observer))
269 }
270 AnySnapshot::NestedMutable(s) => {
271 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
272 }
273 AnySnapshot::Global(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
274 AnySnapshot::TransparentMutable(s) => {
275 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
276 }
277 AnySnapshot::TransparentReadonly(s) => {
278 AnySnapshot::TransparentReadonly(s.take_nested_snapshot(read_observer))
279 }
280 }
281 }
282
283 pub fn has_pending_changes(&self) -> bool {
285 match self {
286 AnySnapshot::Readonly(s) => s.has_pending_changes(),
287 AnySnapshot::Mutable(s) => s.has_pending_changes(),
288 AnySnapshot::NestedReadonly(s) => s.has_pending_changes(),
289 AnySnapshot::NestedMutable(s) => s.has_pending_changes(),
290 AnySnapshot::Global(s) => s.has_pending_changes(),
291 AnySnapshot::TransparentMutable(s) => s.has_pending_changes(),
292 AnySnapshot::TransparentReadonly(s) => s.has_pending_changes(),
293 }
294 }
295
296 pub fn dispose(&self) {
298 match self {
299 AnySnapshot::Readonly(s) => s.dispose(),
300 AnySnapshot::Mutable(s) => s.dispose(),
301 AnySnapshot::NestedReadonly(s) => s.dispose(),
302 AnySnapshot::NestedMutable(s) => s.dispose(),
303 AnySnapshot::Global(s) => s.dispose(),
304 AnySnapshot::TransparentMutable(s) => s.dispose(),
305 AnySnapshot::TransparentReadonly(s) => s.dispose(),
306 }
307 }
308
309 pub fn is_disposed(&self) -> bool {
311 match self {
312 AnySnapshot::Readonly(s) => s.is_disposed(),
313 AnySnapshot::Mutable(s) => s.is_disposed(),
314 AnySnapshot::NestedReadonly(s) => s.is_disposed(),
315 AnySnapshot::NestedMutable(s) => s.is_disposed(),
316 AnySnapshot::Global(s) => s.is_disposed(),
317 AnySnapshot::TransparentMutable(s) => s.is_disposed(),
318 AnySnapshot::TransparentReadonly(s) => s.is_disposed(),
319 }
320 }
321
322 pub fn record_read(&self, state: &dyn StateObject) {
324 match self {
325 AnySnapshot::Readonly(s) => s.record_read(state),
326 AnySnapshot::Mutable(s) => s.record_read(state),
327 AnySnapshot::NestedReadonly(s) => s.record_read(state),
328 AnySnapshot::NestedMutable(s) => s.record_read(state),
329 AnySnapshot::Global(s) => s.record_read(state),
330 AnySnapshot::TransparentMutable(s) => s.record_read(state),
331 AnySnapshot::TransparentReadonly(s) => s.record_read(state),
332 }
333 }
334
335 pub fn record_write(&self, state: Arc<dyn StateObject>) {
337 match self {
338 AnySnapshot::Readonly(s) => s.record_write(state),
339 AnySnapshot::Mutable(s) => s.record_write(state),
340 AnySnapshot::NestedReadonly(s) => s.record_write(state),
341 AnySnapshot::NestedMutable(s) => s.record_write(state),
342 AnySnapshot::Global(s) => s.record_write(state),
343 AnySnapshot::TransparentMutable(s) => s.record_write(state),
344 AnySnapshot::TransparentReadonly(s) => s.record_write(state),
345 }
346 }
347
348 pub fn apply(&self) -> SnapshotApplyResult {
350 match self {
351 AnySnapshot::Mutable(s) => s.apply(),
352 AnySnapshot::NestedMutable(s) => s.apply(),
353 AnySnapshot::Global(s) => s.apply(),
354 AnySnapshot::TransparentMutable(s) => s.apply(),
355 _ => panic!("Cannot apply a read-only snapshot"),
356 }
357 }
358
359 pub fn take_nested_mutable_snapshot(
361 &self,
362 read_observer: Option<ReadObserver>,
363 write_observer: Option<WriteObserver>,
364 ) -> AnySnapshot {
365 match self {
366 AnySnapshot::Mutable(s) => AnySnapshot::NestedMutable(
367 s.take_nested_mutable_snapshot(read_observer, write_observer),
368 ),
369 AnySnapshot::NestedMutable(s) => AnySnapshot::NestedMutable(
370 s.take_nested_mutable_snapshot(read_observer, write_observer),
371 ),
372 AnySnapshot::Global(s) => {
373 AnySnapshot::Mutable(s.take_nested_mutable_snapshot(read_observer, write_observer))
374 }
375 AnySnapshot::TransparentMutable(s) => AnySnapshot::TransparentMutable(
376 s.take_nested_mutable_snapshot(read_observer, write_observer),
377 ),
378 _ => panic!("Cannot take nested mutable snapshot from read-only snapshot"),
379 }
380 }
381}
382
thread_local! {
    /// The snapshot that is current on this thread, or `None` when none has been set.
    /// Read through [`current_snapshot`] and written through `set_current_snapshot`.
    static CURRENT_SNAPSHOT: RefCell<Option<AnySnapshot>> = const { RefCell::new(None) };
}
387
388pub fn current_snapshot() -> Option<AnySnapshot> {
390 CURRENT_SNAPSHOT
391 .try_with(|cell| cell.borrow().clone())
392 .unwrap_or(None)
393}
394
395pub(crate) fn set_current_snapshot(snapshot: Option<AnySnapshot>) {
397 let _ = CURRENT_SNAPSHOT.try_with(|cell| {
398 *cell.borrow_mut() = snapshot;
399 });
400}
401
402pub fn take_mutable_snapshot(
411 read_observer: Option<ReadObserver>,
412 write_observer: Option<WriteObserver>,
413) -> AnyMutableSnapshot {
414 match current_snapshot() {
417 Some(AnySnapshot::Mutable(parent)) => AnyMutableSnapshot::Nested(
418 parent.take_nested_mutable_snapshot(read_observer, write_observer),
419 ),
420 Some(AnySnapshot::NestedMutable(parent)) => AnyMutableSnapshot::Nested(
421 parent.take_nested_mutable_snapshot(read_observer, write_observer),
422 ),
423 _ => AnyMutableSnapshot::Root(
425 GlobalSnapshot::get_or_create()
426 .take_nested_mutable_snapshot(read_observer, write_observer),
427 ),
428 }
429}
430
431pub fn take_transparent_observer_mutable_snapshot(
440 read_observer: Option<ReadObserver>,
441 write_observer: Option<WriteObserver>,
442) -> Arc<TransparentObserverMutableSnapshot> {
443 let parent = current_snapshot();
444 match parent {
445 Some(AnySnapshot::TransparentMutable(transparent)) if transparent.can_reuse() => {
446 transparent
448 }
449 _ => {
450 let current = current_snapshot()
453 .unwrap_or_else(|| AnySnapshot::Global(GlobalSnapshot::get_or_create()));
454 let id = current.snapshot_id();
455 let invalid = current.invalid();
456 TransparentObserverMutableSnapshot::new(
457 id,
458 invalid,
459 read_observer,
460 write_observer,
461 None,
462 )
463 }
464 }
465}
466
/// Allocates a record id by delegating to `runtime::allocate_record_id`.
pub fn allocate_record_id() -> SnapshotId {
    runtime::allocate_record_id()
}

/// Delegates to `runtime::peek_next_snapshot_id`.
///
/// NOTE(review): presumably returns the next snapshot id without consuming it
/// (as the name suggests) — confirm in `runtime`.
pub(crate) fn peek_next_snapshot_id() -> SnapshotId {
    runtime::peek_next_snapshot_id()
}
479
/// Source of apply-observer registration ids. The counter is process-wide
/// (an atomic static), while the observer table itself is per-thread.
static NEXT_OBSERVER_ID: AtomicUsize = AtomicUsize::new(1);

thread_local! {
    /// Apply observers registered on this thread, keyed by registration id.
    static APPLY_OBSERVERS: RefCell<HashMap<usize, ApplyObserver>> = RefCell::new(HashMap::default());
}

thread_local! {
    /// Snapshot id most recently stored by `set_last_write` for each state object
    /// on this thread.
    static LAST_WRITES: RefCell<HashMap<StateObjectId, SnapshotId>> = RefCell::new(HashMap::default());
}

thread_local! {
    /// State objects queued for unused-record cleanup on this thread; visited by
    /// `check_and_overwrite_unused_records_locked`.
    static EXTRA_STATE_OBJECTS: RefCell<crate::snapshot_weak_set::SnapshotWeakSet> = RefCell::new(crate::snapshot_weak_set::SnapshotWeakSet::new());
}

/// Minimum snapshot-id distance between cleanup passes while fewer than
/// `UNUSED_RECORD_CLEANUP_MIN_SIZE` objects are tracked.
const UNUSED_RECORD_CLEANUP_INTERVAL: SnapshotId = 2;
/// Minimum snapshot-id distance between cleanup passes once at least
/// `UNUSED_RECORD_CLEANUP_MIN_SIZE` objects are tracked.
const UNUSED_RECORD_CLEANUP_BUSY_INTERVAL: SnapshotId = 1;
/// Tracked-set size at which the busy interval takes over.
const UNUSED_RECORD_CLEANUP_MIN_SIZE: usize = 64;

thread_local! {
    /// Snapshot id at which the last cleanup pass ran on this thread (0 initially).
    static LAST_UNUSED_RECORD_CLEANUP: Cell<SnapshotId> = const { Cell::new(0) };
}
512
513pub fn register_apply_observer(observer: ApplyObserver) -> ObserverHandle {
517 let id = NEXT_OBSERVER_ID.fetch_add(1, Ordering::SeqCst);
518 APPLY_OBSERVERS.with(|cell| {
519 cell.borrow_mut().insert(id, observer);
520 });
521 ObserverHandle {
522 kind: ObserverKind::Apply,
523 id,
524 }
525}
526
527pub struct ObserverHandle {
531 kind: ObserverKind,
532 id: usize,
533}
534
535enum ObserverKind {
536 Apply,
537}
538
539impl Drop for ObserverHandle {
540 fn drop(&mut self) {
541 match self.kind {
542 ObserverKind::Apply => {
543 APPLY_OBSERVERS.with(|cell| {
544 cell.borrow_mut().remove(&self.id);
545 });
546 }
547 }
548 }
549}
550
551pub(crate) fn notify_apply_observers(modified: &[Arc<dyn StateObject>], snapshot_id: SnapshotId) {
553 APPLY_OBSERVERS.with(|cell| {
555 let observers: Vec<ApplyObserver> = cell.borrow().values().cloned().collect();
556 for observer in observers.into_iter() {
557 observer(modified, snapshot_id);
558 }
559 });
560}
561
562#[allow(dead_code)]
564pub(crate) fn get_last_write(id: StateObjectId) -> Option<SnapshotId> {
565 LAST_WRITES.with(|cell| cell.borrow().get(&id).copied())
566}
567
568pub(crate) fn set_last_write(id: StateObjectId, snapshot_id: SnapshotId) {
570 LAST_WRITES.with(|cell| {
571 cell.borrow_mut().insert(id, snapshot_id);
572 });
573}
574
575#[cfg(test)]
577pub(crate) fn clear_last_writes() {
578 LAST_WRITES.with(|cell| {
579 cell.borrow_mut().clear();
580 });
581}
582
583pub(crate) fn check_and_overwrite_unused_records_locked() {
591 EXTRA_STATE_OBJECTS.with(|cell| {
592 cell.borrow_mut().remove_if(|state| {
593 state.overwrite_unused_records()
595 });
596 });
597}
598
599pub(crate) fn maybe_check_and_overwrite_unused_records_locked(current_snapshot_id: SnapshotId) {
600 let should_run = EXTRA_STATE_OBJECTS.with(|cell| {
601 let set = cell.borrow();
602 if set.is_empty() {
603 return false;
604 }
605 let last_cleanup = LAST_UNUSED_RECORD_CLEANUP.with(|last| last.get());
606 let interval = if set.len() >= UNUSED_RECORD_CLEANUP_MIN_SIZE {
607 UNUSED_RECORD_CLEANUP_BUSY_INTERVAL
608 } else {
609 UNUSED_RECORD_CLEANUP_INTERVAL
610 };
611 current_snapshot_id.saturating_sub(last_cleanup) >= interval
612 });
613
614 if should_run {
615 LAST_UNUSED_RECORD_CLEANUP.with(|cell| cell.set(current_snapshot_id));
616 check_and_overwrite_unused_records_locked();
617 }
618}
619
620#[allow(dead_code)]
626pub(crate) fn process_for_unused_records_locked(state: &Arc<dyn crate::state::StateObject>) {
627 if state.overwrite_unused_records() {
628 EXTRA_STATE_OBJECTS.with(|cell| {
630 cell.borrow_mut().add_trait_object(state);
631 });
632 }
633}
634
/// Resets this thread's last-cleanup marker to 0 (test builds only), so the next
/// `maybe_check_and_overwrite_unused_records_locked` call measures its interval from 0.
#[cfg(test)]
pub(crate) fn clear_unused_record_cleanup_for_tests() {
    LAST_UNUSED_RECORD_CLEANUP.with(|cell| cell.set(0));
}
639
640pub(crate) fn optimistic_merges(
641 current_snapshot_id: SnapshotId,
642 base_parent_id: SnapshotId,
643 modified_objects: &[(StateObjectId, Arc<dyn StateObject>, SnapshotId)],
644 invalid_snapshots: &SnapshotIdSet,
645 applying_invalid: &SnapshotIdSet,
646) -> Option<HashMap<usize, Rc<StateRecord>>> {
647 if modified_objects.is_empty() {
648 return None;
649 }
650
651 let mut result: Option<HashMap<usize, Rc<StateRecord>>> = None;
652
653 for (_, state, writer_id) in modified_objects.iter() {
654 let head = state.first_record();
655
656 let current = match crate::state::readable_record_for(
657 &head,
658 current_snapshot_id,
659 invalid_snapshots,
660 ) {
661 Some(record) => record,
662 None => continue,
663 };
664
665 let (previous_opt, found_base) =
667 mutable::find_previous_record(&head, base_parent_id, applying_invalid);
668 let previous = previous_opt?;
669
670 if !found_base || previous.snapshot_id() == crate::state::PREEXISTING_SNAPSHOT_ID {
671 continue;
672 }
673
674 if Rc::ptr_eq(¤t, &previous) {
675 continue;
676 }
677
678 let applied = mutable::find_record_by_id(&head, *writer_id)?;
679
680 let merged = state.merge_records(
681 Rc::clone(&previous),
682 Rc::clone(¤t),
683 Rc::clone(&applied),
684 )?;
685
686 result
687 .get_or_insert_with(HashMap::default)
688 .insert(Rc::as_ptr(¤t) as usize, merged);
689 }
690
691 result
692}
693
694#[allow(clippy::arc_with_non_send_sync)]
700pub fn merge_read_observers(
701 a: Option<ReadObserver>,
702 b: Option<ReadObserver>,
703) -> Option<ReadObserver> {
704 match (a, b) {
705 (None, None) => None,
706 (Some(a), None) => Some(a),
707 (None, Some(b)) => Some(b),
708 (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
709 a(state);
710 b(state);
711 })),
712 }
713}
714
715#[allow(clippy::arc_with_non_send_sync)]
721pub fn merge_write_observers(
722 a: Option<WriteObserver>,
723 b: Option<WriteObserver>,
724) -> Option<WriteObserver> {
725 match (a, b) {
726 (None, None) => None,
727 (Some(a), None) => Some(a),
728 (None, Some(b)) => Some(b),
729 (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
730 a(state);
731 b(state);
732 })),
733 }
734}
735
/// Per-snapshot bookkeeping: identity, invalid set, pinning, observers,
/// recorded writes, and disposal state.
pub(crate) struct SnapshotState {
    /// Id of this snapshot. A `Cell`, so it can be updated through `&self`.
    /// `dispose` passes it to `close_snapshot` when `runtime_tracked` is set.
    pub(crate) id: Cell<SnapshotId>,
    /// Invalid snapshot-id set for this snapshot (`AnySnapshot::is_valid` shows how
    /// such a set is consulted).
    pub(crate) invalid: RefCell<SnapshotIdSet>,
    /// Handle from `snapshot_pinning::track_pinning`, or `PinHandle::INVALID` when
    /// pinning was skipped. Released in `dispose`.
    pub(crate) pin_handle: Cell<PinHandle>,
    /// Set by the first `dispose` call; later calls do nothing.
    pub(crate) disposed: Cell<bool>,
    /// Called by `record_read` for every recorded read.
    pub(crate) read_observer: Option<ReadObserver>,
    /// Called by `record_write` the first time each state object is written.
    pub(crate) write_observer: Option<WriteObserver>,
    /// State objects written in this snapshot, keyed by `object_id().as_usize()`,
    /// with the writer id from the latest `record_write` for that object.
    #[allow(clippy::type_complexity)]
    pub(crate) modified: RefCell<HashMap<StateObjectId, (Arc<dyn StateObject>, SnapshotId)>>,
    /// Callback installed by `set_on_dispose`; taken and run once by `dispose`.
    on_dispose: RefCell<Option<Box<dyn FnOnce()>>>,
    /// When `true`, `dispose` also calls `close_snapshot` with this snapshot's id.
    runtime_tracked: bool,
    /// Child snapshot ids added via `add_pending_child` and not yet removed.
    pending_children: RefCell<HashSet<SnapshotId>>,
}
761
762impl SnapshotState {
763 pub(crate) fn new(
764 id: SnapshotId,
765 invalid: SnapshotIdSet,
766 read_observer: Option<ReadObserver>,
767 write_observer: Option<WriteObserver>,
768 runtime_tracked: bool,
769 ) -> Self {
770 Self::new_with_pinning(
771 id,
772 invalid,
773 read_observer,
774 write_observer,
775 runtime_tracked,
776 true,
777 )
778 }
779
780 pub(crate) fn new_with_pinning(
785 id: SnapshotId,
786 invalid: SnapshotIdSet,
787 read_observer: Option<ReadObserver>,
788 write_observer: Option<WriteObserver>,
789 runtime_tracked: bool,
790 should_pin: bool,
791 ) -> Self {
792 let pin_handle = if should_pin {
793 snapshot_pinning::track_pinning(id, &invalid)
794 } else {
795 snapshot_pinning::PinHandle::INVALID
796 };
797 Self {
798 id: Cell::new(id),
799 invalid: RefCell::new(invalid),
800 pin_handle: Cell::new(pin_handle),
801 disposed: Cell::new(false),
802 read_observer,
803 write_observer,
804 modified: RefCell::new(HashMap::default()),
805 on_dispose: RefCell::new(None),
806 runtime_tracked,
807 pending_children: RefCell::new(HashSet::default()),
808 }
809 }
810
811 pub(crate) fn record_read(&self, state: &dyn StateObject) {
812 if let Some(ref observer) = self.read_observer {
813 observer(state);
814 }
815 }
816
817 pub(crate) fn record_write(&self, state: Arc<dyn StateObject>, writer_id: SnapshotId) {
818 let state_id = state.object_id().as_usize();
820
821 let mut modified = self.modified.borrow_mut();
822
823 match modified.entry(state_id) {
825 std::collections::hash_map::Entry::Vacant(e) => {
826 if let Some(ref observer) = self.write_observer {
827 observer(&*state);
828 }
829 e.insert((state, writer_id));
831 }
832 std::collections::hash_map::Entry::Occupied(mut e) => {
833 e.insert((state, writer_id));
835 }
836 }
837 }
838
839 pub(crate) fn dispose(&self) {
840 if !self.disposed.replace(true) {
841 let pin_handle = self.pin_handle.get();
842 snapshot_pinning::release_pinning(pin_handle);
843 if let Some(cb) = self.on_dispose.borrow_mut().take() {
844 cb();
845 }
846 if self.runtime_tracked {
847 close_snapshot(self.id.get());
848 }
849 }
850 }
851
852 pub(crate) fn add_pending_child(&self, id: SnapshotId) {
853 self.pending_children.borrow_mut().insert(id);
854 }
855
856 pub(crate) fn remove_pending_child(&self, id: SnapshotId) {
857 self.pending_children.borrow_mut().remove(&id);
858 }
859
860 pub(crate) fn has_pending_children(&self) -> bool {
861 !self.pending_children.borrow().is_empty()
862 }
863
864 pub(crate) fn pending_children(&self) -> Vec<SnapshotId> {
865 self.pending_children.borrow().iter().copied().collect()
866 }
867
868 pub(crate) fn set_on_dispose<F>(&self, f: F)
869 where
870 F: FnOnce() + 'static,
871 {
872 *self.on_dispose.borrow_mut() = Some(Box::new(f));
873 }
874}
875
876#[cfg(test)]
877mod tests {
878 use super::*;
879
880 #[test]
881 fn test_apply_result_is_success() {
882 assert!(SnapshotApplyResult::Success.is_success());
883 assert!(!SnapshotApplyResult::Failure.is_success());
884 }
885
886 #[test]
887 fn test_apply_result_is_failure() {
888 assert!(!SnapshotApplyResult::Success.is_failure());
889 assert!(SnapshotApplyResult::Failure.is_failure());
890 }
891
892 #[test]
893 fn test_apply_result_check_success() {
894 SnapshotApplyResult::Success.check(); }
896
897 #[test]
898 #[should_panic(expected = "Snapshot apply failed")]
899 fn test_apply_result_check_failure() {
900 SnapshotApplyResult::Failure.check(); }
902
903 #[test]
904 fn test_merge_read_observers_both_none() {
905 let result = merge_read_observers(None, None);
906 assert!(result.is_none());
907 }
908
909 #[test]
910 fn test_merge_read_observers_one_some() {
911 let observer = Arc::new(|_: &dyn StateObject| {});
912 let result = merge_read_observers(Some(observer.clone()), None);
913 assert!(result.is_some());
914
915 let result = merge_read_observers(None, Some(observer));
916 assert!(result.is_some());
917 }
918
919 #[test]
920 fn test_merge_write_observers_both_none() {
921 let result = merge_write_observers(None, None);
922 assert!(result.is_none());
923 }
924
925 #[test]
926 fn test_merge_write_observers_one_some() {
927 let observer = Arc::new(|_: &dyn StateObject| {});
928 let result = merge_write_observers(Some(observer.clone()), None);
929 assert!(result.is_some());
930
931 let result = merge_write_observers(None, Some(observer));
932 assert!(result.is_some());
933 }
934
935 #[test]
936 fn test_current_snapshot_none_initially() {
937 set_current_snapshot(None);
938 assert!(current_snapshot().is_none());
939 }
940
941 struct TestStateObject {
943 id: usize,
944 }
945
946 impl TestStateObject {
947 fn new(id: usize) -> Arc<Self> {
948 Arc::new(Self { id })
949 }
950 }
951
952 impl StateObject for TestStateObject {
953 fn object_id(&self) -> crate::state::ObjectId {
954 crate::state::ObjectId(self.id)
955 }
956
957 fn first_record(&self) -> Rc<crate::state::StateRecord> {
958 unimplemented!("Not needed for observer tests")
959 }
960
961 fn readable_record(
962 &self,
963 _snapshot_id: SnapshotId,
964 _invalid: &SnapshotIdSet,
965 ) -> Rc<crate::state::StateRecord> {
966 unimplemented!("Not needed for observer tests")
967 }
968
969 fn prepend_state_record(&self, _record: Rc<crate::state::StateRecord>) {
970 unimplemented!("Not needed for observer tests")
971 }
972
973 fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
974 unimplemented!("Not needed for observer tests")
975 }
976
977 fn as_any(&self) -> &dyn std::any::Any {
978 self
979 }
980 }
981
982 #[test]
983 fn test_apply_observer_receives_correct_modified_objects() {
984 use std::sync::Mutex;
985
986 let received_count = Arc::new(Mutex::new(0));
988 let received_snapshot_id = Arc::new(Mutex::new(0));
989
990 let received_count_clone = received_count.clone();
991 let received_snapshot_id_clone = received_snapshot_id.clone();
992
993 let _handle = register_apply_observer(Arc::new(move |modified, snapshot_id| {
995 *received_snapshot_id_clone.lock().unwrap() = snapshot_id;
996 *received_count_clone.lock().unwrap() = modified.len();
997 }));
998
999 let obj1: Arc<dyn StateObject> = TestStateObject::new(42);
1001 let obj2: Arc<dyn StateObject> = TestStateObject::new(99);
1002 let modified = vec![obj1, obj2];
1003
1004 notify_apply_observers(&modified, 123);
1006
1007 assert_eq!(*received_snapshot_id.lock().unwrap(), 123);
1009 assert_eq!(*received_count.lock().unwrap(), 2);
1010 }
1011
1012 #[test]
1013 fn test_apply_observer_receives_correct_snapshot_id() {
1014 use std::sync::Mutex;
1015
1016 let received_id = Arc::new(Mutex::new(0));
1017 let received_id_clone = received_id.clone();
1018
1019 let _handle = register_apply_observer(Arc::new(move |_, snapshot_id| {
1020 *received_id_clone.lock().unwrap() = snapshot_id;
1021 }));
1022
1023 notify_apply_observers(&[], 456);
1025
1026 assert_eq!(*received_id.lock().unwrap(), 456);
1027 }
1028
1029 #[test]
1030 fn test_multiple_apply_observers_all_called() {
1031 use std::sync::Mutex;
1032
1033 let call_count1 = Arc::new(Mutex::new(0));
1034 let call_count2 = Arc::new(Mutex::new(0));
1035 let call_count3 = Arc::new(Mutex::new(0));
1036
1037 let call_count1_clone = call_count1.clone();
1038 let call_count2_clone = call_count2.clone();
1039 let call_count3_clone = call_count3.clone();
1040
1041 let _handle1 = register_apply_observer(Arc::new(move |_, _| {
1043 *call_count1_clone.lock().unwrap() += 1;
1044 }));
1045
1046 let _handle2 = register_apply_observer(Arc::new(move |_, _| {
1047 *call_count2_clone.lock().unwrap() += 1;
1048 }));
1049
1050 let _handle3 = register_apply_observer(Arc::new(move |_, _| {
1051 *call_count3_clone.lock().unwrap() += 1;
1052 }));
1053
1054 notify_apply_observers(&[], 1);
1056
1057 assert_eq!(*call_count1.lock().unwrap(), 1);
1059 assert_eq!(*call_count2.lock().unwrap(), 1);
1060 assert_eq!(*call_count3.lock().unwrap(), 1);
1061
1062 notify_apply_observers(&[], 2);
1064
1065 assert_eq!(*call_count1.lock().unwrap(), 2);
1067 assert_eq!(*call_count2.lock().unwrap(), 2);
1068 assert_eq!(*call_count3.lock().unwrap(), 2);
1069 }
1070
1071 #[test]
1072 fn test_apply_observer_not_called_for_empty_modifications() {
1073 use std::sync::Mutex;
1074
1075 let call_count = Arc::new(Mutex::new(0));
1076 let call_count_clone = call_count.clone();
1077
1078 let _handle = register_apply_observer(Arc::new(move |modified, _| {
1079 *call_count_clone.lock().unwrap() += 1;
1081 assert_eq!(modified.len(), 0);
1082 }));
1083
1084 notify_apply_observers(&[], 1);
1086
1087 assert_eq!(*call_count.lock().unwrap(), 1);
1089 }
1090
1091 #[test]
1092 fn test_observer_handle_drop_removes_correct_observer() {
1093 use std::sync::Mutex;
1094
1095 let calls = Arc::new(Mutex::new(Vec::new()));
1097
1098 let calls1 = calls.clone();
1099 let handle1 = register_apply_observer(Arc::new(move |_, _| {
1100 calls1.lock().unwrap().push(1);
1101 }));
1102
1103 let calls2 = calls.clone();
1104 let handle2 = register_apply_observer(Arc::new(move |_, _| {
1105 calls2.lock().unwrap().push(2);
1106 }));
1107
1108 let calls3 = calls.clone();
1109 let handle3 = register_apply_observer(Arc::new(move |_, _| {
1110 calls3.lock().unwrap().push(3);
1111 }));
1112
1113 notify_apply_observers(&[], 1);
1115 let result = calls.lock().unwrap().clone();
1116 assert_eq!(result.len(), 3);
1117 assert!(result.contains(&1));
1118 assert!(result.contains(&2));
1119 assert!(result.contains(&3));
1120 calls.lock().unwrap().clear();
1121
1122 drop(handle2);
1124
1125 notify_apply_observers(&[], 2);
1127 let result = calls.lock().unwrap().clone();
1128 assert_eq!(result.len(), 2);
1129 assert!(result.contains(&1));
1130 assert!(result.contains(&3));
1131 assert!(!result.contains(&2));
1132 calls.lock().unwrap().clear();
1133
1134 drop(handle1);
1136
1137 notify_apply_observers(&[], 3);
1139 let result = calls.lock().unwrap().clone();
1140 assert_eq!(result.len(), 1);
1141 assert!(result.contains(&3));
1142 calls.lock().unwrap().clear();
1143
1144 drop(handle3);
1146
1147 notify_apply_observers(&[], 4);
1149 assert_eq!(calls.lock().unwrap().len(), 0);
1150 }
1151
1152 #[test]
1153 fn test_observer_handle_drop_in_different_orders() {
1154 use std::sync::Mutex;
1155
1156 {
1158 let calls = Arc::new(Mutex::new(Vec::new()));
1159
1160 let calls1 = calls.clone();
1161 let h1 = register_apply_observer(Arc::new(move |_, _| {
1162 calls1.lock().unwrap().push(1);
1163 }));
1164
1165 let calls2 = calls.clone();
1166 let h2 = register_apply_observer(Arc::new(move |_, _| {
1167 calls2.lock().unwrap().push(2);
1168 }));
1169
1170 let calls3 = calls.clone();
1171 let h3 = register_apply_observer(Arc::new(move |_, _| {
1172 calls3.lock().unwrap().push(3);
1173 }));
1174
1175 drop(h3);
1176 notify_apply_observers(&[], 1);
1177 let result = calls.lock().unwrap().clone();
1178 assert!(result.contains(&1) && result.contains(&2) && !result.contains(&3));
1179 calls.lock().unwrap().clear();
1180
1181 drop(h2);
1182 notify_apply_observers(&[], 2);
1183 let result = calls.lock().unwrap().clone();
1184 assert_eq!(result.len(), 1);
1185 assert!(result.contains(&1));
1186 calls.lock().unwrap().clear();
1187
1188 drop(h1);
1189 notify_apply_observers(&[], 3);
1190 assert_eq!(calls.lock().unwrap().len(), 0);
1191 }
1192
1193 {
1195 let calls = Arc::new(Mutex::new(Vec::new()));
1196
1197 let calls1 = calls.clone();
1198 let h1 = register_apply_observer(Arc::new(move |_, _| {
1199 calls1.lock().unwrap().push(1);
1200 }));
1201
1202 let calls2 = calls.clone();
1203 let h2 = register_apply_observer(Arc::new(move |_, _| {
1204 calls2.lock().unwrap().push(2);
1205 }));
1206
1207 let calls3 = calls.clone();
1208 let h3 = register_apply_observer(Arc::new(move |_, _| {
1209 calls3.lock().unwrap().push(3);
1210 }));
1211
1212 drop(h1);
1213 notify_apply_observers(&[], 1);
1214 let result = calls.lock().unwrap().clone();
1215 assert!(!result.contains(&1) && result.contains(&2) && result.contains(&3));
1216 calls.lock().unwrap().clear();
1217
1218 drop(h2);
1219 notify_apply_observers(&[], 2);
1220 let result = calls.lock().unwrap().clone();
1221 assert_eq!(result.len(), 1);
1222 assert!(result.contains(&3));
1223 calls.lock().unwrap().clear();
1224
1225 drop(h3);
1226 notify_apply_observers(&[], 3);
1227 assert_eq!(calls.lock().unwrap().len(), 0);
1228 }
1229 }
1230
/// Dropping one apply-observer handle must unregister only that observer:
/// the remaining observer keeps firing, and an observer registered afterwards
/// fires alongside it.
#[test]
fn test_remaining_observers_still_work_after_drop() {
    use std::sync::Mutex;

    // Shared log of `(observer tag, snapshot id)` pairs pushed by each observer.
    let seen = Arc::new(Mutex::new(Vec::new()));

    let seen_by_first = Arc::clone(&seen);
    let first = register_apply_observer(Arc::new(move |_, snapshot_id| {
        seen_by_first.lock().unwrap().push((1, snapshot_id));
    }));

    let seen_by_second = Arc::clone(&seen);
    let second = register_apply_observer(Arc::new(move |_, snapshot_id| {
        seen_by_second.lock().unwrap().push((2, snapshot_id));
    }));

    // Phase 1: both observers are registered and both fire.
    notify_apply_observers(&[], 100);
    let batch = std::mem::take(&mut *seen.lock().unwrap());
    assert_eq!(batch.len(), 2);

    // Phase 2: unregister only the first observer; the second must still fire.
    drop(first);
    notify_apply_observers(&[], 200);
    let batch = std::mem::take(&mut *seen.lock().unwrap());
    assert_eq!(batch, vec![(2, 200)]);

    // Phase 3: a newly registered observer fires together with the survivor.
    let seen_by_third = Arc::clone(&seen);
    let _third = register_apply_observer(Arc::new(move |_, snapshot_id| {
        seen_by_third.lock().unwrap().push((3, snapshot_id));
    }));

    notify_apply_observers(&[], 300);
    let batch = std::mem::take(&mut *seen.lock().unwrap());
    assert_eq!(batch.len(), 2);
    assert!(batch.contains(&(2, 300)));
    assert!(batch.contains(&(3, 300)));

    drop(second);
}
1275
/// Registers many apply observers and checks that each one gets its own
/// registration. Every observer fires once per notification, and dropping a
/// chosen subset of handles unregisters exactly those observers and no others.
#[test]
fn test_observer_ids_are_unique() {
    use std::sync::Mutex;

    const OBSERVER_COUNT: usize = 100;

    // Each observer inserts its own registration index when notified.
    let fired = Arc::new(Mutex::new(std::collections::HashSet::new()));

    let mut handles: Vec<_> = (0..OBSERVER_COUNT)
        .map(|i| {
            let fired = Arc::clone(&fired);
            register_apply_observer(Arc::new(move |_, _| {
                fired.lock().unwrap().insert(i);
            }))
        })
        .collect();

    notify_apply_observers(&[], 1);
    assert_eq!(fired.lock().unwrap().len(), OBSERVER_COUNT);

    // Drop the handle of every even-indexed observer. `retain` visits elements
    // exactly once and in order, so `index` equals each handle's registration
    // index. Removed handles are dropped, which unregisters their observers.
    let mut index = 0;
    handles.retain(|_| {
        let keep = index % 2 == 1;
        index += 1;
        keep
    });
    assert_eq!(handles.len(), OBSERVER_COUNT / 2);

    fired.lock().unwrap().clear();
    notify_apply_observers(&[], 2);

    // Exactly the odd-indexed observers, whose handles are still alive, fire.
    let expected: std::collections::HashSet<usize> = (1..OBSERVER_COUNT).step_by(2).collect();
    assert_eq!(*fired.lock().unwrap(), expected);
}
1309
/// Recording a write must store the state object in the snapshot's `modified`
/// map under its object id, paired with the id of the writer.
#[test]
fn test_state_object_storage_in_modified_set() {
    use crate::state::StateObject;

    // Minimal stub: only `object_id` and `as_any` are implemented; the
    // record-related methods are not needed by this test.
    struct TestState;

    impl StateObject for TestState {
        fn object_id(&self) -> crate::state::ObjectId {
            crate::state::ObjectId(12345)
        }

        fn first_record(&self) -> Rc<crate::state::StateRecord> {
            unimplemented!("Not needed for this test")
        }

        fn readable_record(
            &self,
            _snapshot_id: SnapshotId,
            _invalid: &SnapshotIdSet,
        ) -> Rc<crate::state::StateRecord> {
            unimplemented!("Not needed for this test")
        }

        fn prepend_state_record(&self, _record: Rc<crate::state::StateRecord>) {
            unimplemented!("Not needed for this test")
        }

        fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
            unimplemented!("Not needed for this test")
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    let state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);
    let state_obj: Arc<dyn StateObject> = Arc::new(TestState);

    // The local handle is not used again, so ownership moves into the snapshot.
    state.record_write(state_obj, 1);

    let modified = state.modified.borrow();
    assert_eq!(modified.len(), 1);
    assert!(modified.contains_key(&12345));

    let (stored, writer_id) = modified.get(&12345).unwrap();
    assert_eq!(stored.object_id().as_usize(), 12345);
    assert_eq!(*writer_id, 1);
}
1371
/// Writing the same state object twice must keep a single `modified` entry
/// whose writer id is updated to the most recent writer.
#[test]
fn test_multiple_writes_to_same_state_object() {
    use crate::state::StateObject;

    // Minimal stub: only `object_id` and `as_any` are implemented; the
    // record-related methods are not needed by this test.
    struct TestState;

    impl StateObject for TestState {
        fn object_id(&self) -> crate::state::ObjectId {
            crate::state::ObjectId(99999)
        }

        fn first_record(&self) -> Rc<crate::state::StateRecord> {
            unimplemented!()
        }

        fn readable_record(
            &self,
            _snapshot_id: SnapshotId,
            _invalid: &SnapshotIdSet,
        ) -> Rc<crate::state::StateRecord> {
            unimplemented!()
        }

        fn prepend_state_record(&self, _record: Rc<crate::state::StateRecord>) {
            unimplemented!()
        }

        fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    let state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);
    let state_obj: Arc<dyn StateObject> = Arc::new(TestState);

    state.record_write(Arc::clone(&state_obj), 1);
    assert_eq!(state.modified.borrow().len(), 1);

    // Last use of the local handle: move it rather than cloning again.
    state.record_write(state_obj, 2);

    let modified = state.modified.borrow();
    assert_eq!(modified.len(), 1);
    assert!(modified.contains_key(&99999));
    let (_, writer_id) = modified.get(&99999).unwrap();
    assert_eq!(*writer_id, 2);
}
1429}