#[cfg(test)]
#[path = "reorder_statements_test.rs"]
mod test;
use std::cmp::Reverse;
use cairo_lang_defs::ids::ExternFunctionId;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::unordered_hash_map::{Entry, UnorderedHashMap};
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use itertools::{Itertools, zip_eq};
use salsa::Database;
use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
use crate::db::LoweringGroup;
use crate::{
BlockId, Lowered, MatchInfo, Statement, StatementCall, VarRemapping, VarUsage, VariableId,
};
pub fn reorder_statements(db: &dyn Database, lowered: &mut Lowered<'_>) {
if lowered.blocks.is_empty() {
return;
}
let ctx = ReorderStatementsContext {
db,
lowered: &*lowered,
moveable_functions: db.priv_movable_function_ids(),
statement_to_move: vec![],
};
let mut analysis = BackAnalysis::new(lowered, ctx);
analysis.get_root_info();
let ctx = analysis.analyzer;
let mut changes_by_block =
OrderedHashMap::<BlockId, Vec<(usize, Option<Statement<'_>>)>>::default();
for (src, opt_dst) in ctx.statement_to_move {
changes_by_block.entry(src.0).or_insert_with(Vec::new).push((src.1, None));
if let Some(dst) = opt_dst {
let statement = lowered[src].clone();
changes_by_block.entry(dst.0).or_insert_with(Vec::new).push((dst.1, Some(statement)));
}
}
for (block_id, block_changes) in changes_by_block {
let statements = &mut lowered.blocks[block_id].statements;
for (index, opt_statement) in
block_changes.into_iter().sorted_by_key(|(index, _)| Reverse(*index))
{
match opt_statement {
Some(stmt) => statements.insert(index, stmt),
None => {
statements.remove(index);
}
}
}
}
}
/// State carried by the backward analysis at a given program point.
#[derive(Default, Clone)]
pub struct ReorderStatementsInfo {
    /// For each variable, the location treated as its next use after the current point.
    ///
    /// This is a statement, goto, match or return location. At merge points it is
    /// conservatively pinned to the merging match.
    next_use: UnorderedHashMap<VariableId, StatementLocation>,
}
/// Analyzer driving the statement-reordering pass; also accumulates the resulting edits.
pub struct ReorderStatementsContext<'db> {
    /// Database used to resolve a call's callee to an extern function.
    db: &'db dyn Database,
    /// The function being analyzed; read to check whether variables are droppable.
    lowered: &'db Lowered<'db>,
    /// Extern functions whose call statements may be relocated.
    moveable_functions: &'db UnorderedHashSet<ExternFunctionId<'db>>,
    /// Requested edits, in the order the analysis produced them.
    ///
    /// `(src, Some(dst))` moves the statement at `src` to before `dst`; `(src, None)` deletes
    /// the statement at `src`.
    statement_to_move: Vec<(StatementLocation, Option<StatementLocation>)>,
}
impl<'db> ReorderStatementsContext<'db> {
    /// Returns whether the call `stmt` may be relocated by this pass.
    ///
    /// Only calls to extern functions contained in `moveable_functions` qualify; any other
    /// callee (including non-extern functions) is immovable. The method only reads `self`, so it
    /// takes a shared borrow.
    fn call_can_be_moved(&self, stmt: &StatementCall<'db>) -> bool {
        matches!(
            stmt.function.get_extern(self.db),
            Some((extern_id, _)) if self.moveable_functions.contains(&extern_id)
        )
    }
}
impl<'db> Analyzer<'db, '_> for ReorderStatementsContext<'db> {
    type Info = ReorderStatementsInfo;

