1#![allow(clippy::arc_with_non_send_sync)]
24
25use crate::collections::map::HashMap;
26use crate::collections::map::HashSet;
27use crate::snapshot_id_set::{SnapshotId, SnapshotIdSet};
28use crate::snapshot_pinning::{self, PinHandle};
29use crate::snapshot_weak_set::SnapshotWeakSetDebugStats;
30use crate::state::{StateObject, StateRecord};
31use std::cell::{Cell, RefCell};
32use std::rc::Rc;
33use std::sync::atomic::{AtomicUsize, Ordering};
34use std::sync::{Arc, Weak};
35
36mod global;
37mod mutable;
38mod nested;
39mod readonly;
40mod runtime;
41mod transparent;
42
43#[cfg(test)]
44mod integration_tests;
45
46pub use global::{advance_global_snapshot, GlobalSnapshot};
47pub use mutable::MutableSnapshot;
48pub use nested::{NestedMutableSnapshot, NestedReadonlySnapshot};
49pub use readonly::ReadonlySnapshot;
50pub use transparent::{TransparentObserverMutableSnapshot, TransparentObserverSnapshot};
51
52pub(crate) use runtime::{allocate_snapshot, close_snapshot, with_runtime};
53
54#[cfg(test)]
55pub(crate) use runtime::{reset_runtime_for_tests, TestRuntimeGuard};
56
57pub type ReadObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;
59
60pub type WriteObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;
62
63pub type ApplyObserver = Rc<dyn Fn(&[Arc<dyn StateObject>], SnapshotId) + 'static>;
65
/// Outcome of applying a mutable snapshot.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SnapshotApplyResult {
    /// The apply succeeded.
    Success,
    /// The apply failed.
    Failure,
}

impl SnapshotApplyResult {
    /// Returns `true` only for [`SnapshotApplyResult::Success`].
    pub fn is_success(&self) -> bool {
        *self == Self::Success
    }

    /// Returns `true` only for [`SnapshotApplyResult::Failure`].
    pub fn is_failure(&self) -> bool {
        *self == Self::Failure
    }

    /// Asserts that the apply succeeded.
    ///
    /// # Panics
    ///
    /// Panics with `"Snapshot apply failed"` (reported at the caller's location) on `Failure`.
    #[track_caller]
    pub fn check(&self) {
        assert!(!self.is_failure(), "Snapshot apply failed");
    }
}

