cairo_lang_executable_plugin/
lib.rs

1#[cfg(test)]
2mod test;
3
4use cairo_lang_defs::ids::ModuleId;
5use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
6use cairo_lang_defs::plugin::{
7    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
8};
9use cairo_lang_filesystem::ids::SmolStrId;
10use cairo_lang_semantic::corelib::CorelibSemantic;
11use cairo_lang_semantic::items::free_function::FreeFunctionSemantic;
12use cairo_lang_semantic::plugin::{AnalyzerPlugin, PluginSuite};
13use cairo_lang_semantic::{GenericArgumentId, Mutability, corelib};
14use cairo_lang_syntax::attribute::consts::IMPLICIT_PRECEDENCE_ATTR;
15use cairo_lang_syntax::node::helpers::{OptionWrappedGenericParamListHelper, QueryAttrs};
16use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode, ast};
17use indoc::formatdoc;
18use itertools::Itertools;
19use salsa::Database;
20
/// Attribute marking a free function as an executable; [`ExecutablePlugin`] generates a
/// `#[executable_raw]` wrapper function for every function carrying it.
pub const EXECUTABLE_ATTR: &str = "executable";
/// Attribute marking a raw executable function, whose signature must be
/// `(input: Span<felt252>, ref output: Array<felt252>)` with a `()` return type.
pub const EXECUTABLE_RAW_ATTR: &str = "executable_raw";
/// Prefix prepended to the original function name to name the generated wrapper; it is also
/// embedded in the names of the wrapper's deserialized-parameter locals.
pub const EXECUTABLE_PREFIX: &str = "__executable_wrapper__";
24
25/// Returns a plugin suite with the `ExecutablePlugin` and `RawExecutableAnalyzer`.
26pub fn executable_plugin_suite() -> PluginSuite {
27    std::mem::take(
28        PluginSuite::default()
29            .add_plugin::<ExecutablePlugin>()
30            .add_analyzer_plugin::<RawExecutableAnalyzer>(),
31    )
32}
33
/// Implicits written into the `#[implicit_precedence(...)]` attribute of every generated
/// `#[executable]` wrapper function.
///
/// The entries are joined with `", "` in exactly this order, so reordering the list changes the
/// generated attribute.
const IMPLICIT_PRECEDENCE: &[&str] = &[
    "core::pedersen::Pedersen",
    "core::RangeCheck",
    "core::integer::Bitwise",
    "core::ec::EcOp",
    "core::poseidon::Poseidon",
    "core::circuit::RangeCheck96",
    "core::circuit::AddMod",
    "core::circuit::MulMod",
];
44
45/// Plugin to add diagnostics on bad `#[executable_raw]` annotations.
46#[derive(Default, Debug)]
47struct RawExecutableAnalyzer;
48
49impl AnalyzerPlugin for RawExecutableAnalyzer {
50    fn diagnostics<'db>(
51        &self,
52        db: &'db dyn Database,
53        module_id: ModuleId<'db>,
54    ) -> Vec<PluginDiagnostic<'db>> {
55        let mut diagnostics = vec![];
56        let Ok(free_functions) = module_id.module_data(db).map(|data| data.free_functions(db))
57        else {
58            return diagnostics;
59        };
60        for (id, item) in free_functions.iter() {
61            if !item.has_attr(db, EXECUTABLE_RAW_ATTR) {
62                continue;
63            }
64            let Ok(signature) = db.free_function_signature(*id) else {
65                continue;
66            };
67            if signature.return_type != corelib::unit_ty(db) {
68                diagnostics.push(PluginDiagnostic::error(
69                    signature.stable_ptr.lookup(db).ret_ty(db).stable_ptr(db),
70                    "Invalid return type for `#[executable_raw]` function, expected `()`."
71                        .to_string(),
72                ));
73            }
74            let [input, output] = &signature.params[..] else {
75                diagnostics.push(PluginDiagnostic::error(
76                    signature.stable_ptr.lookup(db).parameters(db).stable_ptr(db),
77                    "Invalid number of params for `#[executable_raw]` function, expected 2."
78                        .to_string(),
79                ));
80                continue;
81            };
82            if input.ty
83                != corelib::get_core_ty_by_name(
84                    db,
85                    SmolStrId::from(db, "Span"),
86                    vec![GenericArgumentId::Type(db.core_info().felt252)],
87                )
88            {
89                diagnostics.push(PluginDiagnostic::error(
90                    input.stable_ptr.untyped(),
91                    "Invalid first param type for `#[executable_raw]` function, expected \
92                     `Span<felt252>`."
93                        .to_string(),
94                ));
95            }
96            if input.mutability == Mutability::Reference {
97                diagnostics.push(PluginDiagnostic::error(
98                    input.stable_ptr.untyped(),
99                    "Invalid first param mutability for `#[executable_raw]` function, got \
100                     unexpected `ref`."
101                        .to_string(),
102                ));
103            }
104            if output.ty != corelib::core_array_felt252_ty(db) {
105                diagnostics.push(PluginDiagnostic::error(
106                    output.stable_ptr.untyped(),
107                    "Invalid second param type for `#[executable_raw]` function, expected \
108                     `Array<felt252>`."
109                        .to_string(),
110                ));
111            }
112            if output.mutability != Mutability::Reference {
113                diagnostics.push(PluginDiagnostic::error(
114                    output.stable_ptr.untyped(),
115                    "Invalid second param mutability for `#[executable_raw]` function, expected \
116                     `ref`."
117                        .to_string(),
118                ));
119            }
120        }
121        diagnostics
122    }
123}
124
125#[derive(Debug, Default)]
126#[non_exhaustive]
127pub struct ExecutablePlugin;
128
129impl MacroPlugin for ExecutablePlugin {
130    fn generate_code<'db>(
131        &self,
132        db: &'db dyn Database,
133        item_ast: ast::ModuleItem<'db>,
134        _metadata: &MacroPluginMetadata<'_>,
135    ) -> PluginResult<'db> {
136        let ast::ModuleItem::FreeFunction(item) = item_ast else {
137            return PluginResult::default();
138        };
139        if !item.has_attr(db, EXECUTABLE_ATTR) {
140            return PluginResult::default();
141        }
142        let mut diagnostics = vec![];
143        let mut builder = PatchBuilder::new(db, &item);
144        let declaration = item.declaration(db);
145        let generics = declaration.generic_params(db);
146        if !generics.is_empty(db) {
147            diagnostics.push(PluginDiagnostic::error(
148                generics.stable_ptr(db),
149                "Executable functions cannot have generic params.".to_string(),
150            ));
151        }
152        let name = declaration.name(db);
153        let implicits_precedence =
154            RewriteNode::Text(format!("#[{IMPLICIT_PRECEDENCE_ATTR}({})]", {
155                IMPLICIT_PRECEDENCE.iter().join(", ")
156            }));
157        builder.add_modified(RewriteNode::interpolate_patched(&formatdoc! {"
158
159                $implicit_precedence$
160                #[{EXECUTABLE_RAW_ATTR}]
161                fn {EXECUTABLE_PREFIX}$function_name$(mut input: Span<felt252>, ref output: Array<felt252>) {{\n
162            "},
163            &[
164                ("implicit_precedence".into(), implicits_precedence,),
165                ("function_name".into(), RewriteNode::from_ast(&name))
166            ].into()
167        ));
168        let params = declaration.signature(db).parameters(db).elements_vec(db);
169        for (param_idx, param) in params.iter().enumerate() {
170            for modifier in param.modifiers(db).elements(db) {
171                if let ast::Modifier::Ref(terminal_ref) = modifier {
172                    diagnostics.push(PluginDiagnostic::error(
173                        terminal_ref.stable_ptr(db),
174                        "Parameters of an `#[executable]` function can't be `ref`.".into(),
175                    ));
176                }
177            }
178            builder.add_modified(
179                RewriteNode::Text(format!(
180                    "    let __param{EXECUTABLE_PREFIX}{param_idx} = Serde::deserialize(ref \
181                     input).expect('Failed to deserialize param #{param_idx}');\n"
182                ))
183                .mapped(db, param),
184            );
185        }
186        if !diagnostics.is_empty() {
187            return PluginResult { code: None, diagnostics, remove_original_item: false };
188        }
189        builder.add_str(
190            "    assert(core::array::SpanTrait::is_empty(input), 'Input too long for params.');\n",
191        );
192        builder.add_modified(RewriteNode::interpolate_patched(
193            "    let __result = @$function_name$(\n",
194            &[("function_name".into(), RewriteNode::from_ast(&name))].into(),
195        ));
196        for (param_idx, param) in params.iter().enumerate() {
197            builder.add_modified(
198                RewriteNode::Text(format!("        __param{EXECUTABLE_PREFIX}{param_idx},\n"))
199                    .mapped(db, param),
200            );
201        }
202        builder.add_str("    );\n");
203        let mut serialize_node = RewriteNode::text("    Serde::serialize(__result, ref output);\n");
204        if let ast::OptionReturnTypeClause::ReturnTypeClause(clause) =
205            declaration.signature(db).ret_ty(db)
206        {
207            serialize_node = serialize_node.mapped(db, &clause);
208        }
209        builder.add_modified(serialize_node);
210        builder.add_str("}\n");
211        let (content, code_mappings) = builder.build();
212        PluginResult {
213            code: Some(PluginGeneratedFile {
214                name: "executable".into(),
215                content,
216                code_mappings,
217                aux_data: None,
218                diagnostics_note: Default::default(),
219                is_unhygienic: false,
220            }),
221            diagnostics,
222            remove_original_item: false,
223        }
224    }
225
226    fn declared_attributes<'db>(&self, db: &'db dyn Database) -> Vec<SmolStrId<'db>> {
227        vec![(SmolStrId::from(db, EXECUTABLE_ATTR)), (SmolStrId::from(db, EXECUTABLE_RAW_ATTR))]
228    }
229
230    fn executable_attributes<'db>(&self, db: &'db dyn Database) -> Vec<SmolStrId<'db>> {
231        vec![cairo_lang_filesystem::ids::SmolStrId::from(db, EXECUTABLE_RAW_ATTR)]
232    }
233}