use std::collections::hash_map::Entry;
use std::collections::HashMap;
6
/// A place of a Petri net, together with its initial marking and layout metadata.
#[derive(Debug, Clone, PartialEq)]
pub struct Place {
    /// Identifier of the place; arcs refer to it through this label.
    pub label: String,
    /// Initial marking, possibly split into several components.
    /// The components are summed by [`Place::token_count`].
    pub initial: Vec<f64>,
    /// Capacity components; not interpreted by any method of this type.
    pub capacity: Vec<f64>,
    /// Horizontal position (presumably used for rendering — confirm with the consumer).
    pub x: f64,
    /// Vertical position (presumably used for rendering — confirm with the consumer).
    pub y: f64,
    /// Optional free-form display text associated with the place.
    pub label_text: Option<String>,
}

impl Place {
    /// Creates a place from its label, initial marking, capacity, position and
    /// optional display text. No validation is performed on any argument.
    pub fn new(
        label: impl Into<String>,
        initial: Vec<f64>,
        capacity: Vec<f64>,
        x: f64,
        y: f64,
        label_text: Option<String>,
    ) -> Self {
        Self {
            label: label.into(),
            initial,
            capacity,
            x,
            y,
            label_text,
        }
    }

    /// Returns the total number of tokens in the initial marking.
    ///
    /// This is the sum of all components of [`Place::initial`]; an empty
    /// vector yields `0.0`.
    pub fn token_count(&self) -> f64 {
        // The explicit branch pins the empty case to +0.0 instead of relying on
        // whichever neutral element std's float `Sum` implementation starts from.
        if self.initial.is_empty() {
            0.0
        } else {
            self.initial.iter().sum()
        }
    }
}
47
/// A transition of a Petri net, with its role and layout metadata.
#[derive(Debug, Clone, PartialEq)]
pub struct Transition {
    /// Identifier of the transition; arcs refer to it through this label.
    pub label: String,
    /// Free-form role name; not interpreted by any method of this type.
    pub role: String,
    /// Horizontal position (presumably used for rendering — confirm with the consumer).
    pub x: f64,
    /// Vertical position (presumably used for rendering — confirm with the consumer).
    pub y: f64,
    /// Optional free-form display text associated with the transition.
    pub label_text: Option<String>,
}

impl Transition {
    /// Creates a transition from its label, role, position and optional display
    /// text. No validation is performed on any argument.
    pub fn new(
        label: impl Into<String>,
        role: impl Into<String>,
        x: f64,
        y: f64,
        label_text: Option<String>,
    ) -> Self {
        Self {
            label: label.into(),
            role: role.into(),
            x,
            y,
            label_text,
        }
    }
}
75
/// A directed, weighted arc from the node labelled `source` to the node
/// labelled `target`.
#[derive(Debug, Clone, PartialEq)]
pub struct Arc {
    /// Label of the node the arc starts at.
    pub source: String,
    /// Label of the node the arc ends at.
    pub target: String,
    /// Weight components; summed by [`Arc::weight_sum`].
    /// An empty vector means unit weight.
    pub weight: Vec<f64>,
    /// Flag marking the arc as an inhibitor arc.
    /// NOTE(review): the firing semantics are implemented by the consumer, not by this type.
    pub inhibit_transition: bool,
}

impl Arc {
    /// Creates an arc between two node labels with the given weight
    /// components and inhibitor flag. No validation is performed on any argument.
    pub fn new(
        source: impl Into<String>,
        target: impl Into<String>,
        weight: Vec<f64>,
        inhibit_transition: bool,
    ) -> Self {
        Self {
            source: source.into(),
            target: target.into(),
            weight,
            inhibit_transition,
        }
    }

