cairo_lint/lints/
unwrap_syscall.rs1use crate::{
2 context::{CairoLintKind, Lint},
3 fixer::InternalFix,
4 queries::{get_all_function_bodies, get_all_function_calls},
5};
6use cairo_lang_defs::ids::{NamedLanguageElementId, TopLevelLanguageElementId};
7use cairo_lang_defs::{ids::ModuleItemId, plugin::PluginDiagnostic};
8use cairo_lang_diagnostics::Severity;
9use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
10use cairo_lang_semantic::items::imp::ImplHead;
11use cairo_lang_semantic::{
12 Arenas, ExprFunctionCall, ExprFunctionCallArg, GenericArgumentId, TypeId, TypeLongId,
13};
14use cairo_lang_syntax::node::{SyntaxNode, TypedStablePtr, TypedSyntaxNode, ast};
15use salsa::Database;
16
/// Import path of the trait providing `unwrap_syscall`; the fixer requests this path as an
/// import addition alongside the rewritten call.
const UNWRAP_SYSCALL_TRAIT_PATH: &str = "starknet::SyscallResultTrait";

/// Lint that flags `unwrap()` on syscall results and suggests `unwrap_syscall()` instead.
pub struct UnwrapSyscall;
20
21impl Lint for UnwrapSyscall {
52 fn allowed_name(&self) -> &'static str {
53 "unwrap_syscall"
54 }
55
56 fn diagnostic_message(&self) -> &'static str {
57 "consider using `unwrap_syscall` instead of `unwrap`"
58 }
59
60 fn kind(&self) -> CairoLintKind {
61 CairoLintKind::UnwrapSyscall
62 }
63
64 fn has_fixer(&self) -> bool {
65 true
66 }
67
68 fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
69 fix_unwrap_syscall(db, node)
70 }
71
72 fn fix_message(&self) -> Option<&'static str> {
73 Some("Replace with `unwrap_syscall()` for syscall results")
74 }
75}
76
77#[tracing::instrument(skip_all, level = "trace")]
78pub fn check_unwrap_syscall<'db>(
79 db: &'db dyn Database,
80 item: &ModuleItemId<'db>,
81 diagnostics: &mut Vec<PluginDiagnostic<'db>>,
82) {
83 let function_bodies = get_all_function_bodies(db, item);
84 for function_body in function_bodies {
85 let function_call_exprs = get_all_function_calls(function_body);
86 let arenas = &function_body.arenas;
87 for function_call_expr in function_call_exprs {
88 check_single_unwrap_syscall(db, &function_call_expr, arenas, diagnostics);
89 }
90 }
91}
92
93fn check_single_unwrap_syscall<'db>(
94 db: &'db dyn Database,
95 expr: &ExprFunctionCall<'db>,
96 arenas: &Arenas<'db>,
97 diagnostics: &mut Vec<PluginDiagnostic<'db>>,
98) {
99 if is_result_trait_impl_unwrap_call(db, expr)
100 && let Some(ExprFunctionCallArg::Value(expr_id)) = expr.args.first()
101 && let receiver_expr = &arenas.exprs[*expr_id]
102 && is_syscall_result_type(db, receiver_expr.ty())
103 {
104 diagnostics.push(PluginDiagnostic {
105 stable_ptr: receiver_expr
106 .stable_ptr()
107 .lookup(db)
108 .as_syntax_node()
109 .parent(db)
110 .unwrap()
111 .stable_ptr(db),
112 message: UnwrapSyscall.diagnostic_message().to_string(),
113 severity: Severity::Warning,
114 inner_span: None,
115 error_code: None,
116 });
117 }
118}
119
120fn is_result_trait_impl_unwrap_call(db: &dyn Database, expr: &ExprFunctionCall) -> bool {
122 if let GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, function }) =
123 expr.function.get_concrete(db).generic_function
124 && function.name(db).long(db).as_str() == "unwrap"
125 && let Some(ImplHead::Concrete(impl_def_id)) = impl_id.head(db)
126 && impl_def_id.full_path(db) == "core::result::ResultTraitImpl"
127 {
128 true
129 } else {
130 false
131 }
132}
133
134fn is_syscall_result_type(db: &dyn Database, ty: TypeId) -> bool {
136 is_specific_concrete_generic_type(db, ty, "core::result::Result", |[_, arg_e]| {
137 if let GenericArgumentId::Type(arg_e) = arg_e
138 && is_specific_concrete_generic_type(db, arg_e, "core::array::Array", |[arg]| {
139 if let GenericArgumentId::Type(arg) = arg
140 && is_specific_concrete_type(db, arg, "core::felt252")
141 {
142 true
143 } else {
144 false
145 }
146 })
147 {
148 true
149 } else {
150 false
151 }
152 })
153}
154
155fn is_specific_concrete_type(db: &dyn Database, ty: TypeId, full_path: &str) -> bool {
156 if let TypeLongId::Concrete(concrete_type_long_id) = ty.long(db)
157 && concrete_type_long_id.generic_type(db).full_path(db) == full_path
158 {
159 true
160 } else {
161 false
162 }
163}
164
165fn is_specific_concrete_generic_type<'db, const N: usize>(
166 db: &'db dyn Database,
167 ty: TypeId<'db>,
168 full_path: &str,
169 generic_args_matcher: impl FnOnce([GenericArgumentId<'db>; N]) -> bool,
170) -> bool {
171 if let TypeLongId::Concrete(concrete_type_long_id) = ty.long(db)
172 && concrete_type_long_id.generic_type(db).full_path(db) == full_path
173 && let Ok(generic_args) =
174 <[GenericArgumentId; N]>::try_from(concrete_type_long_id.generic_args(db))
175 && generic_args_matcher(generic_args)
176 {
177 true
178 } else {
179 false
180 }
181}
182
183#[tracing::instrument(skip_all, level = "trace")]
184fn fix_unwrap_syscall<'db>(
185 db: &'db dyn Database,
186 node: SyntaxNode<'db>,
187) -> Option<InternalFix<'db>> {
188 let ast_expr_binary = ast::ExprBinary::cast(db, node).unwrap_or_else(|| {
189 panic!(
190 "Expected a binary expression for unwrap called on SyscallResult. Actual node text: {:?}",
191 node.get_text(db)
192 )
193 });
194
195 let fixed = format!(
196 "{}{}unwrap_syscall()",
197 ast_expr_binary.lhs(db).as_syntax_node().get_text(db),
198 ast_expr_binary.op(db).as_syntax_node().get_text(db)
199 );
200 Some(InternalFix {
201 node,
202 suggestion: fixed,
203 description: UnwrapSyscall.fix_message().unwrap().to_string(),
204 import_addition_paths: Some(vec![UNWRAP_SYSCALL_TRAIT_PATH.to_string()]),
205 })
206}