#[cfg(test)]
#[path = "const_folding_test.rs"]
mod test;
use std::rc::Rc;
use std::sync::Arc;
use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
use cairo_lang_filesystem::flag::FlagsGroup;
use cairo_lang_filesystem::ids::SmolStrId;
use cairo_lang_semantic::corelib::CorelibSemantic;
use cairo_lang_semantic::helper::ModuleHelper;
use cairo_lang_semantic::items::constant::{
ConstCalcInfo, ConstValue, ConstValueId, ConstantSemantic, TypeRange, canonical_felt252,
felt252_for_downcast,
};
use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
use cairo_lang_semantic::items::structure::StructSemantic;
use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
use cairo_lang_semantic::{
ConcreteTypeId, ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId,
corelib,
};
use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use cairo_lang_utils::{Intern, extract_matches, require, try_extract_matches};
use itertools::{chain, zip_eq};
use num_bigint::BigInt;
use num_integer::Integer;
use num_traits::cast::ToPrimitive;
use num_traits::{Num, One, Zero};
use salsa::Database;
use starknet_types_core::felt::Felt as Felt252;
use crate::db::LoweringGroup;
use crate::ids::{
ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
SpecializedFunction,
};
use crate::specialization::SpecializationArg;
use crate::utils::InliningStrategy;
use crate::{
Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchEnumInfo,
MatchExternInfo, MatchInfo, Statement, StatementCall, StatementConst, StatementDesnap,
StatementEnumConstruct, StatementIntoBox, StatementSnapshot, StatementStructConstruct,
StatementStructDestructure, StatementUnbox, VarRemapping, VarUsage, Variable, VariableArena,
VariableId,
};
/// Converts a constant value into a [`SpecializationArg`].
///
/// Constants of struct, tuple and fixed-size-array types are decomposed member by member, and
/// enum constants are decomposed into their variant plus a recursively converted payload, so each
/// piece can be specialized on independently. Any other constant (including a `Struct` constant
/// of some other type) is kept whole as [`SpecializationArg::Const`] with the given `boxed` flag.
///
/// `boxed` only applies at the top level: nested members and payloads are always converted with
/// `boxed == false`.
fn const_to_specialization_arg<'db>(
    db: &'db dyn Database,
    value: ConstValueId<'db>,
    boxed: bool,
) -> SpecializationArg<'db> {
    let keep_whole = || SpecializationArg::Const { value, boxed };
    match value.long(db) {
        ConstValue::Enum(variant, payload) => {
            let payload_arg = const_to_specialization_arg(db, *payload, false);
            SpecializationArg::Enum { variant: *variant, payload: Box::new(payload_arg) }
        }
        ConstValue::Struct(members, ty) => {
            let is_aggregate = matches!(
                ty.long(db),
                TypeLongId::Concrete(ConcreteTypeId::Struct(_))
                    | TypeLongId::Tuple(_)
                    | TypeLongId::FixedSizeArray { .. }
            );
            if !is_aggregate {
                return keep_whole();
            }
            let member_args = members
                .iter()
                .map(|member| const_to_specialization_arg(db, *member, false))
                .collect();
            SpecializationArg::Struct(member_args)
        }
        _ => keep_whole(),
    }
}
/// What is statically known about the value held by a lowering variable.
///
/// Nested information is shared through `Rc`, so cloning a `VarInfo` tree is cheap.
#[derive(Debug, Clone)]
enum VarInfo<'db> {
/// The variable holds this constant value.
Const(ConstValueId<'db>),
/// The variable is equivalent to another variable (e.g. the output of `x - 0` is `x`).
Var(VarUsage<'db>),
/// The variable is a snapshot of a value described by the inner info.
Snapshot(Rc<VarInfo<'db>>),
/// The variable was built by a `StructConstruct`; each member is described when known
/// (`None` when nothing is known about that member).
Struct(Vec<Option<Rc<VarInfo<'db>>>>),
/// The variable is an enum of a known variant whose payload is described by `payload`.
Enum { variant: ConcreteVariant<'db>, payload: Rc<VarInfo<'db>> },
/// The variable is a box whose content is described by the inner info.
Box(Rc<VarInfo<'db>>),
/// The variable is an array of known length built with `array_new`/`array_append`; each
/// element is described when known.
Array(Vec<Option<Rc<VarInfo<'db>>>>),
}
impl<'db> VarInfo<'db> {
    /// Strips every outer `Snapshot` layer.
    ///
    /// Returns the number of layers removed together with the innermost, non-snapshot info.
    fn peel_snapshots(self: Rc<Self>) -> (usize, Rc<VarInfo<'db>>) {
        let mut depth = 0;
        let mut current = self;
        loop {
            let next = match current.as_ref() {
                VarInfo::Snapshot(inner) => Rc::clone(inner),
                _ => break,
            };
            current = next;
            depth += 1;
        }
        (depth, current)
    }

    /// Wraps `self` in `n_snapshots` `Snapshot` layers; the inverse of
    /// [`Self::peel_snapshots`].
    fn wrap_with_snapshots(self: Rc<Self>, n_snapshots: usize) -> Rc<VarInfo<'db>> {
        (0..n_snapshots).fold(self, |wrapped, _| Rc::new(VarInfo::Snapshot(wrapped)))
    }
}
/// How a not-yet-visited block was discovered to be reachable.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Reachability {
/// So far the only known way into the block is a `Goto` ending the given block, so constants
/// flowing through that goto's remapping may be propagated into the block.
FromSingleGoto(BlockId),
/// The block may be entered in some other way (the root block, a match arm, or more than one
/// goto); nothing is assumed about its remapped inputs.
Any,
}
pub fn const_folding<'db>(
db: &'db dyn Database,
function_id: ConcreteFunctionWithBodyId<'db>,
lowered: &mut Lowered<'db>,
) {
if lowered.blocks.is_empty() {
return;
}
let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
if ctx.should_skip_const_folding(db) {
return;
}
for block_id in (0..lowered.blocks.len()).map(BlockId) {
if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
continue;
}
let block = &mut lowered.blocks[block_id];
for stmt in block.statements.iter_mut() {
ctx.visit_statement(stmt);
}
ctx.visit_block_end(block_id, block);
}
}
/// State for one constant-folding pass over a single function body.
pub struct ConstFoldingContext<'db, 'mt> {
/// The database.
db: &'db dyn Database,
/// The variables of the function being folded; new variables are allocated here.
pub variables: &'mt mut VariableArena<'db>,
/// What is known about variables, collected while visiting statements.
var_info: UnorderedHashMap<VariableId, Rc<VarInfo<'db>>>,
/// Precomputed libfunc information, obtained from `priv_const_folding_info`.
libfunc_info: &'db ConstFoldingLibfuncInfo<'db>,
/// The function whose body is being folded.
caller_function: ConcreteFunctionWithBodyId<'db>,
/// Blocks discovered as reachable but not yet visited (entries are removed on visit).
reachability: UnorderedHashMap<BlockId, Reachability>,
/// Statements generated while visiting the current block; `visit_block_end` splices them in
/// at the start of the block's statement list.
additional_stmts: Vec<Statement<'db>>,
}
impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
/// Creates a folding context for the body of `function_id`, whose variables are `variables`.
///
/// Initially only the root block is marked reachable.
pub fn new(
    db: &'db dyn Database,
    function_id: ConcreteFunctionWithBodyId<'db>,
    variables: &'mt mut VariableArena<'db>,
) -> Self {
    let libfunc_info = priv_const_folding_info(db);
    let reachability = UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]);
    Self {
        db,
        variables,
        var_info: UnorderedHashMap::default(),
        libfunc_info,
        caller_function: function_id,
        reachability,
        additional_stmts: Vec::new(),
    }
}
pub fn visit_block_start<'r, 'get>(
&'r mut self,
block_id: BlockId,
get_block: impl FnOnce(BlockId) -> &'get Block<'db>,
) -> bool
where
'db: 'get,
{
let Some(reachability) = self.reachability.remove(&block_id) else {
return false;
};
match reachability {
Reachability::Any => {}
Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
BlockEnd::Goto(_, remapping) => {
for (dst, src) in remapping.iter() {
if let Some(v) = self.as_const(src.var_id) {
self.var_info.insert(*dst, VarInfo::Const(v).into());
}
}
}
_ => unreachable!("Expected a goto end"),
},
}
true
}
/// Analyzes one statement and records what becomes known about its outputs in `var_info`.
///
/// Where possible, the statement is also replaced in place by a simpler one (for example a
/// constant, or a call to a specialized function).
pub fn visit_statement(&mut self, stmt: &mut Statement<'db>) {
// Let `maybe_replace_inputs` (defined elsewhere in this file) rewrite the inputs first.
self.maybe_replace_inputs(stmt.inputs_mut());
match stmt {
// Boxed constant: the output is a box whose content is that constant.
Statement::Const(StatementConst { value, output, boxed }) if *boxed => {
self.var_info.insert(*output, VarInfo::Box(VarInfo::Const(*value).into()).into());
}
// Flat constant: only `Int`, `Struct`, `Enum` and `NonZero` values are recorded;
// `Generic`, `ImplConstant`, `Var` and `Missing` values are ignored.
Statement::Const(StatementConst { value, output, .. }) => match value.long(self.db) {
ConstValue::Int(..)
| ConstValue::Struct(..)
| ConstValue::Enum(..)
| ConstValue::NonZero(..) => {
self.var_info.insert(*output, VarInfo::Const(*value).into());
}
ConstValue::Generic(_)
| ConstValue::ImplConstant(_)
| ConstValue::Var(..)
| ConstValue::Missing(_) => {}
},
// Snapshot: the original output keeps the input's info; the snapshot output gets it
// wrapped in a `Snapshot` layer.
Statement::Snapshot(stmt) => {
if let Some(info) = self.var_info.get(&stmt.input.var_id) {
let info = info.clone();
self.var_info.insert(stmt.original(), info.clone());
self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info).into());
}
}
// Desnap: propagates only when the input is known to be a snapshot of something.
Statement::Desnap(StatementDesnap { input, output }) => {
if let Some(info) = self.var_info.get(&input.var_id)
&& let VarInfo::Snapshot(info) = info.as_ref()
{
self.var_info.insert(*output, info.clone());
}
}
// Call: first try folding a known libfunc call, otherwise try specializing the callee on
// the known arguments. `handle_statement_call` may also rewrite the call in place while
// returning `None`, in which case specialization is attempted on the rewritten call.
Statement::Call(call_stmt) => {
if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
*stmt = updated_stmt;
} else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
*stmt = updated_stmt;
}
}
// StructConstruct: if every input is a constant (trivially true for no inputs), the output
// is a constant struct. Otherwise, if at least one input has info, record per-member info.
// Inputs without info fall back to `var_info_if_copy` (defined elsewhere).
Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
let mut const_args = vec![];
let mut all_args = vec![];
let mut contains_info = false;
for input in inputs.iter() {
let Some(info) = self.var_info.get(&input.var_id) else {
all_args.push(var_info_if_copy(self.variables, *input));
continue;
};
contains_info = true;
if let VarInfo::Const(value) = info.as_ref() {
const_args.push(*value);
}
all_args.push(Some(info.clone()));
}
if const_args.len() == inputs.len() {
let value =
ConstValue::Struct(const_args, self.variables[*output].ty).intern(self.db);
self.var_info.insert(*output, VarInfo::Const(value).into());
} else if contains_info {
self.var_info.insert(*output, VarInfo::Struct(all_args).into());
}
}
// StructDestructure: distribute a known struct's members to the outputs, re-applying the
// snapshot layers that wrapped the input.
Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
if let Some(info) = self.var_info.get(&input.var_id) {
let (n_snapshots, info) = info.clone().peel_snapshots();
match info.as_ref() {
VarInfo::Const(const_value) => {
if let ConstValue::Struct(member_values, _) = const_value.long(self.db)
{
for (output, value) in zip_eq(outputs, member_values) {
self.var_info.insert(
*output,
Rc::new(VarInfo::Const(*value))
.wrap_with_snapshots(n_snapshots),
);
}
}
}
VarInfo::Struct(members) => {
for (output, member) in zip_eq(outputs, members.clone()) {
if let Some(member) = member {
self.var_info
.insert(*output, member.wrap_with_snapshots(n_snapshots));
}
}
}
_ => {}
}
}
}
// EnumConstruct: the variant is always known. A constant payload makes the whole enum
// constant; otherwise the payload is described by its info, or by the input variable itself.
Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
let value = if let Some(info) = self.var_info.get(&input.var_id) {
if let VarInfo::Const(val) = info.as_ref() {
VarInfo::Const(ConstValue::Enum(*variant, *val).intern(self.db))
} else {
VarInfo::Enum { variant: *variant, payload: info.clone() }
}
} else {
VarInfo::Enum { variant: *variant, payload: VarInfo::Var(*input).into() }
};
self.var_info.insert(*output, value.into());
}
// IntoBox: record the boxed info. If the input is a constant, or a snapshot of one, the
// statement is replaced by a boxed constant.
Statement::IntoBox(StatementIntoBox { input, output }) => {
let var_info = self.var_info.get(&input.var_id);
let const_value = var_info.and_then(|var_info| match var_info.as_ref() {
VarInfo::Const(val) => Some(*val),
VarInfo::Snapshot(info) => {
try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
}
_ => None,
});
let var_info =
var_info.cloned().or_else(|| var_info_if_copy(self.variables, *input));
if let Some(var_info) = var_info {
self.var_info.insert(*output, VarInfo::Box(var_info).into());
}
if let Some(const_value) = const_value {
*stmt = Statement::Const(StatementConst::new_boxed(const_value, *output));
}
}
// Unbox: unwrap known box info. If the content is a constant, the statement is replaced by
// a flat constant.
Statement::Unbox(StatementUnbox { input, output }) => {
if let Some(inner) = self.var_info.get(&input.var_id)
&& let VarInfo::Box(inner) = inner.as_ref()
{
let inner = inner.clone();
if let VarInfo::Const(inner) =
self.var_info.entry(*output).insert_entry(inner).get().as_ref()
{
*stmt = Statement::Const(StatementConst::new_flat(*inner, *output));
}
}
}
}
}
/// Finishes processing `block` (identified by `block_id`).
///
/// Prepends the statements queued in `additional_stmts`, then tries to simplify the block end
/// using known values. Finally it records the reachability of the block's successors.
pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut Block<'db>) {
let statements = &mut block.statements;
// Queued statements are inserted before all of the block's existing statements.
statements.splice(0..0, self.additional_stmts.drain(..));
match &mut block.end {
BlockEnd::Goto(_, remappings) => {
for (_, v) in remappings.iter_mut() {
self.maybe_replace_input(v);
}
}
BlockEnd::Match { info } => {
self.maybe_replace_inputs(info.inputs_mut());
match info {
MatchInfo::Enum(info) => {
if let Some(updated_end) = self.handle_enum_block_end(info, statements) {
block.end = updated_end;
}
}
MatchInfo::Extern(info) => {
if let Some(updated_end) = self.handle_extern_block_end(info, statements) {
block.end = updated_end;
}
}
// A match on a known integer becomes a goto to the arm selecting that value. The
// arm's variable is produced by an input-less struct construct (presumably it is
// unit-typed; not checked here).
MatchInfo::Value(info) => {
if let Some(value) =
self.as_int(info.input.var_id).and_then(|x| x.to_usize())
&& let Some(arm) = info.arms.iter().find(|arm| {
matches!(
&arm.arm_selector,
MatchArmSelector::Value(v) if v.value == value
)
})
{
statements.push(Statement::StructConstruct(StatementStructConstruct {
inputs: vec![],
output: arm.var_ids[0],
}));
block.end = BlockEnd::Goto(arm.block_id, Default::default());
}
}
}
}
BlockEnd::Return(inputs, _) => self.maybe_replace_inputs(inputs),
// Panic and NotSet ends are not expected when this pass runs.
BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
}
// Record successor reachability, using the (possibly simplified) block end.
match &block.end {
BlockEnd::Goto(dst_block_id, _) => {
// The first goto into a block marks it `FromSingleGoto`; any later one makes it `Any`.
match self.reachability.entry(*dst_block_id) {
std::collections::hash_map::Entry::Occupied(mut e) => {
e.insert(Reachability::Any)
}
std::collections::hash_map::Entry::Vacant(e) => {
*e.insert(Reachability::FromSingleGoto(block_id))
}
};
}
// Match arm blocks are marked `Any`; each must not have been discovered before.
BlockEnd::Match { info } => {
for arm in info.arms() {
assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
}
}
BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
}
}
/// Tries to fold a call to one of the known libfuncs/externs.
///
/// Returns `Some(statement)` when the call should be replaced by `statement`. Even when it
/// returns `None`, it may have done one of the following:
/// - rewritten `stmt` in place (switching to a const-generic variant of the callee);
/// - recorded info about the call's outputs in `var_info`;
/// - queued statements in `additional_stmts`, which `visit_block_end` places at the start of the
///   current block.
fn handle_statement_call(&mut self, stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
let db = self.db;
// `panic_with_felt252(c)` with a constant `c` becomes an input-less call to
// `panic_with_const_felt252::<c>`.
if stmt.function == self.panic_with_felt252 {
let val = self.as_const(stmt.inputs[0].var_id)?;
stmt.inputs.clear();
stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
.concretize(db, vec![GenericArgumentId::Constant(val)])
.lowered(db);
return None;
} else if stmt.function == self.panic_with_byte_array && !db.flag_unsafe_panic() {
// `panic_with_byte_array` (skipped when the `unsafe_panic` flag is set). The argument must
// be a snapshot of a struct whose three members are all known: an array of constant words,
// a constant pending word and a constant pending length.
let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
let bytearray = try_extract_matches!(snap.as_ref(), VarInfo::Snapshot)?;
let [Some(data), Some(pending_word), Some(pending_len)] =
&try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
else {
return None;
};
let data = try_extract_matches!(data.as_ref(), VarInfo::Array)?;
let pending_word = try_extract_matches!(pending_word.as_ref(), VarInfo::Const)?;
let pending_len = try_extract_matches!(pending_len.as_ref(), VarInfo::Const)?;
// Panic data layout: [BYTE_ARRAY_MAGIC, data.len(), data words..., pending_word,
// pending_len].
let mut panic_data =
vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
for word in data {
let VarInfo::Const(word) = word.as_ref()?.as_ref() else {
return None;
};
panic_data.push(word.long(db).to_int()?.clone());
}
panic_data.extend([
pending_word.long(db).to_int()?.clone(),
pending_len.long(db).to_int()?.clone(),
]);
// From here on nothing can fail. The panic array is built with `array_new` followed by
// one const + `array_append` pair per word, all queued in `additional_stmts`.
let felt252_ty = self.felt252;
let location = stmt.location;
let new_var = |ty| Variable::with_default_context(db, ty, location);
let as_usage = |var_id| VarUsage { var_id, location };
let array_fn = |extern_id| {
let args = vec![GenericArgumentId::Type(felt252_ty)];
GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
};
let call_stmt = |function, inputs, outputs| {
let with_coupon = false;
Statement::Call(StatementCall {
function,
inputs,
with_coupon,
outputs,
location,
is_specialization_base_call: false,
})
};
let arr_var = new_var(corelib::core_array_felt252_ty(db));
let mut arr = self.variables.alloc(arr_var.clone());
self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
let felt252_var = new_var(felt252_ty);
let arr_append_fn = array_fn(self.array_append);
for word in panic_data {
let to_append = self.variables.alloc(felt252_var.clone());
let new_arr = self.variables.alloc(arr_var.clone());
self.additional_stmts.push(Statement::Const(StatementConst::new_flat(
ConstValue::Int(word, felt252_ty).intern(db),
to_append,
)));
self.additional_stmts.push(call_stmt(
arr_append_fn,
vec![as_usage(arr), as_usage(to_append)],
vec![new_arr],
));
arr = new_arr;
}
// The call is replaced by constructing its output from a `Panic` value and the array.
let panic_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "Panic"), vec![]);
let panic_var = self.variables.alloc(new_var(panic_ty));
self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
inputs: vec![],
output: panic_var,
}));
return Some(Statement::StructConstruct(StatementStructConstruct {
inputs: vec![as_usage(panic_var), as_usage(arr)],
output: stmt.outputs[0],
}));
}
// Everything below only applies to calls of extern functions.
let (id, _generic_args) = stmt.function.get_extern(db)?;
if id == self.felt_sub {
// `x - 0` is recorded as equivalent to `x` (the call itself is kept); two constants fold.
if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
&& rhs.is_zero()
{
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
None
} else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
&& let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
{
let value = canonical_felt252(&(lhs - rhs));
Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
} else {
None
}
} else if id == self.felt_add {
// `0 + x` / `x + 0` are recorded as equivalent to `x`; two constants fold.
if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
&& lhs.is_zero()
{
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]).into());
None
} else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
&& rhs.is_zero()
{
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
None
} else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
&& let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
{
let value = canonical_felt252(&(lhs + rhs));
Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
} else {
None
}
} else if id == self.felt_mul {
// A zero operand folds to `0`; `x * 1` / `1 * x` are recorded as `x`; two constants fold.
let lhs = self.as_int(stmt.inputs[0].var_id);
let rhs = self.as_int(stmt.inputs[1].var_id);
if lhs.map(Zero::is_zero).unwrap_or_default()
|| rhs.map(Zero::is_zero).unwrap_or_default()
{
Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
} else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
&& rhs.is_one()
{
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
None
} else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
&& lhs.is_one()
{
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]).into());
None
} else if let Some(lhs) = lhs
&& let Some(rhs) = rhs
{
let value = canonical_felt252(&(lhs * rhs));
Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
} else {
None
}
} else if id == self.felt_div {
// `x / 1` is recorded as `x`; `0 / x` folds to `0`; two constants (with a non-zero
// divisor) fold using field division.
if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
&& rhs.is_one()
{
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
None
} else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
&& lhs.is_zero()
{
Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
} else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
&& let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
&& let Ok(rhs_nonzero) = Felt252::from(rhs).try_into()
{
let lhs_felt = Felt252::from(lhs);
let value = lhs_felt.field_div(&rhs_nonzero).to_bigint();
Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
} else {
None
}
} else if self.wide_mul_fns.contains(&id) {
// Wide multiplication: a zero operand folds to `0`; two constants fold to their exact
// (unreduced) product.
let lhs = self.as_int(stmt.inputs[0].var_id);
let rhs = self.as_int(stmt.inputs[1].var_id);
let output = stmt.outputs[0];
if lhs.map(Zero::is_zero).unwrap_or_default()
|| rhs.map(Zero::is_zero).unwrap_or_default()
{
return Some(self.propagate_zero_and_get_statement(output));
}
let lhs = lhs?;
Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0]))
} else if id == self.bounded_int_add || id == self.bounded_int_sub {
// Bounded-int add/sub of two constants fold to the exact result.
let lhs = self.as_int(stmt.inputs[0].var_id)?;
let rhs = self.as_int(stmt.inputs[1].var_id)?;
let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
} else if self.div_rem_fns.contains(&id) {
// Div-rem: a zero dividend gives zero quotient and remainder. Two constants fold via
// `Integer::div_rem`. The remainder's const statement is queued in `additional_stmts`,
// and the quotient's const statement replaces the call.
let lhs = self.as_int(stmt.inputs[0].var_id);
if lhs.map(Zero::is_zero).unwrap_or_default() {
let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
self.additional_stmts.push(additional_stmt);
return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
}
let rhs = self.as_int(stmt.inputs[1].var_id)?;
let (q, r) = lhs?.div_rem(rhs);
let q_output = stmt.outputs[0];
let q_value = ConstValue::Int(q, self.variables[q_output].ty).intern(db);
self.var_info.insert(q_output, VarInfo::Const(q_value).into());
let r_output = stmt.outputs[1];
let r_value = ConstValue::Int(r, self.variables[r_output].ty).intern(db);
self.var_info.insert(r_output, VarInfo::Const(r_value).into());
self.additional_stmts
.push(Statement::Const(StatementConst::new_flat(r_value, r_output)));
Some(Statement::Const(StatementConst::new_flat(q_value, q_output)))
} else if id == self.storage_base_address_from_felt252 {
// A constant integer input turns the call, in place, into an input-less call to
// `storage_base_address_const` specialized on that value.
let input_var = stmt.inputs[0].var_id;
if let Some(const_value) = self.as_const(input_var)
&& let ConstValue::Int(val, ty) = const_value.long(db)
{
stmt.inputs.clear();
let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
stmt.function =
self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
}
None
} else if self.upcast_fns.contains(&id) {
// Upcasting a constant yields the same integer, typed as the output.
let int_value = self.as_int(stmt.inputs[0].var_id)?;
let output = stmt.outputs[0];
let value = ConstValue::Int(int_value.clone(), self.variables[output].ty).intern(db);
self.var_info.insert(output, VarInfo::Const(value).into());
Some(Statement::Const(StatementConst::new_flat(value, output)))
} else if id == self.array_new {
// Track array contents: a new array is empty...
self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]).into());
None
} else if id == self.array_append {
// ...and appending to a tracked array extends it with the element's info (if any).
let mut var_infos = if let VarInfo::Array(var_infos) =
self.var_info.get(&stmt.inputs[0].var_id)?.as_ref()
{
var_infos.clone()
} else {
return None;
};
let appended = stmt.inputs[1];
var_infos.push(match self.var_info.get(&appended.var_id) {
Some(var_info) => Some(var_info.clone()),
None => var_info_if_copy(self.variables, appended),
});
self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos).into());
None
} else if id == self.array_len {
// The length of a snapshot of a tracked array is a constant.
let info = self.var_info.get(&stmt.inputs[0].var_id)?;
let desnapped = try_extract_matches!(info.as_ref(), VarInfo::Snapshot)?;
let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0]))
} else {
None
}
}
/// Tries to replace a call with a call to a specialization of the callee on the statically
/// known arguments.
///
/// Returns `None`, leaving the call untouched, in any of these cases:
/// - the call uses a coupon;
/// - the inlining strategy is `Avoid`;
/// - the callee's body is unavailable;
/// - the callee's base is never-inline (or that query fails);
/// - the call is a specialization base call;
/// - the callee is itself a specialization of the caller's base;
/// - the callee is in a multi-function call SCC containing the caller's base;
/// - no argument has known info, or no argument could be specialized;
/// - `priv_should_specialize` rejects it (checked only when the bases differ).
fn try_specialize_call(&self, call_stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
if call_stmt.with_coupon {
return None;
}
if matches!(self.db.optimizations().inlining_strategy(), InliningStrategy::Avoid) {
return None;
}
let Ok(Some(mut called_function)) = call_stmt.function.body(self.db) else {
return None;
};
// The non-specialized function a (possibly specialized) function was derived from.
let extract_base = |function: ConcreteFunctionWithBodyId<'db>| match function.long(self.db)
{
ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
specialized.long(self.db).base
}
_ => function,
};
let called_base = extract_base(called_function);
let caller_base = extract_base(self.caller_function);
if self.db.priv_never_inline(called_base).ok()? {
return None;
}
if call_stmt.is_specialization_base_call {
return None;
}
// A recursive call whose callee is already a specialization of the same base.
if called_base == caller_base && called_function != called_base {
return None;
}
// Mutual recursion between the caller and the callee.
let scc =
self.db.lowered_scc(called_base, DependencyType::Call, LoweringStage::Monomorphized);
if scc.len() > 1 && scc.contains(&caller_base) {
return None;
}
if call_stmt.inputs.iter().all(|arg| self.var_info.get(&arg.var_id).is_none()) {
return None;
}
// When the caller is a specialization of the callee's own base, each argument is paired with
// the caller's specialization arg at the same position. That arg is passed as the `coerce`
// hint to `try_get_specialization_arg` (defined elsewhere); otherwise there are no hints.
let self_specializition = if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
self.caller_function.long(self.db)
&& caller_base == called_base
{
specialized.long(self.db).args.iter().map(Some).collect()
} else {
vec![None; call_stmt.inputs.len()]
};
// Build one specialization arg per input. An input is specialized only if it has info, is
// droppable, and `try_get_specialization_arg` succeeds. Otherwise it is passed through as-is
// (note that `try_get_specialization_arg` also receives `new_args`).
let mut specialization_args = vec![];
let mut new_args = vec![];
for (arg, coerce) in zip_eq(&call_stmt.inputs, &self_specializition) {
if let Some(var_info) = self.var_info.get(&arg.var_id)
&& self.variables[arg.var_id].info.droppable.is_ok()
&& let Some(specialization_arg) = self.try_get_specialization_arg(
var_info.clone(),
self.variables[arg.var_id].ty,
&mut new_args,
*coerce,
)
{
specialization_args.push(specialization_arg);
} else {
specialization_args.push(SpecializationArg::NotSpecialized);
new_args.push(*arg);
continue;
};
}
if specialization_args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized)) {
return None;
}
// If the callee is already a specialization, specialize its base instead. The new args fill
// the `NotSpecialized` holes of the existing args in depth-first, left-to-right order.
if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
called_function.long(self.db)
{
let specialized_function = specialized_function.long(self.db);
called_function = specialized_function.base;
let mut new_args_iter = specialization_args.into_iter();
let mut old_args = specialized_function.args.clone();
let mut stack = vec![];
for arg in old_args.iter_mut().rev() {
stack.push(arg);
}
while let Some(arg) = stack.pop() {
match arg {
SpecializationArg::Const { .. } => {}
SpecializationArg::Snapshot(inner) => {
stack.push(inner.as_mut());
}
SpecializationArg::Enum { payload, .. } => {
stack.push(payload.as_mut());
}
SpecializationArg::Array(_, values) | SpecializationArg::Struct(values) => {
for value in values.iter_mut().rev() {
stack.push(value);
}
}
SpecializationArg::NotSpecialized => {
*arg = new_args_iter.next().unwrap_or(SpecializationArg::NotSpecialized);
}
}
}
specialization_args = old_args;
}
let specialized = SpecializedFunction { base: called_function, args: specialization_args }
.intern(self.db);
let specialized_func_id =
ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);
// A query error is not treated as a rejection.
if caller_base != called_base
&& self.db.priv_should_specialize(specialized_func_id) == Ok(false)
{
return None;
}
// The outputs are moved out of the original call, which this statement replaces.
Some(Statement::Call(StatementCall {
function: specialized_func_id.function_id(self.db).unwrap(),
inputs: new_args,
with_coupon: call_stmt.with_coupon,
outputs: std::mem::take(&mut call_stmt.outputs),
location: call_stmt.location,
is_specialization_base_call: false,
}))
}
fn propagate_const_and_get_statement(
&mut self,
value: BigInt,
output: VariableId,
) -> Statement<'db> {
let ty = self.variables[output].ty;
let value = ConstValueId::from_int(self.db, ty, &value);
self.var_info.insert(output, VarInfo::Const(value).into());
Statement::Const(StatementConst::new_flat(value, output))
}
fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement<'db> {
self.propagate_const_and_get_statement(BigInt::zero(), output)
}
/// Returns a statement that assigns the constant `value` to `output`, if one can be generated.
///
/// The cases, in order:
/// - if `type_size_info` of the output's type is `TypeSizeInformation::Other`, a flat const
///   statement is returned;
/// - otherwise an empty-struct value is produced with an input-less `StructConstruct`;
/// - anything else yields `None`.
fn try_generate_const_statement(
    &self,
    value: ConstValueId<'db>,
    output: VariableId,
) -> Option<Statement<'db>> {
    let size_info = self.db.type_size_info(self.variables[output].ty);
    if size_info == Ok(TypeSizeInformation::Other) {
        return Some(Statement::Const(StatementConst::new_flat(value, output)));
    }
    match value.long(self.db) {
        ConstValue::Struct(members, _) if members.is_empty() => {
            Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
        }
        _ => None,
    }
}
/// Tries to resolve a match on an enum whose variant is statically known.
///
/// The matched input's info is looked up, after peeling its snapshot layers.
///
/// If it is a constant enum, the taken arm's variable is recorded as the payload constant,
/// re-wrapped in the peeled layers. The match becomes a `Goto` to that arm when all of these
/// hold:
/// - the input is droppable;
/// - the arm's variable is copyable;
/// - the payload's type is known;
/// - a constant statement can be generated for it.
///
/// In that case the constant, plus one snapshot statement per peeled layer, is appended to
/// `statements`.
///
/// If it is an enum of a known variant with a non-constant payload, the arm's variable is
/// recorded as the payload's info. The match becomes a `Goto` when the input is droppable and
/// the payload is a copyable variable or a materializable constant. Snapshot statements are
/// then appended for all snapshot layers.
///
/// Returns the replacement block end, or `None` if the match is kept.
fn handle_enum_block_end(
    &mut self,
    info: &mut MatchEnumInfo<'db>,
    statements: &mut Vec<Statement<'db>>,
) -> Option<BlockEnd<'db>> {
    let input = info.input.var_id;
    let (n_snapshots, var_info) = self.var_info.get(&input)?.clone().peel_snapshots();
    let location = info.location;
    let as_usage = |var_id| VarUsage { var_id, location };
    let db = self.db;
    // Builds a statement snapshotting `pre_snap` into `post_snap`; the original output goes to
    // a fresh, ignored variable.
    let snapshot_stmt = |vars: &mut VariableArena<'_>, pre_snap, post_snap| {
        let ignored = vars.alloc(vars[pre_snap].clone());
        Statement::Snapshot(StatementSnapshot::new(as_usage(pre_snap), ignored, post_snap))
    };
    if let VarInfo::Const(const_value) = var_info.as_ref()
        && let ConstValue::Enum(variant, value) = const_value.long(db)
    {
        let arm = &info.arms[variant.idx];
        let output = arm.var_ids[0];
        self.var_info
            .insert(output, Rc::new(VarInfo::Const(*value)).wrap_with_snapshots(n_snapshots));
        if self.variables[input].info.droppable.is_ok()
            && self.variables[output].info.copyable.is_ok()
            && let Ok(mut ty) = value.ty(db)
            && let Some(mut stmt) = self.try_generate_const_statement(*value, output)
        {
            // `layers[k]` holds the payload wrapped in `k` snapshots; the last layer is
            // `output` itself. The constant is materialized into `layers[0]`, and each layer is
            // snapshotted into the next. Every snapshot's *snapshot* output is the next layer,
            // so each layer variable is assigned exactly once.
            let mut layers = Vec::with_capacity(n_snapshots + 1);
            for _ in 0..n_snapshots {
                let layer_var = Variable::with_default_context(db, ty, location);
                layers.push(self.variables.alloc(layer_var));
                ty = TypeLongId::Snapshot(ty).intern(db);
            }
            layers.push(output);
            stmt.outputs_mut()[0] = layers[0];
            statements.push(stmt);
            for pair in layers.windows(2) {
                statements.push(snapshot_stmt(self.variables, pair[0], pair[1]));
            }
            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
        }
    } else if let VarInfo::Enum { variant, payload } = var_info.as_ref() {
        let arm = &info.arms[variant.idx];
        let variant_ty = variant.ty;
        let output = arm.var_ids[0];
        let payload = payload.clone();
        // The payload can be forwarded directly only if the matched input may be dropped. After
        // peeling its own snapshot layers, the payload must also be a copyable variable or a
        // constant that can be materialized.
        let unwrapped =
            self.variables[input].info.droppable.is_ok().then_some(()).and_then(|_| {
                let (extra_snapshots, inner) = payload.clone().peel_snapshots();
                match inner.as_ref() {
                    VarInfo::Var(var) if self.variables[var.var_id].info.copyable.is_ok() => {
                        Some((var.var_id, extra_snapshots))
                    }
                    VarInfo::Const(value) => {
                        // NOTE(review): `const_var` is typed as the variant's type. When
                        // `extra_snapshots > 0` that type presumably already includes those
                        // snapshot layers, while `value` is the un-snapshotted constant.
                        // Confirm whether `value.ty(db)` was intended here.
                        let const_var = self
                            .variables
                            .alloc(Variable::with_default_context(db, variant_ty, location));
                        statements.push(self.try_generate_const_statement(*value, const_var)?);
                        Some((const_var, extra_snapshots))
                    }
                    _ => None,
                }
            });
        self.var_info.insert(output, payload.wrap_with_snapshots(n_snapshots));
        if let Some((mut unwrapped, extra_snapshots)) = unwrapped {
            // Re-apply both the matched input's and the payload's snapshot layers, chaining
            // through fresh variables and ending in `output`.
            let total_snapshots = n_snapshots + extra_snapshots;
            if total_snapshots != 0 {
                for _ in 1..total_snapshots {
                    let ty = TypeLongId::Snapshot(self.variables[unwrapped].ty).intern(db);
                    let non_snap_var = Variable::with_default_context(self.db, ty, location);
                    let snapped = self.variables.alloc(non_snap_var);
                    statements.push(snapshot_stmt(self.variables, unwrapped, snapped));
                    unwrapped = snapped;
                }
                statements.push(snapshot_stmt(self.variables, unwrapped, output));
            };
            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
        }
    }
    None
}
fn handle_extern_block_end(
&mut self,
info: &mut MatchExternInfo<'db>,
statements: &mut Vec<Statement<'db>>,
) -> Option<BlockEnd<'db>> {
let db = self.db;
let (id, generic_args) = info.function.get_extern(db)?;
if self.nz_fns.contains(&id) {
let val = self.as_const(info.inputs[0].var_id)?;
let is_zero = match val.long(db) {
ConstValue::Int(v, _) => v.is_zero(),
ConstValue::Struct(s, _) => s.iter().all(|v| {
v.long(db).to_int().expect("Expected ConstValue::Int for size").is_zero()
}),
_ => unreachable!(),
};
Some(if is_zero {
BlockEnd::Goto(info.arms[0].block_id, Default::default())
} else {
let arm = &info.arms[1];
let nz_var = arm.var_ids[0];
let nz_val = ConstValue::NonZero(val).intern(db);
self.var_info.insert(nz_var, VarInfo::Const(nz_val).into());
statements.push(Statement::Const(StatementConst::new_flat(nz_val, nz_var)));
BlockEnd::Goto(arm.block_id, Default::default())
})
} else if self.eq_fns.contains(&id) {
let lhs = self.as_int(info.inputs[0].var_id);
let rhs = self.as_int(info.inputs[1].var_id);
if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
|| (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
{
let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
let var = &self.variables[nz_input.var_id].clone();
let function = self.type_info.get(&var.ty)?.is_zero;
let unused_nz_var = Variable::with_default_context(
db,
corelib::core_nonzero_ty(db, var.ty),
var.location,
);
let unused_nz_var = self.variables.alloc(unused_nz_var);
return Some(BlockEnd::Match {
info: MatchInfo::Extern(MatchExternInfo {
function,
inputs: vec![nz_input],
arms: vec![
MatchArm {
arm_selector: MatchArmSelector::VariantId(
corelib::jump_nz_zero_variant(db, var.ty),
),
block_id: info.arms[1].block_id,
var_ids: vec![],
},
MatchArm {
arm_selector: MatchArmSelector::VariantId(
corelib::jump_nz_nonzero_variant(db, var.ty),
),
block_id: info.arms[0].block_id,
var_ids: vec![unused_nz_var],
},
],
location: info.location,
}),
});
}
Some(BlockEnd::Goto(
info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
Default::default(),
))
} else if self.uadd_fns.contains(&id)
|| self.usub_fns.contains(&id)
|| self.diff_fns.contains(&id)
|| self.iadd_fns.contains(&id)
|| self.isub_fns.contains(&id)
{
let rhs = self.as_int(info.inputs[1].var_id);
let lhs = self.as_int(info.inputs[0].var_id);
if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
let ty = self.variables[info.arms[0].var_ids[0]].ty;
let range = self.type_value_ranges.get(&ty)?;
let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
lhs + rhs
} else {
lhs - rhs
};
let (arm_index, value) = match range.normalized(value) {
NormalizedResult::InRange(value) => (0, value),
NormalizedResult::Under(value) => (1, value),
NormalizedResult::Over(value) => (
if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
2
} else {
1
},
value,
),
};
let arm = &info.arms[arm_index];
let actual_output = arm.var_ids[0];
let value = ConstValue::Int(value, ty).intern(db);
self.var_info.insert(actual_output, VarInfo::Const(value).into());
statements.push(Statement::Const(StatementConst::new_flat(value, actual_output)));
return Some(BlockEnd::Goto(arm.block_id, Default::default()));
}
if let Some(rhs) = rhs {
if rhs.is_zero() && !self.diff_fns.contains(&id) {
let arm = &info.arms[0];
self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]).into());
return Some(BlockEnd::Goto(arm.block_id, Default::default()));
}
if rhs.is_one() && !self.diff_fns.contains(&id) {
let ty = self.variables[info.arms[0].var_ids[0]].ty;
let ty_info = self.type_info.get(&ty)?;
let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
ty_info.inc?
} else {
ty_info.dec?
};
let enum_ty = function.signature(db).ok()?.return_type;
let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
enum_ty.long(db)
else {
return None;
};
let result = self.variables.alloc(Variable::with_default_context(
db,
function.signature(db).unwrap().return_type,
info.location,
));
statements.push(Statement::Call(StatementCall {
function,
inputs: vec![info.inputs[0]],
with_coupon: false,
outputs: vec![result],
location: info.location,
is_specialization_base_call: false,
}));
return Some(BlockEnd::Match {
info: MatchInfo::Enum(MatchEnumInfo {
concrete_enum_id: *concrete_enum_id,
input: VarUsage { var_id: result, location: info.location },
arms: core::mem::take(&mut info.arms),
location: info.location,
}),
});
}
}
if let Some(lhs) = lhs
&& lhs.is_zero()
&& (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id))
{
let arm = &info.arms[0];
self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]).into());
return Some(BlockEnd::Goto(arm.block_id, Default::default()));
}
None
} else if let Some(reversed) = self.downcast_fns.get(&id) {
let range = |ty: TypeId<'_>| {
Some(if let Some(range) = self.type_value_ranges.get(&ty) {
range.clone()
} else {
let (min, max) = corelib::try_extract_bounded_int_type_ranges(db, ty)?;
TypeRange { min, max }
})
};
let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
let input_var = info.inputs[0].var_id;
let in_ty = self.variables[input_var].ty;
let success_output = info.arms[success_arm].var_ids[0];
let out_ty = self.variables[success_output].ty;
let out_range = range(out_ty)?;
let Some(value) = self.as_int(input_var) else {
let in_range = range(in_ty)?;
return if in_range.min < out_range.min || in_range.max > out_range.max {
None
} else {
let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
let function = db.core_info().upcast_fn.concretize(db, generic_args);
statements.push(Statement::Call(StatementCall {
function: function.lowered(db),
inputs: vec![info.inputs[0]],
with_coupon: false,
outputs: vec![success_output],
location: info.location,
is_specialization_base_call: false,
}));
return Some(BlockEnd::Goto(
info.arms[success_arm].block_id,
Default::default(),
));
};
};
let value = if in_ty == self.felt252 {
felt252_for_downcast(value, &out_range.min)
} else {
value.clone()
};
Some(if let NormalizedResult::InRange(value) = out_range.normalized(value) {
let value = ConstValue::Int(value, out_ty).intern(db);
self.var_info.insert(success_output, VarInfo::Const(value).into());
statements.push(Statement::Const(StatementConst::new_flat(value, success_output)));
BlockEnd::Goto(info.arms[success_arm].block_id, Default::default())
} else {
BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default())
})
} else if id == self.bounded_int_constrain {
let input_var = info.inputs[0].var_id;
let value = self.as_int(input_var)?;
let generic_arg = generic_args[1];
let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
.long(db)
.to_int()
.expect("Expected ConstValue::Int for size");
let arm_idx = if value < constrain_value { 0 } else { 1 };
let output = info.arms[arm_idx].var_ids[0];
statements.push(self.propagate_const_and_get_statement(value.clone(), output));
Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
} else if id == self.bounded_int_trim_min {
let input_var = info.inputs[0].var_id;
let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
return None;
};
let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
range.min == *value
} else {
corelib::try_extract_bounded_int_type_ranges(db, *ty)?.0 == *value
};
let arm_idx = if is_trimmed {
0
} else {
let output = info.arms[1].var_ids[0];
statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1
};
Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
} else if id == self.bounded_int_trim_max {
let input_var = info.inputs[0].var_id;
let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
return None;
};
let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
range.max == *value
} else {
corelib::try_extract_bounded_int_type_ranges(db, *ty)?.1 == *value
};
let arm_idx = if is_trimmed {
0
} else {
let output = info.arms[1].var_ids[0];
statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1
};
Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
} else if id == self.array_get {
let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
if let Some(arr_info) = self.var_info.get(&info.inputs[0].var_id)
&& let VarInfo::Snapshot(arr_info) = arr_info.as_ref()
&& let VarInfo::Array(infos) = arr_info.as_ref()
{
match infos.get(index) {
Some(Some(output_var_info)) => {
let arm = &info.arms[0];
let output_var_info = output_var_info.clone();
self.var_info.insert(
arm.var_ids[0],
VarInfo::Box(VarInfo::Snapshot(output_var_info.clone()).into()).into(),
);
if let VarInfo::Const(value) = output_var_info.as_ref() {
let value_ty = value.ty(db).ok()?;
let value_box_ty = corelib::core_box_ty(db, value_ty);
let location = info.location;
let boxed_var =
Variable::with_default_context(db, value_box_ty, location);
let boxed = self.variables.alloc(boxed_var.clone());
let unused_boxed = self.variables.alloc(boxed_var);
let snapped = self.variables.alloc(Variable::with_default_context(
db,
TypeLongId::Snapshot(value_box_ty).intern(db),
location,
));
statements.extend([
Statement::Const(StatementConst::new_boxed(*value, boxed)),
Statement::Snapshot(StatementSnapshot {
input: VarUsage { var_id: boxed, location },
outputs: [unused_boxed, snapped],
}),
Statement::Call(StatementCall {
function: self
.box_forward_snapshot
.concretize(db, vec![GenericArgumentId::Type(value_ty)])
.lowered(db),
inputs: vec![VarUsage { var_id: snapped, location }],
with_coupon: false,
outputs: vec![arm.var_ids[0]],
location: info.location,
is_specialization_base_call: false,
}),
]);
return Some(BlockEnd::Goto(arm.block_id, Default::default()));
}
}
None => {
return Some(BlockEnd::Goto(info.arms[1].block_id, Default::default()));
}
Some(None) => {}
}
}
if index.is_zero()
&& let [success, failure] = info.arms.as_mut_slice()
{
let arr = info.inputs[0].var_id;
let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
info.inputs.truncate(1);
info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
.concretize(db, generic_args)
.lowered(db);
success.var_ids.insert(0, unused_arr_output0);
failure.var_ids.insert(0, unused_arr_output1);
}
None
} else if id == self.array_pop_front {
let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)?.as_ref()
else {
return None;
};
if let Some(first) = var_infos.first() {
if let Some(first) = first.as_ref().cloned() {
let arm = &info.arms[0];
self.var_info
.insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()).into());
self.var_info.insert(arm.var_ids[1], VarInfo::Box(first).into());
}
None
} else {
let arm = &info.arms[1];
self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]).into());
Some(BlockEnd::Goto(
arm.block_id,
VarRemapping {
remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
},
))
}
} else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
let var_info = self.var_info.get(&info.inputs[0].var_id)?;
let desnapped = try_extract_matches!(var_info.as_ref(), VarInfo::Snapshot)?;
let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
if element_var_infos.is_empty() {
let arm = &info.arms[1];
self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]).into());
Some(BlockEnd::Goto(
arm.block_id,
VarRemapping {
remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
},
))
} else {
None
}
} else {
None
}
}
fn as_const(&self, var_id: VariableId) -> Option<ConstValueId<'db>> {
try_extract_matches!(self.var_info.get(&var_id)?.as_ref(), VarInfo::Const).copied()
}
/// Returns the integer value known for `var_id`.
///
/// Both plain `ConstValue::Int` constants and `ConstValue::NonZero` wrappers around an `Int`
/// are accepted. Only a single `NonZero` layer is looked through, and anything else yields
/// `None`.
fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
    let outer = self.as_const(var_id)?.long(self.db);
    // Peel off at most one `NonZero` wrapper.
    let unwrapped = match outer {
        ConstValue::NonZero(inner) => inner.long(self.db),
        other => other,
    };
    match unwrapped {
        ConstValue::Int(value, _) => Some(value),
        _ => None,
    }
}
/// Rewrites each usage in `inputs` in place with `maybe_replace_input`.
fn maybe_replace_inputs(&self, inputs: &mut [VarUsage<'db>]) {
    inputs.iter_mut().for_each(|input| self.maybe_replace_input(input));
}
/// Replaces `input` with the variable usage it is known to alias.
///
/// The replacement happens only when the tracked info for `input` is `VarInfo::Var`;
/// otherwise `input` is left untouched.
fn maybe_replace_input(&self, input: &mut VarUsage<'db>) {
    let alias = self.var_info.get(&input.var_id).and_then(|info| match info.as_ref() {
        VarInfo::Var(new_var) => Some(*new_var),
        _ => None,
    });
    if let Some(new_var) = alias {
        *input = new_var;
    }
}
/// Tries to describe the statically-known parts of a value as a [`SpecializationArg`].
///
/// - `var_info` is what the pass knows about the value.
/// - `ty` is the value's type.
/// - `unknown_vars` collects parts known only as plain variables (`VarInfo::Var`). Those parts
///   become `SpecializationArg::NotSpecialized`, and their usages are appended here.
///   `unknown_vars` is only extended when this call returns `Some`. Composite cases collect
///   into local vectors and merge them only on success.
/// - `coerce`, when given, is an argument the result must agree with:
///   - Mismatching constants/boxes, array lengths, or enum variants yield `None`.
///   - A coerce of `NotSpecialized` always yields `None`.
///   - A coerce whose shape is incompatible with a snapshot, array, struct, or enum
///     `var_info` panics, since it indicates a caller bug.
///
/// Also returns `None` in these cases:
/// - zero-sized types;
/// - missing type-size or struct-member information;
/// - boxes not holding a constant;
/// - arrays or structs with an element of unknown info;
/// - non-empty arrays or structs whose elements are all `NotSpecialized`;
/// - struct-like `var_info` for a type that is not a struct, tuple, or fixed-size array.
fn try_get_specialization_arg(
    &self,
    var_info: Rc<VarInfo<'db>>,
    ty: TypeId<'db>,
    unknown_vars: &mut Vec<VarUsage<'db>>,
    coerce: Option<&SpecializationArg<'db>>,
) -> Option<SpecializationArg<'db>> {
    // Zero-sized values are never specialized on.
    require(self.db.type_size_info(ty).ok()? != TypeSizeInformation::ZeroSized)?;
    // The caller demands this value stay a regular (non-specialized) argument.
    require(!matches!(coerce, Some(SpecializationArg::NotSpecialized)))?;
    match var_info.as_ref() {
        VarInfo::Const(value) => {
            let res = const_to_specialization_arg(self.db, *value, false);
            let Some(coerce) = coerce else {
                return Some(res);
            };
            if *coerce == res { Some(res) } else { None }
        }
        VarInfo::Box(info) => {
            // Only a box holding a known constant can be specialized on.
            let res = try_extract_matches!(info.as_ref(), VarInfo::Const)
                .map(|value| SpecializationArg::Const { value: *value, boxed: true });
            let Some(coerce) = coerce else {
                return res;
            };
            if Some(coerce.clone()) == res { res } else { None }
        }
        VarInfo::Snapshot(info) => {
            let desnap_ty = *extract_matches!(ty.long(self.db), TypeLongId::Snapshot);
            let mut local_unknown_vars: Vec<VarUsage<'db>> = Vec::new();
            let inner = self.try_get_specialization_arg(
                info.clone(),
                desnap_ty,
                &mut local_unknown_vars,
                coerce.map(|coerce| {
                    extract_matches!(coerce, SpecializationArg::Snapshot).as_ref()
                }),
            )?;
            unknown_vars.extend(local_unknown_vars);
            Some(SpecializationArg::Snapshot(Box::new(inner)))
        }
        VarInfo::Array(infos) => {
            let TypeLongId::Concrete(concrete_ty) = ty.long(self.db) else {
                unreachable!("Expected a concrete type");
            };
            let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
            else {
                unreachable!("Expected a single type generic argument");
            };
            let coerces = match coerce {
                Some(coerce) => {
                    let SpecializationArg::Array(ty, specialization_args) = coerce else {
                        unreachable!("Expected an array specialization argument");
                    };
                    assert_eq!(ty, inner_ty);
                    if specialization_args.len() != infos.len() {
                        return None;
                    }
                    specialization_args.iter().map(Some).collect()
                }
                None => vec![None; infos.len()],
            };
            let mut vars = vec![];
            let mut args = vec![];
            for (info, coerce) in zip_eq(infos, coerces) {
                let info = info.as_ref()?.clone();
                let arg =
                    self.try_get_specialization_arg(info, *inner_ty, &mut vars, coerce)?;
                args.push(arg);
            }
            // Nothing is gained by specializing on an array of purely unknown elements.
            if !args.is_empty()
                && args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
            {
                return None;
            }
            unknown_vars.extend(vars);
            Some(SpecializationArg::Array(*inner_ty, args))
        }
        VarInfo::Struct(infos) => {
            let element_types: Vec<TypeId<'db>> = match ty.long(self.db) {
                TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) => {
                    // Give up on specialization rather than panicking if the members cannot
                    // be computed, matching the `type_size_info(..).ok()?` handling above.
                    let members = self.db.concrete_struct_members(*concrete_struct).ok()?;
                    members.values().map(|member| member.ty).collect()
                }
                TypeLongId::Tuple(element_types) => element_types.clone(),
                TypeLongId::FixedSizeArray { type_id, .. } => vec![*type_id; infos.len()],
                _ => return None,
            };
            let coerces = match coerce {
                Some(SpecializationArg::Struct(specialization_args)) => {
                    assert_eq!(specialization_args.len(), infos.len());
                    specialization_args.iter().map(Some).collect()
                }
                Some(_) => unreachable!("Expected a struct specialization argument"),
                None => vec![None; infos.len()],
            };
            let mut struct_args = Vec::new();
            let mut vars = vec![];
            for ((elem_ty, opt_var_info), coerce) in
                zip_eq(zip_eq(element_types, infos), coerces)
            {
                let var_info = opt_var_info.as_ref()?.clone();
                let arg =
                    self.try_get_specialization_arg(var_info, elem_ty, &mut vars, coerce)?;
                struct_args.push(arg);
            }
            // Nothing is gained by specializing on a struct of purely unknown members.
            if !struct_args.is_empty()
                && struct_args
                    .iter()
                    .all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
            {
                return None;
            }
            unknown_vars.extend(vars);
            Some(SpecializationArg::Struct(struct_args))
        }
        VarInfo::Enum { variant, payload } => {
            let coerce = match coerce {
                Some(coerce) => {
                    let SpecializationArg::Enum { variant: coercion_variant, payload } = coerce
                    else {
                        unreachable!("Expected an enum specialization argument");
                    };
                    if coercion_variant != variant {
                        return None;
                    }
                    Some(payload.as_ref())
                }
                None => None,
            };
            let mut local_unknown_vars = vec![];
            let payload_arg = self.try_get_specialization_arg(
                payload.clone(),
                variant.ty,
                &mut local_unknown_vars,
                coerce,
            )?;
            unknown_vars.extend(local_unknown_vars);
            Some(SpecializationArg::Enum { variant: *variant, payload: Box::new(payload_arg) })
        }
        VarInfo::Var(var_usage) => {
            // Unknown value: pass it through as a regular argument.
            unknown_vars.push(*var_usage);
            Some(SpecializationArg::NotSpecialized)
        }
    }
}
/// Returns `true` when const folding must not run for the function being optimized.
///
/// This happens in two cases:
/// - the `skip_const_folding` optimization flag is set;
/// - the caller function's generic base function is `panic_with_const_felt252` itself.
pub fn should_skip_const_folding(&self, db: &'db dyn Database) -> bool {
    db.optimizations().skip_const_folding() || {
        let caller_generic =
            self.caller_function.base_semantic_function(db).generic_function(db);
        caller_generic
            == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
    }
}
}
/// Returns info aliasing `input` (`VarInfo::Var(input)`) if the variable is copyable.
///
/// Returns `None` when the variable's `copyable` info is an error.
fn var_info_if_copy<'db>(
    variables: &VariableArena<'db>,
    input: VarUsage<'db>,
) -> Option<Rc<VarInfo<'db>>> {
    if variables[input.var_id].info.copyable.is_err() {
        return None;
    }
    Some(Rc::new(VarInfo::Var(input)))
}
/// Salsa-tracked query that builds the [`ConstFoldingLibfuncInfo`] for `db`.
///
/// Declared with `returns(ref)`, so callers get a reference to the salsa-cached value rather
/// than an owned copy.
#[salsa::tracked(returns(ref))]
fn priv_const_folding_info<'db>(
    db: &'db dyn Database,
) -> crate::optimizations::const_folding::ConstFoldingLibfuncInfo<'db> {
    ConstFoldingLibfuncInfo::new(db)
}
/// Corelib functions and per-type information used by the const-folding pass.
///
/// Built by [`ConstFoldingLibfuncInfo::new`] through the salsa-tracked
/// `priv_const_folding_info` query.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub struct ConstFoldingLibfuncInfo<'db> {
    /// `felt252_sub` extern.
    felt_sub: ExternFunctionId<'db>,
    /// `felt252_add` extern.
    felt_add: ExternFunctionId<'db>,
    /// `felt252_mul` extern.
    felt_mul: ExternFunctionId<'db>,
    /// `felt252_div` extern.
    felt_div: ExternFunctionId<'db>,
    /// Generic `box_forward_snapshot`.
    ///
    /// Used when folding `array_get` on a known array, to turn a snapshot of a boxed constant
    /// into the call's output.
    box_forward_snapshot: GenericFunctionId<'db>,
    /// `{ty}_eq` externs for `u8`..`u128` and `i8`..`i128`.
    eq_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{ty}_overflowing_add` externs for the unsigned types.
    uadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{ty}_overflowing_sub` externs for the unsigned types.
    usub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{ty}_diff` externs for the signed types.
    diff_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{ty}_overflowing_add_impl` externs for the signed types.
    iadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{ty}_overflowing_sub_impl` externs for the signed types.
    isub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_mul` plus `{ty}_wide_mul` for the 8- to 64-bit signed and unsigned types.
    wide_mul_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_div_rem` plus `{ty}_safe_divmod` for the unsigned types.
    div_rem_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_add` extern.
    bounded_int_add: ExternFunctionId<'db>,
    /// `bounded_int_sub` extern.
    bounded_int_sub: ExternFunctionId<'db>,
    /// `bounded_int_constrain` extern.
    bounded_int_constrain: ExternFunctionId<'db>,
    /// `bounded_int_trim_min` extern.
    bounded_int_trim_min: ExternFunctionId<'db>,
    /// `bounded_int_trim_max` extern.
    bounded_int_trim_max: ExternFunctionId<'db>,
    /// `array_get` extern.
    array_get: ExternFunctionId<'db>,
    /// `array_snapshot_pop_front` extern.
    array_snapshot_pop_front: ExternFunctionId<'db>,
    /// `array_snapshot_pop_back` extern.
    array_snapshot_pop_back: ExternFunctionId<'db>,
    /// `array_len` extern.
    array_len: ExternFunctionId<'db>,
    /// `array_new` extern.
    array_new: ExternFunctionId<'db>,
    /// `array_append` extern.
    array_append: ExternFunctionId<'db>,
    /// `array_pop_front` extern.
    array_pop_front: ExternFunctionId<'db>,
    /// `starknet::storage_access::storage_base_address_from_felt252` extern.
    storage_base_address_from_felt252: ExternFunctionId<'db>,
    /// Generic `starknet::storage_access::storage_base_address_const`.
    storage_base_address_const: GenericFunctionId<'db>,
    /// `core::panic_with_felt252`, as a lowered function.
    panic_with_felt252: FunctionId<'db>,
    /// `core::panic_with_const_felt252`.
    ///
    /// Const folding is skipped when the function being optimized is this one.
    pub panic_with_const_felt252: FreeFunctionId<'db>,
    /// `core::panics::panic_with_byte_array`, as a lowered function.
    panic_with_byte_array: FunctionId<'db>,
    /// Per-integer-type helpers (`is_zero`, `inc`, `dec`) for `u8`..`u128`, `u256` and
    /// `i8`..`i128`.
    type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>>,
    /// Shared result of `db.const_calc_info()`; its fields are reachable through `Deref`.
    const_calculation_info: Arc<ConstCalcInfo<'db>>,
}
impl<'db> ConstFoldingLibfuncInfo<'db> {
    /// Resolves every corelib function and type the const-folding pass needs.
    ///
    /// It looks up:
    /// - the felt252 arithmetic externs;
    /// - the integer eq/add/sub/diff/wide-mul/div-rem extern families;
    /// - the bounded-int, array and storage-address externs;
    /// - the panic helpers;
    /// - the per-integer-type `is_zero`/`inc`/`dec` functions.
    ///
    /// It also captures `db.const_calc_info()`.
    fn new(db: &'db dyn Database) -> Self {
        let core = ModuleHelper::core(db);
        let box_module = core.submodule("box");
        let integer_module = core.submodule("integer");
        let internal_module = core.submodule("internal");
        let bounded_int_module = internal_module.submodule("bounded_int");
        let num_module = internal_module.submodule("num");
        let array_module = core.submodule("array");
        let starknet_module = core.submodule("starknet");
        let storage_access_module = starknet_module.submodule("storage_access");
        // Type-name prefixes used to build the per-type extern names below.
        let utypes = ["u8", "u16", "u32", "u64", "u128"];
        let itypes = ["i8", "i16", "i32", "i64", "i128"];
        // `{ty}_eq` for every unsigned and signed type.
        let eq_fns = OrderedHashSet::<_>::from_iter(
            chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(&format!("{ty}_eq"))),
        );
        let uadd_fns = OrderedHashSet::<_>::from_iter(
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_add"))),
        );
        let usub_fns = OrderedHashSet::<_>::from_iter(
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_sub"))),
        );
        let diff_fns = OrderedHashSet::<_>::from_iter(
            itypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_diff"))),
        );
        // Signed add/sub are looked up under their `_impl` extern names.
        let iadd_fns =
            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
                integer_module.extern_function_id(&format!("{ty}_overflowing_add_impl"))
            }));
        let isub_fns =
            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
                integer_module.extern_function_id(&format!("{ty}_overflowing_sub_impl"))
            }));
        // Wide multiplication: the bounded-int extern plus the 8- to 64-bit integer types only.
        let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
            [bounded_int_module.extern_function_id("bounded_int_mul")],
            ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
                .map(|ty| integer_module.extern_function_id(&format!("{ty}_wide_mul"))),
        ));
        let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
            [bounded_int_module.extern_function_id("bounded_int_div_rem")],
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_safe_divmod"))),
        ));
        // Each entry is (type name, `is_zero` via `bounded_int_is_zero`, has `inc`/`dec`).
        let type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>> = OrderedHashMap::from_iter(
            [
                ("u8", false, true),
                ("u16", false, true),
                ("u32", false, true),
                ("u64", false, true),
                ("u128", false, true),
                ("u256", false, false),
                ("i8", true, true),
                ("i16", true, true),
                ("i32", true, true),
                ("i64", true, true),
                ("i128", true, true),
            ]
            .map(|(ty_name, as_bounded_int, inc_dec): (&'static str, bool, bool)| {
                let ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, ty_name), vec![]);
                // Signed types use the generic `bounded_int_is_zero`; others use `{ty}_is_zero`.
                let is_zero = if as_bounded_int {
                    bounded_int_module
                        .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
                } else {
                    integer_module.function_id(
                        SmolStrId::from(db, format!("{ty_name}_is_zero")).long(db).as_str(),
                        vec![],
                    )
                }
                .lowered(db);
                // `{ty}_inc` / `{ty}_dec` from `core::internal::num`, when available.
                let (inc, dec) = if inc_dec {
                    (
                        Some(
                            num_module
                                .function_id(
                                    SmolStrId::from(db, format!("{ty_name}_inc")).long(db).as_str(),
                                    vec![],
                                )
                                .lowered(db),
                        ),
                        Some(
                            num_module
                                .function_id(
                                    SmolStrId::from(db, format!("{ty_name}_dec")).long(db).as_str(),
                                    vec![],
                                )
                                .lowered(db),
                        ),
                    )
                } else {
                    (None, None)
                };
                let info = TypeInfo { is_zero, inc, dec };
                (ty, info)
            }),
        );
        Self {
            felt_sub: core.extern_function_id("felt252_sub"),
            felt_add: core.extern_function_id("felt252_add"),
            felt_mul: core.extern_function_id("felt252_mul"),
            felt_div: core.extern_function_id("felt252_div"),
            box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
            eq_fns,
            uadd_fns,
            usub_fns,
            diff_fns,
            iadd_fns,
            isub_fns,
            wide_mul_fns,
            div_rem_fns,
            bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
            bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
            bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
            bounded_int_trim_min: bounded_int_module.extern_function_id("bounded_int_trim_min"),
            bounded_int_trim_max: bounded_int_module.extern_function_id("bounded_int_trim_max"),
            array_get: array_module.extern_function_id("array_get"),
            array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
            array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
            array_len: array_module.extern_function_id("array_len"),
            array_new: array_module.extern_function_id("array_new"),
            array_append: array_module.extern_function_id("array_append"),
            array_pop_front: array_module.extern_function_id("array_pop_front"),
            storage_base_address_from_felt252: storage_access_module
                .extern_function_id("storage_base_address_from_felt252"),
            storage_base_address_const: storage_access_module
                .generic_function_id("storage_base_address_const"),
            panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
            panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
            panic_with_byte_array: core
                .submodule("panics")
                .function_id("panic_with_byte_array", vec![])
                .lowered(db),
            type_info,
            const_calculation_info: db.const_calc_info(),
        }
    }
}
/// Lets a `ConstFoldingContext` read the shared [`ConstFoldingLibfuncInfo`] fields (e.g.
/// `uadd_fns`, `type_info`) directly through `self`.
impl<'db> std::ops::Deref for ConstFoldingContext<'db, '_> {
    type Target = ConstFoldingLibfuncInfo<'db>;
    fn deref(&self) -> &ConstFoldingLibfuncInfo<'db> {
        self.libfunc_info
    }
}
impl<'a> std::ops::Deref for ConstFoldingLibfuncInfo<'a> {
type Target = ConstCalcInfo<'a>;
fn deref(&self) -> &ConstCalcInfo<'a> {
&self.const_calculation_info
}
}
/// Helper functions resolved for a single core integer type.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct TypeInfo<'db> {
    /// The type's is-zero function.
    ///
    /// This is `bounded_int_is_zero` for the signed types and `{ty}_is_zero` otherwise.
    is_zero: FunctionId<'db>,
    /// `core::internal::num::{ty}_inc`; `None` for types without one (`u256`).
    inc: Option<FunctionId<'db>>,
    /// `core::internal::num::{ty}_dec`; `None` for types without one (`u256`).
    dec: Option<FunctionId<'db>>,
}
/// Wraps an out-of-range integer back into a type's value range, reporting the direction.
trait TypeRangeNormalizer {
    /// Classifies `value` against the range and returns it, wrapped if it was out of range.
    fn normalized(&self, value: BigInt) -> NormalizedResult;
}
impl TypeRangeNormalizer for TypeRange {
    /// Wraps `value` into `[min, max]` by at most one range width (`max - min + 1`).
    ///
    /// - Values below `min` are shifted up by one width and reported as `Under`.
    /// - Values above `max` are shifted down by one width and reported as `Over`.
    /// - In-range values are returned unchanged.
    ///
    /// Only a single wrap is applied. The result is therefore in range only when `value` lies
    /// within one width of the bounds, which holds for the sum or difference of two in-range
    /// values.
    fn normalized(&self, value: BigInt) -> NormalizedResult {
        let width = &self.max - &self.min + 1;
        match (value < self.min, value > self.max) {
            (true, _) => NormalizedResult::Under(value + width),
            (false, true) => NormalizedResult::Over(value - width),
            (false, false) => NormalizedResult::InRange(value),
        }
    }
}
/// Outcome of `TypeRangeNormalizer::normalized`.
enum NormalizedResult {
    /// The value was already within the range; holds it unchanged.
    InRange(BigInt),
    /// The value was above the range's maximum; holds it shifted down by one range width.
    Over(BigInt),
    /// The value was below the range's minimum; holds it shifted up by one range width.
    Under(BigInt),
}