    /// Decides the fate of `stmt` and updates `info.next_use` for its inputs.
    ///
    /// The statement is either moved to the earliest next use of its outputs, removed (when its
    /// outputs are unused and its inputs are droppable), or left in place.
    fn visit_stmt(
        &mut self,
        info: &mut Self::Info,
        statement_location: StatementLocation,
        stmt: &Statement<'db>,
    ) {
        // Calls to functions outside the movable set must stay where they are.
        let mut immovable = matches!(stmt, Statement::Call(stmt) if !self.call_can_be_moved(stmt));
        // Earliest next use among the statement's outputs: the candidate insertion point.
        let mut optional_target_location = None;
        for var_to_move in stmt.outputs() {
            // The outputs are defined here, so they are dropped from `next_use` (presumably
            // relying on each variable being defined exactly once).
            let Some((block_id, index)) = info.next_use.remove(var_to_move) else { continue };
            if let Some((target_block_id, target_index)) = &mut optional_target_location {
                *target_index = std::cmp::min(*target_index, index);
                // Outputs used in different blocks have no single insertion point.
                immovable |= target_block_id != &block_id;
            } else {
                optional_target_location = Some((block_id, index));
            }
        }

        if immovable {
            for var_usage in stmt.inputs() {
                info.next_use.insert(var_usage.var_id, statement_location);
            }
            return;
        }

        if let Some(target_location) = optional_target_location {
            // The statement will execute at `target_location`, so that is where its inputs are
            // next used, unless an input already has a recorded next use.
            for var_usage in stmt.inputs() {
                match info.next_use.entry(var_usage.var_id) {
                    Entry::Occupied(mut e) => {
                        // We don't know where `e.get()` and `target_location` converge, so
                        // conservatively use the current location.
                        e.insert(statement_location);
                    }
                    Entry::Vacant(e) => {
                        e.insert(target_location);
                    }
                }
            }
            self.statement_to_move.push((statement_location, Some(target_location)));
        } else if stmt
            .inputs()
            .iter()
            .all(|v| self.lowered.variables[v.var_id].info.droppable.is_ok())
        {
            // Unused outputs and droppable inputs: the statement can simply be deleted.
            self.statement_to_move.push((statement_location, None));
        } else {
            // Unused outputs but a non-droppable input: keep the statement in place.
            for var_usage in stmt.inputs() {
                info.next_use.insert(var_usage.var_id, statement_location);
            }
        }
    }

    /// Records every variable passed through the goto's remapping as used at the goto.
    fn visit_goto(
        &mut self,
        info: &mut Self::Info,
        statement_location: StatementLocation,
        _target_block_id: BlockId,
        remapping: &VarRemapping<'db>,
    ) {
        for VarUsage { var_id, .. } in remapping.values() {
            info.next_use.insert(*var_id, statement_location);
        }
    }

    /// Merges the infos of a match's arms into the info before the match.
    ///
    /// Variables bound by an arm are forgotten. When a variable appears in more than one arm,
    /// its next use is pinned to the match itself (presumably `merge` invokes the callback only
    /// for keys present in both maps). Finally, the match inputs are recorded as used at the
    /// match.
    fn merge_match(
        &mut self,
        statement_location: StatementLocation,
        match_info: &MatchInfo<'db>,
        infos: impl Iterator<Item = Self::Info>,
    ) -> Self::Info {
        let mut infos = zip_eq(infos, match_info.arms()).map(|(mut info, arm)| {
            for var_id in &arm.var_ids {
                info.next_use.remove(var_id);
            }
            info
        });
        let mut info = infos.next().unwrap_or_default();
        for arm_info in infos {
            info.next_use.merge(&arm_info.next_use, |e, _| {
                *e.into_mut() = statement_location;
            });
        }
        // Insert the match inputs only after the arms were merged, so they point at the match.
        for var_usage in match_info.inputs() {
            info.next_use.insert(var_usage.var_id, statement_location);
        }
        info
    }

    /// Starts the backward analysis at a return: every returned variable is used there.
    fn info_from_return(
        &mut self,
        statement_location: StatementLocation,
        vars: &[VarUsage<'db>],
    ) -> Self::Info {
        let mut info = Self::Info::default();
        for var_usage in vars {
            info.next_use.insert(var_usage.var_id, statement_location);
        }
        info
    }
}