1#[cfg(test)]
2#[path = "const_folding_test.rs"]
3mod test;
4
5use std::sync::Arc;
6
7use cairo_lang_defs::ids::{ExternFunctionId, FreeFunctionId};
8use cairo_lang_filesystem::flag::flag_unsafe_panic;
9use cairo_lang_filesystem::ids::SmolStrId;
10use cairo_lang_semantic::corelib::CorelibSemantic;
11use cairo_lang_semantic::helper::ModuleHelper;
12use cairo_lang_semantic::items::constant::{
13 ConstCalcInfo, ConstValue, ConstValueId, ConstantSemantic, TypeRange, canonical_felt252,
14 felt252_for_downcast,
15};
16use cairo_lang_semantic::items::functions::{GenericFunctionId, GenericFunctionWithBodyId};
17use cairo_lang_semantic::items::structure::StructSemantic;
18use cairo_lang_semantic::types::{TypeSizeInformation, TypesSemantic};
19use cairo_lang_semantic::{
20 ConcreteTypeId, ConcreteVariant, GenericArgumentId, MatchArmSelector, TypeId, TypeLongId,
21 corelib,
22};
23use cairo_lang_utils::byte_array::BYTE_ARRAY_MAGIC;
24use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
25use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
26use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
27use cairo_lang_utils::{Intern, extract_matches, try_extract_matches};
28use itertools::{chain, zip_eq};
29use num_bigint::BigInt;
30use num_integer::Integer;
31use num_traits::cast::ToPrimitive;
32use num_traits::{Num, One, Zero};
33use salsa::Database;
34use starknet_types_core::felt::Felt as Felt252;
35
36use crate::db::LoweringGroup;
37use crate::ids::{
38 ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionId, SemanticFunctionIdEx,
39 SpecializedFunction,
40};
41use crate::specialization::SpecializationArg;
42use crate::utils::InliningStrategy;
43use crate::{
44 Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchExternInfo, MatchInfo,
45 Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
46 StatementSnapshot, StatementStructConstruct, StatementStructDestructure, VarRemapping,
47 VarUsage, Variable, VariableArena, VariableId,
48};
49
/// Compile-time knowledge about the value held by a variable, as tracked by const folding.
#[derive(Debug, Clone)]
enum VarInfo<'db> {
    /// The variable holds this constant value.
    Const(ConstValueId<'db>),
    /// The variable holds the same value as the given other variable usage.
    Var(VarUsage<'db>),
    /// The variable is a snapshot of a value described by the inner info.
    Snapshot(Box<VarInfo<'db>>),
    /// The variable is a struct; each member has its info when known, `None` otherwise.
    Struct(Vec<Option<VarInfo<'db>>>),
    /// The variable is an enum of a known variant, with the payload described by `payload`.
    Enum { variant: ConcreteVariant<'db>, payload: Box<VarInfo<'db>> },
    /// The variable is a box whose content is described by the inner info.
    Box(Box<VarInfo<'db>>),
    /// The variable is an array of known length; each element has its info when known, `None`
    /// otherwise.
    Array(Vec<Option<VarInfo<'db>>>),
}
71impl<'db> VarInfo<'db> {
72 fn peel_snapshots(&self) -> (usize, &VarInfo<'db>) {
74 let mut n_snapshots = 0;
75 let mut info = self;
76 while let VarInfo::Snapshot(inner) = info {
77 info = inner.as_ref();
78 n_snapshots += 1;
79 }
80 (n_snapshots, info)
81 }
82 fn wrap_with_snapshots(mut self, n_snapshots: usize) -> VarInfo<'db> {
84 for _ in 0..n_snapshots {
85 self = VarInfo::Snapshot(Box::new(self));
86 }
87 self
88 }
89}
90
/// How a block that has not been visited yet can be reached.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Reachability {
    /// So far reached only through the `Goto` ending the given block; constants flowing through
    /// that goto's remapping are propagated into the block when it is visited.
    FromSingleGoto(BlockId),
    /// Reached in any other way (root block, a match arm, or more than one goto); no incoming
    /// information is propagated.
    Any,
}
100
101pub fn const_folding<'db>(
104 db: &'db dyn Database,
105 function_id: ConcreteFunctionWithBodyId<'db>,
106 lowered: &mut Lowered<'db>,
107) {
108 if lowered.blocks.is_empty() {
109 return;
110 }
111
112 let mut ctx = ConstFoldingContext::new(db, function_id, &mut lowered.variables);
115
116 if ctx.should_skip_const_folding(db) {
117 return;
118 }
119
120 for block_id in (0..lowered.blocks.len()).map(BlockId) {
121 if !ctx.visit_block_start(block_id, |block_id| &lowered.blocks[block_id]) {
122 continue;
123 }
124
125 let block = &mut lowered.blocks[block_id];
126 for stmt in block.statements.iter_mut() {
127 ctx.visit_statement(stmt);
128 }
129 ctx.visit_block_end(block_id, block);
130 }
131}
132
/// State carried through the blocks of a single function while constant-folding it.
pub struct ConstFoldingContext<'db, 'mt> {
    /// The database used for interning and queries.
    db: &'db dyn Database,
    /// The function's variable arena; variables introduced by folding are allocated here.
    pub variables: &'mt mut VariableArena<'db>,
    /// What is currently known about each variable.
    var_info: UnorderedHashMap<VariableId, VarInfo<'db>>,
    /// Libfunc ids/tables the folding logic compares calls against (presumably reached through a
    /// `Deref` impl defined elsewhere, e.g. `self.felt_sub` - confirm).
    libfunc_info: &'db ConstFoldingLibfuncInfo<'db>,
    /// The folded function, or its base when it is a specialization; calls to this base are never
    /// specialized.
    caller_base: ConcreteFunctionWithBodyId<'db>,
    /// For each not-yet-visited block that a visited block end can jump to, how it is reached.
    /// A block absent from this map when visited is treated as unreachable.
    reachability: UnorderedHashMap<BlockId, Reachability>,
    /// Statements generated while handling the current block's statements; they are inserted at
    /// the beginning of the block by `visit_block_end`.
    additional_stmts: Vec<Statement<'db>>,
}
152
153impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
154 pub fn new(
155 db: &'db dyn Database,
156 function_id: ConcreteFunctionWithBodyId<'db>,
157 variables: &'mt mut VariableArena<'db>,
158 ) -> Self {
159 let caller_base = match function_id.long(db) {
160 ConcreteFunctionWithBodyLongId::Specialized(specialized_func) => specialized_func.base,
161 _ => function_id,
162 };
163
164 Self {
165 db,
166 var_info: UnorderedHashMap::default(),
167 variables,
168 libfunc_info: priv_const_folding_info(db),
169 caller_base,
170 reachability: UnorderedHashMap::from_iter([(BlockId::root(), Reachability::Any)]),
171 additional_stmts: vec![],
172 }
173 }
174
175 pub fn visit_block_start<'r, 'get>(
178 &'r mut self,
179 block_id: BlockId,
180 get_block: impl FnOnce(BlockId) -> &'get Block<'db>,
181 ) -> bool
182 where
183 'db: 'get,
184 {
185 let Some(reachability) = self.reachability.remove(&block_id) else {
186 return false;
187 };
188 match reachability {
189 Reachability::Any => {}
190 Reachability::FromSingleGoto(from_block) => match &get_block(from_block).end {
191 BlockEnd::Goto(_, remapping) => {
192 for (dst, src) in remapping.iter() {
193 if let Some(v) = self.as_const(src.var_id) {
194 self.var_info.insert(*dst, VarInfo::Const(v));
195 }
196 }
197 }
198 _ => unreachable!("Expected a goto end"),
199 },
200 }
201 true
202 }
203
    /// Updates the tracked variable info according to `stmt`, possibly rewriting the statement.
    ///
    /// Inputs are first replaced where possible (`maybe_replace_inputs`); then, per kind:
    /// * `Const` - records the value (boxed constants as `VarInfo::Box`); generic, impl, var and
    ///   missing constants are not recorded.
    /// * `Snapshot` / `Desnap` - propagate the input's info, adding / removing a snapshot layer.
    /// * `Call` - may be replaced by a folded statement (`handle_statement_call`, which may also
    ///   edit the call in place) or by a call to a specialized function (`try_specialize_call`).
    /// * `StructConstruct` / `StructDestructure` / `EnumConstruct` - propagate per-member info,
    ///   collapsing to a `Const` when every part is constant.
    pub fn visit_statement(&mut self, stmt: &mut Statement<'db>) {
        self.maybe_replace_inputs(stmt.inputs_mut());
        match stmt {
            Statement::Const(StatementConst { value, output, boxed }) if *boxed => {
                self.var_info.insert(*output, VarInfo::Box(VarInfo::Const(*value).into()));
            }
            Statement::Const(StatementConst { value, output, .. }) => match value.long(self.db) {
                ConstValue::Int(..)
                | ConstValue::Struct(..)
                | ConstValue::Enum(..)
                | ConstValue::NonZero(..) => {
                    self.var_info.insert(*output, VarInfo::Const(*value));
                }
                // Not concrete values; nothing to track.
                ConstValue::Generic(_)
                | ConstValue::ImplConstant(_)
                | ConstValue::Var(..)
                | ConstValue::Missing(_) => {}
            },
            Statement::Snapshot(stmt) => {
                // The original keeps the input's info; the snapshot gets it wrapped once.
                if let Some(info) = self.var_info.get(&stmt.input.var_id).cloned() {
                    self.var_info.insert(stmt.original(), info.clone());
                    self.var_info.insert(stmt.snapshot(), VarInfo::Snapshot(info.into()));
                }
            }
            Statement::Desnap(StatementDesnap { input, output }) => {
                if let Some(VarInfo::Snapshot(info)) = self.var_info.get(&input.var_id) {
                    self.var_info.insert(*output, info.as_ref().clone());
                }
            }
            Statement::Call(call_stmt) => {
                if let Some(updated_stmt) = self.handle_statement_call(call_stmt) {
                    *stmt = updated_stmt;
                } else if let Some(updated_stmt) = self.try_specialize_call(call_stmt) {
                    *stmt = updated_stmt;
                }
            }
            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
                let mut const_args = vec![];
                let mut all_args = vec![];
                let mut contains_info = false;
                for input in inputs.iter() {
                    let Some(info) = self.var_info.get(&input.var_id) else {
                        // No tracked info: fall back to whatever `var_info_if_copy` provides.
                        all_args.push(var_info_if_copy(self.variables, *input));
                        continue;
                    };
                    contains_info = true;
                    if let VarInfo::Const(value) = info {
                        const_args.push(*value);
                    }
                    all_args.push(Some(info.clone()));
                }
                if const_args.len() == inputs.len() {
                    // Every member is constant - the whole struct is a constant.
                    let value =
                        ConstValue::Struct(const_args, self.variables[*output].ty).intern(self.db);
                    self.var_info.insert(*output, VarInfo::Const(value));
                } else if contains_info {
                    self.var_info.insert(*output, VarInfo::Struct(all_args));
                }
            }
            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
                if let Some(info) = self.var_info.get(&input.var_id) {
                    // Destructuring a snapshot yields snapshots of the members.
                    let (n_snapshots, info) = info.peel_snapshots();
                    match info {
                        VarInfo::Const(const_value) => {
                            if let ConstValue::Struct(member_values, _) = const_value.long(self.db)
                            {
                                for (output, value) in zip_eq(outputs, member_values) {
                                    self.var_info.insert(
                                        *output,
                                        VarInfo::Const(*value).wrap_with_snapshots(n_snapshots),
                                    );
                                }
                            }
                        }
                        VarInfo::Struct(members) => {
                            for (output, member) in zip_eq(outputs, members.clone()) {
                                if let Some(member) = member {
                                    self.var_info
                                        .insert(*output, member.wrap_with_snapshots(n_snapshots));
                                }
                            }
                        }
                        _ => {}
                    }
                }
            }
            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
                // The variant is always known here; the payload is constant, tracked, or just
                // referenced by its variable.
                let value = if let Some(info) = self.var_info.get(&input.var_id) {
                    if let VarInfo::Const(val) = info {
                        VarInfo::Const(ConstValue::Enum(*variant, *val).intern(self.db))
                    } else {
                        VarInfo::Enum { variant: *variant, payload: info.clone().into() }
                    }
                } else {
                    VarInfo::Enum { variant: *variant, payload: VarInfo::Var(*input).into() }
                };
                self.var_info.insert(*output, value);
            }
        }
    }
