1use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_semantic as semantic;
9use cairo_lang_semantic::ConcreteFunction;
10use cairo_lang_semantic::corelib::{core_array_felt252_ty, core_module, get_ty_by_name, unit_ty};
11use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
12use cairo_lang_semantic::items::imp::ImplId;
13use cairo_lang_utils::{Intern, LookupIntern};
14use itertools::{Itertools, chain, zip_eq};
15use semantic::{TypeId, TypeLongId};
16
17use crate::borrow_check::Demand;
18use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
19use crate::borrow_check::demand::{AuxCombine, DemandReporter};
20use crate::db::LoweringGroup;
21use crate::ids::{
22 ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, GeneratedFunction,
23 SemanticFunctionIdEx,
24};
25use crate::lower::context::{VarRequest, VariableAllocator};
26use crate::{
27 BlockEnd, BlockId, Lowered, MatchInfo, Statement, StatementCall, StatementStructConstruct,
28 StatementStructDestructure, VarRemapping, VarUsage, VariableId,
29};
30
/// The demand type of the destructor-insertion analysis: demands on [`VariableId`]s with `()`
/// use positions, carrying a [`PanicState`] as auxiliary information.
pub type DestructAdderDemand = Demand<VariableId, (), PanicState>;
32
/// How a group of destructor calls is inserted into the lowered function; used to group and order
/// destruction entries in `add_destructs`.
///
/// The derived `Ord` (declaration order) is part of the sort key used there, so reordering the
/// variants changes the order in which destruction groups are applied.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum AddDestructFlowType {
    /// Plain `Destruct` calls, inserted before the statement at the entry's position.
    Plain,
    /// Panic-destruct chain following the creation of a panic variable by a statement, or at a
    /// `return` of a panic variable.
    PanicVar,
    /// Panic-destruct chain at the start of arm 1 of a match whose bound tuple starts with `Panic`.
    PanicPostMatch,
}
43
/// Analyzer state for the destructor-insertion phase (`add_destructs`).
///
/// Runs as a backward analysis over `lowered` and collects in `destructions` every destructor call
/// that has to be inserted.
pub struct DestructAdder<'a> {
    db: &'a dyn LoweringGroup,
    /// The lowered function being analyzed.
    lowered: &'a Lowered,
    /// Destructor calls recorded by `drop_aux`, inserted after the analysis finishes.
    destructions: Vec<DestructionEntry>,
    /// The core `Panic` type.
    panic_ty: TypeId,
    /// `(Panic, Array<felt252>)`. Per its name, this is the actual return type of a never function
    /// once panics are lowered. A statement whose single output has this type (or `panic_ty`) marks
    /// a panic location.
    never_fn_actual_return_ty: TypeId,
    /// Whether the analyzed function is itself a panic-destruct function. If so, every `return` is
    /// treated as a panic location, so variables dropped there may use panic destructors.
    is_panic_destruct_fn: bool,
}
54
/// A destructor call that needs to be inserted.
enum DestructionEntry {
    /// A plain `Destruct` call.
    Plain(PlainDestructionEntry),
    /// A `PanicDestruct` call, chained through a `Panic` value.
    Panic(PanicDeconstructionEntry),
}
62
/// A plain destructor call to insert.
struct PlainDestructionEntry {
    /// Where to insert the call: it is spliced in before this statement index of the block.
    position: StatementLocation,
    /// The variable to destroy.
    var_id: VariableId,
    /// The impl whose `destruct_fn` is called on `var_id`.
    impl_id: ImplId,
}
/// A panic destructor call to insert, chained onto the `Panic` available at `panic_location`.
struct PanicDeconstructionEntry {
    /// Where the `Panic` value this call consumes becomes available.
    panic_location: PanicLocation,
    /// The variable to destroy.
    var_id: VariableId,
    /// The impl whose `panic_destruct_fn` is called on `var_id`.
    impl_id: ImplId,
}
73
74impl DestructAdder<'_> {
75 fn set_post_stmt_destruct(
77 &mut self,
78 introductions: &[VariableId],
79 info: &mut DestructAdderDemand,
80 block_id: BlockId,
81 statement_index: usize,
82 ) {
83 if let [panic_var] = introductions[..] {
84 let var = &self.lowered.variables[panic_var];
85 if [self.panic_ty, self.never_fn_actual_return_ty].contains(&var.ty) {
86 info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicVar {
87 statement_location: (block_id, statement_index),
88 }]);
89 }
90 }
91 }
92
93 fn set_post_match_state(
96 &mut self,
97 introduced_vars: &[VariableId],
98 info: &mut DestructAdderDemand,
99 match_block_id: BlockId,
100 target_block_id: BlockId,
101 arm_idx: usize,
102 ) {
103 if arm_idx != 1 {
104 return;
106 }
107 if let [err_var] = introduced_vars[..] {
108 let var = &self.lowered.variables[err_var];
109
110 let long_ty = var.ty.lookup_intern(self.db);
111 let TypeLongId::Tuple(tys) = long_ty else {
112 return;
113 };
114 if tys.first() == Some(&self.panic_ty) {
115 info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicMatch {
116 match_block_id,
117 target_block_id,
118 }]);
119 }
120 }
121 }
122}
123
124impl DemandReporter<VariableId, PanicState> for DestructAdder<'_> {
125 type IntroducePosition = StatementLocation;
126 type UsePosition = ();
127
128 fn drop_aux(
129 &mut self,
130 position: StatementLocation,
131 var_id: VariableId,
132 panic_state: PanicState,
133 ) {
134 let var = &self.lowered.variables[var_id];
135 if var.info.droppable.is_ok() {
139 return;
140 };
141 if let Ok(impl_id) = var.info.destruct_impl.clone() {
143 self.destructions.push(DestructionEntry::Plain(PlainDestructionEntry {
144 position,
145 var_id,
146 impl_id,
147 }));
148 return;
149 }
150 if let Ok(impl_id) = var.info.panic_destruct_impl.clone() {
152 if let PanicState::EndsWithPanic(panic_locations) = panic_state {
153 for panic_location in panic_locations {
154 self.destructions.push(DestructionEntry::Panic(PanicDeconstructionEntry {
155 panic_location,
156 var_id,
157 impl_id,
158 }));
159 }
160 return;
161 }
162 }
163
164 panic!("Borrow checker should have caught this.")
165 }
166}
167
/// Auxiliary state attached to each demand in the backward analysis.
///
/// It records whether the flow from this point is known to end with a panic, and where the `Panic`
/// value becomes available.
#[derive(Clone, Default)]
pub enum PanicState {
    /// The flow ends with a panic. The locations are all the places where a `Panic` value may
    /// first become available along this flow; on merge, every branch must be in this state.
    EndsWithPanic(Vec<PanicLocation>),
    /// No panic is guaranteed. This is the default state.
    #[default]
    Otherwise,
}
179impl AuxCombine for PanicState {
181 fn merge<'a, I: Iterator<Item = &'a Self>>(iter: I) -> Self
182 where
183 Self: 'a,
184 {
185 let mut panic_locations = vec![];
186 for state in iter {
187 if let Self::EndsWithPanic(locations) = state {
188 panic_locations.extend_from_slice(locations);
189 } else {
190 return Self::Otherwise;
191 }
192 }
193
194 Self::EndsWithPanic(panic_locations)
195 }
196}
197
/// A place where a `Panic` value first becomes available, so that panic-destructor calls can be
/// chained onto it.
#[derive(Clone)]
pub enum PanicLocation {
    /// The `Panic` comes from the statement at `statement_location`. It is that statement's single
    /// output: a struct construct producing `Panic`, or a call returning `(Panic, Array<felt252>)`.
    /// When the index equals the block's statement count, it is instead the first value returned
    /// by the block's `return`.
    PanicVar { statement_location: StatementLocation },
    /// The `Panic` is the first element of the tuple bound by arm 1 of the enum match that ends
    /// `match_block_id`; `target_block_id` is that arm's block.
    PanicMatch { match_block_id: BlockId, target_block_id: BlockId },
}
206
207impl Analyzer<'_> for DestructAdder<'_> {
208 type Info = DestructAdderDemand;
209
210 fn visit_stmt(
211 &mut self,
212 info: &mut Self::Info,
213 (block_id, statement_index): StatementLocation,
214 stmt: &Statement,
215 ) {
216 self.set_post_stmt_destruct(stmt.outputs(), info, block_id, statement_index);
217 info.variables_introduced(self, stmt.outputs(), (block_id, statement_index + 1));
219 info.variables_used(self, stmt.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())));
220 }
221
222 fn visit_goto(
223 &mut self,
224 info: &mut Self::Info,
225 _statement_location: StatementLocation,
226 _target_block_id: BlockId,
227 remapping: &VarRemapping,
228 ) {
229 info.apply_remapping(self, remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))));
230 }
231
232 fn merge_match(
233 &mut self,
234 (block_id, _statement_index): StatementLocation,
235 match_info: &MatchInfo,
236 infos: impl Iterator<Item = Self::Info>,
237 ) -> Self::Info {
238 let arm_demands = zip_eq(match_info.arms(), infos)
239 .enumerate()
240 .map(|(arm_idx, (arm, mut demand))| {
241 let use_position = (arm.block_id, 0);
242 self.set_post_match_state(
243 &arm.var_ids,
244 &mut demand,
245 block_id,
246 arm.block_id,
247 arm_idx,
248 );
249 demand.variables_introduced(self, &arm.var_ids, use_position);
250 (demand, use_position)
251 })
252 .collect_vec();
253 let mut demand = DestructAdderDemand::merge_demands(&arm_demands, self);
254 demand.variables_used(
255 self,
256 match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
257 );
258 demand
259 }
260
261 fn info_from_return(
262 &mut self,
263 statement_location: StatementLocation,
264 vars: &[VarUsage],
265 ) -> Self::Info {
266 let mut info = DestructAdderDemand::default();
267 if self.is_panic_destruct_fn {
269 info.aux =
270 PanicState::EndsWithPanic(vec![PanicLocation::PanicVar { statement_location }]);
271 }
272
273 info.variables_used(self, vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())));
274 info
275 }
276}
277
278fn panic_ty(db: &dyn LoweringGroup) -> semantic::TypeId {
279 get_ty_by_name(db, core_module(db), "Panic".into(), vec![])
280}
281
282pub fn add_destructs(
288 db: &dyn LoweringGroup,
289 function_id: ConcreteFunctionWithBodyId,
290 lowered: &mut Lowered,
291) {
292 if lowered.blocks.is_empty() {
293 return;
294 }
295
296 let Ok(is_panic_destruct_fn) = function_id.is_panic_destruct_fn(db) else {
297 return;
298 };
299
300 let panic_ty = panic_ty(db);
301 let felt_arr_ty = core_array_felt252_ty(db);
302 let never_fn_actual_return_ty = TypeLongId::Tuple(vec![panic_ty, felt_arr_ty]).intern(db);
303 let checker = DestructAdder {
304 db,
305 lowered,
306 destructions: vec![],
307 panic_ty,
308 never_fn_actual_return_ty,
309 is_panic_destruct_fn,
310 };
311 let mut analysis = BackAnalysis::new(lowered, checker);
312 let mut root_demand = analysis.get_root_info();
313 root_demand.variables_introduced(
314 &mut analysis.analyzer,
315 &lowered.parameters,
316 (BlockId::root(), 0),
317 );
318 assert!(root_demand.finalize(), "Undefined variable should not happen at this stage");
319 let DestructAdder { destructions, .. } = analysis.analyzer;
320
321 let mut variables = VariableAllocator::new(
322 db,
323 function_id.base_semantic_function(db).function_with_body_id(db),
324 std::mem::take(&mut lowered.variables),
325 )
326 .unwrap();
327
328 let info = db.core_info();
329 let plain_trait_function = info.destruct_fn;
330 let panic_trait_function = info.panic_destruct_fn;
331
332 let stable_ptr =
334 function_id.base_semantic_function(db).function_with_body_id(db).untyped_stable_ptr(db);
335
336 let location = variables.get_location(stable_ptr);
337
338 let as_tuple = |entry: &DestructionEntry| match entry {
345 DestructionEntry::Plain(plain_destruct) => {
346 (plain_destruct.position.0.0, plain_destruct.position.1, AddDestructFlowType::Plain, 0)
347 }
348 DestructionEntry::Panic(panic_destruct) => match panic_destruct.panic_location {
349 PanicLocation::PanicMatch { target_block_id, match_block_id } => {
350 (target_block_id.0, 0, AddDestructFlowType::PanicPostMatch, match_block_id.0)
351 }
352 PanicLocation::PanicVar { statement_location } => {
353 (statement_location.0.0, statement_location.1, AddDestructFlowType::PanicVar, 0)
354 }
355 },
356 };
357
358 for ((block_id, statement_idx, destruct_type, match_block_id), destructions) in
359 &destructions.into_iter().sorted_by_key(as_tuple).rev().chunk_by(as_tuple)
360 {
361 let mut stmts = vec![];
362
363 let first_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });
364 let mut last_panic_var = first_panic_var;
365
366 for destruction in destructions {
367 let output_var = variables.new_var(VarRequest { ty: unit_ty(db), location });
368
369 match destruction {
370 DestructionEntry::Plain(plain_destruct) => {
371 let semantic_function = semantic::FunctionLongId {
372 function: ConcreteFunction {
373 generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
374 impl_id: plain_destruct.impl_id,
375 function: plain_trait_function,
376 }),
377 generic_args: vec![],
378 },
379 }
380 .intern(db);
381
382 stmts.push(StatementCall {
383 function: semantic_function.lowered(db),
384 inputs: vec![VarUsage { var_id: plain_destruct.var_id, location }],
385 with_coupon: false,
386 outputs: vec![output_var],
387 location: variables.variables[plain_destruct.var_id].location,
388 })
389 }
390
391 DestructionEntry::Panic(panic_destruct) => {
392 let semantic_function = semantic::FunctionLongId {
393 function: ConcreteFunction {
394 generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
395 impl_id: panic_destruct.impl_id,
396 function: panic_trait_function,
397 }),
398 generic_args: vec![],
399 },
400 }
401 .intern(db);
402
403 let new_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });
404
405 stmts.push(StatementCall {
406 function: semantic_function.lowered(db),
407 inputs: vec![
408 VarUsage { var_id: panic_destruct.var_id, location },
409 VarUsage { var_id: last_panic_var, location },
410 ],
411 with_coupon: false,
412 outputs: vec![new_panic_var, output_var],
413 location,
414 });
415 last_panic_var = new_panic_var;
416 }
417 }
418 }
419
420 match destruct_type {
421 AddDestructFlowType::Plain => {
422 let block = &mut lowered.blocks[BlockId(block_id)];
423 block
424 .statements
425 .splice(statement_idx..statement_idx, stmts.into_iter().map(Statement::Call));
426 }
427 AddDestructFlowType::PanicPostMatch => {
428 let block = &mut lowered.blocks[BlockId(match_block_id)];
429 let BlockEnd::Match { info: MatchInfo::Enum(info) } = &mut block.end else {
430 unreachable!();
431 };
432
433 let arm = &mut info.arms[1];
434 let tuple_var = &mut arm.var_ids[0];
435 let tuple_ty = variables.variables[*tuple_var].ty;
436 let new_tuple_var = variables.new_var(VarRequest { ty: tuple_ty, location });
437 let orig_tuple_var = *tuple_var;
438 *tuple_var = new_tuple_var;
439 let long_ty = tuple_ty.lookup_intern(db);
440 let TypeLongId::Tuple(tys) = long_ty else { unreachable!() };
441
442 let vars = tys
443 .iter()
444 .copied()
445 .map(|ty| variables.new_var(VarRequest { ty, location }))
446 .collect::<Vec<_>>();
447
448 *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = vars[0];
449
450 let target_block_id = arm.block_id;
451
452 let block = &mut lowered.blocks[target_block_id];
453
454 block.statements.splice(
455 0..0,
456 chain!(
457 [Statement::StructDestructure(StatementStructDestructure {
458 input: VarUsage { var_id: new_tuple_var, location },
459 outputs: chain!([first_panic_var], vars.iter().skip(1).cloned())
460 .collect(),
461 })],
462 stmts.into_iter().map(Statement::Call),
463 [Statement::StructConstruct(StatementStructConstruct {
464 inputs: vars
465 .into_iter()
466 .map(|var_id| VarUsage { var_id, location })
467 .collect(),
468 output: orig_tuple_var,
469 })]
470 ),
471 );
472 }
473 AddDestructFlowType::PanicVar => {
474 let block = &mut lowered.blocks[BlockId(block_id)];
475
476 let idx = match block.statements.get_mut(statement_idx) {
477 Some(stmt) => {
478 match stmt {
479 Statement::StructConstruct(stmt) => {
480 let panic_var = &mut stmt.output;
481 *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = *panic_var;
482 *panic_var = first_panic_var;
483 }
484 Statement::Call(stmt) => {
485 let tuple_var = &mut stmt.outputs[0];
486 let new_tuple_var = variables.new_var(VarRequest {
487 ty: never_fn_actual_return_ty,
488 location,
489 });
490 let orig_tuple_var = *tuple_var;
491 *tuple_var = new_tuple_var;
492 let new_panic_var =
493 variables.new_var(VarRequest { ty: panic_ty, location });
494 let new_arr_var =
495 variables.new_var(VarRequest { ty: felt_arr_ty, location });
496 *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() =
497 new_panic_var;
498 let idx = statement_idx + 1;
499 block.statements.splice(
500 idx..idx,
501 chain!(
502 [Statement::StructDestructure(
503 StatementStructDestructure {
504 input: VarUsage { var_id: new_tuple_var, location },
505 outputs: vec![first_panic_var, new_arr_var],
506 }
507 )],
508 stmts.into_iter().map(Statement::Call),
509 [Statement::StructConstruct(StatementStructConstruct {
510 inputs: [new_panic_var, new_arr_var]
511 .into_iter()
512 .map(|var_id| VarUsage { var_id, location })
513 .collect(),
514 output: orig_tuple_var,
515 })]
516 ),
517 );
518 stmts = vec![];
519 }
520 _ => unreachable!("Expected a struct construct or a call statement."),
521 }
522 statement_idx + 1
523 }
524 None => {
525 assert_eq!(statement_idx, block.statements.len());
526 let panic_var = match &mut block.end {
527 BlockEnd::Return(vars, _) => &mut vars[0].var_id,
528 _ => unreachable!("Expected a return statement."),
529 };
530
531 stmts.first_mut().unwrap().inputs.get_mut(1).unwrap().var_id = *panic_var;
532 *panic_var = last_panic_var;
533 statement_idx
534 }
535 };
536
537 block.statements.splice(idx..idx, stmts.into_iter().map(Statement::Call));
538 }
539 };
540 }
541
542 lowered.variables = variables.variables;
543
544 match function_id.lookup_intern(db) {
545 ConcreteFunctionWithBodyLongId::Specialized(_) => return,
547 ConcreteFunctionWithBodyLongId::Semantic(id)
548 | ConcreteFunctionWithBodyLongId::Generated(GeneratedFunction { parent: id, .. }) => {
549 if id.substitution(db).map(|s| s.is_empty()).unwrap_or_default() {
551 return;
552 }
553 }
554 }
555
556 for (_, var) in lowered.variables.iter_mut() {
557 if var.info.copyable.is_err() {
559 var.info.copyable = db.copyable(var.ty);
560 }
561 if var.info.droppable.is_err() {
562 var.info.droppable = db.droppable(var.ty);
563 }
564 }
565}