stepflow_data/
statedata.rs

1use std::collections::{HashMap, HashSet};
2use super::{InvalidValue, InvalidVars};
3use super::value::{Value, ValidVal};
4use super::var::{Var, VarId};
5
6/// Store a set of [`Var`]s and corresponding [`Value`]s.
7///
8/// Internally the [`Value`] is wrapped in a [`ValidVal`](crate::value::ValidVal) to keep knowledge that this value has been validated for a specific [`Var`] already.
9#[derive(Debug, Clone, PartialEq)]
10#[cfg_attr(feature = "serde-support", derive(serde::Serialize))]
11pub struct StateData {
12  data: HashMap<VarId, ValidVal>,
13}
14
15impl StateData {
16  /// Create a new StateData instance
17  pub fn new() -> Self {
18    Self {
19      data: HashMap::new()
20    }
21  }
22
23  /// Add a new value
24  pub fn insert(&mut self, var: &Box<dyn Var + Send + Sync>, state_val: Box<dyn Value>)  -> Result<(), InvalidValue> {
25    let state_val_valid = ValidVal::try_new(state_val, var)?;
26    self.data.insert(var.id().clone(), state_val_valid);
27    Ok(())
28  }
29
30  /// Get the value based on its [`VarId`]. Returns a [`ValidVal`] to keep knowledge that the value has already been validated for the specific [`Var`].
31  pub fn get(&self, var_id: &VarId) -> Option<&ValidVal> {
32    self.data.get(var_id)
33  }
34
35  pub fn contains(&self, var_id: &VarId) -> bool {
36    self.data.contains_key(var_id)
37  }
38
39  /// Confirm that the StateData *only* contains the set of [`VarId`]s listed
40  pub fn contains_only(&self, contains_only: &HashSet<&VarId>) -> bool {
41    let found_excluded = self.data.iter().find(|(var_id, _)| !contains_only.contains(var_id));
42    found_excluded == None
43  }
44
45  /// Merge the data from another `StateData` into this one.
46  pub fn merge_from(&mut self, src: StateData) {
47    for (k, v) in src.data {
48      self.data.insert(k, v);
49    }
50  }
51
52  // Get an iterator over the values
53  pub fn iter_val(&self) -> impl Iterator<Item = (&VarId, &Box<dyn Value>)>  {
54    self.data.iter().map(|(var_id, valid_val)| {
55      (var_id, valid_val.get_val())
56    })
57  }
58
59
60  /// Create a `StateData` instance from an iterator of values
61  // NOTE: can't implement TryFrom for this because of blanket implementation in core
62  pub fn from_vals<'a, T>(iter: T)  -> Result<Self, InvalidVars> 
63    where T : std::iter::IntoIterator<Item = (&'a Box<dyn Var + Send + Sync + 'static>, Box<dyn Value>)>
64  {
65    let mut all_valid = true;
66    let validations = iter.into_iter()
67      .map(|(var, val)| {
68        match ValidVal::try_new(val, var) {
69          Ok(validated) => Ok((var, validated)),
70          Err(e) => {
71            all_valid = false;
72            Err((var, e))
73          }
74        }
75      })
76      .collect::<Vec<Result<_,_>>>();
77
78    if !all_valid {
79      let invalid: HashMap<VarId, InvalidValue> = validations.into_iter().filter_map(|validation| {
80        if let Err(e) = validation {
81          Some((e.0.id().clone(), e.1))
82        } else {
83          None
84        }
85      })
86      .collect();
87      return Err(InvalidVars::new(invalid));
88    }
89
90    let data: HashMap<VarId, ValidVal> = validations
91      .into_iter()
92      .map(|validation| {
93        let valid = validation.unwrap();
94        (valid.0.id().clone(), valid.1)
95      })
96      .collect();
97    Ok(StateData { data })
98  }
99}
100
101
#[cfg(test)]
mod tests {
  use std::collections::{HashMap, HashSet};
  use crate::{var::{Var, VarId, StringVar}, value::{Value, TrueValue}, InvalidValue, test_var_val};
  use stepflow_test_util::test_id;
  use super::{StateData, InvalidVars};

