1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::sync::Arc;
6
7use cairo_lang_defs::ids::{ExternFunctionId, ModuleId};
8use cairo_lang_semantic::helper::ModuleHelper;
9use cairo_lang_semantic::items::constant::ConstValue;
10use cairo_lang_semantic::items::imp::ImplLookupContext;
11use cairo_lang_semantic::{GenericArgumentId, MatchArmSelector, TypeId, corelib};
12use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
13use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
14use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
15use cairo_lang_utils::{Intern, LookupIntern, extract_matches, try_extract_matches};
16use id_arena::Arena;
17use itertools::{chain, zip_eq};
18use num_bigint::BigInt;
19use num_integer::Integer;
20use num_traits::Zero;
21
22use crate::db::LoweringGroup;
23use crate::ids::{FunctionId, SemanticFunctionIdEx};
24use crate::{
25 BlockId, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchExternInfo, MatchInfo,
26 Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
27 StatementStructConstruct, StatementStructDestructure, VarUsage, Variable, VariableId,
28};
29
/// Information tracked by const folding about the value held by a variable.
#[derive(Debug, Clone)]
enum VarInfo {
    /// The variable is known to hold this constant value.
    Const(ConstValue),
    /// The variable is equivalent to another variable; usages of it are replaced by the given
    /// usage (see `maybe_replace_input`).
    Var(VarUsage),
    /// The variable is a snapshot of a value described by the inner info.
    Snapshot(Box<VarInfo>),
    /// The variable is a struct built from members with the given infos, in member order.
    /// A `None` entry is a member about which nothing is tracked.
    Struct(Vec<Option<VarInfo>>),
}
44
/// How a block was found to be reachable while const folding walks the blocks.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Reachability {
    /// No processed block end leads to the block (so far); such blocks are skipped.
    Unreachable,
    /// The only block end seen leading to the block is the goto at the end of the given block.
    FromSingleGoto(BlockId),
    /// The block is reached in some other way: it is block 0, a match arm target, or the target
    /// of more than one goto.
    Any,
}
56
/// Performs constant folding on the lowered function, in place.
///
/// Does nothing when the optimization config has `skip_const_folding` set, or when `lowered` has
/// no blocks. Otherwise the blocks are visited in index order; for every visited block:
/// * statement inputs known to be equivalent to other variables are replaced,
/// * information about statement outputs is recorded in the context's `var_info` map,
/// * call statements that can be fully evaluated are replaced by const statements,
/// * the block end has its inputs replaced, and extern matches that can be decided are replaced by
///   the end returned from `handle_extern_block_end`.
///
/// Blocks that no processed block end leads to are skipped.
pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
    if db.optimization_config().skip_const_folding || lowered.blocks.is_empty() {
        return;
    }
    let libfunc_info = priv_const_folding_info(db);
    // A single context (and therefore a single `var_info` map) is shared by all the blocks.
    let mut ctx = ConstFoldingContext {
        db,
        var_info: UnorderedHashMap::default(),
        variables: &mut lowered.variables,
        libfunc_info: &libfunc_info,
    };
    // Block 0 is always processed; every other block starts as unreachable and is marked when a
    // processed block end leads to it.
    let mut reachability = vec![Reachability::Unreachable; lowered.blocks.len()];
    reachability[0] = Reachability::Any;
    for block_id in 0..lowered.blocks.len() {
        match reachability[block_id] {
            Reachability::Unreachable => continue,
            Reachability::Any => {}
            // The block has a single known predecessor goto: constants passed through its
            // remapping are also constants in the remapped (destination) variables.
            Reachability::FromSingleGoto(from_block) => match &lowered.blocks[from_block].end {
                FlatBlockEnd::Goto(_, remapping) => {
                    for (dst, src) in remapping.iter() {
                        if let Some(v) = ctx.as_const(src.var_id) {
                            ctx.var_info.insert(*dst, VarInfo::Const(v.clone()));
                        }
                    }
                }
                _ => unreachable!("Expected a goto end"),
            },
        }
        let block = &mut lowered.blocks[BlockId(block_id)];
        // Const statements for extra outputs of folded calls; inserted at the block start below.
        let mut additional_consts = vec![];
        for stmt in block.statements.iter_mut() {
            ctx.maybe_replace_inputs(stmt.inputs_mut());
            match stmt {
                Statement::Const(StatementConst { value, output }) => {
                    // Only these kinds of const values are tracked; other kinds are not recorded.
                    if matches!(
                        value,
                        ConstValue::Int(..)
                            | ConstValue::Struct(..)
                            | ConstValue::Enum(..)
                            | ConstValue::NonZero(..)
                    ) {
                        ctx.var_info.insert(*output, VarInfo::Const(value.clone()));
                    }
                }
                Statement::Snapshot(stmt) => {
                    // The `original` output keeps the input's info; the `snapshot` output gets it
                    // wrapped as a snapshot.
                    if let Some(info) = ctx.var_info.get(&stmt.input.var_id).cloned() {
                        ctx.var_info.insert(stmt.original(), info.clone());
                        ctx.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info.into()));
                    }
                }
                Statement::Desnap(StatementDesnap { input, output }) => {
                    // Desnapping a tracked snapshot yields the info of the snapshotted value.
                    if let Some(VarInfo::Snapshot(info)) = ctx.var_info.get(&input.var_id) {
                        ctx.var_info.insert(*output, info.as_ref().clone());
                    }
                }
                Statement::Call(call_stmt) => {
                    // A returned statement means the whole call is replaced by a const statement.
                    if let Some(updated_stmt) =
                        ctx.handle_statement_call(call_stmt, &mut additional_consts)
                    {
                        *stmt = Statement::Const(updated_stmt);
                    }
                }
                Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
                    // `const_args` collects the members known to be constants; `all_args` collects
                    // the per-member info for the partially-known case.
                    let mut const_args = vec![];
                    let mut all_args = vec![];
                    let mut contains_info = false;
                    for input in inputs.iter() {
                        let Some(info) = ctx.var_info.get(&input.var_id) else {
                            // Untracked member: referenced as a variable only if it is copyable,
                            // otherwise left as `None`.
                            all_args.push(
                                ctx.variables[input.var_id]
                                    .copyable
                                    .is_ok()
                                    .then_some(VarInfo::Var(*input)),
                            );
                            continue;
                        };
                        contains_info = true;
                        if let VarInfo::Const(value) = info {
                            const_args.push(value.clone());
                        }
                        all_args.push(Some(info.clone()));
                    }
                    if const_args.len() == inputs.len() {
                        // All the members are constants - the struct itself is a constant.
                        let value = ConstValue::Struct(const_args, ctx.variables[*output].ty);
                        ctx.var_info.insert(*output, VarInfo::Const(value));
                    } else if contains_info {
                        ctx.var_info.insert(*output, VarInfo::Struct(all_args));
                    }
                }
                Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
                    if let Some(mut info) = ctx.var_info.get(&input.var_id) {
                        // Peel the snapshot wrappers off the struct info, counting them so the
                        // same number of wrappers can be re-applied to every member.
                        let mut n_snapshot = 0;
                        while let VarInfo::Snapshot(inner) = info {
                            info = inner.as_ref();
                            n_snapshot += 1;
                        }
                        let wrap_with_snapshots = |mut info| {
                            for _ in 0..n_snapshot {
                                info = VarInfo::Snapshot(Box::new(info));
                            }
                            info
                        };
                        match info {
                            VarInfo::Const(ConstValue::Struct(member_values, _)) => {
                                for (output, value) in zip_eq(outputs, member_values.clone()) {
                                    ctx.var_info.insert(
                                        *output,
                                        wrap_with_snapshots(VarInfo::Const(value)),
                                    );
                                }
                            }
                            VarInfo::Struct(members) => {
                                for (output, member) in zip_eq(outputs, members.clone()) {
                                    if let Some(member) = member {
                                        ctx.var_info.insert(*output, wrap_with_snapshots(member));
                                    }
                                }
                            }
                            _ => {}
                        }
                    }
                }
                Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
                    // An enum built from a constant payload is itself a constant.
                    if let Some(VarInfo::Const(val)) = ctx.var_info.get(&input.var_id) {
                        let value = ConstValue::Enum(variant.clone(), val.clone().into());
                        ctx.var_info.insert(*output, VarInfo::Const(value.clone()));
                    }
                }
            }
        }
        // The additional consts are added at the very beginning of the block's statements.
        block.statements.splice(0..0, additional_consts.into_iter().map(Statement::Const));

        match &mut block.end {
            FlatBlockEnd::Goto(_, remappings) => {
                for (_, v) in remappings.iter_mut() {
                    ctx.maybe_replace_input(v);
                }
            }
            FlatBlockEnd::Match { info } => {
                ctx.maybe_replace_inputs(info.inputs_mut());
                match info {
                    MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) => {
                        // Matching on a constant enum: the payload is a known constant in the
                        // first variable of the arm at the variant's index.
                        if let Some(VarInfo::Const(ConstValue::Enum(variant, value))) =
                            ctx.var_info.get(&input.var_id)
                        {
                            let arm = &arms[variant.idx];
                            ctx.var_info
                                .insert(arm.var_ids[0], VarInfo::Const(value.as_ref().clone()));
                        }
                    }
                    MatchInfo::Extern(info) => {
                        // When the extern match can be folded, an optional const statement is
                        // appended to the block and the block end is replaced.
                        if let Some((extra_stmt, updated_end)) = ctx.handle_extern_block_end(info) {
                            if let Some(stmt) = extra_stmt {
                                block.statements.push(Statement::Const(stmt));
                            }
                            block.end = updated_end;
                        }
                    }
                    MatchInfo::Value(..) => {}
                }
            }
            FlatBlockEnd::Return(ref mut inputs, _) => ctx.maybe_replace_inputs(inputs),
            // NOTE(review): these ends are treated as impossible at this stage - confirm against
            // the callers of this optimization.
            FlatBlockEnd::Panic(_) | FlatBlockEnd::NotSet => unreachable!(),
        }
        // Update the reachability of the successors according to the (possibly updated) end.
        match &block.end {
            FlatBlockEnd::Goto(dst_block_id, _) => {
                // The first goto seen into a block is remembered; any further one makes it `Any`.
                reachability[dst_block_id.0] = match reachability[dst_block_id.0] {
                    Reachability::Unreachable => Reachability::FromSingleGoto(BlockId(block_id)),
                    Reachability::FromSingleGoto(_) | Reachability::Any => Reachability::Any,
                }
            }
            FlatBlockEnd::Match { info } => {
                for arm in info.arms() {
                    // A match arm target is asserted not to have been marked by anything else.
                    assert_eq!(reachability[arm.block_id.0], Reachability::Unreachable);
                    reachability[arm.block_id.0] = Reachability::Any;
                }
            }
            FlatBlockEnd::NotSet | FlatBlockEnd::Return(..) | FlatBlockEnd::Panic(..) => {}
        }
    }
}
244
/// The state used while const folding a single lowered function.
struct ConstFoldingContext<'a> {
    /// The used database.
    db: &'a dyn LoweringGroup,
    /// The variables arena of the function; mutable because folding may allocate new variables.
    variables: &'a mut Arena<Variable>,
    /// The information tracked so far for each variable.
    var_info: UnorderedHashMap<VariableId, VarInfo>,
    /// The ids of the libfuncs (and related helper data) that const folding knows how to handle.
    /// Also exposed through the `Deref` implementation of the context.
    libfunc_info: &'a ConstFoldingLibfuncInfo,
}
255
256impl ConstFoldingContext<'_> {
    /// Handles a call statement, attempting to fold it using the known variable information.
    ///
    /// Returns `Some` with a const statement when the whole call should be replaced by it (the
    /// caller performs the replacement). When a folded call has a second output (the div-rem
    /// functions), the const statement for that output is pushed into `additional_consts`.
    ///
    /// Returns `None` when the call statement is to be kept. In that case `stmt` may still have
    /// been updated in place (its function and inputs changed), or information about its output
    /// may have been recorded in `var_info`.
    fn handle_statement_call(
        &mut self,
        stmt: &mut StatementCall,
        additional_consts: &mut Vec<StatementConst>,
    ) -> Option<StatementConst> {
        if stmt.function == self.panic_with_felt252 {
            // A panic with a constant value: switch to `panic_with_const_felt252`, which takes the
            // value as a generic argument and no inputs.
            let val = self.as_const(stmt.inputs[0].var_id)?;
            stmt.inputs.clear();
            stmt.function = ModuleHelper::core(self.db.upcast())
                .function_id(
                    "panic_with_const_felt252",
                    vec![GenericArgumentId::Constant(val.clone().intern(self.db))],
                )
                .lowered(self.db);
            return None;
        }
        // All the other handled functions are extern functions.
        let (id, _generic_args) = stmt.function.get_extern(self.db)?;
        if id == self.felt_sub {
            // `a - 0` is `a`: the output is recorded as equivalent to the first input.
            let val = self.as_int(stmt.inputs[1].var_id)?;
            if val.is_zero() {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
            }
            None
        } else if self.wide_mul_fns.contains(&id) {
            let lhs = self.as_int_ex(stmt.inputs[0].var_id);
            let rhs = self.as_int(stmt.inputs[1].var_id);
            let output = stmt.outputs[0];
            // A known zero on either side is enough to know that the result is zero.
            if lhs.map(|(v, _)| v.is_zero()).unwrap_or_default()
                || rhs.map(Zero::is_zero).unwrap_or_default()
            {
                return Some(self.propagate_zero_and_get_statement(output));
            }
            // Otherwise both sides must be known. The result is wrapped as non-zero when the
            // left-hand side constant is a non-zero constant.
            let (lhs, nz_ty) = lhs?;
            Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0], nz_ty))
        } else if id == self.bounded_int_add || id == self.bounded_int_sub {
            let lhs = self.as_int(stmt.inputs[0].var_id)?;
            let rhs = self.as_int(stmt.inputs[1].var_id)?;
            let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0], false))
        } else if self.div_rem_fns.contains(&id) {
            let lhs = self.as_int(stmt.inputs[0].var_id);
            // A known zero dividend: both the quotient and the remainder are zero, whatever the
            // divisor is.
            if lhs.map(Zero::is_zero).unwrap_or_default() {
                additional_consts.push(self.propagate_zero_and_get_statement(stmt.outputs[1]));
                return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
            }
            let rhs = self.as_int(stmt.inputs[1].var_id)?;
            let (q, r) = lhs?.div_rem(rhs);
            let q_output = stmt.outputs[0];
            let q_value = ConstValue::Int(q, self.variables[q_output].ty);
            self.var_info.insert(q_output, VarInfo::Const(q_value.clone()));
            let r_output = stmt.outputs[1];
            let r_value = ConstValue::Int(r, self.variables[r_output].ty);
            self.var_info.insert(r_output, VarInfo::Const(r_value.clone()));
            // The remainder is emitted as an additional const; the quotient replaces the call.
            additional_consts.push(StatementConst { value: r_value, output: r_output });
            Some(StatementConst { value: q_value, output: q_output })
        } else if id == self.storage_base_address_from_felt252 {
            // A constant address: switch to `storage_base_address_const`, which takes the value as
            // a generic argument and no inputs.
            let input_var = stmt.inputs[0].var_id;
            if let Some(ConstValue::Int(val, ty)) = self.as_const(input_var) {
                stmt.inputs.clear();
                stmt.function =
                    ModuleHelper { db: self.db.upcast(), id: self.storage_access_module }
                        .function_id(
                            "storage_base_address_const",
                            vec![GenericArgumentId::Constant(
                                ConstValue::Int(val.clone(), *ty).intern(self.db),
                            )],
                        )
                        .lowered(self.db);
            }
            None
        } else if id == self.into_box {
            // Boxing a constant, or a snapshot of a constant, becomes a boxed const statement.
            let const_value = match self.var_info.get(&stmt.inputs[0].var_id)? {
                VarInfo::Const(val) => val,
                VarInfo::Snapshot(info) => try_extract_matches!(info.as_ref(), VarInfo::Const)?,
                _ => return None,
            };
            let value = ConstValue::Boxed(const_value.clone().into());
            // Note that the boxed value is not recorded in `var_info`.
            Some(StatementConst { value, output: stmt.outputs[0] })
        } else if id == self.upcast {
            // The same integer value, with the type of the output variable.
            let int_value = self.as_int(stmt.inputs[0].var_id)?;
            let output = stmt.outputs[0];
            let value = ConstValue::Int(int_value.clone(), self.variables[output].ty);
            self.var_info.insert(output, VarInfo::Const(value.clone()));
            Some(StatementConst { value, output })
        } else {
            None
        }
    }
