1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::sync::Arc;
6
7use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
8use cairo_lang_filesystem::flag::flag_unsafe_panic;
9use cairo_lang_filesystem::ids::SmolStrId;
10use cairo_lang_semantic::corelib::CorelibSemantic;
11use cairo_lang_semantic::helper::ModuleHelper;
12use cairo_lang_semantic::items::constant::{
13 ConstCalcInfo, ConstValue, ConstValueId, ConstantSemantic, TypeRange, canonical_felt252,
14 felt252_for_downcast,
15};
16use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
17use cairo_lang_semantic::items::structure::StructSemantic;
18use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
19use cairo_lang_semantic::{
20 ConcreteTypeId, ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId,
21 corelib,
22};
23use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
24use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
25use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
26use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
27use cairo_lang_utils::{Intern, extract_matches, try_extract_matches};
28use itertools::{chain, zip_eq};
29use num_bigint::BigInt;
30use num_integer::Integer;
31use num_traits::cast::ToPrimitive;
32use num_traits::{Num, One, Zero};
33use salsa::Database;
34use starknet_types_core::felt::Felt as Felt252;
35
36use crate::db::LoweringGroup;
37use crate::ids::{
38 ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
39 SpecializedFunction,
40};
41use crate::specialization::SpecializationArg;
42use crate::utils::InliningStrategy;
43use crate::{
44 Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchEnumInfo,
45 MatchExternInfo, MatchInfo, Statement, StatementCall, StatementConst, StatementDesnap,
46 StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
47 StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableArena, VariableId,
48};
49
50fn const_to_specialization_arg<'db>(
53 db: &'db dyn Database,
54 value: ConstValueId<'db>,
55 boxed: bool,
56) -> SpecializationArg<'db> {
57 match value.long(db) {
58 ConstValue::Struct(members, ty) => {
59 if matches!(ty.long(db), TypeLongId::Concrete(ConcreteTypeId::Struct(_))) {
62 let args = members
63 .iter()
64 .map(|member| const_to_specialization_arg(db, *member, false))
65 .collect();
66 SpecializationArg::Struct(args)
67 } else {
68 SpecializationArg::Const { value, boxed }
69 }
70 }
71 ConstValue::Enum(variant, payload) => SpecializationArg::Enum {
72 variant: *variant,
73 payload: Box::new(const_to_specialization_arg(db, *payload, false)),
74 },
75 _ => SpecializationArg::Const { value, boxed },
76 }
77}
78
/// What the const-folding pass knows about the value held by a variable.
#[derive(Debug, Clone)]
enum VarInfo<'db> {
    /// The variable holds the given constant value.
    Const(ConstValueId<'db>),
    /// The variable holds the same value as the given variable usage (e.g. the output of adding
    /// zero to a variable is recorded as that variable).
    Var(VarUsage<'db>),
    /// The variable is a snapshot of a value described by the inner info.
    Snapshot(Box<VarInfo<'db>>),
    /// The variable is a struct; one entry per member, `None` where nothing is known about it.
    Struct(Vec<Option<VarInfo<'db>>>),
    /// The variable is an enum value of the given variant, whose payload is described by
    /// `payload`.
    Enum { variant: ConcreteVariant<'db>, payload: Box<VarInfo<'db>> },
    /// The variable is a box whose content is described by the inner info.
    Box(Box<VarInfo<'db>>),
    /// The variable is an array; one entry per element, `None` where nothing is known about it.
    Array(Vec<Option<VarInfo<'db>>>),
}
100impl<'db> VarInfo<'db> {
101 fn peel_snapshots(&self) -> (usize, &VarInfo<'db>) {
103 let mut n_snapshots = 0;
104 let mut info = self;
105 while let VarInfo::Snapshot(inner) = info {
106 info = inner.as_ref();
107 n_snapshots += 1;
108 }
109 (n_snapshots, info)
110 }
111 fn wrap_with_snapshots(mut self, n_snapshots: usize) -> VarInfo<'db> {
113 for _ in 0..n_snapshots {
114 self = VarInfo::Snapshot(Box::new(self));
115 }
116 self
117 }
118}
119
/// How a not-yet-visited block has been reached by the block ends visited so far.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Reachability {
    /// The block has so far been targeted only by a single `Goto`, ending the given block.
    FromSingleGoto(BlockId),
    /// The block is the root block, the target of a match arm, or the target of more than one
    /// `Goto`.
    Any,
}
129
130pub fn const_folding<'db>(
133 db: &'db dyn Database,
134 function_id: ConcreteFunctionWithBodyId<'db>,
135 lowered: &mut Lowered<'db>,
136) {
137 if lowered.blocks.is_empty() {
138 return;
139 }
140
141 let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
144
145 if ctx.should_skip_const_folding(db) {
146 return;
147 }
148
149 for block_id in (0..lowered.blocks.len()).map(BlockId) {
150 if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
151 continue;
152 }
153
154 let block = &mut lowered.blocks[block_id];
155 for stmt in block.statements.iter_mut() {
156 ctx.visit_statement(stmt);
157 }
158 ctx.visit_block_end(block_id, block);
159 }
160}
161
/// State carried across the blocks and statements of a single function while const folding it.
pub struct ConstFoldingContext<'db, 'mt> {
    /// The database used for interning values and for semantic/lowering queries.
    db: &'db dyn Database,
    /// The variable arena of the function being folded; new variables are allocated in it.
    pub variables: &'mt mut VariableArena<'db>,
    /// What is currently known about each variable.
    var_info: UnorderedHashMap<VariableId, VarInfo<'db>>,
    /// Libfunc lookup data, obtained from `priv_const_folding_info`.
    libfunc_info: &'db ConstFoldingLibfuncInfo<'db>,
    /// The function whose body is being folded (the caller of any call statement visited).
    caller_function: ConcreteFunctionWithBodyId<'db>,
    /// Reachability of blocks that were targeted by an already-visited block end but were not
    /// visited yet. The entry of a block is removed when the block is visited; a block without
    /// an entry is skipped.
    reachability: UnorderedHashMap<BlockId, Reachability>,
    /// Statements generated while visiting the statements of the current block;
    /// `visit_block_end` moves them to the start of that block.
    additional_stmts: Vec<Statement<'db>>,
}
181
182impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
183 pub fn new(
184 db: &'db dyn Database,
185 function_id: ConcreteFunctionWithBodyId<'db>,
186 variables: &'mt mut VariableArena<'db>,
187 ) -> Self {
188 Self {
189 db,
190 var_info: UnorderedHashMap::default(),
191 variables,
192 libfunc_info: priv_const_folding_info(db),
193 caller_function: function_id,
194 reachability: UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]),
195 additional_stmts: vec![],
196 }
197 }
198
199 pub fn visit_block_start<'r, 'get>(
202 &'r mut self,
203 block_id: BlockId,
204 get_block: impl FnOnce(BlockId) -> &'get Block<'db>,
205 ) -> bool
206 where
207 'db: 'get,
208 {
209 let Some(reachability) = self.reachability.remove(&block_id) else {
210 return false;
211 };
212 match reachability {
213 Reachability::Any => {}
214 Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
215 BlockEnd::Goto(_, remapping) => {
216 for (dst, src) in remapping.iter() {
217 if let Some(v) = self.as_const(src.var_id) {
218 self.var_info.insert(*dst, VarInfo::Const(v));
219 }
220 }
221 }
222 _ => unreachable!("Expected a goto end"),
223 },
224 }
225 true
226 }
227
228 pub fn visit_statement(&mut self, stmt: &mut Statement<'db>) {
238 self.maybe_replace_inputs(stmt.inputs_mut());
239 match stmt {
240 Statement::Const(StatementConst { value, output, boxed }) if *boxed => {
241 self.var_info.insert(*output, VarInfo::Box(VarInfo::Const(*value).into()));
242 }
243 Statement::Const(StatementConst { value, output, .. }) => match value.long(self.db) {
244 ConstValue::Int(..)
245 | ConstValue::Struct(..)
246 | ConstValue::Enum(..)
247 | ConstValue::NonZero(..) => {
248 self.var_info.insert(*output, VarInfo::Const(*value));
249 }
250 ConstValue::Generic(_)
251 | ConstValue::ImplConstant(_)
252 | ConstValue::Var(..)
253 | ConstValue::Missing(_) => {}
254 },
255 Statement::Snapshot(stmt) => {
256 if let Some(info) = self.var_info.get(&stmt.input.var_id).cloned() {
257 self.var_info.insert(stmt.original(), info.clone());
258 self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info.into()));
259 }
260 }
261 Statement::Desnap(StatementDesnap { input, output }) => {
262 if let Some(VarInfo::Snapshot(info)) = self.var_info.get(&input.var_id) {
263 self.var_info.insert(*output, info.as_ref().clone());
264 }
265 }
266 Statement::Call(call_stmt) => {
267 if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
268 *stmt = updated_stmt;
269 } else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
270 *stmt = updated_stmt;
271 }
272 }
273 Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
274 let mut const_args = vec![];
275 let mut all_args = vec![];
276 let mut contains_info = false;
277 for input in inputs.iter() {
278 let Some(info) = self.var_info.get(&input.var_id) else {
279 all_args.push(var_info_if_copy(self.variables, *input));
280 continue;
281 };
282 contains_info = true;
283 if let VarInfo::Const(value) = info {
284 const_args.push(*value);
285 }
286 all_args.push(Some(info.clone()));
287 }
288 if const_args.len() == inputs.len() {
289 let value =
290 ConstValue::Struct(const_args, self.variables[*output].ty).intern(self.db);
291 self.var_info.insert(*output, VarInfo::Const(value));
292 } else if contains_info {
293 self.var_info.insert(*output, VarInfo::Struct(all_args));
294 }
295 }
296 Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
297 if let Some(info) = self.var_info.get(&input.var_id) {
298 let (n_snapshots, info) = info.peel_snapshots();
299 match info {
300 VarInfo::Const(const_value) => {
301 if let ConstValue::Struct(member_values, _) = const_value.long(self.db)
302 {
303 for (output, value) in zip_eq(outputs, member_values) {
304 self.var_info.insert(
305 *output,
306 VarInfo::Const(*value).wrap_with_snapshots(n_snapshots),
307 );
308 }
309 }
310 }
311 VarInfo::Struct(members) => {
312 for (output, member) in zip_eq(outputs, members.clone()) {
313 if let Some(member) = member {
314 self.var_info
315 .insert(*output, member.wrap_with_snapshots(n_snapshots));
316 }
317 }
318 }
319 _ => {}
320 }
321 }
322 }
323 Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
324 let value = if let Some(info) = self.var_info.get(&input.var_id) {
325 if let VarInfo::Const(val) = info {
326 VarInfo::Const(ConstValue::Enum(*variant, *val).intern(self.db))
327 } else {
328 VarInfo::Enum { variant: *variant, payload: info.clone().into() }
329 }
330 } else {
331 VarInfo::Enum { variant: *variant, payload: VarInfo::Var(*input).into() }
332 };
333 self.var_info.insert(*output, value);
334 }
335 }
336 }
337
338 pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut Block<'db>) {
345 let statements = &mut block.statements;
346 statements.splice(0..0, self.additional_stmts.drain(..));
347
348 match &mut block.end {
349 BlockEnd::Goto(_, remappings) => {
350 for (_, v) in remappings.iter_mut() {
351 self.maybe_replace_input(v);
352 }
353 }
354 BlockEnd::Match { info } => {
355 self.maybe_replace_inputs(info.inputs_mut());
356 match info {
357 MatchInfo::Enum(info) => {
358 if let Some(updated_end) = self.handle_enum_block_end(info, statements) {
359 block.end = updated_end;
360 }
361 }
362 MatchInfo::Extern(info) => {
363 if let Some(updated_end) = self.handle_extern_block_end(info, statements) {
364 block.end = updated_end;
365 }
366 }
367 MatchInfo::Value(info) => {
368 if let Some(value) =
369 self.as_int(info.input.var_id).and_then(|x| x.to_usize())
370 && let Some(arm) = info.arms.iter().find(|arm| {
371 matches!(
372 &arm.arm_selector,
373 MatchArmSelector::Value(v) if v.value == value
374 )
375 })
376 {
377 statements.push(Statement::StructConstruct(StatementStructConstruct {
379 inputs: vec![],
380 output: arm.var_ids[0],
381 }));
382 block.end = BlockEnd::Goto(arm.block_id, Default::default());
383 }
384 }
385 }
386 }
387 BlockEnd::Return(inputs, _) => self.maybe_replace_inputs(inputs),
388 BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
389 }
390 match &block.end {
391 BlockEnd::Goto(dst_block_id, _) => {
392 match self.reachability.entry(*dst_block_id) {
393 std::collections::hash_map::Entry::Occupied(mut e) => {
394 e.insert(Reachability::Any)
395 }
396 std::collections::hash_map::Entry::Vacant(e) => {
397 *e.insert(Reachability::FromSingleGoto(block_id))
398 }
399 };
400 }
401 BlockEnd::Match { info } => {
402 for arm in info.arms() {
403 assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
404 }
405 }
406 BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
407 }
408 }
409
/// Handles a call statement whose callee is one of the specially treated functions.
///
/// Returns `Some(statement)` when the call should be replaced by `statement`. Returns `None`
/// when the call statement should be kept; in that case `stmt` may still have been modified in
/// place (e.g. its function and inputs replaced), and info about its outputs may have been
/// recorded. Statements pushed to `additional_stmts` are added to the block by
/// `visit_block_end`.
fn handle_statement_call(&mut self, stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
    let db = self.db;
    if stmt.function == self.panic_with_felt252 {
        // A panic with a constant felt252 is redirected to `panic_with_const_felt252`, which
        // takes the value as a generic argument instead of as an input.
        let val = self.as_const(stmt.inputs[0].var_id)?;
        stmt.inputs.clear();
        stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
            .concretize(db, vec![GenericArgumentId::Constant(val)])
            .lowered(db);
        return None;
    } else if stmt.function == self.panic_with_byte_array && !flag_unsafe_panic(db) {
        // The input must be a snapshot of a struct with exactly three members: an array of
        // known constants and two constants.
        let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
        let bytearray = try_extract_matches!(snap, VarInfo::Snapshot)?;
        let [
            Some(VarInfo::Array(data)),
            Some(VarInfo::Const(pending_word)),
            Some(VarInfo::Const(pending_len)),
        ] = &try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
        else {
            return None;
        };
        // The panic data is: the byte array magic, the number of words in `data`, the words
        // themselves, the pending word and the pending length.
        let mut panic_data =
            vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
        for word in data {
            let Some(VarInfo::Const(word)) = word else {
                return None;
            };
            panic_data.push(word.long(db).to_int()?.clone());
        }
        panic_data.extend([
            pending_word.long(db).to_int()?.clone(),
            pending_len.long(db).to_int()?.clone(),
        ]);
        let felt252_ty = self.felt252;
        let location = stmt.location;
        let new_var = |ty| Variable::with_default_context(db, ty, location);
        let as_usage = |var_id| VarUsage { var_id, location };
        // Concretizes an array extern function for `felt252` elements.
        let array_fn = |extern_id| {
            let args = vec![GenericArgumentId::Type(felt252_ty)];
            GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
        };
        let call_stmt = |function, inputs, outputs| {
            let with_coupon = false;
            Statement::Call(StatementCall {
                function,
                inputs,
                with_coupon,
                outputs,
                location,
                is_specialization_base_call: false,
            })
        };
        // Build the panic data array: a new array, followed by one constant and one append per
        // word. `arr` always refers to the latest version of the array.
        let arr_var = new_var(corelib::core_array_felt252_ty(db));
        let mut arr = self.variables.alloc(arr_var.clone());
        self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
        let felt252_var = new_var(felt252_ty);
        let arr_append_fn = array_fn(self.array_append);
        for word in panic_data {
            let to_append = self.variables.alloc(felt252_var.clone());
            let new_arr = self.variables.alloc(arr_var.clone());
            self.additional_stmts.push(Statement::Const(StatementConst::new_flat(
                ConstValue::Int(word, felt252_ty).intern(db),
                to_append,
            )));
            self.additional_stmts.push(call_stmt(
                arr_append_fn,
                vec![as_usage(arr), as_usage(to_append)],
                vec![new_arr],
            ));
            arr = new_arr;
        }
        // The call is replaced by the construction of a struct from a `Panic` value (built
        // with no inputs) and the data array.
        let panic_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "Panic"), vec![]);
        let panic_var = self.variables.alloc(new_var(panic_ty));
        self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
            inputs: vec![],
            output: panic_var,
        }));
        return Some(Statement::StructConstruct(StatementStructConstruct {
            inputs: vec![as_usage(panic_var), as_usage(arr)],
            output: stmt.outputs[0],
        }));
    }
    // All remaining cases are calls to extern functions.
    let (id, _generic_args) = stmt.function.get_extern(db)?;
    if id == self.felt_sub {
        if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            && rhs.is_zero()
        {
            // `x - 0` is `x`.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
            None
        } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
        {
            let value = canonical_felt252(&(lhs - rhs));
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else {
            None
        }
    } else if id == self.felt_add {
        if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && lhs.is_zero()
        {
            // `0 + x` is `x`.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
            None
        } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            && rhs.is_zero()
        {
            // `x + 0` is `x`.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
            None
        } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
        {
            let value = canonical_felt252(&(lhs + rhs));
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else {
            None
        }
    } else if id == self.felt_mul {
        let lhs = self.as_int(stmt.inputs[0].var_id);
        let rhs = self.as_int(stmt.inputs[1].var_id);
        if lhs.map(Zero::is_zero).unwrap_or_default()
            || rhs.map(Zero::is_zero).unwrap_or_default()
        {
            // Multiplying by a known zero gives zero, even if the other side is unknown.
            Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
        } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            && rhs.is_one()
        {
            // `x * 1` is `x`.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
            None
        } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && lhs.is_one()
        {
            // `1 * x` is `x`.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
            None
        } else if let Some(lhs) = lhs
            && let Some(rhs) = rhs
        {
            let value = canonical_felt252(&(lhs * rhs));
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else {
            None
        }
    } else if id == self.felt_div {
        if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            && rhs.is_one()
        {
            // `x / 1` is `x`.
            self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
            None
        } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && lhs.is_zero()
        {
            // A zero numerator is folded to zero without looking at the divisor.
            Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
        } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
            && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            && let Ok(rhs_nonzero) = Felt252::from(rhs).try_into()
        {
            // Both operands are known: compute the field division on `Felt252` values.
            let lhs_felt = Felt252::from(lhs);
            let value = lhs_felt.field_div(&rhs_nonzero).to_bigint();
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else {
            None
        }
    } else if self.wide_mul_fns.contains(&id) {
        let lhs = self.as_int(stmt.inputs[0].var_id);
        let rhs = self.as_int(stmt.inputs[1].var_id);
        let output = stmt.outputs[0];
        if lhs.map(Zero::is_zero).unwrap_or_default()
            || rhs.map(Zero::is_zero).unwrap_or_default()
        {
            return Some(self.propagate_zero_and_get_statement(output));
        }
        // Otherwise both operands must be known.
        let lhs = lhs?;
        Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0]))
    } else if id == self.bounded_int_add || id == self.bounded_int_sub {
        let lhs = self.as_int(stmt.inputs[0].var_id)?;
        let rhs = self.as_int(stmt.inputs[1].var_id)?;
        let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
        Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
    } else if self.div_rem_fns.contains(&id) {
        let lhs = self.as_int(stmt.inputs[0].var_id);
        if lhs.map(Zero::is_zero).unwrap_or_default() {
            // A zero dividend gives a zero quotient and a zero remainder; the constant for the
            // second output is emitted as an additional statement.
            let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
            self.additional_stmts.push(additional_stmt);
            return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
        }
        let rhs = self.as_int(stmt.inputs[1].var_id)?;
        let (q, r) = lhs?.div_rem(rhs);
        let q_output = stmt.outputs[0];
        let q_value = ConstValue::Int(q, self.variables[q_output].ty).intern(db);
        self.var_info.insert(q_output, VarInfo::Const(q_value));
        let r_output = stmt.outputs[1];
        let r_value = ConstValue::Int(r, self.variables[r_output].ty).intern(db);
        self.var_info.insert(r_output, VarInfo::Const(r_value));
        // The call is replaced by the quotient constant; the remainder constant is emitted as
        // an additional statement.
        self.additional_stmts
            .push(Statement::Const(StatementConst::new_flat(r_value, r_output)));
        Some(Statement::Const(StatementConst::new_flat(q_value, q_output)))
    } else if id == self.storage_base_address_from_felt252 {
        let input_var = stmt.inputs[0].var_id;
        if let Some(const_value) = self.as_const(input_var)
            && let ConstValue::Int(val, ty) = const_value.long(db)
        {
            // A known address is passed as a generic argument of `storage_base_address_const`
            // instead of as an input.
            stmt.inputs.clear();
            let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
            stmt.function =
                self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
        }
        None
    } else if id == self.into_box {
        let input = stmt.inputs[0];
        let var_info = self.var_info.get(&input.var_id);
        // The boxed content is a constant if the input is a constant or a snapshot of one.
        let const_value = match var_info {
            Some(VarInfo::Const(val)) => Some(*val),
            Some(VarInfo::Snapshot(info)) => {
                try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
            }
            _ => None,
        };
        let var_info = var_info.cloned().or_else(|| var_info_if_copy(self.variables, input))?;
        self.var_info.insert(stmt.outputs[0], VarInfo::Box(var_info.into()));
        // The info about the box is recorded above even when the call itself is kept.
        Some(Statement::Const(StatementConst::new_boxed(const_value?, stmt.outputs[0])))
    } else if id == self.unbox {
        if let VarInfo::Box(inner) = self.var_info.get(&stmt.inputs[0].var_id)? {
            let inner = inner.as_ref().clone();
            // The output gets the info of the box content; the call is replaced only when
            // that content is a constant.
            if let VarInfo::Const(inner) =
                self.var_info.entry(stmt.outputs[0]).insert_entry(inner).get()
            {
                return Some(Statement::Const(StatementConst::new_flat(
                    *inner,
                    stmt.outputs[0],
                )));
            }
        }
        None
    } else if self.upcast_fns.contains(&id) {
        // The same integer value, with the type of the output variable.
        let int_value = self.as_int(stmt.inputs[0].var_id)?;
        let output = stmt.outputs[0];
        let value = ConstValue::Int(int_value.clone(), self.variables[output].ty).intern(db);
        self.var_info.insert(output, VarInfo::Const(value));
        Some(Statement::Const(StatementConst::new_flat(value, output)))
    } else if id == self.array_new {
        self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]));
        None
    } else if id == self.array_append {
        let mut var_infos =
            if let VarInfo::Array(var_infos) = self.var_info.get(&stmt.inputs[0].var_id)? {
                var_infos.clone()
            } else {
                return None;
            };
        let appended = stmt.inputs[1];
        var_infos.push(match self.var_info.get(&appended.var_id) {
            Some(var_info) => Some(var_info.clone()),
            None => var_info_if_copy(self.variables, appended),
        });
        self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos));
        None
    } else if id == self.array_len {
        // The input is expected to be a snapshot of an array with tracked elements.
        let info = self.var_info.get(&stmt.inputs[0].var_id)?;
        let desnapped = try_extract_matches!(info, VarInfo::Snapshot)?;
        let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
        Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0]))
    } else {
        None
    }
}
686
687 fn try_specialize_call(&self, call_stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
694 if call_stmt.with_coupon {
695 return None;
696 }
697 if matches!(self.db.optimizations().inlining_strategy(), InliningStrategy::Avoid) {
699 return None;
700 }
701
702 let Ok(Some(mut called_function)) = call_stmt.function.body(self.db) else {
703 return None;
704 };
705
706 let extract_base = |function: ConcreteFunctionWithBodyId<'db>| match function.long(self.db)
707 {
708 ConcreteFunctionWithBodyLongId::Specialized(specialized) => specialized.base,
709 _ => function,
710 };
711 let called_base = extract_base(called_function);
712 let caller_base = extract_base(self.caller_function);
713
714 if self.db.priv_never_inline(called_base).ok()? {
715 return None;
716 }
717
718 if call_stmt.is_specialization_base_call {
720 return None;
721 }
722
723 if called_base == caller_base && called_function != called_base {
725 return None;
726 }
727
728 let scc =
731 self.db.lowered_scc(called_base, DependencyType::Call, LoweringStage::Monomorphized);
732 if scc.len() > 1 && scc.contains(&caller_base) {
733 return None;
734 }
735
736 if call_stmt.inputs.iter().all(|arg| self.var_info.get(&arg.var_id).is_none()) {
737 return None;
739 }
740
741 let self_specializition = if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
743 self.caller_function.long(self.db)
744 && caller_base == called_base
745 {
746 specialized.args.iter().map(Some).collect()
747 } else {
748 vec![None; call_stmt.inputs.len()]
749 };
750
751 let mut specialization_args = vec![];
752 let mut new_args = vec![];
753 for (arg, coerce) in zip_eq(&call_stmt.inputs, &self_specializition) {
754 if let Some(var_info) = self.var_info.get(&arg.var_id)
755 && self.variables[arg.var_id].info.droppable.is_ok()
756 && let Some(specialization_arg) = self.try_get_specialization_arg(
757 var_info.clone(),
758 self.variables[arg.var_id].ty,
759 &mut new_args,
760 *coerce,
761 )
762 {
763 specialization_args.push(specialization_arg);
764 } else {
765 specialization_args.push(SpecializationArg::NotSpecialized);
766 new_args.push(*arg);
767 continue;
768 };
769 }
770
771 if specialization_args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized)) {
772 return None;
774 }
775 if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
776 called_function.long(self.db)
777 {
778 called_function = specialized_function.base;
781 let mut new_args_iter = specialization_args.into_iter();
782 let mut old_args = specialized_function.args.to_vec();
783 let mut stack = vec![];
784 for arg in old_args.iter_mut().rev() {
785 stack.push(arg);
786 }
787 while let Some(arg) = stack.pop() {
788 match arg {
789 SpecializationArg::Const { .. } => {}
790 SpecializationArg::Snapshot(inner) => {
791 stack.push(inner.as_mut());
792 }
793 SpecializationArg::Enum { payload, .. } => {
794 stack.push(payload.as_mut());
795 }
796 SpecializationArg::Array(_, values) => {
797 for value in values.iter_mut().rev() {
798 stack.push(value);
799 }
800 }
801 SpecializationArg::Struct(specialization_args) => {
802 for arg in specialization_args.iter_mut().rev() {
803 stack.push(arg);
804 }
805 }
806 SpecializationArg::NotSpecialized => {
807 *arg = new_args_iter.next().unwrap_or(SpecializationArg::NotSpecialized);
808 }
809 }
810 }
811 specialization_args = old_args;
812 }
813 let specialized =
814 SpecializedFunction { base: called_function, args: specialization_args.into() };
815 let specialized_func_id =
816 ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);
817
818 if caller_base != called_base
819 && self.db.priv_should_specialize(specialized_func_id) == Ok(false)
820 {
821 return None;
822 }
823
824 Some(Statement::Call(StatementCall {
825 function: specialized_func_id.function_id(self.db).unwrap(),
826 inputs: new_args,
827 with_coupon: call_stmt.with_coupon,
828 outputs: std::mem::take(&mut call_stmt.outputs),
829 location: call_stmt.location,
830 is_specialization_base_call: false,
831 }))
832 }
833
834 fn propagate_const_and_get_statement(
836 &mut self,
837 value: BigInt,
838 output: VariableId,
839 ) -> Statement<'db> {
840 let ty = self.variables[output].ty;
841 let value = ConstValueId::from_int(self.db, ty, &value);
842 self.var_info.insert(output, VarInfo::Const(value));
843 Statement::Const(StatementConst::new_flat(value, output))
844 }
845
846 fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement<'db> {
848 self.propagate_const_and_get_statement(BigInt::zero(), output)
849 }
850
851 fn try_generate_const_statement(
853 &self,
854 value: ConstValueId<'db>,
855 output: VariableId,
856 ) -> Option<Statement<'db>> {
857 if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
858 Some(Statement::Const(StatementConst::new_flat(value, output)))
859 } else if matches!(value.long(self.db), ConstValue::Struct(members, _) if members.is_empty())
860 {
861 Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
863 } else {
864 None
865 }
866 }
867
/// Handles a block ending with a match on an enum whose variant is known.
///
/// Whenever the variant is known, info about the variable introduced by the matching arm is
/// recorded. If in addition that variable can be defined by statements appended to
/// `statements`, the replacement block end - a `Goto` to the block of the matching arm - is
/// returned. Returns `None` when the match should be kept.
fn handle_enum_block_end(
    &mut self,
    info: &mut MatchEnumInfo<'db>,
    statements: &mut Vec<Statement<'db>>,
) -> Option<BlockEnd<'db>> {
    let input = info.input.var_id;
    // The matched value may be a (possibly nested) snapshot of the enum.
    let (n_snapshots, var_info) = self.var_info.get(&input)?.peel_snapshots();
    let location = info.location;
    let as_usage = |var_id| VarUsage { var_id, location };
    let db = self.db;
    // Creates a statement taking a snapshot of `pre_snap` into `post_snap`. The `original`
    // output of the statement is a newly allocated variable that is never used.
    let snapshot_stmt = |vars: &mut VariableArena<'_>, pre_snap, post_snap| {
        let ignored = vars.alloc(vars[pre_snap].clone());
        Statement::Snapshot(StatementSnapshot::new(as_usage(pre_snap), ignored, post_snap))
    };
    if let VarInfo::Const(const_value) = var_info
        && let ConstValue::Enum(variant, value) = const_value.long(db)
    {
        // The whole enum is a constant: the arm variable is its payload constant, under the
        // same number of snapshots as the matched value.
        let arm = &info.arms[variant.idx];
        let output = arm.var_ids[0];
        self.var_info.insert(output, VarInfo::Const(*value).wrap_with_snapshots(n_snapshots));
        // The match is replaced only if the matched variable can be dropped, the arm variable
        // is copyable, and a statement producing the payload constant can be generated.
        if self.variables[input].info.droppable.is_ok()
            && self.variables[output].info.copyable.is_ok()
            && let Ok(mut ty) = value.ty(db)
            && let Some(mut stmt) = self.try_generate_const_statement(*value, output)
        {
            // For every snapshot level, the pending statement is redirected to a new
            // intermediate variable and pushed, and a statement snapshotting that variable
            // into `output` becomes the pending statement.
            // NOTE(review): from the second iteration on, `stmt` is the snapshot statement
            // created by the previous iteration, so `outputs_mut()[0]` is whatever that
            // statement reports as its first output - confirm this is the intended slot when
            // `n_snapshots > 1`.
            for _ in 0..n_snapshots {
                let non_snap_var = Variable::with_default_context(db, ty, location);
                ty = TypeLongId::Snapshot(ty).intern(db);
                let pre_snap = self.variables.alloc(non_snap_var);
                stmt.outputs_mut()[0] = pre_snap;
                let take_snap = snapshot_stmt(self.variables, pre_snap, output);
                statements.push(core::mem::replace(&mut stmt, take_snap));
            }
            statements.push(stmt);
            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
        }
    } else if let VarInfo::Enum { variant, payload } = var_info {
        // Only the variant is known, with some info about its payload.
        let arm = &info.arms[variant.idx];
        let variant_ty = variant.ty;
        let output = arm.var_ids[0];
        let payload = payload.as_ref().clone();
        // Find a variable holding the payload without its snapshot layers: either a copyable
        // variable the payload is known to equal, or a new variable assigned the payload
        // constant. This is attempted only if the matched variable can be dropped.
        let unwrapped =
            self.variables[input].info.droppable.is_ok().then_some(()).and_then(|_| {
                let (extra_snapshots, inner) = payload.peel_snapshots();
                match inner {
                    VarInfo::Var(var) if self.variables[var.var_id].info.copyable.is_ok() => {
                        Some((var.var_id, extra_snapshots))
                    }
                    VarInfo::Const(value) => {
                        // NOTE(review): `const_var` is given the type of the variant even when
                        // `extra_snapshots > 0` - confirm that this is the type of `value` in
                        // that case.
                        let const_var = self
                            .variables
                            .alloc(Variable::with_default_context(db, variant_ty, location));
                        statements.push(self.try_generate_const_statement(*value, const_var)?);
                        Some((const_var, extra_snapshots))
                    }
                    _ => None,
                }
            });
        // The info about the arm variable is recorded even if the match is kept.
        self.var_info.insert(output, payload.wrap_with_snapshots(n_snapshots));
        if let Some((mut unwrapped, extra_snapshots)) = unwrapped {
            let total_snapshots = n_snapshots + extra_snapshots;
            if total_snapshots != 0 {
                // All snapshots but the last are taken into new intermediate variables; the
                // last one is taken into the arm variable.
                for _ in 1..total_snapshots {
                    let ty = TypeLongId::Snapshot(self.variables[unwrapped].ty).intern(db);
                    let non_snap_var = Variable::with_default_context(self.db, ty, location);
                    let snapped = self.variables.alloc(non_snap_var);
                    statements.push(snapshot_stmt(self.variables, unwrapped, snapped));
                    unwrapped = snapped;
                }
                statements.push(snapshot_stmt(self.variables, unwrapped, output));
            };
            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
        }
    }
    None
}
953
/// Tries to resolve or simplify a block that ends with a match on an extern function, using
/// the information currently known about the match inputs.
///
/// * `info` - the extern match at the end of the block. It may be rewritten in place (see the
///   `array_get` handling) and its arms may be moved out when a replacement match is built.
/// * `statements` - the statements of the block; constants or calls required by the chosen
///   replacement are appended to it.
///
/// Returns `Some(block_end)` with a replacement block end when one could be computed, and
/// `None` otherwise (presumably the caller then keeps `info` as the block end).
fn handle_extern_block_end(
    &mut self,
    info: &mut MatchExternInfo<'db>,
    statements: &mut Vec<Statement<'db>>,
) -> Option<BlockEnd<'db>> {
    let db = self.db;
    let (id, generic_args) = info.function.get_extern(db)?;
    if self.nz_fns.contains(&id) {
        // Zero check on a known constant: the taken arm is decided here.
        let val = self.as_const(info.inputs[0].var_id)?;
        let is_zero = match val.long(db) {
            ConstValue::Int(v, _) => v.is_zero(),
            // A struct constant counts as zero only if all of its members are zero.
            ConstValue::Struct(s, _) => s.iter().all(|v| {
                v.long(db).to_int().expect("Expected ConstValue::Int for size").is_zero()
            }),
            _ => unreachable!(),
        };
        Some(if is_zero {
            BlockEnd::Goto(info.arms[0].block_id, Default::default())
        } else {
            // Non-zero: bind the arm's output variable to the `NonZero` wrapped constant.
            let arm = &info.arms[1];
            let nz_var = arm.var_ids[0];
            let nz_val = ConstValue::NonZero(val).intern(db);
            self.var_info.insert(nz_var, VarInfo::Const(nz_val));
            statements.push(Statement::Const(StatementConst::new_flat(nz_val, nz_var)));
            BlockEnd::Goto(arm.block_id, Default::default())
        })
    } else if self.eq_fns.contains(&id) {
        let lhs = self.as_int(info.inputs[0].var_id);
        let rhs = self.as_int(info.inputs[1].var_id);
        // Exactly one side is a known zero and the other side is unknown: replace the
        // equality check with an `is_zero` match on the unknown side.
        if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
            || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
        {
            // The input that is not the known zero.
            let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
            // Cloned so that `self.variables` can be allocated into below.
            let var = &self.variables[nz_input.var_id].clone();
            let function = self.type_info.get(&var.ty)?.is_zero;
            // The non-zero arm needs an output variable of the `NonZero` type; it is not
            // referenced anywhere else in this function.
            let unused_nz_var = Variable::with_default_context(
                db,
                corelib::core_nonzero_ty(db, var.ty),
                var.location,
            );
            let unused_nz_var = self.variables.alloc(unused_nz_var);
            return Some(BlockEnd::Match {
                info: MatchInfo::Extern(MatchExternInfo {
                    function,
                    inputs: vec![nz_input],
                    arms: vec![
                        // Input is zero, so both sides are equal: continue to arm 1 of the
                        // original match.
                        MatchArm {
                            arm_selector: MatchArmSelector::VariantId(
                                corelib::jump_nz_zero_variant(db, var.ty),
                            ),
                            block_id: info.arms[1].block_id,
                            var_ids: vec![],
                        },
                        // Input is non-zero, so the sides differ: continue to arm 0 of the
                        // original match.
                        MatchArm {
                            arm_selector: MatchArmSelector::VariantId(
                                corelib::jump_nz_nonzero_variant(db, var.ty),
                            ),
                            block_id: info.arms[0].block_id,
                            var_ids: vec![unused_nz_var],
                        },
                    ],
                    location: info.location,
                }),
            });
        }
        // Otherwise both sides must be known: arm 1 when equal, arm 0 when different.
        Some(BlockEnd::Goto(
            info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
            Default::default(),
        ))
    } else if self.uadd_fns.contains(&id)
        || self.usub_fns.contains(&id)
        || self.diff_fns.contains(&id)
        || self.iadd_fns.contains(&id)
        || self.isub_fns.contains(&id)
    {
        let rhs = self.as_int(info.inputs[1].var_id);
        let lhs = self.as_int(info.inputs[0].var_id);
        if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
            // Both operands are known: compute the result and pick the arm by its range.
            let ty = self.variables[info.arms[0].var_ids[0]].ty;
            let range = self.type_value_ranges.get(&ty)?;
            // Additions add; all other functions of this branch subtract.
            let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
                lhs + rhs
            } else {
                lhs - rhs
            };
            let (arm_index, value) = match range.normalized(value) {
                NormalizedResult::InRange(value) => (0, value),
                NormalizedResult::Under(value) => (1, value),
                // Only the `iadd_fns`/`isub_fns` functions use a separate third arm for
                // results above the range; the others share arm 1.
                NormalizedResult::Over(value) => (
                    if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
                        2
                    } else {
                        1
                    },
                    value,
                ),
            };
            let arm = &info.arms[arm_index];
            let actual_output = arm.var_ids[0];
            let value = ConstValue::Int(value, ty).intern(db);
            self.var_info.insert(actual_output, VarInfo::Const(value));
            statements.push(Statement::Const(StatementConst::new_flat(value, actual_output)));
            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
        }
        if let Some(rhs) = rhs {
            // `x + 0` / `x - 0`: the result is the left input itself (not applied to
            // `diff_fns`).
            if rhs.is_zero() && !self.diff_fns.contains(&id) {
                let arm = &info.arms[0];
                self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]));
                return Some(BlockEnd::Goto(arm.block_id, Default::default()));
            }
            // `x + 1` / `x - 1`: call the type's `inc`/`dec` function and match on the enum
            // it returns, reusing the original arms (not applied to `diff_fns`).
            if rhs.is_one() && !self.diff_fns.contains(&id) {
                let ty = self.variables[info.arms[0].var_ids[0]].ty;
                let ty_info = self.type_info.get(&ty)?;
                let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
                    ty_info.inc?
                } else {
                    ty_info.dec?
                };
                let enum_ty = function.signature(db).ok()?.return_type;
                let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
                    enum_ty.long(db)
                else {
                    return None;
                };
                // The signature was already fetched successfully above, so this `unwrap`
                // repeats a lookup that just succeeded.
                let result = self.variables.alloc(Variable::with_default_context(
                    db,
                    function.signature(db).unwrap().return_type,
                    info.location,
                ));
                statements.push(Statement::Call(StatementCall {
                    function,
                    inputs: vec![info.inputs[0]],
                    with_coupon: false,
                    outputs: vec![result],
                    location: info.location,
                    is_specialization_base_call: false,
                }));
                return Some(BlockEnd::Match {
                    info: MatchInfo::Enum(MatchEnumInfo {
                        concrete_enum_id: *concrete_enum_id,
                        input: VarUsage { var_id: result, location: info.location },
                        // The arms are moved out of the original match, leaving it empty.
                        arms: core::mem::take(&mut info.arms),
                        location: info.location,
                    }),
                });
            }
        }
        // `0 + x`: the result is the right input itself (additions only).
        if let Some(lhs) = lhs
            && lhs.is_zero()
            && (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id))
        {
            let arm = &info.arms[0];
            self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]));
            return Some(BlockEnd::Goto(arm.block_id, Default::default()));
        }
        None
    } else if let Some(reversed) = self.downcast_fns.get(&id) {
        // Value range of a type: taken from the known ranges table, falling back to
        // extracting it as a bounded int type.
        let range = |ty: TypeId<'_>| {
            Some(if let Some(range) = self.type_value_ranges.get(&ty) {
                range.clone()
            } else {
                let (min, max) = corelib::try_extract_bounded_int_type_ranges(db, ty)?;
                TypeRange { min, max }
            })
        };
        // `reversed` swaps which arm is the success arm.
        let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
        let input_var = info.inputs[0].var_id;
        let in_ty = self.variables[input_var].ty;
        let success_output = info.arms[success_arm].var_ids[0];
        let out_ty = self.variables[success_output].ty;
        let out_range = range(out_ty)?;
        let Some(value) = self.as_int(input_var) else {
            // Unknown input value: if the input range does not fit in the output range
            // nothing is done; otherwise the downcast is replaced by a call to the upcast
            // function followed by a jump to the success arm.
            let in_range = range(in_ty)?;
            return if in_range.min < out_range.min || in_range.max > out_range.max {
                None
            } else {
                let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
                let function = db.core_info().upcast_fn.concretize(db, generic_args);
                statements.push(Statement::Call(StatementCall {
                    function: function.lowered(db),
                    inputs: vec![info.inputs[0]],
                    with_coupon: false,
                    outputs: vec![success_output],
                    location: info.location,
                    is_specialization_base_call: false,
                }));
                return Some(BlockEnd::Goto(
                    info.arms[success_arm].block_id,
                    Default::default(),
                ));
            };
        };
        // NOTE(review): `felt252_for_downcast` is defined elsewhere; presumably it adjusts
        // the felt252 value relative to the output range minimum - confirm.
        let value = if in_ty == self.felt252 {
            felt252_for_downcast(value, &out_range.min)
        } else {
            value.clone()
        };
        // Known input value: success arm (with the constant output) when in range,
        // failure arm otherwise.
        Some(if let NormalizedResult::InRange(value) = out_range.normalized(value) {
            let value = ConstValue::Int(value, out_ty).intern(db);
            self.var_info.insert(success_output, VarInfo::Const(value));
            statements.push(Statement::Const(StatementConst::new_flat(value, success_output)));
            BlockEnd::Goto(info.arms[success_arm].block_id, Default::default())
        } else {
            BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default())
        })
    } else if id == self.bounded_int_constrain {
        let input_var = info.inputs[0].var_id;
        let value = self.as_int(input_var)?;
        // The boundary is the constant passed as the second generic argument.
        let generic_arg = generic_args[1];
        let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
            .long(db)
            .to_int()
            .expect("Expected ConstValue::Int for size");
        // Values below the boundary go to arm 0, all others to arm 1.
        let arm_idx = if value < constrain_value { 0 } else { 1 };
        let output = info.arms[arm_idx].var_ids[0];
        statements.push(self.propagate_const_and_get_statement(value.clone(), output));
        Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
    } else if id == self.bounded_int_trim_min {
        let input_var = info.inputs[0].var_id;
        let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
            return None;
        };
        // The value is trimmed when it equals the minimum of its type's range.
        let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
            range.min == *value
        } else {
            corelib::try_extract_bounded_int_type_ranges(db, *ty)?.0 == *value
        };
        // Arm 0 is taken for the trimmed value; arm 1 receives the value as a constant.
        let arm_idx = if is_trimmed {
            0
        } else {
            let output = info.arms[1].var_ids[0];
            statements.push(self.propagate_const_and_get_statement(value.clone(), output));
            1
        };
        Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
    } else if id == self.bounded_int_trim_max {
        let input_var = info.inputs[0].var_id;
        let ConstValue::Int(value, ty) = self.as_const(input_var)?.long(self.db) else {
            return None;
        };
        // The value is trimmed when it equals the maximum of its type's range.
        let is_trimmed = if let Some(range) = self.type_value_ranges.get(ty) {
            range.max == *value
        } else {
            corelib::try_extract_bounded_int_type_ranges(db, *ty)?.1 == *value
        };
        // Arm 0 is taken for the trimmed value; arm 1 receives the value as a constant.
        let arm_idx = if is_trimmed {
            0
        } else {
            let output = info.arms[1].var_ids[0];
            statements.push(self.propagate_const_and_get_statement(value.clone(), output));
            1
        };
        Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
    } else if id == self.array_get {
        // Only handled when the index is a known constant that fits in `usize`.
        let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
        if let Some(VarInfo::Snapshot(arr_info)) = self.var_info.get(&info.inputs[0].var_id)
            && let VarInfo::Array(infos) = arr_info.as_ref()
        {
            match infos.get(index) {
                // The element at the index has known info.
                Some(Some(output_var_info)) => {
                    let arm = &info.arms[0];
                    let output_var_info = output_var_info.clone();
                    // The success output is recorded as a box of a snapshot of the element.
                    let box_info =
                        VarInfo::Box(VarInfo::Snapshot(output_var_info.clone().into()).into());
                    self.var_info.insert(arm.var_ids[0], box_info);
                    if let VarInfo::Const(value) = output_var_info {
                        // Constant element: build the output explicitly as a boxed
                        // constant, snapshot it, and forward the snapshot into the box.
                        let value_ty = value.ty(db).ok()?;
                        let value_box_ty = corelib::core_box_ty(db, value_ty);
                        let location = info.location;
                        let boxed_var =
                            Variable::with_default_context(db, value_box_ty, location);
                        let boxed = self.variables.alloc(boxed_var.clone());
                        let unused_boxed = self.variables.alloc(boxed_var);
                        let snapped = self.variables.alloc(Variable::with_default_context(
                            db,
                            TypeLongId::Snapshot(value_box_ty).intern(db),
                            location,
                        ));
                        statements.extend([
                            Statement::Const(StatementConst::new_boxed(value, boxed)),
                            Statement::Snapshot(StatementSnapshot {
                                input: VarUsage { var_id: boxed, location },
                                outputs: [unused_boxed, snapped],
                            }),
                            Statement::Call(StatementCall {
                                function: self
                                    .box_forward_snapshot
                                    .concretize(db, vec![GenericArgumentId::Type(value_ty)])
                                    .lowered(db),
                                inputs: vec![VarUsage { var_id: snapped, location }],
                                with_coupon: false,
                                outputs: vec![arm.var_ids[0]],
                                location: info.location,
                                is_specialization_base_call: false,
                            }),
                        ]);
                        return Some(BlockEnd::Goto(arm.block_id, Default::default()));
                    }
                }
                // The index is past the end of the known array: take arm 1.
                None => {
                    return Some(BlockEnd::Goto(info.arms[1].block_id, Default::default()));
                }
                // The element exists but nothing is known about it.
                Some(None) => {}
            }
        }
        // Getting index 0 with exactly two arms: rewrite the match in place to use
        // `array_snapshot_pop_front`, dropping the index input and prepending a fresh
        // variable (a copy of the array variable) to the outputs of both arms.
        if index.is_zero()
            && let [success, failure] = info.arms.as_mut_slice()
        {
            let arr = info.inputs[0].var_id;
            let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
            let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
            info.inputs.truncate(1);
            info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
                .concretize(db, generic_args)
                .lowered(db);
            success.var_ids.insert(0, unused_arr_output0);
            failure.var_ids.insert(0, unused_arr_output1);
        }
        None
    } else if id == self.array_pop_front {
        let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)? else {
            return None;
        };
        if let Some(first) = var_infos.first() {
            // Non-empty array: the match itself is kept, but when the first element's info
            // is known, the info of the success arm outputs is recorded (the remaining
            // array, and the popped element as a box).
            if let Some(first) = first.as_ref().cloned() {
                let arm = &info.arms[0];
                self.var_info.insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()));
                self.var_info.insert(arm.var_ids[1], VarInfo::Box(first.into()));
            }
            None
        } else {
            // Empty array: go directly to arm 1, remapping its output to the input array.
            let arm = &info.arms[1];
            self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
            Some(BlockEnd::Goto(
                arm.block_id,
                VarRemapping {
                    remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
                },
            ))
        }
    } else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
        let var_info = self.var_info.get(&info.inputs[0].var_id)?;
        let desnapped = try_extract_matches!(var_info, VarInfo::Snapshot)?;
        let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
        // Only the empty array case is resolved here.
        if element_var_infos.is_empty() {
            let arm = &info.arms[1];
            self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
            Some(BlockEnd::Goto(
                arm.block_id,
                VarRemapping {
                    remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
                },
            ))
        } else {
            None
        }
    } else {
        None
    }
}
1319
1320 fn as_const(&self, var_id: VariableId) -> Option<ConstValueId<'db>> {
1322 try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const).copied()
1323 }
1324
1325 fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
1327 match self.as_const(var_id)?.long(self.db) {
1328 ConstValue::Int(value, _) => Some(value),
1329 ConstValue::NonZero(const_value) => {
1330 if let ConstValue::Int(value, _) = const_value.long(self.db) {
1331 Some(value)
1332 } else {
1333 None
1334 }
1335 }
1336 _ => None,
1337 }
1338 }
1339
1340 fn maybe_replace_inputs(&self, inputs: &mut [VarUsage<'db>]) {
1342 for input in inputs {
1343 self.maybe_replace_input(input);
1344 }
1345 }
1346
1347 fn maybe_replace_input(&self, input: &mut VarUsage<'db>) {
1349 if let Some(VarInfo::Var(new_var)) = self.var_info.get(&input.var_id) {
1350 *input = *new_var;
1351 }
1352 }
1353
1354 fn try_get_specialization_arg(
1360 &self,
1361 var_info: VarInfo<'db>,
1362 ty: TypeId<'db>,
1363 unknown_vars: &mut Vec<VarUsage<'db>>,
1364 coerce: Option<&SpecializationArg<'db>>,
1365 ) -> Option<SpecializationArg<'db>> {
1366 if self.db.type_size_info(ty).ok()? == TypeSizeInformation::ZeroSized {
1367 return None;
1369 }
1370
1371 match var_info {
1372 VarInfo::Const(value) => {
1373 let res = const_to_specialization_arg(self.db, value, false);
1374 let Some(coerce) = coerce else {
1375 return Some(res);
1376 };
1377 if *coerce == res { Some(res) } else { None }
1378 }
1379 VarInfo::Box(info) => {
1380 let res = try_extract_matches!(info.as_ref(), VarInfo::Const)
1381 .map(|value| SpecializationArg::Const { value: *value, boxed: true });
1382 let Some(coerce) = coerce else {
1383 return res;
1384 };
1385 if Some(coerce.clone()) == res { res } else { None }
1386 }
1387 VarInfo::Snapshot(info) => {
1388 let desnap_ty = *extract_matches!(ty.long(self.db), TypeLongId::Snapshot);
1389 let mut local_unknown_vars: Vec<VarUsage<'db>> = Vec::new();
1391 let inner = self.try_get_specialization_arg(
1392 info.as_ref().clone(),
1393 desnap_ty,
1394 &mut local_unknown_vars,
1395 coerce.map(|coerce| {
1396 extract_matches!(coerce, SpecializationArg::Snapshot).as_ref()
1397 }),
1398 )?;
1399 unknown_vars.extend(local_unknown_vars);
1400 Some(SpecializationArg::Snapshot(Box::new(inner)))
1401 }
1402 VarInfo::Array(infos) => {
1403 let TypeLongId::Concrete(concrete_ty) = ty.long(self.db) else {
1404 unreachable!("Expected a concrete type");
1405 };
1406 let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
1407 else {
1408 unreachable!("Expected a single type generic argument");
1409 };
1410 let coerces = match coerce {
1411 Some(coerce) => {
1412 let SpecializationArg::Array(ty, specialization_args) = coerce else {
1413 unreachable!("Expected an array specialization argument");
1414 };
1415 assert_eq!(ty, inner_ty);
1416 if specialization_args.len() != infos.len() {
1417 return None;
1418 }
1419
1420 specialization_args.iter().map(Some).collect()
1421 }
1422 None => vec![None; infos.len()],
1423 };
1424 let mut vars = vec![];
1426 let mut args = vec![];
1427 for (info, coerce) in zip_eq(infos, coerces) {
1428 let info = info?;
1429 let arg =
1430 self.try_get_specialization_arg(info, *inner_ty, &mut vars, coerce)?;
1431 args.push(arg);
1432 }
1433 if !args.is_empty()
1434 && args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1435 {
1436 return None;
1437 }
1438 unknown_vars.extend(vars);
1439 Some(SpecializationArg::Array(*inner_ty, args))
1440 }
1441 VarInfo::Struct(infos) => {
1442 let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
1443 ty.long(self.db)
1444 else {
1445 return None;
1447 };
1448
1449 let members = self.db.concrete_struct_members(*concrete_struct).unwrap();
1450 let coerces = match coerce {
1451 Some(coerce) => {
1452 let SpecializationArg::Struct(specialization_args) = coerce else {
1453 unreachable!("Expected a struct specialization argument");
1454 };
1455 assert_eq!(specialization_args.len(), infos.len());
1456
1457 specialization_args.iter().map(Some).collect()
1458 }
1459 None => vec![None; infos.len()],
1460 };
1461 let mut struct_args = Vec::new();
1462 let mut vars = vec![];
1464 for ((member, opt_var_info), coerce) in
1465 zip_eq(zip_eq(members.values(), infos), coerces)
1466 {
1467 let var_info = opt_var_info?;
1468 let arg =
1469 self.try_get_specialization_arg(var_info, member.ty, &mut vars, coerce)?;
1470 struct_args.push(arg);
1471 }
1472 if !struct_args.is_empty()
1473 && struct_args
1474 .iter()
1475 .all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1476 {
1477 return None;
1478 }
1479 unknown_vars.extend(vars);
1480 Some(SpecializationArg::Struct(struct_args))
1481 }
1482 VarInfo::Enum { variant, payload } => {
1483 let coerce = match coerce {
1484 Some(coerce) => {
1485 let SpecializationArg::Enum { variant: coercion_variant, payload } = coerce
1486 else {
1487 unreachable!("Expected an enum specialization argument");
1488 };
1489 if *coercion_variant != variant {
1490 return None;
1491 }
1492 Some(payload.as_ref())
1493 }
1494 None => None,
1495 };
1496 let mut local_unknown_vars = vec![];
1497 let payload_arg = self.try_get_specialization_arg(
1498 payload.as_ref().clone(),
1499 variant.ty,
1500 &mut local_unknown_vars,
1501 coerce,
1502 )?;
1503
1504 unknown_vars.extend(local_unknown_vars);
1505 Some(SpecializationArg::Enum { variant, payload: Box::new(payload_arg) })
1506 }
1507 VarInfo::Var(var_usage) => {
1508 unknown_vars.push(var_usage);
1509 Some(SpecializationArg::NotSpecialized)
1510 }
1511 }
1512 }
1513
1514 pub fn should_skip_const_folding(&self, db: &'db dyn Database) -> bool {
1516 if db.optimizations().skip_const_folding() {
1517 return true;
1518 }
1519
1520 if self.caller_function.base_semantic_function(db).generic_function(db)
1523 == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
1524 {
1525 return true;
1526 }
1527 false
1528 }
1529}
1530
1531fn var_info_if_copy<'db>(
1533 variables: &VariableArena<'db>,
1534 input: VarUsage<'db>,
1535) -> Option<VarInfo<'db>> {
1536 variables[input.var_id].info.copyable.is_ok().then_some(VarInfo::Var(input))
1537}
1538
/// Salsa-tracked query building the `ConstFoldingLibfuncInfo` for `db`.
///
/// Declared with `returns(ref)`, so callers presumably receive a reference to the stored value
/// rather than a copy - confirm against the salsa documentation.
#[salsa::tracked(returns(ref))]
fn priv_const_folding_info<'db>(
    db: &'db dyn Database,
) -> crate::optimizations::const_folding::ConstFoldingLibfuncInfo<'db> {
    ConstFoldingLibfuncInfo::new(db)
}
1546
/// Ids of the corelib functions and per-type helper data used by const folding, as resolved
/// by `ConstFoldingLibfuncInfo::new`.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub struct ConstFoldingLibfuncInfo<'db> {
    /// The `felt252_sub` extern function.
    felt_sub: ExternFunctionId<'db>,
    /// The `felt252_add` extern function.
    felt_add: ExternFunctionId<'db>,
    /// The `felt252_mul` extern function.
    felt_mul: ExternFunctionId<'db>,
    /// The `felt252_div` extern function.
    felt_div: ExternFunctionId<'db>,
    /// The `into_box` extern function.
    into_box: ExternFunctionId<'db>,
    /// The `unbox` extern function.
    unbox: ExternFunctionId<'db>,
    /// The `box_forward_snapshot` generic function.
    box_forward_snapshot: GenericFunctionId<'db>,
    /// The `{ty}_eq` functions of the integer types.
    eq_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_add` functions of the unsigned integer types.
    uadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_sub` functions of the unsigned integer types.
    usub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_diff` functions of the signed integer types.
    diff_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_add_impl` functions of the signed integer types.
    iadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `{ty}_overflowing_sub_impl` functions of the signed integer types.
    isub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_mul` and the `{ty}_wide_mul` functions of the integer types.
    wide_mul_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_div_rem` and the `{ty}_safe_divmod` functions of the unsigned integer types.
    div_rem_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `bounded_int_add` extern function.
    bounded_int_add: ExternFunctionId<'db>,
    /// The `bounded_int_sub` extern function.
    bounded_int_sub: ExternFunctionId<'db>,
    /// The `bounded_int_constrain` extern function.
    bounded_int_constrain: ExternFunctionId<'db>,
    /// The `bounded_int_trim_min` extern function.
    bounded_int_trim_min: ExternFunctionId<'db>,
    /// The `bounded_int_trim_max` extern function.
    bounded_int_trim_max: ExternFunctionId<'db>,
    /// The `array_get` extern function.
    array_get: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_front` extern function.
    array_snapshot_pop_front: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_back` extern function.
    array_snapshot_pop_back: ExternFunctionId<'db>,
    /// The `array_len` extern function.
    array_len: ExternFunctionId<'db>,
    /// The `array_new` extern function.
    array_new: ExternFunctionId<'db>,
    /// The `array_append` extern function.
    array_append: ExternFunctionId<'db>,
    /// The `array_pop_front` extern function.
    array_pop_front: ExternFunctionId<'db>,
    /// The `storage_base_address_from_felt252` extern function.
    storage_base_address_from_felt252: ExternFunctionId<'db>,
    /// The `storage_base_address_const` generic function.
    storage_base_address_const: GenericFunctionId<'db>,
    /// The `panic_with_felt252` function.
    panic_with_felt252: FunctionId<'db>,
    /// The `panic_with_const_felt252` free function.
    pub panic_with_const_felt252: FreeFunctionId<'db>,
    /// The `panic_with_byte_array` function.
    panic_with_byte_array: FunctionId<'db>,
    /// Helper functions (`is_zero`, `inc`, `dec`) of the supported integer types, by type.
    type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>>,
    /// The const calculation info taken from the database; this struct dereferences to it.
    const_calculation_info: Arc<ConstCalcInfo<'db>>,
}
1619impl<'db> ConstFoldingLibfuncInfo<'db> {
1620 fn new(db: &'db dyn Database) -> Self {
1621 let core = ModuleHelper::core(db);
1622 let box_module = core.submodule("box");
1623 let integer_module = core.submodule("integer");
1624 let internal_module = core.submodule("internal");
1625 let bounded_int_module = internal_module.submodule("bounded_int");
1626 let num_module = internal_module.submodule("num");
1627 let array_module = core.submodule("array");
1628 let starknet_module = core.submodule("starknet");
1629 let storage_access_module = starknet_module.submodule("storage_access");
1630 let utypes = ["u8", "u16", "u32", "u64", "u128"];
1631 let itypes = ["i8", "i16", "i32", "i64", "i128"];
1632 let eq_fns = OrderedHashSet::<_>::from_iter(
1633 chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(&format!("{ty}_eq"))),
1634 );
1635 let uadd_fns = OrderedHashSet::<_>::from_iter(
1636 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_add"))),
1637 );
1638 let usub_fns = OrderedHashSet::<_>::from_iter(
1639 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_sub"))),
1640 );
1641 let diff_fns = OrderedHashSet::<_>::from_iter(
1642 itypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_diff"))),
1643 );
1644 let iadd_fns =
1645 OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1646 integer_module.extern_function_id(&format!("{ty}_overflowing_add_impl"))
1647 }));
1648 let isub_fns =
1649 OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
1650 integer_module.extern_function_id(&format!("{ty}_overflowing_sub_impl"))
1651 }));
1652 let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
1653 [bounded_int_module.extern_function_id("bounded_int_mul")],
1654 ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
1655 .map(|ty| integer_module.extern_function_id(&format!("{ty}_wide_mul"))),
1656 ));
1657 let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
1658 [bounded_int_module.extern_function_id("bounded_int_div_rem")],
1659 utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_safe_divmod"))),
1660 ));
1661 let type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>> = OrderedHashMap::from_iter(
1662 [
1663 ("u8", false, true),
1664 ("u16", false, true),
1665 ("u32", false, true),
1666 ("u64", false, true),
1667 ("u128", false, true),
1668 ("u256", false, false),
1669 ("i8", true, true),
1670 ("i16", true, true),
1671 ("i32", true, true),
1672 ("i64", true, true),
1673 ("i128", true, true),
1674 ]
1675 .map(|(ty_name, as_bounded_int, inc_dec): (&'static str, bool, bool)| {
1676 let ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, ty_name), vec![]);
1677 let is_zero = if as_bounded_int {
1678 bounded_int_module
1679 .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
1680 } else {
1681 integer_module.function_id(
1682 SmolStrId::from(db, format!("{ty_name}_is_zero")).long(db).as_str(),
1683 vec![],
1684 )
1685 }
1686 .lowered(db);
1687 let (inc, dec) = if inc_dec {
1688 (
1689 Some(
1690 num_module
1691 .function_id(
1692 SmolStrId::from(db, format!("{ty_name}_inc")).long(db).as_str(),
1693 vec![],
1694 )
1695 .lowered(db),
1696 ),
1697 Some(
1698 num_module
1699 .function_id(
1700 SmolStrId::from(db, format!("{ty_name}_dec")).long(db).as_str(),
1701 vec![],
1702 )
1703 .lowered(db),
1704 ),
1705 )
1706 } else {
1707 (None, None)
1708 };
1709 let info = TypeInfo { is_zero, inc, dec };
1710 (ty, info)
1711 }),
1712 );
1713 Self {
1714 felt_sub: core.extern_function_id("felt252_sub"),
1715 felt_add: core.extern_function_id("felt252_add"),
1716 felt_mul: core.extern_function_id("felt252_mul"),
1717 felt_div: core.extern_function_id("felt252_div"),
1718 into_box: box_module.extern_function_id("into_box"),
1719 unbox: box_module.extern_function_id("unbox"),
1720 box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
1721 eq_fns,
1722 uadd_fns,
1723 usub_fns,
1724 diff_fns,
1725 iadd_fns,
1726 isub_fns,
1727 wide_mul_fns,
1728 div_rem_fns,
1729 bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
1730 bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
1731 bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
1732 bounded_int_trim_min: bounded_int_module.extern_function_id("bounded_int_trim_min"),
1733 bounded_int_trim_max: bounded_int_module.extern_function_id("bounded_int_trim_max"),
1734 array_get: array_module.extern_function_id("array_get"),
1735 array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
1736 array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
1737 array_len: array_module.extern_function_id("array_len"),
1738 array_new: array_module.extern_function_id("array_new"),
1739 array_append: array_module.extern_function_id("array_append"),
1740 array_pop_front: array_module.extern_function_id("array_pop_front"),
1741 storage_base_address_from_felt252: storage_access_module
1742 .extern_function_id("storage_base_address_from_felt252"),
1743 storage_base_address_const: storage_access_module
1744 .generic_function_id("storage_base_address_const"),
1745 panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
1746 panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
1747 panic_with_byte_array: core
1748 .submodule("panics")
1749 .function_id("panic_with_byte_array", vec![])
1750 .lowered(db),
1751 type_info,
1752 const_calculation_info: db.const_calc_info(),
1753 }
1754 }
1755}
1756
1757impl<'db> std::ops::Deref for ConstFoldingContext<'db, '_> {
1758 type Target = ConstFoldingLibfuncInfo<'db>;
1759 fn deref(&self) -> &ConstFoldingLibfuncInfo<'db> {
1760 self.libfunc_info
1761 }
1762}
1763
1764impl<'a> std::ops::Deref for ConstFoldingLibfuncInfo<'a> {
1765 type Target = ConstCalcInfo<'a>;
1766 fn deref(&self) -> &ConstCalcInfo<'a> {
1767 &self.const_calculation_info
1768 }
1769}
1770
/// Helper functions of a single integer type, as used by const folding.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct TypeInfo<'db> {
    /// The function used for matching on whether a value of the type is zero.
    is_zero: FunctionId<'db>,
    /// The `{ty}_inc` function of the type, if it has one.
    inc: Option<FunctionId<'db>>,
    /// The `{ty}_dec` function of the type, if it has one.
    dec: Option<FunctionId<'db>>,
}
1781
/// Classification of a value against a value range.
trait TypeRangeNormalizer {
    /// Classifies `value` against the range, consuming it and returning it (possibly adjusted)
    /// inside the matching `NormalizedResult` variant.
    fn normalized(&self, value: BigInt) -> NormalizedResult;
}
1787impl TypeRangeNormalizer for TypeRange {
1788 fn normalized(&self, value: BigInt) -> NormalizedResult {
1789 if value < self.min {
1790 NormalizedResult::Under(value - &self.min + &self.max + 1)
1791 } else if value > self.max {
1792 NormalizedResult::Over(value + &self.min - &self.max - 1)
1793 } else {
1794 NormalizedResult::InRange(value)
1795 }
1796 }
1797}
1798
/// The result of classifying a value against a value range.
enum NormalizedResult {
    /// The value is within the range; holds the value unchanged.
    InRange(BigInt),
    /// The value is above the range maximum; holds the value shifted down by the range size.
    Over(BigInt),
    /// The value is below the range minimum; holds the value shifted up by the range size.
    Under(BigInt),
}