1use std::collections::HashMap;
5
6#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
8pub struct ParamState {
9 pub height: f32,
10 pub weight: f32,
11 pub muscle: f32,
12 pub age: f32,
13 pub extra: HashMap<String, f32>,
15}
16
17impl Default for ParamState {
18 fn default() -> Self {
19 ParamState {
20 height: 0.5,
21 weight: 0.5,
22 muscle: 0.5,
23 age: 0.5,
24 extra: HashMap::new(),
25 }
26 }
27}
28
29impl ParamState {
30 pub fn new(height: f32, weight: f32, muscle: f32, age: f32) -> Self {
31 ParamState {
32 height,
33 weight,
34 muscle,
35 age,
36 extra: HashMap::new(),
37 }
38 }
39
40 pub fn get(&self, key: &str) -> Option<f32> {
42 match key {
43 "height" => Some(self.height),
44 "weight" => Some(self.weight),
45 "muscle" => Some(self.muscle),
46 "age" => Some(self.age),
47 other => self.extra.get(other).copied(),
48 }
49 }
50}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing stored floats against literal expectations.
    const EPS: f32 = 1e-6;

    #[test]
    fn default_params_are_midpoint() {
        let p = ParamState::default();
        // Check every fixed field, not just a subset, so a regression in any one is caught.
        for (name, value) in [
            ("height", p.height),
            ("weight", p.weight),
            ("muscle", p.muscle),
            ("age", p.age),
        ] {
            assert!(
                (value - 0.5).abs() < EPS,
                "{} = {}, expected 0.5",
                name,
                value
            );
        }
        assert!(p.extra.is_empty());
    }

    #[test]
    fn new_sets_fixed_params_and_no_extras() {
        let p = ParamState::new(0.1, 0.2, 0.3, 0.4);
        assert_eq!(p.height, 0.1);
        assert_eq!(p.weight, 0.2);
        assert_eq!(p.muscle, 0.3);
        assert_eq!(p.age, 0.4);
        assert!(p.extra.is_empty());
    }

    #[test]
    fn get_by_name() {
        let mut p = ParamState::default();
        p.extra.insert("bmi".to_string(), 0.3);
        assert_eq!(p.get("height"), Some(0.5));
        assert_eq!(p.get("weight"), Some(0.5));
        assert_eq!(p.get("muscle"), Some(0.5));
        assert_eq!(p.get("age"), Some(0.5));
        assert_eq!(p.get("bmi"), Some(0.3));
        assert_eq!(p.get("missing"), None);
    }

    #[test]
    fn fixed_fields_shadow_same_named_extras() {
        let mut p = ParamState::new(0.1, 0.2, 0.3, 0.4);
        p.extra.insert("height".to_string(), 0.9);
        // `get` resolves the fixed name to the field, ignoring the `extra` entry.
        assert_eq!(p.get("height"), Some(0.1));
    }
}