1use std::collections::HashMap;
4
5use crate::net::State;
6
7pub fn copy(state: &State) -> State {
9 state.clone()
10}
11
12pub fn apply(base: &State, updates: &State) -> State {
14 let mut out = base.clone();
15 for (k, v) in updates {
16 out.insert(k.clone(), *v);
17 }
18 out
19}
20
21pub fn merge(states: &[&State]) -> State {
23 let size: usize = states.iter().map(|s| s.len()).sum();
24 let mut out = HashMap::with_capacity(size);
25 for s in states {
26 for (k, v) in *s {
27 out.insert(k.clone(), *v);
28 }
29 }
30 out
31}
32
33pub fn equal(a: &State, b: &State) -> bool {
35 if a.len() != b.len() {
36 return false;
37 }
38 for (k, v) in a {
39 match b.get(k) {
40 Some(bv) if *v == *bv => {}
41 _ => return false,
42 }
43 }
44 true
45}
46
47pub fn equal_tol(a: &State, b: &State, tol: f64) -> bool {
49 if a.len() != b.len() {
50 return false;
51 }
52 for (k, v) in a {
53 match b.get(k) {
54 Some(bv) if (v - bv).abs() <= tol => {}
55 _ => return false,
56 }
57 }
58 true
59}
60
61pub fn get(state: &State, key: &str) -> f64 {
63 state.get(key).copied().unwrap_or(0.0)
64}
65
66pub fn sum(state: &State) -> f64 {
68 state.values().sum()
69}
70
71pub fn sum_keys(state: &State, keys: &[&str]) -> f64 {
73 keys.iter().map(|k| state.get(*k).copied().unwrap_or(0.0)).sum()
74}
75
76pub fn scale(state: &State, factor: f64) -> State {
78 state.iter().map(|(k, v)| (k.clone(), v * factor)).collect()
79}
80
81pub fn filter(state: &State, predicate: impl Fn(&str) -> bool) -> State {
83 state
84 .iter()
85 .filter(|(k, _)| predicate(k))
86 .map(|(k, v)| (k.clone(), *v))
87 .collect()
88}
89
90pub fn keys(state: &State) -> Vec<String> {
92 state.keys().cloned().collect()
93}
94
95pub fn non_zero(state: &State) -> Vec<String> {
97 state
98 .iter()
99 .filter(|(_, v)| **v != 0.0)
100 .map(|(k, _)| k.clone())
101 .collect()
102}
103
104pub fn diff(a: &State, b: &State) -> State {
107 let mut d = State::new();
108
109 for (k, bv) in b {
111 match a.get(k) {
112 Some(av) if *av == *bv => {}
113 _ => {
114 d.insert(k.clone(), *bv);
115 }
116 }
117 }
118
119 for k in a.keys() {
121 if !b.contains_key(k) {
122 d.insert(k.clone(), 0.0);
123 }
124 }
125
126 d
127}
128
129pub fn max(state: &State) -> Option<(String, f64)> {
131 state
132 .iter()
133 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
134 .map(|(k, v)| (k.clone(), *v))
135}
136
137pub fn min(state: &State) -> Option<(String, f64)> {
139 state
140 .iter()
141 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
142 .map(|(k, v)| (k.clone(), *v))
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a state from `(key, value)` pairs.
    fn make_state(pairs: &[(&str, f64)]) -> State {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    /// Sorts `names` so results that follow map iteration order can be compared.
    fn sorted(mut names: Vec<String>) -> Vec<String> {
        names.sort();
        names
    }

    #[test]
    fn test_copy() {
        let state = make_state(&[("A", 10.0), ("B", 5.0)]);
        let copied = copy(&state);
        assert_eq!(state, copied);
    }

    #[test]
    fn test_apply() {
        let base = make_state(&[("A", 10.0), ("B", 5.0)]);
        let updates = make_state(&[("A", 0.0), ("C", 1.0)]);
        let result = apply(&base, &updates);
        assert_eq!(result["A"], 0.0);
        assert_eq!(result["B"], 5.0);
        assert_eq!(result["C"], 1.0);
    }

    #[test]
    fn test_merge() {
        let s1 = make_state(&[("A", 1.0)]);
        let s2 = make_state(&[("B", 2.0)]);
        let s3 = make_state(&[("A", 3.0)]);
        let result = merge(&[&s1, &s2, &s3]);
        assert_eq!(result["A"], 3.0);
        assert_eq!(result["B"], 2.0);
    }

    #[test]
    fn test_merge_empty() {
        assert!(merge(&[]).is_empty());
    }

    #[test]
    fn test_equal() {
        let a = make_state(&[("A", 1.0), ("B", 2.0)]);
        let b = make_state(&[("A", 1.0), ("B", 2.0)]);
        let c = make_state(&[("A", 1.0), ("B", 3.0)]);
        assert!(equal(&a, &b));
        assert!(!equal(&a, &c));
    }

    #[test]
    fn test_equal_requires_same_keys() {
        let a = make_state(&[("A", 1.0), ("B", 2.0)]);
        // Same length, different key set.
        assert!(!equal(&a, &make_state(&[("A", 1.0), ("C", 2.0)])));
        // Different length.
        assert!(!equal(&a, &make_state(&[("A", 1.0)])));
    }

    #[test]
    fn test_equal_tol() {
        let a = make_state(&[("A", 1.0), ("B", 2.0)]);
        let b = make_state(&[("A", 1.0001), ("B", 2.0002)]);
        assert!(equal_tol(&a, &b, 0.001));
        assert!(!equal_tol(&a, &b, 0.0001));
    }

    #[test]
    fn test_equal_tol_length_mismatch() {
        let a = make_state(&[("A", 1.0)]);
        let b = make_state(&[("A", 1.0), ("B", 1.0)]);
        assert!(!equal_tol(&a, &b, 100.0));
    }

    #[test]
    fn test_get_missing_defaults_to_zero() {
        let state = make_state(&[("A", 10.0)]);
        assert_eq!(get(&state, "A"), 10.0);
        assert_eq!(get(&state, "missing"), 0.0);
    }

    #[test]
    fn test_sum() {
        let state = make_state(&[("A", 10.0), ("B", 5.0), ("C", 3.0)]);
        assert_eq!(sum(&state), 18.0);
    }

    #[test]
    fn test_sum_keys() {
        let state = make_state(&[("A", 10.0), ("B", 5.0), ("C", 3.0)]);
        assert_eq!(sum_keys(&state, &["A", "C"]), 13.0);
    }

    #[test]
    fn test_sum_keys_missing_and_repeated() {
        let state = make_state(&[("A", 10.0)]);
        assert_eq!(sum_keys(&state, &["A", "A", "missing"]), 20.0);
    }

    #[test]
    fn test_scale() {
        let state = make_state(&[("A", 10.0), ("B", 5.0)]);
        let scaled = scale(&state, 2.0);
        assert_eq!(scaled["A"], 20.0);
        assert_eq!(scaled["B"], 10.0);
    }

    #[test]
    fn test_filter() {
        let state = make_state(&[("_X0", 1.0), ("_X1", 1.0), ("pos", 5.0)]);
        let history = filter(&state, |k| k.starts_with('_'));
        assert_eq!(history.len(), 2);
        assert!(!history.contains_key("pos"));
    }

    #[test]
    fn test_keys() {
        let state = make_state(&[("B", 1.0), ("A", 2.0)]);
        assert_eq!(sorted(keys(&state)), vec!["A".to_string(), "B".to_string()]);
    }

    #[test]
    fn test_diff() {
        let a = make_state(&[("A", 1.0), ("B", 2.0), ("C", 3.0)]);
        let b = make_state(&[("A", 1.0), ("B", 5.0), ("D", 4.0)]);
        let d = diff(&a, &b);
        assert!(!d.contains_key("A"));
        assert_eq!(d["B"], 5.0);
        assert_eq!(d["C"], 0.0);
        assert_eq!(d["D"], 4.0);
    }

    #[test]
    fn test_diff_identical_is_empty() {
        let a = make_state(&[("A", 1.0), ("B", 2.0)]);
        assert!(diff(&a, &a).is_empty());
    }

    #[test]
    fn test_max_min() {
        let state = make_state(&[("A", 10.0), ("B", 5.0), ("C", 20.0)]);
        let (mk, mv) = max(&state).unwrap();
        assert_eq!(mk, "C");
        assert_eq!(mv, 20.0);

        let (nk, nv) = min(&state).unwrap();
        assert_eq!(nk, "B");
        assert_eq!(nv, 5.0);
    }

    #[test]
    fn test_max_min_empty() {
        let empty = State::new();
        assert_eq!(max(&empty), None);
        assert_eq!(min(&empty), None);
    }

    #[test]
    fn test_non_zero() {
        let state = make_state(&[("A", 0.0), ("B", 5.0), ("C", 0.0)]);
        let nz = non_zero(&state);
        assert_eq!(nz.len(), 1);
        assert!(nz.contains(&"B".to_string()));
    }
}