generic_array_struct/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use heck::ToShoutySnakeCase;
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use syn::{
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    token::{Bracket, Paren, Semi},
10    Data, DeriveInput, Expr, ExprPath, Field, Fields, FieldsUnnamed, GenericParam, Ident, Path,
11    PathSegment, Type, TypeArray, TypePath, Visibility,
12};
13
14struct AttrArgs {
15    array_field_vis: Visibility,
16}
17
18impl Parse for AttrArgs {
19    fn parse(input: ParseStream) -> syn::Result<Self> {
20        if input.is_empty() {
21            return Ok(Self {
22                array_field_vis: Visibility::Inherited,
23            });
24        }
25        Ok(Self {
26            array_field_vis: input.parse()?,
27        })
28    }
29}
30
31/// The main attribute proc macro. See crate docs for usage.
32#[proc_macro_attribute]
33pub fn generic_array_struct(attr_arg: TokenStream, input: TokenStream) -> TokenStream {
34    let AttrArgs { array_field_vis } = parse_macro_input!(attr_arg as AttrArgs);
35
36    let mut input = parse_macro_input!(input as DeriveInput);
37
38    let struct_vis = &input.vis;
39    let struct_ident = &input.ident;
40    let Data::Struct(data_struct) = &mut input.data else {
41        panic!("{MACRO_NAME} only works with structs");
42    };
43    let Fields::Named(fields) = &data_struct.fields else {
44        panic!("{MACRO_NAME} only works with structs with named fields");
45    };
46    let mut generic_iter = input.generics.params.iter();
47    let Some(GenericParam::Type(generic)) = generic_iter.next() else {
48        panic!("{MACRO_NAME} {REQ_SINGLE_GENERIC_TYPE_PARAM_ERRMSG}");
49    };
50    if generic_iter.next().is_some() {
51        panic!("{MACRO_NAME} {REQ_SINGLE_GENERIC_TYPE_PARAM_ERRMSG}");
52    }
53    let generic_param_ident = &generic.ident;
54
55    let (n_fields, fields_idx_consts, accessor_mutator_impls, const_with_impls) =
56        fields.named.iter().enumerate().fold(
57            (0usize, quote! {}, quote! {}, quote! {}),
58            |(
59                n_fields,
60                mut fields_idx_consts,
61                mut accessor_mutator_impls,
62                mut const_with_impls,
63            ),
64             (i, field)| {
65                let Type::Path(expect_same_generic) = &field.ty else {
66                    panic!("{MACRO_NAME} {REQ_ALL_FIELDS_SAME_GENERIC_TYPE_ERRMSG}")
67                };
68                if !expect_same_generic
69                    .path
70                    .get_ident()
71                    .map(|id| id == generic_param_ident)
72                    .unwrap_or(false)
73                {
74                    panic!("{MACRO_NAME} {REQ_ALL_FIELDS_SAME_GENERIC_TYPE_ERRMSG}")
75                }
76
77                let field_vis = &field.vis;
78                // unwrap-safety: named field checked above
79                let field_ident = field.ident.as_ref().unwrap();
80
81                // pub const RGB_IDX_R: usize = 0;
82                let idx_ident = field_idx_ident(struct_ident, field_ident);
83                fields_idx_consts.extend(quote! {
84                    #field_vis const #idx_ident: usize = #i;
85                });
86
87                // fn r(), r_mut(), set_r(), with_r()
88                let ident_mut = format_ident!("{field_ident}_mut");
89                let set_ident = format_ident!("set_{field_ident}");
90                let with_ident = format_ident!("with_{field_ident}");
91                // preserve attributes such as doc comments on getter method
92                let field_attrs = &field.attrs;
93                accessor_mutator_impls.extend(quote! {
94                    #(#field_attrs)*
95                    #[inline]
96                    #field_vis const fn #field_ident(&self) -> &T {
97                        &self.0[#idx_ident]
98                    }
99
100                    #[inline]
101                    #field_vis const fn #ident_mut(&mut self) -> &mut T {
102                        &mut self.0[#idx_ident]
103                    }
104
105                    /// Returns the old field value
106                    #[inline]
107                    #field_vis const fn #set_ident(&mut self, val: T) -> T {
108                        core::mem::replace(&mut self.0[#idx_ident], val)
109                    }
110
111                    #[inline]
112                    #field_vis fn #with_ident(mut self, val: T) -> Self {
113                        self.0[#idx_ident] = val;
114                        self
115                    }
116                });
117
118                // fn const_with_r()
119                let const_with_ident = format_ident!("const_with_{field_ident}");
120                const_with_impls.extend(quote! {
121                    #[inline]
122                    #field_vis const fn #const_with_ident(mut self, val: T) -> Self {
123                        self.0[#idx_ident] = val;
124                        self
125                    }
126                });
127
128                (
129                    n_fields + 1,
130                    fields_idx_consts,
131                    accessor_mutator_impls,
132                    const_with_impls,
133                )
134            },
135        );
136
137    let len_ident = array_len_ident(struct_ident);
138
139    // finally, replace the struct defn with a single array field tuple struct
140    data_struct.fields = Fields::Unnamed(FieldsUnnamed {
141        paren_token: Paren::default(),
142        unnamed: core::iter::once(Field {
143            vis: array_field_vis,
144            attrs: Vec::new(),
145            mutability: syn::FieldMutability::None,
146            ident: None,
147            colon_token: None,
148            ty: Type::Array(TypeArray {
149                bracket_token: Bracket::default(),
150                elem: Box::new(Type::Path(TypePath {
151                    qself: None,
152                    path: path_from_var(generic_param_ident.clone()),
153                })),
154                semi_token: Semi::default(),
155                len: Expr::Path(ExprPath {
156                    attrs: Vec::new(),
157                    qself: None,
158                    path: path_from_var(len_ident.clone()),
159                }),
160            }),
161        })
162        .collect(),
163    });
164
165    quote! {
166        #input
167
168        #struct_vis const #len_ident: usize = #n_fields;
169
170        impl<T> #struct_ident<T> {
171            #accessor_mutator_impls
172        }
173
174        impl<T: Copy> #struct_ident<T> {
175            #const_with_impls
176        }
177
178        #fields_idx_consts
179    }
180    .into()
181}
182
183/// e.g. RGB_LEN
184fn array_len_ident(struct_ident: &Ident) -> Ident {
185    format_ident!("{}_LEN", struct_ident.to_string().to_shouty_snake_case())
186}
187
188/// e.g. RGB_IDX_R
189fn field_idx_ident(struct_ident: &Ident, field_ident: &Ident) -> Ident {
190    format_ident!(
191        "{}_IDX_{}",
192        struct_ident.to_string().to_shouty_snake_case(),
193        field_ident.to_string().to_shouty_snake_case()
194    )
195}
196
197/// A plain path with a single segment
198/// e.g.
199/// - `T` (as in generic type param)
200/// - `RGB_LEN`
201fn path_from_var(ident: Ident) -> Path {
202    Path {
203        leading_colon: None,
204        segments: core::iter::once(PathSegment {
205            ident,
206            arguments: Default::default(),
207        })
208        .collect(),
209    }
210}
211
/// Name of this attribute macro; prefixed to every error message it reports.
const MACRO_NAME: &str = "generic_array_struct";

/// Reported when the struct does not declare exactly one generic parameter that is a type.
const REQ_SINGLE_GENERIC_TYPE_PARAM_ERRMSG: &str =
    "only works with structs with a single generic type param";

/// Reported when a field's type is not the struct's generic type param.
const REQ_ALL_FIELDS_SAME_GENERIC_TYPE_ERRMSG: &str =
    "requires all fields to have the same generic type";