1#![allow(clippy::arc_with_non_send_sync)]
24
25use crate::collections::map::HashMap; use crate::collections::map::HashSet;
27use crate::snapshot_id_set::{SnapshotId, SnapshotIdSet};
28use crate::snapshot_pinning::{self, PinHandle};
29use crate::state::{StateObject, StateRecord};
30use std::cell::{Cell, RefCell};
31use std::sync::atomic::{AtomicUsize, Ordering};
32use std::sync::{Arc, Weak};
33
34mod global;
35mod mutable;
36mod nested;
37mod readonly;
38mod runtime;
39mod transparent;
40
41#[cfg(test)]
42mod integration_tests;
43
44pub use global::{advance_global_snapshot, GlobalSnapshot};
45pub use mutable::MutableSnapshot;
46pub use nested::{NestedMutableSnapshot, NestedReadonlySnapshot};
47pub use readonly::ReadonlySnapshot;
48pub use transparent::{TransparentObserverMutableSnapshot, TransparentObserverSnapshot};
49
50pub(crate) use runtime::{allocate_snapshot, close_snapshot, with_runtime};
51#[cfg(test)]
52pub(crate) use runtime::{reset_runtime_for_tests, TestRuntimeGuard};
53
/// Callback invoked with a state object each time a read is recorded on a snapshot
/// (see `SnapshotState::record_read`).
///
/// Not `Send`/`Sync`; this module allows `clippy::arc_with_non_send_sync` for that reason.
pub type ReadObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;

/// Callback invoked with a state object the first time that object is written within a
/// snapshot (see `SnapshotState::record_write`).
pub type WriteObserver = Arc<dyn Fn(&dyn StateObject) + 'static>;

/// Callback registered through `register_apply_observer` and invoked by
/// `notify_apply_observers` with a slice of modified state objects and a snapshot id.
pub type ApplyObserver = Arc<dyn Fn(&[Arc<dyn StateObject>], SnapshotId) + 'static>;
62
/// Outcome of applying a snapshot's pending changes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SnapshotApplyResult {
    /// The apply succeeded.
    Success,
    /// The apply failed.
    Failure,
}

impl SnapshotApplyResult {
    /// Returns `true` when this is [`SnapshotApplyResult::Success`].
    pub fn is_success(&self) -> bool {
        *self == Self::Success
    }

    /// Returns `true` when this is [`SnapshotApplyResult::Failure`].
    pub fn is_failure(&self) -> bool {
        // Only two variants exist, so "failure" is exactly "not success".
        !self.is_success()
    }

