1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::rc::Rc;
6use std::sync::Arc;
7
8use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
9use cairo_lang_filesystem::flag::FlagsGroup;
10use cairo_lang_filesystem::ids::SmolStrId;
11use cairo_lang_semantic::corelib::CorelibSemantic;
12use cairo_lang_semantic::helper::ModuleHelper;
13use cairo_lang_semantic::items::constant::{
14 ConstCalcInfo, ConstValue, ConstValueId, ConstantSemantic, TypeRange, canonical_felt252,
15 felt252_for_downcast,
16};
17use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
18use cairo_lang_semantic::items::structure::StructSemantic;
19use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
20use cairo_lang_semantic::{
21 ConcreteTypeId, ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId,
22 corelib,
23};
24use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
25use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
26use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
27use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
28use cairo_lang_utils::{Intern, extract_matches, require, try_extract_matches};
29use itertools::{chain, zip_eq};
30use num_bigint::BigInt;
31use num_integer::Integer;
32use num_traits::cast::ToPrimitive;
33use num_traits::{Num, One, Zero};
34use salsa::Database;
35use starknet_types_core::felt::Felt as Felt252;
36
37use crate::db::LoweringGroup;
38use crate::ids::{
39 ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
40 SpecializedFunction,
41};
42use crate::specialization::SpecializationArg;
43use crate::utils::InliningStrategy;
44use crate::{
45 Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchEnumInfo,
46 MatchExternInfo, MatchInfo, Statement, StatementCall, StatementConst, StatementDesnap,
47 StatementEnumConstruct, StatementIntoBox, StatementSnapshot, StatementStructConstruct,
48 StatementStructDestructure, StatementUnbox, VarRemapping, VarUsage, Variable, VariableArena,
49 VariableId,
50};
51
52fn const_to_specialization_arg<'db>(
56 db: &'db dyn Database,
57 value: ConstValueId<'db>,
58 boxed: bool,
59) -> SpecializationArg<'db> {
60 match value.long(db) {
61 ConstValue::Struct(members, ty) => {
62 if matches!(
65 ty.long(db),
66 TypeLongId::Concrete(ConcreteTypeId::Struct(_))
67 | TypeLongId::Tuple(_)
68 | TypeLongId::FixedSizeArray { .. }
69 ) {
70 let args = members
71 .iter()
72 .map(|member| const_to_specialization_arg(db, *member, false))
73 .collect();
74 SpecializationArg::Struct(args)
75 } else {
76 SpecializationArg::Const { value, boxed }
77 }
78 }
79 ConstValue::Enum(variant, payload) => SpecializationArg::Enum {
80 variant: *variant,
81 payload: Box::new(const_to_specialization_arg(db, *payload, false)),
82 },
83 _ => SpecializationArg::Const { value, boxed },
84 }
85}
86
87#[derive(Debug, Clone)]
90enum VarInfo<'db> {
91 Const(ConstValueId<'db>),
93 Var(VarUsage<'db>),
95 Snapshot(Rc<VarInfo<'db>>),
97 Struct(Vec<Option<Rc<VarInfo<'db>>>>),
100 Enum { variant: ConcreteVariant<'db>, payload: Rc<VarInfo<'db>> },
102 Box(Rc<VarInfo<'db>>),
104 Array(Vec<Option<Rc<VarInfo<'db>>>>),
107}
108impl<'db> VarInfo<'db> {
109 fn peel_snapshots(mut self: Rc<Self>) -> (usize, Rc<VarInfo<'db>>) {
111 let mut n_snapshots = 0;
112 while let VarInfo::Snapshot(inner) = self.as_ref() {
113 self = inner.clone();
114 n_snapshots += 1;
115 }
116 (n_snapshots, self)
117 }
118 fn wrap_with_snapshots(mut self: Rc<Self>, n_snapshots: usize) -> Rc<VarInfo<'db>> {
120 for _ in 0..n_snapshots {
121 self = VarInfo::Snapshot(self).into();
122 }
123 self
124 }
125}
126
127#[derive(Debug, Clone, Copy, PartialEq)]
128enum Reachability {
129 FromSingleGoto(BlockId),
132 Any,
135}
136
137pub fn const_folding<'db>(
140 db: &'db dyn Database,
141 function_id: ConcreteFunctionWithBodyId<'db>,
142 lowered: &mut Lowered<'db>,
143) {
144 if lowered.blocks.is_empty() {
145 return;
146 }
147
148 let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
151
152 if ctx.should_skip_const_folding(db) {
153 return;
154 }
155
156 for block_id in (0..lowered.blocks.len()).map(BlockId) {
157 if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
158 continue;
159 }
160
161 let block = &mut lowered.blocks[block_id];
162 for stmt in block.statements.iter_mut() {
163 ctx.visit_statement(stmt);
164 }
165 ctx.visit_block_end(block_id, block);
166 }
167}
168
169pub struct ConstFoldingContext<'db, 'mt> {
170 db: &'db dyn Database,
172 pub variables: &'mt mut VariableArena<'db>,
174 var_info: UnorderedHashMap<VariableId, Rc<VarInfo<'db>>>,
176 libfunc_info: &'db ConstFoldingLibfuncInfo<'db>,
178 caller_function: ConcreteFunctionWithBodyId<'db>,
181 reachability: UnorderedHashMap<BlockId, Reachability>,
185 additional_stmts: Vec<Statement<'db>>,
187}
188
189impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
190 pub fn new(
191 db: &'db dyn Database,
192 function_id: ConcreteFunctionWithBodyId<'db>,
193 variables: &'mt mut VariableArena<'db>,
194 ) -> Self {
195 Self {
196 db,
197 var_info: UnorderedHashMap::default(),
198 variables,
199 libfunc_info: priv_const_folding_info(db),
200 caller_function: function_id,
201 reachability: UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]),
202 additional_stmts: vec![],
203 }
204 }
205
206 pub fn visit_block_start<'r, 'get>(
209 &'r mut self,
210 block_id: BlockId,
211 get_block: impl FnOnce(BlockId) -> &'get Block<'db>,
212 ) -> bool
213 where
214 'db: 'get,
215 {
216 let Some(reachability) = self.reachability.remove(&block_id) else {
217 return false;
218 };
219 match reachability {
220 Reachability::Any => {}
221 Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
222 BlockEnd::Goto(_, remapping) => {
223 for (dst, src) in remapping.iter() {
224 if let Some(v) = self.as_const(src.var_id) {
225 self.var_info.insert(*dst, VarInfo::Const(v).into());
226 }
227 }
228 }
229 _ => unreachable!("Expected a goto end"),
230 },
231 }
232 true
233 }
234
235 pub fn visit_statement(&mut self, stmt: &mut Statement<'db>) {
245 self.maybe_replace_inputs(stmt.inputs_mut());
246 match stmt {
247 Statement::Const(StatementConst { value, output, boxed }) if *boxed => {
248 self.var_info.insert(*output, VarInfo::Box(VarInfo::Const(*value).into()).into());
249 }
250 Statement::Const(StatementConst { value, output, .. }) => match value.long(self.db) {
251 ConstValue::Int(..)
252 | ConstValue::Struct(..)
253 | ConstValue::Enum(..)
254 | ConstValue::NonZero(..) => {
255 self.var_info.insert(*output, VarInfo::Const(*value).into());
256 }
257 ConstValue::Generic(_)
258 | ConstValue::ImplConstant(_)
259 | ConstValue::Var(..)
260 | ConstValue::Missing(_) => {}
261 },
262 Statement::Snapshot(stmt) => {
263 if let Some(info) = self.var_info.get(&stmt.input.var_id) {
264 let info = info.clone();
265 self.var_info.insert(stmt.original(), info.clone());
266 self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info).into());
267 }
268 }
269 Statement::Desnap(StatementDesnap { input, output }) => {
270 if let Some(info) = self.var_info.get(&input.var_id)
271 && let VarInfo::Snapshot(info) = info.as_ref()
272 {
273 self.var_info.insert(*output, info.clone());
274 }
275 }
276 Statement::Call(call_stmt) => {
277 if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
278 *stmt = updated_stmt;
279 } else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
280 *stmt = updated_stmt;
281 }
282 }
283 Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
284 let mut const_args = vec![];
285 let mut all_args = vec![];
286 let mut contains_info = false;
287 for input in inputs.iter() {
288 let Some(info) = self.var_info.get(&input.var_id) else {
289 all_args.push(var_info_if_copy(self.variables, *input));
290 continue;
291 };
292 contains_info = true;
293 if let VarInfo::Const(value) = info.as_ref() {
294 const_args.push(*value);
295 }
296 all_args.push(Some(info.clone()));
297 }
298 if const_args.len() == inputs.len() {
299 let value =
300 ConstValue::Struct(const_args, self.variables[*output].ty).intern(self.db);
301 self.var_info.insert(*output, VarInfo::Const(value).into());
302 } else if contains_info {
303 self.var_info.insert(*output, VarInfo::Struct(all_args).into());
304 }
305 }
306 Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
307 if let Some(info) = self.var_info.get(&input.var_id) {
308 let (n_snapshots, info) = info.clone().peel_snapshots();
309 match info.as_ref() {
310 VarInfo::Const(const_value) => {
311 if let ConstValue::Struct(member_values, _) = const_value.long(self.db)
312 {
313 for (output, value) in zip_eq(outputs, member_values) {
314 self.var_info.insert(
315 *output,
316 Rc::new(VarInfo::Const(*value))
317 .wrap_with_snapshots(n_snapshots),
318 );
319 }
320 }
321 }
322 VarInfo::Struct(members) => {
323 for (output, member) in zip_eq(outputs, members.clone()) {
324 if let Some(member) = member {
325 self.var_info
326 .insert(*output, member.wrap_with_snapshots(n_snapshots));
327 }
328 }
329 }
330 _ => {}
331 }
332 }
333 }
334 Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
335 let value = if let Some(info) = self.var_info.get(&input.var_id) {
336 if let VarInfo::Const(val) = info.as_ref() {
337 VarInfo::Const(ConstValue::Enum(*variant, *val).intern(self.db))
338 } else {
339 VarInfo::Enum { variant: *variant, payload: info.clone() }
340 }
341 } else {
342 VarInfo::Enum { variant: *variant, payload: VarInfo::Var(*input).into() }
343 };
344 self.var_info.insert(*output, value.into());
345 }
346 Statement::IntoBox(StatementIntoBox { input, output }) => {
347 let var_info = self.var_info.get(&input.var_id);
348 let const_value = var_info.and_then(|var_info| match var_info.as_ref() {
349 VarInfo::Const(val) => Some(*val),
350 VarInfo::Snapshot(info) => {
351 try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
352 }
353 _ => None,
354 });
355 let var_info =
356 var_info.cloned().or_else(|| var_info_if_copy(self.variables, *input));
357 if let Some(var_info) = var_info {
358 self.var_info.insert(*output, VarInfo::Box(var_info).into());
359 }
360
361 if let Some(const_value) = const_value {
362 *stmt = Statement::Const(StatementConst::new_boxed(const_value, *output));
363 }
364 }
365 Statement::Unbox(StatementUnbox { input, output }) => {
366 if let Some(inner) = self.var_info.get(&input.var_id)
367 && let VarInfo::Box(inner) = inner.as_ref()
368 {
369 let inner = inner.clone();
370 if let VarInfo::Const(inner) =
371 self.var_info.entry(*output).insert_entry(inner).get().as_ref()
372 {
373 *stmt = Statement::Const(StatementConst::new_flat(*inner, *output));
374 }
375 }
376 }
377 }
378 }
379
380 pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut Block<'db>) {
387 let statements = &mut block.statements;
388 statements.splice(0..0, self.additional_stmts.drain(..));
389
390 match &mut block.end {
391 BlockEnd::Goto(_, remappings) => {
392 for (_, v) in remappings.iter_mut() {
393 self.maybe_replace_input(v);
394 }
395 }
396 BlockEnd::Match { info } => {
397 self.maybe_replace_inputs(info.inputs_mut());
398 match info {
399 MatchInfo::Enum(info) => {
400 if let Some(updated_end) = self.handle_enum_block_end(info, statements) {
401 block.end = updated_end;
402 }
403 }
404 MatchInfo::Extern(info) => {
405 if let Some(updated_end) = self.handle_extern_block_end(info, statements) {
406 block.end = updated_end;
407 }
408 }
409 MatchInfo::Value(info) => {
410 if let Some(value) =
411 self.as_int(info.input.var_id).and_then(|x| x.to_usize())
412 && let Some(arm) = info.arms.iter().find(|arm| {
413 matches!(
414 &arm.arm_selector,
415 MatchArmSelector::Value(v) if v.value == value
416 )
417 })
418 {
419 statements.push(Statement::StructConstruct(StatementStructConstruct {
421 inputs: vec![],
422 output: arm.var_ids[0],
423 }));
424 block.end = BlockEnd::Goto(arm.block_id, Default::default());
425 }
426 }
427 }
428 }
429 BlockEnd::Return(inputs, _) => self.maybe_replace_inputs(inputs),
430 BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
431 }
432 match &block.end {
433 BlockEnd::Goto(dst_block_id, _) => {
434 match self.reachability.entry(*dst_block_id) {
435 std::collections::hash_map::Entry::Occupied(mut e) => {
436 e.insert(Reachability::Any)
437 }
438 std::collections::hash_map::Entry::Vacant(e) => {
439 *e.insert(Reachability::FromSingleGoto(block_id))
440 }
441 };
442 }
443 BlockEnd::Match { info } => {
444 for arm in info.arms() {
445 assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
446 }
447 }
448 BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
449 }
450 }
451
452 fn handle_statement_call(&mut self, stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
460 let db = self.db;
461 if stmt.function == self.panic_with_felt252 {
462 let val = self.as_const(stmt.inputs[0].var_id)?;
463 stmt.inputs.clear();
464 stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
465 .concretize(db, vec![GenericArgumentId::Constant(val)])
466 .lowered(db);
467 return None;
468 } else if stmt.function == self.panic_with_byte_array && !db.flag_unsafe_panic() {
469 let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
470 let bytearray = try_extract_matches!(snap.as_ref(), VarInfo::Snapshot)?;
471 let [Some(data), Some(pending_word), Some(pending_len)] =
472 &try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
473 else {
474 return None;
475 };
476 let data = try_extract_matches!(data.as_ref(), VarInfo::Array)?;
477 let pending_word = try_extract_matches!(pending_word.as_ref(), VarInfo::Const)?;
478 let pending_len = try_extract_matches!(pending_len.as_ref(), VarInfo::Const)?;
479 let mut panic_data =
480 vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
481 for word in data {
482 let VarInfo::Const(word) = word.as_ref()?.as_ref() else {
483 return None;
484 };
485 panic_data.push(word.long(db).to_int()?.clone());
486 }
487 panic_data.extend([
488 pending_word.long(db).to_int()?.clone(),
489 pending_len.long(db).to_int()?.clone(),
490 ]);
491 let felt252_ty = self.felt252;
492 let location = stmt.location;
493 let new_var = |ty| Variable::with_default_context(db, ty, location);
494 let as_usage = |var_id| VarUsage { var_id, location };
495 let array_fn = |extern_id| {
496 let args = vec![GenericArgumentId::Type(felt252_ty)];
497 GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
498 };
499 let call_stmt = |function, inputs, outputs| {
500 let with_coupon = false;
501 Statement::Call(StatementCall {
502 function,
503 inputs,
504 with_coupon,
505 outputs,
506 location,
507 is_specialization_base_call: false,
508 })
509 };
510 let arr_var = new_var(corelib::core_array_felt252_ty(db));
511 let mut arr = self.variables.alloc(arr_var.clone());
512 self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
513 let felt252_var = new_var(felt252_ty);
514 let arr_append_fn = array_fn(self.array_append);
515 for word in panic_data {
516 let to_append = self.variables.alloc(felt252_var.clone());
517 let new_arr = self.variables.alloc(arr_var.clone());
518 self.additional_stmts.push(Statement::Const(StatementConst::new_flat(
519 ConstValue::Int(word, felt252_ty).intern(db),
520 to_append,
521 )));
522 self.additional_stmts.push(call_stmt(
523 arr_append_fn,
524 vec![as_usage(arr), as_usage(to_append)],
525 vec![new_arr],
526 ));
527 arr = new_arr;
528 }
529 let panic_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "Panic"), vec![]);
530 let panic_var = self.variables.alloc(new_var(panic_ty));
531 self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
532 inputs: vec![],
533 output: panic_var,
534 }));
535 return Some(Statement::StructConstruct(StatementStructConstruct {
536 inputs: vec![as_usage(panic_var), as_usage(arr)],
537 output: stmt.outputs[0],
538 }));
539 }
540 let (id, _generic_args) = stmt.function.get_extern(db)?;
541 if id == self.felt_sub {
542 if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
543 && rhs.is_zero()
544 {
545 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
546 None
547 } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
548 && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
549 {
550 let value = canonical_felt252(&(lhs - rhs));
551 Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
552 } else {
553 None
554 }
555 } else if id == self.felt_add {
556 if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
557 && lhs.is_zero()
558 {
559 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]).into());
560 None
561 } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
562 && rhs.is_zero()
563 {
564 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
565 None
566 } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
567 && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
568 {
569 let value = canonical_felt252(&(lhs + rhs));
570 Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
571 } else {
572 None
573 }
574 } else if id == self.felt_mul {
575 let lhs = self.as_int(stmt.inputs[0].var_id);
576 let rhs = self.as_int(stmt.inputs[1].var_id);
577 if lhs.map(Zero::is_zero).unwrap_or_default()
578 || rhs.map(Zero::is_zero).unwrap_or_default()
579 {
580 Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
581 } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
582 && rhs.is_one()
583 {
584 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
585 None
586 } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
587 && lhs.is_one()
588 {
589 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]).into());
590 None
591 } else if let Some(lhs) = lhs
592 && let Some(rhs) = rhs
593 {
594 let value = canonical_felt252(&(lhs * rhs));
595 Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
596 } else {
597 None
598 }
599 } else if id == self.felt_div {
600 if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
602 && rhs.is_one()
604 {
605 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
606 None
607 } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
608 && lhs.is_zero()
610 {
611 Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
612 } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
613 && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
614 && let Ok(rhs_nonzero) = Felt252::from(rhs).try_into()
615 {
616 let lhs_felt = Felt252::from(lhs);
620 let value = lhs_felt.field_div(&rhs_nonzero).to_bigint();
621 Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
622 } else {
623 None
624 }
625 } else if self.wide_mul_fns.contains(&id) {
626 let lhs = self.as_int(stmt.inputs[0].var_id);
627 let rhs = self.as_int(stmt.inputs[1].var_id);
628 let output = stmt.outputs[0];
629 if lhs.map(Zero::is_zero).unwrap_or_default()
630 || rhs.map(Zero::is_zero).unwrap_or_default()
631 {
632 return Some(self.propagate_zero_and_get_statement(output));
633 }
634 let lhs = lhs?;
635 Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0]))
636 } else if id == self.bounded_int_add || id == self.bounded_int_sub {
637 let lhs = self.as_int(stmt.inputs[0].var_id)?;
638 let rhs = self.as_int(stmt.inputs[1].var_id)?;
639 let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
640 Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
641 } else if self.div_rem_fns.contains(&id) {
642 let lhs = self.as_int(stmt.inputs[0].var_id);
643 if lhs.map(Zero::is_zero).unwrap_or_default() {
644 let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
645 self.additional_stmts.push(additional_stmt);
646 return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
647 }
648 let rhs = self.as_int(stmt.inputs[1].var_id)?;
649 let (q, r) = lhs?.div_rem(rhs);
650 let q_output = stmt.outputs[0];
651 let q_value = ConstValue::Int(q, self.variables[q_output].ty).intern(db);
652 self.var_info.insert(q_output, VarInfo::Const(q_value).into());
653 let r_output = stmt.outputs[1];
654 let r_value = ConstValue::Int(r, self.variables[r_output].ty).intern(db);
655 self.var_info.insert(r_output, VarInfo::Const(r_value).into());
656 self.additional_stmts
657 .push(Statement::Const(StatementConst::new_flat(r_value, r_output)));
658 Some(Statement::Const(StatementConst::new_flat(q_value, q_output)))
659 } else if id == self.storage_base_address_from_felt252 {
660 let input_var = stmt.inputs[0].var_id;
661 if let Some(const_value) = self.as_const(input_var)
662 && let ConstValue::Int(val, ty) = const_value.long(db)
663 {
664 stmt.inputs.clear();
665 let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
666 stmt.function =
667 self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
668 }
669 None
670 } else if self.upcast_fns.contains(&id) {
671 let int_value = self.as_int(stmt.inputs[0].var_id)?;
672 let output = stmt.outputs[0];
673 let value = ConstValue::Int(int_value.clone(), self.variables[output].ty).intern(db);
674 self.var_info.insert(output, VarInfo::Const(value).into());
675 Some(Statement::Const(StatementConst::new_flat(value, output)))
676 } else if id == self.array_new {
677 self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]).into());
678 None
679 } else if id == self.array_append {
680 let mut var_infos = if let VarInfo::Array(var_infos) =
681 self.var_info.get(&stmt.inputs[0].var_id)?.as_ref()
682 {
683 var_infos.clone()
684 } else {
685 return None;
686 };
687 let appended = stmt.inputs[1];
688 var_infos.push(match self.var_info.get(&appended.var_id) {
689 Some(var_info) => Some(var_info.clone()),
690 None => var_info_if_copy(self.variables, appended),
691 });
692 self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos).into());
693 None
694 } else if id == self.array_len {
695 let info = self.var_info.get(&stmt.inputs[0].var_id)?;
696 let desnapped = try_extract_matches!(info.as_ref(), VarInfo::Snapshot)?;
697 let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
698 Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0]))
699 } else {
700 None
701 }
702 }
703
704 fn try_specialize_call(&self, call_stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
711 if call_stmt.with_coupon {
712 return None;
713 }
714 if matches!(self.db.optimizations().inlining_strategy(), InliningStrategy::Avoid) {
716 return None;
717 }
718
719 let Ok(Some(mut called_function)) = call_stmt.function.body(self.db) else {
720 return None;
721 };
722
723 let extract_base = |function: ConcreteFunctionWithBodyId<'db>| match function.long(self.db)
724 {
725 ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
726 specialized.long(self.db).base
727 }
728 _ => function,
729 };
730 let called_base = extract_base(called_function);
731 let caller_base = extract_base(self.caller_function);
732
733 if self.db.priv_never_inline(called_base).ok()? {
734 return None;
735 }
736
737 if call_stmt.is_specialization_base_call {
739 return None;
740 }
741
742 if called_base == caller_base && called_function != called_base {
744 return None;
745 }
746
747 let scc =
750 self.db.lowered_scc(called_base, DependencyType::Call, LoweringStage::Monomorphized);
751 if scc.len() > 1 && scc.contains(&caller_base) {
752 return None;
753 }
754
755 if call_stmt.inputs.iter().all(|arg| self.var_info.get(&arg.var_id).is_none()) {
756 return None;
758 }
759
760 let self_specializition = if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
762 self.caller_function.long(self.db)
763 && caller_base == called_base
764 {
765 specialized.long(self.db).args.iter().map(Some).collect()
766 } else {
767 vec![None; call_stmt.inputs.len()]
768 };
769
770 let mut specialization_args = vec![];
771 let mut new_args = vec![];
772 for (arg, coerce) in zip_eq(&call_stmt.inputs, &self_specializition) {
773 if let Some(var_info) = self.var_info.get(&arg.var_id)
774 && self.variables[arg.var_id].info.droppable.is_ok()
775 && let Some(specialization_arg) = self.try_get_specialization_arg(
776 var_info.clone(),
777 self.variables[arg.var_id].ty,
778 &mut new_args,
779 *coerce,
780 )
781 {
782 specialization_args.push(specialization_arg);
783 } else {
784 specialization_args.push(SpecializationArg::NotSpecialized);
785 new_args.push(*arg);
786 continue;
787 };
788 }
789
790 if specialization_args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized)) {
791 return None;
793 }
794 if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
795 called_function.long(self.db)
796 {
797 let specialized_function = specialized_function.long(self.db);
798 called_function = specialized_function.base;
801 let mut new_args_iter = specialization_args.into_iter();
802 let mut old_args = specialized_function.args.clone();
803 let mut stack = vec![];
804 for arg in old_args.iter_mut().rev() {
805 stack.push(arg);
806 }
807 while let Some(arg) = stack.pop() {
808 match arg {
809 SpecializationArg::Const { .. } => {}
810 SpecializationArg::Snapshot(inner) => {
811 stack.push(inner.as_mut());
812 }
813 SpecializationArg::Enum { payload, .. } => {
814 stack.push(payload.as_mut());
815 }
816 SpecializationArg::Array(_, values) | SpecializationArg::Struct(values) => {
817 for value in values.iter_mut().rev() {
818 stack.push(value);
819 }
820 }
821 SpecializationArg::NotSpecialized => {
822 *arg = new_args_iter.next().unwrap_or(SpecializationArg::NotSpecialized);
823 }
824 }
825 }
826 specialization_args = old_args;
827 }
828 let specialized = SpecializedFunction { base: called_function, args: specialization_args }
829 .intern(self.db);
830 let specialized_func_id =
831 ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);
832
833 if caller_base != called_base
834 && self.db.priv_should_specialize(specialized_func_id) == Ok(false)
835 {
836 return None;
837 }
838
839 Some(Statement::Call(StatementCall {
840 function: specialized_func_id.function_id(self.db).unwrap(),
841 inputs: new_args,
842 with_coupon: call_stmt.with_coupon,
843 outputs: std::mem::take(&mut call_stmt.outputs),
844 location: call_stmt.location,
845 is_specialization_base_call: false,
846 }))
847 }
848
849 fn propagate_const_and_get_statement(
851 &mut self,
852 value: BigInt,
853 output: VariableId,
854 ) -> Statement<'db> {
855 let ty = self.variables[output].ty;
856 let value = ConstValueId::from_int(self.db, ty, &value);
857 self.var_info.insert(output, VarInfo::Const(value).into());
858 Statement::Const(StatementConst::new_flat(value, output))
859 }
860
861 fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement<'db> {
863 self.propagate_const_and_get_statement(BigInt::zero(), output)
864 }
865
866 fn try_generate_const_statement(
868 &self,
869 value: ConstValueId<'db>,
870 output: VariableId,
871 ) -> Option<Statement<'db>> {
872 if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
873 Some(Statement::Const(StatementConst::new_flat(value, output)))
874 } else if matches!(value.long(self.db), ConstValue::Struct(members, _) if members.is_empty())
875 {
876 Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
878 } else {
879 None
880 }
881 }
882
883 fn handle_enum_block_end(
888 &mut self,
889 info: &mut MatchEnumInfo<'db>,
890 statements: &mut Vec<Statement<'db>>,
891 ) -> Option<BlockEnd<'db>> {
892 let input = info.input.var_id;
893 let (n_snapshots, var_info) = self.var_info.get(&input)?.clone().peel_snapshots();
894 let location = info.location;
895 let as_usage = |var_id| VarUsage { var_id, location };
896 let db = self.db;
897 let snapshot_stmt = |vars: &mut VariableArena<'_>, pre_snap, post_snap| {
898 let ignored = vars.alloc(vars[pre_snap].clone());
899 Statement::Snapshot(StatementSnapshot::new(as_usage(pre_snap), ignored, post_snap))
900 };
901 if let VarInfo::Const(const_value) = var_info.as_ref()
903 && let ConstValue::Enum(variant, value) = const_value.long(db)
904 {
905 let arm = &info.arms[variant.idx];
906 let output = arm.var_ids[0];
907 self.var_info
909 .insert(output, Rc::new(VarInfo::Const(*value)).wrap_with_snapshots(n_snapshots));
910 if self.variables[input].info.droppable.is_ok()
911 && self.variables[output].info.copyable.is_ok()
912 && let Ok(mut ty) = value.ty(db)
913 && let Some(mut stmt) = self.try_generate_const_statement(*value, output)
914 {
915 for _ in 0..n_snapshots {
917 let non_snap_var = Variable::with_default_context(db, ty, location);
918 ty = TypeLongId::Snapshot(ty).intern(db);
919 let pre_snap = self.variables.alloc(non_snap_var);
920 stmt.outputs_mut()[0] = pre_snap;
921 let take_snap = snapshot_stmt(self.variables, pre_snap, output);
922 statements.push(core::mem::replace(&mut stmt, take_snap));
923 }
924 statements.push(stmt);
925 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
926 }
927 } else if let VarInfo::Enum { variant, payload } = var_info.as_ref() {
928 let arm = &info.arms[variant.idx];
929 let variant_ty = variant.ty;
930 let output = arm.var_ids[0];
931 let payload = payload.clone();
932 let unwrapped =
933 self.variables[input].info.droppable.is_ok().then_some(()).and_then(|_| {
934 let (extra_snapshots, inner) = payload.clone().peel_snapshots();
935 match inner.as_ref() {
936 VarInfo::Var(var) if self.variables[var.var_id].info.copyable.is_ok() => {
937 Some((var.var_id, extra_snapshots))
938 }
939 VarInfo::Const(value) => {
940 let const_var = self
941 .variables
942 .alloc(Variable::with_default_context(db, variant_ty, location));
943 statements.push(self.try_generate_const_statement(*value, const_var)?);
944 Some((const_var, extra_snapshots))
945 }
946 _ => None,
947 }
948 });
949 self.var_info.insert(output, payload.wrap_with_snapshots(n_snapshots));
951 if let Some((mut unwrapped, extra_snapshots)) = unwrapped {
952 let total_snapshots = n_snapshots + extra_snapshots;
953 if total_snapshots != 0 {
954 for _ in 1..total_snapshots {
956 let ty = TypeLongId::Snapshot(self.variables[unwrapped].ty).intern(db);
957 let non_snap_var = Variable::with_default_context(self.db, ty, location);
958 let snapped = self.variables.alloc(non_snap_var);
959 statements.push(snapshot_stmt(self.variables, unwrapped, snapped));
960 unwrapped = snapped;
961 }
962 statements.push(snapshot_stmt(self.variables, unwrapped, output));
963 };
964 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
965 }
966 }
967 None
968 }
969
970 fn handle_extern_block_end(
975 &mut self,
976 info: &mut MatchExternInfo<'db>,
977 statements: &mut Vec<Statement<'db>>,
978 ) -> Option<BlockEnd<'db>> {
979 let db = self.db;
980 let (id, generic_args) = info.function.get_extern(db)?;
981 if self.nz_fns.contains(&id) {
982 let val = self.as_const(info.inputs[0].var_id)?;
983 let is_zero = match val.long(db) {
984 ConstValue::Int(v, _) => v.is_zero(),
985 ConstValue::Struct(s, _) => s.iter().all(|v| {
986 v.long(db).to_int().expect("Expected ConstValue::Int for size").is_zero()
987 }),
988 _ => unreachable!(),
989 };
990 Some(if is_zero {
991 BlockEnd::Goto(info.arms[0].block_id, Default::default())
992 } else {
993 let arm = &info.arms[1];
994 let nz_var = arm.var_ids[0];
995 let nz_val = ConstValue::NonZero(val).intern(db);
996 self.var_info.insert(nz_var, VarInfo::Const(nz_val).into());
997 statements.push(Statement::Const(StatementConst::new_flat(nz_val, nz_var)));
998 BlockEnd::Goto(arm.block_id, Default::default())
999 })
1000 } else if self.eq_fns.contains(&id) {
1001 let lhs = self.as_int(info.inputs[0].var_id);
1002 let rhs = self.as_int(info.inputs[1].var_id);
1003 if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
1004 || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
1005 {
1006 let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
1007 let var = &self.variables[nz_input.var_id].clone();
1008 let function = self.type_info.get(&var.ty)?.is_zero;
1009 let unused_nz_var = Variable::with_default_context(
1010 db,
1011 corelib::core_nonzero_ty(db, var.ty),
1012 var.location,
1013 );
1014 let unused_nz_var = self.variables.alloc(unused_nz_var);
1015 return Some(BlockEnd::Match {
1016 info: MatchInfo::Extern(MatchExternInfo {
1017 function,
1018 inputs: vec![nz_input],
1019 arms: vec![
1020 MatchArm {
1021 arm_selector: MatchArmSelector::VariantId(
1022 corelib::jump_nz_zero_variant(db, var.ty),
1023 ),
1024 block_id: info.arms[1].block_id,
1025 var_ids: vec![],
1026 },
1027 MatchArm {
1028 arm_selector: MatchArmSelector::VariantId(
1029 corelib::jump_nz_nonzero_variant(db, var.ty),
1030 ),
1031 block_id: info.arms[0].block_id,
1032 var_ids: vec![unused_nz_var],
1033 },
1034 ],
1035 location: info.location,
1036 }),
1037 });
1038 }
1039 Some(BlockEnd::Goto(
1040 info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
1041 Default::default(),
1042 ))
1043 } else if self.uadd_fns.contains(&id)
1044 || self.usub_fns.contains(&id)
1045 || self.diff_fns.contains(&id)
1046 || self.iadd_fns.contains(&id)
1047 || self.isub_fns.contains(&id)
1048 {
1049 let rhs = self.as_int(info.inputs[1].var_id);
1050 let lhs = self.as_int(info.inputs[0].var_id);
1051 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
1052 let ty = self.variables[info.arms[0].var_ids[0]].ty;
1053 let range = self.type_value_ranges.get(&ty)?;
1054 let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
1055 lhs + rhs
1056 } else {
1057 lhs - rhs
1058 };
1059 let (arm_index, value) = match range.normalized(value) {
1060 NormalizedResult::InRange(value) => (0, value),
1061 NormalizedResult::Under(value) => (1, value),
1062 NormalizedResult::Over(value) => (
1063 if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
1064 2
1065 } else {
1066 1
1067 },
1068 value,
1069 ),
1070 };
1071 let arm = &info.arms[arm_index];
1072 let actual_output = arm.var_ids[0];
1073 let value = ConstValue::Int(value, ty).intern(db);
1074 self.var_info.insert(actual_output, VarInfo::Const(value).into());
1075 statements.push(Statement::Const(StatementConst::new_flat(value, actual_output)));
1076 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1077 }
1078 if let Some(rhs) = rhs {
1079 if rhs.is_zero() && !self.diff_fns.contains(&id) {
1080 let arm = &info.arms[0];
1081 self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]).into());
1082 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1083 }
1084 if rhs.is_one() && !self.diff_fns.contains(&id) {
1085 let ty = self.variables[info.arms[0].var_ids[0]].ty;
1086 let ty_info = self.type_info.get(&ty)?;
1087 let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
1088 ty_info.inc?
1089 } else {
1090 ty_info.dec?
1091 };
1092 let enum_ty = function.signature(db).ok()?.return_type;
1093 let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
1094 enum_ty.long(db)
1095 else {
1096 return None;
1097 };
1098 let result = self.variables.alloc(Variable::with_default_context(
1099 db,
1100 function.signature(db).unwrap().return_type,
1101 info.location,
1102 ));
1103 statements.push(Statement::Call(StatementCall {
1104 function,
1105 inputs: vec![info.inputs[0]],
1106 with_coupon: false,
1107 outputs: vec![result],
1108 location: info.location,
1109 is_specialization_base_call: false,
1110 }));
1111 return Some(BlockEnd::Match {
1112 info: MatchInfo::Enum(MatchEnumInfo {
1113 concrete_enum_id: *concrete_enum_id,
1114 input: VarUsage { var_id: result, location: info.location },
1115 arms: core::mem::take(&mut info.arms),
1116 location: info.location,
1117 }),
1118 });
1119 }
1120 }
1121 if let Some(lhs) = lhs
1122 && lhs.is_zero()
1123 && (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id))
1124 {
1125 let arm = &info.arms[0];
1126 self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]).into());
1127 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1128 }
1129 None
1130 } else if let Some(reversed) = self.downcast_fns.get(&id) {
1131 let range = |ty: TypeId<'_>| {
1132 Some(if let Some(range) = self.type_value_ranges.get(&ty) {
1133 range.clone()
1134 } else {
1135 let (min, max) = corelib::try_extract_bounded_int_type_ranges(db, ty)?;
1136 TypeRange { min, max }
1137 })
1138 };
1139 let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
1140 let input_var = info.inputs[0].var_id;
1141 let in_ty = self.variables[input_var].ty;
1142 let success_output = info.arms[success_arm].var_ids[0];
1143 let out_ty = self.variables[success_output].ty;
1144 let out_range = range(out_ty)?;
1145 let Some(value) = self.as_int(input_var) else {
1146 let in_range = range(in_ty)?;
1147 return if in_range.min < out_range.min || in_range.max > out_range.max {
1148 None
1149 } else {
1150 let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
1151 let function = db.core_info().upcast_fn.concretize(db, generic_args);
1152 statements.push(Statement::Call(StatementCall {
1153 function: function.lowered(db),
1154 inputs: vec![info.inputs[0]],
1155 with_coupon: false,
1156 outputs: vec![success_output],
1157 location: info.location,
1158 is_specialization_base_call: false,
1159 }));
1160 return Some(BlockEnd::Goto(
1161 info.arms[success_arm].block_id,
1162 Default::default(),
1163 ));
1164 };
1165 };
1166 let value = if in_ty == self.felt252 {
1167 felt252_for_downcast(value, &out_range.min)
1168 } else {
1169 value.clone()
1170 };
1171 Some(if let NormalizedResult::InRange(value) = out_range.normalized(value) {
1172 let value = ConstValue::Int(value, out_ty).intern(db);
1173 self.var_info.insert(success_output, VarInfo::Const(value).into());
1174 statements.push(Statement::Const(StatementConst::new_flat(value, success_output)));
1175 BlockEnd::Goto(info.arms[success_arm].block_id, Default::default())
1176 } else {
1177 BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default())
1178 })
1179 } else if id == self.bounded_int_constrain {
1180 let input_var = info.inputs[0].var_id;
1181 let value = self.as_int(input_var)?;
1182 let generic_arg = generic_args[1];
1183 let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
1184 .long(db)
1185 .to_int()
1186 .expect("Expected ConstValue::Int for size");
1187 let arm_idx = if value < constrain_value { 0 } else { 1 };
1188 let output = info.arms[arm_idx].var_ids[0];
1189 statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1190 Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1191 } else if id == self.bounded_int_trim_min {
1192 let input_var = info.inputs[0].var_id;
1193 let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
1194 return None;
1195 };
1196 let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
1197 range.min == *value
1198 } else {
1199 corelib::try_extract_bounded_int_type_ranges(db, *ty)?.0 == *value
1200 };
1201 let arm_idx = if is_trimmed {
1202 0
1203 } else {
1204 let output = info.arms[1].var_ids[0];
1205 statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1206 1
1207 };
1208 Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1209 } else if id == self.bounded_int_trim_max {
1210 let input_var = info.inputs[0].var_id;
1211 let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
1212 return None;
1213 };
1214 let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
1215 range.max == *value
1216 } else {
1217 corelib::try_extract_bounded_int_type_ranges(db, *ty)?.1 == *value
1218 };
1219 let arm_idx = if is_trimmed {
1220 0
1221 } else {
1222 let output = info.arms[1].var_ids[0];
1223 statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1224 1
1225 };
1226 Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1227 } else if id == self.array_get {
1228 let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
1229 if let Some(arr_info) = self.var_info.get(&info.inputs[0].var_id)
1230 && let VarInfo::Snapshot(arr_info) = arr_info.as_ref()
1231 && let VarInfo::Array(infos) = arr_info.as_ref()
1232 {
1233 match infos.get(index) {
1234 Some(Some(output_var_info)) => {
1235 let arm = &info.arms[0];
1236 let output_var_info = output_var_info.clone();
1237 self.var_info.insert(
1238 arm.var_ids[0],
1239 VarInfo::Box(VarInfo::Snapshot(output_var_info.clone()).into()).into(),
1240 );
1241 if let VarInfo::Const(value) = output_var_info.as_ref() {
1242 let value_ty = value.ty(db).ok()?;
1243 let value_box_ty = corelib::core_box_ty(db, value_ty);
1244 let location = info.location;
1245 let boxed_var =
1246 Variable::with_default_context(db, value_box_ty, location);
1247 let boxed = self.variables.alloc(boxed_var.clone());
1248 let unused_boxed = self.variables.alloc(boxed_var);
1249 let snapped = self.variables.alloc(Variable::with_default_context(
1250 db,
1251 TypeLongId::Snapshot(value_box_ty).intern(db),
1252 location,
1253 ));
1254 statements.extend([
1255 Statement::Const(StatementConst::new_boxed(*value, boxed)),
1256 Statement::Snapshot(StatementSnapshot {
1257 input: VarUsage { var_id: boxed, location },
1258 outputs: [unused_boxed, snapped],
1259 }),
1260 Statement::Call(StatementCall {
1261 function: self
1262 .box_forward_snapshot
1263 .concretize(db, vec![GenericArgumentId::Type(value_ty)])
1264 .lowered(db),
1265 inputs: vec![VarUsage { var_id: snapped, location }],
1266 with_coupon: false,
1267 outputs: vec![arm.var_ids[0]],
1268 location: info.location,
1269 is_specialization_base_call: false,
1270 }),
1271 ]);
1272 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1273 }
1274 }
1275 None => {
1276 return Some(BlockEnd::Goto(info.arms[1].block_id, Default::default()));
1277 }
1278 Some(None) => {}
1279 }
1280 }
1281 if index.is_zero()
1282 && let [success, failure] = info.arms.as_mut_slice()
1283 {
1284 let arr = info.inputs[0].var_id;
1285 let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
1286 let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
1287 info.inputs.truncate(1);
1288 info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
1289 .concretize(db, generic_args)
1290 .lowered(db);
1291 success.var_ids.insert(0, unused_arr_output0);
1292 failure.var_ids.insert(0, unused_arr_output1);
1293 }
1294 None
1295 } else if id == self.array_pop_front {
1296 let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)?.as_ref()
1297 else {
1298 return None;
1299 };
1300 if let Some(first) = var_infos.first() {
1301 if let Some(first) = first.as_ref().cloned() {
1302 let arm = &info.arms[0];
1303 self.var_info
1304 .insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()).into());
1305 self.var_info.insert(arm.var_ids[1], VarInfo::Box(first).into());
1306 }
1307 None
1308 } else {
1309 let arm = &info.arms[1];
1310 self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]).into());
1311 Some(BlockEnd::Goto(
1312 arm.block_id,
1313 VarRemapping {
1314 remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1315 },
1316 ))
1317 }
1318 } else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
1319 let var_info = self.var_info.get(&info.inputs[0].var_id)?;
1320 let desnapped = try_extract_matches!(var_info.as_ref(), VarInfo::Snapshot)?;
1321 let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
1322 if element_var_infos.is_empty() {
1324 let arm = &info.arms[1];
1325 self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]).into());
1326 Some(BlockEnd::Goto(
1327 arm.block_id,
1328 VarRemapping {
1329 remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1330 },
1331 ))
1332 } else {
1333 None
1334 }
1335 } else {
1336 None
1337 }
1338 }
1339
1340 fn as_const(&self, var_id: VariableId) -> Option<ConstValueId<'db>> {
1342 try_extract_matches!(self.var_info.get(&var_id)?.as_ref(), VarInfo::Const).copied()
1343 }
1344
1345 fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
1347 match self.as_const(var_id)?.long(self.db) {
1348 ConstValue::Int(value, _) => Some(value),
1349 ConstValue::NonZero(const_value) => {
1350 if let ConstValue::Int(value, _) = const_value.long(self.db) {
1351 Some(value)
1352 } else {
1353 None
1354 }
1355 }
1356 _ => None,
1357 }
1358 }
1359
1360 fn maybe_replace_inputs(&self, inputs: &mut [VarUsage<'db>]) {
1362 for input in inputs {
1363 self.maybe_replace_input(input);
1364 }
1365 }
1366
1367 fn maybe_replace_input(&self, input: &mut VarUsage<'db>) {
1369 if let Some(info) = self.var_info.get(&input.var_id)
1370 && let VarInfo::Var(new_var) = info.as_ref()
1371 {
1372 *input = *new_var;
1373 }
1374 }
1375
/// Tries to express the tracked `var_info` of a variable of type `ty` as a [`SpecializationArg`].
///
/// Leaf parts whose value is unknown (`VarInfo::Var`) become `SpecializationArg::NotSpecialized`,
/// and their variable usages are appended to `unknown_vars`. For snapshots, arrays, structs and
/// enums, the unknown usages are first collected into a local vector. They are appended to
/// `unknown_vars` only once the whole argument has been built successfully.
///
/// `coerce`, when given, is an existing argument the result must agree with. Known parts that
/// differ from it yield `None`.
///
/// Returns `None` when:
/// - the type is zero-sized, or its size info cannot be computed;
/// - `coerce` is `Some(NotSpecialized)`;
/// - an array or struct element has no tracked info;
/// - a non-empty array or struct would consist only of `NotSpecialized` elements;
/// - a struct-like info belongs to a type that is not a struct, tuple or fixed-size array.
///
/// # Panics
/// - If the shape of `ty` or `coerce` contradicts `var_info` (the `unreachable!`s).
/// - If a struct coercion has a different number of members than the tracked info.
fn try_get_specialization_arg(
    &self,
    var_info: Rc<VarInfo<'db>>,
    ty: TypeId<'db>,
    unknown_vars: &mut Vec<VarUsage<'db>>,
    coerce: Option<&SpecializationArg<'db>>,
) -> Option<SpecializationArg<'db>> {
    // Zero-sized types are never specialized; a failing size query also bails out.
    require(self.db.type_size_info(ty).ok()? != TypeSizeInformation::ZeroSized)?;
    // A coercion that demands "not specialized" cannot be satisfied by building an argument.
    require(!matches!(coerce, Some(SpecializationArg::NotSpecialized)))?;

    match var_info.as_ref() {
        VarInfo::Const(value) => {
            let res = const_to_specialization_arg(self.db, *value, false);
            let Some(coerce) = coerce else {
                return Some(res);
            };
            // With a coercion, the constant must match it exactly.
            if *coerce == res { Some(res) } else { None }
        }
        VarInfo::Box(info) => {
            // Only boxed constants are specializable.
            let res = try_extract_matches!(info.as_ref(), VarInfo::Const)
                .map(|value| SpecializationArg::Const { value: *value, boxed: true });
            let Some(coerce) = coerce else {
                return res;
            };
            if Some(coerce.clone()) == res { res } else { None }
        }
        VarInfo::Snapshot(info) => {
            // The tracked info mirrors the type: a snapshot info implies a snapshot type.
            let desnap_ty = *extract_matches!(ty.long(self.db), TypeLongId::Snapshot);
            let mut local_unknown_vars: Vec<VarUsage<'db>> = Vec::new();
            let inner = self.try_get_specialization_arg(
                info.clone(),
                desnap_ty,
                &mut local_unknown_vars,
                coerce.map(|coerce| {
                    extract_matches!(coerce, SpecializationArg::Snapshot).as_ref()
                }),
            )?;
            unknown_vars.extend(local_unknown_vars);
            Some(SpecializationArg::Snapshot(Box::new(inner)))
        }
        VarInfo::Array(infos) => {
            let TypeLongId::Concrete(concrete_ty) = ty.long(self.db) else {
                unreachable!("Expected a concrete type");
            };
            let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
            else {
                unreachable!("Expected a single type generic argument");
            };
            // Pair every element with its coercion (if any). A coercion of a different length
            // cannot be satisfied.
            let coerces = match coerce {
                Some(coerce) => {
                    let SpecializationArg::Array(ty, specialization_args) = coerce else {
                        unreachable!("Expected an array specialization argument");
                    };
                    assert_eq!(ty, inner_ty);
                    if specialization_args.len() != infos.len() {
                        return None;
                    }

                    specialization_args.iter().map(Some).collect()
                }
                None => vec![None; infos.len()],
            };
            let mut vars = vec![];
            let mut args = vec![];
            for (info, coerce) in zip_eq(infos, coerces) {
                // Any element without tracked info aborts the whole array.
                let info = info.as_ref()?.clone();
                let arg =
                    self.try_get_specialization_arg(info, *inner_ty, &mut vars, coerce)?;
                args.push(arg);
            }
            // A non-empty array with no known element at all is not worth specializing.
            if !args.is_empty()
                && args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
            {
                return None;
            }
            unknown_vars.extend(vars);
            Some(SpecializationArg::Array(*inner_ty, args))
        }
        VarInfo::Struct(infos) => {
            // Member types for structs, tuples and fixed-size arrays alike.
            let element_types: Vec<TypeId<'db>> = match ty.long(self.db) {
                TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) => {
                    let members = self.db.concrete_struct_members(*concrete_struct).unwrap();
                    members.values().map(|member| member.ty).collect()
                }
                TypeLongId::Tuple(element_types) => element_types.clone(),
                TypeLongId::FixedSizeArray { type_id, .. } => vec![*type_id; infos.len()],
                _ => return None,
            };

            let coerces = match coerce {
                Some(SpecializationArg::Struct(specialization_args)) => {
                    assert_eq!(specialization_args.len(), infos.len());
                    specialization_args.iter().map(Some).collect()
                }
                Some(_) => unreachable!("Expected a struct specialization argument"),
                None => vec![None; infos.len()],
            };

            let mut struct_args = Vec::new();
            let mut vars = vec![];
            for ((elem_ty, opt_var_info), coerce) in
                zip_eq(zip_eq(element_types, infos), coerces)
            {
                // Any member without tracked info aborts the whole struct.
                let var_info = opt_var_info.as_ref()?.clone();
                let arg =
                    self.try_get_specialization_arg(var_info, elem_ty, &mut vars, coerce)?;
                struct_args.push(arg);
            }
            // A non-empty struct with no known member at all is not worth specializing.
            if !struct_args.is_empty()
                && struct_args
                    .iter()
                    .all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
            {
                return None;
            }
            unknown_vars.extend(vars);
            Some(SpecializationArg::Struct(struct_args))
        }
        VarInfo::Enum { variant, payload } => {
            // With a coercion, the variant must match and the payload is coerced recursively.
            let coerce = match coerce {
                Some(coerce) => {
                    let SpecializationArg::Enum { variant: coercion_variant, payload } = coerce
                    else {
                        unreachable!("Expected an enum specialization argument");
                    };
                    if coercion_variant != variant {
                        return None;
                    }
                    Some(payload.as_ref())
                }
                None => None,
            };
            let mut local_unknown_vars = vec![];
            let payload_arg = self.try_get_specialization_arg(
                payload.clone(),
                variant.ty,
                &mut local_unknown_vars,
                coerce,
            )?;

            unknown_vars.extend(local_unknown_vars);
            Some(SpecializationArg::Enum { variant: *variant, payload: Box::new(payload_arg) })
        }
        VarInfo::Var(var_usage) => {
            // Unknown value: it must be passed as a regular argument.
            unknown_vars.push(*var_usage);
            Some(SpecializationArg::NotSpecialized)
        }
    }
}
1537
1538 pub fn should_skip_const_folding(&self, db: &'db dyn Database) -> bool {
1540 if db.optimizations().skip_const_folding() {
1541 return true;
1542 }
1543
1544 if self.caller_function.base_semantic_function(db).generic_function(db)
1547 == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
1548 {
1549 return true;
1550 }
1551 false
1552 }
1553}
1554
1555fn var_info_if_copy<'db>(
1557 variables: &VariableArena<'db>,
1558 input: VarUsage<'db>,
1559) -> Option<Rc<VarInfo<'db>>> {
1560 variables[input.var_id].info.copyable.is_ok().then(|| VarInfo::Var(input).into())
1561}
1562
/// Salsa-tracked query building the [`ConstFoldingLibfuncInfo`] for `db`.
///
/// `returns(ref)` makes callers receive a reference to the stored value rather than a clone.
#[salsa::tracked(returns(ref))]
fn priv_const_folding_info<'db>(
    db: &'db dyn Database,
) -> crate::optimizations::const_folding::ConstFoldingLibfuncInfo<'db> {
    ConstFoldingLibfuncInfo::new(db)
}
1570
/// Corelib functions and per-type helpers used by const folding, resolved once in
/// [`ConstFoldingLibfuncInfo::new`].
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub struct ConstFoldingLibfuncInfo<'db> {
    /// The `felt252_sub` extern from `core`.
    felt_sub: ExternFunctionId<'db>,
    /// The `felt252_add` extern from `core`.
    felt_add: ExternFunctionId<'db>,
    /// The `felt252_mul` extern from `core`.
    felt_mul: ExternFunctionId<'db>,
    /// The `felt252_div` extern from `core`.
    felt_div: ExternFunctionId<'db>,
    /// The generic `core::box::box_forward_snapshot` function.
    box_forward_snapshot: GenericFunctionId<'db>,
    /// `core::integer::{ty}_eq` externs for `u8`..`u128` and `i8`..`i128`.
    eq_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `core::integer::{ty}_overflowing_add` externs for `u8`..`u128`.
    uadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `core::integer::{ty}_overflowing_sub` externs for `u8`..`u128`.
    usub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `core::integer::{ty}_diff` externs for `i8`..`i128`.
    diff_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `core::integer::{ty}_overflowing_add_impl` externs for `i8`..`i128`.
    iadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `core::integer::{ty}_overflowing_sub_impl` externs for `i8`..`i128`.
    isub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_mul`, plus `core::integer::{ty}_wide_mul` for `u8`..`u64` and `i8`..`i64`.
    wide_mul_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_div_rem`, plus `core::integer::{ty}_safe_divmod` for `u8`..`u128`.
    div_rem_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `core::internal::bounded_int::bounded_int_add`.
    bounded_int_add: ExternFunctionId<'db>,
    /// `core::internal::bounded_int::bounded_int_sub`.
    bounded_int_sub: ExternFunctionId<'db>,
    /// `core::internal::bounded_int::bounded_int_constrain`.
    bounded_int_constrain: ExternFunctionId<'db>,
    /// `core::internal::bounded_int::bounded_int_trim_min`.
    bounded_int_trim_min: ExternFunctionId<'db>,
    /// `core::internal::bounded_int::bounded_int_trim_max`.
    bounded_int_trim_max: ExternFunctionId<'db>,
    /// `core::array::array_get`.
    array_get: ExternFunctionId<'db>,
    /// `core::array::array_snapshot_pop_front`.
    array_snapshot_pop_front: ExternFunctionId<'db>,
    /// `core::array::array_snapshot_pop_back`.
    array_snapshot_pop_back: ExternFunctionId<'db>,
    /// `core::array::array_len`.
    array_len: ExternFunctionId<'db>,
    /// `core::array::array_new`.
    array_new: ExternFunctionId<'db>,
    /// `core::array::array_append`.
    array_append: ExternFunctionId<'db>,
    /// `core::array::array_pop_front`.
    array_pop_front: ExternFunctionId<'db>,
    /// `core::starknet::storage_access::storage_base_address_from_felt252`.
    storage_base_address_from_felt252: ExternFunctionId<'db>,
    /// The generic `core::starknet::storage_access::storage_base_address_const` function.
    storage_base_address_const: GenericFunctionId<'db>,
    /// Lowered `core::panic_with_felt252`.
    panic_with_felt252: FunctionId<'db>,
    /// `core::panic_with_const_felt252`; const folding is skipped inside this function.
    pub panic_with_const_felt252: FreeFunctionId<'db>,
    /// Lowered `core::panics::panic_with_byte_array`.
    panic_with_byte_array: FunctionId<'db>,
    /// Helper functions (`is_zero`, `inc`, `dec`) for `u8`..`u128`, `u256` and `i8`..`i128`.
    type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>>,
    /// Shared constant-calculation info from `db.const_calc_info()`, also reachable via `Deref`.
    const_calculation_info: Arc<ConstCalcInfo<'db>>,
}
1639impl<'db> ConstFoldingLibfuncInfo<'db> {
1640 fn new(db: &'db dyn Database) -> Self {
1641 let core = ModuleHelper::core(db);
1642 let box_module = core.submodule("box");
1643 let integer_module = core.submodule("integer");
1644 let internal_module = core.submodule("internal");
1645 let bounded_int_module = internal_module.submodule("bounded_int");
1646 let num_module = internal_module.submodule("num");
1647 let array_module = core.submodule("array");
1648 let starknet_module = core.submodule("starknet");
1649 let storage_access_module = starknet_module.submodule("storage_access");
1650 let utypes = ["u8", "u16", "u32", "u64", "u128"];
1651 let itypes = ["i8", "i16", "i32", "i64", "i128"];
1652 let eq_fns = OrderedHashSet::<_>::from_iter(
1653 chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(&format!("{ty}_eq"))),
1654 );
1655 let uadd_fns = OrderedHashSet::<_>::from_iter(
1656 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_add"))),
1657 );
1658 let usub_fns = OrderedHashSet::<_>::from_iter(
1659 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_sub"))),
1660 );
1661 let diff_fns = OrderedHashSet::<_>::from_iter(
1662 itypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_diff"))),
1663 );
1664 let iadd_fns =
1665 OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1666 integer_module.extern_function_id(&format!("{ty}_overflowing_add_impl"))
1667 }));
1668 let isub_fns =
1669 OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1670 integer_module.extern_function_id(&format!("{ty}_overflowing_sub_impl"))
1671 }));
1672 let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
1673 [bounded_int_module.extern_function_id("bounded_int_mul")],
1674 ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
1675 .map(|ty| integer_module.extern_function_id(&format!("{ty}_wide_mul"))),
1676 ));
1677 let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
1678 [bounded_int_module.extern_function_id("bounded_int_div_rem")],
1679 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_safe_divmod"))),
1680 ));
1681 let type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>> = OrderedHashMap::from_iter(
1682 [
1683 ("u8", false, true),
1684 ("u16", false, true),
1685 ("u32", false, true),
1686 ("u64", false, true),
1687 ("u128", false, true),
1688 ("u256", false, false),
1689 ("i8", true, true),
1690 ("i16", true, true),
1691 ("i32", true, true),
1692 ("i64", true, true),
1693 ("i128", true, true),
1694 ]
1695 .map(|(ty_name, as_bounded_int, inc_dec): (&'static str, bool, bool)| {
1696 let ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, ty_name), vec![]);
1697 let is_zero = if as_bounded_int {
1698 bounded_int_module
1699 .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
1700 } else {
1701 integer_module.function_id(
1702 SmolStrId::from(db, format!("{ty_name}_is_zero")).long(db).as_str(),
1703 vec![],
1704 )
1705 }
1706 .lowered(db);
1707 let (inc, dec) = if inc_dec {
1708 (
1709 Some(
1710 num_module
1711 .function_id(
1712 SmolStrId::from(db, format!("{ty_name}_inc")).long(db).as_str(),
1713 vec![],
1714 )
1715 .lowered(db),
1716 ),
1717 Some(
1718 num_module
1719 .function_id(
1720 SmolStrId::from(db, format!("{ty_name}_dec")).long(db).as_str(),
1721 vec![],
1722 )
1723 .lowered(db),
1724 ),
1725 )
1726 } else {
1727 (None, None)
1728 };
1729 let info = TypeInfo { is_zero, inc, dec };
1730 (ty, info)
1731 }),
1732 );
1733 Self {
1734 felt_sub: core.extern_function_id("felt252_sub"),
1735 felt_add: core.extern_function_id("felt252_add"),
1736 felt_mul: core.extern_function_id("felt252_mul"),
1737 felt_div: core.extern_function_id("felt252_div"),
1738 box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
1739 eq_fns,
1740 uadd_fns,
1741 usub_fns,
1742 diff_fns,
1743 iadd_fns,
1744 isub_fns,
1745 wide_mul_fns,
1746 div_rem_fns,
1747 bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
1748 bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
1749 bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
1750 bounded_int_trim_min: bounded_int_module.extern_function_id("bounded_int_trim_min"),
1751 bounded_int_trim_max: bounded_int_module.extern_function_id("bounded_int_trim_max"),
1752 array_get: array_module.extern_function_id("array_get"),
1753 array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
1754 array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
1755 array_len: array_module.extern_function_id("array_len"),
1756 array_new: array_module.extern_function_id("array_new"),
1757 array_append: array_module.extern_function_id("array_append"),
1758 array_pop_front: array_module.extern_function_id("array_pop_front"),
1759 storage_base_address_from_felt252: storage_access_module
1760 .extern_function_id("storage_base_address_from_felt252"),
1761 storage_base_address_const: storage_access_module
1762 .generic_function_id("storage_base_address_const"),
1763 panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
1764 panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
1765 panic_with_byte_array: core
1766 .submodule("panics")
1767 .function_id("panic_with_byte_array", vec![])
1768 .lowered(db),
1769 type_info,
1770 const_calculation_info: db.const_calc_info(),
1771 }
1772 }
1773}
1774
1775impl<'db> std::ops::Deref for ConstFoldingContext<'db, '_> {
1776 type Target = ConstFoldingLibfuncInfo<'db>;
1777 fn deref(&self) -> &ConstFoldingLibfuncInfo<'db> {
1778 self.libfunc_info
1779 }
1780}
1781
1782impl<'a> std::ops::Deref for ConstFoldingLibfuncInfo<'a> {
1783 type Target = ConstCalcInfo<'a>;
1784 fn deref(&self) -> &ConstCalcInfo<'a> {
1785 &self.const_calculation_info
1786 }
1787}
1788
/// Helper functions for one integer type, built in `ConstFoldingLibfuncInfo::new`.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct TypeInfo<'db> {
    /// Lowered zero check: `bounded_int_is_zero::<ty>` for `i8`..`i128`, and
    /// `core::integer::{ty}_is_zero` for the unsigned types.
    is_zero: FunctionId<'db>,
    /// Lowered `core::internal::num::{ty}_inc`; `None` for `u256`.
    inc: Option<FunctionId<'db>>,
    /// Lowered `core::internal::num::{ty}_dec`; `None` for `u256`.
    dec: Option<FunctionId<'db>>,
}
1799
/// Classifies a value against an inclusive `[min, max]` range, wrapping it once when outside.
trait TypeRangeNormalizer {
    /// Returns `InRange(value)` if `min <= value <= max`.
    ///
    /// Otherwise returns `Under` / `Over` carrying `value` shifted by one range period
    /// (`max - min + 1`) toward the range.
    fn normalized(&self, value: BigInt) -> NormalizedResult;
}
1805impl TypeRangeNormalizer for TypeRange {
1806 fn normalized(&self, value: BigInt) -> NormalizedResult {
1807 if value < self.min {
1808 NormalizedResult::Under(value - &self.min + &self.max + 1)
1809 } else if value > self.max {
1810 NormalizedResult::Over(value + &self.min - &self.max - 1)
1811 } else {
1812 NormalizedResult::InRange(value)
1813 }
1814 }
1815}
1816
/// Outcome of `TypeRangeNormalizer::normalized`.
enum NormalizedResult {
    /// The value was within `[min, max]`; carried unchanged.
    InRange(BigInt),
    /// The value was above `max`; carried after subtracting one range period.
    Over(BigInt),
    /// The value was below `min`; carried after adding one range period.
    Under(BigInt),
}