    /// Returns the total weight of the arc.
    ///
    /// This is the sum of all weight components. An arc with no explicit
    /// weight components counts as having weight `1.0`.
    pub fn weight_sum(&self) -> f64 {
        if self.weight.is_empty() {
            1.0
        } else {
            self.weight.iter().sum()
        }
    }
}
109
/// A marking of a net: a token count per place, keyed by place label.
pub type State = std::collections::HashMap<String, f64>;
112
113#[derive(Debug, Clone)]
115pub struct PetriNet {
116 pub places: HashMap<String, Place>,
117 pub transitions: HashMap<String, Transition>,
118 pub arcs: Vec<Arc>,
119 pub token: Vec<String>,
120}
121
122impl PetriNet {
123 pub fn new() -> Self {
125 Self {
126 places: HashMap::new(),
127 transitions: HashMap::new(),
128 arcs: Vec::new(),
129 token: Vec::new(),
130 }
131 }
132
133 pub fn build() -> super::builder::Builder {
135 super::builder::Builder::new()
136 }
137
138 pub fn add_place(
140 &mut self,
141 label: impl Into<String>,
142 initial: Vec<f64>,
143 capacity: Vec<f64>,
144 x: f64,
145 y: f64,
146 label_text: Option<String>,
147 ) -> &Place {
148 let label = label.into();
149 let p = Place::new(label.clone(), initial, capacity, x, y, label_text);
150 self.places.insert(label.clone(), p);
151 &self.places[&label]
152 }
153
154 pub fn add_transition(
156 &mut self,
157 label: impl Into<String>,
158 role: impl Into<String>,
159 x: f64,
160 y: f64,
161 label_text: Option<String>,
162 ) -> &Transition {
163 let label = label.into();
164 let t = Transition::new(label.clone(), role, x, y, label_text);
165 self.transitions.insert(label.clone(), t);
166 &self.transitions[&label]
167 }
168
169 pub fn add_arc(
171 &mut self,
172 source: impl Into<String>,
173 target: impl Into<String>,
174 weight: Vec<f64>,
175 inhibit_transition: bool,
176 ) {
177 let a = Arc::new(source, target, weight, inhibit_transition);
178 self.arcs.push(a);
179 }
180
181 pub fn input_arcs(&self, transition_label: &str) -> Vec<&Arc> {
183 self.arcs
184 .iter()
185 .filter(|a| a.target == transition_label)
186 .collect()
187 }
188
189 pub fn output_arcs(&self, transition_label: &str) -> Vec<&Arc> {
191 self.arcs
192 .iter()
193 .filter(|a| a.source == transition_label)
194 .collect()
195 }
196
197 pub fn set_state(&self, custom_state: Option<&State>) -> State {
200 let mut state = State::new();
201 for (label, place) in &self.places {
202 if let Some(custom) = custom_state {
203 if let Some(&v) = custom.get(label) {
204 state.insert(label.clone(), v);
205 continue;
206 }
207 }
208 state.insert(label.clone(), place.token_count());
209 }
210 state
211 }
212
213 pub fn set_rates(&self, custom_rates: Option<&HashMap<String, f64>>) -> HashMap<String, f64> {
216 let mut rates = HashMap::new();
217 for label in self.transitions.keys() {
218 if let Some(custom) = custom_rates {
219 if let Some(&v) = custom.get(label) {
220 rates.insert(label.clone(), v);
221 continue;
222 }
223 }
224 rates.insert(label.clone(), 1.0);
225 }
226 rates
227 }
228}
229
230impl Default for PetriNet {
231 fn default() -> Self {
232 Self::new()
233 }
234}
235
/// Unit tests for the Petri-net model types.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_place_token_count() {
        let p = Place::new("test", vec![3.0, 7.0], vec![], 0.0, 0.0, None);
        assert_eq!(p.token_count(), 10.0);

        let empty = Place::new("empty", vec![], vec![], 0.0, 0.0, None);
        assert_eq!(empty.token_count(), 0.0);
    }

    #[test]
    fn test_arc_weight_sum() {
        let a = Arc::new("a", "b", vec![2.0, 3.0], false);
        assert_eq!(a.weight_sum(), 5.0);

        let empty = Arc::new("a", "b", vec![], false);
        assert_eq!(empty.weight_sum(), 1.0);
    }

    #[test]
    fn test_petri_net_basic() {
        let mut net = PetriNet::new();
        net.add_place("S", vec![999.0], vec![], 0.0, 0.0, None);
        net.add_place("I", vec![1.0], vec![], 0.0, 0.0, None);
        net.add_transition("infect", "default", 0.0, 0.0, None);
        net.add_arc("S", "infect", vec![1.0], false);
        net.add_arc("I", "infect", vec![1.0], false);
        net.add_arc("infect", "I", vec![2.0], false);

        assert_eq!(net.places.len(), 2);
        assert_eq!(net.transitions.len(), 1);
        assert_eq!(net.arcs.len(), 3);

        let inputs = net.input_arcs("infect");
        assert_eq!(inputs.len(), 2);

        let outputs = net.output_arcs("infect");
        assert_eq!(outputs.len(), 1);
    }

    #[test]
    fn test_set_state() {
        let mut net = PetriNet::new();
        net.add_place("A", vec![10.0], vec![], 0.0, 0.0, None);
        net.add_place("B", vec![0.0], vec![], 0.0, 0.0, None);

        let state = net.set_state(None);
        assert_eq!(state["A"], 10.0);
        assert_eq!(state["B"], 0.0);

        let mut custom = State::new();
        custom.insert("A".into(), 5.0);
        let state = net.set_state(Some(&custom));
        assert_eq!(state["A"], 5.0);
        assert_eq!(state["B"], 0.0);
    }

    #[test]
    fn test_set_state_ignores_unknown_labels() {
        let mut net = PetriNet::new();
        net.add_place("A", vec![1.0], vec![], 0.0, 0.0, None);

        let mut custom = State::new();
        custom.insert("ghost".into(), 42.0);
        let state = net.set_state(Some(&custom));
        assert_eq!(state.len(), 1);
        assert_eq!(state["A"], 1.0);
        assert!(!state.contains_key("ghost"));
    }

    #[test]
    fn test_set_rates() {
        let mut net = PetriNet::new();
        net.add_transition("fast", "default", 0.0, 0.0, None);
        net.add_transition("slow", "default", 0.0, 0.0, None);

        let rates = net.set_rates(None);
        assert_eq!(rates.len(), 2);
        assert_eq!(rates["fast"], 1.0);
        assert_eq!(rates["slow"], 1.0);

        let mut custom: HashMap<String, f64> = HashMap::new();
        custom.insert("slow".into(), 0.25);
        custom.insert("unknown".into(), 9.0);
        let rates = net.set_rates(Some(&custom));
        assert_eq!(rates.len(), 2);
        assert_eq!(rates["fast"], 1.0);
        assert_eq!(rates["slow"], 0.25);
        assert!(!rates.contains_key("unknown"));
    }

    #[test]
    fn test_default_is_empty() {
        let net = PetriNet::default();
        assert!(net.places.is_empty());
        assert!(net.transitions.is_empty());
        assert!(net.arcs.is_empty());
        assert!(net.token.is_empty());
    }

    #[test]
    fn test_add_place_replaces_existing_label() {
        let mut net = PetriNet::new();
        net.add_place("P", vec![1.0], vec![], 0.0, 0.0, None);
        let replaced = net.add_place("P", vec![4.0, 1.0], vec![], 3.0, 4.0, Some("pool".into()));
        assert_eq!(replaced.token_count(), 5.0);
        assert_eq!(replaced.label_text.as_deref(), Some("pool"));

        assert_eq!(net.places.len(), 1);
        assert_eq!(net.places["P"].x, 3.0);
    }
}