cairo_lang_plugins/plugins/panicable.rs

1use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
2use cairo_lang_defs::plugin::{
3    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
4};
5use cairo_lang_filesystem::ids::SmolStrId;
6use cairo_lang_syntax::attribute::structured::{
7    Attribute, AttributeArg, AttributeArgVariant, AttributeStructurize,
8};
9use cairo_lang_syntax::node::helpers::{GetIdentifier, QueryAttrs};
10use cairo_lang_syntax::node::{TypedSyntaxNode, ast};
11use cairo_lang_utils::try_extract_matches;
12use indoc::formatdoc;
13use itertools::Itertools;
14use salsa::Database;
15
/// Macro plugin handling the `#[panic_with(err_value, name)]` attribute on free and extern
/// functions: for an annotated function returning `Option<T>` or `Result<T, E>`, it generates an
/// extra function `name` that unwraps the success value and panics with `err_value` otherwise.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct PanicablePlugin;

/// Name of the attribute recognized by [`PanicablePlugin`].
const PANIC_WITH_ATTR: &str = "panic_with";
21
22impl MacroPlugin for PanicablePlugin {
23    fn generate_code<'db>(
24        &self,
25        db: &'db dyn Database,
26        item_ast: ast::ModuleItem<'db>,
27        _metadata: &MacroPluginMetadata<'_>,
28    ) -> PluginResult<'db> {
29        let (declaration, attributes, visibility) = match item_ast {
30            ast::ModuleItem::ExternFunction(extern_func_ast) => (
31                extern_func_ast.declaration(db),
32                extern_func_ast.attributes(db),
33                extern_func_ast.visibility(db),
34            ),
35            ast::ModuleItem::FreeFunction(free_func_ast) => (
36                free_func_ast.declaration(db),
37                free_func_ast.attributes(db),
38                free_func_ast.visibility(db),
39            ),
40            _ => return PluginResult::default(),
41        };
42
43        generate_panicable_code(db, declaration, attributes, visibility)
44    }
45
46    fn declared_attributes<'db>(&self, db: &'db dyn Database) -> Vec<SmolStrId<'db>> {
47        vec![SmolStrId::from(db, PANIC_WITH_ATTR)]
48    }
49}
50
51/// Generate code defining a panicable variant of a function marked with `#[panic_with]` attribute.
52fn generate_panicable_code<'db>(
53    db: &'db dyn Database,
54    declaration: ast::FunctionDeclaration<'db>,
55    attributes: ast::AttributeList<'db>,
56    visibility: ast::Visibility<'db>,
57) -> PluginResult<'db> {
58    let mut attrs = attributes.query_attr(db, PANIC_WITH_ATTR);
59    let Some(attr) = attrs.next() else {
60        // No `#[panic_with]` attribute found.
61        return PluginResult::default();
62    };
63    let mut diagnostics = vec![];
64    if let Some(extra_attr) = attrs.next() {
65        diagnostics.push(PluginDiagnostic::error(
66            extra_attr.stable_ptr(db),
67            "`#[panic_with]` cannot be applied multiple times to the same item.".into(),
68        ));
69        return PluginResult { code: None, diagnostics, remove_original_item: false };
70    }
71
72    let signature = declaration.signature(db);
73    let Some((inner_ty, success_variant, failure_variant)) =
74        extract_success_ty_and_variants(db, &signature)
75    else {
76        diagnostics.push(PluginDiagnostic::error(
77            signature.ret_ty(db).stable_ptr(db),
78            "Currently only wrapping functions returning an Option<T> or Result<T, E>".into(),
79        ));
80        return PluginResult { code: None, diagnostics, remove_original_item: false };
81    };
82
83    let mut builder = PatchBuilder::new(db, &attr);
84    let attr = attr.structurize(db);
85
86    let Some((err_value, panicable_name)) = parse_arguments(db, &attr) else {
87        diagnostics.push(PluginDiagnostic::error(
88            attr.stable_ptr,
89            "Failed to extract panic data attribute".into(),
90        ));
91        return PluginResult { code: None, diagnostics, remove_original_item: false };
92    };
93    builder.add_node(visibility.as_syntax_node());
94    builder.add_node(declaration.function_kw(db).as_syntax_node());
95    builder.add_modified(RewriteNode::from_ast_trimmed(&panicable_name));
96    builder.add_node(declaration.generic_params(db).as_syntax_node());
97    builder.add_node(signature.lparen(db).as_syntax_node());
98    builder.add_node(signature.parameters(db).as_syntax_node());
99    builder.add_node(signature.rparen(db).as_syntax_node());
100    let args = signature
101        .parameters(db)
102        .elements(db)
103        .map(|param| {
104            let ref_kw = if let Some([ast::Modifier::Ref(_)]) =
105                param.modifiers(db).elements(db).collect_array()
106            {
107                "ref "
108            } else {
109                ""
110            };
111            format!("{}{}", ref_kw, param.name(db).as_syntax_node().get_text(db))
112        })
113        .join(", ");
114    builder.add_modified(RewriteNode::interpolate_patched(
115        &formatdoc!(
116            r#"
117                -> $inner_ty$ {{
118                    match $function_name$({args}) {{
119                        {success_variant}(v) => v,
120                        {failure_variant}(_) => core::panic_with_const_felt252::<$err_value$>(),
121                    }}
122                }}
123            "#
124        ),
125        &[
126            ("inner_ty".to_string(), RewriteNode::from_ast_trimmed(&inner_ty)),
127            ("function_name".to_string(), RewriteNode::from_ast_trimmed(&declaration.name(db))),
128            ("err_value".to_string(), RewriteNode::from_ast_trimmed(&err_value)),
129        ]
130        .into(),
131    ));
132
133    let (content, code_mappings) = builder.build();
134    PluginResult {
135        code: Some(PluginGeneratedFile {
136            name: "panicable".into(),
137            content,
138            code_mappings,
139            aux_data: None,
140            diagnostics_note: Default::default(),
141            is_unhygienic: false,
142        }),
143        diagnostics,
144        remove_original_item: false,
145    }
146}
147
148/// Given a function signature, if it returns `Option::<T>` or `Result::<T, E>`, returns T and the
149/// variant match strings. Otherwise, returns None.
150fn extract_success_ty_and_variants<'a>(
151    db: &'a dyn Database,
152    signature: &ast::FunctionSignature<'a>,
153) -> Option<(ast::GenericArg<'a>, String, String)> {
154    let ret_ty_expr =
155        try_extract_matches!(signature.ret_ty(db), ast::OptionReturnTypeClause::ReturnTypeClause)?
156            .ty(db);
157    let ret_ty_path = try_extract_matches!(ret_ty_expr, ast::Expr::Path)?;
158
159    // Currently only wrapping functions returning an Option<T>.
160    let Some([ast::PathSegment::WithGenericArgs(segment)]) =
161        ret_ty_path.segments(db).elements(db).collect_array()
162    else {
163        return None;
164    };
165    let ty = segment.identifier(db).long(db);
166    if ty == "Option" {
167        let [inner] = segment.generic_args(db).generic_args(db).elements(db).collect_array()?;
168        Some((inner.clone(), "Option::Some".to_owned(), "Option::None".to_owned()))
169    } else if ty == "Result" {
170        let [inner, _err] =
171            segment.generic_args(db).generic_args(db).elements(db).collect_array()?;
172        Some((inner.clone(), "Result::Ok".to_owned(), "Result::Err".to_owned()))
173    } else {
174        None
175    }
176}
177
178/// Parse `#[panic_with(...)]` attribute arguments and return a tuple with error value and
179/// panicable function name.
180fn parse_arguments<'a>(
181    db: &'a dyn Database,
182    attr: &Attribute<'a>,
183) -> Option<(ast::TerminalShortString<'a>, ast::TerminalIdentifier<'a>)> {
184    let [
185        AttributeArg {
186            variant: AttributeArgVariant::Unnamed(ast::Expr::ShortString(err_value)),
187            ..
188        },
189        AttributeArg { variant: AttributeArgVariant::Unnamed(ast::Expr::Path(name)), .. },
190    ] = &attr.args[..]
191    else {
192        return None;
193    };
194
195    let Some([ast::PathSegment::Simple(segment)]) = name.segments(db).elements(db).collect_array()
196    else {
197        return None;
198    };
199
200    Some((err_value.clone(), segment.ident(db)))
201}