cairo_lang_lowering/optimizations/
const_folding.rs

1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::sync::Arc;
6
7use cairo_lang_defs::ids::{ExternFunctionId, ModuleId};
8use cairo_lang_semantic::helper::ModuleHelper;
9use cairo_lang_semantic::items::constant::ConstValue;
10use cairo_lang_semantic::items::imp::ImplLookupContext;
11use cairo_lang_semantic::{GenericArgumentId, MatchArmSelector, TypeId, corelib};
12use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
13use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
14use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
15use cairo_lang_utils::{Intern, LookupIntern, extract_matches, try_extract_matches};
16use id_arena::Arena;
17use itertools::{chain, zip_eq};
18use num_bigint::BigInt;
19use num_integer::Integer;
20use num_traits::Zero;
21
22use crate::db::LoweringGroup;
23use crate::ids::{FunctionId, SemanticFunctionIdEx};
24use crate::{
25    BlockId, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchExternInfo, MatchInfo,
26    Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
27    StatementStructConstruct, StatementStructDestructure, VarUsage, Variable, VariableId,
28};
29
30/// Keeps track of equivalent values that a variables might be replaced with.
31/// Note: We don't keep track of types as we assume the usage is always correct.
32#[derive(Debug, Clone)]
33enum VarInfo {
34    /// The variable is a const value.
35    Const(ConstValue),
36    /// The variable can be replaced by another variable.
37    Var(VarUsage),
38    /// The variable is a snapshot of another variable.
39    Snapshot(Box<VarInfo>),
40    /// The variable is a struct of other variables.
41    /// `None` values represent variables that are not tracked.
42    Struct(Vec<Option<VarInfo>>),
43}
44
45#[derive(Debug, Clone, Copy, PartialEq)]
46enum Reachability {
47    /// The block is not reachable from the function start after const-folding.
48    Unreachable,
49    /// The block is reachable from the function start only through the goto at the end of the given
50    /// block.
51    FromSingleGoto(BlockId),
52    /// The block is reachable from the function start after const-folding - just does not fit
53    /// `FromSingleGoto`.
54    Any,
55}
56
57/// Performs constant folding on the lowered program.
58/// The optimization only works when the blocks are topologically sorted.
59pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
60    if db.optimization_config().skip_const_folding || lowered.blocks.is_empty() {
61        return;
62    }
63    let libfunc_info = priv_const_folding_info(db);
64    // Note that we can keep the var_info across blocks because the lowering
65    // is in static single assignment form.
66    let mut ctx = ConstFoldingContext {
67        db,
68        var_info: UnorderedHashMap::default(),
69        variables: &mut lowered.variables,
70        libfunc_info: &libfunc_info,
71    };
72    let mut reachability = vec![Reachability::Unreachable; lowered.blocks.len()];
73    reachability[0] = Reachability::Any;
74    for block_id in 0..lowered.blocks.len() {
75        match reachability[block_id] {
76            Reachability::Unreachable => continue,
77            Reachability::Any => {}
78            Reachability::FromSingleGoto(from_block) => match &lowered.blocks[from_block].end {
79                FlatBlockEnd::Goto(_, remapping) => {
80                    for (dst, src) in remapping.iter() {
81                        if let Some(v) = ctx.as_const(src.var_id) {
82                            ctx.var_info.insert(*dst, VarInfo::Const(v.clone()));
83                        }
84                    }
85                }
86                _ => unreachable!("Expected a goto end"),
87            },
88        }
89        let block = &mut lowered.blocks[BlockId(block_id)];
90        let mut additional_consts = vec![];
91        for stmt in block.statements.iter_mut() {
92            ctx.maybe_replace_inputs(stmt.inputs_mut());
93            match stmt {
94                Statement::Const(StatementConst { value, output }) => {
95                    // Preventing the insertion of non-member consts values (such as a `Box` of a
96                    // const).
97                    if matches!(
98                        value,
99                        ConstValue::Int(..)
100                            | ConstValue::Struct(..)
101                            | ConstValue::Enum(..)
102                            | ConstValue::NonZero(..)
103                    ) {
104                        ctx.var_info.insert(*output, VarInfo::Const(value.clone()));
105                    }
106                }
107                Statement::Snapshot(stmt) => {
108                    if let Some(info) = ctx.var_info.get(&stmt.input.var_id).cloned() {
109                        ctx.var_info.insert(stmt.original(), info.clone());
110                        ctx.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info.into()));
111                    }
112                }
113                Statement::Desnap(StatementDesnap { input, output }) => {
114                    if let Some(VarInfo::Snapshot(info)) = ctx.var_info.get(&input.var_id) {
115                        ctx.var_info.insert(*output, info.as_ref().clone());
116                    }
117                }
118                Statement::Call(call_stmt) => {
119                    if let Some(updated_stmt) =
120                        ctx.handle_statement_call(call_stmt, &mut additional_consts)
121                    {
122                        *stmt = Statement::Const(updated_stmt);
123                    }
124                }
125                Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
126                    let mut const_args = vec![];
127                    let mut all_args = vec![];
128                    let mut contains_info = false;
129                    for input in inputs.iter() {
130                        let Some(info) = ctx.var_info.get(&input.var_id) else {
131                            all_args.push(
132                                ctx.variables[input.var_id]
133                                    .copyable
134                                    .is_ok()
135                                    .then_some(VarInfo::Var(*input)),
136                            );
137                            continue;
138                        };
139                        contains_info = true;
140                        if let VarInfo::Const(value) = info {
141                            const_args.push(value.clone());
142                        }
143                        all_args.push(Some(info.clone()));
144                    }
145                    if const_args.len() == inputs.len() {
146                        let value = ConstValue::Struct(const_args, ctx.variables[*output].ty);
147                        ctx.var_info.insert(*output, VarInfo::Const(value));
148                    } else if contains_info {
149                        ctx.var_info.insert(*output, VarInfo::Struct(all_args));
150                    }
151                }
152                Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
153                    if let Some(mut info) = ctx.var_info.get(&input.var_id) {
154                        let mut n_snapshot = 0;
155                        while let VarInfo::Snapshot(inner) = info {
156                            info = inner.as_ref();
157                            n_snapshot += 1;
158                        }
159                        let wrap_with_snapshots = |mut info| {
160                            for _ in 0..n_snapshot {
161                                info = VarInfo::Snapshot(Box::new(info));
162                            }
163                            info
164                        };
165                        match info {
166                            VarInfo::Const(ConstValue::Struct(member_values, _)) => {
167                                for (output, value) in zip_eq(outputs, member_values.clone()) {
168                                    ctx.var_info.insert(
169                                        *output,
170                                        wrap_with_snapshots(VarInfo::Const(value)),
171                                    );
172                                }
173                            }
174                            VarInfo::Struct(members) => {
175                                for (output, member) in zip_eq(outputs, members.clone()) {
176                                    if let Some(member) = member {
177                                        ctx.var_info.insert(*output, wrap_with_snapshots(member));
178                                    }
179                                }
180                            }
181                            _ => {}
182                        }
183                    }
184                }
185                Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
186                    if let Some(VarInfo::Const(val)) = ctx.var_info.get(&input.var_id) {
187                        let value = ConstValue::Enum(variant.clone(), val.clone().into());
188                        ctx.var_info.insert(*output, VarInfo::Const(value.clone()));
189                    }
190                }
191            }
192        }
193        block.statements.splice(0..0, additional_consts.into_iter().map(Statement::Const));
194
195        match &mut block.end {
196            FlatBlockEnd::Goto(_, remappings) => {
197                for (_, v) in remappings.iter_mut() {
198                    ctx.maybe_replace_input(v);
199                }
200            }
201            FlatBlockEnd::Match { info } => {
202                ctx.maybe_replace_inputs(info.inputs_mut());
203                match info {
204                    MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) => {
205                        if let Some(VarInfo::Const(ConstValue::Enum(variant, value))) =
206                            ctx.var_info.get(&input.var_id)
207                        {
208                            let arm = &arms[variant.idx];
209                            ctx.var_info
210                                .insert(arm.var_ids[0], VarInfo::Const(value.as_ref().clone()));
211                        }
212                    }
213                    MatchInfo::Extern(info) => {
214                        if let Some((extra_stmt, updated_end)) = ctx.handle_extern_block_end(info) {
215                            if let Some(stmt) = extra_stmt {
216                                block.statements.push(Statement::Const(stmt));
217                            }
218                            block.end = updated_end;
219                        }
220                    }
221                    MatchInfo::Value(..) => {}
222                }
223            }
224            FlatBlockEnd::Return(ref mut inputs, _) => ctx.maybe_replace_inputs(inputs),
225            FlatBlockEnd::Panic(_) | FlatBlockEnd::NotSet => unreachable!(),
226        }
227        match &block.end {
228            FlatBlockEnd::Goto(dst_block_id, _) => {
229                reachability[dst_block_id.0] = match reachability[dst_block_id.0] {
230                    Reachability::Unreachable => Reachability::FromSingleGoto(BlockId(block_id)),
231                    Reachability::FromSingleGoto(_) | Reachability::Any => Reachability::Any,
232                }
233            }
234            FlatBlockEnd::Match { info } => {
235                for arm in info.arms() {
236                    assert_eq!(reachability[arm.block_id.0], Reachability::Unreachable);
237                    reachability[arm.block_id.0] = Reachability::Any;
238                }
239            }
240            FlatBlockEnd::NotSet | FlatBlockEnd::Return(..) | FlatBlockEnd::Panic(..) => {}
241        }
242    }
243}
244
245struct ConstFoldingContext<'a> {
246    /// The used database.
247    db: &'a dyn LoweringGroup,
248    /// The variables arena, mostly used to get the type of variables.
249    variables: &'a mut Arena<Variable>,
250    /// The accumulated information about the const values of variables.
251    var_info: UnorderedHashMap<VariableId, VarInfo>,
252    /// The libfunc information.
253    libfunc_info: &'a ConstFoldingLibfuncInfo,
254}
255
256impl ConstFoldingContext<'_> {
257    /// Handles a statement call.
258    ///
259    /// Returns None if no additional changes are required.
260    /// If changes are required, returns an updated const-statement (to override the current
261    /// statement).
262    /// May add an additional const to `additional_consts` if just replacing the current statement
263    /// is not enough.
264    fn handle_statement_call(
265        &mut self,
266        stmt: &mut StatementCall,
267        additional_consts: &mut Vec<StatementConst>,
268    ) -> Option<StatementConst> {
269        if stmt.function == self.panic_with_felt252 {
270            let val = self.as_const(stmt.inputs[0].var_id)?;
271            stmt.inputs.clear();
272            stmt.function = ModuleHelper::core(self.db.upcast())
273                .function_id(
274                    "panic_with_const_felt252",
275                    vec![GenericArgumentId::Constant(val.clone().intern(self.db))],
276                )
277                .lowered(self.db);
278            return None;
279        }
280        let (id, _generic_args) = stmt.function.get_extern(self.db)?;
281        if id == self.felt_sub {
282            // (a - 0) can be replaced by a.
283            let val = self.as_int(stmt.inputs[1].var_id)?;
284            if val.is_zero() {
285                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
286            }
287            None
288        } else if self.wide_mul_fns.contains(&id) {
289            let lhs = self.as_int_ex(stmt.inputs[0].var_id);
290            let rhs = self.as_int(stmt.inputs[1].var_id);
291            let output = stmt.outputs[0];
292            if lhs.map(|(v, _)| v.is_zero()).unwrap_or_default()
293                || rhs.map(Zero::is_zero).unwrap_or_default()
294            {
295                return Some(self.propagate_zero_and_get_statement(output));
296            }
297            let (lhs, nz_ty) = lhs?;
298            Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0], nz_ty))
299        } else if id == self.bounded_int_add || id == self.bounded_int_sub {
300            let lhs = self.as_int(stmt.inputs[0].var_id)?;
301            let rhs = self.as_int(stmt.inputs[1].var_id)?;
302            let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
303            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0], false))
304        } else if self.div_rem_fns.contains(&id) {
305            let lhs = self.as_int(stmt.inputs[0].var_id);
306            if lhs.map(Zero::is_zero).unwrap_or_default() {
307                additional_consts.push(self.propagate_zero_and_get_statement(stmt.outputs[1]));
308                return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
309            }
310            let rhs = self.as_int(stmt.inputs[1].var_id)?;
311            let (q, r) = lhs?.div_rem(rhs);
312            let q_output = stmt.outputs[0];
313            let q_value = ConstValue::Int(q, self.variables[q_output].ty);
314            self.var_info.insert(q_output, VarInfo::Const(q_value.clone()));
315            let r_output = stmt.outputs[1];
316            let r_value = ConstValue::Int(r, self.variables[r_output].ty);
317            self.var_info.insert(r_output, VarInfo::Const(r_value.clone()));
318            additional_consts.push(StatementConst { value: r_value, output: r_output });
319            Some(StatementConst { value: q_value, output: q_output })
320        } else if id == self.storage_base_address_from_felt252 {
321            let input_var = stmt.inputs[0].var_id;
322            if let Some(ConstValue::Int(val, ty)) = self.as_const(input_var) {
323                stmt.inputs.clear();
324                stmt.function =
325                    ModuleHelper { db: self.db.upcast(), id: self.storage_access_module }
326                        .function_id(
327                            "storage_base_address_const",
328                            vec![GenericArgumentId::Constant(
329                                ConstValue::Int(val.clone(), *ty).intern(self.db),
330                            )],
331                        )
332                        .lowered(self.db);
333            }
334            None
335        } else if id == self.into_box {
336            let const_value = match self.var_info.get(&stmt.inputs[0].var_id)? {
337                VarInfo::Const(val) => val,
338                VarInfo::Snapshot(info) => try_extract_matches!(info.as_ref(), VarInfo::Const)?,
339                _ => return None,
340            };
341            let value = ConstValue::Boxed(const_value.clone().into());
342            // Not inserting the value into the `var_info` map because the
343            // resulting box isn't an actual const at the Sierra level.
344            Some(StatementConst { value, output: stmt.outputs[0] })
345        } else if id == self.upcast {
346            let int_value = self.as_int(stmt.inputs[0].var_id)?;
347            let output = stmt.outputs[0];
348            let value = ConstValue::Int(int_value.clone(), self.variables[output].ty);
349            self.var_info.insert(output, VarInfo::Const(value.clone()));
350            Some(StatementConst { value, output })
351        } else {
352            None
353        }
354    }
355
356    /// Adds `value` as a const to `var_info` and return a const statement for it.
357    fn propagate_const_and_get_statement(
358        &mut self,
359        value: BigInt,
360        output: VariableId,
361        nz_ty: bool,
362    ) -> StatementConst {
363        let mut value = ConstValue::Int(value, self.variables[output].ty);
364        if nz_ty {
365            value = ConstValue::NonZero(Box::new(value));
366        }
367        self.var_info.insert(output, VarInfo::Const(value.clone()));
368        StatementConst { value, output }
369    }
370
371    /// Adds 0 const to `var_info` and return a const statement for it.
372    fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> StatementConst {
373        self.propagate_const_and_get_statement(BigInt::zero(), output, false)
374    }
375
376    /// Handles the end of an extern block.
377    /// Returns None if no additional changes are required.
378    /// If changes are required, returns a possible additional const-statement to the block, as well
379    /// as an updated block end.
380    fn handle_extern_block_end(
381        &mut self,
382        info: &mut MatchExternInfo,
383    ) -> Option<(Option<StatementConst>, FlatBlockEnd)> {
384        let (id, generic_args) = info.function.get_extern(self.db)?;
385        if self.nz_fns.contains(&id) {
386            let val = self.as_const(info.inputs[0].var_id)?;
387            let is_zero = match val {
388                ConstValue::Int(v, _) => v.is_zero(),
389                ConstValue::Struct(s, _) => s.iter().all(|v| {
390                    v.clone().into_int().expect("Expected ConstValue::Int for size").is_zero()
391                }),
392                _ => unreachable!(),
393            };
394            Some(if is_zero {
395                (None, FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()))
396            } else {
397                let arm = &info.arms[1];
398                let nz_var = arm.var_ids[0];
399                let nz_val = ConstValue::NonZero(Box::new(val.clone()));
400                self.var_info.insert(nz_var, VarInfo::Const(nz_val.clone()));
401                (
402                    Some(StatementConst { value: nz_val, output: nz_var }),
403                    FlatBlockEnd::Goto(arm.block_id, Default::default()),
404                )
405            })
406        } else if self.eq_fns.contains(&id) {
407            let lhs = self.as_int(info.inputs[0].var_id);
408            let rhs = self.as_int(info.inputs[1].var_id);
409            if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
410                || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
411            {
412                let db = self.db.upcast();
413                let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
414                let var = &self.variables[nz_input.var_id].clone();
415                let function = self.type_value_ranges.get(&var.ty)?.is_zero;
416                let unused_nz_var = Variable::new(
417                    self.db,
418                    ImplLookupContext::default(),
419                    corelib::core_nonzero_ty(db, var.ty),
420                    var.location,
421                );
422                let unused_nz_var = self.variables.alloc(unused_nz_var);
423                return Some((
424                    None,
425                    FlatBlockEnd::Match {
426                        info: MatchInfo::Extern(MatchExternInfo {
427                            function,
428                            inputs: vec![nz_input],
429                            arms: vec![
430                                MatchArm {
431                                    arm_selector: MatchArmSelector::VariantId(
432                                        corelib::jump_nz_zero_variant(db, var.ty),
433                                    ),
434                                    block_id: info.arms[1].block_id,
435                                    var_ids: vec![],
436                                },
437                                MatchArm {
438                                    arm_selector: MatchArmSelector::VariantId(
439                                        corelib::jump_nz_nonzero_variant(db, var.ty),
440                                    ),
441                                    block_id: info.arms[0].block_id,
442                                    var_ids: vec![unused_nz_var],
443                                },
444                            ],
445                            location: info.location,
446                        }),
447                    },
448                ));
449            }
450            Some((
451                None,
452                FlatBlockEnd::Goto(
453                    info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
454                    Default::default(),
455                ),
456            ))
457        } else if self.uadd_fns.contains(&id)
458            || self.usub_fns.contains(&id)
459            || self.diff_fns.contains(&id)
460            || self.iadd_fns.contains(&id)
461            || self.isub_fns.contains(&id)
462        {
463            let rhs = self.as_int(info.inputs[1].var_id);
464            if rhs.map(Zero::is_zero).unwrap_or_default() && !self.diff_fns.contains(&id) {
465                let arm = &info.arms[0];
466                self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]));
467                return Some((None, FlatBlockEnd::Goto(arm.block_id, Default::default())));
468            }
469            let lhs = self.as_int(info.inputs[0].var_id);
470            let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
471                if lhs.map(Zero::is_zero).unwrap_or_default() {
472                    let arm = &info.arms[0];
473                    self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]));
474                    return Some((None, FlatBlockEnd::Goto(arm.block_id, Default::default())));
475                }
476                lhs? + rhs?
477            } else {
478                lhs? - rhs?
479            };
480            let ty = self.variables[info.arms[0].var_ids[0]].ty;
481            let range = self.type_value_ranges.get(&ty)?;
482            let (arm_index, value) = match range.normalized(value) {
483                NormalizedResult::InRange(value) => (0, value),
484                NormalizedResult::Under(value) => (1, value),
485                NormalizedResult::Over(value) => (
486                    if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) { 2 } else { 1 },
487                    value,
488                ),
489            };
490            let arm = &info.arms[arm_index];
491            let actual_output = arm.var_ids[0];
492            let value = ConstValue::Int(value, ty);
493            self.var_info.insert(actual_output, VarInfo::Const(value.clone()));
494            Some((
495                Some(StatementConst { value, output: actual_output }),
496                FlatBlockEnd::Goto(arm.block_id, Default::default()),
497            ))
498        } else if id == self.downcast {
499            let input_var = info.inputs[0].var_id;
500            let value = self.as_int(input_var)?;
501            let success_output = info.arms[0].var_ids[0];
502            let ty = self.variables[success_output].ty;
503            let range = self.type_value_ranges.get(&ty)?;
504            Some(if let NormalizedResult::InRange(value) = range.normalized(value.clone()) {
505                let value = ConstValue::Int(value, ty);
506                self.var_info.insert(success_output, VarInfo::Const(value.clone()));
507                (
508                    Some(StatementConst { value, output: success_output }),
509                    FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()),
510                )
511            } else {
512                (None, FlatBlockEnd::Goto(info.arms[1].block_id, Default::default()))
513            })
514        } else if id == self.bounded_int_constrain {
515            let input_var = info.inputs[0].var_id;
516            let (value, nz_ty) = self.as_int_ex(input_var)?;
517            let generic_arg = generic_args[1];
518            let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
519                .lookup_intern(self.db)
520                .into_int()
521                .unwrap();
522            let arm_idx = if value < &constrain_value { 0 } else { 1 };
523            let output = info.arms[arm_idx].var_ids[0];
524            Some((
525                Some(self.propagate_const_and_get_statement(value.clone(), output, nz_ty)),
526                FlatBlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()),
527            ))
528        } else if id == self.array_get {
529            if self.as_int(info.inputs[1].var_id)?.is_zero() {
530                if let [success, failure] = info.arms.as_mut_slice() {
531                    let arr = info.inputs[0].var_id;
532                    let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
533                    let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
534                    info.inputs.truncate(1);
535                    info.function = ModuleHelper { db: self.db.upcast(), id: self.array_module }
536                        .function_id("array_snapshot_pop_front", generic_args)
537                        .lowered(self.db);
538                    success.var_ids.insert(0, unused_arr_output0);
539                    failure.var_ids.insert(0, unused_arr_output1);
540                }
541            }
542            None
543        } else {
544            None
545        }
546    }
547
548    /// Returns the const value of a variable if it exists.
549    fn as_const(&self, var_id: VariableId) -> Option<&ConstValue> {
550        try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const)
551    }
552
553    /// Return the const value as an int if it exists and is an integer, additionally, if it is of a
554    /// non-zero type.
555    fn as_int_ex(&self, var_id: VariableId) -> Option<(&BigInt, bool)> {
556        match self.as_const(var_id)? {
557            ConstValue::Int(value, _) => Some((value, false)),
558            ConstValue::NonZero(const_value) => {
559                if let ConstValue::Int(value, _) = const_value.as_ref() {
560                    Some((value, true))
561                } else {
562                    None
563                }
564            }
565            _ => None,
566        }
567    }
568
569    /// Return the const value as a int if it exists and is an integer.
570    fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
571        Some(self.as_int_ex(var_id)?.0)
572    }
573
574    /// Replaces the inputs in place if they are in the var_info map.
575    fn maybe_replace_inputs(&mut self, inputs: &mut [VarUsage]) {
576        for input in inputs {
577            self.maybe_replace_input(input);
578        }
579    }
580
581    /// Replaces the input in place if it is in the var_info map.
582    fn maybe_replace_input(&mut self, input: &mut VarUsage) {
583        if let Some(VarInfo::Var(new_var)) = self.var_info.get(&input.var_id) {
584            *input = *new_var;
585        }
586    }
587}
588
589/// Query implementation of [LoweringGroup::priv_const_folding_info].
590pub fn priv_const_folding_info(
591    db: &dyn LoweringGroup,
592) -> Arc<crate::optimizations::const_folding::ConstFoldingLibfuncInfo> {
593    Arc::new(ConstFoldingLibfuncInfo::new(db))
594}
595
596/// Holds static information about libfuncs required for the optimization.
597#[derive(Debug, PartialEq, Eq)]
598pub struct ConstFoldingLibfuncInfo {
599    /// The `felt252_sub` libfunc.
600    felt_sub: ExternFunctionId,
601    /// The `into_box` libfunc.
602    into_box: ExternFunctionId,
603    /// The `upcast` libfunc.
604    upcast: ExternFunctionId,
605    /// The `downcast` libfunc.
606    downcast: ExternFunctionId,
607    /// The set of functions that check if a number is zero.
608    nz_fns: OrderedHashSet<ExternFunctionId>,
609    /// The set of functions that check if numbers are equal.
610    eq_fns: OrderedHashSet<ExternFunctionId>,
611    /// The set of functions to add unsigned ints.
612    uadd_fns: OrderedHashSet<ExternFunctionId>,
613    /// The set of functions to subtract unsigned ints.
614    usub_fns: OrderedHashSet<ExternFunctionId>,
615    /// The set of functions to get the difference of signed ints.
616    diff_fns: OrderedHashSet<ExternFunctionId>,
617    /// The set of functions to add signed ints.
618    iadd_fns: OrderedHashSet<ExternFunctionId>,
619    /// The set of functions to subtract signed ints.
620    isub_fns: OrderedHashSet<ExternFunctionId>,
621    /// The set of functions to multiply integers.
622    wide_mul_fns: OrderedHashSet<ExternFunctionId>,
623    /// The set of functions to divide and get the remainder of integers.
624    div_rem_fns: OrderedHashSet<ExternFunctionId>,
625    /// The `bounded_int_add` libfunc.
626    bounded_int_add: ExternFunctionId,
627    /// The `bounded_int_sub` libfunc.
628    bounded_int_sub: ExternFunctionId,
629    /// The `bounded_int_constrain` libfunc.
630    bounded_int_constrain: ExternFunctionId,
631    /// The array module.
632    array_module: ModuleId,
633    /// The `array_get` libfunc.
634    array_get: ExternFunctionId,
635    /// The storage access module.
636    storage_access_module: ModuleId,
637    /// The `storage_base_address_from_felt252` libfunc.
638    storage_base_address_from_felt252: ExternFunctionId,
639    /// The `core::panic_with_felt252` function.
640    panic_with_felt252: FunctionId,
641    /// Type ranges.
642    type_value_ranges: OrderedHashMap<TypeId, TypeInfo>,
643}
644impl ConstFoldingLibfuncInfo {
645    fn new(db: &dyn LoweringGroup) -> Self {
646        let core = ModuleHelper::core(db.upcast());
647        let felt_sub = core.extern_function_id("felt252_sub");
648        let box_module = core.submodule("box");
649        let into_box = box_module.extern_function_id("into_box");
650        let integer_module = core.submodule("integer");
651        let bounded_int_module = core.submodule("internal").submodule("bounded_int");
652        let upcast = integer_module.extern_function_id("upcast");
653        let downcast = integer_module.extern_function_id("downcast");
654        let array_module = core.submodule("array");
655        let array_get = array_module.extern_function_id("array_get");
656        let starknet_module = core.submodule("starknet");
657        let storage_access_module = starknet_module.submodule("storage_access");
658        let storage_base_address_from_felt252 =
659            storage_access_module.extern_function_id("storage_base_address_from_felt252");
660        let nz_fns = OrderedHashSet::<_>::from_iter(chain!(
661            [
662                core.extern_function_id("felt252_is_zero"),
663                bounded_int_module.extern_function_id("bounded_int_is_zero")
664            ],
665            ["u8", "u16", "u32", "u64", "u128", "u256"]
666                .map(|ty| integer_module.extern_function_id(format!("{ty}_is_zero")))
667        ));
668        let utypes = ["u8", "u16", "u32", "u64", "u128"];
669        let itypes = ["i8", "i16", "i32", "i64", "i128"];
670        let eq_fns = OrderedHashSet::<_>::from_iter(
671            chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(format!("{ty}_eq"))),
672        );
673        let uadd_fns = OrderedHashSet::<_>::from_iter(
674            utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add"))),
675        );
676        let usub_fns = OrderedHashSet::<_>::from_iter(
677            utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_sub"))),
678        );
679        let diff_fns = OrderedHashSet::<_>::from_iter(
680            itypes.map(|ty| integer_module.extern_function_id(format!("{ty}_diff"))),
681        );
682        let iadd_fns = OrderedHashSet::<_>::from_iter(
683            itypes
684                .map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add_impl"))),
685        );
686        let isub_fns = OrderedHashSet::<_>::from_iter(
687            itypes
688                .map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_sub_impl"))),
689        );
690        let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
691            [bounded_int_module.extern_function_id("bounded_int_mul")],
692            ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
693                .map(|ty| integer_module.extern_function_id(format!("{ty}_wide_mul"))),
694        ));
695        let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
696            [bounded_int_module.extern_function_id("bounded_int_div_rem")],
697            utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_safe_divmod"))),
698        ));
699        let bounded_int_add = bounded_int_module.extern_function_id("bounded_int_add");
700        let bounded_int_sub = bounded_int_module.extern_function_id("bounded_int_sub");
701        let bounded_int_constrain = bounded_int_module.extern_function_id("bounded_int_constrain");
702        let type_value_ranges = OrderedHashMap::from_iter(
703            [
704                ("u8", BigInt::ZERO, u8::MAX.into(), false),
705                ("u16", BigInt::ZERO, u16::MAX.into(), false),
706                ("u32", BigInt::ZERO, u32::MAX.into(), false),
707                ("u64", BigInt::ZERO, u64::MAX.into(), false),
708                ("u128", BigInt::ZERO, u128::MAX.into(), false),
709                ("u256", BigInt::ZERO, BigInt::from(1) << 256, false),
710                ("i8", i8::MIN.into(), i8::MAX.into(), true),
711                ("i16", i16::MIN.into(), i16::MAX.into(), true),
712                ("i32", i32::MIN.into(), i32::MAX.into(), true),
713                ("i64", i64::MIN.into(), i64::MAX.into(), true),
714                ("i128", i128::MIN.into(), i128::MAX.into(), true),
715            ]
716            .map(
717                |(ty_name, min, max, as_bounded_int): (&str, BigInt, BigInt, bool)| {
718                    let ty = corelib::get_core_ty_by_name(db.upcast(), ty_name.into(), vec![]);
719                    let is_zero = if as_bounded_int {
720                        bounded_int_module
721                            .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
722                    } else {
723                        integer_module.function_id(format!("{ty_name}_is_zero"), vec![])
724                    }
725                    .lowered(db);
726                    let info = TypeInfo { min, max, is_zero };
727                    (ty, info)
728                },
729            ),
730        );
731        Self {
732            felt_sub,
733            into_box,
734            upcast,
735            downcast,
736            nz_fns,
737            eq_fns,
738            uadd_fns,
739            usub_fns,
740            diff_fns,
741            iadd_fns,
742            isub_fns,
743            wide_mul_fns,
744            div_rem_fns,
745            bounded_int_add,
746            bounded_int_sub,
747            bounded_int_constrain,
748            array_module: array_module.id,
749            array_get,
750            storage_access_module: storage_access_module.id,
751            storage_base_address_from_felt252,
752            panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
753            type_value_ranges,
754        }
755    }
756}
757
758impl std::ops::Deref for ConstFoldingContext<'_> {
759    type Target = ConstFoldingLibfuncInfo;
760    fn deref(&self) -> &ConstFoldingLibfuncInfo {
761        self.libfunc_info
762    }
763}
764
765/// The information of a type required for const foldings.
766#[derive(Debug, PartialEq, Eq)]
767struct TypeInfo {
768    /// The minimum value of the type.
769    min: BigInt,
770    /// The maximum value of the type.
771    max: BigInt,
772    /// The function to check if the value is zero for the type.
773    is_zero: FunctionId,
774}
775impl TypeInfo {
776    /// Normalizes the value to the range.
777    /// Assumes the value is within size of range of the range.
778    fn normalized(&self, value: BigInt) -> NormalizedResult {
779        if value < self.min {
780            NormalizedResult::Under(value - &self.min + &self.max + 1)
781        } else if value > self.max {
782            NormalizedResult::Over(value + &self.min - &self.max - 1)
783        } else {
784            NormalizedResult::InRange(value)
785        }
786    }
787}
788
789/// The result of normalizing a value to a range.
790enum NormalizedResult {
791    /// The original value is in the range, carries the value, or an equivalent value.
792    InRange(BigInt),
793    /// The original value is larger than range max, carries the normalized value.
794    Over(BigInt),
795    /// The original value is smaller than range min, carries the normalized value.
796    Under(BigInt),
797}