Skip to main content

cairo_lang_lowering/optimizations/
early_unsafe_panic.rs

1#[cfg(test)]
2#[path = "early_unsafe_panic_test.rs"]
3mod test;
4
5use std::collections::HashSet;
6
7use cairo_lang_defs::ids::ExternFunctionId;
8use cairo_lang_filesystem::flag::FlagsGroup;
9use cairo_lang_semantic::helper::ModuleHelper;
10use itertools::zip_eq;
11use salsa::Database;
12
13use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
14use crate::ids::{LocationId, SemanticFunctionIdEx};
15use crate::{
16    BlockEnd, BlockId, Lowered, MatchExternInfo, MatchInfo, Statement, StatementCall, VarUsage,
17};
18
19/// Adds an early unsafe_panic when we detect that `return` is unreachable from a certain point in
20/// the code. This step is needed to avoid issues with undroppable references in Sierra to CASM.
21///
22/// This step might replace a match on an empty enum with a call to unsafe_panic and we rely on the
23/// 'trim_unreachable' optimization to clean that up.
24pub fn early_unsafe_panic<'db>(db: &'db dyn Database, lowered: &mut Lowered<'db>) {
25    if !db.flag_unsafe_panic() || lowered.blocks.is_empty() {
26        return;
27    }
28
29    let core = ModuleHelper::core(db);
30    let libfuncs_with_sideffect = HashSet::from_iter([
31        core.submodule("debug").extern_function_id("print"),
32        core.submodule("internal").extern_function_id("trace"),
33    ]);
34
35    let ctx = UnsafePanicContext { db, fixes: vec![], libfuncs_with_sideffect };
36    let mut analysis = BackAnalysis::new(lowered, ctx);
37    let fixes = if let ReachableSideEffects::Unreachable(location) = analysis.get_root_info() {
38        vec![((BlockId::root(), 0), location)]
39    } else {
40        analysis.analyzer.fixes
41    };
42
43    let panic_func_id = core.submodule("panics").function_id("unsafe_panic", vec![]).lowered(db);
44    for ((block_id, statement_idx), location) in fixes {
45        let block = &mut lowered.blocks[block_id];
46        block.statements.truncate(statement_idx);
47
48        block.end = BlockEnd::Match {
49            info: MatchInfo::Extern(MatchExternInfo {
50                arms: vec![],
51                location,
52                function: panic_func_id,
53                inputs: vec![],
54            }),
55        }
56    }
57}
58
59pub struct UnsafePanicContext<'db> {
60    db: &'db dyn Database,
61
62    /// The list of blocks where we can insert unsafe_panic.
63    fixes: Vec<(StatementLocation, LocationId<'db>)>,
64
65    /// libfuncs with side effects that we need to ignore.
66    libfuncs_with_sideffect: HashSet<ExternFunctionId<'db>>,
67}
68
69impl<'db> UnsafePanicContext<'db> {
70    /// Returns true if the statement has side effects.
71    pub fn has_side_effects(&self, stmt: &Statement<'db>) -> bool {
72        if let Statement::Call(StatementCall { function, .. }) = stmt {
73            let Some((extern_fn, _gargs)) = function.get_extern(self.db) else {
74                return false;
75            };
76
77            if self.libfuncs_with_sideffect.contains(&extern_fn) {
78                return true;
79            }
80        }
81
82        false
83    }
84}
85
86/// Can this state lead to a return or a statement with side effect.
87#[derive(Clone, Default, PartialEq, Debug)]
88pub enum ReachableSideEffects<'db> {
89    /// Some return statement or statement with side effect is reachable.
90    #[default]
91    Reachable,
92    /// No return statement or statement with side effect is reachable.
93    /// holds the location of the closest match with no returning arms.
94    Unreachable(LocationId<'db>),
95}
96
97impl<'db> Analyzer<'db, '_> for UnsafePanicContext<'db> {
98    type Info = ReachableSideEffects<'db>;
99
100    fn visit_stmt(
101        &mut self,
102        info: &mut Self::Info,
103        statement_location: StatementLocation,
104        stmt: &Statement<'db>,
105    ) {
106        if self.has_side_effects(stmt)
107            && let ReachableSideEffects::Unreachable(locations) = *info
108        {
109            self.fixes.push((statement_location, locations));
110            *info = ReachableSideEffects::Reachable
111        }
112    }
113
114    fn merge_match(
115        &mut self,
116        statement_location: StatementLocation,
117        match_info: &MatchInfo<'db>,
118        infos: impl Iterator<Item = Self::Info>,
119    ) -> Self::Info {
120        let mut res = ReachableSideEffects::Unreachable(*match_info.location());
121        for (arm, info) in zip_eq(match_info.arms(), infos) {
122            match info {
123                ReachableSideEffects::Reachable => {
124                    res = ReachableSideEffects::Reachable;
125                }
126                ReachableSideEffects::Unreachable(l) => self.fixes.push(((arm.block_id, 0), l)),
127            }
128        }
129
130        if let ReachableSideEffects::Unreachable(location) = res {
131            self.fixes.push((statement_location, location));
132        }
133
134        res
135    }
136
137    fn info_from_return(&mut self, _: StatementLocation, _vars: &[VarUsage<'db>]) -> Self::Info {
138        ReachableSideEffects::Reachable
139    }
140}