1#![allow(clippy::arc_with_non_send_sync)]
24
25use crate::collections::map::HashMap;
26use crate::collections::map::HashSet;
27use crate::snapshot_id_set::{SnapshotId, SnapshotIdSet};
28use crate::snapshot_pinning::{self, PinHandle};
29use crate::snapshot_weak_set::SnapshotWeakSetDebugStats;
30use crate::state::{StateObject, StateRecord};
31use std::cell::{Cell, RefCell};
32use std::hash::{Hash, Hasher};
33use std::rc::Rc;
34use std::sync::{Arc, Weak};
35
36mod global;
37mod mutable;
38mod nested;
39mod readonly;
40mod runtime;
41mod transparent;
42
43#[cfg(test)]
44mod integration_tests;
45
46pub use global::{advance_global_snapshot, GlobalSnapshot};
47pub use mutable::MutableSnapshot;
48pub use nested::{NestedMutableSnapshot, NestedReadonlySnapshot};
49pub use readonly::ReadonlySnapshot;
50pub use transparent::{TransparentObserverMutableSnapshot, TransparentObserverSnapshot};
51
52pub(crate) use runtime::{allocate_snapshot, close_snapshot, with_runtime};
53
54#[cfg(test)]
55pub(crate) use runtime::{reset_runtime_for_tests, TestRuntimeGuard};
56
/// Callback invoked with the state object that was read.
///
/// The closure type carries no `Send`/`Sync` bound, so the `Arc` wrapping it is not
/// `Send` either; this is why the module allows `clippy::arc_with_non_send_sync`.
pub type ReadObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;

/// Callback invoked with the state object that was written.
///
/// Same shape as [`ReadObserver`]: an `Arc` around a closure without `Send`/`Sync` bounds.
pub type WriteObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;

/// Callback invoked by `notify_apply_observers` with the modified state objects and the
/// snapshot id passed to that function.
pub type ApplyObserver = Rc<dyn Fn(&[Arc<dyn StateObject>], SnapshotId) + 'static>;
65
/// Outcome of applying a snapshot.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SnapshotApplyResult {
    /// The apply went through.
    Success,
    /// The apply did not go through.
    Failure,
}

impl SnapshotApplyResult {
    /// Returns `true` for [`SnapshotApplyResult::Success`].
    pub fn is_success(&self) -> bool {
        *self == Self::Success
    }

    /// Returns `true` for [`SnapshotApplyResult::Failure`].
    pub fn is_failure(&self) -> bool {
        *self == Self::Failure
    }

