1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::sync::Arc;
6
7use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
8use cairo_lang_filesystem::flag::flag_unsafe_panic;
9use cairo_lang_filesystem::ids::SmolStrId;
10use cairo_lang_semantic::corelib::CorelibSemantic;
11use cairo_lang_semantic::helper::ModuleHelper;
12use cairo_lang_semantic::items::constant::{
13 ConstCalcInfo, ConstValue, ConstValueId, ConstantSemantic, TypeRange, canonical_felt252,
14 felt252_for_downcast,
15};
16use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
17use cairo_lang_semantic::items::structure::StructSemantic;
18use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
19use cairo_lang_semantic::{
20 ConcreteTypeId, ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId,
21 corelib,
22};
23use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
24use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
25use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
26use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
27use cairo_lang_utils::{Intern, extract_matches, require, try_extract_matches};
28use itertools::{chain, zip_eq};
29use num_bigint::BigInt;
30use num_integer::Integer;
31use num_traits::cast::ToPrimitive;
32use num_traits::{Num, One, Zero};
33use salsa::Database;
34use starknet_types_core::felt::Felt as Felt252;
35
36use crate::db::LoweringGroup;
37use crate::ids::{
38 ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
39 SpecializedFunction,
40};
41use crate::specialization::SpecializationArg;
42use crate::utils::InliningStrategy;
43use crate::{
44 Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchEnumInfo,
45 MatchExternInfo, MatchInfo, Statement, StatementCall, StatementConst, StatementDesnap,
46 StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
47 StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableArena, VariableId,
48};
49
50fn const_to_specialization_arg<'db>(
54 db: &'db dyn Database,
55 value: ConstValueId<'db>,
56 boxed: bool,
57) -> SpecializationArg<'db> {
58 match value.long(db) {
59 ConstValue::Struct(members, ty) => {
60 if matches!(
63 ty.long(db),
64 TypeLongId::Concrete(ConcreteTypeId::Struct(_))
65 | TypeLongId::Tuple(_)
66 | TypeLongId::FixedSizeArray { .. }
67 ) {
68 let args = members
69 .iter()
70 .map(|member| const_to_specialization_arg(db, *member, false))
71 .collect();
72 SpecializationArg::Struct(args)
73 } else {
74 SpecializationArg::Const { value, boxed }
75 }
76 }
77 ConstValue::Enum(variant, payload) => SpecializationArg::Enum {
78 variant: *variant,
79 payload: Box::new(const_to_specialization_arg(db, *payload, false)),
80 },
81 _ => SpecializationArg::Const { value, boxed },
82 }
83}
84
/// Information tracked by const folding about the value held by a variable.
#[derive(Debug, Clone)]
enum VarInfo<'db> {
    /// The variable holds a known constant value.
    Const(ConstValueId<'db>),
    /// The variable is known to hold the same value as the given variable usage.
    Var(VarUsage<'db>),
    /// The variable is a snapshot of a value described by the inner info.
    Snapshot(Box<VarInfo<'db>>),
    /// The variable is a struct whose members are described one by one; `None` marks a member
    /// for which no information is tracked.
    Struct(Vec<Option<VarInfo<'db>>>),
    /// The variable is an enum of a known variant, with the payload described by `payload`.
    Enum { variant: ConcreteVariant<'db>, payload: Box<VarInfo<'db>> },
    /// The variable is a box whose content is described by the inner info.
    Box(Box<VarInfo<'db>>),
    /// The variable is an array of known length; each entry describes one element, with `None`
    /// marking an element for which no information is tracked.
    Array(Vec<Option<VarInfo<'db>>>),
}
106impl<'db> VarInfo<'db> {
107 fn peel_snapshots(&self) -> (usize, &VarInfo<'db>) {
109 let mut n_snapshots = 0;
110 let mut info = self;
111 while let VarInfo::Snapshot(inner) = info {
112 info = inner.as_ref();
113 n_snapshots += 1;
114 }
115 (n_snapshots, info)
116 }
117 fn wrap_with_snapshots(mut self, n_snapshots: usize) -> VarInfo<'db> {
119 for _ in 0..n_snapshots {
120 self = VarInfo::Snapshot(Box::new(self));
121 }
122 self
123 }
124}
125
/// How a not-yet-visited block was found to be reachable.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Reachability {
    /// The block is targeted by exactly one `Goto`, the one ending the given block; the constants
    /// remapped by that goto are therefore known on entry to the block.
    FromSingleGoto(BlockId),
    /// The block may be entered from more than one place (it is the root block, a match arm
    /// target, or the target of several gotos), so nothing is assumed on entry.
    Any,
}
135
136pub fn const_folding<'db>(
139 db: &'db dyn Database,
140 function_id: ConcreteFunctionWithBodyId<'db>,
141 lowered: &mut Lowered<'db>,
142) {
143 if lowered.blocks.is_empty() {
144 return;
145 }
146
147 let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
150
151 if ctx.should_skip_const_folding(db) {
152 return;
153 }
154
155 for block_id in (0..lowered.blocks.len()).map(BlockId) {
156 if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
157 continue;
158 }
159
160 let block = &mut lowered.blocks[block_id];
161 for stmt in block.statements.iter_mut() {
162 ctx.visit_statement(stmt);
163 }
164 ctx.visit_block_end(block_id, block);
165 }
166}
167
/// State of a constant-folding pass over the lowering of a single function.
///
/// NOTE(review): names such as `self.felt_sub` or `self.panic_with_felt252` used by the methods
/// are not fields declared here; presumably they are reached through `libfunc_info` via a
/// `Deref` impl outside this view — confirm.
pub struct ConstFoldingContext<'db, 'mt> {
    /// The database used for interning and queries.
    db: &'db dyn Database,
    /// The variables of the lowered function being folded; new variables are allocated here.
    pub variables: &'mt mut VariableArena<'db>,
    /// Information gathered so far about the values held by variables.
    var_info: UnorderedHashMap<VariableId, VarInfo<'db>>,
    /// Libfunc information, obtained from `priv_const_folding_info`.
    libfunc_info: &'db ConstFoldingLibfuncInfo<'db>,
    /// The function whose lowering is being folded (the caller of every visited call).
    caller_function: ConcreteFunctionWithBodyId<'db>,
    /// Blocks found reachable but not yet visited. An entry is removed when its block is
    /// visited; blocks without an entry are skipped.
    reachability: UnorderedHashMap<BlockId, Reachability>,
    /// Statements generated while visiting the current block's statements; `visit_block_end`
    /// inserts them at the start of the block.
    additional_stmts: Vec<Statement<'db>>,
}
187
188impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
189 pub fn new(
190 db: &'db dyn Database,
191 function_id: ConcreteFunctionWithBodyId<'db>,
192 variables: &'mt mut VariableArena<'db>,
193 ) -> Self {
194 Self {
195 db,
196 var_info: UnorderedHashMap::default(),
197 variables,
198 libfunc_info: priv_const_folding_info(db),
199 caller_function: function_id,
200 reachability: UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]),
201 additional_stmts: vec![],
202 }
203 }
204
205 pub fn visit_block_start<'r, 'get>(
208 &'r mut self,
209 block_id: BlockId,
210 get_block: impl FnOnce(BlockId) -> &'get Block<'db>,
211 ) -> bool
212 where
213 'db: 'get,
214 {
215 let Some(reachability) = self.reachability.remove(&block_id) else {
216 return false;
217 };
218 match reachability {
219 Reachability::Any => {}
220 Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
221 BlockEnd::Goto(_, remapping) => {
222 for (dst, src) in remapping.iter() {
223 if let Some(v) = self.as_const(src.var_id) {
224 self.var_info.insert(*dst, VarInfo::Const(v));
225 }
226 }
227 }
228 _ => unreachable!("Expected a goto end"),
229 },
230 }
231 true
232 }
233
    /// Visits a single statement: records what becomes known about its outputs and, for calls,
    /// possibly replaces the statement with a simpler one.
    ///
    /// Calls are first offered to `handle_statement_call`; if that yields no replacement, to
    /// `try_specialize_call`. All other statement kinds only update the tracked variable info.
    pub fn visit_statement(&mut self, stmt: &mut Statement<'db>) {
        // NOTE(review): `maybe_replace_inputs` is defined outside this view; presumably it
        // substitutes inputs known to equal another variable (`VarInfo::Var`) — confirm.
        self.maybe_replace_inputs(stmt.inputs_mut());
        match stmt {
            // A boxed constant: the output is a box around a known constant.
            Statement::Const(StatementConst { value, output, boxed }) if *boxed => {
                self.var_info.insert(*output, VarInfo::Box(VarInfo::Const(*value).into()));
            }
            // Only int, struct, enum and non-zero constants are tracked; generic, impl,
            // `Var` and missing constants are ignored.
            Statement::Const(StatementConst { value, output, .. }) => match value.long(self.db) {
                ConstValue::Int(..)
                | ConstValue::Struct(..)
                | ConstValue::Enum(..)
                | ConstValue::NonZero(..) => {
                    self.var_info.insert(*output, VarInfo::Const(*value));
                }
                ConstValue::Generic(_)
                | ConstValue::ImplConstant(_)
                | ConstValue::Var(..)
                | ConstValue::Missing(_) => {}
            },
            // The original output keeps the input's info; the snapshot output wraps it.
            Statement::Snapshot(stmt) => {
                if let Some(info) = self.var_info.get(&stmt.input.var_id).cloned() {
                    self.var_info.insert(stmt.original(), info.clone());
                    self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info.into()));
                }
            }
            // Desnapping a tracked snapshot exposes the info of the value inside it.
            Statement::Desnap(StatementDesnap { input, output }) => {
                if let Some(VarInfo::Snapshot(info)) = self.var_info.get(&input.var_id) {
                    self.var_info.insert(*output, info.as_ref().clone());
                }
            }
            Statement::Call(call_stmt) => {
                if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
                    *stmt = updated_stmt;
                } else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
                    *stmt = updated_stmt;
                }
            }
            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
                // All-constant inputs make a constant struct; otherwise, if any input has info,
                // the struct is tracked member by member.
                let mut const_args = vec![];
                let mut all_args = vec![];
                let mut contains_info = false;
                for input in inputs.iter() {
                    let Some(info) = self.var_info.get(&input.var_id) else {
                        // NOTE(review): `var_info_if_copy` is defined outside this view;
                        // presumably it yields info only for copyable inputs — confirm.
                        all_args.push(var_info_if_copy(self.variables, *input));
                        continue;
                    };
                    contains_info = true;
                    if let VarInfo::Const(value) = info {
                        const_args.push(*value);
                    }
                    all_args.push(Some(info.clone()));
                }
                if const_args.len() == inputs.len() {
                    let value =
                        ConstValue::Struct(const_args, self.variables[*output].ty).intern(self.db);
                    self.var_info.insert(*output, VarInfo::Const(value));
                } else if contains_info {
                    self.var_info.insert(*output, VarInfo::Struct(all_args));
                }
            }
            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
                // Each output inherits its member's info, re-wrapped in as many snapshot layers
                // as the destructured input had.
                if let Some(info) = self.var_info.get(&input.var_id) {
                    let (n_snapshots, info) = info.peel_snapshots();
                    match info {
                        VarInfo::Const(const_value) => {
                            if let ConstValue::Struct(member_values, _) = const_value.long(self.db)
                            {
                                for (output, value) in zip_eq(outputs, member_values) {
                                    self.var_info.insert(
                                        *output,
                                        VarInfo::Const(*value).wrap_with_snapshots(n_snapshots),
                                    );
                                }
                            }
                        }
                        VarInfo::Struct(members) => {
                            // Cloned so the borrow of `self.var_info` ends before inserting.
                            for (output, member) in zip_eq(outputs, members.clone()) {
                                if let Some(member) = member {
                                    self.var_info
                                        .insert(*output, member.wrap_with_snapshots(n_snapshots));
                                }
                            }
                        }
                        _ => {}
                    }
                }
            }
            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
                // A constant payload makes a constant enum; otherwise the variant is known and the
                // payload is described by its info (or by the input variable itself).
                let value = if let Some(info) = self.var_info.get(&input.var_id) {
                    if let VarInfo::Const(val) = info {
                        VarInfo::Const(ConstValue::Enum(*variant, *val).intern(self.db))
                    } else {
                        VarInfo::Enum { variant: *variant, payload: info.clone().into() }
                    }
                } else {
                    VarInfo::Enum { variant: *variant, payload: VarInfo::Var(*input).into() }
                };
                self.var_info.insert(*output, value);
            }
        }
    }
