geoip2_codegen/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    parse::{Parse, ParseStream},
5    parse_macro_input,
6    punctuated::Punctuated,
7    DeriveInput, Fields, FieldsNamed, GenericArgument, Ident, ItemStruct, LitStr, PathArguments,
8    Result, Token, Type,
9};
10
11fn extract_field(field_ident: Ident, ty: &Type) -> proc_macro2::TokenStream {
12    match &ty {
13        syn::Type::Path(tp) => {
14            let segment = &tp.path.segments[0];
15            let ident = &segment.ident;
16            match ident.to_string().as_str() {
17                "Option" => match &segment.arguments {
18                    PathArguments::AngleBracketed(ga) => match &ga.args[0] {
19                        GenericArgument::Type(ty) => match ty {
20                            syn::Type::Path(tp) => {
21                                let segment = &tp.path.segments[0];
22                                let ident = &segment.ident;
23                                match ident.to_string().as_str() {
24                                    "u64" => quote! {
25                                        self.#field_ident = Some(read_usize(buffer, offset)? as u64)
26                                    },
27                                    "u32" => quote! {
28                                        self.#field_ident = Some(read_usize(buffer, offset)? as u32)
29                                    },
30                                    "u16" => quote! {
31                                        self.#field_ident = Some(read_usize(buffer, offset)? as u16)
32                                    },
33                                    "f64" => quote! {
34                                        self.#field_ident = Some(read_f64(buffer, offset)?)
35                                    },
36                                    "bool" => quote! {
37                                        self.#field_ident = Some(read_bool(buffer, offset)?)
38                                    },
39                                    "Map" => quote! {
40                                        self.#field_ident = Some(read_map(buffer, offset)?)
41                                    },
42                                    "models" => {
43                                        let ident = &tp.path.segments[1].ident;
44                                        quote! {
45                                            let mut model = models::#ident::default();
46                                            model.from_bytes(buffer, offset)?;
47                                            self.#field_ident = Some(model);
48                                        }
49                                    }
50                                    "Vec" => match &segment.arguments {
51                                        PathArguments::AngleBracketed(ga) => match &ga.args[0] {
52                                            GenericArgument::Type(syn::Type::Path(p)) => {
53                                                let segment = &p.path.segments[0];
54                                                let ident = &segment.ident;
55                                                match ident.to_string().as_str() {
56                                                    "models" => {
57                                                        let ident = &p.path.segments[1].ident;
58                                                        quote! {
59                                                            let (data_type, size) = read_control(buffer, offset)?;
60                                                            self.#field_ident = Some(match data_type {
61                                                                DATA_TYPE_SLICE => {
62                                                                    let mut array: Vec<models::#ident<'a>> = Vec::with_capacity(size);
63                                                                    for _i in 0..size {
64                                                                        let mut model = models::#ident::default();
65                                                                        model.from_bytes(buffer, offset)?;
66                                                                        array.push(model);
67                                                                    }
68                                                                    array
69                                                                }
70                                                                DATA_TYPE_POINTER => {
71                                                                    let offset = &mut read_pointer(buffer, offset, size)?;
72                                                                    let (data_type, size) = read_control(buffer, offset)?;
73                                                                    match data_type {
74                                                                        DATA_TYPE_SLICE => {
75                                                                            let mut array: Vec<models::#ident<'a>> =
76                                                                                Vec::with_capacity(size);
77                                                                            for _ in 0..size {
78                                                                                let mut model = models::#ident::default();
79                                                                                model.from_bytes(buffer, offset)?;
80                                                                                array.push(model);
81                                                                            }
82                                                                            array
83                                                                        }
84                                                                        _ => return Err(Error::InvalidDataType(data_type)),
85                                                                    }
86                                                                }
87                                                                _ => return Err(Error::InvalidDataType(data_type)),
88                                                            })
89                                                        }
90                                                    }
91                                                    _ => unimplemented!(),
92                                                }
93                                            }
94                                            _ => unimplemented!("{:?}", &ga.args[0]),
95                                        },
96                                        _ => unimplemented!("{:?}", &segment.arguments),
97                                    },
98                                    _ => unimplemented!("{:?}", ident),
99                                }
100                            }
101                            syn::Type::Reference(tr) => match tr.elem.as_ref() {
102                                syn::Type::Path(tp) => {
103                                    let segment = &tp.path.segments[0];
104                                    let ident = &segment.ident;
105                                    match ident.to_string().as_str() {
106                                        "str" => quote! {
107                                            self.#field_ident = Some(read_str(buffer, offset)?)
108                                        },
109                                        _ => unimplemented!("{:?}", ident),
110                                    }
111                                }
112                                _ => unimplemented!("{:?}", tr.elem),
113                            },
114                            _ => unimplemented!("{:?}", ty),
115                        },
116                        _ => unimplemented!("{:?}", &ga.args[0]),
117                    },
118                    _ => unimplemented!("{:?}", &segment.arguments),
119                },
120                "Vec" => match &segment.arguments {
121                    PathArguments::AngleBracketed(ga) => match &ga.args[0] {
122                        GenericArgument::Type(ty) => match ty {
123                            syn::Type::Reference(tr) => match tr.elem.as_ref() {
124                                syn::Type::Path(tp) => {
125                                    let segment = &tp.path.segments[0];
126                                    let ident = &segment.ident;
127                                    match ident.to_string().as_str() {
128                                        "str" => quote! {
129                                            self.#field_ident = read_array(buffer, offset)?
130                                        },
131                                        _ => unimplemented!("{:?}", ident),
132                                    }
133                                }
134                                _ => unimplemented!("{:?}", tr.elem),
135                            },
136                            _ => unimplemented!("{:?}", ty),
137                        },
138                        _ => unimplemented!("{:?}", &ga.args[0]),
139                    },
140                    _ => unimplemented!("{:?}", &segment.arguments),
141                },
142                "u64" => quote! {
143                    self.#field_ident = read_usize(buffer, offset)? as u64
144                },
145                "u32" => quote! {
146                    self.#field_ident = read_usize(buffer, offset)? as u32
147                },
148                "u16" => quote! {
149                    self.#field_ident = read_usize(buffer, offset)? as u16
150                },
151                "Map" => quote! {
152                    self.#field_ident = read_map(buffer, offset)?
153                },
154                _ => unimplemented!("{:?}", ident),
155            }
156        }
157        syn::Type::Reference(tr) => match tr.elem.as_ref() {
158            syn::Type::Path(tp) => {
159                let segment = &tp.path.segments[0];
160                let ident = &segment.ident;
161                match ident.to_string().as_str() {
162                    "str" => quote! {
163                        self.#field_ident = read_str(buffer, offset)?
164                    },
165                    _ => unimplemented!("{:?}", ident),
166                }
167            }
168            _ => unimplemented!("{:?}", tr.elem),
169        },
170        _ => unimplemented!("{:?}", ty),
171    }
172}
173
174fn extract_fields(fields: &Fields) -> Vec<proc_macro2::TokenStream> {
175    let fields = if let syn::Fields::Named(FieldsNamed { named, .. }) = fields {
176        named
177    } else {
178        unimplemented!("{:?}", fields);
179    };
180    let mut result = Vec::new();
181    for field in fields.iter() {
182        let field_ident = field.ident.clone().unwrap();
183        let mut field_ident_value = format!("{}", field_ident);
184        if field_ident_value == "country_type" {
185            field_ident_value = "type".into();
186        }
187        let field_stream = extract_field(field_ident, &field.ty);
188        result.push(quote! {
189            #field_ident_value => {
190                #field_stream
191            }
192        });
193    }
194    result
195}
196
197#[proc_macro_derive(Decoder)]
198pub fn derive_decoder(input: TokenStream) -> TokenStream {
199    let DeriveInput {
200        ident,
201        generics,
202        data,
203        ..
204    } = parse_macro_input!(input);
205
206    let fields = if let syn::Data::Struct(s) = data {
207        extract_fields(&s.fields)
208    } else {
209        unimplemented!("{:?}", data)
210    };
211
212    let output = quote! {
213        impl<'a> #ident #generics {
214            pub(crate) fn from_bytes(&mut self, buffer: &'a [u8], offset: &mut usize) -> Result<(), Error> {
215                let (data_type, size) = read_control(buffer, offset)?;
216                match data_type {
217                    DATA_TYPE_MAP => self.from_bytes_map(buffer, offset, size),
218                    DATA_TYPE_POINTER => {
219                        let offset = &mut read_pointer(buffer, offset, size)?;
220                        let (data_type, size) = read_control(buffer, offset)?;
221                        match data_type {
222                            DATA_TYPE_MAP => self.from_bytes_map(buffer, offset, size),
223                            _ => return Err(Error::InvalidDataType(data_type)),
224                        }
225                    }
226                    _ => return Err(Error::InvalidDataType(data_type)),
227                }
228            }
229
230            fn from_bytes_map(
231                &mut self,
232                buffer: &'a [u8],
233                offset: &mut usize,
234                size: usize,
235            ) -> Result<(), Error> {
236                for _ in 0..size {
237                    match read_str(buffer, offset)? {
238                        #(#fields ,)*
239                        field => return Err(Error::UnknownField(field.into()))
240                    }
241                }
242                Ok(())
243            }
244        }
245    };
246
247    output.into()
248}
249
250struct Args {
251    types: Vec<LitStr>,
252}
253
254impl Parse for Args {
255    fn parse(input: ParseStream) -> Result<Self> {
256        let vars = Punctuated::<LitStr, Token![,]>::parse_terminated(input)?;
257        Ok(Args {
258            types: vars.into_iter().collect(),
259        })
260    }
261}
262
263#[proc_macro_attribute]
264pub fn reader(metadata: TokenStream, input: TokenStream) -> TokenStream {
265    let types = parse_macro_input!(metadata as Args).types;
266
267    let types_len = types.len();
268
269    let input = parse_macro_input!(input as ItemStruct);
270    let ident = &input.ident;
271    let generics = &input.generics;
272    let fields = extract_fields(&input.fields);
273
274    let output = quote! {
275        #input
276
277        impl<'a> Reader<'a, #ident #generics> {
278            pub fn from_bytes(buffer: &[u8]) -> Result<Reader<#ident>, Error> {
279                const types: [&'static str; #types_len] = [#(#types ,)*];
280                let reader = Reader::from_bytes_raw(buffer)?;
281                if !types.contains(&reader.metadata.database_type) {
282                    return Err(Error::InvalidDatabaseType(
283                        reader.metadata.database_type.into(),
284                    ));
285                }
286                Ok(reader)
287            }
288
289            pub fn lookup(&self, address: IpAddr) -> Result<#ident, Error> {
290                let mut result = #ident::default();
291                result.from_bytes(self.decoder_buffer, &mut self.get_offset(address)?)?;
292                Ok(result)
293            }
294        }
295
296        impl<'a> #ident #generics {
297            pub(crate) fn from_bytes(&mut self, buffer: &'a [u8], offset: &mut usize) -> Result<(), Error> {
298                let (data_type, size) = read_control(buffer, offset)?;
299                if data_type != DATA_TYPE_MAP {
300                    return Err(Error::InvalidDataType(data_type));
301                }
302                self.from_bytes_map(buffer, offset, size)
303            }
304
305            fn from_bytes_map(
306                &mut self,
307                buffer: &'a [u8],
308                offset: &mut usize,
309                size: usize,
310            ) -> Result<(), Error> {
311                for _ in 0..size {
312                    match read_str(buffer, offset)? {
313                        #(#fields ,)*
314                        field => return Err(Error::UnknownField(field.into()))
315                    }
316                }
317                Ok(())
318            }
319        }
320    };
321    output.into()
322}