Skip to main content

cairo_lint/lints/
unwrap_syscall.rs

1use crate::{
2    context::{CairoLintKind, Lint},
3    fixer::InternalFix,
4    queries::{get_all_function_bodies, get_all_function_calls},
5};
6use cairo_lang_defs::ids::{NamedLanguageElementId, TopLevelLanguageElementId};
7use cairo_lang_defs::{ids::ModuleItemId, plugin::PluginDiagnostic};
8use cairo_lang_diagnostics::Severity;
9use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
10use cairo_lang_semantic::items::imp::ImplHead;
11use cairo_lang_semantic::{
12    Arenas, ExprFunctionCall, ExprFunctionCallArg, GenericArgumentId, TypeId, TypeLongId,
13};
14use cairo_lang_syntax::node::{SyntaxNode, TypedStablePtr, TypedSyntaxNode, ast};
15use salsa::Database;
16
/// Lint that flags `unwrap()` calls on `SyscallResult` values and suggests `unwrap_syscall()`.
pub struct UnwrapSyscall;

/// Path of the trait providing `unwrap_syscall`; the fixer adds it as an import.
const UNWRAP_SYSCALL_TRAIT_PATH: &str = "starknet::SyscallResultTrait";
20
21/// ## What it does
22///
23/// Detects if the function uses `unwrap` on a `SyscallResult` object.
24///
25/// ## Example
26///
27/// ```cairo
28/// use starknet::storage_access::{storage_address_from_base, storage_base_address_from_felt252};
29/// use starknet::syscalls::storage_read_syscall;
30///
31/// fn main() {
32///     let storage_address = storage_base_address_from_felt252(3534535754756246375475423547453);
33///     let result = storage_read_syscall(0, storage_address_from_base(storage_address));
34///     result.unwrap();
35/// }
36/// ```
37///
38/// Can be changed to:
39///
40/// ```cairo
41/// use starknet::SyscallResultTrait;
42/// use starknet::storage_access::{storage_address_from_base, storage_base_address_from_felt252};
43/// use starknet::syscalls::storage_read_syscall;
44///
45/// fn main() {
46///     let storage_address = storage_base_address_from_felt252(3534535754756246375475423547453);
47///     let result = storage_read_syscall(0, storage_address_from_base(storage_address));
48///     result.unwrap_syscall();
49/// }
50/// ```
51impl Lint for UnwrapSyscall {
52    fn allowed_name(&self) -> &'static str {
53        "unwrap_syscall"
54    }
55
56    fn diagnostic_message(&self) -> &'static str {
57        "consider using `unwrap_syscall` instead of `unwrap`"
58    }
59
60    fn kind(&self) -> CairoLintKind {
61        CairoLintKind::UnwrapSyscall
62    }
63
64    fn has_fixer(&self) -> bool {
65        true
66    }
67
68    fn fix<'db>(&self, db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<InternalFix<'db>> {
69        fix_unwrap_syscall(db, node)
70    }
71
72    fn fix_message(&self) -> Option<&'static str> {
73        Some("Replace with `unwrap_syscall()` for syscall results")
74    }
75}
76
77#[tracing::instrument(skip_all, level = "trace")]
78pub fn check_unwrap_syscall<'db>(
79    db: &'db dyn Database,
80    item: &ModuleItemId<'db>,
81    diagnostics: &mut Vec<PluginDiagnostic<'db>>,
82) {
83    let function_bodies = get_all_function_bodies(db, item);
84    for function_body in function_bodies {
85        let function_call_exprs = get_all_function_calls(function_body);
86        let arenas = &function_body.arenas;
87        for function_call_expr in function_call_exprs {
88            check_single_unwrap_syscall(db, &function_call_expr, arenas, diagnostics);
89        }
90    }
91}
92
93fn check_single_unwrap_syscall<'db>(
94    db: &'db dyn Database,
95    expr: &ExprFunctionCall<'db>,
96    arenas: &Arenas<'db>,
97    diagnostics: &mut Vec<PluginDiagnostic<'db>>,
98) {
99    if is_result_trait_impl_unwrap_call(db, expr)
100        && let Some(ExprFunctionCallArg::Value(expr_id)) = expr.args.first()
101        && let receiver_expr = &arenas.exprs[*expr_id]
102        && is_syscall_result_type(db, receiver_expr.ty())
103    {
104        diagnostics.push(PluginDiagnostic {
105            stable_ptr: receiver_expr
106                .stable_ptr()
107                .lookup(db)
108                .as_syntax_node()
109                .parent(db)
110                .unwrap()
111                .stable_ptr(db),
112            message: UnwrapSyscall.diagnostic_message().to_string(),
113            severity: Severity::Warning,
114            inner_span: None,
115            error_code: None,
116        });
117    }
118}
119
120/// Check if this function call expression calls `core::result::ResultTraitImpl<_>::unwrap`.
121fn is_result_trait_impl_unwrap_call(db: &dyn Database, expr: &ExprFunctionCall) -> bool {
122    if let GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, function }) =
123        expr.function.get_concrete(db).generic_function
124        && function.name(db).long(db).as_str() == "unwrap"
125        && let Some(ImplHead::Concrete(impl_def_id)) = impl_id.head(db)
126        && impl_def_id.full_path(db) == "core::result::ResultTraitImpl"
127    {
128        true
129    } else {
130        false
131    }
132}
133
134// Checks if the type is `Result<T, Array<felt252>>`.
135fn is_syscall_result_type(db: &dyn Database, ty: TypeId) -> bool {
136    is_specific_concrete_generic_type(db, ty, "core::result::Result", |[_, arg_e]| {
137        if let GenericArgumentId::Type(arg_e) = arg_e
138            && is_specific_concrete_generic_type(db, arg_e, "core::array::Array", |[arg]| {
139                if let GenericArgumentId::Type(arg) = arg
140                    && is_specific_concrete_type(db, arg, "core::felt252")
141                {
142                    true
143                } else {
144                    false
145                }
146            })
147        {
148            true
149        } else {
150            false
151        }
152    })
153}
154
155fn is_specific_concrete_type(db: &dyn Database, ty: TypeId, full_path: &str) -> bool {
156    if let TypeLongId::Concrete(concrete_type_long_id) = ty.long(db)
157        && concrete_type_long_id.generic_type(db).full_path(db) == full_path
158    {
159        true
160    } else {
161        false
162    }
163}
164
165fn is_specific_concrete_generic_type<'db, const N: usize>(
166    db: &'db dyn Database,
167    ty: TypeId<'db>,
168    full_path: &str,
169    generic_args_matcher: impl FnOnce([GenericArgumentId<'db>; N]) -> bool,
170) -> bool {
171    if let TypeLongId::Concrete(concrete_type_long_id) = ty.long(db)
172        && concrete_type_long_id.generic_type(db).full_path(db) == full_path
173        && let Ok(generic_args) =
174            <[GenericArgumentId; N]>::try_from(concrete_type_long_id.generic_args(db))
175        && generic_args_matcher(generic_args)
176    {
177        true
178    } else {
179        false
180    }
181}
182
183#[tracing::instrument(skip_all, level = "trace")]
184fn fix_unwrap_syscall<'db>(
185    db: &'db dyn Database,
186    node: SyntaxNode<'db>,
187) -> Option<InternalFix<'db>> {
188    let ast_expr_binary = ast::ExprBinary::cast(db, node).unwrap_or_else(|| {
189        panic!(
190          "Expected a binary expression for unwrap called on SyscallResult. Actual node text: {:?}",
191          node.get_text(db)
192        )
193    });
194
195    let fixed = format!(
196        "{}{}unwrap_syscall()",
197        ast_expr_binary.lhs(db).as_syntax_node().get_text(db),
198        ast_expr_binary.op(db).as_syntax_node().get_text(db)
199    );
200    Some(InternalFix {
201        node,
202        suggestion: fixed,
203        description: UnwrapSyscall.fix_message().unwrap().to_string(),
204        import_addition_paths: Some(vec![UNWRAP_SYSCALL_TRAIT_PATH.to_string()]),
205    })
206}