343
    /// Finishes visiting `block_id`.
    ///
    /// First, statements accumulated in `self.additional_stmts` are inserted at the start of the
    /// block. Then the block end is simplified where possible. Finally, the reachability of the
    /// (possibly updated) end's successors is recorded.
    pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut Block<'db>) {
        let statements = &mut block.statements;
        // Prepend the statements generated while visiting this block's statements.
        statements.splice(0..0, self.additional_stmts.drain(..));

        match &mut block.end {
            BlockEnd::Goto(_, remappings) => {
                // NOTE(review): `maybe_replace_input(s)` is defined outside this view; presumably
                // it substitutes variables known to equal another variable — confirm.
                for (_, v) in remappings.iter_mut() {
                    self.maybe_replace_input(v);
                }
            }
            BlockEnd::Match { info } => {
                self.maybe_replace_inputs(info.inputs_mut());
                match info {
                    MatchInfo::Enum(info) => {
                        if let Some(updated_end) = self.handle_enum_block_end(info, statements) {
                            block.end = updated_end;
                        }
                    }
                    MatchInfo::Extern(info) => {
                        if let Some(updated_end) = self.handle_extern_block_end(info, statements) {
                            block.end = updated_end;
                        }
                    }
                    MatchInfo::Value(info) => {
                        // A match on a known constant value: jump straight to the arm selecting
                        // that value.
                        if let Some(value) =
                            self.as_int(info.input.var_id).and_then(|x| x.to_usize())
                            && let Some(arm) = info.arms.iter().find(|arm| {
                                matches!(
                                    &arm.arm_selector,
                                    MatchArmSelector::Value(v) if v.value == value
                                )
                            })
                        {
                            // Define the arm's variable, previously introduced by the match, with
                            // an input-less struct construction.
                            statements.push(Statement::StructConstruct(StatementStructConstruct {
                                inputs: vec![],
                                output: arm.var_ids[0],
                            }));
                            block.end = BlockEnd::Goto(arm.block_id, Default::default());
                        }
                    }
                }
            }
            BlockEnd::Return(inputs, _) => self.maybe_replace_inputs(inputs),
            // Panic and unset block ends are not expected here.
            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
        }
        // Record how the successors of the final block end are reached.
        match &block.end {
            BlockEnd::Goto(dst_block_id, _) => {
                // The first goto into a block makes it reachable from this single goto; any
                // further goto into it downgrades it to `Any`.
                match self.reachability.entry(*dst_block_id) {
                    std::collections::hash_map::Entry::Occupied(mut e) => {
                        e.insert(Reachability::Any)
                    }
                    std::collections::hash_map::Entry::Vacant(e) => {
                        *e.insert(Reachability::FromSingleGoto(block_id))
                    }
                };
            }
            BlockEnd::Match { info } => {
                // Match arm targets must not have been recorded by any other block end.
                for arm in info.arms() {
                    assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
                }
            }
            BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
        }
    }
