// File: nate_engine_macros/lib.rs

extern crate proc_macro;

use std::collections::{HashMap, HashSet};

use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{format_ident, quote};
use syn::{
    parse::{Parse, ParseStream, Result},
    parse_macro_input, Block, Error, Expr, ExprBinary, ExprLit, FnArg, Ident, ItemFn, ItemStruct,
    Lit, Token,
};

9struct IgnoreArgs {
10    ignore_identifiers: HashSet<String>,
11}
12
13impl Parse for IgnoreArgs {
14    fn parse(input: ParseStream) -> Result<Self> {
15        let mut ignore_identifiers = HashSet::new();
16
17        let parts = input.parse_terminated(Expr::parse, Token![,])?;
18        for part in parts.iter() {
19            if let Expr::Assign(assignment) = part {
20                if let Expr::Path(path) = assignment.left.as_ref() {
21                    if let Some(segment) = path.path.segments.first() {
22                        if segment.ident.to_string().as_str() == "singular" {
23                            match assignment.right.as_ref() {
24                                Expr::Path(path) => {
25                                    if let Some(segment) = path.path.segments.first() {
26                                        ignore_identifiers.insert(segment.ident.to_string());
27                                    }
28                                },
29                                Expr::Array(array) => {
30                                    for element in array.elems.iter() {
31                                        if let Expr::Path(path) = element {
32                                            if let Some(segment) = path.path.segments.first() {
33                                                ignore_identifiers.insert(segment.ident.to_string());
34                                            }
35                                        }
36                                    }
37                                },
38                                _ => (),
39                            }
40                        }
41                    }
42                }
43            }
44        }
45
46        Ok(IgnoreArgs {
47            ignore_identifiers,
48        })
49    }
50}
51
52#[proc_macro_attribute]
53pub fn world(attr: TokenStream, item: TokenStream) -> TokenStream {
54    let item = parse_macro_input!(item as ItemStruct);
55    let item_name = item.ident;
56
57    let ignore_args = parse_macro_input!(attr as IgnoreArgs);
58
59    let fields = item.fields;
60    let mut field_identifiers = Vec::new();
61    let mut field_types = Vec::new();
62    let mut ignore_identifiers = Vec::new();
63    let mut ignore_types = Vec::new();
64    for field in fields.iter() {
65        if let Some(ident) = &field.ident {
66            if ignore_args.ignore_identifiers.contains(&ident.to_string()) {
67                ignore_identifiers.push(ident);
68                ignore_types.push(&field.ty);
69            } else {
70                field_identifiers.push(ident);
71                field_types.push(&field.ty);
72            }
73        }
74    }
75
76    let plural_identifiers: Vec<Ident> = field_identifiers.iter().map(|v| format_ident!("{}s", v)).collect();
77    let setter_identifiers: Vec<Ident> = field_identifiers.iter().map(|v| format_ident!("set_{}", v)).collect();
78    let set_many_identifiers: Vec<Ident> = field_identifiers.iter().map(|v| format_ident!("set_{}s", v)).collect();
79    let clear_identifiers: Vec<Ident> = field_identifiers.iter().map(|v| format_ident!("clear_{}", v)).collect();
80    let clear_many_identifiers: Vec<Ident> = field_identifiers.iter().map(|v| format_ident!("clear{}s", v)).collect();
81    let set_ignore_identifiers: Vec<Ident> = ignore_identifiers.iter().map(|v| format_ident!("set_{}", v)).collect();
82    let clear_ignore_identifiers: Vec<Ident> = ignore_identifiers.iter().map(|v| format_ident!("clear_{}", v)).collect();
83
84    let entity_fields = if field_identifiers.len() > 0 {
85        quote!{
86            #(pub #field_identifiers: std::sync::Arc<std::sync::RwLock<std::vec::Vec<std::option::Option<#field_types>>>>),*,
87        }
88    } else {
89        quote!{ }
90    };
91
92    let entity_initializers = if field_identifiers.len() > 0 {
93        quote!{
94            #(#field_identifiers: std::sync::Arc::new(std::sync::RwLock::new(std::vec::Vec::new()))),*,
95        }
96    } else {
97        quote!{ }
98    };
99
100    let global_fields = if ignore_identifiers.len() > 0 {
101        quote!{
102            #(pub #ignore_identifiers: std::sync::Arc<std::sync::RwLock<std::option::Option<#ignore_types>>>),*,
103        }
104    } else {
105        quote!{ }
106    };
107
108    let global_initializers = if ignore_identifiers.len() > 0 {
109        quote!{
110            #(#ignore_identifiers: std::sync::Arc::new(std::sync::RwLock::new(None))),*,
111        }
112    } else {
113        quote!{ }
114    };
115
116    TokenStream::from(quote!{
117        pub struct #item_name {
118            entities: std::sync::Arc<std::sync::RwLock<std::vec::Vec<usize>>>,
119            #entity_fields
120            #global_fields
121        }
122
123        impl #item_name {
124            pub fn new() -> std::sync::Arc<std::sync::RwLock<Self>> {
125                std::sync::Arc::new(std::sync::RwLock::new(Self {
126                    entities: std::sync::Arc::new(std::sync::RwLock::new(std::vec::Vec::new())),
127                    #entity_initializers
128                    #global_initializers
129                }))
130            }
131
132            pub fn add_entity(&mut self) -> usize {
133                let entity_id = self.entities.read().unwrap().len() as usize;
134                self.entities.write().unwrap().push(entity_id);
135                #(self.#field_identifiers.write().unwrap().push(None));*;
136                entity_id
137            }
138
139            pub fn add_entities(&mut self, entities: usize) -> Vec<usize> {
140                let mut new_entity_ids = Vec::with_capacity(entities as usize);
141                let mut entities_list = self.entities.write().unwrap();
142                #(let mut #field_identifiers = self.#field_identifiers.write().unwrap());*;
143                let start_len = entities_list.len();
144
145                for i in 0..entities {
146                    let new_entity_id = start_len + i;
147                    entities_list.push(new_entity_id);
148                     #(#field_identifiers.push(None));*;
149                     new_entity_ids.push(new_entity_id);
150                }
151
152                new_entity_ids
153            }
154
155            pub fn remove_entity(&mut self, entity_id: usize) {
156                self.entities.write().unwrap().remove(entity_id as usize);
157                #(self.#field_identifiers.write().unwrap().remove(entity_id as usize));*;
158            }
159
160            pub fn remove_entities(&mut self, entity_ids: Vec<usize>) {
161                for entity_id in entity_ids {
162                    self.entities.write().unwrap().remove(entity_id as usize);
163                    #(self.#field_identifiers.write().unwrap().remove(entity_id as usize));*;
164                }
165            }
166
167            #(pub fn #setter_identifiers(&mut self, entity_id: usize, #field_identifiers: #field_types) {
168                self.#field_identifiers.write().unwrap()[entity_id as usize] = Some(#field_identifiers);
169            })*
170
171            #(pub fn #set_ignore_identifiers(&mut self, #ignore_identifiers: #ignore_types) {
172                *self.#ignore_identifiers.write().unwrap() = Some(#ignore_identifiers);
173            })*
174
175            #(pub fn #set_many_identifiers(&mut self, entity_ids: &Vec<usize>, mut #plural_identifiers: Vec<#field_types>) {
176                let mut component = self.#field_identifiers.write().unwrap();
177                for (#field_identifiers, entity_id) in #plural_identifiers.drain(..).zip(entity_ids.iter()) {
178                    component[*entity_id] = Some(#field_identifiers);
179                }
180            })*
181
182            #(pub fn #clear_identifiers(&mut self, entity_id: usize) {
183                self.#field_identifiers.write().unwrap()[entity_id as usize] = None;
184            })*
185
186            #(pub fn #clear_ignore_identifiers(&mut self) {
187                *self.#ignore_identifiers.write().unwrap() = None;
188            })*
189
190            #(pub fn #clear_many_identifiers(&mut self, entity_ids: &Vec<usize>) {
191                let mut component = self.#field_identifiers.write().unwrap();
192                for entity_id in entity_ids {
193                    component[*entity_id as usize] = None;
194                }
195            })*
196        }
197
198        unsafe impl Send for #item_name {}
199        unsafe impl Sync for #item_name {}
200    })
201}
202
203struct WorldArgs {
204    function_name: Ident,
205    function_args: Vec<FnArg>,
206    body: Block,
207}
208
209impl Parse for WorldArgs {
210    fn parse(input: ParseStream) -> Result<Self> {
211        let function_parts = ItemFn::parse(input)?;
212
213        // Parsing System Name
214        let function_name = function_parts.sig.ident;
215
216        // Function Arguments
217        let function_args: Vec<FnArg> = function_parts.sig.inputs.iter().map(|v| v.clone()).collect();
218
219        Ok(WorldArgs {
220            function_name,
221            function_args,
222            body: *function_parts.block,
223        })
224    }
225}
226
227struct FunctionArgs {
228    world_type: Ident,
229    read_components: Vec<Ident>,
230    global_read_components: Vec<Ident>,
231    write_components: Vec<Ident>,
232    global_write_components: Vec<Ident>,
233    global_write_assignments: HashMap<Ident, Expr>,
234    filters: Vec<ExprBinary>,
235    enumerated: bool,
236}
237
238impl Parse for FunctionArgs {
239    fn parse(input: ParseStream) -> Result<Self> {
240        let mut world_type: Option<Ident> = None;
241        let mut read_components: Vec<Ident> = Vec::new();
242        let mut global_read_components: Vec<Ident> = Vec::new();
243        let mut write_components: Vec<Ident> = Vec::new();
244        let mut global_write_components: Vec<Ident> = Vec::new();
245        let mut global_write_assignments: HashMap<Ident, Expr> = HashMap::new();
246        let mut filters: Vec<ExprBinary> = Vec::new();
247        let mut enumerated: bool = false;
248
249        let parts = input.parse_terminated(Expr::parse, Token![,])?;
250        for part in parts.iter() {
251            if let Expr::Assign(assignment) = part {
252                match assignment.left.as_ref() {
253                    Expr::Path(path) => {
254                        if let Some(segment) = path.path.segments.first() {
255                            match segment.ident.to_string().as_str() {
256                                "world" => {
257                                    if let Expr::Path(path) = assignment.right.as_ref() {
258                                        if let Some(segment) = path.path.segments.first() {
259                                            world_type = Some(segment.ident.clone());
260                                        } else {
261                                            return Err(Error::new(Span::call_site(), "Expected Singular Type for World Type"));
262                                        }
263                                    }
264                                },
265                                "read" => {
266                                    match assignment.right.as_ref() {
267                                        Expr::Path(path) => {
268                                            if let Some(segment) = path.path.segments.first() {
269                                                read_components.push(segment.ident.clone());
270                                            }
271                                        },
272                                        Expr::Array(array) => {
273                                            for element in array.elems.iter() {
274                                                if let Expr::Path(path) = element {
275                                                    if let Some(segment) = path.path.segments.first() {
276                                                        read_components.push(segment.ident.clone());
277                                                    }
278                                                }
279                                            }
280                                        },
281                                        _ => (),
282                                    }
283                                },
284                                "write" => {
285                                    match assignment.right.as_ref() {
286                                        Expr::Path(path) => {
287                                            if let Some(segment) = path.path.segments.first() {
288                                                write_components.push(segment.ident.clone());
289                                            }
290                                        },
291                                        Expr::Array(array) => {
292                                            for element in array.elems.iter() {
293                                                if let Expr::Path(path) = element {
294                                                    if let Some(segment) = path.path.segments.first() {
295                                                        write_components.push(segment.ident.clone());
296                                                    }
297                                                }
298                                            }
299                                        },
300                                        _ => (),
301                                    }
302                                },
303                                "filter" => {
304                                    if let Expr::Array(array) = assignment.right.as_ref() {
305                                        for element in array.elems.iter() {
306                                            if let Expr::Binary(binary) = element {
307                                                filters.push(binary.clone());
308                                            }
309                                        }
310                                    }
311                                },
312                                "_read" => {
313                                    match assignment.right.as_ref() {
314                                        Expr::Path(path) => {
315                                            if let Some(segment) = path.path.segments.first() {
316                                                global_read_components.push(segment.ident.clone());
317                                            }
318                                        },
319                                        Expr::Array(array) => {
320                                            for element in array.elems.iter() {
321                                                if let Expr::Path(path) = element {
322                                                    if let Some(segment) = path.path.segments.first() {
323                                                        global_read_components.push(segment.ident.clone());
324                                                    }
325                                                }
326                                            }
327                                        },
328                                        _ => (),
329                                    }
330                                },
331                                "_write" => {
332                                    match assignment.right.as_ref() {
333                                        Expr::Path(path) => {
334                                            if let Some(segment) = path.path.segments.first() {
335                                                global_write_components.push(segment.ident.clone());
336                                            }
337                                        },
338                                        Expr::Array(array) => {
339                                            for element in array.elems.iter() {
340                                                match element {
341                                                    Expr::Path(path) => {
342                                                        if let Some(segment) = path.path.segments.first() {
343                                                            global_write_components.push(segment.ident.clone());
344                                                        }
345                                                    },
346                                                    Expr::Assign(assignment) => {
347                                                        if let Expr::Path(path) = assignment.left.as_ref() {
348                                                            if let Some(segment) = path.path.segments.first() {
349                                                                global_write_components.push(segment.ident.clone());
350                                                                global_write_assignments.insert(segment.ident.clone(), *assignment.right.clone());
351                                                            }
352                                                        }
353                                                    },
354                                                    _ => (),
355                                                }
356                                            }
357                                        },
358                                        _ => (),
359                                    }
360                                },
361                                "enumerate" => enumerated = true,
362                                _ => (),
363                            }
364                        }
365                    },
366                    _ => return Err(Error::new(Span::call_site(), "Invalid Parameter to system macro")),
367                }
368            } else {
369                return Err(Error::new(Span::call_site(), "Expected Assignments in Attribute"));
370            }
371        }
372
373        if world_type.is_none() {
374            return Err(Error::new(Span::call_site(), "World Type was not Provided"));
375        }
376
377        Ok(FunctionArgs {
378            world_type: world_type.unwrap(),
379            read_components,
380            global_read_components,
381            write_components,
382            global_write_components,
383            global_write_assignments,
384            filters,
385            enumerated,
386        })
387    }
388}
389
390#[proc_macro_attribute]
391pub fn system(attr: TokenStream, item: TokenStream) -> TokenStream {
392    let world_args = parse_macro_input!(item as WorldArgs);
393    let function_args = parse_macro_input!(attr as FunctionArgs);
394
395    let fn_name = world_args.function_name;
396    let fn_args = world_args.function_args;
397    let body = world_args.body;
398
399    let read_components = function_args.read_components;
400    let global_read_components = function_args.global_read_components;
401    let global_read_refs: Vec<Ident> = global_read_components.iter().map(|v| format_ident!("{}_ref", v)).collect();
402    let write_components = function_args.write_components;
403    let global_write_components = function_args.global_write_components;
404    let global_write_refs: Vec<Ident> = global_write_components.iter().map(|v| format_ident!("{}_ref", v)).collect();
405    let world_type = function_args.world_type;
406
407    let (mut items, mut iterators) = match read_components.len() {
408        0 => (quote!{ }, quote!{ }),
409        1 => {
410            let read_component = &read_components[0];
411            (
412                quote!{ (entity_id, #read_component) },
413                quote!{#read_component.iter().enumerate()}
414            )
415        },
416        _ => {
417            let read_component1 = &read_components[0];
418            let read_component2 = &read_components[1];
419            (
420                quote!{ ((entity_id, #read_component1), #read_component2) },
421                quote!{ #read_component1.iter().enumerate().zip(#read_component2.iter())}
422            )
423        },
424    };
425
426    if read_components.len() > 2 {
427        for read_component in read_components[2..].iter() {
428            items = quote!{ (#items, #read_component) };
429            iterators = quote!{#iterators.zip(#read_component.iter())};
430        }
431    }
432
433    if read_components.len() == 0 {
434        (items, iterators) = match write_components.len() {
435            0 => (quote!{ }, quote!{ }),
436            1 => {
437                let write_component = &write_components[0];
438                (
439                    quote!{ (entity_id, #write_component) },
440                    quote!{#write_component.iter_mut().enumerate()}
441                )
442            },
443            _ => {
444                let write_component1 = &write_components[0];
445                let write_component2 = &write_components[1];
446                (
447                    quote!{ ((entity_id, #write_component1), #write_component2) },
448                    quote!{#write_component1.iter_mut().enumerate().zip(#write_component2.iter_mut())}
449                )
450            },
451        };
452        if write_components.len() > 2 {
453            for write_component in write_components[2..].iter() {
454                items = quote!{ (#items, #write_component) };
455                iterators = quote!{#iterators.zip(#write_component.iter_mut())};
456            }
457        }
458    } else {
459        for write_component in write_components.iter() {
460            items = quote!{ (#items, #write_component) };
461            iterators = quote!{#iterators.zip(#write_component.iter_mut())};
462        }
463    }
464
465    let mut filter = quote!{ };
466    let combined_length = read_components.len() + write_components.len();
467    match combined_length {
468        0 => (),
469        1 => filter = quote!{ .filter(|v| v.1.is_some()) },
470        2 => filter = quote!{ .filter(|v| v.0.1.is_some() && v.1.is_some()) },
471        _ => {
472            filter = quote!{ v.1.is_some() };
473            for i in 1..(combined_length-1) {
474                filter = quote!{ #filter && v };
475                for _ in 0..i {
476                    filter = quote!{ #filter.0 };
477                }
478                filter = quote!{ #filter.1.is_some() };
479            }
480            filter = quote!{ .filter(|v| #filter) };
481        },
482    }
483
484    let mut filter_condition = quote!{ };
485    if function_args.filters.len() > 0 {
486        let first_filter = &function_args.filters[0];
487        filter_condition = quote!{ if #first_filter };
488        for filter in function_args.filters[1..].iter() {
489            filter_condition = quote!{ #filter_condition && #filter };
490        }
491    }
492
493    let body = if read_components.len() + write_components.len() > 0 {
494        quote!{ 
495            for #items in #iterators #filter {
496                #(let #read_components = #read_components.as_ref().unwrap());*;
497                #(let mut #write_components = #write_components.as_mut().unwrap());*;
498
499                #filter_condition {
500                    #body
501                }
502            }
503        }
504    } else if function_args.enumerated {
505        quote!{ 
506            for entity_id in 0..world.entities.read().unwrap().len() {
507                #body
508            }
509        }
510    } else {
511        quote!{ #body }
512    };
513
514    let mut global_write_assignments = quote!{ };
515    for key in function_args.global_write_assignments.keys() {
516        let value = function_args.global_write_assignments.get(key).unwrap();
517        global_write_assignments = quote!{
518            #global_write_assignments
519            *#key = #value;
520        };
521    }
522
523
524    TokenStream::from(quote!{
525        pub fn #fn_name(world: std::sync::Arc<std::sync::RwLock<#world_type>>, #(#fn_args),*) {
526            let world = world.read().unwrap();
527            #(let #read_components = world.#read_components.read().unwrap());*;
528            #(let mut #write_components = world.#write_components.write().unwrap());*;
529            #(let #global_read_refs = world.#global_read_components.read().unwrap());*;
530            #(let #global_read_components = #global_read_refs.as_ref().expect("Global Components must not be None"));*;
531            #(let mut #global_write_refs = world.#global_write_components.write().unwrap());*;
532            #(let mut #global_write_components = #global_write_refs.as_mut().expect("Global Components must not be None"));*;
533            #global_write_assignments
534
535            #body
536        }
537    })
538}