// borsh_schema_derive_internal/enum_schema.rs

1use proc_macro2::{Span, TokenStream as TokenStream2};
2use quote::{quote, ToTokens};
3use syn::{
4    parse_quote, AttrStyle, Attribute, Field, Fields, FieldsUnnamed, Ident, ItemEnum, ItemStruct,
5    Visibility,
6};
7
8use crate::helpers::{declaration, quote_where_clause};
9
10pub fn process_enum(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2> {
11    let name = &input.ident;
12    let name_str = name.to_token_stream().to_string();
13    let generics = &input.generics;
14    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
15    // Generate function that returns the name of the type.
16    let (declaration, where_clause_additions) =
17        declaration(&name_str, &input.generics, cratename.clone());
18
19    // Generate function that returns the schema for variants.
20    // Definitions of the variants.
21    let mut variants_defs = vec![];
22    // Definitions of the anonymous structs used in variants.
23    let mut anonymous_defs = TokenStream2::new();
24    // Recursive calls to `add_definitions_recursively`.
25    let mut add_recursive_defs = TokenStream2::new();
26    for variant in &input.variants {
27        let variant_name_str = variant.ident.to_token_stream().to_string();
28        let full_variant_name_str = format!("{}{}", name_str, variant_name_str);
29        let full_variant_ident = Ident::new(full_variant_name_str.as_str(), Span::call_site());
30        let mut anonymous_struct = ItemStruct {
31            attrs: vec![],
32            vis: Visibility::Inherited,
33            struct_token: Default::default(),
34            ident: full_variant_ident.clone(),
35            generics: (*generics).clone(),
36            fields: variant.fields.clone(),
37            semi_token: Some(Default::default()),
38        };
39        let generic_params = generics
40            .type_params()
41            .fold(TokenStream2::new(), |acc, generic| {
42                let ident = &generic.ident;
43                quote! {
44                    #acc
45                    #ident ,
46                }
47            });
48        if !generic_params.is_empty() {
49            let attr = Attribute {
50                pound_token: Default::default(),
51                style: AttrStyle::Outer,
52                bracket_token: Default::default(),
53                path: parse_quote! {borsh_skip},
54                tokens: Default::default(),
55            };
56            // Whether we should convert the struct from unit struct to regular struct.
57            let mut unit_to_regular = false;
58            match &mut anonymous_struct.fields {
59                Fields::Named(named) => {
60                    named.named.push(Field {
61                        attrs: vec![attr.clone()],
62                        vis: Visibility::Inherited,
63                        ident: Some(Ident::new("borsh_schema_phantom_data", Span::call_site())),
64                        colon_token: None,
65                        ty: parse_quote! {::core::marker::PhantomData<(#generic_params)>},
66                    });
67                }
68                Fields::Unnamed(unnamed) => {
69                    unnamed.unnamed.push(Field {
70                        attrs: vec![attr.clone()],
71                        vis: Visibility::Inherited,
72                        ident: None,
73                        colon_token: None,
74                        ty: parse_quote! {::core::marker::PhantomData<(#generic_params)>},
75                    });
76                }
77                Fields::Unit => {
78                    unit_to_regular = true;
79                }
80            }
81            if unit_to_regular {
82                let mut fields = FieldsUnnamed {
83                    paren_token: Default::default(),
84                    unnamed: Default::default(),
85                };
86                fields.unnamed.push(Field {
87                    attrs: vec![attr],
88                    vis: Visibility::Inherited,
89                    ident: None,
90                    colon_token: None,
91                    ty: parse_quote! {::core::marker::PhantomData<(#generic_params)>},
92                });
93                anonymous_struct.fields = Fields::Unnamed(fields);
94            }
95        }
96        anonymous_defs.extend(quote! {
97            #[derive(#cratename::BorshSchema)]
98            #anonymous_struct
99        });
100        add_recursive_defs.extend(quote! {
101            <#full_variant_ident #ty_generics>::add_definitions_recursively(definitions);
102        });
103        variants_defs.push(quote! {
104            (#variant_name_str.to_string(), <#full_variant_ident #ty_generics>::declaration())
105        });
106    }
107
108    let type_definitions = quote! {
109        fn add_definitions_recursively(definitions: &mut #cratename::maybestd::collections::HashMap<#cratename::schema::Declaration, #cratename::schema::Definition>) {
110            #anonymous_defs
111            #add_recursive_defs
112            let variants = #cratename::maybestd::vec![#(#variants_defs),*];
113            let definition = #cratename::schema::Definition::Enum{variants};
114            Self::add_definition(Self::declaration(), definition, definitions);
115        }
116    };
117    let where_clause = quote_where_clause(where_clause, where_clause_additions);
118    Ok(quote! {
119        impl #impl_generics #cratename::BorshSchema for #name #ty_generics #where_clause {
120            fn declaration() -> #cratename::schema::Declaration {
121                #declaration
122            }
123            #type_definitions
124        }
125    })
126}
127
// Rustfmt removes commas that the expected token streams below depend on, so formatting is
// disabled for this module.
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;

    /// Compares two token streams by their string rendering; `pretty_assertions` prints a
    /// readable diff when they differ.
    fn assert_eq(expected: TokenStream2, actual: TokenStream2) {
        pretty_assertions::assert_eq!(expected.to_string(), actual.to_string())
    }

    /// Parses `input` as an enum and expands it with `process_enum`, using `borsh` as the
    /// crate name.
    fn expand(input: TokenStream2) -> TokenStream2 {
        let item_enum: ItemEnum = syn::parse2(input).unwrap();
        process_enum(&item_enum, Ident::new("borsh", Span::call_site())).unwrap()
    }

    #[test]
    fn simple_enum() {
        let actual = expand(quote! {
            enum A {
                Bacon,
                Eggs
            }
        });
        let expected = quote! {
            impl borsh::BorshSchema for A {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    #[derive(borsh :: BorshSchema)]
                    struct ABacon;
                    #[derive(borsh :: BorshSchema)]
                    struct AEggs;
                    <ABacon>::add_definitions_recursively(definitions);
                    <AEggs>::add_definitions_recursively(definitions);
                    let variants = borsh::maybestd::vec![
                        ("Bacon".to_string(), <ABacon>::declaration()),
                        ("Eggs".to_string(), <AEggs>::declaration())
                    ];
                    let definition = borsh::schema::Definition::Enum { variants };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn single_field_enum() {
        let actual = expand(quote! {
            enum A {
                Bacon,
            }
        });
        let expected = quote! {
            impl borsh::BorshSchema for A {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    #[derive(borsh :: BorshSchema)]
                    struct ABacon;
                    <ABacon>::add_definitions_recursively(definitions);
                    let variants = borsh::maybestd::vec![("Bacon".to_string(), <ABacon>::declaration())];
                    let definition = borsh::schema::Definition::Enum { variants };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn complex_enum() {
        let actual = expand(quote! {
            enum A {
                Bacon,
                Eggs,
                Salad(Tomatoes, Cucumber, Oil),
                Sausage{wrapper: Wrapper, filling: Filling},
            }
        });
        let expected = quote! {
            impl borsh::BorshSchema for A {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    #[derive(borsh :: BorshSchema)]
                    struct ABacon;
                    #[derive(borsh :: BorshSchema)]
                    struct AEggs;
                    #[derive(borsh :: BorshSchema)]
                    struct ASalad(Tomatoes, Cucumber, Oil);
                    #[derive(borsh :: BorshSchema)]
                    struct ASausage {
                        wrapper: Wrapper,
                        filling: Filling
                    }
                    <ABacon>::add_definitions_recursively(definitions);
                    <AEggs>::add_definitions_recursively(definitions);
                    <ASalad>::add_definitions_recursively(definitions);
                    <ASausage>::add_definitions_recursively(definitions);
                    let variants = borsh::maybestd::vec![
                        ("Bacon".to_string(), <ABacon>::declaration()),
                        ("Eggs".to_string(), <AEggs>::declaration()),
                        ("Salad".to_string(), <ASalad>::declaration()),
                        ("Sausage".to_string(), <ASausage>::declaration())
                    ];
                    let definition = borsh::schema::Definition::Enum { variants };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn complex_enum_generics() {
        let actual = expand(quote! {
            enum A<C, W> {
                Bacon,
                Eggs,
                Salad(Tomatoes, C, Oil),
                Sausage{wrapper: W, filling: Filling},
            }
        });
        let expected = quote! {
            impl<C, W> borsh::BorshSchema for A<C, W>
            where
                C: borsh::BorshSchema,
                W: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    let params = borsh::maybestd::vec![<C>::declaration(), <W>::declaration()];
                    format!(r#"{}<{}>"#, "A", params.join(", "))
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    #[derive(borsh :: BorshSchema)]
                    struct ABacon<C, W>(#[borsh_skip] ::core::marker::PhantomData<(C, W, )>);
                    #[derive(borsh :: BorshSchema)]
                    struct AEggs<C, W>(#[borsh_skip] ::core::marker::PhantomData<(C, W, )>);
                    #[derive(borsh :: BorshSchema)]
                    struct ASalad<C, W>(
                        Tomatoes,
                        C,
                        Oil,
                        #[borsh_skip] ::core::marker::PhantomData<(C, W, )>
                    );
                    #[derive(borsh :: BorshSchema)]
                    struct ASausage<C, W> {
                        wrapper: W,
                        filling: Filling,
                        #[borsh_skip]
                        borsh_schema_phantom_data: ::core::marker::PhantomData<(C, W, )>
                    }
                    <ABacon<C, W> >::add_definitions_recursively(definitions);
                    <AEggs<C, W> >::add_definitions_recursively(definitions);
                    <ASalad<C, W> >::add_definitions_recursively(definitions);
                    <ASausage<C, W> >::add_definitions_recursively(definitions);
                    let variants = borsh::maybestd::vec![
                        ("Bacon".to_string(), <ABacon<C, W> >::declaration()),
                        ("Eggs".to_string(), <AEggs<C, W> >::declaration()),
                        ("Salad".to_string(), <ASalad<C, W> >::declaration()),
                        ("Sausage".to_string(), <ASausage<C, W> >::declaration())
                    ];
                    let definition = borsh::schema::Definition::Enum { variants };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn trailing_comma_generics() {
        let actual = expand(quote! {
            enum Side<A, B>
            where
                A: Display + Debug,
                B: Display + Debug,
            {
                Left(A),
                Right(B),
            }
        });
        let expected = quote! {
            impl<A, B> borsh::BorshSchema for Side<A, B>
            where
                A: Display + Debug,
                B: Display + Debug,
                A: borsh::BorshSchema,
                B: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    let params = borsh::maybestd::vec![<A>::declaration(), <B>::declaration()];
                    format!(r#"{}<{}>"#, "Side", params.join(", "))
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    #[derive(borsh :: BorshSchema)]
                    struct SideLeft<A, B>
                    (
                        A,
                        #[borsh_skip] ::core::marker::PhantomData<(A, B, )>
                    )
                    where
                        A: Display + Debug,
                        B: Display + Debug,
                    ;
                    #[derive(borsh :: BorshSchema)]
                    struct SideRight<A, B>
                    (
                        B,
                        #[borsh_skip] ::core::marker::PhantomData<(A, B, )>
                    )
                    where
                        A: Display + Debug,
                        B: Display + Debug,
                    ;
                    <SideLeft<A, B> >::add_definitions_recursively(definitions);
                    <SideRight<A, B> >::add_definitions_recursively(definitions);
                    let variants = borsh::maybestd::vec![
                        ("Left".to_string(), <SideLeft<A, B> >::declaration()),
                        ("Right".to_string(), <SideRight<A, B> >::declaration())
                    ];
                    let definition = borsh::schema::Definition::Enum { variants };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }
}