415
    /// Tries to simplify a call statement using the information known about its inputs.
    ///
    /// Returns `Some(replacement)` when the call should be replaced by the returned statement.
    /// Returns `None` when the call statement is kept. In that case `stmt` may still have been
    /// rewritten in place (the constant-panic and storage-address cases), and/or information
    /// about its outputs may have been recorded. Some branches also push extra statements to
    /// `self.additional_stmts`.
    fn handle_statement_call(&mut self, stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
        let db = self.db;
        if stmt.function == self.panic_with_felt252 {
            // Panicking with a constant felt252: rewrite the call in place into a call to
            // `panic_with_const_felt252`, passing the value as a generic argument.
            let val = self.as_const(stmt.inputs[0].var_id)?;
            stmt.inputs.clear();
            stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
                .concretize(db, vec![GenericArgumentId::Constant(val)])
                .lowered(db);
            return None;
        } else if stmt.function == self.panic_with_byte_array && !flag_unsafe_panic(db) {
            // Panicking with a snapshot of a byte array whose data words, pending word and
            // pending length are all known constants: build the panic data explicitly.
            let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
            let bytearray = try_extract_matches!(snap, VarInfo::Snapshot)?;
            let [
                Some(VarInfo::Array(data)),
                Some(VarInfo::Const(pending_word)),
                Some(VarInfo::Const(pending_len)),
            ] = &try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
            else {
                return None;
            };
            // Panic data layout: the byte-array magic, the number of data words, the data words,
            // then the pending word and the pending length.
            let mut panic_data =
                vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
            for word in data {
                let Some(VarInfo::Const(word)) = word else {
                    return None;
                };
                panic_data.push(word.long(db).to_int()?.clone());
            }
            panic_data.extend([
                pending_word.long(db).to_int()?.clone(),
                pending_len.long(db).to_int()?.clone(),
            ]);
            let felt252_ty = self.felt252;
            let location = stmt.location;
            let new_var = |ty| Variable::with_default_context(db, ty, location);
            let as_usage = |var_id| VarUsage { var_id, location };
            let array_fn = |extern_id| {
                let args = vec![GenericArgumentId::Type(felt252_ty)];
                GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
            };
            let call_stmt = |function, inputs, outputs| {
                let with_coupon = false;
                Statement::Call(StatementCall {
                    function,
                    inputs,
                    with_coupon,
                    outputs,
                    location,
                    is_specialization_base_call: false,
                })
            };
            // Emit `array_new`, then one felt252 const plus `array_append` per panic data word;
            // every append produces a fresh array variable.
            let arr_var = new_var(corelib::core_array_felt252_ty(db));
            let mut arr = self.variables.alloc(arr_var.clone());
            self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
            let felt252_var = new_var(felt252_ty);
            let arr_append_fn = array_fn(self.array_append);
            for word in panic_data {
                let to_append = self.variables.alloc(felt252_var.clone());
                let new_arr = self.variables.alloc(arr_var.clone());
                self.additional_stmts.push(Statement::Const(StatementConst::new_flat(
                    ConstValue::Int(word, felt252_ty).intern(db),
                    to_append,
                )));
                self.additional_stmts.push(call_stmt(
                    arr_append_fn,
                    vec![as_usage(arr), as_usage(to_append)],
                    vec![new_arr],
                ));
                arr = new_arr;
            }
            // The call is replaced by building its output from a `Panic` value and the array.
            let panic_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "Panic"), vec![]);
            let panic_var = self.variables.alloc(new_var(panic_ty));
            self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
                inputs: vec![],
                output: panic_var,
            }));
            return Some(Statement::StructConstruct(StatementStructConstruct {
                inputs: vec![as_usage(panic_var), as_usage(arr)],
                output: stmt.outputs[0],
            }));
        }
        // Everything below applies to extern functions only.
        let (id, _generic_args) = stmt.function.get_extern(db)?;
        if id == self.felt_sub {
            // `x - 0` is `x`; otherwise fold when both operands are constants.
            if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            {
                let value = canonical_felt252(&(lhs - rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_add {
            // `0 + x` and `x + 0` are `x`; otherwise fold when both operands are constants.
            if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && lhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
                None
            } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            {
                let value = canonical_felt252(&(lhs + rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_mul {
            // A constant zero operand gives zero; a constant one operand gives the other operand;
            // otherwise fold when both operands are constants.
            let lhs = self.as_int(stmt.inputs[0].var_id);
            let rhs = self.as_int(stmt.inputs[1].var_id);
            if lhs.map(Zero::is_zero).unwrap_or_default()
                || rhs.map(Zero::is_zero).unwrap_or_default()
            {
                Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
            } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && lhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
                None
            } else if let Some(lhs) = lhs
                && let Some(rhs) = rhs
            {
                let value = canonical_felt252(&(lhs * rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_div {
            if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                // `x / 1` is `x`.
                && rhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                // `0 / x` is zero.
                && lhs.is_zero()
            {
                Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && let Ok(rhs_nonzero) = Felt252::from(rhs).try_into()
            {
                // Both operands are constant and the divisor converts to a non-zero felt252:
                // fold using field division.
                let lhs_felt = Felt252::from(lhs);
                let value = lhs_felt.field_div(&rhs_nonzero).to_bigint();
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if self.wide_mul_fns.contains(&id) {
            // A constant zero operand gives zero; otherwise fold when both operands are constants.
            let lhs = self.as_int(stmt.inputs[0].var_id);
            let rhs = self.as_int(stmt.inputs[1].var_id);
            let output = stmt.outputs[0];
            if lhs.map(Zero::is_zero).unwrap_or_default()
                || rhs.map(Zero::is_zero).unwrap_or_default()
            {
                return Some(self.propagate_zero_and_get_statement(output));
            }
            let lhs = lhs?;
            Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0]))
        } else if id == self.bounded_int_add || id == self.bounded_int_sub {
            // Folded only when both operands are constants; no range check is performed here.
            let lhs = self.as_int(stmt.inputs[0].var_id)?;
            let rhs = self.as_int(stmt.inputs[1].var_id)?;
            let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else if self.div_rem_fns.contains(&id) {
            // Outputs are (quotient, remainder). A constant zero dividend makes both zero;
            // otherwise fold when both operands are constants. The remainder's defining statement
            // goes through `additional_stmts`, the quotient's replaces the call.
            let lhs = self.as_int(stmt.inputs[0].var_id);
            if lhs.map(Zero::is_zero).unwrap_or_default() {
                let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
                self.additional_stmts.push(additional_stmt);
                return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
            }
            let rhs = self.as_int(stmt.inputs[1].var_id)?;
            let (q, r) = lhs?.div_rem(rhs);
            let q_output = stmt.outputs[0];
            let q_value = ConstValue::Int(q, self.variables[q_output].ty).intern(db);
            self.var_info.insert(q_output, VarInfo::Const(q_value));
            let r_output = stmt.outputs[1];
            let r_value = ConstValue::Int(r, self.variables[r_output].ty).intern(db);
            self.var_info.insert(r_output, VarInfo::Const(r_value));
            self.additional_stmts
                .push(Statement::Const(StatementConst::new_flat(r_value, r_output)));
            Some(Statement::Const(StatementConst::new_flat(q_value, q_output)))
        } else if id == self.storage_base_address_from_felt252 {
            // A constant int input: rewrite the call in place into `storage_base_address_const`,
            // passing the value as a generic argument.
            let input_var = stmt.inputs[0].var_id;
            if let Some(const_value) = self.as_const(input_var)
                && let ConstValue::Int(val, ty) = const_value.long(db)
            {
                stmt.inputs.clear();
                let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
                stmt.function =
                    self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
            }
            None
        } else if id == self.into_box {
            // The output is tracked as a box of the input. If the input (or the value it is a
            // snapshot of) is a known constant, the call becomes a boxed const.
            let input = stmt.inputs[0];
            let var_info = self.var_info.get(&input.var_id);
            let const_value = match var_info {
                Some(VarInfo::Const(val)) => Some(*val),
                Some(VarInfo::Snapshot(info)) => {
                    try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
                }
                _ => None,
            };
            // NOTE(review): `var_info_if_copy` is defined outside this view; presumably it yields
            // info for the input itself when it is copyable — confirm.
            let var_info = var_info.cloned().or_else(|| var_info_if_copy(self.variables, input))?;
            self.var_info.insert(stmt.outputs[0], VarInfo::Box(var_info.into()));
            Some(Statement::Const(StatementConst::new_boxed(const_value?, stmt.outputs[0])))
        } else if id == self.unbox {
            // The output takes the boxed content's info; a constant content becomes a flat const.
            if let VarInfo::Box(inner) = self.var_info.get(&stmt.inputs[0].var_id)? {
                let inner = inner.as_ref().clone();
                if let VarInfo::Const(inner) =
                    self.var_info.entry(stmt.outputs[0]).insert_entry(inner).get()
                {
                    return Some(Statement::Const(StatementConst::new_flat(
                        *inner,
                        stmt.outputs[0],
                    )));
                }
            }
            None
        } else if self.upcast_fns.contains(&id) {
            // Upcasting a constant: the same integer value, interned with the output's type.
            let int_value = self.as_int(stmt.inputs[0].var_id)?;
            let output = stmt.outputs[0];
            let value = ConstValue::Int(int_value.clone(), self.variables[output].ty).intern(db);
            self.var_info.insert(output, VarInfo::Const(value));
            Some(Statement::Const(StatementConst::new_flat(value, output)))
        } else if id == self.array_new {
            // A new array is tracked as having no elements.
            self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]));
            None
        } else if id == self.array_append {
            // Appending to a tracked array extends its element list with the appended value's
            // info (or whatever `var_info_if_copy` yields for it).
            let mut var_infos =
                if let VarInfo::Array(var_infos) = self.var_info.get(&stmt.inputs[0].var_id)? {
                    var_infos.clone()
                } else {
                    return None;
                };
            let appended = stmt.inputs[1];
            var_infos.push(match self.var_info.get(&appended.var_id) {
                Some(var_info) => Some(var_info.clone()),
                None => var_info_if_copy(self.variables, appended),
            });
            self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos));
            None
        } else if id == self.array_len {
            // The length of a snapshot of a tracked array is its number of tracked elements.
            let info = self.var_info.get(&stmt.inputs[0].var_id)?;
            let desnapped = try_extract_matches!(info, VarInfo::Snapshot)?;
            let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
            Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0]))
        } else {
            None
        }
    }
692
    /// Tries to replace a call with a call to a version of the callee specialized on the
    /// arguments whose values are (partially) known.
    ///
    /// Returns the replacement call statement, or `None` when the call is not specialized. On
    /// success the original statement's outputs are moved into the replacement.
    fn try_specialize_call(&self, call_stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
        // Calls with coupons are never specialized.
        if call_stmt.with_coupon {
            return None;
        }
        // No specialization when the inlining strategy is `Avoid`.
        if matches!(self.db.optimizations().inlining_strategy(), InliningStrategy::Avoid) {
            return None;
        }

        // Only functions with a body can be specialized.
        let Ok(Some(mut called_function)) = call_stmt.function.body(self.db) else {
            return None;
        };

        // The unspecialized function a (possibly specialized) function is derived from.
        let extract_base = |function: ConcreteFunctionWithBodyId<'db>| match function.long(self.db)
        {
            ConcreteFunctionWithBodyLongId::Specialized(specialized) => specialized.base,
            _ => function,
        };
        let called_base = extract_base(called_function);
        let caller_base = extract_base(self.caller_function);

        if self.db.priv_never_inline(called_base).ok()? {
            return None;
        }

        // Calls marked as the base call of a specialization are left as is.
        if call_stmt.is_specialization_base_call {
            return None;
        }

        // A recursive call to an already specialized version of the caller is left as is.
        if called_base == caller_base && called_function != called_base {
            return None;
        }

        // Do not specialize a call into a non-trivial call SCC that contains the caller
        // (mutual recursion).
        let scc =
            self.db.lowered_scc(called_base, DependencyType::Call, LoweringStage::Monomorphized);
        if scc.len() > 1 && scc.contains(&caller_base) {
            return None;
        }

        // Nothing is known about any argument.
        if call_stmt.inputs.iter().all(|arg| self.var_info.get(&arg.var_id).is_none()) {
            return None;
        }

        // For a call from a specialized function to its own base, the caller's specialization
        // arguments are passed as per-argument hints; otherwise there are no hints.
        let self_specializition = if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
            self.caller_function.long(self.db)
            && caller_base == called_base
        {
            specialized.args.iter().map(Some).collect()
        } else {
            vec![None; call_stmt.inputs.len()]
        };

        // An argument is specialized when it has info, is droppable, and
        // `try_get_specialization_arg` produces an argument for it; otherwise it is passed
        // through unchanged.
        // NOTE(review): `try_get_specialization_arg` is defined outside this view and receives
        // `&mut new_args`; presumably it appends any variables the specialized function still
        // needs — confirm.
        let mut specialization_args = vec![];
        let mut new_args = vec![];
        for (arg, coerce) in zip_eq(&call_stmt.inputs, &self_specializition) {
            if let Some(var_info) = self.var_info.get(&arg.var_id)
                && self.variables[arg.var_id].info.droppable.is_ok()
                && let Some(specialization_arg) = self.try_get_specialization_arg(
                    var_info.clone(),
                    self.variables[arg.var_id].ty,
                    &mut new_args,
                    *coerce,
                )
            {
                specialization_args.push(specialization_arg);
            } else {
                specialization_args.push(SpecializationArg::NotSpecialized);
                new_args.push(*arg);
                continue;
            };
        }

        // No argument was specialized.
        if specialization_args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized)) {
            return None;
        }
        // Rather than specializing an already specialized function, specialize its base. The new
        // arguments fill the old specialization's `NotSpecialized` slots, visited depth-first in
        // order (left-over slots stay `NotSpecialized`).
        if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
            called_function.long(self.db)
        {
            called_function = specialized_function.base;
            let mut new_args_iter = specialization_args.into_iter();
            let mut old_args = specialized_function.args.to_vec();
            let mut stack = vec![];
            // Pushed in reverse so that popping visits the arguments in order.
            for arg in old_args.iter_mut().rev() {
                stack.push(arg);
            }
            while let Some(arg) = stack.pop() {
                match arg {
                    SpecializationArg::Const { .. } => {}
                    SpecializationArg::Snapshot(inner) => {
                        stack.push(inner.as_mut());
                    }
                    SpecializationArg::Enum { payload, .. } => {
                        stack.push(payload.as_mut());
                    }
                    SpecializationArg::Array(_, values) | SpecializationArg::Struct(values) => {
                        for value in values.iter_mut().rev() {
                            stack.push(value);
                        }
                    }
                    SpecializationArg::NotSpecialized => {
                        *arg = new_args_iter.next().unwrap_or(SpecializationArg::NotSpecialized);
                    }
                }
            }
            specialization_args = old_args;
        }
        let specialized =
            SpecializedFunction { base: called_function, args: specialization_args.into() };
        let specialized_func_id =
            ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);

        // Non-recursive specializations must also pass the `priv_should_specialize` query.
        if caller_base != called_base
            && self.db.priv_should_specialize(specialized_func_id) == Ok(false)
        {
            return None;
        }

        Some(Statement::Call(StatementCall {
            function: specialized_func_id.function_id(self.db).unwrap(),
            inputs: new_args,
            with_coupon: call_stmt.with_coupon,
            outputs: std::mem::take(&mut call_stmt.outputs),
            location: call_stmt.location,
            is_specialization_base_call: false,
        }))
    }
