#[cfg(test)]
#[path = "cse_test.rs"]
mod test;
use std::iter::zip;
use cairo_lang_semantic::items::constant::ConstValueId;
use cairo_lang_semantic::{ConcreteVariant, TypeId};
use cairo_lang_utils::unordered_hash_map::{Entry, UnorderedHashMap};
use itertools::Itertools;
use crate::ids::FunctionId;
use crate::optimizations::var_renamer::VarRenamer;
use crate::utils::RebuilderEx;
use crate::{BlockEnd, BlockId, Lowered, Statement, VariableArena, VariableId};
/// Identity of an expression that the CSE pass may deduplicate.
///
/// Variable operands are stored after resolution through `CseContext::resolve_var`, so statements
/// whose operands were already merged by an earlier elimination produce equal keys.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
enum ExpressionKey<'db> {
    /// A constant value together with the statement's `boxed` flag.
    Const(ConstValueId<'db>, bool),
    /// Construction of a struct of the given type from the (resolved) input variables.
    StructConstruct(TypeId<'db>, Vec<VariableId>),
    /// Destructuring of the (resolved) input variable.
    StructDestructure(VariableId),
    /// Construction of the given enum variant from the (resolved) input variable.
    EnumConstruct(ConcreteVariant<'db>, VariableId),
    /// Snapshot of the (resolved) input variable.
    Snapshot(VariableId),
    /// Desnap of the (resolved) input variable.
    Desnap(VariableId),
    /// Call of a function accepted by `CseContext::is_pure_function` on the (resolved) inputs.
    PureCall(FunctionId<'db>, Vec<VariableId>),
}
/// Mutable state of the CSE pass while walking the blocks of one lowered function.
struct CseContext<'db> {
    /// Expressions available at the current point, mapped to the variables holding their results.
    expression_map: UnorderedHashMap<ExpressionKey<'db>, Vec<VariableId>>,
    /// Outputs of eliminated statements, mapped to the earlier outputs that replace them.
    var_replacements: UnorderedHashMap<VariableId, VariableId>,
    /// The `original` output of each processed snapshot statement, mapped back to that statement's
    /// input, so keys built from either variable resolve to the same operand.
    snapshot_remappings: UnorderedHashMap<VariableId, VariableId>,
    /// The function's variables, consulted for types and copyable/droppable information.
    variables: &'db VariableArena<'db>,
}
impl<'db> CseContext<'db> {
fn new(variables: &'db VariableArena<'db>) -> Self {
Self {
expression_map: UnorderedHashMap::default(),
var_replacements: UnorderedHashMap::default(),
snapshot_remappings: UnorderedHashMap::default(),
variables,
}
}
fn process_statement(&mut self, stmt: &Statement<'db>) -> bool {
let outputs = stmt.outputs();
if outputs.iter().any(|output| self.variables[*output].info.copyable.is_err())
|| stmt
.inputs()
.iter()
.any(|input| self.variables[input.var_id].info.droppable.is_err())
{
return false;
}
let key = match stmt {
Statement::Const(c) => ExpressionKey::Const(c.value, c.boxed),
Statement::StructConstruct(s) => ExpressionKey::StructConstruct(
self.variables[s.output].ty,
s.inputs.iter().map(|usage| self.resolve_var(usage.var_id)).collect(),
),
Statement::StructDestructure(s) => {
ExpressionKey::StructDestructure(self.resolve_var(s.input.var_id))
}
Statement::EnumConstruct(s) => {
ExpressionKey::EnumConstruct(s.variant, self.resolve_var(s.input.var_id))
}
Statement::Snapshot(s) => {
self.snapshot_remappings.insert(s.original(), s.input.var_id);
ExpressionKey::Snapshot(self.resolve_var(s.input.var_id))
}
Statement::Desnap(s) => ExpressionKey::Desnap(self.resolve_var(s.input.var_id)),
Statement::Call(s) if self.is_pure_function(&s.function) => ExpressionKey::PureCall(
s.function,
s.inputs.iter().map(|usage| self.resolve_var(usage.var_id)).collect(),
),
_ => return false, };
match self.expression_map.entry(key) {
Entry::Vacant(entry) => {
entry.insert(outputs.to_vec());
false
}
Entry::Occupied(entry) => {
self.var_replacements
.extend(zip(outputs.iter().copied(), entry.get().iter().copied()));
true
}
}
}
fn resolve_var(&self, var: VariableId) -> VariableId {
match self.var_replacements.get(&var).or_else(|| self.snapshot_remappings.get(&var)) {
Some(&replacement) => self.resolve_var(replacement),
None => var,
}
}
fn is_pure_function(&self, _function: &FunctionId<'db>) -> bool {
false
}
}
/// Performs common subexpression elimination on `lowered`, in place.
///
/// Blocks are visited in index order, each starting from the expression map propagated to it:
/// - the root block starts empty;
/// - every match arm target receives a copy of the map at the end of the matching block;
/// - a `Goto` target receives the map of the first block jumping to it, intersected with the map
///   of every later block jumping to it (an entry survives only if key and outputs both agree).
///
/// A statement whose key is already in the map is removed, and its outputs are recorded as
/// replacements for the earlier outputs. Finally every block is rebuilt through `VarRenamer` with
/// the collected replacements.
///
/// # Panics
///
/// Panics if a block is visited before any expressions were propagated to it, if a match arm
/// target already had expressions propagated to it, or if expressions were propagated to a block
/// that was never visited.
pub fn cse<'db>(lowered: &mut Lowered<'db>) {
    if lowered.blocks.is_empty() {
        return;
    }

    let mut ctx = CseContext::new(&lowered.variables);
    // Expression maps propagated to blocks that have not been visited yet.
    let mut block_expression_map = UnorderedHashMap::<BlockId, _>::default();
    block_expression_map.insert(BlockId::root(), Default::default());

    for block_id in (0..lowered.blocks.len()).map(BlockId) {
        let block = &mut lowered.blocks[block_id];
        ctx.expression_map = block_expression_map
            .remove(&block_id)
            .unwrap_or_else(|| panic!("{block_id:?} expressions were not prepared"));

        // `retain` visits every statement exactly once, in order, so statements are processed in
        // the same order as before while redundant ones are dropped in a single O(n) pass (instead
        // of one shifting `Vec::remove` per redundant statement).
        block.statements.retain(|stmt| !ctx.process_statement(stmt));

        match &block.end {
            BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
            BlockEnd::Match { info } => {
                for arm in info.arms() {
                    let next = arm.block_id;
                    assert!(
                        block_expression_map.insert(next, ctx.expression_map.clone()).is_none(),
                        "{next:?} was previously propagated - should not happen on match arms.",
                    );
                }
            }
            BlockEnd::Goto(block_id, _remapping) => match block_expression_map.entry(*block_id) {
                Entry::Vacant(entry) => {
                    entry.insert(std::mem::take(&mut ctx.expression_map));
                }
                Entry::Occupied(mut entry) => {
                    // Keep only expressions available, with the same outputs, on every path.
                    let e = std::mem::take(entry.get_mut());
                    entry.insert(e.filter(|k, v| {
                        if let Some(new_val) = ctx.expression_map.remove(k) {
                            new_val == *v
                        } else {
                            false
                        }
                    }));
                }
            },
        }
    }

    assert!(
        block_expression_map.is_empty(),
        "Some blocks were not processed: [{}]",
        block_expression_map
            .iter_sorted_by_key(|(k, _)| k.0)
            .map(|(k, _)| format!("{k:?}"))
            .join(", ")
    );

    let CseContext { var_replacements: renamed_vars, .. } = ctx;
    let mut renamer = VarRenamer { renamed_vars };
    for block in lowered.blocks.iter_mut() {
        *block = renamer.rebuild_block(block);
    }
}