cairo_lang_plugins/plugins/generate_trait.rs

1use std::iter::zip;
2
3use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
4use cairo_lang_defs::plugin::{
5    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
6};
7use cairo_lang_filesystem::ids::SmolStrId;
8use cairo_lang_syntax::attribute::structured::{AttributeArgVariant, AttributeStructurize};
9use cairo_lang_syntax::node::helpers::{BodyItems, GenericParamEx, QueryAttrs};
10use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
11use itertools::Itertools;
12use salsa::Database;
13
/// Macro plugin that expands `#[generate_trait]` on an `impl` into a matching trait declaration.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct GenerateTraitPlugin;
17
/// Name of the attribute (written `#[generate_trait]`) that this plugin reacts to.
const GENERATE_TRAIT_ATTR: &str = "generate_trait";
19
20impl MacroPlugin for GenerateTraitPlugin {
21    fn generate_code<'db>(
22        &self,
23        db: &'db dyn Database,
24        item_ast: ast::ModuleItem<'db>,
25        _metadata: &MacroPluginMetadata<'_>,
26    ) -> PluginResult<'db> {
27        match item_ast {
28            ast::ModuleItem::Impl(impl_ast) => generate_trait_for_impl(db, impl_ast),
29            module_item => {
30                let mut diagnostics = vec![];
31
32                if let Some(attr) = module_item.find_attr(db, GENERATE_TRAIT_ATTR) {
33                    diagnostics.push(PluginDiagnostic::warning(
34                        attr.stable_ptr(db),
35                        "`generate_trait` may only be applied to `impl`s".to_string(),
36                    ));
37                }
38
39                PluginResult { diagnostics, ..PluginResult::default() }
40            }
41        }
42    }
43
44    fn declared_attributes<'db>(&self, db: &'db dyn Database) -> Vec<SmolStrId<'db>> {
45        vec![SmolStrId::from(db, GENERATE_TRAIT_ATTR)]
46    }
47}
48
49fn generate_trait_for_impl<'db>(
50    db: &'db dyn Database,
51    impl_ast: ast::ItemImpl<'db>,
52) -> PluginResult<'db> {
53    let Some(attr) = impl_ast.attributes(db).find_attr(db, GENERATE_TRAIT_ATTR) else {
54        return PluginResult::default();
55    };
56    let trait_ast = impl_ast.trait_path(db);
57    let Some([trait_ast_segment]) = trait_ast.segments(db).elements(db).collect_array() else {
58        return PluginResult {
59            code: None,
60            diagnostics: vec![PluginDiagnostic::error(
61                trait_ast.stable_ptr(db),
62                "Generated trait must have a single element path.".to_string(),
63            )],
64            remove_original_item: false,
65        };
66    };
67
68    let mut diagnostics = vec![];
69    let mut builder = PatchBuilder::new(db, &impl_ast);
70    let leading_trivia = impl_ast
71        .attributes(db)
72        .elements(db)
73        .next()
74        .unwrap()
75        .hash(db)
76        .leading_trivia(db)
77        .as_syntax_node()
78        .get_text(db);
79    let extra_ident = leading_trivia.split('\n').next_back().unwrap_or_default();
80    for attr_arg in attr.structurize(db).args {
81        match attr_arg.variant {
82            AttributeArgVariant::Unnamed(ast::Expr::FunctionCall(attr_arg))
83                if attr_arg.path(db).as_syntax_node().get_text_without_trivia(db).long(db)
84                    == "trait_attrs" =>
85            {
86                for arg in attr_arg.arguments(db).arguments(db).elements(db) {
87                    builder.add_modified(RewriteNode::interpolate_patched(
88                        &format!("{extra_ident}#[$attr$]\n"),
89                        &[("attr".to_string(), RewriteNode::from_ast_trimmed(&arg))].into(),
90                    ));
91                }
92            }
93            _ => {
94                diagnostics.push(PluginDiagnostic::error(
95                    attr_arg.arg.stable_ptr(db),
96                    "Expected an argument with the name `trait_attrs`.".to_string(),
97                ));
98            }
99        }
100    }
101    builder.add_str(extra_ident);
102    builder.add_node(impl_ast.visibility(db).as_syntax_node());
103    builder.add_str("trait ");
104    let impl_generic_params = impl_ast.generic_params(db);
105    let generic_params_match = match trait_ast_segment {
106        ast::PathSegment::WithGenericArgs(segment) => {
107            let generic_args = segment.generic_args(db).generic_args(db);
108            builder.add_node(segment.ident(db).as_syntax_node());
109            if let ast::OptionWrappedGenericParamList::WrappedGenericParamList(
110                impl_generic_params,
111            ) = impl_generic_params.clone()
112            {
113                // TODO(orizi): Support generic args that do not directly match the generic params.
114                let trait_generic_args = generic_args.elements(db);
115                let longer_lived = impl_generic_params.generic_params(db);
116                let impl_generic_params = longer_lived.elements(db);
117                zip(trait_generic_args, impl_generic_params).all(
118                    |(trait_generic_arg, impl_generic_param)| {
119                        let ast::GenericArg::Unnamed(trait_generic_arg) = trait_generic_arg else {
120                            return false;
121                        };
122                        let ast::GenericArgValue::Expr(trait_generic_arg) =
123                            trait_generic_arg.value(db)
124                        else {
125                            return false;
126                        };
127                        let ast::Expr::Path(trait_generic_arg) = trait_generic_arg.expr(db) else {
128                            return false;
129                        };
130                        let Some([ast::PathSegment::Simple(trait_generic_arg)]) =
131                            trait_generic_arg.segments(db).elements(db).collect_array()
132                        else {
133                            return false;
134                        };
135                        let trait_generic_arg_name = trait_generic_arg.ident(db);
136                        let Some(impl_generic_param_name) = impl_generic_param.name(db) else {
137                            return false;
138                        };
139                        trait_generic_arg_name.text(db) == impl_generic_param_name.text(db)
140                    },
141                )
142            } else {
143                false
144            }
145        }
146        ast::PathSegment::Simple(segment) => {
147            builder.add_node(segment.ident(db).as_syntax_node());
148            matches!(impl_generic_params, ast::OptionWrappedGenericParamList::Empty(_))
149        }
150        ast::PathSegment::Missing(_) => {
151            return PluginResult {
152                code: None,
153                diagnostics: vec![PluginDiagnostic::error(
154                    trait_ast.stable_ptr(db),
155                    "Generated trait can not have a missing path segment.".to_string(),
156                )],
157                remove_original_item: false,
158            };
159        }
160    };
161    if !generic_params_match {
162        diagnostics.push(PluginDiagnostic::error(
163            trait_ast.stable_ptr(db),
164            "Generated trait must have generic args matching the impl's generic params."
165                .to_string(),
166        ));
167    }
168    match impl_ast.body(db) {
169        ast::MaybeImplBody::None(semicolon) => {
170            builder.add_modified(RewriteNode::from_ast_trimmed(&impl_generic_params));
171            builder.add_node(semicolon.as_syntax_node());
172        }
173        ast::MaybeImplBody::Some(body) => {
174            builder.add_node(impl_generic_params.as_syntax_node());
175            builder.add_node(body.lbrace(db).as_syntax_node());
176            for item in body.iter_items(db) {
177                match item {
178                    ast::ImplItem::Function(function_item) => {
179                        let decl = function_item.declaration(db);
180                        let signature = decl.signature(db);
181                        builder.add_node(function_item.attributes(db).as_syntax_node());
182                        builder.add_node(decl.optional_const(db).as_syntax_node());
183                        builder.add_node(decl.function_kw(db).as_syntax_node());
184                        builder.add_node(decl.name(db).as_syntax_node());
185                        builder.add_node(decl.generic_params(db).as_syntax_node());
186                        builder.add_node(signature.lparen(db).as_syntax_node());
187                        for node in signature.parameters(db).node.get_children(db).iter() {
188                            if let Some(param) = ast::Param::cast(db, *node) {
189                                for modifier in param.modifiers(db).elements(db) {
190                                    // `mut` modifiers are only relevant for impls, not traits.
191                                    if !matches!(modifier, ast::Modifier::Mut(_)) {
192                                        builder.add_node(modifier.as_syntax_node());
193                                    }
194                                }
195                                builder.add_node(param.name(db).as_syntax_node());
196                                builder.add_node(param.type_clause(db).as_syntax_node());
197                            } else {
198                                builder.add_node(*node);
199                            }
200                        }
201                        let rparen = signature.rparen(db);
202                        let ret_ty = signature.ret_ty(db);
203                        let implicits_clause = signature.implicits_clause(db);
204                        let optional_no_panic = signature.optional_no_panic(db);
205                        let last_node = if matches!(
206                            optional_no_panic,
207                            ast::OptionTerminalNoPanic::TerminalNoPanic(_)
208                        ) {
209                            builder.add_node(rparen.as_syntax_node());
210                            builder.add_node(ret_ty.as_syntax_node());
211                            builder.add_node(implicits_clause.as_syntax_node());
212                            optional_no_panic.as_syntax_node()
213                        } else if matches!(
214                            implicits_clause,
215                            ast::OptionImplicitsClause::ImplicitsClause(_)
216                        ) {
217                            builder.add_node(rparen.as_syntax_node());
218                            builder.add_node(ret_ty.as_syntax_node());
219                            implicits_clause.as_syntax_node()
220                        } else if matches!(ret_ty, ast::OptionReturnTypeClause::ReturnTypeClause(_))
221                        {
222                            builder.add_node(rparen.as_syntax_node());
223                            ret_ty.as_syntax_node()
224                        } else {
225                            rparen.as_syntax_node()
226                        };
227                        builder.add_modified(RewriteNode::Trimmed {
228                            node: last_node,
229                            trim_left: false,
230                            trim_right: true,
231                        });
232                        builder.add_str(";\n");
233                    }
234                    ast::ImplItem::Type(type_item) => {
235                        builder.add_node(type_item.attributes(db).as_syntax_node());
236                        builder.add_node(type_item.type_kw(db).as_syntax_node());
237                        builder.add_modified(RewriteNode::Trimmed {
238                            node: type_item.name(db).as_syntax_node(),
239                            trim_left: false,
240                            trim_right: true,
241                        });
242                        builder.add_str(";\n");
243                    }
244                    ast::ImplItem::Constant(const_item) => {
245                        builder.add_node(const_item.attributes(db).as_syntax_node());
246                        builder.add_node(const_item.const_kw(db).as_syntax_node());
247                        builder.add_node(const_item.name(db).as_syntax_node());
248                        builder.add_modified(RewriteNode::Trimmed {
249                            node: const_item.type_clause(db).as_syntax_node(),
250                            trim_left: false,
251                            trim_right: true,
252                        });
253                        builder.add_str(";\n");
254                    }
255                    _ => diagnostics.push(PluginDiagnostic::error(
256                        item.stable_ptr(db),
257                        "Only functions, types, and constants are supported in #[generate_trait]."
258                            .to_string(),
259                    )),
260                }
261            }
262            builder.add_node(body.rbrace(db).as_syntax_node());
263        }
264    }
265    let (content, code_mappings) = builder.build();
266    PluginResult {
267        code: Some(PluginGeneratedFile {
268            name: "generate_trait".into(),
269            content,
270            code_mappings,
271            aux_data: None,
272            diagnostics_note: Default::default(),
273            is_unhygienic: false,
274        }),
275        diagnostics,
276        remove_original_item: false,
277    }
278}