cairo_lang_lowering/optimizations/
const_folding.rs

// Unit tests for this pass live in the sibling `const_folding_test.rs` file.
#[cfg(test)]
#[path = "const_folding_test.rs"]
mod test;
4
5use std::sync::Arc;
6
7use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
8use cairo_lang_filesystem::flag::flag_unsafe_panic;
9use cairo_lang_filesystem::ids::SmolStrId;
10use cairo_lang_semantic::corelib::CorelibSemantic;
11use cairo_lang_semantic::helper::ModuleHelper;
12use cairo_lang_semantic::items::constant::{
13    ConstCalcInfo, ConstValue, ConstValueId, ConstantSemantic, TypeRange, canonical_felt252,
14    felt252_for_downcast,
15};
16use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
17use cairo_lang_semantic::items::structure::StructSemantic;
18use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
19use cairo_lang_semantic::{
20    ConcreteTypeId, ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId,
21    corelib,
22};
23use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
24use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
25use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
26use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
27use cairo_lang_utils::{Intern, extract_matches, try_extract_matches};
28use itertools::{chain, zip_eq};
29use num_bigint::BigInt;
30use num_integer::Integer;
31use num_traits::cast::ToPrimitive;
32use num_traits::{Num, One, Zero};
33use salsa::Database;
34use starknet_types_core::felt::Felt as Felt252;
35
36use crate::db::LoweringGroup;
37use crate::ids::{
38    ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
39    SpecializedFunction,
40};
41use crate::specialization::SpecializationArg;
42use crate::utils::InliningStrategy;
43use crate::{
44    Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchEnumInfo,
45    MatchExternInfo, MatchInfo, Statement, StatementCall, StatementConst, StatementDesnap,
46    StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
47    StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableArena, VariableId,
48};
49
50/// Converts a const value to a specialization arg.
51/// For struct and enum const values, recursively converts to SpecializationArg::Struct/Enum.
52fn const_to_specialization_arg<'db>(
53    db: &'db dyn Database,
54    value: ConstValueId<'db>,
55    boxed: bool,
56) -> SpecializationArg<'db> {
57    match value.long(db) {
58        ConstValue::Struct(members, ty) => {
59            // Only convert to SpecializationArg::Struct if the type is actually a concrete struct,
60            // not a closure or fixed size array.
61            if matches!(ty.long(db), TypeLongId::Concrete(ConcreteTypeId::Struct(_))) {
62                let args = members
63                    .iter()
64                    .map(|member| const_to_specialization_arg(db, *member, false))
65                    .collect();
66                SpecializationArg::Struct(args)
67            } else {
68                SpecializationArg::Const { value, boxed }
69            }
70        }
71        ConstValue::Enum(variant, payload) => SpecializationArg::Enum {
72            variant: *variant,
73            payload: Box::new(const_to_specialization_arg(db, *payload, false)),
74        },
75        _ => SpecializationArg::Const { value, boxed },
76    }
77}
78
/// Keeps track of equivalent values that variables might be replaced with.
/// Note: We don't keep track of types as we assume the usage is always correct.
///
/// Variants nest (e.g. a `Struct` whose members are `Const`s, or a `Snapshot` of an `Array`), so
/// a single info can describe a partially-known compound value.
#[derive(Debug, Clone)]
enum VarInfo<'db> {
    /// The variable is a const value.
    Const(ConstValueId<'db>),
    /// The variable can be replaced by another variable.
    Var(VarUsage<'db>),
    /// The variable is a snapshot of another variable, described by the inner info.
    Snapshot(Box<VarInfo<'db>>),
    /// The variable is a struct of other variables, one entry per member in order.
    /// `None` values represent variables that are not tracked.
    Struct(Vec<Option<VarInfo<'db>>>),
    /// The variable is an enum of a known variant; `payload` describes the variant's inner value.
    Enum { variant: ConcreteVariant<'db>, payload: Box<VarInfo<'db>> },
    /// The variable is a box of another variable, described by the inner info.
    Box(Box<VarInfo<'db>>),
    /// The variable is an array of known size of other variables, one entry per element in order.
    /// `None` values represent variables that are not tracked.
    Array(Vec<Option<VarInfo<'db>>>),
}
100impl<'db> VarInfo<'db> {
101    /// Peels the snapshots from the variable info and returns the number of snapshots performed.
102    fn peel_snapshots(&self) -> (usize, &VarInfo<'db>) {
103        let mut n_snapshots = 0;
104        let mut info = self;
105        while let VarInfo::Snapshot(inner) = info {
106            info = inner.as_ref();
107            n_snapshots += 1;
108        }
109        (n_snapshots, info)
110    }
111    /// Wraps the variable info with the given number of snapshots.
112    fn wrap_with_snapshots(mut self, n_snapshots: usize) -> VarInfo<'db> {
113        for _ in 0..n_snapshots {
114            self = VarInfo::Snapshot(Box::new(self));
115        }
116        self
117    }
118}
119
/// How a block was found to be reachable during the forward const-folding pass.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Reachability {
    /// The block is reachable from the function start only through the goto at the end of the given
    /// block. In that case the consts flowing through that goto's remapping can be propagated into
    /// the block.
    FromSingleGoto(BlockId),
    /// The block is reachable from the function start after const-folding - just does not fit
    /// `FromSingleGoto` (e.g. the root block, a match arm, or the target of more than one goto).
    Any,
}
129
130/// Performs constant folding on the lowered program.
131/// The optimization only works when the blocks are topologically sorted.
132pub fn const_folding<'db>(
133    db: &'db dyn Database,
134    function_id: ConcreteFunctionWithBodyId<'db>,
135    lowered: &mut Lowered<'db>,
136) {
137    if lowered.blocks.is_empty() {
138        return;
139    }
140
141    // Note that we can keep the var_info across blocks because the lowering
142    // is in static single assignment form.
143    let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
144
145    if ctx.should_skip_const_folding(db) {
146        return;
147    }
148
149    for block_id in (0..lowered.blocks.len()).map(BlockId) {
150        if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
151            continue;
152        }
153
154        let block = &mut lowered.blocks[block_id];
155        for stmt in block.statements.iter_mut() {
156            ctx.visit_statement(stmt);
157        }
158        ctx.visit_block_end(block_id, block);
159    }
160}
161
/// The state of a single const-folding pass over one function's lowered blocks.
///
/// Driven by calling `visit_block_start`, then `visit_statement` for each statement, then
/// `visit_block_end`, for every block in order.
pub struct ConstFoldingContext<'db, 'mt> {
    /// The used database.
    db: &'db dyn Database,
    /// The variables arena, mostly used to get the type of variables.
    /// New variables are also allocated here when folding introduces statements.
    pub variables: &'mt mut VariableArena<'db>,
    /// The accumulated information about the const values of variables.
    var_info: UnorderedHashMap<VariableId, VarInfo<'db>>,
    /// The libfunc information.
    libfunc_info: &'db ConstFoldingLibfuncInfo<'db>,
    /// The specialization base of the caller function (or the caller if the function is not
    /// specialized).
    caller_function: ConcreteFunctionWithBodyId<'db>,
    /// Reachability of blocks from the function start.
    /// If the block is not in this map, it means that it is unreachable (or that it was already
    /// visited and its reachability won't be checked again).
    reachability: UnorderedHashMap<BlockId, Reachability>,
    /// Additional statements to add to the block currently being visited.
    /// They are drained by `visit_block_end`, which inserts them at the start of the block.
    additional_stmts: Vec<Statement<'db>>,
}
181
182impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
183    pub fn new(
184        db: &'db dyn Database,
185        function_id: ConcreteFunctionWithBodyId<'db>,
186        variables: &'mt mut VariableArena<'db>,
187    ) -> Self {
188        Self {
189            db,
190            var_info: UnorderedHashMap::default(),
191            variables,
192            libfunc_info: priv_const_folding_info(db),
193            caller_function: function_id,
194            reachability: UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]),
195            additional_stmts: vec![],
196        }
197    }
198
199    /// Determines if a block is reachable from the function start and propagates constant values
200    /// when the block is reachable via a single goto statement.
201    pub fn visit_block_start<'r, 'get>(
202        &'r mut self,
203        block_id: BlockId,
204        get_block: impl FnOnce(BlockId) -> &'get Block<'db>,
205    ) -> bool
206    where
207        'db: 'get,
208    {
209        let Some(reachability) = self.reachability.remove(&block_id) else {
210            return false;
211        };
212        match reachability {
213            Reachability::Any => {}
214            Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
215                BlockEnd::Goto(_, remapping) => {
216                    for (dst, src) in remapping.iter() {
217                        if let Some(v) = self.as_const(src.var_id) {
218                            self.var_info.insert(*dst, VarInfo::Const(v));
219                        }
220                    }
221                }
222                _ => unreachable!("Expected a goto end"),
223            },
224        }
225        true
226    }
227
    /// Processes a statement and applies the constant folding optimizations.
    ///
    /// This method performs the following operations:
    /// - Updates the `var_info` map with constant values of variables
    /// - Replace the input statement with optimized versions when possible
    /// - Updates `self.additional_stmts` with statements that need to be added to the block.
    ///
    /// Note: `self.visit_block_end` must be called after processing all statements
    /// in a block to actually add the additional statements.
    pub fn visit_statement(&mut self, stmt: &mut Statement<'db>) {
        // Rewrite the inputs first (the helper is defined outside this view; presumably it
        // substitutes tracked equivalents), so that the folding below sees the rewritten inputs.
        self.maybe_replace_inputs(stmt.inputs_mut());
        match stmt {
            // A boxed const is tracked as a box around the const value.
            Statement::Const(StatementConst { value, output, boxed }) if *boxed => {
                self.var_info.insert(*output, VarInfo::Box(VarInfo::Const(*value).into()));
            }
            Statement::Const(StatementConst { value, output, .. }) => match value.long(self.db) {
                ConstValue::Int(..)
                | ConstValue::Struct(..)
                | ConstValue::Enum(..)
                | ConstValue::NonZero(..) => {
                    self.var_info.insert(*output, VarInfo::Const(*value));
                }
                // These const kinds are deliberately not tracked.
                ConstValue::Generic(_)
                | ConstValue::ImplConstant(_)
                | ConstValue::Var(..)
                | ConstValue::Missing(_) => {}
            },
            // Both the (unchanged) original and the new snapshot inherit the input's info.
            Statement::Snapshot(stmt) => {
                if let Some(info) = self.var_info.get(&stmt.input.var_id).cloned() {
                    self.var_info.insert(stmt.original(), info.clone());
                    self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info.into()));
                }
            }
            // Desnapping a tracked snapshot yields the info it wraps.
            Statement::Desnap(StatementDesnap { input, output }) => {
                if let Some(VarInfo::Snapshot(info)) = self.var_info.get(&input.var_id) {
                    self.var_info.insert(*output, info.as_ref().clone());
                }
            }
            // Try to fold the call first; only if that yields no replacement, try to
            // specialize it.
            Statement::Call(call_stmt) => {
                if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
                    *stmt = updated_stmt;
                } else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
                    *stmt = updated_stmt;
                }
            }
            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
                // `const_args` collects const members; `all_args` has one entry per member.
                let mut const_args = vec![];
                let mut all_args = vec![];
                let mut contains_info = false;
                for input in inputs.iter() {
                    let Some(info) = self.var_info.get(&input.var_id) else {
                        // Untracked input: fall back to `var_info_if_copy` (helper not in view;
                        // by its name presumably `Some` only for copyable variables).
                        all_args.push(var_info_if_copy(self.variables, *input));
                        continue;
                    };
                    contains_info = true;
                    if let VarInfo::Const(value) = info {
                        const_args.push(*value);
                    }
                    all_args.push(Some(info.clone()));
                }
                // All members const (including the empty struct): the struct is a const.
                // Otherwise, track a partial struct only if at least one input had info.
                if const_args.len() == inputs.len() {
                    let value =
                        ConstValue::Struct(const_args, self.variables[*output].ty).intern(self.db);
                    self.var_info.insert(*output, VarInfo::Const(value));
                } else if contains_info {
                    self.var_info.insert(*output, VarInfo::Struct(all_args));
                }
            }
            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
                if let Some(info) = self.var_info.get(&input.var_id) {
                    // Look through snapshots of the struct; each member's info is then re-wrapped
                    // with the same number of snapshot layers that were peeled.
                    let (n_snapshots, info) = info.peel_snapshots();
                    match info {
                        VarInfo::Const(const_value) => {
                            if let ConstValue::Struct(member_values, _) = const_value.long(self.db)
                            {
                                for (output, value) in zip_eq(outputs, member_values) {
                                    self.var_info.insert(
                                        *output,
                                        VarInfo::Const(*value).wrap_with_snapshots(n_snapshots),
                                    );
                                }
                            }
                        }
                        VarInfo::Struct(members) => {
                            for (output, member) in zip_eq(outputs, members.clone()) {
                                if let Some(member) = member {
                                    self.var_info
                                        .insert(*output, member.wrap_with_snapshots(n_snapshots));
                                }
                            }
                        }
                        _ => {}
                    }
                }
            }
            // Enum constructions are always tracked: a const payload gives a const enum, any
            // other tracked payload is kept as-is, and an untracked payload is recorded as a
            // reference to the input variable.
            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
                let value = if let Some(info) = self.var_info.get(&input.var_id) {
                    if let VarInfo::Const(val) = info {
                        VarInfo::Const(ConstValue::Enum(*variant, *val).intern(self.db))
                    } else {
                        VarInfo::Enum { variant: *variant, payload: info.clone().into() }
                    }
                } else {
                    VarInfo::Enum { variant: *variant, payload: VarInfo::Var(*input).into() }
                };
                self.var_info.insert(*output, value);
            }
        }
    }
