generic_array_struct/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use heck::ToShoutySnakeCase;
4use proc_macro::TokenStream;
5use quote::{format_ident, quote};
6use syn::{
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    token::{Bracket, Paren, Semi},
10    Data, DeriveInput, Expr, ExprPath, Field, Fields, FieldsUnnamed, GenericParam, Ident, Path,
11    PathSegment, Type, TypeArray, TypePath, Visibility,
12};
13
14struct AttrArgs {
15    array_field_vis: Visibility,
16}
17
18impl Parse for AttrArgs {
19    fn parse(input: ParseStream) -> syn::Result<Self> {
20        if input.is_empty() {
21            return Ok(Self {
22                array_field_vis: Visibility::Inherited,
23            });
24        }
25        Ok(Self {
26            array_field_vis: input.parse()?,
27        })
28    }
29}
30
31/// The main attribute proc macro. See crate docs for usage.
32#[proc_macro_attribute]
33pub fn generic_array_struct(attr_arg: TokenStream, input: TokenStream) -> TokenStream {
34    let AttrArgs { array_field_vis } = parse_macro_input!(attr_arg as AttrArgs);
35
36    let mut input = parse_macro_input!(input as DeriveInput);
37
38    let struct_vis = &input.vis;
39    let struct_ident = &input.ident;
40    let data_struct = match &mut input.data {
41        Data::Struct(ds) => ds,
42        _ => panic!("{MACRO_NAME} only works with structs"),
43    };
44    let fields = match &data_struct.fields {
45        Fields::Named(f) => f,
46        _ => panic!("{MACRO_NAME} only works with structs with named fields"),
47    };
48    let mut generic_iter = input.generics.params.iter();
49    let generic = match generic_iter.next() {
50        Some(GenericParam::Type(g)) => g,
51        _ => panic!("{MACRO_NAME} {REQ_SINGLE_GENERIC_TYPE_PARAM_ERRMSG}"),
52    };
53    if generic_iter.next().is_some() {
54        panic!("{MACRO_NAME} {REQ_SINGLE_GENERIC_TYPE_PARAM_ERRMSG}");
55    }
56    let generic_param_ident = &generic.ident;
57
58    let (n_fields, fields_idx_consts, accessor_mutator_impls, const_with_impls) =
59        fields.named.iter().enumerate().fold(
60            (0usize, quote! {}, quote! {}, quote! {}),
61            |(
62                n_fields,
63                mut fields_idx_consts,
64                mut accessor_mutator_impls,
65                mut const_with_impls,
66            ),
67             (i, field)| {
68                let expect_same_generic = match &field.ty {
69                    Type::Path(g) => g,
70                    _ => panic!("{MACRO_NAME} {REQ_ALL_FIELDS_SAME_GENERIC_TYPE_ERRMSG}"),
71                };
72                if !expect_same_generic
73                    .path
74                    .get_ident()
75                    .map(|id| id == generic_param_ident)
76                    .unwrap_or(false)
77                {
78                    panic!("{MACRO_NAME} {REQ_ALL_FIELDS_SAME_GENERIC_TYPE_ERRMSG}")
79                }
80
81                let field_vis = &field.vis;
82                // unwrap-safety: named field checked above
83                let field_ident = field.ident.as_ref().unwrap();
84
85                // pub const RGB_IDX_R: usize = 0;
86                let idx_ident = field_idx_ident(struct_ident, field_ident);
87                fields_idx_consts.extend(quote! {
88                    #field_vis const #idx_ident: usize = #i;
89                });
90
91                // fn r(), r_mut(), set_r(), with_r()
92                let ident_mut = format_ident!("{field_ident}_mut");
93                let set_ident = format_ident!("set_{field_ident}");
94                let with_ident = format_ident!("with_{field_ident}");
95                // preserve attributes such as doc comments on getter method
96                let field_attrs = &field.attrs;
97                accessor_mutator_impls.extend(quote! {
98                    #(#field_attrs)*
99                    #[inline]
100                    #field_vis const fn #field_ident(&self) -> &T {
101                        &self.0[#idx_ident]
102                    }
103
104                    #[inline]
105                    #field_vis fn #ident_mut(&mut self) -> &mut T {
106                        &mut self.0[#idx_ident]
107                    }
108
109                    /// Returns the old field value
110                    #[inline]
111                    #field_vis fn #set_ident(&mut self, val: T) -> T {
112                        core::mem::replace(&mut self.0[#idx_ident], val)
113                    }
114
115                    #[inline]
116                    #field_vis fn #with_ident(mut self, val: T) -> Self {
117                        self.0[#idx_ident] = val;
118                        self
119                    }
120                });
121
122                // fn const_with_r()
123                let const_with_ident = format_ident!("const_with_{field_ident}");
124                const_with_impls.extend(quote! {
125                    #[inline]
126                    #field_vis const fn #const_with_ident(mut self, val: T) -> Self {
127                        self.0[#idx_ident] = val;
128                        self
129                    }
130                });
131
132                (
133                    n_fields + 1,
134                    fields_idx_consts,
135                    accessor_mutator_impls,
136                    const_with_impls,
137                )
138            },
139        );
140
141    let len_ident = array_len_ident(struct_ident);
142
143    // finally, replace the struct defn with a single array field tuple struct
144    data_struct.fields = Fields::Unnamed(FieldsUnnamed {
145        paren_token: Paren::default(),
146        unnamed: core::iter::once(Field {
147            vis: array_field_vis,
148            attrs: Vec::new(),
149            mutability: syn::FieldMutability::None,
150            ident: None,
151            colon_token: None,
152            ty: Type::Array(TypeArray {
153                bracket_token: Bracket::default(),
154                elem: Box::new(Type::Path(TypePath {
155                    qself: None,
156                    path: path_from_var(generic_param_ident.clone()),
157                })),
158                semi_token: Semi::default(),
159                len: Expr::Path(ExprPath {
160                    attrs: Vec::new(),
161                    qself: None,
162                    path: path_from_var(len_ident.clone()),
163                }),
164            }),
165        })
166        .collect(),
167    });
168
169    quote! {
170        #input
171
172        #struct_vis const #len_ident: usize = #n_fields;
173
174        impl<T> #struct_ident<T> {
175            #accessor_mutator_impls
176        }
177
178        impl<T: Copy> #struct_ident<T> {
179            #const_with_impls
180        }
181
182        #fields_idx_consts
183    }
184    .into()
185}
186
187/// e.g. RGB_LEN
188fn array_len_ident(struct_ident: &Ident) -> Ident {
189    format_ident!("{}_LEN", struct_ident.to_string().to_shouty_snake_case())
190}
191
192/// e.g. RGB_IDX_R
193fn field_idx_ident(struct_ident: &Ident, field_ident: &Ident) -> Ident {
194    format_ident!(
195        "{}_IDX_{}",
196        struct_ident.to_string().to_shouty_snake_case(),
197        field_ident.to_string().to_shouty_snake_case()
198    )
199}
200
201/// A plain path with a single segment
202/// e.g.
203/// - `T` (as in generic type param)
204/// - `RGB_LEN`
205fn path_from_var(ident: Ident) -> Path {
206    Path {
207        leading_colon: None,
208        segments: core::iter::once(PathSegment {
209            ident,
210            arguments: Default::default(),
211        })
212        .collect(),
213    }
214}
215
/// Name of the attribute macro; prefixed to every error message it emits.
const MACRO_NAME: &str = "generic_array_struct";

/// Message suffix used when the struct does not have exactly one generic type parameter.
const REQ_SINGLE_GENERIC_TYPE_PARAM_ERRMSG: &str =
    "only works with structs with a single generic type param";

/// Message suffix used when a field's type is not the struct's generic type parameter.
const REQ_ALL_FIELDS_SAME_GENERIC_TYPE_ERRMSG: &str =
    "requires all fields to have the same generic type";