Skip to main content

pflow_core/
net.rs

1//! Core Petri net data structures.
2//!
3//! A Petri net consists of Places (states), Transitions (events), and Arcs (connections).
4
5use std::collections::HashMap;
6
/// A place: a node of the Petri net that stores tokens.
///
/// Token amounts and capacities are kept as `f64` vectors. Presumably each
/// slot corresponds to one token type — TODO confirm against the callers.
#[derive(Debug, Clone)]
pub struct Place {
    /// Name of the place; `PetriNet` uses it as the key of its `places` map.
    pub label: String,
    /// Initial token amounts; summed by [`Place::token_count`].
    pub initial: Vec<f64>,
    /// Capacity values, stored as given (not read anywhere in this file).
    pub capacity: Vec<f64>,
    /// X coordinate, stored as given (presumably for diagram layout — confirm).
    pub x: f64,
    /// Y coordinate, stored as given (presumably for diagram layout — confirm).
    pub y: f64,
    /// Optional display text for the place.
    pub label_text: Option<String>,
}

impl Place {
    /// Builds a `Place` from its parts.
    ///
    /// `label` may be anything convertible into a `String`; every other
    /// argument is stored unchanged in the field of the same name.
    pub fn new(
        label: impl Into<String>,
        initial: Vec<f64>,
        capacity: Vec<f64>,
        x: f64,
        y: f64,
        label_text: Option<String>,
    ) -> Self {
        let label = label.into();
        Place {
            label,
            initial,
            capacity,
            x,
            y,
            label_text,
        }
    }

    /// Returns the total of all entries in `initial`, or `0.0` when there are none.
    pub fn token_count(&self) -> f64 {
        match self.initial.as_slice() {
            // Explicit empty case so the result is always positive zero.
            [] => 0.0,
            amounts => amounts.iter().sum(),
        }
    }
}
47
/// A transition: an event node of the Petri net.
#[derive(Debug, Clone)]
pub struct Transition {
    /// Name of the transition; `PetriNet` uses it as the key of its `transitions` map.
    pub label: String,
    /// Role string attached to the transition (not interpreted in this file).
    pub role: String,
    /// X coordinate, stored as given (presumably for diagram layout — confirm).
    pub x: f64,
    /// Y coordinate, stored as given (presumably for diagram layout — confirm).
    pub y: f64,
    /// Optional display text for the transition.
    pub label_text: Option<String>,
}

impl Transition {
    /// Builds a `Transition` from its parts.
    ///
    /// `label` and `role` may be anything convertible into a `String`; the
    /// remaining arguments are stored unchanged.
    pub fn new(
        label: impl Into<String>,
        role: impl Into<String>,
        x: f64,
        y: f64,
        label_text: Option<String>,
    ) -> Self {
        let label = label.into();
        let role = role.into();
        Transition {
            label,
            role,
            x,
            y,
            label_text,
        }
    }
}
75
/// A directed edge of the Petri net, joining two nodes by their labels.
#[derive(Debug, Clone)]
pub struct Arc {
    /// Label of the node the arc starts from.
    pub source: String,
    /// Label of the node the arc points to.
    pub target: String,
    /// Weight values; summed by [`Arc::weight_sum`].
    pub weight: Vec<f64>,
    /// Flag stored as given. Presumably marks an inhibitor arc — confirm with
    /// the code that fires transitions.
    pub inhibit_transition: bool,
}

impl Arc {
    /// Builds an `Arc` from its parts.
    ///
    /// `source` and `target` may be anything convertible into a `String`;
    /// `weight` and `inhibit_transition` are stored unchanged.
    pub fn new(
        source: impl Into<String>,
        target: impl Into<String>,
        weight: Vec<f64>,
        inhibit_transition: bool,
    ) -> Self {
        let source = source.into();
        let target = target.into();
        Arc {
            source,
            target,
            weight,
            inhibit_transition,
        }
    }

