Skip to main content

nanoxml_derive/
lib.rs

1#![allow(unused)]
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use syn::parse_macro_input;
7use syn::punctuated::Punctuated;
8use syn::token::Comma;
9use syn::{Attribute, Data, DataEnum, DataStruct, DeriveInput, Expr, Field};
10use syn::{Fields, GenericParam, Generics, Ident, Lifetime, LifetimeParam};
11use syn::{Lit, LitStr, Type, Variant};
12
/// Entry point for `#[derive(SerXml)]`.
///
/// Supports structs with named fields and enums. The XML tag name is taken from a
/// type-level `#[nanoxml(rename = "...")]`, falling back to the type's identifier.
/// Any other input shape is reported as a compile error spanned at the type name.
#[cfg(feature = "ser")]
#[proc_macro_derive(SerXml, attributes(nanoxml))]
pub fn derive_serxml(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    // Only build the fallback string when no explicit rename was given.
    let rename = get_rename_attr(&input.attrs).unwrap_or_else(|| input.ident.to_string());
    match input.data {
        Data::Struct(DataStruct {
            fields: Fields::Named(ref fields),
            ..
        }) => derive_serxml_struct(&input.ident, &rename, &fields.named, input.generics),
        Data::Enum(DataEnum { variants, .. }) => {
            derive_serxml_enum(&input.ident, &rename, &variants)
        }
        // A spanned compile error points the user at the offending type instead of
        // surfacing as a "proc-macro derive panicked" message.
        _ => syn::Error::new_spanned(
            &input.ident,
            "SerXml can only be derived for structs with named fields or enums",
        )
        .to_compile_error()
        .into(),
    }
}
29
/// Expands `SerXml` and `SerXmlTopLevel` for a struct with named fields.
///
/// Text-only structs also get an empty `SerXmlAsAttr` impl. A struct qualifies when
/// every field is `#[nanoxml(text)]`, and the empty impl relies on the trait's default
/// methods. `rename` becomes the top-level `TAG_NAME`. Fields flagged `skip_ser`
/// produce no serialization code.
#[cfg(feature = "ser")]
fn derive_serxml_struct(
    name: &Ident,
    rename: &str,
    fields: &Punctuated<Field, Comma>,
    generics: Generics,
) -> TokenStream {
    let xml_fields = get_xml_fields(fields);
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    // Statements emitted into `ser_body` (text content first, then child elements in
    // declaration order) and into `ser_attrs` (attributes in declaration order).
    let mut body_stmts = Vec::new();
    let mut attr_stmts = Vec::new();

    let text_field = xml_fields.iter().find(|f| f.field_kind == FieldKind::Text);
    if let Some(text) = text_field.filter(|f| !f.skip_ser) {
        let ident = text.field_name;
        body_stmts.push(quote! {
            ::nanoxml::derive::ser::SerXmlAsAttr::ser_as_text(&self.#ident, __xml)?;
        });
    }

    for field in xml_fields.iter().filter(|f| !f.skip_ser) {
        let ident = field.field_name;
        let xml_name = &field.renamed;
        match field.field_kind {
            FieldKind::Regular => body_stmts.push(quote! {
                ::nanoxml::derive::ser::SerXml::ser_xml(&self.#ident, __xml, #xml_name)?;
            }),
            FieldKind::Attr => attr_stmts.push(quote! {
                ::nanoxml::derive::ser::SerXmlAsAttr::ser_as_attr(&self.#ident, __xml, #xml_name)?;
            }),
            // Already handled above; get_xml_fields allows at most one text field.
            FieldKind::Text => {}
        }
    }

    let text_only = xml_fields.iter().all(|f| f.field_kind == FieldKind::Text);
    let as_attr_impl = if text_only {
        Some(quote! {
            impl #impl_generics ::nanoxml::derive::ser::SerXmlAsAttr for #name #ty_generics #where_clause {}
        })
    } else {
        None
    };

    let expanded = quote! {
        impl #impl_generics ::nanoxml::derive::ser::SerXml for #name #ty_generics #where_clause {
            fn ser_body<W: ::core::fmt::Write>(&self, __xml: &mut ::nanoxml::ser::XmlBuilder<'_, W>) -> ::core::fmt::Result {
                #(#body_stmts)*
                Ok(())
            }

            fn ser_attrs<W: ::core::fmt::Write>(&self, __xml: &mut ::nanoxml::ser::XmlBuilder<'_, W>) -> ::core::fmt::Result {
                #(#attr_stmts)*
                Ok(())
            }
        }
        #as_attr_impl
        impl #impl_generics ::nanoxml::derive::ser::SerXmlTopLevel for #name #ty_generics #where_clause {
            const TAG_NAME: &'static str = #rename;
        }
    };

    expanded.into()
}
103
/// Expands `SerXml`, `SerXmlAsAttr` and `SerXmlTopLevel` for a fieldless enum.
///
/// Each variant serializes as text: its `#[nanoxml(rename = "...")]` value, or the
/// variant identifier. The enum emits no attributes. `rename` becomes the top-level
/// `TAG_NAME`.
#[cfg(feature = "ser")]
fn derive_serxml_enum(
    name: &Ident,
    rename: &str,
    variants: &Punctuated<Variant, Comma>,
) -> TokenStream {
    let variants = get_xml_variants(variants);

    let cases = variants.iter().map(|v| {
        let variant_name = v.variant_name;
        let renamed = &v.renamed;
        quote! { Self::#variant_name => __xml.text(#renamed), }
    });

    // The body matches on `*self` (a place of type `Self`) rather than on `self`
    // (a `&Self`). For an enum with no variants, `match *self {}` is exhaustive.
    // `match self {}` is rejected because references are considered inhabited.
    // For unit variants the two forms behave identically, since no bindings move.
    let serxml_impl = quote! {
        impl ::nanoxml::derive::ser::SerXml for #name {
            fn ser_body<W: ::core::fmt::Write>(&self, __xml: &mut ::nanoxml::ser::XmlBuilder<'_, W>) -> ::core::fmt::Result {
                match *self {
                    #(#cases)*
                }
            }

            fn ser_attrs<W: ::core::fmt::Write>(&self, __xml: &mut ::nanoxml::ser::XmlBuilder<'_, W>) -> ::core::fmt::Result {
                Ok(())
            }
        }
    };

    let as_attr_impl = quote! { impl ::nanoxml::derive::ser::SerXmlAsAttr for #name {} };

    let top_level_impl = quote! {
        impl ::nanoxml::derive::ser::SerXmlTopLevel for #name {
            const TAG_NAME: &'static str = #rename;
        }
    };

    let full_impl = quote! {
        #serxml_impl
        #as_attr_impl
        #top_level_impl
    };

    full_impl.into()
}
151
/// Entry point for `#[derive(DeXml)]`.
///
/// Supports structs with named fields and enums. The XML tag name is taken from a
/// type-level `#[nanoxml(rename = "...")]`, falling back to the type's identifier.
/// Any other input shape is reported as a compile error spanned at the type name.
#[cfg(feature = "de")]
#[proc_macro_derive(DeXml, attributes(nanoxml))]
pub fn derive_dexml(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    // Only build the fallback string when no explicit rename was given.
    let rename = get_rename_attr(&input.attrs).unwrap_or_else(|| input.ident.to_string());
    match input.data {
        Data::Struct(DataStruct {
            fields: Fields::Named(ref fields),
            ..
        }) => derive_dexml_struct(&input.ident, &rename, &fields.named, input.generics),
        Data::Enum(DataEnum { variants, .. }) => {
            derive_dexml_enum(&input.ident, &rename, &variants)
        }
        // A spanned compile error points the user at the offending type instead of
        // surfacing as a "proc-macro derive panicked" message.
        _ => syn::Error::new_spanned(
            &input.ident,
            "DeXml can only be derived for structs with named fields or enums",
        )
        .to_compile_error()
        .into(),
    }
}
168
/// Expands `DeXml<'lt>` and `DeXmlTopLevel<'lt>` for a struct with named fields.
///
/// `'lt` is the struct's first lifetime parameter if it has one. Otherwise a fresh
/// `'a` is added to the impl blocks only. The generated `de_xml` runs in three stages:
///
/// 1. It reads attributes. Keys that match an `attr` field are decoded through
///    `DeXmlAttr`; unknown keys are ignored.
/// 2. It reads the body. This is either the single `text` field, or a loop over child
///    tags dispatched to regular fields. An unknown tag yields
///    `XmlError::InvalidField("body")`.
/// 3. It assembles `Self`. A missing required field yields `XmlError::MissingField`
///    unless `default_de` is set. A repeated non-seq field yields
///    `XmlError::DuplicateField`.
///
/// `rename` becomes the top-level `TAG_NAME`.
#[cfg(feature = "de")]
fn derive_dexml_struct(
    name: &Ident,
    rename: &str,
    fields: &Punctuated<Field, Comma>,
    generics: Generics,
) -> TokenStream {
    let xml_fields = get_xml_fields(fields);

    // Lifetime passed to `DeXml<..>` / `XmlParser<..>`. Only the bare lifetime is
    // interpolated in argument position. Its bounds (e.g. `'de: 'x`) are valid only in
    // the generics list, which `impl_generics` carries.
    let mut impl_generics_src = generics.clone();
    let parser_lt = match generics.lifetimes().next() {
        Some(param) => param.lifetime.clone(),
        None => {
            let fresh = Lifetime::new("'a", Span::call_site());
            impl_generics_src
                .params
                .insert(0, GenericParam::Lifetime(LifetimeParam::new(fresh.clone())));
            fresh
        }
    };
    // Impl generics include every type/const parameter of the struct (plus the fresh
    // lifetime, if one was added). Type generics and the where-clause come from the
    // original definition.
    let (impl_generics, _, _) = impl_generics_src.split_for_impl();
    let (_, ty_generics, where_clause) = generics.split_for_impl();

    // Per-field accumulators: seq fields start from `DeXmlSeq::new_seq()`, all others as `None`.
    let field_init = xml_fields.iter().map(|f| {
        let field_name = f.field_name;
        let real_type = f.real_type;
        match f.field_type {
            FieldType::Seq => quote! { let mut #field_name = <#real_type as ::nanoxml::derive::de::DeXmlSeq>::new_seq(); },
            _ => quote! { let mut #field_name = None; },
        }
    });

    let de_attr = xml_fields
        .iter()
        .filter(|f| f.field_kind == FieldKind::Attr)
        .map(|f| {
            let field_name = f.field_name;
            let renamed = &f.renamed;
            quote! {
                #renamed => {
                    if #field_name.is_some() {
                        return Err(::nanoxml::de::XmlError::DuplicateField);
                    }
                    #field_name = Some(::nanoxml::derive::de::DeXmlAttr::de_xml_attr(__attr_value)?);
                }
            }
        });

    let de_text = xml_fields
        .iter()
        .find(|f| f.field_kind == FieldKind::Text)
        .map(|f| {
            let field_name = f.field_name;
            quote! {
                #field_name = Some(::nanoxml::derive::de::DeXmlAttr::de_xml_attr(__parser.text()?)?);
                __parser.tag_close()?;
            }
        });

    let de_regular = xml_fields
        .iter()
        .filter(|f| f.field_kind == FieldKind::Regular)
        .map(|f| {
            let field_name = f.field_name;
            let real_type = f.real_type;
            let renamed = &f.renamed;
            match f.field_type {
                FieldType::Seq => quote! {
                    #renamed => <#real_type as ::nanoxml::derive::de::DeXmlSeq>::push_item(&mut #field_name, __parser)?,
                },
                _ => quote! {
                    #renamed => {
                        if #field_name.is_some() {
                            return Err(::nanoxml::de::XmlError::DuplicateField);
                        }
                        #field_name = Some(::nanoxml::derive::de::DeXml::de_xml(__parser)?);
                    }
                },
            }
        });

    let de_body = match de_text {
        Some(de_text) => de_text,
        None => quote! {
            while let Ok(__tag) = __parser.tag_open_or_close()? {
                match __tag {
                    #(#de_regular)*
                    _ => return Err(::nanoxml::de::XmlError::InvalidField("body".to_owned())),
                }
            }
        },
    };

    let field_unwraps: Vec<_> = xml_fields
        .iter()
        .map(|f| {
            let field_name = f.field_name;
            match f.field_type {
                FieldType::Regular => match &f.default_de {
                    None => quote! { #field_name: #field_name.ok_or(::nanoxml::de::XmlError::MissingField(stringify!(#field_name)))?, },
                    Some(None) => quote! { #field_name: #field_name.unwrap_or_default(), },
                    Some(Some(func)) => {
                        // Parse as a path, not a single ident, so values like
                        // "my_mod::default_x" are accepted as well as "default_x".
                        let func_path: syn::Path = syn::parse_str(func).unwrap_or_else(|err| {
                            panic!("default_de value {func:?} is not a valid function path: {err}")
                        });
                        quote! { #field_name: #field_name.unwrap_or_else(#func_path), }
                    }
                },
                FieldType::Option => quote! { #field_name, },
                FieldType::Seq => quote! { #field_name: ::nanoxml::derive::de::DeXmlSeq::finish(#field_name)?, },
            }
        })
        .collect();

    // `Result` is fully qualified so a user-level alias (e.g. `io::Result`) in scope
    // at the derive site does not break the expansion.
    let dexml_impl = quote! {
        impl #impl_generics ::nanoxml::derive::de::DeXml<#parser_lt> for #name #ty_generics #where_clause {
            fn de_xml(__parser: &mut ::nanoxml::de::XmlParser<#parser_lt>) -> ::core::result::Result<Self, ::nanoxml::de::XmlError> {
                #(#field_init)*
                while let Ok((__attr_key, __attr_value)) = __parser.attr_or_tag_open_end()? {
                    match __attr_key {
                        #(#de_attr)*
                        _ => (),
                    }
                }
                #de_body
                Ok(Self { #(#field_unwraps)* })
            }
        }
    };

    let top_level_impl = quote! {
        impl #impl_generics ::nanoxml::derive::de::DeXmlTopLevel<#parser_lt> for #name #ty_generics #where_clause {
            const TAG_NAME: &'static str = #rename;
        }
    };

    let full_impl = quote! {
        #dexml_impl
        #top_level_impl
    };

    full_impl.into()
}
314
/// Expands `DeXmlAttr` and `DeXmlTopLevel` for a fieldless enum.
///
/// The input string is compared against each variant's `rename` value (or the variant
/// identifier) in declaration order. The first match wins. No match yields
/// `XmlError::InvalidVariant`.
///
/// # Panics
/// If the enum has no variants.
#[cfg(feature = "de")]
fn derive_dexml_enum(
    name: &Ident,
    rename: &str,
    variants: &Punctuated<Variant, Comma>,
) -> TokenStream {
    let variants = get_xml_variants(variants);

    if variants.is_empty() {
        panic!("empty enum cannot be deserialized");
    }

    let cases: Vec<_> = variants
        .iter()
        .map(|v| {
            let variant_name = v.variant_name;
            let renamed = &v.renamed;
            quote! { if s == #renamed { Ok(Self::#variant_name) } }
        })
        .collect();

    // Expands to `if .. {} else if .. {} else { Err(..) }`. `Result` is fully qualified
    // so a user-level alias (e.g. `io::Result`) at the derive site cannot shadow it.
    let de_attr_impl = quote! {
        impl ::nanoxml::derive::de::DeXmlAttr<'_> for #name {
            fn de_xml_attr(s: ::nanoxml::de::XmlStr<'_>) -> ::core::result::Result<Self, ::nanoxml::de::XmlError> {
                #(#cases else)*
                {
                    Err(::nanoxml::de::XmlError::InvalidVariant)
                }
            }
        }
    };

    let top_level_impl = quote! {
        impl ::nanoxml::derive::de::DeXmlTopLevel<'_> for #name {
            const TAG_NAME: &'static str = #rename;
        }
    };

    let full_impl = quote! {
        #de_attr_impl
        #top_level_impl
    };

    full_impl.into()
}
360
361fn get_rename_attr(attrs: &[Attribute]) -> Option<String> {
362    let mut renamed = None;
363    for attr in attrs.iter().filter(|attr| attr.path().is_ident("nanoxml")) {
364        attr.parse_nested_meta(|meta| {
365            if meta.path.is_ident("rename") {
366                if renamed.is_some() {
367                    panic!("duplicate rename attr")
368                }
369                let value = meta.value().expect("rename requires value");
370                let lit: LitStr = value.parse().expect("rename requires atr value");
371                renamed = Some(lit.value());
372            } else {
373                panic!("invalid nanoxml attr");
374            }
375            Ok(())
376        })
377        .unwrap();
378    }
379    renamed
380}
381
382struct XmlField<'a> {
383    field_name: &'a Ident,
384    field_type: FieldType,
385    field_kind: FieldKind,
386    real_type: &'a Type,
387    renamed: String,
388    skip_ser: bool,
389    default_de: Option<Option<String>>,
390}
391
392struct XmlVariant<'a> {
393    variant_name: &'a Ident,
394    renamed: String,
395}
396
/// Where a struct field is placed in the XML element.
#[derive(Clone, Copy, PartialEq, Eq)]
enum FieldKind {
    /// A child element (neither `attr` nor `text` was given).
    Regular,
    /// The element's text content (`#[nanoxml(text)]`).
    Text,
    /// An attribute of the element (`#[nanoxml(attr)]`).
    Attr,
}
403
/// How a field's value is gathered during deserialization.
#[derive(Clone, Copy, PartialEq, Eq)]
enum FieldType {
    /// A required value; absent means `MissingField` unless `default_de` is set.
    Regular,
    /// Declared as `Option<..>`; left as `None` when absent.
    Option,
    /// Marked `#[nanoxml(seq)]`; items are collected through `DeXmlSeq`.
    Seq,
}
410
411fn get_xml_fields(fields: &Punctuated<Field, Comma>) -> Vec<XmlField<'_>> {
412    let mut ret = Vec::<XmlField<'_>>::new();
413
414    for field in fields {
415        let field_name = field.ident.as_ref().unwrap();
416        let real_type = &field.ty;
417
418        let mut renamed = None;
419        let mut is_seq = false;
420        let mut is_attr = false;
421        let mut is_text = false;
422        let mut skip_ser = false;
423        let mut default_de = None;
424
425        for attr in field
426            .attrs
427            .iter()
428            .filter(|attr| attr.path().is_ident("nanoxml"))
429        {
430            attr.parse_nested_meta(|meta| {
431                if meta.path.is_ident("rename") {
432                    if renamed.is_some() {
433                        panic!("duplicate rename attr")
434                    }
435                    let value = meta.value().expect("rename requires value");
436                    let lit: LitStr = value.parse().expect("rename requires atr value");
437                    renamed = Some(lit.value());
438                } else if meta.path.is_ident("seq") {
439                    is_seq = true;
440                } else if meta.path.is_ident("attr") {
441                    is_attr = true;
442                } else if meta.path.is_ident("text") {
443                    is_text = true;
444                } else if meta.path.is_ident("skip_ser") {
445                    skip_ser = true;
446                } else if meta.path.is_ident("default_de") {
447                    if default_de.is_some() {
448                        panic!("duplicate default_de attr")
449                    }
450                    match meta.value() {
451                        Ok(value) => {
452                            let lit: LitStr = value.parse().expect("default_de requires str value");
453                            default_de = Some(Some(lit.value()));
454                        }
455                        Err(_) => default_de = Some(None),
456                    }
457                } else {
458                    panic!("invalid nanoxml attr");
459                }
460                Ok(())
461            })
462            .unwrap();
463        }
464
465        let field_type = if is_seq {
466            FieldType::Seq
467        } else if is_option(real_type) {
468            FieldType::Option
469        } else {
470            FieldType::Regular
471        };
472
473        if default_de.is_some() && field_type != FieldType::Regular {
474            panic!("default_de only works for non-option, non-seq fields");
475        }
476
477        let field_kind = match (is_attr, is_text) {
478            (true, true) => {
479                panic!("#[attr] and #[text] on the same field are incompatible");
480            }
481            (false, true) if ret.iter().any(|f| f.field_kind == FieldKind::Text) => {
482                panic!("only one #[text] field is allowed");
483            }
484            (false, true) if ret.iter().any(|f| f.field_kind == FieldKind::Regular) => {
485                panic!("#[text] is incompatible with regular fields");
486            }
487            (false, true) if renamed.is_some() => {
488                panic!("#[text] and #[rename] on the same field are incompatible");
489            }
490            (false, false) if ret.iter().any(|f| f.field_kind == FieldKind::Text) => {
491                panic!("#[text] is incompatible with regular fields");
492            }
493            (true, false) => FieldKind::Attr,
494            (false, true) => FieldKind::Text,
495            (false, false) => FieldKind::Regular,
496        };
497
498        let renamed = renamed.unwrap_or_else(|| field_name.to_string());
499
500        ret.push(XmlField {
501            field_name,
502            field_type,
503            field_kind,
504            real_type,
505            renamed,
506            skip_ser,
507            default_de,
508        });
509    }
510
511    ret
512}
513
514fn get_xml_variants(variants: &Punctuated<Variant, Comma>) -> Vec<XmlVariant<'_>> {
515    let mut ret = Vec::<XmlVariant<'_>>::new();
516
517    for variant in variants {
518        let variant_name = &variant.ident;
519
520        let mut renamed = variant_name.to_string();
521
522        for attr in variant
523            .attrs
524            .iter()
525            .filter(|attr| attr.path().is_ident("nanoxml"))
526        {
527            attr.parse_nested_meta(|meta| {
528                if meta.path.is_ident("rename") {
529                    let value = meta.value().expect("rename requires value");
530                    let lit: LitStr = value.parse().expect("rename requires atr value");
531                    renamed = lit.value();
532                } else {
533                    panic!("invalid nanoxml attr");
534                }
535                Ok(())
536            })
537            .unwrap();
538        }
539
540        ret.push(XmlVariant {
541            variant_name,
542            renamed,
543        });
544    }
545
546    ret
547}
548
549fn is_option(ty: &Type) -> bool {
550    check_type(
551        ty,
552        &["Option|", "std|option|Option|", "core|option|Option|"],
553    )
554}
555
556fn check_type(ty: &Type, valid: &[&str]) -> bool {
557    let path = match *ty {
558        Type::Path(ref path) if path.qself.is_none() => &path.path,
559        _ => return false,
560    };
561
562    let idents_of_path = path.segments.iter().fold(String::new(), |mut acc, v| {
563        acc.push_str(&v.ident.to_string());
564        acc.push('|');
565        acc
566    });
567
568    valid.iter().any(|s| idents_of_path == *s)
569}