cairo_lang_plugins/plugins/
generate_trait.rs1use std::iter::zip;
2
3use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
4use cairo_lang_defs::plugin::{
5 MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
6};
7use cairo_lang_syntax::attribute::structured::{AttributeArgVariant, AttributeStructurize};
8use cairo_lang_syntax::node::db::SyntaxGroup;
9use cairo_lang_syntax::node::helpers::{BodyItems, GenericParamEx, QueryAttrs};
10use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
11use itertools::Itertools;
12
/// Macro plugin that generates a trait declaration from an `impl` annotated with
/// `#[generate_trait]`.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct GenerateTraitPlugin;

/// Name of the attribute handled by [`GenerateTraitPlugin`].
const GENERATE_TRAIT_ATTR: &str = "generate_trait";
18
19impl MacroPlugin for GenerateTraitPlugin {
20 fn generate_code(
21 &self,
22 db: &dyn SyntaxGroup,
23 item_ast: ast::ModuleItem,
24 _metadata: &MacroPluginMetadata<'_>,
25 ) -> PluginResult {
26 match item_ast {
27 ast::ModuleItem::Impl(impl_ast) => generate_trait_for_impl(db, impl_ast),
28 module_item => {
29 let mut diagnostics = vec![];
30
31 if let Some(attr) = module_item.find_attr(db, GENERATE_TRAIT_ATTR) {
32 diagnostics.push(PluginDiagnostic::warning(
33 attr.stable_ptr(db),
34 "`generate_trait` may only be applied to `impl`s".to_string(),
35 ));
36 }
37
38 PluginResult { diagnostics, ..PluginResult::default() }
39 }
40 }
41 }
42
43 fn declared_attributes(&self) -> Vec<String> {
44 vec![GENERATE_TRAIT_ATTR.to_string()]
45 }
46}
47
48fn generate_trait_for_impl(db: &dyn SyntaxGroup, impl_ast: ast::ItemImpl) -> PluginResult {
49 let Some(attr) = impl_ast.attributes(db).find_attr(db, GENERATE_TRAIT_ATTR) else {
50 return PluginResult::default();
51 };
52 let trait_ast = impl_ast.trait_path(db);
53 let Some([trait_ast_segment]) = trait_ast.segments(db).elements(db).collect_array() else {
54 return PluginResult {
55 code: None,
56 diagnostics: vec![PluginDiagnostic::error(
57 trait_ast.stable_ptr(db),
58 "Generated trait must have a single element path.".to_string(),
59 )],
60 remove_original_item: false,
61 };
62 };
63
64 let mut diagnostics = vec![];
65 let mut builder = PatchBuilder::new(db, &impl_ast);
66 let leading_trivia = impl_ast
67 .attributes(db)
68 .elements(db)
69 .next()
70 .unwrap()
71 .hash(db)
72 .leading_trivia(db)
73 .as_syntax_node()
74 .get_text(db);
75 let extra_ident = leading_trivia.split('\n').next_back().unwrap_or_default();
76 for attr_arg in attr.structurize(db).args {
77 match attr_arg.variant {
78 AttributeArgVariant::Unnamed(ast::Expr::FunctionCall(attr_arg))
79 if attr_arg.path(db).as_syntax_node().get_text_without_trivia(db)
80 == "trait_attrs" =>
81 {
82 for arg in attr_arg.arguments(db).arguments(db).elements(db) {
83 builder.add_modified(RewriteNode::interpolate_patched(
84 &format!("{extra_ident}#[$attr$]\n"),
85 &[("attr".to_string(), RewriteNode::from_ast_trimmed(&arg))].into(),
86 ));
87 }
88 }
89 _ => {
90 diagnostics.push(PluginDiagnostic::error(
91 attr_arg.arg.stable_ptr(db),
92 "Expected an argument with the name `trait_attrs`.".to_string(),
93 ));
94 }
95 }
96 }
97 builder.add_str(extra_ident);
98 builder.add_node(impl_ast.visibility(db).as_syntax_node());
99 builder.add_str("trait ");
100 let impl_generic_params = impl_ast.generic_params(db);
101 let generic_params_match = match trait_ast_segment {
102 ast::PathSegment::WithGenericArgs(segment) => {
103 builder.add_node(segment.ident(db).as_syntax_node());
104 if let ast::OptionWrappedGenericParamList::WrappedGenericParamList(
105 impl_generic_params,
106 ) = impl_generic_params.clone()
107 {
108 let trait_generic_args = segment.generic_args(db).generic_args(db).elements(db);
110 let impl_generic_params = impl_generic_params.generic_params(db).elements(db);
111 zip(trait_generic_args, impl_generic_params).all(
112 |(trait_generic_arg, impl_generic_param)| {
113 let ast::GenericArg::Unnamed(trait_generic_arg) = trait_generic_arg else {
114 return false;
115 };
116 let ast::GenericArgValue::Expr(trait_generic_arg) =
117 trait_generic_arg.value(db)
118 else {
119 return false;
120 };
121 let ast::Expr::Path(trait_generic_arg) = trait_generic_arg.expr(db) else {
122 return false;
123 };
124 let Some([ast::PathSegment::Simple(trait_generic_arg)]) =
125 trait_generic_arg.segments(db).elements(db).collect_array()
126 else {
127 return false;
128 };
129 let trait_generic_arg_name = trait_generic_arg.ident(db);
130 let Some(impl_generic_param_name) = impl_generic_param.name(db) else {
131 return false;
132 };
133 trait_generic_arg_name.text(db) == impl_generic_param_name.text(db)
134 },
135 )
136 } else {
137 false
138 }
139 }
140 ast::PathSegment::Simple(segment) => {
141 builder.add_node(segment.ident(db).as_syntax_node());
142 matches!(impl_generic_params, ast::OptionWrappedGenericParamList::Empty(_))
143 }
144 ast::PathSegment::Missing(_) => {
145 return PluginResult {
146 code: None,
147 diagnostics: vec![PluginDiagnostic::error(
148 trait_ast.stable_ptr(db),
149 "Generated trait can not have a missing path segment.".to_string(),
150 )],
151 remove_original_item: false,
152 };
153 }
154 };
155 if !generic_params_match {
156 diagnostics.push(PluginDiagnostic::error(
157 trait_ast.stable_ptr(db),
158 "Generated trait must have generic args matching the impl's generic params."
159 .to_string(),
160 ));
161 }
162 match impl_ast.body(db) {
163 ast::MaybeImplBody::None(semicolon) => {
164 builder.add_modified(RewriteNode::from_ast_trimmed(&impl_generic_params));
165 builder.add_node(semicolon.as_syntax_node());
166 }
167 ast::MaybeImplBody::Some(body) => {
168 builder.add_node(impl_generic_params.as_syntax_node());
169 builder.add_node(body.lbrace(db).as_syntax_node());
170 for item in body.iter_items(db) {
171 match item {
172 ast::ImplItem::Function(function_item) => {
173 let decl = function_item.declaration(db);
174 let signature = decl.signature(db);
175 builder.add_node(function_item.attributes(db).as_syntax_node());
176 builder.add_node(decl.optional_const(db).as_syntax_node());
177 builder.add_node(decl.function_kw(db).as_syntax_node());
178 builder.add_node(decl.name(db).as_syntax_node());
179 builder.add_node(decl.generic_params(db).as_syntax_node());
180 builder.add_node(signature.lparen(db).as_syntax_node());
181 for node in signature.parameters(db).node.get_children(db).iter() {
182 if let Some(param) = ast::Param::cast(db, *node) {
183 for modifier in param.modifiers(db).elements(db) {
184 if !matches!(modifier, ast::Modifier::Mut(_)) {
186 builder.add_node(modifier.as_syntax_node());
187 }
188 }
189 builder.add_node(param.name(db).as_syntax_node());
190 builder.add_node(param.type_clause(db).as_syntax_node());
191 } else {
192 builder.add_node(*node);
193 }
194 }
195 let rparen = signature.rparen(db);
196 let ret_ty = signature.ret_ty(db);
197 let implicits_clause = signature.implicits_clause(db);
198 let optional_no_panic = signature.optional_no_panic(db);
199 let last_node = if matches!(
200 optional_no_panic,
201 ast::OptionTerminalNoPanic::TerminalNoPanic(_)
202 ) {
203 builder.add_node(rparen.as_syntax_node());
204 builder.add_node(ret_ty.as_syntax_node());
205 builder.add_node(implicits_clause.as_syntax_node());
206 optional_no_panic.as_syntax_node()
207 } else if matches!(
208 implicits_clause,
209 ast::OptionImplicitsClause::ImplicitsClause(_)
210 ) {
211 builder.add_node(rparen.as_syntax_node());
212 builder.add_node(ret_ty.as_syntax_node());
213 implicits_clause.as_syntax_node()
214 } else if matches!(ret_ty, ast::OptionReturnTypeClause::ReturnTypeClause(_))
215 {
216 builder.add_node(rparen.as_syntax_node());
217 ret_ty.as_syntax_node()
218 } else {
219 rparen.as_syntax_node()
220 };
221 builder.add_modified(RewriteNode::Trimmed {
222 node: last_node,
223 trim_left: false,
224 trim_right: true,
225 });
226 builder.add_str(";\n");
227 }
228 ast::ImplItem::Type(type_item) => {
229 builder.add_node(type_item.attributes(db).as_syntax_node());
230 builder.add_node(type_item.type_kw(db).as_syntax_node());
231 builder.add_modified(RewriteNode::Trimmed {
232 node: type_item.name(db).as_syntax_node(),
233 trim_left: false,
234 trim_right: true,
235 });
236 builder.add_str(";\n");
237 }
238 ast::ImplItem::Constant(const_item) => {
239 builder.add_node(const_item.attributes(db).as_syntax_node());
240 builder.add_node(const_item.const_kw(db).as_syntax_node());
241 builder.add_node(const_item.name(db).as_syntax_node());
242 builder.add_modified(RewriteNode::Trimmed {
243 node: const_item.type_clause(db).as_syntax_node(),
244 trim_left: false,
245 trim_right: true,
246 });
247 builder.add_str(";\n");
248 }
249 _ => diagnostics.push(PluginDiagnostic::error(
250 item.stable_ptr(db),
251 "Only functions, types, and constants are supported in #[generate_trait]."
252 .to_string(),
253 )),
254 }
255 }
256 builder.add_node(body.rbrace(db).as_syntax_node());
257 }
258 }
259 let (content, code_mappings) = builder.build();
260 PluginResult {
261 code: Some(PluginGeneratedFile {
262 name: "generate_trait".into(),
263 content,
264 code_mappings,
265 aux_data: None,
266 diagnostics_note: Default::default(),
267 is_unhygienic: false,
268 }),
269 diagnostics,
270 remove_original_item: false,
271 }
272}