1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::rc::Rc;
6use std::sync::Arc;
7
8use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
9use cairo_lang_filesystem::flag::FlagsGroup;
10use cairo_lang_filesystem::ids::SmolStrId;
11use cairo_lang_semantic::corelib::CorelibSemantic;
12use cairo_lang_semantic::helper::ModuleHelper;
13use cairo_lang_semantic::items::constant::{
14 ConstCalcInfo, ConstValue, ConstValueId, ConstantSemantic, TypeRange, canonical_felt252,
15 felt252_for_downcast,
16};
17use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
18use cairo_lang_semantic::items::structure::StructSemantic;
19use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
20use cairo_lang_semantic::{
21 ConcreteTypeId, ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId,
22 corelib,
23};
24use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
25use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
26use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
27use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
28use cairo_lang_utils::{Intern, extract_matches, require, try_extract_matches};
29use itertools::{Itertools, chain, zip_eq};
30use num_bigint::BigInt;
31use num_integer::Integer;
32use num_traits::cast::ToPrimitive;
33use num_traits::{Num, One, Zero};
34use salsa::Database;
35use starknet_types_core::felt::Felt as Felt252;
36
37use crate::db::LoweringGroup;
38use crate::ids::{
39 ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
40 SpecializedFunction,
41};
42use crate::specialization::SpecializationArg;
43use crate::utils::InliningStrategy;
44use crate::{
45 Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchEnumInfo,
46 MatchExternInfo, MatchInfo, Statement, StatementCall, StatementConst, StatementDesnap,
47 StatementEnumConstruct, StatementIntoBox, StatementSnapshot, StatementStructConstruct,
48 StatementStructDestructure, StatementUnbox, VarRemapping, VarUsage, Variable, VariableArena,
49 VariableId,
50};
51
52fn const_to_specialization_arg<'db>(
56 db: &'db dyn Database,
57 value: ConstValueId<'db>,
58 boxed: bool,
59) -> SpecializationArg<'db> {
60 match value.long(db) {
61 ConstValue::Struct(members, ty) => {
62 if matches!(
65 ty.long(db),
66 TypeLongId::Concrete(ConcreteTypeId::Struct(_))
67 | TypeLongId::Tuple(_)
68 | TypeLongId::FixedSizeArray { .. }
69 ) {
70 let args = members
71 .iter()
72 .map(|member| const_to_specialization_arg(db, *member, false))
73 .collect();
74 SpecializationArg::Struct(args)
75 } else {
76 SpecializationArg::Const { value, boxed }
77 }
78 }
79 ConstValue::Enum(variant, payload) => SpecializationArg::Enum {
80 variant: *variant,
81 payload: Box::new(const_to_specialization_arg(db, *payload, false)),
82 },
83 _ => SpecializationArg::Const { value, boxed },
84 }
85}
86
87#[derive(Debug, Clone)]
90enum VarInfo<'db> {
91 Const(ConstValueId<'db>),
93 Var(VarUsage<'db>),
95 Snapshot(Rc<VarInfo<'db>>),
97 Struct(Vec<Option<Rc<VarInfo<'db>>>>),
100 Enum { variant: ConcreteVariant<'db>, payload: Rc<VarInfo<'db>> },
102 Box(Rc<VarInfo<'db>>),
104 Array(Vec<Option<Rc<VarInfo<'db>>>>),
107}
108impl<'db> VarInfo<'db> {
109 fn peel_snapshots(mut self: Rc<Self>) -> (usize, Rc<VarInfo<'db>>) {
111 let mut n_snapshots = 0;
112 while let VarInfo::Snapshot(inner) = self.as_ref() {
113 self = inner.clone();
114 n_snapshots += 1;
115 }
116 (n_snapshots, self)
117 }
118 fn wrap_with_snapshots(mut self: Rc<Self>, n_snapshots: usize) -> Rc<VarInfo<'db>> {
120 for _ in 0..n_snapshots {
121 self = VarInfo::Snapshot(self).into();
122 }
123 self
124 }
125}
126
/// How a not-yet-visited block has been reached from the blocks visited so far.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Reachability {
    /// The block is reached only through the `Goto` ending the given block, so the constants in
    /// that goto's remapping can be propagated into it.
    FromSingleGoto(BlockId),
    /// The block is reached in another way: it is the root, a match arm, or the target of more
    /// than one goto.
    Any,
}
136
137pub fn const_folding<'db>(
140 db: &'db dyn Database,
141 function_id: ConcreteFunctionWithBodyId<'db>,
142 lowered: &mut Lowered<'db>,
143) {
144 if lowered.blocks.is_empty() {
145 return;
146 }
147
148 let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
151
152 if ctx.should_skip_const_folding(db) {
153 return;
154 }
155
156 for block_id in (0..lowered.blocks.len()).map(BlockId) {
157 if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
158 continue;
159 }
160
161 let block = &mut lowered.blocks[block_id];
162 for stmt in block.statements.iter_mut() {
163 ctx.visit_statement(stmt);
164 }
165 ctx.visit_block_end(block_id, block);
166 }
167}
168
/// State carried through a const-folding pass over a single function body.
pub struct ConstFoldingContext<'db, 'mt> {
    /// The compiler database.
    db: &'db dyn Database,
    /// The variable arena of the function being folded; new variables are allocated here.
    pub variables: &'mt mut VariableArena<'db>,
    /// What is statically known about each variable.
    var_info: UnorderedHashMap<VariableId, Rc<VarInfo<'db>>>,
    /// Libfunc and type information obtained from `priv_const_folding_info`.
    libfunc_info: &'db ConstFoldingLibfuncInfo<'db>,
    /// The function whose body is being folded.
    caller_function: ConcreteFunctionWithBodyId<'db>,
    /// Reachability of blocks that have not been visited yet. A block absent from this map when
    /// it is visited is skipped.
    reachability: UnorderedHashMap<BlockId, Reachability>,
    /// Statements generated while visiting the current block. `visit_block_end` inserts them at
    /// the start of that block.
    additional_stmts: Vec<Statement<'db>>,
}
188
189impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
190 pub fn new(
191 db: &'db dyn Database,
192 function_id: ConcreteFunctionWithBodyId<'db>,
193 variables: &'mt mut VariableArena<'db>,
194 ) -> Self {
195 Self {
196 db,
197 var_info: UnorderedHashMap::default(),
198 variables,
199 libfunc_info: priv_const_folding_info(db),
200 caller_function: function_id,
201 reachability: UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]),
202 additional_stmts: vec![],
203 }
204 }
205
206 pub fn visit_block_start<'r, 'get>(
209 &'r mut self,
210 block_id: BlockId,
211 get_block: impl FnOnce(BlockId) -> &'get Block<'db>,
212 ) -> bool
213 where
214 'db: 'get,
215 {
216 let Some(reachability) = self.reachability.remove(&block_id) else {
217 return false;
218 };
219 match reachability {
220 Reachability::Any => {}
221 Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
222 BlockEnd::Goto(_, remapping) => {
223 for (dst, src) in remapping.iter() {
224 if let Some(v) = self.as_const(src.var_id) {
225 self.var_info.insert(*dst, VarInfo::Const(v).into());
226 }
227 }
228 }
229 _ => unreachable!("Expected a goto end"),
230 },
231 }
232 true
233 }
234
    /// Visits a single statement.
    ///
    /// Records in `var_info` what becomes statically known about the statement's outputs. Where
    /// the result is fully known, the statement is rewritten in place (e.g. into a `Const`
    /// statement).
    pub fn visit_statement(&mut self, stmt: &mut Statement<'db>) {
        // Let the information gathered so far rewrite the inputs (see `maybe_replace_inputs`).
        self.maybe_replace_inputs(stmt.inputs_mut());
        match stmt {
            // A boxed const is tracked as a box around the constant.
            Statement::Const(StatementConst { value, output, boxed }) if *boxed => {
                self.var_info.insert(*output, VarInfo::Box(VarInfo::Const(*value).into()).into());
            }
            // Only concrete constants are tracked; generic, impl, var and missing constants are
            // left without info.
            Statement::Const(StatementConst { value, output, .. }) => match value.long(self.db) {
                ConstValue::Int(..)
                | ConstValue::Struct(..)
                | ConstValue::Enum(..)
                | ConstValue::NonZero(..) => {
                    self.var_info.insert(*output, VarInfo::Const(*value).into());
                }
                ConstValue::Generic(_)
                | ConstValue::ImplConstant(_)
                | ConstValue::Var(..)
                | ConstValue::Missing(_) => {}
            },
            // The "original" output keeps the input's info; the snapshot output wraps it in a
            // `Snapshot` layer.
            Statement::Snapshot(stmt) => {
                if let Some(info) = self.var_info.get(&stmt.input.var_id) {
                    let info = info.clone();
                    self.var_info.insert(stmt.original(), info.clone());
                    self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info).into());
                }
            }
            // Desnapping removes one `Snapshot` layer from the input's info.
            Statement::Desnap(StatementDesnap { input, output }) => {
                if let Some(info) = self.var_info.get(&input.var_id)
                    && let VarInfo::Snapshot(info) = info.as_ref()
                {
                    self.var_info.insert(*output, info.clone());
                }
            }
            // First try to fold the call itself; failing that, try to call a specialization of
            // the callee instead.
            Statement::Call(call_stmt) => {
                if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
                    *stmt = updated_stmt;
                } else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
                    *stmt = updated_stmt;
                }
            }
            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
                // `const_args` collects the members that are consts; `all_args` holds per-member
                // info (or the `var_info_if_copy` fallback when nothing is recorded).
                let mut const_args = vec![];
                let mut all_args = vec![];
                let mut contains_info = false;
                for input in inputs.iter() {
                    let Some(info) = self.var_info.get(&input.var_id) else {
                        all_args.push(var_info_if_copy(self.variables, *input));
                        continue;
                    };
                    contains_info = true;
                    if let VarInfo::Const(value) = info.as_ref() {
                        const_args.push(*value);
                    }
                    all_args.push(Some(info.clone()));
                }
                if const_args.len() == inputs.len() {
                    // Every member is a const, so the whole struct is a const.
                    let value =
                        ConstValue::Struct(const_args, self.variables[*output].ty).intern(self.db);
                    self.var_info.insert(*output, VarInfo::Const(value).into());
                } else if contains_info {
                    self.var_info.insert(*output, VarInfo::Struct(all_args).into());
                }
            }
            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
                if let Some(info) = self.var_info.get(&input.var_id) {
                    // Member infos are re-wrapped with the input's snapshot depth.
                    let (n_snapshots, info) = info.clone().peel_snapshots();
                    match info.as_ref() {
                        VarInfo::Const(const_value) => {
                            if let ConstValue::Struct(member_values, _) = const_value.long(self.db)
                            {
                                for (output, value) in zip_eq(outputs, member_values) {
                                    self.var_info.insert(
                                        *output,
                                        Rc::new(VarInfo::Const(*value))
                                            .wrap_with_snapshots(n_snapshots),
                                    );
                                }
                            }
                        }
                        VarInfo::Struct(members) => {
                            for (output, member) in zip_eq(outputs, members.clone()) {
                                if let Some(member) = member {
                                    self.var_info
                                        .insert(*output, member.wrap_with_snapshots(n_snapshots));
                                }
                            }
                        }
                        _ => {}
                    }
                }
            }
            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
                // A const payload makes the whole enum a const. Otherwise the variant is tracked
                // with the payload's info, or with the payload variable itself when nothing is
                // known about it.
                let value = if let Some(info) = self.var_info.get(&input.var_id) {
                    if let VarInfo::Const(val) = info.as_ref() {
                        VarInfo::Const(ConstValue::Enum(*variant, *val).intern(self.db))
                    } else {
                        VarInfo::Enum { variant: *variant, payload: info.clone() }
                    }
                } else {
                    VarInfo::Enum { variant: *variant, payload: VarInfo::Var(*input).into() }
                };
                self.var_info.insert(*output, value.into());
            }
            Statement::IntoBox(StatementIntoBox { input, output }) => {
                let var_info = self.var_info.get(&input.var_id);
                // A const (or a snapshot of a const) input lets the box be emitted as a boxed
                // const statement.
                let const_value = var_info.and_then(|var_info| match var_info.as_ref() {
                    VarInfo::Const(val) => Some(*val),
                    VarInfo::Snapshot(info) => {
                        try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
                    }
                    _ => None,
                });
                let var_info =
                    var_info.cloned().or_else(|| var_info_if_copy(self.variables, *input));
                if let Some(var_info) = var_info {
                    self.var_info.insert(*output, VarInfo::Box(var_info).into());
                }

                if let Some(const_value) = const_value {
                    *stmt = Statement::Const(StatementConst::new_boxed(const_value, *output));
                }
            }
            Statement::Unbox(StatementUnbox { input, output }) => {
                // Unboxing a known box yields the boxed info; a boxed const becomes a flat const
                // statement.
                if let Some(inner) = self.var_info.get(&input.var_id)
                    && let VarInfo::Box(inner) = inner.as_ref()
                {
                    let inner = inner.clone();
                    if let VarInfo::Const(inner) =
                        self.var_info.entry(*output).insert_entry(inner).get().as_ref()
                    {
                        *stmt = Statement::Const(StatementConst::new_flat(*inner, *output));
                    }
                }
            }
        }
    }