337
    /// Processes the block's end and incorporates additional statements into the block.
    ///
    /// This method handles the following tasks:
    /// - Inserts the accumulated additional statements into the block.
    /// - Converts match endings to goto when applicable.
    /// - Updates self.reachability based on the block's ending.
    ///
    /// # Panics
    /// Panics if the block end is `Panic` or `NotSet`, or if the target block of a match arm is
    /// already registered in `self.reachability`.
    pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut Block<'db>) {
        let statements = &mut block.statements;
        // The additional statements go at the *start* of the block. The producers visible in
        // `handle_statement_call` only define variables from constants or from other newly
        // created variables, so hoisting them keeps every use after its definition.
        statements.splice(0..0, self.additional_stmts.drain(..));

        match &mut block.end {
            BlockEnd::Goto(_, remappings) => {
                for (_, v) in remappings.iter_mut() {
                    self.maybe_replace_input(v);
                }
            }
            BlockEnd::Match { info } => {
                self.maybe_replace_inputs(info.inputs_mut());
                match info {
                    MatchInfo::Enum(info) => {
                        if let Some(updated_end) = self.handle_enum_block_end(info, statements) {
                            block.end = updated_end;
                        }
                    }
                    MatchInfo::Extern(info) => {
                        if let Some(updated_end) = self.handle_extern_block_end(info, statements) {
                            block.end = updated_end;
                        }
                    }
                    MatchInfo::Value(info) => {
                        // A match on a known value: jump straight to the arm selecting it.
                        if let Some(value) =
                            self.as_int(info.input.var_id).and_then(|x| x.to_usize())
                            && let Some(arm) = info.arms.iter().find(|arm| {
                                matches!(
                                    &arm.arm_selector,
                                    MatchArmSelector::Value(v) if v.value == value
                                )
                            })
                        {
                            // Create the variable that was previously introduced in match arm.
                            // It is built by a struct construction with no inputs.
                            statements.push(Statement::StructConstruct(StatementStructConstruct {
                                inputs: vec![],
                                output: arm.var_ids[0],
                            }));
                            block.end = BlockEnd::Goto(arm.block_id, Default::default());
                        }
                    }
                }
            }
            BlockEnd::Return(inputs, _) => self.maybe_replace_inputs(inputs),
            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
        }
        // Register the successors of the (possibly rewritten) block end.
        match &block.end {
            BlockEnd::Goto(dst_block_id, _) => {
                // The first goto into a block makes it `FromSingleGoto`; any later one
                // downgrades it to `Any`.
                match self.reachability.entry(*dst_block_id) {
                    std::collections::hash_map::Entry::Occupied(mut e) => {
                        e.insert(Reachability::Any)
                    }
                    std::collections::hash_map::Entry::Vacant(e) => {
                        *e.insert(Reachability::FromSingleGoto(block_id))
                    }
                };
            }
            BlockEnd::Match { info } => {
                // Each match arm's block is expected to be reached from this match only.
                for arm in info.arms() {
                    assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
                }
            }
            BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
        }
    }
