#[cfg(test)]
#[path = "return_optimization_test.rs"]
mod test;
use cairo_lang_semantic::types::TypesSemantic;
use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeId, TypeLongId};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use cairo_lang_utils::{Intern, require};
use salsa::Database;
use semantic::MatchArmSelector;
use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
use crate::ids::LocationId;
use crate::{
Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
StatementEnumConstruct, StatementStructConstruct, StatementStructDestructure, VarRemapping,
VarUsage, Variable, VariableArena, VariableId,
};
pub fn return_optimization<'db>(db: &'db dyn Database, lowered: &mut Lowered<'db>) {
if lowered.blocks.is_empty() {
return;
}
let ctx = ReturnOptimizerContext { db, lowered, fixes: vec![] };
let mut analysis = BackAnalysis::new(lowered, ctx);
analysis.get_root_info();
let ctx = analysis.analyzer;
let ReturnOptimizerContext { fixes, .. } = ctx;
for FixInfo { location: (block_id, statement_idx), return_info } in fixes {
let block = &mut lowered.blocks[block_id];
block.statements.truncate(statement_idx);
let mut ctx = EarlyReturnContext {
db,
constructed: UnorderedHashMap::default(),
variables: &mut lowered.variables,
statements: &mut block.statements,
location: return_info.location,
};
let vars = ctx.prepare_early_return_vars(&return_info.returned_vars);
block.end = BlockEnd::Return(vars, return_info.location)
}
}
/// Materializes the values of an early return at the end of a block.
///
/// The construct statements needed to rebuild the returned values are appended to `statements`.
/// Their output variables are freshly allocated in `variables`.
struct EarlyReturnContext<'db, 'a> {
    db: &'db dyn Database,
    /// Deduplicates constructs: `(type, input variables)` -> output variable already built.
    constructed: UnorderedHashMap<(TypeId<'db>, Vec<VariableId>), VariableId>,
    /// Arena in which new output variables are allocated.
    variables: &'a mut VariableArena<'db>,
    /// Statements of the block being rewritten; new construct statements are appended here.
    statements: &'a mut Vec<Statement<'db>>,
    /// Location attached to every new variable and usage.
    location: LocationId<'db>,
}

impl<'db, 'a> EarlyReturnContext<'db, 'a> {
    /// Turns each [`ValueInfo`] into a concrete variable usage, emitting construct statements
    /// as needed. The results are in the same order as `ret_infos`.
    ///
    /// Panics if any value is still `ValueInfo::Interchangeable`.
    fn prepare_early_return_vars(&mut self, ret_infos: &[ValueInfo<'db>]) -> Vec<VarUsage<'db>> {
        ret_infos.iter().map(|value_info| self.materialize(value_info)).collect()
    }

    /// Materializes a single value.
    ///
    /// Nested values are materialized first, so their construct statements precede the
    /// statement that consumes them.
    fn materialize(&mut self, value_info: &ValueInfo<'db>) -> VarUsage<'db> {
        let var_id = match value_info {
            ValueInfo::Var(var_usage) => return *var_usage,
            ValueInfo::StructConstruct { ty, var_infos } => {
                let inputs = self.prepare_early_return_vars(var_infos);
                let key = (*ty, inputs.iter().map(|usage| usage.var_id).collect());
                self.construct_once(key, |output| {
                    Statement::StructConstruct(StatementStructConstruct { inputs, output })
                })
            }
            ValueInfo::EnumConstruct { var_info, variant } => {
                let input = self.materialize(var_info);
                let enum_ty = TypeLongId::Concrete(ConcreteTypeId::Enum(variant.concrete_enum_id))
                    .intern(self.db);
                self.construct_once((enum_ty, vec![input.var_id]), |output| {
                    Statement::EnumConstruct(StatementEnumConstruct {
                        variant: *variant,
                        input,
                        output,
                    })
                })
            }
            ValueInfo::Interchangeable(_) => {
                unreachable!("early_return_possible should have prevented this.")
            }
        };
        VarUsage { var_id, location: self.location }
    }

    /// Returns the output variable for `key`, reusing an earlier identical construct.
    ///
    /// On a cache miss it allocates a fresh variable of type `key.0` and pushes `build(output)`.
    fn construct_once(
        &mut self,
        key: (TypeId<'db>, Vec<VariableId>),
        build: impl FnOnce(VariableId) -> Statement<'db>,
    ) -> VariableId {
        let EarlyReturnContext { db, constructed, variables, statements, location } = self;
        let ty = key.0;
        *constructed.entry(key).or_insert_with(|| {
            let output = variables.alloc(Variable::with_default_context(*db, ty, *location));
            statements.push(build(output));
            output
        })
    }
}
/// Analyzer state for the return optimization: read-only access to the function being
/// analyzed, plus the list of rewrites collected so far.
pub struct ReturnOptimizerContext<'db, 'a> {
    db: &'db dyn Database,
    lowered: &'a Lowered<'db>,
    /// Rewrites to apply once the analysis is done, in discovery order.
    fixes: Vec<FixInfo<'db>>,
}

impl<'db, 'a> ReturnOptimizerContext<'db, 'a> {
    /// Describes `var_usage` as a [`ValueInfo`].
    ///
    /// A variable is described as `Interchangeable` when it is droppable and `single_value_type`
    /// reports its type as true. Otherwise it is described as the variable itself.
    ///
    /// # Panics
    /// If `single_value_type` returns an error.
    fn get_var_info(&self, var_usage: &VarUsage<'db>) -> ValueInfo<'db> {
        let var_ty = &self.lowered.variables[var_usage.var_id].ty;
        if self.is_droppable(var_usage.var_id) && self.db.single_value_type(*var_ty).unwrap() {
            ValueInfo::Interchangeable(*var_ty)
        } else {
            ValueInfo::Var(*var_usage)
        }
    }

    /// Whether the variable's `droppable` info is `Ok`.
    fn is_droppable(&self, var_id: VariableId) -> bool {
        self.lowered.variables[var_id].info.droppable.is_ok()
    }

    /// Merges the infos of the arms of `match_info` into the info before the match.
    ///
    /// Only enum matches with at least one arm are merged. Each arm's info is first moved back
    /// over its arm. The merge succeeds only if every arm then yields an emittable early return,
    /// and all of them return the same `returned_vars`. The last arm's `ReturnInfo` (and so its
    /// location) is the one kept. Returns `None` otherwise, including when `infos` is empty.
    fn try_merge_match(
        &mut self,
        match_info: &MatchInfo<'db>,
        infos: impl Iterator<Item = AnalyzerInfo<'db>>,
    ) -> Option<ReturnInfo<'db>> {
        let MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) = match_info else {
            return None;
        };
        require(!arms.is_empty())?;

        let input_info = self.get_var_info(input);
        // Loop-invariant: whether the matched variable may be left unused.
        let input_is_droppable = self.is_droppable(input.var_id);

        let mut merged: Option<ReturnInfo<'db>> = None;
        // `infos` yields owned values, so each arm's info can be updated in place without cloning.
        for (arm, mut arm_info) in arms.iter().zip(infos) {
            arm_info.apply_match_arm(input_is_droppable, &input_info, arm);
            let return_info = arm_info.try_get_early_return_info()?;
            // All arms must agree on what is returned; compare by reference instead of moving
            // the accumulated info out on every iteration.
            if merged.as_ref().is_some_and(|prev| prev.returned_vars != return_info.returned_vars)
            {
                return None;
            }
            merged = Some(return_info.clone());
        }
        merged
    }
}
/// A rewrite produced by the analysis.
///
/// The statements of block `location.0` are truncated to their first `location.1` entries, and
/// the block is then ended with an early return built from `return_info`.
pub struct FixInfo<'db> {
    /// `(block, index)`: statements from `index` onward are removed.
    location: StatementLocation,
    /// What the rewritten block returns.
    return_info: ReturnInfo<'db>,
}
/// Describes, at the current point of the backward analysis, how one returned value can be
/// obtained.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ValueInfo<'db> {
    /// The value is exactly this variable.
    Var(VarUsage<'db>),
    /// Any variable of this type may stand in for the value.
    ///
    /// Produced for droppable variables whose type `single_value_type` reports as true. An early
    /// return cannot be emitted while such a value remains.
    Interchangeable(semantic::TypeId<'db>),
    /// A struct of type `ty` built from `var_infos`, in the order of the construct's inputs.
    StructConstruct {
        ty: semantic::TypeId<'db>,
        var_infos: Vec<ValueInfo<'db>>,
    },
    /// The enum variant `variant` wrapping the value described by `var_info`.
    EnumConstruct {
        var_info: Box<ValueInfo<'db>>,
        variant: semantic::ConcreteVariant<'db>,
    },
}
/// Outcome of moving a [`ValueInfo`] backwards over a struct destructure or a match arm.
enum OpResult {
    /// The value was rewritten in terms of the statement's input (a construct cancelled out the
    /// statement), so the early return now uses that input.
    InputConsumed,
    /// The value needs a variable produced by the statement that cannot be rebuilt from its
    /// input; the early return is no longer possible.
    ValueInvalidated,
    /// The statement does not affect the value.
    NoChange,
}
impl<'db> ValueInfo<'db> {
    /// Replaces every `Var` leaf of this value by `f(var_usage)`.
    ///
    /// `Interchangeable` leaves are left untouched.
    fn apply<F>(&mut self, f: &F)
    where
        F: Fn(&VarUsage<'db>) -> ValueInfo<'db>,
    {
        match self {
            ValueInfo::Var(var_usage) => *self = f(var_usage),
            ValueInfo::StructConstruct { ty: _, var_infos } => {
                for var_info in var_infos.iter_mut() {
                    var_info.apply(f);
                }
            }
            ValueInfo::EnumConstruct { var_info, .. } => {
                var_info.apply(f);
            }
            ValueInfo::Interchangeable(_) => {}
        }
    }

    /// Updates this value to its state before the struct destructure `stmt` (the analysis runs
    /// backwards).
    ///
    /// Returns:
    /// - `ValueInvalidated` if the value still needs one of `stmt.outputs` directly.
    /// - `InputConsumed` if (part of) the value was rewritten as `stmt.input`.
    /// - `NoChange` otherwise.
    fn apply_deconstruct(
        &mut self,
        ctx: &ReturnOptimizerContext<'db, '_>,
        stmt: &StatementStructDestructure<'db>,
    ) -> OpResult {
        match self {
            ValueInfo::Var(var_usage) => {
                // The destructure's outputs do not exist before the statement.
                if stmt.outputs.contains(&var_usage.var_id) {
                    OpResult::ValueInvalidated
                } else {
                    OpResult::NoChange
                }
            }
            ValueInfo::StructConstruct { ty, var_infos } => {
                // The construct cancels the destructure when it rebuilds a value of the input's
                // type from exactly the destructure's outputs, in order. An `Interchangeable`
                // field matches any output of the same type.
                let mut cancels_out = ty == &ctx.lowered.variables[stmt.input.var_id].ty
                    && var_infos.len() == stmt.outputs.len();
                for (var_info, output) in var_infos.iter().zip(stmt.outputs.iter()) {
                    if !cancels_out {
                        break;
                    }
                    match var_info {
                        ValueInfo::Var(var_usage) if &var_usage.var_id == output => {}
                        ValueInfo::Interchangeable(ty)
                            if &ctx.lowered.variables[*output].ty == ty => {}
                        _ => cancels_out = false,
                    }
                }
                if cancels_out {
                    // The rebuilt struct is just the destructured input; the fields need no
                    // further processing.
                    *self = ValueInfo::Var(stmt.input);
                    return OpResult::InputConsumed;
                }
                // Otherwise the outputs may still appear deeper inside the fields.
                let mut input_consumed = false;
                for var_info in var_infos.iter_mut() {
                    match var_info.apply_deconstruct(ctx, stmt) {
                        OpResult::InputConsumed => {
                            input_consumed = true;
                        }
                        OpResult::ValueInvalidated => {
                            // One invalid field invalidates the whole struct.
                            return OpResult::ValueInvalidated;
                        }
                        OpResult::NoChange => {}
                    }
                }
                match input_consumed {
                    true => OpResult::InputConsumed,
                    false => OpResult::NoChange,
                }
            }
            ValueInfo::EnumConstruct { var_info, .. } => var_info.apply_deconstruct(ctx, stmt),
            ValueInfo::Interchangeable(_) => OpResult::NoChange,
        }
    }

    /// Updates this value to its state before entering `arm` of an enum match.
    ///
    /// `input` describes the matched value. The result has the same meaning as in
    /// `apply_deconstruct`, with `input` playing the role of the statement's input.
    ///
    /// # Panics
    /// If an `EnumConstruct` is met and the arm selector is not `MatchArmSelector::VariantId`.
    fn apply_match_arm(&mut self, input: &ValueInfo<'db>, arm: &MatchArm<'db>) -> OpResult {
        match self {
            ValueInfo::Var(var_usage) => {
                // A variable bound by the arm does not exist before the arm.
                if arm.var_ids == [var_usage.var_id] {
                    OpResult::ValueInvalidated
                } else {
                    OpResult::NoChange
                }
            }
            ValueInfo::StructConstruct { ty: _, var_infos } => {
                let mut input_consumed = false;
                for var_info in var_infos.iter_mut() {
                    match var_info.apply_match_arm(input, arm) {
                        OpResult::InputConsumed => {
                            input_consumed = true;
                        }
                        OpResult::ValueInvalidated => return OpResult::ValueInvalidated,
                        OpResult::NoChange => {}
                    }
                }
                if input_consumed {
                    return OpResult::InputConsumed;
                }
                OpResult::NoChange
            }
            ValueInfo::EnumConstruct { var_info, variant } => {
                let MatchArmSelector::VariantId(arm_variant) = &arm.arm_selector else {
                    panic!("Enum construct should not appear in value match");
                };
                if *variant == *arm_variant {
                    // Re-wrapping the arm's payload in the arm's own variant reproduces the
                    // matched value. This holds when the payload is either interchangeable or
                    // exactly the variable bound by the arm.
                    let cancels_out = match **var_info {
                        ValueInfo::Interchangeable(_) => true,
                        ValueInfo::Var(var_usage) if arm.var_ids == [var_usage.var_id] => true,
                        _ => false,
                    };
                    if cancels_out {
                        *self = input.clone();
                        return OpResult::InputConsumed;
                    }
                }
                var_info.apply_match_arm(input, arm)
            }
            ValueInfo::Interchangeable(_) => OpResult::NoChange,
        }
    }
}
/// A candidate early return: what to return, and the location of the `return` it came from.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ReturnInfo<'db> {
    /// One entry per returned value, in return order.
    returned_vars: Vec<ValueInfo<'db>>,
    /// Location taken from a `BlockEnd::Return` reached by the analysis.
    location: LocationId<'db>,
}
/// Backward-analysis state: the early return still being tracked from this point, if any.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AnalyzerInfo<'db> {
    /// `None` once the analysis has given up on an early return from this point.
    opt_return_info: Option<ReturnInfo<'db>>,
}
impl<'db> AnalyzerInfo<'db> {
fn invalidated() -> Self {
AnalyzerInfo { opt_return_info: None }
}
fn invalidate(&mut self) {
*self = Self::invalidated();
}
fn apply<F>(&mut self, f: &F)
where
F: Fn(&VarUsage<'db>) -> ValueInfo<'db>,
{
let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else {
return;
};
for var_info in returned_vars.iter_mut() {
var_info.apply(f)
}
}
fn replace(&mut self, var_id: VariableId, var_info: ValueInfo<'db>) {
self.apply(&|var_usage| {
if var_usage.var_id == var_id { var_info.clone() } else { ValueInfo::Var(*var_usage) }
});
}
fn apply_deconstruct(
&mut self,
ctx: &ReturnOptimizerContext<'db, '_>,
stmt: &StatementStructDestructure<'db>,
) {
let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
let mut input_consumed = false;
for var_info in returned_vars.iter_mut() {
match var_info.apply_deconstruct(ctx, stmt) {
OpResult::InputConsumed => {
input_consumed = true;
}
OpResult::ValueInvalidated => {
self.invalidate();
return;
}
OpResult::NoChange => {}
};
}
if !(input_consumed || ctx.is_droppable(stmt.input.var_id)) {
self.invalidate();
}
}
fn apply_match_arm(&mut self, is_droppable: bool, input: &ValueInfo<'db>, arm: &MatchArm<'db>) {
let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
let mut input_consumed = false;
for var_info in returned_vars.iter_mut() {
match var_info.apply_match_arm(input, arm) {
OpResult::InputConsumed => {
input_consumed = true;
}
OpResult::ValueInvalidated => {
self.invalidate();
return;
}
OpResult::NoChange => {}
};
}
if !(input_consumed || is_droppable) {
self.invalidate();
}
}
fn try_get_early_return_info(&self) -> Option<&ReturnInfo<'db>> {
let return_info = self.opt_return_info.as_ref()?;
let mut stack = return_info.returned_vars.clone();
while let Some(var_info) = stack.pop() {
match var_info {
ValueInfo::Var(_) => {}
ValueInfo::StructConstruct { ty: _, var_infos } => stack.extend(var_infos),
ValueInfo::EnumConstruct { var_info, variant: _ } => stack.push(*var_info),
ValueInfo::Interchangeable(_) => return None,
}
}
Some(return_info)
}
}
impl<'db, 'a> Analyzer<'db, 'a> for ReturnOptimizerContext<'db, 'a> {
    type Info = AnalyzerInfo<'db>;

    /// If the return can be emitted at the very start of the block, records a fix that replaces
    /// the whole block body by the early return.
    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
        let Some(return_info) = info.try_get_early_return_info() else { return };
        self.fixes.push(FixInfo { location: (block_id, 0), return_info: return_info.clone() });
    }

    /// Moves `info` backwards over `stmt`.
    ///
    /// Struct/enum constructs are folded into the returned values, and struct destructures may
    /// cancel them out. Every other statement is not modeled, so the analysis conservatively
    /// gives up. If the return was emittable right after `stmt` but no longer is before it, a
    /// fix returning right after `stmt` is recorded.
    fn visit_stmt(
        &mut self,
        info: &mut Self::Info,
        (block_id, stmt_idx): StatementLocation,
        stmt: &'a Statement<'db>,
    ) {
        // The analysis runs backwards, so the incoming state is the one right after `stmt`.
        let after_stmt = info.try_get_early_return_info().cloned();

        match stmt {
            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
                let ty = self.lowered.variables[*output].ty;
                let var_infos = inputs.iter().map(|input| self.get_var_info(input)).collect();
                info.replace(*output, ValueInfo::StructConstruct { ty, var_infos });
            }
            Statement::StructDestructure(destructure) => info.apply_deconstruct(self, destructure),
            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
                let var_info = Box::new(self.get_var_info(input));
                info.replace(*output, ValueInfo::EnumConstruct { var_info, variant: *variant });
            }
            _ => info.invalidate(),
        }

