cairo_lang_lowering/optimizations/
const_folding.rs

1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::sync::Arc;
6
7use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
8use cairo_lang_filesystem::flag::flag_unsafe_panic;
9use cairo_lang_filesystem::ids::SmolStrId;
10use cairo_lang_semantic::corelib::CorelibSemantic;
11use cairo_lang_semantic::helper::ModuleHelper;
12use cairo_lang_semantic::items::constant::{
13    ConstCalcInfo, ConstValue, ConstValueId, ConstantSemantic, TypeRange, canonical_felt252,
14    felt252_for_downcast,
15};
16use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
17use cairo_lang_semantic::items::structure::StructSemantic;
18use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
19use cairo_lang_semantic::{
20    ConcreteTypeId, ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId,
21    corelib,
22};
23use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
24use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
25use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
26use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
27use cairo_lang_utils::{Intern, extract_matches, try_extract_matches};
28use itertools::{chain, zip_eq};
29use num_bigint::BigInt;
30use num_integer::Integer;
31use num_traits::cast::ToPrimitive;
32use num_traits::{Num, One, Zero};
33use salsa::Database;
34use starknet_types_core::felt::Felt as Felt252;
35
36use crate::db::LoweringGroup;
37use crate::ids::{
38    ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
39    SpecializedFunction,
40};
41use crate::specialization::SpecializationArg;
42use crate::utils::InliningStrategy;
43use crate::{
44    Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchExternInfo, MatchInfo,
45    Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
46    StatementSnapshot, StatementStructConstruct, StatementStructDestructure, VarRemapping,
47    VarUsage, Variable, VariableArena, VariableId,
48};
49
50/// Keeps track of equivalent values that variables might be replaced with.
51/// Note: We don't keep track of types as we assume the usage is always correct.
52#[derive(Debug, Clone)]
53enum VarInfo<'db> {
54    /// The variable is a const value.
55    Const(ConstValueId<'db>),
56    /// The variable can be replaced by another variable.
57    Var(VarUsage<'db>),
58    /// The variable is a snapshot of another variable.
59    Snapshot(Box<VarInfo<'db>>),
60    /// The variable is a struct of other variables.
61    /// `None` values represent variables that are not tracked.
62    Struct(Vec<Option<VarInfo<'db>>>),
63    /// The variable is an enum of a known variant of other variables.
64    Enum { variant: ConcreteVariant<'db>, payload: Box<VarInfo<'db>> },
65    /// The variable is a box of another variable.
66    Box(Box<VarInfo<'db>>),
67    /// The variable is an array of known size of other variables.
68    /// `None` values represent variables that are not tracked.
69    Array(Vec<Option<VarInfo<'db>>>),
70}
71impl<'db> VarInfo<'db> {
72    /// Peels the snapshots from the variable info and returns the number of snapshots performed.
73    fn peel_snapshots(&self) -> (usize, &VarInfo<'db>) {
74        let mut n_snapshots = 0;
75        let mut info = self;
76        while let VarInfo::Snapshot(inner) = info {
77            info = inner.as_ref();
78            n_snapshots += 1;
79        }
80        (n_snapshots, info)
81    }
82    /// Wraps the variable info with the given number of snapshots.
83    fn wrap_with_snapshots(mut self, n_snapshots: usize) -> VarInfo<'db> {
84        for _ in 0..n_snapshots {
85            self = VarInfo::Snapshot(Box::new(self));
86        }
87        self
88    }
89}
90
91#[derive(Debug, Clone, Copy, PartialEq)]
92enum Reachability {
93    /// The block is reachable from the function start only through the goto at the end of the given
94    /// block.
95    FromSingleGoto(BlockId),
96    /// The block is reachable from the function start after const-folding - just does not fit
97    /// `FromSingleGoto`.
98    Any,
99}
100
101/// Performs constant folding on the lowered program.
102/// The optimization only works when the blocks are topologically sorted.
103pub fn const_folding<'db>(
104    db: &'db dyn Database,
105    function_id: ConcreteFunctionWithBodyId<'db>,
106    lowered: &mut Lowered<'db>,
107) {
108    if lowered.blocks.is_empty() {
109        return;
110    }
111
112    // Note that we can keep the var_info across blocks because the lowering
113    // is in static single assignment form.
114    let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
115
116    if ctx.should_skip_const_folding(db) {
117        return;
118    }
119
120    for block_id in (0..lowered.blocks.len()).map(BlockId) {
121        if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
122            continue;
123        }
124
125        let block = &mut lowered.blocks[block_id];
126        for stmt in block.statements.iter_mut() {
127            ctx.visit_statement(stmt);
128        }
129        ctx.visit_block_end(block_id, block);
130    }
131}
132
133pub struct ConstFoldingContext<'db, 'mt> {
134    /// The used database.
135    db: &'db dyn Database,
136    /// The variables arena, mostly used to get the type of variables.
137    pub variables: &'mt mut VariableArena<'db>,
138    /// The accumulated information about the const values of variables.
139    var_info: UnorderedHashMap<VariableId, VarInfo<'db>>,
140    /// The libfunc information.
141    libfunc_info: &'db ConstFoldingLibfuncInfo<'db>,
142    /// The specialization base of the caller function (or the caller if the function is not
143    /// specialized).
144    caller_base: ConcreteFunctionWithBodyId<'db>,
145    /// Reachability of blocks from the function start.
146    /// If the block is not in this map, it means that it is unreachable (or that it was already
147    /// visited and its reachability won't be checked again).
148    reachability: UnorderedHashMap<BlockId, Reachability>,
149    /// Additional statements to add to the block.
150    additional_stmts: Vec<Statement<'db>>,
151}
152
153impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
154    pub fn new(
155        db: &'db dyn Database,
156        function_id: ConcreteFunctionWithBodyId<'db>,
157        variables: &'mt mut VariableArena<'db>,
158    ) -> Self {
159        let caller_base = match function_id.long(db) {
160            ConcreteFunctionWithBodyLongId::Specialized(specialized_func) => specialized_func.base,
161            _ => function_id,
162        };
163
164        Self {
165            db,
166            var_info: UnorderedHashMap::default(),
167            variables,
168            libfunc_info: priv_const_folding_info(db),
169            caller_base,
170            reachability: UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]),
171            additional_stmts: vec![],
172        }
173    }
174
175    /// Determines if a block is reachable from the function start and propagates constant values
176    /// when the block is reachable via a single goto statement.
177    pub fn visit_block_start<'r, 'get>(
178        &'r mut self,
179        block_id: BlockId,
180        get_block: impl FnOnce(BlockId) -> &'get Block<'db>,
181    ) -> bool
182    where
183        'db: 'get,
184    {
185        let Some(reachability) = self.reachability.remove(&block_id) else {
186            return false;
187        };
188        match reachability {
189            Reachability::Any => {}
190            Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
191                BlockEnd::Goto(_, remapping) => {
192                    for (dst, src) in remapping.iter() {
193                        if let Some(v) = self.as_const(src.var_id) {
194                            self.var_info.insert(*dst, VarInfo::Const(v));
195                        }
196                    }
197                }
198                _ => unreachable!("Expected a goto end"),
199            },
200        }
201        true
202    }
203
204    /// Processes a statement and applies the constant folding optimizations.
205    ///
206    /// This method performs the following operations:
207    /// - Updates the `var_info` map with constant values of variables
208    /// - Replace the input statement with optimized versions when possible
209    /// - Updates `self.additional_stmts` with statements that need to be added to the block.
210    ///
211    /// Note: `self.visit_block_end` must be called after processing all statements
212    /// in a block to actually add the additional statements.
213    pub fn visit_statement(&mut self, stmt: &mut Statement<'db>) {
214        self.maybe_replace_inputs(stmt.inputs_mut());
215        match stmt {
216            Statement::Const(StatementConst { value, output, boxed }) if *boxed => {
217                self.var_info.insert(*output, VarInfo::Box(VarInfo::Const(*value).into()));
218            }
219            Statement::Const(StatementConst { value, output, .. }) => match value.long(self.db) {
220                ConstValue::Int(..)
221                | ConstValue::Struct(..)
222                | ConstValue::Enum(..)
223                | ConstValue::NonZero(..) => {
224                    self.var_info.insert(*output, VarInfo::Const(*value));
225                }
226                ConstValue::Generic(_)
227                | ConstValue::ImplConstant(_)
228                | ConstValue::Var(..)
229                | ConstValue::Missing(_) => {}
230            },
231            Statement::Snapshot(stmt) => {
232                if let Some(info) = self.var_info.get(&stmt.input.var_id).cloned() {
233                    self.var_info.insert(stmt.original(), info.clone());
234                    self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info.into()));
235                }
236            }
237            Statement::Desnap(StatementDesnap { input, output }) => {
238                if let Some(VarInfo::Snapshot(info)) = self.var_info.get(&input.var_id) {
239                    self.var_info.insert(*output, info.as_ref().clone());
240                }
241            }
242            Statement::Call(call_stmt) => {
243                if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
244                    *stmt = updated_stmt;
245                } else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
246                    *stmt = updated_stmt;
247                }
248            }
249            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
250                let mut const_args = vec![];
251                let mut all_args = vec![];
252                let mut contains_info = false;
253                for input in inputs.iter() {
254                    let Some(info) = self.var_info.get(&input.var_id) else {
255                        all_args.push(var_info_if_copy(self.variables, *input));
256                        continue;
257                    };
258                    contains_info = true;
259                    if let VarInfo::Const(value) = info {
260                        const_args.push(*value);
261                    }
262                    all_args.push(Some(info.clone()));
263                }
264                if const_args.len() == inputs.len() {
265                    let value =
266                        ConstValue::Struct(const_args, self.variables[*output].ty).intern(self.db);
267                    self.var_info.insert(*output, VarInfo::Const(value));
268                } else if contains_info {
269                    self.var_info.insert(*output, VarInfo::Struct(all_args));
270                }
271            }
272            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
273                if let Some(info) = self.var_info.get(&input.var_id) {
274                    let (n_snapshots, info) = info.peel_snapshots();
275                    match info {
276                        VarInfo::Const(const_value) => {
277                            if let ConstValue::Struct(member_values, _) = const_value.long(self.db)
278                            {
279                                for (output, value) in zip_eq(outputs, member_values) {
280                                    self.var_info.insert(
281                                        *output,
282                                        VarInfo::Const(*value).wrap_with_snapshots(n_snapshots),
283                                    );
284                                }
285                            }
286                        }
287                        VarInfo::Struct(members) => {
288                            for (output, member) in zip_eq(outputs, members.clone()) {
289                                if let Some(member) = member {
290                                    self.var_info
291                                        .insert(*output, member.wrap_with_snapshots(n_snapshots));
292                                }
293                            }
294                        }
295                        _ => {}
296                    }
297                }
298            }
299            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
300                let value = if let Some(info) = self.var_info.get(&input.var_id) {
301                    if let VarInfo::Const(val) = info {
302                        VarInfo::Const(ConstValue::Enum(*variant, *val).intern(self.db))
303                    } else {
304                        VarInfo::Enum { variant: *variant, payload: info.clone().into() }
305                    }
306                } else {
307                    VarInfo::Enum { variant: *variant, payload: VarInfo::Var(*input).into() }
308                };
309                self.var_info.insert(*output, value);
310            }
311        }
312    }
313
314    /// Processes the block's end and incorporates additional statements into the block.
315    ///
316    /// This method handles the following tasks:
317    /// - Inserts the accumulated additional statements into the block.
318    /// - Converts match endings to goto when applicable.
319    /// - Updates self.reachability based on the block's ending.
320    pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut Block<'db>) {
321        let statements = &mut block.statements;
322        statements.splice(0..0, self.additional_stmts.drain(..));
323
324        match &mut block.end {
325            BlockEnd::Goto(_, remappings) => {
326                for (_, v) in remappings.iter_mut() {
327                    self.maybe_replace_input(v);
328                }
329            }
330            BlockEnd::Match { info } => {
331                self.maybe_replace_inputs(info.inputs_mut());
332                match info {
333                    MatchInfo::Enum(info) => {
334                        if let Some(updated_end) = self.handle_enum_block_end(info, statements) {
335                            block.end = updated_end;
336                        }
337                    }
338                    MatchInfo::Extern(info) => {
339                        if let Some(updated_end) = self.handle_extern_block_end(info, statements) {
340                            block.end = updated_end;
341                        }
342                    }
343                    MatchInfo::Value(info) => {
344                        if let Some(value) =
345                            self.as_int(info.input.var_id).and_then(|x| x.to_usize())
346                            && let Some(arm) = info.arms.iter().find(|arm| {
347                                matches!(
348                                    &arm.arm_selector,
349                                    MatchArmSelector::Value(v) if v.value == value
350                                )
351                            })
352                        {
353                            // Create the variable that was previously introduced in match arm.
354                            statements.push(Statement::StructConstruct(StatementStructConstruct {
355                                inputs: vec![],
356                                output: arm.var_ids[0],
357                            }));
358                            block.end = BlockEnd::Goto(arm.block_id, Default::default());
359                        }
360                    }
361                }
362            }
363            BlockEnd::Return(inputs, _) => self.maybe_replace_inputs(inputs),
364            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
365        }
366        match &block.end {
367            BlockEnd::Goto(dst_block_id, _) => {
368                match self.reachability.entry(*dst_block_id) {
369                    std::collections::hash_map::Entry::Occupied(mut e) => {
370                        e.insert(Reachability::Any)
371                    }
372                    std::collections::hash_map::Entry::Vacant(e) => {
373                        *e.insert(Reachability::FromSingleGoto(block_id))
374                    }
375                };
376            }
377            BlockEnd::Match { info } => {
378                for arm in info.arms() {
379                    assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
380                }
381            }
382            BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
383        }
384    }
385
386    /// Handles a statement call.
387    ///
388    /// Returns None if no additional changes are required.
389    /// If changes are required, returns an updated statement (to override the current
390    /// statement).
391    /// May add additional statements to `self.additional_stmts` if just replacing the current
392    /// statement is not enough.
393    fn handle_statement_call(&mut self, stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
394        let db = self.db;
395        if stmt.function == self.panic_with_felt252 {
396            let val = self.as_const(stmt.inputs[0].var_id)?;
397            stmt.inputs.clear();
398            stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
399                .concretize(db, vec![GenericArgumentId::Constant(val)])
400                .lowered(db);
401            return None;
402        } else if stmt.function == self.panic_with_byte_array && !flag_unsafe_panic(db) {
403            let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
404            let bytearray = try_extract_matches!(snap, VarInfo::Snapshot)?;
405            let [
406                Some(VarInfo::Array(data)),
407                Some(VarInfo::Const(pending_word)),
408                Some(VarInfo::Const(pending_len)),
409            ] = &try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
410            else {
411                return None;
412            };
413            let mut panic_data =
414                vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
415            for word in data {
416                let Some(VarInfo::Const(word)) = word else {
417                    return None;
418                };
419                panic_data.push(word.long(db).to_int()?.clone());
420            }
421            panic_data.extend([
422                pending_word.long(db).to_int()?.clone(),
423                pending_len.long(db).to_int()?.clone(),
424            ]);
425            let felt252_ty = self.felt252;
426            let location = stmt.location;
427            let new_var = |ty| Variable::with_default_context(db, ty, location);
428            let as_usage = |var_id| VarUsage { var_id, location };
429            let array_fn = |extern_id| {
430                let args = vec![GenericArgumentId::Type(felt252_ty)];
431                GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
432            };
433            let call_stmt = |function, inputs, outputs| {
434                let with_coupon = false;
435                Statement::Call(StatementCall { function, inputs, with_coupon, outputs, location })
436            };
437            let arr_var = new_var(corelib::core_array_felt252_ty(db));
438            let mut arr = self.variables.alloc(arr_var.clone());
439            self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
440            let felt252_var = new_var(felt252_ty);
441            let arr_append_fn = array_fn(self.array_append);
442            for word in panic_data {
443                let to_append = self.variables.alloc(felt252_var.clone());
444                let new_arr = self.variables.alloc(arr_var.clone());
445                self.additional_stmts.push(Statement::Const(StatementConst::new_flat(
446                    ConstValue::Int(word, felt252_ty).intern(db),
447                    to_append,
448                )));
449                self.additional_stmts.push(call_stmt(
450                    arr_append_fn,
451                    vec![as_usage(arr), as_usage(to_append)],
452                    vec![new_arr],
453                ));
454                arr = new_arr;
455            }
456            let panic_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "Panic"), vec![]);
457            let panic_var = self.variables.alloc(new_var(panic_ty));
458            self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
459                inputs: vec![],
460                output: panic_var,
461            }));
462            return Some(Statement::StructConstruct(StatementStructConstruct {
463                inputs: vec![as_usage(panic_var), as_usage(arr)],
464                output: stmt.outputs[0],
465            }));
466        }
467        let (id, _generic_args) = stmt.function.get_extern(db)?;
468        if id == self.felt_sub {
469            if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
470                && rhs.is_zero()
471            {
472                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
473                None
474            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
475                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
476            {
477                let value = canonical_felt252(&(lhs - rhs));
478                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
479            } else {
480                None
481            }
482        } else if id == self.felt_add {
483            if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
484                && lhs.is_zero()
485            {
486                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
487                None
488            } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
489                && rhs.is_zero()
490            {
491                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
492                None
493            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
494                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
495            {
496                let value = canonical_felt252(&(lhs + rhs));
497                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
498            } else {
499                None
500            }
501        } else if id == self.felt_mul {
502            let lhs = self.as_int(stmt.inputs[0].var_id);
503            let rhs = self.as_int(stmt.inputs[1].var_id);
504            if lhs.map(Zero::is_zero).unwrap_or_default()
505                || rhs.map(Zero::is_zero).unwrap_or_default()
506            {
507                Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
508            } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
509                && rhs.is_one()
510            {
511                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
512                None
513            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
514                && lhs.is_one()
515            {
516                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
517                None
518            } else if let Some(lhs) = lhs
519                && let Some(rhs) = rhs
520            {
521                let value = canonical_felt252(&(lhs * rhs));
522                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
523            } else {
524                None
525            }
526        } else if id == self.felt_div {
527            // Note that divisor is never 0, due to NonZero type always being the divisor.
528            if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
529                // Returns the original value when dividing by 1.
530                && rhs.is_one()
531            {
532                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
533                None
534            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
535                // If the value is 0, result is 0 regardless of the divisor.
536                && lhs.is_zero()
537            {
538                Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
539            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
540                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
541                && let Ok(rhs_nonzero) = Felt252::from(rhs).try_into()
542            {
543                // Constant fold when both operands are constants
544
545                // Use field_div for Felt252 division
546                let lhs_felt = Felt252::from(lhs);
547                let value = lhs_felt.field_div(&rhs_nonzero).to_bigint();
548                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
549            } else {
550                None
551            }
552        } else if self.wide_mul_fns.contains(&id) {
553            let lhs = self.as_int(stmt.inputs[0].var_id);
554            let rhs = self.as_int(stmt.inputs[1].var_id);
555            let output = stmt.outputs[0];
556            if lhs.map(Zero::is_zero).unwrap_or_default()
557                || rhs.map(Zero::is_zero).unwrap_or_default()
558            {
559                return Some(self.propagate_zero_and_get_statement(output));
560            }
561            let lhs = lhs?;
562            Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0]))
563        } else if id == self.bounded_int_add || id == self.bounded_int_sub {
564            let lhs = self.as_int(stmt.inputs[0].var_id)?;
565            let rhs = self.as_int(stmt.inputs[1].var_id)?;
566            let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
567            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
568        } else if self.div_rem_fns.contains(&id) {
569            let lhs = self.as_int(stmt.inputs[0].var_id);
570            if lhs.map(Zero::is_zero).unwrap_or_default() {
571                let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
572                self.additional_stmts.push(additional_stmt);
573                return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
574            }
575            let rhs = self.as_int(stmt.inputs[1].var_id)?;
576            let (q, r) = lhs?.div_rem(rhs);
577            let q_output = stmt.outputs[0];
578            let q_value = ConstValue::Int(q, self.variables[q_output].ty).intern(db);
579            self.var_info.insert(q_output, VarInfo::Const(q_value));
580            let r_output = stmt.outputs[1];
581            let r_value = ConstValue::Int(r, self.variables[r_output].ty).intern(db);
582            self.var_info.insert(r_output, VarInfo::Const(r_value));
583            self.additional_stmts
584                .push(Statement::Const(StatementConst::new_flat(r_value, r_output)));
585            Some(Statement::Const(StatementConst::new_flat(q_value, q_output)))
586        } else if id == self.storage_base_address_from_felt252 {
587            let input_var = stmt.inputs[0].var_id;
588            if let Some(const_value) = self.as_const(input_var)
589                && let ConstValue::Int(val, ty) = const_value.long(db)
590            {
591                stmt.inputs.clear();
592                let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
593                stmt.function =
594                    self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
595            }
596            None
597        } else if id == self.into_box {
598            let input = stmt.inputs[0];
599            let var_info = self.var_info.get(&input.var_id);
600            let const_value = match var_info {
601                Some(VarInfo::Const(val)) => Some(*val),
602                Some(VarInfo::Snapshot(info)) => {
603                    try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
604                }
605                _ => None,
606            };
607            let var_info = var_info.cloned().or_else(|| var_info_if_copy(self.variables, input))?;
608            self.var_info.insert(stmt.outputs[0], VarInfo::Box(var_info.into()));
609            Some(Statement::Const(StatementConst::new_boxed(const_value?, stmt.outputs[0])))
610        } else if id == self.unbox {
611            if let VarInfo::Box(inner) = self.var_info.get(&stmt.inputs[0].var_id)? {
612                let inner = inner.as_ref().clone();
613                if let VarInfo::Const(inner) =
614                    self.var_info.entry(stmt.outputs[0]).insert_entry(inner).get()
615                {
616                    return Some(Statement::Const(StatementConst::new_flat(
617                        *inner,
618                        stmt.outputs[0],
619                    )));
620                }
621            }
622            None
623        } else if self.upcast_fns.contains(&id) {
624            let int_value = self.as_int(stmt.inputs[0].var_id)?;
625            let output = stmt.outputs[0];
626            let value = ConstValue::Int(int_value.clone(), self.variables[output].ty).intern(db);
627            self.var_info.insert(output, VarInfo::Const(value));
628            Some(Statement::Const(StatementConst::new_flat(value, output)))
629        } else if id == self.array_new {
630            self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]));
631            None
632        } else if id == self.array_append {
633            let mut var_infos =
634                if let VarInfo::Array(var_infos) = self.var_info.get(&stmt.inputs[0].var_id)? {
635                    var_infos.clone()
636                } else {
637                    return None;
638                };
639            let appended = stmt.inputs[1];
640            var_infos.push(match self.var_info.get(&appended.var_id) {
641                Some(var_info) => Some(var_info.clone()),
642                None => var_info_if_copy(self.variables, appended),
643            });
644            self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos));
645            None
646        } else if id == self.array_len {
647            let info = self.var_info.get(&stmt.inputs[0].var_id)?;
648            let desnapped = try_extract_matches!(info, VarInfo::Snapshot)?;
649            let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
650            Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0]))
651        } else {
652            None
653        }
654    }
655
656    /// Tries to specialize the call.
657    /// Returns The specialized call statement if it was specialized, or None otherwise.
658    ///
659    /// Specialization occurs only if `priv_should_specialize` returns true.
660    /// Additionally specialization of a callee the with the same base as the caller is currently
661    /// not supported.
662    fn try_specialize_call(&self, call_stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
663        if call_stmt.with_coupon {
664            return None;
665        }
666        // No specialization when avoiding inlining.
667        if matches!(self.db.optimizations().inlining_strategy(), InliningStrategy::Avoid) {
668            return None;
669        }
670
671        let Ok(Some(mut base)) = call_stmt.function.body(self.db) else {
672            return None;
673        };
674
675        if self.db.priv_never_inline(base).ok()? {
676            return None;
677        }
678
679        // Avoid specializing with the same base as the current function as it may lead to infinite
680        // specialization.
681        if base == self.caller_base {
682            return None;
683        }
684        if call_stmt.inputs.iter().all(|arg| self.var_info.get(&arg.var_id).is_none()) {
685            // No const inputs
686            return None;
687        }
688        let mut specialization_args = vec![];
689        let mut new_args = vec![];
690        for arg in &call_stmt.inputs {
691            if let Some(var_info) = self.var_info.get(&arg.var_id)
692                && self.variables[arg.var_id].info.droppable.is_ok()
693                && let Some(specialization_arg) = self.try_get_specialization_arg(
694                    var_info.clone(),
695                    self.variables[arg.var_id].ty,
696                    &mut new_args,
697                )
698            {
699                specialization_args.push(specialization_arg);
700            } else {
701                specialization_args.push(SpecializationArg::NotSpecialized);
702                new_args.push(*arg);
703                continue;
704            };
705        }
706
707        if specialization_args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized)) {
708            // No argument was assigned -> no specialization.
709            return None;
710        }
711        if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
712            base.long(self.db)
713        {
714            // Canonicalize the specialization rather than adding a specialization of a specialized
715            // function.
716            base = specialized_function.base;
717            let mut new_args_iter = specialization_args.into_iter();
718            let mut old_args = specialized_function.args.to_vec();
719            let mut stack = vec![];
720            for arg in old_args.iter_mut().rev() {
721                stack.push(arg);
722            }
723            while let Some(arg) = stack.pop() {
724                match arg {
725                    SpecializationArg::Const { .. } => {}
726                    SpecializationArg::Snapshot(inner) => {
727                        stack.push(inner.as_mut());
728                    }
729                    SpecializationArg::Enum { payload, .. } => {
730                        stack.push(payload.as_mut());
731                    }
732                    SpecializationArg::Array(_, values) => {
733                        for value in values.iter_mut().rev() {
734                            stack.push(value);
735                        }
736                    }
737                    SpecializationArg::Struct(specialization_args) => {
738                        for arg in specialization_args.iter_mut().rev() {
739                            stack.push(arg);
740                        }
741                    }
742                    SpecializationArg::NotSpecialized => {
743                        *arg = new_args_iter.next().unwrap_or(SpecializationArg::NotSpecialized);
744                    }
745                }
746            }
747            specialization_args = old_args;
748        }
749        let specialized = SpecializedFunction { base, args: specialization_args.into() };
750        let specialized_func_id =
751            ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);
752
753        if self.db.priv_should_specialize(specialized_func_id) == Ok(false) {
754            return None;
755        }
756
757        Some(Statement::Call(StatementCall {
758            function: specialized_func_id.function_id(self.db).unwrap(),
759            inputs: new_args,
760            with_coupon: call_stmt.with_coupon,
761            outputs: std::mem::take(&mut call_stmt.outputs),
762            location: call_stmt.location,
763        }))
764    }
765
766    /// Adds `value` as a const to `var_info` and return a const statement for it.
767    fn propagate_const_and_get_statement(
768        &mut self,
769        value: BigInt,
770        output: VariableId,
771    ) -> Statement<'db> {
772        let ty = self.variables[output].ty;
773        let value = ConstValueId::from_int(self.db, ty, &value);
774        self.var_info.insert(output, VarInfo::Const(value));
775        Statement::Const(StatementConst::new_flat(value, output))
776    }
777
778    /// Adds 0 const to `var_info` and return a const statement for it.
779    fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement<'db> {
780        self.propagate_const_and_get_statement(BigInt::zero(), output)
781    }
782
783    /// Returns a statement that introduces the requested value into `output`, or None if fails.
784    fn try_generate_const_statement(
785        &self,
786        value: ConstValueId<'db>,
787        output: VariableId,
788    ) -> Option<Statement<'db>> {
789        if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
790            Some(Statement::Const(StatementConst::new_flat(value, output)))
791        } else if matches!(value.long(self.db), ConstValue::Struct(members, _) if members.is_empty())
792        {
793            // Handling const empty structs - which are not supported in sierra-gen.
794            Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
795        } else {
796            None
797        }
798    }
799
800    /// Handles the end of block matching on an enum.
801    /// Possibly extends the blocks statements as well.
802    /// Returns None if no additional changes are required.
803    /// If changes are required, returns the updated block end.
804    fn handle_enum_block_end(
805        &mut self,
806        info: &mut MatchEnumInfo<'db>,
807        statements: &mut Vec<Statement<'db>>,
808    ) -> Option<BlockEnd<'db>> {
809        let input = info.input.var_id;
810        let (n_snapshots, var_info) = self.var_info.get(&input)?.peel_snapshots();
811        let location = info.location;
812        let as_usage = |var_id| VarUsage { var_id, location };
813        let db = self.db;
814        let snapshot_stmt = |vars: &mut VariableArena<'_>, pre_snap, post_snap| {
815            let ignored = vars.alloc(vars[pre_snap].clone());
816            Statement::Snapshot(StatementSnapshot::new(as_usage(pre_snap), ignored, post_snap))
817        };
818        // Checking whether we have actual const info on the enum.
819        if let VarInfo::Const(const_value) = var_info
820            && let ConstValue::Enum(variant, value) = const_value.long(db)
821        {
822            let arm = &info.arms[variant.idx];
823            let output = arm.var_ids[0];
824            // Propagating the const value information.
825            self.var_info.insert(output, VarInfo::Const(*value).wrap_with_snapshots(n_snapshots));
826            if self.variables[input].info.droppable.is_ok()
827                && self.variables[output].info.copyable.is_ok()
828                && let Ok(mut ty) = value.ty(db)
829                && let Some(mut stmt) = self.try_generate_const_statement(*value, output)
830            {
831                // Adding snapshot taking statements for snapshots.
832                for _ in 0..n_snapshots {
833                    let non_snap_var = Variable::with_default_context(db, ty, location);
834                    ty = TypeLongId::Snapshot(ty).intern(db);
835                    let pre_snap = self.variables.alloc(non_snap_var);
836                    stmt.outputs_mut()[0] = pre_snap;
837                    let take_snap = snapshot_stmt(self.variables, pre_snap, output);
838                    statements.push(core::mem::replace(&mut stmt, take_snap));
839                }
840                statements.push(stmt);
841                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
842            }
843        } else if let VarInfo::Enum { variant, payload } = var_info {
844            let arm = &info.arms[variant.idx];
845            let variant_ty = variant.ty;
846            let output = arm.var_ids[0];
847            let payload = payload.as_ref().clone();
848            let unwrapped =
849                self.variables[input].info.droppable.is_ok().then_some(()).and_then(|_| {
850                    let (extra_snapshots, inner) = payload.peel_snapshots();
851                    match inner {
852                        VarInfo::Var(var) if self.variables[var.var_id].info.copyable.is_ok() => {
853                            Some((var.var_id, extra_snapshots))
854                        }
855                        VarInfo::Const(value) => {
856                            let const_var = self
857                                .variables
858                                .alloc(Variable::with_default_context(db, variant_ty, location));
859                            statements.push(self.try_generate_const_statement(*value, const_var)?);
860                            Some((const_var, extra_snapshots))
861                        }
862                        _ => None,
863                    }
864                });
865            // Propagating the const value information.
866            self.var_info.insert(output, payload.wrap_with_snapshots(n_snapshots));
867            if let Some((mut unwrapped, extra_snapshots)) = unwrapped {
868                let total_snapshots = n_snapshots + extra_snapshots;
869                if total_snapshots != 0 {
870                    // Adding snapshot taking statements for snapshots.
871                    for _ in 1..total_snapshots {
872                        let ty = TypeLongId::Snapshot(self.variables[unwrapped].ty).intern(db);
873                        let non_snap_var = Variable::with_default_context(self.db, ty, location);
874                        let snapped = self.variables.alloc(non_snap_var);
875                        statements.push(snapshot_stmt(self.variables, unwrapped, snapped));
876                        unwrapped = snapped;
877                    }
878                    statements.push(snapshot_stmt(self.variables, unwrapped, output));
879                };
880                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
881            }
882        }
883        None
884    }
885
886    /// Handles the end of a block based on an extern function call.
887    /// Possibly extends the blocks statements as well.
888    /// Returns None if no additional changes are required.
889    /// If changes are required, returns the updated block end.
890    fn handle_extern_block_end(
891        &mut self,
892        info: &mut MatchExternInfo<'db>,
893        statements: &mut Vec<Statement<'db>>,
894    ) -> Option<BlockEnd<'db>> {
895        let db = self.db;
896        let (id, generic_args) = info.function.get_extern(db)?;
897        if self.nz_fns.contains(&id) {
898            let val = self.as_const(info.inputs[0].var_id)?;
899            let is_zero = match val.long(db) {
900                ConstValue::Int(v, _) => v.is_zero(),
901                ConstValue::Struct(s, _) => s.iter().all(|v| {
902                    v.long(db).to_int().expect("Expected ConstValue::Int for size").is_zero()
903                }),
904                _ => unreachable!(),
905            };
906            Some(if is_zero {
907                BlockEnd::Goto(info.arms[0].block_id, Default::default())
908            } else {
909                let arm = &info.arms[1];
910                let nz_var = arm.var_ids[0];
911                let nz_val = ConstValue::NonZero(val).intern(db);
912                self.var_info.insert(nz_var, VarInfo::Const(nz_val));
913                statements.push(Statement::Const(StatementConst::new_flat(nz_val, nz_var)));
914                BlockEnd::Goto(arm.block_id, Default::default())
915            })
916        } else if self.eq_fns.contains(&id) {
917            let lhs = self.as_int(info.inputs[0].var_id);
918            let rhs = self.as_int(info.inputs[1].var_id);
919            if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
920                || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
921            {
922                let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
923                let var = &self.variables[nz_input.var_id].clone();
924                let function = self.type_info.get(&var.ty)?.is_zero;
925                let unused_nz_var = Variable::with_default_context(
926                    db,
927                    corelib::core_nonzero_ty(db, var.ty),
928                    var.location,
929                );
930                let unused_nz_var = self.variables.alloc(unused_nz_var);
931                return Some(BlockEnd::Match {
932                    info: MatchInfo::Extern(MatchExternInfo {
933                        function,
934                        inputs: vec![nz_input],
935                        arms: vec![
936                            MatchArm {
937                                arm_selector: MatchArmSelector::VariantId(
938                                    corelib::jump_nz_zero_variant(db, var.ty),
939                                ),
940                                block_id: info.arms[1].block_id,
941                                var_ids: vec![],
942                            },
943                            MatchArm {
944                                arm_selector: MatchArmSelector::VariantId(
945                                    corelib::jump_nz_nonzero_variant(db, var.ty),
946                                ),
947                                block_id: info.arms[0].block_id,
948                                var_ids: vec![unused_nz_var],
949                            },
950                        ],
951                        location: info.location,
952                    }),
953                });
954            }
955            Some(BlockEnd::Goto(
956                info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
957                Default::default(),
958            ))
959        } else if self.uadd_fns.contains(&id)
960            || self.usub_fns.contains(&id)
961            || self.diff_fns.contains(&id)
962            || self.iadd_fns.contains(&id)
963            || self.isub_fns.contains(&id)
964        {
965            let rhs = self.as_int(info.inputs[1].var_id);
966            let lhs = self.as_int(info.inputs[0].var_id);
967            if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
968                let ty = self.variables[info.arms[0].var_ids[0]].ty;
969                let range = self.type_value_ranges.get(&ty)?;
970                let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
971                    lhs + rhs
972                } else {
973                    lhs - rhs
974                };
975                let (arm_index, value) = match range.normalized(value) {
976                    NormalizedResult::InRange(value) => (0, value),
977                    NormalizedResult::Under(value) => (1, value),
978                    NormalizedResult::Over(value) => (
979                        if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
980                            2
981                        } else {
982                            1
983                        },
984                        value,
985                    ),
986                };
987                let arm = &info.arms[arm_index];
988                let actual_output = arm.var_ids[0];
989                let value = ConstValue::Int(value, ty).intern(db);
990                self.var_info.insert(actual_output, VarInfo::Const(value));
991                statements.push(Statement::Const(StatementConst::new_flat(value, actual_output)));
992                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
993            }
994            if let Some(rhs) = rhs {
995                if rhs.is_zero() && !self.diff_fns.contains(&id) {
996                    let arm = &info.arms[0];
997                    self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]));
998                    return Some(BlockEnd::Goto(arm.block_id, Default::default()));
999                }
1000                if rhs.is_one() && !self.diff_fns.contains(&id) {
1001                    let ty = self.variables[info.arms[0].var_ids[0]].ty;
1002                    let ty_info = self.type_info.get(&ty)?;
1003                    let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
1004                        ty_info.inc?
1005                    } else {
1006                        ty_info.dec?
1007                    };
1008                    let enum_ty = function.signature(db).ok()?.return_type;
1009                    let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
1010                        enum_ty.long(db)
1011                    else {
1012                        return None;
1013                    };
1014                    let result = self.variables.alloc(Variable::with_default_context(
1015                        db,
1016                        function.signature(db).unwrap().return_type,
1017                        info.location,
1018                    ));
1019                    statements.push(Statement::Call(StatementCall {
1020                        function,
1021                        inputs: vec![info.inputs[0]],
1022                        with_coupon: false,
1023                        outputs: vec![result],
1024                        location: info.location,
1025                    }));
1026                    return Some(BlockEnd::Match {
1027                        info: MatchInfo::Enum(MatchEnumInfo {
1028                            concrete_enum_id: *concrete_enum_id,
1029                            input: VarUsage { var_id: result, location: info.location },
1030                            arms: core::mem::take(&mut info.arms),
1031                            location: info.location,
1032                        }),
1033                    });
1034                }
1035            }
1036            if let Some(lhs) = lhs
1037                && lhs.is_zero()
1038                && (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id))
1039            {
1040                let arm = &info.arms[0];
1041                self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]));
1042                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1043            }
1044            None
1045        } else if let Some(reversed) = self.downcast_fns.get(&id) {
1046            let range = |ty: TypeId<'_>| {
1047                Some(if let Some(range) = self.type_value_ranges.get(&ty) {
1048                    range.clone()
1049                } else {
1050                    let (min, max) = corelib::try_extract_bounded_int_type_ranges(db, ty)?;
1051                    TypeRange { min, max }
1052                })
1053            };
1054            let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
1055            let input_var = info.inputs[0].var_id;
1056            let in_ty = self.variables[input_var].ty;
1057            let success_output = info.arms[success_arm].var_ids[0];
1058            let out_ty = self.variables[success_output].ty;
1059            let out_range = range(out_ty)?;
1060            let Some(value) = self.as_int(input_var) else {
1061                let in_range = range(in_ty)?;
1062                return if in_range.min < out_range.min || in_range.max > out_range.max {
1063                    None
1064                } else {
1065                    let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
1066                    let function = db.core_info().upcast_fn.concretize(db, generic_args);
1067                    statements.push(Statement::Call(StatementCall {
1068                        function: function.lowered(db),
1069                        inputs: vec![info.inputs[0]],
1070                        with_coupon: false,
1071                        outputs: vec![success_output],
1072                        location: info.location,
1073                    }));
1074                    return Some(BlockEnd::Goto(
1075                        info.arms[success_arm].block_id,
1076                        Default::default(),
1077                    ));
1078                };
1079            };
1080            let value = if in_ty == self.felt252 {
1081                felt252_for_downcast(value, &out_range.min)
1082            } else {
1083                value.clone()
1084            };
1085            Some(if let NormalizedResult::InRange(value) = out_range.normalized(value) {
1086                let value = ConstValue::Int(value, out_ty).intern(db);
1087                self.var_info.insert(success_output, VarInfo::Const(value));
1088                statements.push(Statement::Const(StatementConst::new_flat(value, success_output)));
1089                BlockEnd::Goto(info.arms[success_arm].block_id, Default::default())
1090            } else {
1091                BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default())
1092            })
1093        } else if id == self.bounded_int_constrain {
1094            let input_var = info.inputs[0].var_id;
1095            let value = self.as_int(input_var)?;
1096            let generic_arg = generic_args[1];
1097            let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
1098                .long(db)
1099                .to_int()
1100                .expect("Expected ConstValue::Int for size");
1101            let arm_idx = if value < constrain_value { 0 } else { 1 };
1102            let output = info.arms[arm_idx].var_ids[0];
1103            statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1104            Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1105        } else if id == self.array_get {
1106            let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
1107            if let Some(VarInfo::Snapshot(arr_info)) = self.var_info.get(&info.inputs[0].var_id)
1108                && let VarInfo::Array(infos) = arr_info.as_ref()
1109            {
1110                match infos.get(index) {
1111                    Some(Some(output_var_info)) => {
1112                        let arm = &info.arms[0];
1113                        let output_var_info = output_var_info.clone();
1114                        let box_info =
1115                            VarInfo::Box(VarInfo::Snapshot(output_var_info.clone().into()).into());
1116                        self.var_info.insert(arm.var_ids[0], box_info);
1117                        if let VarInfo::Const(value) = output_var_info {
1118                            let value_ty = value.ty(db).ok()?;
1119                            let value_box_ty = corelib::core_box_ty(db, value_ty);
1120                            let location = info.location;
1121                            let boxed_var =
1122                                Variable::with_default_context(db, value_box_ty, location);
1123                            let boxed = self.variables.alloc(boxed_var.clone());
1124                            let unused_boxed = self.variables.alloc(boxed_var);
1125                            let snapped = self.variables.alloc(Variable::with_default_context(
1126                                db,
1127                                TypeLongId::Snapshot(value_box_ty).intern(db),
1128                                location,
1129                            ));
1130                            statements.extend([
1131                                Statement::Const(StatementConst::new_boxed(value, boxed)),
1132                                Statement::Snapshot(StatementSnapshot {
1133                                    input: VarUsage { var_id: boxed, location },
1134                                    outputs: [unused_boxed, snapped],
1135                                }),
1136                                Statement::Call(StatementCall {
1137                                    function: self
1138                                        .box_forward_snapshot
1139                                        .concretize(db, vec![GenericArgumentId::Type(value_ty)])
1140                                        .lowered(db),
1141                                    inputs: vec![VarUsage { var_id: snapped, location }],
1142                                    with_coupon: false,
1143                                    outputs: vec![arm.var_ids[0]],
1144                                    location: info.location,
1145                                }),
1146                            ]);
1147                            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1148                        }
1149                    }
1150                    None => {
1151                        return Some(BlockEnd::Goto(info.arms[1].block_id, Default::default()));
1152                    }
1153                    Some(None) => {}
1154                }
1155            }
1156            if index.is_zero()
1157                && let [success, failure] = info.arms.as_mut_slice()
1158            {
1159                let arr = info.inputs[0].var_id;
1160                let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
1161                let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
1162                info.inputs.truncate(1);
1163                info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
1164                    .concretize(db, generic_args)
1165                    .lowered(db);
1166                success.var_ids.insert(0, unused_arr_output0);
1167                failure.var_ids.insert(0, unused_arr_output1);
1168            }
1169            None
1170        } else if id == self.array_pop_front {
1171            let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)? else {
1172                return None;
1173            };
1174            if let Some(first) = var_infos.first() {
1175                if let Some(first) = first.as_ref().cloned() {
1176                    let arm = &info.arms[0];
1177                    self.var_info.insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()));
1178                    self.var_info.insert(arm.var_ids[1], VarInfo::Box(first.into()));
1179                }
1180                None
1181            } else {
1182                let arm = &info.arms[1];
1183                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1184                Some(BlockEnd::Goto(
1185                    arm.block_id,
1186                    VarRemapping {
1187                        remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1188                    },
1189                ))
1190            }
1191        } else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
1192            let var_info = self.var_info.get(&info.inputs[0].var_id)?;
1193            let desnapped = try_extract_matches!(var_info, VarInfo::Snapshot)?;
1194            let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
1195            // TODO(orizi): Propagate success values as well.
1196            if element_var_infos.is_empty() {
1197                let arm = &info.arms[1];
1198                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1199                Some(BlockEnd::Goto(
1200                    arm.block_id,
1201                    VarRemapping {
1202                        remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1203                    },
1204                ))
1205            } else {
1206                None
1207            }
1208        } else {
1209            None
1210        }
1211    }
1212
1213    /// Returns the const value of a variable if it exists.
1214    fn as_const(&self, var_id: VariableId) -> Option<ConstValueId<'db>> {
1215        try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const).copied()
1216    }
1217
1218    /// Return the const value as a int if it exists and is an integer.
1219    fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
1220        match self.as_const(var_id)?.long(self.db) {
1221            ConstValue::Int(value, _) => Some(value),
1222            ConstValue::NonZero(const_value) => {
1223                if let ConstValue::Int(value, _) = const_value.long(self.db) {
1224                    Some(value)
1225                } else {
1226                    None
1227                }
1228            }
1229            _ => None,
1230        }
1231    }
1232
1233    /// Replaces the inputs in place if they are in the var_info map.
1234    fn maybe_replace_inputs(&self, inputs: &mut [VarUsage<'db>]) {
1235        for input in inputs {
1236            self.maybe_replace_input(input);
1237        }
1238    }
1239
1240    /// Replaces the input in place if it is in the var_info map.
1241    fn maybe_replace_input(&self, input: &mut VarUsage<'db>) {
1242        if let Some(VarInfo::Var(new_var)) = self.var_info.get(&input.var_id) {
1243            *input = *new_var;
1244        }
1245    }
1246
1247    /// Given a var_info and its type, return the corresponding specialization argument, if it
1248    /// exists.
1249    fn try_get_specialization_arg(
1250        &self,
1251        var_info: VarInfo<'db>,
1252        ty: TypeId<'db>,
1253        unknown_vars: &mut Vec<VarUsage<'db>>,
1254    ) -> Option<SpecializationArg<'db>> {
1255        if self.db.type_size_info(ty).ok()? == TypeSizeInformation::ZeroSized {
1256            // Skip zero-sized constants as they are not supported in sierra-gen.
1257            return None;
1258        }
1259
1260        match var_info {
1261            VarInfo::Const(value) => Some(SpecializationArg::Const { value, boxed: false }),
1262            VarInfo::Box(info) => try_extract_matches!(info.as_ref(), VarInfo::Const)
1263                .map(|value| SpecializationArg::Const { value: *value, boxed: true }),
1264            VarInfo::Snapshot(info) => {
1265                let desnap_ty = *extract_matches!(ty.long(self.db), TypeLongId::Snapshot);
1266                // Use a local accumulator to avoid mutating unknown_vars if we return None.
1267                let mut local_unknown_vars: Vec<VarUsage<'db>> = Vec::new();
1268                let inner = self.try_get_specialization_arg(
1269                    info.as_ref().clone(),
1270                    desnap_ty,
1271                    &mut local_unknown_vars,
1272                )?;
1273                unknown_vars.extend(local_unknown_vars);
1274                Some(SpecializationArg::Snapshot(Box::new(inner)))
1275            }
1276            VarInfo::Array(infos) => {
1277                let TypeLongId::Concrete(concrete_ty) = ty.long(self.db) else {
1278                    unreachable!("Expected a concrete type");
1279                };
1280                let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
1281                else {
1282                    unreachable!("Expected a single type generic argument");
1283                };
1284                // Accumulate into locals first; only extend unknown_vars if we end up specializing.
1285                let mut vars = vec![];
1286                let mut args = vec![];
1287                for info in infos {
1288                    let info = info?;
1289                    let arg = self.try_get_specialization_arg(info, *inner_ty, &mut vars)?;
1290                    args.push(arg);
1291                }
1292                if !args.is_empty()
1293                    && args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1294                {
1295                    return None;
1296                }
1297                unknown_vars.extend(vars);
1298                Some(SpecializationArg::Array(*inner_ty, args))
1299            }
1300            VarInfo::Struct(infos) => {
1301                let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
1302                    ty.long(self.db)
1303                else {
1304                    // TODO(ilya): Support closures and fixed size arrays.
1305                    return None;
1306                };
1307
1308                let members = self.db.concrete_struct_members(*concrete_struct).unwrap();
1309                let mut struct_args = Vec::new();
1310                // Accumulate into locals first; only extend unknown_vars if we end up specializing.
1311                let mut vars = vec![];
1312                for (member, opt_var_info) in zip_eq(members.values(), infos) {
1313                    let var_info = opt_var_info?;
1314                    let arg = self.try_get_specialization_arg(var_info, member.ty, &mut vars)?;
1315                    struct_args.push(arg);
1316                }
1317                if !struct_args.is_empty()
1318                    && struct_args
1319                        .iter()
1320                        .all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1321                {
1322                    return None;
1323                }
1324                unknown_vars.extend(vars);
1325                Some(SpecializationArg::Struct(struct_args))
1326            }
1327            VarInfo::Enum { variant, payload } => {
1328                let mut local_unknown_vars = vec![];
1329                let payload_arg = self.try_get_specialization_arg(
1330                    payload.as_ref().clone(),
1331                    variant.ty,
1332                    &mut local_unknown_vars,
1333                )?;
1334
1335                unknown_vars.extend(local_unknown_vars);
1336                Some(SpecializationArg::Enum { variant, payload: Box::new(payload_arg) })
1337            }
1338            VarInfo::Var(var_usage) => {
1339                unknown_vars.push(var_usage);
1340                Some(SpecializationArg::NotSpecialized)
1341            }
1342        }
1343    }
1344
1345    /// Returns true if const-folding should be skipped for the current function.
1346    pub fn should_skip_const_folding(&self, db: &'db dyn Database) -> bool {
1347        if db.optimizations().skip_const_folding() {
1348            return true;
1349        }
1350
1351        // Skipping const-folding for `panic_with_const_felt252` - to avoid replacing a call to
1352        // `panic_with_felt252` with `panic_with_const_felt252` and causing accidental recursion.
1353        if self.caller_base.base_semantic_function(db).generic_function(db)
1354            == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
1355        {
1356            return true;
1357        }
1358        false
1359    }
1360}
1361
1362/// Returns a `VarInfo` of a variable only if it is copyable.
1363fn var_info_if_copy<'db>(
1364    variables: &VariableArena<'db>,
1365    input: VarUsage<'db>,
1366) -> Option<VarInfo<'db>> {
1367    variables[input.var_id].info.copyable.is_ok().then_some(VarInfo::Var(input))
1368}
1369
/// Internal query for the libfuncs information required for const folding.
///
/// Builds a `ConstFoldingLibfuncInfo` by resolving the relevant corelib items. The query is
/// salsa-tracked with `returns(ref)`, so callers receive a reference to the stored result.
#[salsa::tracked(returns(ref))]
fn priv_const_folding_info<'db>(
    db: &'db dyn Database,
) -> crate::optimizations::const_folding::ConstFoldingLibfuncInfo<'db> {
    ConstFoldingLibfuncInfo::new(db)
}
1377
/// Holds static information about libfuncs required for the optimization.
///
/// All ids are resolved from the corelib when the value is constructed. The semantic
/// const-calculation info is reachable through this struct's `Deref` impl.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub struct ConstFoldingLibfuncInfo<'db> {
    /// The `felt252_sub` libfunc.
    felt_sub: ExternFunctionId<'db>,
    /// The `felt252_add` libfunc.
    felt_add: ExternFunctionId<'db>,
    /// The `felt252_mul` libfunc.
    felt_mul: ExternFunctionId<'db>,
    /// The `felt252_div` libfunc.
    felt_div: ExternFunctionId<'db>,
    /// The `into_box` libfunc.
    into_box: ExternFunctionId<'db>,
    /// The `unbox` libfunc.
    unbox: ExternFunctionId<'db>,
    /// The generic `box_forward_snapshot` libfunc.
    box_forward_snapshot: GenericFunctionId<'db>,
    /// The set of functions that check if numbers are equal (`{ty}_eq` for `u8`..`u128` and
    /// `i8`..`i128`).
    eq_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to add unsigned ints (`{ty}_overflowing_add`).
    uadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to subtract unsigned ints (`{ty}_overflowing_sub`).
    usub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to get the difference of signed ints (`{ty}_diff`).
    diff_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to add signed ints (`{ty}_overflowing_add_impl`).
    iadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to subtract signed ints (`{ty}_overflowing_sub_impl`).
    isub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to multiply integers (`bounded_int_mul`, and `{ty}_wide_mul` for
    /// the integer types of up to 64 bits).
    wide_mul_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The set of functions to divide and get the remainder of integers
    /// (`bounded_int_div_rem`, and `{ty}_safe_divmod` for the unsigned types).
    div_rem_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `bounded_int_add` libfunc.
    bounded_int_add: ExternFunctionId<'db>,
    /// The `bounded_int_sub` libfunc.
    bounded_int_sub: ExternFunctionId<'db>,
    /// The `bounded_int_constrain` libfunc.
    bounded_int_constrain: ExternFunctionId<'db>,
    /// The `array_get` libfunc.
    array_get: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_front` libfunc.
    array_snapshot_pop_front: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_back` libfunc.
    array_snapshot_pop_back: ExternFunctionId<'db>,
    /// The `array_len` libfunc.
    array_len: ExternFunctionId<'db>,
    /// The `array_new` libfunc.
    array_new: ExternFunctionId<'db>,
    /// The `array_append` libfunc.
    array_append: ExternFunctionId<'db>,
    /// The `array_pop_front` libfunc.
    array_pop_front: ExternFunctionId<'db>,
    /// The `storage_base_address_from_felt252` libfunc.
    storage_base_address_from_felt252: ExternFunctionId<'db>,
    /// The generic `storage_base_address_const` libfunc.
    storage_base_address_const: GenericFunctionId<'db>,
    /// The lowered `core::panic_with_felt252` function.
    panic_with_felt252: FunctionId<'db>,
    /// The `core::panic_with_const_felt252` function. Const-folding is skipped inside this
    /// function itself to avoid making it recursive.
    pub panic_with_const_felt252: FreeFunctionId<'db>,
    /// The lowered `core::panics::panic_with_byte_array` function.
    panic_with_byte_array: FunctionId<'db>,
    /// Information per type, for the core integer types (`u8`..`u128`, `u256`, `i8`..`i128`).
    type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>>,
    /// The info used for semantic const calculation.
    const_calculation_info: Arc<ConstCalcInfo<'db>>,
}
1446impl<'db> ConstFoldingLibfuncInfo<'db> {
1447    fn new(db: &'db dyn Database) -> Self {
1448        let core = ModuleHelper::core(db);
1449        let box_module = core.submodule("box");
1450        let integer_module = core.submodule("integer");
1451        let internal_module = core.submodule("internal");
1452        let bounded_int_module = internal_module.submodule("bounded_int");
1453        let num_module = internal_module.submodule("num");
1454        let array_module = core.submodule("array");
1455        let starknet_module = core.submodule("starknet");
1456        let storage_access_module = starknet_module.submodule("storage_access");
1457        let utypes = ["u8", "u16", "u32", "u64", "u128"];
1458        let itypes = ["i8", "i16", "i32", "i64", "i128"];
1459        let eq_fns = OrderedHashSet::<_>::from_iter(
1460            chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(&format!("{ty}_eq"))),
1461        );
1462        let uadd_fns = OrderedHashSet::<_>::from_iter(
1463            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_add"))),
1464        );
1465        let usub_fns = OrderedHashSet::<_>::from_iter(
1466            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_sub"))),
1467        );
1468        let diff_fns = OrderedHashSet::<_>::from_iter(
1469            itypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_diff"))),
1470        );
1471        let iadd_fns =
1472            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1473                integer_module.extern_function_id(&format!("{ty}_overflowing_add_impl"))
1474            }));
1475        let isub_fns =
1476            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1477                integer_module.extern_function_id(&format!("{ty}_overflowing_sub_impl"))
1478            }));
1479        let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
1480            [bounded_int_module.extern_function_id("bounded_int_mul")],
1481            ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
1482                .map(|ty| integer_module.extern_function_id(&format!("{ty}_wide_mul"))),
1483        ));
1484        let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
1485            [bounded_int_module.extern_function_id("bounded_int_div_rem")],
1486            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_safe_divmod"))),
1487        ));
1488        let type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>> = OrderedHashMap::from_iter(
1489            [
1490                ("u8", false, true),
1491                ("u16", false, true),
1492                ("u32", false, true),
1493                ("u64", false, true),
1494                ("u128", false, true),
1495                ("u256", false, false),
1496                ("i8", true, true),
1497                ("i16", true, true),
1498                ("i32", true, true),
1499                ("i64", true, true),
1500                ("i128", true, true),
1501            ]
1502            .map(|(ty_name, as_bounded_int, inc_dec): (&'static str, bool, bool)| {
1503                let ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, ty_name), vec![]);
1504                let is_zero = if as_bounded_int {
1505                    bounded_int_module
1506                        .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
1507                } else {
1508                    integer_module.function_id(
1509                        SmolStrId::from(db, format!("{ty_name}_is_zero")).long(db).as_str(),
1510                        vec![],
1511                    )
1512                }
1513                .lowered(db);
1514                let (inc, dec) = if inc_dec {
1515                    (
1516                        Some(
1517                            num_module
1518                                .function_id(
1519                                    SmolStrId::from(db, format!("{ty_name}_inc")).long(db).as_str(),
1520                                    vec![],
1521                                )
1522                                .lowered(db),
1523                        ),
1524                        Some(
1525                            num_module
1526                                .function_id(
1527                                    SmolStrId::from(db, format!("{ty_name}_dec")).long(db).as_str(),
1528                                    vec![],
1529                                )
1530                                .lowered(db),
1531                        ),
1532                    )
1533                } else {
1534                    (None, None)
1535                };
1536                let info = TypeInfo { is_zero, inc, dec };
1537                (ty, info)
1538            }),
1539        );
1540        Self {
1541            felt_sub: core.extern_function_id("felt252_sub"),
1542            felt_add: core.extern_function_id("felt252_add"),
1543            felt_mul: core.extern_function_id("felt252_mul"),
1544            felt_div: core.extern_function_id("felt252_div"),
1545            into_box: box_module.extern_function_id("into_box"),
1546            unbox: box_module.extern_function_id("unbox"),
1547            box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
1548            eq_fns,
1549            uadd_fns,
1550            usub_fns,
1551            diff_fns,
1552            iadd_fns,
1553            isub_fns,
1554            wide_mul_fns,
1555            div_rem_fns,
1556            bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
1557            bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
1558            bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
1559            array_get: array_module.extern_function_id("array_get"),
1560            array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
1561            array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
1562            array_len: array_module.extern_function_id("array_len"),
1563            array_new: array_module.extern_function_id("array_new"),
1564            array_append: array_module.extern_function_id("array_append"),
1565            array_pop_front: array_module.extern_function_id("array_pop_front"),
1566            storage_base_address_from_felt252: storage_access_module
1567                .extern_function_id("storage_base_address_from_felt252"),
1568            storage_base_address_const: storage_access_module
1569                .generic_function_id("storage_base_address_const"),
1570            panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
1571            panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
1572            panic_with_byte_array: core
1573                .submodule("panics")
1574                .function_id("panic_with_byte_array", vec![])
1575                .lowered(db),
1576            type_info,
1577            const_calculation_info: db.const_calc_info(),
1578        }
1579    }
1580}
1581
1582impl<'db> std::ops::Deref for ConstFoldingContext<'db, '_> {
1583    type Target = ConstFoldingLibfuncInfo<'db>;
1584    fn deref(&self) -> &ConstFoldingLibfuncInfo<'db> {
1585        self.libfunc_info
1586    }
1587}
1588
1589impl<'a> std::ops::Deref for ConstFoldingLibfuncInfo<'a> {
1590    type Target = ConstCalcInfo<'a>;
1591    fn deref(&self) -> &ConstCalcInfo<'a> {
1592        &self.const_calculation_info
1593    }
1594}
1595
/// Per-type information required for const folding.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct TypeInfo<'db> {
    /// The lowered function that checks whether a value of the type is zero: either the
    /// type-specific `{ty}_is_zero`, or the generic `bounded_int_is_zero` instantiated for it.
    is_zero: FunctionId<'db>,
    /// The lowered inc function that increases the value by one, if the type has one.
    inc: Option<FunctionId<'db>>,
    /// The lowered dec function that decreases the value by one, if the type has one.
    dec: Option<FunctionId<'db>>,
}
1606
/// Wrap-around normalization of an integer value into a range of values.
trait TypeRangeNormalizer {
    /// Normalizes the value to the range.
    ///
    /// Assumes the value is at most one range-size outside the range, i.e. within
    /// `[min - size, max + size]` where `size = max - min + 1`. The `TypeRange` implementation
    /// shifts an out-of-range value by `size` only once, so values further away would not end
    /// up inside the range.
    fn normalized(&self, value: BigInt) -> NormalizedResult;
}
1612impl TypeRangeNormalizer for TypeRange {
1613    fn normalized(&self, value: BigInt) -> NormalizedResult {
1614        if value < self.min {
1615            NormalizedResult::Under(value - &self.min + &self.max + 1)
1616        } else if value > self.max {
1617            NormalizedResult::Over(value + &self.min - &self.max - 1)
1618        } else {
1619            NormalizedResult::InRange(value)
1620        }
1621    }
1622}
1623
/// The result of normalizing a value to a range.
enum NormalizedResult {
    /// The original value is in the range; carries the value, or an equivalent value.
    InRange(BigInt),
    /// The original value is larger than the range max; carries the normalized value (the
    /// original minus the range size, as computed for `TypeRange`).
    Over(BigInt),
    /// The original value is smaller than the range min; carries the normalized value (the
    /// original plus the range size, as computed for `TypeRange`).
    Under(BigInt),
}