banquo_core/
trace.rs

1//! A set of values where each value is associated with a time.
2//!
3//! A [`Trace`] is an associative map that represents a sequence of values, where the key for
4//! each value is the time when the value was generated. In the context of this library, this data
5//! structure is used to represent the state of a system, or its satisfaction metric, over time.
6//!
//! # Panics
//!
//! `f64` values are used to represent the time of each state. Because NaN is a valid `f64` value,
//! `f64` does not implement `Ord` or `Hash`, which makes it unsuitable for use as a map key. To
//! work around this issue we ensure that no time value is NaN, which enables the implementation of
//! the missing traits. As a result, passing a NaN value as a time to any method will result in a
//! panic.
14//!
15//! # Examples
16//!
17//! An empty `Trace` can be constructed by using the [`Trace::new`] method.
18//!
19//! ```rust
20//! use banquo::Trace;
21//! let trace: Trace<f64> = Trace::new();
22//! ```
23//!
24//! Elements can be inserted into the `Trace` by using the [`Trace::insert`] method.
25//!
26//! ```rust
27//! use banquo::Trace;
28//!
29//! let mut trace: Trace<f64> = Trace::new();
30//! trace.insert(0.0, 100.0);
31//! trace.insert(1.0, 105.3);
32//! trace.insert(2.0, 107.1);
33//! ```
34//!
35//! A `Trace` can also be constructed from an array of known size.
36//!
37//! ```rust
38//! use banquo::Trace;
39//!
40//! let trace = Trace::from([
41//!     (0.0, 100.0),
42//!     (1.0, 105.3),
43//!     (2.0, 107.1),
44//! ]);
45//! ```
46//!
//! Individual states can be accessed either with [`Trace::at_time`], which returns `None` when no
//! state is recorded at the requested time, or by indexing (see [`Index`]), which panics instead.
48//!
49//! ```rust
50//! use banquo::Trace;
51//!
52//! let trace = Trace::from([
53//!     (0.0, 100.0),
54//!     (1.0, 105.3),
55//!     (2.0, 107.1),
56//! ]);
57//!
//! trace.at_time(0.0);  // Some(&100.0)
59//! trace.at_time(3.0);  // None
60//!
61//! trace[0.0];      // 100.0
62//! // trace[3.0];   // Panic
63//! ```
64//!
65//! A `Trace` can be iterated over using for loops. You can also use the [`IntoIterator`]
66//! implementation for `Trace<T>`, `&Trace<T>`, and `&mut Trace<T>` to manually create iterators.
67//! Finally, you can iterate over either the times or states in the trace using the [`Trace::times`]
68//! and [`Trace::states`] methods.
69//!
70//! ```rust
71//! use banquo::Trace;
72//!
73//! let mut trace = Trace::from([
74//!     (0.0, 100.0),
75//!     (1.0, 105.3),
76//!     (2.0, 107.1),
77//! ]);
78//!
79//! let times = trace.times();
80//! let states = trace.states();
81//!
82//! let iter = trace.iter();
83//! let iter = IntoIterator::into_iter(&trace);
84//!
85//! for (time, state) in &trace {  // (f64, &f64)
86//!     // ...
87//! }
88//!
89//! let iter = trace.iter_mut();
90//! let iter = IntoIterator::into_iter(&mut trace);
91//!
92//! for (time, state) in &mut trace {  // (f64, &mut f64)
93//!     // ... 
94//! }
95//!
96//! // let iter = trace.into_iter();
97//!
98//! for (time, state) in trace {  // (f64, f64)
99//!     // ... 
100//! }
101//! ```
102//!
//! Traces can be collected from [`Iterator`]s that yield `(time, state)` tuples where the time can
//! be converted into an `f64` (that is, it implements `Into<f64>`).
105//!
106//! ```rust
107//! use banquo::Trace;
108//!
109//! let elements = vec![
110//!     (0.0, 100.0),
111//!     (1.0, 105.3),
112//!     (2.0, 107.1),
113//! ];
114//!
115//! let trace: Trace<_> = elements
116//!     .into_iter()
117//!     .map(|(time, state): (f64, f64)| (time, state / 10.0))
118//!     .collect();
119//! ```
120//!
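//! Trace iterators support mapping over only the states of a trace while keeping the times the
//! same by using the `map_states` method, which is available on the iterators returned by
//! `Trace::iter`, `Trace::into_iter`, and `Trace::range`.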
123//!
124//! ```rust
125//! use banquo::Trace;
126//!
127//! let trace = Trace::from([
128//!     (0.0, 100.0),
129//!     (1.0, 105.3),
130//!     (2.0, 107.1),
131//! ]);
132//!
133//! let mapped: Trace<_> = trace
134//!     .into_iter()
135//!     .map_states(|state: f64| state / 2.0)
136//!     .collect();
137//! ```
138//!
139//! Traces also support iterating over sub-intervals using the [`Trace::range`] and
140//! [`Trace::range_mut`] methods.
141//!
142//! ```rust
143//! use banquo::Trace;
144//!
145//! let mut trace = Trace::from([
146//!     (0.0, 100.0),
147//!     (1.0, 105.3),
148//!     (2.0, 107.1),
149//! ]);
150//!
151//! let r1 = trace.range(0.0..2.0);  // Contains elements from times 0.0 <= t < 2.0
152//! let r2 = trace.range_mut(0.0..=2.0); // Contains elements from times 0.0 <= t <= 2.0
153//! ```
154use std::collections::BTreeMap;
155use std::ops::{Bound, Index, RangeBounds};
156
157use ordered_float::NotNan;
158
159/// A set of values where each value is associated with a time.
160///
161/// See the [`trace`](trace) module for more information about the semantics of this type.
162#[derive(PartialEq, Eq, Debug, Clone)]
163pub struct Trace<T>(BTreeMap<NotNan<f64>, T>);
164
165impl<T> Default for Trace<T> {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171impl<A, T> FromIterator<(A, T)> for Trace<T>
172where
173    A: Into<f64>,
174{
175    fn from_iter<I>(iter: I) -> Self
176    where
177        I: IntoIterator<Item = (A, T)>,
178    {
179        let elements = iter
180            .into_iter()
181            .map(|(time, state)| (NotNan::new(time.into()).unwrap(), state))
182            .collect();
183
184        Self(elements)
185    }
186}
187
188impl<A, T, const N: usize> From<[(A, T); N]> for Trace<T>
189where
190    A: Into<f64>,
191{
192    #[inline]
193    fn from(values: [(A, T); N]) -> Self {
194        Self::from_iter(values)
195    }
196}
197
198/// Create a `Trace` from an array of elements of known length
199impl<A, T> From<Vec<(A, T)>> for Trace<T>
200where
201    A: Into<f64>,
202{
203    #[inline]
204    fn from(values: Vec<(A, T)>) -> Self {
205        Self::from_iter(values)
206    }
207}
208
209impl<T> Index<f64> for Trace<T> {
210    type Output = T;
211
212    fn index(&self, index: f64) -> &Self::Output {
213        let index = NotNan::new(index).unwrap();
214        self.0.index(&index)
215    }
216}
217
218impl<T> Trace<T> {
219    /// Create a new empty trace. Equivalent to [`Trace::default()`]
220    pub fn new() -> Self {
221        Self(BTreeMap::new())
222    }
223
224    /// Number of elements in the trace.
225    pub fn len(&self) -> usize {
226        self.0.len()
227    }
228
229    /// Determine if the trace contains any elements.
230    pub fn is_empty(&self) -> bool {
231        self.0.is_empty()
232    }
233
234    /// Get the state for a given time. Returns None if the time is not present in the trace.
235    ///
236    /// # Safety
237    ///
238    /// This method will panic if the provided time is NaN
239    pub fn at_time(&self, time: f64) -> Option<&T> {
240        let key = NotNan::new(time).unwrap();
241        self.0.get(&key)
242    }
243
244    /// Insert a state for a given time into the trace. Returns the prior state if it exists.
245    ///
246    /// # Safety
247    ///
248    /// This method will panic if the provided time is NaN
249    pub fn insert(&mut self, time: f64, state: T) -> Option<T> {
250        let key = NotNan::new(time).unwrap();
251        self.0.insert(key, state)
252    }
253}
254
/// Iterator over the times in a trace.
///
/// The times are yielded in chronological order (lower times -> higher times).
///
/// This iterator can be constructed by calling the `times()` method on either a `Trace` or a
/// `Range` value.
///
/// ```rust
/// use banquo::Trace;
///
/// let trace = Trace::from([
///     (0.0, ()),
///     (1.0, ()),
///     (3.0, ()),
///     (4.0, ()),
///     (5.0, ()),
/// ]);
///
/// let times = trace.times();
///
/// let range = trace.range(0.0..=3.0);
/// let times = range.times();
/// ```
#[derive(Debug, Clone)]
pub struct Times<I>(I);

impl<I, T> Iterator for Times<I>
where
    I: Iterator<Item = (f64, T)>,
{
    type Item = f64;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(|(time, _)| time)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exactly one time is yielded per element of the wrapped iterator.
        self.0.size_hint()
    }
}

impl<I, T> ExactSizeIterator for Times<I>
where
    I: ExactSizeIterator<Item = (f64, T)>,
{
    fn len(&self) -> usize {
        self.0.len()
    }
}

impl<I, T> DoubleEndedIterator for Times<I>
where
    I: DoubleEndedIterator<Item = (f64, T)>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.0.next_back().map(|(time, _)| time)
    }
}
312
/// Iterator over the states in a trace.
///
/// The states are yielded in chronological order (lower times -> higher times).
///
/// This iterator can be constructed by calling the `states()` method on a `Trace`, `Iter`,
/// `IntoIter`, `IterMut`, `Range`, or `RangeMut` value, or by calling `Trace::states_mut()`.
///
/// ```rust
/// use banquo::Trace;
///
/// let trace = Trace::from([
///     (0.0, ()),
///     (1.0, ()),
///     (3.0, ()),
///     (4.0, ()),
///     (5.0, ()),
/// ]);
///
/// let states = trace.states();
///
/// let range = trace.range(0.0..=3.0);
/// let states = range.states();
/// ```
#[derive(Debug, Clone)]
pub struct States<I>(I);

impl<I, T> Iterator for States<I>
where
    I: Iterator<Item = (f64, T)>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next().map(|(_, state)| state)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exactly one state is yielded per element of the wrapped iterator.
        self.0.size_hint()
    }
}

impl<I, T> ExactSizeIterator for States<I>
where
    I: ExactSizeIterator<Item = (f64, T)>,
{
    fn len(&self) -> usize {
        self.0.len()
    }
}

impl<I, T> DoubleEndedIterator for States<I>
where
    I: DoubleEndedIterator<Item = (f64, T)>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.0.next_back().map(|(_, state)| state)
    }
}
370
/// Iterator that calls a given function for every state in a trace, keeping the times the same.
///
/// This iterator can be constructed by calling the `map_states()` method on an `IntoIter`,
/// `Iter`, or `Range` value.
///
/// ```rust
/// use banquo::Trace;
///
/// let trace = Trace::from([
///     (0.0, 1),
///     (1.0, 2),
///     (3.0, 3),
/// ]);
///
/// let doubled: Trace<i32> = trace
///     .iter()
///     .map_states(|state: &i32| state * 2)
///     .collect();
///
/// let evens: Trace<bool> = trace
///     .range(0.0..=1.0)
///     .map_states(|state: &i32| state % 2 == 0)
///     .collect();
///
/// assert_eq!(doubled[3.0], 6);
/// assert_eq!(evens.len(), 2);
/// ```
#[derive(Clone)]
pub struct MapStates<I, F> {
    iter: I,
    f: F,
}

impl<I, F> std::fmt::Debug for MapStates<I, F>
where
    I: std::fmt::Debug,
{
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The mapping function is usually a closure, which does not implement `Debug`, so only the
        // wrapped iterator is shown.
        formatter
            .debug_struct("MapStates")
            .field("iter", &self.iter)
            .finish()
    }
}

impl<I, F> MapStates<I, F> {
    /// Apply the mapping function to the state of a single `(time, state)` pair.
    fn map_element<T, U>(&mut self, (time, state): (f64, T)) -> (f64, U)
    where
        F: FnMut(T) -> U,
    {
        (time, (self.f)(state))
    }
}

impl<I, F, T, U> Iterator for MapStates<I, F>
where
    I: Iterator<Item = (f64, T)>,
    F: FnMut(T) -> U,
{
    type Item = (f64, U);

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next().map(|e| self.map_element(e))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

impl<I, F, T, U> DoubleEndedIterator for MapStates<I, F>
where
    I: DoubleEndedIterator<Item = (f64, T)>,
    F: FnMut(T) -> U,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.iter.next_back().map(|e| self.map_element(e))
    }
}

impl<I, F, T, U> ExactSizeIterator for MapStates<I, F>
where
    I: ExactSizeIterator<Item = (f64, T)>,
    F: FnMut(T) -> U,
{
    fn len(&self) -> usize {
        self.iter.len()
    }
}
441
442/// Borrowing iterator over the `(time, state)` pairs in a trace.
443///
444/// The values are yielded in chronological order (lower times -> higher times).
445///
446/// ```rust
447/// use banquo::Trace;
448///
449/// let trace = Trace::from([
450///     (0.0, ()),
451///     (1.0, ()),
452///     (2.0, ()),
453///     (3.0, ()),
454///     (4.0, ()),
455///     (5.0, ()),
456/// ]);
457///
458/// let iter = trace.iter();
459///
460/// for (time, state) in &trace {
461///     // ...
462/// }
463/// ```
464pub struct Iter<'a, T>(std::collections::btree_map::Iter<'a, NotNan<f64>, T>);
465
466impl<'a, T> Iter<'a, T> {
467    fn map_element((&time, state): (&'a NotNan<f64>, &'a T)) -> (f64, &'a T) {
468        (time.into_inner(), state)
469    }
470}
471
472impl<'a, T> Iterator for Iter<'a, T> {
473    type Item = (f64, &'a T);
474
475    fn next(&mut self) -> Option<Self::Item> {
476        self.0.next().map(Self::map_element)
477    }
478
479    fn size_hint(&self) -> (usize, Option<usize>) {
480        self.0.size_hint()
481    }
482}
483
484impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
485    fn next_back(&mut self) -> Option<Self::Item> {
486        self.0.next_back().map(Self::map_element)
487    }
488}
489
490impl<'a, T> ExactSizeIterator for Iter<'a, T> {
491    fn len(&self) -> usize {
492        self.0.len()
493    }
494}
495
496impl<'a, T> Iter<'a, T> {
497    /// Create an iterator over the states of the trace, ignoring the times.
498    ///
499    /// # Examples
500    ///
501    /// ```rust
502    /// use banquo::Trace;
503    ///
504    /// let trace = Trace::from([
505    ///     (0.0, 'a'),
506    ///     (1.0, 'b'),
507    ///     (2.0, 'c'),
508    ///     (3.0, 'd'),
509    ///     (4.0, 'e'),
510    ///     (5.0, 'f'),
511    /// ]);
512    ///
513    /// let iter: Vec<&char> = trace.iter().states().collect();
514    /// ```
515    pub fn states(self) -> States<Self> {
516        States(self)
517    }
518
519    /// Create an iterator that applies the function `f` to each state of the trace, while keeping
520    /// the times the same.
521    ///
522    /// # Examples
523    ///
524    /// ```rust
525    /// use banquo::Trace;
526    ///
527    /// let trace = Trace::from([
528    ///     (0.0, "s1".to_string()),
529    ///     (1.0, "s2".to_string()),
530    ///     (2.0, "s3".to_string()),
531    ///     (3.0, "s4".to_string()),
532    ///     (4.0, "s5".to_string()),
533    ///     (5.0, "s6".to_string()),
534    /// ]);
535    ///
536    /// let iter: Trace<usize> = trace
537    ///     .iter()
538    ///     .map_states(|state: &String| state.len())
539    ///     .collect();
540    /// ```
541    pub fn map_states<F, U>(self, f: F) -> MapStates<Self, F>
542    where
543        F: FnMut(&'a T) -> U,
544    {
545        MapStates { f, iter: self }
546    }
547}
548
549/// Owning iterator over the `(time, state)` pairs in a trace.
550///
551/// The values are yielded in chronological order (lower times -> higher times).
552///
553/// ```rust
554/// use banquo::Trace;
555///
556/// let trace = Trace::from([
557///     (0.0, ()),
558///     (1.0, ()),
559///     (2.0, ()),
560///     (3.0, ()),
561///     (4.0, ()),
562///     (5.0, ()),
563/// ]);
564///
565/// // let iter = trace.into_iter();
566///
567/// for (time, state) in trace {
568///     // ...
569/// }
570/// ```
571pub struct IntoIter<T>(std::collections::btree_map::IntoIter<NotNan<f64>, T>);
572
573impl<T> IntoIter<T> {
574    fn map_element((time, state): (NotNan<f64>, T)) -> (f64, T) {
575        (time.into_inner(), state)
576    }
577}
578
579impl<T> Iterator for IntoIter<T> {
580    type Item = (f64, T);
581
582    fn next(&mut self) -> Option<Self::Item> {
583        self.0.next().map(Self::map_element)
584    }
585
586    fn size_hint(&self) -> (usize, Option<usize>) {
587        self.0.size_hint()
588    }
589}
590
591impl<T> DoubleEndedIterator for IntoIter<T> {
592    fn next_back(&mut self) -> Option<Self::Item> {
593        self.0.next_back().map(Self::map_element)
594    }
595}
596
597impl<T> ExactSizeIterator for IntoIter<T> {
598    fn len(&self) -> usize {
599        self.0.len()
600    }
601}
602
603impl<T> IntoIter<T> {
604    /// Create an iterator over the states of the trace, ignoring the times.
605    ///
606    /// # Example
607    ///
608    /// ```rust
609    /// use banquo::Trace;
610    ///
611    /// let trace = Trace::from([
612    ///     (0.0, 'a'),
613    ///     (1.0, 'b'),
614    ///     (2.0, 'c'),
615    ///     (3.0, 'd'),
616    ///     (4.0, 'e'),
617    ///     (5.0, 'f'),
618    /// ]);
619    ///
620    /// let iter: Vec<char> = trace
621    ///     .into_iter()
622    ///     .states()
623    ///     .collect();
624    /// ```
625    pub fn states(self) -> States<Self> {
626        States(self)
627    }
628
629    /// Create an iterator that applies a function `f` to each state of the trace, while keeping the
630    /// times the same.
631    ///
632    /// # Example
633    ///
634    /// ```rust
635    /// use banquo::Trace;
636    ///
637    /// let trace = Trace::from([
638    ///     (0.0, 'a'),
639    ///     (1.0, 'b'),
640    ///     (2.0, 'c'),
641    ///     (3.0, 'd'),
642    ///     (4.0, 'e'),
643    ///     (5.0, 'f'),
644    /// ]);
645    ///
646    /// let iter: Trace<u8> = trace
647    ///     .into_iter()
648    ///     .map_states(|state: char| state as u8)
649    ///     .collect();
650    /// ```
651    pub fn map_states<F, U>(self, f: F) -> MapStates<Self, F>
652    where
653        F: FnMut(T) -> U,
654    {
655        MapStates{ f, iter: self }
656    }
657}
658
659/// Mutably borrowing iterator over the `(time, state)` pairs in a trace.
660///
661/// The values are yielded in chronological order (lower times -> higher times).
662///
663/// ```rust
664/// use banquo::Trace;
665///
666/// let mut trace = Trace::from([
667///     (0.0, ()),
668///     (1.0, ()),
669///     (2.0, ()),
670///     (3.0, ()),
671///     (4.0, ()),
672///     (5.0, ()),
673/// ]);
674///
675/// let iter = trace.iter_mut();
676///
677/// for (time, state) in &mut trace {
678///     // ...
679/// }
680/// ```
681pub struct IterMut<'a, T>(std::collections::btree_map::IterMut<'a, NotNan<f64>, T>);
682
683impl<'a, T> IterMut<'a, T> {
684    fn map_element((&time, state): (&'a NotNan<f64>, &'a mut T)) -> (f64, &'a mut T) {
685        (time.into_inner(), state)
686    }
687}
688
689impl<'a, T> IterMut<'a, T> {
690    /// Create an iterator over the states of the trace, ignoring the times.
691    ///
692    /// # Example
693    ///
694    /// ```rust
695    /// use banquo::Trace;
696    ///
697    /// let mut trace = Trace::from([
698    ///     (0.0, 10),
699    ///     (1.0, 20),
700    ///     (2.0, 30),
701    ///     (3.0, 40),
702    ///     (4.0, 50),
703    ///     (5.0, 60),
704    /// ]);
705    ///
706    /// trace
707    ///     .iter_mut()
708    ///     .states()
709    ///     .for_each(|state: &mut i32| *state += 5);
710    /// ```
711    pub fn states(self) -> States<Self> {
712        States(self)
713    }
714}
715
716impl<'a, T> Iterator for IterMut<'a, T> {
717    type Item = (f64, &'a mut T);
718
719    fn next(&mut self) -> Option<Self::Item> {
720        self.0.next().map(Self::map_element)
721    }
722
723    fn size_hint(&self) -> (usize, Option<usize>) {
724        self.0.size_hint()
725    }
726}
727
728impl<'a, T> DoubleEndedIterator for IterMut<'a, T> {
729    fn next_back(&mut self) -> Option<Self::Item> {
730        self.0.next_back().map(Self::map_element)
731    }
732}
733
734impl<'a, T> ExactSizeIterator for IterMut<'a, T> {
735    fn len(&self) -> usize {
736        self.0.len()
737    }
738}
739
740/// Borrowing iterator over the `(time, state)` pairs of a sub-interval of trace.
741///
742/// The values are yielded in chronological order (lower times -> higher times).
743///
744/// This value can be constructed by calling the `range()` method on a `Trace` value;
745///
746/// # Example
747///
748/// ```rust
749/// use banquo::Trace;
750///
751/// let trace = Trace::from([
752///     (0.0, ()),
753///     (1.0, ()),
754///     (2.0, ()),
755///     (3.0, ()),
756///     (4.0, ()),
757///     (5.0, ()),
758/// ]);
759///
760/// let range = trace.range(0.0..4.0);
761/// let range = trace.range(1.0..=3.0);
762/// ```
763pub struct Range<'a, T>(std::collections::btree_map::Range<'a, NotNan<f64>, T>);
764
765impl<'a, T> Range<'a, T> {
766    fn map_element((&time, state): (&'a NotNan<f64>, &'a T)) -> (f64, &'a T) {
767        (time.into_inner(), state)
768    }
769}
770
771impl<'a, T> Iterator for Range<'a, T> {
772    type Item = (f64, &'a T);
773
774    fn next(&mut self) -> Option<Self::Item> {
775        self.0.next().map(Self::map_element)
776    }
777
778    fn size_hint(&self) -> (usize, Option<usize>) {
779        self.0.size_hint()
780    }
781}
782
783impl<'a, T> DoubleEndedIterator for Range<'a, T> {
784    fn next_back(&mut self) -> Option<Self::Item> {
785        self.0.next_back().map(Self::map_element)
786    }
787}
788
789impl<'a, T> Range<'a, T> {
790    /// Create an iterator over the times of the range, ignoring the states.
791    ///
792    /// # Example
793    ///
794    /// ```rust
795    /// use banquo::Trace;
796    ///
797    /// let mut trace = Trace::from([
798    ///     (0.0, 'a'),
799    ///     (1.0, 'b'),
800    ///     (2.0, 'c'),
801    ///     (3.0, 'd'),
802    ///     (4.0, 'e'),
803    ///     (5.0, 'f'),
804    /// ]);
805    ///
806    /// let iter: Vec<f64> = trace
807    ///     .range(0.5..=4.5)
808    ///     .times()
809    ///     .collect();
810    /// ```
811    pub fn times(self) -> Times<Self> {
812        Times(self)
813    }
814
815    /// Create an iterator over states of the range, ignoring the times.
816    ///
817    /// # Example
818    ///
819    /// ```rust
820    /// use banquo::Trace;
821    ///
822    /// let mut trace = Trace::from([
823    ///     (0.0, 'a'),
824    ///     (1.0, 'b'),
825    ///     (2.0, 'c'),
826    ///     (3.0, 'd'),
827    ///     (4.0, 'e'),
828    ///     (5.0, 'f'),
829    /// ]);
830    ///
831    /// let iter: Vec<&char> = trace
832    ///     .range(1.0..=4.0)
833    ///     .states()
834    ///     .collect();
835    /// ```
836    pub fn states(self) -> States<Self> {
837        States(self)
838    }
839    
840    /// Create an iterator that applies the function `f` to each state of the sub-interval, while keeping
841    /// the times the same.
842    ///
843    /// # Examples
844    ///
845    /// ```rust
846    /// use banquo::Trace;
847    ///
848    /// let trace = Trace::from([
849    ///     (0.0, "s1".to_string()),
850    ///     (1.0, "s2".to_string()),
851    ///     (2.0, "s3".to_string()),
852    ///     (3.0, "s4".to_string()),
853    ///     (4.0, "s5".to_string()),
854    ///     (5.0, "s6".to_string()),
855    /// ]);
856    ///
857    /// let iter: Trace<usize> = trace
858    ///     .iter()
859    ///     .map_states(|state: &String| state.len())
860    ///     .collect();
861    /// ```
862    pub fn map_states<F, U>(self, f: F) -> MapStates<Self, F>
863    where
864        F: FnMut(&'a T) -> U,
865    {
866        MapStates { f, iter: self }
867    }
868}
869
870/// Mutably borrowing iterator over the `(time, state)` pairs of a sub-interval of trace.
871///
872/// The values are yielded in chronological order (lower times -> higher times).
873///
874/// This value can be constructed by calling the `range_mut()` method on a `Trace` value;
875///
876/// ```rust
877/// use banquo::Trace;
878///
879/// let mut trace = Trace::from([
880///     (0.0, ()),
881///     (1.0, ()),
882///     (2.0, ()),
883///     (3.0, ()),
884///     (4.0, ()),
885///     (5.0, ()),
886/// ]);
887///
888/// let range = trace.range_mut(0.0..4.0);
889/// let range = trace.range_mut(1.0..=3.0);
890/// ```
891pub struct RangeMut<'a, T>(std::collections::btree_map::RangeMut<'a, NotNan<f64>, T>);
892
893impl<'a, T> RangeMut<'a, T> {
894    fn map_element((&time, state): (&'a NotNan<f64>, &'a mut T)) -> (f64, &'a mut T) {
895        (time.into_inner(), state)
896    }
897}
898
899impl<'a, T> Iterator for RangeMut<'a, T> {
900    type Item = (f64, &'a mut T);
901
902    fn next(&mut self) -> Option<Self::Item> {
903        self.0.next().map(Self::map_element)
904    }
905
906    fn size_hint(&self) -> (usize, Option<usize>) {
907        self.0.size_hint()
908    }
909}
910
911impl<'a, T> DoubleEndedIterator for RangeMut<'a, T> {
912    fn next_back(&mut self) -> Option<Self::Item> {
913        self.0.next_back().map(Self::map_element)
914    }
915}
916
917impl<'a, T> RangeMut<'a, T> {
918    /// Create an iterator over states of the range, ignoring the times.
919    ///
920    /// # Example
921    ///
922    /// ```rust
923    /// use banquo::Trace;
924    ///
925    /// let mut trace = Trace::from([
926    ///     (0.0, 'a'),
927    ///     (1.0, 'b'),
928    ///     (2.0, 'c'),
929    ///     (3.0, 'd'),
930    ///     (4.0, 'e'),
931    ///     (5.0, 'f'),
932    /// ]);
933    ///
934    /// let iter: Vec<&char> = trace
935    ///     .range(1.0..=4.0)
936    ///     .states()
937    ///     .collect();
938    /// ```
939    pub fn states(self) -> States<Self> {
940        States(self)
941    }
942}
943
944fn convert_bound(bound: Bound<&f64>) -> Bound<NotNan<f64>> {
945    match bound {
946        Bound::Unbounded => Bound::Unbounded,
947        Bound::Included(&val) => Bound::Included(NotNan::new(val).unwrap()),
948        Bound::Excluded(&val) => Bound::Excluded(NotNan::new(val).unwrap()),
949    }
950}
951
952impl<T> Trace<T> {
953    /// Create an iterator yielding `(time, &state)` values from the trace in chronological order.
954    pub fn iter(&self) -> Iter<T> {
955        self.into_iter()
956    }
957
958    /// Create an iterator yielding `(time, &mut state)` values from the trace in chronological order.
959    pub fn iter_mut(&mut self) -> IterMut<T> {
960        IterMut(self.0.iter_mut())
961    }
962
963    /// Create an iterator yielding time values from the trace in chronological order.
964    pub fn times(&self) -> Times<Iter<T>> {
965        Times(self.iter())
966    }
967
968    /// Create an iterator yielding `&state` values from the trace in chronological order.
969    pub fn states(&self) -> States<Iter<T>> {
970        States(self.iter())
971    }
972
973    /// Create an iterator yielding `&mut state values` from the trace in chronological order.
974    pub fn states_mut(&mut self) -> States<IterMut<T>> {
975        States(self.iter_mut())
976    }
977
978    /// Create an iterator over the `(time, &state)` values from a sub-interval of the trace in
979    /// chronological order
980    ///
981    /// # Safety
982    ///
983    /// This function panics if either range bound is NaN.
984    pub fn range<R>(&self, bounds: R) -> Range<T>
985    where
986        R: RangeBounds<f64>,
987    {
988        let start = convert_bound(bounds.start_bound());
989        let end = convert_bound(bounds.end_bound());
990
991        Range(self.0.range((start, end)))
992    }
993
994    /// Create an iterator over the `(time, &mut state)` values from a sub-interval of the trace in
995    /// chronological order.
996    ///
997    /// # Safety
998    ///
999    /// This function panics if either range bound is NaN.
1000    pub fn range_mut<R>(&mut self, bounds: R) -> RangeMut<T>
1001    where
1002        R: RangeBounds<f64>,
1003    {
1004        let start = convert_bound(bounds.start_bound());
1005        let end = convert_bound(bounds.end_bound());
1006
1007        RangeMut(self.0.range_mut((start, end)))
1008    }
1009}
1010
1011impl<T> IntoIterator for Trace<T> {
1012    type Item = (f64, T);
1013    type IntoIter = IntoIter<T>;
1014
1015    fn into_iter(self) -> Self::IntoIter {
1016        IntoIter(self.0.into_iter())
1017    }
1018}
1019
1020impl<'a, T> IntoIterator for &'a Trace<T> {
1021    type Item = (f64, &'a T);
1022    type IntoIter = Iter<'a, T>;
1023
1024    fn into_iter(self) -> Self::IntoIter {
1025        Iter(self.0.iter())
1026    }
1027}
1028
1029impl<'a, T> IntoIterator for &'a mut Trace<T> {
1030    type Item = (f64, &'a mut T);
1031    type IntoIter = IterMut<'a, T>;
1032
1033    fn into_iter(self) -> Self::IntoIter {
1034        IterMut(self.0.iter_mut())
1035    }
1036}
1037
#[cfg(test)]
mod tests {
    use super::Trace;

    #[test]
    fn get_element() {
        let times = 0..10;
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let trace = Trace::from_iter(times.zip(values));

        assert_eq!(trace.at_time(3.0), Some(&4.0));
    }

    #[test]
    fn times() {
        let trace = Trace::from_iter([
            (1.0, ()),
            (2.0, ()),
            (3.0, ()),
            (4.0, ()),
        ]);

        let mut times = trace.times();

        assert_eq!(times.next(), Some(1.0));
        assert_eq!(times.next(), Some(2.0));
        assert_eq!(times.next(), Some(3.0));
        assert_eq!(times.next(), Some(4.0));
        assert_eq!(times.next(), None);
    }

    #[test]
    fn states() {
        let trace = Trace::from_iter([
            (1.0, 1.0),
            (2.0, 2.0),
            (3.0, 3.0),
            (4.0, 4.0),
        ]);

        let mut states = trace.states();

        assert_eq!(states.next(), Some(&1.0));
        assert_eq!(states.next(), Some(&2.0));
        assert_eq!(states.next(), Some(&3.0));
        assert_eq!(states.next(), Some(&4.0));
        assert_eq!(states.next(), None);
    }

    #[test]
    fn select_range() {
        let times = 0..10;
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let trace = Trace::from_iter(times.zip(values));

        let subtrace_times: Vec<f64> = trace.range(0.0..4.0).times().collect();
        let subtrace_states: Vec<f64> = trace.range(0.0..4.0).states().copied().collect();

        assert_eq!(subtrace_times, vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(subtrace_states, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn insert_returns_previous_state() {
        let mut trace = Trace::new();

        assert_eq!(trace.insert(1.0, 'a'), None);
        assert_eq!(trace.insert(1.0, 'b'), Some('a'));
        assert_eq!(trace.len(), 1);
        assert_eq!(trace[1.0], 'b');
    }

    #[test]
    fn map_states_preserves_times() {
        let trace = Trace::from([(0.0, 1.0), (1.0, 2.0), (2.0, 3.0)]);
        let mapped: Trace<f64> = trace
            .iter()
            .map_states(|state: &f64| state * 2.0)
            .collect();

        assert_eq!(mapped, Trace::from([(0.0, 2.0), (1.0, 4.0), (2.0, 6.0)]));
    }

    #[test]
    fn range_mut_modifies_only_selected_states() {
        let mut trace = Trace::from([(0.0, 1), (1.0, 2), (2.0, 3), (3.0, 4)]);

        for (_, state) in trace.range_mut(1.0..=2.0) {
            *state *= 10;
        }

        let states: Vec<i32> = trace.states().copied().collect();
        assert_eq!(states, vec![1, 20, 30, 4]);
    }

    #[test]
    #[should_panic]
    fn nan_time_panics() {
        let mut trace: Trace<f64> = Trace::new();
        let _ = trace.insert(f64::NAN, 1.0);
    }
}