cairo_lang_lowering/optimizations/const_folding.rs

// Unit tests for this pass live in the sibling file `const_folding_test.rs` and are compiled
// only for test builds.
#[cfg(test)]
#[path = "const_folding_test.rs"]
mod test;
4
5use std::sync::Arc;
6
7use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
8use cairo_lang_filesystem::flag::flag_unsafe_panic;
9use cairo_lang_filesystem::ids::SmolStrId;
10use cairo_lang_semantic::corelib::CorelibSemantic;
11use cairo_lang_semantic::helper::ModuleHelper;
12use cairo_lang_semantic::items::constant::{
13    ConstCalcInfo, ConstValue, ConstValueId, ConstantSemantic, TypeRange, canonical_felt252,
14    felt252_for_downcast,
15};
16use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
17use cairo_lang_semantic::items::structure::StructSemantic;
18use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
19use cairo_lang_semantic::{
20    ConcreteTypeId, ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId,
21    corelib,
22};
23use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
24use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
25use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
26use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
27use cairo_lang_utils::{Intern, extract_matches, require, try_extract_matches};
28use itertools::{chain, zip_eq};
29use num_bigint::BigInt;
30use num_integer::Integer;
31use num_traits::cast::ToPrimitive;
32use num_traits::{Num, One, Zero};
33use salsa::Database;
34use starknet_types_core::felt::Felt as Felt252;
35
36use crate::db::LoweringGroup;
37use crate::ids::{
38    ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
39    SpecializedFunction,
40};
41use crate::specialization::SpecializationArg;
42use crate::utils::InliningStrategy;
43use crate::{
44    Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchEnumInfo,
45    MatchExternInfo, MatchInfo, Statement, StatementCall, StatementConst, StatementDesnap,
46    StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
47    StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableArena, VariableId,
48};
49
50/// Converts a const value to a specialization arg.
51/// For struct, tuple, fixed-size array and enum const values, recursively converts to the
52/// corresponding SpecializationArg variant.
53fn const_to_specialization_arg<'db>(
54    db: &'db dyn Database,
55    value: ConstValueId<'db>,
56    boxed: bool,
57) -> SpecializationArg<'db> {
58    match value.long(db) {
59        ConstValue::Struct(members, ty) => {
60            // Only convert to SpecializationArg::Struct for struct, tuple, or fixed-size array.
61            // For closures and other types, fall back to Const.
62            if matches!(
63                ty.long(db),
64                TypeLongId::Concrete(ConcreteTypeId::Struct(_))
65                    | TypeLongId::Tuple(_)
66                    | TypeLongId::FixedSizeArray { .. }
67            ) {
68                let args = members
69                    .iter()
70                    .map(|member| const_to_specialization_arg(db, *member, false))
71                    .collect();
72                SpecializationArg::Struct(args)
73            } else {
74                SpecializationArg::Const { value, boxed }
75            }
76        }
77        ConstValue::Enum(variant, payload) => SpecializationArg::Enum {
78            variant: *variant,
79            payload: Box::new(const_to_specialization_arg(db, *payload, false)),
80        },
81        _ => SpecializationArg::Const { value, boxed },
82    }
83}
84
/// Keeps track of equivalent values that variables might be replaced with.
/// Note: We don't keep track of types as we assume the usage is always correct.
///
/// Since the lowering is in static single assignment form, an entry recorded for a variable
/// remains valid for the rest of the function, across blocks.
#[derive(Debug, Clone)]
enum VarInfo<'db> {
    /// The variable is known to hold this const value.
    Const(ConstValueId<'db>),
    /// The variable can be replaced by another variable (e.g. the output of `x + 0` by `x`).
    Var(VarUsage<'db>),
    /// The variable is a snapshot of a value described by the inner info.
    Snapshot(Box<VarInfo<'db>>),
    /// The variable is a struct whose members are described by the given infos, in order.
    /// `None` values represent members that are not tracked.
    Struct(Vec<Option<VarInfo<'db>>>),
    /// The variable is an enum of a known variant; `payload` describes the variant's payload.
    Enum { variant: ConcreteVariant<'db>, payload: Box<VarInfo<'db>> },
    /// The variable is a box whose content is described by the inner info.
    Box(Box<VarInfo<'db>>),
    /// The variable is an array of known length; entry `i` describes the element at index `i`.
    /// `None` values represent elements that are not tracked.
    Array(Vec<Option<VarInfo<'db>>>),
}
106impl<'db> VarInfo<'db> {
107    /// Peels the snapshots from the variable info and returns the number of snapshots performed.
108    fn peel_snapshots(&self) -> (usize, &VarInfo<'db>) {
109        let mut n_snapshots = 0;
110        let mut info = self;
111        while let VarInfo::Snapshot(inner) = info {
112            info = inner.as_ref();
113            n_snapshots += 1;
114        }
115        (n_snapshots, info)
116    }
117    /// Wraps the variable info with the given number of snapshots.
118    fn wrap_with_snapshots(mut self, n_snapshots: usize) -> VarInfo<'db> {
119        for _ in 0..n_snapshots {
120            self = VarInfo::Snapshot(Box::new(self));
121        }
122        self
123    }
124}
125
/// How a not-yet-visited block can be reached from the function start.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Reachability {
    /// The block is reachable from the function start only through the goto at the end of the
    /// given block, so the consts flowing through that goto's remapping are known when the
    /// block starts.
    FromSingleGoto(BlockId),
    /// The block is reachable from the function start, but not exclusively through a single
    /// goto (e.g. the root block, a match arm, or a goto target with several predecessors).
    Any,
}
135
136/// Performs constant folding on the lowered program.
137/// The optimization only works when the blocks are topologically sorted.
138pub fn const_folding<'db>(
139    db: &'db dyn Database,
140    function_id: ConcreteFunctionWithBodyId<'db>,
141    lowered: &mut Lowered<'db>,
142) {
143    if lowered.blocks.is_empty() {
144        return;
145    }
146
147    // Note that we can keep the var_info across blocks because the lowering
148    // is in static single assignment form.
149    let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
150
151    if ctx.should_skip_const_folding(db) {
152        return;
153    }
154
155    for block_id in (0..lowered.blocks.len()).map(BlockId) {
156        if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
157            continue;
158        }
159
160        let block = &mut lowered.blocks[block_id];
161        for stmt in block.statements.iter_mut() {
162            ctx.visit_statement(stmt);
163        }
164        ctx.visit_block_end(block_id, block);
165    }
166}
167
/// State of the constant-folding pass over a single function.
///
/// Each block is processed with `visit_block_start`, then `visit_statement` per statement, then
/// `visit_block_end`. Along the way the context accumulates what is known about variables and
/// which upcoming blocks are reachable.
pub struct ConstFoldingContext<'db, 'mt> {
    /// The used database.
    db: &'db dyn Database,
    /// The variables arena, mostly used to get the type of variables.
    /// New variables are also allocated here (e.g. when expanding `panic_with_byte_array`).
    pub variables: &'mt mut VariableArena<'db>,
    /// The accumulated information about the const values of variables.
    var_info: UnorderedHashMap<VariableId, VarInfo<'db>>,
    /// The libfunc information (obtained from `priv_const_folding_info`).
    libfunc_info: &'db ConstFoldingLibfuncInfo<'db>,
    /// The specialization base of the caller function (or the caller if the function is not
    /// specialized).
    /// NOTE(review): `new` stores the `function_id` it receives as-is; confirm callers pass the
    /// specialization base.
    caller_function: ConcreteFunctionWithBodyId<'db>,
    /// Reachability of blocks from the function start.
    /// If the block is not in this map, it means that it is unreachable (or that it was already
    /// visited and its reachability won't be checked again).
    reachability: UnorderedHashMap<BlockId, Reachability>,
    /// Additional statements to add to the block.
    /// `visit_block_end` inserts them before the block's first statement and empties this list.
    additional_stmts: Vec<Statement<'db>>,
}
187
188impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
189    pub fn new(
190        db: &'db dyn Database,
191        function_id: ConcreteFunctionWithBodyId<'db>,
192        variables: &'mt mut VariableArena<'db>,
193    ) -> Self {
194        Self {
195            db,
196            var_info: UnorderedHashMap::default(),
197            variables,
198            libfunc_info: priv_const_folding_info(db),
199            caller_function: function_id,
200            reachability: UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]),
201            additional_stmts: vec![],
202        }
203    }
204
205    /// Determines if a block is reachable from the function start and propagates constant values
206    /// when the block is reachable via a single goto statement.
207    pub fn visit_block_start<'r, 'get>(
208        &'r mut self,
209        block_id: BlockId,
210        get_block: impl FnOnce(BlockId) -> &'get Block<'db>,
211    ) -> bool
212    where
213        'db: 'get,
214    {
215        let Some(reachability) = self.reachability.remove(&block_id) else {
216            return false;
217        };
218        match reachability {
219            Reachability::Any => {}
220            Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
221                BlockEnd::Goto(_, remapping) => {
222                    for (dst, src) in remapping.iter() {
223                        if let Some(v) = self.as_const(src.var_id) {
224                            self.var_info.insert(*dst, VarInfo::Const(v));
225                        }
226                    }
227                }
228                _ => unreachable!("Expected a goto end"),
229            },
230        }
231        true
232    }
233
    /// Processes a statement and applies the constant folding optimizations.
    ///
    /// This method performs the following operations:
    /// - Replaces the statement's inputs with known-equivalent variables (`maybe_replace_inputs`).
    /// - Updates the `var_info` map with what is known about the statement's outputs.
    /// - Replaces the input statement with an optimized version when possible (calls only).
    /// - Updates `self.additional_stmts` with statements that need to be added to the block.
    ///
    /// Note: `self.visit_block_end` must be called after processing all statements
    /// in a block to actually add the additional statements.
    pub fn visit_statement(&mut self, stmt: &mut Statement<'db>) {
        self.maybe_replace_inputs(stmt.inputs_mut());
        match stmt {
            // A boxed const is tracked as a box around the const value.
            Statement::Const(StatementConst { value, output, boxed }) if *boxed => {
                self.var_info.insert(*output, VarInfo::Box(VarInfo::Const(*value).into()));
            }
            Statement::Const(StatementConst { value, output, .. }) => match value.long(self.db) {
                // Concrete values are recorded so later statements can fold on them.
                ConstValue::Int(..)
                | ConstValue::Struct(..)
                | ConstValue::Enum(..)
                | ConstValue::NonZero(..) => {
                    self.var_info.insert(*output, VarInfo::Const(*value));
                }
                // These kinds of consts are not tracked.
                ConstValue::Generic(_)
                | ConstValue::ImplConstant(_)
                | ConstValue::Var(..)
                | ConstValue::Missing(_) => {}
            },
            Statement::Snapshot(stmt) => {
                // The original keeps the input's info; the snapshot gets the same info wrapped in
                // a `Snapshot` layer.
                if let Some(info) = self.var_info.get(&stmt.input.var_id).cloned() {
                    self.var_info.insert(stmt.original(), info.clone());
                    self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info.into()));
                }
            }
            Statement::Desnap(StatementDesnap { input, output }) => {
                // Desnapping removes one `Snapshot` layer; other infos are not propagated.
                if let Some(VarInfo::Snapshot(info)) = self.var_info.get(&input.var_id) {
                    self.var_info.insert(*output, info.as_ref().clone());
                }
            }
            Statement::Call(call_stmt) => {
                // Folding a known libfunc takes precedence over specializing the callee.
                if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
                    *stmt = updated_stmt;
                } else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
                    *stmt = updated_stmt;
                }
            }
            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
                // `const_args` gathers the members that are known consts, `all_args` the info of
                // every member (untracked inputs go through `var_info_if_copy`, which presumably
                // only tracks copyable variables — see its definition).
                let mut const_args = vec![];
                let mut all_args = vec![];
                let mut contains_info = false;
                for input in inputs.iter() {
                    let Some(info) = self.var_info.get(&input.var_id) else {
                        all_args.push(var_info_if_copy(self.variables, *input));
                        continue;
                    };
                    contains_info = true;
                    if let VarInfo::Const(value) = info {
                        const_args.push(*value);
                    }
                    all_args.push(Some(info.clone()));
                }
                if const_args.len() == inputs.len() {
                    // Every member (possibly none) is a known const, so the struct is a const.
                    let value =
                        ConstValue::Struct(const_args, self.variables[*output].ty).intern(self.db);
                    self.var_info.insert(*output, VarInfo::Const(value));
                } else if contains_info {
                    self.var_info.insert(*output, VarInfo::Struct(all_args));
                }
            }
            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
                // When the input is (a snapshot of) a known struct, each output gets its member's
                // info, wrapped in as many snapshot layers as the input had.
                if let Some(info) = self.var_info.get(&input.var_id) {
                    let (n_snapshots, info) = info.peel_snapshots();
                    match info {
                        VarInfo::Const(const_value) => {
                            if let ConstValue::Struct(member_values, _) = const_value.long(self.db)
                            {
                                for (output, value) in zip_eq(outputs, member_values) {
                                    self.var_info.insert(
                                        *output,
                                        VarInfo::Const(*value).wrap_with_snapshots(n_snapshots),
                                    );
                                }
                            }
                        }
                        VarInfo::Struct(members) => {
                            for (output, member) in zip_eq(outputs, members.clone()) {
                                if let Some(member) = member {
                                    self.var_info
                                        .insert(*output, member.wrap_with_snapshots(n_snapshots));
                                }
                            }
                        }
                        _ => {}
                    }
                }
            }
            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
                // A const payload yields a const enum. Otherwise the variant is recorded together
                // with the payload's info, or with the payload variable itself if it is untracked.
                let value = if let Some(info) = self.var_info.get(&input.var_id) {
                    if let VarInfo::Const(val) = info {
                        VarInfo::Const(ConstValue::Enum(*variant, *val).intern(self.db))
                    } else {
                        VarInfo::Enum { variant: *variant, payload: info.clone().into() }
                    }
                } else {
                    VarInfo::Enum { variant: *variant, payload: VarInfo::Var(*input).into() }
                };
                self.var_info.insert(*output, value);
            }
        }
    }
