fastsim_core/utils/
tracked_state.rs

1use super::*;
2
3pub trait CheckAndResetState {
4    /// Ensure [State::Fresh] and reset to [State::Stale]
5    /// # Arguments
6    /// - `loc`: closure that returns file and line number where called
7    fn check_and_reset<F: Fn() -> String>(&mut self, loc: F) -> anyhow::Result<()>;
8}
9
10impl<T> CheckAndResetState for TrackedState<T>
11where
12    T: std::fmt::Debug + Clone + PartialEq + Default,
13{
14    fn check_and_reset<F: Fn() -> String>(&mut self, loc: F) -> anyhow::Result<()> {
15        self.ensure_fresh(loc)?;
16        self.mark_stale();
17        Ok(())
18    }
19}
20
21/// Enum for tracking mutation
22#[derive(Clone, Default, Debug, PartialEq, IsVariant, derive_more::From, TryInto)]
23pub enum StateStatus {
24    /// Updated in this time step
25    #[default]
26    Fresh,
27    /// Not yet updated in this time step
28    Stale,
29}
30
31#[derive(Default, PartialEq, Clone, Debug)]
32/// Struct for storing state variable and ensuring one mutation per
33/// initialization or reset -- i.e. one mutation per time step
34pub struct TrackedState<T>(
35    /// Value
36    T,
37    /// Update status
38    StateStatus,
39);
40
41/// Provides methods to guarantee that states are updated once and only once per time step
42impl<T> TrackedState<T>
43where
44    T: std::fmt::Debug + Clone + PartialEq + Default,
45{
46    pub fn new(value: T) -> Self {
47        Self(value, Default::default())
48    }
49
50    fn is_fresh(&self) -> bool {
51        self.1.is_fresh()
52    }
53
54    fn is_stale(&self) -> bool {
55        self.1.is_stale()
56    }
57
58    /// # Arguments
59    /// - `loc`: closure that returns file and line number where called
60    fn ensure_fresh<F: Fn() -> String>(&self, loc: F) -> anyhow::Result<()> {
61        ensure!(
62            self.is_fresh(),
63            format!(
64                "{}\nState variable has not been updated. This is a bug in `fastsim-core`",
65                loc()
66            )
67        );
68        Ok(())
69    }
70
71    /// # Arguments
72    /// - `loc`: closure that returns file and line number where called
73    fn ensure_stale<F: Fn() -> String>(&self, loc: F) -> anyhow::Result<()> {
74        ensure!(
75            self.is_stale(),
76            format!(
77                "{}\nState variable has already been updated. This is a bug in `fastsim-core`",
78                loc()
79            )
80        );
81        Ok(())
82    }
83
84    /// Reset the tracked state to [State::Stale] for the next update after
85    /// verifying that is has been updated
86    pub fn mark_stale(&mut self) {
87        self.1 = StateStatus::Stale;
88    }
89
90    // Note that `anyhow::Error` is fine here because this should result only in
91    // logic errors and not runtime errors for end users
92    /// Update the value of the tracked state after verifying that it has not
93    /// already been updated
94    /// # Arguments
95    /// - `value`: new value
96    /// - `loc`: closure that returns file and line number where called
97    pub fn update<F: Fn() -> String>(&mut self, value: T, loc: F) -> anyhow::Result<()> {
98        self.ensure_stale(loc)?;
99        self.0 = value;
100        self.1 = StateStatus::Fresh;
101        Ok(())
102    }
103
104    /// Verify that state is [State::Stale] and mark state as [State::Fresh]
105    /// without updating
106    /// # Arguments
107    /// - `loc`: closure that returns file and line number where called
108    pub fn mark_fresh<F: Fn() -> String>(&mut self, loc: F) -> anyhow::Result<()> {
109        self.ensure_stale(loc)?;
110        self.1 = StateStatus::Fresh;
111        Ok(())
112    }
113
114    /// Check that value has been updated and then return as a result
115    /// # Arguments
116    /// - `loc`: call site location filename and line number
117    pub fn get_fresh<F: Fn() -> String>(&self, loc: F) -> anyhow::Result<&T> {
118        self.ensure_fresh(loc)?;
119        Ok(&self.0)
120    }
121
122    /// Check that value has **not** been updated and then return as a result
123    /// # Arguments
124    /// - `loc`: call site location filename and line number
125    pub fn get_stale<F: Fn() -> String>(&self, loc: F) -> anyhow::Result<&T> {
126        self.ensure_stale(loc)?;
127        Ok(&self.0)
128    }
129}
130
131/// State methods that allow for `+=`
132impl<T: std::fmt::Debug + Clone + PartialEq + Default + std::ops::AddAssign> TrackedState<T> {
133    // Note that `anyhow::Error` is fine here because this should result only in
134    // logic errors and not runtime errors for end users
135    /// Update the value of the tracked state
136    /// # Arguments
137    /// - `value`: new value
138    /// - `loc`: closure that returns file and line number where called
139    pub fn increment<F: Fn() -> String>(&mut self, value: T, loc: F) -> anyhow::Result<()> {
140        self.ensure_stale(loc)?;
141        self.0 += value;
142        self.1 = StateStatus::Fresh;
143        Ok(())
144    }
145}
146
147// Custom serialization
148impl<T> Serialize for TrackedState<T>
149where
150    T: std::fmt::Debug + Clone + PartialEq + for<'de> Deserialize<'de> + Serialize + Default,
151{
152    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
153    where
154        S: serde::Serializer,
155    {
156        self.0.serialize(serializer)
157    }
158}
159
160impl<'de, T> Deserialize<'de> for TrackedState<T>
161where
162    T: std::fmt::Debug + Clone + PartialEq + Deserialize<'de> + Serialize + Default,
163{
164    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
165    where
166        D: serde::Deserializer<'de>,
167    {
168        let value: T = T::deserialize(deserializer)?;
169
170        Ok(Self(value, Default::default()))
171    }
172}
173
#[cfg(test)]
mod test_tracked_state {
    use super::*;

    /// A newly constructed state is already fresh, so updating it must fail.
    #[test]
    #[should_panic]
    fn test_update_fresh() {
        let mut tracked = TrackedState::new(si::Power::ZERO);
        let outcome = tracked.update(uc::W * 10.0, || format_dbg!());
        outcome.unwrap();
    }

    /// Once marked stale, the state accepts an update.
    #[test]
    fn test_update_stale() {
        let mut tracked = TrackedState::new(si::Power::ZERO);
        tracked.mark_stale();
        let outcome = tracked.update(uc::W * 10.0, || format_dbg!());
        outcome.unwrap();
    }

    /// Each getter succeeds while the state has the matching status.
    #[test]
    fn test_get_ok() {
        let mut tracked = TrackedState::new(si::Power::ZERO);
        tracked.get_fresh(|| format_dbg!()).unwrap();
        tracked.mark_stale();
        tracked.get_stale(|| format_dbg!()).unwrap();
    }

    /// Reading a fresh state through the stale getter must fail.
    #[test]
    #[should_panic]
    fn test_get_stale_fail() {
        let tracked = TrackedState::new(si::Power::ZERO);
        tracked.get_stale(|| format_dbg!()).unwrap();
    }

    /// Reading a stale state through the fresh getter must fail.
    #[test]
    #[should_panic]
    fn test_get_fresh_fail() {
        let mut tracked = TrackedState::new(si::Power::ZERO);
        tracked.mark_stale();
        tracked.get_fresh(|| format_dbg!()).unwrap();
    }
}