logic/
trace.rs

1use std::collections::{BTreeMap, VecDeque};
2use std::rc::Rc;
3
4use ordered_float::NotNan;
5
6type StateMap<T> = BTreeMap<NotNan<f64>, Rc<T>>;
7
8/// Data structure that represents a sequence of timed states
9pub struct Trace<T> {
10    timed_states: StateMap<T>,
11}
12
/// A state borrowed from a trace, together with the time at which it occurs.
#[derive(Debug)]
pub struct TimedState<'a, T> {
    /// Time of the state, as stored in the trace.
    pub time: f64,
    /// The state itself, borrowed from the trace.
    pub state: &'a T,
}
18
19/// An iterator over the times of a trace
20pub struct Times<'a> {
21    inner: VecDeque<&'a NotNan<f64>>,
22}
23
24impl<'a> Times<'a> {
25    pub fn first(&self) -> Option<f64> {
26        self.inner.front().map(|time| time.into_inner())
27    }
28
29    pub fn last(&self) -> Option<f64> {
30        self.inner.back().map(|time| time.into_inner())
31    }
32}
33
34/// A borrowing iterator over the times of a trace
35pub struct Iter<'a, T> {
36    inner: std::collections::btree_map::Iter<'a, NotNan<f64>, Rc<T>>,
37}
38
39/// An iterator that returns subtraces beginning from the end of the trace and extending to the front of a trace
40///
41/// Example: Given a trace { t1, .. tn-2, tn-1, tn }
42///     First subtrace: { tn }
43///     Second subtrace: { tn-1, tn }
44///     Third subtrace: { tn-2,  tn-1, tn }
45///     Nth subtrace { t1, .. tn-2, tn-1, tn }
46pub struct Subtraces<T> {
47    timed_states: StateMap<T>,
48    length: usize,
49}
50
51impl<T> Trace<T> {
52    pub fn len(&self) -> usize {
53        self.timed_states.len()
54    }
55
56    pub fn is_empty(&self) -> bool {
57        self.len() == 0
58    }
59
60    pub fn iter(&self) -> Iter<T> {
61        Iter {
62            inner: self.timed_states.iter(),
63        }
64    }
65
66    pub fn times(&self) -> Times {
67        Times {
68            inner: self.timed_states.keys().collect(),
69        }
70    }
71
72    pub fn into_subtraces(self) -> Subtraces<T> {
73        Subtraces {
74            timed_states: self.timed_states,
75            length: 1,
76        }
77    }
78
79    pub fn split_at(&mut self, time: f64) -> Option<Trace<T>> {
80        let split_key = self.timed_states.keys().find(|key| key.into_inner() >= time).cloned()?;
81        let later_states = self.timed_states.split_off(&split_key);
82
83        Some(Trace {
84            timed_states: later_states,
85        })
86    }
87
88    pub fn retain_between(&mut self, lower: f64, upper: f64) {
89        let lower_key = self
90            .timed_states
91            .keys()
92            .find(|key| key.into_inner() >= lower)
93            .cloned()
94            .unwrap();
95        let upper_key = self.timed_states.keys().find(|key| key.into_inner() > upper).cloned();
96        let mut new_timed_states = self.timed_states.split_off(&lower_key);
97
98        if let Some(key) = upper_key {
99            new_timed_states.split_off(&key);
100        }
101
102        self.timed_states = new_timed_states
103    }
104
105    pub fn first_state(&self) -> Option<TimedState<T>> {
106        let (first_time, first_state) = self.timed_states.iter().next()?;
107
108        Some(TimedState {
109            time: first_time.into_inner(),
110            state: first_state.as_ref(),
111        })
112    }
113
114    pub fn begin_at_zero(&mut self) {
115        let first_state = self.first_state();
116
117        if first_state.is_none() {
118            return;
119        }
120
121        let offset = first_state.unwrap().time;
122        let new_timed_states = self
123            .timed_states
124            .iter()
125            .map(|(time, state)| (time - offset, state.clone()))
126            .collect();
127
128        self.timed_states = new_timed_states;
129    }
130}
131
132impl<T> FromIterator<(f64, T)> for Trace<T> {
133    fn from_iter<I: IntoIterator<Item = (f64, T)>>(iter: I) -> Self {
134        let items = iter
135            .into_iter()
136            .map(|(time, state)| (NotNan::try_from(time).unwrap(), Rc::new(state)))
137            .collect::<Vec<_>>();
138        let mut timed_states = BTreeMap::new();
139
140        for (time, state) in items {
141            let previous_value = timed_states.insert(time, state);
142            if previous_value.is_some() {
143                panic!("Duplicate time {}", time.into_inner());
144            }
145        }
146
147        Self { timed_states }
148    }
149}
150
151impl<T> Clone for Trace<T> {
152    fn clone(&self) -> Self {
153        Self {
154            timed_states: self.timed_states.clone(),
155        }
156    }
157}
158
159impl<'a> Iterator for Times<'a> {
160    type Item = f64;
161
162    fn next(&mut self) -> Option<Self::Item> {
163        self.inner.pop_front().map(|time| time.into_inner())
164    }
165}
166
167impl<'a> DoubleEndedIterator for Times<'a> {
168    fn next_back(&mut self) -> Option<Self::Item> {
169        self.inner.pop_back().map(|time| time.into_inner())
170    }
171}
172
173impl<'a, T> Iterator for Iter<'a, T> {
174    type Item = TimedState<'a, T>;
175
176    fn next(&mut self) -> Option<Self::Item> {
177        self.inner.next().map(|(key, state)| TimedState {
178            time: key.into_inner(),
179            state: state.as_ref(),
180        })
181    }
182}
183
184impl<T> Iterator for Subtraces<T> {
185    type Item = Trace<T>;
186
187    fn next(&mut self) -> Option<Self::Item> {
188        let key_index = self.timed_states.len().checked_sub(self.length)?;
189        let split_key = self.timed_states.keys().nth(key_index)?;
190        let mut timed_states = self.timed_states.clone();
191        let later_states = timed_states.split_off(split_key);
192        let mut trace = Trace {
193            timed_states: later_states,
194        };
195        trace.begin_at_zero();
196        self.length += 1;
197
198        Some(trace)
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::{Times, Trace};

    #[test]
    fn retain_between_times() {
        let mut trace =
            Trace::from_iter([(0.0, ()), (1.0, ()), (2.0, ()), (3.0, ()), (4.0, ()), (5.0, ())]);
        trace.retain_between(1.0, 4.0);
        let times = trace.times().collect::<Vec<_>>();

        assert_eq!(times, vec![1.0, 2.0, 3.0, 4.0])
    }

    #[test]
    fn split_at_moves_later_states() {
        let mut trace = Trace::from_iter([(0.0, ()), (1.0, ()), (2.0, ()), (3.0, ())]);
        let later = trace.split_at(1.5).expect("states exist after 1.5");

        assert_eq!(trace.times().collect::<Vec<_>>(), vec![0.0, 1.0]);
        assert_eq!(later.times().collect::<Vec<_>>(), vec![2.0, 3.0]);
    }

    #[test]
    fn split_at_past_end_leaves_trace_untouched() {
        let mut trace = Trace::from_iter([(0.0, ()), (1.0, ())]);

        assert!(trace.split_at(2.0).is_none());
        assert_eq!(trace.len(), 2);
    }

    #[test]
    fn begin_at_zero_shifts_times() {
        let mut trace = Trace::from_iter([(2.0, 'a'), (3.5, 'b')]);
        trace.begin_at_zero();
        let states = trace.iter().map(|s| (s.time, *s.state)).collect::<Vec<_>>();

        assert_eq!(states, vec![(0.0, 'a'), (1.5, 'b')]);
    }

    #[test]
    fn times_peek_and_reverse() {
        let trace = Trace::from_iter([(0.0, ()), (1.0, ()), (2.0, ())]);
        let times = trace.times();

        assert_eq!(times.first(), Some(0.0));
        assert_eq!(Times::last(&times), Some(2.0));
        assert_eq!(times.rev().collect::<Vec<_>>(), vec![2.0, 1.0, 0.0]);
    }

    #[test]
    fn subtraces_grow_from_the_end() {
        let trace = Trace::from_iter([(0.0, 'a'), (1.0, 'b'), (3.0, 'c')]);
        let subtraces = trace
            .into_subtraces()
            .map(|sub| sub.iter().map(|s| (s.time, *s.state)).collect::<Vec<_>>())
            .collect::<Vec<_>>();

        assert_eq!(
            subtraces,
            vec![
                vec![(0.0, 'c')],
                vec![(0.0, 'b'), (2.0, 'c')],
                vec![(0.0, 'a'), (1.0, 'b'), (3.0, 'c')],
            ]
        );
    }
}