1#![allow(clippy::arc_with_non_send_sync)]
24
25use crate::collections::map::HashMap; use crate::collections::map::HashSet;
27use crate::snapshot_id_set::{SnapshotId, SnapshotIdSet};
28use crate::snapshot_pinning::{self, PinHandle};
29use crate::state::{StateObject, StateRecord};
30use std::cell::{Cell, RefCell};
31use std::rc::Rc;
32use std::sync::atomic::{AtomicUsize, Ordering};
33use std::sync::{Arc, Weak};
34
35mod global;
36mod mutable;
37mod nested;
38mod readonly;
39mod runtime;
40mod transparent;
41
42#[cfg(test)]
43mod integration_tests;
44
45pub use global::{advance_global_snapshot, GlobalSnapshot};
46pub use mutable::MutableSnapshot;
47pub use nested::{NestedMutableSnapshot, NestedReadonlySnapshot};
48pub use readonly::ReadonlySnapshot;
49pub use transparent::{TransparentObserverMutableSnapshot, TransparentObserverSnapshot};
50
51pub(crate) use runtime::{allocate_snapshot, close_snapshot, with_runtime};
52#[cfg(test)]
53pub(crate) use runtime::{reset_runtime_for_tests, TestRuntimeGuard};
54
/// Callback invoked with a state object when a read of it is recorded.
pub type ReadObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;

/// Callback invoked with a state object when a write to it is recorded.
pub type WriteObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;

/// Callback invoked with a slice of modified state objects and a snapshot id
/// when apply observers are notified.
pub type ApplyObserver = Arc<dyn Fn(&[Arc<dyn StateObject>], SnapshotId) + 'static>;
63
/// Outcome of attempting to apply a snapshot.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SnapshotApplyResult {
    /// The apply succeeded.
    Success,
    /// The apply failed.
    Failure,
}

impl SnapshotApplyResult {
    /// Returns `true` when this result is [`SnapshotApplyResult::Success`].
    pub fn is_success(&self) -> bool {
        *self == Self::Success
    }

    /// Returns `true` when this result is [`SnapshotApplyResult::Failure`].
    pub fn is_failure(&self) -> bool {
        *self == Self::Failure
    }

