fastsim_core/utils/
tracked_state.rs

1use super::*;
2
3/// Methods for manipulating states that need to be implemented up the hierarchy
4pub trait TrackedStateMethods {
5    /// Mark states as [State::Fresh]
6    /// # Arguments
7    /// - `loc`: closure that returns file and line number where called
8    fn mark_fresh<F: Fn() -> String>(&mut self, loc: F) -> anyhow::Result<()>;
9
10    /// Ensure [State::Fresh] and reset to [State::Stale]
11    /// # Arguments
12    /// - `loc`: closure that returns file and line number where called
13    fn check_and_reset<F: Fn() -> String>(&mut self, loc: F) -> anyhow::Result<()>;
14}
15
16impl<T> TrackedStateMethods for TrackedState<T>
17where
18    T: std::fmt::Debug + Clone + PartialEq + Default,
19{
20    fn check_and_reset<F: Fn() -> String>(&mut self, loc: F) -> anyhow::Result<()> {
21        self.ensure_fresh(loc)?;
22        self.mark_stale();
23        Ok(())
24    }
25
26    /// Verify that state is [State::Stale] and mark state as [State::Fresh]
27    /// without updating
28    /// # Arguments
29    /// - `loc`: closure that returns file and line number where called
30    fn mark_fresh<F: Fn() -> String>(&mut self, loc: F) -> anyhow::Result<()> {
31        self.ensure_stale(loc)?;
32        self.1 = StateStatus::Fresh;
33        Ok(())
34    }
35}
36
37/// Enum for tracking mutation
38#[derive(Clone, Default, Debug, PartialEq, IsVariant, derive_more::From, TryInto)]
39pub enum StateStatus {
40    /// Updated in this time step
41    #[default]
42    Fresh,
43    /// Not yet updated in this time step
44    Stale,
45}
46
47#[derive(Default, PartialEq, Clone, Debug)]
48/// Struct for storing state variable and ensuring one mutation per
49/// initialization or reset -- i.e. one mutation per time step
50pub struct TrackedState<T>(
51    /// Value
52    T,
53    /// Update status
54    StateStatus,
55);
56
57/// Provides methods to guarantee that states are updated once and only once per time step
58impl<T> TrackedState<T>
59where
60    T: std::fmt::Debug + Clone + PartialEq + Default,
61{
62    pub fn new(value: T) -> Self {
63        Self(value, Default::default())
64    }
65
66    fn is_fresh(&self) -> bool {
67        self.1.is_fresh()
68    }
69
70    fn is_stale(&self) -> bool {
71        self.1.is_stale()
72    }
73
74    /// # Arguments
75    /// - `loc`: closure that returns file and line number where called
76    fn ensure_fresh<F: Fn() -> String>(&self, loc: F) -> anyhow::Result<()> {
77        ensure!(
78            self.is_fresh(),
79            format!(
80                "{}\nState variable has not been updated. This is a bug in `fastsim-core`",
81                loc()
82            )
83        );
84        Ok(())
85    }
86
87    /// # Arguments
88    /// - `loc`: closure that returns file and line number where called
89    fn ensure_stale<F: Fn() -> String>(&self, loc: F) -> anyhow::Result<()> {
90        ensure!(
91            self.is_stale(),
92            format!(
93                "{}\nState variable has already been updated. This is a bug in `fastsim-core`",
94                loc()
95            )
96        );
97        Ok(())
98    }
99
100    /// Reset the tracked state to [State::Stale] for the next update after
101    /// verifying that is has been updated
102    pub fn mark_stale(&mut self) {
103        self.1 = StateStatus::Stale;
104    }
105
106    // Note that `anyhow::Error` is fine here because this should result only in
107    // logic errors and not runtime errors for end users
108    /// Update the value of the tracked state after verifying that it has not
109    /// already been updated
110    /// # Arguments
111    /// - `value`: new value
112    /// - `loc`: closure that returns file and line number where called
113    pub fn update<F: Fn() -> String>(&mut self, value: T, loc: F) -> anyhow::Result<()> {
114        self.ensure_stale(loc)?;
115        self.0 = value;
116        self.1 = StateStatus::Fresh;
117        Ok(())
118    }
119
120    // Note that `anyhow::Error` is fine here because this should result only in
121    // logic errors and not runtime errors for end users
122    // TODO: try to get rid of all uses of this method!
123    /// Update the value of the tracked state without verifying that it has not
124    /// already been updated -- to be used sparingly!
125    /// # Arguments
126    /// - `value`: new value
127    /// - `loc`: closure that returns file and line number where called
128    pub fn update_unchecked<F: Fn() -> String>(&mut self, value: T, _loc: F) -> anyhow::Result<()> {
129        self.0 = value;
130        self.1 = StateStatus::Fresh;
131        Ok(())
132    }
133
134    /// Check that value has been updated and then return as a result
135    /// # Arguments
136    /// - `loc`: call site location filename and line number
137    pub fn get_fresh<F: Fn() -> String>(&self, loc: F) -> anyhow::Result<&T> {
138        self.ensure_fresh(loc)?;
139        Ok(&self.0)
140    }
141
142    /// Check that value has **not** been updated and then return as a result
143    /// # Arguments
144    /// - `loc`: call site location filename and line number
145    pub fn get_stale<F: Fn() -> String>(&self, loc: F) -> anyhow::Result<&T> {
146        self.ensure_stale(loc)?;
147        Ok(&self.0)
148    }
149}
150
151/// State methods that allow for `+=`
152impl<T: std::fmt::Debug + Clone + PartialEq + Default + std::ops::AddAssign> TrackedState<T> {
153    // Note that `anyhow::Error` is fine here because this should result only in
154    // logic errors and not runtime errors for end users
155    /// Update the value of the tracked state
156    /// # Arguments
157    /// - `value`: new value
158    /// - `loc`: closure that returns file and line number where called
159    pub fn increment<F: Fn() -> String>(&mut self, value: T, loc: F) -> anyhow::Result<()> {
160        self.ensure_stale(loc)?;
161        self.0 += value;
162        self.1 = StateStatus::Fresh;
163        Ok(())
164    }
165}
166
167// Custom serialization
168impl<T> Serialize for TrackedState<T>
169where
170    T: std::fmt::Debug + Clone + PartialEq + for<'de> Deserialize<'de> + Serialize + Default,
171{
172    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
173    where
174        S: serde::Serializer,
175    {
176        self.0.serialize(serializer)
177    }
178}
179
180impl<'de, T> Deserialize<'de> for TrackedState<T>
181where
182    T: std::fmt::Debug + Clone + PartialEq + Deserialize<'de> + Serialize + Default,
183{
184    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
185    where
186        D: serde::Deserializer<'de>,
187    {
188        let value: T = T::deserialize(deserializer)?;
189
190        Ok(Self(value, Default::default()))
191    }
192}
193
#[cfg(test)]
mod test_tracked_state {
    use super::*;

