#[cfg(test)]
#[path = "dedup_blocks_test.rs"]
mod test;
use cairo_lang_semantic::items::constant::ConstValueId;
use cairo_lang_semantic::{ConcreteVariant, TypeId};
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::unordered_hash_map::{self, UnorderedHashMap};
use itertools::{Itertools, zip_eq};
use crate::ids::FunctionId;
use crate::utils::{Rebuilder, RebuilderEx};
use crate::{
Block, BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementConst, StatementDesnap,
StatementEnumConstruct, StatementIntoBox, StatementSnapshot, StatementStructConstruct,
StatementStructDestructure, StatementUnbox, VarRemapping, VarUsage, VariableArena, VariableId,
};
/// A canonic representation of a block, used as a hash-map key to detect duplicated blocks.
///
/// Two blocks with equal `CanonicBlock`s have the same statements, variable types and returns
/// once their variables are renumbered in order of first appearance.
/// It is only built for non-empty blocks that end with a return (see `try_from_block`).
#[derive(Hash, PartialEq, Eq)]
struct CanonicBlock<'db> {
/// Canonic form of the block's statements, in their original order.
stmts: Vec<CanonicStatement<'db>>,
/// The type of every variable referenced by the block (inputs and outputs alike), indexed by
/// canonic variable number.
types: Vec<TypeId<'db>>,
/// Canonic form of the variables returned by the block.
returns: Vec<CanonicVar>,
}
/// A variable inside a canonic block: the index it was given when it was first encountered,
/// which is also its position in `CanonicBlock::types`.
#[derive(Hash, PartialEq, Eq)]
struct CanonicVar(usize);
/// A canonic representation of a `Statement`, with every variable replaced by a `CanonicVar`.
///
/// Each variant mirrors the `Statement` variant of the same name as converted by
/// `CanonicBlockBuilder::handle_statement`.
#[derive(Hash, PartialEq, Eq)]
enum CanonicStatement<'db> {
/// Canonic form of `Statement::Const`.
Const {
value: ConstValueId<'db>,
output: CanonicVar,
boxed: bool,
},
/// Canonic form of `Statement::Call`. The call's `location` and
/// `is_specialization_base_call` fields are not part of the canonic form.
Call {
function: FunctionId<'db>,
inputs: Vec<CanonicVar>,
with_coupon: bool,
outputs: Vec<CanonicVar>,
},
/// Canonic form of `Statement::StructConstruct`.
StructConstruct {
inputs: Vec<CanonicVar>,
output: CanonicVar,
},
/// Canonic form of `Statement::StructDestructure`.
StructDestructure {
input: CanonicVar,
outputs: Vec<CanonicVar>,
},
/// Canonic form of `Statement::EnumConstruct`.
EnumConstruct {
variant: ConcreteVariant<'db>,
input: CanonicVar,
output: CanonicVar,
},
/// Canonic form of `Statement::IntoBox`.
IntoBox {
input: CanonicVar,
output: CanonicVar,
},
/// Canonic form of `Statement::Unbox`.
Unbox {
input: CanonicVar,
output: CanonicVar,
},
/// Canonic form of `Statement::Snapshot`, which has exactly two outputs.
Snapshot {
input: CanonicVar,
outputs: [CanonicVar; 2],
},
/// Canonic form of `Statement::Desnap`.
Desnap {
input: CanonicVar,
output: CanonicVar,
},
}
/// Incrementally builds the pieces of a `CanonicBlock` while numbering the variables it meets.
struct CanonicBlockBuilder<'db, 'a> {
/// Arena used to look up the type of each variable.
variable: &'a VariableArena<'db>,
/// Maps every variable seen so far to its canonic index.
vars: UnorderedHashMap<VariableId, usize>,
/// The type of each canonic variable, by canonic index.
types: Vec<TypeId<'db>>,
/// Usages of the variables that were first seen as an input (i.e. before any statement of the
/// block produced them), in order of first appearance.
inputs: Vec<VarUsage<'db>>,
}
impl<'db, 'a> CanonicBlockBuilder<'db, 'a> {
fn new(variable: &'a VariableArena<'db>) -> CanonicBlockBuilder<'db, 'a> {
CanonicBlockBuilder {
variable,
vars: Default::default(),
types: vec![],
inputs: Default::default(),
}
}
fn handle_input(&mut self, var_usage: &VarUsage<'db>) -> CanonicVar {
let v = var_usage.var_id;
CanonicVar(match self.vars.entry(v) {
std::collections::hash_map::Entry::Occupied(e) => *e.get(),
std::collections::hash_map::Entry::Vacant(e) => {
self.types.push(self.variable[v].ty);
let new_id = *e.insert(self.types.len() - 1);
self.inputs.push(*var_usage);
new_id
}
})
}
fn handle_output(&mut self, v: &VariableId) -> CanonicVar {
CanonicVar(match self.vars.entry(*v) {
std::collections::hash_map::Entry::Occupied(e) => *e.get(),
std::collections::hash_map::Entry::Vacant(e) => {
self.types.push(self.variable[*v].ty);
*e.insert(self.types.len() - 1)
}
})
}
fn handle_statement(&mut self, statement: &Statement<'db>) -> CanonicStatement<'db> {
match statement {
Statement::Const(StatementConst { value, boxed, output }) => CanonicStatement::Const {
value: *value,
output: self.handle_output(output),
boxed: *boxed,
},
Statement::Call(StatementCall {
function,
inputs,
with_coupon,
outputs,
location: _,
is_specialization_base_call: _,
}) => CanonicStatement::Call {
function: *function,
inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
with_coupon: *with_coupon,
outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
},
Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
CanonicStatement::StructConstruct {
inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
output: self.handle_output(output),
}
}
Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
CanonicStatement::StructDestructure {
input: self.handle_input(input),
outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
}
}
Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
CanonicStatement::EnumConstruct {
variant: *variant,
input: self.handle_input(input),
output: self.handle_output(output),
}
}
Statement::Snapshot(StatementSnapshot { input, outputs }) => {
CanonicStatement::Snapshot {
input: self.handle_input(input),
outputs: outputs.map(|output| self.handle_output(&output)),
}
}
Statement::Desnap(StatementDesnap { input, output }) => CanonicStatement::Desnap {
input: self.handle_input(input),
output: self.handle_output(output),
},
Statement::IntoBox(StatementIntoBox { input, output }) => CanonicStatement::IntoBox {
input: self.handle_input(input),
output: self.handle_output(output),
},
Statement::Unbox(StatementUnbox { input, output }) => CanonicStatement::Unbox {
input: self.handle_input(input),
output: self.handle_output(output),
},
}
}
}
impl<'db> CanonicBlock<'db> {
fn try_from_block(
variable: &VariableArena<'db>,
block: &Block<'db>,
) -> Option<(CanonicBlock<'db>, Vec<VarUsage<'db>>)> {
let BlockEnd::Return(returned_vars, _) = &block.end else {
return None;
};
if block.statements.is_empty() {
return None;
}
let mut builder = CanonicBlockBuilder::new(variable);
let stmts = block
.statements
.iter()
.map(|statement| builder.handle_statement(statement))
.collect_vec();
let returns = returned_vars.iter().map(|input| builder.handle_input(input)).collect();
Some((CanonicBlock { stmts, types: builder.types, returns }, builder.inputs))
}
}
/// Helper that replaces variable ids with freshly allocated ones while rebuilding lowered code.
pub struct VarReassigner<'db, 'a> {
/// Arena in which the replacement variables are allocated.
pub variables: &'a mut VariableArena<'db>,
/// Maps each original variable id to the id of its replacement.
pub vars: UnorderedHashMap<VariableId, VariableId>,
}
impl<'db, 'a> VarReassigner<'db, 'a> {
    /// Creates a reassigner that allocates into `variables` and has no mappings yet.
    pub fn new(variables: &'a mut VariableArena<'db>) -> Self {
        let vars = UnorderedHashMap::default();
        Self { variables, vars }
    }
}
impl<'db, 'a> Rebuilder<'db> for VarReassigner<'db, 'a> {
    /// Maps `var` to its replacement id.
    ///
    /// The first time a variable is seen, a clone of it is allocated in the arena and
    /// remembered; later calls with the same id return the same replacement.
    fn map_var_id(&mut self, var: VariableId) -> VariableId {
        if let Some(&replacement) = self.vars.get(&var) {
            return replacement;
        }
        let cloned_variable = self.variables[var].clone();
        let replacement = self.variables.alloc(cloned_variable);
        self.vars.insert(var, replacement);
        replacement
    }
}
/// State collected by `dedup_blocks` while scanning the blocks.
#[derive(Default)]
struct DedupContext<'db> {
/// Maps each canonic block to the first block id that produced it.
canonic_blocks: UnorderedHashMap<CanonicBlock<'db>, BlockId>,
/// Maps every block that has a canonic form to the inputs it takes from outside.
block_id_to_inputs: UnorderedHashMap<BlockId, Vec<VarUsage<'db>>>,
}
/// Rebuilds `block` and maps `inputs` through a single fresh `VarReassigner`.
///
/// Because both go through the same reassigner, an input variable that also appears in the
/// block is mapped to the same replacement id in both results.
/// Returns the rebuilt block and the mapped inputs.
fn rebuild_block_and_inputs<'db>(
    variables: &mut VariableArena<'db>,
    block: &Block<'db>,
    inputs: &[VarUsage<'db>],
) -> (Block<'db>, Vec<VarUsage<'db>>) {
    let mut reassigner = VarReassigner::new(variables);
    // The block is rebuilt before the inputs are mapped; this order decides which variables
    // are allocated first.
    let rebuilt_block = reassigner.rebuild_block(block);
    let rebuilt_inputs: Vec<VarUsage<'db>> =
        inputs.iter().copied().map(|usage| reassigner.map_var_usage(usage)).collect();
    (rebuilt_block, rebuilt_inputs)
}
/// Deduplicates blocks by redirecting gotos and match arms that target duplicated blocks to a
/// single shared copy.
///
/// For every group of blocks with the same canonic form, one shared block with fresh variables
/// is appended to `lowered.blocks`. Gotos to a member of the group are retargeted to the shared
/// block, and match arms are pointed at a new empty block that jumps to it. The original
/// duplicated blocks are not removed by this function.
///
/// Does nothing when `lowered.blocks.has_root()` reports an error.
pub fn dedup_blocks<'db>(lowered: &mut Lowered<'db>) {
if lowered.blocks.has_root().is_err() {
return;
}
let mut ctx = DedupContext::default();
// Maps each duplicated block (including the first block of its group) to the id of the shared
// block and the inputs of that shared block.
let mut duplicates: UnorderedHashMap<BlockId, (BlockId, Vec<VarUsage<'_>>)> =
Default::default();
// Blocks created here are buffered and appended at the end; their ids continue from the
// current number of blocks.
let mut new_blocks = vec![];
let mut next_block_id = BlockId(lowered.blocks.len());
for (block_id, block) in lowered.blocks.iter() {
// Blocks without a canonic form are never deduplicated.
let Some((canonical_block, inputs)) =
CanonicBlock::try_from_block(&lowered.variables, block)
else {
continue;
};
match ctx.canonic_blocks.entry(canonical_block) {
unordered_hash_map::Entry::Occupied(e) => {
// An equal block was seen before. The shared block is created once per group, keyed by
// the first block of the group, as a copy of the current block with fresh variables.
let block_and_inputs = duplicates
.entry(*e.get())
.or_insert_with(|| {
let (block, new_inputs) =
rebuild_block_and_inputs(&mut lowered.variables, block, &inputs);
new_blocks.push(block);
let new_block_id = next_block_id;
next_block_id = next_block_id.next_block_id();
(new_block_id, new_inputs)
})
.clone();
// The current block is redirected to the same shared block.
duplicates.insert(block_id, block_and_inputs);
}
unordered_hash_map::Entry::Vacant(e) => {
// First block with this canonic form: remember it as the reference of its group.
e.insert(block_id);
}
};
ctx.block_id_to_inputs.insert(block_id, inputs);
}
// Creates an empty block that jumps to `block_id`, remapping each variable of `target_inputs`
// to the usage at the same position in `inputs`, and returns the id of the created block.
// `zip_eq` is expected to panic if the two lists differ in length.
let mut new_goto_block =
|block_id, inputs: &[VarUsage<'db>], target_inputs: &[VarUsage<'db>]| {
new_blocks.push(Block {
statements: vec![],
end: BlockEnd::Goto(
block_id,
VarRemapping {
remapping: OrderedHashMap::from_iter(zip_eq(
target_inputs.iter().map(|var_usage| var_usage.var_id),
inputs.iter().cloned(),
)),
},
),
});
let new_block_id = next_block_id;
next_block_id = next_block_id.next_block_id();
new_block_id
};
// This pass runs after the scan above because `duplicates` is only complete once every block
// has been visited (the first block of a group is added when its first duplicate is found).
for block in lowered.blocks.iter_mut() {
match &mut block.end {
BlockEnd::Goto(target_block, remappings) => {
let Some((block_id, target_inputs)) = duplicates.get(target_block) else {
continue;
};
// Every key of `duplicates` was also recorded in `block_id_to_inputs` by the scan above,
// so this lookup cannot fail.
let inputs = ctx.block_id_to_inputs.get(target_block).unwrap();
// Map each input of the shared block to the matching input of the original target.
let mut inputs_remapping = VarRemapping {
remapping: OrderedHashMap::from_iter(zip_eq(
target_inputs.iter().map(|var_usage| var_usage.var_id),
inputs.iter().cloned(),
)),
};
// If the original goto remapped into a variable the target used as an input, use the
// source of that remapping instead, since the old remapping is replaced below.
for (_, src) in inputs_remapping.iter_mut() {
if let Some(src_before_remapping) = remappings.get(&src.var_id) {
*src = *src_before_remapping;
}
}
*target_block = *block_id;
*remappings = inputs_remapping;
}
BlockEnd::Match { info } => {
// An arm's target is replaced here by an intermediate goto block that performs the
// input remapping.
for arm in info.arms_mut() {
let Some((block_id, target_inputs)) = duplicates.get(&arm.block_id) else {
continue;
};
let inputs = &ctx.block_id_to_inputs[&arm.block_id];
arm.block_id = new_goto_block(*block_id, inputs, target_inputs);
}
}
_ => {}
}
}
// Append the buffered blocks in creation order so that their positions match the ids handed
// out through `next_block_id`.
for block in new_blocks {
lowered.blocks.push(block);
}
}