cairo_lang_plugins/plugins/
panicable.rs1use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
2use cairo_lang_defs::plugin::{
3 MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
4};
5use cairo_lang_filesystem::ids::SmolStrId;
6use cairo_lang_syntax::attribute::structured::{
7 Attribute, AttributeArg, AttributeArgVariant, AttributeStructurize,
8};
9use cairo_lang_syntax::node::helpers::{GetIdentifier, QueryAttrs};
10use cairo_lang_syntax::node::{TypedSyntaxNode, ast};
11use cairo_lang_utils::try_extract_matches;
12use indoc::formatdoc;
13use itertools::Itertools;
14use salsa::Database;
15
/// Macro plugin handling the `#[panic_with]` attribute: for an annotated function it emits an
/// additional function that unwraps the original function's `Option`/`Result` return value.
#[non_exhaustive]
#[derive(Debug, Default)]
pub struct PanicablePlugin;

/// Name of the attribute recognized (and declared) by [`PanicablePlugin`].
const PANIC_WITH_ATTR: &str = "panic_with";
21
22impl MacroPlugin for PanicablePlugin {
23 fn generate_code<'db>(
24 &self,
25 db: &'db dyn Database,
26 item_ast: ast::ModuleItem<'db>,
27 _metadata: &MacroPluginMetadata<'_>,
28 ) -> PluginResult<'db> {
29 let (declaration, attributes, visibility) = match item_ast {
30 ast::ModuleItem::ExternFunction(extern_func_ast) => (
31 extern_func_ast.declaration(db),
32 extern_func_ast.attributes(db),
33 extern_func_ast.visibility(db),
34 ),
35 ast::ModuleItem::FreeFunction(free_func_ast) => (
36 free_func_ast.declaration(db),
37 free_func_ast.attributes(db),
38 free_func_ast.visibility(db),
39 ),
40 _ => return PluginResult::default(),
41 };
42
43 generate_panicable_code(db, declaration, attributes, visibility)
44 }
45
46 fn declared_attributes<'db>(&self, db: &'db dyn Database) -> Vec<SmolStrId<'db>> {
47 vec![SmolStrId::from(db, PANIC_WITH_ATTR)]
48 }
49}
50
51fn generate_panicable_code<'db>(
53 db: &'db dyn Database,
54 declaration: ast::FunctionDeclaration<'db>,
55 attributes: ast::AttributeList<'db>,
56 visibility: ast::Visibility<'db>,
57) -> PluginResult<'db> {
58 let mut attrs = attributes.query_attr(db, PANIC_WITH_ATTR);
59 let Some(attr) = attrs.next() else {
60 return PluginResult::default();
62 };
63 let mut diagnostics = vec![];
64 if let Some(extra_attr) = attrs.next() {
65 diagnostics.push(PluginDiagnostic::error(
66 extra_attr.stable_ptr(db),
67 "`#[panic_with]` cannot be applied multiple times to the same item.".into(),
68 ));
69 return PluginResult { code: None, diagnostics, remove_original_item: false };
70 }
71
72 let signature = declaration.signature(db);
73 let Some((inner_ty, success_variant, failure_variant)) =
74 extract_success_ty_and_variants(db, &signature)
75 else {
76 diagnostics.push(PluginDiagnostic::error(
77 signature.ret_ty(db).stable_ptr(db),
78 "Currently only wrapping functions returning an Option<T> or Result<T, E>".into(),
79 ));
80 return PluginResult { code: None, diagnostics, remove_original_item: false };
81 };
82
83 let mut builder = PatchBuilder::new(db, &attr);
84 let attr = attr.structurize(db);
85
86 let Some((err_value, panicable_name)) = parse_arguments(db, &attr) else {
87 diagnostics.push(PluginDiagnostic::error(
88 attr.stable_ptr,
89 "Failed to extract panic data attribute".into(),
90 ));
91 return PluginResult { code: None, diagnostics, remove_original_item: false };
92 };
93 builder.add_node(visibility.as_syntax_node());
94 builder.add_node(declaration.function_kw(db).as_syntax_node());
95 builder.add_modified(RewriteNode::from_ast_trimmed(&panicable_name));
96 builder.add_node(declaration.generic_params(db).as_syntax_node());
97 builder.add_node(signature.lparen(db).as_syntax_node());
98 builder.add_node(signature.parameters(db).as_syntax_node());
99 builder.add_node(signature.rparen(db).as_syntax_node());
100 let args = signature
101 .parameters(db)
102 .elements(db)
103 .map(|param| {
104 let ref_kw =
105 if let Ok(ast::Modifier::Ref(_)) = param.modifiers(db).elements(db).exactly_one() {
106 "ref "
107 } else {
108 ""
109 };
110 format!("{}{}", ref_kw, param.name(db).as_syntax_node().get_text(db))
111 })
112 .join(", ");
113 builder.add_modified(RewriteNode::interpolate_patched(
114 &formatdoc!(
115 r#"
116 -> $inner_ty$ {{
117 match $function_name$({args}) {{
118 {success_variant}(v) => v,
119 {failure_variant}(_) => core::panic_with_const_felt252::<$err_value$>(),
120 }}
121 }}
122 "#
123 ),
124 &[
125 ("inner_ty".to_string(), RewriteNode::from_ast_trimmed(&inner_ty)),
126 ("function_name".to_string(), RewriteNode::from_ast_trimmed(&declaration.name(db))),
127 ("err_value".to_string(), RewriteNode::from_ast_trimmed(&err_value)),
128 ]
129 .into(),
130 ));
131
132 let (content, code_mappings) = builder.build();
133 PluginResult {
134 code: Some(PluginGeneratedFile {
135 name: "panicable".into(),
136 content,
137 code_mappings,
138 aux_data: None,
139 diagnostics_note: Default::default(),
140 is_unhygienic: false,
141 }),
142 diagnostics,
143 remove_original_item: false,
144 }
145}
146
147fn extract_success_ty_and_variants<'a>(
150 db: &'a dyn Database,
151 signature: &ast::FunctionSignature<'a>,
152) -> Option<(ast::GenericArg<'a>, String, String)> {
153 let ret_ty_expr =
154 try_extract_matches!(signature.ret_ty(db), ast::OptionReturnTypeClause::ReturnTypeClause)?
155 .ty(db);
156 let ret_ty_path = try_extract_matches!(ret_ty_expr, ast::Expr::Path)?;
157
158 let Ok(ast::PathSegment::WithGenericArgs(segment)) =
160 ret_ty_path.segments(db).elements(db).exactly_one()
161 else {
162 return None;
163 };
164 let ty = segment.identifier(db).long(db);
165 if ty == "Option" {
166 let inner = segment.generic_args(db).generic_args(db).elements(db).exactly_one().ok()?;
167 Some((inner.clone(), "Option::Some".to_owned(), "Option::None".to_owned()))
168 } else if ty == "Result" {
169 let [inner, _err] =
170 segment.generic_args(db).generic_args(db).elements(db).collect_array()?;
171 Some((inner.clone(), "Result::Ok".to_owned(), "Result::Err".to_owned()))
172 } else {
173 None
174 }
175}
176
177fn parse_arguments<'a>(
180 db: &'a dyn Database,
181 attr: &Attribute<'a>,
182) -> Option<(ast::TerminalShortString<'a>, ast::TerminalIdentifier<'a>)> {
183 let [
184 AttributeArg {
185 variant: AttributeArgVariant::Unnamed(ast::Expr::ShortString(err_value)),
186 ..
187 },
188 AttributeArg { variant: AttributeArgVariant::Unnamed(ast::Expr::Path(name)), .. },
189 ] = &attr.args[..]
190 else {
191 return None;
192 };
193
194 let Ok(ast::PathSegment::Simple(segment)) = name.segments(db).elements(db).exactly_one() else {
195 return None;
196 };
197
198 Some((err_value.clone(), segment.ident(db)))
199}