Skip to main content

alkahest_cas/plot/
dot.rs

//! Graphviz DOT emitter for symbolic expression DAGs.
//!
//! [`render_dot`] walks the expression rooted at a given `ExprId` and emits a
//! `digraph` in DOT format. Shared sub-expressions (the same `ExprId`) are
//! rendered as a single node with multiple incoming edges, faithfully
//! representing the DAG structure.
//!
//! Pipe the output through `dot -Tpng -o graph.png` or `dot -Tsvg` to render.
9use crate::kernel::expr::ExprData;
10use crate::kernel::{ExprId, ExprPool};
11use std::collections::HashSet;
12
13/// Emit a Graphviz DOT string for the expression DAG rooted at `expr`.
14pub fn render_dot(pool: &ExprPool, expr: ExprId) -> String {
15    let mut out =
16        String::from("digraph expr {\n  node [shape=box fontname=\"Courier\" fontsize=10];\n");
17    let mut visited = HashSet::new();
18    emit_node(pool, expr, &mut out, &mut visited);
19    out.push_str("}\n");
20    out
21}
22
23fn node_id(id: ExprId) -> String {
24    format!("n{}", id.0)
25}
26
27fn node_label(pool: &ExprPool, id: ExprId) -> String {
28    match pool.get(id) {
29        ExprData::Symbol { name, .. } => format!("sym\\n{}", escape_dot(&name)),
30        ExprData::Integer(n) => format!("int\\n{}", n),
31        ExprData::Rational(r) => format!("rat\\n{}", r),
32        ExprData::Float(f) => format!("float\\n{}", f),
33        ExprData::Add(_) => "Add".to_string(),
34        ExprData::Mul(_) => "Mul".to_string(),
35        ExprData::Pow { .. } => "Pow".to_string(),
36        ExprData::Func { name, .. } => format!("fn\\n{}", escape_dot(&name)),
37        ExprData::Piecewise { .. } => "Piecewise".to_string(),
38        ExprData::Predicate { kind, .. } => format!("pred\\n{}", kind),
39        ExprData::Forall { .. } => "Forall".to_string(),
40        ExprData::Exists { .. } => "Exists".to_string(),
41        ExprData::BigO(_) => "BigO".to_string(),
42        ExprData::RootSum { .. } => "RootSum".to_string(),
43    }
44}
45
/// Escape `s` for use inside a double-quoted DOT string.
///
/// Every backslash becomes `\\` and every double quote becomes `\"`; all
/// other characters are copied unchanged.
fn escape_dot(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '"') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
49
50fn emit_node(pool: &ExprPool, id: ExprId, out: &mut String, visited: &mut HashSet<u32>) {
51    if !visited.insert(id.0) {
52        return;
53    }
54    let label = node_label(pool, id);
55    out.push_str(&format!("  {} [label=\"{}\"];\n", node_id(id), label));
56
57    let children: Vec<ExprId> = match pool.get(id) {
58        ExprData::Add(kids) | ExprData::Mul(kids) => kids,
59        ExprData::Pow { base, exp } => vec![base, exp],
60        ExprData::Func { args, .. } => args,
61        ExprData::Piecewise { branches, default } => {
62            let mut v: Vec<ExprId> = branches.iter().flat_map(|(c, v)| [*c, *v]).collect();
63            v.push(default);
64            v
65        }
66        ExprData::Predicate { args, .. } => args,
67        ExprData::Forall { var, body } | ExprData::Exists { var, body } => vec![var, body],
68        ExprData::BigO(inner) => vec![inner],
69        _ => vec![],
70    };
71
72    for child in &children {
73        emit_node(pool, *child, out, visited);
74        out.push_str(&format!("  {} -> {};\n", node_id(id), node_id(*child)));
75    }
76}