379
380 pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut Block<'db>) {
387 let statements = &mut block.statements;
388 statements.splice(0..0, self.additional_stmts.drain(..));
389
390 match &mut block.end {
391 BlockEnd::Goto(_, remappings) => {
392 for (_, v) in remappings.iter_mut() {
393 self.maybe_replace_input(v);
394 }
395 }
396 BlockEnd::Match { info } => {
397 self.maybe_replace_inputs(info.inputs_mut());
398 match info {
399 MatchInfo::Enum(info) => {
400 if let Some(updated_end) = self.handle_enum_block_end(info, statements) {
401 block.end = updated_end;
402 }
403 }
404 MatchInfo::Extern(info) => {
405 if let Some(updated_end) = self.handle_extern_block_end(info, statements) {
406 block.end = updated_end;
407 }
408 }
409 MatchInfo::Value(info) => {
410 if let Some(value) =
411 self.as_int(info.input.var_id).and_then(|x| x.to_usize())
412 && let Some(arm) = info.arms.iter().find(|arm| {
413 matches!(
414 &arm.arm_selector,
415 MatchArmSelector::Value(v) if v.value == value
416 )
417 })
418 {
419 statements.push(Statement::StructConstruct(StatementStructConstruct {
421 inputs: vec![],
422 output: arm.var_ids[0],
423 }));
424 block.end = BlockEnd::Goto(arm.block_id, Default::default());
425 }
426 }
427 }
428 }
429 BlockEnd::Return(inputs, _) => self.maybe_replace_inputs(inputs),
430 BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
431 }
432 match &block.end {
433 BlockEnd::Goto(dst_block_id, _) => {
434 match self.reachability.entry(*dst_block_id) {
435 std::collections::hash_map::Entry::Occupied(mut e) => {
436 e.insert(Reachability::Any)
437 }
438 std::collections::hash_map::Entry::Vacant(e) => {
439 *e.insert(Reachability::FromSingleGoto(block_id))
440 }
441 };
442 }
443 BlockEnd::Match { info } => {
444 for arm in info.arms() {
445 assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
446 }
447 }
448 BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
449 }
450 }
451
452 fn handle_statement_call(&mut self, stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
460 let db = self.db;
461 if stmt.function == self.panic_with_felt252 {
462 let val = self.as_const(stmt.inputs[0].var_id)?;
463 stmt.inputs.clear();
464 stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
465 .concretize(db, vec![GenericArgumentId::Constant(val)])
466 .lowered(db);
467 return None;
468 } else if stmt.function == self.panic_with_byte_array && !db.flag_unsafe_panic() {
469 let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
470 let bytearray = try_extract_matches!(snap.as_ref(), VarInfo::Snapshot)?;
471 let [Some(data), Some(pending_word), Some(pending_len)] =
472 &try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
473 else {
474 return None;
475 };
476 let data = try_extract_matches!(data.as_ref(), VarInfo::Array)?;
477 let pending_word = try_extract_matches!(pending_word.as_ref(), VarInfo::Const)?;
478 let pending_len = try_extract_matches!(pending_len.as_ref(), VarInfo::Const)?;
479 let mut panic_data =
480 vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
481 for word in data {
482 let VarInfo::Const(word) = word.as_ref()?.as_ref() else {
483 return None;
484 };
485 panic_data.push(word.long(db).to_int()?.clone());
486 }
487 panic_data.extend([
488 pending_word.long(db).to_int()?.clone(),
489 pending_len.long(db).to_int()?.clone(),
490 ]);
491 let felt252_ty = self.felt252;
492 let location = stmt.location;
493 let new_var = |ty| Variable::with_default_context(db, ty, location);
494 let as_usage = |var_id| VarUsage { var_id, location };
495 let array_fn = |extern_id| {
496 let args = vec![GenericArgumentId::Type(felt252_ty)];
497 GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
498 };
499 let call_stmt = |function, inputs, outputs| {
500 let with_coupon = false;
501 Statement::Call(StatementCall {
502 function,
503 inputs,
504 with_coupon,
505 outputs,
506 location,
507 is_specialization_base_call: false,
508 })
509 };
510 let arr_var = new_var(corelib::core_array_felt252_ty(db));
511 let mut arr = self.variables.alloc(arr_var.clone());
512 self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
513 let felt252_var = new_var(felt252_ty);
514 let arr_append_fn = array_fn(self.array_append);
515 for word in panic_data {
516 let to_append = self.variables.alloc(felt252_var.clone());
517 let new_arr = self.variables.alloc(arr_var.clone());
518 self.additional_stmts.push(Statement::Const(StatementConst::new_flat(
519 ConstValue::Int(word, felt252_ty).intern(db),
520 to_append,
521 )));
522 self.additional_stmts.push(call_stmt(
523 arr_append_fn,
524 vec![as_usage(arr), as_usage(to_append)],
525 vec![new_arr],
526 ));
527 arr = new_arr;
528 }
529 let panic_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "Panic"), vec![]);
530 let panic_var = self.variables.alloc(new_var(panic_ty));
531 self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
532 inputs: vec![],
533 output: panic_var,
534 }));
535 return Some(Statement::StructConstruct(StatementStructConstruct {
536 inputs: vec![as_usage(panic_var), as_usage(arr)],
537 output: stmt.outputs[0],
538 }));
539 }
540 let (id, _generic_args) = stmt.function.get_extern(db)?;
541 if id == self.felt_sub {
542 if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
543 && rhs.is_zero()
544 {
545 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
546 None
547 } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
548 && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
549 {
550 let value = canonical_felt252(&(lhs - rhs));
551 Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
552 } else {
553 None
554 }
555 } else if id == self.felt_add {
556 if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
557 && lhs.is_zero()
558 {
559 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]).into());
560 None
561 } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
562 && rhs.is_zero()
563 {
564 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
565 None
566 } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
567 && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
568 {
569 let value = canonical_felt252(&(lhs + rhs));
570 Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
571 } else {
572 None
573 }
574 } else if id == self.felt_mul {
575 let lhs = self.as_int(stmt.inputs[0].var_id);
576 let rhs = self.as_int(stmt.inputs[1].var_id);
577 if lhs.map(Zero::is_zero).unwrap_or_default()
578 || rhs.map(Zero::is_zero).unwrap_or_default()
579 {
580 Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
581 } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
582 && rhs.is_one()
583 {
584 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
585 None
586 } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
587 && lhs.is_one()
588 {
589 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]).into());
590 None
591 } else if let Some(lhs) = lhs
592 && let Some(rhs) = rhs
593 {
594 let value = canonical_felt252(&(lhs * rhs));
595 Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
596 } else {
597 None
598 }
599 } else if id == self.felt_div {
600 if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
602 && rhs.is_one()
604 {
605 self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]).into());
606 None
607 } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
608 && lhs.is_zero()
610 {
611 Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
612 } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
613 && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
614 && let Ok(rhs_nonzero) = Felt252::from(rhs).try_into()
615 {
616 let lhs_felt = Felt252::from(lhs);
620 let value = lhs_felt.field_div(&rhs_nonzero).to_bigint();
621 Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
622 } else {
623 None
624 }
625 } else if self.wide_mul_fns.contains(&id) {
626 let lhs = self.as_int(stmt.inputs[0].var_id);
627 let rhs = self.as_int(stmt.inputs[1].var_id);
628 let output = stmt.outputs[0];
629 if lhs.map(Zero::is_zero).unwrap_or_default()
630 || rhs.map(Zero::is_zero).unwrap_or_default()
631 {
632 return Some(self.propagate_zero_and_get_statement(output));
633 }
634 let lhs = lhs?;
635 Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0]))
636 } else if id == self.bounded_int_add || id == self.bounded_int_sub {
637 let lhs = self.as_int(stmt.inputs[0].var_id)?;
638 let rhs = self.as_int(stmt.inputs[1].var_id)?;
639 let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
640 Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
641 } else if self.div_rem_fns.contains(&id) {
642 let lhs = self.as_int(stmt.inputs[0].var_id);
643 if lhs.map(Zero::is_zero).unwrap_or_default() {
644 let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
645 self.additional_stmts.push(additional_stmt);
646 return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
647 }
648 let rhs = self.as_int(stmt.inputs[1].var_id)?;
649 let (q, r) = lhs?.div_rem(rhs);
650 let q_output = stmt.outputs[0];
651 let q_value = ConstValue::Int(q, self.variables[q_output].ty).intern(db);
652 self.var_info.insert(q_output, VarInfo::Const(q_value).into());
653 let r_output = stmt.outputs[1];
654 let r_value = ConstValue::Int(r, self.variables[r_output].ty).intern(db);
655 self.var_info.insert(r_output, VarInfo::Const(r_value).into());
656 self.additional_stmts
657 .push(Statement::Const(StatementConst::new_flat(r_value, r_output)));
658 Some(Statement::Const(StatementConst::new_flat(q_value, q_output)))
659 } else if id == self.storage_base_address_from_felt252 {
660 let input_var = stmt.inputs[0].var_id;
661 if let Some(const_value) = self.as_const(input_var)
662 && let ConstValue::Int(val, ty) = const_value.long(db)
663 {
664 stmt.inputs.clear();
665 let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
666 stmt.function =
667 self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
668 }
669 None
670 } else if self.upcast_fns.contains(&id) {
671 let int_value = self.as_int(stmt.inputs[0].var_id)?;
672 let output = stmt.outputs[0];
673 let value = ConstValue::Int(int_value.clone(), self.variables[output].ty).intern(db);
674 self.var_info.insert(output, VarInfo::Const(value).into());
675 Some(Statement::Const(StatementConst::new_flat(value, output)))
676 } else if id == self.array_new {
677 self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]).into());
678 None
679 } else if id == self.array_append {
680 let mut var_infos = if let VarInfo::Array(var_infos) =
681 self.var_info.get(&stmt.inputs[0].var_id)?.as_ref()
682 {
683 var_infos.clone()
684 } else {
685 return None;
686 };
687 let appended = stmt.inputs[1];
688 var_infos.push(match self.var_info.get(&appended.var_id) {
689 Some(var_info) => Some(var_info.clone()),
690 None => var_info_if_copy(self.variables, appended),
691 });
692 self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos).into());
693 None
694 } else if id == self.array_len {
695 let info = self.var_info.get(&stmt.inputs[0].var_id)?;
696 let desnapped = try_extract_matches!(info.as_ref(), VarInfo::Snapshot)?;
697 let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
698 Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0]))
699 } else {
700 None
701 }
702 }
703
    /// Tries to replace a call with a call to a specialization of the callee.
    ///
    /// In the specialization, the arguments with known info are baked in. The returned call
    /// passes only the remaining arguments (`new_args`) and takes over `call_stmt`'s outputs.
    /// Returns `None` when specialization does not apply.
    fn try_specialize_call(&self, call_stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
        // Calls with coupons are never specialized.
        if call_stmt.with_coupon {
            return None;
        }
        // No specialization when the inlining strategy is `Avoid`.
        if matches!(self.db.optimizations().inlining_strategy(), InliningStrategy::Avoid) {
            return None;
        }

        let Ok(Some(mut called_function)) = call_stmt.function.body(self.db) else {
            return None;
        };

        // Maps a (possibly specialized) function to the function it was specialized from.
        let extract_base = |function: ConcreteFunctionWithBodyId<'db>| match function.long(self.db)
        {
            ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
                specialized.long(self.db).base
            }
            _ => function,
        };
        let called_base = extract_base(called_function);
        let caller_base = extract_base(self.caller_function);

        if self.db.priv_never_inline(called_base).ok()? {
            return None;
        }

        // Do not specialize the base call made from within a specialized function.
        if call_stmt.is_specialization_base_call {
            return None;
        }

        // Do not re-specialize a recursive call that already targets a specialization.
        if called_base == caller_base && called_function != called_base {
            return None;
        }

        // Avoid specializing a callee that shares a multi-function call SCC with the caller.
        let scc =
            self.db.lowered_scc(called_base, DependencyType::Call, LoweringStage::Monomorphized);
        if scc.len() > 1 && scc.contains(&caller_base) {
            return None;
        }

        // Nothing is known about any of the inputs.
        if call_stmt.inputs.iter().all(|arg| self.var_info.get(&arg.var_id).is_none()) {
            return None;
        }

        // For a recursive call from within a specialization of the same base, pass the caller's
        // own specialization args to `try_get_specialization_arg` (presumably so recursion
        // reuses the same specialization; confirm against `try_get_specialization_arg`).
        let self_specializition = if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
            self.caller_function.long(self.db)
            && caller_base == called_base
        {
            specialized.long(self.db).args.iter().map(Some).collect()
        } else {
            vec![None; call_stmt.inputs.len()]
        };

        // Build one specialization arg per input. Inputs that cannot be baked in are marked
        // `NotSpecialized` and still passed through `new_args`.
        let mut specialization_args = vec![];
        let mut new_args = vec![];
        for (arg, coerce) in zip_eq(&call_stmt.inputs, &self_specializition) {
            if let Some(var_info) = self.var_info.get(&arg.var_id)
                && self.variables[arg.var_id].info.droppable.is_ok()
                && let Some(specialization_arg) = self.try_get_specialization_arg(
                    var_info.clone(),
                    self.variables[arg.var_id].ty,
                    &mut new_args,
                    *coerce,
                )
            {
                specialization_args.push(specialization_arg);
            } else {
                specialization_args.push(SpecializationArg::NotSpecialized);
                new_args.push(*arg);
                continue;
            };
        }

        // No argument was baked in, so there is nothing to specialize.
        if specialization_args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized)) {
            return None;
        }
        // If the callee is already a specialization, specialize its base instead. The callee's
        // `NotSpecialized` slots are filled with the new args, walking the old args depth-first,
        // left to right.
        if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
            called_function.long(self.db)
        {
            let specialized_function = specialized_function.long(self.db);
            called_function = specialized_function.base;
            let mut new_args_iter = specialization_args.into_iter();
            let mut old_args = specialized_function.args.clone();
            let mut stack = vec![];
            for arg in old_args.iter_mut().rev() {
                stack.push(arg);
            }
            while let Some(arg) = stack.pop() {
                match arg {
                    SpecializationArg::Const { .. } => {}
                    SpecializationArg::Snapshot(inner) => {
                        stack.push(inner.as_mut());
                    }
                    SpecializationArg::Enum { payload, .. } => {
                        stack.push(payload.as_mut());
                    }
                    SpecializationArg::Array(_, values) | SpecializationArg::Struct(values) => {
                        for value in values.iter_mut().rev() {
                            stack.push(value);
                        }
                    }
                    SpecializationArg::NotSpecialized => {
                        *arg = new_args_iter.next().unwrap_or(SpecializationArg::NotSpecialized);
                    }
                }
            }
            specialization_args = old_args;
        }
        let specialized = SpecializedFunction { base: called_function, args: specialization_args }
            .intern(self.db);
        let specialized_func_id =
            ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);

        // Except for recursive calls of the same base, `priv_should_specialize` may veto.
        if caller_base != called_base
            && self.db.priv_should_specialize(specialized_func_id) == Ok(false)
        {
            return None;
        }

        Some(Statement::Call(StatementCall {
            function: specialized_func_id.function_id(self.db).unwrap(),
            inputs: new_args,
            with_coupon: call_stmt.with_coupon,
            outputs: std::mem::take(&mut call_stmt.outputs),
            location: call_stmt.location,
            is_specialization_base_call: false,
        }))
    }
