enum_conversion_derive/
parse_enum.rs

1use std::collections::HashMap;
2use std::fmt::Write as _;
3
4use quote::ToTokens;
5use syn::__private::Span;
6use syn::{Data, GenericParam, Lifetime, LifetimeDef, Token};
7
8use super::*;
9use crate::parse_attributes::{parse_attrs, VariantInfo};
10
11/// This functions determines the name of the enum with generic
12/// params attached.
13///
14/// # Example
15/// ```
16/// use std::fmt::Debug;
17/// enum Enum<'a, T: 'a + Debug, const X: usize> {
18///     F1(&'a T),
19///     F2([T; X])
20/// }
21/// ```
22/// This function should return `(Enum<'a, T, X>, vec!['a])`
23pub fn fetch_name_with_generic_params(ast: &DeriveInput) -> (String, Vec<String>) {
24    let mut param_string = String::new();
25    let mut lifetimes = vec![];
26    for param in ast.generics.params.iter() {
27        let next = match param {
28            syn::GenericParam::Type(type_) => type_.ident.to_token_stream(),
29            syn::GenericParam::Lifetime(life_def) => {
30                let lifetime = life_def.lifetime.to_token_stream();
31                lifetimes.push(lifetime.to_string());
32                lifetime
33            }
34            syn::GenericParam::Const(constant) => constant.ident.to_token_stream(),
35        };
36        _ = write!(param_string, "{},", next);
37    }
38    param_string.pop();
39    if !param_string.is_empty() {
40        (format!("{}<{}>", ast.ident, param_string), lifetimes)
41    } else {
42        (ast.ident.to_string(), lifetimes)
43    }
44}
45
/// Rendered generic parameters and where clause to splice into the trait
/// implementations generated for the decorated type.
///
/// Each field holds the `to_string` of a token stream, so tokens are
/// separated by spaces (e.g. `"< 'a , T : Debug >"`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ImplGenerics {
    /// The generic parameter list (with bounds) inherited from the
    /// decorated type.
    pub impl_generics: String,
    /// `impl_generics` plus one extra lifetime with the appropriate bounds,
    /// needed by impls that return references.
    pub impl_generics_ref: String,
    /// The where clause (including the `where` keyword) of the decorated
    /// type, or empty when it has none.
    pub where_clause: String,
}
60
61/// This fetches the generics for impl blocks on the traits
62/// and the where clause.
63///
64/// # Example:
65/// ```
66/// use std::fmt::Debug;
67/// pub enum Enum<'a, T: Debug, U>
68///where
69///     U: Into<T>
70/// {
71///     F1(&'a T),
72///     F2(U)
73/// }
74/// ```
75/// returns
76/// `
77/// ImplGenerics {
78///     impl_generics: "<T: Debug, U>",
79///     impl_generics_ref: "<'a, 'enum_conv: 'a, T: Debug, U>",
80///     where_clause: "where U: Into<T>",
81/// }
82/// `
83///
84///
85/// For traits the return references, the lifetime of the reference must be bound
86/// by lifetimes in the definition of the enum.
87pub fn fetch_impl_generics(ast: &DeriveInput, lifetime: &str, bounds: &[String]) -> ImplGenerics {
88    let mut generics = ast.generics.clone();
89    let mut generics_ref = generics.clone();
90    generics_ref
91        .params
92        .push(GenericParam::Lifetime(bound_lifetime(lifetime, bounds)));
93
94    let where_clause = generics
95        .where_clause
96        .take()
97        .map(|w| w.to_token_stream().to_string());
98    ImplGenerics {
99        impl_generics: generics.to_token_stream().to_string(),
100        impl_generics_ref: generics_ref.to_token_stream().to_string(),
101        where_clause: where_clause.unwrap_or_default(),
102    }
103}
104
105/// Given a lifetime and a list of other lifetimes, creates
106/// the bound that states the input lifetime cannot outlive
107/// the lifetimes in the list.
108pub fn bound_lifetime(lifetime: &str, bounds: &[String]) -> syn::LifetimeDef {
109    let mut lifetime_def = LifetimeDef::new(Lifetime::new(lifetime, Span::call_site()));
110    lifetime_def.colon_token = if bounds.is_empty() {
111        Some(Token![:](Span::call_site()))
112    } else {
113        None
114    };
115    lifetime_def.bounds = bounds
116        .iter()
117        .map(|lifetime| Lifetime::new(lifetime, Span::call_site()))
118        .collect();
119    lifetime_def
120}
121
122/// Fetches the name of each variant in the enum and
123/// maps it to a string representation of its type.
124///
125/// Also performs validation for unsupported enum types.
126/// These include:
127///  * Enums with multiple variants of the same type.
128///  * Enums with variants with multiple or named fields.
129///  * Enums with unit variants.
130///
131/// Will panic if the input type is not an enum.
132pub(crate) fn fetch_fields_from_enum(ast: &mut DeriveInput) -> HashMap<String, VariantInfo> {
133    let derive_globally = parse_attrs(&mut ast.attrs);
134    if let Data::Enum(data) = &mut ast.data {
135        let mut num_fields: usize = 0;
136        let mut types = data
137            .variants
138            .iter_mut()
139            .map(|var| match &var.fields {
140                syn::Fields::Unnamed(field_) => {
141                    if field_.unnamed.len() != 1 {
142                        panic!(
143                            "Can only derive for enums whose types do \
144                             not contain multiple fields."
145                        );
146                    }
147                    let var_ty = field_
148                        .unnamed
149                        .iter()
150                        .next()
151                        .unwrap()
152                        .ty
153                        .to_token_stream()
154                        .to_string();
155                    let var_name = var.ident.to_token_stream().to_string();
156                    let var_info = VariantInfo {
157                        ty: var_ty,
158                        try_from: parse_attrs(&mut var.attrs) || derive_globally,
159                    };
160                    num_fields += 1;
161                    (var_info, var_name)
162                }
163                syn::Fields::Named(_) => {
164                    panic!("Can only derive for enums whose types do not have named fields.")
165                }
166                syn::Fields::Unit => {
167                    panic!("Can only derive for enums who don't contain unit types as variants.")
168                }
169            })
170            .collect::<HashMap<VariantInfo, String>>();
171        let types: HashMap<String, VariantInfo> = types.drain().map(|(k, v)| (v, k)).collect();
172        if types.keys().len() != num_fields {
173            panic!("Cannot derive for enums with more than one field with the same type.")
174        }
175        types
176    } else {
177        panic!("Can only derive for enums.")
178    }
179}
180
/// Builds the source of a module `enum___conversion___<name>` containing one
/// empty marker enum per variant of the decorated enum.
///
/// The markers identify the types in the enum and disambiguate generic
/// parameters. Only the keys (variant names) of `types` are read. They are
/// emitted in sorted order, because `HashMap` iteration order is unspecified
/// and the generated code should be identical on every build.
pub(crate) fn create_marker_enums<V>(name: &str, types: &HashMap<String, V>) -> String {
    let mut variants: Vec<&str> = types.keys().map(String::as_str).collect();
    variants.sort_unstable();

    let mut piece = format!("#[allow(non_snake_case)]\n mod enum___conversion___{}{{ ", name);
    for variant in variants {
        // Writing into a `String` cannot fail.
        _ = write!(piece, "pub(crate) enum {}{{}}", variant);
    }
    piece.push('}');
    piece
}
198
/// Returns the path `enum___conversion___<name>::<field>` of the marker enum
/// generated for variant `field` of the enum called `name`.
pub fn get_marker(name: &str, field: &str) -> String {
    let module = format!("enum___conversion___{}", name);
    [module.as_str(), field].join("::")
}
204
#[cfg(test)]
mod test_parsers {
    use super::*;

