1use cairo_lang_defs::ids::ModuleId;
2use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
3use cairo_lang_defs::plugin::{
4 MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
5};
6use cairo_lang_semantic::db::SemanticGroup;
7use cairo_lang_semantic::plugin::{AnalyzerPlugin, PluginSuite};
8use cairo_lang_semantic::{GenericArgumentId, Mutability, corelib};
9use cairo_lang_syntax::attribute::consts::IMPLICIT_PRECEDENCE_ATTR;
10use cairo_lang_syntax::node::db::SyntaxGroup;
11use cairo_lang_syntax::node::helpers::{OptionWrappedGenericParamListHelper, QueryAttrs};
12use cairo_lang_syntax::node::{TypedStablePtr, ast};
13use indoc::formatdoc;
14use itertools::Itertools;
15
/// Attribute marking a free function as an executable entry point; `ExecutablePlugin`
/// generates an `#[executable_raw]` wrapper for every free function carrying it.
pub const EXECUTABLE_ATTR: &str = "executable";
/// Attribute placed on generated wrappers and reported as the executable attribute. Functions
/// carrying it are checked by `RawExecutableAnalyzer` to have the shape
/// `(Span<felt252>, ref Array<felt252>)` with a `()` return type.
pub const EXECUTABLE_RAW_ATTR: &str = "executable_raw";
/// Prefix of the generated wrapper's name (followed by the original function name), also used
/// to build the names of the wrapper's deserialized-parameter locals.
pub const EXECUTABLE_PREFIX: &str = "__executable_wrapper__";
19
20pub fn executable_plugin_suite() -> PluginSuite {
22 std::mem::take(
23 PluginSuite::default()
24 .add_plugin::<ExecutablePlugin>()
25 .add_analyzer_plugin::<RawExecutableAnalyzer>(),
26 )
27}
28
/// Implicits listed, joined in this order, inside the `IMPLICIT_PRECEDENCE_ATTR` attribute that
/// is attached to every wrapper generated for an `#[executable]` function.
const IMPLICIT_PRECEDENCE: &[&str] = &[
    "core::pedersen::Pedersen",
    "core::RangeCheck",
    "core::integer::Bitwise",
    "core::ec::EcOp",
    "core::poseidon::Poseidon",
    "core::circuit::RangeCheck96",
    "core::circuit::AddMod",
    "core::circuit::MulMod",
];
39
40#[derive(Debug, Default)]
41#[non_exhaustive]
42struct ExecutablePlugin;
43
44impl MacroPlugin for ExecutablePlugin {
45 fn generate_code(
46 &self,
47 db: &dyn SyntaxGroup,
48 item_ast: ast::ModuleItem,
49 _metadata: &MacroPluginMetadata<'_>,
50 ) -> PluginResult {
51 let ast::ModuleItem::FreeFunction(item) = item_ast else {
52 return PluginResult::default();
53 };
54 if !item.has_attr(db, EXECUTABLE_ATTR) {
55 return PluginResult::default();
56 }
57 let mut diagnostics = vec![];
58 let mut builder = PatchBuilder::new(db, &item);
59 let declaration = item.declaration(db);
60 let generics = declaration.generic_params(db);
61 if !generics.is_empty(db) {
62 diagnostics.push(PluginDiagnostic::error(
63 &generics,
64 "Executable functions cannot have generic params.".to_string(),
65 ));
66 }
67 let name = declaration.name(db);
68 let implicits_precedence =
69 RewriteNode::Text(format!("#[{IMPLICIT_PRECEDENCE_ATTR}({})]", {
70 IMPLICIT_PRECEDENCE.iter().join(", ")
71 }));
72 builder.add_modified(RewriteNode::interpolate_patched(&formatdoc! {"
73
74 $implicit_precedence$
75 #[{EXECUTABLE_RAW_ATTR}]
76 fn {EXECUTABLE_PREFIX}$function_name$(mut input: Span<felt252>, ref output: Array<felt252>) {{\n
77 "},
78 &[
79 ("implicit_precedence".into(), implicits_precedence,),
80 ("function_name".into(), RewriteNode::from_ast(&name))
81 ].into()
82 ));
83 let params = declaration.signature(db).parameters(db).elements(db);
84 for (param_idx, param) in params.iter().enumerate() {
85 for modifier in param.modifiers(db).elements(db) {
86 if let ast::Modifier::Ref(terminal_ref) = modifier {
87 diagnostics.push(PluginDiagnostic::error(
88 &terminal_ref,
89 "Parameters of an `#[executable]` function can't be `ref`.".into(),
90 ));
91 }
92 }
93 builder.add_modified(
94 RewriteNode::Text(format!(
95 " let __param{EXECUTABLE_PREFIX}{param_idx} = Serde::deserialize(ref \
96 input).expect('Failed to deserialize param #{param_idx}');\n"
97 ))
98 .mapped(db, param),
99 );
100 }
101 if !diagnostics.is_empty() {
102 return PluginResult { code: None, diagnostics, remove_original_item: false };
103 }
104 builder.add_str(
105 " assert(core::array::SpanTrait::is_empty(input), 'Input too long for params.');\n",
106 );
107 builder.add_modified(RewriteNode::interpolate_patched(
108 " let __result = @$function_name$(\n",
109 &[("function_name".into(), RewriteNode::from_ast(&name))].into(),
110 ));
111 for (param_idx, param) in params.iter().enumerate() {
112 builder.add_modified(
113 RewriteNode::Text(format!(" __param{EXECUTABLE_PREFIX}{param_idx},\n"))
114 .mapped(db, param),
115 );
116 }
117 builder.add_str(" );\n");
118 let mut serialize_node = RewriteNode::text(" Serde::serialize(__result, ref output);\n");
119 if let ast::OptionReturnTypeClause::ReturnTypeClause(clause) =
120 declaration.signature(db).ret_ty(db)
121 {
122 serialize_node = serialize_node.mapped(db, &clause);
123 }
124 builder.add_modified(serialize_node);
125 builder.add_str("}\n");
126 let (content, code_mappings) = builder.build();
127 PluginResult {
128 code: Some(PluginGeneratedFile {
129 name: "executable".into(),
130 content,
131 code_mappings,
132 aux_data: None,
133 diagnostics_note: Default::default(),
134 }),
135 diagnostics,
136 remove_original_item: false,
137 }
138 }
139
140 fn declared_attributes(&self) -> Vec<String> {
141 vec![EXECUTABLE_ATTR.to_string(), EXECUTABLE_RAW_ATTR.to_string()]
142 }
143
144 fn executable_attributes(&self) -> Vec<String> {
145 vec![EXECUTABLE_RAW_ATTR.to_string()]
146 }
147}
148
149#[derive(Default, Debug)]
151struct RawExecutableAnalyzer;
152
153impl AnalyzerPlugin for RawExecutableAnalyzer {
154 fn diagnostics(&self, db: &dyn SemanticGroup, module_id: ModuleId) -> Vec<PluginDiagnostic> {
155 let syntax_db = db.upcast();
156 let mut diagnostics = vec![];
157 let Ok(free_functions) = db.module_free_functions(module_id) else {
158 return diagnostics;
159 };
160 for (id, item) in free_functions.iter() {
161 if !item.has_attr(syntax_db, EXECUTABLE_RAW_ATTR) {
162 continue;
163 }
164 let Ok(signature) = db.free_function_signature(*id) else {
165 continue;
166 };
167 if signature.return_type != corelib::unit_ty(db) {
168 diagnostics.push(PluginDiagnostic::error(
169 &signature.stable_ptr.lookup(syntax_db).ret_ty(syntax_db),
170 "Invalid return type for `#[executable_raw]` function, expected `()`."
171 .to_string(),
172 ));
173 }
174 let [input, output] = &signature.params[..] else {
175 diagnostics.push(PluginDiagnostic::error(
176 &signature.stable_ptr.lookup(syntax_db).parameters(syntax_db),
177 "Invalid number of params for `#[executable_raw]` function, expected 2."
178 .to_string(),
179 ));
180 continue;
181 };
182 if input.ty
183 != corelib::get_core_ty_by_name(
184 db,
185 "Span".into(),
186 vec![GenericArgumentId::Type(db.core_info().felt252)],
187 )
188 {
189 diagnostics.push(PluginDiagnostic::error(
190 input.stable_ptr.untyped(),
191 "Invalid first param type for `#[executable_raw]` function, expected \
192 `Span<felt252>`."
193 .to_string(),
194 ));
195 }
196 if input.mutability == Mutability::Reference {
197 diagnostics.push(PluginDiagnostic::error(
198 input.stable_ptr.untyped(),
199 "Invalid first param mutability for `#[executable_raw]` function, got \
200 unexpected `ref`."
201 .to_string(),
202 ));
203 }
204 if output.ty != corelib::core_array_felt252_ty(db) {
205 diagnostics.push(PluginDiagnostic::error(
206 output.stable_ptr.untyped(),
207 "Invalid second param type for `#[executable_raw]` function, expected \
208 `Array<felt252>`."
209 .to_string(),
210 ));
211 }
212 if output.mutability != Mutability::Reference {
213 diagnostics.push(PluginDiagnostic::error(
214 output.stable_ptr.untyped(),
215 "Invalid second param mutability for `#[executable_raw]` function, expected \
216 `ref`."
217 .to_string(),
218 ));
219 }
220 }
221 diagnostics
222 }
223}