834
835 fn propagate_const_and_get_statement(
837 &mut self,
838 value: BigInt,
839 output: VariableId,
840 ) -> Statement<'db> {
841 let ty = self.variables[output].ty;
842 let value = ConstValueId::from_int(self.db, ty, &value);
843 self.var_info.insert(output, VarInfo::Const(value));
844 Statement::Const(StatementConst::new_flat(value, output))
845 }
846
847 fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement<'db> {
849 self.propagate_const_and_get_statement(BigInt::zero(), output)
850 }
851
852 fn try_generate_const_statement(
854 &self,
855 value: ConstValueId<'db>,
856 output: VariableId,
857 ) -> Option<Statement<'db>> {
858 if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
859 Some(Statement::Const(StatementConst::new_flat(value, output)))
860 } else if matches!(value.long(self.db), ConstValue::Struct(members, _) if members.is_empty())
861 {
862 Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
864 } else {
865 None
866 }
867 }
868
    /// Tries to resolve an enum match whose input's variant is known.
    ///
    /// On success, returns a `Goto` into the arm of the known variant and pushes to
    /// `statements` whatever is needed to define that arm's variable. Returns `None` when the
    /// match is kept. Even then, info about the matching arm's variable may have been recorded.
    fn handle_enum_block_end(
        &mut self,
        info: &mut MatchEnumInfo<'db>,
        statements: &mut Vec<Statement<'db>>,
    ) -> Option<BlockEnd<'db>> {
        let input = info.input.var_id;
        // `n_snapshots` is the number of snapshot layers around the matched enum's info.
        let (n_snapshots, var_info) = self.var_info.get(&input)?.peel_snapshots();
        let location = info.location;
        let as_usage = |var_id| VarUsage { var_id, location };
        let db = self.db;
        // Builds a snapshot statement from `pre_snap` into `post_snap`; the "original" output
        // goes to a freshly allocated, otherwise unused variable.
        let snapshot_stmt = |vars: &mut VariableArena<'_>, pre_snap, post_snap| {
            let ignored = vars.alloc(vars[pre_snap].clone());
            Statement::Snapshot(StatementSnapshot::new(as_usage(pre_snap), ignored, post_snap))
        };
        if let VarInfo::Const(const_value) = var_info
            // The matched value is a constant enum.
            && let ConstValue::Enum(variant, value) = const_value.long(db)
        {
            let arm = &info.arms[variant.idx];
            let output = arm.var_ids[0];
            // The arm's variable is the constant payload, under the same snapshot layers.
            self.var_info.insert(output, VarInfo::Const(*value).wrap_with_snapshots(n_snapshots));
            if self.variables[input].info.droppable.is_ok()
                && self.variables[output].info.copyable.is_ok()
                && let Ok(mut ty) = value.ty(db)
                && let Some(mut stmt) = self.try_generate_const_statement(*value, output)
            {
                // Route the constant through `n_snapshots` snapshot statements ending in
                // `output`.
                // NOTE(review): from the second iteration on, `stmt` is a snapshot statement;
                // confirm that `outputs_mut()[0]` is the output that should be redirected there.
                for _ in 0..n_snapshots {
                    let non_snap_var = Variable::with_default_context(db, ty, location);
                    ty = TypeLongId::Snapshot(ty).intern(db);
                    let pre_snap = self.variables.alloc(non_snap_var);
                    stmt.outputs_mut()[0] = pre_snap;
                    let take_snap = snapshot_stmt(self.variables, pre_snap, output);
                    statements.push(core::mem::replace(&mut stmt, take_snap));
                }
                statements.push(stmt);
                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
            }
        } else if let VarInfo::Enum { variant, payload } = var_info {
            // The variant is known; the payload is described by `payload`.
            let arm = &info.arms[variant.idx];
            let variant_ty = variant.ty;
            let output = arm.var_ids[0];
            let payload = payload.as_ref().clone();
            // When the input is droppable, find a variable holding the payload (under
            // `extra_snapshots` layers): a copyable payload variable, or a freshly generated
            // constant.
            let unwrapped =
                self.variables[input].info.droppable.is_ok().then_some(()).and_then(|_| {
                    let (extra_snapshots, inner) = payload.peel_snapshots();
                    match inner {
                        VarInfo::Var(var) if self.variables[var.var_id].info.copyable.is_ok() => {
                            Some((var.var_id, extra_snapshots))
                        }
                        VarInfo::Const(value) => {
                            let const_var = self
                                .variables
                                .alloc(Variable::with_default_context(db, variant_ty, location));
                            statements.push(self.try_generate_const_statement(*value, const_var)?);
                            Some((const_var, extra_snapshots))
                        }
                        _ => None,
                    }
                });
            self.var_info.insert(output, payload.wrap_with_snapshots(n_snapshots));
            if let Some((mut unwrapped, extra_snapshots)) = unwrapped {
                let total_snapshots = n_snapshots + extra_snapshots;
                // NOTE(review): when `total_snapshots == 0` no statement defines `output`; later
                // uses presumably rely on the info recorded above (e.g. a `VarInfo::Var` being
                // substituted) — confirm, especially for a `VarInfo::Const` payload.
                if total_snapshots != 0 {
                    // Chain `total_snapshots` snapshot statements, the last one into `output`.
                    for _ in 1..total_snapshots {
                        let ty = TypeLongId::Snapshot(self.variables[unwrapped].ty).intern(db);
                        let non_snap_var = Variable::with_default_context(self.db, ty, location);
                        let snapped = self.variables.alloc(non_snap_var);
                        statements.push(snapshot_stmt(self.variables, unwrapped, snapped));
                        unwrapped = snapped;
                    }
                    statements.push(snapshot_stmt(self.variables, unwrapped, output));
                };
                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
            }
        }
        None
    }
