use std::collections::BTreeMap;
use crate::error::*;
use crate::sir::*;
use fnv;
/// Evaluates the binary operation `kind` on two literal operands at compile time.
///
/// Supported operations are `Add`, `Subtract`, `Multiply` and `Divide` on pairs of `I8`, `I16`,
/// `I32`, `I64`, `F32` or `F64` literals of the same type. Float literals carry their raw bit
/// patterns, so they are decoded with `from_bits`, combined, and re-encoded with `to_bits`.
///
/// Integer `Add`, `Subtract` and `Multiply` use two's-complement wrapping arithmetic, so folding
/// an overflowing constant expression never panics the compiler regardless of build profile.
/// NOTE(review): wrapping is presumably what the generated runtime code does as well — confirm
/// against the code generator.
///
/// # Errors
///
/// Returns a compile error when:
/// - the operands are not two literals of the same supported type;
/// - `kind` is not one of the four supported operations;
/// - an integer division cannot be evaluated (division by zero, or `MIN / -1` overflow).
///
/// Callers may treat any error as "leave this expression unfolded".
fn evaluate_binop(
    kind: BinOpKind,
    left: LiteralKind,
    right: LiteralKind,
) -> WeldResult<LiteralKind> {
    use crate::ast::BinOpKind::*;
    use crate::ast::LiteralKind::*;
    let result = match kind {
        Add => match (left, right) {
            (I8Literal(l), I8Literal(r)) => I8Literal(l.wrapping_add(r)),
            (I16Literal(l), I16Literal(r)) => I16Literal(l.wrapping_add(r)),
            (I32Literal(l), I32Literal(r)) => I32Literal(l.wrapping_add(r)),
            (I64Literal(l), I64Literal(r)) => I64Literal(l.wrapping_add(r)),
            (F32Literal(l), F32Literal(r)) => {
                F32Literal((f32::from_bits(l) + f32::from_bits(r)).to_bits())
            }
            (F64Literal(l), F64Literal(r)) => {
                F64Literal((f64::from_bits(l) + f64::from_bits(r)).to_bits())
            }
            _ => {
                return compile_err!("Mismatched types in evaluate_binop");
            }
        },
        Subtract => match (left, right) {
            (I8Literal(l), I8Literal(r)) => I8Literal(l.wrapping_sub(r)),
            (I16Literal(l), I16Literal(r)) => I16Literal(l.wrapping_sub(r)),
            (I32Literal(l), I32Literal(r)) => I32Literal(l.wrapping_sub(r)),
            (I64Literal(l), I64Literal(r)) => I64Literal(l.wrapping_sub(r)),
            (F32Literal(l), F32Literal(r)) => {
                F32Literal((f32::from_bits(l) - f32::from_bits(r)).to_bits())
            }
            (F64Literal(l), F64Literal(r)) => {
                F64Literal((f64::from_bits(l) - f64::from_bits(r)).to_bits())
            }
            _ => {
                return compile_err!("Mismatched types in evaluate_binop");
            }
        },
        Multiply => match (left, right) {
            (I8Literal(l), I8Literal(r)) => I8Literal(l.wrapping_mul(r)),
            (I16Literal(l), I16Literal(r)) => I16Literal(l.wrapping_mul(r)),
            (I32Literal(l), I32Literal(r)) => I32Literal(l.wrapping_mul(r)),
            (I64Literal(l), I64Literal(r)) => I64Literal(l.wrapping_mul(r)),
            (F32Literal(l), F32Literal(r)) => {
                F32Literal((f32::from_bits(l) * f32::from_bits(r)).to_bits())
            }
            (F64Literal(l), F64Literal(r)) => {
                F64Literal((f64::from_bits(l) * f64::from_bits(r)).to_bits())
            }
            _ => {
                return compile_err!("Mismatched types in evaluate_binop");
            }
        },
        Divide => {
            // `checked_div` yields `None` for a zero divisor and for `MIN / -1`. Plain `/`
            // panics in both cases, so refuse to fold instead of crashing the compiler.
            let quotient = match (left, right) {
                (I8Literal(l), I8Literal(r)) => l.checked_div(r).map(I8Literal),
                (I16Literal(l), I16Literal(r)) => l.checked_div(r).map(I16Literal),
                (I32Literal(l), I32Literal(r)) => l.checked_div(r).map(I32Literal),
                (I64Literal(l), I64Literal(r)) => l.checked_div(r).map(I64Literal),
                (F32Literal(l), F32Literal(r)) => Some(F32Literal(
                    (f32::from_bits(l) / f32::from_bits(r)).to_bits(),
                )),
                (F64Literal(l), F64Literal(r)) => Some(F64Literal(
                    (f64::from_bits(l) / f64::from_bits(r)).to_bits(),
                )),
                _ => {
                    return compile_err!("Mismatched types in evaluate_binop");
                }
            };
            match quotient {
                Some(value) => value,
                None => {
                    return compile_err!(
                        "Integer division by zero or overflow in evaluate_binop"
                    );
                }
            }
        }
        _ => {
            return compile_err!("Unsupported binary operation in evaluate_binop");
        }
    };
    Ok(result)
}
pub fn fold_constants(prog: &mut SirProgram) -> WeldResult<()> {
let parameters = &mut fnv::FnvHashSet::default();
for func in prog.funcs.iter() {
parameters.extend(func.params.iter().map(&|(k, _)| k).cloned());
for block in func.blocks.iter() {
parameters.extend(block.terminator.children().cloned());
}
}
for func in prog.funcs.iter_mut() {
fold_constants_in_function(func, parameters)?;
}
Ok(())
}
/// Folds constant expressions in `func`, then removes the statements and locals the folding
/// left unused.
///
/// Phases:
/// 1. **Count.** For every symbol, count how many statements (across all blocks) write it.
/// 2. **Fold.** Walk blocks and statements in index order:
///    - an `AssignLiteral` to a single-assignment symbol records its value;
///    - an `Assign` from a symbol with a known value is replaced by that literal;
///    - a `BinOp` whose operands both have known values is evaluated with `evaluate_binop`.
///
///    A folded statement is rewritten as an `AssignLiteral`, and its output's value is recorded
///    only if that output is written exactly once.
/// 3. **Dead-statement elimination.** A statement with an output is kept only if that symbol is
///    in the "used" set or in `global_params`. Statements without an output are always kept.
///    The used set holds builder-typed outputs, loop variables, and every symbol read by a
///    statement that was not folded.
/// 4. **Locals.** `func.locals` is pruned with the same rule.
///
/// Values propagate only forward in block-index order: a use in a lower-indexed block than the
/// single definition is left unfolded.
///
/// # Parameters
/// - `func`: the function to rewrite in place.
/// - `global_params`: symbols that must never be removed from statements or locals.
///
/// # Errors
///
/// Propagates any error from `func.symbol_type`. A failed `evaluate_binop` is not an error:
/// that statement is simply left unfolded.
fn fold_constants_in_function(
    func: &mut SirFunction,
    global_params: &fnv::FnvHashSet<Symbol>,
) -> WeldResult<()> {
    use crate::sir::StatementKind::*;

    // Phase 1: how many statements write each symbol. Only single-assignment symbols may have
    // their literal value recorded, because only then is the value the same at every use.
    let mut assignment_counts: fnv::FnvHashMap<Symbol, i32> = fnv::FnvHashMap::default();
    for block in func.blocks.iter() {
        for statement in block.statements.iter() {
            if let Some(ref sym) = statement.output {
                *assignment_counts.entry(sym.clone()).or_insert(0) += 1;
            }
        }
    }

    // Known literal value of each single-assignment symbol, filled in during phase 2.
    let mut values: fnv::FnvHashMap<Symbol, LiteralKind> = fnv::FnvHashMap::default();

    // Symbols whose defining statements must survive phase 3. Builder outputs and loop
    // variables are always kept; reads by unfolded statements are added during phase 2.
    let mut used_symbols = fnv::FnvHashSet::default();
    for block in func.blocks.iter() {
        for statement in block.statements.iter() {
            if let Some(ref sym) = statement.output {
                if func.symbol_type(sym)?.is_builder() {
                    used_symbols.insert(sym.clone());
                }
            }
        }
    }
    used_symbols.extend(func.loop_variables.iter().cloned());

    // Phase 2: fold statements whose inputs are all known literals.
    for block in func.blocks.iter_mut() {
        for statement in block.statements.iter_mut() {
            let folded = match statement.kind {
                AssignLiteral(ref lit) => {
                    let output_sym = statement
                        .output
                        .clone()
                        .expect("assignment statement without an output symbol");
                    if assignment_counts[&output_sym] == 1 {
                        values.insert(output_sym, lit.clone());
                        Some(lit.clone())
                    } else {
                        None
                    }
                }
                Assign(ref sym) => match values.get(sym).cloned() {
                    Some(value) => {
                        let output_sym = statement
                            .output
                            .clone()
                            .expect("assignment statement without an output symbol");
                        // The copy is always replaced by the literal, but its output is only
                        // recorded as constant when nothing else writes it.
                        if assignment_counts[&output_sym] == 1 {
                            values.insert(output_sym, value.clone());
                        }
                        Some(value)
                    }
                    None => None,
                },
                BinOp {
                    ref op,
                    ref left,
                    ref right,
                } => match (values.get(left).cloned(), values.get(right).cloned()) {
                    (Some(left_val), Some(right_val)) => {
                        // An evaluation error (mismatched types, unsupported op, ...) simply
                        // leaves the statement unfolded.
                        match evaluate_binop(*op, left_val, right_val) {
                            Ok(result) => {
                                let output_sym = statement
                                    .output
                                    .clone()
                                    .expect("binary operation without an output symbol");
                                if assignment_counts[&output_sym] == 1 {
                                    values.insert(output_sym, result.clone());
                                }
                                Some(result)
                            }
                            Err(_) => None,
                        }
                    }
                    _ => None,
                },
                _ => None,
            };

            match folded {
                Some(lit) => {
                    let replacement = Statement::new(statement.output.clone(), AssignLiteral(lit));
                    *statement = replacement;
                }
                // Only statements that still read their inputs keep those inputs alive.
                None => used_symbols.extend(statement.kind.children().cloned()),
            }
        }
    }

    // Phase 3: drop statements whose result is never read. Statements without an output are
    // always kept.
    for block in func.blocks.iter_mut() {
        block.statements.retain(|s| match s.output {
            Some(ref sym) => used_symbols.contains(sym) || global_params.contains(sym),
            None => true,
        });
    }

    // Phase 4: prune locals by the same rule, moving kept entries instead of cloning them.
    let old_locals = std::mem::replace(&mut func.locals, BTreeMap::new());
    func.locals = old_locals
        .into_iter()
        .filter(|(k, _)| used_symbols.contains(k) || global_params.contains(k))
        .collect();
    Ok(())
}