    /// Enum exercising generics, a where clause, a foreign attribute and
    /// every supported kind of field type.
    const ENUM: &str = r#"
            enum Enum<'a, 'b, T, U: Debug>
            where T: Into<U>, U: 'a
            {
                #[help]
                Array([u8; 20]),
                BareFn(fn(&'a usize) -> bool),
                Macro(typey!()),
                Path(<Vec<&'a mut T> as IntoIterator>::Item),
                Ptr(*const u8),
                Tuple((&'b i64, bool)),
                Slice([u8]),
                Trait(Box<&dyn Into<U>>),
            }
        "#;

    /// Parses `source` as a derive input (failing with `msg`) and runs
    /// `fetch_fields_from_enum` on it.
    fn fields_of(source: &str, msg: &str) -> HashMap<String, VariantInfo> {
        let mut input: DeriveInput = syn::parse_str(source).expect(msg);
        fetch_fields_from_enum(&mut input)
    }

    /// Every supported field type is recorded under its variant's name,
    /// and the unrelated `#[help]` attribute does not get in the way.
    #[test]
    fn test_parse_fields_and_types() {
        let parsed = fields_of(ENUM, "Test failed.");
        let expected: HashMap<String, VariantInfo> = vec![
            ("Array", "[u8 ; 20]"),
            ("BareFn", "fn (& 'a usize) -> bool"),
            ("Macro", "typey ! ()"),
            ("Path", "< Vec < & 'a mut T > as IntoIterator > :: Item"),
            ("Ptr", "* const u8"),
            ("Slice", "[u8]"),
            ("Trait", "Box < & dyn Into < U > >"),
            ("Tuple", "(& 'b i64 , bool)"),
        ]
        .into_iter()
        .map(|(variant, ty)| -> (String, VariantInfo) { (variant.to_string(), ty.into()) })
        .collect();
        assert_eq!(expected, parsed);
    }

    /// `#[DeriveTryFrom]` on the enum opts every variant in to `TryFrom`.
    #[test]
    fn test_global_try_from_config() {
        let parsed = fields_of(
            r#"
            #[DeriveTryFrom]
            enum Enum {
                F1(i64),
                F2(bool),
            }
        "#,
            "Test failed",
        );
        let expected: HashMap<String, VariantInfo> = vec![("F1", "i64"), ("F2", "bool")]
            .into_iter()
            .map(|(variant, ty)| {
                let info = VariantInfo {
                    ty: ty.to_string(),
                    try_from: true,
                };
                (variant.to_string(), info)
            })
            .collect();
        assert_eq!(parsed, expected);
    }

    /// `#[DeriveTryFrom]` on one variant opts in only that variant.
    #[test]
    fn test_try_from_local_config() {
        let parsed = fields_of(
            r#"
            enum Enum {
                F1(i64),
                #[DeriveTryFrom]
                F2(bool),
            }
        "#,
            "Test failed",
        );
        let mut expected: HashMap<String, VariantInfo> = HashMap::new();
        expected.insert("F1".to_string(), "i64".into());
        expected.insert(
            "F2".to_string(),
            VariantInfo {
                ty: "bool".to_string(),
                try_from: true,
            },
        );
        assert_eq!(parsed, expected);
    }

    /// The impl generics keep the enum's bounds, the reference variant adds
    /// the conversion lifetime bounded by the enum's lifetimes, and the
    /// where clause is rendered separately.
    #[test]
    fn test_generics_and_bounds() {
        let input: DeriveInput = syn::parse_str(ENUM).expect("Test failed.");
        let lifetimes = fetch_name_with_generic_params(&input).1;
        let generics = fetch_impl_generics(&input, ENUM_CONV_LIFETIME, &lifetimes);
        assert_eq!(generics.impl_generics, "< 'a , 'b , T , U : Debug >");
        assert_eq!(
            generics.impl_generics_ref,
            "< 'a , 'b , 'enum_conv : 'a + 'b , T , U : Debug , >"
        );
        assert_eq!(generics.where_clause, "where T : Into < U > , U : 'a");
    }

    /// The rendered name drops bounds; lifetimes are reported in order.
    #[test]
    fn test_get_name_with_generics() {
        let input: DeriveInput = syn::parse_str(ENUM).expect("Test failed.");
        let (name, lifetimes) = fetch_name_with_generic_params(&input);
        assert_eq!(name, "Enum<'a,'b,T,U>");
        assert_eq!(lifetimes, ["'a", "'b"]);
    }

    #[test]
    #[should_panic(expected = "Can only derive for enums.")]
    fn test_panic_on_struct() {
        let _ = fields_of("pub struct Struct;", "Test failed");
    }

    #[test]
    #[should_panic(expected = "Can only derive for enums whose types do not have named fields.")]
    fn test_panic_on_field_with_named_types() {
        let _ = fields_of(
            r#"
            enum Enum {
                F{a: i64},
            }
        "#,
            "Test failed",
        );
    }

    #[test]
    #[should_panic(
        expected = "Cannot derive for enums with more than one field with the same type."
    )]
    fn test_multiple_fields_same_type() {
        let _ = fields_of(
            r#"
        enum Enum {
            F1(u64),
            F2(u64),
        }
        "#,
            "Test failed",
        );
    }

    #[test]
    #[should_panic(
        expected = "Can only derive for enums whose types do not contain multiple fields."
    )]
    fn test_multiple_types_in_field() {
        let _ = fields_of(
            r#"
            enum Enum {
                Field(i64, bool),
            }
        "#,
            "Test failed",
        );
    }

    #[test]
    #[should_panic(
        expected = "Can only derive for enums who don't contain unit types as variants."
    )]
    fn test_unit_type() {
        let _ = fields_of(
            r#"
            enum Enum {
                Some(bool),
                None,
            }
        "#,
            "Test failed",
        );
    }

    /// An enum without variants yields no fields, so the derive is a no-op.
    #[test]
    fn test_harmless() {
        assert!(fields_of(r#"enum Enum{ }"#, "Test failed").is_empty())
    }

    /// A single-variant enum produces a module with a single marker enum.
    #[test]
    fn test_create_marker_structs() {
        let mut input: DeriveInput = syn::parse_str(
            r#"
            enum Enum {
                F1(u64)
            }
        "#,
        )
        .expect("Test failed.");
        let parsed = fetch_fields_from_enum(&mut input);
        let enum_name = input.ident.to_string();
        assert_eq!(
            create_marker_enums(&enum_name, &parsed),
            "#[allow(non_snake_case)]\n mod enum___conversion___Enum{ pub(crate) enum F1{}}"
        );
    }
}