355
356 fn propagate_const_and_get_statement(
358 &mut self,
359 value: BigInt,
360 output: VariableId,
361 nz_ty: bool,
362 ) -> StatementConst {
363 let mut value = ConstValue::Int(value, self.variables[output].ty);
364 if nz_ty {
365 value = ConstValue::NonZero(Box::new(value));
366 }
367 self.var_info.insert(output, VarInfo::Const(value.clone()));
368 StatementConst { value, output }
369 }
370
371 fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> StatementConst {
373 self.propagate_const_and_get_statement(BigInt::zero(), output, false)
374 }
375
    /// Handles a block ending with a match on an extern function, attempting to fold it using the
    /// known variable information.
    ///
    /// Returns `Some((extra_stmt, new_end))` when the block end should be replaced by `new_end`;
    /// `extra_stmt`, when present, is a const statement the caller appends to the block's
    /// statements before the new end.
    ///
    /// Returns `None` when the block end is to be kept. In the `array_get` case `info` may still
    /// have been updated in place.
    fn handle_extern_block_end(
        &mut self,
        info: &mut MatchExternInfo,
    ) -> Option<(Option<StatementConst>, FlatBlockEnd)> {
        let (id, generic_args) = info.function.get_extern(self.db)?;
        if self.nz_fns.contains(&id) {
            // An is-zero check of a constant: jump directly to the matching arm.
            let val = self.as_const(info.inputs[0].var_id)?;
            let is_zero = match val {
                ConstValue::Int(v, _) => v.is_zero(),
                // A struct constant is zero when all of its members are zero.
                // NOTE(review): presumably this is the `u256` case - confirm.
                ConstValue::Struct(s, _) => s.iter().all(|v| {
                    v.clone().into_int().expect("Expected ConstValue::Int for size").is_zero()
                }),
                _ => unreachable!(),
            };
            Some(if is_zero {
                (None, FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()))
            } else {
                // The non-zero arm introduces a variable, which is assigned the non-zero wrapped
                // constant.
                let arm = &info.arms[1];
                let nz_var = arm.var_ids[0];
                let nz_val = ConstValue::NonZero(Box::new(val.clone()));
                self.var_info.insert(nz_var, VarInfo::Const(nz_val.clone()));
                (
                    Some(StatementConst { value: nz_val, output: nz_var }),
                    FlatBlockEnd::Goto(arm.block_id, Default::default()),
                )
            })
        } else if self.eq_fns.contains(&id) {
            let lhs = self.as_int(info.inputs[0].var_id);
            let rhs = self.as_int(info.inputs[1].var_id);
            // Exactly one side is known, and it is zero: the equality check is replaced by an
            // is-zero match on the other side.
            if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
                || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
            {
                let db = self.db.upcast();
                // The input that is not the known zero.
                let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
                let var = &self.variables[nz_input.var_id].clone();
                let function = self.type_value_ranges.get(&var.ty)?.is_zero;
                // The non-zero arm of the new match introduces a variable; a fresh one of the
                // matching `NonZero` type is allocated for it.
                let unused_nz_var = Variable::new(
                    self.db,
                    ImplLookupContext::default(),
                    corelib::core_nonzero_ty(db, var.ty),
                    var.location,
                );
                let unused_nz_var = self.variables.alloc(unused_nz_var);
                return Some((
                    None,
                    FlatBlockEnd::Match {
                        info: MatchInfo::Extern(MatchExternInfo {
                            function,
                            inputs: vec![nz_input],
                            arms: vec![
                                // Being zero is being equal - continue to the original arm 1
                                // (the arm taken below when both sides are known and equal).
                                MatchArm {
                                    arm_selector: MatchArmSelector::VariantId(
                                        corelib::jump_nz_zero_variant(db, var.ty),
                                    ),
                                    block_id: info.arms[1].block_id,
                                    var_ids: vec![],
                                },
                                // Being non-zero is being different - continue to the original
                                // arm 0.
                                MatchArm {
                                    arm_selector: MatchArmSelector::VariantId(
                                        corelib::jump_nz_nonzero_variant(db, var.ty),
                                    ),
                                    block_id: info.arms[0].block_id,
                                    var_ids: vec![unused_nz_var],
                                },
                            ],
                            location: info.location,
                        }),
                    },
                ));
            }
            // Otherwise both sides must be known: arm 1 when equal, arm 0 when different.
            Some((
                None,
                FlatBlockEnd::Goto(
                    info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
                    Default::default(),
                ),
            ))
        } else if self.uadd_fns.contains(&id)
            || self.usub_fns.contains(&id)
            || self.diff_fns.contains(&id)
            || self.iadd_fns.contains(&id)
            || self.isub_fns.contains(&id)
        {
            let rhs = self.as_int(info.inputs[1].var_id);
            // Adding or subtracting zero: the result is the left-hand side, through arm 0. This
            // shortcut is not applied to the diff functions.
            if rhs.map(Zero::is_zero).unwrap_or_default() && !self.diff_fns.contains(&id) {
                let arm = &info.arms[0];
                self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]));
                return Some((None, FlatBlockEnd::Goto(arm.block_id, Default::default())));
            }
            let lhs = self.as_int(info.inputs[0].var_id);
            let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
                // Adding to zero: the result is the right-hand side, through arm 0.
                if lhs.map(Zero::is_zero).unwrap_or_default() {
                    let arm = &info.arms[0];
                    self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]));
                    return Some((None, FlatBlockEnd::Goto(arm.block_id, Default::default())));
                }
                lhs? + rhs?
            } else {
                lhs? - rhs?
            };
            // The range is looked up by the type of the variable introduced by arm 0.
            let ty = self.variables[info.arms[0].var_ids[0]].ty;
            let range = self.type_value_ranges.get(&ty)?;
            // Arm 0 is the in-range arm. Arm 1 is taken when below the range. Above the range,
            // the signed add/sub functions use arm 2 and the other functions use arm 1.
            let (arm_index, value) = match range.normalized(value) {
                NormalizedResult::InRange(value) => (0, value),
                NormalizedResult::Under(value) => (1, value),
                NormalizedResult::Over(value) => (
                    if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) { 2 } else { 1 },
                    value,
                ),
            };
            let arm = &info.arms[arm_index];
            let actual_output = arm.var_ids[0];
            let value = ConstValue::Int(value, ty);
            self.var_info.insert(actual_output, VarInfo::Const(value.clone()));
            Some((
                Some(StatementConst { value, output: actual_output }),
                FlatBlockEnd::Goto(arm.block_id, Default::default()),
            ))
        } else if id == self.downcast {
            let input_var = info.inputs[0].var_id;
            let value = self.as_int(input_var)?;
            // The target type is the type of the variable introduced by the success arm (arm 0).
            let success_output = info.arms[0].var_ids[0];
            let ty = self.variables[success_output].ty;
            let range = self.type_value_ranges.get(&ty)?;
            Some(if let NormalizedResult::InRange(value) = range.normalized(value.clone()) {
                let value = ConstValue::Int(value, ty);
                self.var_info.insert(success_output, VarInfo::Const(value.clone()));
                (
                    Some(StatementConst { value, output: success_output }),
                    FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()),
                )
            } else {
                // Out of the range of the target type: jump to arm 1, with no new constant.
                (None, FlatBlockEnd::Goto(info.arms[1].block_id, Default::default()))
            })
        } else if id == self.bounded_int_constrain {
            let input_var = info.inputs[0].var_id;
            let (value, nz_ty) = self.as_int_ex(input_var)?;
            // The second generic argument is the constant the value is compared against.
            let generic_arg = generic_args[1];
            let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
                .lookup_intern(self.db)
                .into_int()
                .unwrap();
            // Arm 0 when the value is strictly below the constant, arm 1 otherwise.
            let arm_idx = if value < &constrain_value { 0 } else { 1 };
            let output = info.arms[arm_idx].var_ids[0];
            Some((
                Some(self.propagate_const_and_get_statement(value.clone(), output, nz_ty)),
                FlatBlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()),
            ))
        } else if id == self.array_get {
            // Getting the element at the known index 0: the match is changed in place to use
            // `array_snapshot_pop_front`, taking only the array as input.
            if self.as_int(info.inputs[1].var_id)?.is_zero() {
                if let [success, failure] = info.arms.as_mut_slice() {
                    let arr = info.inputs[0].var_id;
                    // A new variable, cloned from the array variable, is added as the first
                    // variable of each of the arms.
                    // NOTE(review): presumably these hold the array output of the pop function and
                    // are otherwise unused - confirm against its signature.
                    let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
                    let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
                    info.inputs.truncate(1);
                    info.function = ModuleHelper { db: self.db.upcast(), id: self.array_module }
                        .function_id("array_snapshot_pop_front", generic_args)
                        .lowered(self.db);
                    success.var_ids.insert(0, unused_arr_output0);
                    failure.var_ids.insert(0, unused_arr_output1);
                }
            }
            // The block end was (possibly) updated in place - there is nothing to replace.
            None
        } else {
            None
        }
    }
