1use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_filesystem::ids::SmolStrId;
9use cairo_lang_semantic as semantic;
10use cairo_lang_semantic::ConcreteFunction;
11use cairo_lang_semantic::corelib::{
12 CorelibSemantic, core_array_felt252_ty, core_module, get_ty_by_name, unit_ty,
13};
14use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
15use cairo_lang_semantic::items::imp::ImplId;
16use cairo_lang_semantic::types::TypesSemantic;
17use cairo_lang_utils::Intern;
18use itertools::{Itertools, chain, zip_eq};
19use salsa::Database;
20use semantic::{TypeId, TypeLongId};
21
22use crate::borrow_check::Demand;
23use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
24use crate::borrow_check::demand::{AuxCombine, DemandReporter};
25use crate::ids::{
26 ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, GeneratedFunction,
27 SemanticFunctionIdEx,
28};
29use crate::lower::context::{VarRequest, VariableAllocator};
30use crate::{
31 BlockEnd, BlockId, Lowered, MatchInfo, Statement, StatementCall, StatementStructConstruct,
32 StatementStructDestructure, VarRemapping, VarUsage, VariableId,
33};
34
/// The demand tracked by the destruct-adding backward analysis: demanded items are lowering
/// variables, uses carry no extra data (`()`), and the auxiliary state is a [`PanicState`].
pub type DestructAdderDemand = Demand<VariableId, (), PanicState>;

