Skip to main content

pflow_core/
stateutil.rs

1//! State map utilities for manipulating Petri net state maps.
2
3use std::collections::HashMap;
4
5use crate::net::State;
6
7/// Creates a deep copy of a state map.
8pub fn copy(state: &State) -> State {
9    state.clone()
10}
11
12/// Creates a new state by copying base and applying updates.
13pub fn apply(base: &State, updates: &State) -> State {
14    let mut out = base.clone();
15    for (k, v) in updates {
16        out.insert(k.clone(), *v);
17    }
18    out
19}
20
21/// Combines multiple state maps, with later maps taking precedence.
22pub fn merge(states: &[&State]) -> State {
23    let size: usize = states.iter().map(|s| s.len()).sum();
24    let mut out = HashMap::with_capacity(size);
25    for s in states {
26        for (k, v) in *s {
27            out.insert(k.clone(), *v);
28        }
29    }
30    out
31}
32
33/// Returns true if two states have the same keys and values (exact comparison).
34pub fn equal(a: &State, b: &State) -> bool {
35    if a.len() != b.len() {
36        return false;
37    }
38    for (k, v) in a {
39        match b.get(k) {
40            Some(bv) if *v == *bv => {}
41            _ => return false,
42        }
43    }
44    true
45}
46
47/// Returns true if two states have the same keys and values within tolerance.
48pub fn equal_tol(a: &State, b: &State, tol: f64) -> bool {
49    if a.len() != b.len() {
50        return false;
51    }
52    for (k, v) in a {
53        match b.get(k) {
54            Some(bv) if (v - bv).abs() <= tol => {}
55            _ => return false,
56        }
57    }
58    true
59}
60
61/// Returns the value for a key, or 0 if not found.
62pub fn get(state: &State, key: &str) -> f64 {
63    state.get(key).copied().unwrap_or(0.0)
64}
65
66/// Returns the sum of all values in the state.
67pub fn sum(state: &State) -> f64 {
68    state.values().sum()
69}
70
71/// Returns the sum of values for the specified keys.
72pub fn sum_keys(state: &State, keys: &[&str]) -> f64 {
73    keys.iter().map(|k| state.get(*k).copied().unwrap_or(0.0)).sum()
74}
75
76/// Returns a new state with all values multiplied by factor.
77pub fn scale(state: &State, factor: f64) -> State {
78    state.iter().map(|(k, v)| (k.clone(), v * factor)).collect()
79}
80
81/// Returns a new state containing only keys that pass the predicate.
82pub fn filter(state: &State, predicate: impl Fn(&str) -> bool) -> State {
83    state
84        .iter()
85        .filter(|(k, _)| predicate(k))
86        .map(|(k, v)| (k.clone(), *v))
87        .collect()
88}
89
90/// Returns all keys in the state map.
91pub fn keys(state: &State) -> Vec<String> {
92    state.keys().cloned().collect()
93}
94
95/// Returns keys that have non-zero values.
96pub fn non_zero(state: &State) -> Vec<String> {
97    state
98        .iter()
99        .filter(|(_, v)| **v != 0.0)
100        .map(|(k, _)| k.clone())
101        .collect()
102}
103
104/// Returns a map of keys where values differ between a and b.
105/// Values in the result are from state b.
106pub fn diff(a: &State, b: &State) -> State {
107    let mut d = State::new();
108
109    // Changed or new keys in b
110    for (k, bv) in b {
111        match a.get(k) {
112            Some(av) if *av == *bv => {}
113            _ => {
114                d.insert(k.clone(), *bv);
115            }
116        }
117    }
118
119    // Keys removed in b
120    for k in a.keys() {
121        if !b.contains_key(k) {
122            d.insert(k.clone(), 0.0);
123        }
124    }
125
126    d
127}
128
129/// Returns the key with the maximum value.
130pub fn max(state: &State) -> Option<(String, f64)> {
131    state
132        .iter()
133        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
134        .map(|(k, v)| (k.clone(), *v))
135}
136
137/// Returns the key with the minimum value.
138pub fn min(state: &State) -> Option<(String, f64)> {
139    state
140        .iter()
141        .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
142        .map(|(k, v)| (k.clone(), *v))
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a state from `(key, value)` pairs.
    fn make_state(pairs: &[(&str, f64)]) -> State {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    #[test]
    fn test_copy() {
        let state = make_state(&[("A", 10.0), ("B", 5.0)]);
        let copied = copy(&state);
        assert_eq!(state, copied);
    }

    #[test]
    fn test_apply() {
        let base = make_state(&[("A", 10.0), ("B", 5.0)]);
        let updates = make_state(&[("A", 0.0), ("C", 1.0)]);
        let result = apply(&base, &updates);
        assert_eq!(result["A"], 0.0);
        assert_eq!(result["B"], 5.0);
        assert_eq!(result["C"], 1.0);
    }

    #[test]
    fn test_apply_leaves_base_untouched() {
        let base = make_state(&[("A", 10.0)]);
        let updates = make_state(&[("A", 0.0)]);
        let _ = apply(&base, &updates);
        assert_eq!(base["A"], 10.0);
    }

    #[test]
    fn test_merge() {
        let s1 = make_state(&[("A", 1.0)]);
        let s2 = make_state(&[("B", 2.0)]);
        let s3 = make_state(&[("A", 3.0)]);
        let result = merge(&[&s1, &s2, &s3]);
        assert_eq!(result["A"], 3.0);
        assert_eq!(result["B"], 2.0);
    }

    #[test]
    fn test_equal() {
        let a = make_state(&[("A", 1.0), ("B", 2.0)]);
        let b = make_state(&[("A", 1.0), ("B", 2.0)]);
        let c = make_state(&[("A", 1.0), ("B", 3.0)]);
        assert!(equal(&a, &b));
        assert!(!equal(&a, &c));
    }

    #[test]
    fn test_equal_different_keys() {
        // Same size and values but different key sets must not compare equal.
        let a = make_state(&[("A", 1.0)]);
        let b = make_state(&[("B", 1.0)]);
        assert!(!equal(&a, &b));
        assert!(!equal_tol(&a, &b, 10.0));
    }

    #[test]
    fn test_equal_tol() {
        let a = make_state(&[("A", 1.0), ("B", 2.0)]);
        let b = make_state(&[("A", 1.0001), ("B", 2.0002)]);
        assert!(equal_tol(&a, &b, 0.001));
        assert!(!equal_tol(&a, &b, 0.0001));
    }

    #[test]
    fn test_get() {
        let state = make_state(&[("A", 10.0)]);
        assert_eq!(get(&state, "A"), 10.0);
        assert_eq!(get(&state, "missing"), 0.0);
    }

    #[test]
    fn test_sum() {
        let state = make_state(&[("A", 10.0), ("B", 5.0), ("C", 3.0)]);
        assert_eq!(sum(&state), 18.0);
    }

    #[test]
    fn test_sum_keys() {
        let state = make_state(&[("A", 10.0), ("B", 5.0), ("C", 3.0)]);
        assert_eq!(sum_keys(&state, &["A", "C"]), 13.0);
        // Missing keys contribute zero.
        assert_eq!(sum_keys(&state, &["A", "Z"]), 10.0);
    }

    #[test]
    fn test_scale() {
        let state = make_state(&[("A", 10.0), ("B", 5.0)]);
        let scaled = scale(&state, 2.0);
        assert_eq!(scaled["A"], 20.0);
        assert_eq!(scaled["B"], 10.0);
    }

    #[test]
    fn test_filter() {
        let state = make_state(&[("_X0", 1.0), ("_X1", 1.0), ("pos", 5.0)]);
        let history = filter(&state, |k| k.starts_with('_'));
        assert_eq!(history.len(), 2);
        assert!(!history.contains_key("pos"));
    }

    #[test]
    fn test_keys() {
        let state = make_state(&[("B", 0.0), ("A", 1.0)]);
        let mut ks = keys(&state);
        // HashMap iteration order is unspecified; sort before comparing.
        ks.sort();
        assert_eq!(ks, vec!["A", "B"]);
    }

    #[test]
    fn test_diff() {
        let a = make_state(&[("A", 1.0), ("B", 2.0), ("C", 3.0)]);
        let b = make_state(&[("A", 1.0), ("B", 5.0), ("D", 4.0)]);
        let d = diff(&a, &b);
        assert!(!d.contains_key("A")); // unchanged
        assert_eq!(d["B"], 5.0); // changed
        assert_eq!(d["C"], 0.0); // removed
        assert_eq!(d["D"], 4.0); // added
    }

    #[test]
    fn test_max_min() {
        let state = make_state(&[("A", 10.0), ("B", 5.0), ("C", 20.0)]);
        let (mk, mv) = max(&state).unwrap();
        assert_eq!(mk, "C");
        assert_eq!(mv, 20.0);

        let (nk, nv) = min(&state).unwrap();
        assert_eq!(nk, "B");
        assert_eq!(nv, 5.0);
    }

    #[test]
    fn test_non_zero() {
        let state = make_state(&[("A", 0.0), ("B", 5.0), ("C", 0.0)]);
        let nz = non_zero(&state);
        assert_eq!(nz.len(), 1);
        assert!(nz.contains(&"B".to_string()));
    }

    #[test]
    fn test_empty_inputs() {
        let empty = State::new();
        assert!(max(&empty).is_none());
        assert!(min(&empty).is_none());
        assert!(merge(&[]).is_empty());
        assert!(diff(&empty, &empty).is_empty());
        assert!(equal(&empty, &empty));
        assert_eq!(sum_keys(&empty, &["A"]), 0.0);
        assert!(keys(&empty).is_empty());
        assert!(non_zero(&empty).is_empty());
    }
}
255}