cu29_soa_derive/
lib.rs

1mod format;
2#[cfg(feature = "macro_debug")]
3use format::{highlight_rust_code, rustfmt_generated_code};
4use proc_macro::TokenStream;
5use quote::{ToTokens, format_ident, quote};
6use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};
7
8/// Build a fixed sized SoA (Structure of Arrays) from a struct.
9/// The outputted SoA will be suitable for in place storage in messages and should be
10/// easier for the compiler to vectorize.
11///
12/// for example:
13///
14/// ```ignore
15/// #[derive(Soa)]
16/// struct MyStruct {
17///    a: i32,
18///    b: f32,
19/// }
20/// ```
21///
22/// will generate:
23/// ```ignore
24/// pub struct MyStructSoa<const N: usize> {
25///     pub a: [i32; N],
26///     pub b: [f32; N],
27/// }
28/// ```
29///
30/// You can then use the generated struct to store multiple
31/// instances of the original struct in an SoA format.
32///
33/// ```ignore
34/// // makes an SOA with a default value
35/// let soa1: MyStructSoa<8> = XyzSoa::new(MyStruct{ a: 1, b: 2.3 });
36/// ```
37///
38/// Then you can access the fields of the SoA as slices:
39/// ```ignore
40/// let a = soa1.a();
41/// let b = soa1.b();
42/// ```
43///
44/// You can also access a range of the fields:
45/// ```ignore
46/// let a = soa1.a_range(0..4);
47/// let b = soa1.b_range(0..4);
48/// ```
49///
50/// You can also modify the fields of the SoA:
51/// ```ignore
52/// soa1.a_mut()[0] = 42;
53/// soa1.b_mut()[0] = 42.0;
54/// ```
55///
56/// You can also modify a range of the fields:
57/// ```ignore
58/// soa1.a_range_mut(0..4)[0] = 42;
59/// soa1.b_range_mut(0..4)[0] = 42.0;
60/// ```
61///
62/// You can also apply a function to all the fields of the SoA:
63/// ```ignore
64/// soa1.apply(|a, b| {
65///    (a + 1, b + 1.0)
66/// });
67/// ```
68#[proc_macro_derive(Soa)]
69pub fn derive_soa(input: TokenStream) -> TokenStream {
70    use syn::TypePath;
71
72    let input = parse_macro_input!(input as DeriveInput);
73    let visibility = &input.vis;
74
75    let name = &input.ident;
76    let module_name = format_ident!("{}_soa", name.to_string().to_lowercase());
77    let soa_struct_name = format_ident!("{}Soa", name);
78    let soa_struct_wire_name = format_ident!("{}SoaWire", name);
79
80    let data = match &input.data {
81        Data::Struct(data) => data,
82        _ => panic!("Only structs are supported"),
83    };
84    let fields = match &data.fields {
85        Fields::Named(fields) => &fields.named,
86        _ => panic!("Only named fields are supported"),
87    };
88
89    let mut field_names = vec![];
90    let mut field_names_mut = vec![];
91    let mut field_names_range = vec![];
92    let mut field_names_range_mut = vec![];
93    let mut field_types = vec![];
94    let mut unique_imports = vec![];
95    let mut unique_import_names = vec![];
96
97    fn is_primitive(type_name: &str) -> bool {
98        matches!(
99            type_name,
100            "i8" | "i16"
101                | "i32"
102                | "i64"
103                | "i128"
104                | "u8"
105                | "u16"
106                | "u32"
107                | "u64"
108                | "u128"
109                | "f32"
110                | "f64"
111                | "bool"
112                | "char"
113                | "str"
114                | "usize"
115                | "isize"
116        )
117    }
118
119    for field in fields {
120        let field_name = field.ident.as_ref().unwrap();
121        let field_type = &field.ty;
122        field_names.push(field_name);
123        field_names_mut.push(format_ident!("{}_mut", field_name));
124        field_names_range.push(format_ident!("{}_range", field_name));
125        field_names_range_mut.push(format_ident!("{}_range_mut", field_name));
126        field_types.push(field_type);
127
128        if let Type::Path(TypePath { path, .. }) = field_type {
129            let type_name = path.segments.last().unwrap().ident.to_string();
130            let path_str = path.to_token_stream().to_string();
131
132            if !is_primitive(&type_name) && !unique_import_names.contains(&path_str) {
133                unique_imports.push(path.clone());
134                unique_import_names.push(path_str);
135            }
136        }
137    }
138
139    let soa_struct_name_iterator = format_ident!("{}Iterator", name);
140    let field_count = field_names.len() + 1; // +1 for the len field
141
142    let iterator = quote! {
143        pub struct #soa_struct_name_iterator<'a, const N: usize> {
144            soa_struct: &'a #soa_struct_name<N>,
145            current: usize,
146        }
147
148        impl<'a, const N: usize> #soa_struct_name_iterator<'a, N> {
149            pub fn new(soa_struct: &'a #soa_struct_name<N>) -> Self {
150                Self {
151                    soa_struct,
152                    current: 0,
153                }
154            }
155        }
156
157        impl<'a, const N: usize> Iterator for #soa_struct_name_iterator<'a, N> {
158            type Item = super::#name;
159
160            fn next(&mut self) -> Option<Self::Item> {
161                if self.current < self.soa_struct.len {
162                    let item = self.soa_struct.get(self.current); // Reuse `get` method
163                    self.current += 1;
164                    Some(item)
165                } else {
166                    None
167                }
168            }
169        }
170    };
171
172    let expanded = quote! {
173        #visibility mod #module_name {
174            use bincode::{Decode, Encode};
175            use bincode::enc::Encoder;
176            use bincode::de::Decoder;
177            use bincode::error::{DecodeError, EncodeError};
178            use serde::Deserialize;
179            use serde::Serialize;
180            use serde::Serializer;
181            use serde::ser::SerializeStruct;
182            use std::ops::{Index, IndexMut};
183            #( use super::#unique_imports; )*
184            use core::array::from_fn;
185
186            #[derive(Debug)]
187            #visibility struct #soa_struct_name<const N: usize> {
188                pub len: usize,
189                #(pub #field_names: [#field_types; N], )*
190            }
191
192            impl<const N: usize> #soa_struct_name<N> {
193                pub fn new(default: super::#name) -> Self {
194                    Self {
195                        #( #field_names: from_fn(|_| default.#field_names.clone()), )*
196                        len: 0,
197                    }
198                }
199
200                pub fn len(&self) -> usize {
201                    self.len
202                }
203
204                pub fn is_empty(&self) -> bool {
205                    self.len == 0
206                }
207
208                pub fn push(&mut self, value: super::#name) {
209                    if self.len < N {
210                        #( self.#field_names[self.len] = value.#field_names.clone(); )*
211                        self.len += 1;
212                    } else {
213                        panic!("Capacity exceeded")
214                    }
215                }
216
217                pub fn pop(&mut self) -> Option<super::#name> {
218                    if self.len == 0 {
219                        None
220                    } else {
221                        self.len -= 1;
222                        Some(super::#name {
223                            #( #field_names: self.#field_names[self.len].clone(), )*
224                        })
225                    }
226                }
227
228                pub fn set(&mut self, index: usize, value: super::#name) {
229                    assert!(index < self.len, "Index out of bounds");
230                    #( self.#field_names[index] = value.#field_names.clone(); )*
231                }
232
233                pub fn get(&self, index: usize) -> super::#name {
234                    assert!(index < self.len, "Index out of bounds");
235                    super::#name {
236                        #( #field_names: self.#field_names[index].clone(), )*
237                    }
238                }
239
240                pub fn apply<F>(&mut self, mut f: F)
241                where
242                    F: FnMut(#(#field_types),*) -> (#(#field_types),*)
243                {
244                    // don't use something common like i here.
245                    for _idx in 0..self.len {
246                        let result = f(#(self.#field_names[_idx].clone()),*);
247                        let (#(#field_names),*) = result;
248                        #(
249                            self.#field_names[_idx] = #field_names;
250                        )*
251                    }
252                }
253
254                pub fn iter(&self) -> #soa_struct_name_iterator<N> {
255                    #soa_struct_name_iterator::new(self)
256                }
257
258                #(
259                    pub fn #field_names(&self) -> &[#field_types] {
260                        &self.#field_names
261                    }
262
263                    pub fn #field_names_mut(&mut self) -> &mut [#field_types] {
264                        &mut self.#field_names
265                    }
266
267                    pub fn #field_names_range(&self, range: std::ops::Range<usize>) -> &[#field_types] {
268                        &self.#field_names[range]
269                    }
270
271                    pub fn #field_names_range_mut(&mut self, range: std::ops::Range<usize>) -> &mut [#field_types] {
272                        &mut self.#field_names[range]
273                    }
274                )*
275            }
276
277            impl<const N: usize> Encode for #soa_struct_name<N> {
278                fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
279                    self.len.encode(encoder)?;
280                    #( self.#field_names[..self.len].encode(encoder)?; )*
281                    Ok(())
282                }
283            }
284
285            impl<const N: usize> Decode<()> for #soa_struct_name<N> {
286                fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
287                    let mut result = Self::default();
288                    result.len = Decode::decode(decoder)?;
289                    #(
290                        for _idx in 0..result.len {
291                            result.#field_names[_idx] = Decode::decode(decoder)?;
292                        }
293                    )*
294                    Ok(result)
295                }
296            }
297
298            impl<const N: usize> Default for #soa_struct_name<N> {
299                fn default() -> Self {
300                    Self {
301                        #( #field_names: from_fn(|_| #field_types::default()), )*
302                        len: 0,
303                    }
304                }
305            }
306
307            impl<const N: usize> Clone for #soa_struct_name<N>
308            where
309                #(
310                    #field_types: Clone,
311                )*
312            {
313                fn clone(&self) -> Self {
314                    Self {
315                        #( #field_names: self.#field_names.clone(), )*
316                        len: self.len,
317                    }
318                }
319            }
320
321            impl<const N: usize> Serialize for #soa_struct_name<N>
322            where
323                #(
324                    #field_types: Serialize,
325                )*
326            {
327                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
328                where
329                    S: Serializer,
330                {
331                    let mut state = serializer.serialize_struct(stringify!(#soa_struct_name), #field_count)?;
332                    state.serialize_field("len", &self.len)?;
333                    #(
334                        state.serialize_field(stringify!(#field_names), &self.#field_names[..self.len])?;
335                    )*
336                    state.end()
337                }
338            }
339
340            impl<'de, const N: usize> Deserialize<'de> for #soa_struct_name<N>
341            where
342                #(
343                    #field_types: Deserialize<'de> + Default,
344                )*
345            {
346                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
347                where
348                    D: serde::Deserializer<'de>,
349                {
350                    #[derive(Deserialize)]
351                    struct #soa_struct_wire_name {
352                        len: usize,
353                        #( #field_names: Vec<#field_types>, )*
354                    }
355
356                    let wire = #soa_struct_wire_name::deserialize(deserializer)?;
357                    let #soa_struct_wire_name { len, #( #field_names ),* } = wire;
358
359                    if len > N {
360                        return Err(serde::de::Error::custom(format!(
361                            "len {} exceeds capacity {}",
362                            len,
363                            N
364                        )));
365                    }
366
367                    #(
368                        if #field_names.len() != len {
369                            return Err(serde::de::Error::custom(format!(
370                                "field {} has length {} but len is {}",
371                                stringify!(#field_names),
372                                #field_names.len(),
373                                len
374                            )));
375                        }
376                    )*
377
378                    let mut result = Self::default();
379                    result.len = len;
380                    #(
381                        for (idx, value) in #field_names.into_iter().enumerate() {
382                            result.#field_names[idx] = value;
383                        }
384                    )*
385                    Ok(result)
386                }
387            }
388
389            #iterator
390
391        }
392        #visibility use #module_name::#soa_struct_name;
393    };
394
395    let tokens: TokenStream = expanded.into();
396
397    #[cfg(feature = "macro_debug")]
398    {
399        let formatted_code = rustfmt_generated_code(tokens.to_string());
400        eprintln!("\n     ===    Gen. SOA     ===\n");
401        eprintln!("{}", highlight_rust_code(formatted_code));
402        eprintln!("\n     === === === === === ===\n");
403    }
404    tokens
405}