547
548 fn as_const(&self, var_id: VariableId) -> Option<&ConstValue> {
550 try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const)
551 }
552
553 fn as_int_ex(&self, var_id: VariableId) -> Option<(&BigInt, bool)> {
556 match self.as_const(var_id)? {
557 ConstValue::Int(value, _) => Some((value, false)),
558 ConstValue::NonZero(const_value) => {
559 if let ConstValue::Int(value, _) = const_value.as_ref() {
560 Some((value, true))
561 } else {
562 None
563 }
564 }
565 _ => None,
566 }
567 }
568
569 fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
571 Some(self.as_int_ex(var_id)?.0)
572 }
573
574 fn maybe_replace_inputs(&mut self, inputs: &mut [VarUsage]) {
576 for input in inputs {
577 self.maybe_replace_input(input);
578 }
579 }
580
581 fn maybe_replace_input(&mut self, input: &mut VarUsage) {
583 if let Some(VarInfo::Var(new_var)) = self.var_info.get(&input.var_id) {
584 *input = *new_var;
585 }
586 }
587}
588
589pub fn priv_const_folding_info(
591 db: &dyn LoweringGroup,
592) -> Arc<crate::optimizations::const_folding::ConstFoldingLibfuncInfo> {
593 Arc::new(ConstFoldingLibfuncInfo::new(db))
594}
595
/// The ids of the core library items that const folding handles, and related helper data.
#[derive(Debug, PartialEq, Eq)]
pub struct ConstFoldingLibfuncInfo {
    /// The `felt252_sub` libfunc.
    felt_sub: ExternFunctionId,
    /// The `into_box` libfunc.
    into_box: ExternFunctionId,
    /// The `upcast` libfunc.
    upcast: ExternFunctionId,
    /// The `downcast` libfunc.
    downcast: ExternFunctionId,
    /// The set of functions checking if a value is zero.
    nz_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of functions checking if two integers are equal.
    eq_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of unsigned `*_overflowing_add` functions.
    uadd_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of unsigned `*_overflowing_sub` functions.
    usub_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of signed `*_diff` functions.
    diff_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of signed `*_overflowing_add_impl` functions.
    iadd_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of signed `*_overflowing_sub_impl` functions.
    isub_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of `*_wide_mul` functions, together with `bounded_int_mul`.
    wide_mul_fns: OrderedHashSet<ExternFunctionId>,
    /// The set of `*_safe_divmod` functions, together with `bounded_int_div_rem`.
    div_rem_fns: OrderedHashSet<ExternFunctionId>,
    /// The `bounded_int_add` libfunc.
    bounded_int_add: ExternFunctionId,
    /// The `bounded_int_sub` libfunc.
    bounded_int_sub: ExternFunctionId,
    /// The `bounded_int_constrain` libfunc.
    bounded_int_constrain: ExternFunctionId,
    /// The `array` module of the core library.
    array_module: ModuleId,
    /// The `array_get` libfunc.
    array_get: ExternFunctionId,
    /// The `starknet::storage_access` module of the core library.
    storage_access_module: ModuleId,
    /// The `storage_base_address_from_felt252` libfunc.
    storage_base_address_from_felt252: ExternFunctionId,
    /// The `core::panic_with_felt252` function.
    panic_with_felt252: FunctionId,
    /// The value range and the is-zero function of each of the handled integer types.
    type_value_ranges: OrderedHashMap<TypeId, TypeInfo>,
}
644impl ConstFoldingLibfuncInfo {
645 fn new(db: &dyn LoweringGroup) -> Self {
646 let core = ModuleHelper::core(db.upcast());
647 let felt_sub = core.extern_function_id("felt252_sub");
648 let box_module = core.submodule("box");
649 let into_box = box_module.extern_function_id("into_box");
650 let integer_module = core.submodule("integer");
651 let bounded_int_module = core.submodule("internal").submodule("bounded_int");
652 let upcast = integer_module.extern_function_id("upcast");
653 let downcast = integer_module.extern_function_id("downcast");
654 let array_module = core.submodule("array");
655 let array_get = array_module.extern_function_id("array_get");
656 let starknet_module = core.submodule("starknet");
657 let storage_access_module = starknet_module.submodule("storage_access");
658 let storage_base_address_from_felt252 =
659 storage_access_module.extern_function_id("storage_base_address_from_felt252");
660 let nz_fns = OrderedHashSet::<_>::from_iter(chain!(
661 [
662 core.extern_function_id("felt252_is_zero"),
663 bounded_int_module.extern_function_id("bounded_int_is_zero")
664 ],
665 ["u8", "u16", "u32", "u64", "u128", "u256"]
666 .map(|ty| integer_module.extern_function_id(format!("{ty}_is_zero")))
667 ));
668 let utypes = ["u8", "u16", "u32", "u64", "u128"];
669 let itypes = ["i8", "i16", "i32", "i64", "i128"];
670 let eq_fns = OrderedHashSet::<_>::from_iter(
671 chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(format!("{ty}_eq"))),
672 );
673 let uadd_fns = OrderedHashSet::<_>::from_iter(
674 utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add"))),
675 );
676 let usub_fns = OrderedHashSet::<_>::from_iter(
677 utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_sub"))),
678 );
679 let diff_fns = OrderedHashSet::<_>::from_iter(
680 itypes.map(|ty| integer_module.extern_function_id(format!("{ty}_diff"))),
681 );
682 let iadd_fns = OrderedHashSet::<_>::from_iter(
683 itypes
684 .map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add_impl"))),
685 );
686 let isub_fns = OrderedHashSet::<_>::from_iter(
687 itypes
688 .map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_sub_impl"))),
689 );
690 let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
691 [bounded_int_module.extern_function_id("bounded_int_mul")],
692 ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
693 .map(|ty| integer_module.extern_function_id(format!("{ty}_wide_mul"))),
694 ));
695 let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
696 [bounded_int_module.extern_function_id("bounded_int_div_rem")],
697 utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_safe_divmod"))),
698 ));
699 let bounded_int_add = bounded_int_module.extern_function_id("bounded_int_add");
700 let bounded_int_sub = bounded_int_module.extern_function_id("bounded_int_sub");
701 let bounded_int_constrain = bounded_int_module.extern_function_id("bounded_int_constrain");
702 let type_value_ranges = OrderedHashMap::from_iter(
703 [
704 ("u8", BigInt::ZERO, u8::MAX.into(), false),
705 ("u16", BigInt::ZERO, u16::MAX.into(), false),
706 ("u32", BigInt::ZERO, u32::MAX.into(), false),
707 ("u64", BigInt::ZERO, u64::MAX.into(), false),
708 ("u128", BigInt::ZERO, u128::MAX.into(), false),
709 ("u256", BigInt::ZERO, BigInt::from(1) << 256, false),
710 ("i8", i8::MIN.into(), i8::MAX.into(), true),
711 ("i16", i16::MIN.into(), i16::MAX.into(), true),
712 ("i32", i32::MIN.into(), i32::MAX.into(), true),
713 ("i64", i64::MIN.into(), i64::MAX.into(), true),
714 ("i128", i128::MIN.into(), i128::MAX.into(), true),
715 ]
716 .map(
717 |(ty_name, min, max, as_bounded_int): (&str, BigInt, BigInt, bool)| {
718 let ty = corelib::get_core_ty_by_name(db.upcast(), ty_name.into(), vec![]);
719 let is_zero = if as_bounded_int {
720 bounded_int_module
721 .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
722 } else {
723 integer_module.function_id(format!("{ty_name}_is_zero"), vec![])
724 }
725 .lowered(db);
726 let info = TypeInfo { min, max, is_zero };
727 (ty, info)
728 },
729 ),
730 );
731 Self {
732 felt_sub,
733 into_box,
734 upcast,
735 downcast,
736 nz_fns,
737 eq_fns,
738 uadd_fns,
739 usub_fns,
740 diff_fns,
741 iadd_fns,
742 isub_fns,
743 wide_mul_fns,
744 div_rem_fns,
745 bounded_int_add,
746 bounded_int_sub,
747 bounded_int_constrain,
748 array_module: array_module.id,
749 array_get,
750 storage_access_module: storage_access_module.id,
751 storage_base_address_from_felt252,
752 panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
753 type_value_ranges,
754 }
755 }
756}
757
758impl std::ops::Deref for ConstFoldingContext<'_> {
759 type Target = ConstFoldingLibfuncInfo;
760 fn deref(&self) -> &ConstFoldingLibfuncInfo {
761 self.libfunc_info
762 }
763}
764
/// The information const folding keeps about an integer type.
#[derive(Debug, PartialEq, Eq)]
struct TypeInfo {
    /// The minimal value of the type (inclusive).
    min: BigInt,
    /// The maximal value of the type (inclusive).
    max: BigInt,
    /// The function matched on to check whether a value of the type is zero.
    is_zero: FunctionId,
}
775impl TypeInfo {
776 fn normalized(&self, value: BigInt) -> NormalizedResult {
779 if value < self.min {
780 NormalizedResult::Under(value - &self.min + &self.max + 1)
781 } else if value > self.max {
782 NormalizedResult::Over(value + &self.min - &self.max - 1)
783 } else {
784 NormalizedResult::InRange(value)
785 }
786 }
787}
788
/// The result of classifying a value against the range of a type (see `TypeInfo::normalized`).
enum NormalizedResult {
    /// The value is within the range; it is held unchanged.
    InRange(BigInt),
    /// The value is above the maximum of the range; holds the value shifted down by the size of
    /// the range.
    Over(BigInt),
    /// The value is below the minimum of the range; holds the value shifted up by the size of
    /// the range.
    Under(BigInt),
}