Skip to main content

cairo_lang_lowering/optimizations/
const_folding.rs

1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::sync::Arc;
6
7use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
8use cairo_lang_filesystem::flag::flag_unsafe_panic;
9use cairo_lang_semantic::corelib::try_extract_nz_wrapped_type;
10use cairo_lang_semantic::helper::ModuleHelper;
11use cairo_lang_semantic::items::constant::{ConstCalcInfo, ConstValue};
12use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
13use cairo_lang_semantic::types::TypeSizeInformation;
14use cairo_lang_semantic::{
15    ConcreteTypeId, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId, corelib,
16};
17use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
18use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
19use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
20use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
21use cairo_lang_utils::{Intern, LookupIntern, extract_matches, try_extract_matches};
22use id_arena::Arena;
23use itertools::{chain, zip_eq};
24use num_bigint::BigInt;
25use num_integer::Integer;
26use num_traits::cast::ToPrimitive;
27use num_traits::{Num, One, Zero};
28
29use crate::db::LoweringGroup;
30use crate::ids::{
31    ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
32    SpecializedFunction,
33};
34use crate::specialization::SpecializationArg;
35use crate::utils::InliningStrategy;
36use crate::{
37    Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchExternInfo, MatchInfo,
38    Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
39    StatementSnapshot, StatementStructConstruct, StatementStructDestructure, VarRemapping,
40    VarUsage, Variable, VariableId,
41};
42
/// Keeps track of equivalent values that a variable might be replaced with.
///
/// Note: We don't keep track of types as we assume the usage is always correct.
#[derive(Debug, Clone)]
enum VarInfo {
    /// The variable is a const value.
    Const(ConstValue),
    /// The variable can be replaced by another variable.
    Var(VarUsage),
    /// The variable is a snapshot of another variable.
    /// Multiple snapshot layers are represented by nesting this variant
    /// (see `VarInfo::peel_snapshots`).
    Snapshot(Box<VarInfo>),
    /// The variable is a struct of other variables, one entry per member.
    /// `None` values represent members that are not tracked.
    Struct(Vec<Option<VarInfo>>),
    /// The variable is a box of another variable.
    Box(Box<VarInfo>),
    /// The variable is an array of known size of other variables, one entry per element.
    /// `None` values represent elements that are not tracked.
    Array(Vec<Option<VarInfo>>),
}
62impl VarInfo {
63    /// Peels the snapshots from the variable info and returns the number of snapshots performed.
64    fn peel_snapshots(&self) -> (usize, &VarInfo) {
65        let mut n_snapshots = 0;
66        let mut info = self;
67        while let VarInfo::Snapshot(inner) = info {
68            info = inner.as_ref();
69            n_snapshots += 1;
70        }
71        (n_snapshots, info)
72    }
73    /// Wraps the variable info with the given number of snapshots.
74    fn wrap_with_snapshots(mut self, n_snapshots: usize) -> VarInfo {
75        for _ in 0..n_snapshots {
76            self = VarInfo::Snapshot(Box::new(self));
77        }
78        self
79    }
80}
81
/// How a block can be reached from the function start, as recorded while visiting the blocks.
///
/// A block with no recorded reachability is treated as unreachable (or already visited) and
/// skipped.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Reachability {
    /// The block is reachable from the function start only through the goto at the end of the
    /// given block; constants from that goto's remapping may be propagated into the block.
    FromSingleGoto(BlockId),
    /// The block is reachable from the function start after const-folding - just does not fit
    /// `FromSingleGoto`.
    Any,
}
91
92/// Performs constant folding on the lowered program.
93/// The optimization only works when the blocks are topologically sorted.
94pub fn const_folding(
95    db: &dyn LoweringGroup,
96    function_id: ConcreteFunctionWithBodyId,
97    lowered: &mut Lowered,
98) {
99    if lowered.blocks.is_empty() {
100        return;
101    }
102
103    // Note that we can keep the var_info across blocks because the lowering
104    // is in static single assignment form.
105    let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
106
107    if ctx.should_skip_const_folding(db) {
108        return;
109    }
110
111    for block_id in (0..lowered.blocks.len()).map(BlockId) {
112        if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
113            continue;
114        }
115
116        let block = &mut lowered.blocks[block_id];
117        for stmt in block.statements.iter_mut() {
118            ctx.visit_statement(stmt);
119        }
120        ctx.visit_block_end(block_id, block);
121    }
122}
123
/// State of a constant-folding pass over the body of a single function.
///
/// Each block is processed with `visit_block_start`, then `visit_statement` for every statement,
/// then `visit_block_end`.
pub struct ConstFoldingContext<'a> {
    /// The used database.
    db: &'a dyn LoweringGroup,
    /// The variables arena, mostly used to get the type of variables.
    pub variables: &'a mut Arena<Variable>,
    /// The accumulated information about the const values of variables.
    var_info: UnorderedHashMap<VariableId, VarInfo>,
    /// The libfunc information.
    libfunc_info: Arc<ConstFoldingLibfuncInfo>,
    /// The specialization base of the caller function (or the caller if the function is not
    /// specialized).
    caller_base: ConcreteFunctionWithBodyId,
    /// Reachability of blocks from the function start.
    /// If the block is not in this map, it means that it is unreachable (or that it was already
    /// visited and its reachability won't be checked again).
    reachability: UnorderedHashMap<BlockId, Reachability>,
    /// Additional statements to add to the block; drained into the block by `visit_block_end`.
    additional_stmts: Vec<Statement>,
}
143
144impl<'a> ConstFoldingContext<'a> {
145    pub fn new(
146        db: &'a dyn LoweringGroup,
147        function_id: ConcreteFunctionWithBodyId,
148        variables: &'a mut Arena<Variable>,
149    ) -> Self {
150        let caller_base = match function_id.lookup_intern(db) {
151            ConcreteFunctionWithBodyLongId::Specialized(specialized_func) => specialized_func.base,
152            _ => function_id,
153        };
154
155        Self {
156            db,
157            var_info: UnorderedHashMap::default(),
158            variables,
159            libfunc_info: priv_const_folding_info(db),
160            caller_base,
161            reachability: UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]),
162            additional_stmts: vec![],
163        }
164    }
165
166    /// Determines if a block is reachable from the function start and propagates constant values
167    /// when the block is reachable via a single goto statement.
168    pub fn visit_block_start<'b>(
169        &mut self,
170        block_id: BlockId,
171        get_block: impl FnOnce(BlockId) -> &'b Block,
172    ) -> bool {
173        let Some(reachability) = self.reachability.remove(&block_id) else {
174            return false;
175        };
176        match reachability {
177            Reachability::Any => {}
178            Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
179                BlockEnd::Goto(_, remapping) => {
180                    for (dst, src) in remapping.iter() {
181                        if let Some(v) = self.as_const(src.var_id) {
182                            self.var_info.insert(*dst, VarInfo::Const(v.clone()));
183                        }
184                    }
185                }
186                _ => unreachable!("Expected a goto end"),
187            },
188        }
189        true
190    }
191
192    /// Processes a statement and applies the constant folding optimizations.
193    ///
194    /// This method performs the following operations:
195    /// - Updates the `var_info` map with constant values of variables
196    /// - Replace the input statement with optimized versions when possible
197    /// - Updates `self.additional_stmts` with statements that need to be added to the block.
198    ///
199    /// Note: `self.visit_block_end` must be called after processing all statements
200    /// in a block to actually add the additional statements.
201    pub fn visit_statement(&mut self, stmt: &mut Statement) {
202        self.maybe_replace_inputs(stmt.inputs_mut());
203        match stmt {
204            Statement::Const(StatementConst { value, output }) => match value {
205                value @ (ConstValue::Int(..)
206                | ConstValue::Struct(..)
207                | ConstValue::Enum(..)
208                | ConstValue::NonZero(..)) => {
209                    self.var_info.insert(*output, VarInfo::Const(value.clone()));
210                }
211                ConstValue::Boxed(inner) => {
212                    self.var_info.insert(
213                        *output,
214                        VarInfo::Box(VarInfo::Const(inner.as_ref().clone()).into()),
215                    );
216                }
217                ConstValue::Generic(_)
218                | ConstValue::ImplConstant(_)
219                | ConstValue::Var(..)
220                | ConstValue::Missing(_) => {}
221            },
222            Statement::Snapshot(stmt) => {
223                if let Some(info) = self.var_info.get(&stmt.input.var_id).cloned() {
224                    self.var_info.insert(stmt.original(), info.clone());
225                    self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info.into()));
226                }
227            }
228            Statement::Desnap(StatementDesnap { input, output }) => {
229                if let Some(VarInfo::Snapshot(info)) = self.var_info.get(&input.var_id) {
230                    self.var_info.insert(*output, info.as_ref().clone());
231                }
232            }
233            Statement::Call(call_stmt) => {
234                if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
235                    *stmt = updated_stmt;
236                } else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
237                    *stmt = updated_stmt;
238                }
239            }
240            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
241                let mut const_args = vec![];
242                let mut all_args = vec![];
243                let mut contains_info = false;
244                for input in inputs.iter() {
245                    let Some(info) = self.var_info.get(&input.var_id) else {
246                        all_args.push(var_info_if_copy(self.variables, *input));
247                        continue;
248                    };
249                    contains_info = true;
250                    if let VarInfo::Const(value) = info {
251                        const_args.push(value.clone());
252                    }
253                    all_args.push(Some(info.clone()));
254                }
255                if const_args.len() == inputs.len() {
256                    let value = ConstValue::Struct(const_args, self.variables[*output].ty);
257                    self.var_info.insert(*output, VarInfo::Const(value));
258                } else if contains_info {
259                    self.var_info.insert(*output, VarInfo::Struct(all_args));
260                }
261            }
262            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
263                if let Some(info) = self.var_info.get(&input.var_id) {
264                    let (n_snapshots, info) = info.peel_snapshots();
265                    match info {
266                        VarInfo::Const(ConstValue::Struct(member_values, _)) => {
267                            for (output, value) in zip_eq(outputs, member_values.clone()) {
268                                self.var_info.insert(
269                                    *output,
270                                    VarInfo::Const(value).wrap_with_snapshots(n_snapshots),
271                                );
272                            }
273                        }
274                        VarInfo::Struct(members) => {
275                            for (output, member) in zip_eq(outputs, members.clone()) {
276                                if let Some(member) = member {
277                                    self.var_info
278                                        .insert(*output, member.wrap_with_snapshots(n_snapshots));
279                                }
280                            }
281                        }
282                        _ => {}
283                    }
284                }
285            }
286            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
287                if let Some(VarInfo::Const(val)) = self.var_info.get(&input.var_id) {
288                    let value = ConstValue::Enum(*variant, val.clone().into());
289                    self.var_info.insert(*output, VarInfo::Const(value.clone()));
290                }
291            }
292        }
293    }
294
295    /// Processes the block's end and incorporates additional statements into the block.
296    ///
297    /// This method handles the following tasks:
298    /// - Inserts the accumulated additional statements into the block.
299    /// - Converts match endings to goto when applicable.
300    /// - Updates self.reachability based on the block's ending.
301    pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut crate::Block) {
302        block.statements.splice(0..0, self.additional_stmts.drain(..));
303
304        match &mut block.end {
305            BlockEnd::Goto(_, remappings) => {
306                for (_, v) in remappings.iter_mut() {
307                    self.maybe_replace_input(v);
308                }
309            }
310            BlockEnd::Match { info } => {
311                self.maybe_replace_inputs(info.inputs_mut());
312                match info {
313                    MatchInfo::Enum(MatchEnumInfo { input, arms, location, .. }) => {
314                        if let Some(var_info) = self.var_info.get(&input.var_id) {
315                            let (n_snapshots, var_info) = var_info.peel_snapshots();
316                            // Checking whether we have actual const info on the enum.
317                            if let VarInfo::Const(ConstValue::Enum(variant, value)) = var_info {
318                                let arm = &arms[variant.idx];
319                                let output = arm.var_ids[0];
320                                if self.variables[input.var_id].info.droppable.is_ok()
321                                    && self.variables[output].info.copyable.is_ok()
322                                {
323                                    if let Ok(mut ty) = value.ty(self.db) {
324                                        if let Some(mut stmt) =
325                                            self.try_generate_const_statement(value, output)
326                                        {
327                                            // Adding snapshot taking statements for snapshots.
328                                            for _ in 0..n_snapshots {
329                                                let non_snap_var = Variable::with_default_context(
330                                                    self.db, ty, *location,
331                                                );
332                                                ty = TypeLongId::Snapshot(ty).intern(self.db);
333                                                let ignored =
334                                                    self.variables.alloc(non_snap_var.clone());
335                                                let pre_snap = self.variables.alloc(non_snap_var);
336                                                stmt.outputs_mut()[0] = pre_snap;
337                                                block.statements.push(core::mem::replace(
338                                                    &mut stmt,
339                                                    Statement::Snapshot(StatementSnapshot::new(
340                                                        VarUsage {
341                                                            var_id: pre_snap,
342                                                            location: *location,
343                                                        },
344                                                        ignored,
345                                                        output,
346                                                    )),
347                                                ));
348                                            }
349                                            block.statements.push(stmt);
350                                            block.end =
351                                                BlockEnd::Goto(arm.block_id, Default::default());
352                                        }
353                                    }
354                                }
355                                // Propagating the const value information.
356                                self.var_info.insert(
357                                    output,
358                                    VarInfo::Const(value.as_ref().clone())
359                                        .wrap_with_snapshots(n_snapshots),
360                                );
361                            }
362                        }
363                    }
364                    MatchInfo::Value(info) => {
365                        if let Some(value) =
366                            self.as_int(info.input.var_id).and_then(|x| x.to_usize())
367                        {
368                            if let Some(arm) = info.arms.iter().find(|arm| {
369                                matches!(
370                                    &arm.arm_selector,
371                                    MatchArmSelector::Value(v) if v.value == value
372                                )
373                            }) {
374                                // Create the variable that was previously introduced in match arm.
375                                block.statements.push(Statement::StructConstruct(
376                                    StatementStructConstruct {
377                                        inputs: vec![],
378                                        output: arm.var_ids[0],
379                                    },
380                                ));
381                                block.end = BlockEnd::Goto(arm.block_id, Default::default());
382                            }
383                        }
384                    }
385                    MatchInfo::Extern(info) => {
386                        if let Some((extra_stmts, updated_end)) = self.handle_extern_block_end(info)
387                        {
388                            block.statements.extend(extra_stmts);
389                            block.end = updated_end;
390                        }
391                    }
392                }
393            }
394            BlockEnd::Return(ref mut inputs, _) => self.maybe_replace_inputs(inputs),
395            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
396        }
397        match &block.end {
398            BlockEnd::Goto(dst_block_id, _) => {
399                match self.reachability.entry(*dst_block_id) {
400                    std::collections::hash_map::Entry::Occupied(mut e) => {
401                        e.insert(Reachability::Any)
402                    }
403                    std::collections::hash_map::Entry::Vacant(e) => {
404                        *e.insert(Reachability::FromSingleGoto(block_id))
405                    }
406                };
407            }
408            BlockEnd::Match { info } => {
409                for arm in info.arms() {
410                    assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
411                }
412            }
413            BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
414        }
415    }
416
417    /// Handles a statement call.
418    ///
419    /// Returns None if no additional changes are required.
420    /// If changes are required, returns an updated statement (to override the current
421    /// statement).
422    /// May add additional statements to `self.additional_stmts` if just replacing the current
423    /// statement is not enough.
424    fn handle_statement_call(&mut self, stmt: &mut StatementCall) -> Option<Statement> {
425        let db = self.db;
426        if stmt.function == self.panic_with_felt252 {
427            let val = self.as_const(stmt.inputs[0].var_id)?;
428            stmt.inputs.clear();
429            stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
430                .concretize(db, vec![GenericArgumentId::Constant(val.clone().intern(db))])
431                .lowered(db);
432            return None;
433        } else if stmt.function == self.panic_with_byte_array && !flag_unsafe_panic(db) {
434            let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
435            let bytearray = try_extract_matches!(snap, VarInfo::Snapshot)?;
436            let [
437                Some(VarInfo::Array(data)),
438                Some(VarInfo::Const(ConstValue::Int(pending_word, _))),
439                Some(VarInfo::Const(ConstValue::Int(pending_len, _))),
440            ] = &try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
441            else {
442                return None;
443            };
444            let mut panic_data =
445                vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
446            for word in data {
447                let Some(VarInfo::Const(ConstValue::Int(word, _))) = word else {
448                    continue;
449                };
450                panic_data.push(word.clone());
451            }
452            panic_data.extend([pending_word.clone(), pending_len.clone()]);
453            let felt252_ty = self.felt252;
454            let location = stmt.location;
455            let new_var = |ty| Variable::with_default_context(db, ty, location);
456            let as_usage = |var_id| VarUsage { var_id, location };
457            let array_fn = |extern_id| {
458                let args = vec![GenericArgumentId::Type(felt252_ty)];
459                GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
460            };
461            let call_stmt = |function, inputs, outputs| {
462                let with_coupon = false;
463                Statement::Call(StatementCall { function, inputs, with_coupon, outputs, location })
464            };
465            let arr_var = new_var(corelib::core_array_felt252_ty(db));
466            let mut arr = self.variables.alloc(arr_var.clone());
467            self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
468            let felt252_var = new_var(felt252_ty);
469            let arr_append_fn = array_fn(self.array_append);
470            for word in panic_data {
471                let to_append = self.variables.alloc(felt252_var.clone());
472                let new_arr = self.variables.alloc(arr_var.clone());
473                self.additional_stmts.push(Statement::Const(StatementConst {
474                    value: ConstValue::Int(word, felt252_ty),
475                    output: to_append,
476                }));
477                self.additional_stmts.push(call_stmt(
478                    arr_append_fn,
479                    vec![as_usage(arr), as_usage(to_append)],
480                    vec![new_arr],
481                ));
482                arr = new_arr;
483            }
484            let panic_ty = corelib::get_core_ty_by_name(db, "Panic".into(), vec![]);
485            let panic_var = self.variables.alloc(new_var(panic_ty));
486            self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
487                inputs: vec![],
488                output: panic_var,
489            }));
490            return Some(Statement::StructConstruct(StatementStructConstruct {
491                inputs: vec![as_usage(panic_var), as_usage(arr)],
492                output: stmt.outputs[0],
493            }));
494        }
495        let (id, _generic_args) = stmt.function.get_extern(db)?;
496        if id == self.felt_sub {
497            // (a - 0) can be replaced by a.
498            let val = self.as_int(stmt.inputs[1].var_id)?;
499            if val.is_zero() {
500                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
501            }
502            None
503        } else if self.wide_mul_fns.contains(&id) {
504            let lhs = self.as_int_ex(stmt.inputs[0].var_id);
505            let rhs = self.as_int(stmt.inputs[1].var_id);
506            let output = stmt.outputs[0];
507            if lhs.map(|(v, _)| v.is_zero()).unwrap_or_default()
508                || rhs.map(Zero::is_zero).unwrap_or_default()
509            {
510                return Some(self.propagate_zero_and_get_statement(output));
511            }
512            let (lhs, nz_ty) = lhs?;
513            Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0], nz_ty))
514        } else if id == self.bounded_int_add || id == self.bounded_int_sub {
515            let lhs = self.as_int(stmt.inputs[0].var_id)?;
516            let rhs = self.as_int(stmt.inputs[1].var_id)?;
517            let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
518            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0], false))
519        } else if self.div_rem_fns.contains(&id) {
520            let lhs = self.as_int(stmt.inputs[0].var_id);
521            if lhs.map(Zero::is_zero).unwrap_or_default() {
522                let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
523                self.additional_stmts.push(additional_stmt);
524                return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
525            }
526            let rhs = self.as_int(stmt.inputs[1].var_id)?;
527            let (q, r) = lhs?.div_rem(rhs);
528            let q_output = stmt.outputs[0];
529            let q_value = ConstValue::Int(q, self.variables[q_output].ty);
530            self.var_info.insert(q_output, VarInfo::Const(q_value.clone()));
531            let r_output = stmt.outputs[1];
532            let r_value = ConstValue::Int(r, self.variables[r_output].ty);
533            self.var_info.insert(r_output, VarInfo::Const(r_value.clone()));
534            self.additional_stmts
535                .push(Statement::Const(StatementConst { value: r_value, output: r_output }));
536            Some(Statement::Const(StatementConst { value: q_value, output: q_output }))
537        } else if id == self.storage_base_address_from_felt252 {
538            let input_var = stmt.inputs[0].var_id;
539            if let Some(ConstValue::Int(val, ty)) = self.as_const(input_var) {
540                stmt.inputs.clear();
541                let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
542                stmt.function =
543                    self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
544            }
545            None
546        } else if id == self.into_box {
547            let input = stmt.inputs[0];
548            let var_info = self.var_info.get(&input.var_id);
549            let const_value = match var_info {
550                Some(VarInfo::Const(val)) => Some(val.clone()),
551                Some(VarInfo::Snapshot(info)) => {
552                    try_extract_matches!(info.as_ref(), VarInfo::Const).cloned()
553                }
554                _ => None,
555            };
556            let var_info = var_info.cloned().or_else(|| var_info_if_copy(self.variables, input))?;
557            self.var_info.insert(stmt.outputs[0], VarInfo::Box(var_info.into()));
558            Some(Statement::Const(StatementConst {
559                value: ConstValue::Boxed(const_value?.into()),
560                output: stmt.outputs[0],
561            }))
562        } else if id == self.unbox {
563            if let VarInfo::Box(inner) = self.var_info.get(&stmt.inputs[0].var_id)? {
564                let inner = inner.as_ref().clone();
565                if let VarInfo::Const(inner) =
566                    self.var_info.entry(stmt.outputs[0]).insert_entry(inner).get()
567                {
568                    return Some(Statement::Const(StatementConst {
569                        value: inner.clone(),
570                        output: stmt.outputs[0],
571                    }));
572                }
573            }
574            None
575        } else if self.upcast_fns.contains(&id) {
576            let int_value = self.as_int(stmt.inputs[0].var_id)?;
577            let output = stmt.outputs[0];
578            let value = ConstValue::Int(int_value.clone(), self.variables[output].ty);
579            self.var_info.insert(output, VarInfo::Const(value.clone()));
580            Some(Statement::Const(StatementConst { value, output }))
581        } else if id == self.array_new {
582            self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]));
583            None
584        } else if id == self.array_append {
585            let mut var_infos =
586                if let VarInfo::Array(var_infos) = self.var_info.get(&stmt.inputs[0].var_id)? {
587                    var_infos.clone()
588                } else {
589                    return None;
590                };
591            let appended = stmt.inputs[1];
592            var_infos.push(match self.var_info.get(&appended.var_id) {
593                Some(var_info) => Some(var_info.clone()),
594                None => var_info_if_copy(self.variables, appended),
595            });
596            self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos));
597            None
598        } else if id == self.array_len {
599            let info = self.var_info.get(&stmt.inputs[0].var_id)?;
600            let desnapped = try_extract_matches!(info, VarInfo::Snapshot)?;
601            let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
602            Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0], false))
603        } else {
604            None
605        }
606    }
607
608    /// Tries to specialize the call.
609    /// Returns The specialized call statement if it was specialized, or None otherwise.
610    ///
611    /// Specialization occurs only if `priv_should_specialize` returns true.
612    /// Additionally specialization of a callee the with the same base as the caller is currently
613    /// not supported.
614    fn try_specialize_call(&mut self, call_stmt: &mut StatementCall) -> Option<Statement> {
615        if call_stmt.with_coupon {
616            return None;
617        }
618        // No specialization when avoiding inlining.
619        if matches!(self.db.optimization_config().inlining_strategy, InliningStrategy::Avoid) {
620            return None;
621        }
622
623        let Ok(Some(mut base)) = call_stmt.function.body(self.db) else {
624            return None;
625        };
626
627        if self.db.priv_never_inline(base).ok()? {
628            return None;
629        }
630
631        if call_stmt
632            .inputs
633            .iter()
634            .all(|arg| !matches!(self.var_info.get(&arg.var_id), Some(VarInfo::Const(_))))
635        {
636            // No const inputs
637            return None;
638        }
639
640        let mut const_args = vec![];
641        let mut new_args = vec![];
642        for arg in &call_stmt.inputs {
643            if let Some(var_info) = self.var_info.get(&arg.var_id) {
644                if let Some(c) =
645                    self.try_get_specialization_arg(var_info.clone(), self.variables[arg.var_id].ty)
646                {
647                    const_args.push(Some(c.clone()));
648                    continue;
649                }
650            }
651            const_args.push(None);
652            new_args.push(*arg);
653        }
654
655        if new_args.len() == call_stmt.inputs.len() {
656            // No argument was assigned -> no specialization.
657            return None;
658        }
659
660        if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
661            base.lookup_intern(self.db)
662        {
663            // Canonicalize the specialization rather than adding a specialization of a specialized
664            // function.
665            base = specialized_function.base;
666            let mut new_args_iter = const_args.into_iter();
667            const_args = specialized_function.args.to_vec();
668            for arg in &mut const_args {
669                if arg.is_none() {
670                    *arg = new_args_iter.next().unwrap_or_default();
671                }
672            }
673        }
674
675        // Avoid specializing with the same base as the current function as it may lead to infinite
676        // specialization.
677        if base == self.caller_base {
678            return None;
679        }
680        let specialized = SpecializedFunction { base, args: const_args.into() };
681        let specialized_func_id =
682            ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);
683
684        if self.db.priv_should_specialize(specialized_func_id) == Ok(false) {
685            return None;
686        }
687
688        Some(Statement::Call(StatementCall {
689            function: specialized_func_id.function_id(self.db).unwrap(),
690            inputs: new_args,
691            with_coupon: call_stmt.with_coupon,
692            outputs: std::mem::take(&mut call_stmt.outputs),
693            location: call_stmt.location,
694        }))
695    }
696
697    /// Adds `value` as a const to `var_info` and return a const statement for it.
698    fn propagate_const_and_get_statement(
699        &mut self,
700        value: BigInt,
701        output: VariableId,
702        nz_ty: bool,
703    ) -> Statement {
704        let ty = self.variables[output].ty;
705        let value = if nz_ty {
706            let inner_ty =
707                try_extract_nz_wrapped_type(self.db, ty).expect("Expected a non-zero type");
708            ConstValue::NonZero(Box::new(ConstValue::Int(value, inner_ty)))
709        } else {
710            ConstValue::Int(value, ty)
711        };
712        self.var_info.insert(output, VarInfo::Const(value.clone()));
713        Statement::Const(StatementConst { value, output })
714    }
715
716    /// Adds 0 const to `var_info` and return a const statement for it.
717    fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement {
718        self.propagate_const_and_get_statement(BigInt::zero(), output, false)
719    }
720
721    /// Returns a statement that introduces the requested value into `output`, or None if fails.
722    fn try_generate_const_statement(
723        &self,
724        value: &ConstValue,
725        output: VariableId,
726    ) -> Option<Statement> {
727        if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
728            Some(Statement::Const(StatementConst { value: value.clone(), output }))
729        } else if matches!(value, ConstValue::Struct(members, _) if members.is_empty()) {
730            // Handling const empty structs - which are not supported in sierra-gen.
731            Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
732        } else {
733            None
734        }
735    }
736
737    /// Handles the end of an extern block.
738    /// Returns None if no additional changes are required.
739    /// If changes are required, returns a possible additional const-statement to the block, as well
740    /// as an updated block end.
741    fn handle_extern_block_end(
742        &mut self,
743        info: &mut MatchExternInfo,
744    ) -> Option<(Vec<Statement>, BlockEnd)> {
745        let db = self.db;
746        let (id, generic_args) = info.function.get_extern(db)?;
747        if self.nz_fns.contains(&id) {
748            let val = self.as_const(info.inputs[0].var_id)?;
749            let is_zero = match val {
750                ConstValue::Int(v, _) => v.is_zero(),
751                ConstValue::Struct(s, _) => s.iter().all(|v| {
752                    v.clone().into_int().expect("Expected ConstValue::Int for size").is_zero()
753                }),
754                _ => unreachable!(),
755            };
756            Some(if is_zero {
757                (vec![], BlockEnd::Goto(info.arms[0].block_id, Default::default()))
758            } else {
759                let arm = &info.arms[1];
760                let nz_var = arm.var_ids[0];
761                let nz_val = ConstValue::NonZero(Box::new(val.clone()));
762                self.var_info.insert(nz_var, VarInfo::Const(nz_val.clone()));
763                (
764                    vec![Statement::Const(StatementConst { value: nz_val, output: nz_var })],
765                    BlockEnd::Goto(arm.block_id, Default::default()),
766                )
767            })
768        } else if self.eq_fns.contains(&id) {
769            let lhs = self.as_int(info.inputs[0].var_id);
770            let rhs = self.as_int(info.inputs[1].var_id);
771            if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
772                || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
773            {
774                let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
775                let var = &self.variables[nz_input.var_id].clone();
776                let function = self.type_value_ranges.get(&var.ty)?.is_zero;
777                let unused_nz_var = Variable::with_default_context(
778                    db,
779                    corelib::core_nonzero_ty(db, var.ty),
780                    var.location,
781                );
782                let unused_nz_var = self.variables.alloc(unused_nz_var);
783                return Some((
784                    vec![],
785                    BlockEnd::Match {
786                        info: MatchInfo::Extern(MatchExternInfo {
787                            function,
788                            inputs: vec![nz_input],
789                            arms: vec![
790                                MatchArm {
791                                    arm_selector: MatchArmSelector::VariantId(
792                                        corelib::jump_nz_zero_variant(db, var.ty),
793                                    ),
794                                    block_id: info.arms[1].block_id,
795                                    var_ids: vec![],
796                                },
797                                MatchArm {
798                                    arm_selector: MatchArmSelector::VariantId(
799                                        corelib::jump_nz_nonzero_variant(db, var.ty),
800                                    ),
801                                    block_id: info.arms[0].block_id,
802                                    var_ids: vec![unused_nz_var],
803                                },
804                            ],
805                            location: info.location,
806                        }),
807                    },
808                ));
809            }
810            Some((
811                vec![],
812                BlockEnd::Goto(
813                    info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
814                    Default::default(),
815                ),
816            ))
817        } else if self.uadd_fns.contains(&id)
818            || self.usub_fns.contains(&id)
819            || self.diff_fns.contains(&id)
820            || self.iadd_fns.contains(&id)
821            || self.isub_fns.contains(&id)
822        {
823            let rhs = self.as_int(info.inputs[1].var_id);
824            let lhs = self.as_int(info.inputs[0].var_id);
825            if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
826                let ty = self.variables[info.arms[0].var_ids[0]].ty;
827                let range = &self.type_value_ranges.get(&ty)?.range;
828                let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
829                    lhs + rhs
830                } else {
831                    lhs - rhs
832                };
833                let (arm_index, value) = match range.normalized(value) {
834                    NormalizedResult::InRange(value) => (0, value),
835                    NormalizedResult::Under(value) => (1, value),
836                    NormalizedResult::Over(value) => (
837                        if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
838                            2
839                        } else {
840                            1
841                        },
842                        value,
843                    ),
844                };
845                let arm = &info.arms[arm_index];
846                let actual_output = arm.var_ids[0];
847                let value = ConstValue::Int(value, ty);
848                self.var_info.insert(actual_output, VarInfo::Const(value.clone()));
849                return Some((
850                    vec![Statement::Const(StatementConst { value, output: actual_output })],
851                    BlockEnd::Goto(arm.block_id, Default::default()),
852                ));
853            }
854            if let Some(rhs) = rhs {
855                if rhs.is_zero() && !self.diff_fns.contains(&id) {
856                    let arm = &info.arms[0];
857                    self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]));
858                    return Some((vec![], BlockEnd::Goto(arm.block_id, Default::default())));
859                }
860                if rhs.is_one() && !self.diff_fns.contains(&id) {
861                    let ty = self.variables[info.arms[0].var_ids[0]].ty;
862                    let ty_info = self.type_value_ranges.get(&ty)?;
863                    let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
864                        ty_info.inc?
865                    } else {
866                        ty_info.dec?
867                    };
868                    let enum_ty = function.signature(db).ok()?.return_type;
869                    let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
870                        enum_ty.lookup_intern(db)
871                    else {
872                        return None;
873                    };
874                    let result = self.variables.alloc(Variable::with_default_context(
875                        db,
876                        function.signature(db).unwrap().return_type,
877                        info.location,
878                    ));
879                    return Some((
880                        vec![Statement::Call(StatementCall {
881                            function,
882                            inputs: vec![info.inputs[0]],
883                            with_coupon: false,
884                            outputs: vec![result],
885                            location: info.location,
886                        })],
887                        BlockEnd::Match {
888                            info: MatchInfo::Enum(MatchEnumInfo {
889                                concrete_enum_id,
890                                input: VarUsage { var_id: result, location: info.location },
891                                arms: core::mem::take(&mut info.arms),
892                                location: info.location,
893                            }),
894                        },
895                    ));
896                }
897            }
898            if let Some(lhs) = lhs {
899                if lhs.is_zero() && (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id)) {
900                    let arm = &info.arms[0];
901                    self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]));
902                    return Some((vec![], BlockEnd::Goto(arm.block_id, Default::default())));
903                }
904            }
905            None
906        } else if let Some(reversed) = self.downcast_fns.get(&id) {
907            let range = |ty: TypeId| {
908                Some(if let Some(ti) = self.type_value_ranges.get(&ty) {
909                    ti.range.clone()
910                } else {
911                    let (min, max) = corelib::try_extract_bounded_int_type_ranges(db, ty)?;
912                    TypeRange { min, max }
913                })
914            };
915            let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
916            let input_var = info.inputs[0].var_id;
917            let success_output = info.arms[success_arm].var_ids[0];
918            let out_ty = self.variables[success_output].ty;
919            let out_range = range(out_ty)?;
920            let Some(value) = self.as_int(input_var) else {
921                let in_ty = self.variables[input_var].ty;
922                let in_range = range(in_ty)?;
923                return if in_range.min < out_range.min || in_range.max > out_range.max {
924                    None
925                } else {
926                    let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
927                    let function = db.core_info().upcast_fn.concretize(db, generic_args);
928                    return Some((
929                        vec![Statement::Call(StatementCall {
930                            function: function.lowered(db),
931                            inputs: vec![info.inputs[0]],
932                            with_coupon: false,
933                            outputs: vec![success_output],
934                            location: info.location,
935                        })],
936                        BlockEnd::Goto(info.arms[success_arm].block_id, Default::default()),
937                    ));
938                };
939            };
940            Some(if let NormalizedResult::InRange(value) = out_range.normalized(value.clone()) {
941                let value = ConstValue::Int(value, out_ty);
942                self.var_info.insert(success_output, VarInfo::Const(value.clone()));
943                (
944                    vec![Statement::Const(StatementConst { value, output: success_output })],
945                    BlockEnd::Goto(info.arms[success_arm].block_id, Default::default()),
946                )
947            } else {
948                (vec![], BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default()))
949            })
950        } else if id == self.bounded_int_constrain {
951            let input_var = info.inputs[0].var_id;
952            let (value, nz_ty) = self.as_int_ex(input_var)?;
953            let generic_arg = generic_args[1];
954            let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
955                .lookup_intern(db)
956                .into_int()
957                .unwrap();
958            let arm_idx = if value < &constrain_value { 0 } else { 1 };
959            let output = info.arms[arm_idx].var_ids[0];
960            Some((
961                vec![self.propagate_const_and_get_statement(value.clone(), output, nz_ty)],
962                BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()),
963            ))
964        } else if id == self.array_get {
965            let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
966            if let Some(VarInfo::Snapshot(arr_info)) = self.var_info.get(&info.inputs[0].var_id) {
967                if let VarInfo::Array(infos) = arr_info.as_ref() {
968                    if let Some(Some(output_var_info)) = infos.get(index) {
969                        let arm = &info.arms[0];
970                        let output_var_info = output_var_info.clone();
971                        let box_info =
972                            VarInfo::Box(VarInfo::Snapshot(output_var_info.clone().into()).into());
973                        self.var_info.insert(arm.var_ids[0], box_info);
974                        if let VarInfo::Const(value) = output_var_info {
975                            let value_ty = value.ty(db).ok()?;
976                            let value_box_ty = corelib::core_box_ty(db, value_ty);
977                            let location = info.location;
978                            let boxed_var =
979                                Variable::with_default_context(db, value_box_ty, location);
980                            let boxed = self.variables.alloc(boxed_var.clone());
981                            let unused_boxed = self.variables.alloc(boxed_var);
982                            let snapped = self.variables.alloc(Variable::with_default_context(
983                                db,
984                                TypeLongId::Snapshot(value_box_ty).intern(db),
985                                location,
986                            ));
987                            return Some((
988                                vec![
989                                    Statement::Const(StatementConst {
990                                        value: ConstValue::Boxed(value.into()),
991                                        output: boxed,
992                                    }),
993                                    Statement::Snapshot(StatementSnapshot {
994                                        input: VarUsage { var_id: boxed, location },
995                                        outputs: [unused_boxed, snapped],
996                                    }),
997                                    Statement::Call(StatementCall {
998                                        function: self
999                                            .box_forward_snapshot
1000                                            .concretize(db, vec![GenericArgumentId::Type(value_ty)])
1001                                            .lowered(db),
1002                                        inputs: vec![VarUsage { var_id: snapped, location }],
1003                                        with_coupon: false,
1004                                        outputs: vec![arm.var_ids[0]],
1005                                        location: info.location,
1006                                    }),
1007                                ],
1008                                BlockEnd::Goto(arm.block_id, Default::default()),
1009                            ));
1010                        }
1011                    } else {
1012                        return Some((
1013                            vec![],
1014                            BlockEnd::Goto(info.arms[1].block_id, Default::default()),
1015                        ));
1016                    }
1017                }
1018            }
1019            if index.is_zero() {
1020                if let [success, failure] = info.arms.as_mut_slice() {
1021                    let arr = info.inputs[0].var_id;
1022                    let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
1023                    let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
1024                    info.inputs.truncate(1);
1025                    info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
1026                        .concretize(db, generic_args)
1027                        .lowered(db);
1028                    success.var_ids.insert(0, unused_arr_output0);
1029                    failure.var_ids.insert(0, unused_arr_output1);
1030                }
1031            }
1032            None
1033        } else if id == self.array_pop_front {
1034            let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)? else {
1035                return None;
1036            };
1037            if let Some(first) = var_infos.first() {
1038                if let Some(first) = first.as_ref().cloned() {
1039                    let arm = &info.arms[0];
1040                    self.var_info.insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()));
1041                    self.var_info.insert(arm.var_ids[1], VarInfo::Box(first.into()));
1042                }
1043                None
1044            } else {
1045                let arm = &info.arms[1];
1046                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1047                Some((
1048                    vec![],
1049                    BlockEnd::Goto(
1050                        arm.block_id,
1051                        VarRemapping { remapping: [(arm.var_ids[0], info.inputs[0])].into() },
1052                    ),
1053                ))
1054            }
1055        } else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
1056            let var_info = self.var_info.get(&info.inputs[0].var_id)?;
1057            let desnapped = try_extract_matches!(var_info, VarInfo::Snapshot)?;
1058            let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
1059            // TODO(orizi): Propagate success values as well.
1060            if element_var_infos.is_empty() {
1061                let arm = &info.arms[1];
1062                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1063                Some((
1064                    vec![],
1065                    BlockEnd::Goto(
1066                        arm.block_id,
1067                        VarRemapping { remapping: [(arm.var_ids[0], info.inputs[0])].into() },
1068                    ),
1069                ))
1070            } else {
1071                None
1072            }
1073        } else {
1074            None
1075        }
1076    }
1077
1078    /// Returns the const value of a variable if it exists.
1079    fn as_const(&self, var_id: VariableId) -> Option<&ConstValue> {
1080        try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const)
1081    }
1082
1083    /// Return the const value as an int if it exists and is an integer, additionally, if it is of a
1084    /// non-zero type.
1085    fn as_int_ex(&self, var_id: VariableId) -> Option<(&BigInt, bool)> {
1086        match self.as_const(var_id)? {
1087            ConstValue::Int(value, _) => Some((value, false)),
1088            ConstValue::NonZero(const_value) => {
1089                if let ConstValue::Int(value, _) = const_value.as_ref() {
1090                    Some((value, true))
1091                } else {
1092                    None
1093                }
1094            }
1095            _ => None,
1096        }
1097    }
1098
1099    /// Return the const value as a int if it exists and is an integer.
1100    fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
1101        Some(self.as_int_ex(var_id)?.0)
1102    }
1103
1104    /// Replaces the inputs in place if they are in the var_info map.
1105    fn maybe_replace_inputs(&mut self, inputs: &mut [VarUsage]) {
1106        for input in inputs {
1107            self.maybe_replace_input(input);
1108        }
1109    }
1110
1111    /// Replaces the input in place if it is in the var_info map.
1112    fn maybe_replace_input(&mut self, input: &mut VarUsage) {
1113        if let Some(VarInfo::Var(new_var)) = self.var_info.get(&input.var_id) {
1114            *input = *new_var;
1115        }
1116    }
1117
1118    /// Given a var_info and its type, return the corresponding specialization argument, if it
1119    /// exists.
1120    fn try_get_specialization_arg(
1121        &mut self,
1122        var_info: VarInfo,
1123        ty: TypeId,
1124    ) -> Option<SpecializationArg> {
1125        if self.db.type_size_info(ty).ok()? == TypeSizeInformation::ZeroSized {
1126            // Skip zero-sized constants as they are not supported in sierra-gen.
1127            return None;
1128        }
1129
1130        match var_info {
1131            VarInfo::Const(c) => Some(SpecializationArg::Const(c.clone())),
1132            VarInfo::Array(infos) if infos.is_empty() => {
1133                let TypeLongId::Concrete(concrete_ty) = ty.lookup_intern(self.db) else {
1134                    unreachable!("Expected a concrete type");
1135                };
1136                let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
1137                else {
1138                    unreachable!("Expected a single type generic argument");
1139                };
1140                Some(SpecializationArg::EmptyArray(*inner_ty))
1141            }
1142            VarInfo::Struct(infos) => {
1143                let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
1144                    ty.lookup_intern(self.db)
1145                else {
1146                    // TODO(ilya): Support closures and fixed size arrays.
1147                    return None;
1148                };
1149
1150                let members = self.db.concrete_struct_members(concrete_struct).unwrap();
1151                let struct_args = zip_eq(members.values(), infos)
1152                    .map(|(member, opt_var_info)| {
1153                        self.try_get_specialization_arg(opt_var_info?.clone(), member.ty)
1154                    })
1155                    .collect::<Option<Vec<_>>>()?;
1156
1157                Some(SpecializationArg::Struct(struct_args))
1158            }
1159            _ => None,
1160        }
1161    }
1162
1163    /// Returns true if const-folding should be skipped for the current function.
1164    pub fn should_skip_const_folding(&self, db: &dyn LoweringGroup) -> bool {
1165        if db.optimization_config().skip_const_folding {
1166            return true;
1167        }
1168
1169        // Skipping const-folding for `panic_with_const_felt252` - to avoid replacing a call to
1170        // `panic_with_felt252` with `panic_with_const_felt252` and causing accidental recursion.
1171        if self.caller_base.base_semantic_function(db).generic_function(db)
1172            == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
1173        {
1174            return true;
1175        }
1176        false
1177    }
1178}
1179
1180/// Returns a `VarInfo` of a variable only if it is copyable.
1181fn var_info_if_copy(variables: &Arena<Variable>, input: VarUsage) -> Option<VarInfo> {
1182    variables[input.var_id].info.copyable.is_ok().then_some(VarInfo::Var(input))
1183}
1184
1185/// Query implementation of [LoweringGroup::priv_const_folding_info].
1186pub fn priv_const_folding_info(
1187    db: &dyn LoweringGroup,
1188) -> Arc<crate::optimizations::const_folding::ConstFoldingLibfuncInfo> {
1189    Arc::new(ConstFoldingLibfuncInfo::new(db))
1190}
1191
/// Holds static information about libfuncs required for the optimization.
///
/// Individual libfuncs are stored as single ids; families of per-type extern functions that the
/// folding logic handles uniformly are stored as `*_fns` sets.
#[derive(Debug, PartialEq, Eq)]
pub struct ConstFoldingLibfuncInfo {
    /// The `felt252_sub` libfunc.
    felt_sub: ExternFunctionId,
    /// The `into_box` libfunc.
    into_box: ExternFunctionId,
    /// The `unbox` libfunc.
    unbox: ExternFunctionId,
    /// The `box_forward_snapshot` libfunc.
    box_forward_snapshot: GenericFunctionId,
    /// The set of functions that check if a number is zero (`felt252_is_zero`,
    /// `bounded_int_is_zero` and the unsigned `*_is_zero` functions).
    nz_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions that check if numbers are equal (`*_eq` of signed and unsigned
    /// ints).
    eq_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to add unsigned ints (`*_overflowing_add`).
    uadd_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to subtract unsigned ints (`*_overflowing_sub`).
    usub_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to get the difference of signed ints (`*_diff`).
    diff_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to add signed ints (`*_overflowing_add_impl`).
    iadd_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to subtract signed ints (`*_overflowing_sub_impl`).
    isub_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to multiply integers (`bounded_int_mul` and `*_wide_mul`).
    wide_mul_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions to divide and get the remainder of integers (`bounded_int_div_rem`
    /// and the unsigned `*_safe_divmod` functions).
    div_rem_fns: OrderedHashSet<ExternFunctionId>,
    /// The `bounded_int_add` libfunc.
    bounded_int_add: ExternFunctionId,
    /// The `bounded_int_sub` libfunc.
    bounded_int_sub: ExternFunctionId,
    /// The `bounded_int_constrain` libfunc.
    bounded_int_constrain: ExternFunctionId,
    /// The `array_get` libfunc.
    array_get: ExternFunctionId,
    /// The `array_snapshot_pop_front` libfunc.
    array_snapshot_pop_front: ExternFunctionId,
    /// The `array_snapshot_pop_back` libfunc.
    array_snapshot_pop_back: ExternFunctionId,
    /// The `array_len` libfunc.
    array_len: ExternFunctionId,
    /// The `array_new` libfunc.
    array_new: ExternFunctionId,
    /// The `array_append` libfunc.
    array_append: ExternFunctionId,
    /// The `array_pop_front` libfunc.
    array_pop_front: ExternFunctionId,
    /// The `storage_base_address_from_felt252` libfunc.
    storage_base_address_from_felt252: ExternFunctionId,
    /// The `storage_base_address_const` libfunc.
    storage_base_address_const: GenericFunctionId,
    /// The `core::panic_with_felt252` function.
    panic_with_felt252: FunctionId,
    /// The `core::panic_with_const_felt252` function.
    pub panic_with_const_felt252: FreeFunctionId,
    /// The `core::panics::panic_with_byte_array` function.
    panic_with_byte_array: FunctionId,
    /// Per-type information for the supported integer types, keyed by type: the type's value
    /// range and related helper functions (such as `is_zero`, `inc` and `dec`).
    type_value_ranges: OrderedHashMap<TypeId, TypeInfo>,
    /// The info used for semantic const calculation.
    const_calculation_info: Arc<ConstCalcInfo>,
}
1256impl ConstFoldingLibfuncInfo {
1257    fn new(db: &dyn LoweringGroup) -> Self {
1258        let core = ModuleHelper::core(db);
1259        let box_module = core.submodule("box");
1260        let integer_module = core.submodule("integer");
1261        let internal_module = core.submodule("internal");
1262        let bounded_int_module = internal_module.submodule("bounded_int");
1263        let num_module = internal_module.submodule("num");
1264        let array_module = core.submodule("array");
1265        let starknet_module = core.submodule("starknet");
1266        let storage_access_module = starknet_module.submodule("storage_access");
1267        let nz_fns = OrderedHashSet::<_>::from_iter(chain!(
1268            [
1269                core.extern_function_id("felt252_is_zero"),
1270                bounded_int_module.extern_function_id("bounded_int_is_zero")
1271            ],
1272            ["u8", "u16", "u32", "u64", "u128", "u256"]
1273                .map(|ty| integer_module.extern_function_id(format!("{ty}_is_zero")))
1274        ));
1275        let utypes = ["u8", "u16", "u32", "u64", "u128"];
1276        let itypes = ["i8", "i16", "i32", "i64", "i128"];
1277        let eq_fns = OrderedHashSet::<_>::from_iter(
1278            chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(format!("{ty}_eq"))),
1279        );
1280        let uadd_fns = OrderedHashSet::<_>::from_iter(
1281            utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add"))),
1282        );
1283        let usub_fns = OrderedHashSet::<_>::from_iter(
1284            utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_sub"))),
1285        );
1286        let diff_fns = OrderedHashSet::<_>::from_iter(
1287            itypes.map(|ty| integer_module.extern_function_id(format!("{ty}_diff"))),
1288        );
1289        let iadd_fns = OrderedHashSet::<_>::from_iter(
1290            itypes
1291                .map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add_impl"))),
1292        );
1293        let isub_fns = OrderedHashSet::<_>::from_iter(
1294            itypes
1295                .map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_sub_impl"))),
1296        );
1297        let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
1298            [bounded_int_module.extern_function_id("bounded_int_mul")],
1299            ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
1300                .map(|ty| integer_module.extern_function_id(format!("{ty}_wide_mul"))),
1301        ));
1302        let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
1303            [bounded_int_module.extern_function_id("bounded_int_div_rem")],
1304            utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_safe_divmod"))),
1305        ));
1306        let type_value_ranges = OrderedHashMap::from_iter(
1307            [
1308                ("u8", BigInt::ZERO, u8::MAX.into(), false, true),
1309                ("u16", BigInt::ZERO, u16::MAX.into(), false, true),
1310                ("u32", BigInt::ZERO, u32::MAX.into(), false, true),
1311                ("u64", BigInt::ZERO, u64::MAX.into(), false, true),
1312                ("u128", BigInt::ZERO, u128::MAX.into(), false, true),
1313                ("u256", BigInt::ZERO, BigInt::from(1) << 256, false, false),
1314                ("i8", i8::MIN.into(), i8::MAX.into(), true, true),
1315                ("i16", i16::MIN.into(), i16::MAX.into(), true, true),
1316                ("i32", i32::MIN.into(), i32::MAX.into(), true, true),
1317                ("i64", i64::MIN.into(), i64::MAX.into(), true, true),
1318                ("i128", i128::MIN.into(), i128::MAX.into(), true, true),
1319            ]
1320            .map(
1321                |(ty_name, min, max, as_bounded_int, inc_dec): (
1322                    &str,
1323                    BigInt,
1324                    BigInt,
1325                    bool,
1326                    bool,
1327                )| {
1328                    let ty = corelib::get_core_ty_by_name(db, ty_name.into(), vec![]);
1329                    let is_zero = if as_bounded_int {
1330                        bounded_int_module
1331                            .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
1332                    } else {
1333                        integer_module.function_id(format!("{ty_name}_is_zero"), vec![])
1334                    }
1335                    .lowered(db);
1336                    let (inc, dec) = if inc_dec {
1337                        (
1338                            Some(
1339                                num_module
1340                                    .function_id(format!("{ty_name}_inc"), vec![])
1341                                    .lowered(db),
1342                            ),
1343                            Some(
1344                                num_module
1345                                    .function_id(format!("{ty_name}_dec"), vec![])
1346                                    .lowered(db),
1347                            ),
1348                        )
1349                    } else {
1350                        (None, None)
1351                    };
1352                    let info = TypeInfo { range: TypeRange { min, max }, is_zero, inc, dec };
1353                    (ty, info)
1354                },
1355            ),
1356        );
1357        Self {
1358            felt_sub: core.extern_function_id("felt252_sub"),
1359            into_box: box_module.extern_function_id("into_box"),
1360            unbox: box_module.extern_function_id("unbox"),
1361            box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
1362            nz_fns,
1363            eq_fns,
1364            uadd_fns,
1365            usub_fns,
1366            diff_fns,
1367            iadd_fns,
1368            isub_fns,
1369            wide_mul_fns,
1370            div_rem_fns,
1371            bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
1372            bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
1373            bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
1374            array_get: array_module.extern_function_id("array_get"),
1375            array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
1376            array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
1377            array_len: array_module.extern_function_id("array_len"),
1378            array_new: array_module.extern_function_id("array_new"),
1379            array_append: array_module.extern_function_id("array_append"),
1380            array_pop_front: array_module.extern_function_id("array_pop_front"),
1381            storage_base_address_from_felt252: storage_access_module
1382                .extern_function_id("storage_base_address_from_felt252"),
1383            storage_base_address_const: storage_access_module
1384                .generic_function_id("storage_base_address_const"),
1385            panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
1386            panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
1387            panic_with_byte_array: core
1388                .submodule("panics")
1389                .function_id("panic_with_byte_array", vec![])
1390                .lowered(db),
1391            type_value_ranges,
1392            const_calculation_info: db.const_calc_info(),
1393        }
1394    }
1395}
1396
1397impl std::ops::Deref for ConstFoldingContext<'_> {
1398    type Target = ConstFoldingLibfuncInfo;
1399    fn deref(&self) -> &ConstFoldingLibfuncInfo {
1400        &self.libfunc_info
1401    }
1402}
1403
1404impl std::ops::Deref for ConstFoldingLibfuncInfo {
1405    type Target = ConstCalcInfo;
1406    fn deref(&self) -> &ConstCalcInfo {
1407        &self.const_calculation_info
1408    }
1409}
1410
/// Information about a numeric type that const folding needs: its value range and the
/// functions used to test for zero and to step the value by one.
#[derive(Debug, PartialEq, Eq)]
struct TypeInfo {
    /// The inclusive value range of the type.
    range: TypeRange,
    /// The function that checks whether a value of the type is zero.
    is_zero: FunctionId,
    /// The function that increments a value by one, or `None` if none is registered for the
    /// type.
    inc: Option<FunctionId>,
    /// The function that decrements a value by one, or `None` if none is registered for the
    /// type.
    dec: Option<FunctionId>,
}
1423
/// The range of values of a numeric type; both bounds are inclusive.
#[derive(Debug, PartialEq, Eq, Clone)]
struct TypeRange {
    /// The minimum value of the type (inclusive).
    min: BigInt,
    /// The maximum value of the type (inclusive).
    max: BigInt,
}
1432impl TypeRange {
1433    /// Normalizes the value to the range.
1434    /// Assumes the value is within size of range of the range.
1435    fn normalized(&self, value: BigInt) -> NormalizedResult {
1436        if value < self.min {
1437            NormalizedResult::Under(value - &self.min + &self.max + 1)
1438        } else if value > self.max {
1439            NormalizedResult::Over(value + &self.min - &self.max - 1)
1440        } else {
1441            NormalizedResult::InRange(value)
1442        }
1443    }
1444}
1445
/// The result of normalizing a value into a type's range (see `TypeRange::normalized`).
enum NormalizedResult {
    /// The original value is within the range; carries it unchanged.
    InRange(BigInt),
    /// The original value is above the range max; carries it minus the range size.
    Over(BigInt),
    /// The original value is below the range min; carries it plus the range size.
    Under(BigInt),
}