409
    /// Handles a statement call.
    ///
    /// Returns None if no additional changes are required.
    /// If changes are required, returns an updated statement (to override the current
    /// statement).
    /// May add additional statements to `self.additional_stmts` if just replacing the current
    /// statement is not enough.
    ///
    /// Note that returning `None` does not mean nothing happened: some branches rewrite `stmt` in
    /// place (swapping the callee for a const-generic variant) or record `var_info` entries for
    /// the call outputs, and still return `None`.
    fn handle_statement_call(&mut self, stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
        let db = self.db;
        if stmt.function == self.panic_with_felt252 {
            // Panicking with a known felt252: call the const-generic variant instead, passing the
            // value as a generic argument rather than as a runtime input.
            let val = self.as_const(stmt.inputs[0].var_id)?;
            stmt.inputs.clear();
            stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
                .concretize(db, vec![GenericArgumentId::Constant(val)])
                .lowered(db);
            return None;
        } else if stmt.function == self.panic_with_byte_array && !flag_unsafe_panic(db) {
            // Panicking with a fully-known byte array snapshot: build the panic data array from
            // constants and replace the call by constructing its output directly.
            let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
            let bytearray = try_extract_matches!(snap, VarInfo::Snapshot)?;
            // Requires exactly three tracked members: an array of words, a const pending word and
            // a const pending length.
            let [
                Some(VarInfo::Array(data)),
                Some(VarInfo::Const(pending_word)),
                Some(VarInfo::Const(pending_len)),
            ] = &try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
            else {
                return None;
            };
            // Panic data layout: [BYTE_ARRAY_MAGIC, n_words, word_0, .., pending_word, pending_len].
            let mut panic_data =
                vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
            for word in data {
                // Every full word must be a known const, otherwise no folding happens.
                let Some(VarInfo::Const(word)) = word else {
                    return None;
                };
                panic_data.push(word.long(db).to_int()?.clone());
            }
            panic_data.extend([
                pending_word.long(db).to_int()?.clone(),
                pending_len.long(db).to_int()?.clone(),
            ]);
            let felt252_ty = self.felt252;
            let location = stmt.location;
            let new_var = |ty| Variable::with_default_context(db, ty, location);
            let as_usage = |var_id| VarUsage { var_id, location };
            let array_fn = |extern_id| {
                let args = vec![GenericArgumentId::Type(felt252_ty)];
                GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
            };
            let call_stmt = |function, inputs, outputs| {
                let with_coupon = false;
                Statement::Call(StatementCall {
                    function,
                    inputs,
                    with_coupon,
                    outputs,
                    location,
                    is_specialization_base_call: false,
                })
            };
            // Emit `array_new`, then a const + `array_append` pair per panic data word. Each
            // append yields a fresh array variable, threaded through `arr`.
            let arr_var = new_var(corelib::core_array_felt252_ty(db));
            let mut arr = self.variables.alloc(arr_var.clone());
            self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
            let felt252_var = new_var(felt252_ty);
            let arr_append_fn = array_fn(self.array_append);
            for word in panic_data {
                let to_append = self.variables.alloc(felt252_var.clone());
                let new_arr = self.variables.alloc(arr_var.clone());
                self.additional_stmts.push(Statement::Const(StatementConst::new_flat(
                    ConstValue::Int(word, felt252_ty).intern(db),
                    to_append,
                )));
                self.additional_stmts.push(call_stmt(
                    arr_append_fn,
                    vec![as_usage(arr), as_usage(to_append)],
                    vec![new_arr],
                ));
                arr = new_arr;
            }
            let panic_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "Panic"), vec![]);
            let panic_var = self.variables.alloc(new_var(panic_ty));
            self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
                inputs: vec![],
                output: panic_var,
            }));
            // The call is replaced by building its output from the `Panic` value and the array.
            return Some(Statement::StructConstruct(StatementStructConstruct {
                inputs: vec![as_usage(panic_var), as_usage(arr)],
                output: stmt.outputs[0],
            }));
        }
        // The remaining folds only apply to extern (libfunc) calls.
        let (id, _generic_args) = stmt.function.get_extern(db)?;
        if id == self.felt_sub {
            // x - 0 => x; const - const => const (reduced to a canonical felt252).
            if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            {
                let value = canonical_felt252(&(lhs - rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_add {
            // 0 + x => x; x + 0 => x; const + const => const.
            if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && lhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
                None
            } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            {
                let value = canonical_felt252(&(lhs + rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_mul {
            // Either side 0 => 0; x * 1 => x; 1 * x => x; const * const => const.
            let lhs = self.as_int(stmt.inputs[0].var_id);
            let rhs = self.as_int(stmt.inputs[1].var_id);
            if lhs.map(Zero::is_zero).unwrap_or_default()
                || rhs.map(Zero::is_zero).unwrap_or_default()
            {
                Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
            } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && lhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
                None
            } else if let Some(lhs) = lhs
                && let Some(rhs) = rhs
            {
                let value = canonical_felt252(&(lhs * rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_div {
            // Note that divisor is never 0, due to NonZero type always being the divisor.
            if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                // Returns the original value when dividing by 1.
                && rhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                // If the value is 0, result is 0 regardless of the divisor.
                && lhs.is_zero()
            {
                Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && let Ok(rhs_nonzero) = Felt252::from(rhs).try_into()
            {
                // Constant fold when both operands are constants

                // Use field_div for Felt252 division
                let lhs_felt = Felt252::from(lhs);
                let value = lhs_felt.field_div(&rhs_nonzero).to_bigint();
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if self.wide_mul_fns.contains(&id) {
            // Either side 0 => 0; both known => the exact product, unreduced (presumably the wide
            // output type always fits the product - TODO confirm).
            let lhs = self.as_int(stmt.inputs[0].var_id);
            let rhs = self.as_int(stmt.inputs[1].var_id);
            let output = stmt.outputs[0];
            if lhs.map(Zero::is_zero).unwrap_or_default()
                || rhs.map(Zero::is_zero).unwrap_or_default()
            {
                return Some(self.propagate_zero_and_get_statement(output));
            }
            let lhs = lhs?;
            Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0]))
        } else if id == self.bounded_int_add || id == self.bounded_int_sub {
            // Folded only when both operands are known; the result is not range-checked here
            // (assumes the output bounded-int type admits it - TODO confirm).
            let lhs = self.as_int(stmt.inputs[0].var_id)?;
            let rhs = self.as_int(stmt.inputs[1].var_id)?;
            let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else if self.div_rem_fns.contains(&id) {
            // 0 divided by anything gives quotient 0 and remainder 0.
            let lhs = self.as_int(stmt.inputs[0].var_id);
            if lhs.map(Zero::is_zero).unwrap_or_default() {
                let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
                self.additional_stmts.push(additional_stmt);
                return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
            }
            // Both known: the remainder's const goes to `additional_stmts`, and the quotient's
            // const replaces the call.
            let rhs = self.as_int(stmt.inputs[1].var_id)?;
            let (q, r) = lhs?.div_rem(rhs);
            let q_output = stmt.outputs[0];
            let q_value = ConstValue::Int(q, self.variables[q_output].ty).intern(db);
            self.var_info.insert(q_output, VarInfo::Const(q_value));
            let r_output = stmt.outputs[1];
            let r_value = ConstValue::Int(r, self.variables[r_output].ty).intern(db);
            self.var_info.insert(r_output, VarInfo::Const(r_value));
            self.additional_stmts
                .push(Statement::Const(StatementConst::new_flat(r_value, r_output)));
            Some(Statement::Const(StatementConst::new_flat(q_value, q_output)))
        } else if id == self.storage_base_address_from_felt252 {
            // A known address: switch the call (kept in place) to the const-generic variant, with
            // no runtime inputs.
            let input_var = stmt.inputs[0].var_id;
            if let Some(const_value) = self.as_const(input_var)
                && let ConstValue::Int(val, ty) = const_value.long(db)
            {
                stmt.inputs.clear();
                let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
                stmt.function =
                    self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
            }
            None
        } else if id == self.into_box {
            // Track the output as a box of the input's info. If the boxed value is a known const
            // (directly or behind one snapshot), the call is replaced with a boxed const.
            let input = stmt.inputs[0];
            let var_info = self.var_info.get(&input.var_id);
            let const_value = match var_info {
                Some(VarInfo::Const(val)) => Some(*val),
                Some(VarInfo::Snapshot(info)) => {
                    try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
                }
                _ => None,
            };
            let var_info = var_info.cloned().or_else(|| var_info_if_copy(self.variables, input))?;
            self.var_info.insert(stmt.outputs[0], VarInfo::Box(var_info.into()));
            Some(Statement::Const(StatementConst::new_boxed(const_value?, stmt.outputs[0])))
        } else if id == self.unbox {
            // Unboxing a tracked box yields its inner info; a const inner value also lets the
            // call be replaced with a flat const.
            if let VarInfo::Box(inner) = self.var_info.get(&stmt.inputs[0].var_id)? {
                let inner = inner.as_ref().clone();
                if let VarInfo::Const(inner) =
                    self.var_info.entry(stmt.outputs[0]).insert_entry(inner).get()
                {
                    return Some(Statement::Const(StatementConst::new_flat(
                        *inner,
                        stmt.outputs[0],
                    )));
                }
            }
            None
        } else if self.upcast_fns.contains(&id) {
            // Upcasting keeps the integer value; only the type changes to the output's type.
            let int_value = self.as_int(stmt.inputs[0].var_id)?;
            let output = stmt.outputs[0];
            let value = ConstValue::Int(int_value.clone(), self.variables[output].ty).intern(db);
            self.var_info.insert(output, VarInfo::Const(value));
            Some(Statement::Const(StatementConst::new_flat(value, output)))
        } else if id == self.array_new {
            // Start tracking a new, empty array.
            self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]));
            None
        } else if id == self.array_append {
            // Extend a tracked array with the appended element's info (`None` if untracked).
            let mut var_infos =
                if let VarInfo::Array(var_infos) = self.var_info.get(&stmt.inputs[0].var_id)? {
                    var_infos.clone()
                } else {
                    return None;
                };
            let appended = stmt.inputs[1];
            var_infos.push(match self.var_info.get(&appended.var_id) {
                Some(var_info) => Some(var_info.clone()),
                None => var_info_if_copy(self.variables, appended),
            });
            self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos));
            None
        } else if id == self.array_len {
            // The length of a snapshot of a tracked array is known statically.
            let info = self.var_info.get(&stmt.inputs[0].var_id)?;
            let desnapped = try_extract_matches!(info, VarInfo::Snapshot)?;
            let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
            Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0]))
        } else {
            None
        }
    }
686
    /// Tries to specialize the call.
    /// Returns the specialized call statement if it was specialized, or None otherwise.
    ///
    /// No specialization is attempted for calls with a coupon, when the inlining strategy is
    /// `Avoid`, for callees for which `priv_never_inline` holds, for specialization-base calls,
    /// for callees in the same multi-function call SCC as the caller, or when no input has known
    /// var info.
    ///
    /// For a callee whose base differs from the caller's base, the specialization is dropped if
    /// `priv_should_specialize` returns `Ok(false)`. A recursive call (callee base equal to the
    /// caller base) is specialized only if the callee is not itself already specialized; in that
    /// case the caller's own specialization args are passed as the `coerce` argument of
    /// `try_get_specialization_arg`.
    ///
    /// If the callee is already a specialized function, the new args are merged into its existing
    /// args (filling its `NotSpecialized` slots in order), so the result is always a
    /// specialization of the original base rather than a specialization of a specialization.
    fn try_specialize_call(&self, call_stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
        if call_stmt.with_coupon {
            return None;
        }
        // No specialization when avoiding inlining.
        if matches!(self.db.optimizations().inlining_strategy(), InliningStrategy::Avoid) {
            return None;
        }

        let Ok(Some(mut called_function)) = call_stmt.function.body(self.db) else {
            return None;
        };

        // Maps a specialized function to the function it specializes; other functions map to
        // themselves.
        let extract_base = |function: ConcreteFunctionWithBodyId<'db>| match function.long(self.db)
        {
            ConcreteFunctionWithBodyLongId::Specialized(specialized) => specialized.base,
            _ => function,
        };
        let called_base = extract_base(called_function);
        let caller_base = extract_base(self.caller_function);

        if self.db.priv_never_inline(called_base).ok()? {
            return None;
        }

        // Do not specialize the call that should be inlined.
        if call_stmt.is_specialization_base_call {
            return None;
        }

        // Do not specialize a recursive call that was already specialized.
        if called_base == caller_base && called_function != called_base {
            return None;
        }

        // Avoid specializing with a function that is in the same SCC as the caller (and is not the
        // same function).
        let scc =
            self.db.lowered_scc(called_base, DependencyType::Call, LoweringStage::Monomorphized);
        if scc.len() > 1 && scc.contains(&caller_base) {
            return None;
        }

        if call_stmt.inputs.iter().all(|arg| self.var_info.get(&arg.var_id).is_none()) {
            // No input has any known var info - nothing to specialize on.
            return None;
        }

        // If we are specializing a recursive call, the caller's own specialization args are used
        // as per-input coercion hints; otherwise every input gets `None`.
        let self_specializition = if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
            self.caller_function.long(self.db)
            && caller_base == called_base
        {
            specialized.args.iter().map(Some).collect()
        } else {
            vec![None; call_stmt.inputs.len()]
        };

        // One specialization arg per call input; `new_args` collects the inputs that are still
        // passed at runtime (including any pushed by `try_get_specialization_arg`).
        let mut specialization_args = vec![];
        let mut new_args = vec![];
        for (arg, coerce) in zip_eq(&call_stmt.inputs, &self_specializition) {
            // Only droppable inputs with known var info are specialized on.
            if let Some(var_info) = self.var_info.get(&arg.var_id)
                && self.variables[arg.var_id].info.droppable.is_ok()
                && let Some(specialization_arg) = self.try_get_specialization_arg(
                    var_info.clone(),
                    self.variables[arg.var_id].ty,
                    &mut new_args,
                    *coerce,
                )
            {
                specialization_args.push(specialization_arg);
            } else {
                specialization_args.push(SpecializationArg::NotSpecialized);
                new_args.push(*arg);
                continue;
            };
        }

        if specialization_args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized)) {
            // No argument was assigned -> no specialization.
            return None;
        }
        if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
            called_function.long(self.db)
        {
            // Canonicalize the specialization rather than adding a specialization of a specialized
            // function.
            called_function = specialized_function.base;
            let mut new_args_iter = specialization_args.into_iter();
            let mut old_args = specialized_function.args.to_vec();
            // Depth-first, left-to-right walk over the existing args: each `NotSpecialized` leaf
            // (a parameter of the specialized function) is replaced by the next new arg, or stays
            // `NotSpecialized` if the new args run out.
            let mut stack = vec![];
            for arg in old_args.iter_mut().rev() {
                stack.push(arg);
            }
            while let Some(arg) = stack.pop() {
                match arg {
                    SpecializationArg::Const { .. } => {}
                    SpecializationArg::Snapshot(inner) => {
                        stack.push(inner.as_mut());
                    }
                    SpecializationArg::Enum { payload, .. } => {
                        stack.push(payload.as_mut());
                    }
                    SpecializationArg::Array(_, values) => {
                        for value in values.iter_mut().rev() {
                            stack.push(value);
                        }
                    }
                    SpecializationArg::Struct(specialization_args) => {
                        for arg in specialization_args.iter_mut().rev() {
                            stack.push(arg);
                        }
                    }
                    SpecializationArg::NotSpecialized => {
                        *arg = new_args_iter.next().unwrap_or(SpecializationArg::NotSpecialized);
                    }
                }
            }
            specialization_args = old_args;
        }
        let specialized =
            SpecializedFunction { base: called_function, args: specialization_args.into() };
        let specialized_func_id =
            ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);

        // Recursive (same-base) specializations skip the `priv_should_specialize` heuristic.
        if caller_base != called_base
            && self.db.priv_should_specialize(specialized_func_id) == Ok(false)
        {
            return None;
        }

        Some(Statement::Call(StatementCall {
            function: specialized_func_id.function_id(self.db).unwrap(),
            inputs: new_args,
            with_coupon: call_stmt.with_coupon,
            outputs: std::mem::take(&mut call_stmt.outputs),
            location: call_stmt.location,
            is_specialization_base_call: false,
        }))
    }
