npyz_derive/
lib.rs

1#![recursion_limit = "256"]
2
3/*!
4Derive `npyz`'s traits for structured arrays.
5
6Using this crate, it is enough to `#[derive(npyz::Serialize, npyz::Deserialize)]` on a struct to be able to
7serialize and deserialize it. All of the fields must implement [`Serialize`](../npyz/trait.Serialize.html)
8and [`Deserialize`](../npyz/trait.Deserialize.html) respectively.
9
10*/
11
12use proc_macro::{TokenStream as TokenStream1};
13use proc_macro2::{Span, TokenStream};
14use quote::quote;
15
16#[proc_macro_derive(Serialize)]
17pub fn npy_serialize(input: TokenStream1) -> TokenStream1 {
18    // Parse the string representation
19    let ast = syn::parse(input).unwrap();
20
21    // Build the impl
22    let expanded = impl_npy_serialize(&ast);
23
24    // Return the generated impl
25    expanded.into()
26}
27
28#[proc_macro_derive(Deserialize)]
29pub fn npy_deserialize(input: TokenStream1) -> TokenStream1 {
30    // Parse the string representation
31    let ast = syn::parse(input).unwrap();
32
33    // Build the impl
34    let expanded = impl_npy_deserialize(&ast);
35
36    // Return the generated impl
37    expanded.into()
38}
39
40#[proc_macro_derive(AutoSerialize)]
41pub fn npy_auto_serialize(input: TokenStream1) -> TokenStream1 {
42    // Parse the string representation
43    let ast = syn::parse(input).unwrap();
44
45    // Build the impl
46    let expanded = impl_npy_auto_serialize(&ast);
47
48    // Return the generated impl
49    expanded.into()
50}
51
52struct FieldData {
53    idents: Vec<syn::Ident>,
54    idents_str: Vec<String>,
55    types: Vec<TokenStream>,
56}
57
58impl FieldData {
59    fn extract(ast: &syn::DeriveInput) -> Self {
60        let fields = match ast.data {
61            syn::Data::Struct(ref data) => &data.fields,
62            _ => panic!("npyz derive macros can only be used with structs"),
63        };
64
65        let idents: Vec<syn::Ident> = fields.iter().map(|f| {
66            f.ident.clone().expect("Tuple structs not supported")
67        }).collect();
68        let idents_str = idents.iter().map(|t| unraw(t)).collect::<Vec<_>>();
69
70        let types: Vec<TokenStream> = fields.iter().map(|f| {
71            let ty = &f.ty;
72            quote!( #ty )
73        }).collect::<Vec<_>>();
74
75        FieldData { idents, idents_str, types }
76    }
77}
78
79fn impl_npy_serialize(ast: &syn::DeriveInput) -> TokenStream {
80    let name = &ast.ident;
81    let vis = &ast.vis;
82    let FieldData { ref idents, ref idents_str, ref types } = FieldData::extract(ast);
83
84    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
85    let field_dtypes_struct = gen_field_dtypes_struct(idents, idents_str);
86
87    let idents_1 = idents;
88
89    wrap_in_const("Serialize", &name, quote! {
90        use ::std::io;
91
92        #vis struct GeneratedWriter #ty_generics #where_clause {
93            writers: FieldWriters #ty_generics,
94        }
95
96        struct FieldWriters #ty_generics #where_clause {
97            #( #idents: <#types as _npyz::Serialize>::TypeWriter ,)*
98        }
99
100        #field_dtypes_struct
101
102        impl #impl_generics _npyz::TypeWrite for GeneratedWriter #ty_generics #where_clause {
103            type Value = #name #ty_generics;
104
105            #[allow(unused_mut)]
106            fn write_one<W: io::Write>(&self, mut w: W, value: &Self::Value) -> io::Result<()> {
107                #({ // braces for pre-NLL
108                    let method = <<#types as _npyz::Serialize>::TypeWriter as _npyz::TypeWrite>::write_one;
109                    method(&self.writers.#idents, &mut w, &value.#idents_1)?;
110                })*
111                p::Ok(())
112            }
113        }
114
115        impl #impl_generics _npyz::Serialize for #name #ty_generics #where_clause {
116            type TypeWriter = GeneratedWriter #ty_generics;
117
118            fn writer(dtype: &_npyz::DType) -> p::Result<GeneratedWriter, _npyz::DTypeError> {
119                let dtypes = FieldDTypes::extract(dtype)?;
120                let writers = FieldWriters {
121                    #( #idents: <#types as _npyz::Serialize>::writer(&dtypes.#idents_1)? ,)*
122                };
123
124                p::Ok(GeneratedWriter { writers })
125            }
126        }
127    })
128}
129
130fn impl_npy_deserialize(ast: &syn::DeriveInput) -> TokenStream {
131    let name = &ast.ident;
132    let vis = &ast.vis;
133    let FieldData { ref idents, ref idents_str, ref types } = FieldData::extract(ast);
134
135    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
136    let field_dtypes_struct = gen_field_dtypes_struct(idents, idents_str);
137
138    let idents_1 = idents;
139
140    wrap_in_const("Deserialize", &name, quote! {
141        use ::std::io;
142
143        #vis struct GeneratedReader #ty_generics #where_clause {
144            readers: FieldReaders #ty_generics,
145        }
146
147        struct FieldReaders #ty_generics #where_clause {
148            #( #idents: <#types as _npyz::Deserialize>::TypeReader ,)*
149        }
150
151        #field_dtypes_struct
152
153        impl #impl_generics _npyz::TypeRead for GeneratedReader #ty_generics #where_clause {
154            type Value = #name #ty_generics;
155
156            #[allow(unused_mut)]
157            fn read_one<R: io::Read>(&self, mut reader: R) -> io::Result<Self::Value> {
158                #(
159                    let func = <<#types as _npyz::Deserialize>::TypeReader as _npyz::TypeRead>::read_one;
160                    let #idents = func(&self.readers.#idents_1, &mut reader)?;
161                )*
162                io::Result::Ok(#name { #( #idents ),* })
163            }
164        }
165
166        impl #impl_generics _npyz::Deserialize for #name #ty_generics #where_clause {
167            type TypeReader = GeneratedReader #ty_generics;
168
169            fn reader(dtype: &_npyz::DType) -> p::Result<GeneratedReader, _npyz::DTypeError> {
170                let dtypes = FieldDTypes::extract(dtype)?;
171                let readers = FieldReaders {
172                    #( #idents: <#types as _npyz::Deserialize>::reader(&dtypes.#idents_1)? ,)*
173                };
174
175                p::Ok(GeneratedReader { readers })
176            }
177        }
178    })
179}
180
181fn impl_npy_auto_serialize(ast: &syn::DeriveInput) -> TokenStream {
182    let name = &ast.ident;
183    let FieldData { idents: _, ref idents_str, ref types } = FieldData::extract(ast);
184
185    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
186
187    wrap_in_const("AutoSerialize", &name, quote! {
188        impl #impl_generics _npyz::AutoSerialize for #name #ty_generics #where_clause {
189            fn default_dtype() -> _npyz::DType {
190                _npyz::DType::Record(::std::vec![#(
191                    _npyz::Field {
192                        name: p::ToString::to_string(#idents_str),
193                        dtype: <#types as _npyz::AutoSerialize>::default_dtype()
194                    }
195                ),*])
196            }
197        }
198    })
199}
200
201fn gen_field_dtypes_struct(
202    idents: &[syn::Ident],
203    idents_str: &[String],
204) -> TokenStream {
205    assert_eq!(idents.len(), idents_str.len());
206    quote!{
207        struct FieldDTypes {
208            #( #idents : _npyz::DType ,)*
209        }
210
211        impl FieldDTypes {
212            fn extract(dtype: &_npyz::DType) -> p::Result<Self, _npyz::DTypeError> {
213                let fields = match dtype {
214                    _npyz::DType::Record(fields) => fields,
215                    ty => return p::Err(_npyz::DTypeError::expected_record(ty)),
216                };
217
218                let correct_names: &[&str] = &[ #(#idents_str),* ];
219
220                if p::Iterator::ne(
221                    p::Iterator::map(fields.iter(), |f| &f.name[..]),
222                    p::Iterator::cloned(correct_names.iter()),
223                ) {
224                    let actual_names = p::Iterator::map(fields.iter(), |f| &f.name[..]);
225                    return p::Err(_npyz::DTypeError::wrong_fields(actual_names, correct_names));
226                }
227
228                #[allow(unused_mut)]
229                let mut fields = p::IntoIterator::into_iter(fields);
230                p::Result::Ok(FieldDTypes {
231                    #( #idents : {
232                        let field = p::Iterator::next(&mut fields).unwrap();
233                        p::Clone::clone(&field.dtype)
234                    },)*
235                })
236            }
237        }
238    }
239}
240
241// from the wonderful folks working on serde
242fn wrap_in_const(
243    trait_: &str,
244    ty: &syn::Ident,
245    code: TokenStream,
246) -> TokenStream {
247    let dummy_const = syn::Ident::new(
248        &format!("__IMPL_npy_{}_FOR_{}", trait_, unraw(ty)),
249        Span::call_site(),
250    );
251
252    quote! {
253        #[allow(non_upper_case_globals, unused_attributes, unused_qualifications)]
254        const #dummy_const: () = {
255            #[allow(unknown_lints)]
256            #[cfg_attr(feature = "cargo-clippy", allow(useless_attribute))]
257            #[allow(rust_2018_idioms)]
258            extern crate npyz as _npyz;
259
260            // if our generated code directly imports any traits, then the #[no_implicit_prelude]
261            // test won't catch accidental use of method syntax on trait methods (which can fail
262            // due to ambiguity with similarly-named methods on other traits).  So if we want to
263            // abbreviate paths, we need to do this instead:
264            use ::std::prelude::v1 as p;
265
266            #code
267        };
268    }
269}
270
271fn unraw(ident: &syn::Ident) -> String {
272    ident.to_string().trim_start_matches("r#").to_owned()
273}