Skip to main content

libpetri_export/
mapper.rs

1use std::collections::HashSet;
2
3use libpetri_core::input::In;
4use libpetri_core::output::{self, Out};
5use libpetri_core::petri_net::PetriNet;
6
7use crate::graph::*;
8use crate::styles;
9
10/// Configuration for DOT export.
11#[derive(Debug, Clone)]
12pub struct DotConfig {
13    pub direction: RankDir,
14    pub show_types: bool,
15    pub show_intervals: bool,
16    pub show_priority: bool,
17    pub environment_places: HashSet<String>,
18}
19
20impl Default for DotConfig {
21    fn default() -> Self {
22        Self {
23            direction: RankDir::TopToBottom,
24            show_types: true,
25            show_intervals: true,
26            show_priority: true,
27            environment_places: HashSet::new(),
28        }
29    }
30}
31
/// Sanitize a name for use as a DOT identifier.
///
/// Every character that is neither alphanumeric (in the Unicode sense) nor an
/// underscore is replaced by `_`; all other characters are kept as-is. The
/// mapping is not injective: for example, `a-b` and `a_b` both become `a_b`.
pub fn sanitize(name: &str) -> String {
    // Replacements never grow the string, so the byte length is an upper bound.
    let mut out = String::with_capacity(name.len());
    for ch in name.chars() {
        let keep = ch == '_' || ch.is_alphanumeric();
        out.push(if keep { ch } else { '_' });
    }
    out
}
44
/// Visual category of a place; selects which node style the place is drawn with.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PlaceCategory {
    /// Consumed by some transition's input arc but never written by an output arc.
    Start,
    /// Written by some transition's output arc but never consumed by an input arc.
    End,
    /// Listed in `DotConfig::environment_places`.
    Environment,
    /// Any place that falls in none of the other categories.
    Regular,
}
53
54/// Maps a PetriNet to a format-agnostic Graph.
55pub fn map_to_graph(net: &PetriNet, config: &DotConfig) -> Graph {
56    let mut graph = Graph::new(net.name());
57    graph.rankdir = config.direction;
58
59    // Graph attributes
60    graph
61        .graph_attrs
62        .push(("nodesep".into(), styles::NODESEP.to_string()));
63    graph
64        .graph_attrs
65        .push(("ranksep".into(), styles::RANKSEP.to_string()));
66    graph
67        .graph_attrs
68        .push(("forcelabels".into(), styles::FORCE_LABELS.into()));
69    graph
70        .graph_attrs
71        .push(("overlap".into(), styles::OVERLAP.into()));
72
73    // Node defaults
74    graph
75        .node_defaults
76        .push(("fontname".into(), styles::FONT_FAMILY.into()));
77    graph
78        .node_defaults
79        .push(("fontsize".into(), styles::FONT_NODE_SIZE.to_string()));
80
81    // Edge defaults
82    graph
83        .edge_defaults
84        .push(("fontname".into(), styles::FONT_FAMILY.into()));
85    graph
86        .edge_defaults
87        .push(("fontsize".into(), styles::FONT_EDGE_SIZE.to_string()));
88
89    // Analyze places
90    let (has_incoming, has_outgoing) = analyze_places(net);
91
92    // Create place nodes
93    for place_ref in net.places() {
94        let name = place_ref.name();
95        let id = format!("p_{}", sanitize(name));
96        let category = place_category(
97            name,
98            has_incoming.contains(name),
99            has_outgoing.contains(name),
100            config.environment_places.contains(name),
101        );
102        let style = match category {
103            PlaceCategory::Start => &styles::START_PLACE,
104            PlaceCategory::End => &styles::END_PLACE,
105            PlaceCategory::Environment => &styles::ENVIRONMENT_PLACE,
106            PlaceCategory::Regular => &styles::PLACE,
107        };
108
109        let shape = match style.shape {
110            "circle" => NodeShape::Circle,
111            "doublecircle" => NodeShape::DoubleCircle,
112            _ => NodeShape::Circle,
113        };
114
115        let node = GraphNode {
116            id,
117            label: name.to_string(),
118            shape,
119            fill: Some(style.fill.to_string()),
120            stroke: Some(style.stroke.to_string()),
121            penwidth: Some(style.penwidth),
122            semantic_id: Some(name.to_string()),
123            style: style.style.map(|s| s.to_string()),
124            height: None,
125            width: None,
126            attrs: Vec::new(),
127        };
128        graph.nodes.push(node);
129    }
130
131    // Create transition nodes and edges
132    for t in net.transitions() {
133        let t_id = format!("t_{}", sanitize(t.name()));
134        let label = transition_label(t, config);
135
136        graph.nodes.push(GraphNode {
137            id: t_id.clone(),
138            label,
139            shape: NodeShape::Box,
140            fill: Some(styles::TRANSITION.fill.to_string()),
141            stroke: Some(styles::TRANSITION.stroke.to_string()),
142            penwidth: Some(styles::TRANSITION.penwidth),
143            semantic_id: Some(t.name().to_string()),
144            style: None,
145            height: None,
146            width: None,
147            attrs: Vec::new(),
148        });
149
150        // Input edges
151        for in_spec in t.input_specs() {
152            let from_id = format!("p_{}", sanitize(in_spec.place_name()));
153            let label = input_label(in_spec);
154
155            graph.edges.push(GraphEdge {
156                from: from_id,
157                to: t_id.clone(),
158                label,
159                color: Some(styles::INPUT_EDGE.color.to_string()),
160                style: Some(EdgeLineStyle::Solid),
161                arrowhead: Some(ArrowHead::Normal),
162                penwidth: Some(styles::INPUT_EDGE.penwidth),
163                arc_type: Some("input".into()),
164                attrs: Vec::new(),
165            });
166        }
167
168        // Output edges
169        if let Some(out_spec) = t.output_spec() {
170            let reset_places: HashSet<&str> = t.resets().iter().map(|r| r.place.name()).collect();
171            output_edges(&t_id, out_spec, &reset_places, &mut graph.edges);
172        }
173
174        // Inhibitor edges
175        for inh in t.inhibitors() {
176            let from_id = format!("p_{}", sanitize(inh.place.name()));
177            graph.edges.push(GraphEdge {
178                from: from_id,
179                to: t_id.clone(),
180                label: None,
181                color: Some(styles::INHIBITOR_EDGE.color.to_string()),
182                style: Some(EdgeLineStyle::Solid),
183                arrowhead: Some(ArrowHead::Odot),
184                penwidth: Some(styles::INHIBITOR_EDGE.penwidth),
185                arc_type: Some("inhibitor".into()),
186                attrs: Vec::new(),
187            });
188        }
189
190        // Read edges
191        for r in t.reads() {
192            let from_id = format!("p_{}", sanitize(r.place.name()));
193            graph.edges.push(GraphEdge {
194                from: from_id,
195                to: t_id.clone(),
196                label: None,
197                color: Some(styles::READ_EDGE.color.to_string()),
198                style: Some(EdgeLineStyle::Dashed),
199                arrowhead: Some(ArrowHead::Normal),
200                penwidth: Some(styles::READ_EDGE.penwidth),
201                arc_type: Some("read".into()),
202                attrs: Vec::new(),
203            });
204        }
205
206        // Reset edges (only those not overlapping with outputs)
207        for r in t.resets() {
208            if t.output_places().contains(&r.place) {
209                continue; // suppress if already an output
210            }
211            let from_id = format!("p_{}", sanitize(r.place.name()));
212            graph.edges.push(GraphEdge {
213                from: t_id.clone(),
214                to: from_id,
215                label: None,
216                color: Some(styles::RESET_EDGE.color.to_string()),
217                style: Some(EdgeLineStyle::Bold),
218                arrowhead: Some(ArrowHead::Normal),
219                penwidth: Some(styles::RESET_EDGE.penwidth),
220                arc_type: Some("reset".into()),
221                attrs: Vec::new(),
222            });
223        }
224    }
225
226    graph
227}
228
229fn analyze_places(net: &PetriNet) -> (HashSet<String>, HashSet<String>) {
230    let mut has_incoming = HashSet::new();
231    let mut has_outgoing = HashSet::new();
232
233    for t in net.transitions() {
234        // Input arcs: place -> transition (place has outgoing)
235        for spec in t.input_specs() {
236            has_outgoing.insert(spec.place_name().to_string());
237        }
238        // Output arcs: transition -> place (place has incoming)
239        if let Some(out) = t.output_spec() {
240            for p in output::all_places(out) {
241                has_incoming.insert(p.name().to_string());
242            }
243        }
244    }
245
246    (has_incoming, has_outgoing)
247}
248
249fn place_category(
250    _name: &str,
251    has_incoming: bool,
252    has_outgoing: bool,
253    is_environment: bool,
254) -> PlaceCategory {
255    if is_environment {
256        PlaceCategory::Environment
257    } else if !has_incoming && has_outgoing {
258        PlaceCategory::Start
259    } else if has_incoming && !has_outgoing {
260        PlaceCategory::End
261    } else {
262        PlaceCategory::Regular
263    }
264}
265
266fn transition_label(t: &libpetri_core::transition::Transition, config: &DotConfig) -> String {
267    let mut label = t.name().to_string();
268
269    if config.show_intervals && *t.timing() != libpetri_core::timing::Timing::Immediate {
270        let earliest = t.timing().earliest();
271        let latest = t.timing().latest();
272        if latest < libpetri_core::timing::MAX_DURATION_MS {
273            label.push_str(&format!("\n[{earliest}, {latest}]"));
274        } else {
275            label.push_str(&format!("\n[{earliest}, \u{221e})"));
276        }
277    }
278
279    if config.show_priority && t.priority() != 0 {
280        label.push_str(&format!("\nP={}", t.priority()));
281    }
282
283    label
284}
285
286fn input_label(spec: &In) -> Option<String> {
287    match spec {
288        In::One { .. } => None,
289        In::Exactly { count, .. } => Some(count.to_string()),
290        In::All { .. } => Some("*".to_string()),
291        In::AtLeast { minimum, .. } => Some(format!("{minimum}+")),
292    }
293}
294
295#[allow(clippy::only_used_in_recursion)]
296fn output_edges(t_id: &str, out: &Out, reset_places: &HashSet<&str>, edges: &mut Vec<GraphEdge>) {
297    match out {
298        Out::Place(p) => {
299            let to_id = format!("p_{}", sanitize(p.name()));
300            edges.push(GraphEdge {
301                from: t_id.to_string(),
302                to: to_id,
303                label: None,
304                color: Some(styles::OUTPUT_EDGE.color.to_string()),
305                style: Some(EdgeLineStyle::Solid),
306                arrowhead: Some(ArrowHead::Normal),
307                penwidth: Some(styles::OUTPUT_EDGE.penwidth),
308                arc_type: Some("output".into()),
309                attrs: Vec::new(),
310            });
311        }
312        Out::And(children) => {
313            for child in children {
314                output_edges(t_id, child, reset_places, edges);
315            }
316        }
317        Out::Xor(children) => {
318            for (i, child) in children.iter().enumerate() {
319                let branch_label = infer_branch_label(child).unwrap_or_else(|| format!("b{i}"));
320                output_edges_with_label(t_id, child, Some(&branch_label), edges);
321            }
322        }
323        Out::Timeout { after_ms: _, child } => {
324            output_edges(t_id, child, reset_places, edges);
325            // The timeout itself is handled by the action
326        }
327        Out::ForwardInput { from, to } => {
328            let to_id = format!("p_{}", sanitize(to.name()));
329            edges.push(GraphEdge {
330                from: t_id.to_string(),
331                to: to_id,
332                label: Some(format!("\u{21a9} {}", from.name())),
333                color: Some(styles::OUTPUT_EDGE.color.to_string()),
334                style: Some(EdgeLineStyle::Dashed),
335                arrowhead: Some(ArrowHead::Normal),
336                penwidth: Some(styles::OUTPUT_EDGE.penwidth),
337                arc_type: Some("output".into()),
338                attrs: Vec::new(),
339            });
340        }
341    }
342}
343
344fn output_edges_with_label(t_id: &str, out: &Out, label: Option<&str>, edges: &mut Vec<GraphEdge>) {
345    match out {
346        Out::Place(p) => {
347            let to_id = format!("p_{}", sanitize(p.name()));
348            edges.push(GraphEdge {
349                from: t_id.to_string(),
350                to: to_id,
351                label: label.map(|s| s.to_string()),
352                color: Some(styles::OUTPUT_EDGE.color.to_string()),
353                style: Some(EdgeLineStyle::Solid),
354                arrowhead: Some(ArrowHead::Normal),
355                penwidth: Some(styles::OUTPUT_EDGE.penwidth),
356                arc_type: Some("output".into()),
357                attrs: Vec::new(),
358            });
359        }
360        Out::And(children) => {
361            for child in children {
362                output_edges_with_label(t_id, child, label, edges);
363            }
364        }
365        _ => {
366            output_edges(t_id, out, &HashSet::new(), edges);
367        }
368    }
369}
370
371fn infer_branch_label(out: &Out) -> Option<String> {
372    match out {
373        Out::Place(p) => Some(p.name().to_string()),
374        _ => None,
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use libpetri_core::input::one;
    use libpetri_core::output::out_place;
    use libpetri_core::place::Place;
    use libpetri_core::transition::Transition;

    /// Returns the node whose DOT id is `id`, failing the test if it is absent.
    fn node_with_id<'g>(graph: &'g Graph, id: &str) -> &'g GraphNode {
        graph
            .nodes
            .iter()
            .find(|node| node.id == id)
            .unwrap_or_else(|| panic!("no node with id {id}"))
    }

    #[test]
    fn sanitize_names() {
        let cases = [
            ("hello", "hello"),
            ("hello world", "hello_world"),
            ("a-b.c", "a_b_c"),
        ];
        for (raw, expected) in cases {
            assert_eq!(sanitize(raw), expected);
        }
    }

    #[test]
    fn basic_graph_mapping() {
        let source = Place::<i32>::new("p1");
        let sink = Place::<i32>::new("p2");
        let transition = Transition::builder("t1")
            .input(one(&source))
            .output(out_place(&sink))
            .build();
        let net = PetriNet::builder("test").transition(transition).build();

        let graph = map_to_graph(&net, &DotConfig::default());

        // Two place nodes plus one transition node.
        assert_eq!(graph.nodes.len(), 3);
        // One input arc plus one output arc.
        assert_eq!(graph.edges.len(), 2);
    }

    #[test]
    fn place_categories() {
        let start = Place::<i32>::new("start");
        let mid = Place::<i32>::new("mid");
        let end = Place::<i32>::new("end");

        let first = Transition::builder("t1")
            .input(one(&start))
            .output(out_place(&mid))
            .build();
        let second = Transition::builder("t2")
            .input(one(&mid))
            .output(out_place(&end))
            .build();
        let net = PetriNet::builder("test").transitions([first, second]).build();

        let graph = map_to_graph(&net, &DotConfig::default());

        // Only consumed from: drawn with the start-place style.
        let start_node = node_with_id(&graph, "p_start");
        assert_eq!(start_node.fill.as_deref(), Some(styles::START_PLACE.fill));

        // Only written to: drawn with the end-place style.
        let end_node = node_with_id(&graph, "p_end");
        assert_eq!(end_node.fill.as_deref(), Some(styles::END_PLACE.fill));
    }
}