        if let (Some(return_info), None) = (after_stmt, info.try_get_early_return_info()) {
            self.fixes.push(FixInfo { location: (block_id, stmt_idx + 1), return_info });
        }
    }

    /// Substitutes variables bound by the goto's `remapping` with the usages they map to.
    ///
    /// Variables not in the remapping are kept unchanged.
    fn visit_goto(
        &mut self,
        info: &mut Self::Info,
        _statement_location: StatementLocation,
        _target_block_id: BlockId,
        remapping: &VarRemapping<'db>,
    ) {
        info.apply(&|var_usage| {
            let source = remapping.get(&var_usage.var_id).unwrap_or(var_usage);
            ValueInfo::Var(*source)
        });
    }

    /// Merges the arms of a match (see `ReturnOptimizerContext::try_merge_match`).
    fn merge_match(
        &mut self,
        _statement_location: StatementLocation,
        match_info: &'a MatchInfo<'db>,
        infos: impl Iterator<Item = Self::Info>,
    ) -> Self::Info {
        let opt_return_info = self.try_merge_match(match_info, infos);
        AnalyzerInfo { opt_return_info }
    }

    /// Starts tracking from a `return`: every returned value is exactly its variable.
    ///
    /// The location is taken from the block's `BlockEnd::Return`.
    fn info_from_return(
        &mut self,
        (block_id, _statement_idx): StatementLocation,
        vars: &'a [VarUsage<'db>],
    ) -> Self::Info {
        let BlockEnd::Return(_, location) = &self.lowered.blocks[block_id].end else {
            unreachable!()
        };
        let returned_vars = vars.iter().copied().map(ValueInfo::Var).collect();
        AnalyzerInfo { opt_return_info: Some(ReturnInfo { returned_vars, location: *location }) }
    }
}