cairo_lang_plugins/plugins/
generate_trait.rs1use std::iter::zip;
2
3use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
4use cairo_lang_defs::plugin::{
5 MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
6};
7use cairo_lang_filesystem::ids::SmolStrId;
8use cairo_lang_syntax::attribute::structured::{AttributeArgVariant, AttributeStructurize};
9use cairo_lang_syntax::node::helpers::{BodyItems, GenericParamEx, QueryAttrs};
10use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
11use itertools::Itertools;
12use salsa::Database;
13
/// Macro plugin that generates a trait declaration from an `impl` marked with
/// `#[generate_trait]`.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct GenerateTraitPlugin;

/// Name of the attribute handled by [`GenerateTraitPlugin`].
const GENERATE_TRAIT_ATTR: &str = "generate_trait";
19
20impl MacroPlugin for GenerateTraitPlugin {
21 fn generate_code<'db>(
22 &self,
23 db: &'db dyn Database,
24 item_ast: ast::ModuleItem<'db>,
25 _metadata: &MacroPluginMetadata<'_>,
26 ) -> PluginResult<'db> {
27 match item_ast {
28 ast::ModuleItem::Impl(impl_ast) => generate_trait_for_impl(db, impl_ast),
29 module_item => {
30 let mut diagnostics = vec![];
31
32 if let Some(attr) = module_item.find_attr(db, GENERATE_TRAIT_ATTR) {
33 diagnostics.push(PluginDiagnostic::warning(
34 attr.stable_ptr(db),
35 "`generate_trait` may only be applied to `impl`s".to_string(),
36 ));
37 }
38
39 PluginResult { diagnostics, ..PluginResult::default() }
40 }
41 }
42 }
43
44 fn declared_attributes<'db>(&self, db: &'db dyn Database) -> Vec<SmolStrId<'db>> {
45 vec![SmolStrId::from(db, GENERATE_TRAIT_ATTR)]
46 }
47}
48
49fn generate_trait_for_impl<'db>(
50 db: &'db dyn Database,
51 impl_ast: ast::ItemImpl<'db>,
52) -> PluginResult<'db> {
53 let Some(attr) = impl_ast.attributes(db).find_attr(db, GENERATE_TRAIT_ATTR) else {
54 return PluginResult::default();
55 };
56 let trait_ast = impl_ast.trait_path(db);
57 let Some([trait_ast_segment]) = trait_ast.segments(db).elements(db).collect_array() else {
58 return PluginResult {
59 code: None,
60 diagnostics: vec![PluginDiagnostic::error(
61 trait_ast.stable_ptr(db),
62 "Generated trait must have a single element path.".to_string(),
63 )],
64 remove_original_item: false,
65 };
66 };
67
68 let mut diagnostics = vec![];
69 let mut builder = PatchBuilder::new(db, &impl_ast);
70 let leading_trivia = impl_ast
71 .attributes(db)
72 .elements(db)
73 .next()
74 .unwrap()
75 .hash(db)
76 .leading_trivia(db)
77 .as_syntax_node()
78 .get_text(db);
79 let extra_ident = leading_trivia.split('\n').next_back().unwrap_or_default();
80 for attr_arg in attr.structurize(db).args {
81 match attr_arg.variant {
82 AttributeArgVariant::Unnamed(ast::Expr::FunctionCall(attr_arg))
83 if attr_arg.path(db).as_syntax_node().get_text_without_trivia(db).long(db)
84 == "trait_attrs" =>
85 {
86 for arg in attr_arg.arguments(db).arguments(db).elements(db) {
87 builder.add_modified(RewriteNode::interpolate_patched(
88 &format!("{extra_ident}#[$attr$]\n"),
89 &[("attr".to_string(), RewriteNode::from_ast_trimmed(&arg))].into(),
90 ));
91 }
92 }
93 _ => {
94 diagnostics.push(PluginDiagnostic::error(
95 attr_arg.arg.stable_ptr(db),
96 "Expected an argument with the name `trait_attrs`.".to_string(),
97 ));
98 }
99 }
100 }
101 builder.add_str(extra_ident);
102 builder.add_node(impl_ast.visibility(db).as_syntax_node());
103 builder.add_str("trait ");
104 let impl_generic_params = impl_ast.generic_params(db);
105 let generic_params_match = match trait_ast_segment {
106 ast::PathSegment::WithGenericArgs(segment) => {
107 let generic_args = segment.generic_args(db).generic_args(db);
108 builder.add_node(segment.ident(db).as_syntax_node());
109 if let ast::OptionWrappedGenericParamList::WrappedGenericParamList(
110 impl_generic_params,
111 ) = impl_generic_params.clone()
112 {
113 let trait_generic_args = generic_args.elements(db);
115 let longer_lived = impl_generic_params.generic_params(db);
116 let impl_generic_params = longer_lived.elements(db);
117 zip(trait_generic_args, impl_generic_params).all(
118 |(trait_generic_arg, impl_generic_param)| {
119 let ast::GenericArg::Unnamed(trait_generic_arg) = trait_generic_arg else {
120 return false;
121 };
122 let ast::GenericArgValue::Expr(trait_generic_arg) =
123 trait_generic_arg.value(db)
124 else {
125 return false;
126 };
127 let ast::Expr::Path(trait_generic_arg) = trait_generic_arg.expr(db) else {
128 return false;
129 };
130 let Some([ast::PathSegment::Simple(trait_generic_arg)]) =
131 trait_generic_arg.segments(db).elements(db).collect_array()
132 else {
133 return false;
134 };
135 let trait_generic_arg_name = trait_generic_arg.ident(db);
136 let Some(impl_generic_param_name) = impl_generic_param.name(db) else {
137 return false;
138 };
139 trait_generic_arg_name.text(db) == impl_generic_param_name.text(db)
140 },
141 )
142 } else {
143 false
144 }
145 }
146 ast::PathSegment::Simple(segment) => {
147 builder.add_node(segment.ident(db).as_syntax_node());
148 matches!(impl_generic_params, ast::OptionWrappedGenericParamList::Empty(_))
149 }
150 ast::PathSegment::Missing(_) => {
151 return PluginResult {
152 code: None,
153 diagnostics: vec![PluginDiagnostic::error(
154 trait_ast.stable_ptr(db),
155 "Generated trait can not have a missing path segment.".to_string(),
156 )],
157 remove_original_item: false,
158 };
159 }
160 };
161 if !generic_params_match {
162 diagnostics.push(PluginDiagnostic::error(
163 trait_ast.stable_ptr(db),
164 "Generated trait must have generic args matching the impl's generic params."
165 .to_string(),
166 ));
167 }
168 match impl_ast.body(db) {
169 ast::MaybeImplBody::None(semicolon) => {
170 builder.add_modified(RewriteNode::from_ast_trimmed(&impl_generic_params));
171 builder.add_node(semicolon.as_syntax_node());
172 }
173 ast::MaybeImplBody::Some(body) => {
174 builder.add_node(impl_generic_params.as_syntax_node());
175 builder.add_node(body.lbrace(db).as_syntax_node());
176 for item in body.iter_items(db) {
177 match item {
178 ast::ImplItem::Function(function_item) => {
179 let decl = function_item.declaration(db);
180 let signature = decl.signature(db);
181 builder.add_node(function_item.attributes(db).as_syntax_node());
182 builder.add_node(decl.optional_const(db).as_syntax_node());
183 builder.add_node(decl.function_kw(db).as_syntax_node());
184 builder.add_node(decl.name(db).as_syntax_node());
185 builder.add_node(decl.generic_params(db).as_syntax_node());
186 builder.add_node(signature.lparen(db).as_syntax_node());
187 for node in signature.parameters(db).node.get_children(db).iter() {
188 if let Some(param) = ast::Param::cast(db, *node) {
189 for modifier in param.modifiers(db).elements(db) {
190 if !matches!(modifier, ast::Modifier::Mut(_)) {
192 builder.add_node(modifier.as_syntax_node());
193 }
194 }
195 builder.add_node(param.name(db).as_syntax_node());
196 builder.add_node(param.type_clause(db).as_syntax_node());
197 } else {
198 builder.add_node(*node);
199 }
200 }
201 let rparen = signature.rparen(db);
202 let ret_ty = signature.ret_ty(db);
203 let implicits_clause = signature.implicits_clause(db);
204 let optional_no_panic = signature.optional_no_panic(db);
205 let last_node = if matches!(
206 optional_no_panic,
207 ast::OptionTerminalNoPanic::TerminalNoPanic(_)
208 ) {
209 builder.add_node(rparen.as_syntax_node());
210 builder.add_node(ret_ty.as_syntax_node());
211 builder.add_node(implicits_clause.as_syntax_node());
212 optional_no_panic.as_syntax_node()
213 } else if matches!(
214 implicits_clause,
215 ast::OptionImplicitsClause::ImplicitsClause(_)
216 ) {
217 builder.add_node(rparen.as_syntax_node());
218 builder.add_node(ret_ty.as_syntax_node());
219 implicits_clause.as_syntax_node()
220 } else if matches!(ret_ty, ast::OptionReturnTypeClause::ReturnTypeClause(_))
221 {
222 builder.add_node(rparen.as_syntax_node());
223 ret_ty.as_syntax_node()
224 } else {
225 rparen.as_syntax_node()
226 };
227 builder.add_modified(RewriteNode::Trimmed {
228 node: last_node,
229 trim_left: false,
230 trim_right: true,
231 });
232 builder.add_str(";\n");
233 }
234 ast::ImplItem::Type(type_item) => {
235 builder.add_node(type_item.attributes(db).as_syntax_node());
236 builder.add_node(type_item.type_kw(db).as_syntax_node());
237 builder.add_modified(RewriteNode::Trimmed {
238 node: type_item.name(db).as_syntax_node(),
239 trim_left: false,
240 trim_right: true,
241 });
242 builder.add_str(";\n");
243 }
244 ast::ImplItem::Constant(const_item) => {
245 builder.add_node(const_item.attributes(db).as_syntax_node());
246 builder.add_node(const_item.const_kw(db).as_syntax_node());
247 builder.add_node(const_item.name(db).as_syntax_node());
248 builder.add_modified(RewriteNode::Trimmed {
249 node: const_item.type_clause(db).as_syntax_node(),
250 trim_left: false,
251 trim_right: true,
252 });
253 builder.add_str(";\n");
254 }
255 _ => diagnostics.push(PluginDiagnostic::error(
256 item.stable_ptr(db),
257 "Only functions, types, and constants are supported in #[generate_trait]."
258 .to_string(),
259 )),
260 }
261 }
262 builder.add_node(body.rbrace(db).as_syntax_node());
263 }
264 }
265 let (content, code_mappings) = builder.build();
266 PluginResult {
267 code: Some(PluginGeneratedFile {
268 name: "generate_trait".into(),
269 content,
270 code_mappings,
271 aux_data: None,
272 diagnostics_note: Default::default(),
273 is_unhygienic: false,
274 }),
275 diagnostics,
276 remove_original_item: false,
277 }
278}