848
849 fn propagate_const_and_get_statement(
851 &mut self,
852 value: BigInt,
853 output: VariableId,
854 ) -> Statement<'db> {
855 let ty = self.variables[output].ty;
856 let value = ConstValueId::from_int(self.db, ty, &value);
857 self.var_info.insert(output, VarInfo::Const(value).into());
858 Statement::Const(StatementConst::new_flat(value, output))
859 }
860
861 fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement<'db> {
863 self.propagate_const_and_get_statement(BigInt::zero(), output)
864 }
865
866 fn try_generate_const_statement(
868 &self,
869 value: ConstValueId<'db>,
870 output: VariableId,
871 ) -> Option<Statement<'db>> {
872 if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
873 Some(Statement::Const(StatementConst::new_flat(value, output)))
874 } else if matches!(value.long(self.db), ConstValue::Struct(members, _) if members.is_empty())
875 {
876 Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
878 } else {
879 None
880 }
881 }
882
    /// Tries to resolve a `match` on an enum whose variant is statically known.
    ///
    /// The matched input's info is looked up with its snapshot layers peeled. If it is a const
    /// enum or an enum of a known variant, the selected arm's variable receives the payload's
    /// info, re-wrapped with the peeled snapshot depth.
    ///
    /// If the payload can also be materialized, as a const statement or a copyable variable,
    /// the needed const/snapshot statements are appended to `statements` and a `Goto` to the
    /// selected arm is returned. Otherwise returns `None` and the match is kept.
    fn handle_enum_block_end(
        &mut self,
        info: &mut MatchEnumInfo<'db>,
        statements: &mut Vec<Statement<'db>>,
    ) -> Option<BlockEnd<'db>> {
        let input = info.input.var_id;
        let (n_snapshots, var_info) = self.var_info.get(&input)?.clone().peel_snapshots();
        let location = info.location;
        let as_usage = |var_id| VarUsage { var_id, location };
        let db = self.db;
        // Builds a snapshot statement from `pre_snap` into `post_snap`. A fresh copy of
        // `pre_snap`'s variable is allocated for the snapshot's "original" output.
        let snapshot_stmt = |vars: &mut VariableArena<'_>, pre_snap, post_snap| {
            let ignored = vars.alloc(vars[pre_snap].clone());
            Statement::Snapshot(StatementSnapshot::new(as_usage(pre_snap), ignored, post_snap))
        };
        // Case 1: the input is a fully known const enum.
        if let VarInfo::Const(const_value) = var_info.as_ref()
            && let ConstValue::Enum(variant, value) = const_value.long(db)
        {
            let arm = &info.arms[variant.idx];
            let output = arm.var_ids[0];
            // Propagate the payload const to the arm's variable.
            self.var_info
                .insert(output, Rc::new(VarInfo::Const(*value)).wrap_with_snapshots(n_snapshots));
            // The match is removed only if all of these hold: the input is droppable, the output
            // is copyable, and the payload can be emitted as a const statement.
            if self.variables[input].info.droppable.is_ok()
                && self.variables[output].info.copyable.is_ok()
                && let Ok(mut ty) = value.ty(db)
                && let Some(mut stmt) = self.try_generate_const_statement(*value, output)
            {
                // Variable chain: the const itself, then one variable per snapshot layer, ending
                // in `output`.
                let snapshot_vars = (0..n_snapshots)
                    .map(|_| {
                        let old_ty = ty;
                        ty = TypeLongId::Snapshot(ty).intern(db);
                        self.variables.alloc(Variable::with_default_context(db, old_ty, location))
                    })
                    .chain([output])
                    .collect::<Vec<_>>();
                // The const is written to the first variable of the chain.
                stmt.outputs_mut()[0] = snapshot_vars[0];
                statements.push(stmt);
                // Each consecutive pair is linked by a snapshot statement.
                statements.extend(snapshot_vars.into_iter().tuple_windows().map(
                    |(pre_snap, post_snap)| snapshot_stmt(self.variables, pre_snap, post_snap),
                ));

                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
            }
        // Case 2: the variant is known but the payload is only partially known.
        } else if let VarInfo::Enum { variant, payload } = var_info.as_ref() {
            let arm = &info.arms[variant.idx];
            let variant_ty = variant.ty;
            let output = arm.var_ids[0];
            let payload = payload.clone();
            // If the input is droppable, try to get the payload as a concrete variable plus its
            // own snapshot depth. The payload must be a copyable variable, or a const that
            // `try_generate_const_statement` can emit into a fresh variable.
            let unwrapped =
                self.variables[input].info.droppable.is_ok().then_some(()).and_then(|_| {
                    let (extra_snapshots, inner) = payload.clone().peel_snapshots();
                    match inner.as_ref() {
                        VarInfo::Var(var) if self.variables[var.var_id].info.copyable.is_ok() => {
                            Some((var.var_id, extra_snapshots))
                        }
                        VarInfo::Const(value) => {
                            let const_var = self
                                .variables
                                .alloc(Variable::with_default_context(db, variant_ty, location));
                            statements.push(self.try_generate_const_statement(*value, const_var)?);
                            Some((const_var, extra_snapshots))
                        }
                        _ => None,
                    }
                });
            // Propagate the payload's info to the arm's variable.
            self.var_info.insert(output, payload.wrap_with_snapshots(n_snapshots));
            if let Some((mut unwrapped, extra_snapshots)) = unwrapped {
                let total_snapshots = n_snapshots + extra_snapshots;
                // With no snapshots, no statement defines `output` here.
                // NOTE(review): this presumably relies on `output`'s `Var` info being substituted
                // by `maybe_replace_input(s)` in later uses; confirm.
                if total_snapshots != 0 {
                    // Take `total_snapshots` snapshots in sequence; the last one writes `output`.
                    for _ in 1..total_snapshots {
                        let ty = TypeLongId::Snapshot(self.variables[unwrapped].ty).intern(db);
                        let non_snap_var = Variable::with_default_context(self.db, ty, location);
                        let snapped = self.variables.alloc(non_snap_var);
                        statements.push(snapshot_stmt(self.variables, unwrapped, snapped));
                        unwrapped = snapped;
                    }
                    statements.push(snapshot_stmt(self.variables, unwrapped, output));
                };
                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
            }
        }
        None
    }
