use ling_ast::ast::{BinOp, UnOp};
use ling_mir::ir::*;
use std::collections::{HashMap, HashSet};
/// Result of the number/bool inference pass: for each function (keyed by name), the
/// local slot indices inferred to hold numbers and those inferred to hold booleans.
#[derive(Default)]
pub struct NumberTypes {
    /// Function name -> slot indices inferred to hold numbers.
    locals: HashMap<String, HashSet<usize>>,
    /// Function name -> slot indices inferred to hold booleans.
    bools: HashMap<String, HashSet<usize>>,
}
impl NumberTypes {
    /// Returns `true` when slot `local` of function `func` was inferred to hold a number.
    ///
    /// A function with no recorded entry reports `false` for every slot.
    pub fn local_is_num(&self, func: &str, local: usize) -> bool {
        let Some(numeric) = self.locals.get(func) else {
            return false;
        };
        numeric.contains(&local)
    }

    /// Returns `true` when `op` is an `I64`/`F64` constant, or a copied/moved local of
    /// `func` that was inferred to hold a number.
    pub fn operand_is_num(&self, func: &str, op: &Operand) -> bool {
        match op {
            Operand::Constant(Constant::I64(_) | Constant::F64(_)) => true,
            Operand::Constant(_) => false,
            Operand::Copy(place) | Operand::Move(place) => self.local_is_num(func, place.0),
        }
    }

    /// Returns `true` when `op` is a `Bool` constant, or a copied/moved local of `func`
    /// that was inferred to hold a boolean.
    pub fn operand_is_bool(&self, func: &str, op: &Operand) -> bool {
        let place = match op {
            Operand::Copy(place) | Operand::Move(place) => place,
            Operand::Constant(c) => return matches!(c, Constant::Bool(_)),
        };
        self.bools
            .get(func)
            .is_some_and(|flags| flags.contains(&place.0))
    }
}
/// Reports whether `op` is a comparison or a logical connective; the bool inference
/// treats the result of any such operator as a boolean.
fn bool_binop(op: &BinOp) -> bool {
    let is_comparison = matches!(
        op,
        BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge
    );
    let is_logical = matches!(op, BinOp::And | BinOp::Or);
    is_comparison || is_logical
}
/// Reports whether `op` is one of the arithmetic operators (`+ - * / %`) whose result
/// the numeric inference treats as numeric when both operands are numeric.
fn arith_binop(op: &BinOp) -> bool {
    let is_additive = matches!(op, BinOp::Add | BinOp::Sub);
    let is_multiplicative = matches!(op, BinOp::Mul | BinOp::Div | BinOp::Rem);
    is_additive || is_multiplicative
}
/// Infers, for every function, which local slots hold numbers and which hold booleans.
///
/// Numeric inference is a whole-program fixpoint that starts optimistic (every slot of
/// every function is assumed numeric). It re-derives each function's set until no set
/// changes. Slots `1..=arg_count` are treated as parameters, and slot 0 is what
/// `rvalue_is_num` reads as a callee's return value.
///
/// - A parameter slot is numeric only when its function is never used as a value and
///   every direct call site passes a numeric operand in that position.
/// - Any other slot is numeric only when it has at least one writer and every writer is
///   numeric.
///
/// Boolean inference runs per function and starts from the empty set. It adds a slot once
/// every one of its writers produces a boolean. A slot with no writers is never boolean.
pub fn analyze(functions: &[MirFunction]) -> NumberTypes {
    let by_name: HashMap<&str, &MirFunction> =
        functions.iter().map(|f| (f.name.as_str(), f)).collect();
    let (call_sites, address_taken) = collect_call_info(functions);
    // Writers depend only on the MIR, never on the inference state, so build them once
    // instead of on every fixpoint iteration.
    let writers: Vec<HashMap<usize, Vec<&Rvalue>>> =
        functions.iter().map(collect_writers).collect();

    // Optimistic start: every slot of every function is assumed numeric.
    let mut state: HashMap<String, HashSet<usize>> = functions
        .iter()
        .map(|f| {
            let all: HashSet<usize> = (0..local_slot_count(f)).collect();
            (f.name.clone(), all)
        })
        .collect();

    let mut changed = true;
    while changed {
        changed = false;
        // Parameter flags come from a snapshot of `state` taken at the start of the pass.
        // Per-function sets are updated in place below, so later functions in the same
        // pass already see their callees' refreshed sets.
        let param_num = param_numeric_flags(functions, &state, &call_sites, &address_taken);
        for (func, func_writers) in functions.iter().zip(&writers) {
            let pnums = &param_num[&func.name];
            let new_set: HashSet<usize> = (0..local_slot_count(func))
                .filter(|&idx| {
                    if (1..=func.arg_count).contains(&idx) {
                        // Parameter: its type is decided by the call sites.
                        pnums[idx - 1]
                    } else {
                        // A slot with no writers is not provably numeric.
                        func_writers.get(&idx).is_some_and(|rvals| {
                            rvals
                                .iter()
                                .all(|r| rvalue_is_num(r, &state, &param_num, func, &by_name))
                        })
                    }
                })
                .collect();
            if state.get(&func.name) != Some(&new_set) {
                changed = true;
                state.insert(func.name.clone(), new_set);
            }
        }
    }

    let bools: HashMap<String, HashSet<usize>> = functions
        .iter()
        .zip(&writers)
        .map(|(func, func_writers)| (func.name.clone(), infer_bool_locals(func_writers)))
        .collect();
    NumberTypes { locals: state, bools }
}