    /// Asserts that the apply succeeded.
    ///
    /// # Panics
    ///
    /// Panics with the message `Snapshot apply failed` when this result is a
    /// failure. The panic location is attributed to the caller.
    #[track_caller]
    pub fn check(&self) {
        assert!(!self.is_failure(), "Snapshot apply failed");
    }
}
92
/// Integer identifier of a state object, used as a map key in this module.
pub type StateObjectId = usize;
95
/// A handle to any of the concrete snapshot kinds.
///
/// Each variant wraps an `Arc` to the concrete snapshot, so cloning an
/// `AnySnapshot` clones the `Arc`, not the snapshot itself.
#[derive(Clone)]
pub enum AnySnapshot {
    /// A read-only snapshot.
    Readonly(Arc<ReadonlySnapshot>),
    /// A mutable snapshot.
    Mutable(Arc<MutableSnapshot>),
    /// A nested read-only snapshot.
    NestedReadonly(Arc<NestedReadonlySnapshot>),
    /// A nested mutable snapshot.
    NestedMutable(Arc<NestedMutableSnapshot>),
    /// The global snapshot.
    Global(Arc<GlobalSnapshot>),
    /// A transparent observer snapshot that is mutable.
    TransparentMutable(Arc<TransparentObserverMutableSnapshot>),
    /// A transparent observer snapshot that is read-only.
    TransparentReadonly(Arc<TransparentObserverSnapshot>),
}
110
111impl AnySnapshot {
112 pub fn snapshot_id(&self) -> SnapshotId {
114 match self {
115 AnySnapshot::Readonly(s) => s.snapshot_id(),
116 AnySnapshot::Mutable(s) => s.snapshot_id(),
117 AnySnapshot::NestedReadonly(s) => s.snapshot_id(),
118 AnySnapshot::NestedMutable(s) => s.snapshot_id(),
119 AnySnapshot::Global(s) => s.snapshot_id(),
120 AnySnapshot::TransparentMutable(s) => s.snapshot_id(),
121 AnySnapshot::TransparentReadonly(s) => s.snapshot_id(),
122 }
123 }
124
125 pub fn invalid(&self) -> SnapshotIdSet {
127 match self {
128 AnySnapshot::Readonly(s) => s.invalid(),
129 AnySnapshot::Mutable(s) => s.invalid(),
130 AnySnapshot::NestedReadonly(s) => s.invalid(),
131 AnySnapshot::NestedMutable(s) => s.invalid(),
132 AnySnapshot::Global(s) => s.invalid(),
133 AnySnapshot::TransparentMutable(s) => s.invalid(),
134 AnySnapshot::TransparentReadonly(s) => s.invalid(),
135 }
136 }
137
138 pub fn is_valid(&self, id: SnapshotId) -> bool {
140 let snapshot_id = self.snapshot_id();
141 id <= snapshot_id && !self.invalid().get(id)
142 }
143
144 pub fn read_only(&self) -> bool {
146 match self {
147 AnySnapshot::Readonly(_) => true,
148 AnySnapshot::Mutable(_) => false,
149 AnySnapshot::NestedReadonly(_) => true,
150 AnySnapshot::NestedMutable(_) => false,
151 AnySnapshot::Global(_) => false,
152 AnySnapshot::TransparentMutable(_) => false,
153 AnySnapshot::TransparentReadonly(_) => true,
154 }
155 }
156
157 pub fn root(&self) -> AnySnapshot {
159 match self {
160 AnySnapshot::Readonly(s) => AnySnapshot::Readonly(s.root_readonly()),
161 AnySnapshot::Mutable(s) => AnySnapshot::Mutable(s.root_mutable()),
162 AnySnapshot::NestedReadonly(s) => AnySnapshot::NestedReadonly(s.root_nested_readonly()),
163 AnySnapshot::NestedMutable(s) => AnySnapshot::Mutable(s.root_mutable()),
164 AnySnapshot::Global(s) => AnySnapshot::Global(s.root_global()),
165 AnySnapshot::TransparentMutable(s) => {
166 AnySnapshot::TransparentMutable(s.root_transparent_mutable())
167 }
168 AnySnapshot::TransparentReadonly(s) => {
169 AnySnapshot::TransparentReadonly(s.root_transparent_readonly())
170 }
171 }
172 }
173
174 pub fn is_same_transparent(&self, other: &Arc<TransparentObserverMutableSnapshot>) -> bool {
176 matches!(self, AnySnapshot::TransparentMutable(snapshot) if Arc::ptr_eq(snapshot, other))
177 }
178
179 pub fn is_same_transparent_mutable(
181 &self,
182 other: &Arc<TransparentObserverMutableSnapshot>,
183 ) -> bool {
184 self.is_same_transparent(other)
185 }
186
187 pub fn is_same_transparent_readonly(&self, other: &Arc<TransparentObserverSnapshot>) -> bool {
189 matches!(self, AnySnapshot::TransparentReadonly(snapshot) if Arc::ptr_eq(snapshot, other))
190 }
191
192 pub fn enter<T>(&self, f: impl FnOnce() -> T) -> T {
194 match self {
195 AnySnapshot::Readonly(s) => s.enter(f),
196 AnySnapshot::Mutable(s) => s.enter(f),
197 AnySnapshot::NestedReadonly(s) => s.enter(f),
198 AnySnapshot::NestedMutable(s) => s.enter(f),
199 AnySnapshot::Global(s) => s.enter(f),
200 AnySnapshot::TransparentMutable(s) => s.enter(f),
201 AnySnapshot::TransparentReadonly(s) => s.enter(f),
202 }
203 }
204
205 pub fn take_nested_snapshot(&self, read_observer: Option<ReadObserver>) -> AnySnapshot {
207 match self {
208 AnySnapshot::Readonly(s) => {
209 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
210 }
211 AnySnapshot::Mutable(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
212 AnySnapshot::NestedReadonly(s) => {
213 AnySnapshot::NestedReadonly(s.take_nested_snapshot(read_observer))
214 }
215 AnySnapshot::NestedMutable(s) => {
216 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
217 }
218 AnySnapshot::Global(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
219 AnySnapshot::TransparentMutable(s) => {
220 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
221 }
222 AnySnapshot::TransparentReadonly(s) => {
223 AnySnapshot::TransparentReadonly(s.take_nested_snapshot(read_observer))
224 }
225 }
226 }
227
228 pub fn has_pending_changes(&self) -> bool {
230 match self {
231 AnySnapshot::Readonly(s) => s.has_pending_changes(),
232 AnySnapshot::Mutable(s) => s.has_pending_changes(),
233 AnySnapshot::NestedReadonly(s) => s.has_pending_changes(),
234 AnySnapshot::NestedMutable(s) => s.has_pending_changes(),
235 AnySnapshot::Global(s) => s.has_pending_changes(),
236 AnySnapshot::TransparentMutable(s) => s.has_pending_changes(),
237 AnySnapshot::TransparentReadonly(s) => s.has_pending_changes(),
238 }
239 }
240
241 pub fn dispose(&self) {
243 match self {
244 AnySnapshot::Readonly(s) => s.dispose(),
245 AnySnapshot::Mutable(s) => s.dispose(),
246 AnySnapshot::NestedReadonly(s) => s.dispose(),
247 AnySnapshot::NestedMutable(s) => s.dispose(),
248 AnySnapshot::Global(s) => s.dispose(),
249 AnySnapshot::TransparentMutable(s) => s.dispose(),
250 AnySnapshot::TransparentReadonly(s) => s.dispose(),
251 }
252 }
253
254 pub fn is_disposed(&self) -> bool {
256 match self {
257 AnySnapshot::Readonly(s) => s.is_disposed(),
258 AnySnapshot::Mutable(s) => s.is_disposed(),
259 AnySnapshot::NestedReadonly(s) => s.is_disposed(),
260 AnySnapshot::NestedMutable(s) => s.is_disposed(),
261 AnySnapshot::Global(s) => s.is_disposed(),
262 AnySnapshot::TransparentMutable(s) => s.is_disposed(),
263 AnySnapshot::TransparentReadonly(s) => s.is_disposed(),
264 }
265 }
266
267 pub fn record_read(&self, state: &dyn StateObject) {
269 match self {
270 AnySnapshot::Readonly(s) => s.record_read(state),
271 AnySnapshot::Mutable(s) => s.record_read(state),
272 AnySnapshot::NestedReadonly(s) => s.record_read(state),
273 AnySnapshot::NestedMutable(s) => s.record_read(state),
274 AnySnapshot::Global(s) => s.record_read(state),
275 AnySnapshot::TransparentMutable(s) => s.record_read(state),
276 AnySnapshot::TransparentReadonly(s) => s.record_read(state),
277 }
278 }
279
280 pub fn record_write(&self, state: Arc<dyn StateObject>) {
282 match self {
283 AnySnapshot::Readonly(s) => s.record_write(state),
284 AnySnapshot::Mutable(s) => s.record_write(state),
285 AnySnapshot::NestedReadonly(s) => s.record_write(state),
286 AnySnapshot::NestedMutable(s) => s.record_write(state),
287 AnySnapshot::Global(s) => s.record_write(state),
288 AnySnapshot::TransparentMutable(s) => s.record_write(state),
289 AnySnapshot::TransparentReadonly(s) => s.record_write(state),
290 }
291 }
292
293 pub fn apply(&self) -> SnapshotApplyResult {
295 match self {
296 AnySnapshot::Mutable(s) => s.apply(),
297 AnySnapshot::NestedMutable(s) => s.apply(),
298 AnySnapshot::Global(s) => s.apply(),
299 AnySnapshot::TransparentMutable(s) => s.apply(),
300 _ => panic!("Cannot apply a read-only snapshot"),
301 }
302 }
303
304 pub fn take_nested_mutable_snapshot(
306 &self,
307 read_observer: Option<ReadObserver>,
308 write_observer: Option<WriteObserver>,
309 ) -> AnySnapshot {
310 match self {
311 AnySnapshot::Mutable(s) => AnySnapshot::NestedMutable(
312 s.take_nested_mutable_snapshot(read_observer, write_observer),
313 ),
314 AnySnapshot::NestedMutable(s) => AnySnapshot::NestedMutable(
315 s.take_nested_mutable_snapshot(read_observer, write_observer),
316 ),
317 AnySnapshot::Global(s) => {
318 AnySnapshot::Mutable(s.take_nested_mutable_snapshot(read_observer, write_observer))
319 }
320 AnySnapshot::TransparentMutable(s) => AnySnapshot::TransparentMutable(
321 s.take_nested_mutable_snapshot(read_observer, write_observer),
322 ),
323 _ => panic!("Cannot take nested mutable snapshot from read-only snapshot"),
324 }
325 }
326}
327
thread_local! {
    /// The snapshot currently installed for this thread, or `None`.
    static CURRENT_SNAPSHOT: RefCell<Option<AnySnapshot>> = const { RefCell::new(None) };
}
332
333pub fn current_snapshot() -> Option<AnySnapshot> {
335 CURRENT_SNAPSHOT
336 .try_with(|cell| cell.borrow().clone())
337 .unwrap_or(None)
338}
339
340pub(crate) fn set_current_snapshot(snapshot: Option<AnySnapshot>) {
342 let _ = CURRENT_SNAPSHOT.try_with(|cell| {
343 *cell.borrow_mut() = snapshot;
344 });
345}
346
347pub fn take_mutable_snapshot(
352 read_observer: Option<ReadObserver>,
353 write_observer: Option<WriteObserver>,
354) -> Arc<MutableSnapshot> {
355 GlobalSnapshot::get_or_create().take_nested_mutable_snapshot(read_observer, write_observer)
356}
357
358pub fn take_transparent_observer_mutable_snapshot(
367 read_observer: Option<ReadObserver>,
368 write_observer: Option<WriteObserver>,
369) -> Arc<TransparentObserverMutableSnapshot> {
370 let parent = current_snapshot();
371 match parent {
372 Some(AnySnapshot::TransparentMutable(transparent)) if transparent.can_reuse() => {
373 transparent
375 }
376 _ => {
377 let current = current_snapshot()
380 .unwrap_or_else(|| AnySnapshot::Global(GlobalSnapshot::get_or_create()));
381 let id = current.snapshot_id();
382 let invalid = current.invalid();
383 TransparentObserverMutableSnapshot::new(
384 id,
385 invalid,
386 read_observer,
387 write_observer,
388 None,
389 )
390 }
391 }
392}
393
/// Allocates a record id by delegating to `runtime::allocate_record_id`.
pub fn allocate_record_id() -> SnapshotId {
    runtime::allocate_record_id()
}
398
/// Returns the value of `runtime::peek_next_snapshot_id`.
///
/// NOTE(review): presumably this does not consume the id — confirm against
/// the `runtime` module.
pub(crate) fn peek_next_snapshot_id() -> SnapshotId {
    runtime::peek_next_snapshot_id()
}
406
/// Source of the ids handed out by `register_apply_observer`; starts at 1.
static NEXT_OBSERVER_ID: AtomicUsize = AtomicUsize::new(1);

thread_local! {
    /// Apply observers registered on this thread, keyed by observer id.
    static APPLY_OBSERVERS: RefCell<HashMap<usize, ApplyObserver>> = RefCell::new(HashMap::default());
}

thread_local! {
    /// Snapshot id most recently stored via `set_last_write`, per state object id.
    static LAST_WRITES: RefCell<HashMap<StateObjectId, SnapshotId>> = RefCell::new(HashMap::default());
}

thread_local! {
    /// State objects tracked on this thread for unused-record cleanup.
    static EXTRA_STATE_OBJECTS: RefCell<crate::snapshot_weak_set::SnapshotWeakSet> = RefCell::new(crate::snapshot_weak_set::SnapshotWeakSet::new());
}

/// Minimum snapshot-id distance between cleanup passes while fewer than
/// `UNUSED_RECORD_CLEANUP_MIN_SIZE` objects are tracked.
const UNUSED_RECORD_CLEANUP_INTERVAL: SnapshotId = 2;
/// Minimum snapshot-id distance between cleanup passes once at least
/// `UNUSED_RECORD_CLEANUP_MIN_SIZE` objects are tracked.
const UNUSED_RECORD_CLEANUP_BUSY_INTERVAL: SnapshotId = 1;
/// Tracked-object count at which the busy interval applies.
const UNUSED_RECORD_CLEANUP_MIN_SIZE: usize = 64;

thread_local! {
    /// Snapshot id at which the last cleanup pass ran on this thread (initially 0).
    static LAST_UNUSED_RECORD_CLEANUP: Cell<SnapshotId> = const { Cell::new(0) };
}
439
440pub fn register_apply_observer(observer: ApplyObserver) -> ObserverHandle {
444 let id = NEXT_OBSERVER_ID.fetch_add(1, Ordering::SeqCst);
445 APPLY_OBSERVERS.with(|cell| {
446 cell.borrow_mut().insert(id, observer);
447 });
448 ObserverHandle {
449 kind: ObserverKind::Apply,
450 id,
451 }
452}
453
/// Registration handle for an observer.
///
/// Dropping the handle removes the observer it refers to, so discarding the
/// handle immediately unregisters the observer.
#[derive(Debug)]
#[must_use = "dropping an ObserverHandle unregisters the observer"]
pub struct ObserverHandle {
    /// Which registry the observer lives in.
    kind: ObserverKind,
    /// Id under which the observer was registered.
    id: usize,
}

/// The kinds of observers an [`ObserverHandle`] can refer to.
#[derive(Debug)]
enum ObserverKind {
    /// An observer registered through `register_apply_observer`.
    Apply,
}
465
466impl Drop for ObserverHandle {
467 fn drop(&mut self) {
468 match self.kind {
469 ObserverKind::Apply => {
470 APPLY_OBSERVERS.with(|cell| {
471 cell.borrow_mut().remove(&self.id);
472 });
473 }
474 }
475 }
476}
477
478pub(crate) fn notify_apply_observers(modified: &[Arc<dyn StateObject>], snapshot_id: SnapshotId) {
480 APPLY_OBSERVERS.with(|cell| {
482 let observers: Vec<ApplyObserver> = cell.borrow().values().cloned().collect();
483 for observer in observers.into_iter() {
484 observer(modified, snapshot_id);
485 }
486 });
487}
488
489#[allow(dead_code)]
491pub(crate) fn get_last_write(id: StateObjectId) -> Option<SnapshotId> {
492 LAST_WRITES.with(|cell| cell.borrow().get(&id).copied())
493}
494
495pub(crate) fn set_last_write(id: StateObjectId, snapshot_id: SnapshotId) {
497 LAST_WRITES.with(|cell| {
498 cell.borrow_mut().insert(id, snapshot_id);
499 });
500}
501
/// Test helper: removes every last-write entry stored on this thread.
#[cfg(test)]
pub(crate) fn clear_last_writes() {
    LAST_WRITES.with(|writes| writes.borrow_mut().clear());
}
509
510pub(crate) fn check_and_overwrite_unused_records_locked() {
518 EXTRA_STATE_OBJECTS.with(|cell| {
519 cell.borrow_mut().remove_if(|state| {
520 state.overwrite_unused_records()
522 });
523 });
524}
525
526pub(crate) fn maybe_check_and_overwrite_unused_records_locked(current_snapshot_id: SnapshotId) {
527 let should_run = EXTRA_STATE_OBJECTS.with(|cell| {
528 let set = cell.borrow();
529 if set.is_empty() {
530 return false;
531 }
532 let last_cleanup = LAST_UNUSED_RECORD_CLEANUP.with(|last| last.get());
533 let interval = if set.len() >= UNUSED_RECORD_CLEANUP_MIN_SIZE {
534 UNUSED_RECORD_CLEANUP_BUSY_INTERVAL
535 } else {
536 UNUSED_RECORD_CLEANUP_INTERVAL
537 };
538 current_snapshot_id.saturating_sub(last_cleanup) >= interval
539 });
540
541 if should_run {
542 LAST_UNUSED_RECORD_CLEANUP.with(|cell| cell.set(current_snapshot_id));
543 check_and_overwrite_unused_records_locked();
544 }
545}
546
547#[allow(dead_code)]
553pub(crate) fn process_for_unused_records_locked(state: &Arc<dyn crate::state::StateObject>) {
554 if state.overwrite_unused_records() {
555 EXTRA_STATE_OBJECTS.with(|cell| {
557 cell.borrow_mut().add_trait_object(state);
558 });
559 }
560}
561
/// Test helper: resets this thread's last-cleanup marker to 0.
#[cfg(test)]
pub(crate) fn clear_unused_record_cleanup_for_tests() {
    LAST_UNUSED_RECORD_CLEANUP.with(|cell| cell.set(0));
}
566
567pub(crate) fn optimistic_merges(
568 current_snapshot_id: SnapshotId,
569 base_parent_id: SnapshotId,
570 modified_objects: &[(StateObjectId, Arc<dyn StateObject>, SnapshotId)],
571 invalid_snapshots: &SnapshotIdSet,
572) -> Option<HashMap<usize, Rc<StateRecord>>> {
573 if modified_objects.is_empty() {
574 return None;
575 }
576
577 let mut result: Option<HashMap<usize, Rc<StateRecord>>> = None;
578
579 for (_, state, writer_id) in modified_objects.iter() {
580 let head = state.first_record();
581
582 let current = match crate::state::readable_record_for(
583 &head,
584 current_snapshot_id,
585 invalid_snapshots,
586 ) {
587 Some(record) => record,
588 None => continue,
589 };
590
591 let (previous_opt, found_base) = mutable::find_previous_record(&head, base_parent_id);
592 let previous = previous_opt?;
593
594 if !found_base || previous.snapshot_id() == crate::state::PREEXISTING_SNAPSHOT_ID {
595 continue;
596 }
597
598 if Rc::ptr_eq(¤t, &previous) {
599 continue;
600 }
601
602 let applied = mutable::find_record_by_id(&head, *writer_id)?;
603
604 let merged = state.merge_records(
605 Rc::clone(&previous),
606 Rc::clone(¤t),
607 Rc::clone(&applied),
608 )?;
609
610 result
611 .get_or_insert_with(HashMap::default)
612 .insert(Rc::as_ptr(¤t) as usize, merged);
613 }
614
615 result
616}
617
618#[allow(clippy::arc_with_non_send_sync)]
624pub fn merge_read_observers(
625 a: Option<ReadObserver>,
626 b: Option<ReadObserver>,
627) -> Option<ReadObserver> {
628 match (a, b) {
629 (None, None) => None,
630 (Some(a), None) => Some(a),
631 (None, Some(b)) => Some(b),
632 (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
633 a(state);
634 b(state);
635 })),
636 }
637}
638
639#[allow(clippy::arc_with_non_send_sync)]
645pub fn merge_write_observers(
646 a: Option<WriteObserver>,
647 b: Option<WriteObserver>,
648) -> Option<WriteObserver> {
649 match (a, b) {
650 (None, None) => None,
651 (Some(a), None) => Some(a),
652 (None, Some(b)) => Some(b),
653 (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
654 a(state);
655 b(state);
656 })),
657 }
658}
659
/// State held for a snapshot: identity, observers, recorded writes and
/// disposal bookkeeping.
pub(crate) struct SnapshotState {
    /// Id of the snapshot.
    pub(crate) id: Cell<SnapshotId>,
    /// Set of snapshot ids treated as invalid by this snapshot.
    pub(crate) invalid: RefCell<SnapshotIdSet>,
    /// Pin handle obtained at construction and released on dispose.
    pub(crate) pin_handle: Cell<PinHandle>,
    /// Set to `true` the first time `dispose` runs.
    pub(crate) disposed: Cell<bool>,
    /// Observer called by `record_read`, if any.
    pub(crate) read_observer: Option<ReadObserver>,
    /// Observer called by `record_write` on the first write to an object, if any.
    pub(crate) write_observer: Option<WriteObserver>,
    /// Written objects keyed by object id, each with the id of its latest writer.
    #[allow(clippy::type_complexity)]
    pub(crate) modified: RefCell<HashMap<StateObjectId, (Arc<dyn StateObject>, SnapshotId)>>,
    /// Callback run once when the state is disposed.
    on_dispose: RefCell<Option<Box<dyn FnOnce()>>>,
    /// When `true`, `dispose` also calls `close_snapshot` for this id.
    runtime_tracked: bool,
    /// Ids of child snapshots registered as pending.
    pending_children: RefCell<HashSet<SnapshotId>>,
}
685
686impl SnapshotState {
687 pub(crate) fn new(
688 id: SnapshotId,
689 invalid: SnapshotIdSet,
690 read_observer: Option<ReadObserver>,
691 write_observer: Option<WriteObserver>,
692 runtime_tracked: bool,
693 ) -> Self {
694 Self::new_with_pinning(
695 id,
696 invalid,
697 read_observer,
698 write_observer,
699 runtime_tracked,
700 true,
701 )
702 }
703
704 pub(crate) fn new_with_pinning(
709 id: SnapshotId,
710 invalid: SnapshotIdSet,
711 read_observer: Option<ReadObserver>,
712 write_observer: Option<WriteObserver>,
713 runtime_tracked: bool,
714 should_pin: bool,
715 ) -> Self {
716 let pin_handle = if should_pin {
717 snapshot_pinning::track_pinning(id, &invalid)
718 } else {
719 snapshot_pinning::PinHandle::INVALID
720 };
721 Self {
722 id: Cell::new(id),
723 invalid: RefCell::new(invalid),
724 pin_handle: Cell::new(pin_handle),
725 disposed: Cell::new(false),
726 read_observer,
727 write_observer,
728 modified: RefCell::new(HashMap::default()),
729 on_dispose: RefCell::new(None),
730 runtime_tracked,
731 pending_children: RefCell::new(HashSet::default()),
732 }
733 }
734
735 pub(crate) fn record_read(&self, state: &dyn StateObject) {
736 if let Some(ref observer) = self.read_observer {
737 observer(state);
738 }
739 }
740
741 pub(crate) fn record_write(&self, state: Arc<dyn StateObject>, writer_id: SnapshotId) {
742 let state_id = state.object_id().as_usize();
744
745 let mut modified = self.modified.borrow_mut();
746
747 match modified.entry(state_id) {
749 std::collections::hash_map::Entry::Vacant(e) => {
750 if let Some(ref observer) = self.write_observer {
751 observer(&*state);
752 }
753 e.insert((state, writer_id));
755 }
756 std::collections::hash_map::Entry::Occupied(mut e) => {
757 e.insert((state, writer_id));
759 }
760 }
761 }
762
763 pub(crate) fn dispose(&self) {
764 if !self.disposed.replace(true) {
765 let pin_handle = self.pin_handle.get();
766 snapshot_pinning::release_pinning(pin_handle);
767 if let Some(cb) = self.on_dispose.borrow_mut().take() {
768 cb();
769 }
770 if self.runtime_tracked {
771 close_snapshot(self.id.get());
772 }
773 }
774 }
775
776 pub(crate) fn add_pending_child(&self, id: SnapshotId) {
777 self.pending_children.borrow_mut().insert(id);
778 }
779
780 pub(crate) fn remove_pending_child(&self, id: SnapshotId) {
781 self.pending_children.borrow_mut().remove(&id);
782 }
783
784 pub(crate) fn has_pending_children(&self) -> bool {
785 !self.pending_children.borrow().is_empty()
786 }
787
788 pub(crate) fn pending_children(&self) -> Vec<SnapshotId> {
789 self.pending_children.borrow().iter().copied().collect()
790 }
791
792 pub(crate) fn set_on_dispose<F>(&self, f: F)
793 where
794 F: FnOnce() + 'static,
795 {
796 *self.on_dispose.borrow_mut() = Some(Box::new(f));
797 }
798}
799
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Minimal `StateObject` used as a fixture. Only `object_id` and `as_any`
    /// are implemented; the record-related methods are never reached by the
    /// tests in this module.
    struct TestStateObject {
        id: usize,
    }

    impl TestStateObject {
        fn new(id: usize) -> Arc<Self> {
            Arc::new(Self { id })
        }
    }

    impl StateObject for TestStateObject {
        fn object_id(&self) -> crate::state::ObjectId {
            crate::state::ObjectId(self.id)
        }

        fn first_record(&self) -> Rc<crate::state::StateRecord> {
            unimplemented!("Not needed for these tests")
        }

        fn readable_record(
            &self,
            _snapshot_id: SnapshotId,
            _invalid: &SnapshotIdSet,
        ) -> Rc<crate::state::StateRecord> {
            unimplemented!("Not needed for these tests")
        }

        fn prepend_state_record(&self, _record: Rc<crate::state::StateRecord>) {
            unimplemented!("Not needed for these tests")
        }

        fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
            unimplemented!("Not needed for these tests")
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Registers an apply observer that appends `tag` to `calls` each time it
    /// is notified.
    fn register_tagged(calls: &Arc<Mutex<Vec<i32>>>, tag: i32) -> ObserverHandle {
        let calls = Arc::clone(calls);
        register_apply_observer(Arc::new(move |_, _| {
            calls.lock().unwrap().push(tag);
        }))
    }

    /// Returns the recorded tags in sorted order and clears the log. Sorting
    /// keeps the assertions independent of registry iteration order.
    fn drain_sorted(calls: &Arc<Mutex<Vec<i32>>>) -> Vec<i32> {
        let mut seen = std::mem::take(&mut *calls.lock().unwrap());
        seen.sort_unstable();
        seen
    }

    #[test]
    fn test_apply_result_is_success() {
        assert!(SnapshotApplyResult::Success.is_success());
        assert!(!SnapshotApplyResult::Failure.is_success());
    }

    #[test]
    fn test_apply_result_is_failure() {
        assert!(!SnapshotApplyResult::Success.is_failure());
        assert!(SnapshotApplyResult::Failure.is_failure());
    }

    #[test]
    fn test_apply_result_check_success() {
        // Must not panic.
        SnapshotApplyResult::Success.check();
    }

    #[test]
    #[should_panic(expected = "Snapshot apply failed")]
    fn test_apply_result_check_failure() {
        SnapshotApplyResult::Failure.check();
    }

    #[test]
    fn test_merge_read_observers_both_none() {
        let result = merge_read_observers(None, None);
        assert!(result.is_none());
    }

    #[test]
    fn test_merge_read_observers_one_some() {
        let observer = Arc::new(|_: &dyn StateObject| {});
        let result = merge_read_observers(Some(observer.clone()), None);
        assert!(result.is_some());

        let result = merge_read_observers(None, Some(observer));
        assert!(result.is_some());
    }

    #[test]
    fn test_merge_write_observers_both_none() {
        let result = merge_write_observers(None, None);
        assert!(result.is_none());
    }

    #[test]
    fn test_merge_write_observers_one_some() {
        let observer = Arc::new(|_: &dyn StateObject| {});
        let result = merge_write_observers(Some(observer.clone()), None);
        assert!(result.is_some());

        let result = merge_write_observers(None, Some(observer));
        assert!(result.is_some());
    }

    #[test]
    fn test_current_snapshot_none_initially() {
        set_current_snapshot(None);
        assert!(current_snapshot().is_none());
    }

    #[test]
    fn test_apply_observer_receives_correct_modified_objects() {
        let received_count = Arc::new(Mutex::new(0));
        let received_snapshot_id = Arc::new(Mutex::new(0));

        let received_count_clone = received_count.clone();
        let received_snapshot_id_clone = received_snapshot_id.clone();

        let _handle = register_apply_observer(Arc::new(move |modified, snapshot_id| {
            *received_snapshot_id_clone.lock().unwrap() = snapshot_id;
            *received_count_clone.lock().unwrap() = modified.len();
        }));

        let obj1: Arc<dyn StateObject> = TestStateObject::new(42);
        let obj2: Arc<dyn StateObject> = TestStateObject::new(99);
        let modified = vec![obj1, obj2];

        notify_apply_observers(&modified, 123);

        assert_eq!(*received_snapshot_id.lock().unwrap(), 123);
        assert_eq!(*received_count.lock().unwrap(), 2);
    }

    #[test]
    fn test_apply_observer_receives_correct_snapshot_id() {
        let received_id = Arc::new(Mutex::new(0));
        let received_id_clone = received_id.clone();

        let _handle = register_apply_observer(Arc::new(move |_, snapshot_id| {
            *received_id_clone.lock().unwrap() = snapshot_id;
        }));

        notify_apply_observers(&[], 456);

        assert_eq!(*received_id.lock().unwrap(), 456);
    }

    #[test]
    fn test_multiple_apply_observers_all_called() {
        let call_count1 = Arc::new(Mutex::new(0));
        let call_count2 = Arc::new(Mutex::new(0));
        let call_count3 = Arc::new(Mutex::new(0));

        let call_count1_clone = call_count1.clone();
        let call_count2_clone = call_count2.clone();
        let call_count3_clone = call_count3.clone();

        let _handle1 = register_apply_observer(Arc::new(move |_, _| {
            *call_count1_clone.lock().unwrap() += 1;
        }));

        let _handle2 = register_apply_observer(Arc::new(move |_, _| {
            *call_count2_clone.lock().unwrap() += 1;
        }));

        let _handle3 = register_apply_observer(Arc::new(move |_, _| {
            *call_count3_clone.lock().unwrap() += 1;
        }));

        notify_apply_observers(&[], 1);

        assert_eq!(*call_count1.lock().unwrap(), 1);
        assert_eq!(*call_count2.lock().unwrap(), 1);
        assert_eq!(*call_count3.lock().unwrap(), 1);

        notify_apply_observers(&[], 2);

        assert_eq!(*call_count1.lock().unwrap(), 2);
        assert_eq!(*call_count2.lock().unwrap(), 2);
        assert_eq!(*call_count3.lock().unwrap(), 2);
    }

    /// `notify_apply_observers` does not filter on emptiness: observers are
    /// still invoked, with an empty slice.
    #[test]
    fn test_apply_observer_called_with_empty_modifications() {
        let call_count = Arc::new(Mutex::new(0));
        let call_count_clone = call_count.clone();

        let _handle = register_apply_observer(Arc::new(move |modified, _| {
            *call_count_clone.lock().unwrap() += 1;
            assert_eq!(modified.len(), 0);
        }));

        notify_apply_observers(&[], 1);

        assert_eq!(*call_count.lock().unwrap(), 1);
    }

    #[test]
    fn test_observer_handle_drop_removes_correct_observer() {
        let calls = Arc::new(Mutex::new(Vec::new()));

        let handle1 = register_tagged(&calls, 1);
        let handle2 = register_tagged(&calls, 2);
        let handle3 = register_tagged(&calls, 3);

        notify_apply_observers(&[], 1);
        assert_eq!(drain_sorted(&calls), vec![1, 2, 3]);

        // Dropping the middle handle must remove only observer 2.
        drop(handle2);
        notify_apply_observers(&[], 2);
        assert_eq!(drain_sorted(&calls), vec![1, 3]);

        drop(handle1);
        notify_apply_observers(&[], 3);
        assert_eq!(drain_sorted(&calls), vec![3]);

        drop(handle3);
        notify_apply_observers(&[], 4);
        assert!(drain_sorted(&calls).is_empty());
    }

    #[test]
    fn test_observer_handle_drop_in_different_orders() {
        // Reverse registration order: 3, 2, 1.
        {
            let calls = Arc::new(Mutex::new(Vec::new()));

            let h1 = register_tagged(&calls, 1);
            let h2 = register_tagged(&calls, 2);
            let h3 = register_tagged(&calls, 3);

            drop(h3);
            notify_apply_observers(&[], 1);
            assert_eq!(drain_sorted(&calls), vec![1, 2]);

            drop(h2);
            notify_apply_observers(&[], 2);
            assert_eq!(drain_sorted(&calls), vec![1]);

            drop(h1);
            notify_apply_observers(&[], 3);
            assert!(drain_sorted(&calls).is_empty());
        }

        // Registration order: 1, 2, 3.
        {
            let calls = Arc::new(Mutex::new(Vec::new()));

            let h1 = register_tagged(&calls, 1);
            let h2 = register_tagged(&calls, 2);
            let h3 = register_tagged(&calls, 3);

            drop(h1);
            notify_apply_observers(&[], 1);
            assert_eq!(drain_sorted(&calls), vec![2, 3]);

            drop(h2);
            notify_apply_observers(&[], 2);
            assert_eq!(drain_sorted(&calls), vec![3]);

            drop(h3);
            notify_apply_observers(&[], 3);
            assert!(drain_sorted(&calls).is_empty());
        }
    }

    #[test]
    fn test_remaining_observers_still_work_after_drop() {
        let calls = Arc::new(Mutex::new(Vec::new()));

        let calls1 = calls.clone();
        let handle1 = register_apply_observer(Arc::new(move |_, snapshot_id| {
            calls1.lock().unwrap().push((1, snapshot_id));
        }));

        let calls2 = calls.clone();
        let handle2 = register_apply_observer(Arc::new(move |_, snapshot_id| {
            calls2.lock().unwrap().push((2, snapshot_id));
        }));

        notify_apply_observers(&[], 100);
        assert_eq!(calls.lock().unwrap().len(), 2);
        calls.lock().unwrap().clear();

        drop(handle1);

        notify_apply_observers(&[], 200);
        assert_eq!(*calls.lock().unwrap(), vec![(2, 200)]);
        calls.lock().unwrap().clear();

        // An observer registered after a drop must receive notifications too.
        let calls3 = calls.clone();
        let _handle3 = register_apply_observer(Arc::new(move |_, snapshot_id| {
            calls3.lock().unwrap().push((3, snapshot_id));
        }));

        notify_apply_observers(&[], 300);
        let result = calls.lock().unwrap().clone();
        assert_eq!(result.len(), 2);
        assert!(result.contains(&(2, 300)));
        assert!(result.contains(&(3, 300)));

        drop(handle2);
    }

    #[test]
    fn test_observer_ids_are_unique() {
        let ids = Arc::new(Mutex::new(std::collections::HashSet::new()));

        let mut handles = Vec::new();

        for i in 0..100 {
            let ids_clone = ids.clone();
            let handle = register_apply_observer(Arc::new(move |_, _| {
                ids_clone.lock().unwrap().insert(i);
            }));
            handles.push(handle);
        }

        // If two registrations had shared an id, one observer would have
        // replaced the other and fewer than 100 would fire.
        notify_apply_observers(&[], 1);
        assert_eq!(ids.lock().unwrap().len(), 100);

        // Drop every other handle (original indices 0, 2, 4, ...).
        let mut index = 0usize;
        handles.retain(|_| {
            let keep = index % 2 == 1;
            index += 1;
            keep
        });

        ids.lock().unwrap().clear();
        notify_apply_observers(&[], 2);
        assert_eq!(ids.lock().unwrap().len(), 50);
    }

    #[test]
    fn test_state_object_storage_in_modified_set() {
        let state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);

        let state_obj: Arc<dyn StateObject> = TestStateObject::new(12345);

        state.record_write(state_obj.clone(), 1);

        let modified = state.modified.borrow();
        assert_eq!(modified.len(), 1);
        assert!(modified.contains_key(&12345));

        let (stored, writer_id) = modified.get(&12345).unwrap();
        assert_eq!(stored.object_id().as_usize(), 12345);
        assert_eq!(*writer_id, 1);
    }

    #[test]
    fn test_multiple_writes_to_same_state_object() {
        let state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);
        let state_obj: Arc<dyn StateObject> = TestStateObject::new(99999);

        state.record_write(state_obj.clone(), 1);
        assert_eq!(state.modified.borrow().len(), 1);

        // A second write keeps a single entry but updates the writer id.
        state.record_write(state_obj.clone(), 2);
        let modified = state.modified.borrow();
        assert_eq!(modified.len(), 1);
        assert!(modified.contains_key(&99999));
        let (_, writer_id) = modified.get(&99999).unwrap();
        assert_eq!(*writer_id, 2);
    }
}