Skip to main content

multilinear/
edit.rs

1use bit_set::BitSet;
2use event_simulation::EditableSimulationInfo;
3use std::ops::Deref;
4
5use crate::{
6    Aspect, Change, EventEdit, From, InvalidChangeError, MultilinearInfo, MultilinearState, To,
7};
8
9/// A wrapper type for `MultilinearInfo` which only supports safe edits that don't invalidate existing simulation states.
10pub struct EditMultilinearInfo<'a> {
11    info: &'a mut MultilinearInfo,
12}
13
14impl Deref for EditMultilinearInfo<'_> {
15    type Target = MultilinearInfo;
16    fn deref(&self) -> &MultilinearInfo {
17        self.info
18    }
19}
20
21impl<'a> EditMultilinearInfo<'a> {
22    pub(crate) fn new(info: &'a mut MultilinearInfo) -> Self {
23        Self { info }
24    }
25
26    /// Adds a new aspect to the multilinear system and returns its identifier.
27    pub fn add_aspect(&mut self) -> Aspect {
28        self.info.add_aspect()
29    }
30
31    /// Adds a new event to the multilinear system and returns an editable event info.
32    pub fn add_event(&mut self) -> EventEdit<'_> {
33        self.info.add_event()
34    }
35
36    /// Adds a new event with multiple change sets and returns the editable event info.
37    pub fn add_event_with_changes(
38        &mut self,
39        changes: &[&[Change]],
40    ) -> Result<EventEdit<'_>, InvalidChangeError> {
41        let mut event = self.info.add_event();
42        for &change_set in changes {
43            event.add_change(change_set)?;
44        }
45        Ok(event)
46    }
47}
48
49impl EditableSimulationInfo for MultilinearInfo {
50    type Edit<'a> = EditMultilinearInfo<'a>;
51
52    unsafe fn edit(&mut self) -> EditMultilinearInfo<'_> {
53        EditMultilinearInfo::new(self)
54    }
55
56    unsafe fn refresh_state(&self, state: &mut MultilinearState) {
57        state.values.resize(self.aspects.len(), 0);
58
59        state.callables = BitSet::new();
60        state.revertables = BitSet::new();
61
62        for (index, event) in self.events.iter().enumerate() {
63            if event.check_action::<From>(&state.values) {
64                state.callables.insert(index);
65            }
66            if event.check_action::<To>(&state.values) {
67                state.revertables.insert(index);
68            }
69        }
70    }
71}