/// Callee name -> every `(caller name, argument operands)` pair that calls it by name.
type CallSiteMap = HashMap<String, Vec<(String, Vec<Operand>)>>;

/// Number of slots analysed for `func`, computed as `locals.len() + arg_count + 1`.
/// NOTE(review): presumably `locals` excludes the return slot and the parameters — confirm.
fn local_slot_count(func: &MirFunction) -> usize {
    func.locals.len() + func.arg_count + 1
}

/// Yields `(slot index, rvalue)` for every assignment statement in `func`, in block order.
fn assignments(func: &MirFunction) -> impl Iterator<Item = (usize, &Rvalue)> + '_ {
    func.basic_blocks
        .iter()
        .flat_map(|bb| &bb.statements)
        .filter_map(|stmt| {
            if let StatementKind::Assign(l, rval) = &stmt.kind {
                Some((l.0, rval))
            } else {
                None
            }
        })
}

/// Groups every rvalue assigned in `func` by the slot it is written to.
fn collect_writers(func: &MirFunction) -> HashMap<usize, Vec<&Rvalue>> {
    let mut writers: HashMap<usize, Vec<&Rvalue>> = HashMap::new();
    for (idx, rval) in assignments(func) {
        writers.entry(idx).or_default().push(rval);
    }
    writers
}

/// Scans all assignments for direct calls and for function constants used as values.
///
/// Returns `(call_sites, address_taken)`. `address_taken` holds every function name
/// that is passed as a call argument or assigned via `Use`.
fn collect_call_info(functions: &[MirFunction]) -> (CallSiteMap, HashSet<String>) {
    let mut call_sites: CallSiteMap = HashMap::new();
    let mut address_taken: HashSet<String> = HashSet::new();
    for func in functions {
        for (_, rval) in assignments(func) {
            match rval {
                Rvalue::Call { func: callee, args } => {
                    if let Operand::Constant(Constant::Function(name)) = callee {
                        call_sites
                            .entry(name.clone())
                            .or_default()
                            .push((func.name.clone(), args.clone()));
                    }
                    // A function passed around as a value can be invoked indirectly, so
                    // its direct call sites no longer describe all of its callers.
                    address_taken.extend(args.iter().filter_map(|a| match a {
                        Operand::Constant(Constant::Function(n)) => Some(n.clone()),
                        _ => None,
                    }));
                },
                Rvalue::Use(Operand::Constant(Constant::Function(n))) => {
                    address_taken.insert(n.clone());
                },
                _ => {},
            }
        }
    }
    (call_sites, address_taken)
}

/// Returns `true` when `op` is an `I64`/`F64` constant or a local that `state` currently
/// marks numeric in function `func`.
fn operand_is_num_in(state: &HashMap<String, HashSet<usize>>, func: &str, op: &Operand) -> bool {
    match op {
        Operand::Copy(l) | Operand::Move(l) => state.get(func).is_some_and(|s| s.contains(&l.0)),
        Operand::Constant(c) => matches!(c, Constant::I64(_) | Constant::F64(_)),
    }
}

