go_away_derive_internals/type_metadata_derive/mod.rs

1use proc_macro2::{Literal, TokenStream};
2use quote::quote;
3use serde_derive_internals::{
4    ast::{Container, Data, Field, Style},
5    attr::TagType,
6    Ctxt,
7};
8
9mod type_id;
10
11use type_id::TypeIdCall;
12
13pub fn type_metadata_derive(ast: &syn::DeriveInput) -> Result<TokenStream, syn::Error> {
14    use quote::TokenStreamExt;
15
16    let ctx = Ctxt::new();
17
18    let container =
19        Container::from_ast(&ctx, ast, serde_derive_internals::Derive::Deserialize).unwrap();
20
21    match ctx.check() {
22        Ok(_) => {}
23        Err(errors) => {
24            let mut rv = TokenStream::new();
25            for err in errors {
26                rv.extend(err.to_compile_error());
27            }
28
29            return Ok(rv);
30        }
31    }
32    let type_id = TypeIdCall::for_struct(&container.ident, container.generics);
33
34    let ident = &container.ident;
35    let name_literal = Literal::string(&ident.to_string());
36    let mut inner = TokenStream::new();
37    match container.data {
38        Data::Enum(variants) if variants.iter().all(|v| matches!(v.style, Style::Unit)) => {
39            inner.append_all(quote! {
40                let mut rv = types::Enum {
41                    name: #name_literal.into(),
42                    variants: vec![]
43                };
44            });
45            for variant in variants {
46                if variant.attrs.skip_deserializing() && variant.attrs.skip_serializing() {
47                    continue;
48                }
49                let variant_name = Literal::string(&variant.ident.to_string());
50                let serialized_name = Literal::string(&variant.attrs.name().serialize_name());
51                inner.append_all(quote! {
52                    rv.variants.push(types::EnumVariant {
53                        name: #variant_name.into(),
54                        serialized_name: #serialized_name.into(),
55                    });
56                })
57            }
58
59            inner.append_all(quote! {
60                FieldType::Named(registry.register_enum(#type_id, rv))
61            });
62        }
63        Data::Enum(variants) if variants.iter().all(|v| matches!(v.style, Style::Newtype)) => {
64            let repr = tag_to_representation(container.attrs.tag());
65            inner.append_all(quote! {
66                let mut rv = types::Union {
67                    name: #name_literal.into(),
68                    representation: #repr,
69                    variants: vec![]
70                };
71            });
72            for variant in variants {
73                if variant.attrs.skip_deserializing() && variant.attrs.skip_serializing() {
74                    continue;
75                }
76                let variant_name = Literal::string(&variant.ident.to_string());
77                let serialized_name = Literal::string(&variant.attrs.name().serialize_name());
78                let metadata_call = metadata_call(variant.fields.first().unwrap().ty);
79                inner.append_all(quote! {
80                    rv.variants.push(
81                        types::UnionVariant {
82                            name: Some(#variant_name.to_string()),
83                            ty: #metadata_call,
84                            serialized_name: #serialized_name.to_string()
85                        }
86                    );
87                })
88            }
89            inner.append_all(quote! {
90                FieldType::Named(registry.register_union(#type_id, rv))
91            })
92        }
93        Data::Enum(variants) => {
94            let repr = tag_to_representation(container.attrs.tag());
95            inner.append_all(quote! {
96                let mut rv = types::Union {
97                    name: #name_literal.into(),
98                    representation: #repr,
99                    variants: vec![]
100                };
101            });
102            for variant in variants {
103                if variant.attrs.skip_deserializing() && variant.attrs.skip_serializing() {
104                    continue;
105                }
106                let variant_name = Literal::string(&variant.ident.to_string());
107                let serialized_name = Literal::string(&variant.attrs.name().serialize_name());
108                let type_id =
109                    TypeIdCall::for_variant(&container.ident, &variant.ident, container.generics);
110                let inner_type_block =
111                    struct_block(&variant.ident.to_string(), &variant.fields, type_id);
112                inner.append_all(quote! {
113                    rv.variants.push(
114                        types::UnionVariant {
115                            name: Some(#variant_name.to_string()),
116                            ty: FieldType::Named({#inner_type_block}),
117                            serialized_name: #serialized_name.to_string()
118                        }
119                    );
120                })
121            }
122            inner.append_all(quote! {
123                FieldType::Named(registry.register_union(#type_id, rv))
124            })
125        }
126        Data::Struct(Style::Newtype, fields) => {
127            let metadata_call = metadata_call(fields.first().unwrap().ty);
128            inner.append_all(quote! {
129                let nt = types::NewType {
130                    name: #name_literal.to_string(),
131                    inner: #metadata_call,
132                };
133                FieldType::Named(registry.register_newtype(#type_id, nt))
134            });
135        }
136        Data::Struct(_, fields) => {
137            let struct_block_contents = struct_block(&ident.to_string(), &fields, type_id);
138            inner.append_all(quote! {
139                let type_ref = {
140                    #struct_block_contents
141                };
142                FieldType::Named(type_ref)
143            });
144        }
145    }
146
147    let (impl_generics, ty_generics, where_clause) = container.generics.split_for_impl();
148    Ok(quote! {
149        #[automatically_derived]
150        impl #impl_generics ::go_away::TypeMetadata for #ident #ty_generics #where_clause {
151            fn metadata(registry: &mut ::go_away::TypeRegistry) -> ::go_away::types::FieldType {
152                use ::go_away::types::{self, FieldType};
153                #inner
154            }
155        }
156    })
157}
158
159fn struct_block(name: &str, fields: &[Field], type_id: TypeIdCall<'_>) -> TokenStream {
160    use quote::TokenStreamExt;
161
162    let mut rv = TokenStream::new();
163
164    let name_literal = Literal::string(name);
165    rv.append_all(quote! {
166        let mut st = types::Struct {
167            name: #name_literal.into(),
168            fields: vec![]
169        };
170    });
171    for field in fields {
172        if field.attrs.skip_deserializing() && field.attrs.skip_serializing() {
173            continue;
174        }
175        let field_name = name_of_member(&field.member);
176        let serialized_name = Literal::string(&field.attrs.name().serialize_name());
177        let ty_def = metadata_call(field.ty);
178        rv.append_all(quote! {
179            st.fields.push(
180                types::Field {
181                    name: #field_name.into(),
182                    serialized_name: #serialized_name.into(),
183                    ty: #ty_def
184                }
185            );
186        });
187    }
188    rv.append_all(quote! {
189        registry.register_struct(#type_id, st)
190    });
191
192    rv
193}
194
195fn tag_to_representation(tag: &TagType) -> proc_macro2::TokenStream {
196    match tag {
197        TagType::Adjacent { tag, content } => {
198            let tag = Literal::string(tag);
199            let content = Literal::string(content);
200            quote! {
201                types::UnionRepresentation::AdjacentlyTagged {
202                    tag: #tag.into(),
203                    content: #content.into()
204                }
205            }
206        }
207        TagType::External => {
208            quote! { types::UnionRepresentation::ExternallyTagged }
209        }
210        TagType::Internal { tag } => {
211            let tag = Literal::string(tag);
212            quote! {
213                types::UnionRepresentation::InternallyTagged {
214                    tag: #tag.into()
215                }
216            }
217        }
218        TagType::None => {
219            quote! { types::UnionRepresentation::Untagged }
220        }
221    }
222}
223
224fn metadata_call(ty: &syn::Type) -> proc_macro2::TokenStream {
225    match ty {
226        syn::Type::Reference(r) => metadata_call(r.elem.as_ref()),
227        syn::Type::Paren(r) => metadata_call(r.elem.as_ref()),
228        other => quote! { <#other as ::go_away::TypeMetadata>::metadata(registry) },
229    }
230}
231
232fn name_of_member(member: &syn::Member) -> proc_macro2::Literal {
233    use syn::Index;
234    match member {
235        syn::Member::Named(ident) => Literal::string(&ident.to_string()),
236        syn::Member::Unnamed(Index { index, .. }) => Literal::string(&format!("_{}", index)),
237    }
238}
239
#[cfg(test)]
mod tests {
    use insta::assert_snapshot;
    use quote::quote;
    use xshell::Shell;

    use super::*;

    /// Plain struct with two named fields.
    #[test]
    fn test_struct() {
        assert_snapshot!(test_conversion(quote! {
            struct MyData {
                field_one: String,
                field_two: String
            }
        }))
    }

    /// Single-field tuple struct (serde newtype).
    #[test]
    fn test_newtype_struct() {
        assert_snapshot!(test_conversion(quote! {
            struct MyData(String);
        }))
    }

    /// Named-field struct with exactly one field; must not be treated as a newtype.
    #[test]
    fn test_struct_with_single_field() {
        assert_snapshot!(test_conversion(quote! {
            struct MyData {
                data: String
            }
        }))
    }

    /// Runs the derive on `ts` and returns the expansion formatted by `rustfmt`.
    fn test_conversion(ts: proc_macro2::TokenStream) -> String {
        let input: syn::DeriveInput = syn::parse2(ts).unwrap();
        let expanded = type_metadata_derive(&input).unwrap();
        format_code(&expanded.to_string())
    }

    /// Pipes `text` through the `rustfmt` binary (which must be on `PATH`).
    fn format_code(text: &str) -> String {
        let sh = Shell::new().unwrap();
        let rustfmt = xshell::cmd!(sh, "rustfmt").stdin(text);
        rustfmt.read().unwrap()
    }
}