343
    /// Processes the block's end and incorporates additional statements into the block.
    ///
    /// This method handles the following tasks:
    /// - Inserts the accumulated additional statements at the start of the block.
    /// - Replaces the inputs of the block's end with known-equivalent variables.
    /// - Converts match endings to goto when applicable.
    /// - Updates `self.reachability` for the blocks that the (possibly updated) end leads to.
    ///
    /// # Panics
    /// If the block end is `Panic` or `NotSet`, or if a match arm block is already present in
    /// `self.reachability`.
    pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut Block<'db>) {
        let statements = &mut block.statements;
        // Prepend the additional statements; `drain` leaves the list empty for the next block.
        statements.splice(0..0, self.additional_stmts.drain(..));

        match &mut block.end {
            BlockEnd::Goto(_, remappings) => {
                for (_, v) in remappings.iter_mut() {
                    self.maybe_replace_input(v);
                }
            }
            BlockEnd::Match { info } => {
                self.maybe_replace_inputs(info.inputs_mut());
                match info {
                    MatchInfo::Enum(info) => {
                        if let Some(updated_end) = self.handle_enum_block_end(info, statements) {
                            block.end = updated_end;
                        }
                    }
                    MatchInfo::Extern(info) => {
                        if let Some(updated_end) = self.handle_extern_block_end(info, statements) {
                            block.end = updated_end;
                        }
                    }
                    MatchInfo::Value(info) => {
                        // When the matched value is a known int, jump straight to the arm that
                        // selects it.
                        if let Some(value) =
                            self.as_int(info.input.var_id).and_then(|x| x.to_usize())
                            && let Some(arm) = info.arms.iter().find(|arm| {
                                matches!(
                                    &arm.arm_selector,
                                    MatchArmSelector::Value(v) if v.value == value
                                )
                            })
                        {
                            // Create the variable that was previously introduced in match arm.
                            statements.push(Statement::StructConstruct(StatementStructConstruct {
                                inputs: vec![],
                                output: arm.var_ids[0],
                            }));
                            block.end = BlockEnd::Goto(arm.block_id, Default::default());
                        }
                    }
                }
            }
            BlockEnd::Return(inputs, _) => self.maybe_replace_inputs(inputs),
            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
        }
        // Record the reachability implied by the final block end.
        match &block.end {
            BlockEnd::Goto(dst_block_id, _) => {
                // The first goto into a block marks it `FromSingleGoto`; any further predecessor
                // downgrades it to `Any`.
                match self.reachability.entry(*dst_block_id) {
                    std::collections::hash_map::Entry::Occupied(mut e) => {
                        e.insert(Reachability::Any)
                    }
                    std::collections::hash_map::Entry::Vacant(e) => {
                        *e.insert(Reachability::FromSingleGoto(block_id))
                    }
                };
            }
            BlockEnd::Match { info } => {
                // Match arm blocks must not have been marked reachable by any other block.
                for arm in info.arms() {
                    assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
                }
            }
            BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
        }
    }
