1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::rc::Rc;
6use std::sync::Arc;
7
8use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
9use cairo_lang_filesystem::flag::FlagsGroup;
10use cairo_lang_filesystem::ids::SmolStrId;
11use cairo_lang_semantic::corelib::CorelibSemantic;
12use cairo_lang_semantic::helper::ModuleHelper;
13use cairo_lang_semantic::items::constant::{
14 ConstCalcInfo, ConstValue, ConstValueId, ConstantSemantic, TypeRange, canonical_felt252,
15 felt252_for_downcast,
16};
17use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
18use cairo_lang_semantic::items::structure::StructSemantic;
19use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
20use cairo_lang_semantic::{
21 ConcreteTypeId, ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId,
22 corelib,
23};
24use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
25use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
26use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
27use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
28use cairo_lang_utils::{Intern, extract_matches, require, try_extract_matches};
29use itertools::{Itertools, chain, zip_eq};
30use num_bigint::BigInt;
31use num_integer::Integer;
32use num_traits::cast::ToPrimitive;
33use num_traits::{Num, One, Zero};
34use salsa::Database;
35use starknet_types_core::felt::Felt as Felt252;
36
37use crate::db::LoweringGroup;
38use crate::ids::{
39 ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
40 SpecializedFunction,
41};
42use crate::specialization::SpecializationArg;
43use crate::utils::InliningStrategy;
44use crate::{
45 Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchEnumInfo,
46 MatchExternInfo, MatchInfo, Statement, StatementCall, StatementConst, StatementDesnap,
47 StatementEnumConstruct, StatementIntoBox, StatementSnapshot, StatementStructConstruct,
48 StatementStructDestructure, StatementUnbox, VarRemapping, VarUsage, Variable, VariableArena,
49 VariableId,
50};
51
52fn const_to_specialization_arg<'db>(
56 db: &'db dyn Database,
57 value: ConstValueId<'db>,
58 boxed: bool,
59) -> SpecializationArg<'db> {
60 match value.long(db) {
61 ConstValue::Struct(members, ty) => {
62 if matches!(
65 ty.long(db),
66 TypeLongId::Concrete(ConcreteTypeId::Struct(_))
67 | TypeLongId::Tuple(_)
68 | TypeLongId::FixedSizeArray { .. }
69 ) {
70 let args = members
71 .iter()
72 .map(|member| const_to_specialization_arg(db, *member, false))
73 .collect();
74 SpecializationArg::Struct(args)
75 } else {
76 SpecializationArg::Const { value, boxed }
77 }
78 }
79 ConstValue::Enum(variant, payload) => SpecializationArg::Enum {
80 variant: *variant,
81 payload: Box::new(const_to_specialization_arg(db, *payload, false)),
82 },
83 _ => SpecializationArg::Const { value, boxed },
84 }
85}
86
87#[derive(Debug, Clone)]
90enum VarInfo<'db> {
91 Const(ConstValueId<'db>),
93 Var(VarUsage<'db>),
95 Snapshot(Rc<VarInfo<'db>>),
97 Struct(Vec<Option<Rc<VarInfo<'db>>>>),
100 Enum { variant: ConcreteVariant<'db>, payload: Rc<VarInfo<'db>> },
102 Box(Rc<VarInfo<'db>>),
104 Array(Vec<Option<Rc<VarInfo<'db>>>>),
107}
108impl<'db> VarInfo<'db> {
109 fn peel_snapshots(mut self: Rc<Self>) -> (usize, Rc<VarInfo<'db>>) {
111 let mut n_snapshots = 0;
112 while let VarInfo::Snapshot(inner) = self.as_ref() {
113 self = inner.clone();
114 n_snapshots += 1;
115 }
116 (n_snapshots, self)
117 }
118 fn wrap_with_snapshots(mut self: Rc<Self>, n_snapshots: usize) -> Rc<VarInfo<'db>> {
120 for _ in 0..n_snapshots {
121 self = VarInfo::Snapshot(self).into();
122 }
123 self
124 }
125}
126
/// How a block can be reached, as tracked by the const-folding pass.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Reachability {
    /// The block was reached, so far, only by a single goto from the given block.
    FromSingleGoto(BlockId),
    /// The block is the root block, the target of a match arm, or the target of more than one
    /// goto.
    Any,
}
136
137pub fn const_folding<'db>(
140 db: &'db dyn Database,
141 function_id: ConcreteFunctionWithBodyId<'db>,
142 lowered: &mut Lowered<'db>,
143) {
144 if lowered.blocks.is_empty() {
145 return;
146 }
147
148 let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
151
152 if ctx.should_skip_const_folding(db) {
153 return;
154 }
155
156 for block_id in (0..lowered.blocks.len()).map(BlockId) {
157 if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
158 continue;
159 }
160
161 let block = &mut lowered.blocks[block_id];
162 for stmt in block.statements.iter_mut() {
163 ctx.visit_statement(stmt);
164 }
165 ctx.visit_block_end(block_id, block);
166 }
167}
168
/// The state kept by the const-folding pass while going over a single function.
pub struct ConstFoldingContext<'db, 'mt> {
    /// The used database.
    db: &'db dyn Database,
    /// The variable arena of the function; new variables are allocated in it when needed.
    pub variables: &'mt mut VariableArena<'db>,
    /// The information accumulated so far on the variables of the function.
    var_info: UnorderedHashMap<VariableId, Rc<VarInfo<'db>>>,
    /// The libfunc related information, as returned by `priv_const_folding_info`.
    libfunc_info: &'db ConstFoldingLibfuncInfo<'db>,
    /// The function whose body is being processed.
    caller_function: ConcreteFunctionWithBodyId<'db>,
    /// The reachability of the blocks that were marked as reachable and were not visited yet.
    /// The entry of a block is removed when its visit starts.
    reachability: UnorderedHashMap<BlockId, Reachability>,
    /// Statements queued while visiting statements; `visit_block_end` moves them to the start
    /// of the visited block.
    additional_stmts: Vec<Statement<'db>>,
}
188
189impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
190 pub fn new(
191 db: &'db dyn Database,
192 function_id: ConcreteFunctionWithBodyId<'db>,
193 variables: &'mt mut VariableArena<'db>,
194 ) -> Self {
195 Self {
196 db,
197 var_info: UnorderedHashMap::default(),
198 variables,
199 libfunc_info: priv_const_folding_info(db),
200 caller_function: function_id,
201 reachability: UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]),
202 additional_stmts: vec![],
203 }
204 }
205
206 pub fn visit_block_start<'r, 'get>(
209 &'r mut self,
210 block_id: BlockId,
211 get_block: impl FnOnce(BlockId) -> &'get Block<'db>,
212 ) -> bool
213 where
214 'db: 'get,
215 {
216 let Some(reachability) = self.reachability.remove(&block_id) else {
217 return false;
218 };
219 match reachability {
220 Reachability::Any => {}
221 Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
222 BlockEnd::Goto(_, remapping) => {
223 for (dst, src) in remapping.iter() {
224 if let Some(v) = self.as_const(src.var_id) {
225 self.var_info.insert(*dst, VarInfo::Const(v).into());
226 }
227 }
228 }
229 _ => unreachable!("Expected a goto end"),
230 },
231 }
232 true
233 }
234
235 pub fn visit_statement(&mut self, stmt: &mut Statement<'db>) {
245 self.maybe_replace_inputs(stmt.inputs_mut());
246 match stmt {
247 Statement::Const(StatementConst { value, output, boxed }) if *boxed => {
248 self.var_info.insert(*output, VarInfo::Box(VarInfo::Const(*value).into()).into());
249 }
250 Statement::Const(StatementConst { value, output, .. }) => match value.long(self.db) {
251 ConstValue::Int(..)
252 | ConstValue::Struct(..)
253 | ConstValue::Enum(..)
254 | ConstValue::NonZero(..) => {
255 self.var_info.insert(*output, VarInfo::Const(*value).into());
256 }
257 ConstValue::Generic(_)
258 | ConstValue::ImplConstant(_)
259 | ConstValue::Var(..)
260 | ConstValue::Missing(_) => {}
261 },
262 Statement::Snapshot(stmt) => {
263 if let Some(info) = self.var_info.get(&stmt.input.var_id) {
264 let info = info.clone();
265 self.var_info.insert(stmt.original(), info.clone());
266 self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info).into());
267 }
268 }
269 Statement::Desnap(StatementDesnap { input, output }) => {
270 if let Some(info) = self.var_info.get(&input.var_id)
271 && let VarInfo::Snapshot(info) = info.as_ref()
272 {
273 self.var_info.insert(*output, info.clone());
274 }
275 }
276 Statement::Call(call_stmt) => {
277 if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
278 *stmt = updated_stmt;
279 } else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
280 *stmt = updated_stmt;
281 }
282 }
283 Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
284 let mut const_args = vec![];
285 let mut all_args = vec![];
286 let mut contains_info = false;
287 for input in inputs.iter() {
288 let Some(info) = self.var_info.get(&input.var_id) else {
289 all_args.push(var_info_if_copy(self.variables, *input));
290 continue;
291 };
292 contains_info = true;
293 if let VarInfo::Const(value) = info.as_ref() {
294 const_args.push(*value);
295 }
296 all_args.push(Some(info.clone()));
297 }
298 if const_args.len() == inputs.len() {
299 let value =
300 ConstValue::Struct(const_args, self.variables[*output].ty).intern(self.db);
301 self.var_info.insert(*output, VarInfo::Const(value).into());
302 } else if contains_info {
303 self.var_info.insert(*output, VarInfo::Struct(all_args).into());
304 }
305 }
306 Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
307 if let Some(info) = self.var_info.get(&input.var_id) {
308 let (n_snapshots, info) = info.clone().peel_snapshots();
309 match info.as_ref() {
310 VarInfo::Const(const_value) => {
311 if let ConstValue::Struct(member_values, _) = const_value.long(self.db)
312 {
313 for (output, value) in zip_eq(outputs, member_values) {
314 self.var_info.insert(
315 *output,
316 Rc::new(VarInfo::Const(*value))
317 .wrap_with_snapshots(n_snapshots),
318 );
319 }
320 }
321 }
322 VarInfo::Struct(members) => {
323 for (output, member) in zip_eq(outputs, members.clone()) {
324 if let Some(member) = member {
325 self.var_info
326 .insert(*output, member.wrap_with_snapshots(n_snapshots));
327 }
328 }
329 }
330 _ => {}
331 }
332 }
333 }
334 Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
335 let value = if let Some(info) = self.var_info.get(&input.var_id) {
336 if let VarInfo::Const(val) = info.as_ref() {
337 VarInfo::Const(ConstValue::Enum(*variant, *val).intern(self.db))
338 } else {
339 VarInfo::Enum { variant: *variant, payload: info.clone() }
340 }
341 } else {
342 VarInfo::Enum { variant: *variant, payload: VarInfo::Var(*input).into() }
343 };
344 self.var_info.insert(*output, value.into());
345 }
346 Statement::IntoBox(StatementIntoBox { input, output }) => {
347 let var_info = self.var_info.get(&input.var_id);
348 let const_value = var_info.and_then(|var_info| match var_info.as_ref() {
349 VarInfo::Const(val) => Some(*val),
350 VarInfo::Snapshot(info) => {
351 try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
352 }
353 _ => None,
354 });
355 let var_info =
356 var_info.cloned().or_else(|| var_info_if_copy(self.variables, *input));
357 if let Some(var_info) = var_info {
358 self.var_info.insert(*output, VarInfo::Box(var_info).into());
359 }
360
361 if let Some(const_value) = const_value {
362 *stmt = Statement::Const(StatementConst::new_boxed(const_value, *output));
363 }
364 }
365 Statement::Unbox(StatementUnbox { input, output }) => {
366 if let Some(inner) = self.var_info.get(&input.var_id)
367 && let VarInfo::Box(inner) = inner.as_ref()
368 {
369 let inner = inner.clone();
370 if let VarInfo::Const(inner) =
371 self.var_info.entry(*output).insert_entry(inner).get().as_ref()
372 {
373 *stmt = Statement::Const(StatementConst::new_flat(*inner, *output));
374 }
375 }
376 }
377 }
378 }
379
/// Finishes the visit of a block, after all of its statements were visited.
///
/// Moves the statements queued in `additional_stmts` to the start of the block, passes the
/// inputs of the block end through `maybe_replace_input(s)`, tries to replace a match end by
/// a goto, and finally records the reachability of the blocks that the (possibly updated) end
/// leads to.
///
/// Panics if the block end is `BlockEnd::Panic` or `BlockEnd::NotSet`, or if the block of a
/// match arm already has a reachability entry.
pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut Block<'db>) {
    let statements = &mut block.statements;
    // The queued statements are inserted before all the existing statements of the block.
    statements.splice(0..0, self.additional_stmts.drain(..));

    match &mut block.end {
        BlockEnd::Goto(_, remappings) => {
            for (_, v) in remappings.iter_mut() {
                self.maybe_replace_input(v);
            }
        }
        BlockEnd::Match { info } => {
            self.maybe_replace_inputs(info.inputs_mut());
            match info {
                MatchInfo::Enum(info) => {
                    if let Some(updated_end) = self.handle_enum_block_end(info, statements) {
                        block.end = updated_end;
                    }
                }
                MatchInfo::Extern(info) => {
                    if let Some(updated_end) = self.handle_extern_block_end(info, statements) {
                        block.end = updated_end;
                    }
                }
                MatchInfo::Value(info) => {
                    // The matched value is a known integer that fits in `usize` and is the
                    // selector of one of the arms.
                    if let Some(value) =
                        self.as_int(info.input.var_id).and_then(|x| x.to_usize())
                        && let Some(arm) = info.arms.iter().find(|arm| {
                            matches!(
                                &arm.arm_selector,
                                MatchArmSelector::Value(v) if v.value == value
                            )
                        })
                    {
                        // The variable of the arm is no longer introduced by the match, so it
                        // is defined here instead, as a struct with no members.
                        statements.push(Statement::StructConstruct(StatementStructConstruct {
                            inputs: vec![],
                            output: arm.var_ids[0],
                        }));
                        block.end = BlockEnd::Goto(arm.block_id, Default::default());
                    }
                }
            }
        }
        BlockEnd::Return(inputs, _) => self.maybe_replace_inputs(inputs),
        BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
    }
    // Recording the reachability according to the final end of the block.
    match &block.end {
        BlockEnd::Goto(dst_block_id, _) => {
            match self.reachability.entry(*dst_block_id) {
                // The destination was already marked as reachable from elsewhere.
                std::collections::hash_map::Entry::Occupied(mut e) => {
                    e.insert(Reachability::Any)
                }
                // This goto is the first known way to reach the destination.
                std::collections::hash_map::Entry::Vacant(e) => {
                    *e.insert(Reachability::FromSingleGoto(block_id))
                }
            };
        }
        BlockEnd::Match { info } => {
            for arm in info.arms() {
                // The block of an arm is expected to have no reachability entry yet.
                assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
            }
        }
        BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
    }
}
451
/// Tries to simplify a call statement using the information known about its inputs.
///
/// Returns the statement that should replace the call, if there is one. Even when `None` is
/// returned, the call may have been updated in place (its function and inputs), information
/// may have been recorded on its outputs, and statements may have been queued in
/// `additional_stmts`.
///
/// NOTE(review): the function ids compared against (`self.felt_sub`, `self.array_new`, ...)
/// are not fields of the context itself; presumably they are reached through `libfunc_info` -
/// confirm against the rest of the file.
fn handle_statement_call(&mut self, stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
    let db = self.db;
    if stmt.function == self.panic_with_felt252 {
        // With a const input, the call is updated in place into an inputless call to
        // `panic_with_const_felt252`, taking the value as a generic argument.
        let val = self.as_const(stmt.inputs[0].var_id)?;
        stmt.inputs.clear();
        stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
            .concretize(db, vec![GenericArgumentId::Constant(val)])
            .lowered(db);
        return None;
    } else if stmt.function == self.panic_with_byte_array && !db.flag_unsafe_panic() {
        // Handled only when the input is a snapshot of a struct with exactly three known
        // members: an array whose elements are all consts, and two consts.
        let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
        let bytearray = try_extract_matches!(snap.as_ref(), VarInfo::Snapshot)?;
        let [Some(data), Some(pending_word), Some(pending_len)] =
            &try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
        else {
            return None;
        };
        let data = try_extract_matches!(data.as_ref(), VarInfo::Array)?;
        let pending_word = try_extract_matches!(pending_word.as_ref(), VarInfo::Const)?;
        let pending_len = try_extract_matches!(pending_len.as_ref(), VarInfo::Const)?;
        // The panic data: the byte array magic, the number of data words, the data words,
        // the pending word and the pending length.
        let mut panic_data =
            vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
        for word in data {
            let VarInfo::Const(word) = word.as_ref()?.as_ref() else {
                return None;
            };
            panic_data.push(word.to_int(db)?.clone());
        }
        panic_data.extend([pending_word.to_int(db)?.clone(), pending_len.to_int(db)?.clone()]);
        let felt252_ty = self.felt252;
        let location = stmt.location;
        let new_var = |ty| Variable::with_default_context(db, ty, location);
        let as_usage = |var_id| VarUsage { var_id, location };
        let array_fn = |extern_id| {
            let args = vec![GenericArgumentId::Type(felt252_ty)];
            GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
        };
        let call_stmt = |function, inputs, outputs| {
            let with_coupon = false;
            Statement::Call(StatementCall {
                function,
                inputs,
                with_coupon,
                outputs,
                location,
                is_specialization_base_call: false,
            })
        };
        // Queueing the statements building the panic data: a new felt252 array, followed by
        // a const statement and an append call per word.
        let arr_var = new_var(corelib::core_array_felt252_ty(db));
        let mut arr = self.variables.alloc(arr_var.clone());
        self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
        let felt252_var = new_var(felt252_ty);
        let arr_append_fn = array_fn(self.array_append);
        for word in panic_data {
            let to_append = self.variables.alloc(felt252_var.clone());
            let new_arr = self.variables.alloc(arr_var.clone());
            self.additional_stmts.push(Statement::Const(StatementConst::new_flat(
                ConstValue::Int(word, felt252_ty).intern(db),
                to_append,
            )));
            self.additional_stmts.push(call_stmt(
                arr_append_fn,
                vec![as_usage(arr), as_usage(to_append)],
                vec![new_arr],
            ));
            arr = new_arr;
        }
        // Queueing the construction of a `Panic` value, built with no inputs.
        let panic_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "Panic"), vec![]);
        let panic_var = self.variables.alloc(new_var(panic_ty));
        self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
            inputs: vec![],
            output: panic_var,
        }));
        // The call itself is replaced by constructing its output from the `Panic` value and
        // the data array.
        return Some(Statement::StructConstruct(StatementStructConstruct {
            inputs: vec![as_usage(panic_var), as_usage(arr)],
            output: stmt.outputs[0],
        }));
    }
    // The rest of the handling is only for calls to extern functions.
    let (id, _generic_args) = stmt.function.get_extern(db)?;
    if id == self.felt_sub {
        if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            && rhs.is_zero()
        {
            // `x - 0`: the output is described by `x`; the call itself is kept.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
            None
        } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
        {
            let value = canonical_felt252(&(lhs - rhs));
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else {
            None
        }
    } else if id == self.felt_add {
        if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && lhs.is_zero()
        {
            // `0 + x`: the output is described by `x`; the call itself is kept.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]).into());
            None
        } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            && rhs.is_zero()
        {
            // `x + 0`: the output is described by `x`; the call itself is kept.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
            None
        } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
        {
            let value = canonical_felt252(&(lhs + rhs));
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else {
            None
        }
    } else if id == self.felt_mul {
        let lhs = self.as_int(stmt.inputs[0].var_id);
        let rhs = self.as_int(stmt.inputs[1].var_id);
        if lhs.map(Zero::is_zero).unwrap_or_default()
            || rhs.map(Zero::is_zero).unwrap_or_default()
        {
            // A known zero on either side makes the result zero.
            Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
        } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            && rhs.is_one()
        {
            // `x * 1`: the output is described by `x`; the call itself is kept.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
            None
        } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && lhs.is_one()
        {
            // `1 * x`: the output is described by `x`; the call itself is kept.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]).into());
            None
        } else if let Some(lhs) = lhs
            && let Some(rhs) = rhs
        {
            let value = canonical_felt252(&(lhs * rhs));
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else {
            None
        }
    } else if id == self.felt_div {
        if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            && rhs.is_one()
        {
            // `x / 1`: the output is described by `x`; the call itself is kept.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
            None
        } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && lhs.is_zero()
        {
            // A known zero dividend is folded to zero without looking at the divisor.
            Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
        } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            && let Ok(rhs_nonzero) = Felt252::from(rhs).try_into()
        {
            // Both sides are known, and the divisor converted into a non-zero felt: the
            // division is computed in the field.
            let lhs_felt = Felt252::from(lhs);
            let value = lhs_felt.field_div(&rhs_nonzero).to_bigint();
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else {
            None
        }
    } else if self.wide_mul_fns.contains(&id) {
        let lhs = self.as_int(stmt.inputs[0].var_id);
        let rhs = self.as_int(stmt.inputs[1].var_id);
        let output = stmt.outputs[0];
        if lhs.map(Zero::is_zero).unwrap_or_default()
            || rhs.map(Zero::is_zero).unwrap_or_default()
        {
            return Some(self.propagate_zero_and_get_statement(output));
        }
        // Otherwise both sides must be known; the plain integer product is used.
        let lhs = lhs?;
        Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0]))
    } else if id == self.bounded_int_add || id == self.bounded_int_sub {
        // Both sides must be known; the plain integer result is used.
        let lhs = self.as_int(stmt.inputs[0].var_id)?;
        let rhs = self.as_int(stmt.inputs[1].var_id)?;
        let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
        Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
    } else if self.div_rem_fns.contains(&id) {
        let lhs = self.as_int(stmt.inputs[0].var_id);
        if lhs.map(Zero::is_zero).unwrap_or_default() {
            // A known zero dividend makes both outputs zero. The call can only be replaced by
            // a single statement, so the const statement of the second output is queued.
            let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
            self.additional_stmts.push(additional_stmt);
            return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
        }
        let rhs = self.as_int(stmt.inputs[1].var_id)?;
        let (q, r) = lhs?.div_rem(rhs);
        let q_output = stmt.outputs[0];
        let q_value = ConstValue::Int(q, self.variables[q_output].ty).intern(db);
        self.var_info.insert(q_output, VarInfo::Const(q_value).into());
        let r_output = stmt.outputs[1];
        let r_value = ConstValue::Int(r, self.variables[r_output].ty).intern(db);
        self.var_info.insert(r_output, VarInfo::Const(r_value).into());
        // The remainder const is queued, and the quotient const replaces the call.
        self.additional_stmts
            .push(Statement::Const(StatementConst::new_flat(r_value, r_output)));
        Some(Statement::Const(StatementConst::new_flat(q_value, q_output)))
    } else if id == self.storage_base_address_from_felt252 {
        let input_var = stmt.inputs[0].var_id;
        if let Some(const_value) = self.as_const(input_var)
            && let ConstValue::Int(val, ty) = const_value.long(db)
        {
            // With a const integer input, the call is updated in place into an inputless call
            // to `storage_base_address_const`, taking the value as a generic argument.
            stmt.inputs.clear();
            let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
            stmt.function =
                self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
        }
        None
    } else if self.upcast_fns.contains(&id) {
        // The output is the same integer value, with the type of the output variable.
        let int_value = self.as_int(stmt.inputs[0].var_id)?;
        let output = stmt.outputs[0];
        let value = ConstValue::Int(int_value.clone(), self.variables[output].ty).intern(db);
        self.var_info.insert(output, VarInfo::Const(value).into());
        Some(Statement::Const(StatementConst::new_flat(value, output)))
    } else if id == self.array_new {
        // A new array is known to have no elements.
        self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]).into());
        None
    } else if id == self.array_append {
        // Tracked only when the info of the input array is known.
        let mut var_infos = if let VarInfo::Array(var_infos) =
            self.var_info.get(&stmt.inputs[0].var_id)?.as_ref()
        {
            var_infos.clone()
        } else {
            return None;
        };
        let appended = stmt.inputs[1];
        var_infos.push(match self.var_info.get(&appended.var_id) {
            Some(var_info) => Some(var_info.clone()),
            None => var_info_if_copy(self.variables, appended),
        });
        self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos).into());
        None
    } else if id == self.array_len {
        // The length of a snapshot of an array with known info is the number of tracked
        // elements.
        let info = self.var_info.get(&stmt.inputs[0].var_id)?;
        let desnapped = try_extract_matches!(info.as_ref(), VarInfo::Snapshot)?;
        let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
        Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0]))
    } else {
        None
    }
}
700
/// Tries to replace a call by a call to a specialization of the called function, in which
/// the arguments with known information are part of the specialization.
///
/// Returns the statement calling the specialized function, or `None` if the call should be
/// left as is. When a statement is returned, the outputs of `call_stmt` were moved into it.
fn try_specialize_call(&self, call_stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
    // Calls with a coupon are not specialized.
    if call_stmt.with_coupon {
        return None;
    }
    // No specialization when the inlining strategy is `Avoid`.
    if matches!(self.db.optimizations().inlining_strategy(), InliningStrategy::Avoid) {
        return None;
    }

    // Only calls to functions with a body can be specialized.
    let Ok(Some(mut called_function)) = call_stmt.function.body(self.db) else {
        return None;
    };

    // The base of a specialized function, or the function itself if it is not specialized.
    let extract_base = |function: ConcreteFunctionWithBodyId<'db>| match function.long(self.db)
    {
        ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
            specialized.long(self.db).base
        }
        _ => function,
    };
    let called_base = extract_base(called_function);
    let caller_base = extract_base(self.caller_function);

    if self.db.priv_never_inline(called_base).ok()? {
        return None;
    }

    // Calls marked as the base call of a specialization are not specialized.
    if call_stmt.is_specialization_base_call {
        return None;
    }

    // A recursive call that is already a call to a specialization is not specialized again.
    if called_base == caller_base && called_function != called_base {
        return None;
    }

    // No specialization when the caller is part of the called function's SCC, and that SCC
    // has more than one function.
    let scc =
        self.db.lowered_scc(called_base, DependencyType::Call, LoweringStage::Monomorphized);
    if scc.len() > 1 && scc.contains(&caller_base) {
        return None;
    }

    // Nothing to specialize by if there is no information on any of the inputs.
    if call_stmt.inputs.iter().all(|arg| self.var_info.get(&arg.var_id).is_none()) {
        return None;
    }

    // When the caller is itself a specialization of the called base, its own specialization
    // args are passed, one per input, to `try_get_specialization_arg`.
    let self_specializition = if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
        self.caller_function.long(self.db)
        && caller_base == called_base
    {
        specialized.long(self.db).args.iter().map(Some).collect()
    } else {
        vec![None; call_stmt.inputs.len()]
    };

    let mut specialization_args = vec![];
    let mut new_args = vec![];
    for (arg, coerce) in zip_eq(&call_stmt.inputs, &self_specializition) {
        // An input is specialized only if it has known information, its variable is
        // droppable, and a specialization arg can be built for it.
        if let Some(var_info) = self.var_info.get(&arg.var_id)
            && self.variables[arg.var_id].info.droppable.is_ok()
            && let Some(specialization_arg) = self.try_get_specialization_arg(
                var_info.clone(),
                self.variables[arg.var_id].ty,
                &mut new_args,
                *coerce,
            )
        {
            specialization_args.push(specialization_arg);
        } else {
            // The input is passed as is to the specialized function.
            specialization_args.push(SpecializationArg::NotSpecialized);
            new_args.push(*arg);
            continue;
        };
    }

    // No input ended up being specialized.
    if specialization_args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized)) {
        return None;
    }
    if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
        called_function.long(self.db)
    {
        // The called function is already a specialization: the base function is specialized
        // instead, by the existing args with their `NotSpecialized` slots filled, in
        // traversal order, by the new args.
        let specialized_function = specialized_function.long(self.db);
        called_function = specialized_function.base;
        let mut new_args_iter = specialization_args.into_iter();
        let mut old_args = specialized_function.args.clone();
        // The stack is filled in reverse, so that the args are popped in their original
        // order.
        let mut stack = vec![];
        for arg in old_args.iter_mut().rev() {
            stack.push(arg);
        }
        while let Some(arg) = stack.pop() {
            match arg {
                SpecializationArg::Const { .. } => {}
                SpecializationArg::Snapshot(inner) => {
                    stack.push(inner.as_mut());
                }
                SpecializationArg::Enum { payload, .. } => {
                    stack.push(payload.as_mut());
                }
                SpecializationArg::Array(_, values) | SpecializationArg::Struct(values) => {
                    for value in values.iter_mut().rev() {
                        stack.push(value);
                    }
                }
                SpecializationArg::NotSpecialized => {
                    // The slot stays `NotSpecialized` if the new args ran out.
                    *arg = new_args_iter.next().unwrap_or(SpecializationArg::NotSpecialized);
                }
            }
        }
        specialization_args = old_args;
    }
    let specialized = SpecializedFunction { base: called_function, args: specialization_args }
        .intern(self.db);
    let specialized_func_id =
        ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);

    // For non-recursive calls, `priv_should_specialize` may veto the specialization.
    if caller_base != called_base
        && self.db.priv_should_specialize(specialized_func_id) == Ok(false)
    {
        return None;
    }

    Some(Statement::Call(StatementCall {
        function: specialized_func_id.function_id(self.db).unwrap(),
        inputs: new_args,
        with_coupon: call_stmt.with_coupon,
        outputs: std::mem::take(&mut call_stmt.outputs),
        location: call_stmt.location,
        is_specialization_base_call: false,
    }))
}
845
846 fn propagate_const_and_get_statement(
848 &mut self,
849 value: BigInt,
850 output: VariableId,
851 ) -> Statement<'db> {
852 let ty = self.variables[output].ty;
853 let value = ConstValueId::from_int(self.db, ty, &value);
854 self.var_info.insert(output, VarInfo::Const(value).into());
855 Statement::Const(StatementConst::new_flat(value, output))
856 }
857
858 fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement<'db> {
860 self.propagate_const_and_get_statement(BigInt::zero(), output)
861 }
862
863 fn try_generate_const_statement(
865 &self,
866 value: ConstValueId<'db>,
867 output: VariableId,
868 ) -> Option<Statement<'db>> {
869 if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
870 Some(Statement::Const(StatementConst::new_flat(value, output)))
871 } else if matches!(value.long(self.db), ConstValue::Struct(members, _) if members.is_empty())
872 {
873 Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
875 } else {
876 None
877 }
878 }
879
/// Tries to replace a match on an enum whose variant is known by a goto to the matching arm.
///
/// Returns the new end of the block, or `None` if the match should be kept. In both cases,
/// when the variant is known, information is recorded on the variable of the matching arm.
/// When a goto is returned, the statements defining the variable of the arm, if any are
/// needed, were pushed into `statements`.
fn handle_enum_block_end(
    &mut self,
    info: &mut MatchEnumInfo<'db>,
    statements: &mut Vec<Statement<'db>>,
) -> Option<BlockEnd<'db>> {
    let input = info.input.var_id;
    // The match may be on a snapshot (possibly nested) of the enum.
    let (n_snapshots, var_info) = self.var_info.get(&input)?.clone().peel_snapshots();
    let location = info.location;
    let as_usage = |var_id| VarUsage { var_id, location };
    let db = self.db;
    // Builds a statement taking a snapshot of `pre_snap` into `post_snap`. The other output
    // of the statement is a newly allocated variable, copied from `pre_snap`, that is not
    // used any further.
    let snapshot_stmt = |vars: &mut VariableArena<'_>, pre_snap, post_snap| {
        let ignored = vars.alloc(vars[pre_snap].clone());
        Statement::Snapshot(StatementSnapshot::new(as_usage(pre_snap), ignored, post_snap))
    };
    if let VarInfo::Const(const_value) = var_info.as_ref()
        && let ConstValue::Enum(variant, value) = const_value.long(db)
    {
        let arm = &info.arms[variant.idx];
        let output = arm.var_ids[0];
        // The variable of the arm holds the const payload, under the same number of
        // snapshots as the matched value.
        self.var_info
            .insert(output, Rc::new(VarInfo::Const(*value)).wrap_with_snapshots(n_snapshots));
        // The match is replaced only if the input is droppable, the variable of the arm is
        // copyable, and a statement generating the payload is available.
        if self.variables[input].info.droppable.is_ok()
            && self.variables[output].info.copyable.is_ok()
            && let Ok(mut ty) = value.ty(db)
            && let Some(mut stmt) = self.try_generate_const_statement(*value, output)
        {
            // The chain of variables leading to `output`: a new variable per snapshot
            // level, starting from the type of the payload, and finally `output` itself.
            let snapshot_vars = (0..n_snapshots)
                .map(|_| {
                    let old_ty = ty;
                    ty = TypeLongId::Snapshot(ty).intern(db);
                    self.variables.alloc(Variable::with_default_context(db, old_ty, location))
                })
                .chain([output])
                .collect::<Vec<_>>();
            // The generated statement outputs into the first variable of the chain, which is
            // `output` itself when there are no snapshots.
            stmt.outputs_mut()[0] = snapshot_vars[0];
            statements.push(stmt);
            // Each variable of the chain is the snapshot of the one before it.
            statements.extend(snapshot_vars.into_iter().tuple_windows().map(
                |(pre_snap, post_snap)| snapshot_stmt(self.variables, pre_snap, post_snap),
            ));

            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
        }
    } else if let VarInfo::Enum { variant, payload } = var_info.as_ref() {
        let arm = &info.arms[variant.idx];
        let variant_ty = variant.ty;
        let output = arm.var_ids[0];
        let payload = payload.clone();
        // Finding a variable holding the payload with its snapshots peeled, if the input is
        // droppable: either a copyable variable that the payload is known to be, or a new
        // variable defined by a statement generating the const payload.
        let unwrapped =
            self.variables[input].info.droppable.is_ok().then_some(()).and_then(|_| {
                let (extra_snapshots, inner) = payload.clone().peel_snapshots();
                match inner.as_ref() {
                    VarInfo::Var(var) if self.variables[var.var_id].info.copyable.is_ok() => {
                        Some((var.var_id, extra_snapshots))
                    }
                    VarInfo::Const(value) => {
                        let const_var = self
                            .variables
                            .alloc(Variable::with_default_context(db, variant_ty, location));
                        statements.push(self.try_generate_const_statement(*value, const_var)?);
                        Some((const_var, extra_snapshots))
                    }
                    _ => None,
                }
            });
        // The variable of the arm holds the payload, under the same number of snapshots as
        // the matched value.
        self.var_info.insert(output, payload.wrap_with_snapshots(n_snapshots));
        if let Some((mut unwrapped, extra_snapshots)) = unwrapped {
            let total_snapshots = n_snapshots + extra_snapshots;
            if total_snapshots != 0 {
                // All the snapshots but the last one are taken into new variables; the last
                // one is taken into `output`.
                for _ in 1..total_snapshots {
                    let ty = TypeLongId::Snapshot(self.variables[unwrapped].ty).intern(db);
                    let non_snap_var = Variable::with_default_context(self.db, ty, location);
                    let snapped = self.variables.alloc(non_snap_var);
                    statements.push(snapshot_stmt(self.variables, unwrapped, snapped));
                    unwrapped = snapped;
                }
                statements.push(snapshot_stmt(self.variables, unwrapped, output));
            };
            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
        }
    }
    None
}
972
973 fn handle_extern_block_end(
978 &mut self,
979 info: &mut MatchExternInfo<'db>,
980 statements: &mut Vec<Statement<'db>>,
981 ) -> Option<BlockEnd<'db>> {
982 let db = self.db;
983 let (id, generic_args) = info.function.get_extern(db)?;
984 if self.nz_fns.contains(&id) {
985 let val = self.as_const(info.inputs[0].var_id)?;
986 let is_zero = match val.long(db) {
987 ConstValue::Int(v, _) => v.is_zero(),
988 ConstValue::Struct(s, _) => s
989 .iter()
990 .all(|v| v.to_int(db).expect("Expected ConstValue::Int for size").is_zero()),
991 _ => unreachable!(),
992 };
993 Some(if is_zero {
994 BlockEnd::Goto(info.arms[0].block_id, Default::default())
995 } else {
996 let arm = &info.arms[1];
997 let nz_var = arm.var_ids[0];
998 let nz_val = ConstValue::NonZero(val).intern(db);
999 self.var_info.insert(nz_var, VarInfo::Const(nz_val).into());
1000 statements.push(Statement::Const(StatementConst::new_flat(nz_val, nz_var)));
1001 BlockEnd::Goto(arm.block_id, Default::default())
1002 })
1003 } else if self.eq_fns.contains(&id) {
1004 let lhs = self.as_int(info.inputs[0].var_id);
1005 let rhs = self.as_int(info.inputs[1].var_id);
1006 if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
1007 || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
1008 {
1009 let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
1010 let var = &self.variables[nz_input.var_id].clone();
1011 let function = self.type_info.get(&var.ty)?.is_zero;
1012 let unused_nz_var = Variable::with_default_context(
1013 db,
1014 corelib::core_nonzero_ty(db, var.ty),
1015 var.location,
1016 );
1017 let unused_nz_var = self.variables.alloc(unused_nz_var);
1018 return Some(BlockEnd::Match {
1019 info: MatchInfo::Extern(MatchExternInfo {
1020 function,
1021 inputs: vec![nz_input],
1022 arms: vec![
1023 MatchArm {
1024 arm_selector: MatchArmSelector::VariantId(
1025 corelib::jump_nz_zero_variant(db, var.ty),
1026 ),
1027 block_id: info.arms[1].block_id,
1028 var_ids: vec![],
1029 },
1030 MatchArm {
1031 arm_selector: MatchArmSelector::VariantId(
1032 corelib::jump_nz_nonzero_variant(db, var.ty),
1033 ),
1034 block_id: info.arms[0].block_id,
1035 var_ids: vec![unused_nz_var],
1036 },
1037 ],
1038 location: info.location,
1039 }),
1040 });
1041 }
1042 Some(BlockEnd::Goto(
1043 info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
1044 Default::default(),
1045 ))
1046 } else if self.uadd_fns.contains(&id)
1047 || self.usub_fns.contains(&id)
1048 || self.diff_fns.contains(&id)
1049 || self.iadd_fns.contains(&id)
1050 || self.isub_fns.contains(&id)
1051 {
1052 let rhs = self.as_int(info.inputs[1].var_id);
1053 let lhs = self.as_int(info.inputs[0].var_id);
1054 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
1055 let ty = self.variables[info.arms[0].var_ids[0]].ty;
1056 let range = self.type_value_ranges.get(&ty)?;
1057 let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
1058 lhs + rhs
1059 } else {
1060 lhs - rhs
1061 };
1062 let (arm_index, value) = match range.normalized(value) {
1063 NormalizedResult::InRange(value) => (0, value),
1064 NormalizedResult::Under(value) => (1, value),
1065 NormalizedResult::Over(value) => (
1066 if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
1067 2
1068 } else {
1069 1
1070 },
1071 value,
1072 ),
1073 };
1074 let arm = &info.arms[arm_index];
1075 let actual_output = arm.var_ids[0];
1076 let value = ConstValue::Int(value, ty).intern(db);
1077 self.var_info.insert(actual_output, VarInfo::Const(value).into());
1078 statements.push(Statement::Const(StatementConst::new_flat(value, actual_output)));
1079 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1080 }
1081 if let Some(rhs) = rhs {
1082 if rhs.is_zero() && !self.diff_fns.contains(&id) {
1083 let arm = &info.arms[0];
1084 self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]).into());
1085 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1086 }
1087 if rhs.is_one() && !self.diff_fns.contains(&id) {
1088 let ty = self.variables[info.arms[0].var_ids[0]].ty;
1089 let ty_info = self.type_info.get(&ty)?;
1090 let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
1091 ty_info.inc?
1092 } else {
1093 ty_info.dec?
1094 };
1095 let enum_ty = function.signature(db).ok()?.return_type;
1096 let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
1097 enum_ty.long(db)
1098 else {
1099 return None;
1100 };
1101 let result = self.variables.alloc(Variable::with_default_context(
1102 db,
1103 function.signature(db).unwrap().return_type,
1104 info.location,
1105 ));
1106 statements.push(Statement::Call(StatementCall {
1107 function,
1108 inputs: vec![info.inputs[0]],
1109 with_coupon: false,
1110 outputs: vec![result],
1111 location: info.location,
1112 is_specialization_base_call: false,
1113 }));
1114 return Some(BlockEnd::Match {
1115 info: MatchInfo::Enum(MatchEnumInfo {
1116 concrete_enum_id: *concrete_enum_id,
1117 input: VarUsage { var_id: result, location: info.location },
1118 arms: core::mem::take(&mut info.arms),
1119 location: info.location,
1120 }),
1121 });
1122 }
1123 }
1124 if let Some(lhs) = lhs
1125 && lhs.is_zero()
1126 && (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id))
1127 {
1128 let arm = &info.arms[0];
1129 self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]).into());
1130 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1131 }
1132 None
1133 } else if let Some(reversed) = self.downcast_fns.get(&id) {
1134 let range = |ty: TypeId<'_>| {
1135 Some(if let Some(range) = self.type_value_ranges.get(&ty) {
1136 range.clone()
1137 } else {
1138 let (min, max) = corelib::try_extract_bounded_int_type(db, ty)?;
1139 TypeRange::new(min.to_int(db)?.clone(), max.to_int(db)?.clone())
1140 })
1141 };
1142 let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
1143 let input_var = info.inputs[0].var_id;
1144 let in_ty = self.variables[input_var].ty;
1145 let success_output = info.arms[success_arm].var_ids[0];
1146 let out_ty = self.variables[success_output].ty;
1147 let out_range = range(out_ty)?;
1148 let Some(value) = self.as_int(input_var) else {
1149 let in_range = range(in_ty)?;
1150 return if in_range.min < out_range.min || in_range.max > out_range.max {
1151 None
1152 } else {
1153 let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
1154 let function = db.core_info().upcast_fn.concretize(db, generic_args);
1155 statements.push(Statement::Call(StatementCall {
1156 function: function.lowered(db),
1157 inputs: vec![info.inputs[0]],
1158 with_coupon: false,
1159 outputs: vec![success_output],
1160 location: info.location,
1161 is_specialization_base_call: false,
1162 }));
1163 return Some(BlockEnd::Goto(
1164 info.arms[success_arm].block_id,
1165 Default::default(),
1166 ));
1167 };
1168 };
1169 let value = if in_ty == self.felt252 {
1170 felt252_for_downcast(value, &out_range.min)
1171 } else {
1172 value.clone()
1173 };
1174 Some(if let NormalizedResult::InRange(value) = out_range.normalized(value) {
1175 let value = ConstValue::Int(value, out_ty).intern(db);
1176 self.var_info.insert(success_output, VarInfo::Const(value).into());
1177 statements.push(Statement::Const(StatementConst::new_flat(value, success_output)));
1178 BlockEnd::Goto(info.arms[success_arm].block_id, Default::default())
1179 } else {
1180 BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default())
1181 })
1182 } else if id == self.bounded_int_constrain {
1183 let input_var = info.inputs[0].var_id;
1184 let value = self.as_int(input_var)?;
1185 let generic_arg = generic_args[1];
1186 let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
1187 .to_int(db)
1188 .expect("Expected ConstValue::Int for size");
1189 let arm_idx = if value < constrain_value { 0 } else { 1 };
1190 let output = info.arms[arm_idx].var_ids[0];
1191 statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1192 Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1193 } else if id == self.bounded_int_trim_min {
1194 let input_var = info.inputs[0].var_id;
1195 let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
1196 return None;
1197 };
1198 let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
1199 range.min == *value
1200 } else {
1201 corelib::try_extract_bounded_int_type(db, *ty)?.0.to_int(db)? == value
1202 };
1203 let arm_idx = if is_trimmed {
1204 0
1205 } else {
1206 let output = info.arms[1].var_ids[0];
1207 statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1208 1
1209 };
1210 Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1211 } else if id == self.bounded_int_trim_max {
1212 let input_var = info.inputs[0].var_id;
1213 let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
1214 return None;
1215 };
1216 let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
1217 range.max == *value
1218 } else {
1219 corelib::try_extract_bounded_int_type(db, *ty)?.1.to_int(db)? == value
1220 };
1221 let arm_idx = if is_trimmed {
1222 0
1223 } else {
1224 let output = info.arms[1].var_ids[0];
1225 statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1226 1
1227 };
1228 Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1229 } else if id == self.array_get {
1230 let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
1231 if let Some(arr_info) = self.var_info.get(&info.inputs[0].var_id)
1232 && let VarInfo::Snapshot(arr_info) = arr_info.as_ref()
1233 && let VarInfo::Array(infos) = arr_info.as_ref()
1234 {
1235 match infos.get(index) {
1236 Some(Some(output_var_info)) => {
1237 let arm = &info.arms[0];
1238 let output_var_info = output_var_info.clone();
1239 self.var_info.insert(
1240 arm.var_ids[0],
1241 VarInfo::Box(VarInfo::Snapshot(output_var_info.clone()).into()).into(),
1242 );
1243 if let VarInfo::Const(value) = output_var_info.as_ref() {
1244 let value_ty = value.ty(db).ok()?;
1245 let value_box_ty = corelib::core_box_ty(db, value_ty);
1246 let location = info.location;
1247 let boxed_var =
1248 Variable::with_default_context(db, value_box_ty, location);
1249 let boxed = self.variables.alloc(boxed_var.clone());
1250 let unused_boxed = self.variables.alloc(boxed_var);
1251 let snapped = self.variables.alloc(Variable::with_default_context(
1252 db,
1253 TypeLongId::Snapshot(value_box_ty).intern(db),
1254 location,
1255 ));
1256 statements.extend([
1257 Statement::Const(StatementConst::new_boxed(*value, boxed)),
1258 Statement::Snapshot(StatementSnapshot {
1259 input: VarUsage { var_id: boxed, location },
1260 outputs: [unused_boxed, snapped],
1261 }),
1262 Statement::Call(StatementCall {
1263 function: self
1264 .box_forward_snapshot
1265 .concretize(db, vec![GenericArgumentId::Type(value_ty)])
1266 .lowered(db),
1267 inputs: vec![VarUsage { var_id: snapped, location }],
1268 with_coupon: false,
1269 outputs: vec![arm.var_ids[0]],
1270 location: info.location,
1271 is_specialization_base_call: false,
1272 }),
1273 ]);
1274 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1275 }
1276 }
1277 None => {
1278 return Some(BlockEnd::Goto(info.arms[1].block_id, Default::default()));
1279 }
1280 Some(None) => {}
1281 }
1282 }
1283 if index.is_zero()
1284 && let [success, failure] = info.arms.as_mut_slice()
1285 {
1286 let arr = info.inputs[0].var_id;
1287 let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
1288 let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
1289 info.inputs.truncate(1);
1290 info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
1291 .concretize(db, generic_args)
1292 .lowered(db);
1293 success.var_ids.insert(0, unused_arr_output0);
1294 failure.var_ids.insert(0, unused_arr_output1);
1295 }
1296 None
1297 } else if id == self.array_pop_front {
1298 let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)?.as_ref()
1299 else {
1300 return None;
1301 };
1302 if let Some(first) = var_infos.first() {
1303 if let Some(first) = first.as_ref().cloned() {
1304 let arm = &info.arms[0];
1305 self.var_info
1306 .insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()).into());
1307 self.var_info.insert(arm.var_ids[1], VarInfo::Box(first).into());
1308 }
1309 None
1310 } else {
1311 let arm = &info.arms[1];
1312 self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]).into());
1313 Some(BlockEnd::Goto(
1314 arm.block_id,
1315 VarRemapping {
1316 remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1317 },
1318 ))
1319 }
1320 } else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
1321 let var_info = self.var_info.get(&info.inputs[0].var_id)?;
1322 let desnapped = try_extract_matches!(var_info.as_ref(), VarInfo::Snapshot)?;
1323 let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
1324 if element_var_infos.is_empty() {
1326 let arm = &info.arms[1];
1327 self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]).into());
1328 Some(BlockEnd::Goto(
1329 arm.block_id,
1330 VarRemapping {
1331 remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1332 },
1333 ))
1334 } else {
1335 None
1336 }
1337 } else {
1338 None
1339 }
1340 }
1341
1342 fn as_const(&self, var_id: VariableId) -> Option<ConstValueId<'db>> {
1344 try_extract_matches!(self.var_info.get(&var_id)?.as_ref(), VarInfo::Const).copied()
1345 }
1346
1347 fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
1349 match self.as_const(var_id)?.long(self.db) {
1350 ConstValue::Int(value, _) => Some(value),
1351 ConstValue::NonZero(const_value) => {
1352 if let ConstValue::Int(value, _) = const_value.long(self.db) {
1353 Some(value)
1354 } else {
1355 None
1356 }
1357 }
1358 _ => None,
1359 }
1360 }
1361
1362 fn maybe_replace_inputs(&self, inputs: &mut [VarUsage<'db>]) {
1364 for input in inputs {
1365 self.maybe_replace_input(input);
1366 }
1367 }
1368
1369 fn maybe_replace_input(&self, input: &mut VarUsage<'db>) {
1371 if let Some(info) = self.var_info.get(&input.var_id)
1372 && let VarInfo::Var(new_var) = info.as_ref()
1373 {
1374 *input = *new_var;
1375 }
1376 }
1377
1378 fn try_get_specialization_arg(
1384 &self,
1385 var_info: Rc<VarInfo<'db>>,
1386 ty: TypeId<'db>,
1387 unknown_vars: &mut Vec<VarUsage<'db>>,
1388 coerce: Option<&SpecializationArg<'db>>,
1389 ) -> Option<SpecializationArg<'db>> {
1390 require(self.db.type_size_info(ty).ok()? != TypeSizeInformation::ZeroSized)?;
1392 require(!matches!(coerce, Some(SpecializationArg::NotSpecialized)))?;
1394
1395 match var_info.as_ref() {
1396 VarInfo::Const(value) => {
1397 let res = const_to_specialization_arg(self.db, *value, false);
1398 let Some(coerce) = coerce else {
1399 return Some(res);
1400 };
1401 if *coerce == res { Some(res) } else { None }
1402 }
1403 VarInfo::Box(info) => {
1404 let res = try_extract_matches!(info.as_ref(), VarInfo::Const)
1405 .map(|value| SpecializationArg::Const { value: *value, boxed: true });
1406 let Some(coerce) = coerce else {
1407 return res;
1408 };
1409 if Some(coerce.clone()) == res { res } else { None }
1410 }
1411 VarInfo::Snapshot(info) => {
1412 let desnap_ty = *extract_matches!(ty.long(self.db), TypeLongId::Snapshot);
1413 let mut local_unknown_vars: Vec<VarUsage<'db>> = Vec::new();
1415 let inner = self.try_get_specialization_arg(
1416 info.clone(),
1417 desnap_ty,
1418 &mut local_unknown_vars,
1419 coerce.map(|coerce| {
1420 extract_matches!(coerce, SpecializationArg::Snapshot).as_ref()
1421 }),
1422 )?;
1423 unknown_vars.extend(local_unknown_vars);
1424 Some(SpecializationArg::Snapshot(Box::new(inner)))
1425 }
1426 VarInfo::Array(infos) => {
1427 let TypeLongId::Concrete(concrete_ty) = ty.long(self.db) else {
1428 unreachable!("Expected a concrete type");
1429 };
1430 let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
1431 else {
1432 unreachable!("Expected a single type generic argument");
1433 };
1434 let coerces = match coerce {
1435 Some(coerce) => {
1436 let SpecializationArg::Array(ty, specialization_args) = coerce else {
1437 unreachable!("Expected an array specialization argument");
1438 };
1439 assert_eq!(ty, inner_ty);
1440 if specialization_args.len() != infos.len() {
1441 return None;
1442 }
1443
1444 specialization_args.iter().map(Some).collect()
1445 }
1446 None => vec![None; infos.len()],
1447 };
1448 let mut vars = vec![];
1450 let mut args = vec![];
1451 for (info, coerce) in zip_eq(infos, coerces) {
1452 let info = info.as_ref()?.clone();
1453 let arg =
1454 self.try_get_specialization_arg(info, *inner_ty, &mut vars, coerce)?;
1455 args.push(arg);
1456 }
1457 if !args.is_empty()
1458 && args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1459 {
1460 return None;
1461 }
1462 unknown_vars.extend(vars);
1463 Some(SpecializationArg::Array(*inner_ty, args))
1464 }
1465 VarInfo::Struct(infos) => {
1466 let element_types: Vec<TypeId<'db>> = match ty.long(self.db) {
1468 TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) => {
1469 let members = self.db.concrete_struct_members(*concrete_struct).unwrap();
1470 members.values().map(|member| member.ty).collect()
1471 }
1472 TypeLongId::Tuple(element_types) => element_types.clone(),
1473 TypeLongId::FixedSizeArray { type_id, .. } => vec![*type_id; infos.len()],
1474 _ => return None,
1476 };
1477
1478 let coerces = match coerce {
1479 Some(SpecializationArg::Struct(specialization_args)) => {
1480 assert_eq!(specialization_args.len(), infos.len());
1481 specialization_args.iter().map(Some).collect()
1482 }
1483 Some(_) => unreachable!("Expected a struct specialization argument"),
1484 None => vec![None; infos.len()],
1485 };
1486
1487 let mut struct_args = Vec::new();
1488 let mut vars = vec![];
1490 for ((elem_ty, opt_var_info), coerce) in
1491 zip_eq(zip_eq(element_types, infos), coerces)
1492 {
1493 let var_info = opt_var_info.as_ref()?.clone();
1494 let arg =
1495 self.try_get_specialization_arg(var_info, elem_ty, &mut vars, coerce)?;
1496 struct_args.push(arg);
1497 }
1498 if !struct_args.is_empty()
1499 && struct_args
1500 .iter()
1501 .all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1502 {
1503 return None;
1504 }
1505 unknown_vars.extend(vars);
1506 Some(SpecializationArg::Struct(struct_args))
1507 }
1508 VarInfo::Enum { variant, payload } => {
1509 let coerce = match coerce {
1510 Some(coerce) => {
1511 let SpecializationArg::Enum { variant: coercion_variant, payload } = coerce
1512 else {
1513 unreachable!("Expected an enum specialization argument");
1514 };
1515 if coercion_variant != variant {
1516 return None;
1517 }
1518 Some(payload.as_ref())
1519 }
1520 None => None,
1521 };
1522 let mut local_unknown_vars = vec![];
1523 let payload_arg = self.try_get_specialization_arg(
1524 payload.clone(),
1525 variant.ty,
1526 &mut local_unknown_vars,
1527 coerce,
1528 )?;
1529
1530 unknown_vars.extend(local_unknown_vars);
1531 Some(SpecializationArg::Enum { variant: *variant, payload: Box::new(payload_arg) })
1532 }
1533 VarInfo::Var(var_usage) => {
1534 unknown_vars.push(*var_usage);
1535 Some(SpecializationArg::NotSpecialized)
1536 }
1537 }
1538 }
1539
1540 pub fn should_skip_const_folding(&self, db: &'db dyn Database) -> bool {
1542 if db.optimizations().skip_const_folding() {
1543 return true;
1544 }
1545
1546 if self.caller_function.base_semantic_function(db).generic_function(db)
1549 == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
1550 {
1551 return true;
1552 }
1553 false
1554 }
1555}
1556
1557fn var_info_if_copy<'db>(
1559 variables: &VariableArena<'db>,
1560 input: VarUsage<'db>,
1561) -> Option<Rc<VarInfo<'db>>> {
1562 variables[input.var_id].info.copyable.is_ok().then(|| VarInfo::Var(input).into())
1563}
1564
/// Builds the `ConstFoldingLibfuncInfo` of `db`.
///
/// This is a salsa tracked function whose result is returned by reference (`returns(ref)`).
// NOTE(review): presumably this makes the info get computed once per database - confirm.
#[salsa::tracked(returns(ref))]
fn priv_const_folding_info<'db>(
    db: &'db dyn Database,
) -> crate::optimizations::const_folding::ConstFoldingLibfuncInfo<'db> {
    ConstFoldingLibfuncInfo::new(db)
}
1572
/// The corelib items used by the const folding optimization, all resolved by `new`.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub struct ConstFoldingLibfuncInfo<'db> {
    /// The `felt252_sub` extern function.
    felt_sub: ExternFunctionId<'db>,
    /// The `felt252_add` extern function.
    felt_add: ExternFunctionId<'db>,
    /// The `felt252_mul` extern function.
    felt_mul: ExternFunctionId<'db>,
    /// The `felt252_div` extern function.
    felt_div: ExternFunctionId<'db>,
    /// The `box_forward_snapshot` generic function of the `box` module.
    box_forward_snapshot: GenericFunctionId<'db>,
    /// The `{ty}_eq` extern functions of the integer types (`u8` to `u128`, `i8` to `i128`).
    eq_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_add` extern functions of the unsigned types (`u8` to `u128`).
    uadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_sub` extern functions of the unsigned types (`u8` to `u128`).
    usub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_diff` extern functions of the signed types (`i8` to `i128`).
    diff_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_add_impl` extern functions of the signed types (`i8` to `i128`).
    iadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_sub_impl` extern functions of the signed types (`i8` to `i128`).
    isub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `bounded_int_mul` extern function, and the `{ty}_wide_mul` extern functions of the
    /// 8 to 64 bit integer types, both unsigned and signed.
    wide_mul_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `bounded_int_div_rem` extern function, and the `{ty}_safe_divmod` extern functions
    /// of the unsigned types (`u8` to `u128`).
    div_rem_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `bounded_int_add` extern function.
    bounded_int_add: ExternFunctionId<'db>,
    /// The `bounded_int_sub` extern function.
    bounded_int_sub: ExternFunctionId<'db>,
    /// The `bounded_int_constrain` extern function.
    bounded_int_constrain: ExternFunctionId<'db>,
    /// The `bounded_int_trim_min` extern function.
    bounded_int_trim_min: ExternFunctionId<'db>,
    /// The `bounded_int_trim_max` extern function.
    bounded_int_trim_max: ExternFunctionId<'db>,
    /// The `array_get` extern function.
    array_get: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_front` extern function.
    array_snapshot_pop_front: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_back` extern function.
    array_snapshot_pop_back: ExternFunctionId<'db>,
    /// The `array_len` extern function.
    array_len: ExternFunctionId<'db>,
    /// The `array_new` extern function.
    array_new: ExternFunctionId<'db>,
    /// The `array_append` extern function.
    array_append: ExternFunctionId<'db>,
    /// The `array_pop_front` extern function.
    array_pop_front: ExternFunctionId<'db>,
    /// The `storage_base_address_from_felt252` extern function.
    storage_base_address_from_felt252: ExternFunctionId<'db>,
    /// The `storage_base_address_const` generic function.
    storage_base_address_const: GenericFunctionId<'db>,
    /// The `panic_with_felt252` function.
    panic_with_felt252: FunctionId<'db>,
    /// The `panic_with_const_felt252` free function.
    pub panic_with_const_felt252: FreeFunctionId<'db>,
    /// The `panic_with_byte_array` function of the `panics` module.
    panic_with_byte_array: FunctionId<'db>,
    /// The helper functions of the integer types (`u8` to `u128`, `u256`, `i8` to `i128`),
    /// keyed by the type.
    type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>>,
    /// The const calculation info of the database, which this struct dereferences to.
    const_calculation_info: Arc<ConstCalcInfo<'db>>,
}
1641impl<'db> ConstFoldingLibfuncInfo<'db> {
1642 fn new(db: &'db dyn Database) -> Self {
1643 let core = ModuleHelper::core(db);
1644 let box_module = core.submodule("box");
1645 let integer_module = core.submodule("integer");
1646 let internal_module = core.submodule("internal");
1647 let bounded_int_module = internal_module.submodule("bounded_int");
1648 let num_module = internal_module.submodule("num");
1649 let array_module = core.submodule("array");
1650 let starknet_module = core.submodule("starknet");
1651 let storage_access_module = starknet_module.submodule("storage_access");
1652 let utypes = ["u8", "u16", "u32", "u64", "u128"];
1653 let itypes = ["i8", "i16", "i32", "i64", "i128"];
1654 let eq_fns = OrderedHashSet::<_>::from_iter(
1655 chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(&format!("{ty}_eq"))),
1656 );
1657 let uadd_fns = OrderedHashSet::<_>::from_iter(
1658 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_add"))),
1659 );
1660 let usub_fns = OrderedHashSet::<_>::from_iter(
1661 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_sub"))),
1662 );
1663 let diff_fns = OrderedHashSet::<_>::from_iter(
1664 itypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_diff"))),
1665 );
1666 let iadd_fns =
1667 OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1668 integer_module.extern_function_id(&format!("{ty}_overflowing_add_impl"))
1669 }));
1670 let isub_fns =
1671 OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1672 integer_module.extern_function_id(&format!("{ty}_overflowing_sub_impl"))
1673 }));
1674 let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
1675 [bounded_int_module.extern_function_id("bounded_int_mul")],
1676 ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
1677 .map(|ty| integer_module.extern_function_id(&format!("{ty}_wide_mul"))),
1678 ));
1679 let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
1680 [bounded_int_module.extern_function_id("bounded_int_div_rem")],
1681 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_safe_divmod"))),
1682 ));
1683 let type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>> = OrderedHashMap::from_iter(
1684 [
1685 ("u8", false, true),
1686 ("u16", false, true),
1687 ("u32", false, true),
1688 ("u64", false, true),
1689 ("u128", false, true),
1690 ("u256", false, false),
1691 ("i8", true, true),
1692 ("i16", true, true),
1693 ("i32", true, true),
1694 ("i64", true, true),
1695 ("i128", true, true),
1696 ]
1697 .map(|(ty_name, as_bounded_int, inc_dec): (&'static str, bool, bool)| {
1698 let ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, ty_name), vec![]);
1699 let is_zero = if as_bounded_int {
1700 bounded_int_module
1701 .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
1702 } else {
1703 integer_module.function_id(
1704 SmolStrId::from(db, format!("{ty_name}_is_zero")).long(db).as_str(),
1705 vec![],
1706 )
1707 }
1708 .lowered(db);
1709 let (inc, dec) = if inc_dec {
1710 (
1711 Some(
1712 num_module
1713 .function_id(
1714 SmolStrId::from(db, format!("{ty_name}_inc")).long(db).as_str(),
1715 vec![],
1716 )
1717 .lowered(db),
1718 ),
1719 Some(
1720 num_module
1721 .function_id(
1722 SmolStrId::from(db, format!("{ty_name}_dec")).long(db).as_str(),
1723 vec![],
1724 )
1725 .lowered(db),
1726 ),
1727 )
1728 } else {
1729 (None, None)
1730 };
1731 let info = TypeInfo { is_zero, inc, dec };
1732 (ty, info)
1733 }),
1734 );
1735 Self {
1736 felt_sub: core.extern_function_id("felt252_sub"),
1737 felt_add: core.extern_function_id("felt252_add"),
1738 felt_mul: core.extern_function_id("felt252_mul"),
1739 felt_div: core.extern_function_id("felt252_div"),
1740 box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
1741 eq_fns,
1742 uadd_fns,
1743 usub_fns,
1744 diff_fns,
1745 iadd_fns,
1746 isub_fns,
1747 wide_mul_fns,
1748 div_rem_fns,
1749 bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
1750 bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
1751 bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
1752 bounded_int_trim_min: bounded_int_module.extern_function_id("bounded_int_trim_min"),
1753 bounded_int_trim_max: bounded_int_module.extern_function_id("bounded_int_trim_max"),
1754 array_get: array_module.extern_function_id("array_get"),
1755 array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
1756 array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
1757 array_len: array_module.extern_function_id("array_len"),
1758 array_new: array_module.extern_function_id("array_new"),
1759 array_append: array_module.extern_function_id("array_append"),
1760 array_pop_front: array_module.extern_function_id("array_pop_front"),
1761 storage_base_address_from_felt252: storage_access_module
1762 .extern_function_id("storage_base_address_from_felt252"),
1763 storage_base_address_const: storage_access_module
1764 .generic_function_id("storage_base_address_const"),
1765 panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
1766 panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
1767 panic_with_byte_array: core
1768 .submodule("panics")
1769 .function_id("panic_with_byte_array", vec![])
1770 .lowered(db),
1771 type_info,
1772 const_calculation_info: db.const_calc_info(),
1773 }
1774 }
1775}
1776
1777impl<'db> std::ops::Deref for ConstFoldingContext<'db, '_> {
1778 type Target = ConstFoldingLibfuncInfo<'db>;
1779 fn deref(&self) -> &ConstFoldingLibfuncInfo<'db> {
1780 self.libfunc_info
1781 }
1782}
1783
1784impl<'a> std::ops::Deref for ConstFoldingLibfuncInfo<'a> {
1785 type Target = ConstCalcInfo<'a>;
1786 fn deref(&self) -> &ConstCalcInfo<'a> {
1787 &self.const_calculation_info
1788 }
1789}
1790
/// The helper functions of a specific integer type.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct TypeInfo<'db> {
    /// The function matching on whether a value of the type is zero.
    is_zero: FunctionId<'db>,
    /// The `{ty}_inc` function of the type, if it has one.
    inc: Option<FunctionId<'db>>,
    /// The `{ty}_dec` function of the type, if it has one.
    dec: Option<FunctionId<'db>>,
}
1801
/// Normalization of values with respect to a range of values.
trait TypeRangeNormalizer {
    /// Classifies `value` with respect to the range, as described by `NormalizedResult`.
    fn normalized(&self, value: BigInt) -> NormalizedResult;
}
1807impl TypeRangeNormalizer for TypeRange {
1808 fn normalized(&self, value: BigInt) -> NormalizedResult {
1809 if value < self.min {
1810 NormalizedResult::Under(value - &self.min + &self.max + 1)
1811 } else if value > self.max {
1812 NormalizedResult::Over(value + &self.min - &self.max - 1)
1813 } else {
1814 NormalizedResult::InRange(value)
1815 }
1816 }
1817}
1818
/// The result of normalizing a value with respect to a range of values.
enum NormalizedResult {
    /// The value is within the range, and is held as is.
    InRange(BigInt),
    /// The value is above the maximum of the range. Holds the value reduced by the size of the
    /// range.
    Over(BigInt),
    /// The value is below the minimum of the range. Holds the value increased by the size of
    /// the range.
    Under(BigInt),
}