cairo_lang_lowering/optimizations/
const_folding.rs

1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::sync::Arc;
6
7use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
8use cairo_lang_filesystem::flag::flag_unsafe_panic;
9use cairo_lang_filesystem::ids::SmolStrId;
10use cairo_lang_semantic::corelib::{CorelibSemantic, try_extract_nz_wrapped_type};
11use cairo_lang_semantic::helper::ModuleHelper;
12use cairo_lang_semantic::items::constant::{
13    ConstCalcInfo, ConstValue, ConstValueId, ConstantSemantic,
14};
15use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
16use cairo_lang_semantic::items::structure::StructSemantic;
17use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
18use cairo_lang_semantic::{
19    ConcreteTypeId, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId, corelib,
20};
21use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
22use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
23use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
24use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
25use cairo_lang_utils::{Intern, extract_matches, try_extract_matches};
26use itertools::{chain, zip_eq};
27use num_bigint::BigInt;
28use num_integer::Integer;
29use num_traits::cast::ToPrimitive;
30use num_traits::{Num, One, Zero};
31use salsa::Database;
32
33use crate::db::LoweringGroup;
34use crate::ids::{
35    ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
36    SpecializedFunction,
37};
38use crate::specialization::SpecializationArg;
39use crate::utils::InliningStrategy;
40use crate::{
41    Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchExternInfo, MatchInfo,
42    Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
43    StatementSnapshot, StatementStructConstruct, StatementStructDestructure, VarRemapping,
44    VarUsage, Variable, VariableArena, VariableId,
45};
46
47/// Keeps track of equivalent values that a variables might be replaced with.
48/// Note: We don't keep track of types as we assume the usage is always correct.
49#[derive(Debug, Clone)]
50enum VarInfo<'db> {
51    /// The variable is a const value.
52    Const(ConstValueId<'db>),
53    /// The variable can be replaced by another variable.
54    Var(VarUsage<'db>),
55    /// The variable is a snapshot of another variable.
56    Snapshot(Box<VarInfo<'db>>),
57    /// The variable is a struct of other variables.
58    /// `None` values represent variables that are not tracked.
59    Struct(Vec<Option<VarInfo<'db>>>),
60    /// The variable is a box of another variable.
61    Box(Box<VarInfo<'db>>),
62    /// The variable is an array of known size of other variables.
63    /// `None` values represent variables that are not tracked.
64    Array(Vec<Option<VarInfo<'db>>>),
65}
66impl<'db> VarInfo<'db> {
67    /// Peels the snapshots from the variable info and returns the number of snapshots performed.
68    fn peel_snapshots(&self) -> (usize, &VarInfo<'db>) {
69        let mut n_snapshots = 0;
70        let mut info = self;
71        while let VarInfo::Snapshot(inner) = info {
72            info = inner.as_ref();
73            n_snapshots += 1;
74        }
75        (n_snapshots, info)
76    }
77    /// Wraps the variable info with the given number of snapshots.
78    fn wrap_with_snapshots(mut self, n_snapshots: usize) -> VarInfo<'db> {
79        for _ in 0..n_snapshots {
80            self = VarInfo::Snapshot(Box::new(self));
81        }
82        self
83    }
84}
85
86#[derive(Debug, Clone, Copy, PartialEq)]
87enum Reachability {
88    /// The block is reachable from the function start only through the goto at the end of the given
89    /// block.
90    FromSingleGoto(BlockId),
91    /// The block is reachable from the function start after const-folding - just does not fit
92    /// `FromSingleGoto`.
93    Any,
94}
95
96/// Performs constant folding on the lowered program.
97/// The optimization only works when the blocks are topologically sorted.
98pub fn const_folding<'db>(
99    db: &'db dyn Database,
100    function_id: ConcreteFunctionWithBodyId<'db>,
101    lowered: &mut Lowered<'db>,
102) {
103    if lowered.blocks.is_empty() {
104        return;
105    }
106
107    // Note that we can keep the var_info across blocks because the lowering
108    // is in static single assignment form.
109    let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
110
111    if ctx.should_skip_const_folding(db) {
112        return;
113    }
114
115    for block_id in (0..lowered.blocks.len()).map(BlockId) {
116        if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
117            continue;
118        }
119
120        let block = &mut lowered.blocks[block_id];
121        for stmt in block.statements.iter_mut() {
122            ctx.visit_statement(stmt);
123        }
124        ctx.visit_block_end(block_id, block);
125    }
126}
127
128pub struct ConstFoldingContext<'db, 'mt> {
129    /// The used database.
130    db: &'db dyn Database,
131    /// The variables arena, mostly used to get the type of variables.
132    pub variables: &'mt mut VariableArena<'db>,
133    /// The accumulated information about the const values of variables.
134    var_info: UnorderedHashMap<VariableId, VarInfo<'db>>,
135    /// The libfunc information.
136    libfunc_info: &'db ConstFoldingLibfuncInfo<'db>,
137    /// The specialization base of the caller function (or the caller if the function is not
138    /// specialized).
139    caller_base: ConcreteFunctionWithBodyId<'db>,
140    /// Reachability of blocks from the function start.
141    /// If the block is not in this map, it means that it is unreachable (or that it was already
142    /// visited and its reachability won't be checked again).
143    reachability: UnorderedHashMap<BlockId, Reachability>,
144    /// Additional statements to add to the block.
145    additional_stmts: Vec<Statement<'db>>,
146}
147
148impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
149    pub fn new(
150        db: &'db dyn Database,
151        function_id: ConcreteFunctionWithBodyId<'db>,
152        variables: &'mt mut VariableArena<'db>,
153    ) -> Self {
154        let caller_base = match function_id.long(db) {
155            ConcreteFunctionWithBodyLongId::Specialized(specialized_func) => specialized_func.base,
156            _ => function_id,
157        };
158
159        Self {
160            db,
161            var_info: UnorderedHashMap::default(),
162            variables,
163            libfunc_info: priv_const_folding_info(db),
164            caller_base,
165            reachability: UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]),
166            additional_stmts: vec![],
167        }
168    }
169
170    /// Determines if a block is reachable from the function start and propagates constant values
171    /// when the block is reachable via a single goto statement.
172    pub fn visit_block_start<'r, 'get>(
173        &'r mut self,
174        block_id: BlockId,
175        get_block: impl FnOnce(BlockId) -> &'get Block<'db>,
176    ) -> bool
177    where
178        'db: 'get,
179    {
180        let Some(reachability) = self.reachability.remove(&block_id) else {
181            return false;
182        };
183        match reachability {
184            Reachability::Any => {}
185            Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
186                BlockEnd::Goto(_, remapping) => {
187                    for (dst, src) in remapping.iter() {
188                        if let Some(v) = self.var_info.get(&src.var_id) {
189                            self.var_info.insert(*dst, v.clone());
190                        }
191                    }
192                }
193                _ => unreachable!("Expected a goto end"),
194            },
195        }
196        true
197    }
198
199    /// Processes a statement and applies the constant folding optimizations.
200    ///
201    /// This method performs the following operations:
202    /// - Updates the `var_info` map with constant values of variables
203    /// - Replace the input statement with optimized versions when possible
204    /// - Updates `self.additional_stmts` with statements that need to be added to the block.
205    ///
206    /// Note: `self.visit_block_end` must be called after processing all statements
207    /// in a block to actually add the additional statements.
208    pub fn visit_statement(&mut self, stmt: &mut Statement<'db>) {
209        self.maybe_replace_inputs(stmt.inputs_mut());
210        match stmt {
211            Statement::Const(StatementConst { value, output, boxed }) if *boxed => {
212                self.var_info.insert(*output, VarInfo::Box(VarInfo::Const(*value).into()));
213            }
214            Statement::Const(StatementConst { value, output, .. }) => match value.long(self.db) {
215                ConstValue::Int(..)
216                | ConstValue::Struct(..)
217                | ConstValue::Enum(..)
218                | ConstValue::NonZero(..) => {
219                    self.var_info.insert(*output, VarInfo::Const(*value));
220                }
221                ConstValue::Generic(_)
222                | ConstValue::ImplConstant(_)
223                | ConstValue::Var(..)
224                | ConstValue::Missing(_) => {}
225            },
226            Statement::Snapshot(stmt) => {
227                if let Some(info) = self.var_info.get(&stmt.input.var_id).cloned() {
228                    self.var_info.insert(stmt.original(), info.clone());
229                    self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info.into()));
230                }
231            }
232            Statement::Desnap(StatementDesnap { input, output }) => {
233                if let Some(VarInfo::Snapshot(info)) = self.var_info.get(&input.var_id) {
234                    self.var_info.insert(*output, info.as_ref().clone());
235                }
236            }
237            Statement::Call(call_stmt) => {
238                if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
239                    *stmt = updated_stmt;
240                } else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
241                    *stmt = updated_stmt;
242                }
243            }
244            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
245                let mut const_args = vec![];
246                let mut all_args = vec![];
247                let mut contains_info = false;
248                for input in inputs.iter() {
249                    let Some(info) = self.var_info.get(&input.var_id) else {
250                        all_args.push(var_info_if_copy(self.variables, *input));
251                        continue;
252                    };
253                    contains_info = true;
254                    if let VarInfo::Const(value) = info {
255                        const_args.push(*value);
256                    }
257                    all_args.push(Some(info.clone()));
258                }
259                if const_args.len() == inputs.len() {
260                    let value =
261                        ConstValue::Struct(const_args, self.variables[*output].ty).intern(self.db);
262                    self.var_info.insert(*output, VarInfo::Const(value));
263                } else if contains_info {
264                    self.var_info.insert(*output, VarInfo::Struct(all_args));
265                }
266            }
267            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
268                if let Some(info) = self.var_info.get(&input.var_id) {
269                    let (n_snapshots, info) = info.peel_snapshots();
270                    match info {
271                        VarInfo::Const(const_value) => {
272                            if let ConstValue::Struct(member_values, _) = const_value.long(self.db)
273                            {
274                                for (output, value) in zip_eq(outputs, member_values) {
275                                    self.var_info.insert(
276                                        *output,
277                                        VarInfo::Const(*value).wrap_with_snapshots(n_snapshots),
278                                    );
279                                }
280                            }
281                        }
282                        VarInfo::Struct(members) => {
283                            for (output, member) in zip_eq(outputs, members.clone()) {
284                                if let Some(member) = member {
285                                    self.var_info
286                                        .insert(*output, member.wrap_with_snapshots(n_snapshots));
287                                }
288                            }
289                        }
290                        _ => {}
291                    }
292                }
293            }
294            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
295                if let Some(VarInfo::Const(val)) = self.var_info.get(&input.var_id) {
296                    let value = ConstValue::Enum(*variant, *val).intern(self.db);
297                    self.var_info.insert(*output, VarInfo::Const(value));
298                }
299            }
300        }
301    }
302
303    /// Processes the block's end and incorporates additional statements into the block.
304    ///
305    /// This method handles the following tasks:
306    /// - Inserts the accumulated additional statements into the block.
307    /// - Converts match endings to goto when applicable.
308    /// - Updates self.reachability based on the block's ending.
309    pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut crate::Block<'db>) {
310        block.statements.splice(0..0, self.additional_stmts.drain(..));
311
312        match &mut block.end {
313            BlockEnd::Goto(_, remappings) => {
314                for (_, v) in remappings.iter_mut() {
315                    self.maybe_replace_input(v);
316                }
317            }
318            BlockEnd::Match { info } => {
319                self.maybe_replace_inputs(info.inputs_mut());
320                match info {
321                    MatchInfo::Enum(MatchEnumInfo { input, arms, location, .. }) => {
322                        if let Some(var_info) = self.var_info.get(&input.var_id) {
323                            let (n_snapshots, var_info) = var_info.peel_snapshots();
324                            // Checking whether we have actual const info on the enum.
325                            if let VarInfo::Const(const_value) = var_info
326                                && let ConstValue::Enum(variant, value) = const_value.long(self.db)
327                            {
328                                let arm = &arms[variant.idx];
329                                let output = arm.var_ids[0];
330                                if self.variables[input.var_id].info.droppable.is_ok()
331                                    && self.variables[output].info.copyable.is_ok()
332                                    && let Ok(mut ty) = value.ty(self.db)
333                                    && let Some(mut stmt) =
334                                        self.try_generate_const_statement(*value, output)
335                                {
336                                    // Adding snapshot taking statements for snapshots.
337                                    for _ in 0..n_snapshots {
338                                        let non_snap_var =
339                                            Variable::with_default_context(self.db, ty, *location);
340                                        ty = TypeLongId::Snapshot(ty).intern(self.db);
341                                        let ignored = self.variables.alloc(non_snap_var.clone());
342                                        let pre_snap = self.variables.alloc(non_snap_var);
343                                        stmt.outputs_mut()[0] = pre_snap;
344                                        block.statements.push(core::mem::replace(
345                                            &mut stmt,
346                                            Statement::Snapshot(StatementSnapshot::new(
347                                                VarUsage { var_id: pre_snap, location: *location },
348                                                ignored,
349                                                output,
350                                            )),
351                                        ));
352                                    }
353                                    block.statements.push(stmt);
354                                    block.end = BlockEnd::Goto(arm.block_id, Default::default());
355                                }
356                                // Propagating the const value information.
357                                self.var_info.insert(
358                                    output,
359                                    VarInfo::Const(*value).wrap_with_snapshots(n_snapshots),
360                                );
361                            }
362                        }
363                    }
364                    MatchInfo::Value(info) => {
365                        if let Some(value) =
366                            self.as_int(info.input.var_id).and_then(|x| x.to_usize())
367                            && let Some(arm) = info.arms.iter().find(|arm| {
368                                matches!(
369                                    &arm.arm_selector,
370                                    MatchArmSelector::Value(v) if v.value == value
371                                )
372                            })
373                        {
374                            // Create the variable that was previously introduced in match arm.
375                            block.statements.push(Statement::StructConstruct(
376                                StatementStructConstruct { inputs: vec![], output: arm.var_ids[0] },
377                            ));
378                            block.end = BlockEnd::Goto(arm.block_id, Default::default());
379                        }
380                    }
381                    MatchInfo::Extern(info) => {
382                        if let Some((extra_stmts, updated_end)) = self.handle_extern_block_end(info)
383                        {
384                            block.statements.extend(extra_stmts);
385                            block.end = updated_end;
386                        }
387                    }
388                }
389            }
390            BlockEnd::Return(inputs, _) => self.maybe_replace_inputs(inputs),
391            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
392        }
393        match &block.end {
394            BlockEnd::Goto(dst_block_id, _) => {
395                match self.reachability.entry(*dst_block_id) {
396                    std::collections::hash_map::Entry::Occupied(mut e) => {
397                        e.insert(Reachability::Any)
398                    }
399                    std::collections::hash_map::Entry::Vacant(e) => {
400                        *e.insert(Reachability::FromSingleGoto(block_id))
401                    }
402                };
403            }
404            BlockEnd::Match { info } => {
405                for arm in info.arms() {
406                    assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
407                }
408            }
409            BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
410        }
411    }
412
413    /// Handles a statement call.
414    ///
415    /// Returns None if no additional changes are required.
416    /// If changes are required, returns an updated statement (to override the current
417    /// statement).
418    /// May add additional statements to `self.additional_stmts` if just replacing the current
419    /// statement is not enough.
420    fn handle_statement_call(&mut self, stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
421        let db = self.db;
422        if stmt.function == self.panic_with_felt252 {
423            let val = self.as_const(stmt.inputs[0].var_id)?;
424            stmt.inputs.clear();
425            stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
426                .concretize(db, vec![GenericArgumentId::Constant(val)])
427                .lowered(db);
428            return None;
429        } else if stmt.function == self.panic_with_byte_array && !flag_unsafe_panic(db) {
430            let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
431            let bytearray = try_extract_matches!(snap, VarInfo::Snapshot)?;
432            let [
433                Some(VarInfo::Array(data)),
434                Some(VarInfo::Const(pending_word)),
435                Some(VarInfo::Const(pending_len)),
436            ] = &try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
437            else {
438                return None;
439            };
440            let mut panic_data =
441                vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
442            for word in data {
443                let Some(VarInfo::Const(word)) = word else {
444                    return None;
445                };
446                panic_data.push(word.long(db).to_int()?.clone());
447            }
448            panic_data.extend([
449                pending_word.long(db).to_int()?.clone(),
450                pending_len.long(db).to_int()?.clone(),
451            ]);
452            let felt252_ty = self.felt252;
453            let location = stmt.location;
454            let new_var = |ty| Variable::with_default_context(db, ty, location);
455            let as_usage = |var_id| VarUsage { var_id, location };
456            let array_fn = |extern_id| {
457                let args = vec![GenericArgumentId::Type(felt252_ty)];
458                GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
459            };
460            let call_stmt = |function, inputs, outputs| {
461                let with_coupon = false;
462                Statement::Call(StatementCall { function, inputs, with_coupon, outputs, location })
463            };
464            let arr_var = new_var(corelib::core_array_felt252_ty(db));
465            let mut arr = self.variables.alloc(arr_var.clone());
466            self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
467            let felt252_var = new_var(felt252_ty);
468            let arr_append_fn = array_fn(self.array_append);
469            for word in panic_data {
470                let to_append = self.variables.alloc(felt252_var.clone());
471                let new_arr = self.variables.alloc(arr_var.clone());
472                self.additional_stmts.push(Statement::Const(StatementConst::new_flat(
473                    ConstValue::Int(word, felt252_ty).intern(db),
474                    to_append,
475                )));
476                self.additional_stmts.push(call_stmt(
477                    arr_append_fn,
478                    vec![as_usage(arr), as_usage(to_append)],
479                    vec![new_arr],
480                ));
481                arr = new_arr;
482            }
483            let panic_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "Panic"), vec![]);
484            let panic_var = self.variables.alloc(new_var(panic_ty));
485            self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
486                inputs: vec![],
487                output: panic_var,
488            }));
489            return Some(Statement::StructConstruct(StatementStructConstruct {
490                inputs: vec![as_usage(panic_var), as_usage(arr)],
491                output: stmt.outputs[0],
492            }));
493        }
494        let (id, _generic_args) = stmt.function.get_extern(db)?;
495        if id == self.felt_sub {
496            // (a - 0) can be replaced by a.
497            let val = self.as_int(stmt.inputs[1].var_id)?;
498            if val.is_zero() {
499                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
500            }
501            None
502        } else if self.wide_mul_fns.contains(&id) {
503            let lhs = self.as_int_ex(stmt.inputs[0].var_id);
504            let rhs = self.as_int(stmt.inputs[1].var_id);
505            let output = stmt.outputs[0];
506            if lhs.map(|(v, _)| v.is_zero()).unwrap_or_default()
507                || rhs.map(Zero::is_zero).unwrap_or_default()
508            {
509                return Some(self.propagate_zero_and_get_statement(output));
510            }
511            let (lhs, nz_ty) = lhs?;
512            Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0], nz_ty))
513        } else if id == self.bounded_int_add || id == self.bounded_int_sub {
514            let lhs = self.as_int(stmt.inputs[0].var_id)?;
515            let rhs = self.as_int(stmt.inputs[1].var_id)?;
516            let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
517            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0], false))
518        } else if self.div_rem_fns.contains(&id) {
519            let lhs = self.as_int(stmt.inputs[0].var_id);
520            if lhs.map(Zero::is_zero).unwrap_or_default() {
521                let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
522                self.additional_stmts.push(additional_stmt);
523                return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
524            }
525            let rhs = self.as_int(stmt.inputs[1].var_id)?;
526            let (q, r) = lhs?.div_rem(rhs);
527            let q_output = stmt.outputs[0];
528            let q_value = ConstValue::Int(q, self.variables[q_output].ty).intern(db);
529            self.var_info.insert(q_output, VarInfo::Const(q_value));
530            let r_output = stmt.outputs[1];
531            let r_value = ConstValue::Int(r, self.variables[r_output].ty).intern(db);
532            self.var_info.insert(r_output, VarInfo::Const(r_value));
533            self.additional_stmts
534                .push(Statement::Const(StatementConst::new_flat(r_value, r_output)));
535            Some(Statement::Const(StatementConst::new_flat(q_value, q_output)))
536        } else if id == self.storage_base_address_from_felt252 {
537            let input_var = stmt.inputs[0].var_id;
538            if let Some(const_value) = self.as_const(input_var)
539                && let ConstValue::Int(val, ty) = const_value.long(db)
540            {
541                stmt.inputs.clear();
542                let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
543                stmt.function =
544                    self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
545            }
546            None
547        } else if id == self.into_box {
548            let input = stmt.inputs[0];
549            let var_info = self.var_info.get(&input.var_id);
550            let const_value = match var_info {
551                Some(VarInfo::Const(val)) => Some(*val),
552                Some(VarInfo::Snapshot(info)) => {
553                    try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
554                }
555                _ => None,
556            };
557            let var_info = var_info.cloned().or_else(|| var_info_if_copy(self.variables, input))?;
558            self.var_info.insert(stmt.outputs[0], VarInfo::Box(var_info.into()));
559            Some(Statement::Const(StatementConst::new_boxed(const_value?, stmt.outputs[0])))
560        } else if id == self.unbox {
561            if let VarInfo::Box(inner) = self.var_info.get(&stmt.inputs[0].var_id)? {
562                let inner = inner.as_ref().clone();
563                if let VarInfo::Const(inner) =
564                    self.var_info.entry(stmt.outputs[0]).insert_entry(inner).get()
565                {
566                    return Some(Statement::Const(StatementConst::new_flat(
567                        *inner,
568                        stmt.outputs[0],
569                    )));
570                }
571            }
572            None
573        } else if self.upcast_fns.contains(&id) {
574            let int_value = self.as_int(stmt.inputs[0].var_id)?;
575            let output = stmt.outputs[0];
576            let value = ConstValue::Int(int_value.clone(), self.variables[output].ty).intern(db);
577            self.var_info.insert(output, VarInfo::Const(value));
578            Some(Statement::Const(StatementConst::new_flat(value, output)))
579        } else if id == self.array_new {
580            self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]));
581            None
582        } else if id == self.array_append {
583            let mut var_infos =
584                if let VarInfo::Array(var_infos) = self.var_info.get(&stmt.inputs[0].var_id)? {
585                    var_infos.clone()
586                } else {
587                    return None;
588                };
589            let appended = stmt.inputs[1];
590            var_infos.push(match self.var_info.get(&appended.var_id) {
591                Some(var_info) => Some(var_info.clone()),
592                None => var_info_if_copy(self.variables, appended),
593            });
594            self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos));
595            None
596        } else if id == self.array_len {
597            let info = self.var_info.get(&stmt.inputs[0].var_id)?;
598            let desnapped = try_extract_matches!(info, VarInfo::Snapshot)?;
599            let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
600            Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0], false))
601        } else {
602            None
603        }
604    }
605
606    /// Tries to specialize the call.
607    /// Returns The specialized call statement if it was specialized, or None otherwise.
608    ///
609    /// Specialization occurs only if `priv_should_specialize` returns true.
610    /// Additionally specialization of a callee the with the same base as the caller is currently
611    /// not supported.
612    fn try_specialize_call(&self, call_stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
613        if call_stmt.with_coupon {
614            return None;
615        }
616        // No specialization when avoiding inlining.
617        if matches!(self.db.optimization_config().inlining_strategy, InliningStrategy::Avoid) {
618            return None;
619        }
620
621        let Ok(Some(mut base)) = call_stmt.function.body(self.db) else {
622            return None;
623        };
624
625        if self.db.priv_never_inline(base).ok()? {
626            return None;
627        }
628
629        // Avoid specializing with the same base as the current function as it may lead to infinite
630        // specialization.
631        if base == self.caller_base {
632            return None;
633        }
634        if call_stmt.inputs.iter().all(|arg| self.var_info.get(&arg.var_id).is_none()) {
635            // No const inputs
636            return None;
637        }
638
639        let mut const_args = vec![];
640        let mut new_args = vec![];
641        for arg in &call_stmt.inputs {
642            if let Some(var_info) = self.var_info.get(&arg.var_id)
643                && let Some(c) =
644                    self.try_get_specialization_arg(var_info.clone(), self.variables[arg.var_id].ty)
645            {
646                const_args.push(Some(c.clone()));
647                continue;
648            }
649            const_args.push(None);
650            new_args.push(*arg);
651        }
652
653        if new_args.len() == call_stmt.inputs.len() {
654            // No argument was assigned -> no specialization.
655            return None;
656        }
657
658        if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
659            base.long(self.db)
660        {
661            // Canonicalize the specialization rather than adding a specialization of a specialized
662            // function.
663            base = specialized_function.base;
664            let mut new_args_iter = const_args.into_iter();
665            const_args = specialized_function.args.to_vec();
666            for arg in &mut const_args {
667                if arg.is_none() {
668                    *arg = new_args_iter.next().unwrap_or_default();
669                }
670            }
671        }
672        let specialized = SpecializedFunction { base, args: const_args.into() };
673        let specialized_func_id =
674            ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);
675
676        if self.db.priv_should_specialize(specialized_func_id) == Ok(false) {
677            return None;
678        }
679
680        Some(Statement::Call(StatementCall {
681            function: specialized_func_id.function_id(self.db).unwrap(),
682            inputs: new_args,
683            with_coupon: call_stmt.with_coupon,
684            outputs: std::mem::take(&mut call_stmt.outputs),
685            location: call_stmt.location,
686        }))
687    }
688
689    /// Adds `value` as a const to `var_info` and return a const statement for it.
690    fn propagate_const_and_get_statement(
691        &mut self,
692        value: BigInt,
693        output: VariableId,
694        nz_ty: bool,
695    ) -> Statement<'db> {
696        let ty = self.variables[output].ty;
697        let value = if nz_ty {
698            let inner_ty =
699                try_extract_nz_wrapped_type(self.db, ty).expect("Expected a non-zero type");
700            ConstValue::NonZero(ConstValue::Int(value, inner_ty).intern(self.db)).intern(self.db)
701        } else {
702            ConstValue::Int(value, ty).intern(self.db)
703        };
704        self.var_info.insert(output, VarInfo::Const(value));
705        Statement::Const(StatementConst::new_flat(value, output))
706    }
707
708    /// Adds 0 const to `var_info` and return a const statement for it.
709    fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement<'db> {
710        self.propagate_const_and_get_statement(BigInt::zero(), output, false)
711    }
712
713    /// Returns a statement that introduces the requested value into `output`, or None if fails.
714    fn try_generate_const_statement(
715        &self,
716        value: ConstValueId<'db>,
717        output: VariableId,
718    ) -> Option<Statement<'db>> {
719        if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
720            Some(Statement::Const(StatementConst::new_flat(value, output)))
721        } else if matches!(value.long(self.db), ConstValue::Struct(members, _) if members.is_empty())
722        {
723            // Handling const empty structs - which are not supported in sierra-gen.
724            Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
725        } else {
726            None
727        }
728    }
729
730    /// Handles the end of an extern block.
731    /// Returns None if no additional changes are required.
732    /// If changes are required, returns a possible additional const-statement to the block, as well
733    /// as an updated block end.
734    fn handle_extern_block_end(
735        &mut self,
736        info: &mut MatchExternInfo<'db>,
737    ) -> Option<(Vec<Statement<'db>>, BlockEnd<'db>)> {
738        let db = self.db;
739        let (id, generic_args) = info.function.get_extern(db)?;
740        if self.nz_fns.contains(&id) {
741            let val = self.as_const(info.inputs[0].var_id)?;
742            let is_zero = match val.long(db) {
743                ConstValue::Int(v, _) => v.is_zero(),
744                ConstValue::Struct(s, _) => s.iter().all(|v| {
745                    v.long(db).to_int().expect("Expected ConstValue::Int for size").is_zero()
746                }),
747                _ => unreachable!(),
748            };
749            Some(if is_zero {
750                (vec![], BlockEnd::Goto(info.arms[0].block_id, Default::default()))
751            } else {
752                let arm = &info.arms[1];
753                let nz_var = arm.var_ids[0];
754                let nz_val = ConstValue::NonZero(val).intern(db);
755                self.var_info.insert(nz_var, VarInfo::Const(nz_val));
756                (
757                    vec![Statement::Const(StatementConst::new_flat(nz_val, nz_var))],
758                    BlockEnd::Goto(arm.block_id, Default::default()),
759                )
760            })
761        } else if self.eq_fns.contains(&id) {
762            let lhs = self.as_int(info.inputs[0].var_id);
763            let rhs = self.as_int(info.inputs[1].var_id);
764            if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
765                || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
766            {
767                let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
768                let var = &self.variables[nz_input.var_id].clone();
769                let function = self.type_value_ranges.get(&var.ty)?.is_zero;
770                let unused_nz_var = Variable::with_default_context(
771                    db,
772                    corelib::core_nonzero_ty(db, var.ty),
773                    var.location,
774                );
775                let unused_nz_var = self.variables.alloc(unused_nz_var);
776                return Some((
777                    vec![],
778                    BlockEnd::Match {
779                        info: MatchInfo::Extern(MatchExternInfo {
780                            function,
781                            inputs: vec![nz_input],
782                            arms: vec![
783                                MatchArm {
784                                    arm_selector: MatchArmSelector::VariantId(
785                                        corelib::jump_nz_zero_variant(db, var.ty),
786                                    ),
787                                    block_id: info.arms[1].block_id,
788                                    var_ids: vec![],
789                                },
790                                MatchArm {
791                                    arm_selector: MatchArmSelector::VariantId(
792                                        corelib::jump_nz_nonzero_variant(db, var.ty),
793                                    ),
794                                    block_id: info.arms[0].block_id,
795                                    var_ids: vec![unused_nz_var],
796                                },
797                            ],
798                            location: info.location,
799                        }),
800                    },
801                ));
802            }
803            Some((
804                vec![],
805                BlockEnd::Goto(
806                    info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
807                    Default::default(),
808                ),
809            ))
810        } else if self.uadd_fns.contains(&id)
811            || self.usub_fns.contains(&id)
812            || self.diff_fns.contains(&id)
813            || self.iadd_fns.contains(&id)
814            || self.isub_fns.contains(&id)
815        {
816            let rhs = self.as_int(info.inputs[1].var_id);
817            let lhs = self.as_int(info.inputs[0].var_id);
818            if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
819                let ty = self.variables[info.arms[0].var_ids[0]].ty;
820                let range = &self.type_value_ranges.get(&ty)?.range;
821                let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
822                    lhs + rhs
823                } else {
824                    lhs - rhs
825                };
826                let (arm_index, value) = match range.normalized(value) {
827                    NormalizedResult::InRange(value) => (0, value),
828                    NormalizedResult::Under(value) => (1, value),
829                    NormalizedResult::Over(value) => (
830                        if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
831                            2
832                        } else {
833                            1
834                        },
835                        value,
836                    ),
837                };
838                let arm = &info.arms[arm_index];
839                let actual_output = arm.var_ids[0];
840                let value = ConstValue::Int(value, ty).intern(db);
841                self.var_info.insert(actual_output, VarInfo::Const(value));
842                return Some((
843                    vec![Statement::Const(StatementConst::new_flat(value, actual_output))],
844                    BlockEnd::Goto(arm.block_id, Default::default()),
845                ));
846            }
847            if let Some(rhs) = rhs {
848                if rhs.is_zero() && !self.diff_fns.contains(&id) {
849                    let arm = &info.arms[0];
850                    self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]));
851                    return Some((vec![], BlockEnd::Goto(arm.block_id, Default::default())));
852                }
853                if rhs.is_one() && !self.diff_fns.contains(&id) {
854                    let ty = self.variables[info.arms[0].var_ids[0]].ty;
855                    let ty_info = self.type_value_ranges.get(&ty)?;
856                    let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
857                        ty_info.inc?
858                    } else {
859                        ty_info.dec?
860                    };
861                    let enum_ty = function.signature(db).ok()?.return_type;
862                    let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
863                        enum_ty.long(db)
864                    else {
865                        return None;
866                    };
867                    let result = self.variables.alloc(Variable::with_default_context(
868                        db,
869                        function.signature(db).unwrap().return_type,
870                        info.location,
871                    ));
872                    return Some((
873                        vec![Statement::Call(StatementCall {
874                            function,
875                            inputs: vec![info.inputs[0]],
876                            with_coupon: false,
877                            outputs: vec![result],
878                            location: info.location,
879                        })],
880                        BlockEnd::Match {
881                            info: MatchInfo::Enum(MatchEnumInfo {
882                                concrete_enum_id: *concrete_enum_id,
883                                input: VarUsage { var_id: result, location: info.location },
884                                arms: core::mem::take(&mut info.arms),
885                                location: info.location,
886                            }),
887                        },
888                    ));
889                }
890            }
891            if let Some(lhs) = lhs
892                && lhs.is_zero()
893                && (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id))
894            {
895                let arm = &info.arms[0];
896                self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]));
897                return Some((vec![], BlockEnd::Goto(arm.block_id, Default::default())));
898            }
899            None
900        } else if let Some(reversed) = self.downcast_fns.get(&id) {
901            let range = |ty: TypeId<'_>| {
902                Some(if let Some(ti) = self.type_value_ranges.get(&ty) {
903                    ti.range.clone()
904                } else {
905                    let (min, max) = corelib::try_extract_bounded_int_type_ranges(db, ty)?;
906                    TypeRange { min, max }
907                })
908            };
909            let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
910            let input_var = info.inputs[0].var_id;
911            let success_output = info.arms[success_arm].var_ids[0];
912            let out_ty = self.variables[success_output].ty;
913            let out_range = range(out_ty)?;
914            let Some(value) = self.as_int(input_var) else {
915                let in_ty = self.variables[input_var].ty;
916                let in_range = range(in_ty)?;
917                return if in_range.min < out_range.min || in_range.max > out_range.max {
918                    None
919                } else {
920                    let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
921                    let function = db.core_info().upcast_fn.concretize(db, generic_args);
922                    return Some((
923                        vec![Statement::Call(StatementCall {
924                            function: function.lowered(db),
925                            inputs: vec![info.inputs[0]],
926                            with_coupon: false,
927                            outputs: vec![success_output],
928                            location: info.location,
929                        })],
930                        BlockEnd::Goto(info.arms[success_arm].block_id, Default::default()),
931                    ));
932                };
933            };
934            Some(if let NormalizedResult::InRange(value) = out_range.normalized(value.clone()) {
935                let value = ConstValue::Int(value, out_ty).intern(db);
936                self.var_info.insert(success_output, VarInfo::Const(value));
937                (
938                    vec![Statement::Const(StatementConst::new_flat(value, success_output))],
939                    BlockEnd::Goto(info.arms[success_arm].block_id, Default::default()),
940                )
941            } else {
942                (vec![], BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default()))
943            })
944        } else if id == self.bounded_int_constrain {
945            let input_var = info.inputs[0].var_id;
946            let (value, nz_ty) = self.as_int_ex(input_var)?;
947            let generic_arg = generic_args[1];
948            let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
949                .long(db)
950                .to_int()
951                .expect("Expected ConstValue::Int for size");
952            let arm_idx = if value < constrain_value { 0 } else { 1 };
953            let output = info.arms[arm_idx].var_ids[0];
954            Some((
955                vec![self.propagate_const_and_get_statement(value.clone(), output, nz_ty)],
956                BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()),
957            ))
958        } else if id == self.array_get {
959            let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
960            if let Some(VarInfo::Snapshot(arr_info)) = self.var_info.get(&info.inputs[0].var_id)
961                && let VarInfo::Array(infos) = arr_info.as_ref()
962            {
963                match infos.get(index) {
964                    Some(Some(output_var_info)) => {
965                        let arm = &info.arms[0];
966                        let output_var_info = output_var_info.clone();
967                        let box_info =
968                            VarInfo::Box(VarInfo::Snapshot(output_var_info.clone().into()).into());
969                        self.var_info.insert(arm.var_ids[0], box_info);
970                        if let VarInfo::Const(value) = output_var_info {
971                            let value_ty = value.ty(db).ok()?;
972                            let value_box_ty = corelib::core_box_ty(db, value_ty);
973                            let location = info.location;
974                            let boxed_var =
975                                Variable::with_default_context(db, value_box_ty, location);
976                            let boxed = self.variables.alloc(boxed_var.clone());
977                            let unused_boxed = self.variables.alloc(boxed_var);
978                            let snapped = self.variables.alloc(Variable::with_default_context(
979                                db,
980                                TypeLongId::Snapshot(value_box_ty).intern(db),
981                                location,
982                            ));
983                            return Some((
984                                vec![
985                                    Statement::Const(StatementConst::new_boxed(value, boxed)),
986                                    Statement::Snapshot(StatementSnapshot {
987                                        input: VarUsage { var_id: boxed, location },
988                                        outputs: [unused_boxed, snapped],
989                                    }),
990                                    Statement::Call(StatementCall {
991                                        function: self
992                                            .box_forward_snapshot
993                                            .concretize(db, vec![GenericArgumentId::Type(value_ty)])
994                                            .lowered(db),
995                                        inputs: vec![VarUsage { var_id: snapped, location }],
996                                        with_coupon: false,
997                                        outputs: vec![arm.var_ids[0]],
998                                        location: info.location,
999                                    }),
1000                                ],
1001                                BlockEnd::Goto(arm.block_id, Default::default()),
1002                            ));
1003                        }
1004                    }
1005                    None => {
1006                        return Some((
1007                            vec![],
1008                            BlockEnd::Goto(info.arms[1].block_id, Default::default()),
1009                        ));
1010                    }
1011                    Some(None) => {}
1012                }
1013            }
1014            if index.is_zero()
1015                && let [success, failure] = info.arms.as_mut_slice()
1016            {
1017                let arr = info.inputs[0].var_id;
1018                let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
1019                let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
1020                info.inputs.truncate(1);
1021                info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
1022                    .concretize(db, generic_args)
1023                    .lowered(db);
1024                success.var_ids.insert(0, unused_arr_output0);
1025                failure.var_ids.insert(0, unused_arr_output1);
1026            }
1027            None
1028        } else if id == self.array_pop_front {
1029            let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)? else {
1030                return None;
1031            };
1032            if let Some(first) = var_infos.first() {
1033                if let Some(first) = first.as_ref().cloned() {
1034                    let arm = &info.arms[0];
1035                    self.var_info.insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()));
1036                    self.var_info.insert(arm.var_ids[1], VarInfo::Box(first.into()));
1037                }
1038                None
1039            } else {
1040                let arm = &info.arms[1];
1041                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1042                Some((
1043                    vec![],
1044                    BlockEnd::Goto(
1045                        arm.block_id,
1046                        VarRemapping {
1047                            remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1048                        },
1049                    ),
1050                ))
1051            }
1052        } else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
1053            let var_info = self.var_info.get(&info.inputs[0].var_id)?;
1054            let desnapped = try_extract_matches!(var_info, VarInfo::Snapshot)?;
1055            let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
1056            // TODO(orizi): Propagate success values as well.
1057            if element_var_infos.is_empty() {
1058                let arm = &info.arms[1];
1059                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1060                Some((
1061                    vec![],
1062                    BlockEnd::Goto(
1063                        arm.block_id,
1064                        VarRemapping {
1065                            remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1066                        },
1067                    ),
1068                ))
1069            } else {
1070                None
1071            }
1072        } else {
1073            None
1074        }
1075    }
1076
1077    /// Returns the const value of a variable if it exists.
1078    fn as_const(&self, var_id: VariableId) -> Option<ConstValueId<'db>> {
1079        try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const).copied()
1080    }
1081
1082    /// Return the const value as an int if it exists and is an integer, additionally, if it is of a
1083    /// non-zero type.
1084    fn as_int_ex(&self, var_id: VariableId) -> Option<(&BigInt, bool)> {
1085        match self.as_const(var_id)?.long(self.db) {
1086            ConstValue::Int(value, _) => Some((value, false)),
1087            ConstValue::NonZero(const_value) => {
1088                if let ConstValue::Int(value, _) = const_value.long(self.db) {
1089                    Some((value, true))
1090                } else {
1091                    None
1092                }
1093            }
1094            _ => None,
1095        }
1096    }
1097
1098    /// Return the const value as a int if it exists and is an integer.
1099    fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
1100        Some(self.as_int_ex(var_id)?.0)
1101    }
1102
1103    /// Replaces the inputs in place if they are in the var_info map.
1104    fn maybe_replace_inputs(&self, inputs: &mut [VarUsage<'db>]) {
1105        for input in inputs {
1106            self.maybe_replace_input(input);
1107        }
1108    }
1109
1110    /// Replaces the input in place if it is in the var_info map.
1111    fn maybe_replace_input(&self, input: &mut VarUsage<'db>) {
1112        if let Some(VarInfo::Var(new_var)) = self.var_info.get(&input.var_id) {
1113            *input = *new_var;
1114        }
1115    }
1116
1117    /// Given a var_info and its type, return the corresponding specialization argument, if it
1118    /// exists.
1119    fn try_get_specialization_arg(
1120        &self,
1121        var_info: VarInfo<'db>,
1122        ty: TypeId<'db>,
1123    ) -> Option<SpecializationArg<'db>> {
1124        if self.db.type_size_info(ty).ok()? == TypeSizeInformation::ZeroSized {
1125            // Skip zero-sized constants as they are not supported in sierra-gen.
1126            return None;
1127        }
1128
1129        match var_info {
1130            VarInfo::Const(value) => Some(SpecializationArg::Const { value, boxed: false }),
1131            VarInfo::Box(info) => try_extract_matches!(info.as_ref(), VarInfo::Const)
1132                .map(|value| SpecializationArg::Const { value: *value, boxed: true }),
1133            VarInfo::Snapshot(info) => {
1134                let desnap_ty = *extract_matches!(ty.long(self.db), TypeLongId::Snapshot);
1135                Some(SpecializationArg::Snapshot(
1136                    self.try_get_specialization_arg(info.as_ref().clone(), desnap_ty)?.into(),
1137                ))
1138            }
1139            VarInfo::Array(infos) => {
1140                let TypeLongId::Concrete(concrete_ty) = ty.long(self.db) else {
1141                    unreachable!("Expected a concrete type");
1142                };
1143                let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
1144                else {
1145                    unreachable!("Expected a single type generic argument");
1146                };
1147                Some(SpecializationArg::Array(
1148                    *inner_ty,
1149                    infos
1150                        .into_iter()
1151                        .map(|info| self.try_get_specialization_arg(info?, *inner_ty))
1152                        .collect::<Option<Vec<_>>>()?,
1153                ))
1154            }
1155            VarInfo::Struct(infos) => {
1156                let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
1157                    ty.long(self.db)
1158                else {
1159                    // TODO(ilya): Support closures and fixed size arrays.
1160                    return None;
1161                };
1162
1163                let members = self.db.concrete_struct_members(*concrete_struct).unwrap();
1164                let struct_args = zip_eq(members.values(), infos)
1165                    .map(|(member, opt_var_info)| {
1166                        self.try_get_specialization_arg(opt_var_info?.clone(), member.ty)
1167                    })
1168                    .collect::<Option<Vec<_>>>()?;
1169
1170                Some(SpecializationArg::Struct(struct_args))
1171            }
1172            _ => None,
1173        }
1174    }
1175
1176    /// Returns true if const-folding should be skipped for the current function.
1177    pub fn should_skip_const_folding(&self, db: &'db dyn Database) -> bool {
1178        if db.optimization_config().skip_const_folding {
1179            return true;
1180        }
1181
1182        // Skipping const-folding for `panic_with_const_felt252` - to avoid replacing a call to
1183        // `panic_with_felt252` with `panic_with_const_felt252` and causing accidental recursion.
1184        if self.caller_base.base_semantic_function(db).generic_function(db)
1185            == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
1186        {
1187            return true;
1188        }
1189        false
1190    }
1191}
1192
1193/// Returns a `VarInfo` of a variable only if it is copyable.
1194fn var_info_if_copy<'db>(
1195    variables: &VariableArena<'db>,
1196    input: VarUsage<'db>,
1197) -> Option<VarInfo<'db>> {
1198    variables[input.var_id].info.copyable.is_ok().then_some(VarInfo::Var(input))
1199}
1200
1201/// Internal query for the libfuncs information required for const folding.
1202#[salsa::tracked(returns(ref))]
1203fn priv_const_folding_info<'db>(
1204    db: &'db dyn Database,
1205) -> crate::optimizations::const_folding::ConstFoldingLibfuncInfo<'db> {
1206    ConstFoldingLibfuncInfo::new(db)
1207}
1208
1209/// Holds static information about libfuncs required for the optimization.
1210#[derive(Debug, PartialEq, Eq, salsa::Update)]
1211pub struct ConstFoldingLibfuncInfo<'db> {
1212    /// The `felt252_sub` libfunc.
1213    felt_sub: ExternFunctionId<'db>,
1214    /// The `into_box` libfunc.
1215    into_box: ExternFunctionId<'db>,
1216    /// The `unbox` libfunc.
1217    unbox: ExternFunctionId<'db>,
1218    /// The `box_forward_snapshot` libfunc.
1219    box_forward_snapshot: GenericFunctionId<'db>,
1220    /// The set of functions that check if numbers are equal.
1221    eq_fns: OrderedHashSet<ExternFunctionId<'db>>,
1222    /// The set of functions to add unsigned ints.
1223    uadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
1224    /// The set of functions to subtract unsigned ints.
1225    usub_fns: OrderedHashSet<ExternFunctionId<'db>>,
1226    /// The set of functions to get the difference of signed ints.
1227    diff_fns: OrderedHashSet<ExternFunctionId<'db>>,
1228    /// The set of functions to add signed ints.
1229    iadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
1230    /// The set of functions to subtract signed ints.
1231    isub_fns: OrderedHashSet<ExternFunctionId<'db>>,
1232    /// The set of functions to multiply integers.
1233    wide_mul_fns: OrderedHashSet<ExternFunctionId<'db>>,
1234    /// The set of functions to divide and get the remainder of integers.
1235    div_rem_fns: OrderedHashSet<ExternFunctionId<'db>>,
1236    /// The `bounded_int_add` libfunc.
1237    bounded_int_add: ExternFunctionId<'db>,
1238    /// The `bounded_int_sub` libfunc.
1239    bounded_int_sub: ExternFunctionId<'db>,
1240    /// The `bounded_int_constrain` libfunc.
1241    bounded_int_constrain: ExternFunctionId<'db>,
1242    /// The `array_get` libfunc.
1243    array_get: ExternFunctionId<'db>,
1244    /// The `array_snapshot_pop_front` libfunc.
1245    array_snapshot_pop_front: ExternFunctionId<'db>,
1246    /// The `array_snapshot_pop_back` libfunc.
1247    array_snapshot_pop_back: ExternFunctionId<'db>,
1248    /// The `array_len` libfunc.
1249    array_len: ExternFunctionId<'db>,
1250    /// The `array_new` libfunc.
1251    array_new: ExternFunctionId<'db>,
1252    /// The `array_append` libfunc.
1253    array_append: ExternFunctionId<'db>,
1254    /// The `array_pop_front` libfunc.
1255    array_pop_front: ExternFunctionId<'db>,
1256    /// The `storage_base_address_from_felt252` libfunc.
1257    storage_base_address_from_felt252: ExternFunctionId<'db>,
1258    /// The `storage_base_address_const` libfunc.
1259    storage_base_address_const: GenericFunctionId<'db>,
1260    /// The `core::panic_with_felt252` function.
1261    panic_with_felt252: FunctionId<'db>,
1262    /// The `core::panic_with_const_felt252` function.
1263    pub panic_with_const_felt252: FreeFunctionId<'db>,
1264    /// The `core::panics::panic_with_byte_array` function.
1265    panic_with_byte_array: FunctionId<'db>,
1266    /// Type ranges.
1267    type_value_ranges: OrderedHashMap<TypeId<'db>, TypeInfo<'db>>,
1268    /// The info used for semantic const calculation.
1269    const_calculation_info: Arc<ConstCalcInfo<'db>>,
1270}
1271impl<'db> ConstFoldingLibfuncInfo<'db> {
1272    fn new(db: &'db dyn Database) -> Self {
1273        let core = ModuleHelper::core(db);
1274        let box_module = core.submodule("box");
1275        let integer_module = core.submodule("integer");
1276        let internal_module = core.submodule("internal");
1277        let bounded_int_module = internal_module.submodule("bounded_int");
1278        let num_module = internal_module.submodule("num");
1279        let array_module = core.submodule("array");
1280        let starknet_module = core.submodule("starknet");
1281        let storage_access_module = starknet_module.submodule("storage_access");
1282        let utypes = ["u8", "u16", "u32", "u64", "u128"];
1283        let itypes = ["i8", "i16", "i32", "i64", "i128"];
1284        let eq_fns = OrderedHashSet::<_>::from_iter(
1285            chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(&format!("{ty}_eq"))),
1286        );
1287        let uadd_fns = OrderedHashSet::<_>::from_iter(
1288            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_add"))),
1289        );
1290        let usub_fns = OrderedHashSet::<_>::from_iter(
1291            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_sub"))),
1292        );
1293        let diff_fns = OrderedHashSet::<_>::from_iter(
1294            itypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_diff"))),
1295        );
1296        let iadd_fns =
1297            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1298                integer_module.extern_function_id(&format!("{ty}_overflowing_add_impl"))
1299            }));
1300        let isub_fns =
1301            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1302                integer_module.extern_function_id(&format!("{ty}_overflowing_sub_impl"))
1303            }));
1304        let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
1305            [bounded_int_module.extern_function_id("bounded_int_mul")],
1306            ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
1307                .map(|ty| integer_module.extern_function_id(&format!("{ty}_wide_mul"))),
1308        ));
1309        let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
1310            [bounded_int_module.extern_function_id("bounded_int_div_rem")],
1311            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_safe_divmod"))),
1312        ));
1313        let type_value_ranges: OrderedHashMap<TypeId<'db>, TypeInfo<'db>> =
1314            OrderedHashMap::from_iter(
1315                [
1316                    ("u8", BigInt::ZERO, u8::MAX.into(), false, true),
1317                    ("u16", BigInt::ZERO, u16::MAX.into(), false, true),
1318                    ("u32", BigInt::ZERO, u32::MAX.into(), false, true),
1319                    ("u64", BigInt::ZERO, u64::MAX.into(), false, true),
1320                    ("u128", BigInt::ZERO, u128::MAX.into(), false, true),
1321                    ("u256", BigInt::ZERO, BigInt::from(1) << 256, false, false),
1322                    ("i8", i8::MIN.into(), i8::MAX.into(), true, true),
1323                    ("i16", i16::MIN.into(), i16::MAX.into(), true, true),
1324                    ("i32", i32::MIN.into(), i32::MAX.into(), true, true),
1325                    ("i64", i64::MIN.into(), i64::MAX.into(), true, true),
1326                    ("i128", i128::MIN.into(), i128::MAX.into(), true, true),
1327                ]
1328                .map(
1329                    |(ty_name, min, max, as_bounded_int, inc_dec): (
1330                        &'static str,
1331                        BigInt,
1332                        BigInt,
1333                        bool,
1334                        bool,
1335                    )| {
1336                        let ty =
1337                            corelib::get_core_ty_by_name(db, SmolStrId::from(db, ty_name), vec![]);
1338                        let is_zero = if as_bounded_int {
1339                            bounded_int_module.function_id(
1340                                "bounded_int_is_zero",
1341                                vec![GenericArgumentId::Type(ty)],
1342                            )
1343                        } else {
1344                            integer_module.function_id(
1345                                SmolStrId::from(db, format!("{ty_name}_is_zero")).long(db).as_str(),
1346                                vec![],
1347                            )
1348                        }
1349                        .lowered(db);
1350                        let (inc, dec) = if inc_dec {
1351                            (
1352                                Some(
1353                                    num_module
1354                                        .function_id(
1355                                            SmolStrId::from(db, format!("{ty_name}_inc"))
1356                                                .long(db)
1357                                                .as_str(),
1358                                            vec![],
1359                                        )
1360                                        .lowered(db),
1361                                ),
1362                                Some(
1363                                    num_module
1364                                        .function_id(
1365                                            SmolStrId::from(db, format!("{ty_name}_dec"))
1366                                                .long(db)
1367                                                .as_str(),
1368                                            vec![],
1369                                        )
1370                                        .lowered(db),
1371                                ),
1372                            )
1373                        } else {
1374                            (None, None)
1375                        };
1376                        let info = TypeInfo { range: TypeRange { min, max }, is_zero, inc, dec };
1377                        (ty, info)
1378                    },
1379                ),
1380            );
1381        Self {
1382            felt_sub: core.extern_function_id("felt252_sub"),
1383            into_box: box_module.extern_function_id("into_box"),
1384            unbox: box_module.extern_function_id("unbox"),
1385            box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
1386            eq_fns,
1387            uadd_fns,
1388            usub_fns,
1389            diff_fns,
1390            iadd_fns,
1391            isub_fns,
1392            wide_mul_fns,
1393            div_rem_fns,
1394            bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
1395            bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
1396            bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
1397            array_get: array_module.extern_function_id("array_get"),
1398            array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
1399            array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
1400            array_len: array_module.extern_function_id("array_len"),
1401            array_new: array_module.extern_function_id("array_new"),
1402            array_append: array_module.extern_function_id("array_append"),
1403            array_pop_front: array_module.extern_function_id("array_pop_front"),
1404            storage_base_address_from_felt252: storage_access_module
1405                .extern_function_id("storage_base_address_from_felt252"),
1406            storage_base_address_const: storage_access_module
1407                .generic_function_id("storage_base_address_const"),
1408            panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
1409            panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
1410            panic_with_byte_array: core
1411                .submodule("panics")
1412                .function_id("panic_with_byte_array", vec![])
1413                .lowered(db),
1414            type_value_ranges,
1415            const_calculation_info: db.const_calc_info(),
1416        }
1417    }
1418}
1419
1420impl<'db> std::ops::Deref for ConstFoldingContext<'db, '_> {
1421    type Target = ConstFoldingLibfuncInfo<'db>;
1422    fn deref(&self) -> &ConstFoldingLibfuncInfo<'db> {
1423        self.libfunc_info
1424    }
1425}
1426
1427impl<'a> std::ops::Deref for ConstFoldingLibfuncInfo<'a> {
1428    type Target = ConstCalcInfo<'a>;
1429    fn deref(&self) -> &ConstCalcInfo<'a> {
1430        &self.const_calculation_info
1431    }
1432}
1433
/// Per-type information that const folding needs about a numeric type.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct TypeInfo<'db> {
    /// The inclusive value range of the type.
    range: TypeRange,
    /// The lowered function that checks whether a value of the type is zero
    /// (`{ty}_is_zero`, or `bounded_int_is_zero` for types handled as bounded ints).
    is_zero: FunctionId<'db>,
    /// The function that increases a value by one, if one is registered for the type.
    inc: Option<FunctionId<'db>>,
    /// The function that decreases a value by one, if one is registered for the type.
    dec: Option<FunctionId<'db>>,
}
1446
/// The inclusive range `[min, max]` of values of a numeric type.
#[derive(Debug, PartialEq, Eq, Clone)]
struct TypeRange {
    /// The minimum value of the type (inclusive).
    min: BigInt,
    /// The maximum value of the type (inclusive).
    max: BigInt,
}
1455impl TypeRange {
1456    /// Normalizes the value to the range.
1457    /// Assumes the value is within size of range of the range.
1458    fn normalized(&self, value: BigInt) -> NormalizedResult {
1459        if value < self.min {
1460            NormalizedResult::Under(value - &self.min + &self.max + 1)
1461        } else if value > self.max {
1462            NormalizedResult::Over(value + &self.min - &self.max - 1)
1463        } else {
1464            NormalizedResult::InRange(value)
1465        }
1466    }
1467}
1468
/// The result of normalizing a value into a `TypeRange` (see `TypeRange::normalized`).
enum NormalizedResult {
    /// The original value is within the range; carries the value, or an equivalent value.
    InRange(BigInt),
    /// The original value is above the range max; carries the value shifted down by one range
    /// width (`max - min + 1`).
    Over(BigInt),
    /// The original value is below the range min; carries the value shifted up by one range
    /// width (`max - min + 1`).
    Under(BigInt),
}