1use cairo_lang_defs::ids::ModuleId;
2use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
3use cairo_lang_defs::plugin::{
4 MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
5};
6use cairo_lang_semantic::db::SemanticGroup;
7use cairo_lang_semantic::plugin::{AnalyzerPlugin, PluginSuite};
8use cairo_lang_semantic::{GenericArgumentId, Mutability, corelib};
9use cairo_lang_syntax::attribute::consts::IMPLICIT_PRECEDENCE_ATTR;
10use cairo_lang_syntax::node::db::SyntaxGroup;
11use cairo_lang_syntax::node::helpers::{OptionWrappedGenericParamListHelper, QueryAttrs};
12use cairo_lang_syntax::node::{TypedStablePtr, ast};
13use indoc::formatdoc;
14use itertools::Itertools;
15
/// Attribute marking a free function for which `ExecutablePlugin` generates an
/// `#[executable_raw]` wrapper.
pub const EXECUTABLE_ATTR: &str = "executable";
/// Attribute placed on the generated wrappers. It is also reported by `ExecutablePlugin` as an
/// executable attribute, and functions carrying it are signature-checked by
/// `RawExecutableAnalyzer`.
pub const EXECUTABLE_RAW_ATTR: &str = "executable_raw";
/// Prefix prepended to the original function's name to form the generated wrapper's name; also
/// embedded in the wrapper's local parameter variable names.
pub const EXECUTABLE_PREFIX: &str = "__executable_wrapper__";
19
20pub fn executable_plugin_suite() -> PluginSuite {
22 std::mem::take(
23 PluginSuite::default()
24 .add_plugin::<ExecutablePlugin>()
25 .add_analyzer_plugin::<RawExecutableAnalyzer>(),
26 )
27}
28
/// Implicit types listed, comma-separated and in this order, in the
/// `#[implicit_precedence(...)]` attribute emitted on every generated executable wrapper.
// NOTE(review): the order presumably has to match what the executable runner expects for the
// builtins — confirm before reordering or extending.
const IMPLICIT_PRECEDENCE: &[&str] = &[
    "core::pedersen::Pedersen",
    "core::RangeCheck",
    "core::integer::Bitwise",
    "core::ec::EcOp",
    "core::poseidon::Poseidon",
    "core::circuit::RangeCheck96",
    "core::circuit::AddMod",
    "core::circuit::MulMod",
];
39
/// Macro plugin that generates an `#[executable_raw]` wrapper function for every free function
/// annotated with `#[executable]`.
#[non_exhaustive]
#[derive(Default, Debug)]
struct ExecutablePlugin;
44impl MacroPlugin for ExecutablePlugin {
45 fn generate_code(
46 &self,
47 db: &dyn SyntaxGroup,
48 item_ast: ast::ModuleItem,
49 _metadata: &MacroPluginMetadata<'_>,
50 ) -> PluginResult {
51 let ast::ModuleItem::FreeFunction(item) = item_ast else {
52 return PluginResult::default();
53 };
54 if !item.has_attr(db, EXECUTABLE_ATTR) {
55 return PluginResult::default();
56 }
57 let mut diagnostics = vec![];
58 let mut builder = PatchBuilder::new(db, &item);
59 let declaration = item.declaration(db);
60 let generics = declaration.generic_params(db);
61 if !generics.is_empty(db) {
62 diagnostics.push(PluginDiagnostic::error(
63 &generics,
64 "Executable functions cannot have generic params.".to_string(),
65 ));
66 return PluginResult { code: None, diagnostics, remove_original_item: false };
67 }
68 let name = declaration.name(db);
69 let implicits_precedence =
70 RewriteNode::Text(format!("#[{IMPLICIT_PRECEDENCE_ATTR}({})]", {
71 IMPLICIT_PRECEDENCE.iter().join(", ")
72 }));
73 builder.add_modified(RewriteNode::interpolate_patched(&formatdoc! {"
74
75 $implicit_precedence$
76 #[{EXECUTABLE_RAW_ATTR}]
77 fn {EXECUTABLE_PREFIX}$function_name$(mut input: Span<felt252>, ref output: Array<felt252>) {{\n
78 "},
79 &[
80 ("implicit_precedence".into(), implicits_precedence,),
81 ("function_name".into(), RewriteNode::from_ast(&name))
82 ].into()
83 ));
84 let params = declaration.signature(db).parameters(db).elements(db);
85 for (param_idx, param) in params.iter().enumerate() {
86 builder.add_modified(
87 RewriteNode::Text(format!(
88 " let __param{EXECUTABLE_PREFIX}{param_idx} = Serde::deserialize(ref \
89 input).expect('Failed to deserialize param #{param_idx}');\n"
90 ))
91 .mapped(db, param),
92 );
93 }
94 builder.add_str(
95 " assert(core::array::SpanTrait::is_empty(input), 'Input too long for params.');\n",
96 );
97 builder.add_modified(RewriteNode::interpolate_patched(
98 " let __result = @$function_name$(\n",
99 &[("function_name".into(), RewriteNode::from_ast(&name))].into(),
100 ));
101 for (param_idx, param) in params.iter().enumerate() {
102 builder.add_modified(
103 RewriteNode::Text(format!(" __param{EXECUTABLE_PREFIX}{param_idx},\n"))
104 .mapped(db, param),
105 );
106 }
107 builder.add_str(" );\n");
108 let mut serialize_node = RewriteNode::text(" Serde::serialize(__result, ref output);\n");
109 if let ast::OptionReturnTypeClause::ReturnTypeClause(clause) =
110 declaration.signature(db).ret_ty(db)
111 {
112 serialize_node = serialize_node.mapped(db, &clause);
113 }
114 builder.add_modified(serialize_node);
115 builder.add_str("}\n");
116 let (content, code_mappings) = builder.build();
117 PluginResult {
118 code: Some(PluginGeneratedFile {
119 name: "executable".into(),
120 content,
121 code_mappings,
122 aux_data: None,
123 diagnostics_note: Default::default(),
124 }),
125 diagnostics,
126 remove_original_item: false,
127 }
128 }
129
130 fn declared_attributes(&self) -> Vec<String> {
131 vec![EXECUTABLE_ATTR.to_string(), EXECUTABLE_RAW_ATTR.to_string()]
132 }
133
134 fn executable_attributes(&self) -> Vec<String> {
135 vec![EXECUTABLE_RAW_ATTR.to_string()]
136 }
137}
138
/// Analyzer plugin that reports signature errors on free functions annotated with
/// `#[executable_raw]`.
#[derive(Debug, Default)]
struct RawExecutableAnalyzer;
142
143impl AnalyzerPlugin for RawExecutableAnalyzer {
144 fn diagnostics(&self, db: &dyn SemanticGroup, module_id: ModuleId) -> Vec<PluginDiagnostic> {
145 let syntax_db = db.upcast();
146 let mut diagnostics = vec![];
147 let Ok(free_functions) = db.module_free_functions(module_id) else {
148 return diagnostics;
149 };
150 for (id, item) in free_functions.iter() {
151 if !item.has_attr(syntax_db, EXECUTABLE_RAW_ATTR) {
152 continue;
153 }
154 let Ok(signature) = db.free_function_signature(*id) else {
155 continue;
156 };
157 if signature.return_type != corelib::unit_ty(db) {
158 diagnostics.push(PluginDiagnostic::error(
159 &signature.stable_ptr.lookup(syntax_db).ret_ty(syntax_db),
160 "Invalid return type for `#[executable_raw]` function, expected `()`."
161 .to_string(),
162 ));
163 }
164 let [input, output] = &signature.params[..] else {
165 diagnostics.push(PluginDiagnostic::error(
166 &signature.stable_ptr.lookup(syntax_db).parameters(syntax_db),
167 "Invalid number of params for `#[executable_raw]` function, expected 2."
168 .to_string(),
169 ));
170 continue;
171 };
172 if input.ty
173 != corelib::get_core_ty_by_name(db, "Span".into(), vec![GenericArgumentId::Type(
174 corelib::core_felt252_ty(db),
175 )])
176 {
177 diagnostics.push(PluginDiagnostic::error(
178 input.stable_ptr.untyped(),
179 "Invalid first param type for `#[executable_raw]` function, expected \
180 `Span<felt252>`."
181 .to_string(),
182 ));
183 }
184 if input.mutability == Mutability::Reference {
185 diagnostics.push(PluginDiagnostic::error(
186 input.stable_ptr.untyped(),
187 "Invalid first param mutability for `#[executable_raw]` function, got \
188 unexpected `ref`."
189 .to_string(),
190 ));
191 }
192 if output.ty != corelib::core_array_felt252_ty(db) {
193 diagnostics.push(PluginDiagnostic::error(
194 output.stable_ptr.untyped(),
195 "Invalid second param type for `#[executable_raw]` function, expected \
196 `Array<felt252>`."
197 .to_string(),
198 ));
199 }
200 if output.mutability != Mutability::Reference {
201 diagnostics.push(PluginDiagnostic::error(
202 output.stable_ptr.untyped(),
203 "Invalid second param mutability for `#[executable_raw]` function, expected \
204 `ref`."
205 .to_string(),
206 ));
207 }
208 }
209 diagnostics
210 }
211}