Skip to main content

cairo_lang_lowering/panic/
mod.rs

1use std::collections::VecDeque;
2
3use assert_matches::assert_matches;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_filesystem::flag::FlagsGroup;
6use cairo_lang_filesystem::ids::SmolStrId;
7use cairo_lang_semantic::corelib::{
8    CorelibSemantic, core_submodule, get_core_enum_concrete_variant, get_function_id, get_panic_ty,
9    never_ty,
10};
11use cairo_lang_semantic::helper::ModuleHelper;
12use cairo_lang_semantic::items::constant::ConstValue;
13use cairo_lang_semantic::{self as semantic, GenericArgumentId};
14use cairo_lang_utils::Intern;
15use itertools::{Itertools, chain, zip_eq};
16use salsa::Database;
17use semantic::{ConcreteVariant, MatchArmSelector, TypeId};
18
19use crate::blocks::BlocksBuilder;
20use crate::db::{ConcreteSCCRepresentative, LoweringGroup};
21use crate::ids::{
22    ConcreteFunctionWithBodyId, FunctionId, FunctionLongId, LocationId, SemanticFunctionIdEx,
23    Signature,
24};
25use crate::lower::context::{VarRequest, VariableAllocator};
26use crate::{
27    Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchEnumInfo,
28    MatchExternInfo, MatchInfo, Statement, StatementCall, StatementEnumConstruct,
29    StatementStructConstruct, StatementStructDestructure, VarRemapping, VarUsage, VariableId,
30};
31
32// TODO(spapini): Remove tuple in the Ok() variant of the panic, by supporting multiple values in
33// the Sierra type.
34
35/// Lowering phase that converts `BlockEnd::Panic` into `BlockEnd::Return`, and wraps necessary
36/// types with `PanicResult<>`.
37pub fn lower_panics<'db>(
38    db: &'db dyn Database,
39    function_id: ConcreteFunctionWithBodyId<'db>,
40    lowered: &mut Lowered<'db>,
41) -> Maybe<()> {
42    // Skip this phase for non panicable functions.
43    if !db.function_with_body_may_panic(function_id)? {
44        return Ok(());
45    }
46
47    let opt_trace_fn = if db.flag_panic_backtrace() {
48        Some(
49            ModuleHelper::core(db)
50                .submodule("internal")
51                .function_id(
52                    "trace",
53                    vec![GenericArgumentId::Constant(
54                        ConstValue::Int(
55                            0x70616e6963u64.into(), // 'panic' as numeric.
56                            db.core_info().felt252,
57                        )
58                        .intern(db),
59                    )],
60                )
61                .lowered(db),
62        )
63    } else {
64        None
65    };
66
67    if db.flag_unsafe_panic() {
68        lower_unsafe_panic(db, lowered, opt_trace_fn);
69        return Ok(());
70    }
71
72    let signature = function_id.signature(db)?;
73    // TODO(orizi): Validate all signature types are fully concrete at this point.
74    let panic_info = PanicSignatureInfo::new(db, &signature);
75    let variables = VariableAllocator::new(
76        db,
77        function_id.base_semantic_function(db).function_with_body_id(db),
78        std::mem::take(&mut lowered.variables),
79    )
80    .unwrap();
81    let mut ctx = PanicLoweringContext {
82        variables,
83        block_queue: VecDeque::from_iter(lowered.blocks.get().iter().cloned()),
84        flat_blocks: BlocksBuilder::new(),
85        panic_info,
86    };
87
88    if let Some(trace_fn) = opt_trace_fn {
89        for block in ctx.block_queue.iter_mut() {
90            if let BlockEnd::Panic(end) = &block.end {
91                block.statements.push(Statement::Call(StatementCall {
92                    function: trace_fn,
93                    inputs: vec![],
94                    with_coupon: false,
95                    outputs: vec![],
96                    location: end.location,
97                    is_specialization_base_call: false,
98                }));
99            }
100        }
101    }
102
103    // Iterate block queue (old and new blocks).
104    while let Some(block) = ctx.block_queue.pop_front() {
105        ctx = handle_block(ctx, block)?;
106    }
107
108    lowered.variables = ctx.variables.variables;
109    lowered.blocks = ctx.flat_blocks.build().unwrap();
110
111    Ok(())
112}
113
114/// Lowering phase that converts BlockEnd::Panic into BlockEnd::Match { function: unsafe_panic }.
115/// 'opt_trace_fn' is an optional function to call before the panic.
116fn lower_unsafe_panic<'db>(
117    db: &'db dyn Database,
118    lowered: &mut Lowered<'db>,
119    opt_trace_fn: Option<FunctionId<'db>>,
120) {
121    let panics = core_submodule(db, SmolStrId::from(db, "panics"));
122    let panic_func_id = FunctionLongId::Semantic(get_function_id(
123        db,
124        panics,
125        SmolStrId::from(db, "unsafe_panic"),
126        vec![],
127    ))
128    .intern(db);
129
130    for block in lowered.blocks.iter_mut() {
131        let BlockEnd::Panic(err_data) = &mut block.end else {
132            continue;
133        };
134
135        // Clean up undroppable panic related variables to pass add_destructs.
136        let Some(Statement::StructConstruct(tuple_construct)) = block.statements.pop() else {
137            panic!("Expected a tuple construct before the panic.");
138        };
139        // Assert `err_data` is produced by the statement that was removed above.
140        assert_eq!(tuple_construct.output, err_data.var_id);
141
142        let panic_construct_statement = block.statements.pop();
143        // Assert that the output of `panic_construct_statement` is the first input of
144        // 'tuple_construct'.
145        assert_matches!(panic_construct_statement, Some(Statement::StructConstruct(panic_construct)) if panic_construct.output == tuple_construct.inputs[0].var_id);
146
147        if let Some(trace_fn) = opt_trace_fn {
148            block.statements.push(Statement::Call(StatementCall {
149                function: trace_fn,
150                inputs: vec![],
151                with_coupon: false,
152                outputs: vec![],
153                location: err_data.location,
154                is_specialization_base_call: false,
155            }));
156        }
157
158        block.end = BlockEnd::Match {
159            info: MatchInfo::Extern(MatchExternInfo {
160                arms: vec![],
161                location: err_data.location,
162                function: panic_func_id,
163                inputs: vec![],
164            }),
165        }
166    }
167}
168
169/// Handles the lowering of panics in a single block.
170fn handle_block<'db>(
171    mut ctx: PanicLoweringContext<'db>,
172    mut block: Block<'db>,
173) -> Maybe<PanicLoweringContext<'db>> {
174    let mut block_ctx = PanicBlockLoweringContext { ctx, statements: Vec::new() };
175    for (i, stmt) in block.statements.iter().cloned().enumerate() {
176        if let Some((cur_block_end, continuation_block)) = block_ctx.handle_statement(&stmt)? {
177            // This case means that the lowering should split the block here.
178
179            // Block ended with a match.
180            ctx = block_ctx.handle_end(cur_block_end);
181            if let Some(continuation_block) = continuation_block {
182                // The rest of the statements in this block have not been handled yet, and should be
183                // handled as a part of the continuation block - the second block in the "split".
184                let block_to_edit =
185                    &mut ctx.block_queue[continuation_block.0 - ctx.flat_blocks.len()];
186                block_to_edit.statements.extend(block.statements.drain(i + 1..));
187                block_to_edit.end = block.end;
188            }
189            return Ok(ctx);
190        }
191    }
192    ctx = block_ctx.handle_end(block.end);
193    Ok(ctx)
194}
195
/// Panic-related information about a function signature.
///
/// It is used both for the function being lowered (to rewrite its returns and panics) and for its
/// callees (to rewrite calls into a match on the returned `PanicResult`).
pub struct PanicSignatureInfo<'db> {
    /// The types of all the variables returned on OK: Reference variables and the original result.
    ok_ret_tys: Vec<TypeId<'db>>,
    /// The type of the Ok() variant - a tuple of `ok_ret_tys`.
    ok_ty: TypeId<'db>,
    /// The Ok() variant.
    ok_variant: ConcreteVariant<'db>,
    /// The Err() variant.
    pub err_variant: ConcreteVariant<'db>,
    /// The PanicResult concrete type - the new return type of the function.
    /// When `always_panic` is set, this is the Err() variant's payload type instead.
    pub actual_return_ty: TypeId<'db>,
    /// Does the function always panic (its original return type is `never`).
    /// Note that if it does - the function returned type is always `(Panic, Array<felt252>)`.
    pub always_panic: bool,
}
211impl<'db> PanicSignatureInfo<'db> {
212    pub fn new(db: &'db dyn Database, signature: &Signature<'db>) -> Self {
213        let extra_rets = signature.extra_rets.iter().map(|param| param.ty());
214        let original_return_ty = signature.return_type;
215
216        let ok_ret_tys = chain!(extra_rets, [original_return_ty]).collect_vec();
217        let ok_ty = semantic::TypeLongId::Tuple(ok_ret_tys.clone()).intern(db);
218        let ok_variant = get_core_enum_concrete_variant(
219            db,
220            SmolStrId::from(db, "PanicResult"),
221            vec![GenericArgumentId::Type(ok_ty)],
222            SmolStrId::from(db, "Ok"),
223        );
224        let err_variant = get_core_enum_concrete_variant(
225            db,
226            SmolStrId::from(db, "PanicResult"),
227            vec![GenericArgumentId::Type(ok_ty)],
228            SmolStrId::from(db, "Err"),
229        );
230        let always_panic = original_return_ty == never_ty(db);
231        let panic_ty = if always_panic { err_variant.ty } else { get_panic_ty(db, ok_ty) };
232
233        Self {
234            ok_ret_tys,
235            ok_ty,
236            ok_variant,
237            err_variant,
238            actual_return_ty: panic_ty,
239            always_panic,
240        }
241    }
242}
243
/// Function-wide state of the panic lowering phase.
struct PanicLoweringContext<'db> {
    /// Allocator for the function's variables, including the ones created by this phase.
    variables: VariableAllocator<'db>,
    /// Blocks waiting to be lowered; they are popped from the front and, once lowered, allocated
    /// into `flat_blocks` in that same order.
    block_queue: VecDeque<Block<'db>>,
    /// The already-lowered blocks of the resulting function.
    flat_blocks: BlocksBuilder<'db>,
    /// Panic information of the function being lowered.
    panic_info: PanicSignatureInfo<'db>,
}
250impl<'db> PanicLoweringContext<'db> {
251    pub fn db(&self) -> &'db dyn Database {
252        self.variables.db
253    }
254
255    fn enqueue_block(&mut self, block: Block<'db>) -> BlockId {
256        self.block_queue.push_back(block);
257        BlockId(self.flat_blocks.len() + self.block_queue.len())
258    }
259}
260
/// State for lowering a single block, wrapping the function-wide [PanicLoweringContext].
struct PanicBlockLoweringContext<'db> {
    /// The function-wide context; returned to the caller once the block is finalized.
    ctx: PanicLoweringContext<'db>,
    /// The statements of the block being built.
    statements: Vec<Statement<'db>>,
}
impl<'db> PanicBlockLoweringContext<'db> {
    /// Returns the database of the underlying [PanicLoweringContext].
    pub fn db(&self) -> &'db dyn Database {
        self.ctx.db()
    }

    /// Allocates a new variable of type `ty` at `location`.
    fn new_var(&mut self, ty: TypeId<'db>, location: LocationId<'db>) -> VariableId {
        self.ctx.variables.new_var(VarRequest { ty, location })
    }

    /// Handles a statement. If needed, returns the continuation block and the block end for this
    /// block.
    /// The continuation block happens when a panic match is added, and the block needs to be split.
    /// The continuation block is the second block in the "split". This function already partially
    /// creates this second block (enqueued, but empty), and returns its id.
    /// In case there is no panic match - but just a panic, there is no continuation block.
    ///
    /// Any other statement (including calls to functions without a body) is copied unchanged into
    /// the current block, and `None` is returned.
    fn handle_statement(
        &mut self,
        stmt: &Statement<'db>,
    ) -> Maybe<Option<(BlockEnd<'db>, Option<BlockId>)>> {
        if let Statement::Call(call) = &stmt
            && let Some(with_body) = call.function.body(self.db())?
            && self.db().function_with_body_may_panic(with_body)?
        {
            return Ok(Some(self.handle_call_panic(call)?));
        }
        self.statements.push(stmt.clone());
        Ok(None)
    }

    /// Handles a call statement to a panicking function.
    /// Returns the block end to set for the current block, and the continuation block ID for the
    /// caller to complete it.
    /// If the callee always panics, there is no continuation block: the current block simply
    /// ends with a panic on the call's result.
    /// Otherwise the current block ends with a match on the returned `PanicResult`, with an Ok()
    /// arm that jumps to the continuation block and an Err() arm that panics.
    /// The order in which variables and blocks are allocated here determines their ids.
    fn handle_call_panic(
        &mut self,
        call: &StatementCall<'db>,
    ) -> Maybe<(BlockEnd<'db>, Option<BlockId>)> {
        // Extract return variable.
        let mut original_outputs = call.outputs.clone();
        let location = call.location.with_auto_generation_note(self.db(), "Panic handling");

        // Get callee info.
        let callee_signature = call.function.signature(self.db())?;
        let callee_info = PanicSignatureInfo::new(self.db(), &callee_signature);
        if callee_info.always_panic {
            // The panic value, which is actually of type (Panic, Array<felt252>).
            // NOTE(review): the original call outputs are not kept in this path - presumably the
            // call has no other outputs at this stage; confirm.
            let panic_result_var = self.new_var(callee_info.actual_return_ty, location);
            // Emit the new statement.
            self.statements.push(Statement::Call(StatementCall {
                function: call.function,
                inputs: call.inputs.clone(),
                with_coupon: call.with_coupon,
                outputs: vec![panic_result_var],
                location,
                is_specialization_base_call: call.is_specialization_base_call,
            }));
            return Ok((BlockEnd::Panic(VarUsage { var_id: panic_result_var, location }), None));
        }

        // Allocate the new variables.
        // panic_result_var - for the new return variable, which is actually of type
        // PanicResult<ty>.
        let panic_result_var = self.new_var(callee_info.actual_return_ty, location);
        // The leading outputs beyond the callee's Ok() values stay direct outputs of the call; the
        // rest of `original_outputs` is bound later, in the Ok() arm.
        let n_callee_implicits = original_outputs.len() - callee_info.ok_ret_tys.len();
        let mut call_outputs = original_outputs.drain(..n_callee_implicits).collect_vec();
        call_outputs.push(panic_result_var);
        // inner_ok_value - for the Ok() match arm input.
        let inner_ok_value = self.new_var(callee_info.ok_ty, location);
        // inner_ok_values - for the destructure.
        let inner_ok_values = callee_info
            .ok_ret_tys
            .iter()
            .copied()
            .map(|ty| self.new_var(ty, location))
            .collect_vec();

        // Emit the new statement.
        self.statements.push(Statement::Call(StatementCall {
            function: call.function,
            inputs: call.inputs.clone(),
            with_coupon: call.with_coupon,
            outputs: call_outputs,
            location,
            is_specialization_base_call: call.is_specialization_base_call,
        }));

        // Start constructing a match on the result.
        // This is the continuation block, which gets the rest of the original block. It is only
        // partially created here (no statements, unset end); its id is returned to let the caller
        // complete it.
        let block_continuation =
            self.ctx.enqueue_block(Block { statements: vec![], end: BlockEnd::NotSet });

        // Prepare Ok() match arm block: destructure the Ok() payload into the callee's returned
        // values, and jump to the continuation block, remapping them into the original outputs.
        let block_ok = self.ctx.enqueue_block(Block {
            statements: vec![Statement::StructDestructure(StatementStructDestructure {
                input: VarUsage { var_id: inner_ok_value, location },
                outputs: inner_ok_values.clone(),
            })],
            end: BlockEnd::Goto(
                block_continuation,
                VarRemapping {
                    remapping: zip_eq(
                        original_outputs,
                        inner_ok_values.into_iter().map(|var_id| VarUsage { var_id, location }),
                    )
                    .collect(),
                },
            ),
        });

        // Prepare Err() match arm block: it panics with the callee's panic data, and is converted
        // to a return of this function's Err() when it is itself lowered.
        let err_var = self.new_var(self.ctx.panic_info.err_variant.ty, location);
        let block_err = self.ctx.enqueue_block(Block {
            statements: vec![],
            end: BlockEnd::Panic(VarUsage { var_id: err_var, location }),
        });

        let cur_block_end = BlockEnd::Match {
            info: MatchInfo::Enum(MatchEnumInfo {
                concrete_enum_id: callee_info.ok_variant.concrete_enum_id,
                input: VarUsage { var_id: panic_result_var, location },
                arms: vec![
                    MatchArm {
                        arm_selector: MatchArmSelector::VariantId(callee_info.ok_variant),
                        block_id: block_ok,
                        var_ids: vec![inner_ok_value],
                    },
                    MatchArm {
                        arm_selector: MatchArmSelector::VariantId(callee_info.err_variant),
                        block_id: block_err,
                        var_ids: vec![err_var],
                    },
                ],
                location,
            }),
        };

        Ok((cur_block_end, Some(block_continuation)))
    }

    /// Finalizes the block with `end`, allocates it into `flat_blocks`, and returns the
    /// function-wide context.
    /// - `Panic` becomes a return of `PanicResult::Err` (or of the panic data itself when the
    ///   function always panics).
    /// - `Return` becomes a return of `PanicResult::Ok` of the tupled returned values.
    /// - `Goto` and `Match` are kept as-is; `NotSet` is unreachable.
    fn handle_end(mut self, end: BlockEnd<'db>) -> PanicLoweringContext<'db> {
        let end = match end {
            BlockEnd::Goto(target, remapping) => BlockEnd::Goto(target, remapping),
            BlockEnd::Panic(err_data) => {
                // Wrap with PanicResult::Err.
                let ty = self.ctx.panic_info.actual_return_ty;
                let location = err_data.location;
                let output = if self.ctx.panic_info.always_panic {
                    err_data.var_id
                } else {
                    let output = self.new_var(ty, location);
                    self.statements.push(Statement::EnumConstruct(StatementEnumConstruct {
                        variant: self.ctx.panic_info.err_variant,
                        input: err_data,
                        output,
                    }));
                    output
                };
                BlockEnd::Return(vec![VarUsage { var_id: output, location }], location)
            }
            BlockEnd::Return(returns, location) => {
                // Tuple construction.
                let tupled_res = self.new_var(self.ctx.panic_info.ok_ty, location);
                self.statements.push(Statement::StructConstruct(StatementStructConstruct {
                    inputs: returns,
                    output: tupled_res,
                }));

                // Wrap with PanicResult::Ok.
                let ty = self.ctx.panic_info.actual_return_ty;
                let output = self.new_var(ty, location);
                self.statements.push(Statement::EnumConstruct(StatementEnumConstruct {
                    variant: self.ctx.panic_info.ok_variant,
                    input: VarUsage { var_id: tupled_res, location },
                    output,
                }));
                BlockEnd::Return(vec![VarUsage { var_id: output, location }], location)
            }
            BlockEnd::NotSet => unreachable!(),
            BlockEnd::Match { info } => BlockEnd::Match { info },
        };
        self.ctx.flat_blocks.alloc(Block { statements: self.statements, end });
        self.ctx
    }
}
448
449// ============= Query implementations =============
450
451/// Query implementation of [crate::db::LoweringGroup::function_may_panic].
452#[salsa::tracked]
453pub fn function_may_panic<'db>(db: &'db dyn Database, function: FunctionId<'db>) -> Maybe<bool> {
454    if let Some(body) = function.body(db)? {
455        return db.function_with_body_may_panic(body);
456    }
457    Ok(function.signature(db)?.panicable)
458}
459
460/// A trait to add helper methods in [LoweringGroup].
461pub trait MayPanicTrait<'db>: Database {
462    /// Returns whether a [ConcreteFunctionWithBodyId] may panic.
463    fn function_with_body_may_panic(
464        &'db self,
465        function: ConcreteFunctionWithBodyId<'db>,
466    ) -> Maybe<bool> {
467        let db: &'db dyn Database = self.as_dyn_database();
468        let scc_representative = db.lowered_scc_representative(
469            function,
470            DependencyType::Call,
471            LoweringStage::Monomorphized,
472        );
473        scc_may_panic(db, scc_representative)
474    }
475}
476impl<'db, T: Database + ?Sized> MayPanicTrait<'db> for T {}
477
478/// Returns whether any function in the strongly connected component may panic.
479fn scc_may_panic<'db>(db: &'db dyn Database, scc: ConcreteSCCRepresentative<'db>) -> Maybe<bool> {
480    scc_may_panic_tracked(db, scc.0)
481}
482
483/// Tracked implementation of [scc_may_panic].
484#[salsa::tracked]
485fn scc_may_panic_tracked<'db>(
486    db: &'db dyn Database,
487    rep: ConcreteFunctionWithBodyId<'db>,
488) -> Maybe<bool> {
489    // Find the SCC representative.
490    let scc_functions = db.lowered_scc(rep, DependencyType::Call, LoweringStage::Monomorphized);
491    for function in scc_functions {
492        if db.needs_withdraw_gas(function)? {
493            return Ok(true);
494        }
495        if db.has_direct_panic(function)? {
496            return Ok(true);
497        }
498        // For each direct callee, find if it may panic.
499        let direct_callees = db.lowered_direct_callees(
500            function,
501            DependencyType::Call,
502            LoweringStage::Monomorphized,
503        )?;
504        for direct_callee in direct_callees {
505            if let Some(callee_body) = direct_callee.body(db)? {
506                let callee_scc = db.lowered_scc_representative(
507                    callee_body,
508                    DependencyType::Call,
509                    LoweringStage::Monomorphized,
510                );
511                if callee_scc.0 != rep && scc_may_panic(db, callee_scc)? {
512                    return Ok(true);
513                }
514            } else if direct_callee.signature(db)?.panicable {
515                return Ok(true);
516            }
517        }
518    }
519    Ok(false)
520}
521
522/// Query implementation of [crate::db::LoweringGroup::has_direct_panic].
523#[salsa::tracked]
524pub fn has_direct_panic<'db>(
525    db: &'db dyn Database,
526    function_id: ConcreteFunctionWithBodyId<'db>,
527) -> Maybe<bool> {
528    let lowered_function = db.lowered_body(function_id, LoweringStage::Monomorphized)?;
529    Ok(itertools::any(lowered_function.blocks.iter(), |(_, block)| {
530        matches!(&block.end, BlockEnd::Panic(..))
531    }))
532}