// cairo_lang_lowering/optimizations/early_unsafe_panic.rs

// Unit tests for this optimization live in a separate file next to this one.
#[cfg(test)]
#[path = "early_unsafe_panic_test.rs"]
mod test;
4
5use std::collections::HashSet;
6
7use cairo_lang_defs::ids::ExternFunctionId;
8use cairo_lang_filesystem::flag::FlagsGroup;
9use cairo_lang_semantic::helper::ModuleHelper;
10use itertools::zip_eq;
11use salsa::Database;
12
13use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
14use crate::ids::{LocationId, SemanticFunctionIdEx};
15use crate::{
16 BlockEnd, BlockId, Lowered, MatchExternInfo, MatchInfo, Statement, StatementCall, VarUsage,
17};
18
19pub fn early_unsafe_panic<'db>(db: &'db dyn Database, lowered: &mut Lowered<'db>) {
25 if !db.flag_unsafe_panic() || lowered.blocks.is_empty() {
26 return;
27 }
28
29 let core = ModuleHelper::core(db);
30 let libfuncs_with_sideffect = HashSet::from_iter([
31 core.submodule("debug").extern_function_id("print"),
32 core.submodule("internal").extern_function_id("trace"),
33 ]);
34
35 let ctx = UnsafePanicContext { db, fixes: vec![], libfuncs_with_sideffect };
36 let mut analysis = BackAnalysis::new(lowered, ctx);
37 let fixes = if let ReachableSideEffects::Unreachable(location) = analysis.get_root_info() {
38 vec![((BlockId::root(), 0), location)]
39 } else {
40 analysis.analyzer.fixes
41 };
42
43 let panic_func_id = core.submodule("panics").function_id("unsafe_panic", vec![]).lowered(db);
44 for ((block_id, statement_idx), location) in fixes {
45 let block = &mut lowered.blocks[block_id];
46 block.statements.truncate(statement_idx);
47
48 block.end = BlockEnd::Match {
49 info: MatchInfo::Extern(MatchExternInfo {
50 arms: vec![],
51 location,
52 function: panic_func_id,
53 inputs: vec![],
54 }),
55 }
56 }
57}
58
59pub struct UnsafePanicContext<'db> {
60 db: &'db dyn Database,
61
62 fixes: Vec<(StatementLocation, LocationId<'db>)>,
64
65 libfuncs_with_sideffect: HashSet<ExternFunctionId<'db>>,
67}
68
69impl<'db> UnsafePanicContext<'db> {
70 pub fn has_side_effects(&self, stmt: &Statement<'db>) -> bool {
72 if let Statement::Call(StatementCall { function, .. }) = stmt {
73 let Some((extern_fn, _gargs)) = function.get_extern(self.db) else {
74 return false;
75 };
76
77 if self.libfuncs_with_sideffect.contains(&extern_fn) {
78 return true;
79 }
80 }
81
82 false
83 }
84}
85
86#[derive(Clone, Default, PartialEq, Debug)]
88pub enum ReachableSideEffects<'db> {
89 #[default]
91 Reachable,
92 Unreachable(LocationId<'db>),
95}
96
97impl<'db> Analyzer<'db, '_> for UnsafePanicContext<'db> {
98 type Info = ReachableSideEffects<'db>;
99
100 fn visit_stmt(
101 &mut self,
102 info: &mut Self::Info,
103 statement_location: StatementLocation,
104 stmt: &Statement<'db>,
105 ) {
106 if self.has_side_effects(stmt)
107 && let ReachableSideEffects::Unreachable(locations) = *info
108 {
109 self.fixes.push((statement_location, locations));
110 *info = ReachableSideEffects::Reachable
111 }
112 }
113
114 fn merge_match(
115 &mut self,
116 statement_location: StatementLocation,
117 match_info: &MatchInfo<'db>,
118 infos: impl Iterator<Item = Self::Info>,
119 ) -> Self::Info {
120 let mut res = ReachableSideEffects::Unreachable(*match_info.location());
121 for (arm, info) in zip_eq(match_info.arms(), infos) {
122 match info {
123 ReachableSideEffects::Reachable => {
124 res = ReachableSideEffects::Reachable;
125 }
126 ReachableSideEffects::Unreachable(l) => self.fixes.push(((arm.block_id, 0), l)),
127 }
128 }
129
130 if let ReachableSideEffects::Unreachable(location) = res {
131 self.fixes.push((statement_location, location));
132 }
133
134 res
135 }
136
137 fn info_from_return(&mut self, _: StatementLocation, _vars: &[VarUsage<'db>]) -> Self::Info {
138 ReachableSideEffects::Reachable
139 }
140}