generic_array_struct/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use builder::impl_builder;
4use errs::{
5    panic_only_works_with_structs, panic_only_works_with_structs_with_named_fields,
6    panic_req_all_fields_same_generic, panic_req_single_generic,
7};
8use idents::{
9    array_len_ident, const_with_ident, field_idx_ident, ident_mut, set_ident, with_ident,
10};
11use proc_macro::TokenStream;
12use quote::quote;
13use syn::{
14    parse::{Parse, ParseStream},
15    parse_macro_input,
16    token::{Bracket, Paren, Semi},
17    Data, DataStruct, DeriveInput, Expr, ExprPath, Field, Fields, FieldsNamed, FieldsUnnamed,
18    GenericParam, Ident, Type, TypeArray, TypePath, Visibility,
19};
20use utils::path_from_ident;
21
22mod builder;
23mod errs;
24mod idents;
25mod utils;
26
/// Name of this attribute macro. Not referenced in this file; presumably used by
/// sibling modules (e.g. `errs`) when building diagnostic messages — confirm.
const MACRO_NAME: &str = "generic_array_struct";
28
29#[repr(transparent)]
30struct GenericArrayStructParams(DeriveInput);
31
32/// Accessors
33impl GenericArrayStructParams {
34    #[inline]
35    pub fn struct_vis(&self) -> &Visibility {
36        &self.0.vis
37    }
38
39    #[inline]
40    pub fn struct_ident(&self) -> &Ident {
41        &self.0.ident
42    }
43
44    #[inline]
45    pub fn generic_ident(&self) -> &Ident {
46        let mut generic_iter = self.0.generics.params.iter();
47        let generic = match generic_iter.next() {
48            Some(GenericParam::Type(g)) => g,
49            _ => panic_req_single_generic(),
50        };
51        if generic_iter.next().is_some() {
52            panic_req_single_generic();
53        }
54        &generic.ident
55    }
56
57    #[inline]
58    pub fn data_struct(&self) -> &DataStruct {
59        match &self.0.data {
60            Data::Struct(ds) => ds,
61            _ => panic_only_works_with_structs(),
62        }
63    }
64
65    #[inline]
66    pub fn data_struct_mut(&mut self) -> &mut DataStruct {
67        match &mut self.0.data {
68            Data::Struct(ds) => ds,
69            _ => panic_only_works_with_structs(),
70        }
71    }
72
73    #[inline]
74    pub fn fields_named(&self) -> &FieldsNamed {
75        match &self.data_struct().fields {
76            Fields::Named(f) => f,
77            _ => panic_only_works_with_structs_with_named_fields(),
78        }
79    }
80}
81
82struct AttrArgs {
83    array_field_vis: Visibility,
84    should_gen_builder: bool,
85}
86
87impl Parse for AttrArgs {
88    fn parse(input: ParseStream) -> syn::Result<Self> {
89        let should_gen_builder = if input.peek(Ident) {
90            let id: Ident = input.parse()?;
91            if id != "builder" {
92                panic!("Expected token `builder`")
93            } else {
94                true
95            }
96        } else {
97            false
98        };
99
100        if input.is_empty() {
101            return Ok(Self {
102                array_field_vis: Visibility::Inherited,
103                should_gen_builder,
104            });
105        }
106
107        let array_field_vis = input.parse()?;
108        Ok(Self {
109            array_field_vis,
110            should_gen_builder,
111        })
112    }
113}
114
115/// The main attribute proc macro. See crate docs for usage.
116#[proc_macro_attribute]
117pub fn generic_array_struct(attr_arg: TokenStream, input: TokenStream) -> TokenStream {
118    let AttrArgs {
119        array_field_vis,
120        should_gen_builder,
121    } = parse_macro_input!(attr_arg as AttrArgs);
122
123    let input = parse_macro_input!(input as DeriveInput);
124    let mut params = GenericArrayStructParams(input);
125
126    let mut fields_idx_consts = quote! {};
127    let mut accessor_mutator_impls = quote! {};
128    let mut const_with_impls = quote! {};
129    let n_fields =
130        params
131            .fields_named()
132            .named
133            .iter()
134            .enumerate()
135            .fold(0usize, |n_fields, (i, field)| {
136                let expect_same_generic = match &field.ty {
137                    Type::Path(g) => g,
138                    _ => panic_req_all_fields_same_generic(),
139                };
140                if !expect_same_generic
141                    .path
142                    .get_ident()
143                    .map(|id| id == params.generic_ident())
144                    .unwrap_or(false)
145                {
146                    panic_req_all_fields_same_generic();
147                }
148
149                let field_vis = &field.vis;
150                // unwrap-safety: named field checked above
151                let field_ident = field.ident.as_ref().unwrap();
152
153                // pub const RGB_IDX_R: usize = 0;
154                let idx_ident = field_idx_ident(params.struct_ident(), field_ident);
155                fields_idx_consts.extend(quote! {
156                    #field_vis const #idx_ident: usize = #i;
157                });
158
159                // fn r(), r_mut(), set_r(), with_r()
160                let id_mut = ident_mut(field_ident);
161                let set_id = set_ident(field_ident);
162                let with_id = with_ident(field_ident);
163                // preserve attributes such as doc comments on getter method
164                let field_attrs = &field.attrs;
165                accessor_mutator_impls.extend(quote! {
166                    #(#field_attrs)*
167                    #[inline]
168                    #field_vis const fn #field_ident(&self) -> &T {
169                        &self.0[#idx_ident]
170                    }
171
172                    #[inline]
173                    #field_vis fn #id_mut(&mut self) -> &mut T {
174                        &mut self.0[#idx_ident]
175                    }
176
177                    /// Returns the old field value
178                    #[inline]
179                    #field_vis fn #set_id(&mut self, val: T) -> T {
180                        core::mem::replace(&mut self.0[#idx_ident], val)
181                    }
182
183                    #[inline]
184                    #field_vis fn #with_id(mut self, val: T) -> Self {
185                        self.0[#idx_ident] = val;
186                        self
187                    }
188                });
189
190                // fn const_with_r()
191                let const_with_id = const_with_ident(field_ident);
192                const_with_impls.extend(quote! {
193                    #[inline]
194                    #field_vis const fn #const_with_id(mut self, val: T) -> Self {
195                        self.0[#idx_ident] = val;
196                        self
197                    }
198                });
199
200                n_fields + 1
201            });
202
203    let len_ident = array_len_ident(params.struct_ident());
204
205    let struct_vis = params.struct_vis();
206    let struct_ident = params.struct_ident();
207    let mut res = quote! {
208        #struct_vis const #len_ident: usize = #n_fields;
209
210        impl<T> #struct_ident<T> {
211            #accessor_mutator_impls
212        }
213
214        impl<T: Copy> #struct_ident<T> {
215            #const_with_impls
216        }
217
218        #fields_idx_consts
219    };
220
221    if should_gen_builder {
222        res.extend(impl_builder(&params, struct_vis));
223    }
224
225    // finally, replace the struct defn with a single array field tuple struct
226    params.data_struct_mut().fields = Fields::Unnamed(FieldsUnnamed {
227        paren_token: Paren::default(),
228        unnamed: core::iter::once(Field {
229            vis: array_field_vis,
230            attrs: Vec::new(),
231            mutability: syn::FieldMutability::None,
232            ident: None,
233            colon_token: None,
234            ty: Type::Array(TypeArray {
235                bracket_token: Bracket::default(),
236                elem: Box::new(Type::Path(TypePath {
237                    qself: None,
238                    path: path_from_ident(params.generic_ident().clone()),
239                })),
240                semi_token: Semi::default(),
241                len: Expr::Path(ExprPath {
242                    attrs: Vec::new(),
243                    qself: None,
244                    path: path_from_ident(len_ident),
245                }),
246            }),
247        })
248        .collect(),
249    });
250
251    // extend with original input with modified struct defn
252    let GenericArrayStructParams(input) = params;
253    res.extend(quote! { #input });
254
255    res.into()
256}