Skip to main content

cairo_lang_executable_plugin/
lib.rs

1#[cfg(test)]
2mod test;
3
4use cairo_lang_defs::ids::ModuleId;
5use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
6use cairo_lang_defs::plugin::{
7    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
8};
9use cairo_lang_filesystem::ids::SmolStrId;
10use cairo_lang_semantic::corelib::CorelibSemantic;
11use cairo_lang_semantic::items::free_function::FreeFunctionSemantic;
12use cairo_lang_semantic::plugin::{AnalyzerPlugin, PluginSuite};
13use cairo_lang_semantic::{GenericArgumentId, Mutability, corelib};
14use cairo_lang_syntax::attribute::consts::IMPLICIT_PRECEDENCE_ATTR;
15use cairo_lang_syntax::node::helpers::{OptionWrappedGenericParamListHelper, QueryAttrs};
16use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode, ast};
17use indoc::formatdoc;
18use itertools::Itertools;
19use salsa::Database;
20
/// Attribute marking a free function as an executable entry point; [`ExecutablePlugin`]
/// generates a raw wrapper function for every free function carrying it.
pub const EXECUTABLE_ATTR: &str = "executable";
/// Attribute marking a raw executable entry point, i.e. a free function taking the serialized
/// input as `Span<felt252>` and writing the serialized output into a `ref Array<felt252>`.
pub const EXECUTABLE_RAW_ATTR: &str = "executable_raw";
/// Name prefix of the `#[executable_raw]` wrapper generated for each `#[executable]` function.
pub const EXECUTABLE_PREFIX: &str = "__executable_wrapper__";
24
25/// Returns a plugin suite with the `ExecutablePlugin` and `RawExecutableAnalyzer`.
26pub fn executable_plugin_suite() -> PluginSuite {
27    let mut suite = PluginSuite::default();
28    suite.add_plugin::<ExecutablePlugin>().add_analyzer_plugin::<RawExecutableAnalyzer>();
29    suite
30}
31
/// Fully-qualified implicit types, in the order emitted into the `#[implicit_precedence(...)]`
/// attribute of every generated `#[executable_raw]` wrapper.
///
/// The order of the entries is significant: it is joined verbatim into the attribute.
const IMPLICIT_PRECEDENCE: &[&str] = &[
    "core::pedersen::Pedersen",
    "core::RangeCheck",
    "core::integer::Bitwise",
    "core::ec::EcOp",
    "core::poseidon::Poseidon",
    "core::circuit::RangeCheck96",
    "core::circuit::AddMod",
    "core::circuit::MulMod",
];
42
43/// Plugin to add diagnostics on bad `#[executable_raw]` annotations.
44#[derive(Default, Debug)]
45struct RawExecutableAnalyzer;
46
47impl AnalyzerPlugin for RawExecutableAnalyzer {
48    fn diagnostics<'db>(
49        &self,
50        db: &'db dyn Database,
51        module_id: ModuleId<'db>,
52    ) -> Vec<PluginDiagnostic<'db>> {
53        let mut diagnostics = vec![];
54        let Ok(free_functions) = module_id.module_data(db).map(|data| data.free_functions(db))
55        else {
56            return diagnostics;
57        };
58        for (id, item) in free_functions.iter() {
59            if !item.has_attr(db, EXECUTABLE_RAW_ATTR) {
60                continue;
61            }
62            let Ok(signature) = db.free_function_signature(*id) else {
63                continue;
64            };
65            if signature.return_type != corelib::unit_ty(db) {
66                diagnostics.push(PluginDiagnostic::error(
67                    signature.stable_ptr.lookup(db).ret_ty(db).stable_ptr(db),
68                    "Invalid return type for `#[executable_raw]` function, expected `()`."
69                        .to_string(),
70                ));
71            }
72            let [input, output] = &signature.params[..] else {
73                diagnostics.push(PluginDiagnostic::error(
74                    signature.stable_ptr.lookup(db).parameters(db).stable_ptr(db),
75                    "Invalid number of params for `#[executable_raw]` function, expected 2."
76                        .to_string(),
77                ));
78                continue;
79            };
80            if input.ty
81                != corelib::get_core_ty_by_name(
82                    db,
83                    SmolStrId::from(db, "Span"),
84                    vec![GenericArgumentId::Type(db.core_info().felt252)],
85                )
86            {
87                diagnostics.push(PluginDiagnostic::error(
88                    input.stable_ptr.untyped(),
89                    "Invalid first param type for `#[executable_raw]` function, expected \
90                     `Span<felt252>`."
91                        .to_string(),
92                ));
93            }
94            if input.mutability == Mutability::Reference {
95                diagnostics.push(PluginDiagnostic::error(
96                    input.stable_ptr.untyped(),
97                    "Invalid first param mutability for `#[executable_raw]` function, got \
98                     unexpected `ref`."
99                        .to_string(),
100                ));
101            }
102            if output.ty != corelib::core_array_felt252_ty(db) {
103                diagnostics.push(PluginDiagnostic::error(
104                    output.stable_ptr.untyped(),
105                    "Invalid second param type for `#[executable_raw]` function, expected \
106                     `Array<felt252>`."
107                        .to_string(),
108                ));
109            }
110            if output.mutability != Mutability::Reference {
111                diagnostics.push(PluginDiagnostic::error(
112                    output.stable_ptr.untyped(),
113                    "Invalid second param mutability for `#[executable_raw]` function, expected \
114                     `ref`."
115                        .to_string(),
116                ));
117            }
118        }
119        diagnostics
120    }
121}
122
123#[derive(Debug, Default)]
124#[non_exhaustive]
125pub struct ExecutablePlugin;
126
127impl MacroPlugin for ExecutablePlugin {
128    fn generate_code<'db>(
129        &self,
130        db: &'db dyn Database,
131        item_ast: ast::ModuleItem<'db>,
132        _metadata: &MacroPluginMetadata<'_>,
133    ) -> PluginResult<'db> {
134        let ast::ModuleItem::FreeFunction(item) = item_ast else {
135            return PluginResult::default();
136        };
137        if !item.has_attr(db, EXECUTABLE_ATTR) {
138            return PluginResult::default();
139        }
140        let mut diagnostics = vec![];
141        let mut builder = PatchBuilder::new(db, &item);
142        let declaration = item.declaration(db);
143        let generics = declaration.generic_params(db);
144        if !generics.is_empty(db) {
145            diagnostics.push(PluginDiagnostic::error(
146                generics.stable_ptr(db),
147                "Executable functions cannot have generic params.".to_string(),
148            ));
149        }
150        let name = declaration.name(db);
151        let implicits_precedence =
152            RewriteNode::Text(format!("#[{IMPLICIT_PRECEDENCE_ATTR}({})]", {
153                IMPLICIT_PRECEDENCE.iter().join(", ")
154            }));
155        builder.add_modified(RewriteNode::interpolate_patched(&formatdoc! {"
156
157                $implicit_precedence$
158                #[{EXECUTABLE_RAW_ATTR}]
159                fn {EXECUTABLE_PREFIX}$function_name$(mut input: Span<felt252>, ref output: Array<felt252>) {{\n
160            "},
161            &[
162                ("implicit_precedence".into(), implicits_precedence,),
163                ("function_name".into(), RewriteNode::from_ast(&name))
164            ].into()
165        ));
166        let params = declaration.signature(db).parameters(db).elements_vec(db);
167        for (param_idx, param) in params.iter().enumerate() {
168            for modifier in param.modifiers(db).elements(db) {
169                if let ast::Modifier::Ref(terminal_ref) = modifier {
170                    diagnostics.push(PluginDiagnostic::error(
171                        terminal_ref.stable_ptr(db),
172                        "Parameters of an `#[executable]` function can't be `ref`.".into(),
173                    ));
174                }
175            }
176            builder.add_modified(
177                RewriteNode::Text(format!(
178                    "    let __param{EXECUTABLE_PREFIX}{param_idx} = Serde::deserialize(ref \
179                     input).expect('Failed to deserialize param #{param_idx}');\n"
180                ))
181                .mapped(db, param),
182            );
183        }
184        if !diagnostics.is_empty() {
185            return PluginResult { code: None, diagnostics, remove_original_item: false };
186        }
187        builder.add_str(
188            "    assert(core::array::SpanTrait::is_empty(input), 'Input too long for params.');\n",
189        );
190        builder.add_modified(RewriteNode::interpolate_patched(
191            "    let __result = @$function_name$(\n",
192            &[("function_name".into(), RewriteNode::from_ast(&name))].into(),
193        ));
194        for (param_idx, param) in params.iter().enumerate() {
195            builder.add_modified(
196                RewriteNode::Text(format!("        __param{EXECUTABLE_PREFIX}{param_idx},\n"))
197                    .mapped(db, param),
198            );
199        }
200        builder.add_str("    );\n");
201        let mut serialize_node = RewriteNode::text("    Serde::serialize(__result, ref output);\n");
202        if let ast::OptionReturnTypeClause::ReturnTypeClause(clause) =
203            declaration.signature(db).ret_ty(db)
204        {
205            serialize_node = serialize_node.mapped(db, &clause);
206        }
207        builder.add_modified(serialize_node);
208        builder.add_str("}\n");
209        let (content, code_mappings) = builder.build();
210        PluginResult {
211            code: Some(PluginGeneratedFile {
212                name: "executable".into(),
213                content,
214                code_mappings,
215                aux_data: None,
216                diagnostics_note: Default::default(),
217                is_unhygienic: false,
218            }),
219            diagnostics,
220            remove_original_item: false,
221        }
222    }
223
224    fn declared_attributes<'db>(&self, db: &'db dyn Database) -> Vec<SmolStrId<'db>> {
225        vec![(SmolStrId::from(db, EXECUTABLE_ATTR)), (SmolStrId::from(db, EXECUTABLE_RAW_ATTR))]
226    }
227
228    fn executable_attributes<'db>(&self, db: &'db dyn Database) -> Vec<SmolStrId<'db>> {
229        vec![cairo_lang_filesystem::ids::SmolStrId::from(db, EXECUTABLE_RAW_ATTR)]
230    }
231}