use crate::ast::{Ast, AstNode};
use crate::error::{Error, Result};
use hash_chain::ChainMap;
use itertools::{Itertools, Position};
use miette::NamedSource;
use petgraph::stable_graph::{NodeIndex, StableDiGraph};
use petgraph::visit::{EdgeRef, IntoNodeReferences};
use petgraph::EdgeDirection;
use std::collections::HashMap;
use std::{cell::RefCell, rc::Rc};
/// Kind of a vertex in the control-flow [`Graph`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum GraphNodeType {
/// Structural placeholder used while the graph is being built; the cleanup
/// passes in `from_ast` prune or splice these out.
Dummy,
/// Unique entry node, created by `GraphContext::new`.
Begin,
/// Unique exit node; `return` statements are wired straight to it.
End,
/// A plain statement holding its source text (also used for `break`,
/// `continue`, `return`, and a `for` loop's init/update parts).
Node(String),
/// A decision point holding the condition text; its outgoing edges are
/// `EdgeType::Branch(true)` and `EdgeType::Branch(false)`.
Choice(String),
}
/// Label carried by a control-flow edge.
#[derive(Debug, Clone, Copy)]
pub enum EdgeType {
/// Unconditional transfer of control.
Normal,
/// Outgoing edge of a `GraphNodeType::Choice`, taken when the condition
/// evaluates to the carried value.
Branch(bool),
}
/// Control-flow graph. The stable variant keeps node indices valid across
/// removals, which the cleanup passes rely on when they collect indices
/// first and delete afterwards.
pub type Graph = StableDiGraph<GraphNodeType, EdgeType>;
/// Mutable state threaded through `build_graph` while lowering an AST.
struct GraphContext {
/// Graph under construction.
pub graph: Graph,
/// Node a `break` jumps to; `None` where `break` is not allowed.
pub break_target: Option<NodeIndex>,
/// Node a `continue` jumps to; `None` where `continue` is not allowed.
pub continue_target: Option<NodeIndex>,
/// Label name -> placeholder node that the labelled statement is entered from.
/// `switch` pushes a child scope holding its case labels.
pub goto_target: ChainMap<String, NodeIndex>,
/// Global entry node.
// Stored for symmetry with `global_end` but not read by the code shown here,
// hence the lint allowance.
#[allow(dead_code)]
pub global_begin: NodeIndex,
/// Global exit node; `return` statements are wired to it.
pub global_end: NodeIndex,
/// Entry node of the region the current AST node is being wired into.
pub local_source: NodeIndex,
/// Exit node of the region the current AST node is being wired into.
pub local_sink: NodeIndex,
}
impl GraphContext {
    /// Creates a context whose graph contains only the global `Begin` and
    /// `End` nodes. The initial local region spans exactly those two nodes,
    /// and no `break`/`continue` target or label is registered yet.
    fn new() -> GraphContext {
        let mut graph = Graph::new();
        // `Begin` is added before `End` so their indices match the original layout.
        let global_begin = graph.add_node(GraphNodeType::Begin);
        let global_end = graph.add_node(GraphNodeType::End);
        Self {
            local_source: global_begin,
            local_sink: global_end,
            global_begin,
            global_end,
            break_target: None,
            continue_target: None,
            goto_target: ChainMap::new(HashMap::new()),
            graph,
        }
    }
}
fn build_graph(ast: &Ast, context: &mut GraphContext, source: &str, file_name: &str) -> Result<()> {
let local_source = context.local_source;
let local_sink = context.local_sink;
let break_target = context.break_target;
let continue_target = context.continue_target;
if let Some(labels) = &ast.label {
for i in labels {
if let Some(v) = context.goto_target.get(i) {
context.graph.add_edge(*v, local_source, EdgeType::Normal);
} else {
let v = context.graph.add_node(GraphNodeType::Dummy);
context.goto_target.insert_at(0, i.clone(), v)?;
context.graph.add_edge(v, local_source, EdgeType::Normal);
}
}
}
match &ast.node {
AstNode::Dummy => {
return Err(Error::UnexpectedDummyAstNode {
src: NamedSource::new(file_name, source.to_string()),
range: ast.range.clone().into(),
})
}
AstNode::Compound(v) => {
let mut sub_source = context.graph.add_node(GraphNodeType::Dummy);
let mut sub_sink = context.graph.add_node(GraphNodeType::Dummy);
context
.graph
.add_edge(local_source, sub_source, EdgeType::Normal);
if v.is_empty() {
context
.graph
.add_edge(sub_source, sub_sink, EdgeType::Normal);
} else {
for (pos, i) in v.iter().with_position() {
context.local_source = sub_source;
context.local_sink = sub_sink;
build_graph(&i.borrow(), context, source, file_name)?;
match pos {
itertools::Position::First | itertools::Position::Middle => {
sub_source = sub_sink;
sub_sink = context.graph.add_node(GraphNodeType::Dummy);
}
_ => {}
}
}
}
context
.graph
.add_edge(sub_sink, local_sink, EdgeType::Normal);
context.local_source = local_source;
context.local_sink = local_sink;
}
AstNode::Stat(s) => {
let current = context.graph.add_node(GraphNodeType::Node(s.clone()));
context
.graph
.add_edge(local_source, current, EdgeType::Normal);
context
.graph
.add_edge(current, local_sink, EdgeType::Normal);
}
AstNode::Continue(s) => {
let current = context.graph.add_node(GraphNodeType::Node(s.clone()));
context
.graph
.add_edge(local_source, current, EdgeType::Normal);
context.graph.add_edge(
current,
context.continue_target.ok_or(Error::UnexpectedContinue {
src: NamedSource::new(file_name, source.to_string()),
range: ast.range.clone().into(),
})?,
EdgeType::Normal,
);
}
AstNode::Break(s) => {
let current = context.graph.add_node(GraphNodeType::Node(s.clone()));
context
.graph
.add_edge(local_source, current, EdgeType::Normal);
context.graph.add_edge(
current,
context.break_target.ok_or(Error::UnexpectedBreak {
src: NamedSource::new(file_name, source.to_string()),
range: ast.range.clone().into(),
})?,
EdgeType::Normal,
);
}
AstNode::Return(s) => {
let current = context.graph.add_node(GraphNodeType::Node(s.clone()));
context
.graph
.add_edge(local_source, current, EdgeType::Normal);
context
.graph
.add_edge(current, context.global_end, EdgeType::Normal);
}
AstNode::If {
cond,
body,
otherwise,
} => {
let cond = context.graph.add_node(GraphNodeType::Choice(cond.clone()));
let sub_source = context.graph.add_node(GraphNodeType::Dummy);
let sub_sink = context.graph.add_node(GraphNodeType::Dummy);
context.graph.add_edge(local_source, cond, EdgeType::Normal);
context
.graph
.add_edge(cond, sub_source, EdgeType::Branch(true));
context
.graph
.add_edge(sub_sink, local_sink, EdgeType::Normal);
context.local_source = sub_source;
context.local_sink = sub_sink;
build_graph(&body.borrow(), context, source, file_name)?;
context.local_source = local_source;
context.local_sink = local_sink;
if let Some(t) = otherwise {
let sub_source1 = context.graph.add_node(GraphNodeType::Dummy);
context
.graph
.add_edge(cond, sub_source1, EdgeType::Branch(false));
context.local_source = sub_source1;
context.local_sink = sub_sink;
build_graph(&t.borrow(), context, source, file_name)?;
context.local_source = local_source;
context.local_sink = local_sink;
} else {
context
.graph
.add_edge(cond, local_sink, EdgeType::Branch(false));
}
}
AstNode::While { cond, body } => {
let cond = context.graph.add_node(GraphNodeType::Choice(cond.clone()));
let sub_source = context.graph.add_node(GraphNodeType::Dummy);
let sub_sink = context.graph.add_node(GraphNodeType::Dummy);
context.graph.add_edge(local_source, cond, EdgeType::Normal);
context
.graph
.add_edge(cond, sub_source, EdgeType::Branch(true));
context
.graph
.add_edge(cond, local_sink, EdgeType::Branch(false));
context.graph.add_edge(sub_sink, cond, EdgeType::Normal);
context.continue_target = Some(cond);
context.break_target = Some(local_sink);
context.local_source = sub_source;
context.local_sink = sub_sink;
build_graph(&body.borrow(), context, source, file_name)?;
context.continue_target = continue_target;
context.break_target = break_target;
context.local_source = local_source;
context.local_sink = local_sink;
}
AstNode::DoWhile { cond, body } => {
let sub_source = context.graph.add_node(GraphNodeType::Dummy);
let sub_sink = context.graph.add_node(GraphNodeType::Dummy);
let cond = context.graph.add_node(GraphNodeType::Choice(cond.clone()));
context
.graph
.add_edge(local_source, sub_source, EdgeType::Normal);
context.graph.add_edge(sub_sink, cond, EdgeType::Normal);
context
.graph
.add_edge(cond, sub_source, EdgeType::Branch(true));
context
.graph
.add_edge(cond, local_sink, EdgeType::Branch(false));
context.continue_target = Some(cond);
context.break_target = Some(local_sink);
context.local_source = sub_source;
context.local_sink = sub_sink;
build_graph(&body.borrow(), context, source, file_name)?;
context.continue_target = continue_target;
context.break_target = break_target;
context.local_source = local_source;
context.local_sink = local_sink;
}
AstNode::For {
init,
cond,
upd,
body,
} => {
let sub_source = context.graph.add_node(GraphNodeType::Dummy);
let sub_sink = context.graph.add_node(GraphNodeType::Dummy);
let cond = context.graph.add_node(GraphNodeType::Choice(cond.clone()));
let init = context.graph.add_node(GraphNodeType::Node(init.clone()));
let upd = context.graph.add_node(GraphNodeType::Node(upd.clone()));
context.graph.add_edge(local_source, init, EdgeType::Normal);
context.graph.add_edge(init, cond, EdgeType::Normal);
context
.graph
.add_edge(cond, sub_source, EdgeType::Branch(true));
context
.graph
.add_edge(cond, local_sink, EdgeType::Branch(false));
context.graph.add_edge(sub_sink, upd, EdgeType::Normal);
context.graph.add_edge(upd, cond, EdgeType::Normal);
context.continue_target = Some(upd);
context.break_target = Some(local_sink);
context.local_source = sub_source;
context.local_sink = sub_sink;
build_graph(&body.borrow(), context, source, file_name)?;
context.continue_target = continue_target;
context.break_target = break_target;
context.local_source = local_source;
context.local_sink = local_sink;
}
AstNode::Switch { cond, body, cases } => {
let case_goto_targets: HashMap<String, NodeIndex> = cases
.iter()
.map(|c| (c.clone(), context.graph.add_node(GraphNodeType::Dummy)))
.collect();
let table_start = generate_jump_table(
cond,
&mut context.graph,
&mut cases.iter().filter(|x| *x != "default").with_position(),
&case_goto_targets,
&cases.iter().any(|x| x == "default"),
&local_sink,
);
context
.graph
.add_edge(local_source, table_start, EdgeType::Normal);
let sub_source = context.graph.add_node(GraphNodeType::Dummy);
let sub_sink = context.graph.add_node(GraphNodeType::Dummy);
context.goto_target.new_child_with(case_goto_targets);
context.local_source = sub_source;
context.local_sink = sub_sink;
context.break_target = Some(local_sink);
context.continue_target = None;
context
.graph
.add_edge(sub_sink, local_sink, EdgeType::Normal);
build_graph(&body.borrow(), context, source, file_name)?;
context.local_source = local_source;
context.local_sink = local_sink;
context.break_target = break_target;
context.continue_target = continue_target;
context.goto_target.remove_child();
}
AstNode::Goto(t) => {
if let Some(target) = context.goto_target.get(t) {
context
.graph
.add_edge(local_source, *target, EdgeType::Normal);
} else {
let v = context.graph.add_node(GraphNodeType::Dummy);
context.goto_target.insert_at(0, t.clone(), v)?;
context.graph.add_edge(local_source, v, EdgeType::Normal);
}
}
}
Ok(())
}
fn generate_jump_table<'a, I, R>(
cond: &str,
graph: &mut Graph,
iter: &mut I,
case_goto_targets: &HashMap<String, NodeIndex>,
has_default: &bool,
sink: &NodeIndex,
) -> NodeIndex
where
I: Itertools<Item = (Position, R)>,
R: AsRef<str>,
{
if let Some((pos, i)) = iter.next() {
let cur = graph.add_node(GraphNodeType::Choice(format!("{} == {}", cond, i.as_ref())));
graph.add_edge(cur, case_goto_targets[i.as_ref()], EdgeType::Branch(true));
match pos {
itertools::Position::First | itertools::Position::Middle => {
let idx =
generate_jump_table(cond, graph, iter, case_goto_targets, has_default, sink);
graph.add_edge(cur, idx, EdgeType::Branch(false));
}
itertools::Position::Last | itertools::Position::Only => {
if *has_default {
graph.add_edge(cur, case_goto_targets["default"], EdgeType::Branch(false));
} else {
graph.add_edge(cur, *sink, EdgeType::Branch(false));
}
}
};
cur
} else {
let cur = graph.add_node(GraphNodeType::Dummy);
if *has_default {
graph.add_edge(cur, case_goto_targets["default"], EdgeType::Normal);
} else {
graph.add_edge(cur, *sink, EdgeType::Normal);
}
cur
}
}
/// Removes every `Dummy` node that currently has no incoming edges.
///
/// Returns `true` if at least one node was removed. Removing a node drops its
/// outgoing edges too, which can orphan further placeholders, so callers run
/// this until it returns `false`.
///
/// `_source` is unused.
fn remove_zero_in_degree_nodes(graph: &mut Graph, _source: &str) -> bool {
    let orphans: Vec<_> = graph
        .node_indices()
        .filter(|&i| {
            graph.node_weight(i) == Some(&GraphNodeType::Dummy)
                && graph
                    .edges_directed(i, EdgeDirection::Incoming)
                    .next()
                    .is_none()
        })
        .collect();
    // Remove all candidates. `.map(remove_node).any(..)` would short-circuit after
    // the first removal, forcing the caller to rescan the whole graph once per node.
    let mut removed_any = false;
    for node in orphans {
        removed_any |= graph.remove_node(node).is_some();
    }
    removed_any
}
fn remove_single_node<F>(graph: &mut Graph, _source: &str, predicate: F) -> Result<bool>
where
F: Fn(NodeIndex, &GraphNodeType) -> bool,
{
if let Some(node_index) = graph
.node_references()
.filter(|(x, t)| predicate(*x, t))
.map(|(x, _)| x)
.take(1)
.next()
{
let incoming_edges = graph
.edges_directed(node_index, EdgeDirection::Incoming)
.map(|x| (x.source(), *x.weight()))
.collect_vec();
let neighbors = graph
.neighbors_directed(node_index, EdgeDirection::Outgoing)
.collect_vec();
if neighbors.len() != 1 {
return Err(Error::UnexpectedOutgoingEdges {
node_index,
neighbors,
graph: graph.clone(),
});
}
let next_node = neighbors[0];
for (src, edge_type) in incoming_edges {
graph.add_edge(src, next_node, edge_type);
}
graph.remove_node(node_index);
Ok(true)
} else {
Ok(false)
}
}
/// Converts a parsed AST into a simplified control-flow [`Graph`].
///
/// After lowering with `build_graph`, three cleanup passes each run until
/// nothing changes:
///
/// 1. drop `Dummy` nodes that have no incoming edges;
/// 2. splice out every remaining `Dummy` node;
/// 3. splice out `Node`s whose text is empty or exactly `;` after trimming.
///
/// # Errors
///
/// Propagates any error from graph construction or from node splicing.
pub fn from_ast(ast: Rc<RefCell<Ast>>, source: &str, file_name: &str) -> Result<Graph> {
    let mut context = GraphContext::new();
    build_graph(&ast.borrow(), &mut context, source, file_name)?;
    let mut graph = context.graph;

    let is_placeholder =
        |_: NodeIndex, kind: &GraphNodeType| matches!(kind, GraphNodeType::Dummy);
    let is_empty_statement = |_: NodeIndex, kind: &GraphNodeType| {
        matches!(kind, GraphNodeType::Node(text) if text.is_empty() || text.trim() == ";")
    };

    // Pass 1: prune orphaned placeholders until a call removes nothing.
    let mut pruned = true;
    while pruned {
        pruned = remove_zero_in_degree_nodes(&mut graph, source);
    }
    // Passes 2 and 3: splice matching nodes out one at a time.
    while remove_single_node(&mut graph, source, is_placeholder)? {}
    while remove_single_node(&mut graph, source, is_empty_statement)? {}
    Ok(graph)
}