obi_derive_internal/
struct_schema.rs

1use crate::helpers::declaration;
2use proc_macro2::TokenStream;
3use quote::{quote, ToTokens};
4use syn::{Fields, ItemStruct};
5
6pub fn process_struct(input: &ItemStruct) -> syn::Result<TokenStream> {
7    let name = &input.ident;
8    let name_str = name.to_token_stream().to_string();
9    let generics = &input.generics;
10    let (impl_generics, ty_generics, _) = generics.split_for_impl();
11    // Generate function that returns the name of the type.
12    let (declaration, mut where_clause) = declaration(&name_str, &input.generics);
13
14    // Generate function that returns the schema of required types.
15    let mut fields_vec = vec![];
16    let mut struct_fields = TokenStream::new();
17    let mut add_definitions_recursively_rec = TokenStream::new();
18    match &input.fields {
19        Fields::Named(fields) => {
20            for field in &fields.named {
21                let field_name = field.ident.as_ref().unwrap().to_token_stream().to_string();
22                let field_type = &field.ty;
23                fields_vec.push(quote! {
24                    (#field_name.to_string(), <#field_type>::declaration())
25                });
26                add_definitions_recursively_rec.extend(quote! {
27                    <#field_type>::add_definitions_recursively(definitions);
28                });
29                where_clause.push(quote! {
30                    #field_type: obi::OBISchema
31                });
32            }
33            if !fields_vec.is_empty() {
34                struct_fields = quote! {
35                    let fields = vec![#(#fields_vec),*];
36                };
37            }
38        }
39        // Unsupported on unnamed struct (tuple)
40        Fields::Unnamed(_fields) => {}
41        Fields::Unit => {}
42    }
43
44    if fields_vec.is_empty() {
45        struct_fields = quote! {
46            let fields = vec![];
47        };
48    }
49
50    let add_definitions_recursively = quote! {
51        fn add_definitions_recursively(definitions: &mut ::std::collections::HashMap<obi::schema::Declaration, obi::schema::Definition>) {
52            #struct_fields
53            let definition = obi::schema::Definition::Struct { fields };
54            Self::add_definition(Self::declaration(), definition, definitions);
55            #add_definitions_recursively_rec
56        }
57    };
58    let where_clause = if !where_clause.is_empty() {
59        quote! { where #(#where_clause),*}
60    } else {
61        TokenStream::new()
62    };
63    Ok(quote! {
64        impl #impl_generics obi::OBISchema for #name #ty_generics #where_clause {
65            fn declaration() -> obi::schema::Declaration {
66                #declaration
67            }
68            #add_definitions_recursively
69        }
70    })
71}
72
// rustfmt must not touch `quote!` bodies here: reformatting them removes commas from the
// hand-written expected token streams.
#[rustfmt::skip::macros(quote)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `tokens` into a struct item, panicking if they do not form one.
    fn parse_struct(tokens: TokenStream) -> ItemStruct {
        syn::parse2(tokens).unwrap()
    }

    /// Asserts that two token streams render to the same string.
    fn assert_tokens_eq(expected: TokenStream, actual: TokenStream) {
        let (want, got) = (expected.to_string(), actual.to_string());
        assert_eq!(want, got)
    }

    #[test]
    fn unit_struct() {
        let input = parse_struct(quote!{
            struct A;
        });
        let expected = quote!{
            impl obi::OBISchema for A
            {
                fn declaration() -> obi::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(definitions: &mut ::std::collections::HashMap<obi::schema::Declaration, obi::schema::Definition>) {
                    let fields = vec![];
                    let definition = obi::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_tokens_eq(expected, process_struct(&input).unwrap());
    }

    #[test]
    fn simple_struct() {
        let input = parse_struct(quote!{
            struct A {
                x: u64,
                y: String,
            }
        });
        let expected = quote!{
            impl obi::OBISchema for A
            where
                u64: obi::OBISchema,
                String: obi::OBISchema
            {
                fn declaration() -> obi::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut ::std::collections::HashMap<
                        obi::schema::Declaration,
                        obi::schema::Definition
                    >
                ) {
                    let fields = vec![
                        ("x".to_string(), <u64>::declaration()),
                        ("y".to_string(), <String>::declaration())
                    ];
                    let definition = obi::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                    <u64>::add_definitions_recursively(definitions);
                    <String>::add_definitions_recursively(definitions);
                }
            }
        };
        assert_tokens_eq(expected, process_struct(&input).unwrap());
    }

    #[test]
    fn simple_generics() {
        let input = parse_struct(quote!{
            struct A<K, V> {
                x: HashMap<K, V>,
                y: String,
            }
        });
        let expected = quote!{
            impl<K, V> obi::OBISchema for A<K, V>
            where
                K: obi::OBISchema,
                V: obi::OBISchema,
                HashMap<K, V>: obi::OBISchema,
                String: obi::OBISchema
            {
                fn declaration() -> obi::schema::Declaration {
                    let params = vec![<K>::declaration(), <V>::declaration()];
                    format!(r#"{}<{}>"#, "A", params.join(", "))
                }
                fn add_definitions_recursively(
                    definitions: &mut ::std::collections::HashMap<
                        obi::schema::Declaration,
                        obi::schema::Definition
                    >
                ) {
                    let fields = vec![
                        ("x".to_string(), <HashMap<K, V> >::declaration()),
                        ("y".to_string(), <String>::declaration())
                    ];
                    let definition = obi::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                    <HashMap<K, V> >::add_definitions_recursively(definitions);
                    <String>::add_definitions_recursively(definitions);
                }
            }
        };
        assert_tokens_eq(expected, process_struct(&input).unwrap());
    }

    #[test]
    fn tuple_struct_whole_skip() {
        let input = parse_struct(quote!{
            struct A(#[obi_skip] String);
        });
        let expected = quote!{
            impl obi::OBISchema for A {
                fn declaration() -> obi::schema::Declaration {
                    "A".to_string()
                }
                fn add_definitions_recursively(
                    definitions: &mut ::std::collections::HashMap<
                        obi::schema::Declaration,
                        obi::schema::Definition
                    >
                ) {
                    let fields = vec![];
                    let definition = obi::schema::Definition::Struct { fields };
                    Self::add_definition(Self::declaration(), definition, definitions);
                }
            }
        };
        assert_tokens_eq(expected, process_struct(&input).unwrap());
    }
}