macro_field_utils/
variants.rs

1use darling::ast::Fields;
2use proc_macro2::TokenStream;
3use quote::{quote, ToTokens};
4
5use crate::{FieldInfo, FieldsCollector, FieldsHelper};
6
/// Utility macro to implement [VariantInfo] for a variant receiver type.
///
/// `variant_info!(Variant, Field)` implements `VariantInfo<Field>` for `Variant` by returning
/// references to its own fields, so `Variant` must declare:
/// - `ident: syn::Ident`
/// - `discriminant: Option<syn::Expr>`
/// - `fields: darling::ast::Fields<Field>`
///
/// The expansion refers to `syn::...` and `darling::...` directly, so those paths must resolve
/// at the call site. A trailing comma after the last argument is accepted.
///
/// ``` ignore
/// variant_info!(VariantReceiver, FieldReceiver);
/// ```
#[macro_export]
macro_rules! variant_info {
    ($v:path, $f:path $(,)?) => {
        impl $crate::VariantInfo<$f> for $v {
            fn ident(&self) -> &syn::Ident {
                &self.ident
            }

            fn discriminant(&self) -> &Option<syn::Expr> {
                &self.discriminant
            }

            fn fields(&self) -> &darling::ast::Fields<$f> {
                &self.fields
            }
        }
    };
}
26
27/// Trait to retrieve variants' properties
28pub trait VariantInfo<F: FieldInfo> {
29    /// Retrieves the identifier of the passed-in variant
30    fn ident(&self) -> &syn::Ident;
31    /// Retrieves the variant discriminant (if any). For a variant such as `Example = 2`, the `2`
32    fn discriminant(&self) -> &Option<syn::Expr>;
33    /// Retrieves the fields associated with the variant
34    fn fields(&self) -> &Fields<F>;
35    /// Retrieves the discriminant expr, if any (`= 3`)
36    fn discriminant_expr(&self) -> TokenStream {
37        self.discriminant()
38            .as_ref()
39            .map(|d| {
40                let expr = d.to_token_stream();
41                quote!(= #expr)
42            })
43            .unwrap_or_default()
44    }
45}
46
47/// Utility struct to work with enum's variants
48pub struct VariantsHelper<'v, V: VariantInfo<F>, F: FieldInfo> {
49    variants: Vec<&'v V>,
50    variant_filter: Option<Box<dyn Fn(&V) -> bool + 'v>>,
51    variant_attributes: Option<Box<dyn Fn(&V) -> Option<TokenStream> + 'v>>,
52    include_extra_variants: Vec<(TokenStream, Option<TokenStream>)>,
53    ignore_all_extra_variants: Option<TokenStream>,
54    include_wrapper: bool,
55    left_collector: Option<Box<dyn Fn(&V, FieldsHelper<'_, F>) -> TokenStream + 'v>>,
56    right_collector: Option<Box<dyn Fn(&V, FieldsHelper<'_, F>) -> TokenStream + 'v>>,
57}
58
59impl<'v, V: VariantInfo<F>, F: FieldInfo> VariantsHelper<'v, V, F> {
60    /// Builds a new [VariantsHelper]
61    pub fn new(variants: &'v [V]) -> Self {
62        Self {
63            variants: variants.iter().collect(),
64            variant_filter: None,
65            variant_attributes: None,
66            include_extra_variants: Vec::new(),
67            ignore_all_extra_variants: None,
68            include_wrapper: true,
69            left_collector: None,
70            right_collector: None,
71        }
72    }
73
74    /// Remove all variants `v` for which `predicate(&v)` returns `false`.
75    /// This method operates in place, visiting each element exactly once in the
76    /// original order, and preserves the order of the retained elements.
77    pub fn filtering_variants<P>(mut self, predicate: P) -> Self
78    where
79        P: Fn(&V) -> bool + 'v,
80    {
81        self.variant_filter = Some(Box::new(predicate));
82        self
83    }
84
85    /// Adds an arbitrary number of attributes to each variant.
86    pub fn with_variant_attributes<P>(mut self, predicate: P) -> Self
87    where
88        P: Fn(&V) -> Option<TokenStream> + 'v,
89    {
90        self.variant_attributes = Some(Box::new(predicate));
91        self
92    }
93
94    /// Include extra variants by including the given left and right sides (`#left` if right is [None] or `#left =>
95    /// #right`) at the end.
96    pub fn include_extra_variants(
97        mut self,
98        include_extra_variants: impl IntoIterator<Item = (impl ToTokens, Option<impl ToTokens>)>,
99    ) -> Self {
100        let mut include_extra_variants = include_extra_variants
101            .into_iter()
102            .map(|(l, r)| (l.to_token_stream(), r.map(|t| t.to_token_stream())))
103            .collect::<Vec<_>>();
104        self.include_extra_variants.append(&mut include_extra_variants);
105        self
106    }
107
108    /// Wether to ignore all extra variants with the given right side `_ => #right`, if [Some].
109    ///
110    /// It should be used only when collecting a match.
111    pub fn ignore_all_extra_variants(mut self, right_side: Option<TokenStream>) -> Self {
112        self.ignore_all_extra_variants = right_side;
113        self
114    }
115
116    /// Wether to include the wrapper (curly braces), defaults to `true`.
117    pub fn include_wrapper(mut self, include_wrapper: bool) -> Self {
118        self.include_wrapper = include_wrapper;
119        self
120    }
121
122    /// Specifies the left collector.
123    ///
124    /// Defaults to `VariantsCollector::variant_definition`
125    pub fn left_collector<C>(mut self, left_collector: C) -> Self
126    where
127        C: Fn(&V, FieldsHelper<'_, F>) -> TokenStream + 'v,
128    {
129        self.left_collector = Some(Box::new(left_collector));
130        self
131    }
132
133    /// Specifies the right collector.
134    ///
135    /// Defaults to `VariantsCollector::empty`
136    pub fn right_collector<C>(mut self, right_collector: C) -> Self
137    where
138        C: Fn(&V, FieldsHelper<'_, F>) -> TokenStream + 'v,
139    {
140        self.right_collector = Some(Box::new(right_collector));
141        self
142    }
143
144    /// Collects the fields.
145    ///
146    /// # Examples
147    /// If using the default collectors:
148    /// ``` ignore
149    /// {
150    ///     Variant1,
151    ///     Variant2 = 5,
152    ///     Variant3 {
153    ///         field_1: String,
154    ///         field_2: i32,
155    ///     },
156    /// }
157    /// ```
158    ///
159    /// If using `VariantsCollector::variant_fields_collector(quote!(SelfTy))` left collector and
160    /// `VariantsCollector::variant_fields_collector(quote!(OtherTy)` right
161    /// collector:
162    /// ``` ignore
163    /// {
164    ///     SelfTy::Variant1 => OtherTy::Variant1,
165    ///     SelfTy::Variant2 {
166    ///         field_1: field_1,
167    ///         field_2: field_2,
168    ///     } => OtherTy::Variant2 {
169    ///         field_1: field_1,
170    ///         field_2: field_2,
171    ///     },
172    /// }
173    /// ```
174    pub fn collect(self) -> TokenStream {
175        let left_collector = self
176            .left_collector
177            .unwrap_or_else(|| Box::new(VariantsCollector::empty));
178
179        let mut variants = self
180            .variants
181            .into_iter()
182            .filter(|&v| {
183                if let Some(variant_filter_fn) = &self.variant_filter {
184                    variant_filter_fn(v)
185                } else {
186                    true
187                }
188            })
189            .map(|v| {
190                let attrs = if let Some(attrs_fn) = &self.variant_attributes {
191                    attrs_fn(v)
192                } else {
193                    None
194                }
195                .unwrap_or_default();
196
197                let left = left_collector(v, FieldsHelper::new(v.fields()));
198                let right = if let Some(right_collector) = &self.right_collector {
199                    let right = right_collector(v, FieldsHelper::new(v.fields()));
200                    quote!(=> #right)
201                } else {
202                    TokenStream::default()
203                };
204
205                quote!(
206                    #attrs
207                    #left #right
208                )
209            })
210            .collect::<Vec<_>>();
211
212        for (left, right) in self.include_extra_variants {
213            let right = right.map(|r| quote!(=> #r)).unwrap_or_default();
214            variants.push(quote!(#left #right));
215        }
216
217        if let Some(right) = self.ignore_all_extra_variants {
218            variants.push(quote!(_ => #right));
219        }
220
221        if self.include_wrapper {
222            quote!(
223                {
224                    #( #variants ),*
225                }
226            )
227        } else {
228            quote!( #( #variants ),* )
229        }
230    }
231}
232
233/// Utility struct with common collectors for [VariantsHelper]
234pub struct VariantsCollector;
235impl VariantsCollector {
236    /// Empty collector
237    pub fn empty<V, F>(_v: &V, _f: FieldsHelper<'_, F>) -> TokenStream
238    where
239        V: VariantInfo<F>,
240        F: FieldInfo,
241    {
242        TokenStream::default()
243    }
244
245    /// Variant definition collector
246    ///
247    /// ## Examples
248    /// ``` ignore
249    /// {
250    ///     Variant1,
251    ///     Variant2 = 5,
252    ///     Variant3 {
253    ///         field_1: String,
254    ///         field_2: i32,
255    ///     },
256    /// }
257    /// ```
258    pub fn variant_definition<V, F>(v: &V, fields: FieldsHelper<'_, F>) -> TokenStream
259    where
260        V: VariantInfo<F>,
261        F: FieldInfo,
262    {
263        let ident = v.ident();
264        let dis = v.discriminant_expr();
265        let fields_expr = fields.collect();
266        quote!(
267            #ident #dis #fields_expr
268        )
269    }
270
271    /// Collects the fields of a variant
272    ///
273    /// ## Examples
274    /// An unit variant:
275    /// ``` ignore
276    /// Ty::Variant1
277    /// ```
278    /// A variant with unnamed fields:
279    /// ``` ignore
280    /// Ty::Variant2 (v_0, v_1)
281    /// ```
282    /// A variant with named fields:
283    /// ``` ignore
284    /// Ty::Variant3 {
285    ///     field_1: field_1,
286    ///     field_2: field_2,
287    /// }
288    /// ```
289    pub fn variant_fields_collector<V, F>(ty: impl ToTokens) -> impl Fn(&V, FieldsHelper<'_, F>) -> TokenStream
290    where
291        V: VariantInfo<F>,
292        F: FieldInfo,
293    {
294        move |v, fields| {
295            let ident = v.ident();
296            let right = fields.right_collector(FieldsCollector::ident).collect();
297            quote!( #ty::#ident #right )
298        }
299    }
300}
301
// Unit tests for `VariantsHelper`. Each scenario renders the variants of one sample enum and
// compares the result with an expected token stream via their string forms.
#[cfg(test)]
mod tests {
    #![allow(clippy::manual_unwrap_or_default)] // darling macro

    use darling::{FromField, FromVariant};
    use quote::quote;
    use syn::Result;

    use super::*;
    use crate::field_info;

    // Field receiver parsed by darling; `#[tst(skip)]` on a field sets `skip`.
    #[derive(FromField, Clone)]
    #[darling(attributes(tst))]
    struct FieldReceiver {
        /// The identifier of the passed-in field, or [None] for tuple fields
        ident: Option<syn::Ident>,
        /// The visibility of the passed-in field
        vis: syn::Visibility,
        /// The type of the passed-in field
        ty: syn::Type,

        /// Whether to skip
        #[darling(default)]
        pub skip: bool,
    }
    field_info!(FieldReceiver);

    // Variant receiver parsed by darling; `#[tst(skip)]` on a variant sets `skip`.
    #[derive(FromVariant, Clone)]
    #[darling(attributes(tst))]
    struct VariantReceiver {
        /// The identifier of the passed-in variant
        ident: syn::Ident,
        /// For a variant such as `Example = 2`, the `2`
        discriminant: Option<syn::Expr>,
        /// The fields associated with the variant
        fields: Fields<FieldReceiver>,

        /// Whether to skip
        #[darling(default)]
        skip: bool,
    }
    variant_info!(VariantReceiver, FieldReceiver);

    #[test]
    fn test_variant_helper() -> Result<()> {
        // Sample enum: one unit variant, one skipped unit variant, one tuple variant and one
        // struct variant containing a skipped field.
        let input: syn::DeriveInput = syn::parse2(quote! {
            pub enum MyEnum {
                Variant1,
                #[tst(skip)]
                Variant2,
                Variant3 (String, i64),
                Variant4 {
                    field_1: String,
                    #[tst(skip)]
                    field_2: i32,
                    field_3: bool,
                }
            }
        })?;
        // The input is an enum, so `take_enum` yields its variants.
        let variants = darling::ast::Data::<VariantReceiver, FieldReceiver>::try_from(&input.data)?
            .take_enum()
            .unwrap();

        // Scenario 1: drop skipped variants and skipped fields; render variant definitions.
        let collected = VariantsHelper::new(&variants)
            .filtering_variants(|v| !v.skip)
            .left_collector(|v, fields| {
                let ident = &v.ident;
                let dis = v.discriminant_expr();
                let fields_expr = fields.filtering(|_ix, f| !f.skip).collect();
                quote!(
                    #ident #dis #fields_expr
                )
            })
            .collect();
        #[rustfmt::skip]
        let expected = quote!({
            Variant1,
            Variant3 (String, i64),
            Variant4 {
                field_1: String,
                field_3: bool
            }
        });

        assert_eq!(collected.to_string(), expected.to_string());

        // Scenario 2: keep every variant, tag skipped ones with `#[skipped]`, drop skipped fields
        // and append `.. Default::default()` to variants with fields.
        let collected = VariantsHelper::new(&variants)
            .with_variant_attributes(|v| if v.skip { Some(quote!(#[skipped])) } else { None })
            .left_collector(|v, fields| {
                let ident = &v.ident;
                let dis = v.discriminant_expr();
                let fields_expr = fields.filtering(|_ix, f| !f.skip).include_all_default(true).collect();
                quote!(
                    #ident #dis #fields_expr
                )
            })
            .collect();
        #[rustfmt::skip]
        let expected = quote!({
            Variant1,
            #[skipped]
            Variant2,
            Variant3 (String, i64, .. Default::default()),
            Variant4 {
                field_1: String,
                field_3: bool,
                .. Default::default()
            }
        });

        assert_eq!(collected.to_string(), expected.to_string());

        // Scenario 3: drop skipped variants, keep all fields but tag skipped ones with
        // `#[skipped]`, and append a `..` rest marker to variants with fields.
        let collected = VariantsHelper::new(&variants)
            .filtering_variants(|v| !v.skip)
            .left_collector(|v, fields| {
                let ident = &v.ident;
                let dis = v.discriminant_expr();
                let fields_expr = fields
                    .with_attributes(|_ix, f| if f.skip { Some(quote!(#[skipped])) } else { None })
                    .ignore_all_extra(true)
                    .collect();
                quote!(
                    #ident #dis #fields_expr
                )
            })
            .collect();
        #[rustfmt::skip]
        let expected = quote!({
            Variant1,
            Variant3 (String, i64, ..),
            Variant4 {
                field_1: String,
                #[skipped]
                field_2: i32,
                field_3: bool,
                ..
            }
        });

        assert_eq!(collected.to_string(), expected.to_string());

        // Scenario 4: `left => right` entries mapping `MyEnum1` variants to `MyEnum2` variants,
        // binding fields by identifier (tuple fields become `v_0`, `v_1`, ...).
        let collected = VariantsHelper::new(&variants)
            .filtering_variants(|v| !v.skip)
            .left_collector(|v, fields| {
                let ident = &v.ident;
                let fields_expr = fields
                    .filtering(|_ix, f| !f.skip)
                    .right_collector(FieldsCollector::ident)
                    .collect();
                quote!( MyEnum1::#ident #fields_expr )
            })
            .right_collector(|v, fields| {
                let ident = &v.ident;
                let fields_expr = fields
                    .filtering(|_ix, f| !f.skip)
                    .right_collector(FieldsCollector::ident)
                    .collect();
                quote!( MyEnum2::#ident #fields_expr )
            })
            .collect();
        #[rustfmt::skip]
        let expected = quote!({
            MyEnum1::Variant1 => MyEnum2::Variant1,
            MyEnum1::Variant3 (v_0, v_1) => MyEnum2::Variant3 (v_0, v_1),
            MyEnum1::Variant4 {
                field_1: field_1,
                field_3: field_3
            } => MyEnum2::Variant4 {
                field_1: field_1,
                field_3: field_3
            }
        });

        assert_eq!(collected.to_string(), expected.to_string());

        // Scenario 5: no filtering at all; the left side ends with a `..` rest marker and the
        // right side with `..Default::default()`.
        let collected = VariantsHelper::new(&variants)
            .left_collector(|v, fields| {
                let ident = &v.ident;
                let fields_expr = fields
                    .ignore_all_extra(true)
                    .right_collector(FieldsCollector::ident)
                    .collect();
                quote!( MyEnum1::#ident #fields_expr )
            })
            .right_collector(|v, fields| {
                let ident = &v.ident;
                let fields_expr = fields
                    .include_all_default(true)
                    .right_collector(FieldsCollector::ident)
                    .collect();
                quote!( MyEnum2::#ident #fields_expr )
            })
            .collect();
        #[rustfmt::skip]
        let expected = quote!({
            MyEnum1::Variant1 => MyEnum2::Variant1,
            MyEnum1::Variant2 => MyEnum2::Variant2,
            MyEnum1::Variant3 (v_0, v_1, ..) => MyEnum2::Variant3 (v_0, v_1, ..Default::default()),
            MyEnum1::Variant4 {
                field_1: field_1,
                field_2: field_2,
                field_3: field_3,
                ..
            } => MyEnum2::Variant4 {
                field_1: field_1,
                field_2: field_2,
                field_3: field_3,
                ..Default::default()
            }
        });

        assert_eq!(collected.to_string(), expected.to_string());

        Ok(())
    }
}