/// The kind of insertion point for a group of destructor calls in `add_destructs`.
///
/// The variant order is significant: the derived `Ord` is part of the sort key `add_destructs`
/// uses to order and group insertions, so do not reorder the variants.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum AddDestructFlowType {
    /// Plain destruct calls, spliced in at a recorded statement position.
    Plain,
    /// Panic-destruct calls attached to a statement that produces a panic value, or to the
    /// block's return when the location is past the last statement.
    PanicVar,
    /// Panic-destruct calls at the start of arm 1 of a match whose bound tuple starts with
    /// `Panic`.
    PanicPostMatch,
}
47
48pub struct DestructAdder<'db, 'a> {
50 db: &'db dyn Database,
51 lowered: &'a Lowered<'db>,
52 destructions: Vec<DestructionEntry<'db>>,
53 panic_ty: TypeId<'db>,
54 never_fn_actual_return_ty: TypeId<'db>,
56 is_panic_destruct_fn: bool,
57}
58
59enum DestructionEntry<'db> {
61 Plain(PlainDestructionEntry<'db>),
63 Panic(PanicDeconstructionEntry<'db>),
65}
66
67struct PlainDestructionEntry<'db> {
68 position: StatementLocation,
69 var_id: VariableId,
70 impl_id: ImplId<'db>,
71}
72struct PanicDeconstructionEntry<'db> {
73 panic_location: PanicLocation,
74 var_id: VariableId,
75 impl_id: ImplId<'db>,
76}
77
78impl<'db> DestructAdder<'db, '_> {
79 fn set_post_stmt_destruct(
81 &mut self,
82 introductions: &[VariableId],
83 info: &mut DestructAdderDemand,
84 block_id: BlockId,
85 statement_index: usize,
86 ) {
87 if let [panic_var] = introductions[..] {
88 let var = &self.lowered.variables[panic_var];
89 if [self.panic_ty, self.never_fn_actual_return_ty].contains(&var.ty) {
90 info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicVar {
91 statement_location: (block_id, statement_index),
92 }]);
93 }
94 }
95 }
96
97 fn set_post_match_state(
100 &mut self,
101 introduced_vars: &[VariableId],
102 info: &mut DestructAdderDemand,
103 match_block_id: BlockId,
104 target_block_id: BlockId,
105 arm_idx: usize,
106 ) {
107 if arm_idx != 1 {
108 return;
110 }
111 if let [err_var] = introduced_vars[..] {
112 let var = &self.lowered.variables[err_var];
113
114 let long_ty = var.ty.long(self.db);
115 let TypeLongId::Tuple(tys) = long_ty else {
116 return;
117 };
118 if tys.first() == Some(&self.panic_ty) {
119 info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicMatch {
120 match_block_id,
121 target_block_id,
122 }]);
123 }
124 }
125 }
126}
127
128impl<'db> DemandReporter<VariableId, PanicState> for DestructAdder<'db, '_> {
129 type IntroducePosition = StatementLocation;
130 type UsePosition = ();
131
132 fn drop_aux(
133 &mut self,
134 position: StatementLocation,
135 var_id: VariableId,
136 panic_state: PanicState,
137 ) {
138 let var = &self.lowered.variables[var_id];
139 if var.info.droppable.is_ok() {
143 return;
144 };
145 if let Ok(impl_id) = var.info.destruct_impl.clone() {
147 self.destructions.push(DestructionEntry::Plain(PlainDestructionEntry {
148 position,
149 var_id,
150 impl_id,
151 }));
152 return;
153 }
154 if let Ok(impl_id) = var.info.panic_destruct_impl.clone()
156 && let PanicState::EndsWithPanic(panic_locations) = panic_state
157 {
158 for panic_location in panic_locations {
159 self.destructions.push(DestructionEntry::Panic(PanicDeconstructionEntry {
160 panic_location,
161 var_id,
162 impl_id,
163 }));
164 }
165 return;
166 }
167
168 panic!("Borrow checker should have caught this.")
169 }
170}
171
172#[derive(Clone, Default)]
175pub enum PanicState {
176 EndsWithPanic(Vec<PanicLocation>),
180 #[default]
181 Otherwise,
182}
183impl AuxCombine for PanicState {
185 fn merge<'a, I: Iterator<Item = &'a Self>>(iter: I) -> Self
186 where
187 Self: 'a,
188 {
189 let mut panic_locations = vec![];
190 for state in iter {
191 if let Self::EndsWithPanic(locations) = state {
192 panic_locations.extend_from_slice(locations);
193 } else {
194 return Self::Otherwise;
195 }
196 }
197
198 Self::EndsWithPanic(panic_locations)
199 }
200}
201
202#[derive(Clone)]
204pub enum PanicLocation {
205 PanicVar { statement_location: StatementLocation },
207 PanicMatch { match_block_id: BlockId, target_block_id: BlockId },
209}
210
211impl<'db> Analyzer<'db, '_> for DestructAdder<'db, '_> {
212 type Info = DestructAdderDemand;
213
214 fn visit_stmt(
215 &mut self,
216 info: &mut Self::Info,
217 (block_id, statement_index): StatementLocation,
218 stmt: &Statement<'db>,
219 ) {
220 self.set_post_stmt_destruct(stmt.outputs(), info, block_id, statement_index);
221 info.variables_introduced(self, stmt.outputs(), (block_id, statement_index + 1));
223 info.variables_used(self, stmt.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())));
224 }
225
226 fn visit_goto(
227 &mut self,
228 info: &mut Self::Info,
229 _statement_location: StatementLocation,
230 _target_block_id: BlockId,
231 remapping: &VarRemapping<'db>,
232 ) {
233 info.apply_remapping(self, remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))));
234 }
235
236 fn merge_match(
237 &mut self,
238 (block_id, _statement_index): StatementLocation,
239 match_info: &MatchInfo<'db>,
240 infos: impl Iterator<Item = Self::Info>,
241 ) -> Self::Info {
242 let arm_demands = zip_eq(match_info.arms(), infos)
243 .enumerate()
244 .map(|(arm_idx, (arm, mut demand))| {
245 let use_position = (arm.block_id, 0);
246 self.set_post_match_state(
247 &arm.var_ids,
248 &mut demand,
249 block_id,
250 arm.block_id,
251 arm_idx,
252 );
253 demand.variables_introduced(self, &arm.var_ids, use_position);
254 (demand, use_position)
255 })
256 .collect_vec();
257 let mut demand = DestructAdderDemand::merge_demands(&arm_demands, self);
258 demand.variables_used(
259 self,
260 match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
261 );
262 demand
263 }
264
265 fn info_from_return(
266 &mut self,
267 statement_location: StatementLocation,
268 vars: &[VarUsage<'db>],
269 ) -> Self::Info {
270 let mut info = DestructAdderDemand::default();
271 if self.is_panic_destruct_fn {
273 info.aux =
274 PanicState::EndsWithPanic(vec![PanicLocation::PanicVar { statement_location }]);
275 }
276
277 info.variables_used(self, vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())));
278 info
279 }
280}
281
282fn panic_ty<'db>(db: &'db dyn Database) -> semantic::TypeId<'db> {
283 get_ty_by_name(db, core_module(db), SmolStrId::from(db, "Panic"), vec![])
284}
285
/// Inserts destructor calls into `lowered` for variables that are dropped without being
/// droppable.
///
/// The pass works in three steps:
/// 1. A backward analysis (`DestructAdder`) collects the destruction points.
/// 2. The points are grouped and materialized as call statements:
///    - Plain destructs are spliced in at their recorded statement position.
///    - Panic destructs are chained through the `Panic` value. They are placed either right after
///      the statement producing that value (or at the return that returns it), or at the start
///      of the match arm that receives the `(Panic, ...)` tuple.
/// 3. Unless the function is specialized or its substitution is known to be empty, the
///    `copyable`/`droppable` info of variables for which it previously failed is recomputed.
pub fn add_destructs<'db>(
    db: &'db dyn Database,
    function_id: ConcreteFunctionWithBodyId<'db>,
    lowered: &mut Lowered<'db>,
) {
    // Nothing to do when there are no blocks.
    if lowered.blocks.is_empty() {
        return;
    }

    // NOTE(review): on a query error the function is left untouched; presumably the error is
    // reported elsewhere.
    let Ok(is_panic_destruct_fn) = function_id.is_panic_destruct_fn(db) else {
        return;
    };

    let panic_ty = panic_ty(db);
    let felt_arr_ty = core_array_felt252_ty(db);
    // The `(Panic, Array<felt252>)` tuple type.
    let never_fn_actual_return_ty = TypeLongId::Tuple(vec![panic_ty, felt_arr_ty]).intern(db);
    let checker = DestructAdder {
        db,
        lowered,
        destructions: vec![],
        panic_ty,
        never_fn_actual_return_ty,
        is_panic_destruct_fn,
    };
    // Run the backward analysis. The parameters are introduced at the start of the root block,
    // and the resulting demand must be fully satisfied.
    let mut analysis = BackAnalysis::new(lowered, checker);
    let mut root_demand = analysis.get_root_info();
    root_demand.variables_introduced(
        &mut analysis.analyzer,
        &lowered.parameters,
        (BlockId::root(), 0),
    );
    assert!(root_demand.finalize(), "Undefined variable should not happen at this stage");
    let DestructAdder { destructions, .. } = analysis.analyzer;

    // Take ownership of the variable arena so new variables can be allocated; it is written back
    // after all insertions.
    let mut variables = VariableAllocator::new(
        db,
        function_id.base_semantic_function(db).function_with_body_id(db),
        std::mem::take(&mut lowered.variables),
    )
    .unwrap();

    let info = db.core_info();
    let plain_trait_function = info.destruct_fn;
    let panic_trait_function = info.panic_destruct_fn;

    // Generated variables and calls are located at the function itself (plain-destruct calls use
    // the destructed variable's location instead).
    let stable_ptr =
        function_id.base_semantic_function(db).function_with_body_id(db).untyped_stable_ptr(db);

    let location = variables.get_location(stable_ptr);

    let as_tuple = |entry: &DestructionEntry<'_>| match entry {
        // Sort/group key: (block, statement index, flow type, match block).
        // Entries are processed in descending key order, i.e. from the end of each block
        // backwards. Presumably this keeps statement positions of not-yet-processed entries
        // valid while statements are spliced in. Entries sharing a key form one chain of calls.
        DestructionEntry::Plain(plain_destruct) => {
            (plain_destruct.position.0.0, plain_destruct.position.1, AddDestructFlowType::Plain, 0)
        }
        DestructionEntry::Panic(panic_destruct) => match panic_destruct.panic_location {
            PanicLocation::PanicMatch { target_block_id, match_block_id } => {
                (target_block_id.0, 0, AddDestructFlowType::PanicPostMatch, match_block_id.0)
            }
            PanicLocation::PanicVar { statement_location } => {
                (statement_location.0.0, statement_location.1, AddDestructFlowType::PanicVar, 0)
            }
        },
    };

    for ((block_id, statement_idx, destruct_type, match_block_id), destructions) in
        &destructions.into_iter().sorted_by_key(as_tuple).rev().chunk_by(as_tuple)
    {
        let mut stmts = vec![];

        // Panic destructs form a chain: each call consumes the previous `Panic` value and produces
        // a new one. `first_panic_var` feeds the first call of the chain.
        // NOTE(review): this variable is allocated for `Plain` groups too, where it stays unused.
        let first_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });
        let mut last_panic_var = first_panic_var;

        for destruction in destructions {
            // Unit result of the destruct call.
            let output_var = variables.new_var(VarRequest { ty: unit_ty(db), location });

            match destruction {
                DestructionEntry::Plain(plain_destruct) => {
                    // The trait's destruct function, resolved through the recorded impl.
                    let semantic_function = semantic::FunctionLongId {
                        function: ConcreteFunction {
                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
                                impl_id: plain_destruct.impl_id,
                                function: plain_trait_function,
                            }),
                            generic_args: vec![],
                        },
                    }
                    .intern(db);

                    stmts.push(StatementCall {
                        function: semantic_function.lowered(db),
                        inputs: vec![VarUsage { var_id: plain_destruct.var_id, location }],
                        with_coupon: false,
                        outputs: vec![output_var],
                        location: variables.variables[plain_destruct.var_id].location,
                        is_specialization_base_call: false,
                    })
                }

                DestructionEntry::Panic(panic_destruct) => {
                    // The trait's panic-destruct function, resolved through the recorded impl.
                    let semantic_function = semantic::FunctionLongId {
                        function: ConcreteFunction {
                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
                                impl_id: panic_destruct.impl_id,
                                function: panic_trait_function,
                            }),
                            generic_args: vec![],
                        },
                    }
                    .intern(db);

                    let new_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });

                    // Inputs: (value, current panic). Outputs: (next panic, unit).
                    stmts.push(StatementCall {
                        function: semantic_function.lowered(db),
                        inputs: vec![
                            VarUsage { var_id: panic_destruct.var_id, location },
                            VarUsage { var_id: last_panic_var, location },
                        ],
                        with_coupon: false,
                        outputs: vec![new_panic_var, output_var],
                        location,
                        is_specialization_base_call: false,
                    });
                    last_panic_var = new_panic_var;
                }
            }
        }

        match destruct_type {
            AddDestructFlowType::Plain => {
                let block = &mut lowered.blocks[BlockId(block_id)];
                block
                    .statements
                    .splice(statement_idx..statement_idx, stmts.into_iter().map(Statement::Call));
            }
            AddDestructFlowType::PanicPostMatch => {
                let block = &mut lowered.blocks[BlockId(match_block_id)];
                let BlockEnd::Match { info: MatchInfo::Enum(info) } = &mut block.end else {
                    unreachable!();
                };

                // Rebind arm 1's tuple to a fresh variable. The original tuple variable is rebuilt
                // at the start of the arm's block: destructure, run the chain on the `Panic`
                // element, and reconstruct with the chain's final `Panic`.
                let arm = &mut info.arms[1];
                let tuple_var = &mut arm.var_ids[0];
                let tuple_ty = variables.variables[*tuple_var].ty;
                let new_tuple_var = variables.new_var(VarRequest { ty: tuple_ty, location });
                let orig_tuple_var = *tuple_var;
                *tuple_var = new_tuple_var;
                let long_ty = tuple_ty.long(db);
                let TypeLongId::Tuple(tys) = long_ty else { unreachable!() };

                // One fresh variable per tuple element, used to reconstruct the tuple.
                let vars = tys
                    .iter()
                    .copied()
                    .map(|ty| variables.new_var(VarRequest { ty, location }))
                    .collect::<Vec<_>>();

                // The last call of the chain produces the `Panic` element of the rebuilt tuple.
                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = vars[0];

                let target_block_id = arm.block_id;

                let block = &mut lowered.blocks[target_block_id];

                // The destructured `Panic` element feeds the chain via `first_panic_var`; the
                // remaining elements pass straight through to the reconstruction.
                block.statements.splice(
                    0..0,
                    chain!(
                        [Statement::StructDestructure(StatementStructDestructure {
                            input: VarUsage { var_id: new_tuple_var, location },
                            outputs: chain!([first_panic_var], vars.iter().skip(1).cloned())
                                .collect(),
                        })],
                        stmts.into_iter().map(Statement::Call),
                        [Statement::StructConstruct(StatementStructConstruct {
                            inputs: vars
                                .into_iter()
                                .map(|var_id| VarUsage { var_id, location })
                                .collect(),
                            output: orig_tuple_var,
                        })]
                    ),
                );
            }
            AddDestructFlowType::PanicVar => {
                let block = &mut lowered.blocks[BlockId(block_id)];

                // `idx` is where the (remaining) chain statements are inserted.
                let idx = match block.statements.get_mut(statement_idx) {
                    Some(stmt) => {
                        match stmt {
                            // The statement constructs the panic value: make it produce
                            // `first_panic_var`, and let the last call of the chain produce the
                            // original variable.
                            Statement::StructConstruct(stmt) => {
                                let panic_var = &mut stmt.output;
                                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = *panic_var;
                                *panic_var = first_panic_var;
                            }
                            // The call returns the `(Panic, Array<felt252>)` tuple. Redirect it
                            // into a fresh tuple variable, then destructure it right after the
                            // call. Run the chain on its `Panic` and rebuild the original tuple
                            // variable from the final `Panic` and the array.
                            Statement::Call(stmt) => {
                                let tuple_var = &mut stmt.outputs[0];
                                let new_tuple_var = variables.new_var(VarRequest {
                                    ty: never_fn_actual_return_ty,
                                    location,
                                });
                                let orig_tuple_var = *tuple_var;
                                *tuple_var = new_tuple_var;
                                let new_panic_var =
                                    variables.new_var(VarRequest { ty: panic_ty, location });
                                let new_arr_var =
                                    variables.new_var(VarRequest { ty: felt_arr_ty, location });
                                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() =
                                    new_panic_var;
                                let idx = statement_idx + 1;
                                block.statements.splice(
                                    idx..idx,
                                    chain!(
                                        [Statement::StructDestructure(
                                            StatementStructDestructure {
                                                input: VarUsage { var_id: new_tuple_var, location },
                                                outputs: vec![first_panic_var, new_arr_var],
                                            }
                                        )],
                                        stmts.into_iter().map(Statement::Call),
                                        [Statement::StructConstruct(StatementStructConstruct {
                                            inputs: [new_panic_var, new_arr_var]
                                                .into_iter()
                                                .map(|var_id| VarUsage { var_id, location })
                                                .collect(),
                                            output: orig_tuple_var,
                                        })]
                                    ),
                                );
                                // Everything is already inserted; the splice below is a no-op.
                                stmts = vec![];
                            }
                            _ => unreachable!("Expected a struct construct or a call statement."),
                        }
                        statement_idx + 1
                    }
                    // The location is past the last statement, so it refers to the block's
                    // return. The chain starts from the returned panic value (the first returned
                    // variable), and the return now yields the chain's final panic value.
                    None => {
                        assert_eq!(statement_idx, block.statements.len());
                        let panic_var = match &mut block.end {
                            BlockEnd::Return(vars, _) => &mut vars[0].var_id,
                            _ => unreachable!("Expected a return statement."),
                        };

                        stmts.first_mut().unwrap().inputs.get_mut(1).unwrap().var_id = *panic_var;
                        *panic_var = last_panic_var;
                        statement_idx
                    }
                };

                block.statements.splice(idx..idx, stmts.into_iter().map(Statement::Call));
            }
        };
    }

    lowered.variables = variables.variables;

    match function_id.long(db) {
        // Specialized functions: no recomputation.
        ConcreteFunctionWithBodyLongId::Specialized(_) => return,
        ConcreteFunctionWithBodyLongId::Semantic(id)
        | ConcreteFunctionWithBodyLongId::Generated(GeneratedFunction { parent: id, .. }) => {
            // A substitution known to be empty needs no recomputation (an error counts as
            // "not known to be empty").
            if id.substitution(db).map(|s| s.is_empty()).unwrap_or_default() {
                return;
            }
        }
    }

    // NOTE(review): presumably the earlier failures stem from generic types that the
    // substitution resolves — confirm. Re-query only the infos that previously failed.
    for (_, var) in lowered.variables.iter_mut() {
        if var.info.copyable.is_err() {
            var.info.copyable = db.copyable(var.ty);
        }
        if var.info.droppable.is_err() {
            var.info.droppable = db.droppable(var.ty);
        }
    }
}