netcdf_derive/
lib.rs

1use proc_macro2::{Ident, TokenStream};
2use proc_macro_error::{abort, proc_macro_error};
3use quote::quote;
4use syn::{
5    parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, FieldsNamed, LitStr, Type,
6    Variant,
7};
8
9#[proc_macro_derive(NcType, attributes(netcdf))]
10/// Derives `NcTypeDescriptor` for user defined types.
11///
12/// See the documentation under `netcdf::TypeDescriptor` for examples
13///
14/// Use `#[netcdf(rename = "name")]` to
15/// rename field names or enum member names, or the name
16/// of the compound/enum.
17///
18/// Types one derives `NcType` for must have some properties to
19/// ensure correctness:
20/// * Structs must have `repr(C)` to ensure layout compatibility
21/// * Structs must be packed (no padding allowed)
22/// * Enums must have `repr(T)` where `T` is an int type (`{i/u}{8/16/32/64}`)
23#[proc_macro_error]
24pub fn derive(stream: proc_macro::TokenStream) -> proc_macro::TokenStream {
25    let input = parse_macro_input!(stream as DeriveInput);
26    let name = &input.ident;
27    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
28
29    let mut renamed = None;
30    let mut repr_c = false;
31    for attr in &input.attrs {
32        if attr.path().is_ident("netcdf") {
33            attr.parse_nested_meta(|meta| {
34                if meta.path.is_ident("rename") {
35                    renamed = Some(meta.value()?.parse::<LitStr>()?.value());
36                } else {
37                    abort!(meta.path, "NcType encountered an unknown attribute");
38                }
39                Ok(())
40            })
41            .unwrap();
42        } else if attr.path().is_ident("repr") {
43            attr.parse_nested_meta(|meta| {
44                if meta.path.is_ident("C") {
45                    repr_c = true;
46                }
47                Ok(())
48            })
49            .unwrap();
50        }
51    }
52    let ncname = renamed.unwrap_or_else(|| name.to_string());
53
54    let body = match input.data {
55        Data::Struct(DataStruct {
56            struct_token: _,
57            ref fields,
58            semi_token: _,
59        }) => {
60            if !repr_c {
61                abort!(
62                    input,
63                    "Can not derive NcType for struct without fixed layout";
64                    help = "struct must have attribute #[repr(C)]"
65                );
66            }
67            match fields {
68                Fields::Named(fields) => impl_compound(name, &ncname, fields.clone()),
69                Fields::Unnamed(f) => {
70                    abort!(f, "Can not derive NcType for struct with unnamed field"; note="#[derive(NcType)]")
71                }
72                Fields::Unit => abort!(input, "Can not derive NcType for unit struct"),
73            }
74        }
75        Data::Enum(DataEnum {
76            enum_token: _,
77            brace_token: _,
78            ref variants,
79        }) => {
80            let mut basetyp = None;
81            for attr in &input.attrs {
82                if attr.path().is_ident("repr") {
83                    attr.parse_nested_meta(|meta| {
84                        for item in ["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"] {
85                            if meta.path.is_ident(item) {
86                                basetyp = Some(meta.path.get_ident().unwrap().clone());
87                            }
88                        }
89                        Ok(())
90                    })
91                    .unwrap();
92                }
93            }
94            let Some(basetyp) = basetyp else {
95                abort!(
96                    input,
97                    "Can not derive NcType for enum without suitable repr";
98                    help="Add #[repr(i32)] (or another integer type) as an attribute to the enum"
99                );
100            };
101            impl_enum(/*&name,*/ &ncname, &basetyp, variants.iter())
102        }
103        Data::Union(_) => abort!(
104            input,
105            "Can not derive NcType for union";
106            note = "netCDF has no concept of Union type"
107        ),
108    };
109
110    let expanded = quote! {
111        const _: () = {
112            use netcdf::types::*;
113
114            #[automatically_derived]
115            unsafe impl #impl_generics NcTypeDescriptor for #name #ty_generics #where_clause {
116                fn type_descriptor() -> NcVariableType {
117                    #body
118                }
119            }
120        };
121    };
122    proc_macro::TokenStream::from(expanded)
123}
124
125fn impl_compound(ty: &Ident, ncname: &str, fields: FieldsNamed) -> TokenStream {
126    struct FieldInfo {
127        name: String,
128        typ: Type,
129    }
130    let mut items: Vec<FieldInfo> = vec![];
131
132    for field in fields.named {
133        let ident = field.ident.expect("Field must have a name").clone();
134        let mut rename = None;
135        for attr in field.attrs {
136            if attr.path().is_ident("netcdf") {
137                attr.parse_nested_meta(|meta| {
138                    if meta.path.is_ident("rename") {
139                        rename = Some(meta.value()?.parse::<LitStr>()?.value());
140                    } else {
141                        abort!(meta.path, "NcType encountered an unknown attribute")
142                    }
143                    Ok(())
144                })
145                .unwrap();
146            }
147        }
148        let name = rename.unwrap_or_else(|| ident.to_string());
149        items.push(FieldInfo {
150            name,
151            typ: field.ty,
152        });
153    }
154
155    let fieldnames = items
156        .iter()
157        .map(|item| item.name.clone())
158        .collect::<Vec<_>>();
159    let typeids = items
160        .iter()
161        .map(|item| item.typ.clone())
162        .collect::<Vec<_>>();
163    let fieldinfo = quote!(vec![#(
164            (
165                (#fieldnames).to_owned(),
166                <#typeids as NcTypeDescriptor>::type_descriptor(),
167                (<#typeids as NcTypeDescriptor>::ARRAY_ELEMENTS).as_dims().map(Vec::from),
168            )
169            ),*]);
170
171    quote! {
172        let mut fields = vec![];
173        let mut offset = 0;
174        for (name, basetype, arraydims) in #fieldinfo {
175            let nelems = arraydims.as_ref().map_or(1, |x| x.iter().copied().product());
176            let thissize = basetype.size() * nelems;
177            fields.push(CompoundTypeField {
178                name,
179                offset,
180                basetype,
181                arraydims,
182            });
183            offset += thissize;
184        }
185        let compound = NcVariableType::Compound(CompoundType {
186            name: (#ncname).to_owned(),
187            size: offset,
188            fields,
189        });
190        assert_eq!(compound.size(), std::mem::size_of::<#ty>(), "Compound must be packed");
191        compound
192    }
193}
194
195fn impl_enum<'a>(
196    // ty: &Ident,
197    ncname: &str,
198    basetyp: &Ident,
199    fields: impl Iterator<Item = &'a Variant>,
200) -> TokenStream {
201    let mut fieldnames = vec![];
202    let mut fieldvalues = vec![];
203
204    for field in fields {
205        let ident = field.ident.clone();
206        let mut rename = None;
207        for attr in &field.attrs {
208            if attr.path().is_ident("netcdf") {
209                attr.parse_nested_meta(|meta| {
210                    if meta.path.is_ident("rename") {
211                        rename = Some(meta.value()?.parse::<LitStr>()?.value());
212                    } else {
213                        abort!(meta.path, "NcType encountered an unknown attribute")
214                    }
215                    Ok(())
216                })
217                .unwrap();
218            }
219        }
220        let name = rename.unwrap_or_else(|| ident.to_string());
221        fieldnames.push(name);
222
223        let variant = match field.discriminant.clone() {
224            Some((_, x)) => quote!(#x),
225            None => fieldvalues
226                .last()
227                .map(|e| quote!(#e + 1))
228                .unwrap_or(quote!(0)),
229        };
230
231        fieldvalues.push(variant);
232    }
233
234    let fieldnames = quote!(vec![#(#fieldnames),*]);
235    let fieldvalues = quote!(vec![#(#fieldvalues),*]);
236
237    quote! {
238        NcVariableType::Enum(EnumType {
239            name: (#ncname).to_owned(),
240            fieldnames: (#fieldnames).iter().map(|x| x.to_string()).collect(),
241            fieldvalues: ((#fieldvalues).into_iter().collect::<Vec::<#basetyp>>()).into(),
242        })
243    }
244}