833
834    /// Adds `value` as a const to `var_info` and return a const statement for it.
835    fn propagate_const_and_get_statement(
836        &mut self,
837        value: BigInt,
838        output: VariableId,
839    ) -> Statement<'db> {
840        let ty = self.variables[output].ty;
841        let value = ConstValueId::from_int(self.db, ty, &value);
842        self.var_info.insert(output, VarInfo::Const(value));
843        Statement::Const(StatementConst::new_flat(value, output))
844    }
845
846    /// Adds 0 const to `var_info` and return a const statement for it.
847    fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement<'db> {
848        self.propagate_const_and_get_statement(BigInt::zero(), output)
849    }
850
851    /// Returns a statement that introduces the requested value into `output`, or None if fails.
852    fn try_generate_const_statement(
853        &self,
854        value: ConstValueId<'db>,
855        output: VariableId,
856    ) -> Option<Statement<'db>> {
857        if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
858            Some(Statement::Const(StatementConst::new_flat(value, output)))
859        } else if matches!(value.long(self.db), ConstValue::Struct(members, _) if members.is_empty())
860        {
861            // Handling const empty structs - which are not supported in sierra-gen.
862            Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
863        } else {
864            None
865        }
866    }
867
868    /// Handles the end of block matching on an enum.
869    /// Possibly extends the blocks statements as well.
870    /// Returns None if no additional changes are required.
871    /// If changes are required, returns the updated block end.
872    fn handle_enum_block_end(
873        &mut self,
874        info: &mut MatchEnumInfo<'db>,
875        statements: &mut Vec<Statement<'db>>,
876    ) -> Option<BlockEnd<'db>> {
877        let input = info.input.var_id;
878        let (n_snapshots, var_info) = self.var_info.get(&input)?.peel_snapshots();
879        let location = info.location;
880        let as_usage = |var_id| VarUsage { var_id, location };
881        let db = self.db;
882        let snapshot_stmt = |vars: &mut VariableArena<'_>, pre_snap, post_snap| {
883            let ignored = vars.alloc(vars[pre_snap].clone());
884            Statement::Snapshot(StatementSnapshot::new(as_usage(pre_snap), ignored, post_snap))
885        };
886        // Checking whether we have actual const info on the enum.
887        if let VarInfo::Const(const_value) = var_info
888            && let ConstValue::Enum(variant, value) = const_value.long(db)
889        {
890            let arm = &info.arms[variant.idx];
891            let output = arm.var_ids[0];
892            // Propagating the const value information.
893            self.var_info.insert(output, VarInfo::Const(*value).wrap_with_snapshots(n_snapshots));
894            if self.variables[input].info.droppable.is_ok()
895                && self.variables[output].info.copyable.is_ok()
896                && let Ok(mut ty) = value.ty(db)
897                && let Some(mut stmt) = self.try_generate_const_statement(*value, output)
898            {
899                // Adding snapshot taking statements for snapshots.
900                for _ in 0..n_snapshots {
901                    let non_snap_var = Variable::with_default_context(db, ty, location);
902                    ty = TypeLongId::Snapshot(ty).intern(db);
903                    let pre_snap = self.variables.alloc(non_snap_var);
904                    stmt.outputs_mut()[0] = pre_snap;
905                    let take_snap = snapshot_stmt(self.variables, pre_snap, output);
906                    statements.push(core::mem::replace(&mut stmt, take_snap));
907                }
908                statements.push(stmt);
909                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
910            }
911        } else if let VarInfo::Enum { variant, payload } = var_info {
912            let arm = &info.arms[variant.idx];
913            let variant_ty = variant.ty;
914            let output = arm.var_ids[0];
915            let payload = payload.as_ref().clone();
916            let unwrapped =
917                self.variables[input].info.droppable.is_ok().then_some(()).and_then(|_| {
918                    let (extra_snapshots, inner) = payload.peel_snapshots();
919                    match inner {
920                        VarInfo::Var(var) if self.variables[var.var_id].info.copyable.is_ok() => {
921                            Some((var.var_id, extra_snapshots))
922                        }
923                        VarInfo::Const(value) => {
924                            let const_var = self
925                                .variables
926                                .alloc(Variable::with_default_context(db, variant_ty, location));
927                            statements.push(self.try_generate_const_statement(*value, const_var)?);
928                            Some((const_var, extra_snapshots))
929                        }
930                        _ => None,
931                    }
932                });
933            // Propagating the const value information.
934            self.var_info.insert(output, payload.wrap_with_snapshots(n_snapshots));
935            if let Some((mut unwrapped, extra_snapshots)) = unwrapped {
936                let total_snapshots = n_snapshots + extra_snapshots;
937                if total_snapshots != 0 {
938                    // Adding snapshot taking statements for snapshots.
939                    for _ in 1..total_snapshots {
940                        let ty = TypeLongId::Snapshot(self.variables[unwrapped].ty).intern(db);
941                        let non_snap_var = Variable::with_default_context(self.db, ty, location);
942                        let snapped = self.variables.alloc(non_snap_var);
943                        statements.push(snapshot_stmt(self.variables, unwrapped, snapped));
944                        unwrapped = snapped;
945                    }
946                    statements.push(snapshot_stmt(self.variables, unwrapped, output));
947                };
948                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
949            }
950        }
951        None
952    }
953
    /// Handles the end of a block based on an extern function call.
    /// Possibly extends the block's statements as well.
    /// Returns None if no additional changes are required.
    /// If changes are required, returns the updated block end.
    ///
    /// The extern is identified by its `ExternFunctionId`. Depending on which inputs have known
    /// values in `var_info`, the match is replaced by a `Goto` to the arm that is statically
    /// taken (appending statements that define that arm's outputs), replaced by a different
    /// match, or left as is. Note that even when `None` is returned, known values of arm outputs
    /// may have been recorded in `var_info`, and `info` itself may have been rewritten in place
    /// (the `array_get` at index 0 case).
    fn handle_extern_block_end(
        &mut self,
        info: &mut MatchExternInfo<'db>,
        statements: &mut Vec<Statement<'db>>,
    ) -> Option<BlockEnd<'db>> {
        let db = self.db;
        let (id, generic_args) = info.function.get_extern(db)?;
        if self.nz_fns.contains(&id) {
            // Const input: arm 0 is taken when the value is zero; otherwise arm 1, whose output
            // is defined as the value wrapped in a `NonZero` const.
            let val = self.as_const(info.inputs[0].var_id)?;
            let is_zero = match val.long(db) {
                ConstValue::Int(v, _) => v.is_zero(),
                // A struct const is zero iff all of its members are zero ints.
                ConstValue::Struct(s, _) => s.iter().all(|v| {
                    v.long(db).to_int().expect("Expected ConstValue::Int for size").is_zero()
                }),
                _ => unreachable!(),
            };
            Some(if is_zero {
                BlockEnd::Goto(info.arms[0].block_id, Default::default())
            } else {
                let arm = &info.arms[1];
                let nz_var = arm.var_ids[0];
                let nz_val = ConstValue::NonZero(val).intern(db);
                self.var_info.insert(nz_var, VarInfo::Const(nz_val));
                statements.push(Statement::Const(StatementConst::new_flat(nz_val, nz_var)));
                BlockEnd::Goto(arm.block_id, Default::default())
            })
        } else if self.eq_fns.contains(&id) {
            // Arm 0 is taken when the operands differ, arm 1 when they are equal.
            let lhs = self.as_int(info.inputs[0].var_id);
            let rhs = self.as_int(info.inputs[1].var_id);
            // One operand is a const zero and the other is unknown: replace the equality check
            // with a match on the type's `is_zero` function (zero -> arm 1, non-zero -> arm 0).
            if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
                || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
            {
                let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
                // Cloned so that `var` does not borrow `self.variables` across the `alloc` below.
                let var = &self.variables[nz_input.var_id].clone();
                let function = self.type_info.get(&var.ty)?.is_zero;
                let unused_nz_var = Variable::with_default_context(
                    db,
                    corelib::core_nonzero_ty(db, var.ty),
                    var.location,
                );
                let unused_nz_var = self.variables.alloc(unused_nz_var);
                return Some(BlockEnd::Match {
                    info: MatchInfo::Extern(MatchExternInfo {
                        function,
                        inputs: vec![nz_input],
                        arms: vec![
                            MatchArm {
                                arm_selector: MatchArmSelector::VariantId(
                                    corelib::jump_nz_zero_variant(db, var.ty),
                                ),
                                block_id: info.arms[1].block_id,
                                var_ids: vec![],
                            },
                            MatchArm {
                                arm_selector: MatchArmSelector::VariantId(
                                    corelib::jump_nz_nonzero_variant(db, var.ty),
                                ),
                                block_id: info.arms[0].block_id,
                                var_ids: vec![unused_nz_var],
                            },
                        ],
                        location: info.location,
                    }),
                });
            }
            // Otherwise the branch can only be resolved when both operands are const.
            Some(BlockEnd::Goto(
                info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
                Default::default(),
            ))
        } else if self.uadd_fns.contains(&id)
            || self.usub_fns.contains(&id)
            || self.diff_fns.contains(&id)
            || self.iadd_fns.contains(&id)
            || self.isub_fns.contains(&id)
        {
            let rhs = self.as_int(info.inputs[1].var_id);
            let lhs = self.as_int(info.inputs[0].var_id);
            // Both operands const: compute the result and normalize it against the output type's
            // range. In range -> arm 0; under -> arm 1; over -> arm 2 for the signed (`iadd` /
            // `isub`) families and arm 1 otherwise. The taken arm's output is defined as the
            // value returned by `normalized`.
            if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
                let ty = self.variables[info.arms[0].var_ids[0]].ty;
                let range = self.type_value_ranges.get(&ty)?;
                let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
                    lhs + rhs
                } else {
                    lhs - rhs
                };
                let (arm_index, value) = match range.normalized(value) {
                    NormalizedResult::InRange(value) => (0, value),
                    NormalizedResult::Under(value) => (1, value),
                    NormalizedResult::Over(value) => (
                        if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
                            2
                        } else {
                            1
                        },
                        value,
                    ),
                };
                let arm = &info.arms[arm_index];
                let actual_output = arm.var_ids[0];
                let value = ConstValue::Int(value, ty).intern(db);
                self.var_info.insert(actual_output, VarInfo::Const(value));
                statements.push(Statement::Const(StatementConst::new_flat(value, actual_output)));
                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
            }
            if let Some(rhs) = rhs {
                // `x +/- 0` (not for diff functions): arm 0, with its output aliased to `lhs`.
                if rhs.is_zero() && !self.diff_fns.contains(&id) {
                    let arm = &info.arms[0];
                    self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]));
                    return Some(BlockEnd::Goto(arm.block_id, Default::default()));
                }
                // `x +/- 1` (not for diff functions): call the type's `inc`/`dec` function and
                // match on its enum result, reusing the original arms.
                // NOTE(review): assumes the `inc`/`dec` result enum's variants line up, in order,
                // with the original arms - confirm against how `type_info` is populated.
                if rhs.is_one() && !self.diff_fns.contains(&id) {
                    let ty = self.variables[info.arms[0].var_ids[0]].ty;
                    let ty_info = self.type_info.get(&ty)?;
                    let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
                        ty_info.inc?
                    } else {
                        ty_info.dec?
                    };
                    let enum_ty = function.signature(db).ok()?.return_type;
                    let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
                        enum_ty.long(db)
                    else {
                        return None;
                    };
                    let result = self.variables.alloc(Variable::with_default_context(
                        db,
                        function.signature(db).unwrap().return_type,
                        info.location,
                    ));
                    statements.push(Statement::Call(StatementCall {
                        function,
                        inputs: vec![info.inputs[0]],
                        with_coupon: false,
                        outputs: vec![result],
                        location: info.location,
                        is_specialization_base_call: false,
                    }));
                    return Some(BlockEnd::Match {
                        info: MatchInfo::Enum(MatchEnumInfo {
                            concrete_enum_id: *concrete_enum_id,
                            input: VarUsage { var_id: result, location: info.location },
                            arms: core::mem::take(&mut info.arms),
                            location: info.location,
                        }),
                    });
                }
            }
            // `0 + x` (addition families only): arm 0, with its output aliased to `rhs`.
            if let Some(lhs) = lhs
                && lhs.is_zero()
                && (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id))
            {
                let arm = &info.arms[0];
                self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]));
                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
            }
            None
        } else if let Some(reversed) = self.downcast_fns.get(&id) {
            // Value range of a type: from `type_value_ranges`, falling back to bounded-int types.
            let range = |ty: TypeId<'_>| {
                Some(if let Some(range) = self.type_value_ranges.get(&ty) {
                    range.clone()
                } else {
                    let (min, max) = corelib::try_extract_bounded_int_type_ranges(db, ty)?;
                    TypeRange { min, max }
                })
            };
            // `reversed` downcasts have their success and failure arms swapped.
            let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
            let input_var = info.inputs[0].var_id;
            let in_ty = self.variables[input_var].ty;
            let success_output = info.arms[success_arm].var_ids[0];
            let out_ty = self.variables[success_output].ty;
            let out_range = range(out_ty)?;
            // Non-const input: if the input type's range fits within the output type's range,
            // the downcast always succeeds, so replace it with an upcast into the success output.
            let Some(value) = self.as_int(input_var) else {
                let in_range = range(in_ty)?;
                return if in_range.min < out_range.min || in_range.max > out_range.max {
                    None
                } else {
                    let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
                    let function = db.core_info().upcast_fn.concretize(db, generic_args);
                    statements.push(Statement::Call(StatementCall {
                        function: function.lowered(db),
                        inputs: vec![info.inputs[0]],
                        with_coupon: false,
                        outputs: vec![success_output],
                        location: info.location,
                        is_specialization_base_call: false,
                    }));
                    return Some(BlockEnd::Goto(
                        info.arms[success_arm].block_id,
                        Default::default(),
                    ));
                };
            };
            // Const input: felt252 inputs are first mapped by `felt252_for_downcast` relative to
            // the output type's minimum; the value then either fits (success) or not (failure).
            let value = if in_ty == self.felt252 {
                felt252_for_downcast(value, &out_range.min)
            } else {
                value.clone()
            };
            Some(if let NormalizedResult::InRange(value) = out_range.normalized(value) {
                let value = ConstValue::Int(value, out_ty).intern(db);
                self.var_info.insert(success_output, VarInfo::Const(value));
                statements.push(Statement::Const(StatementConst::new_flat(value, success_output)));
                BlockEnd::Goto(info.arms[success_arm].block_id, Default::default())
            } else {
                BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default())
            })
        } else if id == self.bounded_int_constrain {
            // Const input: arm 0 when the value is below the boundary (the second generic
            // argument), arm 1 otherwise; the taken arm's output gets the same value.
            let input_var = info.inputs[0].var_id;
            let value = self.as_int(input_var)?;
            let generic_arg = generic_args[1];
            let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
                .long(db)
                .to_int()
                .expect("Expected ConstValue::Int for size");
            let arm_idx = if value < constrain_value { 0 } else { 1 };
            let output = info.arms[arm_idx].var_ids[0];
            statements.push(self.propagate_const_and_get_statement(value.clone(), output));
            Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
        } else if id == self.bounded_int_trim_min {
            // Const input: arm 0 when the value equals its type's minimum (it is trimmed);
            // otherwise arm 1, whose output gets the same value.
            let input_var = info.inputs[0].var_id;
            let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
                return None;
            };
            let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
                range.min == *value
            } else {
                corelib::try_extract_bounded_int_type_ranges(db, *ty)?.0 == *value
            };
            let arm_idx = if is_trimmed {
                0
            } else {
                let output = info.arms[1].var_ids[0];
                statements.push(self.propagate_const_and_get_statement(value.clone(), output));
                1
            };
            Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
        } else if id == self.bounded_int_trim_max {
            // Same as `bounded_int_trim_min`, against the type's maximum.
            let input_var = info.inputs[0].var_id;
            let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
                return None;
            };
            let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
                range.max == *value
            } else {
                corelib::try_extract_bounded_int_type_ranges(db, *ty)?.1 == *value
            };
            let arm_idx = if is_trimmed {
                0
            } else {
                let output = info.arms[1].var_ids[0];
                statements.push(self.propagate_const_and_get_statement(value.clone(), output));
                1
            };
            Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
        } else if id == self.array_get {
            // Requires a const index. Arm 0 is the found-element arm, arm 1 the failure arm.
            let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
            if let Some(VarInfo::Snapshot(arr_info)) = self.var_info.get(&info.inputs[0].var_id)
                && let VarInfo::Array(infos) = arr_info.as_ref()
            {
                match infos.get(index) {
                    // Known element: its info is recorded as `Box(Snapshot(elem))` for the output.
                    Some(Some(output_var_info)) => {
                        let arm = &info.arms[0];
                        let output_var_info = output_var_info.clone();
                        let box_info =
                            VarInfo::Box(VarInfo::Snapshot(output_var_info.clone().into()).into());
                        self.var_info.insert(arm.var_ids[0], box_info);
                        // Const element: build the output as a boxed const, snapshot the box and
                        // forward it through `box_forward_snapshot`, then jump to arm 0.
                        if let VarInfo::Const(value) = output_var_info {
                            let value_ty = value.ty(db).ok()?;
                            let value_box_ty = corelib::core_box_ty(db, value_ty);
                            let location = info.location;
                            let boxed_var =
                                Variable::with_default_context(db, value_box_ty, location);
                            let boxed = self.variables.alloc(boxed_var.clone());
                            let unused_boxed = self.variables.alloc(boxed_var);
                            let snapped = self.variables.alloc(Variable::with_default_context(
                                db,
                                TypeLongId::Snapshot(value_box_ty).intern(db),
                                location,
                            ));
                            statements.extend([
                                Statement::Const(StatementConst::new_boxed(value, boxed)),
                                Statement::Snapshot(StatementSnapshot {
                                    input: VarUsage { var_id: boxed, location },
                                    outputs: [unused_boxed, snapped],
                                }),
                                Statement::Call(StatementCall {
                                    function: self
                                        .box_forward_snapshot
                                        .concretize(db, vec![GenericArgumentId::Type(value_ty)])
                                        .lowered(db),
                                    inputs: vec![VarUsage { var_id: snapped, location }],
                                    with_coupon: false,
                                    outputs: vec![arm.var_ids[0]],
                                    location: info.location,
                                    is_specialization_base_call: false,
                                }),
                            ]);
                            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
                        }
                    }
                    // Index past the known array length: the failure arm is taken.
                    None => {
                        return Some(BlockEnd::Goto(info.arms[1].block_id, Default::default()));
                    }
                    // Element exists but its info is unknown.
                    Some(None) => {}
                }
            }
            // Index 0 and not resolved above: rewrite in place into `array_snapshot_pop_front`,
            // dropping the index input and prepending an unused array output to each arm.
            if index.is_zero()
                && let [success, failure] = info.arms.as_mut_slice()
            {
                let arr = info.inputs[0].var_id;
                let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
                let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
                info.inputs.truncate(1);
                info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
                    .concretize(db, generic_args)
                    .lowered(db);
                success.var_ids.insert(0, unused_arr_output0);
                failure.var_ids.insert(0, unused_arr_output1);
            }
            None
        } else if id == self.array_pop_front {
            let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)? else {
                return None;
            };
            if let Some(first) = var_infos.first() {
                // Known non-empty array with a known first element: record the info of the
                // remaining array and the popped (boxed) element, but keep the match.
                if let Some(first) = first.as_ref().cloned() {
                    let arm = &info.arms[0];
                    self.var_info.insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()));
                    self.var_info.insert(arm.var_ids[1], VarInfo::Box(first.into()));
                }
                None
            } else {
                // Known empty array: take arm 1, remapping its output to the input array.
                let arm = &info.arms[1];
                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
                Some(BlockEnd::Goto(
                    arm.block_id,
                    VarRemapping {
                        remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
                    },
                ))
            }
        } else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
            // Only a snapshot of a known empty array is resolved: take arm 1, remapping its
            // output to the input.
            let var_info = self.var_info.get(&info.inputs[0].var_id)?;
            let desnapped = try_extract_matches!(var_info, VarInfo::Snapshot)?;
            let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
            // TODO(orizi): Propagate success values as well.
            if element_var_infos.is_empty() {
                let arm = &info.arms[1];
                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
                Some(BlockEnd::Goto(
                    arm.block_id,
                    VarRemapping {
                        remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
                    },
                ))
            } else {
                None
            }
        } else {
            None
        }
    }