    /// Returns the total of all entries in `weight`.
    ///
    /// An arc with no weight entries counts as weight `1.0`.
    pub fn weight_sum(&self) -> f64 {
        if self.weight.is_empty() {
            return 1.0;
        }
        self.weight.iter().sum()
    }
}
109
110/// State map type: place label -> token count.
111pub type State = HashMap<String, f64>;
112
113/// A complete Petri net model.
114#[derive(Debug, Clone)]
115pub struct PetriNet {
116    pub places: HashMap<String, Place>,
117    pub transitions: HashMap<String, Transition>,
118    pub arcs: Vec<Arc>,
119    pub token: Vec<String>,
120}
121
122impl PetriNet {
123    /// Creates an empty Petri net.
124    pub fn new() -> Self {
125        Self {
126            places: HashMap::new(),
127            transitions: HashMap::new(),
128            arcs: Vec::new(),
129            token: Vec::new(),
130        }
131    }
132
133    /// Starts building a Petri net with the fluent API.
134    pub fn build() -> super::builder::Builder {
135        super::builder::Builder::new()
136    }
137
138    /// Adds a place to the net.
139    pub fn add_place(
140        &mut self,
141        label: impl Into<String>,
142        initial: Vec<f64>,
143        capacity: Vec<f64>,
144        x: f64,
145        y: f64,
146        label_text: Option<String>,
147    ) -> &Place {
148        let label = label.into();
149        let p = Place::new(label.clone(), initial, capacity, x, y, label_text);
150        self.places.insert(label.clone(), p);
151        &self.places[&label]
152    }
153
154    /// Adds a transition to the net.
155    pub fn add_transition(
156        &mut self,
157        label: impl Into<String>,
158        role: impl Into<String>,
159        x: f64,
160        y: f64,
161        label_text: Option<String>,
162    ) -> &Transition {
163        let label = label.into();
164        let t = Transition::new(label.clone(), role, x, y, label_text);
165        self.transitions.insert(label.clone(), t);
166        &self.transitions[&label]
167    }
168
169    /// Adds an arc to the net.
170    pub fn add_arc(
171        &mut self,
172        source: impl Into<String>,
173        target: impl Into<String>,
174        weight: Vec<f64>,
175        inhibit_transition: bool,
176    ) {
177        let a = Arc::new(source, target, weight, inhibit_transition);
178        self.arcs.push(a);
179    }
180
181    /// Returns all arcs that lead into the given transition.
182    pub fn input_arcs(&self, transition_label: &str) -> Vec<&Arc> {
183        self.arcs
184            .iter()
185            .filter(|a| a.target == transition_label)
186            .collect()
187    }
188
189    /// Returns all arcs that lead out from the given transition.
190    pub fn output_arcs(&self, transition_label: &str) -> Vec<&Arc> {
191        self.arcs
192            .iter()
193            .filter(|a| a.source == transition_label)
194            .collect()
195    }
196
197    /// Creates a state map from the net's initial state.
198    /// If `custom_state` is provided, those values override defaults.
199    pub fn set_state(&self, custom_state: Option<&State>) -> State {
200        let mut state = State::new();
201        for (label, place) in &self.places {
202            if let Some(custom) = custom_state {
203                if let Some(&v) = custom.get(label) {
204                    state.insert(label.clone(), v);
205                    continue;
206                }
207            }
208            state.insert(label.clone(), place.token_count());
209        }
210        state
211    }
212
213    /// Creates a rate map for all transitions.
214    /// If `custom_rates` is provided, those values override the default of 1.0.
215    pub fn set_rates(&self, custom_rates: Option<&HashMap<String, f64>>) -> HashMap<String, f64> {
216        let mut rates = HashMap::new();
217        for label in self.transitions.keys() {
218            if let Some(custom) = custom_rates {
219                if let Some(&v) = custom.get(label) {
220                    rates.insert(label.clone(), v);
221                    continue;
222                }
223            }
224            rates.insert(label.clone(), 1.0);
225        }
226        rates
227    }
228}
229
230impl Default for PetriNet {
231    fn default() -> Self {
232        Self::new()
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds a place at the origin with no capacity and no display text.
    fn add_plain_place(net: &mut PetriNet, label: &str, tokens: f64) {
        net.add_place(label, vec![tokens], vec![], 0.0, 0.0, None);
    }

    /// `token_count` sums the initial tokens and is zero for an empty place.
    #[test]
    fn test_place_token_count() {
        let filled = Place::new("test", vec![3.0, 7.0], vec![], 0.0, 0.0, None);
        let empty = Place::new("empty", vec![], vec![], 0.0, 0.0, None);

        assert_eq!(filled.token_count(), 10.0);
        assert_eq!(empty.token_count(), 0.0);
    }

    /// `weight_sum` sums the weights and defaults to one for an empty arc.
    #[test]
    fn test_arc_weight_sum() {
        let weighted = Arc::new("a", "b", vec![2.0, 3.0], false);
        let unweighted = Arc::new("a", "b", vec![], false);

        assert_eq!(weighted.weight_sum(), 5.0);
        assert_eq!(unweighted.weight_sum(), 1.0);
    }

    /// A small infection net reports the expected element counts and arc directions.
    #[test]
    fn test_petri_net_basic() {
        let mut net = PetriNet::new();
        add_plain_place(&mut net, "S", 999.0);
        add_plain_place(&mut net, "I", 1.0);
        net.add_transition("infect", "default", 0.0, 0.0, None);

        let arcs = [
            ("S", "infect", 1.0),
            ("I", "infect", 1.0),
            ("infect", "I", 2.0),
        ];
        for (source, target, weight) in arcs {
            net.add_arc(source, target, vec![weight], false);
        }

        assert_eq!(net.places.len(), 2);
        assert_eq!(net.transitions.len(), 1);
        assert_eq!(net.arcs.len(), 3);

        assert_eq!(net.input_arcs("infect").len(), 2);
        assert_eq!(net.output_arcs("infect").len(), 1);
    }

    /// `set_state` uses initial tokens by default and honours custom overrides.
    #[test]
    fn test_set_state() {
        let mut net = PetriNet::new();
        add_plain_place(&mut net, "A", 10.0);
        add_plain_place(&mut net, "B", 0.0);

        let defaults = net.set_state(None);
        assert_eq!(defaults["A"], 10.0);
        assert_eq!(defaults["B"], 0.0);

        let custom: State = [("A".to_string(), 5.0)].into_iter().collect();
        let overridden = net.set_state(Some(&custom));
        assert_eq!(overridden["A"], 5.0);
        assert_eq!(overridden["B"], 0.0);
    }
}