// stepflow_data/statedata.rs
use std::collections::{HashMap, HashSet};
2use super::{InvalidValue, InvalidVars};
3use super::value::{Value, ValidVal};
4use super::var::{Var, VarId};
5
6#[derive(Debug, Clone, PartialEq)]
10#[cfg_attr(feature = "serde-support", derive(serde::Serialize))]
11pub struct StateData {
12 data: HashMap<VarId, ValidVal>,
13}
14
15impl StateData {
16 pub fn new() -> Self {
18 Self {
19 data: HashMap::new()
20 }
21 }
22
23 pub fn insert(&mut self, var: &Box<dyn Var + Send + Sync>, state_val: Box<dyn Value>) -> Result<(), InvalidValue> {
25 let state_val_valid = ValidVal::try_new(state_val, var)?;
26 self.data.insert(var.id().clone(), state_val_valid);
27 Ok(())
28 }
29
30 pub fn get(&self, var_id: &VarId) -> Option<&ValidVal> {
32 self.data.get(var_id)
33 }
34
35 pub fn contains(&self, var_id: &VarId) -> bool {
36 self.data.contains_key(var_id)
37 }
38
39 pub fn contains_only(&self, contains_only: &HashSet<&VarId>) -> bool {
41 let found_excluded = self.data.iter().find(|(var_id, _)| !contains_only.contains(var_id));
42 found_excluded == None
43 }
44
45 pub fn merge_from(&mut self, src: StateData) {
47 for (k, v) in src.data {
48 self.data.insert(k, v);
49 }
50 }
51
52 pub fn iter_val(&self) -> impl Iterator<Item = (&VarId, &Box<dyn Value>)> {
54 self.data.iter().map(|(var_id, valid_val)| {
55 (var_id, valid_val.get_val())
56 })
57 }
58
59
60 pub fn from_vals<'a, T>(iter: T) -> Result<Self, InvalidVars>
63 where T : std::iter::IntoIterator<Item = (&'a Box<dyn Var + Send + Sync + 'static>, Box<dyn Value>)>
64 {
65 let mut all_valid = true;
66 let validations = iter.into_iter()
67 .map(|(var, val)| {
68 match ValidVal::try_new(val, var) {
69 Ok(validated) => Ok((var, validated)),
70 Err(e) => {
71 all_valid = false;
72 Err((var, e))
73 }
74 }
75 })
76 .collect::<Vec<Result<_,_>>>();
77
78 if !all_valid {
79 let invalid: HashMap<VarId, InvalidValue> = validations.into_iter().filter_map(|validation| {
80 if let Err(e) = validation {
81 Some((e.0.id().clone(), e.1))
82 } else {
83 None
84 }
85 })
86 .collect();
87 return Err(InvalidVars::new(invalid));
88 }
89
90 let data: HashMap<VarId, ValidVal> = validations
91 .into_iter()
92 .map(|validation| {
93 let valid = validation.unwrap();
94 (valid.0.id().clone(), valid.1)
95 })
96 .collect();
97 Ok(StateData { data })
98 }
99}
100
101
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};
    use crate::{var::{Var, VarId, StringVar}, value::{Value, TrueValue}, InvalidValue, test_var_val};
    use stepflow_test_util::test_id;
    use super::{StateData, InvalidVars};

    /// Merging moves every entry of the source into the destination and keeps
    /// what the destination already held.
    #[test]
    fn merge() {
        let mut data1 = StateData::new();
        let mut data2 = StateData::new();
        let mut data_merged = StateData::new();

        let var1 = test_var_val();
        let var2 = test_var_val();
        let var3 = test_var_val();
        let var4 = test_var_val();

        data1.insert(&var1.0, var1.1).unwrap();
        data2.insert(&var2.0, var2.1).unwrap();
        data2.insert(&var3.0, var3.1).unwrap();
        data_merged.insert(&var4.0, var4.1).unwrap();

        assert!(!data_merged.contains(var1.0.id()));
        data_merged.merge_from(data1);
        assert!(data_merged.contains(var1.0.id()));

        assert!(!data_merged.contains(var2.0.id()));
        assert!(!data_merged.contains(var3.0.id()));
        data_merged.merge_from(data2);
        assert!(data_merged.contains(var2.0.id()));
        assert!(data_merged.contains(var3.0.id()));

        // The entry present before any merge must survive both merges.
        assert!(data_merged.contains(var4.0.id()));
    }

    /// When every pair validates, `from_vals` stores a value for each variable.
    #[test]
    fn from_vals_ok() {
        let vars = vec![test_var_val(), test_var_val()];

        let data = StateData::from_vals(vars.iter().map(|(var, val)| (var, val.clone()))).unwrap();

        for (var, _) in &vars {
            assert!(data.contains(var.id()));
        }
    }

    /// Every rejected variable is reported, not just the first one encountered.
    #[test]
    fn from_vals_err() {
        let var1 = test_var_val();
        let var2 = test_var_val();
        // Pairs expected to be rejected with `InvalidValue::WrongType`.
        let badvar1: (Box<dyn Var + Send + Sync>, Box<dyn Value>) = (
            Box::new(StringVar::new(test_id!(VarId))),
            Box::new(TrueValue::new()),
        );
        let badvar2: (Box<dyn Var + Send + Sync>, Box<dyn Value>) = (
            Box::new(StringVar::new(test_id!(VarId))),
            Box::new(TrueValue::new()),
        );
        // Capture the ids before the pairs are moved into the vector.
        let badvar1_id = badvar1.0.id().clone();
        let badvar2_id = badvar2.0.id().clone();

        let vars = vec![var1, badvar1, var2, badvar2];
        let vars = vars.iter().map(|(var, val)| (var, val.clone()));

        let mut bad_ids = HashMap::new();
        bad_ids.insert(badvar1_id, InvalidValue::WrongType);
        bad_ids.insert(badvar2_id, InvalidValue::WrongType);
        let expected_err = InvalidVars(bad_ids);

        assert_eq!(StateData::from_vals(vars), Err(expected_err));
    }

    /// `contains_only` holds while all stored ids are in the set and fails once
    /// an id outside the set is stored.
    #[test]
    fn contains_only() {
        let mut data = StateData::new();

        let var1 = test_var_val();
        let var2 = test_var_val();
        let var3 = test_var_val();

        data.insert(&var1.0, var1.1).unwrap();
        data.insert(&var2.0, var2.1).unwrap();

        let mut contains_only = HashSet::new();
        contains_only.insert(var1.0.id());
        contains_only.insert(var2.0.id());

        assert!(data.contains_only(&contains_only));

        data.insert(&var3.0, var3.1).unwrap();

        assert!(!data.contains_only(&contains_only));
    }

    /// `iter_val` yields each stored id together with its original value.
    #[test]
    fn iter() {
        let mut data = StateData::new();
        let var1 = test_var_val();
        let var2 = test_var_val();
        data.insert(&var1.0, var1.1.clone()).unwrap();
        data.insert(&var2.0, var2.1.clone()).unwrap();

        let hashmap = data.iter_val().collect::<HashMap<_, _>>();
        assert_eq!(hashmap.len(), 2);
        assert_eq!(hashmap.get(var1.0.id()), Some(&&var1.1));
        assert_eq!(hashmap.get(var2.0.id()), Some(&&var2.1));
    }
}