rust_rocket/
track.rs

1//! This module contains `Key` and `Track` types.
2
3use crate::interpolation::*;
4use serde::{Deserialize, Serialize};
5
6#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
7/// The `Key` Type.
8pub struct Key {
9    row: u32,
10    value: f32,
11    interpolation: Interpolation,
12}
13
14impl Key {
15    /// Construct a new `Key`.
16    pub fn new(row: u32, value: f32, interp: Interpolation) -> Key {
17        Key {
18            row,
19            value,
20            interpolation: interp,
21        }
22    }
23}
24
25#[derive(Serialize, Deserialize, Debug, Clone)]
26/// The `Track` Type. This is a collection of `Key`s with a name.
27pub struct Track {
28    name: String,
29    keys: Vec<Key>,
30}
31
32impl Track {
33    /// Construct a new Track with a name.
34    pub fn new<S: Into<String>>(name: S) -> Track {
35        Track {
36            name: name.into(),
37            keys: Vec::new(),
38        }
39    }
40
41    /// Get the name of the track.
42    pub fn get_name(&self) -> &str {
43        self.name.as_str()
44    }
45
46    fn get_exact_position(&self, row: u32) -> Option<usize> {
47        self.keys.iter().position(|k| k.row == row)
48    }
49
50    fn get_insert_position(&self, row: u32) -> Option<usize> {
51        self.keys.iter().position(|k| k.row >= row)
52    }
53
54    fn get_lower_bound_position(&self, row: u32) -> usize {
55        self.keys
56            .iter()
57            .position(|k| k.row > row)
58            .unwrap_or(self.keys.len())
59            - 1
60    }
61
62    /// Insert or update a key on a track.
63    pub fn set_key(&mut self, key: Key) {
64        if let Some(pos) = self.get_exact_position(key.row) {
65            self.keys[pos] = key;
66        } else if let Some(pos) = self.get_insert_position(key.row) {
67            self.keys.insert(pos, key);
68        } else {
69            self.keys.push(key);
70        }
71    }
72
73    /// Delete a key from a track.
74    ///
75    /// If a key does not exist this will do nothing.
76    pub fn delete_key(&mut self, row: u32) {
77        if let Some(pos) = self.get_exact_position(row) {
78            self.keys.remove(pos);
79        }
80    }
81
82    /// Get a value based on a row.
83    ///
84    /// The row can be between two integers.
85    /// This will perform the required interpolation.
86    pub fn get_value(&self, row: f32) -> f32 {
87        if self.keys.is_empty() {
88            return 0.0;
89        }
90
91        let lower_row = row.floor() as u32;
92
93        if lower_row <= self.keys[0].row {
94            return self.keys[0].value;
95        }
96
97        if lower_row >= self.keys[self.keys.len() - 1].row {
98            return self.keys[self.keys.len() - 1].value;
99        }
100
101        let pos = self.get_lower_bound_position(lower_row);
102
103        let lower = &self.keys[pos];
104        let higher = &self.keys[pos + 1];
105
106        let t = (row - (lower.row as f32)) / ((higher.row as f32) - (lower.row as f32));
107        let it = lower.interpolation.interpolate(t);
108
109        (lower.value as f32) + ((higher.value as f32) - (lower.value as f32)) * it
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_three_keys() {
        let mut track = Track::new("test");
        track.set_key(Key::new(0, 1.0, Interpolation::Step));
        track.set_key(Key::new(5, 0.0, Interpolation::Step));
        track.set_key(Key::new(10, 1.0, Interpolation::Step));

        assert_eq!(track.get_value(-1.), 1.0);
        assert_eq!(track.get_value(0.), 1.0);
        assert_eq!(track.get_value(1.), 1.0);

        assert_eq!(track.get_value(4.), 1.0);
        assert_eq!(track.get_value(5.), 0.0);
        assert_eq!(track.get_value(6.), 0.0);

        assert_eq!(track.get_value(9.), 0.0);
        assert_eq!(track.get_value(10.), 1.0);
        assert_eq!(track.get_value(11.), 1.0);
    }

    /// A track without keys keeps its name and evaluates to 0.0 everywhere.
    #[test]
    fn test_empty_track() {
        let track = Track::new("empty");
        assert_eq!(track.get_name(), "empty");
        assert_eq!(track.get_value(-1.), 0.0);
        assert_eq!(track.get_value(0.), 0.0);
        assert_eq!(track.get_value(3.5), 0.0);
    }

    /// Keys inserted out of order end up sorted by row, and setting a key on
    /// an occupied row replaces the existing key instead of duplicating it.
    #[test]
    fn test_set_key_out_of_order_and_replace() {
        let mut track = Track::new("test");
        track.set_key(Key::new(10, 3.0, Interpolation::Step));
        track.set_key(Key::new(0, 1.0, Interpolation::Step));
        track.set_key(Key::new(5, 2.0, Interpolation::Step));
        track.set_key(Key::new(5, 0.5, Interpolation::Step));

        let rows: Vec<u32> = track.keys.iter().map(|k| k.row).collect();
        assert_eq!(rows, vec![0, 5, 10]);

        assert_eq!(track.get_value(-1.), 1.0);
        assert_eq!(track.get_value(4.), 1.0);
        assert_eq!(track.get_value(5.), 0.5);
        assert_eq!(track.get_value(6.), 0.5);
        assert_eq!(track.get_value(10.), 3.0);
        assert_eq!(track.get_value(11.), 3.0);
    }

    /// Deleting a missing row is a no-op; deleting an existing row removes
    /// exactly that key.
    #[test]
    fn test_delete_key() {
        let mut track = Track::new("test");
        track.set_key(Key::new(0, 1.0, Interpolation::Step));
        track.set_key(Key::new(5, 0.0, Interpolation::Step));
        track.set_key(Key::new(10, 3.0, Interpolation::Step));

        track.delete_key(7);
        assert_eq!(track.keys.len(), 3);

        track.delete_key(5);
        let rows: Vec<u32> = track.keys.iter().map(|k| k.row).collect();
        assert_eq!(rows, vec![0, 10]);

        // With the middle key gone, Step holds key 0's value until row 10.
        assert_eq!(track.get_value(2.), 1.0);
        assert_eq!(track.get_value(8.), 1.0);
        assert_eq!(track.get_value(10.), 3.0);
    }
}