1#![warn(clippy::all)]
5#![warn(clippy::pedantic)]
6
7use convert_case::{Case, Casing};
8use itertools::Itertools;
9use proc_macro2::{Span, TokenStream};
10use quote::{quote, ToTokens};
11use std::{
12 collections::{hash_map::Entry, HashMap},
13 iter::IntoIterator,
14};
15use syn::token::Mut;
16use syn::{
17 parse_macro_input, parse_str, spanned::Spanned, Attribute, Data, DataEnum, DataStruct,
18 DeriveInput, Error, Field, Fields, Ident, Lit, LitStr, Member, Meta, MetaList, NestedMeta,
19 Path, Result, Variant,
20};
21
22#[proc_macro_derive(Visitor, attributes(visitor))]
23pub fn derive_visitor(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24 expand_with(input, |stream| impl_visitor(stream, false))
25}
26
27#[proc_macro_derive(VisitorMut, attributes(visitor))]
28pub fn derive_visitor_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29 expand_with(input, |stream| impl_visitor(stream, true))
30}
31
32#[proc_macro_derive(Drive, attributes(drive))]
33pub fn derive_drive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
34 expand_with(input, |stream| impl_drive(stream, false))
35}
36
37#[proc_macro_derive(DriveMut, attributes(drive))]
38pub fn derive_drive_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
39 expand_with(input, |stream| impl_drive(stream, true))
40}
41
42fn expand_with(
43 input: proc_macro::TokenStream,
44 handler: impl Fn(DeriveInput) -> Result<TokenStream>,
45) -> proc_macro::TokenStream {
46 let input = parse_macro_input!(input as DeriveInput);
47 handler(input)
48 .unwrap_or_else(|error| error.to_compile_error())
49 .into()
50}
51
52fn extract_meta(attrs: Vec<Attribute>, attr_name: &str) -> Result<Option<Meta>> {
53 let macro_attrs = attrs
54 .into_iter()
55 .filter(|attr| attr.path.is_ident(attr_name))
56 .collect::<Vec<Attribute>>();
57
58 if let Some(second) = macro_attrs.get(2) {
59 return Err(Error::new_spanned(second, "duplicate attribute"));
60 }
61
62 macro_attrs.first().map(Attribute::parse_meta).transpose()
63}
64
/// Parameters of one helper attribute (e.g. `#[visitor(...)]` or
/// `#[drive(...)]`), keyed by each parameter's path.
///
/// `Default` yields an empty map, used when the attribute is absent.
#[derive(Default)]
struct Params(HashMap<Path, Meta>);
67
68impl Params {
69 fn from_attrs(attrs: Vec<Attribute>, attr_name: &str) -> Result<Self> {
70 Ok(extract_meta(attrs, attr_name)?
71 .map(|meta| {
72 if let Meta::List(meta_list) = meta {
73 Self::from_meta_list(meta_list)
74 } else {
75 Err(Error::new_spanned(meta, "invalid attribute"))
76 }
77 })
78 .transpose()?
79 .unwrap_or_default())
80 }
81
82 fn from_meta_list(meta_list: MetaList) -> Result<Self> {
83 let mut params = HashMap::new();
84 for meta in meta_list.nested {
85 if let NestedMeta::Meta(meta) = meta {
86 let path = meta.path();
87 let entry = params.entry(path.clone());
88 if matches!(entry, Entry::Occupied(_)) {
89 return Err(Error::new_spanned(path, "duplicate parameter"));
90 }
91 entry.or_insert(meta);
92 } else {
93 return Err(Error::new_spanned(meta, "invalid attribute"));
94 }
95 }
96 Ok(Self(params))
97 }
98
99 fn validate(&self, allowed_params: &[&str]) -> Result<()> {
100 for path in self.0.keys() {
101 if !allowed_params
102 .iter()
103 .any(|allowed_param| path.is_ident(allowed_param))
104 {
105 return Err(Error::new_spanned(
106 path,
107 format!(
108 "unknown parameter, supported: {}",
109 Itertools::intersperse(allowed_params.iter().copied(), ", ")
110 .collect::<String>()
111 ),
112 ));
113 }
114 }
115 Ok(())
116 }
117
118 fn param(&mut self, name: &str) -> Result<Option<Param>> {
119 self.0
120 .remove(&Ident::new(name, Span::call_site()).into())
121 .map(Param::from_meta)
122 .transpose()
123 }
124}
125
126impl Iterator for Params {
127 type Item = Result<Param>;
128 fn next(&mut self) -> Option<Self::Item> {
129 self.0
130 .keys()
131 .next()
132 .cloned()
133 .map(|path| Param::from_meta(self.0.remove(&path).unwrap()))
134 }
135}
136
/// One parameter of a helper attribute. Each variant stores the parameter's
/// path and the span of the whole parameter (see `Param::from_meta`).
enum Param {
    /// A bare path, e.g. `skip`.
    Unit(Path, Span),
    /// `name = "..."`. Non-string literals are rejected in `Param::from_meta`.
    StringLiteral(Path, Span, LitStr),
    /// `name(...)`, with its contents parsed into nested [`Params`].
    NestedParams(Path, Span, Params),
}
142
143impl Param {
144 fn from_meta(meta: Meta) -> Result<Self> {
145 let path = meta.path().clone();
146 let span = meta.span();
147 match meta {
148 Meta::Path(_) => Ok(Param::Unit(path, span)),
149 Meta::List(meta_list) => Ok(Param::NestedParams(
150 path,
151 span,
152 Params::from_meta_list(meta_list)?,
153 )),
154 Meta::NameValue(name_value) => {
155 if let Lit::Str(lit_str) = name_value.lit {
156 Ok(Param::StringLiteral(path, span, lit_str))
157 } else {
158 Err(Error::new_spanned(name_value, "invalid parameter"))
159 }
160 }
161 }
162 }
163 fn path(&self) -> &Path {
164 match self {
165 Self::Unit(path, _)
166 | Self::StringLiteral(path, _, _)
167 | Self::NestedParams(path, _, _) => path,
168 }
169 }
170
171 fn span(&self) -> Span {
172 match self {
173 Self::Unit(_, span)
174 | Self::StringLiteral(_, span, _)
175 | Self::NestedParams(_, span, _) => *span,
176 }
177 }
178
179 fn unit(self) -> Result<()> {
180 if let Self::Unit(_, _) = self {
181 Ok(())
182 } else {
183 Err(Error::new(self.span(), "invalid parameter"))
184 }
185 }
186
187 fn string_literal(self) -> Result<LitStr> {
188 if let Self::StringLiteral(_, _, lit_str) = self {
189 Ok(lit_str)
190 } else {
191 Err(Error::new(self.span(), "invalid parameter"))
192 }
193 }
194}
195
/// Visitor methods to call for one visited type.
///
/// `visitor_route` calls `enter` on `Event::Enter` and `exit` on
/// `Event::Exit`. `None` means no call is generated for that event.
struct VisitorItemParams {
    enter: Option<Ident>,
    exit: Option<Ident>,
}
200
201fn visitor_method_name_from_path(struct_path: &Path, event: &str) -> Ident {
202 let last_segment = struct_path.segments.last().unwrap();
203 Ident::new(
204 &format!(
205 "{}_{}",
206 event,
207 last_segment.ident.to_string().to_case(Case::Snake)
208 ),
209 Span::call_site(),
210 )
211}
212
213fn visitor_method_name_from_param(param: Param, path: &Path, event: &str) -> Result<Ident> {
214 match param {
215 Param::StringLiteral(_, _, lit_str) => lit_str.parse(),
216 Param::Unit(_, _) => Ok(visitor_method_name_from_path(path, event)),
217 Param::NestedParams(_, span, _) => Err(Error::new(span, "invalid parameter")),
218 }
219}
220
221fn impl_visitor(input: DeriveInput, mutable: bool) -> Result<TokenStream> {
222 let params = Params::from_attrs(input.attrs, "visitor")?
223 .map_ok(|param| {
224 let path = param.path().clone();
225
226 let item_params = match param {
227 Param::Unit(_, _) => VisitorItemParams {
228 enter: Some(visitor_method_name_from_path(&path, "enter")),
229 exit: Some(visitor_method_name_from_path(&path, "exit")),
230 },
231 Param::NestedParams(_, _, mut nested) => {
232 nested.validate(&["enter", "exit"])?;
233 VisitorItemParams {
234 enter: nested
235 .param("enter")?
236 .map(|param| visitor_method_name_from_param(param, &path, "enter"))
237 .transpose()?,
238 exit: nested
239 .param("exit")?
240 .map(|param| visitor_method_name_from_param(param, &path, "exit"))
241 .transpose()?,
242 }
243 }
244 Param::StringLiteral(_, _, lit) => {
245 return Err(Error::new_spanned(lit, "invalid attribute"));
246 }
247 };
248 Ok((path, item_params))
249 })
250 .flatten()
251 .collect::<Result<HashMap<Path, VisitorItemParams>>>()?;
252
253 match input.data {
254 Data::Enum(enum_) => {
255 for variant in enum_.variants {
256 if let Some(attr) = variant.attrs.first() {
257 return Err(Error::new_spanned(
258 attr,
259 "#[visitor] attribute can only be applied to enum or struct",
260 ));
261 }
262 for field in variant.fields {
263 if let Some(attr) = field.attrs.first() {
264 return Err(Error::new_spanned(
265 attr,
266 "#[visitor] attribute can only be applied to enum or struct",
267 ));
268 }
269 }
270 }
271 }
272 Data::Struct(struct_) => {
273 for field in struct_.fields {
274 if let Some(attr) = field.attrs.first() {
275 return Err(Error::new_spanned(
276 attr,
277 "#[visitor] attribute can only be applied to enum or struct",
278 ));
279 }
280 }
281 }
282 Data::Union(union_) => {
283 return Err(Error::new_spanned(
284 union_.union_token,
285 "unions are not supported",
286 ));
287 }
288 }
289
290 let name = input.ident;
291 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
292 let routes = params
293 .into_iter()
294 .map(|(path, item_params)| visitor_route(&path, item_params, mutable));
295 let impl_trait = Ident::new(
296 if mutable { "VisitorMut" } else { "Visitor" },
297 Span::call_site(),
298 );
299 let mut_modifier = if mutable {
300 Some(Mut(Span::call_site()))
301 } else {
302 None
303 };
304 Ok(quote! {
305 impl #impl_generics ::derive_visitor::#impl_trait for #name #ty_generics #where_clause {
306 fn visit(&mut self, item: & #mut_modifier dyn ::std::any::Any, event: ::derive_visitor::Event) {
307 #(
308 #routes
309 )*
310 }
311 }
312 })
313}
314
315fn visitor_route(path: &Path, item_params: VisitorItemParams, mutable: bool) -> TokenStream {
316 let enter = item_params.enter.map(|method_name| {
317 quote! {
318 ::derive_visitor::Event::Enter => {
319 self.#method_name(item);
320 }
321 }
322 });
323 let exit = item_params.exit.map(|method_name| {
324 quote! {
325 ::derive_visitor::Event::Exit => {
326 self.#method_name(item);
327 }
328 }
329 });
330
331 let method = Ident::new(
332 if mutable {
333 "downcast_mut"
334 } else {
335 "downcast_ref"
336 },
337 Span::call_site(),
338 );
339
340 quote! {
341 if let Some(item) = <dyn ::std::any::Any>::#method::<#path>(item) {
342 match event {
343 #enter
344 #exit
345 _ => {}
346 }
347 }
348 }
349}
350
351fn impl_drive(input: DeriveInput, mutable: bool) -> Result<TokenStream> {
352 let mut params = Params::from_attrs(input.attrs, "drive")?;
353 params.validate(&["skip"])?;
354
355 let skip_visit_self = params
356 .param("skip")?
357 .map(Param::unit)
358 .transpose()?
359 .is_some();
360
361 let name = input.ident;
362 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
363
364 let visitor = Ident::new(
365 if mutable { "VisitorMut" } else { "Visitor" },
366 Span::call_site(),
367 );
368
369 let enter_self = if skip_visit_self {
370 None
371 } else {
372 Some(quote! {
373 ::derive_visitor::#visitor::visit(visitor, self, ::derive_visitor::Event::Enter);
374 })
375 };
376
377 let exit_self = if skip_visit_self {
378 None
379 } else {
380 Some(quote! {
381 ::derive_visitor::#visitor::visit(visitor, self, ::derive_visitor::Event::Exit);
382 })
383 };
384
385 let drive_fields = match input.data {
386 Data::Struct(struct_) => drive_struct(struct_, mutable),
387 Data::Enum(enum_) => drive_enum(enum_, mutable),
388 Data::Union(union_) => {
389 return Err(Error::new_spanned(
390 union_.union_token,
391 "unions are not supported",
392 ));
393 }
394 }?;
395
396 let impl_trait = Ident::new(
397 if mutable { "DriveMut" } else { "Drive" },
398 Span::call_site(),
399 );
400 let method = Ident::new(
401 if mutable { "drive_mut" } else { "drive" },
402 Span::call_site(),
403 );
404 let mut_modifier = if mutable {
405 Some(Mut(Span::call_site()))
406 } else {
407 None
408 };
409
410 Ok(quote! {
411 impl #impl_generics ::derive_visitor::#impl_trait for #name #ty_generics #where_clause {
412 fn #method<V: ::derive_visitor::#visitor>(& #mut_modifier self, visitor: &mut V) {
413 #enter_self
414 #drive_fields
415 #exit_self
416 }
417 }
418 })
419}
420
421fn drive_struct(struct_: DataStruct, mutable: bool) -> Result<TokenStream> {
422 struct_
423 .fields
424 .into_iter()
425 .enumerate()
426 .map(|(index, field)| {
427 let member = field.ident.as_ref().map_or_else(
428 || Member::Unnamed(index.into()),
429 |ident| Member::Named(ident.clone()),
430 );
431 let mut_modifier = if mutable {
432 Some(Mut(Span::call_site()))
433 } else {
434 None
435 };
436 drive_field("e! { & #mut_modifier self.#member }, field, mutable)
437 })
438 .collect()
439}
440
441fn drive_enum(enum_: DataEnum, mutable: bool) -> Result<TokenStream> {
442 let variants = enum_
443 .variants
444 .into_iter()
445 .map(|x| drive_variant(x, mutable))
446 .collect::<Result<TokenStream>>()?;
447 Ok(quote! {
448 match self {
449 #variants
450 _ => {}
451 }
452 })
453}
454
455fn drive_variant(variant: Variant, mutable: bool) -> Result<TokenStream> {
456 let mut params = Params::from_attrs(variant.attrs, "drive")?;
457 params.validate(&["skip"])?;
458 if params.param("skip")?.map(Param::unit).is_some() {
459 return Ok(TokenStream::new());
460 }
461 let name = variant.ident;
462 let destructuring = destructure_fields(variant.fields.clone())?;
463 let fields = variant
464 .fields
465 .into_iter()
466 .enumerate()
467 .map(|(index, field)| {
468 drive_field(
469 &field
470 .ident
471 .clone()
472 .unwrap_or_else(|| Ident::new(&format!("i{}", index), Span::call_site()))
473 .to_token_stream(),
474 field,
475 mutable,
476 )
477 })
478 .collect::<Result<TokenStream>>()?;
479 Ok(quote! {
480 Self::#name#destructuring => {
481 #fields
482 }
483 })
484}
485
486fn destructure_fields(fields: Fields) -> Result<TokenStream> {
487 Ok(match fields {
488 Fields::Named(fields) => {
489 let field_list = fields
490 .named
491 .into_iter()
492 .map(|field| {
493 let mut params = Params::from_attrs(field.attrs, "drive")?;
494 let field_name = field.ident.unwrap();
495 Ok(if params.param("skip")?.map(Param::unit).is_some() {
496 quote! { #field_name: _ }
497 } else {
498 field_name.into_token_stream()
499 })
500 })
501 .collect::<Result<Vec<TokenStream>>>()?;
502 quote! {
503 { #( #field_list ),* }
504 }
505 }
506 Fields::Unnamed(fields) => {
507 let field_list = fields
508 .unnamed
509 .into_iter()
510 .enumerate()
511 .map(|(index, field)| {
512 let mut params = Params::from_attrs(field.attrs, "drive")?;
513 Ok(if params.param("skip")?.map(Param::unit).is_some() {
514 quote! { _ }
515 } else {
516 Ident::new(&format!("i{}", index), Span::call_site()).into_token_stream()
517 })
518 })
519 .collect::<Result<Vec<TokenStream>>>()?;
520 quote! {
521 ( #( #field_list ),* )
522 }
523 }
524 Fields::Unit => TokenStream::new(),
525 })
526}
527
528fn drive_field(value_expr: &TokenStream, field: Field, mutable: bool) -> Result<TokenStream> {
529 let mut params = Params::from_attrs(field.attrs, "drive")?;
530 params.validate(&["skip", "with"])?;
531
532 if params.param("skip")?.map(Param::unit).is_some() {
533 return Ok(TokenStream::new());
534 }
535
536 let drive_fn = params.param("with")?.map_or_else(
537 || {
538 parse_str(if mutable {
539 "::derive_visitor::DriveMut::drive_mut"
540 } else {
541 "::derive_visitor::Drive::drive"
542 })
543 },
544 |param| param.string_literal()?.parse::<Path>(),
545 )?;
546
547 Ok(quote! {
548 #drive_fn(#value_expr, visitor);
549 })
550}