415
    /// Handles a statement call.
    ///
    /// Returns None if no additional changes are required. Note that `stmt` may still have been
    /// modified in place in that case (the `panic_with_felt252` and
    /// `storage_base_address_from_felt252` rewrites below do this).
    /// If changes are required, returns an updated statement (to override the current
    /// statement).
    /// May add additional statements to `self.additional_stmts` if just replacing the current
    /// statement is not enough.
    /// In both cases, whatever becomes known about the call's outputs is recorded in
    /// `self.var_info`.
    fn handle_statement_call(&mut self, stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
        let db = self.db;
        if stmt.function == self.panic_with_felt252 {
            // A panic with a known felt252 becomes a call to the const variant, with the value
            // passed as a generic argument instead of an input.
            let val = self.as_const(stmt.inputs[0].var_id)?;
            stmt.inputs.clear();
            stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
                .concretize(db, vec![GenericArgumentId::Constant(val)])
                .lowered(db);
            return None;
        } else if stmt.function == self.panic_with_byte_array && !flag_unsafe_panic(db) {
            // The input must be a snapshot of a struct whose members are a fully-known array
            // and two consts; otherwise nothing is folded.
            let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
            let bytearray = try_extract_matches!(snap, VarInfo::Snapshot)?;
            let [
                Some(VarInfo::Array(data)),
                Some(VarInfo::Const(pending_word)),
                Some(VarInfo::Const(pending_len)),
            ] = &try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
            else {
                return None;
            };
            // Panic data layout: magic, number of words, the words, pending word, pending length.
            let mut panic_data =
                vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
            for word in data {
                let Some(VarInfo::Const(word)) = word else {
                    return None;
                };
                panic_data.push(word.long(db).to_int()?.clone());
            }
            panic_data.extend([
                pending_word.long(db).to_int()?.clone(),
                pending_len.long(db).to_int()?.clone(),
            ]);
            let felt252_ty = self.felt252;
            let location = stmt.location;
            let new_var = |ty| Variable::with_default_context(db, ty, location);
            let as_usage = |var_id| VarUsage { var_id, location };
            let array_fn = |extern_id| {
                let args = vec![GenericArgumentId::Type(felt252_ty)];
                GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
            };
            let call_stmt = |function, inputs, outputs| {
                let with_coupon = false;
                Statement::Call(StatementCall {
                    function,
                    inputs,
                    with_coupon,
                    outputs,
                    location,
                    is_specialization_base_call: false,
                })
            };
            // Emit `array_new`, then for every word a const statement followed by an
            // `array_append` onto the previous array variable.
            let arr_var = new_var(corelib::core_array_felt252_ty(db));
            let mut arr = self.variables.alloc(arr_var.clone());
            self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
            let felt252_var = new_var(felt252_ty);
            let arr_append_fn = array_fn(self.array_append);
            for word in panic_data {
                let to_append = self.variables.alloc(felt252_var.clone());
                let new_arr = self.variables.alloc(arr_var.clone());
                self.additional_stmts.push(Statement::Const(StatementConst::new_flat(
                    ConstValue::Int(word, felt252_ty).intern(db),
                    to_append,
                )));
                self.additional_stmts.push(call_stmt(
                    arr_append_fn,
                    vec![as_usage(arr), as_usage(to_append)],
                    vec![new_arr],
                ));
                arr = new_arr;
            }
            // The call's output is built from an empty `Panic` value and the panic data array.
            let panic_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "Panic"), vec![]);
            let panic_var = self.variables.alloc(new_var(panic_ty));
            self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
                inputs: vec![],
                output: panic_var,
            }));
            return Some(Statement::StructConstruct(StatementStructConstruct {
                inputs: vec![as_usage(panic_var), as_usage(arr)],
                output: stmt.outputs[0],
            }));
        }
        // All remaining folds apply to extern libfunc calls only.
        let (id, _generic_args) = stmt.function.get_extern(db)?;
        if id == self.felt_sub {
            // `x - 0` is `x`; two known operands are folded (via `canonical_felt252`).
            if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            {
                let value = canonical_felt252(&(lhs - rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_add {
            // `0 + x` and `x + 0` are `x`; two known operands are folded.
            if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && lhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
                None
            } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            {
                let value = canonical_felt252(&(lhs + rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_mul {
            // A zero operand gives zero (even if the other is unknown); multiplying by one gives
            // the other operand; two known operands are folded.
            let lhs = self.as_int(stmt.inputs[0].var_id);
            let rhs = self.as_int(stmt.inputs[1].var_id);
            if lhs.map(Zero::is_zero).unwrap_or_default()
                || rhs.map(Zero::is_zero).unwrap_or_default()
            {
                Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
            } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && lhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
                None
            } else if let Some(lhs) = lhs
                && let Some(rhs) = rhs
            {
                let value = canonical_felt252(&(lhs * rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_div {
            // Note that divisor is never 0, due to NonZero type always being the divisor.
            if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                // Returns the original value when dividing by 1.
                && rhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                // If the value is 0, result is 0 regardless of the divisor.
                && lhs.is_zero()
            {
                Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && let Ok(rhs_nonzero) = Felt252::from(rhs).try_into()
            {
                // Constant fold when both operands are constants

                // Use field_div for Felt252 division
                let lhs_felt = Felt252::from(lhs);
                let value = lhs_felt.field_div(&rhs_nonzero).to_bigint();
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if self.wide_mul_fns.contains(&id) {
            // A zero operand gives zero; otherwise both operands must be known.
            let lhs = self.as_int(stmt.inputs[0].var_id);
            let rhs = self.as_int(stmt.inputs[1].var_id);
            let output = stmt.outputs[0];
            if lhs.map(Zero::is_zero).unwrap_or_default()
                || rhs.map(Zero::is_zero).unwrap_or_default()
            {
                return Some(self.propagate_zero_and_get_statement(output));
            }
            let lhs = lhs?;
            Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0]))
        } else if id == self.bounded_int_add || id == self.bounded_int_sub {
            // Both operands must be known.
            let lhs = self.as_int(stmt.inputs[0].var_id)?;
            let rhs = self.as_int(stmt.inputs[1].var_id)?;
            let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else if self.div_rem_fns.contains(&id) {
            // A zero dividend gives zero quotient and remainder. Otherwise both operands must be
            // known. The remainder's const statement goes to `additional_stmts` and the quotient's
            // replaces the call.
            let lhs = self.as_int(stmt.inputs[0].var_id);
            if lhs.map(Zero::is_zero).unwrap_or_default() {
                let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
                self.additional_stmts.push(additional_stmt);
                return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
            }
            let rhs = self.as_int(stmt.inputs[1].var_id)?;
            let (q, r) = lhs?.div_rem(rhs);
            let q_output = stmt.outputs[0];
            let q_value = ConstValue::Int(q, self.variables[q_output].ty).intern(db);
            self.var_info.insert(q_output, VarInfo::Const(q_value));
            let r_output = stmt.outputs[1];
            let r_value = ConstValue::Int(r, self.variables[r_output].ty).intern(db);
            self.var_info.insert(r_output, VarInfo::Const(r_value));
            self.additional_stmts
                .push(Statement::Const(StatementConst::new_flat(r_value, r_output)));
            Some(Statement::Const(StatementConst::new_flat(q_value, q_output)))
        } else if id == self.storage_base_address_from_felt252 {
            // A known address is passed as a generic argument of the const variant instead of
            // as an input; the call is modified in place.
            let input_var = stmt.inputs[0].var_id;
            if let Some(const_value) = self.as_const(input_var)
                && let ConstValue::Int(val, ty) = const_value.long(db)
            {
                stmt.inputs.clear();
                let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
                stmt.function =
                    self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
            }
            None
        } else if id == self.into_box {
            // Track the box's content. A known const (or a snapshot of one) is replaced by a
            // boxed const statement; for a snapshot, the inner (non-snapshot) value is used.
            let input = stmt.inputs[0];
            let var_info = self.var_info.get(&input.var_id);
            let const_value = match var_info {
                Some(VarInfo::Const(val)) => Some(*val),
                Some(VarInfo::Snapshot(info)) => {
                    try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
                }
                _ => None,
            };
            let var_info = var_info.cloned().or_else(|| var_info_if_copy(self.variables, input))?;
            self.var_info.insert(stmt.outputs[0], VarInfo::Box(var_info.into()));
            Some(Statement::Const(StatementConst::new_boxed(const_value?, stmt.outputs[0])))
        } else if id == self.unbox {
            // Unboxing a tracked box yields its inner info; a known const is materialized as a
            // const statement.
            if let VarInfo::Box(inner) = self.var_info.get(&stmt.inputs[0].var_id)? {
                let inner = inner.as_ref().clone();
                if let VarInfo::Const(inner) =
                    self.var_info.entry(stmt.outputs[0]).insert_entry(inner).get()
                {
                    return Some(Statement::Const(StatementConst::new_flat(
                        *inner,
                        stmt.outputs[0],
                    )));
                }
            }
            None
        } else if self.upcast_fns.contains(&id) {
            // An upcast of a known int is the same value, typed as the output variable.
            let int_value = self.as_int(stmt.inputs[0].var_id)?;
            let output = stmt.outputs[0];
            let value = ConstValue::Int(int_value.clone(), self.variables[output].ty).intern(db);
            self.var_info.insert(output, VarInfo::Const(value));
            Some(Statement::Const(StatementConst::new_flat(value, output)))
        } else if id == self.array_new {
            // Arrays are tracked element by element, starting from the empty array.
            self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]));
            None
        } else if id == self.array_append {
            // Appending to a tracked array extends it (untracked elements are recorded as
            // `var_info_if_copy` returns).
            let mut var_infos =
                if let VarInfo::Array(var_infos) = self.var_info.get(&stmt.inputs[0].var_id)? {
                    var_infos.clone()
                } else {
                    return None;
                };
            let appended = stmt.inputs[1];
            var_infos.push(match self.var_info.get(&appended.var_id) {
                Some(var_info) => Some(var_info.clone()),
                None => var_info_if_copy(self.variables, appended),
            });
            self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos));
            None
        } else if id == self.array_len {
            // The length of a snapshot of a tracked array is known.
            let info = self.var_info.get(&stmt.inputs[0].var_id)?;
            let desnapped = try_extract_matches!(info, VarInfo::Snapshot)?;
            let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
            Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0]))
        } else {
            None
        }
    }
692
693    /// Tries to specialize the call.
694    /// Returns The specialized call statement if it was specialized, or None otherwise.
695    ///
696    /// Specialization occurs only if `priv_should_specialize` returns true.
697    /// Additionally specialization of a callee the with the same base as the caller is currently
698    /// not supported.
699    fn try_specialize_call(&self, call_stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
700        if call_stmt.with_coupon {
701            return None;
702        }
703        // No specialization when avoiding inlining.
704        if matches!(self.db.optimizations().inlining_strategy(), InliningStrategy::Avoid) {
705            return None;
706        }
707
708        let Ok(Some(mut called_function)) = call_stmt.function.body(self.db) else {
709            return None;
710        };
711
712        let extract_base = |function: ConcreteFunctionWithBodyId<'db>| match function.long(self.db)
713        {
714            ConcreteFunctionWithBodyLongId::Specialized(specialized) => specialized.base,
715            _ => function,
716        };
717        let called_base = extract_base(called_function);
718        let caller_base = extract_base(self.caller_function);
719
720        if self.db.priv_never_inline(called_base).ok()? {
721            return None;
722        }
723
724        // Do not specialize the call that should be inlined.
725        if call_stmt.is_specialization_base_call {
726            return None;
727        }
728
729        // Do not specialize a recursive call that was already specialized.
730        if called_base == caller_base && called_function != called_base {
731            return None;
732        }
733
734        // Avoid specializing with a function that is in the same SCC as the caller (and is not the
735        // same function).
736        let scc =
737            self.db.lowered_scc(called_base, DependencyType::Call, LoweringStage::Monomorphized);
738        if scc.len() > 1 && scc.contains(&caller_base) {
739            return None;
740        }
741
742        if call_stmt.inputs.iter().all(|arg| self.var_info.get(&arg.var_id).is_none()) {
743            // No const inputs
744            return None;
745        }
746
747        // If we are specializing a recursive call, use only subset of the caller.
748        let self_specializition = if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
749            self.caller_function.long(self.db)
750            && caller_base == called_base
751        {
752            specialized.args.iter().map(Some).collect()
753        } else {
754            vec![None; call_stmt.inputs.len()]
755        };
756
757        let mut specialization_args = vec![];
758        let mut new_args = vec![];
759        for (arg, coerce) in zip_eq(&call_stmt.inputs, &self_specializition) {
760            if let Some(var_info) = self.var_info.get(&arg.var_id)
761                && self.variables[arg.var_id].info.droppable.is_ok()
762                && let Some(specialization_arg) = self.try_get_specialization_arg(
763                    var_info.clone(),
764                    self.variables[arg.var_id].ty,
765                    &mut new_args,
766                    *coerce,
767                )
768            {
769                specialization_args.push(specialization_arg);
770            } else {
771                specialization_args.push(SpecializationArg::NotSpecialized);
772                new_args.push(*arg);
773                continue;
774            };
775        }
776
777        if specialization_args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized)) {
778            // No argument was assigned -> no specialization.
779            return None;
780        }
781        if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
782            called_function.long(self.db)
783        {
784            // Canonicalize the specialization rather than adding a specialization of a specialized
785            // function.
786            called_function = specialized_function.base;
787            let mut new_args_iter = specialization_args.into_iter();
788            let mut old_args = specialized_function.args.to_vec();
789            let mut stack = vec![];
790            for arg in old_args.iter_mut().rev() {
791                stack.push(arg);
792            }
793            while let Some(arg) = stack.pop() {
794                match arg {
795                    SpecializationArg::Const { .. } => {}
796                    SpecializationArg::Snapshot(inner) => {
797                        stack.push(inner.as_mut());
798                    }
799                    SpecializationArg::Enum { payload, .. } => {
800                        stack.push(payload.as_mut());
801                    }
802                    SpecializationArg::Array(_, values) | SpecializationArg::Struct(values) => {
803                        for value in values.iter_mut().rev() {
804                            stack.push(value);
805                        }
806                    }
807                    SpecializationArg::NotSpecialized => {
808                        *arg = new_args_iter.next().unwrap_or(SpecializationArg::NotSpecialized);
809                    }
810                }
811            }
812            specialization_args = old_args;
813        }
814        let specialized =
815            SpecializedFunction { base: called_function, args: specialization_args.into() };
816        let specialized_func_id =
817            ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);
818
819        if caller_base != called_base
820            && self.db.priv_should_specialize(specialized_func_id) == Ok(false)
821        {
822            return None;
823        }
824
825        Some(Statement::Call(StatementCall {
826            function: specialized_func_id.function_id(self.db).unwrap(),
827            inputs: new_args,
828            with_coupon: call_stmt.with_coupon,
829            outputs: std::mem::take(&mut call_stmt.outputs),
830            location: call_stmt.location,
831            is_specialization_base_call: false,
832        }))
833    }
834
835    /// Adds `value` as a const to `var_info` and return a const statement for it.
836    fn propagate_const_and_get_statement(
837        &mut self,
838        value: BigInt,
839        output: VariableId,
840    ) -> Statement<'db> {
841        let ty = self.variables[output].ty;
842        let value = ConstValueId::from_int(self.db, ty, &value);
843        self.var_info.insert(output, VarInfo::Const(value));
844        Statement::Const(StatementConst::new_flat(value, output))
845    }
846
847    /// Adds 0 const to `var_info` and return a const statement for it.
848    fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement<'db> {
849        self.propagate_const_and_get_statement(BigInt::zero(), output)
850    }
851
852    /// Returns a statement that introduces the requested value into `output`, or None if fails.
853    fn try_generate_const_statement(
854        &self,
855        value: ConstValueId<'db>,
856        output: VariableId,
857    ) -> Option<Statement<'db>> {
858        if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
859            Some(Statement::Const(StatementConst::new_flat(value, output)))
860        } else if matches!(value.long(self.db), ConstValue::Struct(members, _) if members.is_empty())
861        {
862            // Handling const empty structs - which are not supported in sierra-gen.
863            Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
864        } else {
865            None
866        }
867    }
868
869    /// Handles the end of block matching on an enum.
870    /// Possibly extends the blocks statements as well.
871    /// Returns None if no additional changes are required.
872    /// If changes are required, returns the updated block end.
873    fn handle_enum_block_end(
874        &mut self,
875        info: &mut MatchEnumInfo<'db>,
876        statements: &mut Vec<Statement<'db>>,
877    ) -> Option<BlockEnd<'db>> {
878        let input = info.input.var_id;
879        let (n_snapshots, var_info) = self.var_info.get(&input)?.peel_snapshots();
880        let location = info.location;
881        let as_usage = |var_id| VarUsage { var_id, location };
882        let db = self.db;
883        let snapshot_stmt = |vars: &mut VariableArena<'_>, pre_snap, post_snap| {
884            let ignored = vars.alloc(vars[pre_snap].clone());
885            Statement::Snapshot(StatementSnapshot::new(as_usage(pre_snap), ignored, post_snap))
886        };
887        // Checking whether we have actual const info on the enum.
888        if let VarInfo::Const(const_value) = var_info
889            && let ConstValue::Enum(variant, value) = const_value.long(db)
890        {
891            let arm = &info.arms[variant.idx];
892            let output = arm.var_ids[0];
893            // Propagating the const value information.
894            self.var_info.insert(output, VarInfo::Const(*value).wrap_with_snapshots(n_snapshots));
895            if self.variables[input].info.droppable.is_ok()
896                && self.variables[output].info.copyable.is_ok()
897                && let Ok(mut ty) = value.ty(db)
898                && let Some(mut stmt) = self.try_generate_const_statement(*value, output)
899            {
900                // Adding snapshot taking statements for snapshots.
901                for _ in 0..n_snapshots {
902                    let non_snap_var = Variable::with_default_context(db, ty, location);
903                    ty = TypeLongId::Snapshot(ty).intern(db);
904                    let pre_snap = self.variables.alloc(non_snap_var);
905                    stmt.outputs_mut()[0] = pre_snap;
906                    let take_snap = snapshot_stmt(self.variables, pre_snap, output);
907                    statements.push(core::mem::replace(&mut stmt, take_snap));
908                }
909                statements.push(stmt);
910                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
911            }
912        } else if let VarInfo::Enum { variant, payload } = var_info {
913            let arm = &info.arms[variant.idx];
914            let variant_ty = variant.ty;
915            let output = arm.var_ids[0];
916            let payload = payload.as_ref().clone();
917            let unwrapped =
918                self.variables[input].info.droppable.is_ok().then_some(()).and_then(|_| {
919                    let (extra_snapshots, inner) = payload.peel_snapshots();
920                    match inner {
921                        VarInfo::Var(var) if self.variables[var.var_id].info.copyable.is_ok() => {
922                            Some((var.var_id, extra_snapshots))
923                        }
924                        VarInfo::Const(value) => {
925                            let const_var = self
926                                .variables
927                                .alloc(Variable::with_default_context(db, variant_ty, location));
928                            statements.push(self.try_generate_const_statement(*value, const_var)?);
929                            Some((const_var, extra_snapshots))
930                        }
931                        _ => None,
932                    }
933                });
934            // Propagating the const value information.
935            self.var_info.insert(output, payload.wrap_with_snapshots(n_snapshots));
936            if let Some((mut unwrapped, extra_snapshots)) = unwrapped {
937                let total_snapshots = n_snapshots + extra_snapshots;
938                if total_snapshots != 0 {
939                    // Adding snapshot taking statements for snapshots.
940                    for _ in 1..total_snapshots {
941                        let ty = TypeLongId::Snapshot(self.variables[unwrapped].ty).intern(db);
942                        let non_snap_var = Variable::with_default_context(self.db, ty, location);
943                        let snapped = self.variables.alloc(non_snap_var);
944                        statements.push(snapshot_stmt(self.variables, unwrapped, snapped));
945                        unwrapped = snapped;
946                    }
947                    statements.push(snapshot_stmt(self.variables, unwrapped, output));
948                };
949                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
950            }
951        }
952        None
953    }
954
955    /// Handles the end of a block based on an extern function call.
956    /// Possibly extends the blocks statements as well.
957    /// Returns None if no additional changes are required.
958    /// If changes are required, returns the updated block end.
959    fn handle_extern_block_end(
960        &mut self,
961        info: &mut MatchExternInfo<'db>,
962        statements: &mut Vec<Statement<'db>>,
963    ) -> Option<BlockEnd<'db>> {
964        let db = self.db;
965        let (id, generic_args) = info.function.get_extern(db)?;
966        if self.nz_fns.contains(&id) {
967            let val = self.as_const(info.inputs[0].var_id)?;
968            let is_zero = match val.long(db) {
969                ConstValue::Int(v, _) => v.is_zero(),
970                ConstValue::Struct(s, _) => s.iter().all(|v| {
971                    v.long(db).to_int().expect("Expected ConstValue::Int for size").is_zero()
972                }),
973                _ => unreachable!(),
974            };
975            Some(if is_zero {
976                BlockEnd::Goto(info.arms[0].block_id, Default::default())
977            } else {
978                let arm = &info.arms[1];
979                let nz_var = arm.var_ids[0];
980                let nz_val = ConstValue::NonZero(val).intern(db);
981                self.var_info.insert(nz_var, VarInfo::Const(nz_val));
982                statements.push(Statement::Const(StatementConst::new_flat(nz_val, nz_var)));
983                BlockEnd::Goto(arm.block_id, Default::default())
984            })
985        } else if self.eq_fns.contains(&id) {
986            let lhs = self.as_int(info.inputs[0].var_id);
987            let rhs = self.as_int(info.inputs[1].var_id);
988            if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
989                || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
990            {
991                let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
992                let var = &self.variables[nz_input.var_id].clone();
993                let function = self.type_info.get(&var.ty)?.is_zero;
994                let unused_nz_var = Variable::with_default_context(
995                    db,
996                    corelib::core_nonzero_ty(db, var.ty),
997                    var.location,
998                );
999                let unused_nz_var = self.variables.alloc(unused_nz_var);
1000                return Some(BlockEnd::Match {
1001                    info: MatchInfo::Extern(MatchExternInfo {
1002                        function,
1003                        inputs: vec![nz_input],
1004                        arms: vec![
1005                            MatchArm {
1006                                arm_selector: MatchArmSelector::VariantId(
1007                                    corelib::jump_nz_zero_variant(db, var.ty),
1008                                ),
1009                                block_id: info.arms[1].block_id,
1010                                var_ids: vec![],
1011                            },
1012                            MatchArm {
1013                                arm_selector: MatchArmSelector::VariantId(
1014                                    corelib::jump_nz_nonzero_variant(db, var.ty),
1015                                ),
1016                                block_id: info.arms[0].block_id,
1017                                var_ids: vec![unused_nz_var],
1018                            },
1019                        ],
1020                        location: info.location,
1021                    }),
1022                });
1023            }
1024            Some(BlockEnd::Goto(
1025                info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
1026                Default::default(),
1027            ))
1028        } else if self.uadd_fns.contains(&id)
1029            || self.usub_fns.contains(&id)
1030            || self.diff_fns.contains(&id)
1031            || self.iadd_fns.contains(&id)
1032            || self.isub_fns.contains(&id)
1033        {
1034            let rhs = self.as_int(info.inputs[1].var_id);
1035            let lhs = self.as_int(info.inputs[0].var_id);
1036            if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
1037                let ty = self.variables[info.arms[0].var_ids[0]].ty;
1038                let range = self.type_value_ranges.get(&ty)?;
1039                let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
1040                    lhs + rhs
1041                } else {
1042                    lhs - rhs
1043                };
1044                let (arm_index, value) = match range.normalized(value) {
1045                    NormalizedResult::InRange(value) => (0, value),
1046                    NormalizedResult::Under(value) => (1, value),
1047                    NormalizedResult::Over(value) => (
1048                        if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
1049                            2
1050                        } else {
1051                            1
1052                        },
1053                        value,
1054                    ),
1055                };
1056                let arm = &info.arms[arm_index];
1057                let actual_output = arm.var_ids[0];
1058                let value = ConstValue::Int(value, ty).intern(db);
1059                self.var_info.insert(actual_output, VarInfo::Const(value));
1060                statements.push(Statement::Const(StatementConst::new_flat(value, actual_output)));
1061                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1062            }
1063            if let Some(rhs) = rhs {
1064                if rhs.is_zero() && !self.diff_fns.contains(&id) {
1065                    let arm = &info.arms[0];
1066                    self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]));
1067                    return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1068                }
1069                if rhs.is_one() && !self.diff_fns.contains(&id) {
1070                    let ty = self.variables[info.arms[0].var_ids[0]].ty;
1071                    let ty_info = self.type_info.get(&ty)?;
1072                    let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
1073                        ty_info.inc?
1074                    } else {
1075                        ty_info.dec?
1076                    };
1077                    let enum_ty = function.signature(db).ok()?.return_type;
1078                    let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
1079                        enum_ty.long(db)
1080                    else {
1081                        return None;
1082                    };
1083                    let result = self.variables.alloc(Variable::with_default_context(
1084                        db,
1085                        function.signature(db).unwrap().return_type,
1086                        info.location,
1087                    ));
1088                    statements.push(Statement::Call(StatementCall {
1089                        function,
1090                        inputs: vec![info.inputs[0]],
1091                        with_coupon: false,
1092                        outputs: vec![result],
1093                        location: info.location,
1094                        is_specialization_base_call: false,
1095                    }));
1096                    return Some(BlockEnd::Match {
1097                        info: MatchInfo::Enum(MatchEnumInfo {
1098                            concrete_enum_id: *concrete_enum_id,
1099                            input: VarUsage { var_id: result, location: info.location },
1100                            arms: core::mem::take(&mut info.arms),
1101                            location: info.location,
1102                        }),
1103                    });
1104                }
1105            }
1106            if let Some(lhs) = lhs
1107                && lhs.is_zero()
1108                && (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id))
1109            {
1110                let arm = &info.arms[0];
1111                self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]));
1112                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1113            }
1114            None
1115        } else if let Some(reversed) = self.downcast_fns.get(&id) {
1116            let range = |ty: TypeId<'_>| {
1117                Some(if let Some(range) = self.type_value_ranges.get(&ty) {
1118                    range.clone()
1119                } else {
1120                    let (min, max) = corelib::try_extract_bounded_int_type_ranges(db, ty)?;
1121                    TypeRange { min, max }
1122                })
1123            };
1124            let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
1125            let input_var = info.inputs[0].var_id;
1126            let in_ty = self.variables[input_var].ty;
1127            let success_output = info.arms[success_arm].var_ids[0];
1128            let out_ty = self.variables[success_output].ty;
1129            let out_range = range(out_ty)?;
1130            let Some(value) = self.as_int(input_var) else {
1131                let in_range = range(in_ty)?;
1132                return if in_range.min < out_range.min || in_range.max > out_range.max {
1133                    None
1134                } else {
1135                    let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
1136                    let function = db.core_info().upcast_fn.concretize(db, generic_args);
1137                    statements.push(Statement::Call(StatementCall {
1138                        function: function.lowered(db),
1139                        inputs: vec![info.inputs[0]],
1140                        with_coupon: false,
1141                        outputs: vec![success_output],
1142                        location: info.location,
1143                        is_specialization_base_call: false,
1144                    }));
1145                    return Some(BlockEnd::Goto(
1146                        info.arms[success_arm].block_id,
1147                        Default::default(),
1148                    ));
1149                };
1150            };
1151            let value = if in_ty == self.felt252 {
1152                felt252_for_downcast(value, &out_range.min)
1153            } else {
1154                value.clone()
1155            };
1156            Some(if let NormalizedResult::InRange(value) = out_range.normalized(value) {
1157                let value = ConstValue::Int(value, out_ty).intern(db);
1158                self.var_info.insert(success_output, VarInfo::Const(value));
1159                statements.push(Statement::Const(StatementConst::new_flat(value, success_output)));
1160                BlockEnd::Goto(info.arms[success_arm].block_id, Default::default())
1161            } else {
1162                BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default())
1163            })
1164        } else if id == self.bounded_int_constrain {
1165            let input_var = info.inputs[0].var_id;
1166            let value = self.as_int(input_var)?;
1167            let generic_arg = generic_args[1];
1168            let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
1169                .long(db)
1170                .to_int()
1171                .expect("Expected ConstValue::Int for size");
1172            let arm_idx = if value < constrain_value { 0 } else { 1 };
1173            let output = info.arms[arm_idx].var_ids[0];
1174            statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1175            Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1176        } else if id == self.bounded_int_trim_min {
1177            let input_var = info.inputs[0].var_id;
1178            let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
1179                return None;
1180            };
1181            let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
1182                range.min == *value
1183            } else {
1184                corelib::try_extract_bounded_int_type_ranges(db, *ty)?.0 == *value
1185            };
1186            let arm_idx = if is_trimmed {
1187                0
1188            } else {
1189                let output = info.arms[1].var_ids[0];
1190                statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1191                1
1192            };
1193            Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1194        } else if id == self.bounded_int_trim_max {
1195            let input_var = info.inputs[0].var_id;
1196            let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
1197                return None;
1198            };
1199            let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
1200                range.max == *value
1201            } else {
1202                corelib::try_extract_bounded_int_type_ranges(db, *ty)?.1 == *value
1203            };
1204            let arm_idx = if is_trimmed {
1205                0
1206            } else {
1207                let output = info.arms[1].var_ids[0];
1208                statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1209                1
1210            };
1211            Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1212        } else if id == self.array_get {
1213            let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
1214            if let Some(VarInfo::Snapshot(arr_info)) = self.var_info.get(&info.inputs[0].var_id)
1215                && let VarInfo::Array(infos) = arr_info.as_ref()
1216            {
1217                match infos.get(index) {
1218                    Some(Some(output_var_info)) => {
1219                        let arm = &info.arms[0];
1220                        let output_var_info = output_var_info.clone();
1221                        let box_info =
1222                            VarInfo::Box(VarInfo::Snapshot(output_var_info.clone().into()).into());
1223                        self.var_info.insert(arm.var_ids[0], box_info);
1224                        if let VarInfo::Const(value) = output_var_info {
1225                            let value_ty = value.ty(db).ok()?;
1226                            let value_box_ty = corelib::core_box_ty(db, value_ty);
1227                            let location = info.location;
1228                            let boxed_var =
1229                                Variable::with_default_context(db, value_box_ty, location);
1230                            let boxed = self.variables.alloc(boxed_var.clone());
1231                            let unused_boxed = self.variables.alloc(boxed_var);
1232                            let snapped = self.variables.alloc(Variable::with_default_context(
1233                                db,
1234                                TypeLongId::Snapshot(value_box_ty).intern(db),
1235                                location,
1236                            ));
1237                            statements.extend([
1238                                Statement::Const(StatementConst::new_boxed(value, boxed)),
1239                                Statement::Snapshot(StatementSnapshot {
1240                                    input: VarUsage { var_id: boxed, location },
1241                                    outputs: [unused_boxed, snapped],
1242                                }),
1243                                Statement::Call(StatementCall {
1244                                    function: self
1245                                        .box_forward_snapshot
1246                                        .concretize(db, vec![GenericArgumentId::Type(value_ty)])
1247                                        .lowered(db),
1248                                    inputs: vec![VarUsage { var_id: snapped, location }],
1249                                    with_coupon: false,
1250                                    outputs: vec![arm.var_ids[0]],
1251                                    location: info.location,
1252                                    is_specialization_base_call: false,
1253                                }),
1254                            ]);
1255                            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1256                        }
1257                    }
1258                    None => {
1259                        return Some(BlockEnd::Goto(info.arms[1].block_id, Default::default()));
1260                    }
1261                    Some(None) => {}
1262                }
1263            }
1264            if index.is_zero()
1265                && let [success, failure] = info.arms.as_mut_slice()
1266            {
1267                let arr = info.inputs[0].var_id;
1268                let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
1269                let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
1270                info.inputs.truncate(1);
1271                info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
1272                    .concretize(db, generic_args)
1273                    .lowered(db);
1274                success.var_ids.insert(0, unused_arr_output0);
1275                failure.var_ids.insert(0, unused_arr_output1);
1276            }
1277            None
1278        } else if id == self.array_pop_front {
1279            let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)? else {
1280                return None;
1281            };
1282            if let Some(first) = var_infos.first() {
1283                if let Some(first) = first.as_ref().cloned() {
1284                    let arm = &info.arms[0];
1285                    self.var_info.insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()));
1286                    self.var_info.insert(arm.var_ids[1], VarInfo::Box(first.into()));
1287                }
1288                None
1289            } else {
1290                let arm = &info.arms[1];
1291                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1292                Some(BlockEnd::Goto(
1293                    arm.block_id,
1294                    VarRemapping {
1295                        remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1296                    },
1297                ))
1298            }
1299        } else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
1300            let var_info = self.var_info.get(&info.inputs[0].var_id)?;
1301            let desnapped = try_extract_matches!(var_info, VarInfo::Snapshot)?;
1302            let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
1303            // TODO(orizi): Propagate success values as well.
1304            if element_var_infos.is_empty() {
1305                let arm = &info.arms[1];
1306                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1307                Some(BlockEnd::Goto(
1308                    arm.block_id,
1309                    VarRemapping {
1310                        remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1311                    },
1312                ))
1313            } else {
1314                None
1315            }
1316        } else {
1317            None
1318        }
1319    }
1320
1321    /// Returns the const value of a variable if it exists.
1322    fn as_const(&self, var_id: VariableId) -> Option<ConstValueId<'db>> {
1323        try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const).copied()
1324    }
1325
1326    /// Returns the const value as an int if it exists and is an integer.
1327    fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
1328        match self.as_const(var_id)?.long(self.db) {
1329            ConstValue::Int(value, _) => Some(value),
1330            ConstValue::NonZero(const_value) => {
1331                if let ConstValue::Int(value, _) = const_value.long(self.db) {
1332                    Some(value)
1333                } else {
1334                    None
1335                }
1336            }
1337            _ => None,
1338        }
1339    }
1340
1341    /// Replaces the inputs in place if they are in the var_info map.
1342    fn maybe_replace_inputs(&self, inputs: &mut [VarUsage<'db>]) {
1343        for input in inputs {
1344            self.maybe_replace_input(input);
1345        }
1346    }
1347
1348    /// Replaces the input in place if it is in the var_info map.
1349    fn maybe_replace_input(&self, input: &mut VarUsage<'db>) {
1350        if let Some(VarInfo::Var(new_var)) = self.var_info.get(&input.var_id) {
1351            *input = *new_var;
1352        }
1353    }
1354
1355    /// Given a var_info and its type, return the corresponding specialization argument, if it
1356    /// exists.
1357    ///
1358    /// The `coerce` argument is used to constrain the specialization argument of recursive calls to
1359    /// the value that is used by the caller.
1360    fn try_get_specialization_arg(
1361        &self,
1362        var_info: VarInfo<'db>,
1363        ty: TypeId<'db>,
1364        unknown_vars: &mut Vec<VarUsage<'db>>,
1365        coerce: Option<&SpecializationArg<'db>>,
1366    ) -> Option<SpecializationArg<'db>> {
1367        // Skip zero-sized constants as they are not supported in sierra-gen.
1368        require(self.db.type_size_info(ty).ok()? != TypeSizeInformation::ZeroSized)?;
1369        // Skip specialization arguments if they are coerced to not be specialized.
1370        require(!matches!(coerce, Some(SpecializationArg::NotSpecialized)))?;
1371
1372        match var_info {
1373            VarInfo::Const(value) => {
1374                let res = const_to_specialization_arg(self.db, value, false);
1375                let Some(coerce) = coerce else {
1376                    return Some(res);
1377                };
1378                if *coerce == res { Some(res) } else { None }
1379            }
1380            VarInfo::Box(info) => {
1381                let res = try_extract_matches!(info.as_ref(), VarInfo::Const)
1382                    .map(|value| SpecializationArg::Const { value: *value, boxed: true });
1383                let Some(coerce) = coerce else {
1384                    return res;
1385                };
1386                if Some(coerce.clone()) == res { res } else { None }
1387            }
1388            VarInfo::Snapshot(info) => {
1389                let desnap_ty = *extract_matches!(ty.long(self.db), TypeLongId::Snapshot);
1390                // Use a local accumulator to avoid mutating unknown_vars if we return None.
1391                let mut local_unknown_vars: Vec<VarUsage<'db>> = Vec::new();
1392                let inner = self.try_get_specialization_arg(
1393                    info.as_ref().clone(),
1394                    desnap_ty,
1395                    &mut local_unknown_vars,
1396                    coerce.map(|coerce| {
1397                        extract_matches!(coerce, SpecializationArg::Snapshot).as_ref()
1398                    }),
1399                )?;
1400                unknown_vars.extend(local_unknown_vars);
1401                Some(SpecializationArg::Snapshot(Box::new(inner)))
1402            }
1403            VarInfo::Array(infos) => {
1404                let TypeLongId::Concrete(concrete_ty) = ty.long(self.db) else {
1405                    unreachable!("Expected a concrete type");
1406                };
1407                let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
1408                else {
1409                    unreachable!("Expected a single type generic argument");
1410                };
1411                let coerces = match coerce {
1412                    Some(coerce) => {
1413                        let SpecializationArg::Array(ty, specialization_args) = coerce else {
1414                            unreachable!("Expected an array specialization argument");
1415                        };
1416                        assert_eq!(ty, inner_ty);
1417                        if specialization_args.len() != infos.len() {
1418                            return None;
1419                        }
1420
1421                        specialization_args.iter().map(Some).collect()
1422                    }
1423                    None => vec![None; infos.len()],
1424                };
1425                // Accumulate into locals first; only extend unknown_vars if we end up specializing.
1426                let mut vars = vec![];
1427                let mut args = vec![];
1428                for (info, coerce) in zip_eq(infos, coerces) {
1429                    let info = info?;
1430                    let arg =
1431                        self.try_get_specialization_arg(info, *inner_ty, &mut vars, coerce)?;
1432                    args.push(arg);
1433                }
1434                if !args.is_empty()
1435                    && args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1436                {
1437                    return None;
1438                }
1439                unknown_vars.extend(vars);
1440                Some(SpecializationArg::Array(*inner_ty, args))
1441            }
1442            VarInfo::Struct(infos) => {
1443                // Get element types based on the type.
1444                let element_types: Vec<TypeId<'db>> = match ty.long(self.db) {
1445                    TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) => {
1446                        let members = self.db.concrete_struct_members(*concrete_struct).unwrap();
1447                        members.values().map(|member| member.ty).collect()
1448                    }
1449                    TypeLongId::Tuple(element_types) => element_types.clone(),
1450                    TypeLongId::FixedSizeArray { type_id, .. } => vec![*type_id; infos.len()],
1451                    // For closures and other types, don't specialize.
1452                    _ => return None,
1453                };
1454
1455                let coerces = match coerce {
1456                    Some(SpecializationArg::Struct(specialization_args)) => {
1457                        assert_eq!(specialization_args.len(), infos.len());
1458                        specialization_args.iter().map(Some).collect()
1459                    }
1460                    Some(_) => unreachable!("Expected a struct specialization argument"),
1461                    None => vec![None; infos.len()],
1462                };
1463
1464                let mut struct_args = Vec::new();
1465                // Accumulate into locals first; only extend unknown_vars if we end up specializing.
1466                let mut vars = vec![];
1467                for ((elem_ty, opt_var_info), coerce) in
1468                    zip_eq(zip_eq(element_types, infos), coerces)
1469                {
1470                    let var_info = opt_var_info?;
1471                    let arg =
1472                        self.try_get_specialization_arg(var_info, elem_ty, &mut vars, coerce)?;
1473                    struct_args.push(arg);
1474                }
1475                if !struct_args.is_empty()
1476                    && struct_args
1477                        .iter()
1478                        .all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1479                {
1480                    return None;
1481                }
1482                unknown_vars.extend(vars);
1483                Some(SpecializationArg::Struct(struct_args))
1484            }
1485            VarInfo::Enum { variant, payload } => {
1486                let coerce = match coerce {
1487                    Some(coerce) => {
1488                        let SpecializationArg::Enum { variant: coercion_variant, payload } = coerce
1489                        else {
1490                            unreachable!("Expected an enum specialization argument");
1491                        };
1492                        if *coercion_variant != variant {
1493                            return None;
1494                        }
1495                        Some(payload.as_ref())
1496                    }
1497                    None => None,
1498                };
1499                let mut local_unknown_vars = vec![];
1500                let payload_arg = self.try_get_specialization_arg(
1501                    payload.as_ref().clone(),
1502                    variant.ty,
1503                    &mut local_unknown_vars,
1504                    coerce,
1505                )?;
1506
1507                unknown_vars.extend(local_unknown_vars);
1508                Some(SpecializationArg::Enum { variant, payload: Box::new(payload_arg) })
1509            }
1510            VarInfo::Var(var_usage) => {
1511                unknown_vars.push(var_usage);
1512                Some(SpecializationArg::NotSpecialized)
1513            }
1514        }
1515    }
1516
1517    /// Returns true if const-folding should be skipped for the current function.
1518    pub fn should_skip_const_folding(&self, db: &'db dyn Database) -> bool {
1519        if db.optimizations().skip_const_folding() {
1520            return true;
1521        }
1522
1523        // Skipping const-folding for `panic_with_const_felt252` - to avoid replacing a call to
1524        // `panic_with_felt252` with `panic_with_const_felt252` and causing accidental recursion.
1525        if self.caller_function.base_semantic_function(db).generic_function(db)
1526            == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
1527        {
1528            return true;
1529        }
1530        false
1531    }
1532}
1533
1534/// Returns a `VarInfo` of a variable only if it is copyable.
1535fn var_info_if_copy<'db>(
1536    variables: &VariableArena<'db>,
1537    input: VarUsage<'db>,
1538) -> Option<VarInfo<'db>> {
1539    variables[input.var_id].info.copyable.is_ok().then_some(VarInfo::Var(input))
1540}
1541
/// Internal query for the libfuncs information required for const folding.
///
/// A salsa tracked function declared with `returns(ref)`. The `ConstFoldingLibfuncInfo` built by
/// `ConstFoldingLibfuncInfo::new` is memoized by salsa, and callers receive a reference to it
/// instead of redoing the corelib lookups.
#[salsa::tracked(returns(ref))]
fn priv_const_folding_info<'db>(
    db: &'db dyn Database,
) -> crate::optimizations::const_folding::ConstFoldingLibfuncInfo<'db> {
    ConstFoldingLibfuncInfo::new(db)
}
1549
/// Holds static information about libfuncs required for the optimization.
///
/// Built by `ConstFoldingLibfuncInfo::new` from corelib lookups. Most entries are extern
/// function ids compared against call targets. The `GenericFunctionId` entries are generic and
/// are concretized with type arguments where they are used.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub struct ConstFoldingLibfuncInfo<'db> {
    /// The `felt252_sub` libfunc.
    felt_sub: ExternFunctionId<'db>,
    /// The `felt252_add` libfunc.
    felt_add: ExternFunctionId<'db>,
    /// The `felt252_mul` libfunc.
    felt_mul: ExternFunctionId<'db>,
    /// The `felt252_div` libfunc.
    felt_div: ExternFunctionId<'db>,
    /// The `into_box` libfunc.
    into_box: ExternFunctionId<'db>,
    /// The `unbox` libfunc.
    unbox: ExternFunctionId<'db>,
    /// The `box_forward_snapshot` libfunc (generic; concretized with the boxed type when used).
    box_forward_snapshot: GenericFunctionId<'db>,
    /// The set of functions that check if numbers are equal.
    eq_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to add unsigned ints.
    uadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to subtract unsigned ints.
    usub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to get the difference of signed ints.
    diff_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to add signed ints.
    iadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to subtract signed ints.
    isub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to multiply integers.
    wide_mul_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to divide and get the remainder of integers.
    div_rem_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `bounded_int_add` libfunc.
    bounded_int_add: ExternFunctionId<'db>,
    /// The `bounded_int_sub` libfunc.
    bounded_int_sub: ExternFunctionId<'db>,
    /// The `bounded_int_constrain` libfunc.
    bounded_int_constrain: ExternFunctionId<'db>,
    /// The `bounded_int_trim_min` libfunc.
    bounded_int_trim_min: ExternFunctionId<'db>,
    /// The `bounded_int_trim_max` libfunc.
    bounded_int_trim_max: ExternFunctionId<'db>,
    /// The `array_get` libfunc.
    array_get: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_front` libfunc.
    array_snapshot_pop_front: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_back` libfunc.
    array_snapshot_pop_back: ExternFunctionId<'db>,
    /// The `array_len` libfunc.
    array_len: ExternFunctionId<'db>,
    /// The `array_new` libfunc.
    array_new: ExternFunctionId<'db>,
    /// The `array_append` libfunc.
    array_append: ExternFunctionId<'db>,
    /// The `array_pop_front` libfunc.
    array_pop_front: ExternFunctionId<'db>,
    /// The `storage_base_address_from_felt252` libfunc.
    storage_base_address_from_felt252: ExternFunctionId<'db>,
    /// The `storage_base_address_const` libfunc (generic).
    storage_base_address_const: GenericFunctionId<'db>,
    /// The `core::panic_with_felt252` function.
    panic_with_felt252: FunctionId<'db>,
    /// The `core::panic_with_const_felt252` function.
    pub panic_with_const_felt252: FreeFunctionId<'db>,
    /// The `core::panics::panic_with_byte_array` function.
    panic_with_byte_array: FunctionId<'db>,
    /// Information per type, keyed by the corelib integer types.
    type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>>,
    /// The info used for semantic const calculation.
    const_calculation_info: Arc<ConstCalcInfo<'db>>,
}
1622impl<'db> ConstFoldingLibfuncInfo<'db> {
1623    fn new(db: &'db dyn Database) -> Self {
1624        let core = ModuleHelper::core(db);
1625        let box_module = core.submodule("box");
1626        let integer_module = core.submodule("integer");
1627        let internal_module = core.submodule("internal");
1628        let bounded_int_module = internal_module.submodule("bounded_int");
1629        let num_module = internal_module.submodule("num");
1630        let array_module = core.submodule("array");
1631        let starknet_module = core.submodule("starknet");
1632        let storage_access_module = starknet_module.submodule("storage_access");
1633        let utypes = ["u8", "u16", "u32", "u64", "u128"];
1634        let itypes = ["i8", "i16", "i32", "i64", "i128"];
1635        let eq_fns = OrderedHashSet::<_>::from_iter(
1636            chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(&format!("{ty}_eq"))),
1637        );
1638        let uadd_fns = OrderedHashSet::<_>::from_iter(
1639            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_add"))),
1640        );
1641        let usub_fns = OrderedHashSet::<_>::from_iter(
1642            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_sub"))),
1643        );
1644        let diff_fns = OrderedHashSet::<_>::from_iter(
1645            itypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_diff"))),
1646        );
1647        let iadd_fns =
1648            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1649                integer_module.extern_function_id(&format!("{ty}_overflowing_add_impl"))
1650            }));
1651        let isub_fns =
1652            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1653                integer_module.extern_function_id(&format!("{ty}_overflowing_sub_impl"))
1654            }));
1655        let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
1656            [bounded_int_module.extern_function_id("bounded_int_mul")],
1657            ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
1658                .map(|ty| integer_module.extern_function_id(&format!("{ty}_wide_mul"))),
1659        ));
1660        let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
1661            [bounded_int_module.extern_function_id("bounded_int_div_rem")],
1662            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_safe_divmod"))),
1663        ));
1664        let type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>> = OrderedHashMap::from_iter(
1665            [
1666                ("u8", false, true),
1667                ("u16", false, true),
1668                ("u32", false, true),
1669                ("u64", false, true),
1670                ("u128", false, true),
1671                ("u256", false, false),
1672                ("i8", true, true),
1673                ("i16", true, true),
1674                ("i32", true, true),
1675                ("i64", true, true),
1676                ("i128", true, true),
1677            ]
1678            .map(|(ty_name, as_bounded_int, inc_dec): (&'static str, bool, bool)| {
1679                let ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, ty_name), vec![]);
1680                let is_zero = if as_bounded_int {
1681                    bounded_int_module
1682                        .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
1683                } else {
1684                    integer_module.function_id(
1685                        SmolStrId::from(db, format!("{ty_name}_is_zero")).long(db).as_str(),
1686                        vec![],
1687                    )
1688                }
1689                .lowered(db);
1690                let (inc, dec) = if inc_dec {
1691                    (
1692                        Some(
1693                            num_module
1694                                .function_id(
1695                                    SmolStrId::from(db, format!("{ty_name}_inc")).long(db).as_str(),
1696                                    vec![],
1697                                )
1698                                .lowered(db),
1699                        ),
1700                        Some(
1701                            num_module
1702                                .function_id(
1703                                    SmolStrId::from(db, format!("{ty_name}_dec")).long(db).as_str(),
1704                                    vec![],
1705                                )
1706                                .lowered(db),
1707                        ),
1708                    )
1709                } else {
1710                    (None, None)
1711                };
1712                let info = TypeInfo { is_zero, inc, dec };
1713                (ty, info)
1714            }),
1715        );
1716        Self {
1717            felt_sub: core.extern_function_id("felt252_sub"),
1718            felt_add: core.extern_function_id("felt252_add"),
1719            felt_mul: core.extern_function_id("felt252_mul"),
1720            felt_div: core.extern_function_id("felt252_div"),
1721            into_box: box_module.extern_function_id("into_box"),
1722            unbox: box_module.extern_function_id("unbox"),
1723            box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
1724            eq_fns,
1725            uadd_fns,
1726            usub_fns,
1727            diff_fns,
1728            iadd_fns,
1729            isub_fns,
1730            wide_mul_fns,
1731            div_rem_fns,
1732            bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
1733            bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
1734            bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
1735            bounded_int_trim_min: bounded_int_module.extern_function_id("bounded_int_trim_min"),
1736            bounded_int_trim_max: bounded_int_module.extern_function_id("bounded_int_trim_max"),
1737            array_get: array_module.extern_function_id("array_get"),
1738            array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
1739            array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
1740            array_len: array_module.extern_function_id("array_len"),
1741            array_new: array_module.extern_function_id("array_new"),
1742            array_append: array_module.extern_function_id("array_append"),
1743            array_pop_front: array_module.extern_function_id("array_pop_front"),
1744            storage_base_address_from_felt252: storage_access_module
1745                .extern_function_id("storage_base_address_from_felt252"),
1746            storage_base_address_const: storage_access_module
1747                .generic_function_id("storage_base_address_const"),
1748            panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
1749            panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
1750            panic_with_byte_array: core
1751                .submodule("panics")
1752                .function_id("panic_with_byte_array", vec![])
1753                .lowered(db),
1754            type_info,
1755            const_calculation_info: db.const_calc_info(),
1756        }
1757    }
1758}
1759
1760impl<'db> std::ops::Deref for ConstFoldingContext<'db, '_> {
1761    type Target = ConstFoldingLibfuncInfo<'db>;
1762    fn deref(&self) -> &ConstFoldingLibfuncInfo<'db> {
1763        self.libfunc_info
1764    }
1765}
1766
1767impl<'a> std::ops::Deref for ConstFoldingLibfuncInfo<'a> {
1768    type Target = ConstCalcInfo<'a>;
1769    fn deref(&self) -> &ConstCalcInfo<'a> {
1770        &self.const_calculation_info
1771    }
1772}
1773
/// The information of a type required for const folding.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct TypeInfo<'db> {
    /// The function to check if the value is zero for the type.
    is_zero: FunctionId<'db>,
    /// Inc function to increase the value by one; `None` for types that have no such helper
    /// (e.g. `u256` in `ConstFoldingLibfuncInfo::new`).
    inc: Option<FunctionId<'db>>,
    /// Dec function to decrease the value by one; `None` for types that have no such helper
    /// (e.g. `u256` in `ConstFoldingLibfuncInfo::new`).
    dec: Option<FunctionId<'db>>,
}
1784
/// Brings a value that fell outside a [`TypeRange`] back into it by wrapping around once.
trait TypeRangeNormalizer {
    /// Normalizes the value to the range.
    ///
    /// Assumes the value is at most one range width (`max - min + 1`) outside the range, so
    /// that a single wrap-around brings it back inside.
    fn normalized(&self, value: BigInt) -> NormalizedResult;
}
1790impl TypeRangeNormalizer for TypeRange {
1791    fn normalized(&self, value: BigInt) -> NormalizedResult {
1792        if value < self.min {
1793            NormalizedResult::Under(value - &self.min + &self.max + 1)
1794        } else if value > self.max {
1795            NormalizedResult::Over(value + &self.min - &self.max - 1)
1796        } else {
1797            NormalizedResult::InRange(value)
1798        }
1799    }
1800}
1801
/// The result of normalizing a value to a range.
///
/// `Over` and `Under` carry the value after it was shifted by one range width
/// (`max - min + 1`) back towards the range.
enum NormalizedResult {
    /// The original value is in the range, carries the value, or an equivalent value.
    InRange(BigInt),
    /// The original value is larger than range max, carries the normalized value.
    Over(BigInt),
    /// The original value is smaller than range min, carries the normalized value.
    Under(BigInt),
}