banquo_core/trace.rs
1//! A set of values where each value is associated with a time.
2//!
3//! A [`Trace`] is an associative map that represents a sequence of values, where the key for
4//! each value is the time when the value was generated. In the context of this library, this data
5//! structure is used to represent the state of a system, or its satisfaction metric, over time.
6//!
7//! # Panics
8//!
9//! Times are represented as `f64` values, which do not implement `Ord` or `Hash` because NaN is
10//! neither equal to itself nor ordered relative to other values. This makes raw `f64` values
11//! unsuitable for use as map keys. To work around this, every time value is checked to ensure it
12//! is not NaN, which makes the missing traits available. As a result, passing a NaN time value to
13//! any method will result in a panic.
14//!
15//! # Examples
16//!
17//! An empty `Trace` can be constructed by using the [`Trace::new`] method.
18//!
19//! ```rust
20//! use banquo::Trace;
21//! let trace: Trace<f64> = Trace::new();
22//! ```
23//!
24//! Elements can be inserted into the `Trace` by using the [`Trace::insert`] method.
25//!
26//! ```rust
27//! use banquo::Trace;
28//!
29//! let mut trace: Trace<f64> = Trace::new();
30//! trace.insert(0.0, 100.0);
31//! trace.insert(1.0, 105.3);
32//! trace.insert(2.0, 107.1);
33//! ```
34//!
35//! A `Trace` can also be constructed from an array of known size.
36//!
37//! ```rust
38//! use banquo::Trace;
39//!
40//! let trace = Trace::from([
41//! (0.0, 100.0),
42//! (1.0, 105.3),
43//! (2.0, 107.1),
44//! ]);
45//! ```
46//!
47//! Individual states can be accessed with [`Trace::at_time`] (returns `None` if absent) or [`Index`] (panics if absent).
48//!
49//! ```rust
50//! use banquo::Trace;
51//!
52//! let trace = Trace::from([
53//! (0.0, 100.0),
54//! (1.0, 105.3),
55//! (2.0, 107.1),
56//! ]);
57//!
58//! trace.at_time(0.0); // Some(&100.0)
59//! trace.at_time(3.0); // None
60//!
61//! trace[0.0]; // 100.0
62//! // trace[3.0]; // Panic
63//! ```
64//!
65//! A `Trace` can be iterated over using for loops. You can also use the [`IntoIterator`]
66//! implementation for `Trace<T>`, `&Trace<T>`, and `&mut Trace<T>` to manually create iterators.
67//! Finally, you can iterate over either the times or states in the trace using the [`Trace::times`]
68//! and [`Trace::states`] methods.
69//!
70//! ```rust
71//! use banquo::Trace;
72//!
73//! let mut trace = Trace::from([
74//! (0.0, 100.0),
75//! (1.0, 105.3),
76//! (2.0, 107.1),
77//! ]);
78//!
79//! let times = trace.times();
80//! let states = trace.states();
81//!
82//! let iter = trace.iter();
83//! let iter = IntoIterator::into_iter(&trace);
84//!
85//! for (time, state) in &trace { // (f64, &f64)
86//! // ...
87//! }
88//!
89//! let iter = trace.iter_mut();
90//! let iter = IntoIterator::into_iter(&mut trace);
91//!
92//! for (time, state) in &mut trace { // (f64, &mut f64)
93//! // ...
94//! }
95//!
96//! // let iter = trace.into_iter();
97//!
98//! for (time, state) in trace { // (f64, f64)
99//! // ...
100//! }
101//! ```
102//!
103//! Traces can be collected from [`Iterator`]s that yield tuple values where the first element can
104//! be converted into an `f64`.
105//!
106//! ```rust
107//! use banquo::Trace;
108//!
109//! let elements = vec![
110//! (0.0, 100.0),
111//! (1.0, 105.3),
112//! (2.0, 107.1),
113//! ];
114//!
115//! let trace: Trace<_> = elements
116//! .into_iter()
117//! .map(|(time, state): (f64, f64)| (time, state / 10.0))
118//! .collect();
119//! ```
120//!
121//! Trace iterators support mapping over only the states of a trace while keeping the times the
122//! same by using the `map_states` method.
123//!
124//! ```rust
125//! use banquo::Trace;
126//!
127//! let trace = Trace::from([
128//! (0.0, 100.0),
129//! (1.0, 105.3),
130//! (2.0, 107.1),
131//! ]);
132//!
133//! let mapped: Trace<_> = trace
134//! .into_iter()
135//! .map_states(|state: f64| state / 2.0)
136//! .collect();
137//! ```
138//!
139//! Traces also support iterating over sub-intervals using the [`Trace::range`] and
140//! [`Trace::range_mut`] methods.
141//!
142//! ```rust
143//! use banquo::Trace;
144//!
145//! let mut trace = Trace::from([
146//! (0.0, 100.0),
147//! (1.0, 105.3),
148//! (2.0, 107.1),
149//! ]);
150//!
151//! let r1 = trace.range(0.0..2.0); // Contains elements from times 0.0 <= t < 2.0
152//! let r2 = trace.range_mut(0.0..=2.0); // Contains elements from times 0.0 <= t <= 2.0
153//! ```
154use std::collections::BTreeMap;
155use std::ops::{Bound, Index, RangeBounds};
156
157use ordered_float::NotNan;
158
159/// A set of values where each value is associated with a time.
160///
161/// See the [`trace`](trace) module for more information about the semantics of this type.
162#[derive(PartialEq, Eq, Debug, Clone)]
163pub struct Trace<T>(BTreeMap<NotNan<f64>, T>);
164
165impl<T> Default for Trace<T> {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171impl<A, T> FromIterator<(A, T)> for Trace<T>
172where
173 A: Into<f64>,
174{
175 fn from_iter<I>(iter: I) -> Self
176 where
177 I: IntoIterator<Item = (A, T)>,
178 {
179 let elements = iter
180 .into_iter()
181 .map(|(time, state)| (NotNan::new(time.into()).unwrap(), state))
182 .collect();
183
184 Self(elements)
185 }
186}
187
188impl<A, T, const N: usize> From<[(A, T); N]> for Trace<T>
189where
190 A: Into<f64>,
191{
192 #[inline]
193 fn from(values: [(A, T); N]) -> Self {
194 Self::from_iter(values)
195 }
196}
197
198/// Create a `Trace` from an array of elements of known length
199impl<A, T> From<Vec<(A, T)>> for Trace<T>
200where
201 A: Into<f64>,
202{
203 #[inline]
204 fn from(values: Vec<(A, T)>) -> Self {
205 Self::from_iter(values)
206 }
207}
208
209impl<T> Index<f64> for Trace<T> {
210 type Output = T;
211
212 fn index(&self, index: f64) -> &Self::Output {
213 let index = NotNan::new(index).unwrap();
214 self.0.index(&index)
215 }
216}
217
218impl<T> Trace<T> {
219 /// Create a new empty trace. Equivalent to [`Trace::default()`]
220 pub fn new() -> Self {
221 Self(BTreeMap::new())
222 }
223
224 /// Number of elements in the trace.
225 pub fn len(&self) -> usize {
226 self.0.len()
227 }
228
229 /// Determine if the trace contains any elements.
230 pub fn is_empty(&self) -> bool {
231 self.0.is_empty()
232 }
233
234 /// Get the state for a given time. Returns None if the time is not present in the trace.
235 ///
236 /// # Safety
237 ///
238 /// This method will panic if the provided time is NaN
239 pub fn at_time(&self, time: f64) -> Option<&T> {
240 let key = NotNan::new(time).unwrap();
241 self.0.get(&key)
242 }
243
244 /// Insert a state for a given time into the trace. Returns the prior state if it exists.
245 ///
246 /// # Safety
247 ///
248 /// This method will panic if the provided time is NaN
249 pub fn insert(&mut self, time: f64, state: T) -> Option<T> {
250 let key = NotNan::new(time).unwrap();
251 self.0.insert(key, state)
252 }
253}
254
/// Iterator over the times in a trace.
///
/// The times are yielded in chronological order (lower times -> higher times).
///
/// This iterator can be constructed by calling the `times()` method on either a `Trace` or a
/// `Range` value.
///
/// ```rust
/// use banquo::Trace;
///
/// let trace = Trace::from([
///     (0.0, ()),
///     (1.0, ()),
///     (3.0, ()),
///     (4.0, ()),
///     (5.0, ()),
/// ]);
///
/// let times = trace.times();
///
/// let range = trace.range(0.0..=3.0);
/// let times = range.times();
/// ```
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Times<I>(I);

impl<I, T> Iterator for Times<I>
where
    I: Iterator<Item = (f64, T)>,
{
    type Item = f64;

    fn next(&mut self) -> Option<Self::Item> {
        // Keep the time and drop the state.
        self.0.next().map(|(time, _)| time)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Projecting each pair onto its time neither adds nor removes elements.
        self.0.size_hint()
    }
}

impl<I, T> ExactSizeIterator for Times<I>
where
    I: ExactSizeIterator<Item = (f64, T)>,
{
    fn len(&self) -> usize {
        self.0.len()
    }
}

impl<I, T> DoubleEndedIterator for Times<I>
where
    I: DoubleEndedIterator<Item = (f64, T)>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.0.next_back().map(|(time, _)| time)
    }
}
312
/// Iterator over the states in a trace.
///
/// The states are yielded in chronological order (lower times -> higher times).
///
/// This iterator can be constructed by calling the `states()` method on a `Trace` or on any of
/// the trace iterators (`Iter`, `IterMut`, `IntoIter`, `Range`, `RangeMut`), or by calling the
/// `states_mut()` method on a `Trace`.
///
/// ```rust
/// use banquo::Trace;
///
/// let trace = Trace::from([
///     (0.0, ()),
///     (1.0, ()),
///     (3.0, ()),
///     (4.0, ()),
///     (5.0, ()),
/// ]);
///
/// let states = trace.states();
///
/// let range = trace.range(0.0..=3.0);
/// let states = range.states();
/// ```
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct States<I>(I);