/// Numeric identity of a state object, as used for the keys of per-object maps.
pub type StateObjectId = usize;
97
98#[derive(Clone)]
103pub enum AnySnapshot {
104 Readonly(Arc<ReadonlySnapshot>),
105 Mutable(Arc<MutableSnapshot>),
106 NestedReadonly(Arc<NestedReadonlySnapshot>),
107 NestedMutable(Arc<NestedMutableSnapshot>),
108 Global(Arc<GlobalSnapshot>),
109 TransparentMutable(Arc<TransparentObserverMutableSnapshot>),
110 TransparentReadonly(Arc<TransparentObserverSnapshot>),
111}
112
113#[derive(Clone)]
120pub enum AnyMutableSnapshot {
121 Root(Arc<MutableSnapshot>),
122 Nested(Arc<NestedMutableSnapshot>),
123}
124
125impl AnyMutableSnapshot {
126 pub fn snapshot_id(&self) -> SnapshotId {
128 match self {
129 AnyMutableSnapshot::Root(s) => s.snapshot_id(),
130 AnyMutableSnapshot::Nested(s) => s.snapshot_id(),
131 }
132 }
133
134 pub fn invalid(&self) -> SnapshotIdSet {
136 match self {
137 AnyMutableSnapshot::Root(s) => s.invalid(),
138 AnyMutableSnapshot::Nested(s) => s.invalid(),
139 }
140 }
141
142 pub fn enter<T>(&self, f: impl FnOnce() -> T) -> T {
144 match self {
145 AnyMutableSnapshot::Root(s) => s.enter(f),
146 AnyMutableSnapshot::Nested(s) => s.enter(f),
147 }
148 }
149
150 pub fn apply(&self) -> SnapshotApplyResult {
152 match self {
153 AnyMutableSnapshot::Root(s) => s.apply(),
154 AnyMutableSnapshot::Nested(s) => s.apply(),
155 }
156 }
157
158 pub fn dispose(&self) {
160 match self {
161 AnyMutableSnapshot::Root(s) => s.dispose(),
162 AnyMutableSnapshot::Nested(s) => s.dispose(),
163 }
164 }
165}
166
167impl AnySnapshot {
168 pub fn snapshot_id(&self) -> SnapshotId {
170 match self {
171 AnySnapshot::Readonly(s) => s.snapshot_id(),
172 AnySnapshot::Mutable(s) => s.snapshot_id(),
173 AnySnapshot::NestedReadonly(s) => s.snapshot_id(),
174 AnySnapshot::NestedMutable(s) => s.snapshot_id(),
175 AnySnapshot::Global(s) => s.snapshot_id(),
176 AnySnapshot::TransparentMutable(s) => s.snapshot_id(),
177 AnySnapshot::TransparentReadonly(s) => s.snapshot_id(),
178 }
179 }
180
181 pub fn invalid(&self) -> SnapshotIdSet {
183 match self {
184 AnySnapshot::Readonly(s) => s.invalid(),
185 AnySnapshot::Mutable(s) => s.invalid(),
186 AnySnapshot::NestedReadonly(s) => s.invalid(),
187 AnySnapshot::NestedMutable(s) => s.invalid(),
188 AnySnapshot::Global(s) => s.invalid(),
189 AnySnapshot::TransparentMutable(s) => s.invalid(),
190 AnySnapshot::TransparentReadonly(s) => s.invalid(),
191 }
192 }
193
194 pub fn is_valid(&self, id: SnapshotId) -> bool {
196 let snapshot_id = self.snapshot_id();
197 id <= snapshot_id && !self.invalid().get(id)
198 }
199
200 pub fn read_only(&self) -> bool {
202 match self {
203 AnySnapshot::Readonly(_) => true,
204 AnySnapshot::Mutable(_) => false,
205 AnySnapshot::NestedReadonly(_) => true,
206 AnySnapshot::NestedMutable(_) => false,
207 AnySnapshot::Global(_) => false,
208 AnySnapshot::TransparentMutable(_) => false,
209 AnySnapshot::TransparentReadonly(_) => true,
210 }
211 }
212
213 pub fn root(&self) -> AnySnapshot {
215 match self {
216 AnySnapshot::Readonly(s) => AnySnapshot::Readonly(s.root_readonly()),
217 AnySnapshot::Mutable(s) => AnySnapshot::Mutable(s.root_mutable()),
218 AnySnapshot::NestedReadonly(s) => AnySnapshot::NestedReadonly(s.root_nested_readonly()),
219 AnySnapshot::NestedMutable(s) => AnySnapshot::Mutable(s.root_mutable()),
220 AnySnapshot::Global(s) => AnySnapshot::Global(s.root_global()),
221 AnySnapshot::TransparentMutable(s) => {
222 AnySnapshot::TransparentMutable(s.root_transparent_mutable())
223 }
224 AnySnapshot::TransparentReadonly(s) => {
225 AnySnapshot::TransparentReadonly(s.root_transparent_readonly())
226 }
227 }
228 }
229
230 pub fn is_same_transparent(&self, other: &Arc<TransparentObserverMutableSnapshot>) -> bool {
232 matches!(self, AnySnapshot::TransparentMutable(snapshot) if Arc::ptr_eq(snapshot, other))
233 }
234
235 pub fn is_same_transparent_mutable(
237 &self,
238 other: &Arc<TransparentObserverMutableSnapshot>,
239 ) -> bool {
240 self.is_same_transparent(other)
241 }
242
243 pub fn is_same_transparent_readonly(&self, other: &Arc<TransparentObserverSnapshot>) -> bool {
245 matches!(self, AnySnapshot::TransparentReadonly(snapshot) if Arc::ptr_eq(snapshot, other))
246 }
247
248 pub fn enter<T>(&self, f: impl FnOnce() -> T) -> T {
250 match self {
251 AnySnapshot::Readonly(s) => s.enter(f),
252 AnySnapshot::Mutable(s) => s.enter(f),
253 AnySnapshot::NestedReadonly(s) => s.enter(f),
254 AnySnapshot::NestedMutable(s) => s.enter(f),
255 AnySnapshot::Global(s) => s.enter(f),
256 AnySnapshot::TransparentMutable(s) => s.enter(f),
257 AnySnapshot::TransparentReadonly(s) => s.enter(f),
258 }
259 }
260
261 pub fn take_nested_snapshot(&self, read_observer: Option<ReadObserver>) -> AnySnapshot {
263 match self {
264 AnySnapshot::Readonly(s) => {
265 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
266 }
267 AnySnapshot::Mutable(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
268 AnySnapshot::NestedReadonly(s) => {
269 AnySnapshot::NestedReadonly(s.take_nested_snapshot(read_observer))
270 }
271 AnySnapshot::NestedMutable(s) => {
272 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
273 }
274 AnySnapshot::Global(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
275 AnySnapshot::TransparentMutable(s) => {
276 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
277 }
278 AnySnapshot::TransparentReadonly(s) => {
279 AnySnapshot::TransparentReadonly(s.take_nested_snapshot(read_observer))
280 }
281 }
282 }
283
284 pub fn has_pending_changes(&self) -> bool {
286 match self {
287 AnySnapshot::Readonly(s) => s.has_pending_changes(),
288 AnySnapshot::Mutable(s) => s.has_pending_changes(),
289 AnySnapshot::NestedReadonly(s) => s.has_pending_changes(),
290 AnySnapshot::NestedMutable(s) => s.has_pending_changes(),
291 AnySnapshot::Global(s) => s.has_pending_changes(),
292 AnySnapshot::TransparentMutable(s) => s.has_pending_changes(),
293 AnySnapshot::TransparentReadonly(s) => s.has_pending_changes(),
294 }
295 }
296
297 pub fn dispose(&self) {
299 match self {
300 AnySnapshot::Readonly(s) => s.dispose(),
301 AnySnapshot::Mutable(s) => s.dispose(),
302 AnySnapshot::NestedReadonly(s) => s.dispose(),
303 AnySnapshot::NestedMutable(s) => s.dispose(),
304 AnySnapshot::Global(s) => s.dispose(),
305 AnySnapshot::TransparentMutable(s) => s.dispose(),
306 AnySnapshot::TransparentReadonly(s) => s.dispose(),
307 }
308 }
309
310 pub fn is_disposed(&self) -> bool {
312 match self {
313 AnySnapshot::Readonly(s) => s.is_disposed(),
314 AnySnapshot::Mutable(s) => s.is_disposed(),
315 AnySnapshot::NestedReadonly(s) => s.is_disposed(),
316 AnySnapshot::NestedMutable(s) => s.is_disposed(),
317 AnySnapshot::Global(s) => s.is_disposed(),
318 AnySnapshot::TransparentMutable(s) => s.is_disposed(),
319 AnySnapshot::TransparentReadonly(s) => s.is_disposed(),
320 }
321 }
322
323 pub fn record_read(&self, state: &dyn StateObject) {
325 match self {
326 AnySnapshot::Readonly(s) => s.record_read(state),
327 AnySnapshot::Mutable(s) => s.record_read(state),
328 AnySnapshot::NestedReadonly(s) => s.record_read(state),
329 AnySnapshot::NestedMutable(s) => s.record_read(state),
330 AnySnapshot::Global(s) => s.record_read(state),
331 AnySnapshot::TransparentMutable(s) => s.record_read(state),
332 AnySnapshot::TransparentReadonly(s) => s.record_read(state),
333 }
334 }
335
336 pub fn record_write(&self, state: Arc<dyn StateObject>) {
338 match self {
339 AnySnapshot::Readonly(s) => s.record_write(state),
340 AnySnapshot::Mutable(s) => s.record_write(state),
341 AnySnapshot::NestedReadonly(s) => s.record_write(state),
342 AnySnapshot::NestedMutable(s) => s.record_write(state),
343 AnySnapshot::Global(s) => s.record_write(state),
344 AnySnapshot::TransparentMutable(s) => s.record_write(state),
345 AnySnapshot::TransparentReadonly(s) => s.record_write(state),
346 }
347 }
348
349 pub fn apply(&self) -> SnapshotApplyResult {
351 match self {
352 AnySnapshot::Mutable(s) => s.apply(),
353 AnySnapshot::NestedMutable(s) => s.apply(),
354 AnySnapshot::Global(s) => s.apply(),
355 AnySnapshot::TransparentMutable(s) => s.apply(),
356 _ => panic!("Cannot apply a read-only snapshot"),
357 }
358 }
359
360 pub fn take_nested_mutable_snapshot(
362 &self,
363 read_observer: Option<ReadObserver>,
364 write_observer: Option<WriteObserver>,
365 ) -> AnySnapshot {
366 match self {
367 AnySnapshot::Mutable(s) => AnySnapshot::NestedMutable(
368 s.take_nested_mutable_snapshot(read_observer, write_observer),
369 ),
370 AnySnapshot::NestedMutable(s) => AnySnapshot::NestedMutable(
371 s.take_nested_mutable_snapshot(read_observer, write_observer),
372 ),
373 AnySnapshot::Global(s) => {
374 AnySnapshot::Mutable(s.take_nested_mutable_snapshot(read_observer, write_observer))
375 }
376 AnySnapshot::TransparentMutable(s) => AnySnapshot::TransparentMutable(
377 s.take_nested_mutable_snapshot(read_observer, write_observer),
378 ),
379 _ => panic!("Cannot take nested mutable snapshot from read-only snapshot"),
380 }
381 }
382}
383
384thread_local! {
385 static CURRENT_SNAPSHOT: RefCell<Option<AnySnapshot>> = const { RefCell::new(None) };
387}
388
389pub fn current_snapshot() -> Option<AnySnapshot> {
391 CURRENT_SNAPSHOT
392 .try_with(|cell| cell.borrow().clone())
393 .unwrap_or(None)
394}
395
396pub(crate) fn set_current_snapshot(snapshot: Option<AnySnapshot>) {
398 let _ = CURRENT_SNAPSHOT.try_with(|cell| {
399 *cell.borrow_mut() = snapshot;
400 });
401}
402
403pub fn take_mutable_snapshot(
412 read_observer: Option<ReadObserver>,
413 write_observer: Option<WriteObserver>,
414) -> AnyMutableSnapshot {
415 match current_snapshot() {
418 Some(AnySnapshot::Mutable(parent)) => AnyMutableSnapshot::Nested(
419 parent.take_nested_mutable_snapshot(read_observer, write_observer),
420 ),
421 Some(AnySnapshot::NestedMutable(parent)) => AnyMutableSnapshot::Nested(
422 parent.take_nested_mutable_snapshot(read_observer, write_observer),
423 ),
424 _ => AnyMutableSnapshot::Root(
426 GlobalSnapshot::get_or_create()
427 .take_nested_mutable_snapshot(read_observer, write_observer),
428 ),
429 }
430}
431
432pub fn take_transparent_observer_mutable_snapshot(
441 read_observer: Option<ReadObserver>,
442 write_observer: Option<WriteObserver>,
443) -> Arc<TransparentObserverMutableSnapshot> {
444 let parent = current_snapshot();
445 match parent {
446 Some(AnySnapshot::TransparentMutable(transparent)) if transparent.can_reuse() => {
447 transparent
449 }
450 _ => {
451 let current = current_snapshot()
454 .unwrap_or_else(|| AnySnapshot::Global(GlobalSnapshot::get_or_create()));
455 let id = current.snapshot_id();
456 let invalid = current.invalid();
457 TransparentObserverMutableSnapshot::new(
458 id,
459 invalid,
460 read_observer,
461 write_observer,
462 None,
463 )
464 }
465 }
466}
467
468pub fn allocate_record_id() -> SnapshotId {
470 runtime::allocate_record_id()
471}
472
473pub(crate) fn peek_next_snapshot_id() -> SnapshotId {
478 runtime::peek_next_snapshot_id()
479}
480
481static NEXT_OBSERVER_ID: AtomicUsize = AtomicUsize::new(1);
483
484thread_local! {
485 static APPLY_OBSERVERS: RefCell<HashMap<usize, ApplyObserver>> = RefCell::new(HashMap::default());
487}
488
489thread_local! {
490 static LAST_WRITES: RefCell<HashMap<StateObjectId, SnapshotId>> = RefCell::new(HashMap::default());
498}
499
500thread_local! {
501 static EXTRA_STATE_OBJECTS: RefCell<crate::snapshot_weak_set::SnapshotWeakSet> = RefCell::new(crate::snapshot_weak_set::SnapshotWeakSet::new());
504}
505
506const UNUSED_RECORD_CLEANUP_INTERVAL: SnapshotId = 2;
507const UNUSED_RECORD_CLEANUP_BUSY_INTERVAL: SnapshotId = 1;
508const UNUSED_RECORD_CLEANUP_MIN_SIZE: usize = 64;
509
510thread_local! {
511 static LAST_UNUSED_RECORD_CLEANUP: Cell<SnapshotId> = const { Cell::new(0) };
512}
513
514#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
515pub struct SnapshotV2DebugStats {
516 pub apply_observers_len: usize,
517 pub apply_observers_cap: usize,
518 pub last_writes_len: usize,
519 pub last_writes_cap: usize,
520 pub extra_state_objects_len: usize,
521 pub extra_state_objects_cap: usize,
522 pub last_unused_record_cleanup: SnapshotId,
523}
524
525pub fn debug_snapshot_v2_stats() -> SnapshotV2DebugStats {
526 let (apply_observers_len, apply_observers_cap) = APPLY_OBSERVERS.with(|cell| {
527 let observers = cell.borrow();
528 (observers.len(), observers.capacity())
529 });
530 let (last_writes_len, last_writes_cap) = LAST_WRITES.with(|cell| {
531 let writes = cell.borrow();
532 (writes.len(), writes.capacity())
533 });
534 let SnapshotWeakSetDebugStats {
535 len: extra_state_objects_len,
536 capacity: extra_state_objects_cap,
537 } = EXTRA_STATE_OBJECTS.with(|cell| cell.borrow().debug_stats());
538 let last_unused_record_cleanup = LAST_UNUSED_RECORD_CLEANUP.with(|cell| cell.get());
539
540 SnapshotV2DebugStats {
541 apply_observers_len,
542 apply_observers_cap,
543 last_writes_len,
544 last_writes_cap,
545 extra_state_objects_len,
546 extra_state_objects_cap,
547 last_unused_record_cleanup,
548 }
549}
550
551pub fn register_apply_observer(observer: ApplyObserver) -> ObserverHandle {
555 let id = NEXT_OBSERVER_ID.fetch_add(1, Ordering::SeqCst);
556 APPLY_OBSERVERS.with(|cell| {
557 cell.borrow_mut().insert(id, observer);
558 });
559 ObserverHandle {
560 kind: ObserverKind::Apply,
561 id,
562 }
563}
564
565pub struct ObserverHandle {
569 kind: ObserverKind,
570 id: usize,
571}
572
573enum ObserverKind {
574 Apply,
575}
576
577impl Drop for ObserverHandle {
578 fn drop(&mut self) {
579 match self.kind {
580 ObserverKind::Apply => {
581 APPLY_OBSERVERS.with(|cell| {
582 cell.borrow_mut().remove(&self.id);
583 });
584 }
585 }
586 }
587}
588
589pub(crate) fn notify_apply_observers(modified: &[Arc<dyn StateObject>], snapshot_id: SnapshotId) {
591 APPLY_OBSERVERS.with(|cell| {
593 let observers: Vec<ApplyObserver> = cell.borrow().values().cloned().collect();
594 for observer in observers.into_iter() {
595 observer(modified, snapshot_id);
596 }
597 });
598}
599
600pub(crate) fn set_last_write(id: StateObjectId, snapshot_id: SnapshotId) {
602 LAST_WRITES.with(|cell| {
603 cell.borrow_mut().insert(id, snapshot_id);
604 });
605}
606
607#[cfg(test)]
609pub(crate) fn clear_last_writes() {
610 LAST_WRITES.with(|cell| {
611 cell.borrow_mut().clear();
612 });
613}
614
615pub(crate) fn check_and_overwrite_unused_records_locked() {
623 EXTRA_STATE_OBJECTS.with(|cell| {
624 cell.borrow_mut().remove_if(|state| {
625 state.overwrite_unused_records()
627 });
628 });
629}
630
631pub(crate) fn maybe_check_and_overwrite_unused_records_locked(current_snapshot_id: SnapshotId) {
632 let should_run = EXTRA_STATE_OBJECTS.with(|cell| {
633 let set = cell.borrow();
634 if set.is_empty() {
635 return false;
636 }
637 let last_cleanup = LAST_UNUSED_RECORD_CLEANUP.with(|last| last.get());
638 let interval = if set.len() >= UNUSED_RECORD_CLEANUP_MIN_SIZE {
639 UNUSED_RECORD_CLEANUP_BUSY_INTERVAL
640 } else {
641 UNUSED_RECORD_CLEANUP_INTERVAL
642 };
643 current_snapshot_id.saturating_sub(last_cleanup) >= interval
644 });
645
646 if should_run {
647 LAST_UNUSED_RECORD_CLEANUP.with(|cell| cell.set(current_snapshot_id));
648 check_and_overwrite_unused_records_locked();
649 }
650}
651
652#[cfg(test)]
653pub(crate) fn clear_unused_record_cleanup_for_tests() {
654 LAST_UNUSED_RECORD_CLEANUP.with(|cell| cell.set(0));
655}
656
657pub(crate) fn optimistic_merges(
658 current_snapshot_id: SnapshotId,
659 base_parent_id: SnapshotId,
660 modified_objects: &[(StateObjectId, Arc<dyn StateObject>, SnapshotId)],
661 invalid_snapshots: &SnapshotIdSet,
662 applying_invalid: &SnapshotIdSet,
663) -> Option<HashMap<usize, Rc<StateRecord>>> {
664 if modified_objects.is_empty() {
665 return None;
666 }
667
668 let mut result: Option<HashMap<usize, Rc<StateRecord>>> = None;
669
670 for (_, state, writer_id) in modified_objects.iter() {
671 let head = state.first_record();
672
673 let current = match crate::state::readable_record_for(
674 &head,
675 current_snapshot_id,
676 invalid_snapshots,
677 ) {
678 Some(record) => record,
679 None => continue,
680 };
681
682 let (previous_opt, found_base) =
684 mutable::find_previous_record(&head, base_parent_id, applying_invalid);
685 let previous = previous_opt?;
686
687 if !found_base || previous.snapshot_id() == crate::state::PREEXISTING_SNAPSHOT_ID {
688 continue;
689 }
690
691 if Rc::ptr_eq(¤t, &previous) {
692 continue;
693 }
694
695 let applied = mutable::find_record_by_id(&head, *writer_id)?;
696
697 let merged = state.merge_records(
698 Rc::clone(&previous),
699 Rc::clone(¤t),
700 Rc::clone(&applied),
701 )?;
702
703 result
704 .get_or_insert_with(HashMap::default)
705 .insert(Rc::as_ptr(¤t) as usize, merged);
706 }
707
708 result
709}
710
711#[allow(clippy::arc_with_non_send_sync)]
717pub fn merge_read_observers(
718 a: Option<ReadObserver>,
719 b: Option<ReadObserver>,
720) -> Option<ReadObserver> {
721 match (a, b) {
722 (None, None) => None,
723 (Some(a), None) => Some(a),
724 (None, Some(b)) => Some(b),
725 (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
726 a(state);
727 b(state);
728 })),
729 }
730}
731
732#[allow(clippy::arc_with_non_send_sync)]
738pub fn merge_write_observers(
739 a: Option<WriteObserver>,
740 b: Option<WriteObserver>,
741) -> Option<WriteObserver> {
742 match (a, b) {
743 (None, None) => None,
744 (Some(a), None) => Some(a),
745 (None, Some(b)) => Some(b),
746 (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
747 a(state);
748 b(state);
749 })),
750 }
751}
752
753pub(crate) struct SnapshotState {
755 pub(crate) id: Cell<SnapshotId>,
757 pub(crate) invalid: RefCell<SnapshotIdSet>,
759 pub(crate) pin_handle: Cell<PinHandle>,
761 pub(crate) disposed: Cell<bool>,
763 pub(crate) read_observer: Option<ReadObserver>,
765 pub(crate) write_observer: Option<WriteObserver>,
767 #[allow(clippy::type_complexity)]
769 pub(crate) modified: RefCell<HashMap<StateObjectId, (Arc<dyn StateObject>, SnapshotId)>>,
771 on_dispose: RefCell<Option<Box<dyn FnOnce()>>>,
773 runtime_tracked: bool,
775 pending_children: RefCell<HashSet<SnapshotId>>,
777}
778
779impl SnapshotState {
780 pub(crate) fn new(
781 id: SnapshotId,
782 invalid: SnapshotIdSet,
783 read_observer: Option<ReadObserver>,
784 write_observer: Option<WriteObserver>,
785 runtime_tracked: bool,
786 ) -> Self {
787 Self::new_with_pinning(
788 id,
789 invalid,
790 read_observer,
791 write_observer,
792 runtime_tracked,
793 true,
794 )
795 }
796
797 pub(crate) fn new_with_pinning(
802 id: SnapshotId,
803 invalid: SnapshotIdSet,
804 read_observer: Option<ReadObserver>,
805 write_observer: Option<WriteObserver>,
806 runtime_tracked: bool,
807 should_pin: bool,
808 ) -> Self {
809 let pin_handle = if should_pin {
810 snapshot_pinning::track_pinning(id, &invalid)
811 } else {
812 snapshot_pinning::PinHandle::INVALID
813 };
814 Self {
815 id: Cell::new(id),
816 invalid: RefCell::new(invalid),
817 pin_handle: Cell::new(pin_handle),
818 disposed: Cell::new(false),
819 read_observer,
820 write_observer,
821 modified: RefCell::new(HashMap::default()),
822 on_dispose: RefCell::new(None),
823 runtime_tracked,
824 pending_children: RefCell::new(HashSet::default()),
825 }
826 }
827
828 pub(crate) fn record_read(&self, state: &dyn StateObject) {
829 if let Some(ref observer) = self.read_observer {
830 observer(state);
831 }
832 }
833
834 pub(crate) fn record_write(&self, state: Arc<dyn StateObject>, writer_id: SnapshotId) {
835 let state_id = state.object_id().as_usize();
837
838 let mut modified = self.modified.borrow_mut();
839
840 match modified.entry(state_id) {
842 std::collections::hash_map::Entry::Vacant(e) => {
843 if let Some(ref observer) = self.write_observer {
844 observer(&*state);
845 }
846 e.insert((state, writer_id));
848 }
849 std::collections::hash_map::Entry::Occupied(mut e) => {
850 e.insert((state, writer_id));
852 }
853 }
854 }
855
856 pub(crate) fn dispose(&self) {
857 if !self.disposed.replace(true) {
858 let pin_handle = self.pin_handle.get();
859 snapshot_pinning::release_pinning(pin_handle);
860 if let Some(cb) = self.on_dispose.borrow_mut().take() {
861 cb();
862 }
863 if self.runtime_tracked {
864 close_snapshot(self.id.get());
865 }
866 }
867 }
868
869 pub(crate) fn add_pending_child(&self, id: SnapshotId) {
870 self.pending_children.borrow_mut().insert(id);
871 }
872
873 pub(crate) fn remove_pending_child(&self, id: SnapshotId) {
874 self.pending_children.borrow_mut().remove(&id);
875 }
876
877 pub(crate) fn has_pending_children(&self) -> bool {
878 !self.pending_children.borrow().is_empty()
879 }
880
881 pub(crate) fn pending_children(&self) -> Vec<SnapshotId> {
882 self.pending_children.borrow().iter().copied().collect()
883 }
884
885 pub(crate) fn set_on_dispose<F>(&self, f: F)
886 where
887 F: FnOnce() + 'static,
888 {
889 *self.on_dispose.borrow_mut() = Some(Box::new(f));
890 }
891}
892
893#[cfg(test)]
894mod tests {
895 use super::*;
896
897 #[test]
898 fn test_apply_result_is_success() {
899 assert!(SnapshotApplyResult::Success.is_success());
900 assert!(!SnapshotApplyResult::Failure.is_success());
901 }
902
903 #[test]
904 fn test_apply_result_is_failure() {
905 assert!(!SnapshotApplyResult::Success.is_failure());
906 assert!(SnapshotApplyResult::Failure.is_failure());
907 }
908
909 #[test]
910 fn test_apply_result_check_success() {
911 SnapshotApplyResult::Success.check(); }
913
914 #[test]
915 #[should_panic(expected = "Snapshot apply failed")]
916 fn test_apply_result_check_failure() {
917 SnapshotApplyResult::Failure.check(); }
919
920 #[test]
921 fn test_merge_read_observers_both_none() {
922 let result = merge_read_observers(None, None);
923 assert!(result.is_none());
924 }
925
926 #[test]
927 fn test_merge_read_observers_one_some() {
928 let observer = Arc::new(|_: &dyn StateObject| {});
929 let result = merge_read_observers(Some(observer.clone()), None);
930 assert!(result.is_some());
931
932 let result = merge_read_observers(None, Some(observer));
933 assert!(result.is_some());
934 }
935
936 #[test]
937 fn test_merge_write_observers_both_none() {
938 let result = merge_write_observers(None, None);
939 assert!(result.is_none());
940 }
941
942 #[test]
943 fn test_merge_write_observers_one_some() {
944 let observer = Arc::new(|_: &dyn StateObject| {});
945 let result = merge_write_observers(Some(observer.clone()), None);
946 assert!(result.is_some());
947
948 let result = merge_write_observers(None, Some(observer));
949 assert!(result.is_some());
950 }
951
952 #[test]
953 fn test_current_snapshot_none_initially() {
954 set_current_snapshot(None);
955 assert!(current_snapshot().is_none());
956 }
957
958 struct TestStateObject {
960 id: usize,
961 }
962
963 impl TestStateObject {
964 fn new(id: usize) -> Arc<Self> {
965 Arc::new(Self { id })
966 }
967 }
968
969 impl StateObject for TestStateObject {
970 fn object_id(&self) -> crate::state::ObjectId {
971 crate::state::ObjectId(self.id)
972 }
973
974 fn first_record(&self) -> Rc<crate::state::StateRecord> {
975 unimplemented!("Not needed for observer tests")
976 }
977
978 fn readable_record(
979 &self,
980 _snapshot_id: SnapshotId,
981 _invalid: &SnapshotIdSet,
982 ) -> Rc<crate::state::StateRecord> {
983 unimplemented!("Not needed for observer tests")
984 }
985
986 fn prepend_state_record(&self, _record: Rc<crate::state::StateRecord>) {
987 unimplemented!("Not needed for observer tests")
988 }
989
990 fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
991 unimplemented!("Not needed for observer tests")
992 }
993
994 fn as_any(&self) -> &dyn std::any::Any {
995 self
996 }
997 }
998
999 #[test]
1000 fn test_apply_observer_receives_correct_modified_objects() {
1001 use std::sync::Mutex;
1002
1003 let received_count = Arc::new(Mutex::new(0));
1005 let received_snapshot_id = Arc::new(Mutex::new(0));
1006
1007 let received_count_clone = received_count.clone();
1008 let received_snapshot_id_clone = received_snapshot_id.clone();
1009
1010 let _handle = register_apply_observer(Rc::new(move |modified, snapshot_id| {
1012 *received_snapshot_id_clone.lock().unwrap() = snapshot_id;
1013 *received_count_clone.lock().unwrap() = modified.len();
1014 }));
1015
1016 let obj1: Arc<dyn StateObject> = TestStateObject::new(42);
1018 let obj2: Arc<dyn StateObject> = TestStateObject::new(99);
1019 let modified = vec![obj1, obj2];
1020
1021 notify_apply_observers(&modified, 123);
1023
1024 assert_eq!(*received_snapshot_id.lock().unwrap(), 123);
1026 assert_eq!(*received_count.lock().unwrap(), 2);
1027 }
1028
1029 #[test]
1030 fn test_apply_observer_receives_correct_snapshot_id() {
1031 use std::sync::Mutex;
1032
1033 let received_id = Arc::new(Mutex::new(0));
1034 let received_id_clone = received_id.clone();
1035
1036 let _handle = register_apply_observer(Rc::new(move |_, snapshot_id| {
1037 *received_id_clone.lock().unwrap() = snapshot_id;
1038 }));
1039
1040 notify_apply_observers(&[], 456);
1042
1043 assert_eq!(*received_id.lock().unwrap(), 456);
1044 }
1045
1046 #[test]
1047 fn test_multiple_apply_observers_all_called() {
1048 use std::sync::Mutex;
1049
1050 let call_count1 = Arc::new(Mutex::new(0));
1051 let call_count2 = Arc::new(Mutex::new(0));
1052 let call_count3 = Arc::new(Mutex::new(0));
1053
1054 let call_count1_clone = call_count1.clone();
1055 let call_count2_clone = call_count2.clone();
1056 let call_count3_clone = call_count3.clone();
1057
1058 let _handle1 = register_apply_observer(Rc::new(move |_, _| {
1060 *call_count1_clone.lock().unwrap() += 1;
1061 }));
1062
1063 let _handle2 = register_apply_observer(Rc::new(move |_, _| {
1064 *call_count2_clone.lock().unwrap() += 1;
1065 }));
1066
1067 let _handle3 = register_apply_observer(Rc::new(move |_, _| {
1068 *call_count3_clone.lock().unwrap() += 1;
1069 }));
1070
1071 notify_apply_observers(&[], 1);
1073
1074 assert_eq!(*call_count1.lock().unwrap(), 1);
1076 assert_eq!(*call_count2.lock().unwrap(), 1);
1077 assert_eq!(*call_count3.lock().unwrap(), 1);
1078
1079 notify_apply_observers(&[], 2);
1081
1082 assert_eq!(*call_count1.lock().unwrap(), 2);
1084 assert_eq!(*call_count2.lock().unwrap(), 2);
1085 assert_eq!(*call_count3.lock().unwrap(), 2);
1086 }
1087
1088 #[test]
1089 fn test_apply_observer_not_called_for_empty_modifications() {
1090 use std::sync::Mutex;
1091
1092 let call_count = Arc::new(Mutex::new(0));
1093 let call_count_clone = call_count.clone();
1094
1095 let _handle = register_apply_observer(Rc::new(move |modified, _| {
1096 *call_count_clone.lock().unwrap() += 1;
1098 assert_eq!(modified.len(), 0);
1099 }));
1100
1101 notify_apply_observers(&[], 1);
1103
1104 assert_eq!(*call_count.lock().unwrap(), 1);
1106 }
1107
1108 #[test]
1109 fn test_observer_handle_drop_removes_correct_observer() {
1110 use std::sync::Mutex;
1111
1112 let calls = Arc::new(Mutex::new(Vec::new()));
1114
1115 let calls1 = calls.clone();
1116 let handle1 = register_apply_observer(Rc::new(move |_, _| {
1117 calls1.lock().unwrap().push(1);
1118 }));
1119
1120 let calls2 = calls.clone();
1121 let handle2 = register_apply_observer(Rc::new(move |_, _| {
1122 calls2.lock().unwrap().push(2);
1123 }));
1124
1125 let calls3 = calls.clone();
1126 let handle3 = register_apply_observer(Rc::new(move |_, _| {
1127 calls3.lock().unwrap().push(3);
1128 }));
1129
1130 notify_apply_observers(&[], 1);
1132 let result = calls.lock().unwrap().clone();
1133 assert_eq!(result.len(), 3);
1134 assert!(result.contains(&1));
1135 assert!(result.contains(&2));
1136 assert!(result.contains(&3));
1137 calls.lock().unwrap().clear();
1138
1139 drop(handle2);
1141
1142 notify_apply_observers(&[], 2);
1144 let result = calls.lock().unwrap().clone();
1145 assert_eq!(result.len(), 2);
1146 assert!(result.contains(&1));
1147 assert!(result.contains(&3));
1148 assert!(!result.contains(&2));
1149 calls.lock().unwrap().clear();
1150
1151 drop(handle1);
1153
1154 notify_apply_observers(&[], 3);
1156 let result = calls.lock().unwrap().clone();
1157 assert_eq!(result.len(), 1);
1158 assert!(result.contains(&3));
1159 calls.lock().unwrap().clear();
1160
1161 drop(handle3);
1163
1164 notify_apply_observers(&[], 4);
1166 assert_eq!(calls.lock().unwrap().len(), 0);
1167 }
1168
1169 #[test]
1170 fn test_observer_handle_drop_in_different_orders() {
1171 use std::sync::Mutex;
1172
1173 {
1175 let calls = Arc::new(Mutex::new(Vec::new()));
1176
1177 let calls1 = calls.clone();
1178 let h1 = register_apply_observer(Rc::new(move |_, _| {
1179 calls1.lock().unwrap().push(1);
1180 }));
1181
1182 let calls2 = calls.clone();
1183 let h2 = register_apply_observer(Rc::new(move |_, _| {
1184 calls2.lock().unwrap().push(2);
1185 }));
1186
1187 let calls3 = calls.clone();
1188 let h3 = register_apply_observer(Rc::new(move |_, _| {
1189 calls3.lock().unwrap().push(3);
1190 }));
1191
1192 drop(h3);
1193 notify_apply_observers(&[], 1);
1194 let result = calls.lock().unwrap().clone();
1195 assert!(result.contains(&1) && result.contains(&2) && !result.contains(&3));
1196 calls.lock().unwrap().clear();
1197
1198 drop(h2);
1199 notify_apply_observers(&[], 2);
1200 let result = calls.lock().unwrap().clone();
1201 assert_eq!(result.len(), 1);
1202 assert!(result.contains(&1));
1203 calls.lock().unwrap().clear();
1204
1205 drop(h1);
1206 notify_apply_observers(&[], 3);
1207 assert_eq!(calls.lock().unwrap().len(), 0);
1208 }
1209
1210 {
1212 let calls = Arc::new(Mutex::new(Vec::new()));
1213
1214 let calls1 = calls.clone();
1215 let h1 = register_apply_observer(Rc::new(move |_, _| {
1216 calls1.lock().unwrap().push(1);
1217 }));
1218
1219 let calls2 = calls.clone();
1220 let h2 = register_apply_observer(Rc::new(move |_, _| {
1221 calls2.lock().unwrap().push(2);
1222 }));
1223
1224 let calls3 = calls.clone();
1225 let h3 = register_apply_observer(Rc::new(move |_, _| {
1226 calls3.lock().unwrap().push(3);
1227 }));
1228
1229 drop(h1);
1230 notify_apply_observers(&[], 1);
1231 let result = calls.lock().unwrap().clone();
1232 assert!(!result.contains(&1) && result.contains(&2) && result.contains(&3));
1233 calls.lock().unwrap().clear();
1234
1235 drop(h2);
1236 notify_apply_observers(&[], 2);
1237 let result = calls.lock().unwrap().clone();
1238 assert_eq!(result.len(), 1);
1239 assert!(result.contains(&3));
1240 calls.lock().unwrap().clear();
1241
1242 drop(h3);
1243 notify_apply_observers(&[], 3);
1244 assert_eq!(calls.lock().unwrap().len(), 0);
1245 }
1246 }
1247
#[test]
fn test_remaining_observers_still_work_after_drop() {
    use std::sync::Mutex;

    // Shared log of (observer tag, snapshot id) pairs; every observer appends to it.
    let calls = Arc::new(Mutex::new(Vec::new()));

    // Registers an apply observer that logs `tag` together with the snapshot id it is
    // notified with, and returns the registration handle.
    let register_tagged = |tag: i32| {
        let sink = Arc::clone(&calls);
        register_apply_observer(Rc::new(move |_, snapshot_id| {
            sink.lock().unwrap().push((tag, snapshot_id));
        }))
    };

    let handle1 = register_tagged(1);
    let handle2 = register_tagged(2);

    // Phase 1: both handles alive, so both observers fire.
    notify_apply_observers(&[], 100);
    assert_eq!(calls.lock().unwrap().len(), 2);
    calls.lock().unwrap().clear();

    // Phase 2: after dropping the first handle, only observer 2 may fire.
    drop(handle1);
    notify_apply_observers(&[], 200);
    assert_eq!(*calls.lock().unwrap(), vec![(2, 200)]);
    calls.lock().unwrap().clear();

    // Phase 3: a new registration coexists with the surviving observer.
    let _handle3 = register_tagged(3);
    notify_apply_observers(&[], 300);
    let result = calls.lock().unwrap().clone();
    assert_eq!(result.len(), 2);
    assert!(result.contains(&(2, 300)));
    assert!(result.contains(&(3, 300)));

    drop(handle2);
}
1292
#[test]
fn test_observer_ids_are_unique() {
    use std::sync::Mutex;

    // Each observer inserts its own registration index when notified. The set of indices
    // that actually fire shows whether every registration is kept distinct.
    let ids = Arc::new(Mutex::new(std::collections::HashSet::new()));

    let mut handles = Vec::with_capacity(100);
    for i in 0..100 {
        let ids_clone = Arc::clone(&ids);
        let handle = register_apply_observer(Rc::new(move |_, _| {
            ids_clone.lock().unwrap().insert(i);
        }));
        handles.push(handle);
    }

    notify_apply_observers(&[], 1);
    let all: std::collections::HashSet<_> = (0..100).collect();
    assert_eq!(*ids.lock().unwrap(), all);

    // Drop the handles registered with even indices (0, 2, ..., 98). `Vec::retain` visits
    // elements exactly once in their original order, so a running counter equals each
    // handle's registration index. Discarded handles are dropped inside `retain`.
    let mut position = 0;
    handles.retain(|_| {
        let keep = position % 2 == 1;
        position += 1;
        keep
    });
    assert_eq!(handles.len(), 50);

    ids.lock().unwrap().clear();
    notify_apply_observers(&[], 2);

    // Pin *which* observers survive (the odd-indexed ones), not merely how many.
    let odd: std::collections::HashSet<_> = (1..100).step_by(2).collect();
    assert_eq!(*ids.lock().unwrap(), odd);
}
1326
#[test]
fn test_state_object_storage_in_modified_set() {
    use crate::state::{ObjectId, StateObject, StateRecord};

    // Minimal StateObject stub. Only `object_id` and `as_any` return values; every
    // record-related method panics because this test never calls them.
    struct TestState;

    impl StateObject for TestState {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn object_id(&self) -> ObjectId {
            ObjectId(12345)
        }

        fn first_record(&self) -> Rc<StateRecord> {
            unimplemented!("Not needed for this test")
        }

        fn readable_record(
            &self,
            _snapshot: SnapshotId,
            _invalid_ids: &SnapshotIdSet,
        ) -> Rc<StateRecord> {
            unimplemented!("Not needed for this test")
        }

        fn prepend_state_record(&self, _new_record: Rc<StateRecord>) {
            unimplemented!("Not needed for this test")
        }

        fn promote_record(&self, _child: SnapshotId) -> Result<(), &'static str> {
            unimplemented!("Not needed for this test")
        }
    }

    let state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);
    let state_obj: Arc<dyn StateObject> = Arc::new(TestState);

    // A single write attributed to snapshot 1.
    state.record_write(Arc::clone(&state_obj), 1);

    // The modified map should now hold exactly that object, keyed by its id, together with
    // the id of the snapshot that wrote it.
    let modified = state.modified.borrow();
    assert_eq!(modified.len(), 1);
    assert!(modified.contains_key(&12345));

    let (stored, writer_id) = modified.get(&12345).unwrap();
    assert_eq!(stored.object_id().as_usize(), 12345);
    assert_eq!(*writer_id, 1);
}
1382
#[test]
fn test_multiple_writes_to_same_state_object() {
    use crate::state::{ObjectId, StateObject, StateRecord};

    // Stub whose only meaningful behaviour is reporting a fixed object id. The
    // record-related methods are never exercised and panic if called.
    struct TestState;

    impl StateObject for TestState {
        fn object_id(&self) -> ObjectId {
            ObjectId(99999)
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn first_record(&self) -> Rc<StateRecord> {
            unimplemented!()
        }

        fn readable_record(
            &self,
            _snapshot: SnapshotId,
            _invalid_ids: &SnapshotIdSet,
        ) -> Rc<StateRecord> {
            unimplemented!()
        }

        fn prepend_state_record(&self, _new_record: Rc<StateRecord>) {
            unimplemented!()
        }

        fn promote_record(&self, _child: SnapshotId) -> Result<(), &'static str> {
            unimplemented!()
        }
    }

    let state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);
    let state_obj: Arc<dyn StateObject> = Arc::new(TestState);

    // First write: the object appears in the modified map.
    state.record_write(Arc::clone(&state_obj), 1);
    assert_eq!(state.modified.borrow().len(), 1);

    // Second write to the same object. The test expects a single entry whose recorded
    // writer is now the later snapshot id.
    state.record_write(Arc::clone(&state_obj), 2);
    let modified = state.modified.borrow();
    assert_eq!(modified.len(), 1);
    assert!(modified.contains_key(&99999));
    let (_, writer_id) = modified.get(&99999).unwrap();
    assert_eq!(*writer_id, 2);
}
1434}