1mod format;
2#[cfg(feature = "macro_debug")]
3use format::{highlight_rust_code, rustfmt_generated_code};
4use proc_macro::TokenStream;
5use quote::{ToTokens, format_ident, quote};
6use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};
7
8#[proc_macro_derive(Soa)]
69pub fn derive_soa(input: TokenStream) -> TokenStream {
70 use syn::TypePath;
71
72 let input = parse_macro_input!(input as DeriveInput);
73 let visibility = &input.vis;
74
75 let name = &input.ident;
76 let module_name = format_ident!("{}_soa", name.to_string().to_lowercase());
77 let soa_struct_name = format_ident!("{}Soa", name);
78 let soa_struct_wire_name = format_ident!("{}SoaWire", name);
79
80 let data = match &input.data {
81 Data::Struct(data) => data,
82 _ => panic!("Only structs are supported"),
83 };
84 let fields = match &data.fields {
85 Fields::Named(fields) => &fields.named,
86 _ => panic!("Only named fields are supported"),
87 };
88
89 let mut field_names = vec![];
90 let mut field_names_mut = vec![];
91 let mut field_names_range = vec![];
92 let mut field_names_range_mut = vec![];
93 let mut field_types = vec![];
94 let mut unique_imports = vec![];
95 let mut unique_import_names = vec![];
96
97 fn is_primitive(type_name: &str) -> bool {
98 matches!(
99 type_name,
100 "i8" | "i16"
101 | "i32"
102 | "i64"
103 | "i128"
104 | "u8"
105 | "u16"
106 | "u32"
107 | "u64"
108 | "u128"
109 | "f32"
110 | "f64"
111 | "bool"
112 | "char"
113 | "str"
114 | "usize"
115 | "isize"
116 )
117 }
118
119 for field in fields {
120 let field_name = field.ident.as_ref().unwrap();
121 let field_type = &field.ty;
122 field_names.push(field_name);
123 field_names_mut.push(format_ident!("{}_mut", field_name));
124 field_names_range.push(format_ident!("{}_range", field_name));
125 field_names_range_mut.push(format_ident!("{}_range_mut", field_name));
126 field_types.push(field_type);
127
128 if let Type::Path(TypePath { path, .. }) = field_type {
129 let type_name = path.segments.last().unwrap().ident.to_string();
130 let path_str = path.to_token_stream().to_string();
131
132 if !is_primitive(&type_name) && !unique_import_names.contains(&path_str) {
133 unique_imports.push(path.clone());
134 unique_import_names.push(path_str);
135 }
136 }
137 }
138
139 let soa_struct_name_iterator = format_ident!("{}Iterator", name);
140 let field_count = field_names.len() + 1; let iterator = quote! {
143 pub struct #soa_struct_name_iterator<'a, const N: usize> {
144 soa_struct: &'a #soa_struct_name<N>,
145 current: usize,
146 }
147
148 impl<'a, const N: usize> #soa_struct_name_iterator<'a, N> {
149 pub fn new(soa_struct: &'a #soa_struct_name<N>) -> Self {
150 Self {
151 soa_struct,
152 current: 0,
153 }
154 }
155 }
156
157 impl<'a, const N: usize> Iterator for #soa_struct_name_iterator<'a, N> {
158 type Item = super::#name;
159
160 fn next(&mut self) -> Option<Self::Item> {
161 if self.current < self.soa_struct.len {
162 let item = self.soa_struct.get(self.current); self.current += 1;
164 Some(item)
165 } else {
166 None
167 }
168 }
169 }
170 };
171
172 let expanded = quote! {
173 #visibility mod #module_name {
174 use bincode::{Decode, Encode};
175 use bincode::enc::Encoder;
176 use bincode::de::Decoder;
177 use bincode::error::{DecodeError, EncodeError};
178 use serde::Deserialize;
179 use serde::Serialize;
180 use serde::Serializer;
181 use serde::ser::SerializeStruct;
182 use std::ops::{Index, IndexMut};
183 #( use super::#unique_imports; )*
184 use core::array::from_fn;
185
186 #[derive(Debug)]
187 #visibility struct #soa_struct_name<const N: usize> {
188 pub len: usize,
189 #(pub #field_names: [#field_types; N], )*
190 }
191
192 impl<const N: usize> #soa_struct_name<N> {
193 pub fn new(default: super::#name) -> Self {
194 Self {
195 #( #field_names: from_fn(|_| default.#field_names.clone()), )*
196 len: 0,
197 }
198 }
199
200 pub fn len(&self) -> usize {
201 self.len
202 }
203
204 pub fn is_empty(&self) -> bool {
205 self.len == 0
206 }
207
208 pub fn push(&mut self, value: super::#name) {
209 if self.len < N {
210 #( self.#field_names[self.len] = value.#field_names.clone(); )*
211 self.len += 1;
212 } else {
213 panic!("Capacity exceeded")
214 }
215 }
216
217 pub fn pop(&mut self) -> Option<super::#name> {
218 if self.len == 0 {
219 None
220 } else {
221 self.len -= 1;
222 Some(super::#name {
223 #( #field_names: self.#field_names[self.len].clone(), )*
224 })
225 }
226 }
227
228 pub fn set(&mut self, index: usize, value: super::#name) {
229 assert!(index < self.len, "Index out of bounds");
230 #( self.#field_names[index] = value.#field_names.clone(); )*
231 }
232
233 pub fn get(&self, index: usize) -> super::#name {
234 assert!(index < self.len, "Index out of bounds");
235 super::#name {
236 #( #field_names: self.#field_names[index].clone(), )*
237 }
238 }
239
240 pub fn apply<F>(&mut self, mut f: F)
241 where
242 F: FnMut(#(#field_types),*) -> (#(#field_types),*)
243 {
244 for _idx in 0..self.len {
246 let result = f(#(self.#field_names[_idx].clone()),*);
247 let (#(#field_names),*) = result;
248 #(
249 self.#field_names[_idx] = #field_names;
250 )*
251 }
252 }
253
254 pub fn iter(&self) -> #soa_struct_name_iterator<N> {
255 #soa_struct_name_iterator::new(self)
256 }
257
258 #(
259 pub fn #field_names(&self) -> &[#field_types] {
260 &self.#field_names
261 }
262
263 pub fn #field_names_mut(&mut self) -> &mut [#field_types] {
264 &mut self.#field_names
265 }
266
267 pub fn #field_names_range(&self, range: std::ops::Range<usize>) -> &[#field_types] {
268 &self.#field_names[range]
269 }
270
271 pub fn #field_names_range_mut(&mut self, range: std::ops::Range<usize>) -> &mut [#field_types] {
272 &mut self.#field_names[range]
273 }
274 )*
275 }
276
277 impl<const N: usize> Encode for #soa_struct_name<N> {
278 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
279 self.len.encode(encoder)?;
280 #( self.#field_names[..self.len].encode(encoder)?; )*
281 Ok(())
282 }
283 }
284
285 impl<const N: usize> Decode<()> for #soa_struct_name<N> {
286 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
287 let mut result = Self::default();
288 result.len = Decode::decode(decoder)?;
289 #(
290 for _idx in 0..result.len {
291 result.#field_names[_idx] = Decode::decode(decoder)?;
292 }
293 )*
294 Ok(result)
295 }
296 }
297
298 impl<const N: usize> Default for #soa_struct_name<N> {
299 fn default() -> Self {
300 Self {
301 #( #field_names: from_fn(|_| #field_types::default()), )*
302 len: 0,
303 }
304 }
305 }
306
307 impl<const N: usize> Clone for #soa_struct_name<N>
308 where
309 #(
310 #field_types: Clone,
311 )*
312 {
313 fn clone(&self) -> Self {
314 Self {
315 #( #field_names: self.#field_names.clone(), )*
316 len: self.len,
317 }
318 }
319 }
320
321 impl<const N: usize> Serialize for #soa_struct_name<N>
322 where
323 #(
324 #field_types: Serialize,
325 )*
326 {
327 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
328 where
329 S: Serializer,
330 {
331 let mut state = serializer.serialize_struct(stringify!(#soa_struct_name), #field_count)?;
332 state.serialize_field("len", &self.len)?;
333 #(
334 state.serialize_field(stringify!(#field_names), &self.#field_names[..self.len])?;
335 )*
336 state.end()
337 }
338 }
339
340 impl<'de, const N: usize> Deserialize<'de> for #soa_struct_name<N>
341 where
342 #(
343 #field_types: Deserialize<'de> + Default,
344 )*
345 {
346 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
347 where
348 D: serde::Deserializer<'de>,
349 {
350 #[derive(Deserialize)]
351 struct #soa_struct_wire_name {
352 len: usize,
353 #( #field_names: Vec<#field_types>, )*
354 }
355
356 let wire = #soa_struct_wire_name::deserialize(deserializer)?;
357 let #soa_struct_wire_name { len, #( #field_names ),* } = wire;
358
359 if len > N {
360 return Err(serde::de::Error::custom(format!(
361 "len {} exceeds capacity {}",
362 len,
363 N
364 )));
365 }
366
367 #(
368 if #field_names.len() != len {
369 return Err(serde::de::Error::custom(format!(
370 "field {} has length {} but len is {}",
371 stringify!(#field_names),
372 #field_names.len(),
373 len
374 )));
375 }
376 )*
377
378 let mut result = Self::default();
379 result.len = len;
380 #(
381 for (idx, value) in #field_names.into_iter().enumerate() {
382 result.#field_names[idx] = value;
383 }
384 )*
385 Ok(result)
386 }
387 }
388
389 #iterator
390
391 }
392 #visibility use #module_name::#soa_struct_name;
393 };
394
395 let tokens: TokenStream = expanded.into();
396
397 #[cfg(feature = "macro_debug")]
398 {
399 let formatted_code = rustfmt_generated_code(tokens.to_string());
400 eprintln!("\n === Gen. SOA ===\n");
401 eprintln!("{}", highlight_rust_code(formatted_code));
402 eprintln!("\n === === === === === ===\n");
403 }
404 tokens
405}