impl<I, T> Iterator for States<I>
where
    I: Iterator<Item = (f64, T)>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // Keep the state and drop the time.
        self.0.next().map(|(_, state)| state)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Projecting each pair onto its state neither adds nor removes elements.
        self.0.size_hint()
    }
}

impl<I, T> ExactSizeIterator for States<I>
where
    I: ExactSizeIterator<Item = (f64, T)>,
{
    fn len(&self) -> usize {
        self.0.len()
    }
}

impl<I, T> DoubleEndedIterator for States<I>
where
    I: DoubleEndedIterator<Item = (f64, T)>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.0.next_back().map(|(_, state)| state)
    }
}
370
/// Iterator that calls a given function for every state in a trace, keeping the times the same.
///
/// This iterator can be constructed by calling the `map_states()` method on an `IntoIter`, `Iter`,
/// or `Range` value.
///
/// ```rust
/// use banquo::Trace;
///
/// let trace = Trace::from([
///     (0.0, 1),
///     (1.0, 2),
///     (3.0, 3),
/// ]);
///
/// let doubled: Trace<i32> = trace
///     .iter()
///     .map_states(|state: &i32| state * 2)
///     .collect();
///
/// let shifted: Trace<i32> = trace
///     .range(0.0..=1.0)
///     .map_states(|state: &i32| state + 10)
///     .collect();
/// ```
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct MapStates<I, F> {
    iter: I,
    f: F,
}