954
955 fn handle_extern_block_end(
960 &mut self,
961 info: &mut MatchExternInfo<'db>,
962 statements: &mut Vec<Statement<'db>>,
963 ) -> Option<BlockEnd<'db>> {
964 let db = self.db;
965 let (id, generic_args) = info.function.get_extern(db)?;
966 if self.nz_fns.contains(&id) {
967 let val = self.as_const(info.inputs[0].var_id)?;
968 let is_zero = match val.long(db) {
969 ConstValue::Int(v, _) => v.is_zero(),
970 ConstValue::Struct(s, _) => s.iter().all(|v| {
971 v.long(db).to_int().expect("Expected ConstValue::Int for size").is_zero()
972 }),
973 _ => unreachable!(),
974 };
975 Some(if is_zero {
976 BlockEnd::Goto(info.arms[0].block_id, Default::default())
977 } else {
978 let arm = &info.arms[1];
979 let nz_var = arm.var_ids[0];
980 let nz_val = ConstValue::NonZero(val).intern(db);
981 self.var_info.insert(nz_var, VarInfo::Const(nz_val));
982 statements.push(Statement::Const(StatementConst::new_flat(nz_val, nz_var)));
983 BlockEnd::Goto(arm.block_id, Default::default())
984 })
985 } else if self.eq_fns.contains(&id) {
986 let lhs = self.as_int(info.inputs[0].var_id);
987 let rhs = self.as_int(info.inputs[1].var_id);
988 if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
989 || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
990 {
991 let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
992 let var = &self.variables[nz_input.var_id].clone();
993 let function = self.type_info.get(&var.ty)?.is_zero;
994 let unused_nz_var = Variable::with_default_context(
995 db,
996 corelib::core_nonzero_ty(db, var.ty),
997 var.location,
998 );
999 let unused_nz_var = self.variables.alloc(unused_nz_var);
1000 return Some(BlockEnd::Match {
1001 info: MatchInfo::Extern(MatchExternInfo {
1002 function,
1003 inputs: vec![nz_input],
1004 arms: vec![
1005 MatchArm {
1006 arm_selector: MatchArmSelector::VariantId(
1007 corelib::jump_nz_zero_variant(db, var.ty),
1008 ),
1009 block_id: info.arms[1].block_id,
1010 var_ids: vec![],
1011 },
1012 MatchArm {
1013 arm_selector: MatchArmSelector::VariantId(
1014 corelib::jump_nz_nonzero_variant(db, var.ty),
1015 ),
1016 block_id: info.arms[0].block_id,
1017 var_ids: vec![unused_nz_var],
1018 },
1019 ],
1020 location: info.location,
1021 }),
1022 });
1023 }
1024 Some(BlockEnd::Goto(
1025 info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
1026 Default::default(),
1027 ))
1028 } else if self.uadd_fns.contains(&id)
1029 || self.usub_fns.contains(&id)
1030 || self.diff_fns.contains(&id)
1031 || self.iadd_fns.contains(&id)
1032 || self.isub_fns.contains(&id)
1033 {
1034 let rhs = self.as_int(info.inputs[1].var_id);
1035 let lhs = self.as_int(info.inputs[0].var_id);
1036 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
1037 let ty = self.variables[info.arms[0].var_ids[0]].ty;
1038 let range = self.type_value_ranges.get(&ty)?;
1039 let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
1040 lhs + rhs
1041 } else {
1042 lhs - rhs
1043 };
1044 let (arm_index, value) = match range.normalized(value) {
1045 NormalizedResult::InRange(value) => (0, value),
1046 NormalizedResult::Under(value) => (1, value),
1047 NormalizedResult::Over(value) => (
1048 if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
1049 2
1050 } else {
1051 1
1052 },
1053 value,
1054 ),
1055 };
1056 let arm = &info.arms[arm_index];
1057 let actual_output = arm.var_ids[0];
1058 let value = ConstValue::Int(value, ty).intern(db);
1059 self.var_info.insert(actual_output, VarInfo::Const(value));
1060 statements.push(Statement::Const(StatementConst::new_flat(value, actual_output)));
1061 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1062 }
1063 if let Some(rhs) = rhs {
1064 if rhs.is_zero() && !self.diff_fns.contains(&id) {
1065 let arm = &info.arms[0];
1066 self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]));
1067 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1068 }
1069 if rhs.is_one() && !self.diff_fns.contains(&id) {
1070 let ty = self.variables[info.arms[0].var_ids[0]].ty;
1071 let ty_info = self.type_info.get(&ty)?;
1072 let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
1073 ty_info.inc?
1074 } else {
1075 ty_info.dec?
1076 };
1077 let enum_ty = function.signature(db).ok()?.return_type;
1078 let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
1079 enum_ty.long(db)
1080 else {
1081 return None;
1082 };
1083 let result = self.variables.alloc(Variable::with_default_context(
1084 db,
1085 function.signature(db).unwrap().return_type,
1086 info.location,
1087 ));
1088 statements.push(Statement::Call(StatementCall {
1089 function,
1090 inputs: vec![info.inputs[0]],
1091 with_coupon: false,
1092 outputs: vec![result],
1093 location: info.location,
1094 is_specialization_base_call: false,
1095 }));
1096 return Some(BlockEnd::Match {
1097 info: MatchInfo::Enum(MatchEnumInfo {
1098 concrete_enum_id: *concrete_enum_id,
1099 input: VarUsage { var_id: result, location: info.location },
1100 arms: core::mem::take(&mut info.arms),
1101 location: info.location,
1102 }),
1103 });
1104 }
1105 }
1106 if let Some(lhs) = lhs
1107 && lhs.is_zero()
1108 && (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id))
1109 {
1110 let arm = &info.arms[0];
1111 self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]));
1112 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1113 }
1114 None
1115 } else if let Some(reversed) = self.downcast_fns.get(&id) {
1116 let range = |ty: TypeId<'_>| {
1117 Some(if let Some(range) = self.type_value_ranges.get(&ty) {
1118 range.clone()
1119 } else {
1120 let (min, max) = corelib::try_extract_bounded_int_type_ranges(db, ty)?;
1121 TypeRange { min, max }
1122 })
1123 };
1124 let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
1125 let input_var = info.inputs[0].var_id;
1126 let in_ty = self.variables[input_var].ty;
1127 let success_output = info.arms[success_arm].var_ids[0];
1128 let out_ty = self.variables[success_output].ty;
1129 let out_range = range(out_ty)?;
1130 let Some(value) = self.as_int(input_var) else {
1131 let in_range = range(in_ty)?;
1132 return if in_range.min < out_range.min || in_range.max > out_range.max {
1133 None
1134 } else {
1135 let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
1136 let function = db.core_info().upcast_fn.concretize(db, generic_args);
1137 statements.push(Statement::Call(StatementCall {
1138 function: function.lowered(db),
1139 inputs: vec![info.inputs[0]],
1140 with_coupon: false,
1141 outputs: vec![success_output],
1142 location: info.location,
1143 is_specialization_base_call: false,
1144 }));
1145 return Some(BlockEnd::Goto(
1146 info.arms[success_arm].block_id,
1147 Default::default(),
1148 ));
1149 };
1150 };
1151 let value = if in_ty == self.felt252 {
1152 felt252_for_downcast(value, &out_range.min)
1153 } else {
1154 value.clone()
1155 };
1156 Some(if let NormalizedResult::InRange(value) = out_range.normalized(value) {
1157 let value = ConstValue::Int(value, out_ty).intern(db);
1158 self.var_info.insert(success_output, VarInfo::Const(value));
1159 statements.push(Statement::Const(StatementConst::new_flat(value, success_output)));
1160 BlockEnd::Goto(info.arms[success_arm].block_id, Default::default())
1161 } else {
1162 BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default())
1163 })
1164 } else if id == self.bounded_int_constrain {
1165 let input_var = info.inputs[0].var_id;
1166 let value = self.as_int(input_var)?;
1167 let generic_arg = generic_args[1];
1168 let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
1169 .long(db)
1170 .to_int()
1171 .expect("Expected ConstValue::Int for size");
1172 let arm_idx = if value < constrain_value { 0 } else { 1 };
1173 let output = info.arms[arm_idx].var_ids[0];
1174 statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1175 Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1176 } else if id == self.bounded_int_trim_min {
1177 let input_var = info.inputs[0].var_id;
1178 let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
1179 return None;
1180 };
1181 let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
1182 range.min == *value
1183 } else {
1184 corelib::try_extract_bounded_int_type_ranges(db, *ty)?.0 == *value
1185 };
1186 let arm_idx = if is_trimmed {
1187 0
1188 } else {
1189 let output = info.arms[1].var_ids[0];
1190 statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1191 1
1192 };
1193 Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1194 } else if id == self.bounded_int_trim_max {
1195 let input_var = info.inputs[0].var_id;
1196 let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
1197 return None;
1198 };
1199 let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
1200 range.max == *value
1201 } else {
1202 corelib::try_extract_bounded_int_type_ranges(db, *ty)?.1 == *value
1203 };
1204 let arm_idx = if is_trimmed {
1205 0
1206 } else {
1207 let output = info.arms[1].var_ids[0];
1208 statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1209 1
1210 };
1211 Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1212 } else if id == self.array_get {
1213 let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
1214 if let Some(VarInfo::Snapshot(arr_info)) = self.var_info.get(&info.inputs[0].var_id)
1215 && let VarInfo::Array(infos) = arr_info.as_ref()
1216 {
1217 match infos.get(index) {
1218 Some(Some(output_var_info)) => {
1219 let arm = &info.arms[0];
1220 let output_var_info = output_var_info.clone();
1221 let box_info =
1222 VarInfo::Box(VarInfo::Snapshot(output_var_info.clone().into()).into());
1223 self.var_info.insert(arm.var_ids[0], box_info);
1224 if let VarInfo::Const(value) = output_var_info {
1225 let value_ty = value.ty(db).ok()?;
1226 let value_box_ty = corelib::core_box_ty(db, value_ty);
1227 let location = info.location;
1228 let boxed_var =
1229 Variable::with_default_context(db, value_box_ty, location);
1230 let boxed = self.variables.alloc(boxed_var.clone());
1231 let unused_boxed = self.variables.alloc(boxed_var);
1232 let snapped = self.variables.alloc(Variable::with_default_context(
1233 db,
1234 TypeLongId::Snapshot(value_box_ty).intern(db),
1235 location,
1236 ));
1237 statements.extend([
1238 Statement::Const(StatementConst::new_boxed(value, boxed)),
1239 Statement::Snapshot(StatementSnapshot {
1240 input: VarUsage { var_id: boxed, location },
1241 outputs: [unused_boxed, snapped],
1242 }),
1243 Statement::Call(StatementCall {
1244 function: self
1245 .box_forward_snapshot
1246 .concretize(db, vec![GenericArgumentId::Type(value_ty)])
1247 .lowered(db),
1248 inputs: vec![VarUsage { var_id: snapped, location }],
1249 with_coupon: false,
1250 outputs: vec![arm.var_ids[0]],
1251 location: info.location,
1252 is_specialization_base_call: false,
1253 }),
1254 ]);
1255 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1256 }
1257 }
1258 None => {
1259 return Some(BlockEnd::Goto(info.arms[1].block_id, Default::default()));
1260 }
1261 Some(None) => {}
1262 }
1263 }
1264 if index.is_zero()
1265 && let [success, failure] = info.arms.as_mut_slice()
1266 {
1267 let arr = info.inputs[0].var_id;
1268 let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
1269 let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
1270 info.inputs.truncate(1);
1271 info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
1272 .concretize(db, generic_args)
1273 .lowered(db);
1274 success.var_ids.insert(0, unused_arr_output0);
1275 failure.var_ids.insert(0, unused_arr_output1);
1276 }
1277 None
1278 } else if id == self.array_pop_front {
1279 let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)? else {
1280 return None;
1281 };
1282 if let Some(first) = var_infos.first() {
1283 if let Some(first) = first.as_ref().cloned() {
1284 let arm = &info.arms[0];
1285 self.var_info.insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()));
1286 self.var_info.insert(arm.var_ids[1], VarInfo::Box(first.into()));
1287 }
1288 None
1289 } else {
1290 let arm = &info.arms[1];
1291 self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1292 Some(BlockEnd::Goto(
1293 arm.block_id,
1294 VarRemapping {
1295 remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1296 },
1297 ))
1298 }
1299 } else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
1300 let var_info = self.var_info.get(&info.inputs[0].var_id)?;
1301 let desnapped = try_extract_matches!(var_info, VarInfo::Snapshot)?;
1302 let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
1303 if element_var_infos.is_empty() {
1305 let arm = &info.arms[1];
1306 self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1307 Some(BlockEnd::Goto(
1308 arm.block_id,
1309 VarRemapping {
1310 remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1311 },
1312 ))
1313 } else {
1314 None
1315 }
1316 } else {
1317 None
1318 }
1319 }
1320
1321 fn as_const(&self, var_id: VariableId) -> Option<ConstValueId<'db>> {
1323 try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const).copied()
1324 }
1325
1326 fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
1328 match self.as_const(var_id)?.long(self.db) {
1329 ConstValue::Int(value, _) => Some(value),
1330 ConstValue::NonZero(const_value) => {
1331 if let ConstValue::Int(value, _) = const_value.long(self.db) {
1332 Some(value)
1333 } else {
1334 None
1335 }
1336 }
1337 _ => None,
1338 }
1339 }
1340
1341 fn maybe_replace_inputs(&self, inputs: &mut [VarUsage<'db>]) {
1343 for input in inputs {
1344 self.maybe_replace_input(input);
1345 }
1346 }
1347
1348 fn maybe_replace_input(&self, input: &mut VarUsage<'db>) {
1350 if let Some(VarInfo::Var(new_var)) = self.var_info.get(&input.var_id) {
1351 *input = *new_var;
1352 }
1353 }
1354
    /// Tries to express a value, described by its tracked `var_info` and its type `ty`, as a
    /// [`SpecializationArg`].
    ///
    /// Sub-values that are only known as variables (`VarInfo::Var`) become
    /// `SpecializationArg::NotSpecialized`, and their usages are appended to `unknown_vars`. For
    /// the `Snapshot`, `Array`, `Struct` and `Enum` cases, they are appended only if the whole
    /// sub-tree succeeds.
    ///
    /// When `coerce` is `Some`, the result must agree with it (same constant, same array length,
    /// same enum variant, ...). `coerce` is also passed down to the nested values.
    ///
    /// Returns `None` in these cases:
    /// - `ty`'s size info is unavailable or `ty` is zero-sized;
    /// - `coerce` is `NotSpecialized`;
    /// - a nested element info is missing;
    /// - the result does not agree with `coerce`;
    /// - a non-empty array/struct would consist only of non-specialized elements.
    fn try_get_specialization_arg(
        &self,
        var_info: VarInfo<'db>,
        ty: TypeId<'db>,
        unknown_vars: &mut Vec<VarUsage<'db>>,
        coerce: Option<&SpecializationArg<'db>>,
    ) -> Option<SpecializationArg<'db>> {
        // Zero-sized values are never specialized on.
        require(self.db.type_size_info(ty).ok()? != TypeSizeInformation::ZeroSized)?;
        // A `NotSpecialized` coercion target makes the whole attempt fail.
        require(!matches!(coerce, Some(SpecializationArg::NotSpecialized)))?;

        match var_info {
            VarInfo::Const(value) => {
                // NOTE(review): the `false` flag presumably means "not boxed", as opposed to the
                // `boxed: true` built in the `Box` arm below. Confirm against
                // `const_to_specialization_arg`.
                let res = const_to_specialization_arg(self.db, value, false);
                let Some(coerce) = coerce else {
                    return Some(res);
                };
                if *coerce == res { Some(res) } else { None }
            }
            VarInfo::Box(info) => {
                // Only boxes of known constants are specializable.
                let res = try_extract_matches!(info.as_ref(), VarInfo::Const)
                    .map(|value| SpecializationArg::Const { value: *value, boxed: true });
                let Some(coerce) = coerce else {
                    return res;
                };
                if Some(coerce.clone()) == res { res } else { None }
            }
            VarInfo::Snapshot(info) => {
                // Recurse into the snapshotted value using the de-snapped type.
                let desnap_ty = *extract_matches!(ty.long(self.db), TypeLongId::Snapshot);
                let mut local_unknown_vars: Vec<VarUsage<'db>> = Vec::new();
                let inner = self.try_get_specialization_arg(
                    info.as_ref().clone(),
                    desnap_ty,
                    &mut local_unknown_vars,
                    coerce.map(|coerce| {
                        extract_matches!(coerce, SpecializationArg::Snapshot).as_ref()
                    }),
                )?;
                unknown_vars.extend(local_unknown_vars);
                Some(SpecializationArg::Snapshot(Box::new(inner)))
            }
            VarInfo::Array(infos) => {
                // `ty` is expected to be a concrete type with a single type generic argument
                // (the element type).
                let TypeLongId::Concrete(concrete_ty) = ty.long(self.db) else {
                    unreachable!("Expected a concrete type");
                };
                let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
                else {
                    unreachable!("Expected a single type generic argument");
                };
                // Per-element coercion targets; a length mismatch cannot be coerced.
                let coerces = match coerce {
                    Some(coerce) => {
                        let SpecializationArg::Array(ty, specialization_args) = coerce else {
                            unreachable!("Expected an array specialization argument");
                        };
                        assert_eq!(ty, inner_ty);
                        if specialization_args.len() != infos.len() {
                            return None;
                        }

                        specialization_args.iter().map(Some).collect()
                    }
                    None => vec![None; infos.len()],
                };
                // Collected locally so that `unknown_vars` is untouched on failure.
                let mut vars = vec![];
                let mut args = vec![];
                for (info, coerce) in zip_eq(infos, coerces) {
                    let info = info?;
                    let arg =
                        self.try_get_specialization_arg(info, *inner_ty, &mut vars, coerce)?;
                    args.push(arg);
                }
                // A non-empty array with no specialized element gives nothing to specialize on.
                if !args.is_empty()
                    && args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
                {
                    return None;
                }
                unknown_vars.extend(vars);
                Some(SpecializationArg::Array(*inner_ty, args))
            }
            VarInfo::Struct(infos) => {
                // Member types for structs, tuples and fixed-size arrays; other types give up.
                let element_types: Vec<TypeId<'db>> = match ty.long(self.db) {
                    TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) => {
                        let members = self.db.concrete_struct_members(*concrete_struct).unwrap();
                        members.values().map(|member| member.ty).collect()
                    }
                    TypeLongId::Tuple(element_types) => element_types.clone(),
                    TypeLongId::FixedSizeArray { type_id, .. } => vec![*type_id; infos.len()],
                    _ => return None,
                };

                let coerces = match coerce {
                    Some(SpecializationArg::Struct(specialization_args)) => {
                        assert_eq!(specialization_args.len(), infos.len());
                        specialization_args.iter().map(Some).collect()
                    }
                    Some(_) => unreachable!("Expected a struct specialization argument"),
                    None => vec![None; infos.len()],
                };

                let mut struct_args = Vec::new();
                // Collected locally so that `unknown_vars` is untouched on failure.
                let mut vars = vec![];
                for ((elem_ty, opt_var_info), coerce) in
                    zip_eq(zip_eq(element_types, infos), coerces)
                {
                    let var_info = opt_var_info?;
                    let arg =
                        self.try_get_specialization_arg(var_info, elem_ty, &mut vars, coerce)?;
                    struct_args.push(arg);
                }
                // A non-empty struct with no specialized member gives nothing to specialize on.
                if !struct_args.is_empty()
                    && struct_args
                        .iter()
                        .all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
                {
                    return None;
                }
                unknown_vars.extend(vars);
                Some(SpecializationArg::Struct(struct_args))
            }
            VarInfo::Enum { variant, payload } => {
                // A coercion target must name the same variant; its payload coerces ours.
                let coerce = match coerce {
                    Some(coerce) => {
                        let SpecializationArg::Enum { variant: coercion_variant, payload } = coerce
                        else {
                            unreachable!("Expected an enum specialization argument");
                        };
                        if *coercion_variant != variant {
                            return None;
                        }
                        Some(payload.as_ref())
                    }
                    None => None,
                };
                let mut local_unknown_vars = vec![];
                let payload_arg = self.try_get_specialization_arg(
                    payload.as_ref().clone(),
                    variant.ty,
                    &mut local_unknown_vars,
                    coerce,
                )?;

                unknown_vars.extend(local_unknown_vars);
                Some(SpecializationArg::Enum { variant, payload: Box::new(payload_arg) })
            }
            VarInfo::Var(var_usage) => {
                // Unknown value: passed through as a regular argument.
                unknown_vars.push(var_usage);
                Some(SpecializationArg::NotSpecialized)
            }
        }
    }
