proc_contra/
lib.rs

//! Macro implementations for [contra](https://docs.rs/contra)
//!
//! Provides the `Serialize` and `Deserialize` derive macros for structs and enums.

5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{DataEnum, DataStruct, DeriveInput};
8
9/// Derives the *Serialize* trait implementation
10///
11/// # Example
12/// ```ignore
13/// use proc_contra::Serialize;
14///
15/// #[derive(Serialize)]
16/// struct Point {
17///     x: f32,
18///     y: f32,
19///     z: f32
20/// }
21/// ```
22///
23/// Expands into:
24/// ```
25/// use lib_contra::{serialize::Serialize, serialize::Serializer, position::Position, error::SuccessResult};
26///
27/// struct Point {
28///     x: f32,
29///     y: f32,
30///     z: f32
31/// }
32///
33/// impl Serialize for Point {
34///     fn serialize<S: Serializer>(&self, ser: &mut S, _pos: &Position) -> SuccessResult {
35///         ser.begin_struct("Point", 3)?;
36///     
37///         ser.serialize_field("x", &self.x, &Position::Trailing)?;
38///         ser.serialize_field("y", &self.y, &Position::Trailing)?;
39///         ser.serialize_field("z", &self.z, &Position::Closing)?;
40///     
41///         ser.end_struct("Point")?;
42///     
43///         Ok(())
44///     }
45/// }
46/// ```
47///
48#[proc_macro_derive(Serialize)]
49pub fn impl_serialize(input: TokenStream) -> TokenStream {
50    let ast = syn::parse_macro_input!(input as DeriveInput);
51
52    match ast.data {
53        syn::Data::Struct(decl) => gen_struct_serialize(ast.ident, decl),
54        syn::Data::Enum(decl) => gen_enum_serialize(ast.ident, decl),
55        syn::Data::Union(_) => todo!(),
56    }
57}
58
59/// Derives the *Deserialize* trait implementation
60///
61/// # Example
62/// ```ignore
63/// use proc_contra::Deserialize;
64/// use lib_contra::{deserialize::Deserialize, position::Position, deserialize::Deserializer, error::AnyError};
65/// #[derive(Deserialize)]
66/// struct Point {
67///     x: f32,
68///     y: f32,
69///     z: f32
70/// }
71/// ```
72///
73/// Expands into:
74/// ```
75/// use lib_contra::{deserialize::{MapAccess, Visitor, Deserialize}, position::Position, deserialize::Deserializer, error::AnyError};
76///
77/// struct Point {
78///     x: f32,
79///     y: f32,
80///     z: f32
81/// }
82///
83/// impl Deserialize for Point {
84///     fn deserialize<D: Deserializer>(de: D) -> Result<Self, AnyError> {
85///         enum Field {
86///             x, y, z
87///         }
88///         impl Deserialize for Field {
89///             fn deserialize<D: Deserializer>(de: D) -> Result<Self, AnyError> {
90///                 struct FieldVisitor {}
91///                 impl Visitor for FieldVisitor {
92///                     type Value = Field;
93///                     fn expected_a(self) -> String { "Point field".to_string() }
94///                     fn visit_str(self, v: &str) -> Result<Self::Value, AnyError> {
95///                         match v {
96///                             "x" => Ok(Field::x),
97///                             "y" => Ok(Field::y),
98///                             "z" => Ok(Field::z),
99///                             val => Err(format!("unexpected Point field {}", val).into())
100///                         }
101///                     }
102///                 }
103///                 de.deserialize_str(FieldVisitor {})
104///             }
105///         }
106///
107///         struct PointVisitor {}
108///         impl Visitor for PointVisitor {
109///             type Value = Point;
110///             fn expected_a(self) -> String { "Point object".to_string() }
111///             fn visit_map<M: MapAccess>(self, mut map: M) -> Result<Self::Value, AnyError> {
112///                 let mut x = None;
113///                 let mut y = None;
114///                 let mut z = None;
115///                 
116///                 while let Some(key) = map.next_key()? {
117///                     match key {
118///                         Field::x => { if x.is_some() { return Err("duplicate field x".into()); } x = Some(map.next_value()?) },
119///                         Field::y => { if y.is_some() { return Err("duplicate field y".into()); } y = Some(map.next_value()?) },
120///                         Field::z => { if z.is_some() { return Err("duplicate field z".into()); } z = Some(map.next_value()?) },
121///                     }
122///                 }
123///
124///                 let x = x.ok_or_else(|| "missing field x")?;
125///                 let y = y.ok_or_else(|| "missing field y")?;
126///                 let z = z.ok_or_else(|| "missing field z")?;
127///
128///                 Ok(Point {
129///                     x, y, z
130///                 })
131///             }
132///         }
133///
134///         de.deserialize_struct(PointVisitor {})
135///     }
136/// }
137/// ```
138#[proc_macro_derive(Deserialize)]
139pub fn impl_deserialize(input: TokenStream) -> TokenStream {
140    let ast = syn::parse_macro_input!(input as DeriveInput);
141
142    match ast.data {
143        syn::Data::Struct(decl) => gen_struct_deserialize(ast.ident, decl),
144        syn::Data::Enum(decl) => gen_enum_deserialize(ast.ident, decl),
145        syn::Data::Union(_) => todo!(),
146    }
147}
148
149fn gen_struct_serialize(ident: syn::Ident, decl: DataStruct) -> TokenStream {
150    let c_ident = ident;
151    let n_fields = decl.fields.len();
152    let mut ser_fields = decl
153        .fields
154        .into_iter()
155        .map(|f| f.ident)
156        .filter(|f| f.is_some())
157        .map(|f| f.unwrap());
158    let closing_field = ser_fields.next_back()
159        .map(|f| Some(quote!(ser.serialize_field(stringify!(#f), &self.#f, &contra::lib_contra::position::Position::Closing )?; ))).into_iter();
160    let trailing_fields = ser_fields
161        .map(|f| Some(quote!(ser.serialize_field(stringify!(#f), &self.#f, &contra::lib_contra::position::Position::Trailing)?; ))).into_iter();
162    let ser_fields = trailing_fields
163        .chain(closing_field.into_iter())
164        .filter(|f| f.is_some());
165
166    quote!(
167        impl contra::lib_contra::serialize::Serialize for #c_ident {
168            fn serialize<S: contra::lib_contra::serialize::Serializer>(&self, ser: &mut S, _pos: &contra::lib_contra::position::Position) -> contra::lib_contra::error::SuccessResult {
169                ser.begin_struct(stringify!(#c_ident), #n_fields)?;
170
171                #(#ser_fields)*
172
173                ser.end_struct(stringify!(#c_ident))?;
174
175                Ok(())
176            }
177        }
178    ).into()
179}
180
181fn gen_enum_serialize(ident: syn::Ident, decl: DataEnum) -> TokenStream {
182    let e_ident = ident;
183    let variants = decl.variants.into_iter().map(|v| v.ident);
184
185    let ser_variants = variants
186        .clone()
187        .map(|v| quote! { #e_ident::#v => ser.serialize_str(stringify!(#v)) });
188
189    quote!(
190        impl contra::lib_contra::serialize::Serialize for #e_ident {
191            fn serialize<S: contra::lib_contra::serialize::Serializer>(&self, ser: &mut S, _pos: &contra::lib_contra::position::Position) -> contra::lib_contra::error::SuccessResult {
192                match self {
193                    #(#ser_variants,)*
194                }
195            }
196        }
197    ).into()
198}
199
200fn gen_enum_deserialize(ident: syn::Ident, decl: DataEnum) -> TokenStream {
201    let e_ident = ident;
202    let variants = decl.variants.into_iter().map(|v| v.ident);
203
204    let parse_variants = variants
205        .clone()
206        .map(|v| quote! { stringify!(#v) => Ok(#e_ident::#v) });
207
208    quote! {
209        impl contra::lib_contra::deserialize::Deserialize for #e_ident {
210            fn deserialize<D: contra::lib_contra::deserialize::Deserializer>(des: D) -> Result<Self, contra::lib_contra::error::AnyError> {
211                struct EnumVisitor {}
212                impl contra::lib_contra::deserialize::Visitor for EnumVisitor {
213                    type Value = #e_ident;
214
215                    fn expected_a(self) -> String {
216                        concat!(stringify!(#e_ident), " variant").to_string()
217                    }
218
219                    fn visit_str(self, v: &str) -> Result<Self::Value, contra::lib_contra::error::AnyError> {
220                        match v {
221                            #(#parse_variants,)*
222                            err => Err(format!("invalid {} variant \"{}\"", stringify!(#e_ident), err).into())
223                        }
224                    }
225                }
226
227                des.deserialize_str(EnumVisitor {})
228            }
229        }
230    }.into()
231}
232
233fn gen_struct_deserialize(ident: syn::Ident, decl: DataStruct) -> TokenStream {
234    let c_ident = ident;
235    let f_idents = decl
236        .fields
237        .into_iter()
238        .map(|f| f.ident)
239        .filter(|f| f.is_some())
240        .map(|f| f.unwrap());
241
242    let field_enum_decl = f_idents.clone().map(|i| quote! { #i });
243    let field_enum_parse = f_idents
244        .clone()
245        .map(|i| quote! { stringify!(#i) => Ok(Field::#i) });
246    let tmp_field_decl = f_idents.clone().map(|i| quote! { let mut #i = None });
247    let tmp_field_parse = f_idents.clone().map(|i| {
248        quote! {
249            Field::#i => {
250                if #i.is_some() {
251                    return Err(concat!("duplicate field ", stringify!(#i)).into());
252                }
253                #i = Some(map.next_value()?)
254            }
255        }
256    });
257    let tmp_field_result = f_idents
258        .clone()
259        .map(|i| quote! { let #i = #i.ok_or_else(|| concat!("missing field ", stringify!(#i)))? });
260    let tmp_field_initializer_list = f_idents.clone().map(|i| quote! { #i });
261
262    quote!(
263        impl contra::lib_contra::deserialize::Deserialize for #c_ident {
264            fn deserialize<D: contra::lib_contra::deserialize::Deserializer>(de: D) -> Result<Self, contra::lib_contra::error::AnyError> {
265                enum Field {
266                    #(#field_enum_decl,)*
267                }
268                impl contra::lib_contra::deserialize::Deserialize for Field {
269                    fn deserialize<D: contra::lib_contra::deserialize::Deserializer>(de: D) -> Result<Self, contra::lib_contra::error::AnyError> {
270                        struct FieldVisitor {}
271                        impl contra::lib_contra::deserialize::Visitor for FieldVisitor {
272                            type Value = Field;
273                            fn expected_a(self) -> String {
274                                concat!(stringify!(#c_ident), " field").into()
275                            }
276                            fn visit_str(self, v: &str) -> Result<Self::Value, contra::lib_contra::error::AnyError> {
277                                match v {
278                                    #(#field_enum_parse,)*
279                                    val => Err(format!("unknown \"{}\" field for {}", val, stringify!(#c_ident)).into())
280                                }
281                            }
282                        }
283                        de.deserialize_str(FieldVisitor {})
284                    }
285                }
286
287                struct StructVisitor {}
288                impl contra::lib_contra::deserialize::Visitor for StructVisitor {
289                    type Value = #c_ident;
290                    fn expected_a(self) -> String {
291                        concat!(stringify!(#c_ident), " object").into()
292                    }
293                    fn visit_map<M: contra::lib_contra::deserialize::MapAccess>(self, mut map: M) -> Result<Self::Value, contra::lib_contra::error::AnyError> {
294                        #(#tmp_field_decl;)*
295
296                        while let Some(key) = map.next_key::<Field>()? {
297                            match key {
298                                #(#tmp_field_parse,)*
299                            }
300                        }
301
302                        #(#tmp_field_result;)*
303
304                        Ok(#c_ident {
305                            #(#tmp_field_initializer_list,)*
306                        })
307                    }
308                }
309
310                de.deserialize_struct(StructVisitor {})
311            }
312        }
313    ).into()
314}