975
    /// Attempts to resolve a match on an extern function call at compile time.
    ///
    /// Uses what is known about the call's inputs (`self.var_info`) either to pick the arm that
    /// will be taken, or to replace the match with a cheaper equivalent.
    ///
    /// Returns `Some(block_end)` when the match should be replaced by `block_end`, and `None`
    /// when the original match end should be kept. Along the way this may:
    /// - push statements onto `statements` (emitted before the returned block end),
    /// - allocate new variables in `self.variables`,
    /// - record facts about arm output variables in `self.var_info`,
    /// - mutate `info` in place. This happens in the `array_get` index-0 rewrite, and when
    ///   `info.arms` is taken for the add/sub-by-one rewrite.
    fn handle_extern_block_end(
        &mut self,
        info: &mut MatchExternInfo<'db>,
        statements: &mut Vec<Statement<'db>>,
    ) -> Option<BlockEnd<'db>> {
        let db = self.db;
        let (id, generic_args) = info.function.get_extern(db)?;
        if self.nz_fns.contains(&id) {
            // Zero check on a constant: arm 0 is the zero case, arm 1 the non-zero case (which
            // binds a `NonZero` wrapper of the value).
            let val = self.as_const(info.inputs[0].var_id)?;
            let is_zero = match val.long(db) {
                ConstValue::Int(v, _) => v.is_zero(),
                // Multi-member constants (every member is expected to be an integer) are zero
                // only if all members are zero.
                ConstValue::Struct(s, _) => s.iter().all(|v| {
                    v.long(db).to_int().expect("Expected ConstValue::Int for size").is_zero()
                }),
                _ => unreachable!(),
            };
            Some(if is_zero {
                BlockEnd::Goto(info.arms[0].block_id, Default::default())
            } else {
                let arm = &info.arms[1];
                let nz_var = arm.var_ids[0];
                let nz_val = ConstValue::NonZero(val).intern(db);
                self.var_info.insert(nz_var, VarInfo::Const(nz_val).into());
                statements.push(Statement::Const(StatementConst::new_flat(nz_val, nz_var)));
                BlockEnd::Goto(arm.block_id, Default::default())
            })
        } else if self.eq_fns.contains(&id) {
            // Equality: arm 0 is "not equal", arm 1 is "equal".
            let lhs = self.as_int(info.inputs[0].var_id);
            let rhs = self.as_int(info.inputs[1].var_id);
            // One side is the constant zero and the other is unknown: replace the comparison
            // with the type's zero check on the unknown side.
            if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
                || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
            {
                let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
                let var = &self.variables[nz_input.var_id].clone();
                let function = self.type_info.get(&var.ty)?.is_zero;
                // The non-zero arm of the zero check binds a `NonZero` value nobody uses.
                let unused_nz_var = Variable::with_default_context(
                    db,
                    corelib::core_nonzero_ty(db, var.ty),
                    var.location,
                );
                let unused_nz_var = self.variables.alloc(unused_nz_var);
                return Some(BlockEnd::Match {
                    info: MatchInfo::Extern(MatchExternInfo {
                        function,
                        inputs: vec![nz_input],
                        arms: vec![
                            // Zero => the values are equal.
                            MatchArm {
                                arm_selector: MatchArmSelector::VariantId(
                                    corelib::jump_nz_zero_variant(db, var.ty),
                                ),
                                block_id: info.arms[1].block_id,
                                var_ids: vec![],
                            },
                            // Non-zero => the values differ.
                            MatchArm {
                                arm_selector: MatchArmSelector::VariantId(
                                    corelib::jump_nz_nonzero_variant(db, var.ty),
                                ),
                                block_id: info.arms[0].block_id,
                                var_ids: vec![unused_nz_var],
                            },
                        ],
                        location: info.location,
                    }),
                });
            }
            // Otherwise both sides must be known constants.
            Some(BlockEnd::Goto(
                info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
                Default::default(),
            ))
        } else if self.uadd_fns.contains(&id)
            || self.usub_fns.contains(&id)
            || self.diff_fns.contains(&id)
            || self.iadd_fns.contains(&id)
            || self.isub_fns.contains(&id)
        {
            let rhs = self.as_int(info.inputs[1].var_id);
            let lhs = self.as_int(info.inputs[0].var_id);
            // Both operands known: compute the result and select the arm by where it falls
            // relative to the output type's range.
            if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
                let ty = self.variables[info.arms[0].var_ids[0]].ty;
                let range = self.type_value_ranges.get(&ty)?;
                let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
                    lhs + rhs
                } else {
                    lhs - rhs
                };
                // Arm 0 holds an in-range result and arm 1 an underflow. An overflow goes to
                // arm 2 for the signed add/sub functions and to arm 1 for the others.
                let (arm_index, value) = match range.normalized(value) {
                    NormalizedResult::InRange(value) => (0, value),
                    NormalizedResult::Under(value) => (1, value),
                    NormalizedResult::Over(value) => (
                        if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
                            2
                        } else {
                            1
                        },
                        value,
                    ),
                };
                let arm = &info.arms[arm_index];
                let actual_output = arm.var_ids[0];
                let value = ConstValue::Int(value, ty).intern(db);
                self.var_info.insert(actual_output, VarInfo::Const(value).into());
                statements.push(Statement::Const(StatementConst::new_flat(value, actual_output)));
                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
            }
            if let Some(rhs) = rhs {
                // `x + 0` / `x - 0` always succeeds with `x` (not applicable to `diff`, whose
                // output type differs from its input type).
                if rhs.is_zero() && !self.diff_fns.contains(&id) {
                    let arm = &info.arms[0];
                    self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]).into());
                    return Some(BlockEnd::Goto(arm.block_id, Default::default()));
                }
                // `x + 1` / `x - 1`: call the type's `inc`/`dec` helper and match on the enum it
                // returns, reusing the original arms. NOTE(review): this assumes the helper's
                // enum variants line up with the original arms; confirm against corelib.
                if rhs.is_one() && !self.diff_fns.contains(&id) {
                    let ty = self.variables[info.arms[0].var_ids[0]].ty;
                    let ty_info = self.type_info.get(&ty)?;
                    let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
                        ty_info.inc?
                    } else {
                        ty_info.dec?
                    };
                    let enum_ty = function.signature(db).ok()?.return_type;
                    let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
                        enum_ty.long(db)
                    else {
                        return None;
                    };
                    let result = self.variables.alloc(Variable::with_default_context(
                        db,
                        function.signature(db).unwrap().return_type,
                        info.location,
                    ));
                    statements.push(Statement::Call(StatementCall {
                        function,
                        inputs: vec![info.inputs[0]],
                        with_coupon: false,
                        outputs: vec![result],
                        location: info.location,
                        is_specialization_base_call: false,
                    }));
                    return Some(BlockEnd::Match {
                        info: MatchInfo::Enum(MatchEnumInfo {
                            concrete_enum_id: *concrete_enum_id,
                            input: VarUsage { var_id: result, location: info.location },
                            arms: core::mem::take(&mut info.arms),
                            location: info.location,
                        }),
                    });
                }
            }
            // `0 + x` always succeeds with `x`.
            if let Some(lhs) = lhs
                && lhs.is_zero()
                && (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id))
            {
                let arm = &info.arms[0];
                self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]).into());
                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
            }
            None
        } else if let Some(reversed) = self.downcast_fns.get(&id) {
            // Value range of a type: from the known native ranges, or from a bounded-int type's
            // own bounds.
            let range = |ty: TypeId<'_>| {
                Some(if let Some(range) = self.type_value_ranges.get(&ty) {
                    range.clone()
                } else {
                    let (min, max) = corelib::try_extract_bounded_int_type_ranges(db, ty)?;
                    TypeRange { min, max }
                })
            };
            // `reversed` downcasts have their success and failure arms swapped.
            let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
            let input_var = info.inputs[0].var_id;
            let in_ty = self.variables[input_var].ty;
            let success_output = info.arms[success_arm].var_ids[0];
            let out_ty = self.variables[success_output].ty;
            let out_range = range(out_ty)?;
            let Some(value) = self.as_int(input_var) else {
                // Unknown input: if every value of the input type fits in the output type, the
                // downcast cannot fail, so replace it with an unconditional upcast.
                let in_range = range(in_ty)?;
                return if in_range.min < out_range.min || in_range.max > out_range.max {
                    None
                } else {
                    let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
                    let function = db.core_info().upcast_fn.concretize(db, generic_args);
                    statements.push(Statement::Call(StatementCall {
                        function: function.lowered(db),
                        inputs: vec![info.inputs[0]],
                        with_coupon: false,
                        outputs: vec![success_output],
                        location: info.location,
                        is_specialization_base_call: false,
                    }));
                    return Some(BlockEnd::Goto(
                        info.arms[success_arm].block_id,
                        Default::default(),
                    ));
                };
            };
            // A felt252 input is first converted to the integer value the downcast compares.
            let value = if in_ty == self.felt252 {
                felt252_for_downcast(value, &out_range.min)
            } else {
                value.clone()
            };
            Some(if let NormalizedResult::InRange(value) = out_range.normalized(value) {
                let value = ConstValue::Int(value, out_ty).intern(db);
                self.var_info.insert(success_output, VarInfo::Const(value).into());
                statements.push(Statement::Const(StatementConst::new_flat(value, success_output)));
                BlockEnd::Goto(info.arms[success_arm].block_id, Default::default())
            } else {
                BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default())
            })
        } else if id == self.bounded_int_constrain {
            // `generic_args[1]` is the constant boundary: values below it take arm 0 and the
            // rest take arm 1. The known value is propagated to the taken arm's output.
            let input_var = info.inputs[0].var_id;
            let value = self.as_int(input_var)?;
            let generic_arg = generic_args[1];
            let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
                .long(db)
                .to_int()
                .expect("Expected ConstValue::Int for size");
            let arm_idx = if value < constrain_value { 0 } else { 1 };
            let output = info.arms[arm_idx].var_ids[0];
            statements.push(self.propagate_const_and_get_statement(value.clone(), output));
            Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
        } else if id == self.bounded_int_trim_min {
            // Arm 0 is taken when the value equals its type's minimum. Otherwise arm 1 receives
            // the value.
            let input_var = info.inputs[0].var_id;
            let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
                return None;
            };
            let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
                range.min == *value
            } else {
                corelib::try_extract_bounded_int_type_ranges(db, *ty)?.0 == *value
            };
            let arm_idx = if is_trimmed {
                0
            } else {
                let output = info.arms[1].var_ids[0];
                statements.push(self.propagate_const_and_get_statement(value.clone(), output));
                1
            };
            Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
        } else if id == self.bounded_int_trim_max {
            // Arm 0 is taken when the value equals its type's maximum. Otherwise arm 1 receives
            // the value.
            let input_var = info.inputs[0].var_id;
            let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
                return None;
            };
            let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
                range.max == *value
            } else {
                corelib::try_extract_bounded_int_type_ranges(db, *ty)?.1 == *value
            };
            let arm_idx = if is_trimmed {
                0
            } else {
                let output = info.arms[1].var_ids[0];
                statements.push(self.propagate_const_and_get_statement(value.clone(), output));
                1
            };
            Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
        } else if id == self.array_get {
            // Requires a constant index that fits in `usize`.
            let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
            // The input is a snapshot of an array whose elements are (partially) known.
            if let Some(arr_info) = self.var_info.get(&info.inputs[0].var_id)
                && let VarInfo::Snapshot(arr_info) = arr_info.as_ref()
                && let VarInfo::Array(infos) = arr_info.as_ref()
            {
                match infos.get(index) {
                    Some(Some(output_var_info)) => {
                        let arm = &info.arms[0];
                        let output_var_info = output_var_info.clone();
                        self.var_info.insert(
                            arm.var_ids[0],
                            VarInfo::Box(VarInfo::Snapshot(output_var_info.clone()).into()).into(),
                        );
                        // A constant element is materialized directly: a boxed constant is
                        // snapshotted and passed through `box_forward_snapshot` to produce
                        // the arm's output.
                        if let VarInfo::Const(value) = output_var_info.as_ref() {
                            let value_ty = value.ty(db).ok()?;
                            let value_box_ty = corelib::core_box_ty(db, value_ty);
                            let location = info.location;
                            let boxed_var =
                                Variable::with_default_context(db, value_box_ty, location);
                            let boxed = self.variables.alloc(boxed_var.clone());
                            let unused_boxed = self.variables.alloc(boxed_var);
                            let snapped = self.variables.alloc(Variable::with_default_context(
                                db,
                                TypeLongId::Snapshot(value_box_ty).intern(db),
                                location,
                            ));
                            statements.extend([
                                Statement::Const(StatementConst::new_boxed(*value, boxed)),
                                Statement::Snapshot(StatementSnapshot {
                                    input: VarUsage { var_id: boxed, location },
                                    outputs: [unused_boxed, snapped],
                                }),
                                Statement::Call(StatementCall {
                                    function: self
                                        .box_forward_snapshot
                                        .concretize(db, vec![GenericArgumentId::Type(value_ty)])
                                        .lowered(db),
                                    inputs: vec![VarUsage { var_id: snapped, location }],
                                    with_coupon: false,
                                    outputs: vec![arm.var_ids[0]],
                                    location: info.location,
                                    is_specialization_base_call: false,
                                }),
                            ]);
                            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
                        }
                    }
                    // Index past the known length: the get fails.
                    None => {
                        return Some(BlockEnd::Goto(info.arms[1].block_id, Default::default()));
                    }
                    // The element exists but nothing is known about it.
                    Some(None) => {}
                }
            }
            // Index 0 is rewritten as `array_snapshot_pop_front`: the index input is dropped and
            // a fresh, unused array output is prepended to each arm. NOTE(review): this relies
            // on `array_snapshot_pop_front`'s arm layout; confirm against corelib.
            if index.is_zero()
                && let [success, failure] = info.arms.as_mut_slice()
            {
                let arr = info.inputs[0].var_id;
                let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
                let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
                info.inputs.truncate(1);
                info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
                    .concretize(db, generic_args)
                    .lowered(db);
                success.var_ids.insert(0, unused_arr_output0);
                failure.var_ids.insert(0, unused_arr_output1);
            }
            None
        } else if id == self.array_pop_front {
            let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)?.as_ref()
            else {
                return None;
            };
            if let Some(first) = var_infos.first() {
                // Non-empty array: the match is kept, but if the popped element is known, both
                // the remaining array and the boxed element are recorded for arm 0's outputs.
                if let Some(first) = first.as_ref().cloned() {
                    let arm = &info.arms[0];
                    self.var_info
                        .insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()).into());
                    self.var_info.insert(arm.var_ids[1], VarInfo::Box(first).into());
                }
                None
            } else {
                // Known-empty array: always take arm 1, forwarding the input array as its output.
                let arm = &info.arms[1];
                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]).into());
                Some(BlockEnd::Goto(
                    arm.block_id,
                    VarRemapping {
                        remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
                    },
                ))
            }
        } else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
            let var_info = self.var_info.get(&info.inputs[0].var_id)?;
            let desnapped = try_extract_matches!(var_info.as_ref(), VarInfo::Snapshot)?;
            let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
            if element_var_infos.is_empty() {
                // Snapshot of a known-empty array: always take arm 1, forwarding the input.
                let arm = &info.arms[1];
                self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]).into());
                Some(BlockEnd::Goto(
                    arm.block_id,
                    VarRemapping {
                        remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
                    },
                ))
            } else {
                None
            }
        } else {
            None
        }
    }