/// For each function, decides per parameter whether every direct call site passes a
/// numeric argument. Functions that are never called by name, or are ever used as a
/// value, get all-`false` flags.
fn param_numeric_flags(
    functions: &[MirFunction],
    state: &HashMap<String, HashSet<usize>>,
    call_sites: &CallSiteMap,
    address_taken: &HashSet<String>,
) -> HashMap<String, Vec<bool>> {
    functions
        .iter()
        .map(|func| {
            // Only functions never used as values have a complete, known set of callers.
            let sites = call_sites
                .get(&func.name)
                .filter(|_| !address_taken.contains(&func.name));
            let flags: Vec<bool> = (0..func.arg_count)
                .map(|j| {
                    sites.is_some_and(|list| {
                        list.iter().all(|(caller, args)| {
                            args.get(j).is_some_and(|a| operand_is_num_in(state, caller, a))
                        })
                    })
                })
                .collect();
            (func.name.clone(), flags)
        })
        .collect()
}

/// Least-fixpoint boolean inference for one function's `writers`.
///
/// A slot becomes boolean once every writer is a comparison/logical binary op, a `Not`,
/// a `Bool` constant, or a copy/move of a slot already known to be boolean.
fn infer_bool_locals(writers: &HashMap<usize, Vec<&Rvalue>>) -> HashSet<usize> {
    let mut set: HashSet<usize> = HashSet::new();
    let mut changed = true;
    while changed {
        changed = false;
        for (&idx, rvals) in writers {
            if set.contains(&idx) {
                continue;
            }
            let is_bool = rvals.iter().all(|r| match r {
                Rvalue::BinaryOp(op, _, _) => bool_binop(op),
                Rvalue::UnaryOp(UnOp::Not, _) => true,
                Rvalue::Use(Operand::Constant(Constant::Bool(_))) => true,
                Rvalue::Use(Operand::Copy(l)) | Rvalue::Use(Operand::Move(l)) => {
                    set.contains(&l.0)
                },
                _ => false,
            });
            if is_bool {
                set.insert(idx);
                changed = true;
            }
        }
    }
    set
}
/// Decides whether the value produced by `rval` is numeric under the current `state`.
///
/// Operands are judged in `func`'s scope: `I64`/`F64` constants and locals present in
/// `state[func.name]` count as numeric.
///
/// - A `Use` or a negation is numeric when its operand is.
/// - A binary op is numeric when it is arithmetic and both operands are numeric.
/// - A call is numeric only when the callee is a constant function name found in `by_name`
///   whose slot 0 is currently in `state`.
/// - Every other rvalue, including any other unary op, is non-numeric.
///
/// `_param_num` is accepted but not consulted.
/// NOTE(review): presumably reserved for parameter-aware reasoning — confirm.
fn rvalue_is_num(
    rval: &Rvalue,
    state: &HashMap<String, HashSet<usize>>,
    _param_num: &HashMap<String, Vec<bool>>,
    func: &MirFunction,
    by_name: &HashMap<&str, &MirFunction>,
) -> bool {
    let numeric_here = state.get(&func.name);
    let operand_numeric = |operand: &Operand| match operand {
        Operand::Constant(c) => matches!(c, Constant::I64(_) | Constant::F64(_)),
        Operand::Copy(local) | Operand::Move(local) => {
            numeric_here.is_some_and(|set| set.contains(&local.0))
        },
    };
    match rval {
        Rvalue::Use(operand) | Rvalue::UnaryOp(UnOp::Neg, operand) => operand_numeric(operand),
        Rvalue::BinaryOp(op, lhs, rhs) => {
            arith_binop(op) && operand_numeric(lhs) && operand_numeric(rhs)
        },
        Rvalue::Call {
            func: Operand::Constant(Constant::Function(name)),
            ..
        } => {
            by_name.contains_key(name.as_str())
                && state.get(name).is_some_and(|set| set.contains(&0))
        },
        _ => false,
    }
}