1use cairo_lang_defs::ids::ModuleId;
2use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
3use cairo_lang_defs::plugin::{
4    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
5};
6use cairo_lang_semantic::db::SemanticGroup;
7use cairo_lang_semantic::plugin::{AnalyzerPlugin, PluginSuite};
8use cairo_lang_semantic::{GenericArgumentId, Mutability, corelib};
9use cairo_lang_syntax::attribute::consts::IMPLICIT_PRECEDENCE_ATTR;
10use cairo_lang_syntax::node::db::SyntaxGroup;
11use cairo_lang_syntax::node::helpers::{OptionWrappedGenericParamListHelper, QueryAttrs};
12use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode, ast};
13use indoc::formatdoc;
14use itertools::Itertools;
15
/// Attribute that marks a free function for which the executable plugin generates a
/// wrapper function.
pub const EXECUTABLE_ATTR: &str = "executable";
/// Attribute placed on the generated wrapper function. Free functions carrying it are
/// validated to have the raw signature `(Span<felt252>, ref Array<felt252>)` returning `()`.
pub const EXECUTABLE_RAW_ATTR: &str = "executable_raw";
/// Prefix prepended to the original function's name to form the generated wrapper's name.
/// It is also embedded in the names of the wrapper's local parameter variables.
pub const EXECUTABLE_PREFIX: &str = "__executable_wrapper__";
19
20pub fn executable_plugin_suite() -> PluginSuite {
22    std::mem::take(
23        PluginSuite::default()
24            .add_plugin::<ExecutablePlugin>()
25            .add_analyzer_plugin::<RawExecutableAnalyzer>(),
26    )
27}
28
/// Fully-qualified builtin type paths listed in the `IMPLICIT_PRECEDENCE_ATTR` attribute of
/// every generated `#[executable]` wrapper.
///
/// The entries are joined with `", "` in exactly this order. Reordering them changes the
/// generated attribute.
const IMPLICIT_PRECEDENCE: &[&str] = &[
    "core::pedersen::Pedersen",
    "core::RangeCheck",
    "core::integer::Bitwise",
    "core::ec::EcOp",
    "core::poseidon::Poseidon",
    "core::circuit::RangeCheck96",
    "core::circuit::AddMod",
    "core::circuit::MulMod",
];
39
/// Macro plugin that, for every free function annotated with `#[executable]`, generates an
/// `#[executable_raw]` wrapper. The wrapper deserializes the arguments from its input span
/// and serializes the function's result into its output array.
#[derive(Default, Debug)]
#[non_exhaustive]
struct ExecutablePlugin;
43
44impl MacroPlugin for ExecutablePlugin {
45    fn generate_code(
46        &self,
47        db: &dyn SyntaxGroup,
48        item_ast: ast::ModuleItem,
49        _metadata: &MacroPluginMetadata<'_>,
50    ) -> PluginResult {
51        let ast::ModuleItem::FreeFunction(item) = item_ast else {
52            return PluginResult::default();
53        };
54        if !item.has_attr(db, EXECUTABLE_ATTR) {
55            return PluginResult::default();
56        }
57        let mut diagnostics = vec![];
58        let mut builder = PatchBuilder::new(db, &item);
59        let declaration = item.declaration(db);
60        let generics = declaration.generic_params(db);
61        if !generics.is_empty(db) {
62            diagnostics.push(PluginDiagnostic::error(
63                generics.stable_ptr(db),
64                "Executable functions cannot have generic params.".to_string(),
65            ));
66        }
67        let name = declaration.name(db);
68        let implicits_precedence =
69            RewriteNode::Text(format!("#[{IMPLICIT_PRECEDENCE_ATTR}({})]", {
70                IMPLICIT_PRECEDENCE.iter().join(", ")
71            }));
72        builder.add_modified(RewriteNode::interpolate_patched(&formatdoc! {"
73
74                $implicit_precedence$
75                #[{EXECUTABLE_RAW_ATTR}]
76                fn {EXECUTABLE_PREFIX}$function_name$(mut input: Span<felt252>, ref output: Array<felt252>) {{\n
77            "},
78            &[
79                ("implicit_precedence".into(), implicits_precedence,),
80                ("function_name".into(), RewriteNode::from_ast(&name))
81            ].into()
82        ));
83        let params = declaration.signature(db).parameters(db).elements_vec(db);
84        for (param_idx, param) in params.iter().enumerate() {
85            for modifier in param.modifiers(db).elements(db) {
86                if let ast::Modifier::Ref(terminal_ref) = modifier {
87                    diagnostics.push(PluginDiagnostic::error(
88                        terminal_ref.stable_ptr(db),
89                        "Parameters of an `#[executable]` function can't be `ref`.".into(),
90                    ));
91                }
92            }
93            builder.add_modified(
94                RewriteNode::Text(format!(
95                    "    let __param{EXECUTABLE_PREFIX}{param_idx} = Serde::deserialize(ref \
96                     input).expect('Failed to deserialize param #{param_idx}');\n"
97                ))
98                .mapped(db, param),
99            );
100        }
101        if !diagnostics.is_empty() {
102            return PluginResult { code: None, diagnostics, remove_original_item: false };
103        }
104        builder.add_str(
105            "    assert(core::array::SpanTrait::is_empty(input), 'Input too long for params.');\n",
106        );
107        builder.add_modified(RewriteNode::interpolate_patched(
108            "    let __result = @$function_name$(\n",
109            &[("function_name".into(), RewriteNode::from_ast(&name))].into(),
110        ));
111        for (param_idx, param) in params.iter().enumerate() {
112            builder.add_modified(
113                RewriteNode::Text(format!("        __param{EXECUTABLE_PREFIX}{param_idx},\n"))
114                    .mapped(db, param),
115            );
116        }
117        builder.add_str("    );\n");
118        let mut serialize_node = RewriteNode::text("    Serde::serialize(__result, ref output);\n");
119        if let ast::OptionReturnTypeClause::ReturnTypeClause(clause) =
120            declaration.signature(db).ret_ty(db)
121        {
122            serialize_node = serialize_node.mapped(db, &clause);
123        }
124        builder.add_modified(serialize_node);
125        builder.add_str("}\n");
126        let (content, code_mappings) = builder.build();
127        PluginResult {
128            code: Some(PluginGeneratedFile {
129                name: "executable".into(),
130                content,
131                code_mappings,
132                aux_data: None,
133                diagnostics_note: Default::default(),
134                is_unhygienic: false,
135            }),
136            diagnostics,
137            remove_original_item: false,
138        }
139    }
140
141    fn declared_attributes(&self) -> Vec<String> {
142        vec![EXECUTABLE_ATTR.to_string(), EXECUTABLE_RAW_ATTR.to_string()]
143    }
144
145    fn executable_attributes(&self) -> Vec<String> {
146        vec![EXECUTABLE_RAW_ATTR.to_string()]
147    }
148}
149
/// Analyzer plugin that reports errors on `#[executable_raw]` free functions whose signature
/// is not `(Span<felt252>, ref Array<felt252>)` returning `()`.
#[derive(Debug, Default)]
struct RawExecutableAnalyzer;
153
154impl AnalyzerPlugin for RawExecutableAnalyzer {
155    fn diagnostics(&self, db: &dyn SemanticGroup, module_id: ModuleId) -> Vec<PluginDiagnostic> {
156        let mut diagnostics = vec![];
157        let Ok(free_functions) = db.module_free_functions(module_id) else {
158            return diagnostics;
159        };
160        for (id, item) in free_functions.iter() {
161            if !item.has_attr(db, EXECUTABLE_RAW_ATTR) {
162                continue;
163            }
164            let Ok(signature) = db.free_function_signature(*id) else {
165                continue;
166            };
167            if signature.return_type != corelib::unit_ty(db) {
168                diagnostics.push(PluginDiagnostic::error(
169                    signature.stable_ptr.lookup(db).ret_ty(db).stable_ptr(db),
170                    "Invalid return type for `#[executable_raw]` function, expected `()`."
171                        .to_string(),
172                ));
173            }
174            let [input, output] = &signature.params[..] else {
175                diagnostics.push(PluginDiagnostic::error(
176                    signature.stable_ptr.lookup(db).parameters(db).stable_ptr(db),
177                    "Invalid number of params for `#[executable_raw]` function, expected 2."
178                        .to_string(),
179                ));
180                continue;
181            };
182            if input.ty
183                != corelib::get_core_ty_by_name(
184                    db,
185                    "Span".into(),
186                    vec![GenericArgumentId::Type(db.core_info().felt252)],
187                )
188            {
189                diagnostics.push(PluginDiagnostic::error(
190                    input.stable_ptr.untyped(),
191                    "Invalid first param type for `#[executable_raw]` function, expected \
192                     `Span<felt252>`."
193                        .to_string(),
194                ));
195            }
196            if input.mutability == Mutability::Reference {
197                diagnostics.push(PluginDiagnostic::error(
198                    input.stable_ptr.untyped(),
199                    "Invalid first param mutability for `#[executable_raw]` function, got \
200                     unexpected `ref`."
201                        .to_string(),
202                ));
203            }
204            if output.ty != corelib::core_array_felt252_ty(db) {
205                diagnostics.push(PluginDiagnostic::error(
206                    output.stable_ptr.untyped(),
207                    "Invalid second param type for `#[executable_raw]` function, expected \
208                     `Array<felt252>`."
209                        .to_string(),
210                ));
211            }
212            if output.mutability != Mutability::Reference {
213                diagnostics.push(PluginDiagnostic::error(
214                    output.stable_ptr.untyped(),
215                    "Invalid second param mutability for `#[executable_raw]` function, expected \
216                     `ref`."
217                        .to_string(),
218                ));
219            }
220        }
221        diagnostics
222    }
223}