    /// Asserts that the apply succeeded.
    ///
    /// # Panics
    ///
    /// Panics with `"Snapshot apply failed"` when `self` is `Failure`. Because of
    /// `#[track_caller]` the panic location is the caller's.
    #[track_caller]
    pub fn check(&self) {
        match self {
            Self::Success => {}
            Self::Failure => panic!("Snapshot apply failed"),
        }
    }
}
94
/// Plain-integer identifier of a state object.
///
/// `SnapshotState::record_write` derives it from `state.object_id().as_usize()`.
pub type StateObjectId = usize;
97
/// A shared handle to any of the concrete snapshot types.
///
/// Cloning clones the inner `Arc`, not the snapshot. The methods on this enum forward to
/// the wrapped snapshot.
#[derive(Clone)]
pub enum AnySnapshot {
    /// Wraps a [`ReadonlySnapshot`].
    Readonly(Arc<ReadonlySnapshot>),
    /// Wraps a [`MutableSnapshot`].
    Mutable(Arc<MutableSnapshot>),
    /// Wraps a [`NestedReadonlySnapshot`].
    NestedReadonly(Arc<NestedReadonlySnapshot>),
    /// Wraps a [`NestedMutableSnapshot`].
    NestedMutable(Arc<NestedMutableSnapshot>),
    /// Wraps the [`GlobalSnapshot`].
    Global(Arc<GlobalSnapshot>),
    /// Wraps a [`TransparentObserverMutableSnapshot`].
    TransparentMutable(Arc<TransparentObserverMutableSnapshot>),
    /// Wraps a [`TransparentObserverSnapshot`].
    TransparentReadonly(Arc<TransparentObserverSnapshot>),
}
112
/// A shared handle to a mutable snapshot, either root-level or nested.
///
/// This is the type returned by `take_mutable_snapshot`.
#[derive(Clone)]
pub enum AnyMutableSnapshot {
    /// Wraps a [`MutableSnapshot`].
    Root(Arc<MutableSnapshot>),
    /// Wraps a [`NestedMutableSnapshot`].
    Nested(Arc<NestedMutableSnapshot>),
}
124
125impl AnyMutableSnapshot {
126 pub fn snapshot_id(&self) -> SnapshotId {
128 match self {
129 AnyMutableSnapshot::Root(s) => s.snapshot_id(),
130 AnyMutableSnapshot::Nested(s) => s.snapshot_id(),
131 }
132 }
133
134 pub fn invalid(&self) -> SnapshotIdSet {
136 match self {
137 AnyMutableSnapshot::Root(s) => s.invalid(),
138 AnyMutableSnapshot::Nested(s) => s.invalid(),
139 }
140 }
141
142 pub fn enter<T>(&self, f: impl FnOnce() -> T) -> T {
144 match self {
145 AnyMutableSnapshot::Root(s) => s.enter(f),
146 AnyMutableSnapshot::Nested(s) => s.enter(f),
147 }
148 }
149
150 pub fn apply(&self) -> SnapshotApplyResult {
152 match self {
153 AnyMutableSnapshot::Root(s) => s.apply(),
154 AnyMutableSnapshot::Nested(s) => s.apply(),
155 }
156 }
157
158 pub fn dispose(&self) {
160 match self {
161 AnyMutableSnapshot::Root(s) => s.dispose(),
162 AnyMutableSnapshot::Nested(s) => s.dispose(),
163 }
164 }
165}
166
167impl AnySnapshot {
168 pub fn snapshot_id(&self) -> SnapshotId {
170 match self {
171 AnySnapshot::Readonly(s) => s.snapshot_id(),
172 AnySnapshot::Mutable(s) => s.snapshot_id(),
173 AnySnapshot::NestedReadonly(s) => s.snapshot_id(),
174 AnySnapshot::NestedMutable(s) => s.snapshot_id(),
175 AnySnapshot::Global(s) => s.snapshot_id(),
176 AnySnapshot::TransparentMutable(s) => s.snapshot_id(),
177 AnySnapshot::TransparentReadonly(s) => s.snapshot_id(),
178 }
179 }
180
181 pub fn invalid(&self) -> SnapshotIdSet {
183 match self {
184 AnySnapshot::Readonly(s) => s.invalid(),
185 AnySnapshot::Mutable(s) => s.invalid(),
186 AnySnapshot::NestedReadonly(s) => s.invalid(),
187 AnySnapshot::NestedMutable(s) => s.invalid(),
188 AnySnapshot::Global(s) => s.invalid(),
189 AnySnapshot::TransparentMutable(s) => s.invalid(),
190 AnySnapshot::TransparentReadonly(s) => s.invalid(),
191 }
192 }
193
194 pub fn is_valid(&self, id: SnapshotId) -> bool {
196 let snapshot_id = self.snapshot_id();
197 id <= snapshot_id && !self.invalid().get(id)
198 }
199
200 pub fn read_only(&self) -> bool {
202 match self {
203 AnySnapshot::Readonly(_) => true,
204 AnySnapshot::Mutable(_) => false,
205 AnySnapshot::NestedReadonly(_) => true,
206 AnySnapshot::NestedMutable(_) => false,
207 AnySnapshot::Global(_) => false,
208 AnySnapshot::TransparentMutable(_) => false,
209 AnySnapshot::TransparentReadonly(_) => true,
210 }
211 }
212
213 pub fn root(&self) -> AnySnapshot {
215 match self {
216 AnySnapshot::Readonly(s) => AnySnapshot::Readonly(s.root_readonly()),
217 AnySnapshot::Mutable(s) => AnySnapshot::Mutable(s.root_mutable()),
218 AnySnapshot::NestedReadonly(s) => AnySnapshot::NestedReadonly(s.root_nested_readonly()),
219 AnySnapshot::NestedMutable(s) => AnySnapshot::Mutable(s.root_mutable()),
220 AnySnapshot::Global(s) => AnySnapshot::Global(s.root_global()),
221 AnySnapshot::TransparentMutable(s) => {
222 AnySnapshot::TransparentMutable(s.root_transparent_mutable())
223 }
224 AnySnapshot::TransparentReadonly(s) => {
225 AnySnapshot::TransparentReadonly(s.root_transparent_readonly())
226 }
227 }
228 }
229
230 pub fn is_same_transparent(&self, other: &Arc<TransparentObserverMutableSnapshot>) -> bool {
232 matches!(self, AnySnapshot::TransparentMutable(snapshot) if Arc::ptr_eq(snapshot, other))
233 }
234
235 pub fn is_same_transparent_mutable(
237 &self,
238 other: &Arc<TransparentObserverMutableSnapshot>,
239 ) -> bool {
240 self.is_same_transparent(other)
241 }
242
243 pub fn is_same_transparent_readonly(&self, other: &Arc<TransparentObserverSnapshot>) -> bool {
245 matches!(self, AnySnapshot::TransparentReadonly(snapshot) if Arc::ptr_eq(snapshot, other))
246 }
247
248 pub fn enter<T>(&self, f: impl FnOnce() -> T) -> T {
250 match self {
251 AnySnapshot::Readonly(s) => s.enter(f),
252 AnySnapshot::Mutable(s) => s.enter(f),
253 AnySnapshot::NestedReadonly(s) => s.enter(f),
254 AnySnapshot::NestedMutable(s) => s.enter(f),
255 AnySnapshot::Global(s) => s.enter(f),
256 AnySnapshot::TransparentMutable(s) => s.enter(f),
257 AnySnapshot::TransparentReadonly(s) => s.enter(f),
258 }
259 }
260
261 pub fn take_nested_snapshot(&self, read_observer: Option<ReadObserver>) -> AnySnapshot {
263 match self {
264 AnySnapshot::Readonly(s) => {
265 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
266 }
267 AnySnapshot::Mutable(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
268 AnySnapshot::NestedReadonly(s) => {
269 AnySnapshot::NestedReadonly(s.take_nested_snapshot(read_observer))
270 }
271 AnySnapshot::NestedMutable(s) => {
272 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
273 }
274 AnySnapshot::Global(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
275 AnySnapshot::TransparentMutable(s) => {
276 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
277 }
278 AnySnapshot::TransparentReadonly(s) => {
279 AnySnapshot::TransparentReadonly(s.take_nested_snapshot(read_observer))
280 }
281 }
282 }
283
284 pub fn has_pending_changes(&self) -> bool {
286 match self {
287 AnySnapshot::Readonly(s) => s.has_pending_changes(),
288 AnySnapshot::Mutable(s) => s.has_pending_changes(),
289 AnySnapshot::NestedReadonly(s) => s.has_pending_changes(),
290 AnySnapshot::NestedMutable(s) => s.has_pending_changes(),
291 AnySnapshot::Global(s) => s.has_pending_changes(),
292 AnySnapshot::TransparentMutable(s) => s.has_pending_changes(),
293 AnySnapshot::TransparentReadonly(s) => s.has_pending_changes(),
294 }
295 }
296
297 pub fn dispose(&self) {
299 match self {
300 AnySnapshot::Readonly(s) => s.dispose(),
301 AnySnapshot::Mutable(s) => s.dispose(),
302 AnySnapshot::NestedReadonly(s) => s.dispose(),
303 AnySnapshot::NestedMutable(s) => s.dispose(),
304 AnySnapshot::Global(s) => s.dispose(),
305 AnySnapshot::TransparentMutable(s) => s.dispose(),
306 AnySnapshot::TransparentReadonly(s) => s.dispose(),
307 }
308 }
309
310 pub fn is_disposed(&self) -> bool {
312 match self {
313 AnySnapshot::Readonly(s) => s.is_disposed(),
314 AnySnapshot::Mutable(s) => s.is_disposed(),
315 AnySnapshot::NestedReadonly(s) => s.is_disposed(),
316 AnySnapshot::NestedMutable(s) => s.is_disposed(),
317 AnySnapshot::Global(s) => s.is_disposed(),
318 AnySnapshot::TransparentMutable(s) => s.is_disposed(),
319 AnySnapshot::TransparentReadonly(s) => s.is_disposed(),
320 }
321 }
322
323 pub fn record_read(&self, state: &dyn StateObject) {
325 match self {
326 AnySnapshot::Readonly(s) => s.record_read(state),
327 AnySnapshot::Mutable(s) => s.record_read(state),
328 AnySnapshot::NestedReadonly(s) => s.record_read(state),
329 AnySnapshot::NestedMutable(s) => s.record_read(state),
330 AnySnapshot::Global(s) => s.record_read(state),
331 AnySnapshot::TransparentMutable(s) => s.record_read(state),
332 AnySnapshot::TransparentReadonly(s) => s.record_read(state),
333 }
334 }
335
336 pub fn record_write(&self, state: Arc<dyn StateObject>) {
338 match self {
339 AnySnapshot::Readonly(s) => s.record_write(state),
340 AnySnapshot::Mutable(s) => s.record_write(state),
341 AnySnapshot::NestedReadonly(s) => s.record_write(state),
342 AnySnapshot::NestedMutable(s) => s.record_write(state),
343 AnySnapshot::Global(s) => s.record_write(state),
344 AnySnapshot::TransparentMutable(s) => s.record_write(state),
345 AnySnapshot::TransparentReadonly(s) => s.record_write(state),
346 }
347 }
348
349 pub fn apply(&self) -> SnapshotApplyResult {
351 match self {
352 AnySnapshot::Mutable(s) => s.apply(),
353 AnySnapshot::NestedMutable(s) => s.apply(),
354 AnySnapshot::Global(s) => s.apply(),
355 AnySnapshot::TransparentMutable(s) => s.apply(),
356 _ => panic!("Cannot apply a read-only snapshot"),
357 }
358 }
359
360 pub fn take_nested_mutable_snapshot(
362 &self,
363 read_observer: Option<ReadObserver>,
364 write_observer: Option<WriteObserver>,
365 ) -> AnySnapshot {
366 match self {
367 AnySnapshot::Mutable(s) => AnySnapshot::NestedMutable(
368 s.take_nested_mutable_snapshot(read_observer, write_observer),
369 ),
370 AnySnapshot::NestedMutable(s) => AnySnapshot::NestedMutable(
371 s.take_nested_mutable_snapshot(read_observer, write_observer),
372 ),
373 AnySnapshot::Global(s) => {
374 AnySnapshot::Mutable(s.take_nested_mutable_snapshot(read_observer, write_observer))
375 }
376 AnySnapshot::TransparentMutable(s) => AnySnapshot::TransparentMutable(
377 s.take_nested_mutable_snapshot(read_observer, write_observer),
378 ),
379 _ => panic!("Cannot take nested mutable snapshot from read-only snapshot"),
380 }
381 }
382}
383
thread_local! {
    // The snapshot currently entered on this thread, or `None` when no snapshot is entered.
    // Read through `current_snapshot` and written through `set_current_snapshot`.
    static CURRENT_SNAPSHOT: RefCell<Option<AnySnapshot>> = const { RefCell::new(None) };
}
388
389pub fn current_snapshot() -> Option<AnySnapshot> {
391 CURRENT_SNAPSHOT
392 .try_with(|cell| cell.borrow().clone())
393 .unwrap_or(None)
394}
395
396pub(crate) fn set_current_snapshot(snapshot: Option<AnySnapshot>) {
398 let _ = CURRENT_SNAPSHOT.try_with(|cell| {
399 *cell.borrow_mut() = snapshot;
400 });
401}
402
403struct CurrentSnapshotGuard {
404 previous: Option<AnySnapshot>,
405}
406
407impl CurrentSnapshotGuard {
408 fn enter(snapshot: AnySnapshot) -> Self {
409 let previous = current_snapshot();
410 set_current_snapshot(Some(snapshot));
411 Self { previous }
412 }
413}
414
415impl Drop for CurrentSnapshotGuard {
416 fn drop(&mut self) {
417 set_current_snapshot(self.previous.take());
418 }
419}
420
421pub(crate) fn enter_snapshot_scope<T>(snapshot: AnySnapshot, f: impl FnOnce() -> T) -> T {
422 let _guard = CurrentSnapshotGuard::enter(snapshot);
423 f()
424}
425
426pub fn take_mutable_snapshot(
435 read_observer: Option<ReadObserver>,
436 write_observer: Option<WriteObserver>,
437) -> AnyMutableSnapshot {
438 match current_snapshot() {
441 Some(AnySnapshot::Mutable(parent)) => AnyMutableSnapshot::Nested(
442 parent.take_nested_mutable_snapshot(read_observer, write_observer),
443 ),
444 Some(AnySnapshot::NestedMutable(parent)) => AnyMutableSnapshot::Nested(
445 parent.take_nested_mutable_snapshot(read_observer, write_observer),
446 ),
447 _ => AnyMutableSnapshot::Root(
449 GlobalSnapshot::get_or_create()
450 .take_nested_mutable_snapshot(read_observer, write_observer),
451 ),
452 }
453}
454
455pub fn take_transparent_observer_mutable_snapshot(
464 read_observer: Option<ReadObserver>,
465 write_observer: Option<WriteObserver>,
466) -> Arc<TransparentObserverMutableSnapshot> {
467 let parent = current_snapshot();
468 match parent {
469 Some(AnySnapshot::TransparentMutable(transparent)) if transparent.can_reuse() => {
470 transparent
472 }
473 _ => {
474 let current = current_snapshot()
477 .unwrap_or_else(|| AnySnapshot::Global(GlobalSnapshot::get_or_create()));
478 let id = current.snapshot_id();
479 let invalid = current.invalid();
480 TransparentObserverMutableSnapshot::new(
481 id,
482 invalid,
483 read_observer,
484 write_observer,
485 None,
486 )
487 }
488 }
489}
490
/// Returns a record id obtained from `runtime::allocate_record_id`.
///
/// This is a public forwarder to the private `runtime` module; the allocation policy lives
/// there.
pub fn allocate_record_id() -> SnapshotId {
    runtime::allocate_record_id()
}
495
/// Returns the value of `runtime::peek_next_snapshot_id`.
///
/// NOTE(review): by its name this reads the next snapshot id without allocating it —
/// confirm against the `runtime` module.
pub(crate) fn peek_next_snapshot_id() -> SnapshotId {
    runtime::peek_next_snapshot_id()
}
503
/// Identity token for a registered observer.
///
/// Equality and hashing are based on the address of the `Rc` allocation, not on any value:
/// clones of one id compare equal, and two ids created by separate `new` calls do not.
#[derive(Clone)]
struct ObserverId(Rc<()>);

impl ObserverId {
    /// Creates a fresh id backed by its own `Rc` allocation.
    fn new() -> Self {
        let token = Rc::new(());
        ObserverId(token)
    }
}

impl PartialEq for ObserverId {
    /// Two ids are equal exactly when they share the same allocation.
    fn eq(&self, other: &Self) -> bool {
        let mine = Rc::as_ptr(&self.0);
        let theirs = Rc::as_ptr(&other.0);
        std::ptr::eq(mine, theirs)
    }
}

impl Eq for ObserverId {}

impl Hash for ObserverId {
    /// Hashes the allocation address, consistent with `eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        std::ptr::hash(Rc::as_ptr(&self.0), state);
    }
}
526
thread_local! {
    // Apply observers registered on this thread, keyed by the id stored in their
    // `ObserverHandle`. Filled by `register_apply_observer`, emptied by the handle's `Drop`.
    static APPLY_OBSERVERS: RefCell<HashMap<ObserverId, ApplyObserver>> = RefCell::new(HashMap::default());
}

thread_local! {
    // Per-thread map from state object id to the snapshot id recorded by `set_last_write`.
    static LAST_WRITES: RefCell<HashMap<StateObjectId, SnapshotId>> = RefCell::new(HashMap::default());
}

thread_local! {
    // Per-thread weak set of state objects swept by
    // `check_and_overwrite_unused_records_locked`.
    static EXTRA_STATE_OBJECTS: RefCell<crate::snapshot_weak_set::SnapshotWeakSet> = RefCell::new(crate::snapshot_weak_set::SnapshotWeakSet::new());
}

// Minimum snapshot-id distance between two sweeps while the weak set holds fewer than
// `UNUSED_RECORD_CLEANUP_MIN_SIZE` entries.
const UNUSED_RECORD_CLEANUP_INTERVAL: SnapshotId = 2;
// Minimum snapshot-id distance between two sweeps once the weak set has reached
// `UNUSED_RECORD_CLEANUP_MIN_SIZE` entries.
const UNUSED_RECORD_CLEANUP_BUSY_INTERVAL: SnapshotId = 1;
// Weak-set size from which the shorter "busy" interval applies.
const UNUSED_RECORD_CLEANUP_MIN_SIZE: usize = 64;

thread_local! {
    // Snapshot id at which the last sweep was started on this thread (0 initially).
    static LAST_UNUSED_RECORD_CLEANUP: Cell<SnapshotId> = const { Cell::new(0) };
}
555
/// Point-in-time sizes of this module's thread-local bookkeeping, as reported by
/// `debug_snapshot_v2_stats`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SnapshotV2DebugStats {
    /// Number of registered apply observers.
    pub apply_observers_len: usize,
    /// Capacity of the apply-observer map.
    pub apply_observers_cap: usize,
    /// Number of entries in the last-writes map.
    pub last_writes_len: usize,
    /// Capacity of the last-writes map.
    pub last_writes_cap: usize,
    /// `len` reported by the extra-state-objects weak set.
    pub extra_state_objects_len: usize,
    /// `capacity` reported by the extra-state-objects weak set.
    pub extra_state_objects_cap: usize,
    /// Snapshot id stored when the last unused-record sweep was started.
    pub last_unused_record_cleanup: SnapshotId,
}
566
567pub fn debug_snapshot_v2_stats() -> SnapshotV2DebugStats {
568 let (apply_observers_len, apply_observers_cap) = APPLY_OBSERVERS.with(|cell| {
569 let observers = cell.borrow();
570 (observers.len(), observers.capacity())
571 });
572 let (last_writes_len, last_writes_cap) = LAST_WRITES.with(|cell| {
573 let writes = cell.borrow();
574 (writes.len(), writes.capacity())
575 });
576 let SnapshotWeakSetDebugStats {
577 len: extra_state_objects_len,
578 capacity: extra_state_objects_cap,
579 } = EXTRA_STATE_OBJECTS.with(|cell| cell.borrow().debug_stats());
580 let last_unused_record_cleanup = LAST_UNUSED_RECORD_CLEANUP.with(|cell| cell.get());
581
582 SnapshotV2DebugStats {
583 apply_observers_len,
584 apply_observers_cap,
585 last_writes_len,
586 last_writes_cap,
587 extra_state_objects_len,
588 extra_state_objects_cap,
589 last_unused_record_cleanup,
590 }
591}
592
593pub fn register_apply_observer(observer: ApplyObserver) -> ObserverHandle {
597 let id = ObserverId::new();
598 APPLY_OBSERVERS.with(|cell| {
599 cell.borrow_mut().insert(id.clone(), observer);
600 });
601 ObserverHandle {
602 kind: ObserverKind::Apply,
603 id,
604 }
605}
606
607pub struct ObserverHandle {
611 kind: ObserverKind,
612 id: ObserverId,
613}
614
615enum ObserverKind {
616 Apply,
617}
618
619impl Drop for ObserverHandle {
620 fn drop(&mut self) {
621 match self.kind {
622 ObserverKind::Apply => {
623 APPLY_OBSERVERS.with(|cell| {
624 cell.borrow_mut().remove(&self.id);
625 });
626 }
627 }
628 }
629}
630
631pub(crate) fn notify_apply_observers(modified: &[Arc<dyn StateObject>], snapshot_id: SnapshotId) {
633 APPLY_OBSERVERS.with(|cell| {
635 let observers: Vec<ApplyObserver> = cell.borrow().values().cloned().collect();
636 for observer in observers.into_iter() {
637 observer(modified, snapshot_id);
638 }
639 });
640}
641
642pub(crate) fn set_last_write(id: StateObjectId, snapshot_id: SnapshotId) {
644 LAST_WRITES.with(|cell| {
645 cell.borrow_mut().insert(id, snapshot_id);
646 });
647}
648
/// Test helper: empties this thread's last-writes map.
#[cfg(test)]
pub(crate) fn clear_last_writes() {
    LAST_WRITES.with(|cell| cell.borrow_mut().clear());
}
656
657pub(crate) fn check_and_overwrite_unused_records_locked() {
665 EXTRA_STATE_OBJECTS.with(|cell| {
666 cell.borrow_mut().remove_if(|state| {
667 state.overwrite_unused_records()
669 });
670 });
671}
672
673pub(crate) fn maybe_check_and_overwrite_unused_records_locked(current_snapshot_id: SnapshotId) {
674 let should_run = EXTRA_STATE_OBJECTS.with(|cell| {
675 let set = cell.borrow();
676 if set.is_empty() {
677 return false;
678 }
679 let last_cleanup = LAST_UNUSED_RECORD_CLEANUP.with(|last| last.get());
680 let interval = if set.len() >= UNUSED_RECORD_CLEANUP_MIN_SIZE {
681 UNUSED_RECORD_CLEANUP_BUSY_INTERVAL
682 } else {
683 UNUSED_RECORD_CLEANUP_INTERVAL
684 };
685 current_snapshot_id.saturating_sub(last_cleanup) >= interval
686 });
687
688 if should_run {
689 LAST_UNUSED_RECORD_CLEANUP.with(|cell| cell.set(current_snapshot_id));
690 check_and_overwrite_unused_records_locked();
691 }
692}
693
/// Test helper: resets this thread's last-sweep marker to 0.
#[cfg(test)]
pub(crate) fn clear_unused_record_cleanup_for_tests() {
    LAST_UNUSED_RECORD_CLEANUP.with(|marker| {
        marker.replace(0);
    });
}
698
699pub(crate) fn optimistic_merges(
700 current_snapshot_id: SnapshotId,
701 base_parent_id: SnapshotId,
702 modified_objects: &[(StateObjectId, Arc<dyn StateObject>, SnapshotId)],
703 invalid_snapshots: &SnapshotIdSet,
704 applying_invalid: &SnapshotIdSet,
705) -> Option<HashMap<usize, Rc<StateRecord>>> {
706 if modified_objects.is_empty() {
707 return None;
708 }
709
710 let mut result: Option<HashMap<usize, Rc<StateRecord>>> = None;
711
712 for (_, state, writer_id) in modified_objects.iter() {
713 let head = state.first_record();
714
715 let current = match crate::state::readable_record_for(
716 &head,
717 current_snapshot_id,
718 invalid_snapshots,
719 ) {
720 Some(record) => record,
721 None => continue,
722 };
723
724 let (previous_opt, found_base) =
726 mutable::find_previous_record(&head, base_parent_id, applying_invalid);
727 let previous = previous_opt?;
728
729 if !found_base || previous.snapshot_id() == crate::state::PREEXISTING_SNAPSHOT_ID {
730 continue;
731 }
732
733 if Rc::ptr_eq(¤t, &previous) {
734 continue;
735 }
736
737 let applied = mutable::find_record_by_id(&head, *writer_id)?;
738
739 let merged = state.merge_records(
740 Rc::clone(&previous),
741 Rc::clone(¤t),
742 Rc::clone(&applied),
743 )?;
744
745 result
746 .get_or_insert_with(HashMap::default)
747 .insert(Rc::as_ptr(¤t) as usize, merged);
748 }
749
750 result
751}
752
753#[allow(clippy::arc_with_non_send_sync)]
759pub fn merge_read_observers(
760 a: Option<ReadObserver>,
761 b: Option<ReadObserver>,
762) -> Option<ReadObserver> {
763 match (a, b) {
764 (None, None) => None,
765 (Some(a), None) => Some(a),
766 (None, Some(b)) => Some(b),
767 (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
768 a(state);
769 b(state);
770 })),
771 }
772}
773
774#[allow(clippy::arc_with_non_send_sync)]
780pub fn merge_write_observers(
781 a: Option<WriteObserver>,
782 b: Option<WriteObserver>,
783) -> Option<WriteObserver> {
784 match (a, b) {
785 (None, None) => None,
786 (Some(a), None) => Some(a),
787 (None, Some(b)) => Some(b),
788 (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
789 a(state);
790 b(state);
791 })),
792 }
793}
794
/// Bookkeeping shared by the snapshot implementations; all mutation goes through
/// `Cell`/`RefCell` interior mutability.
pub(crate) struct SnapshotState {
    /// Id of the snapshot.
    pub(crate) id: Cell<SnapshotId>,
    /// Set of snapshot ids this snapshot treats as invalid.
    pub(crate) invalid: RefCell<SnapshotIdSet>,
    /// Handle from `snapshot_pinning::track_pinning`, or `PinHandle::INVALID` when the state
    /// was created without pinning. Released in `dispose`.
    pub(crate) pin_handle: Cell<PinHandle>,
    /// Set to `true` by the first `dispose` call.
    pub(crate) disposed: Cell<bool>,
    /// Observer invoked by `record_read`, if any.
    pub(crate) read_observer: RefCell<Option<ReadObserver>>,
    /// Observer invoked by `record_write` on the first write to an object, if any.
    pub(crate) write_observer: RefCell<Option<WriteObserver>>,
    /// Written objects keyed by state object id, each with the writer id passed to the most
    /// recent `record_write` for that object.
    #[allow(clippy::type_complexity)]
    pub(crate) modified: RefCell<HashMap<StateObjectId, (Arc<dyn StateObject>, SnapshotId)>>,
    /// One-shot callback run by `dispose`; installed through `set_on_dispose`.
    on_dispose: RefCell<Option<Box<dyn FnOnce()>>>,
    /// When `true`, `dispose` also calls `close_snapshot` with this snapshot's id.
    runtime_tracked: bool,
    /// Ids added through `add_pending_child` and not yet removed.
    pending_children: RefCell<HashSet<SnapshotId>>,
}
820
821impl SnapshotState {
822 pub(crate) fn new(
823 id: SnapshotId,
824 invalid: SnapshotIdSet,
825 read_observer: Option<ReadObserver>,
826 write_observer: Option<WriteObserver>,
827 runtime_tracked: bool,
828 ) -> Self {
829 Self::new_with_pinning(
830 id,
831 invalid,
832 read_observer,
833 write_observer,
834 runtime_tracked,
835 true,
836 )
837 }
838
839 pub(crate) fn new_with_pinning(
844 id: SnapshotId,
845 invalid: SnapshotIdSet,
846 read_observer: Option<ReadObserver>,
847 write_observer: Option<WriteObserver>,
848 runtime_tracked: bool,
849 should_pin: bool,
850 ) -> Self {
851 let pin_handle = if should_pin {
852 snapshot_pinning::track_pinning(id, &invalid)
853 } else {
854 snapshot_pinning::PinHandle::INVALID
855 };
856 Self {
857 id: Cell::new(id),
858 invalid: RefCell::new(invalid),
859 pin_handle: Cell::new(pin_handle),
860 disposed: Cell::new(false),
861 read_observer: RefCell::new(read_observer),
862 write_observer: RefCell::new(write_observer),
863 modified: RefCell::new(HashMap::default()),
864 on_dispose: RefCell::new(None),
865 runtime_tracked,
866 pending_children: RefCell::new(HashSet::default()),
867 }
868 }
869
870 pub(crate) fn record_read(&self, state: &dyn StateObject) {
871 if let Some(observer) = self.read_observer.borrow().as_ref() {
872 observer(state);
873 }
874 }
875
876 pub(crate) fn record_write(&self, state: Arc<dyn StateObject>, writer_id: SnapshotId) {
877 let state_id = state.object_id().as_usize();
879
880 let mut modified = self.modified.borrow_mut();
881
882 match modified.entry(state_id) {
884 std::collections::hash_map::Entry::Vacant(e) => {
885 if let Some(observer) = self.write_observer.borrow().as_ref() {
886 observer(&*state);
887 }
888 e.insert((state, writer_id));
890 }
891 std::collections::hash_map::Entry::Occupied(mut e) => {
892 e.insert((state, writer_id));
894 }
895 }
896 }
897
898 pub(crate) fn dispose(&self) {
899 if !self.disposed.replace(true) {
900 let pin_handle = self.pin_handle.get();
901 snapshot_pinning::release_pinning(pin_handle);
902 if let Some(cb) = self.on_dispose.borrow_mut().take() {
903 cb();
904 }
905 if self.runtime_tracked {
906 close_snapshot(self.id.get());
907 }
908 }
909 }
910
911 pub(crate) fn add_pending_child(&self, id: SnapshotId) {
912 self.pending_children.borrow_mut().insert(id);
913 }
914
915 pub(crate) fn remove_pending_child(&self, id: SnapshotId) {
916 self.pending_children.borrow_mut().remove(&id);
917 }
918
919 pub(crate) fn has_pending_children(&self) -> bool {
920 !self.pending_children.borrow().is_empty()
921 }
922
923 pub(crate) fn pending_children(&self) -> Vec<SnapshotId> {
924 self.pending_children.borrow().iter().copied().collect()
925 }
926
927 pub(crate) fn set_on_dispose<F>(&self, f: F)
928 where
929 F: FnOnce() + 'static,
930 {
931 *self.on_dispose.borrow_mut() = Some(Box::new(f));
932 }
933}
934
935#[cfg(test)]
936mod tests {
937 use super::*;
938
939 #[test]
940 fn apply_observer_ids_do_not_use_process_global_counter() {
941 let source = include_str!("mod.rs");
942 assert!(!source.contains(concat!("NEXT_", "OBSERVER_ID")));
943 assert!(!source.contains(concat!("Atomic", "Usize")));
944 }
945
946 #[test]
947 fn test_apply_result_is_success() {
948 assert!(SnapshotApplyResult::Success.is_success());
949 assert!(!SnapshotApplyResult::Failure.is_success());
950 }
951
952 #[test]
953 fn test_apply_result_is_failure() {
954 assert!(!SnapshotApplyResult::Success.is_failure());
955 assert!(SnapshotApplyResult::Failure.is_failure());
956 }
957
958 #[test]
959 fn test_apply_result_check_success() {
960 SnapshotApplyResult::Success.check(); }
962
963 #[test]
964 #[should_panic(expected = "Snapshot apply failed")]
965 fn test_apply_result_check_failure() {
966 SnapshotApplyResult::Failure.check(); }
968
969 #[test]
970 fn snapshot_enter_restores_current_snapshot_after_panic() {
971 let _guard = reset_runtime_for_tests();
972 let snapshot = take_mutable_snapshot(None, None);
973
974 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
975 snapshot.enter(|| panic!("snapshot body panic"));
976 }));
977
978 assert!(result.is_err());
979 assert!(
980 current_snapshot().is_none(),
981 "snapshot enter must restore the previous current snapshot while unwinding"
982 );
983 }
984
985 #[test]
986 fn test_merge_read_observers_both_none() {
987 let result = merge_read_observers(None, None);
988 assert!(result.is_none());
989 }
990
991 #[test]
992 fn test_merge_read_observers_one_some() {
993 let observer = Arc::new(|_: &dyn StateObject| {});
994 let result = merge_read_observers(Some(observer.clone()), None);
995 assert!(result.is_some());
996
997 let result = merge_read_observers(None, Some(observer));
998 assert!(result.is_some());
999 }
1000
1001 #[test]
1002 fn test_merge_write_observers_both_none() {
1003 let result = merge_write_observers(None, None);
1004 assert!(result.is_none());
1005 }
1006
1007 #[test]
1008 fn test_merge_write_observers_one_some() {
1009 let observer = Arc::new(|_: &dyn StateObject| {});
1010 let result = merge_write_observers(Some(observer.clone()), None);
1011 assert!(result.is_some());
1012
1013 let result = merge_write_observers(None, Some(observer));
1014 assert!(result.is_some());
1015 }
1016
1017 #[test]
1018 fn test_current_snapshot_none_initially() {
1019 set_current_snapshot(None);
1020 assert!(current_snapshot().is_none());
1021 }
1022
1023 struct TestStateObject {
1025 id: usize,
1026 }
1027
1028 impl TestStateObject {
1029 fn new(id: usize) -> Arc<Self> {
1030 Arc::new(Self { id })
1031 }
1032 }
1033
1034 impl StateObject for TestStateObject {
1035 fn object_id(&self) -> crate::state::ObjectId {
1036 crate::state::ObjectId(self.id)
1037 }
1038
1039 fn first_record(&self) -> Rc<crate::state::StateRecord> {
1040 unimplemented!("Not needed for observer tests")
1041 }
1042
1043 fn try_readable_record(
1044 &self,
1045 _snapshot_id: SnapshotId,
1046 _invalid: &SnapshotIdSet,
1047 ) -> Option<Rc<crate::state::StateRecord>> {
1048 None
1049 }
1050
1051 fn readable_record(
1052 &self,
1053 _snapshot_id: SnapshotId,
1054 _invalid: &SnapshotIdSet,
1055 ) -> Rc<crate::state::StateRecord> {
1056 unimplemented!("Not needed for observer tests")
1057 }
1058
1059 fn prepend_state_record(&self, _record: Rc<crate::state::StateRecord>) {
1060 unimplemented!("Not needed for observer tests")
1061 }
1062
1063 fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
1064 unimplemented!("Not needed for observer tests")
1065 }
1066
1067 fn as_any(&self) -> &dyn std::any::Any {
1068 self
1069 }
1070 }
1071
1072 #[test]
1073 fn test_apply_observer_receives_correct_modified_objects() {
1074 use std::sync::Mutex;
1075
1076 let received_count = Arc::new(Mutex::new(0));
1078 let received_snapshot_id = Arc::new(Mutex::new(0));
1079
1080 let received_count_clone = received_count.clone();
1081 let received_snapshot_id_clone = received_snapshot_id.clone();
1082
1083 let _handle = register_apply_observer(Rc::new(move |modified, snapshot_id| {
1085 *received_snapshot_id_clone.lock().unwrap() = snapshot_id;
1086 *received_count_clone.lock().unwrap() = modified.len();
1087 }));
1088
1089 let obj1: Arc<dyn StateObject> = TestStateObject::new(42);
1091 let obj2: Arc<dyn StateObject> = TestStateObject::new(99);
1092 let modified = vec![obj1, obj2];
1093
1094 notify_apply_observers(&modified, 123);
1096
1097 assert_eq!(*received_snapshot_id.lock().unwrap(), 123);
1099 assert_eq!(*received_count.lock().unwrap(), 2);
1100 }
1101
1102 #[test]
1103 fn test_apply_observer_receives_correct_snapshot_id() {
1104 use std::sync::Mutex;
1105
1106 let received_id = Arc::new(Mutex::new(0));
1107 let received_id_clone = received_id.clone();
1108
1109 let _handle = register_apply_observer(Rc::new(move |_, snapshot_id| {
1110 *received_id_clone.lock().unwrap() = snapshot_id;
1111 }));
1112
1113 notify_apply_observers(&[], 456);
1115
1116 assert_eq!(*received_id.lock().unwrap(), 456);
1117 }
1118
1119 #[test]
1120 fn test_multiple_apply_observers_all_called() {
1121 use std::sync::Mutex;
1122
1123 let call_count1 = Arc::new(Mutex::new(0));
1124 let call_count2 = Arc::new(Mutex::new(0));
1125 let call_count3 = Arc::new(Mutex::new(0));
1126
1127 let call_count1_clone = call_count1.clone();
1128 let call_count2_clone = call_count2.clone();
1129 let call_count3_clone = call_count3.clone();
1130
1131 let _handle1 = register_apply_observer(Rc::new(move |_, _| {
1133 *call_count1_clone.lock().unwrap() += 1;
1134 }));
1135
1136 let _handle2 = register_apply_observer(Rc::new(move |_, _| {
1137 *call_count2_clone.lock().unwrap() += 1;
1138 }));
1139
1140 let _handle3 = register_apply_observer(Rc::new(move |_, _| {
1141 *call_count3_clone.lock().unwrap() += 1;
1142 }));
1143
1144 notify_apply_observers(&[], 1);
1146
1147 assert_eq!(*call_count1.lock().unwrap(), 1);
1149 assert_eq!(*call_count2.lock().unwrap(), 1);
1150 assert_eq!(*call_count3.lock().unwrap(), 1);
1151
1152 notify_apply_observers(&[], 2);
1154
1155 assert_eq!(*call_count1.lock().unwrap(), 2);
1157 assert_eq!(*call_count2.lock().unwrap(), 2);
1158 assert_eq!(*call_count3.lock().unwrap(), 2);
1159 }
1160
1161 #[test]
1162 fn test_apply_observer_not_called_for_empty_modifications() {
1163 use std::sync::Mutex;
1164
1165 let call_count = Arc::new(Mutex::new(0));
1166 let call_count_clone = call_count.clone();
1167
1168 let _handle = register_apply_observer(Rc::new(move |modified, _| {
1169 *call_count_clone.lock().unwrap() += 1;
1171 assert_eq!(modified.len(), 0);
1172 }));
1173
1174 notify_apply_observers(&[], 1);
1176
1177 assert_eq!(*call_count.lock().unwrap(), 1);
1179 }
1180
1181 #[test]
1182 fn test_observer_handle_drop_removes_correct_observer() {
1183 use std::sync::Mutex;
1184
1185 let calls = Arc::new(Mutex::new(Vec::new()));
1187
1188 let calls1 = calls.clone();
1189 let handle1 = register_apply_observer(Rc::new(move |_, _| {
1190 calls1.lock().unwrap().push(1);
1191 }));
1192
1193 let calls2 = calls.clone();
1194 let handle2 = register_apply_observer(Rc::new(move |_, _| {
1195 calls2.lock().unwrap().push(2);
1196 }));
1197
1198 let calls3 = calls.clone();
1199 let handle3 = register_apply_observer(Rc::new(move |_, _| {
1200 calls3.lock().unwrap().push(3);
1201 }));
1202
1203 notify_apply_observers(&[], 1);
1205 let result = calls.lock().unwrap().clone();
1206 assert_eq!(result.len(), 3);
1207 assert!(result.contains(&1));
1208 assert!(result.contains(&2));
1209 assert!(result.contains(&3));
1210 calls.lock().unwrap().clear();
1211
1212 drop(handle2);
1214
1215 notify_apply_observers(&[], 2);
1217 let result = calls.lock().unwrap().clone();
1218 assert_eq!(result.len(), 2);
1219 assert!(result.contains(&1));
1220 assert!(result.contains(&3));
1221 assert!(!result.contains(&2));
1222 calls.lock().unwrap().clear();
1223
1224 drop(handle1);
1226
1227 notify_apply_observers(&[], 3);
1229 let result = calls.lock().unwrap().clone();
1230 assert_eq!(result.len(), 1);
1231 assert!(result.contains(&3));
1232 calls.lock().unwrap().clear();
1233
1234 drop(handle3);
1236
1237 notify_apply_observers(&[], 4);
1239 assert_eq!(calls.lock().unwrap().len(), 0);
1240 }
1241
/// Drops three observer handles first in reverse registration order
/// (3, 2, 1) and then in registration order (1, 2, 3), checking after each
/// drop which observers are still notified.
#[test]
fn test_observer_handle_drop_in_different_orders() {
    use std::sync::Mutex;

    // Registers an observer that appends `tag` to `log` on every notification.
    let observe = |log: &Arc<Mutex<Vec<i32>>>, tag: i32| {
        let log = Arc::clone(log);
        register_apply_observer(Rc::new(move |_, _| {
            log.lock().unwrap().push(tag);
        }))
    };

    // Scenario 1: drop in reverse registration order.
    {
        let calls = Arc::new(Mutex::new(Vec::new()));
        let h1 = observe(&calls, 1);
        let h2 = observe(&calls, 2);
        let h3 = observe(&calls, 3);

        drop(h3);
        notify_apply_observers(&[], 1);
        {
            // The guard is scoped so the lock is released before the next notification.
            let mut seen = calls.lock().unwrap();
            assert!(seen.contains(&1));
            assert!(seen.contains(&2));
            assert!(!seen.contains(&3));
            seen.clear();
        }

        drop(h2);
        notify_apply_observers(&[], 2);
        {
            let mut seen = calls.lock().unwrap();
            assert_eq!(seen.len(), 1);
            assert!(seen.contains(&1));
            seen.clear();
        }

        drop(h1);
        notify_apply_observers(&[], 3);
        assert_eq!(calls.lock().unwrap().len(), 0);
    }

    // Scenario 2: drop in registration order.
    {
        let calls = Arc::new(Mutex::new(Vec::new()));
        let h1 = observe(&calls, 1);
        let h2 = observe(&calls, 2);
        let h3 = observe(&calls, 3);

        drop(h1);
        notify_apply_observers(&[], 1);
        {
            let mut seen = calls.lock().unwrap();
            assert!(!seen.contains(&1));
            assert!(seen.contains(&2));
            assert!(seen.contains(&3));
            seen.clear();
        }

        drop(h2);
        notify_apply_observers(&[], 2);
        {
            let mut seen = calls.lock().unwrap();
            assert_eq!(seen.len(), 1);
            assert!(seen.contains(&3));
            seen.clear();
        }

        drop(h3);
        notify_apply_observers(&[], 3);
        assert_eq!(calls.lock().unwrap().len(), 0);
    }
}
1320
/// After one handle is dropped, the surviving observer keeps receiving the
/// snapshot id of each notification, and an observer registered afterwards
/// is notified alongside it.
#[test]
fn test_remaining_observers_still_work_after_drop() {
    use std::sync::Mutex;

    let calls = Arc::new(Mutex::new(Vec::new()));

    // Registers an observer that logs `(tag, snapshot_id)` for every notification.
    let observe = |tag: i32| {
        let log = Arc::clone(&calls);
        register_apply_observer(Rc::new(move |_, snapshot_id| {
            log.lock().unwrap().push((tag, snapshot_id));
        }))
    };
    // Takes everything logged so far, leaving the log empty.
    let drain = || std::mem::take(&mut *calls.lock().unwrap());

    let handle1 = observe(1);
    let handle2 = observe(2);

    notify_apply_observers(&[], 100);
    assert_eq!(drain().len(), 2);

    drop(handle1);

    // Only observer 2 is left, and it sees the new snapshot id.
    notify_apply_observers(&[], 200);
    assert_eq!(drain(), vec![(2, 200)]);

    // An observer registered after the drop joins the remaining one.
    let _handle3 = observe(3);

    notify_apply_observers(&[], 300);
    let seen = calls.lock().unwrap().clone();
    assert_eq!(seen.len(), 2);
    assert!(seen.contains(&(2, 300)));
    assert!(seen.contains(&(3, 300)));

    drop(handle2);
}
1365
/// Registers 100 observers, drops every even-numbered handle, and verifies
/// that exactly the odd-numbered observers are still notified.
///
/// Checking the surviving set (not just its size) catches a handle that
/// unregisters some observer other than its own.
#[test]
fn test_observer_ids_are_unique() {
    use std::sync::Mutex;

    let ids = Arc::new(Mutex::new(std::collections::HashSet::new()));

    // Each handle is stored next to its observer number so handles can be
    // selected by number rather than by a shifting vector position.
    let mut handles = Vec::with_capacity(100);

    for i in 0..100_usize {
        let ids_clone = Arc::clone(&ids);
        let handle = register_apply_observer(Rc::new(move |_, _| {
            ids_clone.lock().unwrap().insert(i);
        }));
        handles.push((i, handle));
    }

    // Every observer records its own number once.
    notify_apply_observers(&[], 1);
    assert_eq!(ids.lock().unwrap().len(), 100);

    // Drop the even-numbered handles in a single pass; `retain` visits the
    // elements in order, so the handles are dropped in ascending order.
    handles.retain(|(i, _)| *i % 2 == 1);

    ids.lock().unwrap().clear();
    notify_apply_observers(&[], 2);

    // Exactly the 50 odd-numbered observers must remain.
    let expected: std::collections::HashSet<usize> = (1..100).step_by(2).collect();
    assert_eq!(*ids.lock().unwrap(), expected);
}
1399
/// A recorded write must land in the snapshot state's `modified` map under
/// the object's id, stored together with the id passed to `record_write`.
#[test]
fn test_state_object_storage_in_modified_set() {
    use crate::state::{ObjectId, StateObject, StateRecord};

    /// Stub state object with a fixed id; only `object_id` and `as_any`
    /// have real implementations.
    struct TestState;

    impl StateObject for TestState {
        fn object_id(&self) -> ObjectId {
            ObjectId(12345)
        }

        fn first_record(&self) -> Rc<StateRecord> {
            unimplemented!("Not needed for this test")
        }

        fn try_readable_record(
            &self,
            _snapshot_id: SnapshotId,
            _invalid: &SnapshotIdSet,
        ) -> Option<Rc<StateRecord>> {
            None
        }

        fn readable_record(
            &self,
            _snapshot_id: SnapshotId,
            _invalid: &SnapshotIdSet,
        ) -> Rc<StateRecord> {
            unimplemented!("Not needed for this test")
        }

        fn prepend_state_record(&self, _record: Rc<StateRecord>) {
            unimplemented!("Not needed for this test")
        }

        fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
            unimplemented!("Not needed for this test")
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    let snapshot_state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);
    let object: Arc<dyn StateObject> = Arc::new(TestState);

    snapshot_state.record_write(Arc::clone(&object), 1);

    let modified = snapshot_state.modified.borrow();

    // Exactly one entry, keyed by the stub's object id.
    assert_eq!(modified.len(), 1);
    assert!(modified.contains_key(&12345));

    // The entry holds the object itself and the id given to `record_write`.
    let (recorded, recorded_by) = modified.get(&12345).unwrap();
    assert_eq!(recorded.object_id().as_usize(), 12345);
    assert_eq!(*recorded_by, 1);
}
1463
/// Writing the same state object twice must keep a single entry in the
/// `modified` map, with the stored id replaced by the one from the later
/// write.
#[test]
fn test_multiple_writes_to_same_state_object() {
    use crate::state::{ObjectId, StateObject, StateRecord};

    /// Stub state object with a fixed id; only `object_id` and `as_any`
    /// have real implementations.
    struct TestState;

    impl StateObject for TestState {
        fn object_id(&self) -> ObjectId {
            ObjectId(99999)
        }

        fn first_record(&self) -> Rc<StateRecord> {
            unimplemented!()
        }

        fn try_readable_record(
            &self,
            _snapshot_id: SnapshotId,
            _invalid: &SnapshotIdSet,
        ) -> Option<Rc<StateRecord>> {
            None
        }

        fn readable_record(
            &self,
            _snapshot_id: SnapshotId,
            _invalid: &SnapshotIdSet,
        ) -> Rc<StateRecord> {
            unimplemented!()
        }

        fn prepend_state_record(&self, _record: Rc<StateRecord>) {
            unimplemented!()
        }

        fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    let snapshot_state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);
    let object: Arc<dyn StateObject> = Arc::new(TestState);

    // First write creates the entry. The borrow is a temporary so it is
    // released before the next `record_write`.
    snapshot_state.record_write(Arc::clone(&object), 1);
    assert_eq!(snapshot_state.modified.borrow().len(), 1);

    // Second write to the same object must not add another entry.
    snapshot_state.record_write(Arc::clone(&object), 2);

    let modified = snapshot_state.modified.borrow();
    assert_eq!(modified.len(), 1);
    assert!(modified.contains_key(&99999));

    // The stored id reflects the most recent write.
    let (_, recorded_by) = modified.get(&99999).unwrap();
    assert_eq!(*recorded_by, 2);
}
1523}