// borsh_derive/internals/schema/structs/mod.rs

1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens};
3use syn::{ExprPath, Fields, Ident, ItemStruct, Path, Type};
4
5use crate::internals::{attributes::field, generics, schema};
6
7/// function which computes derive output [proc_macro2::TokenStream]
8/// of code, which computes declaration of a single field, which is later added to
9/// the struct's definition as a whole  
10fn field_declaration_output(
11    field_name: Option<&Ident>,
12    field_type: &Type,
13    cratename: &Path,
14    declaration_override: Option<ExprPath>,
15) -> TokenStream2 {
16    let default_path: ExprPath =
17        syn::parse2(quote! { <#field_type as #cratename::BorshSchema>::declaration }).unwrap();
18
19    let path = declaration_override.unwrap_or(default_path);
20
21    if let Some(field_name) = field_name {
22        let field_name = field_name.to_token_stream().to_string();
23        quote! {
24            (#field_name.to_string(), #path())
25        }
26    } else {
27        quote! {
28            #path()
29        }
30    }
31}
32
33/// function which computes derive output [proc_macro2::TokenStream]
34/// of code, which adds definitions of a field to the output `definitions: &mut BTreeMap`
35fn field_definitions_output(
36    field_type: &Type,
37    cratename: &Path,
38    definitions_override: Option<ExprPath>,
39) -> TokenStream2 {
40    let default_path: ExprPath = syn::parse2(
41        quote! { <#field_type as #cratename::BorshSchema>::add_definitions_recursively },
42    )
43    .unwrap();
44    let path = definitions_override.unwrap_or(default_path);
45
46    quote! {
47        #path(definitions);
48    }
49}
50
51pub fn process(input: &ItemStruct, cratename: Path) -> syn::Result<TokenStream2> {
52    let name = &input.ident;
53    let struct_name = name.to_token_stream().to_string();
54    let generics = generics::without_defaults(&input.generics);
55    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
56    let mut where_clause = generics::default_where(where_clause);
57    let mut generics_output = schema::GenericsOutput::new(&generics);
58    let (struct_fields, add_definitions_recursively) =
59        process_fields(&cratename, &input.fields, &mut generics_output)?;
60
61    let add_definitions_recursively = quote! {
62        fn add_definitions_recursively(definitions: &mut #cratename::__private::maybestd::collections::BTreeMap<#cratename::schema::Declaration, #cratename::schema::Definition>) {
63            #struct_fields
64            let definition = #cratename::schema::Definition::Struct { fields };
65
66            let no_recursion_flag = definitions.get(&<Self as #cratename::BorshSchema>::declaration()).is_none();
67            #cratename::schema::add_definition(<Self as #cratename::BorshSchema>::declaration(), definition, definitions);
68            if no_recursion_flag {
69                #add_definitions_recursively
70            }
71        }
72    };
73
74    let (predicates, declaration) = generics_output.result(&struct_name, &cratename);
75    where_clause.predicates.extend(predicates);
76    Ok(quote! {
77        impl #impl_generics #cratename::BorshSchema for #name #ty_generics #where_clause {
78            fn declaration() -> #cratename::schema::Declaration {
79                #declaration
80            }
81            #add_definitions_recursively
82        }
83    })
84}
85
86fn process_fields(
87    cratename: &Path,
88    fields: &Fields,
89    generics: &mut schema::GenericsOutput,
90) -> syn::Result<(TokenStream2, TokenStream2)> {
91    let mut struct_fields = TokenStream2::new();
92    let mut add_definitions_recursively = TokenStream2::new();
93
94    // Generate function that returns the schema of required types.
95    let mut fields_vec = vec![];
96    schema::visit_struct_fields(fields, &mut generics.params_visitor)?;
97    match fields {
98        Fields::Named(fields) => {
99            for field in &fields.named {
100                process_field(
101                    field,
102                    cratename,
103                    &mut fields_vec,
104                    &mut add_definitions_recursively,
105                )?;
106            }
107            if !fields_vec.is_empty() {
108                struct_fields = quote! {
109                    let fields = #cratename::schema::Fields::NamedFields(#cratename::__private::maybestd::vec![#(#fields_vec),*]);
110                };
111            }
112        }
113        Fields::Unnamed(fields) => {
114            for field in &fields.unnamed {
115                process_field(
116                    field,
117                    cratename,
118                    &mut fields_vec,
119                    &mut add_definitions_recursively,
120                )?;
121            }
122            if !fields_vec.is_empty() {
123                struct_fields = quote! {
124                    let fields = #cratename::schema::Fields::UnnamedFields(#cratename::__private::maybestd::vec![#(#fields_vec),*]);
125                };
126            }
127        }
128        Fields::Unit => {}
129    }
130
131    if fields_vec.is_empty() {
132        struct_fields = quote! {
133            let fields = #cratename::schema::Fields::Empty;
134        };
135    }
136    Ok((struct_fields, add_definitions_recursively))
137}
138fn process_field(
139    field: &syn::Field,
140    cratename: &Path,
141    fields_vec: &mut Vec<TokenStream2>,
142    add_definitions_recursively: &mut TokenStream2,
143) -> syn::Result<()> {
144    let parsed = field::Attributes::parse(&field.attrs)?;
145    if !parsed.skip {
146        let field_name = field.ident.as_ref();
147        let field_type = &field.ty;
148        fields_vec.push(field_declaration_output(
149            field_name,
150            field_type,
151            cratename,
152            parsed.schema_declaration(),
153        ));
154        add_definitions_recursively.extend(field_definitions_output(
155            field_type,
156            cratename,
157            parsed.schema_definitions(),
158        ));
159    }
160    Ok(())
161}
162
// Snapshot tests for `process`: each test parses a struct definition with `quote!`,
// runs the derive, and checks either the pretty-printed output
// (`local_insta_assert_snapshot!`) or, for the conflict cases, the returned error
// (`local_insta_assert_debug_snapshot!`).
// NOTE(review): snapshot identities presumably derive from the test function names,
// so renaming a test likely detaches it from its stored snapshot — verify against
// `test_helpers` before renaming.
#[cfg(test)]
mod tests {
    use crate::internals::test_helpers::{
        default_cratename, local_insta_assert_debug_snapshot, local_insta_assert_snapshot,
        pretty_print_syn_str,
    };

    use super::*;

    /// Unit struct: it has no fields, so the generated code binds `Fields::Empty`.
    #[test]
    fn unit_struct() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A;
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Single-field generic tuple struct.
    #[test]
    fn wrapper_struct() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A<T>(T);
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Tuple struct with concrete field types (`UnnamedFields`).
    #[test]
    fn tuple_struct() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A(u64, String);
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Tuple struct whose fields are the type parameters themselves.
    #[test]
    fn tuple_struct_params() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A<K, V>(K, V);
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Named-field struct with concrete field types (`NamedFields`).
    #[test]
    fn simple_struct() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A {
                x: u64,
                y: String,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Same input as `simple_struct`, but with the crate path `reexporter::borsh`
    /// passed instead of `default_cratename()`.
    #[test]
    fn simple_struct_with_custom_crate() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A {
                x: u64,
                y: String,
            }
        })
        .unwrap();

        let crate_: Path = syn::parse2(quote! { reexporter::borsh }).unwrap();
        let actual = process(&item_struct, crate_).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Named fields whose types use the generic parameters `K` and `V`.
    #[test]
    fn simple_generics() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A<K, V> {
                x: HashMap<K, V>,
                y: String,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Generic struct with an explicit where-clause that ends in a trailing comma.
    #[test]
    fn trailing_comma_generics() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A<K, V>
            where
                K: Display + Debug,
            {
                x: HashMap<K, V>,
                y: String,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// The only field is `#[borsh(skip)]`, so the schema falls back to `Fields::Empty`.
    #[test]
    fn tuple_struct_whole_skip() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A(#[borsh(skip)] String);
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// One of two tuple fields is skipped; only the remaining `String` is declared.
    #[test]
    fn tuple_struct_partial_skip() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A(#[borsh(skip)] u64, String);
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Generic tuple struct where the skipped field is the only one using `K` and `V`.
    #[test]
    fn generic_tuple_struct_borsh_skip1() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct G<K, V, U> (
                #[borsh(skip)]
                HashMap<K, V>,
                U,
            );
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Generic tuple struct where the skipped field is the only one using `U`.
    #[test]
    fn generic_tuple_struct_borsh_skip2() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct G<K, V, U> (
                HashMap<K, V>,
                #[borsh(skip)]
                U,
            );
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Skipped field uses `K` and `V`, but `K` still appears in a non-skipped field.
    #[test]
    fn generic_tuple_struct_borsh_skip3() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct G<U, K, V> (
                #[borsh(skip)]
                HashMap<K, V>,
                U,
                K,
            );
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// The only generic field (`C`) sits between two concrete fields and is skipped.
    #[test]
    fn generic_tuple_struct_borsh_skip4() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct ASalad<C>(Tomatoes, #[borsh(skip)] C, Oil);
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Named-field variant of `generic_tuple_struct_borsh_skip1`.
    #[test]
    fn generic_named_fields_struct_borsh_skip() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct G<K, V, U> {
                #[borsh(skip)]
                x: HashMap<K, V>,
                y: U,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Self-referential struct: `CRecC` contains a `HashMap<String, CRecC>`.
    #[test]
    fn recursive_struct() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct CRecC {
                a: String,
                b: HashMap<String, CRecC>,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// Field typed by an associated type (`T::Associated`) with no attribute override.
    #[test]
    fn generic_associated_type() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct Parametrized<V, T: Debug>
            where
                T: TraitName,
            {
                field: T::Associated,
                another: V,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// `schema(params = ...)` on a field whose type is a qualified associated type.
    #[test]
    fn generic_associated_type_param_override() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct Parametrized<V, T>
            where
                T: TraitName,
            {
                #[borsh(schema(params =
                    "T => <T as TraitName>::Associated"
               ))]
                field: <T as TraitName>::Associated,
                another: V,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// `schema(params = ...)` with two entries, for a tuple field containing both
    /// `T` and its associated type.
    #[test]
    fn generic_associated_type_param_override2() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct Parametrized<V, T>
            where
                T: TraitName,
            {
                #[borsh(schema(params =
                    "T => T, T => <T as TraitName>::Associated"
               ))]
                field: (<T as TraitName>::Associated, T),
                another: V,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// `skip` combined with `schema(params = ...)` on one field: `process` must
    /// return an error, whose Debug form is snapshotted.
    #[test]
    fn generic_associated_type_param_override_conflict() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct Parametrized<V, T>
            where
                T: TraitName,
            {
                #[borsh(skip,schema(params =
                    "T => <T as TraitName>::Associated"
               ))]
                field: <T as TraitName>::Associated,
                another: V,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename());

        local_insta_assert_debug_snapshot!(actual.unwrap_err());
    }

    /// `skip` combined with `schema(with_funcs(...))` on one field: `process` must
    /// return an error, whose Debug form is snapshotted.
    #[test]
    fn check_with_funcs_skip_conflict() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A<K, V> {
                #[borsh(skip,schema(with_funcs(
                    declaration = "third_party_impl::declaration::<K, V>",
                    definitions = "third_party_impl::add_definitions_recursively::<K, V>"
                )))]
                x: ThirdParty<K, V>,
                y: u64,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename());

        local_insta_assert_debug_snapshot!(actual.unwrap_err());
    }

    /// `schema(with_funcs(declaration = ..., definitions = ...))` supplies custom
    /// function paths for field `x`.
    #[test]
    fn with_funcs_attr() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A<K, V> {
                #[borsh(schema(with_funcs(
                    declaration = "third_party_impl::declaration::<K, V>",
                    definitions = "third_party_impl::add_definitions_recursively::<K, V>"
                )))]
                x: ThirdParty<K, V>,
                y: u64,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    /// `schema(params = "V => V")` on a field whose type also uses the bounded
    /// parameter `K: EntityRef`.
    #[test]
    fn schema_param_override3() {
        let item_struct: ItemStruct = syn::parse2(quote! {
            struct A<K: EntityRef, V> {
                #[borsh(
                    schema(
                        params = "V => V"
                    )
                )]
                x: PrimaryMap<K, V>,
                y: String,
            }
        })
        .unwrap();

        let actual = process(&item_struct, default_cratename()).unwrap();

        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }
}