313
    /// Finishes visiting block `block_id`.
    ///
    /// * Prepends the statements accumulated in `additional_stmts` to the block.
    /// * Replaces inputs of the block end where possible.
    /// * Tries to turn a `Match` end whose input is known into a `Goto`.
    /// * Marks the successors of the (possibly updated) end as reachable.
    ///
    /// Hits `unreachable!()` if the block ends with `Panic` or `NotSet`.
    pub fn visit_block_end(&mut self, block_id: BlockId, block: &mut Block<'db>) {
        let statements = &mut block.statements;
        statements.splice(0..0, self.additional_stmts.drain(..));

        match &mut block.end {
            BlockEnd::Goto(_, remappings) => {
                for (_, v) in remappings.iter_mut() {
                    self.maybe_replace_input(v);
                }
            }
            BlockEnd::Match { info } => {
                self.maybe_replace_inputs(info.inputs_mut());
                match info {
                    MatchInfo::Enum(info) => {
                        if let Some(updated_end) = self.handle_enum_block_end(info, statements) {
                            block.end = updated_end;
                        }
                    }
                    MatchInfo::Extern(info) => {
                        if let Some(updated_end) = self.handle_extern_block_end(info, statements) {
                            block.end = updated_end;
                        }
                    }
                    MatchInfo::Value(info) => {
                        // A known value selects exactly one arm.
                        if let Some(value) =
                            self.as_int(info.input.var_id).and_then(|x| x.to_usize())
                            && let Some(arm) = info.arms.iter().find(|arm| {
                                matches!(
                                    &arm.arm_selector,
                                    MatchArmSelector::Value(v) if v.value == value
                                )
                            })
                        {
                            // The removed match used to introduce the arm's variable; define it
                            // here with an empty struct construction instead.
                            statements.push(Statement::StructConstruct(StatementStructConstruct {
                                inputs: vec![],
                                output: arm.var_ids[0],
                            }));
                            block.end = BlockEnd::Goto(arm.block_id, Default::default());
                        }
                    }
                }
            }
            BlockEnd::Return(inputs, _) => self.maybe_replace_inputs(inputs),
            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
        }
        // Record how the successors of the final block end are reached.
        match &block.end {
            BlockEnd::Goto(dst_block_id, _) => {
                // A second incoming goto downgrades the destination to `Any`.
                match self.reachability.entry(*dst_block_id) {
                    std::collections::hash_map::Entry::Occupied(mut e) => {
                        e.insert(Reachability::Any)
                    }
                    std::collections::hash_map::Entry::Vacant(e) => {
                        *e.insert(Reachability::FromSingleGoto(block_id))
                    }
                };
            }
            BlockEnd::Match { info } => {
                // Each match arm block is expected to have no other incoming edge recorded yet.
                for arm in info.arms() {
                    assert!(self.reachability.insert(arm.block_id, Reachability::Any).is_none());
                }
            }
            BlockEnd::NotSet | BlockEnd::Return(..) | BlockEnd::Panic(..) => {}
        }
    }
