use anyhow::Result;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use crate::config::Config;
use crate::mir::Mir;
use crate::node::Node;
use crate::statement::Statement;
use crate::symbol::{Loc, Symbol, SymbolTable};
/// A sequence of statements plus the bookkeeping used while building and
/// post-processing it (symbol table, temporary counter, call cache, config).
#[derive(Debug, Clone)]
pub struct Block {
/// Statements appended so far, in the order they were pushed.
pub stmts: Vec<Statement>,
/// Symbol table that `create_mem` and `create_tmp_named` register names in.
pub sym_table: SymbolTable,
/// Count of auto-named temporaries handed out; the next one is named `ψ{num_tmp}`.
pub num_tmp: usize,
/// Cache of broken-out calls, keyed by `(op name, operand hash)`; the value
/// is the temporary that receives the call's result.
pub calls: HashMap<(String, u64), Node>,
/// Configuration queried for intrinsic checks, CSE, and the scratch count.
pub config: Config,
}
impl Block {
/// Builds an empty block that owns `config`.
///
/// The symbol table is created with `config.is_complex()`; the statement
/// list and call cache start empty and the temporary counter starts at 0.
pub fn new(config: Config) -> Block {
    // Query the config before it is moved into the struct below.
    let sym_table = SymbolTable::new(config.is_complex());
    Block {
        config,
        sym_table,
        calls: HashMap::new(),
        stmts: Vec::new(),
        num_tmp: 0,
    }
}
/// Drops all statements, all symbols, and all cached call results.
///
/// `num_tmp` is deliberately left alone, so temporaries created after a
/// `clear` keep receiving names that were never used before.
pub fn clear(&mut self) {
    self.stmts.clear();
    self.sym_table.syms.clear();
    // Every cached entry names a temporary that was defined by a `Call`
    // statement removed above. Keeping the cache would let a later
    // `break_call_unary` / `break_call_binary` (with CSE enabled) return that
    // temporary even though nothing assigns it any more.
    self.calls.clear();
}
/// Appends a `Statement::Label` carrying an owned copy of `label`.
pub fn add_label(&mut self, label: &str) {
    let label = String::from(label);
    self.stmts.push(Statement::Label { label });
}
/// Appends a `Statement::Branch` targeting `label`.
pub fn add_branch(&mut self, label: &str) {
    let label = String::from(label);
    self.stmts.push(Statement::Branch { label });
}
/// Appends a `Statement::BranchIf` built from `cond`, `label` and `is_else`.
///
/// `cond` is stored as given; unlike `add_assign` it is not passed through
/// `process` first.
pub fn add_branch_if(&mut self, cond: Node, label: &str, is_else: bool) {
    let label = String::from(label);
    let stmt = Statement::BranchIf {
        cond,
        label,
        is_else,
    };
    self.stmts.push(stmt);
}
/// Appends `lhs = rhs` after running `rhs` through `process`.
///
/// Processing may itself push statements (call breakouts, spills); those end
/// up before the assignment because they are emitted first.
pub fn add_assign(&mut self, lhs: Node, rhs: Node) {
    let processed = self.process(rhs);
    let stmt = Statement::assign(lhs, processed);
    self.stmts.push(stmt);
}
/// Compiles every statement into `ir`, in order.
///
/// # Errors
///
/// Stops at, and returns, the first error reported by a statement's
/// `compile`; later statements are not visited.
pub fn compile(&mut self, ir: &mut Mir) -> Result<()> {
    self.stmts.iter_mut().try_for_each(|stmt| {
        stmt.compile(ir)?;
        Ok(())
    })
}
/// Registers `name` in the symbol table through `SymbolTable::add_mem`.
pub fn create_mem(&mut self, name: &str) {
self.sym_table.add_mem(name);
}
/// Creates a fresh temporary named `ψN`, where `N` is the current value of
/// `num_tmp`, and bumps the counter.
pub fn create_tmp(&mut self) -> Node {
    let index = self.num_tmp;
    self.num_tmp = index + 1;
    self.create_tmp_named(&format!("ψ{}", index))
}
/// Registers `name` through `SymbolTable::add_stack` and returns a
/// `Node::Var` referring to the resulting symbol.
///
/// # Panics
///
/// Panics if the symbol table cannot resolve `name` right after it was
/// added, which would indicate a broken symbol-table invariant.
pub fn create_tmp_named(&mut self, name: &str) -> Node {
    self.sym_table.add_stack(name);
    let sym = self
        .sym_table
        .find_sym(name)
        .expect("symbol registered by add_stack must be resolvable by find_sym");
    Node::Var { sym }
}
/// Returns whether the symbol table reports that it contains `name`.
pub fn var_exists(&self, name: &str) -> bool {
self.sym_table.contains(name)
}
/// Returns the node produced by `Node::create_void`; block state is not touched.
pub fn create_void(&mut self) -> Node {
Node::create_void()
}
/// Builds a constant node by forwarding `val` and `idx` to `Node::create_const`.
// NOTE(review): the meaning of `idx` is defined by `Node`, not visible here.
pub fn create_const(&mut self, val: f64, idx: u32) -> Node {
Node::create_const(val, idx)
}
/// Builds a variable node for `sym` by forwarding to `Node::create_var`.
pub fn create_var(&mut self, sym: Rc<RefCell<Symbol>>) -> Node {
Node::create_var(sym)
}
/// Builds a unary node `op(arg)` via `Node::create_unary`, with the power
/// argument fixed to 1.
pub fn create_unary(&mut self, op: &str, arg: Node) -> Node {
Node::create_unary(op, arg, 1)
}
/// Builds a binary node `op(left, right)` via `Node::create_binary`, with the
/// power argument fixed to 1 and no condition location (`None`).
pub fn create_binary(&mut self, op: &str, left: Node, right: Node) -> Node {
Node::create_binary(op, left, right, 1, None)
}
/// Forwards `arg` and `power` to `Node::create_powi` and returns the result.
pub fn create_powi(&mut self, arg: Node, power: i32) -> Node {
Node::create_powi(arg, power)
}
/// Forwards `left`, `right` and `power` to `Node::create_modular_powi` and
/// returns the result.
// NOTE(review): which operand is the base and which the modulus is decided by
// `Node::create_modular_powi` — confirm there before relying on the order.
pub fn create_modular_powi(&mut self, left: Node, right: Node, power: i32) -> Node {
Node::create_modular_powi(left, right, power)
}
/// Builds an if/else node whose condition is first stored in a temporary.
///
/// A fresh temporary is created, `cond` is assigned to it (through
/// `add_assign`, so `cond` is processed), and the temporary is then handed to
/// `Node::create_ifelse` together with `left` and `right`.
pub fn create_ifelse(&mut self, cond: Node, left: Node, right: Node) -> Node {
    let guard = self.create_tmp();
    let target = guard.clone();
    self.add_assign(target, cond);
    Node::create_ifelse(&guard, left, right)
}
/// Prepares `node` for use in a statement; currently this is exactly `trim`.
fn process(&mut self, node: Node) -> Node {
self.trim(node)
}
/// Recursively rewrites `node`: leaves (`Void`, `Const`, `Var`) are returned
/// unchanged, unary nodes go through `trim_unary`, and binary nodes go
/// through `trim_binary`.
fn trim(&mut self, node: Node) -> Node {
    match node {
        Node::Unary { op, arg, power, .. } => self.trim_unary(&op, *arg, power),
        Node::Binary {
            op,
            left,
            right,
            power,
            cond,
            ..
        } => self.trim_binary(&op, *left, *right, power, cond),
        // Leaves carry nothing to rewrite.
        leaf @ Node::Void | leaf @ Node::Const { .. } | leaf @ Node::Var { .. } => leaf,
    }
}
/// Trims the operand of a unary node, then either rebuilds the node (when
/// `op` is an intrinsic unary according to the config) or breaks it out into
/// a separate call statement via `break_call_unary`.
fn trim_unary(&mut self, op: &str, arg: Node, power: i32) -> Node {
    let arg = self.trim(arg);
    if self.config.is_intrinsic_unary(op) {
        Node::create_unary(op, arg, power)
    } else {
        // Note: `power` is not forwarded on the call path.
        self.break_call_unary(op, arg)
    }
}
/// Emits `tmp = op(arg)` as a one-argument `Statement::call` and returns `tmp`.
///
/// The result is cached in `calls` under `(op, arg.hashof())`. When CSE is
/// enabled and the key is already cached, the cached temporary is returned
/// and nothing is emitted.
fn break_call_unary(&mut self, op: &str, arg: Node) -> Node {
    let key = (op.to_string(), arg.hashof());
    if self.config.cse() {
        if let Some(cached) = self.calls.get(&key).cloned() {
            return cached;
        }
    }
    // The operand is wrapped in a `_call_` node before being attached to the
    // call statement.
    let packed = self.create_unary("_call_", arg);
    let result = self.create_tmp();
    let stmt = Statement::call(op, result.clone(), packed, 1);
    self.stmts.push(stmt);
    self.calls.insert(key, result.clone());
    result
}
/// Trims both operands of a binary node and rebuilds it.
///
/// Non-intrinsic operators are broken out through `break_call_binary`
/// (`power` and `cond` are not forwarded on that path). For intrinsic
/// operators, when both operands report an `ershov_number()` equal to
/// `count_scratch() - 1`, the right operand is first assigned to a fresh
/// temporary and that temporary is used in its place.
// NOTE(review): presumably the spill keeps evaluation within the available
// scratch registers — confirm against the code generator.
fn trim_binary(
    &mut self,
    op: &str,
    left: Node,
    right: Node,
    power: i32,
    cond: Option<Loc>,
) -> Node {
    let left = self.trim(left);
    let right = self.trim(right);
    if !self.config.is_intrinsic_binary(op) {
        return self.break_call_binary(op, left, right);
    }

    let threshold = self.config.count_scratch() - 1;
    let must_spill = left.ershov_number() == threshold && right.ershov_number() == threshold;
    let right = if must_spill {
        let spill = self.create_tmp();
        let stmt = Statement::assign(spill.clone(), right);
        self.stmts.push(stmt);
        spill
    } else {
        right
    };

    Node::create_binary(op, left, right, power, cond)
}
/// Emits `tmp = op(left, right)` as a two-argument `Statement::call` and
/// returns `tmp`.
///
/// The result is cached in `calls` under `(op, key)`, where `key` mixes the
/// operand hashes as `left ^ (right + 1)`; the `+ 1` keeps `op(a, a)` from
/// always collapsing to key 0. The key is computed from the operands as
/// passed in, before they are run through `process`. When CSE is enabled and
/// the key is already cached, the cached temporary is returned and nothing is
/// emitted.
pub fn break_call_binary(&mut self, op: &str, left: Node, right: Node) -> Node {
    // `wrapping_add` instead of `+`: a hash may legitimately be `u64::MAX`,
    // and plain addition would panic there when overflow checks are enabled
    // (the default in debug builds). With checks off the value is unchanged.
    let key = (
        op.to_string(),
        left.hashof() ^ right.hashof().wrapping_add(1),
    );
    if self.config.cse() {
        if let Some(cached) = self.calls.get(&key) {
            return cached.clone();
        }
    }
    let left = self.process(left);
    let right = self.process(right);
    // Both operands are packed into a single `_call_` node that is attached
    // to the call statement.
    let arg = self.create_binary("_call_", left, right);
    let lhs = self.create_tmp();
    self.stmts.push(Statement::call(op, lhs.clone(), arg, 2));
    self.calls.insert(key, lhs.clone());
    lhs
}
/// Repeats `elimination_pass` until a pass reports no change, running at
/// most five passes.
pub fn eliminate(&mut self) {
    let mut passes_left = 5;
    while passes_left > 0 && self.elimination_pass() {
        passes_left -= 1;
    }
}
/// Runs one round of common-subexpression elimination over the block.
///
/// The round has two phases:
///
/// 1. Scan: every `Assign` right-hand side and `Call` argument at depth 0 is
///    handed to `find_cse`, which records sub-expressions whose hash shows up
///    more than once and allocates a temporary for each.
/// 2. Rewrite: `self.stmts` is rebuilt. In depth-0 statements each recorded
///    sub-expression is replaced by its temporary, and the temporary's
///    defining assignment is emitted right before its first use.
///
/// "Depth" is a running balance: every `Label` adds one, every `Branch` /
/// `BranchIf` subtracts one. A `BranchIf` condition is rewritten when the
/// balance *before* that branch is 0.
///
/// Statements at non-zero depth are copied through unchanged. Rewriting them
/// could emit a temporary's definition inside such a region while a later
/// depth-0 statement reads the temporary, which is only correct if the region
/// is guaranteed to have executed.
// NOTE(review): non-zero depth presumably marks conditionally or repeatedly
// executed code — confirm against how labels and branches are emitted.
///
/// Returns `false` (leaving the statements untouched) when CSE is disabled in
/// the config or when no repeated sub-expression was found; `true` otherwise.
pub fn elimination_pass(&mut self) -> bool {
    if !self.config.cse() {
        return false;
    }

    let mut stmts = std::mem::take(&mut self.stmts);
    // Hashes seen at least once.
    let mut hs: HashSet<u64> = HashSet::new();
    // Hash -> (temporary, expression) for hashes seen more than once.
    let mut cs: HashMap<u64, (Node, Node)> = HashMap::new();

    let mut depth: i32 = 0;
    for s in stmts.iter_mut() {
        match s {
            Statement::Assign { rhs, .. } => {
                if depth == 0 {
                    self.find_cse(&mut hs, &mut cs, rhs);
                }
            }
            Statement::Call { arg, .. } => {
                if depth == 0 {
                    self.find_cse(&mut hs, &mut cs, arg);
                }
            }
            Statement::Label { .. } => {
                depth += 1;
            }
            Statement::Branch { .. } | Statement::BranchIf { .. } => {
                depth -= 1;
            }
        }
    }

    if cs.is_empty() {
        self.stmts = stmts;
        return false;
    }

    // Hashes of temporaries whose defining assignment has been emitted.
    let mut ls: HashSet<u64> = HashSet::new();
    // Same balance as in the scan, so both phases agree on what is depth 0.
    let mut depth: i32 = 0;
    for s in stmts {
        match s {
            Statement::Assign { lhs, rhs } => {
                let rhs = if depth == 0 {
                    self.rewrite_cse(&cs, &mut ls, rhs)
                } else {
                    rhs
                };
                self.stmts.push(Statement::Assign { lhs, rhs });
            }
            Statement::Call {
                op,
                lhs,
                arg,
                num_args,
            } => {
                let arg = if depth == 0 {
                    self.rewrite_cse(&cs, &mut ls, arg)
                } else {
                    arg
                };
                self.stmts.push(Statement::Call {
                    op,
                    lhs,
                    arg,
                    num_args,
                });
            }
            Statement::Label { label } => {
                depth += 1;
                self.stmts.push(Statement::Label { label });
            }
            Statement::Branch { label } => {
                depth -= 1;
                self.stmts.push(Statement::Branch { label });
            }
            Statement::BranchIf {
                cond,
                label,
                is_else,
            } => {
                // The condition is evaluated before the branch is taken, so
                // the balance before this statement decides.
                let cond = if depth == 0 {
                    self.rewrite_cse(&cs, &mut ls, cond)
                } else {
                    cond
                };
                depth -= 1;
                self.stmts.push(Statement::BranchIf {
                    cond,
                    label,
                    is_else,
                });
            }
        }
    }
    true
}
/// Walks `node` and its children looking for repeated sub-expressions.
///
/// A node is a candidate when `weightof() >= 5` and it is not a `_call_`
/// wrapper (unary or binary). The first time a candidate's hash is seen it is
/// remembered in `hs`; on a repeat sighting an entry `(fresh temporary, clone
/// of the node)` is added to `cs` unless one already exists for that hash.
/// Children (`first()`, `second()`) are visited regardless of whether the
/// node itself was a candidate.
fn find_cse(
    &mut self,
    hs: &mut HashSet<u64>,
    cs: &mut HashMap<u64, (Node, Node)>,
    node: &mut Node,
) {
    let is_candidate =
        node.weightof() >= 5 && !node.is_unary("_call_") && !node.is_binary("_call_");
    if is_candidate {
        let h = node.hashof();
        // `insert` returns false when the hash was already present, i.e. this
        // is at least the second sighting.
        if !hs.insert(h) {
            cs.entry(h)
                .or_insert_with(|| (self.create_tmp(), node.clone()));
        }
    }

    if let Some(child) = node.first() {
        self.find_cse(hs, cs, child);
    }
    if let Some(child) = node.second() {
        self.find_cse(hs, cs, child);
    }
}
/// Returns `node` with every sub-expression recorded in `cs` replaced by its
/// temporary.
///
/// Nodes with `weightof() < 5` and leaves are returned unchanged. For a unary
/// or binary node, `common_subexpr` is consulted with the node's stored hash
/// `h`: on a hit the temporary replaces the whole node (its children are not
/// visited); on a miss the children are rewritten recursively and the node is
/// rebuilt.
fn rewrite_cse(
    &mut self,
    cs: &HashMap<u64, (Node, Node)>,
    ls: &mut HashSet<u64>,
    node: Node,
) -> Node {
    if node.weightof() < 5 {
        return node;
    }

    match node {
        Node::Unary {
            op, arg, power, h, ..
        } => match self.common_subexpr(cs, ls, h) {
            Some(tmp) => tmp,
            None => {
                let arg = self.rewrite_cse(cs, ls, *arg);
                Node::create_unary(op.as_str(), arg, power)
            }
        },
        Node::Binary {
            op,
            left,
            right,
            power,
            cond,
            h,
            ..
        } => match self.common_subexpr(cs, ls, h) {
            Some(tmp) => tmp,
            None => {
                let left = self.rewrite_cse(cs, ls, *left);
                let right = self.rewrite_cse(cs, ls, *right);
                Node::create_binary(op.as_str(), left, right, power, cond)
            }
        },
        leaf @ Node::Void | leaf @ Node::Const { .. } | leaf @ Node::Var { .. } => leaf,
    }
}
/// Looks up hash `h` in `cs` and, on a hit, returns a clone of the associated
/// temporary.
///
/// The first time a given temporary is returned, its defining assignment
/// (`temporary = expression`) is pushed onto `self.stmts`; `ls` records the
/// temporary's own hash so the definition is emitted only once. Returns
/// `None` when `h` is not in `cs`.
fn common_subexpr(
    &mut self,
    cs: &HashMap<u64, (Node, Node)>,
    ls: &mut HashSet<u64>,
    h: u64,
) -> Option<Node> {
    let (tmp, expr) = cs.get(&h)?;
    // `insert` returns true only when the key was absent, i.e. the
    // definition has not been emitted yet.
    if ls.insert(tmp.hashof()) {
        let definition = Statement::assign(tmp.clone(), expr.clone());
        self.stmts.push(definition);
    }
    Some(tmp.clone())
}
}