cairo_lang_lowering/
destructs.rs

1//! This module implements the destructor call addition.
2//!
3//! It is assumed to run after the panic phase.
4//! This is similar to the borrow checking algorithm, except we handle "undroppable drops" by adding
5//! destructor calls.
6
7use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_filesystem::ids::SmolStrId;
9use cairo_lang_semantic as semantic;
10use cairo_lang_semantic::ConcreteFunction;
11use cairo_lang_semantic::corelib::{
12    CorelibSemantic, core_array_felt252_ty, core_module, get_ty_by_name, unit_ty,
13};
14use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
15use cairo_lang_semantic::items::imp::ImplId;
16use cairo_lang_semantic::types::TypesSemantic;
17use cairo_lang_utils::Intern;
18use itertools::{Itertools, chain, zip_eq};
19use salsa::Database;
20use semantic::{TypeId, TypeLongId};
21
22use crate::borrow_check::Demand;
23use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
24use crate::borrow_check::demand::{AuxCombine, DemandReporter};
25use crate::ids::{
26    ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, GeneratedFunction,
27    SemanticFunctionIdEx,
28};
29use crate::lower::context::{VarRequest, VariableAllocator};
30use crate::{
31    BlockEnd, BlockId, Lowered, MatchInfo, Statement, StatementCall, StatementStructConstruct,
32    StatementStructDestructure, VarRemapping, VarUsage, VariableId,
33};
34
35pub type DestructAdderDemand = Demand<VariableId, (), PanicState>;
36
/// The kind of flow a destructor call is inserted into; used to group destructor calls.
///
/// The declaration order of the variants matters: the derived `Ord` is part of the sort key that
/// `add_destructs` uses to order and group the collected destructor calls.
#[derive(Eq, PartialEq, Ord, PartialOrd)]
enum AddDestructFlowType {
    /// A plain destructor call.
    Plain,
    /// A panic destructor call following the creation of a panic variable (or the return of
    /// one).
    PanicVar,
    /// A panic destructor call following a match on a `PanicResult`.
    PanicPostMatch,
}
47
48/// Context for the destructor call addition phase,
49pub struct DestructAdder<'db, 'a> {
50    db: &'db dyn Database,
51    lowered: &'a Lowered<'db>,
52    destructions: Vec<DestructionEntry<'db>>,
53    panic_ty: TypeId<'db>,
54    /// The actual return type of a never function after adding panics.
55    never_fn_actual_return_ty: TypeId<'db>,
56    is_panic_destruct_fn: bool,
57}
58
59/// A destructor call that needs to be added.
60enum DestructionEntry<'db> {
61    /// A normal destructor call.
62    Plain(PlainDestructionEntry<'db>),
63    /// A panic destructor call.
64    Panic(PanicDeconstructionEntry<'db>),
65}
66
67struct PlainDestructionEntry<'db> {
68    position: StatementLocation,
69    var_id: VariableId,
70    impl_id: ImplId<'db>,
71}
72struct PanicDeconstructionEntry<'db> {
73    panic_location: PanicLocation,
74    var_id: VariableId,
75    impl_id: ImplId<'db>,
76}
77
78impl<'db> DestructAdder<'db, '_> {
79    /// Checks if the statement introduces a panic variable and sets the panic state accordingly.
80    fn set_post_stmt_destruct(
81        &mut self,
82        introductions: &[VariableId],
83        info: &mut DestructAdderDemand,
84        block_id: BlockId,
85        statement_index: usize,
86    ) {
87        if let [panic_var] = introductions[..] {
88            let var = &self.lowered.variables[panic_var];
89            if [self.panic_ty, self.never_fn_actual_return_ty].contains(&var.ty) {
90                info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicVar {
91                    statement_location: (block_id, statement_index),
92                }]);
93            }
94        }
95    }
96
97    /// Check if the match arm introduces a `PanicResult::Err` variable and sets the panic state
98    /// accordingly.
99    fn set_post_match_state(
100        &mut self,
101        introduced_vars: &[VariableId],
102        info: &mut DestructAdderDemand,
103        match_block_id: BlockId,
104        target_block_id: BlockId,
105        arm_idx: usize,
106    ) {
107        if arm_idx != 1 {
108            // The post match panic should be on the second arm of a match on a PanicResult.
109            return;
110        }
111        if let [err_var] = introduced_vars[..] {
112            let var = &self.lowered.variables[err_var];
113
114            let long_ty = var.ty.long(self.db);
115            let TypeLongId::Tuple(tys) = long_ty else {
116                return;
117            };
118            if tys.first() == Some(&self.panic_ty) {
119                info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicMatch {
120                    match_block_id,
121                    target_block_id,
122                }]);
123            }
124        }
125    }
126}
127
128impl<'db> DemandReporter<VariableId, PanicState> for DestructAdder<'db, '_> {
129    type IntroducePosition = StatementLocation;
130    type UsePosition = ();
131
132    fn drop_aux(
133        &mut self,
134        position: StatementLocation,
135        var_id: VariableId,
136        panic_state: PanicState,
137    ) {
138        let var = &self.lowered.variables[var_id];
139        // Note that droppable here means droppable before monomorphization.
140        // I.e. it is possible that T was substituted with a unit type, but T was not droppable
141        // and therefore the unit type var is not droppable here.
142        if var.info.droppable.is_ok() {
143            return;
144        };
145        // If a non droppable variable gets out of scope, add a destruct call for it.
146        if let Ok(impl_id) = var.info.destruct_impl.clone() {
147            self.destructions.push(DestructionEntry::Plain(PlainDestructionEntry {
148                position,
149                var_id,
150                impl_id,
151            }));
152            return;
153        }
154        // If a non destructible variable gets out of scope, add a panic_destruct call for it.
155        if let Ok(impl_id) = var.info.panic_destruct_impl.clone()
156            && let PanicState::EndsWithPanic(panic_locations) = panic_state
157        {
158            for panic_location in panic_locations {
159                self.destructions.push(DestructionEntry::Panic(PanicDeconstructionEntry {
160                    panic_location,
161                    var_id,
162                    impl_id,
163                }));
164            }
165            return;
166        }
167
168        panic!("Borrow checker should have caught this.")
169    }
170}
171
172/// A state saved for each position in the back analysis.
173/// Used to determine if a Panic object is guaranteed to exist or be created, and where.
174#[derive(Clone, Default)]
175pub enum PanicState {
176    /// The flow will end with a panic. The locations are all the possible places a Panic object
177    /// can be created from this flow.
178    /// The flow is guaranteed to end up in one of these places.
179    EndsWithPanic(Vec<PanicLocation>),
180    #[default]
181    Otherwise,
182}
183/// How to combine two panic states in a flow divergence.
184impl AuxCombine for PanicState {
185    fn merge<'a, I: Iterator<Item = &'a Self>>(iter: I) -> Self
186    where
187        Self: 'a,
188    {
189        let mut panic_locations = vec![];
190        for state in iter {
191            if let Self::EndsWithPanic(locations) = state {
192                panic_locations.extend_from_slice(locations);
193            } else {
194                return Self::Otherwise;
195            }
196        }
197
198        Self::EndsWithPanic(panic_locations)
199    }
200}
201
202/// Location where a `Panic` is first available.
203#[derive(Clone)]
204pub enum PanicLocation {
205    /// The `Panic` value is at a variable created by a StructConstruct at `statement_location`.
206    PanicVar { statement_location: StatementLocation },
207    /// The `Panic` is inside a PanicResult::Err that was create by a match at `match_block_id`.
208    PanicMatch { match_block_id: BlockId, target_block_id: BlockId },
209}
210
211impl<'db> Analyzer<'db, '_> for DestructAdder<'db, '_> {
212    type Info = DestructAdderDemand;
213
214    fn visit_stmt(
215        &mut self,
216        info: &mut Self::Info,
217        (block_id, statement_index): StatementLocation,
218        stmt: &Statement<'db>,
219    ) {
220        self.set_post_stmt_destruct(stmt.outputs(), info, block_id, statement_index);
221        // Since we need to insert destructor call right after the statement.
222        info.variables_introduced(self, stmt.outputs(), (block_id, statement_index + 1));
223        info.variables_used(self, stmt.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())));
224    }
225
226    fn visit_goto(
227        &mut self,
228        info: &mut Self::Info,
229        _statement_location: StatementLocation,
230        _target_block_id: BlockId,
231        remapping: &VarRemapping<'db>,
232    ) {
233        info.apply_remapping(self, remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))));
234    }
235
236    fn merge_match(
237        &mut self,
238        (block_id, _statement_index): StatementLocation,
239        match_info: &MatchInfo<'db>,
240        infos: impl Iterator<Item = Self::Info>,
241    ) -> Self::Info {
242        let arm_demands = zip_eq(match_info.arms(), infos)
243            .enumerate()
244            .map(|(arm_idx, (arm, mut demand))| {
245                let use_position = (arm.block_id, 0);
246                self.set_post_match_state(
247                    &arm.var_ids,
248                    &mut demand,
249                    block_id,
250                    arm.block_id,
251                    arm_idx,
252                );
253                demand.variables_introduced(self, &arm.var_ids, use_position);
254                (demand, use_position)
255            })
256            .collect_vec();
257        let mut demand = DestructAdderDemand::merge_demands(&arm_demands, self);
258        demand.variables_used(
259            self,
260            match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
261        );
262        demand
263    }
264
265    fn info_from_return(
266        &mut self,
267        statement_location: StatementLocation,
268        vars: &[VarUsage<'db>],
269    ) -> Self::Info {
270        let mut info = DestructAdderDemand::default();
271        // Allow panic destructors to be called inside panic destruct functions.
272        if self.is_panic_destruct_fn {
273            info.aux =
274                PanicState::EndsWithPanic(vec![PanicLocation::PanicVar { statement_location }]);
275        }
276
277        info.variables_used(self, vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())));
278        info
279    }
280}
281
282fn panic_ty<'db>(db: &'db dyn Database) -> semantic::TypeId<'db> {
283    get_ty_by_name(db, core_module(db), SmolStrId::from(db, "Panic"), vec![])
284}
285
/// Inserts destructor calls into the lowered function.
///
/// A backward analysis (`DestructAdder`) collects every non-droppable variable that gets out of
/// scope. For each one, this function inserts either:
/// - a plain destructor call where the variable gets out of scope, when the variable has a
///   destruct impl; or
/// - a panic destructor call at each location where a `Panic` value is available. These calls are
///   chained: each one consumes the current `Panic` variable and outputs a new one.
///
/// Additionally overrides the inferred impls for the `Copyable` and `Droppable` traits according to
/// the concrete type. This is performed here instead of in `concretize_lowered` to support custom
/// destructors for droppable types.
///
/// The function is left unchanged if it has no blocks, or if it cannot be determined whether
/// `function_id` is a panic destruct function.
///
/// # Panics
///
/// Panics on internal invariant violations: an undefined variable found by the analysis, a failure
/// to create the variable allocator, or an unexpected statement/block-end shape at a panic
/// location.
pub fn add_destructs<'db>(
    db: &'db dyn Database,
    function_id: ConcreteFunctionWithBodyId<'db>,
    lowered: &mut Lowered<'db>,
) {
    // Nothing to do for a function without blocks.
    if lowered.blocks.is_empty() {
        return;
    }

    // If it cannot be determined whether this is a panic destruct function, leave it unchanged.
    let Ok(is_panic_destruct_fn) = function_id.is_panic_destruct_fn(db) else {
        return;
    };

    let panic_ty = panic_ty(db);
    let felt_arr_ty = core_array_felt252_ty(db);
    // `(Panic, Array<felt252>)`: the actual return type of a never function after adding panics.
    let never_fn_actual_return_ty = TypeLongId::Tuple(vec![panic_ty, felt_arr_ty]).intern(db);
    let checker = DestructAdder {
        db,
        lowered,
        destructions: vec![],
        panic_ty,
        never_fn_actual_return_ty,
        is_panic_destruct_fn,
    };
    let mut analysis = BackAnalysis::new(lowered, checker);
    let mut root_demand = analysis.get_root_info();
    // Parameters are introduced at the start of the root block, like any other variable.
    root_demand.variables_introduced(
        &mut analysis.analyzer,
        &lowered.parameters,
        (BlockId::root(), 0),
    );
    assert!(root_demand.finalize(), "Undefined variable should not happen at this stage");
    let DestructAdder { destructions, .. } = analysis.analyzer;

    // Take the variables out of `lowered` so new ones can be allocated; they are written back
    // after all insertions.
    let mut variables = VariableAllocator::new(
        db,
        function_id.base_semantic_function(db).function_with_body_id(db),
        std::mem::take(&mut lowered.variables),
    )
    .unwrap();

    let info = db.core_info();
    let plain_trait_function = info.destruct_fn;
    let panic_trait_function = info.panic_destruct_fn;

    // Add destructions.
    let stable_ptr =
        function_id.base_semantic_function(db).function_with_body_id(db).untyped_stable_ptr(db);

    let location = variables.get_location(stable_ptr);

    // We need to add the destructions in reverse order, so that they won't interfere with each
    // other: processing later positions first means splicing statements in does not shift the
    // indices of positions that are still to be processed.
    // For panic destruction, we need to group them by type and create chains of destruct calls
    // where each one consumes a panic variable and creates a new one.
    // To facilitate this, we convert each entry to a tuple with the relevant information for
    // ordering and grouping.
    let as_tuple = |entry: &DestructionEntry<'_>| match entry {
        DestructionEntry::Plain(plain_destruct) => {
            (plain_destruct.position.0.0, plain_destruct.position.1, AddDestructFlowType::Plain, 0)
        }
        DestructionEntry::Panic(panic_destruct) => match panic_destruct.panic_location {
            PanicLocation::PanicMatch { target_block_id, match_block_id } => {
                (target_block_id.0, 0, AddDestructFlowType::PanicPostMatch, match_block_id.0)
            }
            PanicLocation::PanicVar { statement_location } => {
                (statement_location.0.0, statement_location.1, AddDestructFlowType::PanicVar, 0)
            }
        },
    };

    // Each group shares the same key and is inserted at a single position.
    for ((block_id, statement_idx, destruct_type, match_block_id), destructions) in
        &destructions.into_iter().sorted_by_key(as_tuple).rev().chunk_by(as_tuple)
    {
        let mut stmts = vec![];

        // `first_panic_var` is the `Panic` input of the first panic destructor call of the chain;
        // `last_panic_var` tracks the `Panic` output of the most recent call.
        // NOTE(review): for `Plain` groups this variable is allocated but never used.
        let first_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });
        let mut last_panic_var = first_panic_var;

        for destruction in destructions {
            // The unit output of the (panic) destructor call.
            let output_var = variables.new_var(VarRequest { ty: unit_ty(db), location });

            match destruction {
                DestructionEntry::Plain(plain_destruct) => {
                    let semantic_function = semantic::FunctionLongId {
                        function: ConcreteFunction {
                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
                                impl_id: plain_destruct.impl_id,
                                function: plain_trait_function,
                            }),
                            generic_args: vec![],
                        },
                    }
                    .intern(db);

                    stmts.push(StatementCall {
                        function: semantic_function.lowered(db),
                        inputs: vec![VarUsage { var_id: plain_destruct.var_id, location }],
                        with_coupon: false,
                        outputs: vec![output_var],
                        location: variables.variables[plain_destruct.var_id].location,
                        is_specialization_base_call: false,
                    })
                }

                DestructionEntry::Panic(panic_destruct) => {
                    let semantic_function = semantic::FunctionLongId {
                        function: ConcreteFunction {
                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
                                impl_id: panic_destruct.impl_id,
                                function: panic_trait_function,
                            }),
                            generic_args: vec![],
                        },
                    }
                    .intern(db);

                    let new_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });

                    // Chain link: consume the current `Panic` and output a fresh one.
                    stmts.push(StatementCall {
                        function: semantic_function.lowered(db),
                        inputs: vec![
                            VarUsage { var_id: panic_destruct.var_id, location },
                            VarUsage { var_id: last_panic_var, location },
                        ],
                        with_coupon: false,
                        outputs: vec![new_panic_var, output_var],
                        location,
                        is_specialization_base_call: false,
                    });
                    last_panic_var = new_panic_var;
                }
            }
        }

        match destruct_type {
            AddDestructFlowType::Plain => {
                // Insert the destructor calls where the variables got out of scope.
                let block = &mut lowered.blocks[BlockId(block_id)];
                block
                    .statements
                    .splice(statement_idx..statement_idx, stmts.into_iter().map(Statement::Call));
            }
            AddDestructFlowType::PanicPostMatch => {
                let block = &mut lowered.blocks[BlockId(match_block_id)];
                let BlockEnd::Match { info: MatchInfo::Enum(info) } = &mut block.end else {
                    unreachable!();
                };

                // The `Err` arm: `set_post_match_state` only reports panics for arm index 1.
                // Rebind the arm to a fresh tuple variable. At the start of the target block it is
                // destructured (its `Panic` member into `first_panic_var`), the chain runs, and the
                // original tuple variable is rebuilt from the chain's final `Panic` (`vars[0]`)
                // and the remaining members.
                let arm = &mut info.arms[1];
                let tuple_var = &mut arm.var_ids[0];
                let tuple_ty = variables.variables[*tuple_var].ty;
                let new_tuple_var = variables.new_var(VarRequest { ty: tuple_ty, location });
                let orig_tuple_var = *tuple_var;
                *tuple_var = new_tuple_var;
                let long_ty = tuple_ty.long(db);
                let TypeLongId::Tuple(tys) = long_ty else { unreachable!() };

                let vars = tys
                    .iter()
                    .copied()
                    .map(|ty| variables.new_var(VarRequest { ty, location }))
                    .collect::<Vec<_>>();

                // The last call of the chain produces the `Panic` member of the rebuilt tuple.
                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = vars[0];

                let target_block_id = arm.block_id;

                let block = &mut lowered.blocks[target_block_id];

                block.statements.splice(
                    0..0,
                    chain!(
                        [Statement::StructDestructure(StatementStructDestructure {
                            input: VarUsage { var_id: new_tuple_var, location },
                            outputs: chain!([first_panic_var], vars.iter().skip(1).cloned())
                                .collect(),
                        })],
                        stmts.into_iter().map(Statement::Call),
                        [Statement::StructConstruct(StatementStructConstruct {
                            inputs: vars
                                .into_iter()
                                .map(|var_id| VarUsage { var_id, location })
                                .collect(),
                            output: orig_tuple_var,
                        })]
                    ),
                );
            }
            AddDestructFlowType::PanicVar => {
                let block = &mut lowered.blocks[BlockId(block_id)];

                // `idx` is where the remaining `stmts` (if any) are spliced in below.
                let idx = match block.statements.get_mut(statement_idx) {
                    Some(stmt) => {
                        match stmt {
                            Statement::StructConstruct(stmt) => {
                                // The construct now outputs `first_panic_var` (the chain input),
                                // and the chain's last call outputs the original panic variable,
                                // so later uses of it are unchanged.
                                let panic_var = &mut stmt.output;
                                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = *panic_var;
                                *panic_var = first_panic_var;
                            }
                            Statement::Call(stmt) => {
                                // A call returning the never-function tuple: redirect its output
                                // to a fresh tuple, destructure it, run the chain on its `Panic`,
                                // then rebuild the original tuple variable from the chain's final
                                // `Panic` and the array.
                                let tuple_var = &mut stmt.outputs[0];
                                let new_tuple_var = variables.new_var(VarRequest {
                                    ty: never_fn_actual_return_ty,
                                    location,
                                });
                                let orig_tuple_var = *tuple_var;
                                *tuple_var = new_tuple_var;
                                let new_panic_var =
                                    variables.new_var(VarRequest { ty: panic_ty, location });
                                let new_arr_var =
                                    variables.new_var(VarRequest { ty: felt_arr_ty, location });
                                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() =
                                    new_panic_var;
                                let idx = statement_idx + 1;
                                block.statements.splice(
                                    idx..idx,
                                    chain!(
                                        [Statement::StructDestructure(
                                            StatementStructDestructure {
                                                input: VarUsage { var_id: new_tuple_var, location },
                                                outputs: vec![first_panic_var, new_arr_var],
                                            }
                                        )],
                                        stmts.into_iter().map(Statement::Call),
                                        [Statement::StructConstruct(StatementStructConstruct {
                                            inputs: [new_panic_var, new_arr_var]
                                                .into_iter()
                                                .map(|var_id| VarUsage { var_id, location })
                                                .collect(),
                                            output: orig_tuple_var,
                                        })]
                                    ),
                                );
                                // Already inserted above; the final splice below is a no-op.
                                stmts = vec![];
                            }
                            _ => unreachable!("Expected a struct construct or a call statement."),
                        }
                        statement_idx + 1
                    }
                    None => {
                        // The location is the block end: a return inside a panic destruct
                        // function (see `info_from_return`). Feed the returned panic variable
                        // into the chain and return the chain's final output instead.
                        assert_eq!(statement_idx, block.statements.len());
                        let panic_var = match &mut block.end {
                            BlockEnd::Return(vars, _) => &mut vars[0].var_id,
                            _ => unreachable!("Expected a return statement."),
                        };

                        stmts.first_mut().unwrap().inputs.get_mut(1).unwrap().var_id = *panic_var;
                        *panic_var = last_panic_var;
                        statement_idx
                    }
                };

                block.statements.splice(idx..idx, stmts.into_iter().map(Statement::Call));
            }
        };
    }

    // Write the (extended) variable arena back.
    lowered.variables = variables.variables;

    match function_id.long(db) {
        // If specialized, destructors are already correct.
        ConcreteFunctionWithBodyLongId::Specialized(_) => return,
        ConcreteFunctionWithBodyLongId::Semantic(id)
        | ConcreteFunctionWithBodyLongId::Generated(GeneratedFunction { parent: id, .. }) => {
            // If there is no substitution, destructors are already correct.
            if id.substitution(db).map(|s| s.is_empty()).unwrap_or_default() {
                return;
            }
        }
    }

    for (_, var) in lowered.variables.iter_mut() {
        // After adding destructors, we can infer the concrete `Copyable` and `Droppable` impls.
        if var.info.copyable.is_err() {
            var.info.copyable = db.copyable(var.ty);
        }
        if var.info.droppable.is_err() {
            var.info.droppable = db.droppable(var.ty);
        }
    }
}