  /// Merging copies in every entry from the source while keeping the destination's own entries.
  #[test]
  fn merge() {
    let mut data1 = StateData::new();
    let mut data2 = StateData::new();
    let mut data_merged = StateData::new();

    let var1 = test_var_val();
    let var2 = test_var_val();
    let var3 = test_var_val();
    let var4 = test_var_val();

    data1.insert(&var1.0, var1.1).unwrap();
    data2.insert(&var2.0, var2.1).unwrap();
    data2.insert(&var3.0, var3.1).unwrap();
    data_merged.insert(&var4.0, var4.1).unwrap();

    assert!(!data_merged.contains(var1.0.id()));
    data_merged.merge_from(data1);
    assert!(data_merged.contains(var1.0.id()));

    assert!(!data_merged.contains(var2.0.id()));
    assert!(!data_merged.contains(var3.0.id()));
    data_merged.merge_from(data2);
    assert!(data_merged.contains(var2.0.id()));
    assert!(data_merged.contains(var3.0.id()));

    // the destination's pre-existing entry must survive both merges
    assert!(data_merged.contains(var4.0.id()));
  }

  /// `from_vals` succeeds and stores every var when all values validate.
  #[test]
  fn from_vals_ok() {
    let var1 = test_var_val();
    let var2 = test_var_val();
    let vals = vec![(&var1.0, var1.1.clone()), (&var2.0, var2.1.clone())];

    let data = StateData::from_vals(vals).unwrap();
    assert!(data.contains(var1.0.id()));
    assert!(data.contains(var2.0.id()));

    let mut expected_ids = HashSet::new();
    expected_ids.insert(var1.0.id());
    expected_ids.insert(var2.0.id());
    assert!(data.contains_only(&expected_ids));
  }

  /// `from_vals` reports every invalid var, not just the first one encountered.
  #[test]
  fn from_vals_err() {
    let var1 = test_var_val();
    let var2 = test_var_val();
    let badvar1: (Box<dyn Var + Send + Sync>, Box<dyn Value>) = (
      Box::new(StringVar::new(test_id!(VarId))),
      Box::new(TrueValue::new()));
    let badvar2: (Box<dyn Var + Send + Sync>, Box<dyn Value>) = (
      Box::new(StringVar::new(test_id!(VarId))),
      Box::new(TrueValue::new()));
    let badvar1_id = badvar1.0.id().clone();
    let badvar2_id = badvar2.0.id().clone();

    let vars = vec![var1, badvar1, var2, badvar2];
    let vars = vars
      .iter()
      .map(|(var, val)| {
        (var, val.clone())
      });

    let mut bad_ids = HashMap::new();
    bad_ids.insert(badvar1_id.clone(), InvalidValue::WrongType);
    bad_ids.insert(badvar2_id.clone(), InvalidValue::WrongType);
    let expected_err = InvalidVars(bad_ids);

    assert_eq!(StateData::from_vals(vars), Err(expected_err));
  }

  #[test]
  fn contains_only() {
    let mut data = StateData::new();

    // an empty StateData trivially contains only the (empty) listed set
    assert!(data.contains_only(&HashSet::new()));

    let var1 = test_var_val();
    let var2 = test_var_val();
    let var3 = test_var_val();

    // add var1 + var2
    data.insert(&var1.0, var1.1).unwrap();
    data.insert(&var2.0, var2.1).unwrap();

    let mut contains_only = HashSet::new();
    contains_only.insert(var1.0.id());
    contains_only.insert(var2.0.id());

    // check only contains var1 + var2
    assert!(data.contains_only(&contains_only));

    // add var3
    data.insert(&var3.0, var3.1).unwrap();

    // var3 is not in the listed set, so the check must now fail
    assert!(!data.contains_only(&contains_only));
  }

  #[test]
  fn iter() {
    let mut data = StateData::new();
    let var1 = test_var_val();
    let var2 = test_var_val();
    data.insert(&var1.0, var1.1.clone()).unwrap();
    data.insert(&var2.0, var2.1.clone()).unwrap();

    let hashmap = data.iter_val().collect::<HashMap<_,_>>();
    assert_eq!(hashmap.len(), 2);
    assert_eq!(hashmap.get(var1.0.id()), Some(&&var1.1));
    assert_eq!(hashmap.get(var2.0.id()), Some(&&var2.1));
  }
}