1516
1517 pub fn should_skip_const_folding(&self, db: &'db dyn Database) -> bool {
1519 if db.optimizations().skip_const_folding() {
1520 return true;
1521 }
1522
1523 if self.caller_function.base_semantic_function(db).generic_function(db)
1526 == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
1527 {
1528 return true;
1529 }
1530 false
1531 }
1532}
1533
1534fn var_info_if_copy<'db>(
1536 variables: &VariableArena<'db>,
1537 input: VarUsage<'db>,
1538) -> Option<VarInfo<'db>> {
1539 variables[input.var_id].info.copyable.is_ok().then_some(VarInfo::Var(input))
1540}
1541
/// Salsa-tracked query that builds the [`ConstFoldingLibfuncInfo`] for `db` via
/// [`ConstFoldingLibfuncInfo::new`] and returns it by reference (`returns(ref)`).
#[salsa::tracked(returns(ref))]
fn priv_const_folding_info<'db>(
    db: &'db dyn Database,
) -> crate::optimizations::const_folding::ConstFoldingLibfuncInfo<'db> {
    ConstFoldingLibfuncInfo::new(db)
}
1549
/// Corelib functions and per-type helpers that the const-folding pass recognizes or emits.
///
/// Built once by `ConstFoldingLibfuncInfo::new`. The shared [`ConstCalcInfo`] it holds is also
/// exposed through `Deref`.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub struct ConstFoldingLibfuncInfo<'db> {
    /// `core::felt252_sub`.
    felt_sub: ExternFunctionId<'db>,
    /// `core::felt252_add`.
    felt_add: ExternFunctionId<'db>,
    /// `core::felt252_mul`.
    felt_mul: ExternFunctionId<'db>,
    /// `core::felt252_div`.
    felt_div: ExternFunctionId<'db>,
    /// `core::box::into_box`.
    into_box: ExternFunctionId<'db>,
    /// `core::box::unbox`.
    unbox: ExternFunctionId<'db>,
    /// Generic `core::box::box_forward_snapshot`.
    box_forward_snapshot: GenericFunctionId<'db>,
    /// `{u8..u128, i8..i128}_eq` from `core::integer`.
    eq_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{u8..u128}_overflowing_add` from `core::integer`.
    uadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{u8..u128}_overflowing_sub` from `core::integer`.
    usub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{i8..i128}_diff` from `core::integer`.
    diff_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{i8..i128}_overflowing_add_impl` from `core::integer`.
    iadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{i8..i128}_overflowing_sub_impl` from `core::integer`.
    isub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_mul` plus `{u8..u64, i8..i64}_wide_mul` from `core::integer`.
    wide_mul_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_div_rem` plus `{u8..u128}_safe_divmod` from `core::integer`.
    div_rem_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `core::internal::bounded_int::bounded_int_add`.
    bounded_int_add: ExternFunctionId<'db>,
    /// `core::internal::bounded_int::bounded_int_sub`.
    bounded_int_sub: ExternFunctionId<'db>,
    /// `core::internal::bounded_int::bounded_int_constrain`.
    bounded_int_constrain: ExternFunctionId<'db>,
    /// `core::internal::bounded_int::bounded_int_trim_min`.
    bounded_int_trim_min: ExternFunctionId<'db>,
    /// `core::internal::bounded_int::bounded_int_trim_max`.
    bounded_int_trim_max: ExternFunctionId<'db>,
    /// `core::array::array_get`.
    array_get: ExternFunctionId<'db>,
    /// `core::array::array_snapshot_pop_front`.
    array_snapshot_pop_front: ExternFunctionId<'db>,
    /// `core::array::array_snapshot_pop_back`.
    array_snapshot_pop_back: ExternFunctionId<'db>,
    /// `core::array::array_len`.
    array_len: ExternFunctionId<'db>,
    /// `core::array::array_new`.
    array_new: ExternFunctionId<'db>,
    /// `core::array::array_append`.
    array_append: ExternFunctionId<'db>,
    /// `core::array::array_pop_front`.
    array_pop_front: ExternFunctionId<'db>,
    /// `core::starknet::storage_access::storage_base_address_from_felt252`.
    storage_base_address_from_felt252: ExternFunctionId<'db>,
    /// Generic `core::starknet::storage_access::storage_base_address_const`.
    storage_base_address_const: GenericFunctionId<'db>,
    /// Lowered `core::panic_with_felt252`.
    panic_with_felt252: FunctionId<'db>,
    /// The `core::panic_with_const_felt252` free function.
    pub panic_with_const_felt252: FreeFunctionId<'db>,
    /// Lowered `core::panics::panic_with_byte_array`.
    panic_with_byte_array: FunctionId<'db>,
    /// Per-integer-type helpers (zero check, inc/dec), keyed by the integer type.
    type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>>,
    /// Shared constant-calculation info (from `db.const_calc_info()`).
    const_calculation_info: Arc<ConstCalcInfo<'db>>,
}
1622impl<'db> ConstFoldingLibfuncInfo<'db> {
1623 fn new(db: &'db dyn Database) -> Self {
1624 let core = ModuleHelper::core(db);
1625 let box_module = core.submodule("box");
1626 let integer_module = core.submodule("integer");
1627 let internal_module = core.submodule("internal");
1628 let bounded_int_module = internal_module.submodule("bounded_int");
1629 let num_module = internal_module.submodule("num");
1630 let array_module = core.submodule("array");
1631 let starknet_module = core.submodule("starknet");
1632 let storage_access_module = starknet_module.submodule("storage_access");
1633 let utypes = ["u8", "u16", "u32", "u64", "u128"];
1634 let itypes = ["i8", "i16", "i32", "i64", "i128"];
1635 let eq_fns = OrderedHashSet::<_>::from_iter(
1636 chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(&format!("{ty}_eq"))),
1637 );
1638 let uadd_fns = OrderedHashSet::<_>::from_iter(
1639 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_add"))),
1640 );
1641 let usub_fns = OrderedHashSet::<_>::from_iter(
1642 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_sub"))),
1643 );
1644 let diff_fns = OrderedHashSet::<_>::from_iter(
1645 itypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_diff"))),
1646 );
1647 let iadd_fns =
1648 OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1649 integer_module.extern_function_id(&format!("{ty}_overflowing_add_impl"))
1650 }));
1651 let isub_fns =
1652 OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1653 integer_module.extern_function_id(&format!("{ty}_overflowing_sub_impl"))
1654 }));
1655 let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
1656 [bounded_int_module.extern_function_id("bounded_int_mul")],
1657 ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
1658 .map(|ty| integer_module.extern_function_id(&format!("{ty}_wide_mul"))),
1659 ));
1660 let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
1661 [bounded_int_module.extern_function_id("bounded_int_div_rem")],
1662 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_safe_divmod"))),
1663 ));
1664 let type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>> = OrderedHashMap::from_iter(
1665 [
1666 ("u8", false, true),
1667 ("u16", false, true),
1668 ("u32", false, true),
1669 ("u64", false, true),
1670 ("u128", false, true),
1671 ("u256", false, false),
1672 ("i8", true, true),
1673 ("i16", true, true),
1674 ("i32", true, true),
1675 ("i64", true, true),
1676 ("i128", true, true),
1677 ]
1678 .map(|(ty_name, as_bounded_int, inc_dec): (&'static str, bool, bool)| {
1679 let ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, ty_name), vec![]);
1680 let is_zero = if as_bounded_int {
1681 bounded_int_module
1682 .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
1683 } else {
1684 integer_module.function_id(
1685 SmolStrId::from(db, format!("{ty_name}_is_zero")).long(db).as_str(),
1686 vec![],
1687 )
1688 }
1689 .lowered(db);
1690 let (inc, dec) = if inc_dec {
1691 (
1692 Some(
1693 num_module
1694 .function_id(
1695 SmolStrId::from(db, format!("{ty_name}_inc")).long(db).as_str(),
1696 vec![],
1697 )
1698 .lowered(db),
1699 ),
1700 Some(
1701 num_module
1702 .function_id(
1703 SmolStrId::from(db, format!("{ty_name}_dec")).long(db).as_str(),
1704 vec![],
1705 )
1706 .lowered(db),
1707 ),
1708 )
1709 } else {
1710 (None, None)
1711 };
1712 let info = TypeInfo { is_zero, inc, dec };
1713 (ty, info)
1714 }),
1715 );
1716 Self {
1717 felt_sub: core.extern_function_id("felt252_sub"),
1718 felt_add: core.extern_function_id("felt252_add"),
1719 felt_mul: core.extern_function_id("felt252_mul"),
1720 felt_div: core.extern_function_id("felt252_div"),
1721 into_box: box_module.extern_function_id("into_box"),
1722 unbox: box_module.extern_function_id("unbox"),
1723 box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
1724 eq_fns,
1725 uadd_fns,
1726 usub_fns,
1727 diff_fns,
1728 iadd_fns,
1729 isub_fns,
1730 wide_mul_fns,
1731 div_rem_fns,
1732 bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
1733 bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
1734 bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
1735 bounded_int_trim_min: bounded_int_module.extern_function_id("bounded_int_trim_min"),
1736 bounded_int_trim_max: bounded_int_module.extern_function_id("bounded_int_trim_max"),
1737 array_get: array_module.extern_function_id("array_get"),
1738 array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
1739 array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
1740 array_len: array_module.extern_function_id("array_len"),
1741 array_new: array_module.extern_function_id("array_new"),
1742 array_append: array_module.extern_function_id("array_append"),
1743 array_pop_front: array_module.extern_function_id("array_pop_front"),
1744 storage_base_address_from_felt252: storage_access_module
1745 .extern_function_id("storage_base_address_from_felt252"),
1746 storage_base_address_const: storage_access_module
1747 .generic_function_id("storage_base_address_const"),
1748 panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
1749 panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
1750 panic_with_byte_array: core
1751 .submodule("panics")
1752 .function_id("panic_with_byte_array", vec![])
1753 .lowered(db),
1754 type_info,
1755 const_calculation_info: db.const_calc_info(),
1756 }
1757 }
1758}
1759
1760impl<'db> std::ops::Deref for ConstFoldingContext<'db, '_> {
1761 type Target = ConstFoldingLibfuncInfo<'db>;
1762 fn deref(&self) -> &ConstFoldingLibfuncInfo<'db> {
1763 self.libfunc_info
1764 }
1765}
1766
1767impl<'a> std::ops::Deref for ConstFoldingLibfuncInfo<'a> {
1768 type Target = ConstCalcInfo<'a>;
1769 fn deref(&self) -> &ConstCalcInfo<'a> {
1770 &self.const_calculation_info
1771 }
1772}
1773
/// Lowered corelib helper functions registered for a single integer type.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct TypeInfo<'db> {
    /// Zero check for the type: a `bounded_int_is_zero` instantiation for signed types, or
    /// `core::integer::{ty}_is_zero` otherwise.
    is_zero: FunctionId<'db>,
    /// `core::internal::num::{ty}_inc`, if registered for this type.
    inc: Option<FunctionId<'db>>,
    /// `core::internal::num::{ty}_dec`, if registered for this type.
    dec: Option<FunctionId<'db>>,
}
1784
1785trait TypeRangeNormalizer {
1786 fn normalized(&self, value: BigInt) -> NormalizedResult;
1789}
1790impl TypeRangeNormalizer for TypeRange {
1791 fn normalized(&self, value: BigInt) -> NormalizedResult {
1792 if value < self.min {
1793 NormalizedResult::Under(value - &self.min + &self.max + 1)
1794 } else if value > self.max {
1795 NormalizedResult::Over(value + &self.min - &self.max - 1)
1796 } else {
1797 NormalizedResult::InRange(value)
1798 }
1799 }
1800}
1801
1802enum NormalizedResult {
1804 InRange(BigInt),
1806 Over(BigInt),
1808 Under(BigInt),
1810}