385
    /// Tries to fold a call to one of the libfuncs known to const folding.
    ///
    /// Returns `Some(replacement)` when the call statement should be replaced. Even when
    /// returning `None`, it may have:
    /// * rewritten `stmt` in place (e.g. switching to a const-generic variant),
    /// * recorded info for the call's outputs, or
    /// * pushed statements into `additional_stmts`, which are later prepended to the block.
    fn handle_statement_call(&mut self, stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
        let db = self.db;
        if stmt.function == self.panic_with_felt252 {
            // Constant panic value: call the const-generic variant, which takes no inputs.
            let val = self.as_const(stmt.inputs[0].var_id)?;
            stmt.inputs.clear();
            stmt.function = GenericFunctionId::Free(self.panic_with_const_felt252)
                .concretize(db, vec![GenericArgumentId::Constant(val)])
                .lowered(db);
            return None;
        } else if stmt.function == self.panic_with_byte_array && !flag_unsafe_panic(db) {
            // When the whole (snapshot of the) byte array is known, build the panic data
            // directly: [magic, number of full words, words..., pending word, pending length].
            let snap = self.var_info.get(&stmt.inputs[0].var_id)?;
            let bytearray = try_extract_matches!(snap, VarInfo::Snapshot)?;
            let [
                Some(VarInfo::Array(data)),
                Some(VarInfo::Const(pending_word)),
                Some(VarInfo::Const(pending_len)),
            ] = &try_extract_matches!(bytearray.as_ref(), VarInfo::Struct)?[..]
            else {
                return None;
            };
            let mut panic_data =
                vec![BigInt::from_str_radix(BYTE_ARRAY_MAGIC, 16).unwrap(), data.len().into()];
            for word in data {
                let Some(VarInfo::Const(word)) = word else {
                    return None;
                };
                panic_data.push(word.long(db).to_int()?.clone());
            }
            panic_data.extend([
                pending_word.long(db).to_int()?.clone(),
                pending_len.long(db).to_int()?.clone(),
            ]);
            let felt252_ty = self.felt252;
            let location = stmt.location;
            let new_var = |ty| Variable::with_default_context(db, ty, location);
            let as_usage = |var_id| VarUsage { var_id, location };
            let array_fn = |extern_id| {
                let args = vec![GenericArgumentId::Type(felt252_ty)];
                GenericFunctionId::Extern(extern_id).concretize(db, args).lowered(db)
            };
            let call_stmt = |function, inputs, outputs| {
                let with_coupon = false;
                Statement::Call(StatementCall { function, inputs, with_coupon, outputs, location })
            };
            // `arr = array_new()`, then one `const` + `array_append` per panic-data word.
            let arr_var = new_var(corelib::core_array_felt252_ty(db));
            let mut arr = self.variables.alloc(arr_var.clone());
            self.additional_stmts.push(call_stmt(array_fn(self.array_new), vec![], vec![arr]));
            let felt252_var = new_var(felt252_ty);
            let arr_append_fn = array_fn(self.array_append);
            for word in panic_data {
                let to_append = self.variables.alloc(felt252_var.clone());
                let new_arr = self.variables.alloc(arr_var.clone());
                self.additional_stmts.push(Statement::Const(StatementConst::new_flat(
                    ConstValue::Int(word, felt252_ty).intern(db),
                    to_append,
                )));
                self.additional_stmts.push(call_stmt(
                    arr_append_fn,
                    vec![as_usage(arr), as_usage(to_append)],
                    vec![new_arr],
                ));
                arr = new_arr;
            }
            let panic_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "Panic"), vec![]);
            let panic_var = self.variables.alloc(new_var(panic_ty));
            self.additional_stmts.push(Statement::StructConstruct(StatementStructConstruct {
                inputs: vec![],
                output: panic_var,
            }));
            // The call's output is rebuilt as the struct `(panic_var, arr)` (presumably the panic
            // tuple - confirm against the libfunc signature).
            return Some(Statement::StructConstruct(StatementStructConstruct {
                inputs: vec![as_usage(panic_var), as_usage(arr)],
                output: stmt.outputs[0],
            }));
        }
        // Everything below handles extern (libfunc) calls only.
        let (id, _generic_args) = stmt.function.get_extern(db)?;
        if id == self.felt_sub {
            // `a - 0` is `a`; `const - const` is folded.
            if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            {
                let value = canonical_felt252(&(lhs - rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_add {
            // `0 + b` is `b`; `a + 0` is `a`; `const + const` is folded.
            if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && lhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
                None
            } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_zero()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
            {
                let value = canonical_felt252(&(lhs + rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_mul {
            // Multiplying by 0 gives 0; by 1 gives the other operand; `const * const` is folded.
            let lhs = self.as_int(stmt.inputs[0].var_id);
            let rhs = self.as_int(stmt.inputs[1].var_id);
            if lhs.map(Zero::is_zero).unwrap_or_default()
                || rhs.map(Zero::is_zero).unwrap_or_default()
            {
                Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
            } else if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && lhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[1]));
                None
            } else if let Some(lhs) = lhs
                && let Some(rhs) = rhs
            {
                let value = canonical_felt252(&(lhs * rhs));
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if id == self.felt_div {
            // `a / 1` is `a`; `0 / b` is 0; `const / nonzero const` uses field division.
            if let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && rhs.is_one()
            {
                self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
                None
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && lhs.is_zero()
            {
                Some(self.propagate_zero_and_get_statement(stmt.outputs[0]))
            } else if let Some(lhs) = self.as_int(stmt.inputs[0].var_id)
                && let Some(rhs) = self.as_int(stmt.inputs[1].var_id)
                && let Ok(rhs_nonzero) = Felt252::from(rhs).try_into()
            {
                let lhs_felt = Felt252::from(lhs);
                let value = lhs_felt.field_div(&rhs_nonzero).to_bigint();
                Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
            } else {
                None
            }
        } else if self.wide_mul_fns.contains(&id) {
            // A zero operand gives 0; otherwise both operands must be constant.
            let lhs = self.as_int(stmt.inputs[0].var_id);
            let rhs = self.as_int(stmt.inputs[1].var_id);
            let output = stmt.outputs[0];
            if lhs.map(Zero::is_zero).unwrap_or_default()
                || rhs.map(Zero::is_zero).unwrap_or_default()
            {
                return Some(self.propagate_zero_and_get_statement(output));
            }
            let lhs = lhs?;
            Some(self.propagate_const_and_get_statement(lhs * rhs?, stmt.outputs[0]))
        } else if id == self.bounded_int_add || id == self.bounded_int_sub {
            let lhs = self.as_int(stmt.inputs[0].var_id)?;
            let rhs = self.as_int(stmt.inputs[1].var_id)?;
            let value = if id == self.bounded_int_add { lhs + rhs } else { lhs - rhs };
            Some(self.propagate_const_and_get_statement(value, stmt.outputs[0]))
        } else if self.div_rem_fns.contains(&id) {
            // Outputs are (quotient, remainder). A zero dividend makes both zero; otherwise both
            // operands must be constant. The remainder's definition goes to `additional_stmts`.
            let lhs = self.as_int(stmt.inputs[0].var_id);
            if lhs.map(Zero::is_zero).unwrap_or_default() {
                let additional_stmt = self.propagate_zero_and_get_statement(stmt.outputs[1]);
                self.additional_stmts.push(additional_stmt);
                return Some(self.propagate_zero_and_get_statement(stmt.outputs[0]));
            }
            let rhs = self.as_int(stmt.inputs[1].var_id)?;
            let (q, r) = lhs?.div_rem(rhs);
            let q_output = stmt.outputs[0];
            let q_value = ConstValue::Int(q, self.variables[q_output].ty).intern(db);
            self.var_info.insert(q_output, VarInfo::Const(q_value));
            let r_output = stmt.outputs[1];
            let r_value = ConstValue::Int(r, self.variables[r_output].ty).intern(db);
            self.var_info.insert(r_output, VarInfo::Const(r_value));
            self.additional_stmts
                .push(Statement::Const(StatementConst::new_flat(r_value, r_output)));
            Some(Statement::Const(StatementConst::new_flat(q_value, q_output)))
        } else if id == self.storage_base_address_from_felt252 {
            // Constant input: switch in place to the const-generic variant with no inputs.
            let input_var = stmt.inputs[0].var_id;
            if let Some(const_value) = self.as_const(input_var)
                && let ConstValue::Int(val, ty) = const_value.long(db)
            {
                stmt.inputs.clear();
                let arg = GenericArgumentId::Constant(ConstValue::Int(val.clone(), *ty).intern(db));
                stmt.function =
                    self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
            }
            None
        } else if id == self.into_box {
            // Track the boxed info; a constant (or a snapshot of one) becomes a boxed const.
            let input = stmt.inputs[0];
            let var_info = self.var_info.get(&input.var_id);
            let const_value = match var_info {
                Some(VarInfo::Const(val)) => Some(*val),
                Some(VarInfo::Snapshot(info)) => {
                    try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
                }
                _ => None,
            };
            let var_info = var_info.cloned().or_else(|| var_info_if_copy(self.variables, input))?;
            self.var_info.insert(stmt.outputs[0], VarInfo::Box(var_info.into()));
            Some(Statement::Const(StatementConst::new_boxed(const_value?, stmt.outputs[0])))
        } else if id == self.unbox {
            // Propagate the box's inner info; a constant content becomes a flat const.
            if let VarInfo::Box(inner) = self.var_info.get(&stmt.inputs[0].var_id)? {
                let inner = inner.as_ref().clone();
                if let VarInfo::Const(inner) =
                    self.var_info.entry(stmt.outputs[0]).insert_entry(inner).get()
                {
                    return Some(Statement::Const(StatementConst::new_flat(
                        *inner,
                        stmt.outputs[0],
                    )));
                }
            }
            None
        } else if self.upcast_fns.contains(&id) {
            // The same integer, re-typed as the output's type.
            let int_value = self.as_int(stmt.inputs[0].var_id)?;
            let output = stmt.outputs[0];
            let value = ConstValue::Int(int_value.clone(), self.variables[output].ty).intern(db);
            self.var_info.insert(output, VarInfo::Const(value));
            Some(Statement::Const(StatementConst::new_flat(value, output)))
        } else if id == self.array_new {
            self.var_info.insert(stmt.outputs[0], VarInfo::Array(vec![]));
            None
        } else if id == self.array_append {
            // Only extends arrays whose contents are already tracked.
            let mut var_infos =
                if let VarInfo::Array(var_infos) = self.var_info.get(&stmt.inputs[0].var_id)? {
                    var_infos.clone()
                } else {
                    return None;
                };
            let appended = stmt.inputs[1];
            var_infos.push(match self.var_info.get(&appended.var_id) {
                Some(var_info) => Some(var_info.clone()),
                None => var_info_if_copy(self.variables, appended),
            });
            self.var_info.insert(stmt.outputs[0], VarInfo::Array(var_infos));
            None
        } else if id == self.array_len {
            // The input is a snapshot of an array of known length.
            let info = self.var_info.get(&stmt.inputs[0].var_id)?;
            let desnapped = try_extract_matches!(info, VarInfo::Snapshot)?;
            let length = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?.len();
            Some(self.propagate_const_and_get_statement(length.into(), stmt.outputs[0]))
        } else {
            None
        }
    }
655
    /// Tries to replace `call_stmt` with a call to a specialized version of its callee.
    ///
    /// In the specialized version, arguments with known info are baked in. Returns `None`, and
    /// leaves the call unchanged, when:
    /// * the call uses a coupon,
    /// * the inlining strategy is `Avoid`,
    /// * the callee has no body or is never-inline,
    /// * the callee's base is the caller's own base,
    /// * no argument can be specialized, or
    /// * `priv_should_specialize` says no.
    ///
    /// On success, the call's outputs are moved into the returned statement.
    fn try_specialize_call(&self, call_stmt: &mut StatementCall<'db>) -> Option<Statement<'db>> {
        if call_stmt.with_coupon {
            return None;
        }
        if matches!(self.db.optimizations().inlining_strategy(), InliningStrategy::Avoid) {
            return None;
        }

        let Ok(Some(mut base)) = call_stmt.function.body(self.db) else {
            return None;
        };

        // A query error also aborts specialization.
        if self.db.priv_never_inline(base).ok()? {
            return None;
        }

        if base == self.caller_base {
            return None;
        }
        if call_stmt.inputs.iter().all(|arg| self.var_info.get(&arg.var_id).is_none()) {
            return None;
        }
        // Known, droppable arguments are turned into specialization args. Everything else is
        // passed at runtime through `new_args`. `try_get_specialization_arg` may itself push
        // runtime parts into `new_args` (its body is not visible here - confirm).
        let mut specialization_args = vec![];
        let mut new_args = vec![];
        for arg in &call_stmt.inputs {
            if let Some(var_info) = self.var_info.get(&arg.var_id)
                && self.variables[arg.var_id].info.droppable.is_ok()
                && let Some(specialization_arg) = self.try_get_specialization_arg(
                    var_info.clone(),
                    self.variables[arg.var_id].ty,
                    &mut new_args,
                )
            {
                specialization_args.push(specialization_arg);
            } else {
                specialization_args.push(SpecializationArg::NotSpecialized);
                new_args.push(*arg);
                continue;
            };
        }

        if specialization_args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized)) {
            return None;
        }
        // The callee is already a specialization. Specialize its base instead, filling the old
        // args' `NotSpecialized` slots (in depth-first, left-to-right order) with the new args.
        if let ConcreteFunctionWithBodyLongId::Specialized(specialized_function) =
            base.long(self.db)
        {
            base = specialized_function.base;
            let mut new_args_iter = specialization_args.into_iter();
            let mut old_args = specialized_function.args.to_vec();
            let mut stack = vec![];
            for arg in old_args.iter_mut().rev() {
                stack.push(arg);
            }
            while let Some(arg) = stack.pop() {
                match arg {
                    SpecializationArg::Const { .. } => {}
                    SpecializationArg::Snapshot(inner) => {
                        stack.push(inner.as_mut());
                    }
                    SpecializationArg::Enum { payload, .. } => {
                        stack.push(payload.as_mut());
                    }
                    SpecializationArg::Array(_, values) => {
                        for value in values.iter_mut().rev() {
                            stack.push(value);
                        }
                    }
                    SpecializationArg::Struct(specialization_args) => {
                        for arg in specialization_args.iter_mut().rev() {
                            stack.push(arg);
                        }
                    }
                    SpecializationArg::NotSpecialized => {
                        *arg = new_args_iter.next().unwrap_or(SpecializationArg::NotSpecialized);
                    }
                }
            }
            specialization_args = old_args;
        }
        let specialized = SpecializedFunction { base, args: specialization_args.into() };
        let specialized_func_id =
            ConcreteFunctionWithBodyLongId::Specialized(specialized).intern(self.db);

        if self.db.priv_should_specialize(specialized_func_id) == Ok(false) {
            return None;
        }

        Some(Statement::Call(StatementCall {
            function: specialized_func_id.function_id(self.db).unwrap(),
            inputs: new_args,
            with_coupon: call_stmt.with_coupon,
            outputs: std::mem::take(&mut call_stmt.outputs),
            location: call_stmt.location,
        }))
    }
765
766 fn propagate_const_and_get_statement(
768 &mut self,
769 value: BigInt,
770 output: VariableId,
771 ) -> Statement<'db> {
772 let ty = self.variables[output].ty;
773 let value = ConstValueId::from_int(self.db, ty, &value);
774 self.var_info.insert(output, VarInfo::Const(value));
775 Statement::Const(StatementConst::new_flat(value, output))
776 }
777
778 fn propagate_zero_and_get_statement(&mut self, output: VariableId) -> Statement<'db> {
780 self.propagate_const_and_get_statement(BigInt::zero(), output)
781 }
782
783 fn try_generate_const_statement(
785 &self,
786 value: ConstValueId<'db>,
787 output: VariableId,
788 ) -> Option<Statement<'db>> {
789 if self.db.type_size_info(self.variables[output].ty) == Ok(TypeSizeInformation::Other) {
790 Some(Statement::Const(StatementConst::new_flat(value, output)))
791 } else if matches!(value.long(self.db), ConstValue::Struct(members, _) if members.is_empty())
792 {
793 Some(Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }))
795 } else {
796 None
797 }
798 }
799
800 fn handle_enum_block_end(
805 &mut self,
806 info: &mut MatchEnumInfo<'db>,
807 statements: &mut Vec<Statement<'db>>,
808 ) -> Option<BlockEnd<'db>> {
809 let input = info.input.var_id;
810 let (n_snapshots, var_info) = self.var_info.get(&input)?.peel_snapshots();
811 let location = info.location;
812 let as_usage = |var_id| VarUsage { var_id, location };
813 let db = self.db;
814 let snapshot_stmt = |vars: &mut VariableArena<'_>, pre_snap, post_snap| {
815 let ignored = vars.alloc(vars[pre_snap].clone());
816 Statement::Snapshot(StatementSnapshot::new(as_usage(pre_snap), ignored, post_snap))
817 };
818 if let VarInfo::Const(const_value) = var_info
820 && let ConstValue::Enum(variant, value) = const_value.long(db)
821 {
822 let arm = &info.arms[variant.idx];
823 let output = arm.var_ids[0];
824 self.var_info.insert(output, VarInfo::Const(*value).wrap_with_snapshots(n_snapshots));
826 if self.variables[input].info.droppable.is_ok()
827 && self.variables[output].info.copyable.is_ok()
828 && let Ok(mut ty) = value.ty(db)
829 && let Some(mut stmt) = self.try_generate_const_statement(*value, output)
830 {
831 for _ in 0..n_snapshots {
833 let non_snap_var = Variable::with_default_context(db, ty, location);
834 ty = TypeLongId::Snapshot(ty).intern(db);
835 let pre_snap = self.variables.alloc(non_snap_var);
836 stmt.outputs_mut()[0] = pre_snap;
837 let take_snap = snapshot_stmt(self.variables, pre_snap, output);
838 statements.push(core::mem::replace(&mut stmt, take_snap));
839 }
840 statements.push(stmt);
841 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
842 }
843 } else if let VarInfo::Enum { variant, payload } = var_info {
844 let arm = &info.arms[variant.idx];
845 let variant_ty = variant.ty;
846 let output = arm.var_ids[0];
847 let payload = payload.as_ref().clone();
848 let unwrapped =
849 self.variables[input].info.droppable.is_ok().then_some(()).and_then(|_| {
850 let (extra_snapshots, inner) = payload.peel_snapshots();
851 match inner {
852 VarInfo::Var(var) if self.variables[var.var_id].info.copyable.is_ok() => {
853 Some((var.var_id, extra_snapshots))
854 }
855 VarInfo::Const(value) => {
856 let const_var = self
857 .variables
858 .alloc(Variable::with_default_context(db, variant_ty, location));
859 statements.push(self.try_generate_const_statement(*value, const_var)?);
860 Some((const_var, extra_snapshots))
861 }
862 _ => None,
863 }
864 });
865 self.var_info.insert(output, payload.wrap_with_snapshots(n_snapshots));
867 if let Some((mut unwrapped, extra_snapshots)) = unwrapped {
868 let total_snapshots = n_snapshots + extra_snapshots;
869 if total_snapshots != 0 {
870 for _ in 1..total_snapshots {
872 let ty = TypeLongId::Snapshot(self.variables[unwrapped].ty).intern(db);
873 let non_snap_var = Variable::with_default_context(self.db, ty, location);
874 let snapped = self.variables.alloc(non_snap_var);
875 statements.push(snapshot_stmt(self.variables, unwrapped, snapped));
876 unwrapped = snapped;
877 }
878 statements.push(snapshot_stmt(self.variables, unwrapped, output));
879 };
880 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
881 }
882 }
883 None
884 }
885
886 fn handle_extern_block_end(
891 &mut self,
892 info: &mut MatchExternInfo<'db>,
893 statements: &mut Vec<Statement<'db>>,
894 ) -> Option<BlockEnd<'db>> {
895 let db = self.db;
896 let (id, generic_args) = info.function.get_extern(db)?;
897 if self.nz_fns.contains(&id) {
898 let val = self.as_const(info.inputs[0].var_id)?;
899 let is_zero = match val.long(db) {
900 ConstValue::Int(v, _) => v.is_zero(),
901 ConstValue::Struct(s, _) => s.iter().all(|v| {
902 v.long(db).to_int().expect("Expected ConstValue::Int for size").is_zero()
903 }),
904 _ => unreachable!(),
905 };
906 Some(if is_zero {
907 BlockEnd::Goto(info.arms[0].block_id, Default::default())
908 } else {
909 let arm = &info.arms[1];
910 let nz_var = arm.var_ids[0];
911 let nz_val = ConstValue::NonZero(val).intern(db);
912 self.var_info.insert(nz_var, VarInfo::Const(nz_val));
913 statements.push(Statement::Const(StatementConst::new_flat(nz_val, nz_var)));
914 BlockEnd::Goto(arm.block_id, Default::default())
915 })
916 } else if self.eq_fns.contains(&id) {
917 let lhs = self.as_int(info.inputs[0].var_id);
918 let rhs = self.as_int(info.inputs[1].var_id);
919 if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
920 || (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
921 {
922 let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
923 let var = &self.variables[nz_input.var_id].clone();
924 let function = self.type_info.get(&var.ty)?.is_zero;
925 let unused_nz_var = Variable::with_default_context(
926 db,
927 corelib::core_nonzero_ty(db, var.ty),
928 var.location,
929 );
930 let unused_nz_var = self.variables.alloc(unused_nz_var);
931 return Some(BlockEnd::Match {
932 info: MatchInfo::Extern(MatchExternInfo {
933 function,
934 inputs: vec![nz_input],
935 arms: vec![
936 MatchArm {
937 arm_selector: MatchArmSelector::VariantId(
938 corelib::jump_nz_zero_variant(db, var.ty),
939 ),
940 block_id: info.arms[1].block_id,
941 var_ids: vec![],
942 },
943 MatchArm {
944 arm_selector: MatchArmSelector::VariantId(
945 corelib::jump_nz_nonzero_variant(db, var.ty),
946 ),
947 block_id: info.arms[0].block_id,
948 var_ids: vec![unused_nz_var],
949 },
950 ],
951 location: info.location,
952 }),
953 });
954 }
955 Some(BlockEnd::Goto(
956 info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
957 Default::default(),
958 ))
959 } else if self.uadd_fns.contains(&id)
960 || self.usub_fns.contains(&id)
961 || self.diff_fns.contains(&id)
962 || self.iadd_fns.contains(&id)
963 || self.isub_fns.contains(&id)
964 {
965 let rhs = self.as_int(info.inputs[1].var_id);
966 let lhs = self.as_int(info.inputs[0].var_id);
967 if let (Some(lhs), Some(rhs)) = (lhs, rhs) {
968 let ty = self.variables[info.arms[0].var_ids[0]].ty;
969 let range = self.type_value_ranges.get(&ty)?;
970 let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
971 lhs + rhs
972 } else {
973 lhs - rhs
974 };
975 let (arm_index, value) = match range.normalized(value) {
976 NormalizedResult::InRange(value) => (0, value),
977 NormalizedResult::Under(value) => (1, value),
978 NormalizedResult::Over(value) => (
979 if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) {
980 2
981 } else {
982 1
983 },
984 value,
985 ),
986 };
987 let arm = &info.arms[arm_index];
988 let actual_output = arm.var_ids[0];
989 let value = ConstValue::Int(value, ty).intern(db);
990 self.var_info.insert(actual_output, VarInfo::Const(value));
991 statements.push(Statement::Const(StatementConst::new_flat(value, actual_output)));
992 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
993 }
994 if let Some(rhs) = rhs {
995 if rhs.is_zero() && !self.diff_fns.contains(&id) {
996 let arm = &info.arms[0];
997 self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[0]));
998 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
999 }
1000 if rhs.is_one() && !self.diff_fns.contains(&id) {
1001 let ty = self.variables[info.arms[0].var_ids[0]].ty;
1002 let ty_info = self.type_info.get(&ty)?;
1003 let function = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
1004 ty_info.inc?
1005 } else {
1006 ty_info.dec?
1007 };
1008 let enum_ty = function.signature(db).ok()?.return_type;
1009 let TypeLongId::Concrete(ConcreteTypeId::Enum(concrete_enum_id)) =
1010 enum_ty.long(db)
1011 else {
1012 return None;
1013 };
1014 let result = self.variables.alloc(Variable::with_default_context(
1015 db,
1016 function.signature(db).unwrap().return_type,
1017 info.location,
1018 ));
1019 statements.push(Statement::Call(StatementCall {
1020 function,
1021 inputs: vec![info.inputs[0]],
1022 with_coupon: false,
1023 outputs: vec![result],
1024 location: info.location,
1025 }));
1026 return Some(BlockEnd::Match {
1027 info: MatchInfo::Enum(MatchEnumInfo {
1028 concrete_enum_id: *concrete_enum_id,
1029 input: VarUsage { var_id: result, location: info.location },
1030 arms: core::mem::take(&mut info.arms),
1031 location: info.location,
1032 }),
1033 });
1034 }
1035 }
1036 if let Some(lhs) = lhs
1037 && lhs.is_zero()
1038 && (self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id))
1039 {
1040 let arm = &info.arms[0];
1041 self.var_info.insert(arm.var_ids[0], VarInfo::Var(info.inputs[1]));
1042 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1043 }
1044 None
1045 } else if let Some(reversed) = self.downcast_fns.get(&id) {
1046 let range = |ty: TypeId<'_>| {
1047 Some(if let Some(range) = self.type_value_ranges.get(&ty) {
1048 range.clone()
1049 } else {
1050 let (min, max) = corelib::try_extract_bounded_int_type_ranges(db, ty)?;
1051 TypeRange { min, max }
1052 })
1053 };
1054 let (success_arm, failure_arm) = if *reversed { (1, 0) } else { (0, 1) };
1055 let input_var = info.inputs[0].var_id;
1056 let in_ty = self.variables[input_var].ty;
1057 let success_output = info.arms[success_arm].var_ids[0];
1058 let out_ty = self.variables[success_output].ty;
1059 let out_range = range(out_ty)?;
1060 let Some(value) = self.as_int(input_var) else {
1061 let in_range = range(in_ty)?;
1062 return if in_range.min < out_range.min || in_range.max > out_range.max {
1063 None
1064 } else {
1065 let generic_args = [in_ty, out_ty].map(GenericArgumentId::Type).to_vec();
1066 let function = db.core_info().upcast_fn.concretize(db, generic_args);
1067 statements.push(Statement::Call(StatementCall {
1068 function: function.lowered(db),
1069 inputs: vec![info.inputs[0]],
1070 with_coupon: false,
1071 outputs: vec![success_output],
1072 location: info.location,
1073 }));
1074 return Some(BlockEnd::Goto(
1075 info.arms[success_arm].block_id,
1076 Default::default(),
1077 ));
1078 };
1079 };
1080 let value = if in_ty == self.felt252 {
1081 felt252_for_downcast(value, &out_range.min)
1082 } else {
1083 value.clone()
1084 };
1085 Some(if let NormalizedResult::InRange(value) = out_range.normalized(value) {
1086 let value = ConstValue::Int(value, out_ty).intern(db);
1087 self.var_info.insert(success_output, VarInfo::Const(value));
1088 statements.push(Statement::Const(StatementConst::new_flat(value, success_output)));
1089 BlockEnd::Goto(info.arms[success_arm].block_id, Default::default())
1090 } else {
1091 BlockEnd::Goto(info.arms[failure_arm].block_id, Default::default())
1092 })
1093 } else if id == self.bounded_int_constrain {
1094 let input_var = info.inputs[0].var_id;
1095 let value = self.as_int(input_var)?;
1096 let generic_arg = generic_args[1];
1097 let constrain_value = extract_matches!(generic_arg, GenericArgumentId::Constant)
1098 .long(db)
1099 .to_int()
1100 .expect("Expected ConstValue::Int for size");
1101 let arm_idx = if value < constrain_value { 0 } else { 1 };
1102 let output = info.arms[arm_idx].var_ids[0];
1103 statements.push(self.propagate_const_and_get_statement(value.clone(), output));
1104 Some(BlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()))
1105 } else if id == self.array_get {
1106 let index = self.as_int(info.inputs[1].var_id)?.to_usize()?;
1107 if let Some(VarInfo::Snapshot(arr_info)) = self.var_info.get(&info.inputs[0].var_id)
1108 && let VarInfo::Array(infos) = arr_info.as_ref()
1109 {
1110 match infos.get(index) {
1111 Some(Some(output_var_info)) => {
1112 let arm = &info.arms[0];
1113 let output_var_info = output_var_info.clone();
1114 let box_info =
1115 VarInfo::Box(VarInfo::Snapshot(output_var_info.clone().into()).into());
1116 self.var_info.insert(arm.var_ids[0], box_info);
1117 if let VarInfo::Const(value) = output_var_info {
1118 let value_ty = value.ty(db).ok()?;
1119 let value_box_ty = corelib::core_box_ty(db, value_ty);
1120 let location = info.location;
1121 let boxed_var =
1122 Variable::with_default_context(db, value_box_ty, location);
1123 let boxed = self.variables.alloc(boxed_var.clone());
1124 let unused_boxed = self.variables.alloc(boxed_var);
1125 let snapped = self.variables.alloc(Variable::with_default_context(
1126 db,
1127 TypeLongId::Snapshot(value_box_ty).intern(db),
1128 location,
1129 ));
1130 statements.extend([
1131 Statement::Const(StatementConst::new_boxed(value, boxed)),
1132 Statement::Snapshot(StatementSnapshot {
1133 input: VarUsage { var_id: boxed, location },
1134 outputs: [unused_boxed, snapped],
1135 }),
1136 Statement::Call(StatementCall {
1137 function: self
1138 .box_forward_snapshot
1139 .concretize(db, vec![GenericArgumentId::Type(value_ty)])
1140 .lowered(db),
1141 inputs: vec![VarUsage { var_id: snapped, location }],
1142 with_coupon: false,
1143 outputs: vec![arm.var_ids[0]],
1144 location: info.location,
1145 }),
1146 ]);
1147 return Some(BlockEnd::Goto(arm.block_id, Default::default()));
1148 }
1149 }
1150 None => {
1151 return Some(BlockEnd::Goto(info.arms[1].block_id, Default::default()));
1152 }
1153 Some(None) => {}
1154 }
1155 }
1156 if index.is_zero()
1157 && let [success, failure] = info.arms.as_mut_slice()
1158 {
1159 let arr = info.inputs[0].var_id;
1160 let unused_arr_output0 = self.variables.alloc(self.variables[arr].clone());
1161 let unused_arr_output1 = self.variables.alloc(self.variables[arr].clone());
1162 info.inputs.truncate(1);
1163 info.function = GenericFunctionId::Extern(self.array_snapshot_pop_front)
1164 .concretize(db, generic_args)
1165 .lowered(db);
1166 success.var_ids.insert(0, unused_arr_output0);
1167 failure.var_ids.insert(0, unused_arr_output1);
1168 }
1169 None
1170 } else if id == self.array_pop_front {
1171 let VarInfo::Array(var_infos) = self.var_info.get(&info.inputs[0].var_id)? else {
1172 return None;
1173 };
1174 if let Some(first) = var_infos.first() {
1175 if let Some(first) = first.as_ref().cloned() {
1176 let arm = &info.arms[0];
1177 self.var_info.insert(arm.var_ids[0], VarInfo::Array(var_infos[1..].to_vec()));
1178 self.var_info.insert(arm.var_ids[1], VarInfo::Box(first.into()));
1179 }
1180 None
1181 } else {
1182 let arm = &info.arms[1];
1183 self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1184 Some(BlockEnd::Goto(
1185 arm.block_id,
1186 VarRemapping {
1187 remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1188 },
1189 ))
1190 }
1191 } else if id == self.array_snapshot_pop_back || id == self.array_snapshot_pop_front {
1192 let var_info = self.var_info.get(&info.inputs[0].var_id)?;
1193 let desnapped = try_extract_matches!(var_info, VarInfo::Snapshot)?;
1194 let element_var_infos = try_extract_matches!(desnapped.as_ref(), VarInfo::Array)?;
1195 if element_var_infos.is_empty() {
1197 let arm = &info.arms[1];
1198 self.var_info.insert(arm.var_ids[0], VarInfo::Array(vec![]));
1199 Some(BlockEnd::Goto(
1200 arm.block_id,
1201 VarRemapping {
1202 remapping: FromIterator::from_iter([(arm.var_ids[0], info.inputs[0])]),
1203 },
1204 ))
1205 } else {
1206 None
1207 }
1208 } else {
1209 None
1210 }
1211 }
1212
1213 fn as_const(&self, var_id: VariableId) -> Option<ConstValueId<'db>> {
1215 try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const).copied()
1216 }
1217
1218 fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
1220 match self.as_const(var_id)?.long(self.db) {
1221 ConstValue::Int(value, _) => Some(value),
1222 ConstValue::NonZero(const_value) => {
1223 if let ConstValue::Int(value, _) = const_value.long(self.db) {
1224 Some(value)
1225 } else {
1226 None
1227 }
1228 }
1229 _ => None,
1230 }
1231 }
1232
1233 fn maybe_replace_inputs(&self, inputs: &mut [VarUsage<'db>]) {
1235 for input in inputs {
1236 self.maybe_replace_input(input);
1237 }
1238 }
1239
1240 fn maybe_replace_input(&self, input: &mut VarUsage<'db>) {
1242 if let Some(VarInfo::Var(new_var)) = self.var_info.get(&input.var_id) {
1243 *input = *new_var;
1244 }
1245 }
1246
1247 fn try_get_specialization_arg(
1250 &self,
1251 var_info: VarInfo<'db>,
1252 ty: TypeId<'db>,
1253 unknown_vars: &mut Vec<VarUsage<'db>>,
1254 ) -> Option<SpecializationArg<'db>> {
1255 if self.db.type_size_info(ty).ok()? == TypeSizeInformation::ZeroSized {
1256 return None;
1258 }
1259
1260 match var_info {
1261 VarInfo::Const(value) => Some(SpecializationArg::Const { value, boxed: false }),
1262 VarInfo::Box(info) => try_extract_matches!(info.as_ref(), VarInfo::Const)
1263 .map(|value| SpecializationArg::Const { value: *value, boxed: true }),
1264 VarInfo::Snapshot(info) => {
1265 let desnap_ty = *extract_matches!(ty.long(self.db), TypeLongId::Snapshot);
1266 let mut local_unknown_vars: Vec<VarUsage<'db>> = Vec::new();
1268 let inner = self.try_get_specialization_arg(
1269 info.as_ref().clone(),
1270 desnap_ty,
1271 &mut local_unknown_vars,
1272 )?;
1273 unknown_vars.extend(local_unknown_vars);
1274 Some(SpecializationArg::Snapshot(Box::new(inner)))
1275 }
1276 VarInfo::Array(infos) => {
1277 let TypeLongId::Concrete(concrete_ty) = ty.long(self.db) else {
1278 unreachable!("Expected a concrete type");
1279 };
1280 let [GenericArgumentId::Type(inner_ty)] = &concrete_ty.generic_args(self.db)[..]
1281 else {
1282 unreachable!("Expected a single type generic argument");
1283 };
1284 let mut vars = vec![];
1286 let mut args = vec![];
1287 for info in infos {
1288 let info = info?;
1289 let arg = self.try_get_specialization_arg(info, *inner_ty, &mut vars)?;
1290 args.push(arg);
1291 }
1292 if !args.is_empty()
1293 && args.iter().all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1294 {
1295 return None;
1296 }
1297 unknown_vars.extend(vars);
1298 Some(SpecializationArg::Array(*inner_ty, args))
1299 }
1300 VarInfo::Struct(infos) => {
1301 let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct)) =
1302 ty.long(self.db)
1303 else {
1304 return None;
1306 };
1307
1308 let members = self.db.concrete_struct_members(*concrete_struct).unwrap();
1309 let mut struct_args = Vec::new();
1310 let mut vars = vec![];
1312 for (member, opt_var_info) in zip_eq(members.values(), infos) {
1313 let var_info = opt_var_info?;
1314 let arg = self.try_get_specialization_arg(var_info, member.ty, &mut vars)?;
1315 struct_args.push(arg);
1316 }
1317 if !struct_args.is_empty()
1318 && struct_args
1319 .iter()
1320 .all(|arg| matches!(arg, SpecializationArg::NotSpecialized))
1321 {
1322 return None;
1323 }
1324 unknown_vars.extend(vars);
1325 Some(SpecializationArg::Struct(struct_args))
1326 }
1327 VarInfo::Enum { variant, payload } => {
1328 let mut local_unknown_vars = vec![];
1329 let payload_arg = self.try_get_specialization_arg(
1330 payload.as_ref().clone(),
1331 variant.ty,
1332 &mut local_unknown_vars,
1333 )?;
1334
1335 unknown_vars.extend(local_unknown_vars);
1336 Some(SpecializationArg::Enum { variant, payload: Box::new(payload_arg) })
1337 }
1338 VarInfo::Var(var_usage) => {
1339 unknown_vars.push(var_usage);
1340 Some(SpecializationArg::NotSpecialized)
1341 }
1342 }
1343 }
1344
1345 pub fn should_skip_const_folding(&self, db: &'db dyn Database) -> bool {
1347 if db.optimizations().skip_const_folding() {
1348 return true;
1349 }
1350
1351 if self.caller_base.base_semantic_function(db).generic_function(db)
1354 == GenericFunctionWithBodyId::Free(self.libfunc_info.panic_with_const_felt252)
1355 {
1356 return true;
1357 }
1358 false
1359 }
1360}
1361
1362fn var_info_if_copy<'db>(
1364 variables: &VariableArena<'db>,
1365 input: VarUsage<'db>,
1366) -> Option<VarInfo<'db>> {
1367 variables[input.var_id].info.copyable.is_ok().then_some(VarInfo::Var(input))
1368}
1369
/// Salsa-tracked query that builds the `ConstFoldingLibfuncInfo` for `db`.
///
/// The result is memoized by salsa and handed out by reference (`returns(ref)`).
#[salsa::tracked(returns(ref))]
fn priv_const_folding_info<'db>(
    db: &'db dyn Database,
) -> crate::optimizations::const_folding::ConstFoldingLibfuncInfo<'db> {
    ConstFoldingLibfuncInfo::new(db)
}
1377
/// Ids of the core-library functions (and related data) that the const-folding pass recognizes
/// or emits.
///
/// Constructed by `ConstFoldingLibfuncInfo::new`, which the salsa-tracked
/// `priv_const_folding_info` query calls.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
pub struct ConstFoldingLibfuncInfo<'db> {
    /// The `felt252_sub` extern function.
    felt_sub: ExternFunctionId<'db>,
    /// The `felt252_add` extern function.
    felt_add: ExternFunctionId<'db>,
    /// The `felt252_mul` extern function.
    felt_mul: ExternFunctionId<'db>,
    /// The `felt252_div` extern function.
    felt_div: ExternFunctionId<'db>,
    /// The `core::box::into_box` extern function.
    into_box: ExternFunctionId<'db>,
    /// The `core::box::unbox` extern function.
    unbox: ExternFunctionId<'db>,
    /// The generic `core::box::box_forward_snapshot` function.
    box_forward_snapshot: GenericFunctionId<'db>,
    /// `{ty}_eq` for every native unsigned and signed integer type (`u8`..`u128`, `i8`..`i128`).
    eq_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{ty}_overflowing_add` for the unsigned integer types `u8`..`u128`.
    uadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{ty}_overflowing_sub` for the unsigned integer types `u8`..`u128`.
    usub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{ty}_diff` for the signed integer types `i8`..`i128`.
    diff_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{ty}_overflowing_add_impl` for the signed integer types `i8`..`i128`.
    iadd_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `{ty}_overflowing_sub_impl` for the signed integer types `i8`..`i128`.
    isub_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_mul` plus `{ty}_wide_mul` for `u8`..`u64` and `i8`..`i64`.
    wide_mul_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// `bounded_int_div_rem` plus `{ty}_safe_divmod` for `u8`..`u128`.
    div_rem_fns: OrderedHashSet<ExternFunctionId<'db>>,
    /// The `bounded_int_add` extern function.
    bounded_int_add: ExternFunctionId<'db>,
    /// The `bounded_int_sub` extern function.
    bounded_int_sub: ExternFunctionId<'db>,
    /// The `bounded_int_constrain` extern function.
    bounded_int_constrain: ExternFunctionId<'db>,
    /// The `array_get` extern function.
    array_get: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_front` extern function.
    array_snapshot_pop_front: ExternFunctionId<'db>,
    /// The `array_snapshot_pop_back` extern function.
    array_snapshot_pop_back: ExternFunctionId<'db>,
    /// The `array_len` extern function.
    array_len: ExternFunctionId<'db>,
    /// The `array_new` extern function.
    array_new: ExternFunctionId<'db>,
    /// The `array_append` extern function.
    array_append: ExternFunctionId<'db>,
    /// The `array_pop_front` extern function.
    array_pop_front: ExternFunctionId<'db>,
    /// The `starknet::storage_access::storage_base_address_from_felt252` extern function.
    storage_base_address_from_felt252: ExternFunctionId<'db>,
    /// The generic `starknet::storage_access::storage_base_address_const` function.
    storage_base_address_const: GenericFunctionId<'db>,
    /// The lowered `core::panic_with_felt252` function.
    panic_with_felt252: FunctionId<'db>,
    /// The `core::panic_with_const_felt252` free function; const folding is skipped inside it.
    pub panic_with_const_felt252: FreeFunctionId<'db>,
    /// The lowered `core::panics::panic_with_byte_array` function.
    panic_with_byte_array: FunctionId<'db>,
    /// Per-integer-type helper functions (`is_zero`, and `inc`/`dec` where available).
    type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>>,
    /// Shared constant-calculation info from `db.const_calc_info()`, also reachable via `Deref`.
    const_calculation_info: Arc<ConstCalcInfo<'db>>,
}
impl<'db> ConstFoldingLibfuncInfo<'db> {
    /// Resolves, from the core library, every function id the const-folding pass matches against
    /// or emits, and builds the per-integer-type helper table.
    fn new(db: &'db dyn Database) -> Self {
        let core = ModuleHelper::core(db);
        let box_module = core.submodule("box");
        let integer_module = core.submodule("integer");
        let internal_module = core.submodule("internal");
        let bounded_int_module = internal_module.submodule("bounded_int");
        let num_module = internal_module.submodule("num");
        let array_module = core.submodule("array");
        let starknet_module = core.submodule("starknet");
        let storage_access_module = starknet_module.submodule("storage_access");
        // Native unsigned / signed integer type names used to build the per-type function sets.
        let utypes = ["u8", "u16", "u32", "u64", "u128"];
        let itypes = ["i8", "i16", "i32", "i64", "i128"];
        // `{ty}_eq` for every native integer type.
        let eq_fns = OrderedHashSet::<_>::from_iter(
            chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(&format!("{ty}_eq"))),
        );
        // Unsigned add/sub report overflow via `{ty}_overflowing_add` / `{ty}_overflowing_sub`.
        let uadd_fns = OrderedHashSet::<_>::from_iter(
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_add"))),
        );
        let usub_fns = OrderedHashSet::<_>::from_iter(
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_overflowing_sub"))),
        );
        // Signed `{ty}_diff` functions.
        let diff_fns = OrderedHashSet::<_>::from_iter(
            itypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_diff"))),
        );
        // Signed add/sub use the `_impl` variants of the overflowing functions.
        let iadd_fns =
            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
                integer_module.extern_function_id(&format!("{ty}_overflowing_add_impl"))
            }));
        let isub_fns =
            OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
                integer_module.extern_function_id(&format!("{ty}_overflowing_sub_impl"))
            }));
        // Multiplications: the generic bounded-int one plus `{ty}_wide_mul` for the types listed
        // (note: no 128-bit types here).
        let wide_mul_fns = OrderedHashSet::<_>::from_iter(chain!(
            [bounded_int_module.extern_function_id("bounded_int_mul")],
            ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
                .map(|ty| integer_module.extern_function_id(&format!("{ty}_wide_mul"))),
        ));
        // Division with remainder: the generic bounded-int one plus unsigned `{ty}_safe_divmod`.
        let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
            [bounded_int_module.extern_function_id("bounded_int_div_rem")],
            utypes.map(|ty| integer_module.extern_function_id(&format!("{ty}_safe_divmod"))),
        ));
        // Each entry is `(type name, as_bounded_int, inc_dec)`:
        // - `as_bounded_int`: use the generic `bounded_int_is_zero` instead of a dedicated
        //   `core::integer::{ty}_is_zero`.
        // - `inc_dec`: look up `core::internal::num::{ty}_inc` / `{ty}_dec`.
        let type_info: OrderedHashMap<TypeId<'db>, TypeInfo<'db>> = OrderedHashMap::from_iter(
            [
                ("u8", false, true),
                ("u16", false, true),
                ("u32", false, true),
                ("u64", false, true),
                ("u128", false, true),
                ("u256", false, false),
                ("i8", true, true),
                ("i16", true, true),
                ("i32", true, true),
                ("i64", true, true),
                ("i128", true, true),
            ]
            .map(|(ty_name, as_bounded_int, inc_dec): (&'static str, bool, bool)| {
                let ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, ty_name), vec![]);
                // NOTE(review): the formatted names are interned as `SmolStrId`s, presumably to
                // obtain a `&str` that outlives this closure for `function_id`; confirm.
                let is_zero = if as_bounded_int {
                    bounded_int_module
                        .function_id("bounded_int_is_zero", vec![GenericArgumentId::Type(ty)])
                } else {
                    integer_module.function_id(
                        SmolStrId::from(db, format!("{ty_name}_is_zero")).long(db).as_str(),
                        vec![],
                    )
                }
                .lowered(db);
                let (inc, dec) = if inc_dec {
                    (
                        Some(
                            num_module
                                .function_id(
                                    SmolStrId::from(db, format!("{ty_name}_inc")).long(db).as_str(),
                                    vec![],
                                )
                                .lowered(db),
                        ),
                        Some(
                            num_module
                                .function_id(
                                    SmolStrId::from(db, format!("{ty_name}_dec")).long(db).as_str(),
                                    vec![],
                                )
                                .lowered(db),
                        ),
                    )
                } else {
                    (None, None)
                };
                let info = TypeInfo { is_zero, inc, dec };
                (ty, info)
            }),
        );
        Self {
            felt_sub: core.extern_function_id("felt252_sub"),
            felt_add: core.extern_function_id("felt252_add"),
            felt_mul: core.extern_function_id("felt252_mul"),
            felt_div: core.extern_function_id("felt252_div"),
            into_box: box_module.extern_function_id("into_box"),
            unbox: box_module.extern_function_id("unbox"),
            box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
            eq_fns,
            uadd_fns,
            usub_fns,
            diff_fns,
            iadd_fns,
            isub_fns,
            wide_mul_fns,
            div_rem_fns,
            bounded_int_add: bounded_int_module.extern_function_id("bounded_int_add"),
            bounded_int_sub: bounded_int_module.extern_function_id("bounded_int_sub"),
            bounded_int_constrain: bounded_int_module.extern_function_id("bounded_int_constrain"),
            array_get: array_module.extern_function_id("array_get"),
            array_snapshot_pop_front: array_module.extern_function_id("array_snapshot_pop_front"),
            array_snapshot_pop_back: array_module.extern_function_id("array_snapshot_pop_back"),
            array_len: array_module.extern_function_id("array_len"),
            array_new: array_module.extern_function_id("array_new"),
            array_append: array_module.extern_function_id("array_append"),
            array_pop_front: array_module.extern_function_id("array_pop_front"),
            storage_base_address_from_felt252: storage_access_module
                .extern_function_id("storage_base_address_from_felt252"),
            storage_base_address_const: storage_access_module
                .generic_function_id("storage_base_address_const"),
            panic_with_felt252: core.function_id("panic_with_felt252", vec![]).lowered(db),
            panic_with_const_felt252: core.free_function_id("panic_with_const_felt252"),
            panic_with_byte_array: core
                .submodule("panics")
                .function_id("panic_with_byte_array", vec![])
                .lowered(db),
            type_info,
            const_calculation_info: db.const_calc_info(),
        }
    }
}
1581
1582impl<'db> std::ops::Deref for ConstFoldingContext<'db, '_> {
1583 type Target = ConstFoldingLibfuncInfo<'db>;
1584 fn deref(&self) -> &ConstFoldingLibfuncInfo<'db> {
1585 self.libfunc_info
1586 }
1587}
1588
1589impl<'a> std::ops::Deref for ConstFoldingLibfuncInfo<'a> {
1590 type Target = ConstCalcInfo<'a>;
1591 fn deref(&self) -> &ConstCalcInfo<'a> {
1592 &self.const_calculation_info
1593 }
1594}
1595
/// Helper functions associated with one integer type in `ConstFoldingLibfuncInfo::type_info`.
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct TypeInfo<'db> {
    /// The lowered zero-test function for this type: `bounded_int_is_zero` for signed types,
    /// or `{ty}_is_zero` from `core::integer` for the others.
    is_zero: FunctionId<'db>,
    /// The lowered `core::internal::num::{ty}_inc`, if registered for this type (`None` for
    /// `u256`); used when folding an add of one.
    inc: Option<FunctionId<'db>>,
    /// The lowered `core::internal::num::{ty}_dec`, if registered for this type (`None` for
    /// `u256`); used when folding a sub of one.
    dec: Option<FunctionId<'db>>,
}
1606
/// Places an integer relative to a type's value range, wrapping it if it falls outside.
trait TypeRangeNormalizer {
    /// Classifies `value` as in, above, or below the range.
    ///
    /// An out-of-range value is shifted by one range width; see `NormalizedResult`.
    fn normalized(&self, value: BigInt) -> NormalizedResult;
}
1612impl TypeRangeNormalizer for TypeRange {
1613 fn normalized(&self, value: BigInt) -> NormalizedResult {
1614 if value < self.min {
1615 NormalizedResult::Under(value - &self.min + &self.max + 1)
1616 } else if value > self.max {
1617 NormalizedResult::Over(value + &self.min - &self.max - 1)
1618 } else {
1619 NormalizedResult::InRange(value)
1620 }
1621 }
1622}
1623
/// Outcome of `TypeRangeNormalizer::normalized`.
enum NormalizedResult {
    /// The value was already inside the range; carries it unchanged.
    InRange(BigInt),
    /// The value was above the range; carries it reduced by one range width.
    Over(BigInt),
    /// The value was below the range; carries it increased by one range width.
    Under(BigInt),
}