Skip to main content

mtb_entity_slab_macros/
lib.rs

//! Syntax:
//!
//! ```ignore
//! #[entity_id(WrapperName)]                      -- default policy (256)
//! #[entity_id(WrapperName, options...)]          -- custom options (see below)
//! #[entity_id(WrapperName, opaque)]              -- opaque ID with default policy (256)
//! #[entity_id(WrapperName, options..., opaque)]  -- opaque ID with custom options
//! pub struct|enum|union ObjType { ... }
//! ```
//!
//! Options (comma-separated, in any order; a trailing comma is accepted):
//!
//! - `policy = <Policy>`: allocation policy, spelled as `128 | 256 | 512 | 1024 | 2048 | 4096`,
//!   `PolicyN` or `AllocPolicyN` (e.g. `256`, `Policy256`, `AllocPolicy256`). The policy may
//!   also be given bare, without `policy =` (e.g. `#[entity_id(MyID, 512)]`).
//! - `opaque`: makes the wrapper opaque (restricts field visibility to crate-only and hides
//!   the inner ID in `Debug` output)
//! - `allocator_type = AllocTypeName`: generates a type alias for the corresponding `EntityAlloc<Obj, Policy>`
//! - `backend = ptr | index`: choose backend ID type (default: `ptr`)
//!
//! Example (pointer backend, allocator alias):
//! ```ignore
//! #[entity_id(MyInstID, policy = 256, backend = ptr, allocator_type = MyInstAlloc)]
//! struct Inst { /* fields */ }
//! let alloc: MyInstAlloc = MyInstAlloc::new();
//! let raw = alloc.allocate_ptr(Inst { /* init */ });
//! let id = MyInstID::from(raw);
//! let obj = id.deref_alloc(&alloc);
//! ```
//!
//! Example (index backend):
//! ```ignore
//! #[entity_id(MyIdxID, policy = 256, backend = index)]
//! struct Inst { /* fields */ }
//! let alloc: EntityAlloc<Inst, AllocPolicy256> = EntityAlloc::new();
//! let raw = IndexedID::allocate_from(&alloc, Inst { /* init */ });
//! let id = MyIdxID::from_backend(raw);
//! let obj = id.deref_alloc(&alloc);
//! ```
//!
//! Wrapper types intentionally DO NOT expose high-level allocate/free convenience methods.
//! Allocation always happens at the raw layer (`allocate_ptr`, `allocate_index`, or
//! `IEntityAllocID::allocate_from`); the raw ID is then wrapped via `from_backend`
//! (or the equivalent `From` conversion).
//!
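//! Generates:
//!
//! ```ignore
//! // non-opaque wrapper
//! #[repr(transparent)]
//! $vis struct $WrapperName(pub $BackendID);
//!
//! // opaque wrapper
//! #[repr(transparent)]
//! $vis struct $WrapperName(pub(crate) $BackendID);
//! ```
//!
//! Here `$vis` is the visibility of the annotated item, and `$BackendID` is
//! `PtrID<$ObjType, $Policy>` (`backend = ptr`) or `IndexedID<$ObjType, $Policy>`
//! (`backend = index`).
//!
//! The wrapper gets the following impls:
//!
//! - `Clone`, `Copy`, `PartialEq`, `Eq`, `PartialOrd`, `Ord`, `Hash` and `Debug`.
//! - `From` conversions to and from `$BackendID`.
//! - `$crate::IPoliciedID`, so it can be used anywhere a raw policy-bound ID is expected.
//!
//! With `backend = index` it additionally gets `from_gen_index` / `into_gen_index`.
//! Allocation remains at the lower layer: create a value with the allocator, then wrap it
//! with `WrapperName::from_backend(...)`.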
55
56use proc_macro::TokenStream;
57use proc_macro_crate::{FoundCrate, crate_name};
58use quote::{ToTokens, quote};
59use syn::{
60    Ident, LitInt,
61    parse::{Parse, ParseBuffer},
62};
63
64#[proc_macro_attribute]
65pub fn entity_id(attr: TokenStream, item: TokenStream) -> TokenStream {
66    let mut attr = syn::parse_macro_input!(attr as WrapperAttr);
67    let item = syn::parse_macro_input!(item as syn::DeriveInput);
68
69    // Capture ident and generics for codegen
70    attr.init_items(&item);
71
72    let output = quote! {
73        #item
74        #attr
75    };
76    TokenStream::from(output)
77}
78
/// Raw ID representation stored inside the generated wrapper.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BackendKind {
    /// Wrap a `PtrID<Obj, Policy>` (`backend = ptr`).
    Ptr,
    /// Wrap an `IndexedID<Obj, Policy>` (`backend = index`).
    Index,
}
84
85struct WrapperAttr {
86    crate_name: proc_macro2::TokenStream,
87    visability: syn::Visibility,
88    wrapper_name: Ident,
89    object_name: Ident,
90    policy: WrapperPolicy,
91    generics: syn::Generics,
92    backend_kind: BackendKind,
93    opaque: bool,
94    allocator_type: Option<Ident>,
95}
96
97impl Parse for WrapperAttr {
98    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
99        use syn::Token;
100        let crate_name = {
101            let found_crate =
102                crate_name("mtb_entity_slab").or_else(|_| crate_name("mtb-entity-slab"));
103            match found_crate {
104                Ok(FoundCrate::Itself) => quote! { crate },
105                Ok(FoundCrate::Name(name)) => {
106                    let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
107                    quote! { #ident }
108                }
109                Err(_) => quote! { mtb_entity_slab },
110            }
111        };
112
113        let mut ret = Self {
114            crate_name,
115            wrapper_name: syn::parse_quote! { __entity_placeholder },
116            object_name: syn::parse_quote! { __entity_placeholder },
117            policy: WrapperPolicy::Policy256,
118            visability: syn::Visibility::Inherited,
119            generics: syn::Generics::default(),
120            backend_kind: BackendKind::Ptr,
121            opaque: false,
122            allocator_type: None,
123        };
124
125        // First arg: wrapper name
126        ret.wrapper_name = input.parse::<Ident>()?;
127
128        while input.peek(Token![,]) {
129            let _ = input.parse::<Token![,]>()?;
130            if input.is_empty() {
131                break;
132            }
133            if input.peek(Ident) {
134                ret.parse_ident_prefix(input)?;
135                continue;
136            }
137            if input.peek(LitInt) {
138                let policy_level: LitInt = input.parse()?;
139                ret.policy = WrapperPolicy::from_litint(&policy_level)?;
140                continue;
141            }
142            return Err(syn::Error::new(
143                input.span(),
144                "Invalid argument (expected wrapper options: opaque | policy = v | PolicyXXX | integer)",
145            ));
146        }
147        Ok(ret)
148    }
149}
150
151impl ToTokens for WrapperAttr {
152    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
153        // Original user item (struct/enum) has already been injected by caller before this ToTokens
154        // We generate a wrapper newtype around PtrID<ObjectT, PolicyT>, optionally opaque.
155        let crate_path = &self.crate_name;
156        let wrapper_name = &self.wrapper_name;
157        let object_name = &self.object_name;
158        let generics = &self.generics;
159        let vis = &self.visability;
160        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
161
162        // Object type (with generics) we are wrapping pointer to
163        let object_ty = quote! { #object_name #ty_generics };
164
165        // Resolve policy concrete type referencing object type
166        let policy_ty = self.make_policy_type_tokens();
167
168        // Field visibility: public if not opaque, restricted if opaque
169        let field_vis = if self.opaque {
170            quote! { pub(crate) }
171        } else {
172            quote! { pub }
173        };
174
175        // Debug impl body differs for opaque vs transparent, and by backend kind
176        let debug_body = self.make_debug_body_tokens();
177
178        let allocator_alias = match &self.allocator_type {
179            None => quote! {},
180            Some(allocator_type) => {
181                quote! {
182                    /// Type alias to `EntityAlloc<..., ...>` to prevent repetition
183                    #vis type #allocator_type #ty_generics #where_clause
184                        = #crate_path::EntityAlloc<#object_ty, #policy_ty>;
185                }
186            }
187        };
188
189        let backend_tokens = match self.backend_kind {
190            BackendKind::Ptr => quote! { #crate_path::PtrID<#object_ty, #policy_ty> },
191            BackendKind::Index => quote! { #crate_path::IndexedID<#object_ty, #policy_ty> },
192        };
193        let genindex_convert = self.make_genindex_convert_tokens();
194
195        let attr_toks = quote! {
196            // #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
197            #[repr(transparent)]
198            #vis struct #wrapper_name #impl_generics (#field_vis #backend_tokens) #where_clause;
199
200            impl #impl_generics ::std::clone::Clone for #wrapper_name #ty_generics #where_clause {
201                #[inline]
202                fn clone(&self) -> Self { *self }
203            }
204            impl #impl_generics ::std::marker::Copy for #wrapper_name #ty_generics #where_clause {}
205            impl #impl_generics ::std::cmp::PartialEq for #wrapper_name #ty_generics #where_clause {
206                #[inline]
207                fn eq(&self, other: &Self) -> bool {
208                    self.0 == other.0
209                }
210            }
211            impl #impl_generics ::std::cmp::Eq for #wrapper_name #ty_generics #where_clause {}
212            impl #impl_generics ::std::cmp::PartialOrd for #wrapper_name #ty_generics #where_clause {
213                #[inline]
214                fn partial_cmp(&self, other: &Self) -> Option<::std::cmp::Ordering> {
215                    self.0.partial_cmp(&other.0)
216                }
217            }
218            impl #impl_generics ::std::cmp::Ord for #wrapper_name #ty_generics #where_clause {
219                #[inline]
220                fn cmp(&self, other: &Self) -> ::std::cmp::Ordering {
221                    self.0.cmp(&other.0)
222                }
223            }
224            impl #impl_generics ::std::hash::Hash for #wrapper_name #ty_generics #where_clause {
225                #[inline]
226                fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
227                    self.0.hash(state);
228                }
229            }
230            impl #impl_generics ::std::fmt::Debug for #wrapper_name #ty_generics #where_clause {
231                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> std::fmt::Result { #debug_body }
232            }
233
234            impl #impl_generics ::std::convert::From<#backend_tokens> for #wrapper_name #ty_generics #where_clause {
235                #[inline]
236                fn from(p: #backend_tokens) -> Self { Self(p) }
237            }
238            impl #impl_generics ::std::convert::From<#wrapper_name #ty_generics> for #backend_tokens #where_clause {
239                #[inline]
240                fn from(x: #wrapper_name #ty_generics) -> Self { x.0 }
241            }
242            #genindex_convert
243
244            // Implement policy-bound ID trait over chosen backend (pointer or indexed).
245            impl #impl_generics #crate_path::IPoliciedID for #wrapper_name #ty_generics #where_clause {
246                type ObjectT = #object_ty;
247                type PolicyT = #policy_ty;
248                type BackID = #backend_tokens;
249
250                #[inline]
251                fn from_backend(ptr: #backend_tokens) -> Self { Self(ptr) }
252                #[inline]
253                fn into_backend(self) -> #backend_tokens { self.0 }
254            }
255
256            #allocator_alias
257        };
258        tokens.extend(attr_toks);
259    }
260}
261impl WrapperAttr {
262    fn init_items(&mut self, item: &syn::DeriveInput) {
263        self.object_name = item.ident.clone();
264        self.generics = item.generics.clone();
265        self.visability = item.vis.clone();
266    }
267
268    fn parse_policy(input: &ParseBuffer<'_>) -> syn::Result<WrapperPolicy> {
269        let policy = if input.peek(LitInt) {
270            let lit: LitInt = input.parse()?;
271            WrapperPolicy::from_litint(&lit)?
272        } else if input.peek(Ident) {
273            let pol_ident: Ident = input.parse()?;
274            WrapperPolicy::from_ident(&pol_ident)?
275        } else {
276            return Err(syn::Error::new(
277                input.span(),
278                "Expected integer or identifier for policy",
279            ));
280        };
281        Ok(policy)
282    }
283
284    fn parse_ident_prefix(&mut self, input: &ParseBuffer<'_>) -> syn::Result<()> {
285        use syn::Token;
286        let ident: Ident = input.parse()?;
287        match ident.to_string().as_str() {
288            "opaque" => self.opaque = true,
289            "policy" => {
290                let _eq = input.parse::<Token![=]>()?;
291                self.policy = Self::parse_policy(input)?
292            }
293            "allocator_type" => {
294                let _eq = input.parse::<Token![=]>()?;
295                self.allocator_type = Some(input.parse::<Ident>()?);
296            }
297            "backend" => {
298                let _eq = input.parse::<Token![=]>()?;
299                let backend_ident: Ident = input.parse()?;
300                self.backend_kind = match backend_ident.to_string().as_str() {
301                    "ptr" => BackendKind::Ptr,
302                    "index" => BackendKind::Index,
303                    _ => {
304                        return Err(syn::Error::new(
305                            backend_ident.span(),
306                            "Invalid backend kind (expected `ptr` or `index`)",
307                        ));
308                    }
309                }
310            }
311            _ => {
312                // treat as policy identifier like Policy256 / 256-as-ident
313                self.policy = WrapperPolicy::from_ident(&ident)?
314            }
315        }
316        Ok(())
317    }
318
319    fn make_debug_body_tokens(&self) -> proc_macro2::TokenStream {
320        let wrapper_name = &self.wrapper_name;
321        match self.backend_kind {
322            _ if self.opaque => {
323                quote! { write!(f, concat!(stringify!(#wrapper_name), "(<opaque>)")) }
324            }
325            BackendKind::Ptr => {
326                quote! { write!(f, concat!(stringify!(#wrapper_name), "({:p})"), self.0) }
327            }
328            BackendKind::Index => {
329                quote! {
330                    let index = self.0.indexed.real_index();
331                    let gene = self.0.indexed.generation();
332                    write!(f, concat!(stringify!(#wrapper_name), "({:x}:{:x})"), index, gene)
333                }
334            }
335        }
336    }
337
338    fn make_genindex_convert_tokens(&self) -> proc_macro2::TokenStream {
339        let crate_path = &self.crate_name;
340        let wrapper_name = &self.wrapper_name;
341        let generics = &self.generics;
342        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
343
344        match self.backend_kind {
345            BackendKind::Ptr => quote! {},
346            BackendKind::Index => quote! {
347                impl #impl_generics #wrapper_name #ty_generics #where_clause {
348                    #[inline]
349                    pub fn from_gen_index(gen_index: #crate_path::GenIndex) -> Self {
350                        Self(#crate_path::IndexedID::from(gen_index))
351                    }
352
353                    #[inline]
354                    pub fn into_gen_index(self) -> #crate_path::GenIndex {
355                        self.0.indexed
356                    }
357                }
358            },
359        }
360    }
361    fn make_policy_type_tokens(&self) -> proc_macro2::TokenStream {
362        let crate_path = &self.crate_name;
363        match self.policy {
364            WrapperPolicy::Policy128 => quote! { #crate_path::AllocPolicy128 },
365            WrapperPolicy::Policy256 => quote! { #crate_path::AllocPolicy256 },
366            WrapperPolicy::Policy512 => quote! { #crate_path::AllocPolicy512 },
367            WrapperPolicy::Policy1024 => quote! { #crate_path::AllocPolicy1024 },
368            WrapperPolicy::Policy2048 => quote! { #crate_path::AllocPolicy2048 },
369            WrapperPolicy::Policy4096 => quote! { #crate_path::AllocPolicy4096 },
370        }
371    }
372}
373
/// Allocation policy selectable through the attribute; each variant corresponds to
/// the runtime crate's `AllocPolicyN` type.
#[derive(Clone, Copy, Debug)]
enum WrapperPolicy {
    /// `128`, `Policy128` or `AllocPolicy128`.
    Policy128,
    /// `256`, `Policy256` or `AllocPolicy256` (the default).
    Policy256,
    /// `512`, `Policy512` or `AllocPolicy512`.
    Policy512,
    /// `1024`, `Policy1024` or `AllocPolicy1024`.
    Policy1024,
    /// `2048`, `Policy2048` or `AllocPolicy2048`.
    Policy2048,
    /// `4096`, `Policy4096` or `AllocPolicy4096`.
    Policy4096,
}
383
384impl WrapperPolicy {
385    fn from_str(s: &str) -> syn::Result<Self> {
386        let ret = match s {
387            "128" | "Policy128" | "AllocPolicy128" => Self::Policy128,
388            "256" | "Policy256" | "AllocPolicy256" => Self::Policy256,
389            "512" | "Policy512" | "AllocPolicy512" => Self::Policy512,
390            "1024" | "Policy1024" | "AllocPolicy1024" => Self::Policy1024,
391            "2048" | "Policy2048" | "AllocPolicy2048" => Self::Policy2048,
392            "4096" | "Policy4096" | "AllocPolicy4096" => Self::Policy4096,
393            _ => {
394                return Err(syn::Error::new_spanned(
395                    syn::LitStr::new(s, proc_macro2::Span::call_site()),
396                    concat!(
397                        "Invalid policy (expected 128|256|512|1024|2048|4096,",
398                        " Policy128|..|4096,",
399                        " or AllocPolicy128|..|4096)"
400                    ),
401                ));
402            }
403        };
404        Ok(ret)
405    }
406
407    fn from_ident(ident: &Ident) -> syn::Result<Self> {
408        Self::from_str(&ident.to_string())
409    }
410    fn from_litint(litint: &LitInt) -> syn::Result<Self> {
411        Self::from_str(litint.base10_digits())
412    }
413}