netgauze_serde_macros/
lib.rs

// Copyright (C) 2022-present The NetGauze Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use quote::{format_ident, quote, TokenStreamExt};
17use syn::{parse::Parse, spanned::Spanned, Expr, Lit};
18
19#[derive(Debug)]
20struct AttributeNameValue {
21    ident: syn::Ident,
22    value: String,
23}
24
25#[derive(Debug)]
26struct Attribute {
27    value: Vec<AttributeNameValue>,
28}
29
30fn parse_attribute(attr: &syn::Attribute) -> Result<Option<Attribute>, syn::Error> {
31    let mut out = vec![];
32    let expr = attr.parse_args_with(syn::Expr::parse)?;
33    if let Expr::Assign(assign) = expr {
34        if let (Expr::Path(left_path), Expr::Lit(right_lit)) =
35            (assign.left.as_ref(), assign.right.as_ref())
36        {
37            if let (Some(ident), Lit::Str(str_lit)) = (left_path.path.get_ident(), &right_lit.lit) {
38                let attr = AttributeNameValue {
39                    ident: ident.clone(),
40                    value: str_lit.value(),
41                };
42                out.push(attr);
43            }
44        }
45    }
46    Ok(Some(Attribute { value: out }))
47}
48
49fn filter_attribute_by_name(
50    enum_data: &syn::DataEnum,
51    filter: &str,
52) -> syn::Result<(Vec<syn::Ident>, Vec<syn::Ident>)> {
53    let mut variants = vec![];
54    let mut idents = vec![];
55    for variant in &enum_data.variants {
56        for field in &variant.fields {
57            for attr in field.attrs.iter().filter(|attr| {
58                attr.path()
59                    .segments
60                    .iter()
61                    .any(|seg| seg.ident == syn::Ident::new(filter, seg.span()))
62            }) {
63                if let syn::Type::Path(path) = &field.ty {
64                    variants.push(variant.ident.clone());
65                    let ident = path.path.get_ident();
66                    match ident {
67                        Some(ident) => idents.push(ident.clone()),
68                        None => {
69                            return Err(syn::Error::new(
70                                attr.span(),
71                                "Couldn't find identifier for this attribute",
72                            ));
73                        }
74                    }
75                }
76            }
77        }
78    }
79    Ok((variants, idents))
80}
81
82fn filter_attribute_by_name_with_module(
83    enum_data: &syn::DataEnum,
84    filter: &str,
85) -> (Vec<syn::Ident>, Vec<(Vec<syn::Ident>, syn::Ident)>) {
86    let mut variants = vec![];
87    let mut idents = vec![];
88    for variant in &enum_data.variants {
89        for field in &variant.fields {
90            for _attr in field.attrs.iter().filter(|attr| {
91                attr.path()
92                    .segments
93                    .iter()
94                    .any(|seg| seg.ident == syn::Ident::new(filter, seg.span()))
95            }) {
96                if let syn::Type::Path(path) = &field.ty {
97                    variants.push(variant.ident.clone());
98                    let ident = path.path.get_ident();
99                    match ident {
100                        Some(ident) => idents.push((vec![format_ident!("self")], ident.clone())),
101                        None => {
102                            let segments = path.path.segments.iter().collect::<Vec<_>>();
103                            let module_path = segments.as_slice()[0..segments.len() - 1]
104                                .iter()
105                                .map(|x| x.ident.clone())
106                                .collect::<Vec<_>>();
107                            let from_ident = segments.last().unwrap().ident.clone();
108                            idents.push((module_path, from_ident));
109                        }
110                    }
111                }
112            }
113        }
114    }
115    (variants, idents)
116}
117
/// Field-less marker type that namespaces the code generation behind
/// `#[derive(LocatedError)]`.
#[derive(Debug)]
struct LocatedError;
120
121impl LocatedError {
122    fn get_from_nom(enum_data: &syn::DataEnum) -> syn::Result<Vec<syn::Ident>> {
123        let mut from_nom_variants = vec![];
124        for variant in &enum_data.variants {
125            for field in &variant.fields {
126                for _ in field.attrs.iter().filter(|attr| {
127                    attr.path()
128                        .segments
129                        .iter()
130                        .any(|seg| seg.ident == syn::Ident::new("from_nom", seg.span()))
131                }) {
132                    if let syn::Type::Path(_) = &field.ty {
133                        from_nom_variants.push(variant.ident.clone());
134                    }
135                }
136            }
137        }
138        Ok(from_nom_variants)
139    }
140
141    fn get_from_located(
142        enum_data: &syn::DataEnum,
143    ) -> syn::Result<Vec<(syn::Ident, syn::Ident, Vec<syn::Ident>)>> {
144        let mut ret = vec![];
145        for variant in &enum_data.variants {
146            for field in &variant.fields {
147                for attr in field.attrs.iter().filter(|attr| {
148                    attr.path()
149                        .segments
150                        .iter()
151                        .any(|seg| seg.ident == syn::Ident::new("from_located", seg.span()))
152                }) {
153                    if let syn::Type::Path(path) = &field.ty {
154                        let located_variants = variant.ident.clone();
155                        let ident = path.path.get_ident();
156                        let (ident_module, located_ident) = match ident {
157                            Some(ident) => (None, format_ident!("Located{}", ident.clone())),
158                            None => {
159                                let path = path
160                                    .path
161                                    .segments
162                                    .iter()
163                                    .map(|x| x.ident.to_string())
164                                    .collect::<Vec<String>>();
165                                let ident_string = path.join("::");
166                                let ident_module = ident_string
167                                    [..ident_string.rfind("::").unwrap_or(0)]
168                                    .to_string();
169                                let ident_module = if ident_module.is_empty() {
170                                    None
171                                } else {
172                                    Some(ident_module)
173                                };
174                                let located_ident = ident_string
175                                    [ident_string.rfind("::").map(|x| x + 2).unwrap_or(0)..]
176                                    .to_string();
177                                (ident_module, format_ident!("Located{}", located_ident))
178                            }
179                        };
180                        let located_module = match parse_attribute(attr)? {
181                            None => {
182                                return Err(syn::Error::new(
183                                    attr.span(),
184                                    "'module' must be defined",
185                                ));
186                            }
187                            Some(parsed_attr) => {
188                                match parsed_attr.value.first() {
189                                    None => {
190                                        return Err(syn::Error::new(
191                                            attr.span(),
192                                            "'module' of the Located error is not defined",
193                                        ));
194                                    }
195                                    Some(name_value) => {
196                                        if name_value.ident != format_ident!("module") {
197                                            return Err(syn::Error::new(
198                                                attr.span(),
199                                                format!("Only accepts one attribute 'module', found {:?}", name_value.ident),
200                                            ));
201                                        }
202                                        let mut module_path = name_value.value.clone();
203                                        if let Some(path) = ident_module {
204                                            if !module_path.is_empty() {
205                                                module_path.push_str("::");
206                                            }
207                                            module_path.push_str(path.as_str());
208                                        }
209                                        module_path
210                                            .split("::")
211                                            .map(|part| format_ident!("{}", part))
212                                            .collect()
213                                    }
214                                }
215                            }
216                        };
217                        ret.push((located_variants, located_ident, located_module));
218                    }
219                }
220            }
221        }
222        Ok(ret)
223    }
224
225    fn from(input: &syn::DeriveInput) -> Result<proc_macro::TokenStream, syn::Error> {
226        let syn::Data::Enum(en) = &input.data else {
227            return Err(syn::Error::new(
228                input.span(),
229                "Works only with enum error types",
230            ));
231        };
232        let ident = input.ident.clone();
233        let located_struct_name: syn::Ident = format_ident!("Located{}", ident);
234
235        let from_nom_variants = LocatedError::get_from_nom(en)?;
236        let (from_external_variants, from_external_ident) =
237            filter_attribute_by_name(en, "from_external")?;
238        let from_located = LocatedError::get_from_located(en)?;
239
240        let mut output = quote! {
241            #[derive(PartialEq, Clone, Debug)]
242            #[automatically_derived]
243            pub struct #located_struct_name<'a> {
244                span: netgauze_parse_utils::Span<'a>,
245                error: #ident,
246            }
247
248            #[automatically_derived]
249            impl<'a> #located_struct_name<'a> {
250                pub const fn new(span: netgauze_parse_utils::Span<'a>, error: #ident) -> Self {
251                    Self { span, error }
252                }
253            }
254
255            #[automatically_derived]
256            impl<'a> From<#located_struct_name<'a>> for (netgauze_parse_utils::Span<'a>, #ident) {
257                fn from(input: #located_struct_name<'a>) -> Self {
258                    (input.span, input.error)
259                }
260            }
261
262            #[automatically_derived]
263            impl<'a> netgauze_parse_utils::LocatedParsingError for #located_struct_name<'a> {
264                type Span = netgauze_parse_utils::Span<'a>;
265                type Error = #ident;
266
267                fn span(&self) -> &Self::Span {
268                    &self.span
269                }
270
271                fn error(&self) -> &Self::Error {
272                    &self.error
273                }
274            }
275
276            #[automatically_derived]
277            impl<'a> nom::error::FromExternalError<netgauze_parse_utils::Span<'a>, #ident> for #located_struct_name<'a> {
278                fn from_external_error(input: netgauze_parse_utils::Span<'a>, _kind: nom::error::ErrorKind, error:  #ident) -> Self {
279                    #located_struct_name::new(input, error)
280                }
281            }
282
283            #(
284                #[automatically_derived]
285                impl<'a> nom::error::FromExternalError<netgauze_parse_utils::Span<'a>, #from_external_ident> for #located_struct_name<'a> {
286                    fn from_external_error(input: netgauze_parse_utils::Span<'a>, _kind: nom::error::ErrorKind, error:  #from_external_ident) -> Self {
287                        #located_struct_name::new(input, #ident::#from_external_variants(error))
288                    }
289                }
290            )*
291
292            #(
293                #[automatically_derived]
294                impl<'a> nom::error::ParseError<netgauze_parse_utils::Span<'a>> for #located_struct_name<'a> {
295                    fn from_error_kind(input: netgauze_parse_utils::Span<'a>, kind: nom::error::ErrorKind) -> Self {
296                        #located_struct_name::new(input, #ident::#from_nom_variants(kind))
297                    }
298
299                    fn append(_input: netgauze_parse_utils::Span<'a>, _kind: nom::error::ErrorKind, other: Self) -> Self {
300                        other
301                    }
302                }
303            )*
304        };
305
306        for (located_variant, located_ident, located_module) in &from_located {
307            let tmp = quote! {
308                #[automatically_derived]
309                impl<'a> From<#(#located_module)::*::#located_ident<'a>> for #located_struct_name<'a> {
310                    fn from(input: #(#located_module)::*::#located_ident<'a>) -> Self {
311                        let (span, error) = input.into();
312                        #located_struct_name::new(span, #ident::#located_variant(error))
313                    }
314                }
315            };
316            output.append_all(tmp);
317        }
318        Ok(proc_macro::TokenStream::from(output))
319    }
320}
321
322/// For a given error enum {Name} generate a struct called Located{Name} that
323/// carries the `Span` (the error location in the input stream) info along the
324/// error. Additionally, generates [`From`] for `nom` library errors, external,
325/// and another located errors.
326///
327/// Example:
328/// ```no_compile
329/// use netgauze_serde_macros::LocatedError;
330///
331/// #[derive(LocatedError, PartialEq, Clone, Debug)]
332/// pub enum ExtendedCommunityParsingError {
333///     NomError(#[from_nom] nom::error::ErrorKind),
334///     CommunityError(#[from_located(module = "self")] CommunityParsingError),
335///     UndefinedCapabilityCode(#[from_external] UndefinedBgpCapabilityCode),
336/// }
337///
338/// #[derive(LocatedError, PartialEq, Clone, Debug)]
339/// pub enum CommunityParsingError {
340///     NomError(#[from_nom] nom::error::ErrorKind),
341/// }
342///
343/// #[derive(Copy, Clone, PartialEq, Debug)]
344/// pub struct UndefinedBgpCapabilityCode(pub u8);
345/// ```
346#[proc_macro_derive(LocatedError, attributes(from_nom, from_external, from_located))]
347pub fn located_error(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
348    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
349    LocatedError::from(&ast)
350        .unwrap_or_else(|err| proc_macro::TokenStream::from(err.to_compile_error()))
351}
352
/// Field-less marker type that namespaces the code generation behind
/// `#[derive(WritingError)]`.
#[derive(Debug)]
struct WritingError;
355
356impl WritingError {
357    fn from(input: &syn::DeriveInput) -> Result<proc_macro::TokenStream, syn::Error> {
358        let syn::Data::Enum(en) = &input.data else {
359            return Err(syn::Error::new(
360                input.span(),
361                "Works only with enum error types",
362            ));
363        };
364        let ident = input.ident.clone();
365        let (from_variants, from_idents) = filter_attribute_by_name_with_module(en, "from");
366        let (from_std_io_error_variants, _) = filter_attribute_by_name(en, "from_std_io_error")?;
367
368        let mut output = quote! {
369            #(
370                #[automatically_derived]
371                impl From<std::io::Error> for #ident {
372                    fn from(err: std::io::Error) -> Self {
373                        #ident::#from_std_io_error_variants(err.to_string())
374                    }
375                }
376            )*
377        };
378        for (index, variant) in from_variants.iter().enumerate() {
379            let (from_module_path, from_ident) = from_idents
380                .get(index)
381                .expect("Error in generating WritingError");
382            let tmp = quote! {
383                #[automatically_derived]
384                impl From<#(#from_module_path)::*::#from_ident> for #ident {
385                    fn from(err: #(#from_module_path)::*::#from_ident) -> Self {
386                        #ident::#variant(err)
387                    }
388                }
389            };
390            output.append_all(tmp);
391        }
392        Ok(proc_macro::TokenStream::from(output))
393    }
394}
395
396/// Decorate an `enum` as an error for serializing binary protocol
397/// provides the following decorations for any members of the enum.
398///
399/// 1. `#[from_std_io_error]` automatically generate [`From`] implementation
400///    from [`std::io::Error`] to a [`String`].
401///
402/// 2. `#[from]`, automatically generates a [`From`] implementation for a given
403///    type.
404///
405///
406/// Example:
407/// ```no_compile
408/// use netgauze_serde_macros::WritingError;
409///
410/// #[derive(WritingError, PartialEq, Clone, Debug)]
411/// pub enum BgpOpenMessageWritingError {
412///     // std::io::Error will be converted to this value
413///     StdIOError(#[from_std_io_error] String),
414/// }
415///
416/// #[derive(WritingError, PartialEq, Clone, Debug)]
417/// pub enum BgpMessageWritingError {
418///     // std::io::Error will be converted to this value
419///     StdIOError(#[from_std_io_error] String),
420///
421///     // BgpOpenMessageWritingError will be converted to this value
422///     OpenError(#[from] BgpOpenMessageWritingError),
423/// }
424/// ```
425#[proc_macro_derive(WritingError, attributes(from_std_io_error, from))]
426pub fn writing_error(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
427    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
428    WritingError::from(&ast)
429        .unwrap_or_else(|err| proc_macro::TokenStream::from(err.to_compile_error()))
430}