borsh_schema_derive_internal/
struct_schema.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens};
3use syn::{Fields, Ident, ItemStruct};
4
5use crate::helpers::{contains_skip, declaration, quote_where_clause};
6
7pub fn process_struct(input: &ItemStruct, cratename: Ident) -> syn::Result<TokenStream2> {
8    let name = &input.ident;
9    let name_str = name.to_token_stream().to_string();
10    let generics = &input.generics;
11    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
12    // Generate function that returns the name of the type.
13    let (declaration, mut where_clause_additions) =
14        declaration(&name_str, &input.generics, cratename.clone());
15
16    // Generate function that returns the schema of required types.
17    let mut fields_vec = vec![];
18    let mut struct_fields = TokenStream2::new();
19    let mut add_definitions_recursively_rec = TokenStream2::new();
20    match &input.fields {
21        Fields::Named(fields) => {
22            for field in &fields.named {
23                if contains_skip(&field.attrs) {
24                    continue;
25                }
26                let field_name = field.ident.as_ref().unwrap().to_token_stream().to_string();
27                let field_type = &field.ty;
28                fields_vec.push(quote! {
29                    (#field_name.to_string(), <#field_type>::declaration())
30                });
31                add_definitions_recursively_rec.extend(quote! {
32                    <#field_type>::add_definitions_recursively(definitions);
33                });
34                where_clause_additions.push(quote! {
35                    #field_type: #cratename::BorshSchema
36                });
37            }
38            if !fields_vec.is_empty() {
39                struct_fields = quote! {
40                    let fields = #cratename::schema::Fields::NamedFields(#cratename::maybestd::vec![#(#fields_vec),*]);
41                };
42            }
43        }
44        Fields::Unnamed(fields) => {
45            for field in &fields.unnamed {
46                if contains_skip(&field.attrs) {
47                    continue;
48                }
49                let field_type = &field.ty;
50                fields_vec.push(quote! {
51                    <#field_type>::declaration()
52                });
53                add_definitions_recursively_rec.extend(quote! {
54                    <#field_type>::add_definitions_recursively(definitions);
55                });
56                where_clause_additions.push(quote! {
57                    #field_type: #cratename::BorshSchema
58                });
59            }
60            if !fields_vec.is_empty() {
61                struct_fields = quote! {
62                    let fields = #cratename::schema::Fields::UnnamedFields(#cratename::maybestd::vec![#(#fields_vec),*]);
63                };
64            }
65        }
66        Fields::Unit => {}
67    }
68
69    if fields_vec.is_empty() {
70        struct_fields = quote! {
71            let fields = #cratename::schema::Fields::Empty;
72        };
73    }
74
75    let add_definitions_recursively = quote! {
76        fn add_definitions_recursively(definitions: &mut #cratename::maybestd::collections::HashMap<#cratename::schema::Declaration, #cratename::schema::Definition>) {
77            #struct_fields
78            let definition = #cratename::schema::Definition::Struct { fields };
79            Self::add_definition(Self::declaration(), definition, definitions);
80            #add_definitions_recursively_rec
81        }
82    };
83    let where_clause = quote_where_clause(where_clause, where_clause_additions);
84    Ok(quote! {
85        impl #impl_generics #cratename::BorshSchema for #name #ty_generics #where_clause {
86            fn declaration() -> #cratename::schema::Declaration {
87                #declaration
88            }
89            #add_definitions_recursively
90        }
91    })
92}
93
// Rustfmt removes commas inside `quote!` bodies, which would change the expected token
// streams compared below, so formatting of those macro invocations is skipped.
#[rustfmt::skip::macros(quote)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Compares two token streams by their string rendering, with a readable diff on failure.
    fn assert_eq(expected: TokenStream2, actual: TokenStream2) {
        pretty_assertions::assert_eq!(expected.to_string(), actual.to_string())
    }

    #[test]
    fn unit_struct() {
        let item_struct: ItemStruct = syn::parse2(quote!{
            struct A;
        })
        .unwrap();

        let actual = process_struct(
            &item_struct,
            Ident::new("borsh", proc_macro2::Span::call_site()),
        )
        .unwrap();
        let expected = quote!{
            impl borsh::BorshSchema for A
            {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(definitions: &mut borsh::maybestd::collections::HashMap<borsh::schema::Declaration, borsh::schema::Definition>) {
                    let fields = borsh::schema::Fields::Empty;
                    let definition = borsh::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn wrapper_struct() {
        let item_struct: ItemStruct = syn::parse2(quote!{
            struct A<T>(T);
        })
        .unwrap();

        let actual = process_struct(
            &item_struct,
            Ident::new("borsh", proc_macro2::Span::call_site()),
        )
        .unwrap();
        let expected = quote!{
            impl<T> borsh::BorshSchema for A<T>
            where
                T: borsh::BorshSchema,
                T: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    let params = borsh::maybestd::vec![<T>::declaration()];
                    format!(r#"{}<{}>"#, "A", params.join(", "))
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    let fields = borsh::schema::Fields::UnnamedFields(borsh::maybestd::vec![<T>::declaration()]);
                    let definition = borsh::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                    <T>::add_definitions_recursively(definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn tuple_struct() {
        let item_struct: ItemStruct = syn::parse2(quote!{
            struct A(u64, String);
        })
        .unwrap();

        let actual = process_struct(
            &item_struct,
            Ident::new("borsh", proc_macro2::Span::call_site()),
        )
        .unwrap();
        let expected = quote!{
            impl borsh::BorshSchema for A
            where
                u64: borsh::BorshSchema,
                String: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    let fields = borsh::schema::Fields::UnnamedFields(borsh::maybestd::vec![
                        <u64>::declaration(),
                        <String>::declaration()
                    ]);
                    let definition = borsh::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                    <u64>::add_definitions_recursively(definitions);
                    <String>::add_definitions_recursively(definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn tuple_struct_params() {
        let item_struct: ItemStruct = syn::parse2(quote!{
            struct A<K, V>(K, V);
        })
        .unwrap();

        let actual = process_struct(
            &item_struct,
            Ident::new("borsh", proc_macro2::Span::call_site()),
        )
        .unwrap();
        let expected = quote!{
            impl<K, V> borsh::BorshSchema for A<K, V>
            where
                K: borsh::BorshSchema,
                V: borsh::BorshSchema,
                K: borsh::BorshSchema,
                V: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    let params = borsh::maybestd::vec![<K>::declaration(), <V>::declaration()];
                    format!(r#"{}<{}>"#, "A", params.join(", "))
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    let fields =
                        borsh::schema::Fields::UnnamedFields(borsh::maybestd::vec![<K>::declaration(), <V>::declaration()]);
                    let definition = borsh::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                    <K>::add_definitions_recursively(definitions);
                    <V>::add_definitions_recursively(definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn simple_struct() {
        let item_struct: ItemStruct = syn::parse2(quote!{
            struct A {
                x: u64,
                y: String,
            }
        })
        .unwrap();

        let actual = process_struct(
            &item_struct,
            Ident::new("borsh", proc_macro2::Span::call_site()),
        )
        .unwrap();
        let expected = quote!{
            impl borsh::BorshSchema for A
            where
                u64: borsh::BorshSchema,
                String: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    let fields = borsh::schema::Fields::NamedFields(borsh::maybestd::vec![
                        ("x".to_string(), <u64>::declaration()),
                        ("y".to_string(), <String>::declaration())
                    ]);
                    let definition = borsh::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                    <u64>::add_definitions_recursively(definitions);
                    <String>::add_definitions_recursively(definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn simple_generics() {
        let item_struct: ItemStruct = syn::parse2(quote!{
            struct A<K, V> {
                x: HashMap<K, V>,
                y: String,
            }
        })
        .unwrap();

        let actual = process_struct(
            &item_struct,
            Ident::new("borsh", proc_macro2::Span::call_site()),
        )
        .unwrap();
        let expected = quote!{
            impl<K, V> borsh::BorshSchema for A<K, V>
            where
                K: borsh::BorshSchema,
                V: borsh::BorshSchema,
                HashMap<K, V>: borsh::BorshSchema,
                String: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    let params = borsh::maybestd::vec![<K>::declaration(), <V>::declaration()];
                    format!(r#"{}<{}>"#, "A", params.join(", "))
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    let fields = borsh::schema::Fields::NamedFields(borsh::maybestd::vec![
                        ("x".to_string(), <HashMap<K, V> >::declaration()),
                        ("y".to_string(), <String>::declaration())
                    ]);
                    let definition = borsh::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                    <HashMap<K, V> >::add_definitions_recursively(definitions);
                    <String>::add_definitions_recursively(definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn trailing_comma_generics() {
        let item_struct: ItemStruct = syn::parse2(quote!{
            struct A<K, V>
            where
                K: Display + Debug,
            {
                x: HashMap<K, V>,
                y: String,
            }
        })
        .unwrap();

        let actual = process_struct(
            &item_struct,
            Ident::new("borsh", proc_macro2::Span::call_site()),
        )
        .unwrap();
        let expected = quote!{
            impl<K, V> borsh::BorshSchema for A<K, V>
            where
                K: Display + Debug,
                K: borsh::BorshSchema,
                V: borsh::BorshSchema,
                HashMap<K, V>: borsh::BorshSchema,
                String: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    let params = borsh::maybestd::vec![<K>::declaration(), <V>::declaration()];
                    format!(r#"{}<{}>"#, "A", params.join(", "))
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    let fields = borsh::schema::Fields::NamedFields(borsh::maybestd::vec![
                        ("x".to_string(), <HashMap<K, V> >::declaration()),
                        ("y".to_string(), <String>::declaration())
                    ]);
                    let definition = borsh::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                    <HashMap<K, V> >::add_definitions_recursively(definitions);
                    <String>::add_definitions_recursively(definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn tuple_struct_whole_skip() {
        let item_struct: ItemStruct = syn::parse2(quote!{
            struct A(#[borsh_skip] String);
        })
        .unwrap();

        let actual = process_struct(
            &item_struct,
            Ident::new("borsh", proc_macro2::Span::call_site()),
        )
        .unwrap();
        let expected = quote!{
            impl borsh::BorshSchema for A {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    let fields = borsh::schema::Fields::Empty;
                    let definition = borsh::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn tuple_struct_partial_skip() {
        let item_struct: ItemStruct = syn::parse2(quote!{
            struct A(#[borsh_skip] u64, String);
        })
        .unwrap();

        let actual = process_struct(
            &item_struct,
            Ident::new("borsh", proc_macro2::Span::call_site()),
        )
        .unwrap();
        let expected = quote!{
            impl borsh::BorshSchema for A
            where
                String: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    let fields = borsh::schema::Fields::UnnamedFields(borsh::maybestd::vec![<String>::declaration()]);
                    let definition = borsh::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                    <String>::add_definitions_recursively(definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    // A braced struct whose only field is skipped must fall back to `Fields::Empty` and
    // must not emit a where clause.
    #[test]
    fn named_struct_whole_skip() {
        let item_struct: ItemStruct = syn::parse2(quote!{
            struct A {
                #[borsh_skip]
                x: u64,
            }
        })
        .unwrap();

        let actual = process_struct(
            &item_struct,
            Ident::new("borsh", proc_macro2::Span::call_site()),
        )
        .unwrap();
        let expected = quote!{
            impl borsh::BorshSchema for A {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    let fields = borsh::schema::Fields::Empty;
                    let definition = borsh::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    // Skipped named fields must contribute neither a schema entry, a bound, nor a
    // recursive definition call.
    #[test]
    fn named_struct_partial_skip() {
        let item_struct: ItemStruct = syn::parse2(quote!{
            struct A {
                #[borsh_skip]
                x: u64,
                y: String,
            }
        })
        .unwrap();

        let actual = process_struct(
            &item_struct,
            Ident::new("borsh", proc_macro2::Span::call_site()),
        )
        .unwrap();
        let expected = quote!{
            impl borsh::BorshSchema for A
            where
                String: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    let fields = borsh::schema::Fields::NamedFields(borsh::maybestd::vec![
                        ("y".to_string(), <String>::declaration())
                    ]);
                    let definition = borsh::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                    <String>::add_definitions_recursively(definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }
}