use super::*;
use crate::collections::map::HashMap;
use crate::state::{StateRecord, PREEXISTING_SNAPSHOT_ID};
use std::rc::Rc;
use std::sync::Arc;
pub(super) fn find_record_by_id(
head: &Rc<StateRecord>,
target: SnapshotId,
) -> Option<Rc<StateRecord>> {
let mut cursor = Some(Rc::clone(head));
while let Some(record) = cursor {
if !record.is_tombstone() && record.snapshot_id() == target {
return Some(record);
}
cursor = record.next();
}
None
}
pub(super) fn find_previous_record(
head: &Rc<StateRecord>,
base_snapshot_id: SnapshotId,
invalid: &SnapshotIdSet,
) -> (Option<Rc<StateRecord>>, bool) {
let mut cursor = Some(Rc::clone(head));
let mut best: Option<Rc<StateRecord>> = None;
let mut fallback: Option<Rc<StateRecord>> = None;
let mut found_base = false;
while let Some(record) = cursor {
if !record.is_tombstone() {
let id = record.snapshot_id();
let is_valid = id <= base_snapshot_id && !invalid.get(id);
if is_valid {
found_base = true;
let replace = best
.as_ref()
.map(|current| current.snapshot_id() < id)
.unwrap_or(true);
if replace {
best = Some(record.clone());
}
}
if fallback.is_none() {
fallback = Some(record.clone());
}
}
cursor = record.next();
}
(best.or(fallback), found_base)
}
enum ApplyOperation {
PromoteChild {
object_id: StateObjectId,
state: Arc<dyn StateObject>,
writer_id: SnapshotId,
},
PromoteExisting {
object_id: StateObjectId,
state: Arc<dyn StateObject>,
source_id: SnapshotId,
applied: Rc<StateRecord>,
},
CommitMerged {
object_id: StateObjectId,
state: Arc<dyn StateObject>,
merged: Rc<StateRecord>,
applied: Rc<StateRecord>,
},
}
#[allow(clippy::arc_with_non_send_sync)]
pub struct MutableSnapshot {
state: SnapshotState,
base_parent_id: SnapshotId,
nested_count: Cell<usize>,
applied: Cell<bool>,
}
impl MutableSnapshot {
pub(crate) fn from_parts(
id: SnapshotId,
invalid: SnapshotIdSet,
read_observer: Option<ReadObserver>,
write_observer: Option<WriteObserver>,
base_parent_id: SnapshotId,
runtime_tracked: bool,
) -> Arc<Self> {
Arc::new(Self {
state: SnapshotState::new(id, invalid, read_observer, write_observer, runtime_tracked),
base_parent_id,
nested_count: Cell::new(0),
applied: Cell::new(false),
})
}
pub fn new_root(
read_observer: Option<ReadObserver>,
write_observer: Option<WriteObserver>,
) -> Arc<Self> {
GlobalSnapshot::get_or_create().take_nested_mutable_snapshot(read_observer, write_observer)
}
pub fn new(
id: SnapshotId,
invalid: SnapshotIdSet,
read_observer: Option<ReadObserver>,
write_observer: Option<WriteObserver>,
base_parent_id: SnapshotId,
) -> Arc<Self> {
Self::from_parts(
id,
invalid,
read_observer,
write_observer,
base_parent_id,
false,
)
}
fn validate_not_applied(&self) {
if self.applied.get() {
panic!("Snapshot has already been applied");
}
}
fn validate_not_disposed(&self) {
if self.state.disposed.get() {
panic!("Snapshot has been disposed");
}
}
pub fn snapshot_id(&self) -> SnapshotId {
self.state.id.get()
}
pub fn invalid(&self) -> SnapshotIdSet {
self.state.invalid.borrow().clone()
}
pub fn read_only(&self) -> bool {
false
}
pub(crate) fn set_on_dispose<F>(&self, f: F)
where
F: FnOnce() + 'static,
{
self.state.set_on_dispose(f);
}
pub fn root_mutable(self: &Arc<Self>) -> Arc<Self> {
self.clone()
}
pub fn enter<T>(self: &Arc<Self>, f: impl FnOnce() -> T) -> T {
let previous = current_snapshot();
set_current_snapshot(Some(AnySnapshot::Mutable(self.clone())));
let result = f();
set_current_snapshot(previous);
result
}
pub fn take_nested_snapshot(
self: &Arc<Self>,
read_observer: Option<ReadObserver>,
) -> Arc<ReadonlySnapshot> {
self.validate_not_disposed();
self.validate_not_applied();
let merged_observer = merge_read_observers(read_observer, self.state.read_observer.clone());
let nested = ReadonlySnapshot::new(
self.state.id.get(),
self.state.invalid.borrow().clone(),
merged_observer,
);
self.nested_count.set(self.nested_count.get() + 1);
let parent_weak = Arc::downgrade(self);
nested.set_on_dispose(move || {
if let Some(parent) = parent_weak.upgrade() {
let cur = parent.nested_count.get();
if cur > 0 {
parent.nested_count.set(cur - 1);
}
}
});
nested
}
pub fn has_pending_changes(&self) -> bool {
!self.state.modified.borrow().is_empty()
}
pub fn pending_children(&self) -> Vec<SnapshotId> {
self.state.pending_children()
}
pub fn has_pending_children(&self) -> bool {
self.state.has_pending_children()
}
pub fn dispose(&self) {
if !self.state.disposed.get() && self.nested_count.get() == 0 {
self.state.dispose();
}
}
pub fn record_read(&self, state: &dyn StateObject) {
self.state.record_read(state);
}
pub fn record_write(&self, state: Arc<dyn StateObject>) {
self.validate_not_applied();
self.validate_not_disposed();
self.state.record_write(state, self.state.id.get());
}
pub fn notify_objects_initialized(&self) {
if !self.applied.get() && !self.state.disposed.get() {
}
}
pub fn close(&self) {
self.state.disposed.set(true);
}
pub fn is_disposed(&self) -> bool {
self.state.disposed.get()
}
pub fn apply(&self) -> SnapshotApplyResult {
if self.state.disposed.get() {
return SnapshotApplyResult::Failure;
}
if self.applied.get() {
return SnapshotApplyResult::Failure;
}
let modified = self.state.modified.borrow();
if modified.is_empty() {
self.applied.set(true);
self.state.dispose();
return SnapshotApplyResult::Success;
}
let this_id = self.state.id.get();
let mut modified_objects: Vec<(StateObjectId, Arc<dyn StateObject>, SnapshotId)> =
Vec::with_capacity(modified.len());
for (&obj_id, (obj, writer_id)) in modified.iter() {
modified_objects.push((obj_id, obj.clone(), *writer_id));
}
drop(modified);
let parent_snapshot = GlobalSnapshot::get_or_create();
let parent_snapshot_id = parent_snapshot.snapshot_id();
let parent_invalid = parent_snapshot.invalid();
drop(parent_snapshot);
let next_invalid = super::runtime::open_snapshots().clear(parent_snapshot_id);
let this_invalid_for_optimistic = self.state.invalid.borrow().clone();
let optimistic = super::optimistic_merges(
parent_snapshot_id,
self.base_parent_id,
&modified_objects,
&next_invalid,
&this_invalid_for_optimistic,
);
let mut operations: Vec<ApplyOperation> = Vec::with_capacity(modified_objects.len());
for (obj_id, state, writer_id) in &modified_objects {
let head = state.first_record();
let applied = match find_record_by_id(&head, *writer_id) {
Some(record) => record,
None => return SnapshotApplyResult::Failure,
};
let current =
crate::state::readable_record_for(&head, parent_snapshot_id, &next_invalid)
.unwrap_or_else(|| state.readable_record(parent_snapshot_id, &parent_invalid));
let this_invalid = self.state.invalid.borrow();
let (previous_opt, found_base) =
find_previous_record(&head, self.base_parent_id, &this_invalid);
drop(this_invalid);
let Some(previous) = previous_opt else {
return SnapshotApplyResult::Failure;
};
if !found_base || previous.snapshot_id() == PREEXISTING_SNAPSHOT_ID {
operations.push(ApplyOperation::PromoteChild {
object_id: *obj_id,
state: state.clone(),
writer_id: *writer_id,
});
continue;
}
if Rc::ptr_eq(¤t, &previous) {
operations.push(ApplyOperation::PromoteChild {
object_id: *obj_id,
state: state.clone(),
writer_id: *writer_id,
});
continue;
}
let merged = if let Some(candidate) = optimistic
.as_ref()
.and_then(|map| map.get(&(Rc::as_ptr(¤t) as usize)))
.cloned()
{
candidate
} else {
match state.merge_records(
Rc::clone(&previous),
Rc::clone(¤t),
Rc::clone(&applied),
) {
Some(record) => record,
None => return SnapshotApplyResult::Failure,
}
};
if Rc::ptr_eq(&merged, &applied) {
operations.push(ApplyOperation::PromoteChild {
object_id: *obj_id,
state: state.clone(),
writer_id: *writer_id,
});
} else if Rc::ptr_eq(&merged, ¤t) {
operations.push(ApplyOperation::PromoteExisting {
object_id: *obj_id,
state: state.clone(),
source_id: current.snapshot_id(),
applied: applied.clone(),
});
} else {
operations.push(ApplyOperation::CommitMerged {
object_id: *obj_id,
state: state.clone(),
merged: merged.clone(),
applied: applied.clone(),
});
}
}
let mut applied_info: Vec<(StateObjectId, Arc<dyn StateObject>, SnapshotId)> =
Vec::with_capacity(operations.len());
for operation in operations {
match operation {
ApplyOperation::PromoteChild {
object_id,
state,
writer_id,
} => {
if state.promote_record(writer_id).is_err() {
return SnapshotApplyResult::Failure;
}
let new_head_id = state.first_record().snapshot_id();
applied_info.push((object_id, state, new_head_id));
}
ApplyOperation::PromoteExisting {
object_id,
state,
source_id,
applied,
} => {
if state.promote_record(source_id).is_err() {
return SnapshotApplyResult::Failure;
}
applied.set_tombstone(true);
applied.clear_value();
let new_head_id = state.first_record().snapshot_id();
applied_info.push((object_id, state, new_head_id));
}
ApplyOperation::CommitMerged {
object_id,
state,
merged,
applied,
} => {
let Ok(new_head_id) = state.commit_merged_record(merged) else {
return SnapshotApplyResult::Failure;
};
applied.set_tombstone(true);
applied.clear_value();
applied_info.push((object_id, state, new_head_id));
}
}
}
for (obj_id, _, head_id) in &applied_info {
super::set_last_write(*obj_id, *head_id);
}
self.applied.set(true);
self.state.dispose();
for (_, state, _) in &applied_info {
super::EXTRA_STATE_OBJECTS.with(|cell| {
cell.borrow_mut().add_trait_object(state);
});
}
let observer_states: Vec<Arc<dyn StateObject>> = applied_info
.iter()
.map(|(_, state, _)| state.clone())
.collect();
super::notify_apply_observers(&observer_states, this_id);
SnapshotApplyResult::Success
}
pub fn take_nested_mutable_snapshot(
self: &Arc<Self>,
read_observer: Option<ReadObserver>,
write_observer: Option<WriteObserver>,
) -> Arc<NestedMutableSnapshot> {
self.validate_not_disposed();
self.validate_not_applied();
let merged_read = merge_read_observers(read_observer, self.state.read_observer.clone());
let merged_write = merge_write_observers(write_observer, self.state.write_observer.clone());
let parent_id = self.state.id.get();
let current_invalid = self.state.invalid.borrow().clone();
let (new_id, _runtime_invalid) = allocate_snapshot();
let parent_invalid_with_child = current_invalid.set(new_id);
self.state.invalid.replace(parent_invalid_with_child);
let invalid = current_invalid.add_range(parent_id + 1, new_id);
let self_weak = Arc::downgrade(self);
let nested = NestedMutableSnapshot::new(
new_id,
invalid,
merged_read,
merged_write,
self_weak,
self.state.id.get(), );
self.nested_count.set(self.nested_count.get() + 1);
self.state.add_pending_child(new_id);
let parent_weak = Arc::downgrade(self);
nested.set_on_dispose({
let child_id = new_id;
move || {
if let Some(parent) = parent_weak.upgrade() {
if parent.nested_count.get() > 0 {
parent
.nested_count
.set(parent.nested_count.get().saturating_sub(1));
}
let mut invalid = parent.state.invalid.borrow_mut();
let new_set = invalid.clone().clear(child_id);
*invalid = new_set;
parent.state.remove_pending_child(child_id);
}
}
});
nested
}
pub(crate) fn merge_child_modifications(
&self,
child_modified: &HashMap<StateObjectId, (Arc<dyn StateObject>, SnapshotId)>,
) -> Result<(), ()> {
{
let parent_mod = self.state.modified.borrow();
for key in child_modified.keys() {
if parent_mod.contains_key(key) {
return Err(());
}
}
}
let mut parent_mod = self.state.modified.borrow_mut();
for (key, value) in child_modified.iter() {
parent_mod.entry(*key).or_insert_with(|| value.clone());
}
Ok(())
}
}
#[cfg(test)]
impl MutableSnapshot {
pub(crate) fn debug_modified_objects(
&self,
) -> Vec<(StateObjectId, Arc<dyn StateObject>, SnapshotId)> {
let modified = self.state.modified.borrow();
modified
.iter()
.map(|(&obj_id, (state, writer_id))| (obj_id, state.clone(), *writer_id))
.collect()
}
pub(crate) fn debug_base_parent_id(&self) -> SnapshotId {
self.base_parent_id
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::snapshot_v2::runtime::TestRuntimeGuard;
use crate::state::{NeverEqual, SnapshotMutableState, StateObject};
use std::sync::Arc;
fn reset_runtime() -> TestRuntimeGuard {
reset_runtime_for_tests()
}
fn new_state(initial: i32) -> Arc<SnapshotMutableState<i32>> {
SnapshotMutableState::new_in_arc(initial, Arc::new(NeverEqual))
}
struct MockStateObject;
fn mock_state_record() -> Rc<crate::state::StateRecord> {
crate::state::StateRecord::new(crate::state::PREEXISTING_SNAPSHOT_ID, (), None)
}
impl StateObject for MockStateObject {
fn object_id(&self) -> crate::state::ObjectId {
crate::state::ObjectId(0)
}
fn first_record(&self) -> Rc<crate::state::StateRecord> {
mock_state_record()
}
fn readable_record(
&self,
_snapshot_id: crate::snapshot_id_set::SnapshotId,
_invalid: &SnapshotIdSet,
) -> Rc<crate::state::StateRecord> {
mock_state_record()
}
fn prepend_state_record(&self, _record: Rc<crate::state::StateRecord>) {}
fn promote_record(
&self,
_child_id: crate::snapshot_id_set::SnapshotId,
) -> Result<(), &'static str> {
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[test]
fn test_mutable_snapshot_creation() {
let _guard = reset_runtime();
let snapshot = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
assert_eq!(snapshot.snapshot_id(), 1);
assert!(!snapshot.read_only());
assert!(!snapshot.is_disposed());
assert!(!snapshot.applied.get());
}
#[test]
fn test_mutable_snapshot_no_pending_changes_initially() {
let _guard = reset_runtime();
let snapshot = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
assert!(!snapshot.has_pending_changes());
}
#[test]
fn test_mutable_snapshot_enter() {
let _guard = reset_runtime();
let snapshot = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
set_current_snapshot(None);
assert!(current_snapshot().is_none());
snapshot.enter(|| {
let current = current_snapshot();
assert!(current.is_some());
assert_eq!(current.unwrap().snapshot_id(), 1);
});
assert!(current_snapshot().is_none());
}
#[test]
fn test_mutable_snapshot_read_observer() {
let _guard = reset_runtime();
use std::sync::{Arc as StdArc, Mutex};
let read_count = StdArc::new(Mutex::new(0));
let read_count_clone = read_count.clone();
let observer = Arc::new(move |_: &dyn StateObject| {
*read_count_clone.lock().unwrap() += 1;
});
let snapshot = MutableSnapshot::new(1, SnapshotIdSet::new(), Some(observer), None, 0);
let mock_state = MockStateObject;
snapshot.record_read(&mock_state);
snapshot.record_read(&mock_state);
assert_eq!(*read_count.lock().unwrap(), 2);
}
#[test]
fn test_mutable_snapshot_write_observer() {
let _guard = reset_runtime();
use std::sync::{Arc as StdArc, Mutex};
let write_count = StdArc::new(Mutex::new(0));
let write_count_clone = write_count.clone();
let observer = Arc::new(move |_: &dyn StateObject| {
*write_count_clone.lock().unwrap() += 1;
});
let snapshot = MutableSnapshot::new(1, SnapshotIdSet::new(), None, Some(observer), 0);
let mock_state = Arc::new(MockStateObject);
snapshot.record_write(mock_state.clone());
snapshot.record_write(mock_state.clone());
assert!(*write_count.lock().unwrap() >= 1);
}
#[test]
fn test_mutable_snapshot_apply_empty() {
let _guard = reset_runtime();
let snapshot = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
let result = snapshot.apply();
assert!(result.is_success());
assert!(snapshot.applied.get());
}
#[test]
fn test_mutable_snapshot_apply_twice_fails() {
let _guard = reset_runtime();
let snapshot = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
snapshot.apply().check();
let result = snapshot.apply();
assert!(result.is_failure());
}
#[test]
fn test_mutable_snapshot_nested_readonly() {
let _guard = reset_runtime();
let parent = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
let nested = parent.take_nested_snapshot(None);
assert_eq!(nested.snapshot_id(), 1);
assert!(nested.read_only());
assert_eq!(parent.nested_count.get(), 1);
}
#[test]
fn test_mutable_snapshot_nested_mutable() {
let _guard = reset_runtime();
let parent = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
let nested = parent.take_nested_mutable_snapshot(None, None);
assert!(nested.snapshot_id() > parent.snapshot_id());
assert!(!nested.read_only());
assert_eq!(parent.nested_count.get(), 1);
}
#[test]
fn test_mutable_snapshot_nested_mutable_dispose_clears_invalid() {
let _guard = reset_runtime();
let parent = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
let nested = parent.take_nested_mutable_snapshot(None, None);
let child_id = nested.snapshot_id();
assert!(parent.state.invalid.borrow().get(child_id));
nested.dispose();
assert_eq!(parent.nested_count.get(), 0);
assert!(!parent.state.invalid.borrow().get(child_id));
}
#[test]
fn test_mutable_snapshot_nested_dispose() {
let _guard = reset_runtime();
let parent = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
let nested = parent.take_nested_snapshot(None);
assert_eq!(parent.nested_count.get(), 1);
nested.dispose();
assert_eq!(parent.nested_count.get(), 0);
}
#[test]
#[should_panic(expected = "Snapshot has already been applied")]
fn test_mutable_snapshot_write_after_apply_panics() {
let _guard = reset_runtime();
let snapshot = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
snapshot.apply().check();
let mock_state = Arc::new(MockStateObject);
snapshot.record_write(mock_state);
}
#[test]
#[should_panic(expected = "Snapshot has been disposed")]
fn test_mutable_snapshot_write_after_dispose_panics() {
let _guard = reset_runtime();
let snapshot = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
snapshot.dispose();
let mock_state = Arc::new(MockStateObject);
snapshot.record_write(mock_state);
}
#[test]
fn test_mutable_snapshot_dispose() {
let _guard = reset_runtime();
let snapshot = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
assert!(!snapshot.is_disposed());
snapshot.dispose();
assert!(snapshot.is_disposed());
}
#[test]
fn test_mutable_snapshot_apply_observer() {
let _guard = reset_runtime();
use std::sync::{Arc as StdArc, Mutex};
let applied_count = StdArc::new(Mutex::new(0));
let applied_count_clone = applied_count.clone();
let observer = Rc::new(
move |_modified: &[Arc<dyn StateObject>], _snapshot_id: SnapshotId| {
*applied_count_clone.lock().unwrap() += 1;
},
);
let _handle = register_apply_observer(observer);
let snapshot = MutableSnapshot::new(1, SnapshotIdSet::new(), None, None, 0);
let state = new_state(0);
snapshot.enter(|| state.set(10));
snapshot.apply().check();
assert_eq!(*applied_count.lock().unwrap(), 1);
}
#[test]
fn test_mutable_conflict_detection_same_object() {
let _guard = reset_runtime();
let global = GlobalSnapshot::get_or_create();
let state = new_state(0);
let s1 = global.take_nested_mutable_snapshot(None, None);
s1.enter(|| state.set(1));
let s2 = global.take_nested_mutable_snapshot(None, None);
s2.enter(|| state.set(2));
assert!(s1.apply().is_success());
assert!(s2.apply().is_failure());
}
#[test]
fn test_mutable_no_conflict_different_objects() {
let _guard = reset_runtime();
let global = GlobalSnapshot::get_or_create();
let state1 = new_state(0);
let state2 = new_state(0);
let s1 = global.take_nested_mutable_snapshot(None, None);
s1.enter(|| state1.set(10));
let s2 = global.take_nested_mutable_snapshot(None, None);
s2.enter(|| state2.set(20));
assert!(s1.apply().is_success());
assert!(s2.apply().is_success());
}
}