    /// Asserts that the apply succeeded.
    ///
    /// # Panics
    ///
    /// Panics with `"Snapshot apply failed"` (reported at the caller's location) when this is
    /// [`SnapshotApplyResult::Failure`].
    #[track_caller]
    pub fn check(&self) {
        assert!(self.is_success(), "Snapshot apply failed");
    }
}
91
/// Numeric key for a state object, as produced by `object_id().as_usize()`; used to key
/// `SnapshotState::modified` and the last-write table.
pub type StateObjectId = usize;
94
/// Type-erased handle over every concrete snapshot kind in this module.
///
/// Each variant holds an `Arc` to the concrete snapshot, so cloning an `AnySnapshot` clones the
/// `Arc` (a reference-count bump), not the snapshot itself.
#[derive(Clone)]
pub enum AnySnapshot {
    Readonly(Arc<ReadonlySnapshot>),
    Mutable(Arc<MutableSnapshot>),
    NestedReadonly(Arc<NestedReadonlySnapshot>),
    NestedMutable(Arc<NestedMutableSnapshot>),
    Global(Arc<GlobalSnapshot>),
    TransparentMutable(Arc<TransparentObserverMutableSnapshot>),
    TransparentReadonly(Arc<TransparentObserverSnapshot>),
}
109
110impl AnySnapshot {
111 pub fn snapshot_id(&self) -> SnapshotId {
113 match self {
114 AnySnapshot::Readonly(s) => s.snapshot_id(),
115 AnySnapshot::Mutable(s) => s.snapshot_id(),
116 AnySnapshot::NestedReadonly(s) => s.snapshot_id(),
117 AnySnapshot::NestedMutable(s) => s.snapshot_id(),
118 AnySnapshot::Global(s) => s.snapshot_id(),
119 AnySnapshot::TransparentMutable(s) => s.snapshot_id(),
120 AnySnapshot::TransparentReadonly(s) => s.snapshot_id(),
121 }
122 }
123
124 pub fn invalid(&self) -> SnapshotIdSet {
126 match self {
127 AnySnapshot::Readonly(s) => s.invalid(),
128 AnySnapshot::Mutable(s) => s.invalid(),
129 AnySnapshot::NestedReadonly(s) => s.invalid(),
130 AnySnapshot::NestedMutable(s) => s.invalid(),
131 AnySnapshot::Global(s) => s.invalid(),
132 AnySnapshot::TransparentMutable(s) => s.invalid(),
133 AnySnapshot::TransparentReadonly(s) => s.invalid(),
134 }
135 }
136
137 pub fn is_valid(&self, id: SnapshotId) -> bool {
139 let snapshot_id = self.snapshot_id();
140 id <= snapshot_id && !self.invalid().get(id)
141 }
142
143 pub fn read_only(&self) -> bool {
145 match self {
146 AnySnapshot::Readonly(_) => true,
147 AnySnapshot::Mutable(_) => false,
148 AnySnapshot::NestedReadonly(_) => true,
149 AnySnapshot::NestedMutable(_) => false,
150 AnySnapshot::Global(_) => false,
151 AnySnapshot::TransparentMutable(_) => false,
152 AnySnapshot::TransparentReadonly(_) => true,
153 }
154 }
155
156 pub fn root(&self) -> AnySnapshot {
158 match self {
159 AnySnapshot::Readonly(s) => AnySnapshot::Readonly(s.root_readonly()),
160 AnySnapshot::Mutable(s) => AnySnapshot::Mutable(s.root_mutable()),
161 AnySnapshot::NestedReadonly(s) => AnySnapshot::NestedReadonly(s.root_nested_readonly()),
162 AnySnapshot::NestedMutable(s) => AnySnapshot::Mutable(s.root_mutable()),
163 AnySnapshot::Global(s) => AnySnapshot::Global(s.root_global()),
164 AnySnapshot::TransparentMutable(s) => {
165 AnySnapshot::TransparentMutable(s.root_transparent_mutable())
166 }
167 AnySnapshot::TransparentReadonly(s) => {
168 AnySnapshot::TransparentReadonly(s.root_transparent_readonly())
169 }
170 }
171 }
172
173 pub fn is_same_transparent(&self, other: &Arc<TransparentObserverMutableSnapshot>) -> bool {
175 matches!(self, AnySnapshot::TransparentMutable(snapshot) if Arc::ptr_eq(snapshot, other))
176 }
177
178 pub fn is_same_transparent_mutable(
180 &self,
181 other: &Arc<TransparentObserverMutableSnapshot>,
182 ) -> bool {
183 self.is_same_transparent(other)
184 }
185
186 pub fn is_same_transparent_readonly(&self, other: &Arc<TransparentObserverSnapshot>) -> bool {
188 matches!(self, AnySnapshot::TransparentReadonly(snapshot) if Arc::ptr_eq(snapshot, other))
189 }
190
191 pub fn enter<T>(&self, f: impl FnOnce() -> T) -> T {
193 match self {
194 AnySnapshot::Readonly(s) => s.enter(f),
195 AnySnapshot::Mutable(s) => s.enter(f),
196 AnySnapshot::NestedReadonly(s) => s.enter(f),
197 AnySnapshot::NestedMutable(s) => s.enter(f),
198 AnySnapshot::Global(s) => s.enter(f),
199 AnySnapshot::TransparentMutable(s) => s.enter(f),
200 AnySnapshot::TransparentReadonly(s) => s.enter(f),
201 }
202 }
203
204 pub fn take_nested_snapshot(&self, read_observer: Option<ReadObserver>) -> AnySnapshot {
206 match self {
207 AnySnapshot::Readonly(s) => {
208 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
209 }
210 AnySnapshot::Mutable(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
211 AnySnapshot::NestedReadonly(s) => {
212 AnySnapshot::NestedReadonly(s.take_nested_snapshot(read_observer))
213 }
214 AnySnapshot::NestedMutable(s) => {
215 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
216 }
217 AnySnapshot::Global(s) => AnySnapshot::Readonly(s.take_nested_snapshot(read_observer)),
218 AnySnapshot::TransparentMutable(s) => {
219 AnySnapshot::Readonly(s.take_nested_snapshot(read_observer))
220 }
221 AnySnapshot::TransparentReadonly(s) => {
222 AnySnapshot::TransparentReadonly(s.take_nested_snapshot(read_observer))
223 }
224 }
225 }
226
227 pub fn has_pending_changes(&self) -> bool {
229 match self {
230 AnySnapshot::Readonly(s) => s.has_pending_changes(),
231 AnySnapshot::Mutable(s) => s.has_pending_changes(),
232 AnySnapshot::NestedReadonly(s) => s.has_pending_changes(),
233 AnySnapshot::NestedMutable(s) => s.has_pending_changes(),
234 AnySnapshot::Global(s) => s.has_pending_changes(),
235 AnySnapshot::TransparentMutable(s) => s.has_pending_changes(),
236 AnySnapshot::TransparentReadonly(s) => s.has_pending_changes(),
237 }
238 }
239
240 pub fn dispose(&self) {
242 match self {
243 AnySnapshot::Readonly(s) => s.dispose(),
244 AnySnapshot::Mutable(s) => s.dispose(),
245 AnySnapshot::NestedReadonly(s) => s.dispose(),
246 AnySnapshot::NestedMutable(s) => s.dispose(),
247 AnySnapshot::Global(s) => s.dispose(),
248 AnySnapshot::TransparentMutable(s) => s.dispose(),
249 AnySnapshot::TransparentReadonly(s) => s.dispose(),
250 }
251 }
252
253 pub fn is_disposed(&self) -> bool {
255 match self {
256 AnySnapshot::Readonly(s) => s.is_disposed(),
257 AnySnapshot::Mutable(s) => s.is_disposed(),
258 AnySnapshot::NestedReadonly(s) => s.is_disposed(),
259 AnySnapshot::NestedMutable(s) => s.is_disposed(),
260 AnySnapshot::Global(s) => s.is_disposed(),
261 AnySnapshot::TransparentMutable(s) => s.is_disposed(),
262 AnySnapshot::TransparentReadonly(s) => s.is_disposed(),
263 }
264 }
265
266 pub fn record_read(&self, state: &dyn StateObject) {
268 match self {
269 AnySnapshot::Readonly(s) => s.record_read(state),
270 AnySnapshot::Mutable(s) => s.record_read(state),
271 AnySnapshot::NestedReadonly(s) => s.record_read(state),
272 AnySnapshot::NestedMutable(s) => s.record_read(state),
273 AnySnapshot::Global(s) => s.record_read(state),
274 AnySnapshot::TransparentMutable(s) => s.record_read(state),
275 AnySnapshot::TransparentReadonly(s) => s.record_read(state),
276 }
277 }
278
279 pub fn record_write(&self, state: Arc<dyn StateObject>) {
281 match self {
282 AnySnapshot::Readonly(s) => s.record_write(state),
283 AnySnapshot::Mutable(s) => s.record_write(state),
284 AnySnapshot::NestedReadonly(s) => s.record_write(state),
285 AnySnapshot::NestedMutable(s) => s.record_write(state),
286 AnySnapshot::Global(s) => s.record_write(state),
287 AnySnapshot::TransparentMutable(s) => s.record_write(state),
288 AnySnapshot::TransparentReadonly(s) => s.record_write(state),
289 }
290 }
291
292 pub fn apply(&self) -> SnapshotApplyResult {
294 match self {
295 AnySnapshot::Mutable(s) => s.apply(),
296 AnySnapshot::NestedMutable(s) => s.apply(),
297 AnySnapshot::Global(s) => s.apply(),
298 AnySnapshot::TransparentMutable(s) => s.apply(),
299 _ => panic!("Cannot apply a read-only snapshot"),
300 }
301 }
302
303 pub fn take_nested_mutable_snapshot(
305 &self,
306 read_observer: Option<ReadObserver>,
307 write_observer: Option<WriteObserver>,
308 ) -> AnySnapshot {
309 match self {
310 AnySnapshot::Mutable(s) => AnySnapshot::NestedMutable(
311 s.take_nested_mutable_snapshot(read_observer, write_observer),
312 ),
313 AnySnapshot::NestedMutable(s) => AnySnapshot::NestedMutable(
314 s.take_nested_mutable_snapshot(read_observer, write_observer),
315 ),
316 AnySnapshot::Global(s) => {
317 AnySnapshot::Mutable(s.take_nested_mutable_snapshot(read_observer, write_observer))
318 }
319 AnySnapshot::TransparentMutable(s) => AnySnapshot::TransparentMutable(
320 s.take_nested_mutable_snapshot(read_observer, write_observer),
321 ),
322 _ => panic!("Cannot take nested mutable snapshot from read-only snapshot"),
323 }
324 }
325}
326
thread_local! {
    /// The snapshot that is current on this thread, if any. Read via `current_snapshot` and
    /// replaced via `set_current_snapshot`.
    static CURRENT_SNAPSHOT: RefCell<Option<AnySnapshot>> = const { RefCell::new(None) };
}
331
332pub fn current_snapshot() -> Option<AnySnapshot> {
334 CURRENT_SNAPSHOT
335 .try_with(|cell| cell.borrow().clone())
336 .unwrap_or(None)
337}
338
339pub(crate) fn set_current_snapshot(snapshot: Option<AnySnapshot>) {
341 let _ = CURRENT_SNAPSHOT.try_with(|cell| {
342 *cell.borrow_mut() = snapshot;
343 });
344}
345
346pub fn take_mutable_snapshot(
351 read_observer: Option<ReadObserver>,
352 write_observer: Option<WriteObserver>,
353) -> Arc<MutableSnapshot> {
354 GlobalSnapshot::get_or_create().take_nested_mutable_snapshot(read_observer, write_observer)
355}
356
357pub fn take_transparent_observer_mutable_snapshot(
366 read_observer: Option<ReadObserver>,
367 write_observer: Option<WriteObserver>,
368) -> Arc<TransparentObserverMutableSnapshot> {
369 let parent = current_snapshot();
370 match parent {
371 Some(AnySnapshot::TransparentMutable(transparent)) if transparent.can_reuse() => {
372 transparent
374 }
375 _ => {
376 let current = current_snapshot()
379 .unwrap_or_else(|| AnySnapshot::Global(GlobalSnapshot::get_or_create()));
380 let id = current.snapshot_id();
381 let invalid = current.invalid();
382 TransparentObserverMutableSnapshot::new(
383 id,
384 invalid,
385 read_observer,
386 write_observer,
387 None,
388 )
389 }
390 }
391}
392
393pub fn allocate_record_id() -> SnapshotId {
395 runtime::allocate_record_id()
396}
397
398pub(crate) fn peek_next_snapshot_id() -> SnapshotId {
403 runtime::peek_next_snapshot_id()
404}
405
/// Source of ids for the `ObserverHandle`s returned by `register_apply_observer`.
///
/// Process-wide (atomic), while the registries below are per-thread.
static NEXT_OBSERVER_ID: AtomicUsize = AtomicUsize::new(1);

thread_local! {
    /// Apply observers registered on this thread, keyed by their `ObserverHandle` id.
    static APPLY_OBSERVERS: RefCell<HashMap<usize, ApplyObserver>> = RefCell::new(HashMap::default());
}

thread_local! {
    /// Snapshot id most recently recorded via `set_last_write`, per state object id.
    // NOTE(review): within this file, entries are only removed by the test-only
    // `clear_last_writes`; confirm whether this map can grow without bound on long-lived threads.
    static LAST_WRITES: RefCell<HashMap<StateObjectId, SnapshotId>> = RefCell::new(HashMap::default());
}

thread_local! {
    /// State objects added by `process_for_unused_records_locked` and revisited by
    /// `check_and_overwrite_unused_records_locked`.
    static EXTRA_STATE_OBJECTS: RefCell<crate::snapshot_weak_set::SnapshotWeakSet> = RefCell::new(crate::snapshot_weak_set::SnapshotWeakSet::new());
}
430
431pub fn register_apply_observer(observer: ApplyObserver) -> ObserverHandle {
435 let id = NEXT_OBSERVER_ID.fetch_add(1, Ordering::SeqCst);
436 APPLY_OBSERVERS.with(|cell| {
437 cell.borrow_mut().insert(id, observer);
438 });
439 ObserverHandle {
440 kind: ObserverKind::Apply,
441 id,
442 }
443}
444
445pub struct ObserverHandle {
449 kind: ObserverKind,
450 id: usize,
451}
452
/// Identifies which observer registry an `ObserverHandle` belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ObserverKind {
    /// Registered in `APPLY_OBSERVERS`.
    Apply,
}
456
457impl Drop for ObserverHandle {
458 fn drop(&mut self) {
459 match self.kind {
460 ObserverKind::Apply => {
461 APPLY_OBSERVERS.with(|cell| {
462 cell.borrow_mut().remove(&self.id);
463 });
464 }
465 }
466 }
467}
468
469pub(crate) fn notify_apply_observers(modified: &[Arc<dyn StateObject>], snapshot_id: SnapshotId) {
471 APPLY_OBSERVERS.with(|cell| {
473 let observers: Vec<ApplyObserver> = cell.borrow().values().cloned().collect();
474 for observer in observers.into_iter() {
475 observer(modified, snapshot_id);
476 }
477 });
478}
479
480#[allow(dead_code)]
482pub(crate) fn get_last_write(id: StateObjectId) -> Option<SnapshotId> {
483 LAST_WRITES.with(|cell| cell.borrow().get(&id).copied())
484}
485
486pub(crate) fn set_last_write(id: StateObjectId, snapshot_id: SnapshotId) {
488 LAST_WRITES.with(|cell| {
489 cell.borrow_mut().insert(id, snapshot_id);
490 });
491}
492
/// Test helper: forgets every entry recorded via `set_last_write` on this thread.
#[cfg(test)]
pub(crate) fn clear_last_writes() {
    LAST_WRITES.with(|writes| writes.borrow_mut().clear());
}
500
501pub(crate) fn check_and_overwrite_unused_records_locked() {
509 EXTRA_STATE_OBJECTS.with(|cell| {
510 cell.borrow_mut().remove_if(|state| {
511 state.overwrite_unused_records()
513 });
514 });
515}
516
517#[allow(dead_code)]
523pub(crate) fn process_for_unused_records_locked(state: &Arc<dyn crate::state::StateObject>) {
524 if state.overwrite_unused_records() {
525 EXTRA_STATE_OBJECTS.with(|cell| {
527 cell.borrow_mut().add_trait_object(state);
528 });
529 }
530}
531
532pub(crate) fn optimistic_merges(
533 current_snapshot_id: SnapshotId,
534 base_parent_id: SnapshotId,
535 modified_objects: &[(StateObjectId, Arc<dyn StateObject>, SnapshotId)],
536 invalid_snapshots: &SnapshotIdSet,
537) -> Option<HashMap<usize, Arc<StateRecord>>> {
538 if modified_objects.is_empty() {
539 return None;
540 }
541
542 let mut result: Option<HashMap<usize, Arc<StateRecord>>> = None;
543
544 for (_, state, writer_id) in modified_objects.iter() {
545 let head = state.first_record();
546
547 let current = match crate::state::readable_record_for(
548 &head,
549 current_snapshot_id,
550 invalid_snapshots,
551 ) {
552 Some(record) => record,
553 None => continue,
554 };
555
556 let (previous_opt, found_base) = mutable::find_previous_record(&head, base_parent_id);
557 let previous = previous_opt?;
558
559 if !found_base || previous.snapshot_id() == crate::state::PREEXISTING_SNAPSHOT_ID {
560 continue;
561 }
562
563 if Arc::ptr_eq(¤t, &previous) {
564 continue;
565 }
566
567 let applied = mutable::find_record_by_id(&head, *writer_id)?;
568
569 let merged = state.merge_records(
570 Arc::clone(&previous),
571 Arc::clone(¤t),
572 Arc::clone(&applied),
573 )?;
574
575 result
576 .get_or_insert_with(HashMap::default)
577 .insert(Arc::as_ptr(¤t) as usize, merged);
578 }
579
580 result
581}
582
583#[allow(clippy::arc_with_non_send_sync)]
589pub fn merge_read_observers(
590 a: Option<ReadObserver>,
591 b: Option<ReadObserver>,
592) -> Option<ReadObserver> {
593 match (a, b) {
594 (None, None) => None,
595 (Some(a), None) => Some(a),
596 (None, Some(b)) => Some(b),
597 (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
598 a(state);
599 b(state);
600 })),
601 }
602}
603
604#[allow(clippy::arc_with_non_send_sync)]
610pub fn merge_write_observers(
611 a: Option<WriteObserver>,
612 b: Option<WriteObserver>,
613) -> Option<WriteObserver> {
614 match (a, b) {
615 (None, None) => None,
616 (Some(a), None) => Some(a),
617 (None, Some(b)) => Some(b),
618 (Some(a), Some(b)) => Some(Arc::new(move |state: &dyn StateObject| {
619 a(state);
620 b(state);
621 })),
622 }
623}
624
/// Per-snapshot bookkeeping: id, invalid set, pin handle, observers, written state objects,
/// dispose callback and pending child snapshots. (Presumably embedded by the concrete snapshot
/// types in the submodules; TODO confirm.)
pub(crate) struct SnapshotState {
    /// This snapshot's id.
    pub(crate) id: Cell<SnapshotId>,
    /// Snapshot ids treated as invalid from this snapshot.
    pub(crate) invalid: RefCell<SnapshotIdSet>,
    /// Handle from `snapshot_pinning::track_pinning`; released by `dispose`.
    pub(crate) pin_handle: Cell<PinHandle>,
    /// Set to `true` the first time `dispose` runs.
    pub(crate) disposed: Cell<bool>,
    /// Called by `record_read` for every recorded read.
    pub(crate) read_observer: Option<ReadObserver>,
    /// Called by `record_write` the first time each state object is written.
    pub(crate) write_observer: Option<WriteObserver>,
    /// Written state objects keyed by `object_id().as_usize()`, each paired with the writer id
    /// passed to the most recent `record_write`.
    #[allow(clippy::type_complexity)]
    pub(crate) modified: RefCell<HashMap<StateObjectId, (Arc<dyn StateObject>, SnapshotId)>>,
    /// One-shot callback run by `dispose`, installed via `set_on_dispose`.
    on_dispose: RefCell<Option<Box<dyn FnOnce()>>>,
    /// When `true`, `dispose` also calls `close_snapshot` for this snapshot's id.
    runtime_tracked: bool,
    /// Ids added via `add_pending_child` and not yet removed via `remove_pending_child`.
    pending_children: RefCell<HashSet<SnapshotId>>,
}
650
651impl SnapshotState {
652 pub(crate) fn new(
653 id: SnapshotId,
654 invalid: SnapshotIdSet,
655 read_observer: Option<ReadObserver>,
656 write_observer: Option<WriteObserver>,
657 runtime_tracked: bool,
658 ) -> Self {
659 let pin_handle = snapshot_pinning::track_pinning(id, &invalid);
660 Self {
661 id: Cell::new(id),
662 invalid: RefCell::new(invalid),
663 pin_handle: Cell::new(pin_handle),
664 disposed: Cell::new(false),
665 read_observer,
666 write_observer,
667 modified: RefCell::new(HashMap::default()),
668 on_dispose: RefCell::new(None),
669 runtime_tracked,
670 pending_children: RefCell::new(HashSet::default()),
671 }
672 }
673
674 pub(crate) fn record_read(&self, state: &dyn StateObject) {
675 if let Some(ref observer) = self.read_observer {
676 observer(state);
677 }
678 }
679
680 pub(crate) fn record_write(&self, state: Arc<dyn StateObject>, writer_id: SnapshotId) {
681 let state_id = state.object_id().as_usize();
683
684 let mut modified = self.modified.borrow_mut();
685
686 match modified.entry(state_id) {
688 std::collections::hash_map::Entry::Vacant(e) => {
689 if let Some(ref observer) = self.write_observer {
690 observer(&*state);
691 }
692 e.insert((state, writer_id));
694 }
695 std::collections::hash_map::Entry::Occupied(mut e) => {
696 e.insert((state, writer_id));
698 }
699 }
700 }
701
702 pub(crate) fn dispose(&self) {
703 if !self.disposed.replace(true) {
704 let pin_handle = self.pin_handle.get();
705 snapshot_pinning::release_pinning(pin_handle);
706 if let Some(cb) = self.on_dispose.borrow_mut().take() {
707 cb();
708 }
709 if self.runtime_tracked {
710 close_snapshot(self.id.get());
711 }
712 }
713 }
714
715 pub(crate) fn add_pending_child(&self, id: SnapshotId) {
716 self.pending_children.borrow_mut().insert(id);
717 }
718
719 pub(crate) fn remove_pending_child(&self, id: SnapshotId) {
720 self.pending_children.borrow_mut().remove(&id);
721 }
722
723 pub(crate) fn has_pending_children(&self) -> bool {
724 !self.pending_children.borrow().is_empty()
725 }
726
727 pub(crate) fn pending_children(&self) -> Vec<SnapshotId> {
728 self.pending_children.borrow().iter().copied().collect()
729 }
730
731 pub(crate) fn set_on_dispose<F>(&self, f: F)
732 where
733 F: FnOnce() + 'static,
734 {
735 *self.on_dispose.borrow_mut() = Some(Box::new(f));
736 }
737}
738
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Minimal `StateObject` with a configurable id. The record-related methods are never
    /// called by these tests.
    struct TestStateObject {
        id: usize,
    }

    impl TestStateObject {
        fn new(id: usize) -> Arc<Self> {
            Arc::new(Self { id })
        }
    }

    impl StateObject for TestStateObject {
        fn object_id(&self) -> crate::state::ObjectId {
            crate::state::ObjectId(self.id)
        }

        fn first_record(&self) -> Arc<crate::state::StateRecord> {
            unimplemented!("Not needed for observer tests")
        }

        fn readable_record(
            &self,
            _snapshot_id: SnapshotId,
            _invalid: &SnapshotIdSet,
        ) -> Arc<crate::state::StateRecord> {
            unimplemented!("Not needed for observer tests")
        }

        fn prepend_state_record(&self, _record: Arc<crate::state::StateRecord>) {
            unimplemented!("Not needed for observer tests")
        }

        fn promote_record(&self, _child_id: SnapshotId) -> Result<(), &'static str> {
            unimplemented!("Not needed for observer tests")
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    #[test]
    fn test_apply_result_is_success() {
        assert!(SnapshotApplyResult::Success.is_success());
        assert!(!SnapshotApplyResult::Failure.is_success());
    }

    #[test]
    fn test_apply_result_is_failure() {
        assert!(!SnapshotApplyResult::Success.is_failure());
        assert!(SnapshotApplyResult::Failure.is_failure());
    }

    #[test]
    fn test_apply_result_check_success() {
        SnapshotApplyResult::Success.check();
    }

    #[test]
    #[should_panic(expected = "Snapshot apply failed")]
    fn test_apply_result_check_failure() {
        SnapshotApplyResult::Failure.check();
    }

    #[test]
    fn test_merge_read_observers_both_none() {
        let result = merge_read_observers(None, None);
        assert!(result.is_none());
    }

    #[test]
    fn test_merge_read_observers_one_some() {
        let observer = Arc::new(|_: &dyn StateObject| {});
        let result = merge_read_observers(Some(observer.clone()), None);
        assert!(result.is_some());

        let result = merge_read_observers(None, Some(observer));
        assert!(result.is_some());
    }

    #[test]
    fn test_merge_write_observers_both_none() {
        let result = merge_write_observers(None, None);
        assert!(result.is_none());
    }

    #[test]
    fn test_merge_write_observers_one_some() {
        let observer = Arc::new(|_: &dyn StateObject| {});
        let result = merge_write_observers(Some(observer.clone()), None);
        assert!(result.is_some());

        let result = merge_write_observers(None, Some(observer));
        assert!(result.is_some());
    }

    #[test]
    fn test_current_snapshot_none_after_clearing() {
        set_current_snapshot(None);
        assert!(current_snapshot().is_none());
    }

    #[test]
    fn test_apply_observer_receives_correct_modified_objects() {
        let received_count = Arc::new(Mutex::new(0));
        let received_snapshot_id = Arc::new(Mutex::new(0));

        let received_count_clone = received_count.clone();
        let received_snapshot_id_clone = received_snapshot_id.clone();

        let _handle = register_apply_observer(Arc::new(move |modified, snapshot_id| {
            *received_snapshot_id_clone.lock().unwrap() = snapshot_id;
            *received_count_clone.lock().unwrap() = modified.len();
        }));

        let obj1: Arc<dyn StateObject> = TestStateObject::new(42);
        let obj2: Arc<dyn StateObject> = TestStateObject::new(99);
        let modified = vec![obj1, obj2];

        notify_apply_observers(&modified, 123);

        assert_eq!(*received_snapshot_id.lock().unwrap(), 123);
        assert_eq!(*received_count.lock().unwrap(), 2);
    }

    #[test]
    fn test_apply_observer_receives_correct_snapshot_id() {
        let received_id = Arc::new(Mutex::new(0));
        let received_id_clone = received_id.clone();

        let _handle = register_apply_observer(Arc::new(move |_, snapshot_id| {
            *received_id_clone.lock().unwrap() = snapshot_id;
        }));

        notify_apply_observers(&[], 456);

        assert_eq!(*received_id.lock().unwrap(), 456);
    }

    #[test]
    fn test_multiple_apply_observers_all_called() {
        let call_count1 = Arc::new(Mutex::new(0));
        let call_count2 = Arc::new(Mutex::new(0));
        let call_count3 = Arc::new(Mutex::new(0));

        let call_count1_clone = call_count1.clone();
        let call_count2_clone = call_count2.clone();
        let call_count3_clone = call_count3.clone();

        let _handle1 = register_apply_observer(Arc::new(move |_, _| {
            *call_count1_clone.lock().unwrap() += 1;
        }));

        let _handle2 = register_apply_observer(Arc::new(move |_, _| {
            *call_count2_clone.lock().unwrap() += 1;
        }));

        let _handle3 = register_apply_observer(Arc::new(move |_, _| {
            *call_count3_clone.lock().unwrap() += 1;
        }));

        notify_apply_observers(&[], 1);

        assert_eq!(*call_count1.lock().unwrap(), 1);
        assert_eq!(*call_count2.lock().unwrap(), 1);
        assert_eq!(*call_count3.lock().unwrap(), 1);

        notify_apply_observers(&[], 2);

        assert_eq!(*call_count1.lock().unwrap(), 2);
        assert_eq!(*call_count2.lock().unwrap(), 2);
        assert_eq!(*call_count3.lock().unwrap(), 2);
    }

    /// Observers are still invoked when the modified set is empty.
    #[test]
    fn test_apply_observer_called_for_empty_modifications() {
        let call_count = Arc::new(Mutex::new(0));
        let call_count_clone = call_count.clone();

        let _handle = register_apply_observer(Arc::new(move |modified, _| {
            *call_count_clone.lock().unwrap() += 1;
            assert_eq!(modified.len(), 0);
        }));

        notify_apply_observers(&[], 1);

        assert_eq!(*call_count.lock().unwrap(), 1);
    }

    #[test]
    fn test_observer_handle_drop_removes_correct_observer() {
        let calls = Arc::new(Mutex::new(Vec::new()));

        let calls1 = calls.clone();
        let handle1 = register_apply_observer(Arc::new(move |_, _| {
            calls1.lock().unwrap().push(1);
        }));

        let calls2 = calls.clone();
        let handle2 = register_apply_observer(Arc::new(move |_, _| {
            calls2.lock().unwrap().push(2);
        }));

        let calls3 = calls.clone();
        let handle3 = register_apply_observer(Arc::new(move |_, _| {
            calls3.lock().unwrap().push(3);
        }));

        notify_apply_observers(&[], 1);
        let result = calls.lock().unwrap().clone();
        assert_eq!(result.len(), 3);
        assert!(result.contains(&1));
        assert!(result.contains(&2));
        assert!(result.contains(&3));
        calls.lock().unwrap().clear();

        drop(handle2);

        notify_apply_observers(&[], 2);
        let result = calls.lock().unwrap().clone();
        assert_eq!(result.len(), 2);
        assert!(result.contains(&1));
        assert!(result.contains(&3));
        assert!(!result.contains(&2));
        calls.lock().unwrap().clear();

        drop(handle1);

        notify_apply_observers(&[], 3);
        let result = calls.lock().unwrap().clone();
        assert_eq!(result.len(), 1);
        assert!(result.contains(&3));
        calls.lock().unwrap().clear();

        drop(handle3);

        notify_apply_observers(&[], 4);
        assert_eq!(calls.lock().unwrap().len(), 0);
    }

    #[test]
    fn test_observer_handle_drop_in_different_orders() {
        // Drop in reverse registration order.
        {
            let calls = Arc::new(Mutex::new(Vec::new()));

            let calls1 = calls.clone();
            let h1 = register_apply_observer(Arc::new(move |_, _| {
                calls1.lock().unwrap().push(1);
            }));

            let calls2 = calls.clone();
            let h2 = register_apply_observer(Arc::new(move |_, _| {
                calls2.lock().unwrap().push(2);
            }));

            let calls3 = calls.clone();
            let h3 = register_apply_observer(Arc::new(move |_, _| {
                calls3.lock().unwrap().push(3);
            }));

            drop(h3);
            notify_apply_observers(&[], 1);
            let result = calls.lock().unwrap().clone();
            assert!(result.contains(&1) && result.contains(&2) && !result.contains(&3));
            calls.lock().unwrap().clear();

            drop(h2);
            notify_apply_observers(&[], 2);
            let result = calls.lock().unwrap().clone();
            assert_eq!(result.len(), 1);
            assert!(result.contains(&1));
            calls.lock().unwrap().clear();

            drop(h1);
            notify_apply_observers(&[], 3);
            assert_eq!(calls.lock().unwrap().len(), 0);
        }

        // Drop in registration order.
        {
            let calls = Arc::new(Mutex::new(Vec::new()));

            let calls1 = calls.clone();
            let h1 = register_apply_observer(Arc::new(move |_, _| {
                calls1.lock().unwrap().push(1);
            }));

            let calls2 = calls.clone();
            let h2 = register_apply_observer(Arc::new(move |_, _| {
                calls2.lock().unwrap().push(2);
            }));

            let calls3 = calls.clone();
            let h3 = register_apply_observer(Arc::new(move |_, _| {
                calls3.lock().unwrap().push(3);
            }));

            drop(h1);
            notify_apply_observers(&[], 1);
            let result = calls.lock().unwrap().clone();
            assert!(!result.contains(&1) && result.contains(&2) && result.contains(&3));
            calls.lock().unwrap().clear();

            drop(h2);
            notify_apply_observers(&[], 2);
            let result = calls.lock().unwrap().clone();
            assert_eq!(result.len(), 1);
            assert!(result.contains(&3));
            calls.lock().unwrap().clear();

            drop(h3);
            notify_apply_observers(&[], 3);
            assert_eq!(calls.lock().unwrap().len(), 0);
        }
    }

    #[test]
    fn test_remaining_observers_still_work_after_drop() {
        let calls = Arc::new(Mutex::new(Vec::new()));

        let calls1 = calls.clone();
        let handle1 = register_apply_observer(Arc::new(move |_, snapshot_id| {
            calls1.lock().unwrap().push((1, snapshot_id));
        }));

        let calls2 = calls.clone();
        let handle2 = register_apply_observer(Arc::new(move |_, snapshot_id| {
            calls2.lock().unwrap().push((2, snapshot_id));
        }));

        notify_apply_observers(&[], 100);
        assert_eq!(calls.lock().unwrap().len(), 2);
        calls.lock().unwrap().clear();

        drop(handle1);

        notify_apply_observers(&[], 200);
        assert_eq!(*calls.lock().unwrap(), vec![(2, 200)]);
        calls.lock().unwrap().clear();

        // A newly registered observer coexists with the surviving one.
        let calls3 = calls.clone();
        let _handle3 = register_apply_observer(Arc::new(move |_, snapshot_id| {
            calls3.lock().unwrap().push((3, snapshot_id));
        }));

        notify_apply_observers(&[], 300);
        let result = calls.lock().unwrap().clone();
        assert_eq!(result.len(), 2);
        assert!(result.contains(&(2, 300)));
        assert!(result.contains(&(3, 300)));

        drop(handle2);
    }

    #[test]
    fn test_observer_ids_are_unique() {
        let ids = Arc::new(Mutex::new(std::collections::HashSet::new()));

        let mut handles = Vec::new();

        // With colliding ids, later registrations would overwrite earlier ones and fewer than
        // 100 observers would fire.
        for i in 0..100 {
            let ids_clone = ids.clone();
            let handle = register_apply_observer(Arc::new(move |_, _| {
                ids_clone.lock().unwrap().insert(i);
            }));
            handles.push(handle);
        }

        notify_apply_observers(&[], 1);
        assert_eq!(ids.lock().unwrap().len(), 100);

        // Each removal shifts the remaining handles left, so removing at `index` drops the
        // handle originally registered at `2 * index`: every even-numbered observer goes away.
        for index in 0..50 {
            drop(handles.remove(index));
        }

        ids.lock().unwrap().clear();
        notify_apply_observers(&[], 2);
        assert_eq!(ids.lock().unwrap().len(), 50);
    }

    #[test]
    fn test_state_object_storage_in_modified_set() {
        let state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);

        let state_obj: Arc<dyn StateObject> = TestStateObject::new(12345);

        state.record_write(Arc::clone(&state_obj), 1);

        let modified = state.modified.borrow();
        assert_eq!(modified.len(), 1);
        assert!(modified.contains_key(&12345));

        let (stored, writer_id) = modified.get(&12345).unwrap();
        assert_eq!(stored.object_id().as_usize(), 12345);
        assert_eq!(*writer_id, 1);
    }

    #[test]
    fn test_multiple_writes_to_same_state_object() {
        let state = SnapshotState::new(1, SnapshotIdSet::new(), None, None, false);
        let state_obj: Arc<dyn StateObject> = TestStateObject::new(99999);

        state.record_write(Arc::clone(&state_obj), 1);
        assert_eq!(state.modified.borrow().len(), 1);

        // A second write to the same object keeps one entry and updates the writer id.
        state.record_write(Arc::clone(&state_obj), 2);
        let modified = state.modified.borrow();
        assert_eq!(modified.len(), 1);
        assert!(modified.contains_key(&99999));
        let (_, writer_id) = modified.get(&99999).unwrap();
        assert_eq!(*writer_id, 2);
    }
}
1292}