impl<I, F> MapStates<I, F> {
    /// Apply the mapping function to the state of a single pair, passing the time through.
    fn map_element<T, U>(&mut self, (time, state): (f64, T)) -> (f64, U)
    where
        F: FnMut(T) -> U,
    {
        (time, (self.f)(state))
    }
}

impl<I, F, T, U> Iterator for MapStates<I, F>
where
    I: Iterator<Item = (f64, T)>,
    F: FnMut(T) -> U,
{
    type Item = (f64, U);

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next().map(|e| self.map_element(e))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Mapping is one-to-one, so the wrapped iterator's bounds are exact for this one too.
        self.iter.size_hint()
    }
}

impl<I, F, T, U> DoubleEndedIterator for MapStates<I, F>
where
    I: DoubleEndedIterator<Item = (f64, T)>,
    F: FnMut(T) -> U,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.iter.next_back().map(|e| self.map_element(e))
    }
}

impl<I, F, T, U> ExactSizeIterator for MapStates<I, F>
where
    I: ExactSizeIterator<Item = (f64, T)>,
    F: FnMut(T) -> U,
{
    fn len(&self) -> usize {
        self.iter.len()
    }
}
441
442/// Borrowing iterator over the `(time, state)` pairs in a trace.
443///
444/// The values are yielded in chronological order (lower times -> higher times).
445///
446/// ```rust
447/// use banquo::Trace;
448///
449/// let trace = Trace::from([
450/// (0.0, ()),
451/// (1.0, ()),
452/// (2.0, ()),
453/// (3.0, ()),
454/// (4.0, ()),
455/// (5.0, ()),
456/// ]);
457///
458/// let iter = trace.iter();
459///
460/// for (time, state) in &trace {
461/// // ...
462/// }
463/// ```
464pub struct Iter<'a, T>(std::collections::btree_map::Iter<'a, NotNan<f64>, T>);
465
466impl<'a, T> Iter<'a, T> {
467 fn map_element((&time, state): (&'a NotNan<f64>, &'a T)) -> (f64, &'a T) {
468 (time.into_inner(), state)
469 }
470}
471
472impl<'a, T> Iterator for Iter<'a, T> {
473 type Item = (f64, &'a T);
474
475 fn next(&mut self) -> Option<Self::Item> {
476 self.0.next().map(Self::map_element)
477 }
478
479 fn size_hint(&self) -> (usize, Option<usize>) {
480 self.0.size_hint()
481 }
482}
483
484impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
485 fn next_back(&mut self) -> Option<Self::Item> {
486 self.0.next_back().map(Self::map_element)
487 }
488}
489
490impl<'a, T> ExactSizeIterator for Iter<'a, T> {
491 fn len(&self) -> usize {
492 self.0.len()
493 }
494}
495
496impl<'a, T> Iter<'a, T> {
497 /// Create an iterator over the states of the trace, ignoring the times.
498 ///
499 /// # Examples
500 ///
501 /// ```rust
502 /// use banquo::Trace;
503 ///
504 /// let trace = Trace::from([
505 /// (0.0, 'a'),
506 /// (1.0, 'b'),
507 /// (2.0, 'c'),
508 /// (3.0, 'd'),
509 /// (4.0, 'e'),
510 /// (5.0, 'f'),
511 /// ]);
512 ///
513 /// let iter: Vec<&char> = trace.iter().states().collect();
514 /// ```
515 pub fn states(self) -> States<Self> {
516 States(self)
517 }
518
519 /// Create an iterator that applies the function `f` to each state of the trace, while keeping
520 /// the times the same.
521 ///
522 /// # Examples
523 ///
524 /// ```rust
525 /// use banquo::Trace;
526 ///
527 /// let trace = Trace::from([
528 /// (0.0, "s1".to_string()),
529 /// (1.0, "s2".to_string()),
530 /// (2.0, "s3".to_string()),
531 /// (3.0, "s4".to_string()),
532 /// (4.0, "s5".to_string()),
533 /// (5.0, "s6".to_string()),
534 /// ]);
535 ///
536 /// let iter: Trace<usize> = trace
537 /// .iter()
538 /// .map_states(|state: &String| state.len())
539 /// .collect();
540 /// ```
541 pub fn map_states<F, U>(self, f: F) -> MapStates<Self, F>
542 where
543 F: FnMut(&'a T) -> U,
544 {
545 MapStates { f, iter: self }
546 }
547}
548
549/// Owning iterator over the `(time, state)` pairs in a trace.
550///
551/// The values are yielded in chronological order (lower times -> higher times).
552///
553/// ```rust
554/// use banquo::Trace;
555///
556/// let trace = Trace::from([
557/// (0.0, ()),
558/// (1.0, ()),
559/// (2.0, ()),
560/// (3.0, ()),
561/// (4.0, ()),
562/// (5.0, ()),
563/// ]);
564///
565/// // let iter = trace.into_iter();
566///
567/// for (time, state) in trace {
568/// // ...
569/// }
570/// ```
571pub struct IntoIter<T>(std::collections::btree_map::IntoIter<NotNan<f64>, T>);
572
573impl<T> IntoIter<T> {
574 fn map_element((time, state): (NotNan<f64>, T)) -> (f64, T) {
575 (time.into_inner(), state)
576 }
577}
578
579impl<T> Iterator for IntoIter<T> {
580 type Item = (f64, T);
581
582 fn next(&mut self) -> Option<Self::Item> {
583 self.0.next().map(Self::map_element)
584 }
585
586 fn size_hint(&self) -> (usize, Option<usize>) {
587 self.0.size_hint()
588 }
589}
590
591impl<T> DoubleEndedIterator for IntoIter<T> {
592 fn next_back(&mut self) -> Option<Self::Item> {
593 self.0.next_back().map(Self::map_element)
594 }
595}
596
597impl<T> ExactSizeIterator for IntoIter<T> {
598 fn len(&self) -> usize {
599 self.0.len()
600 }
601}
602
603impl<T> IntoIter<T> {
604 /// Create an iterator over the states of the trace, ignoring the times.
605 ///
606 /// # Example
607 ///
608 /// ```rust
609 /// use banquo::Trace;
610 ///
611 /// let trace = Trace::from([
612 /// (0.0, 'a'),
613 /// (1.0, 'b'),
614 /// (2.0, 'c'),
615 /// (3.0, 'd'),
616 /// (4.0, 'e'),
617 /// (5.0, 'f'),
618 /// ]);
619 ///
620 /// let iter: Vec<char> = trace
621 /// .into_iter()
622 /// .states()
623 /// .collect();
624 /// ```
625 pub fn states(self) -> States<Self> {
626 States(self)
627 }
628
629 /// Create an iterator that applies a function `f` to each state of the trace, while keeping the
630 /// times the same.
631 ///
632 /// # Example
633 ///
634 /// ```rust
635 /// use banquo::Trace;
636 ///
637 /// let trace = Trace::from([
638 /// (0.0, 'a'),
639 /// (1.0, 'b'),
640 /// (2.0, 'c'),
641 /// (3.0, 'd'),
642 /// (4.0, 'e'),
643 /// (5.0, 'f'),
644 /// ]);
645 ///
646 /// let iter: Trace<u8> = trace
647 /// .into_iter()
648 /// .map_states(|state: char| state as u8)
649 /// .collect();
650 /// ```
651 pub fn map_states<F, U>(self, f: F) -> MapStates<Self, F>
652 where
653 F: FnMut(T) -> U,
654 {
655 MapStates{ f, iter: self }
656 }
657}
658
659/// Mutably borrowing iterator over the `(time, state)` pairs in a trace.
660///
661/// The values are yielded in chronological order (lower times -> higher times).
662///
663/// ```rust
664/// use banquo::Trace;
665///
666/// let mut trace = Trace::from([
667/// (0.0, ()),
668/// (1.0, ()),
669/// (2.0, ()),
670/// (3.0, ()),
671/// (4.0, ()),
672/// (5.0, ()),
673/// ]);
674///
675/// let iter = trace.iter_mut();
676///
677/// for (time, state) in &mut trace {
678/// // ...
679/// }
680/// ```
681pub struct IterMut<'a, T>(std::collections::btree_map::IterMut<'a, NotNan<f64>, T>);
682
683impl<'a, T> IterMut<'a, T> {
684 fn map_element((&time, state): (&'a NotNan<f64>, &'a mut T)) -> (f64, &'a mut T) {
685 (time.into_inner(), state)
686 }
687}
688
689impl<'a, T> IterMut<'a, T> {
690 /// Create an iterator over the states of the trace, ignoring the times.
691 ///
692 /// # Example
693 ///
694 /// ```rust
695 /// use banquo::Trace;
696 ///
697 /// let mut trace = Trace::from([
698 /// (0.0, 10),
699 /// (1.0, 20),
700 /// (2.0, 30),
701 /// (3.0, 40),
702 /// (4.0, 50),
703 /// (5.0, 60),
704 /// ]);
705 ///
706 /// trace
707 /// .iter_mut()
708 /// .states()
709 /// .for_each(|state: &mut i32| *state += 5);
710 /// ```
711 pub fn states(self) -> States<Self> {
712 States(self)
713 }
714}
715
716impl<'a, T> Iterator for IterMut<'a, T> {
717 type Item = (f64, &'a mut T);
718
719 fn next(&mut self) -> Option<Self::Item> {
720 self.0.next().map(Self::map_element)
721 }
722
723 fn size_hint(&self) -> (usize, Option<usize>) {
724 self.0.size_hint()
725 }
726}
727
728impl<'a, T> DoubleEndedIterator for IterMut<'a, T> {
729 fn next_back(&mut self) -> Option<Self::Item> {
730 self.0.next_back().map(Self::map_element)
731 }
732}
733
734impl<'a, T> ExactSizeIterator for IterMut<'a, T> {
735 fn len(&self) -> usize {
736 self.0.len()
737 }
738}
739
740/// Borrowing iterator over the `(time, state)` pairs of a sub-interval of trace.
741///
742/// The values are yielded in chronological order (lower times -> higher times).
743///
744/// This value can be constructed by calling the `range()` method on a `Trace` value;
745///
746/// # Example
747///
748/// ```rust
749/// use banquo::Trace;
750///
751/// let trace = Trace::from([
752/// (0.0, ()),
753/// (1.0, ()),
754/// (2.0, ()),
755/// (3.0, ()),
756/// (4.0, ()),
757/// (5.0, ()),
758/// ]);
759///
760/// let range = trace.range(0.0..4.0);
761/// let range = trace.range(1.0..=3.0);
762/// ```
763pub struct Range<'a, T>(std::collections::btree_map::Range<'a, NotNan<f64>, T>);
764
765impl<'a, T> Range<'a, T> {
766 fn map_element((&time, state): (&'a NotNan<f64>, &'a T)) -> (f64, &'a T) {
767 (time.into_inner(), state)
768 }
769}
770
771impl<'a, T> Iterator for Range<'a, T> {
772 type Item = (f64, &'a T);
773
774 fn next(&mut self) -> Option<Self::Item> {
775 self.0.next().map(Self::map_element)
776 }
777
778 fn size_hint(&self) -> (usize, Option<usize>) {
779 self.0.size_hint()
780 }
781}
782
783impl<'a, T> DoubleEndedIterator for Range<'a, T> {
784 fn next_back(&mut self) -> Option<Self::Item> {
785 self.0.next_back().map(Self::map_element)
786 }
787}
788
789impl<'a, T> Range<'a, T> {
790 /// Create an iterator over the times of the range, ignoring the states.
791 ///
792 /// # Example
793 ///
794 /// ```rust
795 /// use banquo::Trace;
796 ///
797 /// let mut trace = Trace::from([
798 /// (0.0, 'a'),
799 /// (1.0, 'b'),
800 /// (2.0, 'c'),
801 /// (3.0, 'd'),
802 /// (4.0, 'e'),
803 /// (5.0, 'f'),
804 /// ]);
805 ///
806 /// let iter: Vec<f64> = trace
807 /// .range(0.5..=4.5)
808 /// .times()
809 /// .collect();
810 /// ```
811 pub fn times(self) -> Times<Self> {
812 Times(self)
813 }
814
815 /// Create an iterator over states of the range, ignoring the times.
816 ///
817 /// # Example
818 ///
819 /// ```rust
820 /// use banquo::Trace;
821 ///
822 /// let mut trace = Trace::from([
823 /// (0.0, 'a'),
824 /// (1.0, 'b'),
825 /// (2.0, 'c'),
826 /// (3.0, 'd'),
827 /// (4.0, 'e'),
828 /// (5.0, 'f'),
829 /// ]);
830 ///
831 /// let iter: Vec<&char> = trace
832 /// .range(1.0..=4.0)
833 /// .states()
834 /// .collect();
835 /// ```
836 pub fn states(self) -> States<Self> {
837 States(self)
838 }
839
840 /// Create an iterator that applies the function `f` to each state of the sub-interval, while keeping
841 /// the times the same.
842 ///
843 /// # Examples
844 ///
845 /// ```rust
846 /// use banquo::Trace;
847 ///
848 /// let trace = Trace::from([
849 /// (0.0, "s1".to_string()),
850 /// (1.0, "s2".to_string()),
851 /// (2.0, "s3".to_string()),
852 /// (3.0, "s4".to_string()),
853 /// (4.0, "s5".to_string()),
854 /// (5.0, "s6".to_string()),
855 /// ]);
856 ///
857 /// let iter: Trace<usize> = trace
858 /// .iter()
859 /// .map_states(|state: &String| state.len())
860 /// .collect();
861 /// ```
862 pub fn map_states<F, U>(self, f: F) -> MapStates<Self, F>
863 where
864 F: FnMut(&'a T) -> U,
865 {
866 MapStates { f, iter: self }
867 }
868}
869
870/// Mutably borrowing iterator over the `(time, state)` pairs of a sub-interval of trace.
871///
872/// The values are yielded in chronological order (lower times -> higher times).
873///
874/// This value can be constructed by calling the `range_mut()` method on a `Trace` value;
875///
876/// ```rust
877/// use banquo::Trace;
878///
879/// let mut trace = Trace::from([
880/// (0.0, ()),
881/// (1.0, ()),
882/// (2.0, ()),
883/// (3.0, ()),
884/// (4.0, ()),
885/// (5.0, ()),
886/// ]);
887///
888/// let range = trace.range_mut(0.0..4.0);
889/// let range = trace.range_mut(1.0..=3.0);
890/// ```
891pub struct RangeMut<'a, T>(std::collections::btree_map::RangeMut<'a, NotNan<f64>, T>);
892
893impl<'a, T> RangeMut<'a, T> {
894 fn map_element((&time, state): (&'a NotNan<f64>, &'a mut T)) -> (f64, &'a mut T) {
895 (time.into_inner(), state)
896 }
897}
898
899impl<'a, T> Iterator for RangeMut<'a, T> {
900 type Item = (f64, &'a mut T);
901
902 fn next(&mut self) -> Option<Self::Item> {
903 self.0.next().map(Self::map_element)
904 }
905
906 fn size_hint(&self) -> (usize, Option<usize>) {
907 self.0.size_hint()
908 }
909}
910
911impl<'a, T> DoubleEndedIterator for RangeMut<'a, T> {
912 fn next_back(&mut self) -> Option<Self::Item> {
913 self.0.next_back().map(Self::map_element)
914 }
915}
916
917impl<'a, T> RangeMut<'a, T> {
918 /// Create an iterator over states of the range, ignoring the times.
919 ///
920 /// # Example
921 ///
922 /// ```rust
923 /// use banquo::Trace;
924 ///
925 /// let mut trace = Trace::from([
926 /// (0.0, 'a'),
927 /// (1.0, 'b'),
928 /// (2.0, 'c'),
929 /// (3.0, 'd'),
930 /// (4.0, 'e'),
931 /// (5.0, 'f'),
932 /// ]);
933 ///
934 /// let iter: Vec<&char> = trace
935 /// .range(1.0..=4.0)
936 /// .states()
937 /// .collect();
938 /// ```
939 pub fn states(self) -> States<Self> {
940 States(self)
941 }
942}
943
944fn convert_bound(bound: Bound<&f64>) -> Bound<NotNan<f64>> {
945 match bound {
946 Bound::Unbounded => Bound::Unbounded,
947 Bound::Included(&val) => Bound::Included(NotNan::new(val).unwrap()),
948 Bound::Excluded(&val) => Bound::Excluded(NotNan::new(val).unwrap()),
949 }
950}
951
952impl<T> Trace<T> {
953 /// Create an iterator yielding `(time, &state)` values from the trace in chronological order.
954 pub fn iter(&self) -> Iter<T> {
955 self.into_iter()
956 }
957
958 /// Create an iterator yielding `(time, &mut state)` values from the trace in chronological order.
959 pub fn iter_mut(&mut self) -> IterMut<T> {
960 IterMut(self.0.iter_mut())
961 }
962
963 /// Create an iterator yielding time values from the trace in chronological order.
964 pub fn times(&self) -> Times<Iter<T>> {
965 Times(self.iter())
966 }
967
968 /// Create an iterator yielding `&state` values from the trace in chronological order.
969 pub fn states(&self) -> States<Iter<T>> {
970 States(self.iter())
971 }
972
973 /// Create an iterator yielding `&mut state values` from the trace in chronological order.
974 pub fn states_mut(&mut self) -> States<IterMut<T>> {
975 States(self.iter_mut())
976 }
977
978 /// Create an iterator over the `(time, &state)` values from a sub-interval of the trace in
979 /// chronological order
980 ///
981 /// # Safety
982 ///
983 /// This function panics if either range bound is NaN.
984 pub fn range<R>(&self, bounds: R) -> Range<T>
985 where
986 R: RangeBounds<f64>,
987 {
988 let start = convert_bound(bounds.start_bound());
989 let end = convert_bound(bounds.end_bound());
990
991 Range(self.0.range((start, end)))
992 }
993
994 /// Create an iterator over the `(time, &mut state)` values from a sub-interval of the trace in
995 /// chronological order.
996 ///
997 /// # Safety
998 ///
999 /// This function panics if either range bound is NaN.
1000 pub fn range_mut<R>(&mut self, bounds: R) -> RangeMut<T>
1001 where
1002 R: RangeBounds<f64>,
1003 {
1004 let start = convert_bound(bounds.start_bound());
1005 let end = convert_bound(bounds.end_bound());
1006
1007 RangeMut(self.0.range_mut((start, end)))
1008 }
1009}
1010
1011impl<T> IntoIterator for Trace<T> {
1012 type Item = (f64, T);
1013 type IntoIter = IntoIter<T>;
1014
1015 fn into_iter(self) -> Self::IntoIter {
1016 IntoIter(self.0.into_iter())
1017 }
1018}
1019
1020impl<'a, T> IntoIterator for &'a Trace<T> {
1021 type Item = (f64, &'a T);
1022 type IntoIter = Iter<'a, T>;
1023
1024 fn into_iter(self) -> Self::IntoIter {
1025 Iter(self.0.iter())
1026 }
1027}
1028
1029impl<'a, T> IntoIterator for &'a mut Trace<T> {
1030 type Item = (f64, &'a mut T);
1031 type IntoIter = IterMut<'a, T>;
1032
1033 fn into_iter(self) -> Self::IntoIter {
1034 IterMut(self.0.iter_mut())
1035 }
1036}
1037
#[cfg(test)]
mod tests {
    use super::Trace;

    #[test]
    fn get_element() {
        let times = 0..10;
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let trace = Trace::from_iter(times.zip(values));

        assert_eq!(trace.at_time(3.0), Some(&4.0));
        assert_eq!(trace.at_time(10.0), None);
        assert_eq!(trace[9.0], 10.0);
    }

    #[test]
    fn insert_replaces_existing_state() {
        let mut trace: Trace<char> = Trace::new();
        assert!(trace.is_empty());

        assert_eq!(trace.insert(1.0, 'a'), None);
        assert_eq!(trace.insert(1.0, 'b'), Some('a'));
        assert_eq!(trace.len(), 1);
        assert_eq!(trace.at_time(1.0), Some(&'b'));
    }

    #[test]
    #[should_panic]
    fn nan_time_panics() {
        let mut trace: Trace<()> = Trace::new();
        trace.insert(f64::NAN, ());
    }

    #[test]
    #[should_panic]
    fn index_missing_time_panics() {
        let trace = Trace::from([(0.0, 1.0)]);
        let _state: f64 = trace[1.0];
    }

    #[test]
    fn times() {
        let trace = Trace::from_iter([
            (1.0, ()),
            (2.0, ()),
            (3.0, ()),
            (4.0, ()),
        ]);

        let mut times = trace.times();

        assert_eq!(times.next(), Some(1.0));
        assert_eq!(times.next(), Some(2.0));
        assert_eq!(times.next(), Some(3.0));
        assert_eq!(times.next(), Some(4.0));
        assert_eq!(times.next(), None);
    }

    #[test]
    fn states() {
        let trace = Trace::from_iter([
            (1.0, 1.0),
            (2.0, 2.0),
            (3.0, 3.0),
            (4.0, 4.0),
        ]);

        let mut states = trace.states();

        assert_eq!(states.next(), Some(&1.0));
        assert_eq!(states.next(), Some(&2.0));
        assert_eq!(states.next(), Some(&3.0));
        assert_eq!(states.next(), Some(&4.0));
        assert_eq!(states.next(), None);
    }

    #[test]
    fn select_range() {
        let times = 0..10;
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let trace = Trace::from_iter(times.zip(values));

        let subtrace_times: Vec<f64> = trace.range(0f64..4.0).times().collect();
        let subtrace_states: Vec<f64> = trace.range(0f64..4.0).states().copied().collect();

        assert_eq!(subtrace_times, vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(subtrace_states, vec![1.0, 2.0, 3.0, 4.0]);

        let inclusive_times: Vec<f64> = trace.range(7f64..=9.0).times().collect();
        assert_eq!(inclusive_times, vec![7.0, 8.0, 9.0]);
    }

    #[test]
    fn range_mut_updates_only_selected_states() {
        let mut trace = Trace::from([(0.0, 1), (1.0, 2), (2.0, 3)]);

        for (_, state) in trace.range_mut(1f64..) {
            *state *= 10;
        }

        assert_eq!(trace.states().copied().collect::<Vec<_>>(), vec![1, 20, 30]);
    }

    #[test]
    fn map_states_keeps_times() {
        let trace = Trace::from([(0.0, 1), (2.0, 2)]);
        let mapped: Trace<i32> = trace
            .into_iter()
            .map_states(|state: i32| state * 3)
            .collect();

        assert_eq!(mapped, Trace::from([(0.0, 3), (2.0, 6)]));
    }
}