generic_array_struct/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use builder::impl_builder;
4use destr::impl_destr;
5use errs::{
6    panic_only_works_with_structs, panic_only_works_with_structs_with_named_fields,
7    panic_req_all_fields_same_generic, panic_req_single_generic,
8};
9use idents::{
10    array_len_ident, const_with_ident, field_idx_ident, ident_mut, set_ident, with_ident,
11};
12use proc_macro::TokenStream;
13use quote::quote;
14use syn::{
15    parse::{Parse, ParseStream},
16    parse_macro_input,
17    token::{Bracket, Paren, Semi},
18    Attribute, Data, DataStruct, DeriveInput, Expr, ExprPath, Field, Fields, FieldsNamed,
19    FieldsUnnamed, GenericParam, Ident, Type, TypeArray, TypePath, Visibility,
20};
21use utils::path_from_ident;
22
23use crate::{idents::assoc_field_idx_ident, trymap::impl_trymap, zip::impl_zip};
24
25mod builder;
26mod destr;
27mod errs;
28mod idents;
29mod trymap;
30mod utils;
31mod zip;
32
/// Name of the attribute macro exported by this crate (the
/// `generic_array_struct` function below).
///
/// NOTE(review): not referenced in this file; presumably consumed by a sibling
/// module (e.g. `errs`) when formatting messages — confirm.
const MACRO_NAME: &str = "generic_array_struct";
34
/// The item annotated with `#[generic_array_struct]`, parsed as a
/// [`DeriveInput`], wrapped so it can expose shape-validating accessors.
///
/// `repr(transparent)`: identical layout to the wrapped `DeriveInput`.
#[repr(transparent)]
struct GenericArrayStructParams(DeriveInput);
37
38/// Accessors
39impl GenericArrayStructParams {
40    #[inline]
41    pub fn struct_vis(&self) -> &Visibility {
42        &self.0.vis
43    }
44
45    #[inline]
46    pub fn struct_ident(&self) -> &Ident {
47        &self.0.ident
48    }
49
50    #[inline]
51    pub fn generic_ident(&self) -> &Ident {
52        let mut generic_iter = self.0.generics.params.iter();
53        let generic = match generic_iter.next() {
54            Some(GenericParam::Type(g)) => g,
55            _ => panic_req_single_generic(),
56        };
57        if generic_iter.next().is_some() {
58            panic_req_single_generic();
59        }
60        &generic.ident
61    }
62
63    #[inline]
64    pub fn data_struct(&self) -> &DataStruct {
65        match &self.0.data {
66            Data::Struct(ds) => ds,
67            _ => panic_only_works_with_structs(),
68        }
69    }
70
71    #[inline]
72    pub fn data_struct_mut(&mut self) -> &mut DataStruct {
73        match &mut self.0.data {
74            Data::Struct(ds) => ds,
75            _ => panic_only_works_with_structs(),
76        }
77    }
78
79    #[inline]
80    pub fn fields_named(&self) -> &FieldsNamed {
81        match &self.data_struct().fields {
82            Fields::Named(f) => f,
83            _ => panic_only_works_with_structs_with_named_fields(),
84        }
85    }
86
87    #[inline]
88    pub fn attrs(&self) -> &[Attribute] {
89        &self.0.attrs
90    }
91}
92
/// Arguments written inside `#[generic_array_struct(...)]`: zero or more
/// flags followed by an optional visibility.
struct AttrArgs {
    /// Visibility given to the single array field the struct is rewritten
    /// into; `Visibility::Inherited` when no visibility is written.
    array_field_vis: Visibility,
    /// Which optional code generators to run.
    flags: Flags,
}
97
// It would be pretty funny if this were a #[generic_array_struct] itself.
/// Optional code generators selectable from the attribute arguments.
///
/// NOTE: `FLAGS_LEN` is derived from this struct's size, so adding a
/// non-`bool` field would break that flag count.
#[derive(Default)]
struct Flags {
    /// Emit the output of `builder::impl_builder`.
    builder: bool,
    /// Emit the output of `destr::impl_destr`.
    destr: bool,
    /// Emit the output of `trymap::impl_trymap`.
    trymap: bool,
    /// Emit the output of `zip::impl_zip`.
    zip: bool,
}
106
107const FLAGS_LEN: usize = core::mem::size_of::<Flags>();
108
/// Turns on the flag behind `flag`.
///
/// # Panics
///
/// Panics if the flag is already on, i.e. the same argument (`name`) was
/// written twice.
fn set_flag_checked(flag: &mut bool, name: &'static str) {
    assert!(!*flag, "`{name}` already set");
    *flag = true;
}
115
116impl Parse for AttrArgs {
117    fn parse(input: ParseStream) -> syn::Result<Self> {
118        let mut flags = Flags::default();
119        let Flags {
120            builder,
121            destr,
122            trymap,
123            zip,
124        } = &mut flags;
125
126        for i in 0..FLAGS_LEN {
127            if !input.peek(Ident) {
128                break;
129            }
130
131            let id: Ident = input.parse()?;
132            // cant match here, ident is not str
133            if id == "all" {
134                if i != 0 {
135                    panic!("`all` must not be used with other args");
136                }
137
138                *builder = true;
139                *destr = true;
140                *trymap = true;
141                *zip = true;
142
143                break;
144            } else if id == "builder" {
145                set_flag_checked(builder, "builder");
146            } else if id == "destr" {
147                set_flag_checked(destr, "destr");
148            } else if id == "trymap" {
149                set_flag_checked(trymap, "trymap");
150            } else if id == "zip" {
151                set_flag_checked(zip, "zip");
152            } else {
153                panic!("Expected one of [`all`, `builder`, `destr`, `trymap`, `zip`]")
154            }
155        }
156
157        if input.is_empty() {
158            return Ok(Self {
159                array_field_vis: Visibility::Inherited,
160                flags,
161            });
162        }
163
164        let array_field_vis = input.parse()?;
165        Ok(Self {
166            array_field_vis,
167            flags,
168        })
169    }
170}
171
172/// The main attribute proc macro. See crate docs for usage.
173#[proc_macro_attribute]
174pub fn generic_array_struct(attr_arg: TokenStream, input: TokenStream) -> TokenStream {
175    let AttrArgs {
176        array_field_vis,
177        flags:
178            Flags {
179                builder,
180                destr,
181                trymap,
182                zip,
183            },
184    } = parse_macro_input!(attr_arg as AttrArgs);
185
186    let input = parse_macro_input!(input as DeriveInput);
187    let mut params = GenericArrayStructParams(input);
188
189    let mut fields_idx_consts = quote! {};
190    let mut fields_idx_assoc_consts = quote! {};
191    let mut accessor_mutator_impls = quote! {};
192    let mut const_with_impls = quote! {};
193    let n_fields =
194        params
195            .fields_named()
196            .named
197            .iter()
198            .enumerate()
199            .fold(0usize, |n_fields, (i, field)| {
200                let expect_same_generic = match &field.ty {
201                    Type::Path(g) => g,
202                    _ => panic_req_all_fields_same_generic(),
203                };
204                if !expect_same_generic
205                    .path
206                    .get_ident()
207                    .map(|id| id == params.generic_ident())
208                    .unwrap_or(false)
209                {
210                    panic_req_all_fields_same_generic();
211                }
212
213                let field_vis = &field.vis;
214                // unwrap-safety: named field checked above
215                let field_ident = field.ident.as_ref().unwrap();
216
217                // pub const RGB_IDX_R: usize = 0;
218                let idx_ident = field_idx_ident(params.struct_ident(), field_ident);
219                fields_idx_consts.extend(quote! {
220                    #field_vis const #idx_ident: usize = #i;
221                });
222
223                // associated consts
224                // pub const IDX_R: usize = 0;
225                let assoc_idx_ident = assoc_field_idx_ident(field_ident);
226                fields_idx_assoc_consts.extend(quote! {
227                    #field_vis const #assoc_idx_ident: usize = #i;
228                });
229
230                // fn r(), r_mut(), set_r(), with_r()
231                let id_mut = ident_mut(field_ident);
232                let set_id = set_ident(field_ident);
233                let with_id = with_ident(field_ident);
234                // preserve attributes such as doc comments on getter method
235                let field_attrs = &field.attrs;
236                accessor_mutator_impls.extend(quote! {
237                    #(#field_attrs)*
238                    #[inline]
239                    #field_vis const fn #field_ident(&self) -> &T {
240                        &self.0[#idx_ident]
241                    }
242
243                    #[inline]
244                    #field_vis const fn #id_mut(&mut self) -> &mut T {
245                        &mut self.0[#idx_ident]
246                    }
247
248                    /// Returns the old field value
249                    #[inline]
250                    #field_vis const fn #set_id(&mut self, val: T) -> T {
251                        core::mem::replace(&mut self.0[#idx_ident], val)
252                    }
253
254                    #[inline]
255                    #field_vis fn #with_id(mut self, val: T) -> Self {
256                        self.0[#idx_ident] = val;
257                        self
258                    }
259                });
260
261                // fn const_with_r()
262                let const_with_id = const_with_ident(field_ident);
263                const_with_impls.extend(quote! {
264                    #[inline]
265                    #field_vis const fn #const_with_id(mut self, val: T) -> Self {
266                        self.0[#idx_ident] = val;
267                        self
268                    }
269                });
270
271                n_fields + 1
272            });
273
274    let len_ident = array_len_ident(params.struct_ident());
275
276    let struct_vis = params.struct_vis();
277    let struct_ident = params.struct_ident();
278    let mut res = quote! {
279        #struct_vis const #len_ident: usize = #n_fields;
280
281        impl<T> #struct_ident<T> {
282            #accessor_mutator_impls
283        }
284
285        impl<T: Copy> #struct_ident<T> {
286            #const_with_impls
287        }
288
289        impl<T> #struct_ident<T> {
290            #struct_vis const LEN: usize = #n_fields;
291
292            #fields_idx_assoc_consts
293        }
294
295        #fields_idx_consts
296    };
297
298    if builder {
299        res.extend(impl_builder(&params, struct_vis));
300    }
301
302    if destr {
303        res.extend(impl_destr(&params, struct_vis));
304    }
305
306    if trymap {
307        res.extend(impl_trymap(&params));
308    }
309
310    if zip {
311        res.extend(impl_zip(&params));
312    }
313
314    // finally, replace the struct defn with a single array field tuple struct
315    params.data_struct_mut().fields = Fields::Unnamed(FieldsUnnamed {
316        paren_token: Paren::default(),
317        unnamed: core::iter::once(Field {
318            vis: array_field_vis,
319            attrs: Vec::new(),
320            mutability: syn::FieldMutability::None,
321            ident: None,
322            colon_token: None,
323            ty: Type::Array(TypeArray {
324                bracket_token: Bracket::default(),
325                elem: Box::new(Type::Path(TypePath {
326                    qself: None,
327                    path: path_from_ident(params.generic_ident().clone()),
328                })),
329                semi_token: Semi::default(),
330                len: Expr::Path(ExprPath {
331                    attrs: Vec::new(),
332                    qself: None,
333                    path: path_from_ident(len_ident),
334                }),
335            }),
336        })
337        .collect(),
338    });
339
340    // extend with original input with modified struct defn
341    let GenericArrayStructParams(input) = params;
342    res.extend(quote! { #input });
343
344    res.into()
345}