cairo_lang_executable/
plugin.rs

1use cairo_lang_defs::ids::ModuleId;
2use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
3use cairo_lang_defs::plugin::{
4    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
5};
6use cairo_lang_semantic::db::SemanticGroup;
7use cairo_lang_semantic::plugin::{AnalyzerPlugin, PluginSuite};
8use cairo_lang_semantic::{GenericArgumentId, Mutability, corelib};
9use cairo_lang_syntax::attribute::consts::IMPLICIT_PRECEDENCE_ATTR;
10use cairo_lang_syntax::node::db::SyntaxGroup;
11use cairo_lang_syntax::node::helpers::{OptionWrappedGenericParamListHelper, QueryAttrs};
12use cairo_lang_syntax::node::{TypedStablePtr, ast};
13use indoc::formatdoc;
14use itertools::Itertools;
15
/// Attribute that marks a free function for which an executable wrapper is generated.
pub const EXECUTABLE_ATTR: &str = "executable";
/// Attribute carried by the generated wrapper; any free function that has it gets its signature
/// checked by `RawExecutableAnalyzer`.
pub const EXECUTABLE_RAW_ATTR: &str = "executable_raw";
/// Prefix prepended to the wrapped function's name to name the generated wrapper. It is also
/// embedded in the wrapper's local parameter variable names.
pub const EXECUTABLE_PREFIX: &str = "__executable_wrapper__";
19
20/// Returns a plugin suite with the `ExecutablePlugin` and `RawExecutableAnalyzer`.
21pub fn executable_plugin_suite() -> PluginSuite {
22    std::mem::take(
23        PluginSuite::default()
24            .add_plugin::<ExecutablePlugin>()
25            .add_analyzer_plugin::<RawExecutableAnalyzer>(),
26    )
27}
28
/// Fully qualified implicit types, in the order emitted into the implicit-precedence attribute of
/// every generated executable wrapper. The order of entries is significant.
const IMPLICIT_PRECEDENCE: &[&str] = &[
    // Hash / arithmetic builtins.
    "core::pedersen::Pedersen",
    "core::RangeCheck",
    "core::integer::Bitwise",
    "core::ec::EcOp",
    "core::poseidon::Poseidon",
    // Circuit builtins.
    "core::circuit::RangeCheck96",
    "core::circuit::AddMod",
    "core::circuit::MulMod",
];
39
/// Macro plugin that generates a raw executable wrapper for every free function annotated with
/// `#[executable]`.
#[derive(Default, Debug)]
#[non_exhaustive]
struct ExecutablePlugin;
43
44impl MacroPlugin for ExecutablePlugin {
45    fn generate_code(
46        &self,
47        db: &dyn SyntaxGroup,
48        item_ast: ast::ModuleItem,
49        _metadata: &MacroPluginMetadata<'_>,
50    ) -> PluginResult {
51        let ast::ModuleItem::FreeFunction(item) = item_ast else {
52            return PluginResult::default();
53        };
54        if !item.has_attr(db, EXECUTABLE_ATTR) {
55            return PluginResult::default();
56        }
57        let mut diagnostics = vec![];
58        let mut builder = PatchBuilder::new(db, &item);
59        let declaration = item.declaration(db);
60        let generics = declaration.generic_params(db);
61        if !generics.is_empty(db) {
62            diagnostics.push(PluginDiagnostic::error(
63                &generics,
64                "Executable functions cannot have generic params.".to_string(),
65            ));
66            return PluginResult { code: None, diagnostics, remove_original_item: false };
67        }
68        let name = declaration.name(db);
69        let implicits_precedence =
70            RewriteNode::Text(format!("#[{IMPLICIT_PRECEDENCE_ATTR}({})]", {
71                IMPLICIT_PRECEDENCE.iter().join(", ")
72            }));
73        builder.add_modified(RewriteNode::interpolate_patched(&formatdoc! {"
74
75                $implicit_precedence$
76                #[{EXECUTABLE_RAW_ATTR}]
77                fn {EXECUTABLE_PREFIX}$function_name$(mut input: Span<felt252>, ref output: Array<felt252>) {{\n
78            "},
79            &[
80                ("implicit_precedence".into(), implicits_precedence,),
81                ("function_name".into(), RewriteNode::from_ast(&name))
82            ].into()
83        ));
84        let params = declaration.signature(db).parameters(db).elements(db);
85        for (param_idx, param) in params.iter().enumerate() {
86            builder.add_modified(
87                RewriteNode::Text(format!(
88                    "    let __param{EXECUTABLE_PREFIX}{param_idx} = Serde::deserialize(ref \
89                     input).expect('Failed to deserialize param #{param_idx}');\n"
90                ))
91                .mapped(db, param),
92            );
93        }
94        builder.add_str(
95            "    assert(core::array::SpanTrait::is_empty(input), 'Input too long for params.');\n",
96        );
97        builder.add_modified(RewriteNode::interpolate_patched(
98            "    let __result = @$function_name$(\n",
99            &[("function_name".into(), RewriteNode::from_ast(&name))].into(),
100        ));
101        for (param_idx, param) in params.iter().enumerate() {
102            builder.add_modified(
103                RewriteNode::Text(format!("        __param{EXECUTABLE_PREFIX}{param_idx},\n"))
104                    .mapped(db, param),
105            );
106        }
107        builder.add_str("    );\n");
108        let mut serialize_node = RewriteNode::text("    Serde::serialize(__result, ref output);\n");
109        if let ast::OptionReturnTypeClause::ReturnTypeClause(clause) =
110            declaration.signature(db).ret_ty(db)
111        {
112            serialize_node = serialize_node.mapped(db, &clause);
113        }
114        builder.add_modified(serialize_node);
115        builder.add_str("}\n");
116        let (content, code_mappings) = builder.build();
117        PluginResult {
118            code: Some(PluginGeneratedFile {
119                name: "executable".into(),
120                content,
121                code_mappings,
122                aux_data: None,
123                diagnostics_note: Default::default(),
124            }),
125            diagnostics,
126            remove_original_item: false,
127        }
128    }
129
130    fn declared_attributes(&self) -> Vec<String> {
131        vec![EXECUTABLE_ATTR.to_string(), EXECUTABLE_RAW_ATTR.to_string()]
132    }
133
134    fn executable_attributes(&self) -> Vec<String> {
135        vec![EXECUTABLE_RAW_ATTR.to_string()]
136    }
137}
138
/// Analyzer plugin that reports errors on `#[executable_raw]` free functions whose signature does
/// not have the shape expected of a raw executable.
#[derive(Debug, Default)]
struct RawExecutableAnalyzer;
142
143impl AnalyzerPlugin for RawExecutableAnalyzer {
144    fn diagnostics(&self, db: &dyn SemanticGroup, module_id: ModuleId) -> Vec<PluginDiagnostic> {
145        let syntax_db = db.upcast();
146        let mut diagnostics = vec![];
147        let Ok(free_functions) = db.module_free_functions(module_id) else {
148            return diagnostics;
149        };
150        for (id, item) in free_functions.iter() {
151            if !item.has_attr(syntax_db, EXECUTABLE_RAW_ATTR) {
152                continue;
153            }
154            let Ok(signature) = db.free_function_signature(*id) else {
155                continue;
156            };
157            if signature.return_type != corelib::unit_ty(db) {
158                diagnostics.push(PluginDiagnostic::error(
159                    &signature.stable_ptr.lookup(syntax_db).ret_ty(syntax_db),
160                    "Invalid return type for `#[executable_raw]` function, expected `()`."
161                        .to_string(),
162                ));
163            }
164            let [input, output] = &signature.params[..] else {
165                diagnostics.push(PluginDiagnostic::error(
166                    &signature.stable_ptr.lookup(syntax_db).parameters(syntax_db),
167                    "Invalid number of params for `#[executable_raw]` function, expected 2."
168                        .to_string(),
169                ));
170                continue;
171            };
172            if input.ty
173                != corelib::get_core_ty_by_name(db, "Span".into(), vec![GenericArgumentId::Type(
174                    corelib::core_felt252_ty(db),
175                )])
176            {
177                diagnostics.push(PluginDiagnostic::error(
178                    input.stable_ptr.untyped(),
179                    "Invalid first param type for `#[executable_raw]` function, expected \
180                     `Span<felt252>`."
181                        .to_string(),
182                ));
183            }
184            if input.mutability == Mutability::Reference {
185                diagnostics.push(PluginDiagnostic::error(
186                    input.stable_ptr.untyped(),
187                    "Invalid first param mutability for `#[executable_raw]` function, got \
188                     unexpected `ref`."
189                        .to_string(),
190                ));
191            }
192            if output.ty != corelib::core_array_felt252_ty(db) {
193                diagnostics.push(PluginDiagnostic::error(
194                    output.stable_ptr.untyped(),
195                    "Invalid second param type for `#[executable_raw]` function, expected \
196                     `Array<felt252>`."
197                        .to_string(),
198                ));
199            }
200            if output.mutability != Mutability::Reference {
201                diagnostics.push(PluginDiagnostic::error(
202                    output.stable_ptr.untyped(),
203                    "Invalid second param mutability for `#[executable_raw]` function, expected \
204                     `ref`."
205                        .to_string(),
206                ));
207            }
208        }
209        diagnostics
210    }
211}