Skip to main content

nexus_rt_derive/
lib.rs

//! Derive macros for nexus-rt.
//!
//! Use `nexus-rt` instead of depending on this crate directly.
//! The derives are re-exported from `nexus_rt::{Resource, Deref, DerefMut, Param, View}`.
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::visit_mut::VisitMut;
9use syn::{Data, DeriveInput, Fields, Lifetime, parse_macro_input};
10
11// =============================================================================
12// #[derive(Resource)]
13// =============================================================================
14
15/// Derive the `Resource` marker trait, allowing this type to be stored
16/// in a `World`.
17///
18/// ```ignore
19/// use nexus_rt::Resource;
20///
21/// #[derive(Resource)]
22/// struct OrderBook {
23///     bids: Vec<(f64, f64)>,
24///     asks: Vec<(f64, f64)>,
25/// }
26/// ```
27#[proc_macro_derive(Resource)]
28pub fn derive_resource(input: TokenStream) -> TokenStream {
29    let input = parse_macro_input!(input as DeriveInput);
30    let name = &input.ident;
31    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32
33    // Add Send + 'static where clause so errors point at the derive,
34    // not at the register() call site.
35    let mut bounds = where_clause.cloned();
36    let predicate: syn::WherePredicate = syn::parse_quote!(#name #ty_generics: Send + 'static);
37    bounds
38        .get_or_insert_with(|| syn::parse_quote!(where))
39        .predicates
40        .push(predicate);
41
42    quote! {
43        impl #impl_generics ::nexus_rt::Resource for #name #ty_generics
44            #bounds
45        {}
46    }
47    .into()
48}
49
50// =============================================================================
51// #[derive(Deref)]
52// =============================================================================
53
54/// Derive `Deref` for newtype wrappers.
55///
56/// - Single-field structs: auto-selects the field.
57/// - Multi-field structs: requires `#[deref]` on exactly one field.
58///
59/// ```ignore
60/// use nexus_rt::Deref;
61///
62/// #[derive(Deref)]
63/// struct MyWrapper(u64);
64///
65/// #[derive(Deref)]
66/// struct Named {
67///     #[deref]
68///     data: Vec<u8>,
69///     label: String,
70/// }
71/// ```
72#[proc_macro_derive(Deref, attributes(deref))]
73pub fn derive_deref(input: TokenStream) -> TokenStream {
74    let input = parse_macro_input!(input as DeriveInput);
75    let name = &input.ident;
76    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
77
78    let (field_ty, field_access) = match deref_field(&input.data, name) {
79        Ok(v) => v,
80        Err(e) => return e.to_compile_error().into(),
81    };
82
83    quote! {
84        impl #impl_generics ::core::ops::Deref for #name #ty_generics
85            #where_clause
86        {
87            type Target = #field_ty;
88
89            #[inline]
90            fn deref(&self) -> &Self::Target {
91                &self.#field_access
92            }
93        }
94    }
95    .into()
96}
97
98// =============================================================================
99// #[derive(DerefMut)]
100// =============================================================================
101
102/// Derive `DerefMut` for newtype wrappers.
103///
104/// Same field selection rules as `#[derive(Deref)]`. Must be used
105/// alongside `#[derive(Deref)]`.
106#[proc_macro_derive(DerefMut, attributes(deref))]
107pub fn derive_deref_mut(input: TokenStream) -> TokenStream {
108    let input = parse_macro_input!(input as DeriveInput);
109    let name = &input.ident;
110    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
111
112    let (_field_ty, field_access) = match deref_field(&input.data, name) {
113        Ok(v) => v,
114        Err(e) => return e.to_compile_error().into(),
115    };
116
117    quote! {
118        impl #impl_generics ::core::ops::DerefMut for #name #ty_generics
119            #where_clause
120        {
121            #[inline]
122            fn deref_mut(&mut self) -> &mut Self::Target {
123                &mut self.#field_access
124            }
125        }
126    }
127    .into()
128}
129
130// =============================================================================
131// Shared field resolution
132// =============================================================================
133
134/// Find the deref target field. Returns (field_type, field_access).
135fn deref_field(
136    data: &Data,
137    name: &syn::Ident,
138) -> Result<(syn::Type, proc_macro2::TokenStream), syn::Error> {
139    let fields = match data {
140        Data::Struct(s) => &s.fields,
141        Data::Enum(_) => {
142            return Err(syn::Error::new_spanned(
143                name,
144                "Deref/DerefMut can only be derived for structs, not enums",
145            ));
146        }
147        Data::Union(_) => {
148            return Err(syn::Error::new_spanned(
149                name,
150                "Deref/DerefMut can only be derived for structs, not unions",
151            ));
152        }
153    };
154
155    match fields {
156        // Tuple struct: single field → auto-select
157        Fields::Unnamed(f) if f.unnamed.len() == 1 => {
158            let field = f.unnamed.first().unwrap();
159            let ty = field.ty.clone();
160            let access = quote!(0);
161            Ok((ty, access))
162        }
163        // Named struct: single field → auto-select
164        Fields::Named(f) if f.named.len() == 1 => {
165            let field = f.named.first().unwrap();
166            let ty = field.ty.clone();
167            let ident = field.ident.as_ref().unwrap();
168            let access = quote!(#ident);
169            Ok((ty, access))
170        }
171        // Multiple fields → look for #[deref] attribute
172        Fields::Named(f) => {
173            let marked: Vec<_> = f
174                .named
175                .iter()
176                .filter(|field| field.attrs.iter().any(|a| a.path().is_ident("deref")))
177                .collect();
178
179            match marked.len() {
180                0 => Err(syn::Error::new_spanned(
181                    name,
182                    "multiple fields require exactly one `#[deref]` attribute",
183                )),
184                1 => {
185                    let field = marked[0];
186                    let ty = field.ty.clone();
187                    let ident = field.ident.as_ref().unwrap();
188                    let access = quote!(#ident);
189                    Ok((ty, access))
190                }
191                _ => Err(syn::Error::new_spanned(
192                    name,
193                    "only one field may have `#[deref]`",
194                )),
195            }
196        }
197        Fields::Unnamed(f) => {
198            let marked: Vec<_> = f
199                .unnamed
200                .iter()
201                .enumerate()
202                .filter(|(_, field)| field.attrs.iter().any(|a| a.path().is_ident("deref")))
203                .collect();
204
205            match marked.len() {
206                0 => Err(syn::Error::new_spanned(
207                    name,
208                    "multiple fields require exactly one `#[deref]` attribute",
209                )),
210                1 => {
211                    let (idx, field) = marked[0];
212                    let ty = field.ty.clone();
213                    let idx = syn::Index::from(idx);
214                    let access = quote!(#idx);
215                    Ok((ty, access))
216                }
217                _ => Err(syn::Error::new_spanned(
218                    name,
219                    "only one field may have `#[deref]`",
220                )),
221            }
222        }
223        Fields::Unit => Err(syn::Error::new_spanned(
224            name,
225            "Deref/DerefMut cannot be derived for unit structs",
226        )),
227    }
228}
229
230// =============================================================================
231// #[derive(Param)]
232// =============================================================================
233
234/// Derive the `Param` trait for a struct, enabling it to be used as a
235/// grouped handler parameter.
236///
237/// The struct must have exactly one lifetime parameter. Each field must
238/// implement `Param`, or be annotated with `#[param(ignore)]` (in which
239/// case it must implement `Default`).
240///
241/// ```ignore
242/// use nexus_rt::{Param, Res, ResMut, Local};
243///
244/// #[derive(Param)]
245/// struct TradingParams<'w> {
246///     book: Res<'w, OrderBook>,
247///     risk: ResMut<'w, RiskState>,
248///     local_count: Local<'w, u64>,
249/// }
250///
251/// fn on_order(params: TradingParams<'_>, order: Order) {
252///     // params.book, params.risk, params.local_count all available
253/// }
254/// ```
255#[proc_macro_derive(Param, attributes(param))]
256pub fn derive_param(input: TokenStream) -> TokenStream {
257    let input = parse_macro_input!(input as DeriveInput);
258    match derive_param_impl(&input) {
259        Ok(tokens) => tokens.into(),
260        Err(e) => e.to_compile_error().into(),
261    }
262}
263
264fn derive_param_impl(input: &DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
265    let name = &input.ident;
266
267    // Validate: must be a struct
268    let fields = match &input.data {
269        Data::Struct(s) => &s.fields,
270        _ => {
271            return Err(syn::Error::new_spanned(
272                name,
273                "derive(Param) can only be applied to structs",
274            ));
275        }
276    };
277
278    // Validate: exactly one lifetime parameter, no type/const generics
279    let lifetimes: Vec<_> = input.generics.lifetimes().collect();
280    if lifetimes.len() != 1 {
281        return Err(syn::Error::new_spanned(
282            &input.generics,
283            "derive(Param) requires exactly one lifetime parameter, \
284             e.g., `struct MyParam<'w>`",
285        ));
286    }
287    // TODO: support type and const generics by threading them through
288    // the generated State struct and Param impl (e.g., `Buffer<const N: usize>`).
289    // This is straightforward with syn's split_for_impl() but deferred to
290    // avoid the lifetime inference issues Bevy hit with generic SystemParams.
291    if input.generics.type_params().next().is_some()
292        || input.generics.const_params().next().is_some()
293    {
294        return Err(syn::Error::new_spanned(
295            &input.generics,
296            "derive(Param) does not yet support type or const generics — \
297             only a single lifetime parameter (e.g., `struct MyParam<'w>`). \
298             Use a concrete type instead (e.g., `Res<'w, Buffer<64>>` not `Res<'w, Buffer<N>>`)",
299        ));
300    }
301    let world_lifetime = &lifetimes[0].lifetime;
302
303    // Must be named fields
304    let named_fields = match fields {
305        Fields::Named(f) => &f.named,
306        _ => {
307            return Err(syn::Error::new_spanned(
308                name,
309                "derive(Param) requires named fields",
310            ));
311        }
312    };
313
314    // Classify fields: param fields (participate in init/fetch) vs ignored
315    let mut param_fields = Vec::new();
316    let mut ignored_fields = Vec::new();
317
318    for field in named_fields {
319        let field_name = field.ident.as_ref().unwrap();
320        let is_ignored = field.attrs.iter().any(|a| {
321            a.path().is_ident("param")
322                && a.meta
323                    .require_list()
324                    .is_ok_and(|l| l.tokens.to_string().trim() == "ignore")
325        });
326
327        if is_ignored {
328            ignored_fields.push(field_name);
329        } else {
330            // Substitute the struct's lifetime with 'static in the field type
331            let mut static_ty = field.ty.clone();
332            let mut replacer = LifetimeReplacer {
333                from: world_lifetime.ident.to_string(),
334            };
335            replacer.visit_type_mut(&mut static_ty);
336
337            param_fields.push((field_name, &field.ty, static_ty));
338        }
339    }
340
341    // Generate the State struct name
342    let state_name = format_ident!("{}State", name);
343
344    // State struct fields
345    let state_fields = param_fields.iter().map(|(field_name, _, static_ty)| {
346        quote! {
347            #field_name: <#static_ty as ::nexus_rt::Param>::State
348        }
349    });
350    let ignored_state_fields = ignored_fields.iter().map(|field_name| {
351        quote! {
352            #field_name: ()
353        }
354    });
355
356    // init() body
357    let init_fields = param_fields.iter().map(|(field_name, _, static_ty)| {
358        quote! {
359            #field_name: <#static_ty as ::nexus_rt::Param>::init(registry)
360        }
361    });
362    let init_ignored = ignored_fields.iter().map(|field_name| {
363        quote! { #field_name: () }
364    });
365
366    // fetch() body
367    let fetch_fields = param_fields.iter().map(|(field_name, _, static_ty)| {
368        quote! {
369            #field_name: <#static_ty as ::nexus_rt::Param>::fetch(world, &mut state.#field_name)
370        }
371    });
372    let fetch_ignored = ignored_fields.iter().map(|field_name| {
373        quote! {
374            #field_name: ::core::default::Default::default()
375        }
376    });
377
378    Ok(quote! {
379        #[doc(hidden)]
380        #[allow(non_camel_case_types)]
381        pub struct #state_name {
382            #(#state_fields,)*
383            #(#ignored_state_fields,)*
384        }
385
386        impl ::nexus_rt::Param for #name<'_> {
387            type State = #state_name;
388            type Item<'w> = #name<'w>;
389
390            fn init(registry: &::nexus_rt::Registry) -> Self::State {
391                #state_name {
392                    #(#init_fields,)*
393                    #(#init_ignored,)*
394                }
395            }
396
397            unsafe fn fetch<'w>(
398                world: &'w ::nexus_rt::World,
399                state: &'w mut Self::State,
400            ) -> #name<'w> {
401                #name {
402                    #(#fetch_fields,)*
403                    #(#fetch_ignored,)*
404                }
405            }
406        }
407    })
408}
409
410/// Replaces occurrences of a specific lifetime with `'static`.
411struct LifetimeReplacer {
412    from: String,
413}
414
415impl VisitMut for LifetimeReplacer {
416    fn visit_lifetime_mut(&mut self, lt: &mut Lifetime) {
417        if lt.ident == self.from {
418            *lt = Lifetime::new("'static", lt.apostrophe);
419        }
420    }
421}
422
423// =============================================================================
424// #[derive(View)]
425// =============================================================================
426
427/// Derive a `View` projection for use with pipeline `.view()` scopes.
428///
429/// Generates a marker ZST (`As{ViewName}`) and `unsafe impl View<Source>`
430/// for each `#[source(Type)]` attribute. Use with `.view::<AsViewName>()`
431/// in pipeline and DAG builders.
432///
433/// # Attributes
434///
435/// **On the struct:**
436/// - `#[source(TypePath)]` — one per source event type
437///
438/// **On fields:**
439/// - `#[borrow]` — borrow from source (`&source.field`) instead of copy
440/// - `#[source(TypePath, from = "name")]` — remap field name for a specific source
441///
442/// # Examples
443///
444/// ```ignore
445/// use nexus_rt::View;
446///
447/// #[derive(View)]
448/// #[source(NewOrderCommand)]
449/// #[source(AmendOrderCommand)]
450/// struct OrderView<'a> {
451///     #[borrow]
452///     symbol: &'a str,
453///     qty: u64,
454///     price: f64,
455/// }
456///
457/// // Generates: struct AsOrderView;
458/// // Generates: unsafe impl View<NewOrderCommand> for AsOrderView { ... }
459/// // Generates: unsafe impl View<AmendOrderCommand> for AsOrderView { ... }
460/// ```
461#[proc_macro_derive(View, attributes(source, borrow))]
462pub fn derive_view(input: TokenStream) -> TokenStream {
463    let input = parse_macro_input!(input as DeriveInput);
464    match derive_view_impl(&input) {
465        Ok(tokens) => tokens.into(),
466        Err(e) => e.to_compile_error().into(),
467    }
468}
469
470fn derive_view_impl(input: &DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
471    // Only structs
472    let fields = match &input.data {
473        Data::Struct(s) => match &s.fields {
474            Fields::Named(f) => &f.named,
475            _ => {
476                return Err(syn::Error::new_spanned(
477                    &input.ident,
478                    "#[derive(View)] only supports structs with named fields",
479                ));
480            }
481        },
482        _ => {
483            return Err(syn::Error::new_spanned(
484                &input.ident,
485                "#[derive(View)] can only be used on structs",
486            ));
487        }
488    };
489
490    let view_name = &input.ident;
491    let vis = &input.vis;
492
493    // Extract #[source(TypePath)] attributes from the struct
494    let sources = parse_source_attrs(&input.attrs, view_name)?;
495    if sources.is_empty() {
496        return Err(syn::Error::new_spanned(
497            view_name,
498            "#[derive(View)] requires at least one #[source(Type)] attribute",
499        ));
500    }
501
502    // Reject type and const generics
503    if input.generics.type_params().count() > 0 {
504        return Err(syn::Error::new_spanned(
505            &input.generics,
506            "#[derive(View)] does not support type parameters",
507        ));
508    }
509    if input.generics.const_params().count() > 0 {
510        return Err(syn::Error::new_spanned(
511            &input.generics,
512            "#[derive(View)] does not support const parameters",
513        ));
514    }
515
516    // Detect lifetime: 0 or 1 lifetime param
517    let lifetime_param = match input.generics.lifetimes().count() {
518        0 => None,
519        1 => Some(input.generics.lifetimes().next().unwrap().lifetime.clone()),
520        _ => {
521            return Err(syn::Error::new_spanned(
522                &input.generics,
523                "#[derive(View)] supports at most one lifetime parameter",
524            ));
525        }
526    };
527
528    // Marker name: As{ViewName}
529    let marker_name = format_ident!("As{}", view_name);
530
531    // Build ViewType<'a>, StaticViewType, and tick-lifetime tokens
532    let (view_type_with_a, static_view_type, view_type_tick) = lifetime_param.as_ref().map_or_else(
533        || {
534            (
535                quote! { #view_name },
536                quote! { #view_name },
537                quote! { #view_name },
538            )
539        },
540        |lt| {
541            let lt_ident = &lt.ident;
542            let mut static_generics = input.generics.clone();
543            LifetimeReplacer {
544                from: lt_ident.to_string(),
545            }
546            .visit_generics_mut(&mut static_generics);
547            let (_, static_ty_generics, _) = static_generics.split_for_impl();
548            (
549                quote! { #view_name<'a> },
550                quote! { #view_name #static_ty_generics },
551                quote! { #view_name<'_> },
552            )
553        },
554    );
555
556    // Parse field info
557    let field_infos: Vec<FieldInfo> = fields
558        .iter()
559        .map(parse_field_info)
560        .collect::<Result<_, _>>()?;
561
562    // Generate impl for each source
563    let mut impls = Vec::new();
564    for source_type in &sources {
565        let field_exprs: Vec<proc_macro2::TokenStream> = field_infos
566            .iter()
567            .map(|fi| {
568                let view_field = &fi.ident;
569                // Check for per-source field remap
570                let source_field = fi
571                    .remaps
572                    .iter()
573                    .find(|(path, _)| path_matches(path, source_type))
574                    .map_or_else(|| fi.ident.clone(), |(_, name)| format_ident!("{}", name));
575
576                if fi.borrow {
577                    quote! { #view_field: &source.#source_field }
578                } else {
579                    quote! { #view_field: source.#source_field }
580                }
581            })
582            .collect();
583
584        impls.push(quote! {
585            // SAFETY: ViewType<'a> and StaticViewType are the same struct
586            // with different lifetime parameters. Layout-identical by construction.
587            unsafe impl ::nexus_rt::View<#source_type> for #marker_name {
588                type ViewType<'a> = #view_type_with_a where #source_type: 'a;
589                type StaticViewType = #static_view_type;
590
591                fn view(source: &#source_type) -> #view_type_tick {
592                    #view_name {
593                        #(#field_exprs),*
594                    }
595                }
596            }
597        });
598    }
599
600    Ok(quote! {
601        /// View marker generated by `#[derive(View)]`.
602        #vis struct #marker_name;
603
604        #(#impls)*
605    })
606}
607
608struct FieldInfo {
609    ident: syn::Ident,
610    borrow: bool,
611    /// Per-source field remaps: (source_path, source_field_name)
612    remaps: Vec<(syn::Path, String)>,
613}
614
615fn parse_field_info(field: &syn::Field) -> Result<FieldInfo, syn::Error> {
616    let ident = field
617        .ident
618        .clone()
619        .ok_or_else(|| syn::Error::new_spanned(field, "View fields must be named"))?;
620
621    let borrow = field.attrs.iter().any(|a| a.path().is_ident("borrow"));
622
623    let mut remaps = Vec::new();
624    for attr in &field.attrs {
625        if attr.path().is_ident("source") {
626            // Parse #[source(TypePath, from = "field_name")]
627            attr.parse_args_with(|input: syn::parse::ParseStream| {
628                let path: syn::Path = input.parse()?;
629
630                if input.is_empty() {
631                    return Ok(());
632                }
633
634                input.parse::<syn::Token![,]>()?;
635                let kw: syn::Ident = input.parse()?;
636                if kw != "from" {
637                    return Err(syn::Error::new_spanned(&kw, "expected `from`"));
638                }
639                input.parse::<syn::Token![=]>()?;
640                let lit: syn::LitStr = input.parse()?;
641                remaps.push((path, lit.value()));
642                Ok(())
643            })?;
644        }
645    }
646
647    Ok(FieldInfo {
648        ident,
649        borrow,
650        remaps,
651    })
652}
653
654/// Parse `#[source(TypePath)]` attributes from struct-level attrs.
655fn parse_source_attrs(
656    attrs: &[syn::Attribute],
657    span_target: &syn::Ident,
658) -> Result<Vec<syn::Path>, syn::Error> {
659    let mut sources = Vec::new();
660    for attr in attrs {
661        if attr.path().is_ident("source") {
662            let path: syn::Path = attr.parse_args()?;
663            sources.push(path);
664        }
665    }
666    let _ = span_target; // used for error span if needed
667    Ok(sources)
668}
669
670/// Check if two paths match by comparing full path equality.
671fn path_matches(a: &syn::Path, b: &syn::Path) -> bool {
672    a == b
673}