borsh_derive/internals/schema/structs/
mod.rs1use proc_macro2::TokenStream as TokenStream2;
2use quote::{quote, ToTokens};
3use syn::{ExprPath, Fields, Ident, ItemStruct, Path, Type};
4
5use crate::internals::{attributes::field, generics, schema};
6
7fn field_declaration_output(
11 field_name: Option<&Ident>,
12 field_type: &Type,
13 cratename: &Path,
14 declaration_override: Option<ExprPath>,
15) -> TokenStream2 {
16 let default_path: ExprPath =
17 syn::parse2(quote! { <#field_type as #cratename::BorshSchema>::declaration }).unwrap();
18
19 let path = declaration_override.unwrap_or(default_path);
20
21 if let Some(field_name) = field_name {
22 let field_name = field_name.to_token_stream().to_string();
23 quote! {
24 (#field_name.to_string(), #path())
25 }
26 } else {
27 quote! {
28 #path()
29 }
30 }
31}
32
33fn field_definitions_output(
36 field_type: &Type,
37 cratename: &Path,
38 definitions_override: Option<ExprPath>,
39) -> TokenStream2 {
40 let default_path: ExprPath = syn::parse2(
41 quote! { <#field_type as #cratename::BorshSchema>::add_definitions_recursively },
42 )
43 .unwrap();
44 let path = definitions_override.unwrap_or(default_path);
45
46 quote! {
47 #path(definitions);
48 }
49}
50
51pub fn process(input: &ItemStruct, cratename: Path) -> syn::Result<TokenStream2> {
52 let name = &input.ident;
53 let struct_name = name.to_token_stream().to_string();
54 let generics = generics::without_defaults(&input.generics);
55 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
56 let mut where_clause = generics::default_where(where_clause);
57 let mut generics_output = schema::GenericsOutput::new(&generics);
58 let (struct_fields, add_definitions_recursively) =
59 process_fields(&cratename, &input.fields, &mut generics_output)?;
60
61 let add_definitions_recursively = quote! {
62 fn add_definitions_recursively(definitions: &mut #cratename::__private::maybestd::collections::BTreeMap<#cratename::schema::Declaration, #cratename::schema::Definition>) {
63 #struct_fields
64 let definition = #cratename::schema::Definition::Struct { fields };
65
66 let no_recursion_flag = definitions.get(&<Self as #cratename::BorshSchema>::declaration()).is_none();
67 #cratename::schema::add_definition(<Self as #cratename::BorshSchema>::declaration(), definition, definitions);
68 if no_recursion_flag {
69 #add_definitions_recursively
70 }
71 }
72 };
73
74 let (predicates, declaration) = generics_output.result(&struct_name, &cratename);
75 where_clause.predicates.extend(predicates);
76 Ok(quote! {
77 impl #impl_generics #cratename::BorshSchema for #name #ty_generics #where_clause {
78 fn declaration() -> #cratename::schema::Declaration {
79 #declaration
80 }
81 #add_definitions_recursively
82 }
83 })
84}
85
86fn process_fields(
87 cratename: &Path,
88 fields: &Fields,
89 generics: &mut schema::GenericsOutput,
90) -> syn::Result<(TokenStream2, TokenStream2)> {
91 let mut struct_fields = TokenStream2::new();
92 let mut add_definitions_recursively = TokenStream2::new();
93
94 let mut fields_vec = vec![];
96 schema::visit_struct_fields(fields, &mut generics.params_visitor)?;
97 match fields {
98 Fields::Named(fields) => {
99 for field in &fields.named {
100 process_field(
101 field,
102 cratename,
103 &mut fields_vec,
104 &mut add_definitions_recursively,
105 )?;
106 }
107 if !fields_vec.is_empty() {
108 struct_fields = quote! {
109 let fields = #cratename::schema::Fields::NamedFields(#cratename::__private::maybestd::vec![#(#fields_vec),*]);
110 };
111 }
112 }
113 Fields::Unnamed(fields) => {
114 for field in &fields.unnamed {
115 process_field(
116 field,
117 cratename,
118 &mut fields_vec,
119 &mut add_definitions_recursively,
120 )?;
121 }
122 if !fields_vec.is_empty() {
123 struct_fields = quote! {
124 let fields = #cratename::schema::Fields::UnnamedFields(#cratename::__private::maybestd::vec![#(#fields_vec),*]);
125 };
126 }
127 }
128 Fields::Unit => {}
129 }
130
131 if fields_vec.is_empty() {
132 struct_fields = quote! {
133 let fields = #cratename::schema::Fields::Empty;
134 };
135 }
136 Ok((struct_fields, add_definitions_recursively))
137}
138fn process_field(
139 field: &syn::Field,
140 cratename: &Path,
141 fields_vec: &mut Vec<TokenStream2>,
142 add_definitions_recursively: &mut TokenStream2,
143) -> syn::Result<()> {
144 let parsed = field::Attributes::parse(&field.attrs)?;
145 if !parsed.skip {
146 let field_name = field.ident.as_ref();
147 let field_type = &field.ty;
148 fields_vec.push(field_declaration_output(
149 field_name,
150 field_type,
151 cratename,
152 parsed.schema_declaration(),
153 ));
154 add_definitions_recursively.extend(field_definitions_output(
155 field_type,
156 cratename,
157 parsed.schema_definitions(),
158 ));
159 }
160 Ok(())
161}
162
#[cfg(test)]
mod tests {
    use crate::internals::test_helpers::{
        default_cratename, local_insta_assert_debug_snapshot, local_insta_assert_snapshot,
        pretty_print_syn_str,
    };

    use super::*;

    /// Parses `tokens` as a struct item; panics if they are not one.
    fn parse_struct(tokens: TokenStream2) -> ItemStruct {
        syn::parse2(tokens).unwrap()
    }

    #[test]
    fn unit_struct() {
        let input = parse_struct(quote! {
            struct A;
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn wrapper_struct() {
        let input = parse_struct(quote! {
            struct A<T>(T);
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn tuple_struct() {
        let input = parse_struct(quote! {
            struct A(u64, String);
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn tuple_struct_params() {
        let input = parse_struct(quote! {
            struct A<K, V>(K, V);
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn simple_struct() {
        let input = parse_struct(quote! {
            struct A {
                x: u64,
                y: String,
            }
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn simple_struct_with_custom_crate() {
        let input = parse_struct(quote! {
            struct A {
                x: u64,
                y: String,
            }
        });
        let crate_: Path = syn::parse2(quote! { reexporter::borsh }).unwrap();
        let actual = process(&input, crate_).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn simple_generics() {
        let input = parse_struct(quote! {
            struct A<K, V> {
                x: HashMap<K, V>,
                y: String,
            }
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn trailing_comma_generics() {
        let input = parse_struct(quote! {
            struct A<K, V>
            where
                K: Display + Debug,
            {
                x: HashMap<K, V>,
                y: String,
            }
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn tuple_struct_whole_skip() {
        let input = parse_struct(quote! {
            struct A(#[borsh(skip)] String);
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn tuple_struct_partial_skip() {
        let input = parse_struct(quote! {
            struct A(#[borsh(skip)] u64, String);
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_tuple_struct_borsh_skip1() {
        let input = parse_struct(quote! {
            struct G<K, V, U> (
                #[borsh(skip)]
                HashMap<K, V>,
                U,
            );
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_tuple_struct_borsh_skip2() {
        let input = parse_struct(quote! {
            struct G<K, V, U> (
                HashMap<K, V>,
                #[borsh(skip)]
                U,
            );
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_tuple_struct_borsh_skip3() {
        let input = parse_struct(quote! {
            struct G<U, K, V> (
                #[borsh(skip)]
                HashMap<K, V>,
                U,
                K,
            );
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_tuple_struct_borsh_skip4() {
        let input = parse_struct(quote! {
            struct ASalad<C>(Tomatoes, #[borsh(skip)] C, Oil);
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_named_fields_struct_borsh_skip() {
        let input = parse_struct(quote! {
            struct G<K, V, U> {
                #[borsh(skip)]
                x: HashMap<K, V>,
                y: U,
            }
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn recursive_struct() {
        let input = parse_struct(quote! {
            struct CRecC {
                a: String,
                b: HashMap<String, CRecC>,
            }
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_associated_type() {
        let input = parse_struct(quote! {
            struct Parametrized<V, T: Debug>
            where
                T: TraitName,
            {
                field: T::Associated,
                another: V,
            }
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_associated_type_param_override() {
        let input = parse_struct(quote! {
            struct Parametrized<V, T>
            where
                T: TraitName,
            {
                #[borsh(schema(params =
                    "T => <T as TraitName>::Associated"
                ))]
                field: <T as TraitName>::Associated,
                another: V,
            }
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_associated_type_param_override2() {
        let input = parse_struct(quote! {
            struct Parametrized<V, T>
            where
                T: TraitName,
            {
                #[borsh(schema(params =
                    "T => T, T => <T as TraitName>::Associated"
                ))]
                field: (<T as TraitName>::Associated, T),
                another: V,
            }
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn generic_associated_type_param_override_conflict() {
        let input = parse_struct(quote! {
            struct Parametrized<V, T>
            where
                T: TraitName,
            {
                #[borsh(skip,schema(params =
                    "T => <T as TraitName>::Associated"
                ))]
                field: <T as TraitName>::Associated,
                another: V,
            }
        });
        let actual = process(&input, default_cratename());
        local_insta_assert_debug_snapshot!(actual.unwrap_err());
    }

    #[test]
    fn check_with_funcs_skip_conflict() {
        let input = parse_struct(quote! {
            struct A<K, V> {
                #[borsh(skip,schema(with_funcs(
                    declaration = "third_party_impl::declaration::<K, V>",
                    definitions = "third_party_impl::add_definitions_recursively::<K, V>"
                )))]
                x: ThirdParty<K, V>,
                y: u64,
            }
        });
        let actual = process(&input, default_cratename());
        local_insta_assert_debug_snapshot!(actual.unwrap_err());
    }

    #[test]
    fn with_funcs_attr() {
        let input = parse_struct(quote! {
            struct A<K, V> {
                #[borsh(schema(with_funcs(
                    declaration = "third_party_impl::declaration::<K, V>",
                    definitions = "third_party_impl::add_definitions_recursively::<K, V>"
                )))]
                x: ThirdParty<K, V>,
                y: u64,
            }
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }

    #[test]
    fn schema_param_override3() {
        let input = parse_struct(quote! {
            struct A<K: EntityRef, V> {
                #[borsh(
                    schema(
                        params = "V => V"
                    )
                )]
                x: PrimaryMap<K, V>,
                y: String,
            }
        });
        let actual = process(&input, default_cratename()).unwrap();
        local_insta_assert_snapshot!(pretty_print_syn_str(&actual).unwrap());
    }
}