Skip to main content

cairo_lang_plugins/plugins/generate_trait.rs

1use std::iter::zip;
2
3use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
4use cairo_lang_defs::plugin::{
5    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
6};
7use cairo_lang_syntax::attribute::structured::{AttributeArgVariant, AttributeStructurize};
8use cairo_lang_syntax::node::db::SyntaxGroup;
9use cairo_lang_syntax::node::helpers::{BodyItems, GenericParamEx, QueryAttrs};
10use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
11use itertools::Itertools;
12
/// Attribute name that requests a trait to be generated from an `impl`.
const GENERATE_TRAIT_ATTR: &str = "generate_trait";

/// Macro plugin that emits a trait declaration derived from an `impl` marked with
/// `#[generate_trait]`.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct GenerateTraitPlugin;
18
19impl MacroPlugin for GenerateTraitPlugin {
20    fn generate_code(
21        &self,
22        db: &dyn SyntaxGroup,
23        item_ast: ast::ModuleItem,
24        _metadata: &MacroPluginMetadata<'_>,
25    ) -> PluginResult {
26        match item_ast {
27            ast::ModuleItem::Impl(impl_ast) => generate_trait_for_impl(db, impl_ast),
28            module_item => {
29                let mut diagnostics = vec![];
30
31                if let Some(attr) = module_item.find_attr(db, GENERATE_TRAIT_ATTR) {
32                    diagnostics.push(PluginDiagnostic::warning(
33                        attr.stable_ptr(db),
34                        "`generate_trait` may only be applied to `impl`s".to_string(),
35                    ));
36                }
37
38                PluginResult { diagnostics, ..PluginResult::default() }
39            }
40        }
41    }
42
43    fn declared_attributes(&self) -> Vec<String> {
44        vec![GENERATE_TRAIT_ATTR.to_string()]
45    }
46}
47
48fn generate_trait_for_impl(db: &dyn SyntaxGroup, impl_ast: ast::ItemImpl) -> PluginResult {
49    let Some(attr) = impl_ast.attributes(db).find_attr(db, GENERATE_TRAIT_ATTR) else {
50        return PluginResult::default();
51    };
52    let trait_ast = impl_ast.trait_path(db);
53    let Some([trait_ast_segment]) = trait_ast.segments(db).elements(db).collect_array() else {
54        return PluginResult {
55            code: None,
56            diagnostics: vec![PluginDiagnostic::error(
57                trait_ast.stable_ptr(db),
58                "Generated trait must have a single element path.".to_string(),
59            )],
60            remove_original_item: false,
61        };
62    };
63
64    let mut diagnostics = vec![];
65    let mut builder = PatchBuilder::new(db, &impl_ast);
66    let leading_trivia = impl_ast
67        .attributes(db)
68        .elements(db)
69        .next()
70        .unwrap()
71        .hash(db)
72        .leading_trivia(db)
73        .as_syntax_node()
74        .get_text(db);
75    let extra_ident = leading_trivia.split('\n').next_back().unwrap_or_default();
76    for attr_arg in attr.structurize(db).args {
77        match attr_arg.variant {
78            AttributeArgVariant::Unnamed(ast::Expr::FunctionCall(attr_arg))
79                if attr_arg.path(db).as_syntax_node().get_text_without_trivia(db)
80                    == "trait_attrs" =>
81            {
82                for arg in attr_arg.arguments(db).arguments(db).elements(db) {
83                    builder.add_modified(RewriteNode::interpolate_patched(
84                        &format!("{extra_ident}#[$attr$]\n"),
85                        &[("attr".to_string(), RewriteNode::from_ast_trimmed(&arg))].into(),
86                    ));
87                }
88            }
89            _ => {
90                diagnostics.push(PluginDiagnostic::error(
91                    attr_arg.arg.stable_ptr(db),
92                    "Expected an argument with the name `trait_attrs`.".to_string(),
93                ));
94            }
95        }
96    }
97    builder.add_str(extra_ident);
98    builder.add_node(impl_ast.visibility(db).as_syntax_node());
99    builder.add_str("trait ");
100    let impl_generic_params = impl_ast.generic_params(db);
101    let generic_params_match = match trait_ast_segment {
102        ast::PathSegment::WithGenericArgs(segment) => {
103            builder.add_node(segment.ident(db).as_syntax_node());
104            if let ast::OptionWrappedGenericParamList::WrappedGenericParamList(
105                impl_generic_params,
106            ) = impl_generic_params.clone()
107            {
108                // TODO(orizi): Support generic args that do not directly match the generic params.
109                let trait_generic_args = segment.generic_args(db).generic_args(db).elements(db);
110                let impl_generic_params = impl_generic_params.generic_params(db).elements(db);
111                zip(trait_generic_args, impl_generic_params).all(
112                    |(trait_generic_arg, impl_generic_param)| {
113                        let ast::GenericArg::Unnamed(trait_generic_arg) = trait_generic_arg else {
114                            return false;
115                        };
116                        let ast::GenericArgValue::Expr(trait_generic_arg) =
117                            trait_generic_arg.value(db)
118                        else {
119                            return false;
120                        };
121                        let ast::Expr::Path(trait_generic_arg) = trait_generic_arg.expr(db) else {
122                            return false;
123                        };
124                        let Some([ast::PathSegment::Simple(trait_generic_arg)]) =
125                            trait_generic_arg.segments(db).elements(db).collect_array()
126                        else {
127                            return false;
128                        };
129                        let trait_generic_arg_name = trait_generic_arg.ident(db);
130                        let Some(impl_generic_param_name) = impl_generic_param.name(db) else {
131                            return false;
132                        };
133                        trait_generic_arg_name.text(db) == impl_generic_param_name.text(db)
134                    },
135                )
136            } else {
137                false
138            }
139        }
140        ast::PathSegment::Simple(segment) => {
141            builder.add_node(segment.ident(db).as_syntax_node());
142            matches!(impl_generic_params, ast::OptionWrappedGenericParamList::Empty(_))
143        }
144        ast::PathSegment::Missing(_) => {
145            return PluginResult {
146                code: None,
147                diagnostics: vec![PluginDiagnostic::error(
148                    trait_ast.stable_ptr(db),
149                    "Generated trait can not have a missing path segment.".to_string(),
150                )],
151                remove_original_item: false,
152            };
153        }
154    };
155    if !generic_params_match {
156        diagnostics.push(PluginDiagnostic::error(
157            trait_ast.stable_ptr(db),
158            "Generated trait must have generic args matching the impl's generic params."
159                .to_string(),
160        ));
161    }
162    match impl_ast.body(db) {
163        ast::MaybeImplBody::None(semicolon) => {
164            builder.add_modified(RewriteNode::from_ast_trimmed(&impl_generic_params));
165            builder.add_node(semicolon.as_syntax_node());
166        }
167        ast::MaybeImplBody::Some(body) => {
168            builder.add_node(impl_generic_params.as_syntax_node());
169            builder.add_node(body.lbrace(db).as_syntax_node());
170            for item in body.iter_items(db) {
171                match item {
172                    ast::ImplItem::Function(function_item) => {
173                        let decl = function_item.declaration(db);
174                        let signature = decl.signature(db);
175                        builder.add_node(function_item.attributes(db).as_syntax_node());
176                        builder.add_node(decl.optional_const(db).as_syntax_node());
177                        builder.add_node(decl.function_kw(db).as_syntax_node());
178                        builder.add_node(decl.name(db).as_syntax_node());
179                        builder.add_node(decl.generic_params(db).as_syntax_node());
180                        builder.add_node(signature.lparen(db).as_syntax_node());
181                        for node in signature.parameters(db).node.get_children(db).iter() {
182                            if let Some(param) = ast::Param::cast(db, *node) {
183                                for modifier in param.modifiers(db).elements(db) {
184                                    // `mut` modifiers are only relevant for impls, not traits.
185                                    if !matches!(modifier, ast::Modifier::Mut(_)) {
186                                        builder.add_node(modifier.as_syntax_node());
187                                    }
188                                }
189                                builder.add_node(param.name(db).as_syntax_node());
190                                builder.add_node(param.type_clause(db).as_syntax_node());
191                            } else {
192                                builder.add_node(*node);
193                            }
194                        }
195                        let rparen = signature.rparen(db);
196                        let ret_ty = signature.ret_ty(db);
197                        let implicits_clause = signature.implicits_clause(db);
198                        let optional_no_panic = signature.optional_no_panic(db);
199                        let last_node = if matches!(
200                            optional_no_panic,
201                            ast::OptionTerminalNoPanic::TerminalNoPanic(_)
202                        ) {
203                            builder.add_node(rparen.as_syntax_node());
204                            builder.add_node(ret_ty.as_syntax_node());
205                            builder.add_node(implicits_clause.as_syntax_node());
206                            optional_no_panic.as_syntax_node()
207                        } else if matches!(
208                            implicits_clause,
209                            ast::OptionImplicitsClause::ImplicitsClause(_)
210                        ) {
211                            builder.add_node(rparen.as_syntax_node());
212                            builder.add_node(ret_ty.as_syntax_node());
213                            implicits_clause.as_syntax_node()
214                        } else if matches!(ret_ty, ast::OptionReturnTypeClause::ReturnTypeClause(_))
215                        {
216                            builder.add_node(rparen.as_syntax_node());
217                            ret_ty.as_syntax_node()
218                        } else {
219                            rparen.as_syntax_node()
220                        };
221                        builder.add_modified(RewriteNode::Trimmed {
222                            node: last_node,
223                            trim_left: false,
224                            trim_right: true,
225                        });
226                        builder.add_str(";\n");
227                    }
228                    ast::ImplItem::Type(type_item) => {
229                        builder.add_node(type_item.attributes(db).as_syntax_node());
230                        builder.add_node(type_item.type_kw(db).as_syntax_node());
231                        builder.add_modified(RewriteNode::Trimmed {
232                            node: type_item.name(db).as_syntax_node(),
233                            trim_left: false,
234                            trim_right: true,
235                        });
236                        builder.add_str(";\n");
237                    }
238                    ast::ImplItem::Constant(const_item) => {
239                        builder.add_node(const_item.attributes(db).as_syntax_node());
240                        builder.add_node(const_item.const_kw(db).as_syntax_node());
241                        builder.add_node(const_item.name(db).as_syntax_node());
242                        builder.add_modified(RewriteNode::Trimmed {
243                            node: const_item.type_clause(db).as_syntax_node(),
244                            trim_left: false,
245                            trim_right: true,
246                        });
247                        builder.add_str(";\n");
248                    }
249                    _ => diagnostics.push(PluginDiagnostic::error(
250                        item.stable_ptr(db),
251                        "Only functions, types, and constants are supported in #[generate_trait]."
252                            .to_string(),
253                    )),
254                }
255            }
256            builder.add_node(body.rbrace(db).as_syntax_node());
257        }
258    }
259    let (content, code_mappings) = builder.build();
260    PluginResult {
261        code: Some(PluginGeneratedFile {
262            name: "generate_trait".into(),
263            content,
264            code_mappings,
265            aux_data: None,
266            diagnostics_note: Default::default(),
267            is_unhygienic: false,
268        }),
269        diagnostics,
270        remove_original_item: false,
271    }
272}