1use proc_macro2::{Span, TokenStream as TokenStream2};
2use quote::{quote, ToTokens};
3use syn::{
4 parse_quote, AttrStyle, Attribute, Field, Fields, FieldsUnnamed, Ident, ItemEnum, ItemStruct,
5 Visibility,
6};
7
8use crate::helpers::{declaration, quote_where_clause};
9
10pub fn process_enum(input: &ItemEnum, cratename: Ident) -> syn::Result<TokenStream2> {
11 let name = &input.ident;
12 let name_str = name.to_token_stream().to_string();
13 let generics = &input.generics;
14 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
15 let (declaration, where_clause_additions) =
17 declaration(&name_str, &input.generics, cratename.clone());
18
19 let mut variants_defs = vec![];
22 let mut anonymous_defs = TokenStream2::new();
24 let mut add_recursive_defs = TokenStream2::new();
26 for variant in &input.variants {
27 let variant_name_str = variant.ident.to_token_stream().to_string();
28 let full_variant_name_str = format!("{}{}", name_str, variant_name_str);
29 let full_variant_ident = Ident::new(full_variant_name_str.as_str(), Span::call_site());
30 let mut anonymous_struct = ItemStruct {
31 attrs: vec![],
32 vis: Visibility::Inherited,
33 struct_token: Default::default(),
34 ident: full_variant_ident.clone(),
35 generics: (*generics).clone(),
36 fields: variant.fields.clone(),
37 semi_token: Some(Default::default()),
38 };
39 let generic_params = generics
40 .type_params()
41 .fold(TokenStream2::new(), |acc, generic| {
42 let ident = &generic.ident;
43 quote! {
44 #acc
45 #ident ,
46 }
47 });
48 if !generic_params.is_empty() {
49 let attr = Attribute {
50 pound_token: Default::default(),
51 style: AttrStyle::Outer,
52 bracket_token: Default::default(),
53 path: parse_quote! {borsh_skip},
54 tokens: Default::default(),
55 };
56 let mut unit_to_regular = false;
58 match &mut anonymous_struct.fields {
59 Fields::Named(named) => {
60 named.named.push(Field {
61 attrs: vec![attr.clone()],
62 vis: Visibility::Inherited,
63 ident: Some(Ident::new("borsh_schema_phantom_data", Span::call_site())),
64 colon_token: None,
65 ty: parse_quote! {::core::marker::PhantomData<(#generic_params)>},
66 });
67 }
68 Fields::Unnamed(unnamed) => {
69 unnamed.unnamed.push(Field {
70 attrs: vec![attr.clone()],
71 vis: Visibility::Inherited,
72 ident: None,
73 colon_token: None,
74 ty: parse_quote! {::core::marker::PhantomData<(#generic_params)>},
75 });
76 }
77 Fields::Unit => {
78 unit_to_regular = true;
79 }
80 }
81 if unit_to_regular {
82 let mut fields = FieldsUnnamed {
83 paren_token: Default::default(),
84 unnamed: Default::default(),
85 };
86 fields.unnamed.push(Field {
87 attrs: vec![attr],
88 vis: Visibility::Inherited,
89 ident: None,
90 colon_token: None,
91 ty: parse_quote! {::core::marker::PhantomData<(#generic_params)>},
92 });
93 anonymous_struct.fields = Fields::Unnamed(fields);
94 }
95 }
96 anonymous_defs.extend(quote! {
97 #[derive(#cratename::BorshSchema)]
98 #anonymous_struct
99 });
100 add_recursive_defs.extend(quote! {
101 <#full_variant_ident #ty_generics>::add_definitions_recursively(definitions);
102 });
103 variants_defs.push(quote! {
104 (#variant_name_str.to_string(), <#full_variant_ident #ty_generics>::declaration())
105 });
106 }
107
108 let type_definitions = quote! {
109 fn add_definitions_recursively(definitions: &mut #cratename::maybestd::collections::HashMap<#cratename::schema::Declaration, #cratename::schema::Definition>) {
110 #anonymous_defs
111 #add_recursive_defs
112 let variants = #cratename::maybestd::vec![#(#variants_defs),*];
113 let definition = #cratename::schema::Definition::Enum{variants};
114 Self::add_definition(Self::declaration(), definition, definitions);
115 }
116 };
117 let where_clause = quote_where_clause(where_clause, where_clause_additions);
118 Ok(quote! {
119 impl #impl_generics #cratename::BorshSchema for #name #ty_generics #where_clause {
120 fn declaration() -> #cratename::schema::Declaration {
121 #declaration
122 }
123 #type_definitions
124 }
125 })
126}
127
#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;

    /// Compares two token streams by their string rendering, with a readable diff on failure.
    fn assert_eq(expected: TokenStream2, actual: TokenStream2) {
        pretty_assertions::assert_eq!(expected.to_string(), actual.to_string())
    }

    /// Runs `process_enum` with `borsh` as the crate name, panicking on error.
    fn process(item_enum: &ItemEnum) -> TokenStream2 {
        process_enum(item_enum, Ident::new("borsh", Span::call_site())).unwrap()
    }

    #[test]
    fn simple_enum() {
        let item_enum: ItemEnum = syn::parse2(quote!{
            enum A {
                Bacon,
                Eggs
            }
        }).unwrap();

        let actual = process(&item_enum);
        let expected = quote!{
            impl borsh::BorshSchema for A {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    #[derive(borsh :: BorshSchema)]
                    struct ABacon;
                    #[derive(borsh :: BorshSchema)]
                    struct AEggs;
                    <ABacon>::add_definitions_recursively(definitions);
                    <AEggs>::add_definitions_recursively(definitions);
                    let variants = borsh::maybestd::vec![
                        ("Bacon".to_string(), <ABacon>::declaration()),
                        ("Eggs".to_string(), <AEggs>::declaration())
                    ];
                    let definition = borsh::schema::Definition::Enum { variants };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn single_field_enum() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum A {
                Bacon,
            }
        }).unwrap();

        let actual = process(&item_enum);
        let expected = quote!{
            impl borsh::BorshSchema for A {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    #[derive(borsh :: BorshSchema)]
                    struct ABacon;
                    <ABacon>::add_definitions_recursively(definitions);
                    let variants = borsh::maybestd::vec![("Bacon".to_string(), <ABacon>::declaration())];
                    let definition = borsh::schema::Definition::Enum { variants };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn complex_enum() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum A {
                Bacon,
                Eggs,
                Salad(Tomatoes, Cucumber, Oil),
                Sausage{wrapper: Wrapper, filling: Filling},
            }
        }).unwrap();

        let actual = process(&item_enum);
        let expected = quote!{
            impl borsh::BorshSchema for A {
                fn declaration() -> borsh::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    #[derive(borsh :: BorshSchema)]
                    struct ABacon;
                    #[derive(borsh :: BorshSchema)]
                    struct AEggs;
                    #[derive(borsh :: BorshSchema)]
                    struct ASalad(Tomatoes, Cucumber, Oil);
                    #[derive(borsh :: BorshSchema)]
                    struct ASausage {
                        wrapper: Wrapper,
                        filling: Filling
                    }
                    <ABacon>::add_definitions_recursively(definitions);
                    <AEggs>::add_definitions_recursively(definitions);
                    <ASalad>::add_definitions_recursively(definitions);
                    <ASausage>::add_definitions_recursively(definitions);
                    let variants = borsh::maybestd::vec![
                        ("Bacon".to_string(), <ABacon>::declaration()),
                        ("Eggs".to_string(), <AEggs>::declaration()),
                        ("Salad".to_string(), <ASalad>::declaration()),
                        ("Sausage".to_string(), <ASausage>::declaration())
                    ];
                    let definition = borsh::schema::Definition::Enum { variants };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn complex_enum_generics() {
        let item_enum: ItemEnum = syn::parse2(quote! {
            enum A<C, W> {
                Bacon,
                Eggs,
                Salad(Tomatoes, C, Oil),
                Sausage{wrapper: W, filling: Filling},
            }
        }).unwrap();

        let actual = process(&item_enum);
        let expected = quote!{
            impl<C, W> borsh::BorshSchema for A<C, W>
            where
                C: borsh::BorshSchema,
                W: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    let params = borsh::maybestd::vec![<C>::declaration(), <W>::declaration()];
                    format!(r#"{}<{}>"#, "A", params.join(", "))
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    #[derive(borsh :: BorshSchema)]
                    struct ABacon<C, W>(#[borsh_skip] ::core::marker::PhantomData<(C, W, )>);
                    #[derive(borsh :: BorshSchema)]
                    struct AEggs<C, W>(#[borsh_skip] ::core::marker::PhantomData<(C, W, )>);
                    #[derive(borsh :: BorshSchema)]
                    struct ASalad<C, W>(
                        Tomatoes,
                        C,
                        Oil,
                        #[borsh_skip] ::core::marker::PhantomData<(C, W, )>
                    );
                    #[derive(borsh :: BorshSchema)]
                    struct ASausage<C, W> {
                        wrapper: W,
                        filling: Filling,
                        #[borsh_skip]
                        borsh_schema_phantom_data: ::core::marker::PhantomData<(C, W, )>
                    }
                    <ABacon<C, W> >::add_definitions_recursively(definitions);
                    <AEggs<C, W> >::add_definitions_recursively(definitions);
                    <ASalad<C, W> >::add_definitions_recursively(definitions);
                    <ASausage<C, W> >::add_definitions_recursively(definitions);
                    let variants = borsh::maybestd::vec![
                        ("Bacon".to_string(), <ABacon<C, W> >::declaration()),
                        ("Eggs".to_string(), <AEggs<C, W> >::declaration()),
                        ("Salad".to_string(), <ASalad<C, W> >::declaration()),
                        ("Sausage".to_string(), <ASausage<C, W> >::declaration())
                    ];
                    let definition = borsh::schema::Definition::Enum { variants };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }

    #[test]
    fn trailing_comma_generics() {
        let item_enum: ItemEnum = syn::parse2(quote!{
            enum Side<A, B>
            where
                A: Display + Debug,
                B: Display + Debug,
            {
                Left(A),
                Right(B),
            }
        })
        .unwrap();

        let actual = process(&item_enum);
        let expected = quote!{
            impl<A, B> borsh::BorshSchema for Side<A, B>
            where
                A: Display + Debug,
                B: Display + Debug,
                A: borsh::BorshSchema,
                B: borsh::BorshSchema
            {
                fn declaration() -> borsh::schema::Declaration {
                    let params = borsh::maybestd::vec![<A>::declaration(), <B>::declaration()];
                    format!(r#"{}<{}>"#, "Side", params.join(", "))
                }
                fn add_definitions_recursively(
                    definitions: &mut borsh::maybestd::collections::HashMap<
                        borsh::schema::Declaration,
                        borsh::schema::Definition
                    >
                ) {
                    #[derive(borsh :: BorshSchema)]
                    struct SideLeft<A, B>
                    (
                        A,
                        #[borsh_skip] ::core::marker::PhantomData<(A, B, )>
                    )
                    where
                        A: Display + Debug,
                        B: Display + Debug,
                    ;
                    #[derive(borsh :: BorshSchema)]
                    struct SideRight<A, B>
                    (
                        B,
                        #[borsh_skip] ::core::marker::PhantomData<(A, B, )>
                    )
                    where
                        A: Display + Debug,
                        B: Display + Debug,
                    ;
                    <SideLeft<A, B> >::add_definitions_recursively(definitions);
                    <SideRight<A, B> >::add_definitions_recursively(definitions);
                    let variants = borsh::maybestd::vec![
                        ("Left".to_string(), <SideLeft<A, B> >::declaration()),
                        ("Right".to_string(), <SideRight<A, B> >::declaration())
                    ];
                    let definition = borsh::schema::Definition::Enum { variants };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_eq(expected, actual);
    }
}