    // Assert the specific `Err` instead of using `#[should_panic]`, which
    // would also pass on any unrelated panic.

    #[test]
    fn test_update_fresh() {
        // a newly constructed state is fresh, so updating it again must fail
        let mut pwr = TrackedState::new(si::Power::ZERO);
        assert!(pwr.update(uc::W * 10.0, || format_dbg!()).is_err());
        // the failed update must not have changed the value
        assert_eq!(*pwr.get_fresh(|| format_dbg!()).unwrap(), si::Power::ZERO);
    }

    #[test]
    fn test_update_stale() {
        let mut pwr = TrackedState::new(si::Power::ZERO);
        pwr.mark_stale();
        pwr.update(uc::W * 10.0, || format_dbg!()).unwrap();
        assert_eq!(*pwr.get_fresh(|| format_dbg!()).unwrap(), uc::W * 10.0);
    }

    #[test]
    fn test_get_ok() {
        let mut pwr = TrackedState::new(si::Power::ZERO);
        pwr.get_fresh(|| format_dbg!()).unwrap();
        pwr.mark_stale();
        pwr.get_stale(|| format_dbg!()).unwrap();
    }

    #[test]
    fn test_get_stale_fail() {
        let pwr = TrackedState::new(si::Power::ZERO);
        assert!(pwr.get_stale(|| format_dbg!()).is_err());
    }

    #[test]
    fn test_get_fresh_fail() {
        let mut pwr = TrackedState::new(si::Power::ZERO);
        pwr.mark_stale();
        assert!(pwr.get_fresh(|| format_dbg!()).is_err());
    }

    #[test]
    fn test_increment() {
        let mut state = TrackedState::new(1_i32);
        // fresh state cannot be incremented, and the value is left untouched
        assert!(state.increment(2, || format_dbg!()).is_err());
        assert_eq!(*state.get_fresh(|| format_dbg!()).unwrap(), 1);
        state.mark_stale();
        state.increment(2, || format_dbg!()).unwrap();
        assert_eq!(*state.get_fresh(|| format_dbg!()).unwrap(), 3);
    }

    #[test]
    fn test_mark_fresh_and_check_and_reset() {
        let mut state = TrackedState::new(7_i32);
        // newly constructed state is already fresh
        assert!(state.mark_fresh(|| format_dbg!()).is_err());
        // fresh -> stale
        state.check_and_reset(|| format_dbg!()).unwrap();
        // a stale state cannot be reset again
        assert!(state.check_and_reset(|| format_dbg!()).is_err());
        // stale -> fresh without touching the value
        state.mark_fresh(|| format_dbg!()).unwrap();
        assert_eq!(*state.get_fresh(|| format_dbg!()).unwrap(), 7);
    }

    #[test]
    fn test_update_unchecked() {
        let mut state = TrackedState::new(1_i32);
        // allowed even though the state is already fresh
        state.update_unchecked(5, || format_dbg!()).unwrap();
        assert_eq!(*state.get_fresh(|| format_dbg!()).unwrap(), 5);
    }
}