1#![recursion_limit = "256"]
2
3use proc_macro::{TokenStream as TokenStream1};
13use proc_macro2::{Span, TokenStream};
14use quote::quote;
15
16#[proc_macro_derive(Serialize)]
17pub fn npy_serialize(input: TokenStream1) -> TokenStream1 {
18 let ast = syn::parse(input).unwrap();
20
21 let expanded = impl_npy_serialize(&ast);
23
24 expanded.into()
26}
27
28#[proc_macro_derive(Deserialize)]
29pub fn npy_deserialize(input: TokenStream1) -> TokenStream1 {
30 let ast = syn::parse(input).unwrap();
32
33 let expanded = impl_npy_deserialize(&ast);
35
36 expanded.into()
38}
39
40#[proc_macro_derive(AutoSerialize)]
41pub fn npy_auto_serialize(input: TokenStream1) -> TokenStream1 {
42 let ast = syn::parse(input).unwrap();
44
45 let expanded = impl_npy_auto_serialize(&ast);
47
48 expanded.into()
50}
51
52struct FieldData {
53 idents: Vec<syn::Ident>,
54 idents_str: Vec<String>,
55 types: Vec<TokenStream>,
56}
57
58impl FieldData {
59 fn extract(ast: &syn::DeriveInput) -> Self {
60 let fields = match ast.data {
61 syn::Data::Struct(ref data) => &data.fields,
62 _ => panic!("npyz derive macros can only be used with structs"),
63 };
64
65 let idents: Vec<syn::Ident> = fields.iter().map(|f| {
66 f.ident.clone().expect("Tuple structs not supported")
67 }).collect();
68 let idents_str = idents.iter().map(|t| unraw(t)).collect::<Vec<_>>();
69
70 let types: Vec<TokenStream> = fields.iter().map(|f| {
71 let ty = &f.ty;
72 quote!( #ty )
73 }).collect::<Vec<_>>();
74
75 FieldData { idents, idents_str, types }
76 }
77}
78
79fn impl_npy_serialize(ast: &syn::DeriveInput) -> TokenStream {
80 let name = &ast.ident;
81 let vis = &ast.vis;
82 let FieldData { ref idents, ref idents_str, ref types } = FieldData::extract(ast);
83
84 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
85 let field_dtypes_struct = gen_field_dtypes_struct(idents, idents_str);
86
87 let idents_1 = idents;
88
89 wrap_in_const("Serialize", &name, quote! {
90 use ::std::io;
91
92 #vis struct GeneratedWriter #ty_generics #where_clause {
93 writers: FieldWriters #ty_generics,
94 }
95
96 struct FieldWriters #ty_generics #where_clause {
97 #( #idents: <#types as _npyz::Serialize>::TypeWriter ,)*
98 }
99
100 #field_dtypes_struct
101
102 impl #impl_generics _npyz::TypeWrite for GeneratedWriter #ty_generics #where_clause {
103 type Value = #name #ty_generics;
104
105 #[allow(unused_mut)]
106 fn write_one<W: io::Write>(&self, mut w: W, value: &Self::Value) -> io::Result<()> {
107 #({ let method = <<#types as _npyz::Serialize>::TypeWriter as _npyz::TypeWrite>::write_one;
109 method(&self.writers.#idents, &mut w, &value.#idents_1)?;
110 })*
111 p::Ok(())
112 }
113 }
114
115 impl #impl_generics _npyz::Serialize for #name #ty_generics #where_clause {
116 type TypeWriter = GeneratedWriter #ty_generics;
117
118 fn writer(dtype: &_npyz::DType) -> p::Result<GeneratedWriter, _npyz::DTypeError> {
119 let dtypes = FieldDTypes::extract(dtype)?;
120 let writers = FieldWriters {
121 #( #idents: <#types as _npyz::Serialize>::writer(&dtypes.#idents_1)? ,)*
122 };
123
124 p::Ok(GeneratedWriter { writers })
125 }
126 }
127 })
128}
129
130fn impl_npy_deserialize(ast: &syn::DeriveInput) -> TokenStream {
131 let name = &ast.ident;
132 let vis = &ast.vis;
133 let FieldData { ref idents, ref idents_str, ref types } = FieldData::extract(ast);
134
135 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
136 let field_dtypes_struct = gen_field_dtypes_struct(idents, idents_str);
137
138 let idents_1 = idents;
139
140 wrap_in_const("Deserialize", &name, quote! {
141 use ::std::io;
142
143 #vis struct GeneratedReader #ty_generics #where_clause {
144 readers: FieldReaders #ty_generics,
145 }
146
147 struct FieldReaders #ty_generics #where_clause {
148 #( #idents: <#types as _npyz::Deserialize>::TypeReader ,)*
149 }
150
151 #field_dtypes_struct
152
153 impl #impl_generics _npyz::TypeRead for GeneratedReader #ty_generics #where_clause {
154 type Value = #name #ty_generics;
155
156 #[allow(unused_mut)]
157 fn read_one<R: io::Read>(&self, mut reader: R) -> io::Result<Self::Value> {
158 #(
159 let func = <<#types as _npyz::Deserialize>::TypeReader as _npyz::TypeRead>::read_one;
160 let #idents = func(&self.readers.#idents_1, &mut reader)?;
161 )*
162 io::Result::Ok(#name { #( #idents ),* })
163 }
164 }
165
166 impl #impl_generics _npyz::Deserialize for #name #ty_generics #where_clause {
167 type TypeReader = GeneratedReader #ty_generics;
168
169 fn reader(dtype: &_npyz::DType) -> p::Result<GeneratedReader, _npyz::DTypeError> {
170 let dtypes = FieldDTypes::extract(dtype)?;
171 let readers = FieldReaders {
172 #( #idents: <#types as _npyz::Deserialize>::reader(&dtypes.#idents_1)? ,)*
173 };
174
175 p::Ok(GeneratedReader { readers })
176 }
177 }
178 })
179}
180
181fn impl_npy_auto_serialize(ast: &syn::DeriveInput) -> TokenStream {
182 let name = &ast.ident;
183 let FieldData { idents: _, ref idents_str, ref types } = FieldData::extract(ast);
184
185 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
186
187 wrap_in_const("AutoSerialize", &name, quote! {
188 impl #impl_generics _npyz::AutoSerialize for #name #ty_generics #where_clause {
189 fn default_dtype() -> _npyz::DType {
190 _npyz::DType::Record(::std::vec![#(
191 _npyz::Field {
192 name: p::ToString::to_string(#idents_str),
193 dtype: <#types as _npyz::AutoSerialize>::default_dtype()
194 }
195 ),*])
196 }
197 }
198 })
199}
200
201fn gen_field_dtypes_struct(
202 idents: &[syn::Ident],
203 idents_str: &[String],
204) -> TokenStream {
205 assert_eq!(idents.len(), idents_str.len());
206 quote!{
207 struct FieldDTypes {
208 #( #idents : _npyz::DType ,)*
209 }
210
211 impl FieldDTypes {
212 fn extract(dtype: &_npyz::DType) -> p::Result<Self, _npyz::DTypeError> {
213 let fields = match dtype {
214 _npyz::DType::Record(fields) => fields,
215 ty => return p::Err(_npyz::DTypeError::expected_record(ty)),
216 };
217
218 let correct_names: &[&str] = &[ #(#idents_str),* ];
219
220 if p::Iterator::ne(
221 p::Iterator::map(fields.iter(), |f| &f.name[..]),
222 p::Iterator::cloned(correct_names.iter()),
223 ) {
224 let actual_names = p::Iterator::map(fields.iter(), |f| &f.name[..]);
225 return p::Err(_npyz::DTypeError::wrong_fields(actual_names, correct_names));
226 }
227
228 #[allow(unused_mut)]
229 let mut fields = p::IntoIterator::into_iter(fields);
230 p::Result::Ok(FieldDTypes {
231 #( #idents : {
232 let field = p::Iterator::next(&mut fields).unwrap();
233 p::Clone::clone(&field.dtype)
234 },)*
235 })
236 }
237 }
238 }
239}
240
241fn wrap_in_const(
243 trait_: &str,
244 ty: &syn::Ident,
245 code: TokenStream,
246) -> TokenStream {
247 let dummy_const = syn::Ident::new(
248 &format!("__IMPL_npy_{}_FOR_{}", trait_, unraw(ty)),
249 Span::call_site(),
250 );
251
252 quote! {
253 #[allow(non_upper_case_globals, unused_attributes, unused_qualifications)]
254 const #dummy_const: () = {
255 #[allow(unknown_lints)]
256 #[cfg_attr(feature = "cargo-clippy", allow(useless_attribute))]
257 #[allow(rust_2018_idioms)]
258 extern crate npyz as _npyz;
259
260 use ::std::prelude::v1 as p;
265
266 #code
267 };
268 }
269}
270
271fn unraw(ident: &syn::Ident) -> String {
272 ident.to_string().trim_start_matches("r#").to_owned()
273}