cairo_lang_lowering/panic/
mod.rs

1use std::collections::VecDeque;
2
3use assert_matches::assert_matches;
4use cairo_lang_diagnostics::Maybe;
5use cairo_lang_filesystem::db::FilesGroup;
6use cairo_lang_filesystem::flag::{Flag, flag_unsafe_panic};
7use cairo_lang_filesystem::ids::{FlagId, FlagLongId, SmolStrId};
8use cairo_lang_semantic::corelib::{
9    CorelibSemantic, core_submodule, get_core_enum_concrete_variant, get_function_id, get_panic_ty,
10    never_ty,
11};
12use cairo_lang_semantic::helper::ModuleHelper;
13use cairo_lang_semantic::items::constant::ConstValue;
14use cairo_lang_semantic::{self as semantic, GenericArgumentId};
15use cairo_lang_utils::Intern;
16use itertools::{Itertools, chain, zip_eq};
17use salsa::Database;
18use semantic::{ConcreteVariant, MatchArmSelector, TypeId};
19
20use crate::blocks::BlocksBuilder;
21use crate::db::{ConcreteSCCRepresentative, LoweringGroup};
22use crate::ids::{
23    ConcreteFunctionWithBodyId, FunctionId, FunctionLongId, LocationId, SemanticFunctionIdEx,
24    Signature,
25};
26use crate::lower::context::{VarRequest, VariableAllocator};
27use crate::{
28    Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchEnumInfo,
29    MatchExternInfo, MatchInfo, Statement, StatementCall, StatementEnumConstruct,
30    StatementStructConstruct, StatementStructDestructure, VarRemapping, VarUsage, VariableId,
31};
32
33// TODO(spapini): Remove tuple in the Ok() variant of the panic, by supporting multiple values in
34// the Sierra type.
35
36/// Lowering phase that converts `BlockEnd::Panic` into `BlockEnd::Return`, and wraps necessary
37/// types with `PanicResult<>`.
38pub fn lower_panics<'db>(
39    db: &'db dyn Database,
40    function_id: ConcreteFunctionWithBodyId<'db>,
41    lowered: &mut Lowered<'db>,
42) -> Maybe<()> {
43    // Skip this phase for non panicable functions.
44    if !db.function_with_body_may_panic(function_id)? {
45        return Ok(());
46    }
47
48    let opt_trace_fn = if matches!(
49        db.get_flag(FlagId::new(db, FlagLongId("panic_backtrace".into()))),
50        Some(flag) if matches!(*flag, Flag::PanicBacktrace(true)),
51    ) {
52        Some(
53            ModuleHelper::core(db)
54                .submodule("internal")
55                .function_id(
56                    "trace",
57                    vec![GenericArgumentId::Constant(
58                        ConstValue::Int(
59                            0x70616e6963u64.into(), // 'panic' as numeric.
60                            db.core_info().felt252,
61                        )
62                        .intern(db),
63                    )],
64                )
65                .lowered(db),
66        )
67    } else {
68        None
69    };
70
71    if flag_unsafe_panic(db) {
72        lower_unsafe_panic(db, lowered, opt_trace_fn);
73        return Ok(());
74    }
75
76    let signature = function_id.signature(db)?;
77    // TODO(orizi): Validate all signature types are fully concrete at this point.
78    let panic_info = PanicSignatureInfo::new(db, &signature);
79    let variables = VariableAllocator::new(
80        db,
81        function_id.base_semantic_function(db).function_with_body_id(db),
82        std::mem::take(&mut lowered.variables),
83    )
84    .unwrap();
85    let mut ctx = PanicLoweringContext {
86        variables,
87        block_queue: VecDeque::from(lowered.blocks.get().clone()),
88        flat_blocks: BlocksBuilder::new(),
89        panic_info,
90    };
91
92    if let Some(trace_fn) = opt_trace_fn {
93        for block in ctx.block_queue.iter_mut() {
94            if let BlockEnd::Panic(end) = &block.end {
95                block.statements.push(Statement::Call(StatementCall {
96                    function: trace_fn,
97                    inputs: vec![],
98                    with_coupon: false,
99                    outputs: vec![],
100                    location: end.location,
101                    is_specialization_base_call: false,
102                }));
103            }
104        }
105    }
106
107    // Iterate block queue (old and new blocks).
108    while let Some(block) = ctx.block_queue.pop_front() {
109        ctx = handle_block(ctx, block)?;
110    }
111
112    lowered.variables = ctx.variables.variables;
113    lowered.blocks = ctx.flat_blocks.build().unwrap();
114
115    Ok(())
116}
117
118/// Lowering phase that converts BlockEnd::Panic into BlockEnd::Match { function: unsafe_panic }.
119/// 'opt_trace_fn' is an optional function to call before the panic.
120fn lower_unsafe_panic<'db>(
121    db: &'db dyn Database,
122    lowered: &mut Lowered<'db>,
123    opt_trace_fn: Option<FunctionId<'db>>,
124) {
125    let panics = core_submodule(db, SmolStrId::from(db, "panics"));
126    let panic_func_id = FunctionLongId::Semantic(get_function_id(
127        db,
128        panics,
129        SmolStrId::from(db, "unsafe_panic"),
130        vec![],
131    ))
132    .intern(db);
133
134    for block in lowered.blocks.iter_mut() {
135        let BlockEnd::Panic(err_data) = &mut block.end else {
136            continue;
137        };
138
139        // Clean up undroppable panic related variables to pass add_destructs.
140        let Some(Statement::StructConstruct(tuple_construct)) = block.statements.pop() else {
141            panic!("Expected a tuple construct before the panic.");
142        };
143        // Assert `err_data` is produced by the statement that was removed above.
144        assert_eq!(tuple_construct.output, err_data.var_id);
145
146        let panic_construct_statement = block.statements.pop();
147        // Assert that the output of `panic_construct_statement` is the first input of
148        // 'tuple_construct'.
149        assert_matches!(panic_construct_statement, Some(Statement::StructConstruct(panic_construct)) if panic_construct.output == tuple_construct.inputs[0].var_id);
150
151        if let Some(trace_fn) = opt_trace_fn {
152            block.statements.push(Statement::Call(StatementCall {
153                function: trace_fn,
154                inputs: vec![],
155                with_coupon: false,
156                outputs: vec![],
157                location: err_data.location,
158                is_specialization_base_call: false,
159            }));
160        }
161
162        block.end = BlockEnd::Match {
163            info: MatchInfo::Extern(MatchExternInfo {
164                arms: vec![],
165                location: err_data.location,
166                function: panic_func_id,
167                inputs: vec![],
168            }),
169        }
170    }
171}
172
173/// Handles the lowering of panics in a single block.
174fn handle_block<'db>(
175    mut ctx: PanicLoweringContext<'db>,
176    mut block: Block<'db>,
177) -> Maybe<PanicLoweringContext<'db>> {
178    let mut block_ctx = PanicBlockLoweringContext { ctx, statements: Vec::new() };
179    for (i, stmt) in block.statements.iter().cloned().enumerate() {
180        if let Some((cur_block_end, continuation_block)) = block_ctx.handle_statement(&stmt)? {
181            // This case means that the lowering should split the block here.
182
183            // Block ended with a match.
184            ctx = block_ctx.handle_end(cur_block_end);
185            if let Some(continuation_block) = continuation_block {
186                // The rest of the statements in this block have not been handled yet, and should be
187                // handled as a part of the continuation block - the second block in the "split".
188                let block_to_edit =
189                    &mut ctx.block_queue[continuation_block.0 - ctx.flat_blocks.len()];
190                block_to_edit.statements.extend(block.statements.drain(i + 1..));
191                block_to_edit.end = block.end;
192            }
193            return Ok(ctx);
194        }
195    }
196    ctx = block_ctx.handle_end(block.end);
197    Ok(ctx)
198}
199
200pub struct PanicSignatureInfo<'db> {
201    /// The types of all the variables returned on OK: Reference variables and the original result.
202    ok_ret_tys: Vec<TypeId<'db>>,
203    /// The type of the Ok() variant.
204    ok_ty: TypeId<'db>,
205    /// The Ok() variant.
206    ok_variant: ConcreteVariant<'db>,
207    /// The Err() variant.
208    pub err_variant: ConcreteVariant<'db>,
209    /// The PanicResult concrete type - the new return type of the function.
210    pub actual_return_ty: TypeId<'db>,
211    /// Does the function always panic.
212    /// Note that if it does - the function returned type is always `(Panic, Array<felt252>)`.
213    pub always_panic: bool,
214}
215impl<'db> PanicSignatureInfo<'db> {
216    pub fn new(db: &'db dyn Database, signature: &Signature<'db>) -> Self {
217        let extra_rets = signature.extra_rets.iter().map(|param| param.ty());
218        let original_return_ty = signature.return_type;
219
220        let ok_ret_tys = chain!(extra_rets, [original_return_ty]).collect_vec();
221        let ok_ty = semantic::TypeLongId::Tuple(ok_ret_tys.clone()).intern(db);
222        let ok_variant = get_core_enum_concrete_variant(
223            db,
224            SmolStrId::from(db, "PanicResult"),
225            vec![GenericArgumentId::Type(ok_ty)],
226            SmolStrId::from(db, "Ok"),
227        );
228        let err_variant = get_core_enum_concrete_variant(
229            db,
230            SmolStrId::from(db, "PanicResult"),
231            vec![GenericArgumentId::Type(ok_ty)],
232            SmolStrId::from(db, "Err"),
233        );
234        let always_panic = original_return_ty == never_ty(db);
235        let panic_ty = if always_panic { err_variant.ty } else { get_panic_ty(db, ok_ty) };
236
237        Self {
238            ok_ret_tys,
239            ok_ty,
240            ok_variant,
241            err_variant,
242            actual_return_ty: panic_ty,
243            always_panic,
244        }
245    }
246}
247
248struct PanicLoweringContext<'db> {
249    variables: VariableAllocator<'db>,
250    block_queue: VecDeque<Block<'db>>,
251    flat_blocks: BlocksBuilder<'db>,
252    panic_info: PanicSignatureInfo<'db>,
253}
254impl<'db> PanicLoweringContext<'db> {
255    pub fn db(&self) -> &'db dyn Database {
256        self.variables.db
257    }
258
259    fn enqueue_block(&mut self, block: Block<'db>) -> BlockId {
260        self.block_queue.push_back(block);
261        BlockId(self.flat_blocks.len() + self.block_queue.len())
262    }
263}
264
265struct PanicBlockLoweringContext<'db> {
266    ctx: PanicLoweringContext<'db>,
267    statements: Vec<Statement<'db>>,
268}
269impl<'db> PanicBlockLoweringContext<'db> {
270    pub fn db(&self) -> &'db dyn Database {
271        self.ctx.db()
272    }
273
274    fn new_var(&mut self, ty: TypeId<'db>, location: LocationId<'db>) -> VariableId {
275        self.ctx.variables.new_var(VarRequest { ty, location })
276    }
277
278    /// Handles a statement. If needed, returns the continuation block and the block end for this
279    /// block.
280    /// The continuation block happens when a panic match is added, and the block needs to be split.
281    /// The continuation block is the second block in the "split". This function already partially
282    /// creates this second block, and returns it.
283    /// In case there is no panic match - but just a panic, there is no continuation block.
284    fn handle_statement(
285        &mut self,
286        stmt: &Statement<'db>,
287    ) -> Maybe<Option<(BlockEnd<'db>, Option<BlockId>)>> {
288        if let Statement::Call(call) = &stmt
289            && let Some(with_body) = call.function.body(self.db())?
290            && self.db().function_with_body_may_panic(with_body)?
291        {
292            return Ok(Some(self.handle_call_panic(call)?));
293        }
294        self.statements.push(stmt.clone());
295        Ok(None)
296    }
297
298    /// Handles a call statement to a panicking function.
299    /// Returns the continuation block ID for the caller to complete it, and the block end to set
300    /// for the current block.
301    fn handle_call_panic(
302        &mut self,
303        call: &StatementCall<'db>,
304    ) -> Maybe<(BlockEnd<'db>, Option<BlockId>)> {
305        // Extract return variable.
306        let mut original_outputs = call.outputs.clone();
307        let location = call.location.with_auto_generation_note(self.db(), "Panic handling");
308
309        // Get callee info.
310        let callee_signature = call.function.signature(self.db())?;
311        let callee_info = PanicSignatureInfo::new(self.db(), &callee_signature);
312        if callee_info.always_panic {
313            // The panic value, which is actually of type (Panics, Array<felt252>).
314            let panic_result_var = self.new_var(callee_info.actual_return_ty, location);
315            // Emit the new statement.
316            self.statements.push(Statement::Call(StatementCall {
317                function: call.function,
318                inputs: call.inputs.clone(),
319                with_coupon: call.with_coupon,
320                outputs: vec![panic_result_var],
321                location,
322                is_specialization_base_call: call.is_specialization_base_call,
323            }));
324            return Ok((BlockEnd::Panic(VarUsage { var_id: panic_result_var, location }), None));
325        }
326
327        // Allocate 2 new variables.
328        // panic_result_var - for the new return variable, with is actually of type PanicResult<ty>.
329        let panic_result_var = self.new_var(callee_info.actual_return_ty, location);
330        let n_callee_implicits = original_outputs.len() - callee_info.ok_ret_tys.len();
331        let mut call_outputs = original_outputs.drain(..n_callee_implicits).collect_vec();
332        call_outputs.push(panic_result_var);
333        // inner_ok_value - for the Ok() match arm input.
334        let inner_ok_value = self.new_var(callee_info.ok_ty, location);
335        // inner_ok_values - for the destructure.
336        let inner_ok_values = callee_info
337            .ok_ret_tys
338            .iter()
339            .copied()
340            .map(|ty| self.new_var(ty, location))
341            .collect_vec();
342
343        // Emit the new statement.
344        self.statements.push(Statement::Call(StatementCall {
345            function: call.function,
346            inputs: call.inputs.clone(),
347            with_coupon: call.with_coupon,
348            outputs: call_outputs,
349            location,
350            is_specialization_base_call: call.is_specialization_base_call,
351        }));
352
353        // Start constructing a match on the result.
354        let block_continuation =
355            self.ctx.enqueue_block(Block { statements: vec![], end: BlockEnd::NotSet });
356
357        // Prepare Ok() match arm block. This block will be the continuation block.
358        // This block is only partially created. It is returned at this function to let the caller
359        // complete it.
360        let block_ok = self.ctx.enqueue_block(Block {
361            statements: vec![Statement::StructDestructure(StatementStructDestructure {
362                input: VarUsage { var_id: inner_ok_value, location },
363                outputs: inner_ok_values.clone(),
364            })],
365            end: BlockEnd::Goto(
366                block_continuation,
367                VarRemapping {
368                    remapping: zip_eq(
369                        original_outputs,
370                        inner_ok_values.into_iter().map(|var_id| VarUsage { var_id, location }),
371                    )
372                    .collect(),
373                },
374            ),
375        });
376
377        // Prepare Err() match arm block.
378        let err_var = self.new_var(self.ctx.panic_info.err_variant.ty, location);
379        let block_err = self.ctx.enqueue_block(Block {
380            statements: vec![],
381            end: BlockEnd::Panic(VarUsage { var_id: err_var, location }),
382        });
383
384        let cur_block_end = BlockEnd::Match {
385            info: MatchInfo::Enum(MatchEnumInfo {
386                concrete_enum_id: callee_info.ok_variant.concrete_enum_id,
387                input: VarUsage { var_id: panic_result_var, location },
388                arms: vec![
389                    MatchArm {
390                        arm_selector: MatchArmSelector::VariantId(callee_info.ok_variant),
391                        block_id: block_ok,
392                        var_ids: vec![inner_ok_value],
393                    },
394                    MatchArm {
395                        arm_selector: MatchArmSelector::VariantId(callee_info.err_variant),
396                        block_id: block_err,
397                        var_ids: vec![err_var],
398                    },
399                ],
400                location,
401            }),
402        };
403
404        Ok((cur_block_end, Some(block_continuation)))
405    }
406
407    fn handle_end(mut self, end: BlockEnd<'db>) -> PanicLoweringContext<'db> {
408        let end = match end {
409            BlockEnd::Goto(target, remapping) => BlockEnd::Goto(target, remapping),
410            BlockEnd::Panic(err_data) => {
411                // Wrap with PanicResult::Err.
412                let ty = self.ctx.panic_info.actual_return_ty;
413                let location = err_data.location;
414                let output = if self.ctx.panic_info.always_panic {
415                    err_data.var_id
416                } else {
417                    let output = self.new_var(ty, location);
418                    self.statements.push(Statement::EnumConstruct(StatementEnumConstruct {
419                        variant: self.ctx.panic_info.err_variant,
420                        input: err_data,
421                        output,
422                    }));
423                    output
424                };
425                BlockEnd::Return(vec![VarUsage { var_id: output, location }], location)
426            }
427            BlockEnd::Return(returns, location) => {
428                // Tuple construction.
429                let tupled_res = self.new_var(self.ctx.panic_info.ok_ty, location);
430                self.statements.push(Statement::StructConstruct(StatementStructConstruct {
431                    inputs: returns,
432                    output: tupled_res,
433                }));
434
435                // Wrap with PanicResult::Ok.
436                let ty = self.ctx.panic_info.actual_return_ty;
437                let output = self.new_var(ty, location);
438                self.statements.push(Statement::EnumConstruct(StatementEnumConstruct {
439                    variant: self.ctx.panic_info.ok_variant,
440                    input: VarUsage { var_id: tupled_res, location },
441                    output,
442                }));
443                BlockEnd::Return(vec![VarUsage { var_id: output, location }], location)
444            }
445            BlockEnd::NotSet => unreachable!(),
446            BlockEnd::Match { info } => BlockEnd::Match { info },
447        };
448        self.ctx.flat_blocks.alloc(Block { statements: self.statements, end });
449        self.ctx
450    }
451}
452
453// ============= Query implementations =============
454
455/// Query implementation of [crate::db::LoweringGroup::function_may_panic].
456#[salsa::tracked]
457pub fn function_may_panic<'db>(db: &'db dyn Database, function: FunctionId<'db>) -> Maybe<bool> {
458    if let Some(body) = function.body(db)? {
459        return db.function_with_body_may_panic(body);
460    }
461    Ok(function.signature(db)?.panicable)
462}
463
464/// A trait to add helper methods in [LoweringGroup].
465pub trait MayPanicTrait<'db>: Database {
466    /// Returns whether a [ConcreteFunctionWithBodyId] may panic.
467    fn function_with_body_may_panic(
468        &'db self,
469        function: ConcreteFunctionWithBodyId<'db>,
470    ) -> Maybe<bool> {
471        let db: &'db dyn Database = self.as_dyn_database();
472        let scc_representative = db.lowered_scc_representative(
473            function,
474            DependencyType::Call,
475            LoweringStage::Monomorphized,
476        );
477        scc_may_panic(db, scc_representative)
478    }
479}
480impl<'db, T: Database + ?Sized> MayPanicTrait<'db> for T {}
481
482/// Returns whether any function in the strongly connected component may panic.
483fn scc_may_panic<'db>(db: &'db dyn Database, scc: ConcreteSCCRepresentative<'db>) -> Maybe<bool> {
484    scc_may_panic_tracked(db, scc.0)
485}
486
487/// Tracked implementation of [scc_may_panic].
488#[salsa::tracked]
489fn scc_may_panic_tracked<'db>(
490    db: &'db dyn Database,
491    rep: ConcreteFunctionWithBodyId<'db>,
492) -> Maybe<bool> {
493    // Find the SCC representative.
494    let scc_functions = db.lowered_scc(rep, DependencyType::Call, LoweringStage::Monomorphized);
495    for function in scc_functions {
496        if db.needs_withdraw_gas(function)? {
497            return Ok(true);
498        }
499        if db.has_direct_panic(function)? {
500            return Ok(true);
501        }
502        // For each direct callee, find if it may panic.
503        let direct_callees = db.lowered_direct_callees(
504            function,
505            DependencyType::Call,
506            LoweringStage::Monomorphized,
507        )?;
508        for direct_callee in direct_callees {
509            if let Some(callee_body) = direct_callee.body(db)? {
510                let callee_scc = db.lowered_scc_representative(
511                    callee_body,
512                    DependencyType::Call,
513                    LoweringStage::Monomorphized,
514                );
515                if callee_scc.0 != rep && scc_may_panic(db, callee_scc)? {
516                    return Ok(true);
517                }
518            } else if direct_callee.signature(db)?.panicable {
519                return Ok(true);
520            }
521        }
522    }
523    Ok(false)
524}
525
526/// Query implementation of [crate::db::LoweringGroup::has_direct_panic].
527#[salsa::tracked]
528pub fn has_direct_panic<'db>(
529    db: &'db dyn Database,
530    function_id: ConcreteFunctionWithBodyId<'db>,
531) -> Maybe<bool> {
532    let lowered_function = db.lowered_body(function_id, LoweringStage::Monomorphized)?;
533    Ok(itertools::any(lowered_function.blocks.iter(), |(_, block)| {
534        matches!(&block.end, BlockEnd::Panic(..))
535    }))
536}