Skip to main content

grib_template_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3
4/// Derive macro generating an impl of the trait
5/// `grib_template_helpers::TryFromSlice`.
6#[proc_macro_derive(TryFromSlice, attributes(grib_template))]
7pub fn derive_try_from_slice(input: TokenStream) -> TokenStream {
8    let input = syn::parse_macro_input!(input as syn::DeriveInput);
9
10    match &input.data {
11        syn::Data::Struct(data) => impl_try_from_slice_for_struct(&input, data),
12        syn::Data::Enum(data) => impl_try_from_slice_for_enum(&input, data),
13        _ => unimplemented!("`TryFromSlice` can only be derived for structs/enums"),
14    }
15    .into()
16}
17
18fn impl_try_from_slice_for_struct(
19    input: &syn::DeriveInput,
20    data: &syn::DataStruct,
21) -> proc_macro2::TokenStream {
22    let name = &input.ident;
23    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
24    let Some((kind, fields)) = extract_struct_info(data) else {
25        unimplemented!(
26            "`TryFromSlice` can only be derived for structs with named fields or with a single unnamed `u8` field"
27        )
28    };
29
30    if kind == StructKind::TupleStruct {
31        let field_reads = fields.iter().map(|field| {
32            let ty = &field.ty;
33            quote! {
34                <#ty as grib_template_helpers::TryFromSlice>::try_from_slice(slice, pos)?
35            }
36        });
37
38        return quote! {
39            impl #impl_generics grib_template_helpers::TryFromSlice for #name #type_generics #where_clause {
40                fn try_from_slice(
41                    slice: &[u8],
42                    pos: &mut usize,
43                ) -> grib_template_helpers::TryFromSliceResult<Self> {
44                    Ok(Self(#(#field_reads),*))
45                }
46            }
47        };
48    }
49
50    let mut field_reads = Vec::new();
51    let mut idents = Vec::new();
52
53    for field in fields {
54        let ident = field.ident.as_ref().unwrap();
55        let ty = &field.ty;
56
57        let len_attr = field
58            .attrs
59            .iter()
60            .find_map(|attr| attr_value(attr, "len").map(|v| parse_len_attr(&v)));
61        if let Some(len) = len_attr {
62            if let syn::Type::Path(type_path) = ty
63                && let Some((inner_ty, has_option)) = extract_vec_inner(type_path)
64            {
65                let tokens = quote! {
66                    let mut #ident = Vec::with_capacity(#len);
67                    for _ in 0..#len {
68                        let item =
69                            <#inner_ty as grib_template_helpers::TryFromSlice>::try_from_slice(
70                                slice,
71                                pos,
72                            )?;
73                        #ident.push(item);
74                    }
75                };
76
77                let tokens = if has_option {
78                    quote! {
79                        let #ident = if *pos == slice.len() {
80                            None
81                        } else {
82                            #tokens
83                            Some(#ident)
84                        };
85                    }
86                } else {
87                    tokens
88                };
89                field_reads.push(tokens);
90
91                idents.push(ident);
92                continue;
93            }
94            unimplemented!(
95                "`#[grib_template(len = N)]` is only available for `Vec<T>` and `Option<Vec<T>>`"
96            );
97        }
98
99        let disc_attr = field
100            .attrs
101            .iter()
102            .find_map(|attr| attr_value(attr, "variant").map(|v| parse_variant_attr(&v)));
103        if let Some(disc_ident) = disc_attr {
104            field_reads.push(quote! {
105                let #ident = <#ty as grib_template_helpers::TryEnumFromSlice>::try_enum_from_slice(
106                    #disc_ident,
107                    slice,
108                    pos,
109                )?;
110            });
111            idents.push(ident);
112            continue;
113        }
114
115        field_reads.push(quote! {
116            let #ident = <#ty as grib_template_helpers::TryFromSlice>::try_from_slice(slice, pos)?;
117        });
118        idents.push(ident);
119    }
120
121    quote! {
122        impl #impl_generics grib_template_helpers::TryFromSlice for #name #type_generics #where_clause {
123            fn try_from_slice(
124                slice: &[u8],
125                pos: &mut usize,
126            ) -> grib_template_helpers::TryFromSliceResult<Self> {
127                #(#field_reads)*
128                Ok(Self { #(#idents),* })
129            }
130        }
131    }
132}
133
134fn impl_try_from_slice_for_enum(
135    input: &syn::DeriveInput,
136    data: &syn::DataEnum,
137) -> proc_macro2::TokenStream {
138    let name = &input.ident;
139    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
140
141    let mut arms = Vec::new();
142
143    for variant in &data.variants {
144        let variant_ident = &variant.ident;
145        let disc_expr = variant
146            .discriminant
147            .as_ref()
148            .expect("`TryFromSlice` requires the enum to have explicit discriminant")
149            .1
150            .clone();
151
152        if let syn::Fields::Unnamed(fields) = &variant.fields
153            && fields.unnamed.len() == 1
154        {
155            let inner_ty = &fields.unnamed.first().unwrap().ty;
156            arms.push(quote! {
157                #disc_expr => {
158                    let inner = <#inner_ty as grib_template_helpers::TryFromSlice>::try_from_slice(
159                        slice,
160                        pos
161                    )?;
162                    Ok(#name::#variant_ident(inner))
163                }
164            });
165        } else {
166            unimplemented!("`TryFromSlice` only supports single-field tuple variants");
167        }
168    }
169
170    quote! {
171        impl #impl_generics grib_template_helpers::TryEnumFromSlice for #name #type_generics #where_clause {
172            fn try_enum_from_slice(
173                discriminant: impl Into<u64>,
174                slice: &[u8],
175                pos: &mut usize,
176            ) -> grib_template_helpers::TryFromSliceResult<Self> {
177                match discriminant.into() {
178                    #(#arms),*,
179                    _ => panic!("unknown variant for {}", stringify!(#name)),
180                }
181            }
182        }
183    }
184}
185
186enum LenKind {
187    Literal(usize),
188    Ident(syn::Ident),
189}
190
191impl quote::ToTokens for LenKind {
192    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
193        match self {
194            LenKind::Literal(n) => {
195                tokens.extend(quote! { #n });
196            }
197            LenKind::Ident(ident) => {
198                tokens.extend(quote! { #ident as usize });
199            }
200        }
201    }
202}
203
204fn attr_value(attr: &syn::Attribute, ident: &str) -> Option<syn::Expr> {
205    if !attr.path().is_ident("grib_template") {
206        return None;
207    }
208    let meta = attr.parse_args::<syn::Meta>().ok()?;
209    if let syn::Meta::NameValue(nv) = meta {
210        if !nv.path.is_ident(ident) {
211            return None;
212        }
213        Some(nv.value)
214    } else {
215        None
216    }
217}
218
219fn parse_len_attr(attr_value: &syn::Expr) -> Option<LenKind> {
220    match attr_value {
221        syn::Expr::Lit(syn::ExprLit {
222            lit: syn::Lit::Int(lit_int),
223            ..
224        }) => Some(LenKind::Literal(lit_int.base10_parse::<usize>().unwrap())),
225        syn::Expr::Lit(syn::ExprLit {
226            lit: syn::Lit::Str(lit_str),
227            ..
228        }) => Some(LenKind::Ident(syn::Ident::new(
229            &lit_str.value(),
230            lit_str.span(),
231        ))),
232        _ => None,
233    }
234}
235
236fn parse_variant_attr(attr_value: &syn::Expr) -> Option<syn::Ident> {
237    match attr_value {
238        syn::Expr::Lit(syn::ExprLit {
239            lit: syn::Lit::Str(lit_str),
240            ..
241        }) => Some(syn::Ident::new(&lit_str.value(), lit_str.span())),
242        _ => None,
243    }
244}
245
246fn extract_vec_inner(type_path: &syn::TypePath) -> Option<(syn::Type, bool)> {
247    if type_path.path.segments.len() == 1 {
248        let (type_path, has_option) = if type_path.path.segments[0].ident == "Option"
249            && let syn::PathArguments::AngleBracketed(ref args) =
250                type_path.path.segments[0].arguments
251            && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
252            && let syn::Type::Path(type_path) = inner_ty
253        {
254            (type_path, true)
255        } else {
256            (type_path, false)
257        };
258
259        if type_path.path.segments[0].ident == "Vec"
260            && let syn::PathArguments::AngleBracketed(ref args) =
261                type_path.path.segments[0].arguments
262            && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
263        {
264            return Some((inner_ty.clone(), has_option));
265        }
266    }
267    None
268}
269
270/// Derive macro generating an impl of the trait `grib_template_helpers::Dump`.
271#[proc_macro_derive(Dump)]
272pub fn derive_dump(input: TokenStream) -> TokenStream {
273    let input = syn::parse_macro_input!(input as syn::DeriveInput);
274
275    match &input.data {
276        syn::Data::Struct(data) => impl_dump_for_struct(&input, data),
277        syn::Data::Enum(data) => impl_dump_for_enum(&input, data),
278        _ => unimplemented!("`Dump` can only be derived for structs/enums"),
279    }
280    .into()
281}
282
283fn impl_dump_for_struct(
284    input: &syn::DeriveInput,
285    data: &syn::DataStruct,
286) -> proc_macro2::TokenStream {
287    let name = &input.ident;
288    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
289    let Some((kind, fields)) = extract_struct_info(data) else {
290        unimplemented!(
291            "`Dump` can only be derived for structs with named fields or with a single unnamed `u8` field"
292        )
293    };
294
295    if kind == StructKind::TupleStruct {
296        let doc = get_doc(&fields[0].attrs)
297            .map(|s| format!("  // {}", s.trim()))
298            .unwrap_or_default();
299
300        return quote! {
301            impl #impl_generics grib_template_helpers::Dump for #name #type_generics #where_clause {
302                fn dump<W: std::io::Write>(
303                    &self,
304                    parent: Option<&std::borrow::Cow<str>>,
305                    pos: &mut usize,
306                    output: &mut W,
307                ) -> Result<(), std::io::Error> {
308                    let size = 1;
309                    grib_template_helpers::write_position_column(output, pos, size)?;
310                    if let Some(parent) = parent {
311                        write!(output, "{}", parent)?;
312                    }
313                    writeln!(output, " = {:#010b}{}",
314                        self.0,
315                        #doc,
316                    )
317                }
318            }
319        };
320    }
321
322    let mut dumps = Vec::new();
323
324    for field in fields {
325        let ident = field.ident.as_ref().unwrap();
326        let ty = &field.ty;
327
328        let doc = get_doc(&field.attrs)
329            .map(|s| format!("  // {}", s.trim()))
330            .unwrap_or_default();
331        dumps.push(quote! {
332            <#ty as grib_template_helpers::DumpField>::dump_field(
333                &self.#ident,
334                stringify!(#ident),
335                parent,
336                #doc,
337                pos,
338                output,
339            )?;
340        });
341    }
342
343    quote! {
344        impl #impl_generics grib_template_helpers::Dump for #name #type_generics #where_clause {
345            fn dump<W: std::io::Write>(
346                &self,
347                parent: Option<&std::borrow::Cow<str>>,
348                pos: &mut usize,
349                output: &mut W,
350            ) -> Result<(), std::io::Error> {
351                #(#dumps)*;
352                Ok(())
353            }
354        }
355    }
356}
357
358fn impl_dump_for_enum(input: &syn::DeriveInput, data: &syn::DataEnum) -> proc_macro2::TokenStream {
359    let name = &input.ident;
360    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
361
362    let mut arms = Vec::new();
363
364    for variant in &data.variants {
365        let variant_ident = &variant.ident;
366
367        if let syn::Fields::Unnamed(fields) = &variant.fields
368            && fields.unnamed.len() == 1
369        {
370            let inner_ty = &fields.unnamed.first().unwrap().ty;
371            arms.push(quote! {
372                #name::#variant_ident(inner) => <#inner_ty as grib_template_helpers::Dump>::dump(
373                    inner,
374                    parent,
375                    pos,
376                    output
377                )
378            });
379        } else {
380            unimplemented!("`Dump` only supports single-field tuple variants");
381        }
382    }
383
384    quote! {
385        impl #impl_generics grib_template_helpers::Dump for #name #type_generics #where_clause {
386            fn dump<W: std::io::Write>(
387                &self,
388                parent: Option<&std::borrow::Cow<str>>,
389                pos: &mut usize,
390                output: &mut W,
391            ) -> Result<(), std::io::Error> {
392                match self {
393                    #(#arms),*,
394                }
395            }
396        }
397    }
398}
399
400fn get_doc(attrs: &[syn::Attribute]) -> Option<String> {
401    let mut doc = String::new();
402    for attr in attrs.iter() {
403        match attr.meta {
404            syn::Meta::NameValue(ref value) if value.path.is_ident("doc") => {
405                if let syn::Expr::Lit(lit) = &value.value
406                    && let syn::Lit::Str(s) = &lit.lit
407                {
408                    doc.push_str(&s.value());
409                }
410            }
411            _ => {}
412        }
413    }
414    if doc.is_empty() { None } else { Some(doc) }
415}
416
417fn extract_struct_info(
418    data: &syn::DataStruct,
419) -> Option<(
420    StructKind,
421    &syn::punctuated::Punctuated<syn::Field, syn::token::Comma>,
422)> {
423    match &data.fields {
424        syn::Fields::Named(fields) => Some((StructKind::NamedStruct, &fields.named)),
425        syn::Fields::Unnamed(fields) => {
426            let fields = &fields.unnamed;
427            if fields.len() == 1 && is_type_u8(&fields.first().unwrap().ty) {
428                Some((StructKind::TupleStruct, fields))
429            } else {
430                None
431            }
432        }
433        _ => None,
434    }
435}
436
/// Shape of a struct accepted by the derives, as classified by `extract_struct_info`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum StructKind {
    /// Tuple struct with a single `u8` field, e.g. `struct Flags(u8);`.
    TupleStruct,
    /// Struct with named fields.
    NamedStruct,
}
442
443fn is_type_u8(ty: &syn::Type) -> bool {
444    if let syn::Type::Path(syn::TypePath { path, .. }) = ty
445        && let Some(segment) = path.segments.last()
446    {
447        return segment.ident == "u8";
448    }
449    false
450}