soa_rs_derive/
lib.rs

1//! This crate provides the derive macro for Soars.
2
3mod fields;
4mod zst;
5
6use fields::{fields_struct, FieldKind};
7use proc_macro::TokenStream;
8use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
9use quote::{quote, quote_spanned};
10use std::{
11    error::Error,
12    fmt::{self, Display, Formatter},
13};
14use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields};
15use zst::{zst_struct, ZstKind};
16
17#[proc_macro_derive(Soars, attributes(align, soa_derive, soa_array))]
18pub fn soa(input: TokenStream) -> TokenStream {
19    let input: DeriveInput = parse_macro_input!(input);
20    let span = input.ident.span();
21    match soa_inner(input) {
22        Ok(tokens) => tokens,
23        Err(e) => match e {
24            SoarsError::NotAStruct => quote_spanned! {
25                span => compile_error!("Soars only applies to structs");
26            },
27            SoarsError::Syn(e) => e.into_compile_error(),
28        },
29    }
30    .into()
31}
32
33fn soa_inner(input: DeriveInput) -> Result<TokenStream2, SoarsError> {
34    let DeriveInput {
35        ident,
36        vis,
37        data,
38        attrs,
39        generics: _,
40    } = input;
41
42    let attrs = SoaAttrs::new(attrs)?;
43    match data {
44        Data::Struct(strukt) => match strukt.fields {
45            Fields::Named(fields) => Ok(fields_struct(
46                ident,
47                vis,
48                fields.named,
49                FieldKind::Named,
50                attrs,
51            )?),
52            Fields::Unnamed(fields) => Ok(fields_struct(
53                ident,
54                vis,
55                fields.unnamed,
56                FieldKind::Unnamed,
57                attrs,
58            )?),
59            Fields::Unit => Ok(zst_struct(ident, vis, ZstKind::Unit)),
60        },
61        Data::Enum(_) | Data::Union(_) => Err(SoarsError::NotAStruct),
62    }
63}
64
/// Errors produced while expanding `#[derive(Soars)]`; converted to compile errors by `soa`.
#[derive(Debug, Clone)]
enum SoarsError {
    /// The derive was applied to an enum or a union instead of a struct.
    NotAStruct,
    /// An error from `syn`, rendered with `into_compile_error`.
    Syn(syn::Error),
}
70
71impl From<syn::Error> for SoarsError {
72    fn from(value: syn::Error) -> Self {
73        Self::Syn(value)
74    }
75}
76
/// Options gathered from the helper attributes on the deriving struct.
#[derive(Debug, Clone)]
struct SoaAttrs {
    /// `#[derive(...)]` attributes to attach to each generated type (from `soa_derive`).
    pub derive: SoaDerive,
    /// `true` when at least one `#[soa_array]` attribute is present.
    pub include_array: bool,
}
82
83impl SoaAttrs {
84    pub fn new(attributes: Vec<Attribute>) -> Result<Self, syn::Error> {
85        let mut derive_parse = SoaDeriveParse::new();
86        let mut include_array = false;
87        for attr in attributes {
88            let path = attr.path();
89            if path.is_ident("soa_derive") {
90                derive_parse.append(attr)?;
91            } else if path.is_ident("soa_array") {
92                include_array = true;
93            }
94        }
95
96        Ok(Self {
97            derive: derive_parse.into_derive(),
98            include_array,
99        })
100    }
101}
102
/// Accumulates the derive paths requested via `#[soa_derive(...)]`, one list per generated type.
///
/// NOTE(review): the derived `Default` yields empty lists, whereas `SoaDeriveParse::new`
/// pre-seeds `Copy` and `Clone` for `r#ref` and `slices`; the attribute parser uses `new`.
#[derive(Debug, Clone, Default)]
struct SoaDeriveParse {
    /// Derives for the `Ref` type.
    r#ref: Vec<syn::Path>,
    /// Derives for the `RefMut` type.
    ref_mut: Vec<syn::Path>,
    /// Derives for the `Slices` type.
    slices: Vec<syn::Path>,
    /// Derives for the `SlicesMut` type.
    slices_mut: Vec<syn::Path>,
    /// Derives for the `Array` type.
    array: Vec<syn::Path>,
}
111
112impl SoaDeriveParse {
113    pub fn new() -> Self {
114        Self {
115            r#ref: copy_clone(),
116            ref_mut: vec![],
117            slices: copy_clone(),
118            slices_mut: vec![],
119            array: vec![],
120        }
121    }
122
123    fn into_derive(self) -> SoaDerive {
124        let Self {
125            r#ref: reff,
126            ref_mut,
127            slices,
128            slices_mut,
129            array,
130        } = self;
131        SoaDerive {
132            r#ref: quote! {
133                #[derive(#(#reff),*)]
134            },
135            ref_mut: quote! {
136                #[derive(#(#ref_mut),*)]
137            },
138            slices: quote! {
139                #[derive(#(#slices),*)]
140            },
141            slices_mut: quote! {
142                #[derive(#(#slices_mut),*)]
143            },
144            array: quote! {
145                #[derive(#(#array),*)]
146            },
147        }
148    }
149
150    pub fn append(&mut self, attr: Attribute) -> Result<(), syn::Error> {
151        let mut collected = vec![];
152        let mut mask = SoaDeriveMask::new();
153        attr.parse_nested_meta(|meta| {
154            if meta.path.is_ident("include") {
155                mask = SoaDeriveMask::splat(false);
156                meta.parse_nested_meta(|meta| {
157                    mask.set_by_path(&meta.path, true).map_err(|_| {
158                        meta.error(format!("unknown include specifier {:?}", meta.path))
159                    })
160                })?;
161            } else if meta.path.is_ident("exclude") {
162                meta.parse_nested_meta(|meta| {
163                    mask.set_by_path(&meta.path, false).map_err(|_| {
164                        meta.error(format!("unknown exclude specifier {:?}", meta.path))
165                    })
166                })?;
167            } else {
168                collected.push(meta.path);
169            }
170            Ok(())
171        })?;
172
173        let to_extend = mask
174            .r#ref
175            .then_some(&mut self.r#ref)
176            .into_iter()
177            .chain(mask.ref_mut.then_some(&mut self.ref_mut))
178            .chain(mask.slice.then_some(&mut self.slices))
179            .chain(mask.slice_mut.then_some(&mut self.slices_mut))
180            .chain(mask.array.then_some(&mut self.array));
181
182        for set in to_extend {
183            set.extend(collected.iter().cloned());
184        }
185
186        Ok(())
187    }
188}
189
190fn copy_clone() -> Vec<syn::Path> {
191    vec![str_to_path("Copy"), str_to_path("Clone")]
192}
193
194fn str_to_path(s: &str) -> syn::Path {
195    syn::Path::from(syn::PathSegment {
196        ident: Ident::new(s, Span::call_site()),
197        arguments: syn::PathArguments::None,
198    })
199}
200
/// Ready-to-emit `#[derive(...)]` attribute tokens for each generated type.
///
/// Produced by `SoaDeriveParse::into_derive`; each field holds exactly one attribute.
#[derive(Debug, Clone, Default)]
struct SoaDerive {
    /// Derive attribute for the `Ref` type.
    pub r#ref: TokenStream2,
    /// Derive attribute for the `RefMut` type.
    pub ref_mut: TokenStream2,
    /// Derive attribute for the `Slices` type.
    pub slices: TokenStream2,
    /// Derive attribute for the `SlicesMut` type.
    pub slices_mut: TokenStream2,
    /// Derive attribute for the `Array` type.
    pub array: TokenStream2,
}
209
/// One flag per generated type, used to select which types a `soa_derive` entry applies to.
///
/// The specifier names accepted by `set_by_path` are given next to each field.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
struct SoaDeriveMask {
    /// Specifier `Ref`.
    pub r#ref: bool,
    /// Specifier `RefMut`.
    pub ref_mut: bool,
    /// Specifier `Slices`.
    pub slice: bool,
    /// Specifier `SlicesMut`.
    pub slice_mut: bool,
    /// Specifier `Array`.
    pub array: bool,
}
218
219impl SoaDeriveMask {
220    pub const fn new() -> Self {
221        Self::splat(true)
222    }
223
224    pub const fn splat(value: bool) -> Self {
225        Self {
226            r#ref: value,
227            ref_mut: value,
228            slice: value,
229            slice_mut: value,
230            array: value,
231        }
232    }
233
234    pub fn set_by_path(&mut self, path: &syn::Path, value: bool) -> Result<(), SetByPathError> {
235        if path.is_ident("Ref") {
236            self.r#ref = value;
237        } else if path.is_ident("RefMut") {
238            self.ref_mut = value;
239        } else if path.is_ident("Slices") {
240            self.slice = value;
241        } else if path.is_ident("SlicesMut") {
242            self.slice_mut = value;
243        } else if path.is_ident("Array") {
244            self.array = value;
245        } else {
246            return Err(SetByPathError);
247        }
248        Ok(())
249    }
250}
251
252impl Default for SoaDeriveMask {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
/// Returned by `SoaDeriveMask::set_by_path` when the path is not one of `Ref`, `RefMut`,
/// `Slices`, `SlicesMut` or `Array`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
struct SetByPathError;
260
261impl Display for SetByPathError {
262    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
263        write!(f, "unknown mask specifier")
264    }
265}
266
/// Lets `SetByPathError` be used as a standard error; its message comes from the `Display` impl.
impl Error for SetByPathError {}