epee_encoding_derive/
lib.rs

1#![no_std]
2
3extern crate alloc;
4
5use alloc::format;
6use alloc::string::ToString;
7
8use proc_macro2::{Ident, Span, TokenStream};
9use quote::quote;
10use syn::{
11    parse_macro_input, parse_quote, Data, DeriveInput, Expr, Fields, GenericParam, Generics, Lit,
12    Type,
13};
14
15#[proc_macro_derive(
16    EpeeObject,
17    attributes(epee_default, epee_alt_name, epee_flatten, epee_try_from_into)
18)]
19pub fn derive_epee_object(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
20    // Parse the input tokens into a syntax tree.
21    let input = parse_macro_input!(input as DeriveInput);
22
23    let struct_name = input.ident;
24
25    let generics = add_trait_bounds(input.generics);
26    let (_impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();
27
28    let output = match input.data {
29        Data::Struct(data) => build(&data.fields, &struct_name),
30        _ => panic!("Only structs can be epee objects"),
31    };
32
33    output.into()
34}
35
36fn add_trait_bounds(mut generics: Generics) -> Generics {
37    for param in &mut generics.params {
38        if let GenericParam::Type(ref mut type_param) = *param {
39            type_param
40                .bounds
41                .push(parse_quote!(epee_encoding::EpeeValue));
42        }
43    }
44    generics
45}
46
47fn build(fields: &Fields, struct_name: &Ident) -> TokenStream {
48    let mut struct_fields = TokenStream::new();
49    let mut default_values = TokenStream::new();
50    let mut count_fields = TokenStream::new();
51    let mut write_fields = TokenStream::new();
52
53    let mut read_match_body = TokenStream::new();
54    let mut read_catch_all = TokenStream::new();
55
56    let mut object_finish = TokenStream::new();
57
58    let numb_o_fields: u64 = fields.len().try_into().unwrap();
59
60    for field in fields {
61        let field_name = field.ident.clone().expect("Epee only accepts named fields");
62        let field_type = &field.ty;
63        // If this field has a default value find it
64        let default_val: Option<Expr> = field
65            .attrs
66            .iter()
67            .find(|f| f.path().is_ident("epee_default"))
68            .map(|f| f.parse_args().unwrap());
69        // If this field has a different name when encoded find it
70        let alt_name: Option<Lit> = field
71            .attrs
72            .iter()
73            .find(|f| f.path().is_ident("epee_alt_name"))
74            .map(|f| f.parse_args().unwrap());
75
76        let is_flattened = field
77            .attrs
78            .iter()
79            .any(|f| f.path().is_ident("epee_flatten"));
80
81        let try_from_into: Option<Type> = field
82            .attrs
83            .iter()
84            .find(|f| f.path().is_ident("epee_try_from_into"))
85            .map(|f| f.parse_args().unwrap());
86
87        // Gets this objects epee name, the name its encoded with
88        let epee_name = if let Some(alt) = alt_name {
89            if is_flattened {
90                panic!("Cant rename a flattened field")
91            }
92            match alt {
93                Lit::Str(name) => name.value(),
94                _ => panic!("Alt name was not a string"),
95            }
96        } else {
97            field_name.to_string()
98        };
99
100        if try_from_into.is_some() && is_flattened {
101            panic!("Cant flatten this field: {}", field_name);
102        }
103
104        // This is fields part of a struct:
105        // struct T {
106        //  #struct_fields
107        // }
108        if is_flattened {
109            struct_fields = quote! {
110                #struct_fields
111                #field_name: <#field_type as epee_encoding::EpeeObject>::Builder,
112            };
113
114            count_fields = quote! {
115                #count_fields
116                // This filed has been flattened so dont count it.
117                numb_o_fields -= 1;
118                // Add the flattend fields to this one.
119                numb_o_fields += self.#field_name.number_of_fields();
120
121            };
122        } else if let Some(try_from_into) = &try_from_into {
123            struct_fields = quote! {
124                #struct_fields
125                #field_name: (Option<#try_from_into>, bool),
126            };
127        } else {
128            struct_fields = quote! {
129                #struct_fields
130                #field_name: (Option<#field_type>, bool),
131            };
132        }
133
134        let inner_write_field = if let Some(try_from_into) = &try_from_into {
135            quote! {
136                epee_encoding::write_field(&Into::<#try_from_into>::into(self.#field_name.clone()), &#epee_name, w)?;
137            }
138        } else {
139            quote! {
140                epee_encoding::write_field(&self.#field_name, &#epee_name, w)?;
141            }
142        };
143
144        // `default_val`: this is the body of a default function:
145        // fn default() -> Self {
146        //    Self {
147        //       #default_values
148        //    }
149        // }
150
151        // `count_fields`: this is the part of the write function that takes
152        // away from the number of fields if the field is the default value.
153
154        // `write_fields`: this is the part of the write function that writes
155        // this specific epee field.
156        if let Some(default_val) = default_val {
157            if is_flattened {
158                panic!("Cant have a default on a flattened field");
159            };
160
161            default_values = quote! {
162                #default_values
163                #field_name: (Some(#default_val), false),
164            };
165
166            if try_from_into.is_some() {
167                count_fields = quote! {
168                    #count_fields
169                    if self.#field_name == #default_val.into() {
170                        numb_o_fields -= 1;
171                    };
172                };
173                write_fields = quote! {
174                    #write_fields
175                    if self.#field_name != #default_val.into() {
176                         #inner_write_field
177                    }
178                }
179            } else {
180                count_fields = quote! {
181                    #count_fields
182                    if self.#field_name == #default_val {
183                        numb_o_fields -= 1;
184                    };
185                };
186
187                write_fields = quote! {
188                    #write_fields
189                    if self.#field_name != #default_val {
190                         #inner_write_field
191                    }
192                }
193            }
194        } else if !is_flattened {
195            if let Some(try_from_into) = &try_from_into {
196                count_fields = quote! {
197                    #count_fields
198                    if !epee_encoding::EpeeValue::should_write(&Into::<#try_from_into>::into(self.#field_name.clone())) {
199                        numb_o_fields -= 1;
200                    };
201                };
202            } else {
203                count_fields = quote! {
204                    #count_fields
205                    if !epee_encoding::EpeeValue::should_write(&self.#field_name) {
206                        numb_o_fields -= 1;
207                    };
208                };
209            }
210            default_values = quote! {
211                #default_values
212                #field_name: (epee_encoding::EpeeValue::epee_default_value(), false),
213            };
214
215            write_fields = quote! {
216                #write_fields
217                #inner_write_field
218            };
219        } else {
220            default_values = quote! {
221                #default_values
222                #field_name: Default::default(),
223            };
224
225            write_fields = quote! {
226                #write_fields
227                self.#field_name.write_fields(w)?;
228            };
229        };
230
231        // This is what these values do:
232        // fn add_field(name: &str, r: &mut r) -> Result<bool> {
233        //    match name {
234        //        #read_match_body
235        //        _ => {
236        //           #read_catch_all
237        //           return Ok(false);
238        //         }
239        //    }
240        //    Ok(true)
241        // }
242        if is_flattened {
243            read_catch_all = quote! {
244                #read_catch_all
245                if self.#field_name.add_field(name, r)? {
246                    return Ok(true);
247                };
248            };
249
250            object_finish = quote! {
251                #object_finish
252                #field_name: self.#field_name.finish()?,
253            };
254        } else {
255            if try_from_into.is_some() {
256                object_finish = quote! {
257                    #object_finish
258                    #field_name: self.#field_name.0.ok_or_else(|| epee_encoding::error::Error::Format("Required field was not found!"))?
259                                 .try_into().map_err(|_| epee_encoding::error::Error::Format("Error converting data using try_into"))?,
260                };
261            } else {
262                object_finish = quote! {
263                    #object_finish
264                    #field_name: self.#field_name.0.ok_or_else(|| epee_encoding::error::Error::Format("Required field was not found!"))?,
265                };
266            }
267            read_match_body = quote! {
268                #read_match_body
269                #epee_name => {
270                    self.#field_name.0.replace(epee_encoding::read_epee_value(r)?);
271                    if self.#field_name.1 {
272                        return Err(epee_encoding::error::Error::Format("Double key in data!"))
273                    }
274                    self.#field_name.1 = true;
275                },
276            };
277        }
278    }
279
280    let builder_name = Ident::new(&format!("__{}EpeeBuilder", struct_name), Span::call_site());
281    let mod_name = Ident::new(&format!("__{}_epee_module", struct_name), Span::call_site());
282
283    let builder_impl = quote! {
284        pub struct #builder_name {
285            #struct_fields
286        }
287
288        impl Default for #builder_name {
289            fn default() -> Self {
290                Self {
291                    #default_values
292                }
293            }
294        }
295
296        impl epee_encoding::EpeeObjectBuilder<#struct_name> for #builder_name {
297            fn add_field<R: epee_encoding::io::Read>(&mut self, name: &str, r: &mut R) -> epee_encoding::error::Result<bool> {
298                match name {
299                    #read_match_body
300                    _ => {
301                        #read_catch_all
302                        return Ok(false);
303                    }
304                };
305
306                Ok(true)
307            }
308
309            fn finish(self) -> epee_encoding::error::Result<#struct_name> {
310                Ok(#struct_name {
311                    #object_finish
312                })
313            }
314        }
315    };
316
317    let object_impl = quote! {
318        impl EpeeObject for #struct_name {
319            type Builder = #mod_name::#builder_name;
320
321            fn number_of_fields(&self) -> u64 {
322                let mut numb_o_fields: u64 = #numb_o_fields;
323                #count_fields
324                numb_o_fields
325            }
326
327
328            fn write_fields<W: epee_encoding::io::Write>(&self, w: &mut W) -> epee_encoding::error::Result<()> {
329
330                #write_fields
331
332                Ok(())
333            }
334        }
335    };
336
337    quote! {
338        mod #mod_name {
339            use super::*;
340            #builder_impl
341        }
342
343        #object_impl
344    }
345}