1345
1346 fn as_const(&self, var_id: VariableId) -> Option<ConstValueId<'db>> {
1348 try_extract_matches!(self.var_info.get(&var_id)?.as_ref(), VarInfo::Const).copied()
1349 }
1350
1351 fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
1353 match self.as_const(var_id)?.long(self.db) {
1354 ConstValue::Int(value, _) => Some(value),
1355 ConstValue::NonZero(const_value) => {
1356 if let ConstValue::Int(value, _) = const_value.long(self.db) {
1357 Some(value)
1358 } else {
1359 None
1360 }
1361 }
1362 _ => None,
1363 }
1364 }
1365
1366 fn maybe_replace_inputs(&self, inputs: &mut [VarUsage<'db>]) {
1368 for input in inputs {
1369 self.maybe_replace_input(input);
1370 }
1371 }
1372
1373 fn maybe_replace_input(&self, input: &mut VarUsage<'db>) {
1375 if let Some(info) = self.var_info.get(&input.var_id)
1376 && let VarInfo::Var(new_var) = info.as_ref()
1377 {
1378 *input = *new_var;
1379 }
1380 }
1381
1382 fn try_get_specialization_arg(
1388 &self,
1389 var_info: Rc<VarInfo<'db>>,
1390 ty: TypeId<'db>,
1391 unknown_vars: &mut Vec<VarUsage<'db>>,
1392 coerce: Option<&SpecializationArg<'db>>,
1393 ) -> Option<SpecializationArg<'db>> {
1394 require(self.db.type_size_info(ty).ok()? != TypeSizeInformation::ZeroSized)?;
1396 require(!matches!(coerce, Some(SpecializationArg::NotSpecialized)))?;
1398
1399 match var_info.as_ref() {
1400 VarInfo::Const(value) => {
1401 let res = const_to_specialization_arg(self.db, *value, false);
1402 let Some(coerce) = coerce else {
1403 return Some(res);
1404 };
1405 if *coerce == res { Some(res) } else { None }
1406 }
1407 VarInfo::Box(info) => {
1408 let res = try_extract_matches!(info.as_ref(), VarInfo::Const)
1409 .map(|value| SpecializationArg::Const { value: *value, boxed: true });
1410 let Some(coerce) = coerce else {
1411 return res;
1412 };
1413 if Some(coerce.clone()) == res { res } else { None }
1414 }
1415 VarInfo::Snapshot(info) => {
1416 let desnap_ty = *extract_matches!(ty.long(self.db), TypeLongId::Snapshot);
1417 let mut local_unknown_vars: Vec<VarUsage<'db>> = Vec::new();
1419 let inner = self.try_get_specialization_arg(
1420 info.clone(),
1421 desnap_ty,
1422 &mut local_unknown_vars,
1423 coerce.map(|coerce| {
1424 extract_matches!(coerce, SpecializationArg::Snapshot).as_ref()
1425 }),
1426 )?;
1427 unknown_vars.extend(local_unknown_vars);
1428 Some(SpecializationArg::Snapshot(Box::new(inner)))
1429 }
1430 VarInfo::Array(infos) => {
1431 let TypeLongId::Concrete(concrete_ty) = ty.long(self.db) else {
1432 unreachable!("Expected a concrete type");
1433 };
1434 let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
1435 else {
1436 unreachable!("Expected a single type generic argument");
1437 };
1438 let coerces = match coerce {
1439 Some(coerce) => {
1440 let SpecializationArg::Array(ty, specialization_args) = coerce else {
1441 unreachable!("Expected an array specialization argument");
1442 };
1443 assert_eq!(ty, inner_ty);
1444 if specialization_args.len() != infos.len() {
1445 return None;
1446 }
1447
1448 specialization_args.iter().map(Some).collect()
1449 }
1450 None => vec![None; infos.len()],
1451 };
1452 let mut vars = vec![];
1454 let mut args = vec![];
1455 for (info, coerce) in zip_eq(infos, coerces) {
1456 let info = info.as_ref()?.clone();
1457 let arg =
1458 self.try_get_specialization_arg(info, *inner_ty, &mut vars, coerce)?;
1459 args.push(arg);
1460 }
1461 if !args.is_empty()
1462 && args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1463 {
1464 return None;
1465 }
1466 unknown_vars.extend(vars);
1467 Some(SpecializationArg::Array(*inner_ty, args))
1468 }
1469 VarInfo::Struct(infos) => {
1470 let element_types: Vec<TypeId<'db>> = match ty.long(self.db) {
1472 TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) => {
1473 let members = self.db.concrete_struct_members(*concrete_struct).unwrap();
1474 members.values().map(|member| member.ty).collect()
1475 }
1476 TypeLongId::Tuple(element_types) => element_types.clone(),
1477 TypeLongId::FixedSizeArray { type_id, .. } => vec![*type_id; infos.len()],
1478 _ => return None,
1480 };
1481
1482 let coerces = match coerce {
1483 Some(SpecializationArg::Struct(specialization_args)) => {
1484 assert_eq!(specialization_args.len(), infos.len());
1485 specialization_args.iter().map(Some).collect()
1486 }
1487 Some(_) => unreachable!("Expected a struct specialization argument"),
1488 None => vec![None; infos.len()],
1489 };
1490
1491 let mut struct_args = Vec::new();
1492 let mut vars = vec![];
1494 for ((elem_ty, opt_var_info), coerce) in
1495 zip_eq(zip_eq(element_types, infos), coerces)
1496 {
1497 let var_info = opt_var_info.as_ref()?.clone();
1498 let arg =
1499 self.try_get_specialization_arg(var_info, elem_ty, &mut vars, coerce)?;
1500 struct_args.push(arg);
1501 }
1502 if !struct_args.is_empty()
1503 && struct_args
1504 .iter()
1505 .all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1506 {
1507 return None;
1508 }
1509 unknown_vars.extend(vars);
1510 Some(SpecializationArg::Struct(struct_args))
1511 }
1512 VarInfo::Enum { variant, payload } => {
1513 let coerce = match coerce {
1514 Some(coerce) => {
1515 let SpecializationArg::Enum { variant: coercion_variant, payload } = coerce
1516 else {
1517 unreachable!("Expected an enum specialization argument");
1518 };
1519 if coercion_variant != variant {
1520 return None;
1521 }
1522 Some(payload.as_ref())
1523 }
1524 None => None,
1525 };
1526 let mut local_unknown_vars = vec![];
1527 let payload_arg = self.try_get_specialization_arg(
1528 payload.clone(),
1529 variant.ty,
1530 &mut local_unknown_vars,
1531 coerce,
1532 )?;
1533
1534 unknown_vars.extend(local_unknown_vars);
1535 Some(SpecializationArg::Enum { variant: *variant, payload: Box::new(payload_arg) })
1536 }
1537 VarInfo::Var(var_usage) => {
1538 unknown_vars.push(*var_usage);
1539 Some(SpecializationArg::NotSpecialized)
1540 }
1541 }
1542 }
1543
1544 pub fn should_skip_const_folding(&self, db: &'db dyn Database) -> bool {
1546 if db.optimizations().skip_const_folding() {
1547 return true;
1548 }
1549
1550 if self.caller_function.base_semantic_function(db).generic_function(db)
1553 == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
1554 {
1555 return true;
1556 }
1557 false
1558 }
1559}
1560
1561fn var_info_if_copy<'db>(
1563 variables: &VariableArena<'db>,
1564 input: VarUsage<'db>,
1565) -> Option<Rc<VarInfo<'db>>> {
1566 variables[input.var_id].info.copyable.is_ok().then(|| VarInfo::Var(input).into())
1567}
1568
/// Salsa query that builds the [`ConstFoldingLibfuncInfo`] for the database, returned by
/// reference so it is computed once and shared.
#[salsa::tracked(returns(ref))]
fn priv_const_folding_info<'db>(
    db: &'db dyn Database,
) -> crate::optimizations::const_folding::ConstFoldingLibfuncInfo<'db> {
    ConstFoldingLibfuncInfo::new(db)
}
1576
/// The libfuncs and functions that const folding recognizes or emits, resolved from the core
/// library by `ConstFoldingLibfuncInfo::new`.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub struct ConstFoldingLibfuncInfo<'db> {
    /// The `felt252_sub` libfunc.
    felt_sub: ExternFunctionId<'db>,
    /// The `felt252_add` libfunc.
    felt_add: ExternFunctionId<'db>,
    /// The `felt252_mul` libfunc.
    felt_mul: ExternFunctionId<'db>,
    /// The `felt252_div` libfunc.
    felt_div: ExternFunctionId<'db>,
    /// The generic `core::box::box_forward_snapshot` function.
    box_forward_snapshot: GenericFunctionId<'db>,
    /// The `{ty}_eq` libfuncs for the native unsigned and signed integer types.
    eq_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_add` libfuncs for the unsigned integer types.
    uadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_sub` libfuncs for the unsigned integer types.
    usub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_diff` libfuncs for the signed integer types.
    diff_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_add_impl` libfuncs for the signed integer types.
    iadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_sub_impl` libfuncs for the signed integer types.
    isub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_mul` and the `{ty}_wide_mul` libfuncs (u8..u64, i8..i64).
    wide_mul_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_div_rem` and the `{ty}_safe_divmod` libfuncs for the unsigned types.
    div_rem_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `bounded_int_add` libfunc.
    bounded_int_add: ExternFunctionId<'db>,
    /// The `bounded_int_sub` libfunc.
    bounded_int_sub: ExternFunctionId<'db>,
    /// The `bounded_int_constrain` libfunc.
    bounded_int_constrain: ExternFunctionId<'db>,
    /// The `bounded_int_trim_min` libfunc.
    bounded_int_trim_min: ExternFunctionId<'db>,
    /// The `bounded_int_trim_max` libfunc.
    bounded_int_trim_max: ExternFunctionId<'db>,
    /// The `array_get` libfunc.
    array_get: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_front` libfunc.
    array_snapshot_pop_front: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_back` libfunc.
    array_snapshot_pop_back: ExternFunctionId<'db>,
    /// The `array_len` libfunc.
    array_len: ExternFunctionId<'db>,
    /// The `array_new` libfunc.
    array_new: ExternFunctionId<'db>,
    /// The `array_append` libfunc.
    array_append: ExternFunctionId<'db>,
    /// The `array_pop_front` libfunc.
    array_pop_front: ExternFunctionId<'db>,
    /// The `storage_base_address_from_felt252` libfunc.
    storage_base_address_from_felt252: ExternFunctionId<'db>,
    /// The generic `storage_base_address_const` function.
    storage_base_address_const: GenericFunctionId<'db>,
    /// The lowered `core::panic_with_felt252` function.
    panic_with_felt252: FunctionId<'db>,
    /// The `core::panic_with_const_felt252` free function; const folding is skipped inside it.
    pub panic_with_const_felt252: FreeFunctionId<'db>,
    /// The lowered `core::panics::panic_with_byte_array` function.
    panic_with_byte_array: FunctionId<'db>,
    /// Per-integer-type helper functions (zero check, increment, decrement).
    type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>>,
    /// Shared constant-calculation info from the database, also reachable through `Deref`.
    const_calculation_info: Arc<ConstCalcInfo<'db>>,
}
impl<'db> ConstFoldingLibfuncInfo<'db> {
    /// Resolves every libfunc and function used by const folding from the core library.
    fn new(db: &'db dyn Database) -> Self {
        let core = ModuleHelper::core(db);
        let box_module = core.submodule("box");
        let integer_module = core.submodule("integer");
        let internal_module = core.submodule("internal");
        let bounded_int_module = internal_module.submodule("bounded_int");
        let num_module = internal_module.submodule("num");
        let array_module = core.submodule("array");
        let starknet_module = core.submodule("starknet");
        let storage_access_module = starknet_module.submodule("storage_access");
        // Native unsigned and signed integer type names; libfunc names are derived from them.
        let utypes = ["u8", "u16", "u32", "u64", "u128"];
        let itypes = ["i8", "i16", "i32", "i64", "i128"];
        let eq_fns = OrderedHashSet::<_>::from_iter(
            chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(&format!("{ty}_eq"))),
        );
        let uadd_fns = OrderedHashSet::<_>::from_iter(
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_add"))),
        );
        let usub_fns = OrderedHashSet::<_>::from_iter(
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_sub"))),
        );
        let diff_fns = OrderedHashSet::<_>::from_iter(
            itypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_diff"))),
        );
        let iadd_fns =
            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
                integer_module.extern_function_id(&format!("{ty}_overflowing_add_impl"))
            }));
        let isub_fns =
            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
                integer_module.extern_function_id(&format!("{ty}_overflowing_sub_impl"))
            }));
        // `u128` and `i128` are intentionally absent from the `{ty}_wide_mul` list.
        let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
            [bounded_int_module.extern_function_id("bounded_int_mul")],
            ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
                .map(|ty| integer_module.extern_function_id(&format!("{ty}_wide_mul"))),
        ));
        let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
            [bounded_int_module.extern_function_id("bounded_int_div_rem")],
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_safe_divmod"))),
        ));
        // Per-type helpers. Each tuple is (type name, whether the zero check is
        // `bounded_int_is_zero` rather than `{ty}_is_zero`, whether `{ty}_inc`/`{ty}_dec` exist).
        let type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>> = OrderedHashMap::from_iter(
            [
                ("u8", false, true),
                ("u16", false, true),
                ("u32", false, true),
                ("u64", false, true),
                ("u128", false, true),
                ("u256", false, false),
                ("i8", true, true),
                ("i16", true, true),
                ("i32", true, true),
                ("i64", true, true),
                ("i128", true, true),
            ]
            .map(|(ty_name, as_bounded_int, inc_dec): (&'static str, bool, bool)| {
                let ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, ty_name), vec![]);
                let is_zero = if as_bounded_int {
                    bounded_int_module
                        .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
                } else {
                    integer_module.function_id(
                        SmolStrId::from(db, format!("{ty_name}_is_zero")).long(db).as_str(),
                        vec![],
                    )
                }
                .lowered(db);
                let (inc, dec) = if inc_dec {
                    (
                        Some(
                            num_module
                                .function_id(
                                    SmolStrId::from(db, format!("{ty_name}_inc")).long(db).as_str(),
                                    vec![],
                                )
                                .lowered(db),
                        ),
                        Some(
                            num_module
                                .function_id(
                                    SmolStrId::from(db, format!("{ty_name}_dec")).long(db).as_str(),
                                    vec![],
                                )
                                .lowered(db),
                        ),
                    )
                } else {
                    (None, None)
                };
                let info = TypeInfo { is_zero, inc, dec };
                (ty, info)
            }),
        );
        Self {
            felt_sub: core.extern_function_id("felt252_sub"),
            felt_add: core.extern_function_id("felt252_add"),
            felt_mul: core.extern_function_id("felt252_mul"),
            felt_div: core.extern_function_id("felt252_div"),
            box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
            eq_fns,
            uadd_fns,
            usub_fns,
            diff_fns,
            iadd_fns,
            isub_fns,
            wide_mul_fns,
            div_rem_fns,
            bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
            bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
            bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
            bounded_int_trim_min: bounded_int_module.extern_function_id("bounded_int_trim_min"),
            bounded_int_trim_max: bounded_int_module.extern_function_id("bounded_int_trim_max"),
            array_get: array_module.extern_function_id("array_get"),
            array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
            array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
            array_len: array_module.extern_function_id("array_len"),
            array_new: array_module.extern_function_id("array_new"),
            array_append: array_module.extern_function_id("array_append"),
            array_pop_front: array_module.extern_function_id("array_pop_front"),
            storage_base_address_from_felt252: storage_access_module
                .extern_function_id("storage_base_address_from_felt252"),
            storage_base_address_const: storage_access_module
                .generic_function_id("storage_base_address_const"),
            panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
            panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
            panic_with_byte_array: core
                .submodule("panics")
                .function_id("panic_with_byte_array", vec![])
                .lowered(db),
            type_info,
            const_calculation_info: db.const_calc_info(),
        }
    }
}
1780
1781impl<'db> std::ops::Deref for ConstFoldingContext<'db, '_> {
1782 type Target = ConstFoldingLibfuncInfo<'db>;
1783 fn deref(&self) -> &ConstFoldingLibfuncInfo<'db> {
1784 self.libfunc_info
1785 }
1786}
1787
1788impl<'a> std::ops::Deref for ConstFoldingLibfuncInfo<'a> {
1789 type Target = ConstCalcInfo<'a>;
1790 fn deref(&self) -> &ConstCalcInfo<'a> {
1791 &self.const_calculation_info
1792 }
1793}
1794
/// Helper functions const folding may emit for a specific integer type.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct TypeInfo<'db> {
    /// The type's zero check: `bounded_int_is_zero` for the signed types, `{ty}_is_zero`
    /// otherwise.
    is_zero: FunctionId<'db>,
    /// The `{ty}_inc` helper from `core::internal::num`, if the type has one.
    inc: Option<FunctionId<'db>>,
    /// The `{ty}_dec` helper from `core::internal::num`, if the type has one.
    dec: Option<FunctionId<'db>>,
}
1805
/// Classifies a value relative to a type's value range.
trait TypeRangeNormalizer {
    /// Returns `InRange(value)` if `value` lies within the range. Otherwise returns the value
    /// wrapped once by the range's size: `Under` when it was below the minimum, `Over` when it
    /// was above the maximum.
    fn normalized(&self, value: BigInt) -> NormalizedResult;
}
1811impl TypeRangeNormalizer for TypeRange {
1812 fn normalized(&self, value: BigInt) -> NormalizedResult {
1813 if value < self.min {
1814 NormalizedResult::Under(value - &self.min + &self.max + 1)
1815 } else if value > self.max {
1816 NormalizedResult::Over(value + &self.min - &self.max - 1)
1817 } else {
1818 NormalizedResult::InRange(value)
1819 }
1820 }
1821}
1822
/// Outcome of `TypeRangeNormalizer::normalized`.
enum NormalizedResult {
    /// The value is within the range (unchanged).
    InRange(BigInt),
    /// The value was above the maximum; the payload is the value wrapped down by one range size.
    Over(BigInt),
    /// The value was below the minimum; the payload is the value wrapped up by one range size.
    Under(BigInt),
}