Skip to main content

libpetri_export/
mapper.rs

1use std::collections::HashSet;
2
3use libpetri_core::input::In;
4use libpetri_core::output::{self, Out};
5use libpetri_core::petri_net::PetriNet;
6
7use crate::graph::*;
8use crate::styles;
9
10/// Configuration for DOT export.
11#[derive(Debug, Clone)]
12pub struct DotConfig {
13    pub direction: RankDir,
14    pub show_types: bool,
15    pub show_intervals: bool,
16    pub show_priority: bool,
17    pub environment_places: HashSet<String>,
18}
19
20impl Default for DotConfig {
21    fn default() -> Self {
22        Self {
23            direction: RankDir::TopToBottom,
24            show_types: true,
25            show_intervals: true,
26            show_priority: true,
27            environment_places: HashSet::new(),
28        }
29    }
30}
31
/// Turns an arbitrary name into an identifier-safe fragment.
///
/// Characters that are `_` or alphanumeric (in the Unicode sense of
/// [`char::is_alphanumeric`]) are kept; every other character becomes `_`.
/// The mapping is lossy: distinct inputs such as `"a-b"` and `"a.b"` both
/// produce `"a_b"`.
pub fn sanitize(name: &str) -> String {
    let mut out = String::with_capacity(name.len());
    for ch in name.chars() {
        let keep = ch == '_' || ch.is_alphanumeric();
        out.push(if keep { ch } else { '_' });
    }
    out
}
44
/// Visual role of a place, used to choose its node style.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PlaceCategory {
    /// Has outgoing (input-spec) arcs but no incoming (output-spec) arcs.
    Start,
    /// Has incoming (output-spec) arcs but no outgoing (input-spec) arcs.
    End,
    /// Listed in `DotConfig::environment_places`.
    Environment,
    /// Any other place: both kinds of arcs, or neither.
    Regular,
}
53
54/// Maps a PetriNet to a format-agnostic Graph.
55pub fn map_to_graph(net: &PetriNet, config: &DotConfig) -> Graph {
56    let mut graph = Graph::new(net.name());
57    graph.rankdir = config.direction;
58
59    // Graph attributes
60    graph
61        .graph_attrs
62        .push(("nodesep".into(), styles::NODESEP.to_string()));
63    graph
64        .graph_attrs
65        .push(("ranksep".into(), styles::RANKSEP.to_string()));
66    graph
67        .graph_attrs
68        .push(("forcelabels".into(), styles::FORCE_LABELS.into()));
69    graph
70        .graph_attrs
71        .push(("overlap".into(), styles::OVERLAP.into()));
72
73    // Node defaults
74    graph
75        .node_defaults
76        .push(("fontname".into(), styles::FONT_FAMILY.into()));
77    graph
78        .node_defaults
79        .push(("fontsize".into(), styles::FONT_NODE_SIZE.to_string()));
80
81    // Edge defaults
82    graph
83        .edge_defaults
84        .push(("fontname".into(), styles::FONT_FAMILY.into()));
85    graph
86        .edge_defaults
87        .push(("fontsize".into(), styles::FONT_EDGE_SIZE.to_string()));
88
89    // Analyze places
90    let (has_incoming, has_outgoing) = analyze_places(net);
91
92    // Create place nodes
93    for place_ref in net.places() {
94        let name = place_ref.name();
95        let id = format!("p_{}", sanitize(name));
96        let category = place_category(
97            name,
98            has_incoming.contains(name),
99            has_outgoing.contains(name),
100            config.environment_places.contains(name),
101        );
102        let style = match category {
103            PlaceCategory::Start => &styles::START_PLACE,
104            PlaceCategory::End => &styles::END_PLACE,
105            PlaceCategory::Environment => &styles::ENVIRONMENT_PLACE,
106            PlaceCategory::Regular => &styles::PLACE,
107        };
108
109        let shape = match style.shape {
110            "circle" => NodeShape::Circle,
111            "doublecircle" => NodeShape::DoubleCircle,
112            _ => NodeShape::Circle,
113        };
114
115        let node = GraphNode {
116            id,
117            label: String::new(),
118            shape,
119            fill: Some(style.fill.to_string()),
120            stroke: Some(style.stroke.to_string()),
121            penwidth: Some(style.penwidth),
122            semantic_id: Some(name.to_string()),
123            style: style.style.map(|s| s.to_string()),
124            height: style.height,
125            width: style.width,
126            attrs: vec![
127                ("xlabel".into(), name.to_string()),
128                ("fixedsize".into(), "true".into()),
129            ],
130        };
131        graph.nodes.push(node);
132    }
133
134    // Create transition nodes and edges
135    for t in net.transitions() {
136        let t_id = format!("t_{}", sanitize(t.name()));
137        let label = transition_label(t, config);
138
139        graph.nodes.push(GraphNode {
140            id: t_id.clone(),
141            label,
142            shape: NodeShape::Box,
143            fill: Some(styles::TRANSITION.fill.to_string()),
144            stroke: Some(styles::TRANSITION.stroke.to_string()),
145            penwidth: Some(styles::TRANSITION.penwidth),
146            semantic_id: Some(t.name().to_string()),
147            style: None,
148            height: styles::TRANSITION.height,
149            width: styles::TRANSITION.width,
150            attrs: Vec::new(),
151        });
152
153        // Input edges
154        for in_spec in t.input_specs() {
155            let from_id = format!("p_{}", sanitize(in_spec.place_name()));
156            let label = input_label(in_spec);
157
158            graph.edges.push(GraphEdge {
159                from: from_id,
160                to: t_id.clone(),
161                label,
162                color: Some(styles::INPUT_EDGE.color.to_string()),
163                style: Some(EdgeLineStyle::Solid),
164                arrowhead: Some(ArrowHead::Normal),
165                penwidth: styles::INPUT_EDGE.penwidth,
166                arc_type: Some("input".into()),
167                attrs: Vec::new(),
168            });
169        }
170
171        // Output edges
172        if let Some(out_spec) = t.output_spec() {
173            let reset_places: HashSet<&str> = t.resets().iter().map(|r| r.place.name()).collect();
174            output_edges(&t_id, out_spec, &reset_places, &mut graph.edges);
175        }
176
177        // Inhibitor edges
178        for inh in t.inhibitors() {
179            let from_id = format!("p_{}", sanitize(inh.place.name()));
180            graph.edges.push(GraphEdge {
181                from: from_id,
182                to: t_id.clone(),
183                label: None,
184                color: Some(styles::INHIBITOR_EDGE.color.to_string()),
185                style: Some(EdgeLineStyle::Solid),
186                arrowhead: Some(ArrowHead::Odot),
187                penwidth: styles::INHIBITOR_EDGE.penwidth,
188                arc_type: Some("inhibitor".into()),
189                attrs: Vec::new(),
190            });
191        }
192
193        // Read edges
194        for r in t.reads() {
195            let from_id = format!("p_{}", sanitize(r.place.name()));
196            graph.edges.push(GraphEdge {
197                from: from_id,
198                to: t_id.clone(),
199                label: Some("read".into()),
200                color: Some(styles::READ_EDGE.color.to_string()),
201                style: Some(EdgeLineStyle::Dashed),
202                arrowhead: Some(ArrowHead::Normal),
203                penwidth: styles::READ_EDGE.penwidth,
204                arc_type: Some("read".into()),
205                attrs: Vec::new(),
206            });
207        }
208
209        // Reset edges (only those not overlapping with outputs)
210        for r in t.resets() {
211            if t.output_places().contains(&r.place) {
212                continue; // suppress if already an output
213            }
214            let from_id = format!("p_{}", sanitize(r.place.name()));
215            graph.edges.push(GraphEdge {
216                from: t_id.clone(),
217                to: from_id,
218                label: Some("reset".into()),
219                color: Some(styles::RESET_EDGE.color.to_string()),
220                style: Some(EdgeLineStyle::Bold),
221                arrowhead: Some(ArrowHead::Normal),
222                penwidth: styles::RESET_EDGE.penwidth,
223                arc_type: Some("reset".into()),
224                attrs: Vec::new(),
225            });
226        }
227    }
228
229    graph
230}
231
232fn analyze_places(net: &PetriNet) -> (HashSet<String>, HashSet<String>) {
233    let mut has_incoming = HashSet::new();
234    let mut has_outgoing = HashSet::new();
235
236    for t in net.transitions() {
237        // Input arcs: place -> transition (place has outgoing)
238        for spec in t.input_specs() {
239            has_outgoing.insert(spec.place_name().to_string());
240        }
241        // Output arcs: transition -> place (place has incoming)
242        if let Some(out) = t.output_spec() {
243            for p in output::all_places(out) {
244                has_incoming.insert(p.name().to_string());
245            }
246        }
247    }
248
249    (has_incoming, has_outgoing)
250}
251
252fn place_category(
253    _name: &str,
254    has_incoming: bool,
255    has_outgoing: bool,
256    is_environment: bool,
257) -> PlaceCategory {
258    if is_environment {
259        PlaceCategory::Environment
260    } else if !has_incoming && has_outgoing {
261        PlaceCategory::Start
262    } else if has_incoming && !has_outgoing {
263        PlaceCategory::End
264    } else {
265        PlaceCategory::Regular
266    }
267}
268
269fn transition_label(t: &libpetri_core::transition::Transition, config: &DotConfig) -> String {
270    let mut parts = vec![t.name().to_string()];
271
272    if config.show_intervals && *t.timing() != libpetri_core::timing::Timing::Immediate {
273        let earliest = t.timing().earliest();
274        let latest = t.timing().latest();
275        if latest < libpetri_core::timing::MAX_DURATION_MS {
276            parts.push(format!("[{earliest}, {latest}]ms"));
277        } else {
278            parts.push(format!("[{earliest}, \u{221e})ms"));
279        }
280    }
281
282    if config.show_priority && t.priority() != 0 {
283        parts.push(format!("prio={}", t.priority()));
284    }
285
286    parts.join(" ")
287}
288
289fn input_label(spec: &In) -> Option<String> {
290    match spec {
291        In::One { .. } => None,
292        In::Exactly { count, .. } => Some(format!("\u{00d7}{count}")),
293        In::All { .. } => Some("*".to_string()),
294        In::AtLeast { minimum, .. } => Some(format!("\u{2265}{minimum}")),
295    }
296}
297
298#[allow(clippy::only_used_in_recursion)]
299fn output_edges(t_id: &str, out: &Out, reset_places: &HashSet<&str>, edges: &mut Vec<GraphEdge>) {
300    match out {
301        Out::Place(p) => {
302            let to_id = format!("p_{}", sanitize(p.name()));
303            edges.push(GraphEdge {
304                from: t_id.to_string(),
305                to: to_id,
306                label: None,
307                color: Some(styles::OUTPUT_EDGE.color.to_string()),
308                style: Some(EdgeLineStyle::Solid),
309                arrowhead: Some(ArrowHead::Normal),
310                penwidth: styles::OUTPUT_EDGE.penwidth,
311                arc_type: Some("output".into()),
312                attrs: Vec::new(),
313            });
314        }
315        Out::And(children) => {
316            for child in children {
317                output_edges(t_id, child, reset_places, edges);
318            }
319        }
320        Out::Xor(children) => {
321            for child in children {
322                let branch_label = infer_branch_label(child);
323                output_edges_with_label(t_id, child, branch_label.as_deref(), edges);
324            }
325        }
326        Out::Timeout { after_ms, child } => {
327            let label = format!("\u{23f1}{after_ms}ms");
328            output_edges_with_label(t_id, child, Some(&label), edges);
329        }
330        Out::ForwardInput { from, to } => {
331            let to_id = format!("p_{}", sanitize(to.name()));
332            edges.push(GraphEdge {
333                from: t_id.to_string(),
334                to: to_id,
335                label: Some(format!("\u{27f5}{}", from.name())),
336                color: Some(styles::OUTPUT_EDGE.color.to_string()),
337                style: Some(EdgeLineStyle::Dashed),
338                arrowhead: Some(ArrowHead::Normal),
339                penwidth: styles::OUTPUT_EDGE.penwidth,
340                arc_type: Some("output".into()),
341                attrs: Vec::new(),
342            });
343        }
344    }
345}
346
347fn output_edges_with_label(t_id: &str, out: &Out, label: Option<&str>, edges: &mut Vec<GraphEdge>) {
348    match out {
349        Out::Place(p) => {
350            let to_id = format!("p_{}", sanitize(p.name()));
351            edges.push(GraphEdge {
352                from: t_id.to_string(),
353                to: to_id,
354                label: label.map(|s| s.to_string()),
355                color: Some(styles::OUTPUT_EDGE.color.to_string()),
356                style: Some(EdgeLineStyle::Solid),
357                arrowhead: Some(ArrowHead::Normal),
358                penwidth: styles::OUTPUT_EDGE.penwidth,
359                arc_type: Some("output".into()),
360                attrs: Vec::new(),
361            });
362        }
363        Out::And(children) => {
364            for child in children {
365                output_edges_with_label(t_id, child, label, edges);
366            }
367        }
368        Out::ForwardInput { from, to } => {
369            let to_id = format!("p_{}", sanitize(to.name()));
370            let fwd_label = match label {
371                Some(l) => format!("{l} \u{27f5}{}", from.name()),
372                None => format!("\u{27f5}{}", from.name()),
373            };
374            edges.push(GraphEdge {
375                from: t_id.to_string(),
376                to: to_id,
377                label: Some(fwd_label),
378                color: Some(styles::OUTPUT_EDGE.color.to_string()),
379                style: Some(EdgeLineStyle::Dashed),
380                arrowhead: Some(ArrowHead::Normal),
381                penwidth: styles::OUTPUT_EDGE.penwidth,
382                arc_type: Some("output".into()),
383                attrs: Vec::new(),
384            });
385        }
386        _ => {
387            output_edges(t_id, out, &HashSet::new(), edges);
388        }
389    }
390}
391
392fn infer_branch_label(out: &Out) -> Option<String> {
393    match out {
394        Out::Place(p) => Some(p.name().to_string()),
395        Out::Timeout { after_ms, .. } => Some(format!("\u{23f1}{after_ms}ms")),
396        Out::ForwardInput { to, .. } => Some(to.name().to_string()),
397        _ => None,
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use libpetri_core::input::one;
    use libpetri_core::output::out_place;
    use libpetri_core::place::Place;
    use libpetri_core::transition::Transition;

    #[test]
    fn sanitize_names() {
        assert_eq!(sanitize("hello"), "hello");
        assert_eq!(sanitize("hello world"), "hello_world");
        assert_eq!(sanitize("a-b.c"), "a_b_c");
    }

    #[test]
    fn basic_graph_mapping() {
        let p1 = Place::<i32>::new("p1");
        let p2 = Place::<i32>::new("p2");
        let t = Transition::builder("t1")
            .input(one(&p1))
            .output(out_place(&p2))
            .build();
        let net = PetriNet::builder("test").transition(t).build();

        let graph = map_to_graph(&net, &DotConfig::default());

        // 2 place nodes + 1 transition node
        assert_eq!(graph.nodes.len(), 3);
        // 1 input edge + 1 output edge
        assert_eq!(graph.edges.len(), 2);
    }

    #[test]
    fn place_categories() {
        let p_start = Place::<i32>::new("start");
        let p_mid = Place::<i32>::new("mid");
        let p_end = Place::<i32>::new("end");

        let t1 = Transition::builder("t1")
            .input(one(&p_start))
            .output(out_place(&p_mid))
            .build();
        let t2 = Transition::builder("t2")
            .input(one(&p_mid))
            .output(out_place(&p_end))
            .build();

        let net = PetriNet::builder("test").transitions([t1, t2]).build();

        let graph = map_to_graph(&net, &DotConfig::default());

        // Only outgoing arcs: start style.
        let start_node = graph.nodes.iter().find(|n| n.id == "p_start").unwrap();
        assert_eq!(start_node.fill.as_deref(), Some(styles::START_PLACE.fill));

        // Both incoming and outgoing arcs: regular style.
        let mid_node = graph.nodes.iter().find(|n| n.id == "p_mid").unwrap();
        assert_eq!(mid_node.fill.as_deref(), Some(styles::PLACE.fill));

        // Only incoming arcs: end style, drawn as a double circle.
        let end_node = graph.nodes.iter().find(|n| n.id == "p_end").unwrap();
        assert_eq!(end_node.fill.as_deref(), Some(styles::END_PLACE.fill));
        assert_eq!(end_node.shape, NodeShape::DoubleCircle);
    }

    #[test]
    fn environment_places_override_topological_category() {
        let p_source = Place::<i32>::new("source");
        let p_sink = Place::<i32>::new("sink");
        let t = Transition::builder("t1")
            .input(one(&p_source))
            .output(out_place(&p_sink))
            .build();
        let net = PetriNet::builder("test").transition(t).build();

        let mut config = DotConfig::default();
        config.environment_places.insert("source".to_string());
        let graph = map_to_graph(&net, &config);

        // "source" would be a start place, but environment membership wins.
        let env_node = graph.nodes.iter().find(|n| n.id == "p_source").unwrap();
        assert_eq!(env_node.fill.as_deref(), Some(styles::ENVIRONMENT_PLACE.fill));

        // Places not listed keep their topological category.
        let sink_node = graph.nodes.iter().find(|n| n.id == "p_sink").unwrap();
        assert_eq!(sink_node.fill.as_deref(), Some(styles::END_PLACE.fill));
    }

    #[test]
    fn places_have_empty_label_and_xlabel() {
        let p1 = Place::<i32>::new("Start");
        let p2 = Place::<i32>::new("End");
        let t = Transition::builder("t1")
            .input(one(&p1))
            .output(out_place(&p2))
            .build();
        let net = PetriNet::builder("test").transition(t).build();

        let graph = map_to_graph(&net, &DotConfig::default());

        for node in &graph.nodes {
            if node.id.starts_with("p_") {
                assert_eq!(node.label, "", "Place label should be empty");
                let xlabel = node.attrs.iter().find(|(k, _)| k == "xlabel");
                assert!(xlabel.is_some(), "Place should have xlabel");
                let fixedsize = node.attrs.iter().find(|(k, _)| k == "fixedsize");
                assert_eq!(fixedsize.unwrap().1, "true");
            }
        }
    }

    #[test]
    fn transition_has_dimensions() {
        let p1 = Place::<i32>::new("p1");
        let p2 = Place::<i32>::new("p2");
        let t = Transition::builder("t1")
            .input(one(&p1))
            .output(out_place(&p2))
            .build();
        let net = PetriNet::builder("test").transition(t).build();

        let graph = map_to_graph(&net, &DotConfig::default());

        let t_node = graph.nodes.iter().find(|n| n.id == "t_t1").unwrap();
        assert_eq!(t_node.height, Some(0.4));
        assert_eq!(t_node.width, Some(0.8));
    }

    #[test]
    fn input_labels_use_unicode() {
        use libpetri_core::input::{exactly, at_least};

        let p1 = Place::<i32>::new("p1");
        let p2 = Place::<i32>::new("p2");

        let t = Transition::builder("t1")
            .input(exactly(3, &p1))
            .output(out_place(&p2))
            .build();
        let net = PetriNet::builder("test").transition(t).build();
        let graph = map_to_graph(&net, &DotConfig::default());
        // Input edges are emitted before output edges for each transition.
        let edge = &graph.edges[0];
        assert_eq!(edge.label.as_deref(), Some("\u{00d7}3"));

        let t2 = Transition::builder("t2")
            .input(at_least(2, &p1))
            .output(out_place(&p2))
            .build();
        let net2 = PetriNet::builder("test2").transition(t2).build();
        let graph2 = map_to_graph(&net2, &DotConfig::default());
        let edge2 = &graph2.edges[0];
        assert_eq!(edge2.label.as_deref(), Some("\u{2265}2"));
    }

    #[test]
    fn edge_penwidth_only_set_when_style_has_some() {
        let p1 = Place::<i32>::new("p1");
        let p2 = Place::<i32>::new("p2");
        let t = Transition::builder("t1")
            .input(one(&p1))
            .output(out_place(&p2))
            .build();
        let net = PetriNet::builder("test").transition(t).build();

        let graph = map_to_graph(&net, &DotConfig::default());

        // Input/output edges should have no penwidth (styles have None)
        for edge in &graph.edges {
            assert_eq!(edge.penwidth, None, "input/output edges should have no penwidth");
        }
    }

    #[test]
    fn transition_label_space_separated() {
        let p1 = Place::<i32>::new("p1");
        let p2 = Place::<i32>::new("p2");
        let t = Transition::builder("fire")
            .input(one(&p1))
            .output(out_place(&p2))
            .timing(libpetri_core::timing::Timing::Delayed { after_ms: 500 })
            .build();
        let net = PetriNet::builder("test").transition(t).build();

        let graph = map_to_graph(&net, &DotConfig::default());
        let t_node = graph.nodes.iter().find(|n| n.id == "t_fire").unwrap();
        assert_eq!(t_node.label, "fire [500, \u{221e})ms");
    }

    #[test]
    fn read_edge_has_label() {
        use libpetri_core::arc::read;

        let p1 = Place::<i32>::new("p1");
        let p2 = Place::<i32>::new("p2");
        let cfg = Place::<i32>::new("cfg");
        let t = Transition::builder("t1")
            .input(one(&p1))
            .output(out_place(&p2))
            .read(read(&cfg))
            .build();
        let net = PetriNet::builder("test").transition(t).build();

        let graph = map_to_graph(&net, &DotConfig::default());
        let read_edge = graph.edges.iter().find(|e| e.arc_type.as_deref() == Some("read")).unwrap();
        assert_eq!(read_edge.label.as_deref(), Some("read"));
    }

    #[test]
    fn reset_edge_has_label_and_penwidth() {
        use libpetri_core::arc::reset;

        let p1 = Place::<i32>::new("p1");
        let p2 = Place::<i32>::new("p2");
        let cache = Place::<i32>::new("cache");
        let t = Transition::builder("t1")
            .input(one(&p1))
            .output(out_place(&p2))
            .reset(reset(&cache))
            .build();
        let net = PetriNet::builder("test").transition(t).build();

        let graph = map_to_graph(&net, &DotConfig::default());
        let reset_edge = graph.edges.iter().find(|e| e.arc_type.as_deref() == Some("reset")).unwrap();
        assert_eq!(reset_edge.label.as_deref(), Some("reset"));
        assert_eq!(reset_edge.penwidth, Some(2.0));
    }

    #[test]
    fn reset_on_output_place_is_not_drawn() {
        use libpetri_core::arc::reset;

        let p1 = Place::<i32>::new("p1");
        let p2 = Place::<i32>::new("p2");
        let t = Transition::builder("t1")
            .input(one(&p1))
            .output(out_place(&p2))
            .reset(reset(&p2))
            .build();
        let net = PetriNet::builder("test").transition(t).build();

        let graph = map_to_graph(&net, &DotConfig::default());

        // The output edge already connects t1 -> p2, so the reset arc is suppressed.
        assert!(graph.edges.iter().all(|e| e.arc_type.as_deref() != Some("reset")));
        // Only the input and output edges remain.
        assert_eq!(graph.edges.len(), 2);
    }
}