1319
1320    /// Returns the const value of a variable if it exists.
1321    fn as_const(&self, var_id: VariableId) -> Option<ConstValueId<'db>> {
1322        try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const).copied()
1323    }
1324
1325    /// Return the const value as an int if it exists and is an integer.
1326    fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
1327        match self.as_const(var_id)?.long(self.db) {
1328            ConstValue::Int(value, _) => Some(value),
1329            ConstValue::NonZero(const_value) => {
1330                if let ConstValue::Int(value, _) = const_value.long(self.db) {
1331                    Some(value)
1332                } else {
1333                    None
1334                }
1335            }
1336            _ => None,
1337        }
1338    }
1339
1340    /// Replaces the inputs in place if they are in the var_info map.
1341    fn maybe_replace_inputs(&self, inputs: &mut [VarUsage<'db>]) {
1342        for input in inputs {
1343            self.maybe_replace_input(input);
1344        }
1345    }
1346
1347    /// Replaces the input in place if it is in the var_info map.
1348    fn maybe_replace_input(&self, input: &mut VarUsage<'db>) {
1349        if let Some(VarInfo::Var(new_var)) = self.var_info.get(&input.var_id) {
1350            *input = *new_var;
1351        }
1352    }
1353
1354    /// Given a var_info and its type, return the corresponding specialization argument, if it
1355    /// exists.
1356    ///
1357    /// The `coerce` argument is used to constrain the specialization argument of recursive calls to
1358    /// the value that is used by the caller.
1359    fn try_get_specialization_arg(
1360        &self,
1361        var_info: VarInfo<'db>,
1362        ty: TypeId<'db>,
1363        unknown_vars: &mut Vec<VarUsage<'db>>,
1364        coerce: Option<&SpecializationArg<'db>>,
1365    ) -> Option<SpecializationArg<'db>> {
1366        if self.db.type_size_info(ty).ok()? == TypeSizeInformation::ZeroSized {
1367            // Skip zero-sized constants as they are not supported in sierra-gen.
1368            return None;
1369        }
1370
1371        match var_info {
1372            VarInfo::Const(value) => {
1373                let res = const_to_specialization_arg(self.db, value, false);
1374                let Some(coerce) = coerce else {
1375                    return Some(res);
1376                };
1377                if *coerce == res { Some(res) } else { None }
1378            }
1379            VarInfo::Box(info) => {
1380                let res = try_extract_matches!(info.as_ref(), VarInfo::Const)
1381                    .map(|value| SpecializationArg::Const { value: *value, boxed: true });
1382                let Some(coerce) = coerce else {
1383                    return res;
1384                };
1385                if Some(coerce.clone()) == res { res } else { None }
1386            }
1387            VarInfo::Snapshot(info) => {
1388                let desnap_ty = *extract_matches!(ty.long(self.db), TypeLongId::Snapshot);
1389                // Use a local accumulator to avoid mutating unknown_vars if we return None.
1390                let mut local_unknown_vars: Vec<VarUsage<'db>> = Vec::new();
1391                let inner = self.try_get_specialization_arg(
1392                    info.as_ref().clone(),
1393                    desnap_ty,
1394                    &mut local_unknown_vars,
1395                    coerce.map(|coerce| {
1396                        extract_matches!(coerce, SpecializationArg::Snapshot).as_ref()
1397                    }),
1398                )?;
1399                unknown_vars.extend(local_unknown_vars);
1400                Some(SpecializationArg::Snapshot(Box::new(inner)))
1401            }
1402            VarInfo::Array(infos) => {
1403                let TypeLongId::Concrete(concrete_ty) = ty.long(self.db) else {
1404                    unreachable!("Expected a concrete type");
1405                };
1406                let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
1407                else {
1408                    unreachable!("Expected a single type generic argument");
1409                };
1410                let coerces = match coerce {
1411                    Some(coerce) => {
1412                        let SpecializationArg::Array(ty, specialization_args) = coerce else {
1413                            unreachable!("Expected an array specialization argument");
1414                        };
1415                        assert_eq!(ty, inner_ty);
1416                        if specialization_args.len() != infos.len() {
1417                            return None;
1418                        }
1419
1420                        specialization_args.iter().map(Some).collect()
1421                    }
1422                    None => vec![None; infos.len()],
1423                };
1424                // Accumulate into locals first; only extend unknown_vars if we end up specializing.
1425                let mut vars = vec![];
1426                let mut args = vec![];
1427                for (info, coerce) in zip_eq(infos, coerces) {
1428                    let info = info?;
1429                    let arg =
1430                        self.try_get_specialization_arg(info, *inner_ty, &mut vars, coerce)?;
1431                    args.push(arg);
1432                }
1433                if !args.is_empty()
1434                    && args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1435                {
1436                    return None;
1437                }
1438                unknown_vars.extend(vars);
1439                Some(SpecializationArg::Array(*inner_ty, args))
1440            }
1441            VarInfo::Struct(infos) => {
1442                let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
1443                    ty.long(self.db)
1444                else {
1445                    // TODO(ilya): Support closures and fixed size arrays.
1446                    return None;
1447                };
1448
1449                let members = self.db.concrete_struct_members(*concrete_struct).unwrap();
1450                let coerces = match coerce {
1451                    Some(coerce) => {
1452                        let SpecializationArg::Struct(specialization_args) = coerce else {
1453                            unreachable!("Expected a struct specialization argument");
1454                        };
1455                        assert_eq!(specialization_args.len(), infos.len());
1456
1457                        specialization_args.iter().map(Some).collect()
1458                    }
1459                    None => vec![None; infos.len()],
1460                };
1461                let mut struct_args = Vec::new();
1462                // Accumulate into locals first; only extend unknown_vars if we end up specializing.
1463                let mut vars = vec![];
1464                for ((member, opt_var_info), coerce) in
1465                    zip_eq(zip_eq(members.values(), infos), coerces)
1466                {
1467                    let var_info = opt_var_info?;
1468                    let arg =
1469                        self.try_get_specialization_arg(var_info, member.ty, &mut vars, coerce)?;
1470                    struct_args.push(arg);
1471                }
1472                if !struct_args.is_empty()
1473                    && struct_args
1474                        .iter()
1475                        .all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1476                {
1477                    return None;
1478                }
1479                unknown_vars.extend(vars);
1480                Some(SpecializationArg::Struct(struct_args))
1481            }
1482            VarInfo::Enum { variant, payload } => {
1483                let coerce = match coerce {
1484                    Some(coerce) => {
1485                        let SpecializationArg::Enum { variant: coercion_variant, payload } = coerce
1486                        else {
1487                            unreachable!("Expected an enum specialization argument");
1488                        };
1489                        if *coercion_variant != variant {
1490                            return None;
1491                        }
1492                        Some(payload.as_ref())
1493                    }
1494                    None => None,
1495                };
1496                let mut local_unknown_vars = vec![];
1497                let payload_arg = self.try_get_specialization_arg(
1498                    payload.as_ref().clone(),
1499                    variant.ty,
1500                    &mut local_unknown_vars,
1501                    coerce,
1502                )?;
1503
1504                unknown_vars.extend(local_unknown_vars);
1505                Some(SpecializationArg::Enum { variant, payload: Box::new(payload_arg) })
1506            }
1507            VarInfo::Var(var_usage) => {
1508                unknown_vars.push(var_usage);
1509                Some(SpecializationArg::NotSpecialized)
1510            }
1511        }
1512    }
1513
1514    /// Returns true if const-folding should be skipped for the current function.
1515    pub fn should_skip_const_folding(&self, db: &'db dyn Database) -> bool {
1516        if db.optimizations().skip_const_folding() {
1517            return true;
1518        }
1519
1520        // Skipping const-folding for `panic_with_const_felt252` - to avoid replacing a call to
1521        // `panic_with_felt252` with `panic_with_const_felt252` and causing accidental recursion.
1522        if self.caller_function.base_semantic_function(db).generic_function(db)
1523            == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
1524        {
1525            return true;
1526        }
1527        false
1528    }
1529}
1530
1531/// Returns a `VarInfo` of a variable only if it is copyable.
1532fn var_info_if_copy<'db>(
1533    variables: &VariableArena<'db>,
1534    input: VarUsage<'db>,
1535) -> Option<VarInfo<'db>> {
1536    variables[input.var_id].info.copyable.is_ok().then_some(VarInfo::Var(input))
1537}
1538
/// Internal query for the libfuncs information required for const folding.
///
/// This is a salsa-tracked query, so the result of `ConstFoldingLibfuncInfo::new` is cached by
/// the database. `returns(ref)` makes callers receive a reference to the cached value instead of
/// a clone.
#[salsa::tracked(returns(ref))]
fn priv_const_folding_info<'db>(
    db: &'db dyn Database,
) -> crate::optimizations::const_folding::ConstFoldingLibfuncInfo<'db> {
    ConstFoldingLibfuncInfo::new(db)
}
1546
/// Holds static information about libfuncs required for the optimization.
///
/// Every field is resolved from the core library once, in `ConstFoldingLibfuncInfo::new`.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub struct ConstFoldingLibfuncInfo<'db> {
    /// The `felt252_sub` libfunc.
    felt_sub: ExternFunctionId<'db>,
    /// The `felt252_add` libfunc.
    felt_add: ExternFunctionId<'db>,
    /// The `felt252_mul` libfunc.
    felt_mul: ExternFunctionId<'db>,
    /// The `felt252_div` libfunc.
    felt_div: ExternFunctionId<'db>,
    /// The `into_box` libfunc.
    into_box: ExternFunctionId<'db>,
    /// The `unbox` libfunc.
    unbox: ExternFunctionId<'db>,
    /// The `box_forward_snapshot` libfunc (generic; concretized per use).
    box_forward_snapshot: GenericFunctionId<'db>,
    /// The set of functions that check if numbers are equal (`{ty}_eq` for the `u*` and `i*`
    /// types).
    eq_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to add unsigned ints (`{ty}_overflowing_add`).
    uadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to subtract unsigned ints (`{ty}_overflowing_sub`).
    usub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to get the difference of signed ints (`{ty}_diff`).
    diff_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to add signed ints (`{ty}_overflowing_add_impl`).
    iadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to subtract signed ints (`{ty}_overflowing_sub_impl`).
    isub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of widening multiplication functions (`bounded_int_mul` and `{ty}_wide_mul`).
    wide_mul_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to divide and get the remainder of integers (`bounded_int_div_rem`
    /// and the unsigned `{ty}_safe_divmod`).
    div_rem_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `bounded_int_add` libfunc.
    bounded_int_add: ExternFunctionId<'db>,
    /// The `bounded_int_sub` libfunc.
    bounded_int_sub: ExternFunctionId<'db>,
    /// The `bounded_int_constrain` libfunc.
    bounded_int_constrain: ExternFunctionId<'db>,
    /// The `bounded_int_trim_min` libfunc.
    bounded_int_trim_min: ExternFunctionId<'db>,
    /// The `bounded_int_trim_max` libfunc.
    bounded_int_trim_max: ExternFunctionId<'db>,
    /// The `array_get` libfunc.
    array_get: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_front` libfunc.
    array_snapshot_pop_front: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_back` libfunc.
    array_snapshot_pop_back: ExternFunctionId<'db>,
    /// The `array_len` libfunc.
    array_len: ExternFunctionId<'db>,
    /// The `array_new` libfunc.
    array_new: ExternFunctionId<'db>,
    /// The `array_append` libfunc.
    array_append: ExternFunctionId<'db>,
    /// The `array_pop_front` libfunc.
    array_pop_front: ExternFunctionId<'db>,
    /// The `storage_base_address_from_felt252` libfunc.
    storage_base_address_from_felt252: ExternFunctionId<'db>,
    /// The `storage_base_address_const` libfunc (generic; concretized per use).
    storage_base_address_const: GenericFunctionId<'db>,
    /// The `core::panic_with_felt252` function.
    panic_with_felt252: FunctionId<'db>,
    /// The `core::panic_with_const_felt252` function.
    pub panic_with_const_felt252: FreeFunctionId<'db>,
    /// The `core::panics::panic_with_byte_array` function.
    panic_with_byte_array: FunctionId<'db>,
    /// Information per type (its `is_zero`, `inc` and `dec` functions), keyed by the core
    /// integer types.
    type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>>,
    /// The info used for semantic const calculation.
    const_calculation_info: Arc<ConstCalcInfo<'db>>,
}
impl<'db> ConstFoldingLibfuncInfo<'db> {
    /// Resolves every libfunc, function and per-type helper used by const folding from the core
    /// library.
    fn new(db: &'db dyn Database) -> Self {
        let core = ModuleHelper::core(db);
        let box_module = core.submodule("box");
        let integer_module = core.submodule("integer");
        let internal_module = core.submodule("internal");
        let bounded_int_module = internal_module.submodule("bounded_int");
        let num_module = internal_module.submodule("num");
        let array_module = core.submodule("array");
        let starknet_module = core.submodule("starknet");
        let storage_access_module = starknet_module.submodule("storage_access");
        let utypes = ["u8", "u16", "u32", "u64", "u128"];
        let itypes = ["i8", "i16", "i32", "i64", "i128"];
        let eq_fns = OrderedHashSet::<_>::from_iter(
            chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(&format!("{ty}_eq"))),
        );
        let uadd_fns = OrderedHashSet::<_>::from_iter(
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_add"))),
        );
        let usub_fns = OrderedHashSet::<_>::from_iter(
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_sub"))),
        );
        let diff_fns = OrderedHashSet::<_>::from_iter(
            itypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_diff"))),
        );
        let iadd_fns =
            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
                integer_module.extern_function_id(&format!("{ty}_overflowing_add_impl"))
            }));
        let isub_fns =
            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
                integer_module.extern_function_id(&format!("{ty}_overflowing_sub_impl"))
            }));
        // The `_wide_mul` list deliberately stops at 64-bit types: `u128`/`i128` are not included.
        let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
            [bounded_int_module.extern_function_id("bounded_int_mul")],
            ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
                .map(|ty| integer_module.extern_function_id(&format!("{ty}_wide_mul"))),
        ));
        let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
            [bounded_int_module.extern_function_id("bounded_int_div_rem")],
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_safe_divmod"))),
        ));
        // Each entry is `(type name, as_bounded_int, inc_dec)`:
        // - `as_bounded_int`: `is_zero` is the generic `bounded_int_is_zero` concretized on the
        //   type, instead of a `{ty}_is_zero` function in the `integer` module.
        // - `inc_dec`: whether `{ty}_inc`/`{ty}_dec` helpers are resolved from the `num` module.
        let type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>> = OrderedHashMap::from_iter(
            [
                ("u8", false, true),
                ("u16", false, true),
                ("u32", false, true),
                ("u64", false, true),
                ("u128", false, true),
                ("u256", false, false),
                ("i8", true, true),
                ("i16", true, true),
                ("i32", true, true),
                ("i64", true, true),
                ("i128", true, true),
            ]
            .map(|(ty_name, as_bounded_int, inc_dec): (&'static str, bool, bool)| {
                let ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, ty_name), vec![]);
                // NOTE(review): dynamic names are interned through `SmolStrId` before being passed
                // to `function_id`, presumably so the `&str` lives for `'db` — confirm against
                // `ModuleHelper::function_id`'s signature.
                let is_zero = if as_bounded_int {
                    bounded_int_module
                        .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
                } else {
                    integer_module.function_id(
                        SmolStrId::from(db, format!("{ty_name}_is_zero")).long(db).as_str(),
                        vec![],
                    )
                }
                .lowered(db);
                let (inc, dec) = if inc_dec {
                    (
                        Some(
                            num_module
                                .function_id(
                                    SmolStrId::from(db, format!("{ty_name}_inc")).long(db).as_str(),
                                    vec![],
                                )
                                .lowered(db),
                        ),
                        Some(
                            num_module
                                .function_id(
                                    SmolStrId::from(db, format!("{ty_name}_dec")).long(db).as_str(),
                                    vec![],
                                )
                                .lowered(db),
                        ),
                    )
                } else {
                    (None, None)
                };
                let info = TypeInfo { is_zero, inc, dec };
                (ty, info)
            }),
        );
        Self {
            felt_sub: core.extern_function_id("felt252_sub"),
            felt_add: core.extern_function_id("felt252_add"),
            felt_mul: core.extern_function_id("felt252_mul"),
            felt_div: core.extern_function_id("felt252_div"),
            into_box: box_module.extern_function_id("into_box"),
            unbox: box_module.extern_function_id("unbox"),
            box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
            eq_fns,
            uadd_fns,
            usub_fns,
            diff_fns,
            iadd_fns,
            isub_fns,
            wide_mul_fns,
            div_rem_fns,
            bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
            bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
            bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
            bounded_int_trim_min: bounded_int_module.extern_function_id("bounded_int_trim_min"),
            bounded_int_trim_max: bounded_int_module.extern_function_id("bounded_int_trim_max"),
            array_get: array_module.extern_function_id("array_get"),
            array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
            array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
            array_len: array_module.extern_function_id("array_len"),
            array_new: array_module.extern_function_id("array_new"),
            array_append: array_module.extern_function_id("array_append"),
            array_pop_front: array_module.extern_function_id("array_pop_front"),
            storage_base_address_from_felt252: storage_access_module
                .extern_function_id("storage_base_address_from_felt252"),
            storage_base_address_const: storage_access_module
                .generic_function_id("storage_base_address_const"),
            panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
            panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
            panic_with_byte_array: core
                .submodule("panics")
                .function_id("panic_with_byte_array", vec![])
                .lowered(db),
            type_info,
            const_calculation_info: db.const_calc_info(),
        }
    }
}
1756
1757impl<'db> std::ops::Deref for ConstFoldingContext<'db, '_> {
1758    type Target = ConstFoldingLibfuncInfo<'db>;
1759    fn deref(&self) -> &ConstFoldingLibfuncInfo<'db> {
1760        self.libfunc_info
1761    }
1762}
1763
1764impl<'a> std::ops::Deref for ConstFoldingLibfuncInfo<'a> {
1765    type Target = ConstCalcInfo<'a>;
1766    fn deref(&self) -> &ConstCalcInfo<'a> {
1767        &self.const_calculation_info
1768    }
1769}
1770
/// The per-type information required for const folding.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct TypeInfo<'db> {
    /// The function to check if the value is zero for the type.
    is_zero: FunctionId<'db>,
    /// Inc function to increase the value by one; `None` for types without one.
    inc: Option<FunctionId<'db>>,
    /// Dec function to decrease the value by one; `None` for types without one.
    dec: Option<FunctionId<'db>>,
}
1781
/// Wraps out-of-range integers back into a type's value range.
trait TypeRangeNormalizer {
    /// Normalizes the value to the range.
    /// Assumes the value lies at most one range-size outside of the range, so a single
    /// wrap-around is enough to bring it back in.
    fn normalized(&self, value: BigInt) -> NormalizedResult;
}
1787impl TypeRangeNormalizer for TypeRange {
1788    fn normalized(&self, value: BigInt) -> NormalizedResult {
1789        if value < self.min {
1790            NormalizedResult::Under(value - &self.min + &self.max + 1)
1791        } else if value > self.max {
1792            NormalizedResult::Over(value + &self.min - &self.max - 1)
1793        } else {
1794            NormalizedResult::InRange(value)
1795        }
1796    }
1797}
1798
/// The result of normalizing a value to a range.
///
/// `Over` and `Under` indicate that the original value wrapped around (e.g. an arithmetic
/// overflow or underflow), while still carrying the wrapped, in-range value.
enum NormalizedResult {
    /// The original value is in the range, carries the value, or an equivalent value.
    InRange(BigInt),
    /// The original value is larger than range max, carries the normalized value.
    Over(BigInt),
    /// The original value is smaller than range min, carries the normalized value.
    Under(BigInt),
}