use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::visit_mut::VisitMut;
use syn::{Data, DeriveInput, Fields, Lifetime, parse_macro_input};
/// Derives `nexus_rt::Resource` for a type.
///
/// The generated impl reuses the type's own generics and where-clause, and
/// appends one extra predicate requiring the type itself (with its generic
/// arguments) to be `Send + 'static`. The impl body is empty.
#[proc_macro_derive(Resource)]
pub fn derive_resource(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    let ident = &ast.ident;
    let (impl_gen, ty_gen, existing_where) = ast.generics.split_for_impl();

    // Start from the user's where-clause (or an empty `where`) and add the
    // `Self: Send + 'static` requirement at the end.
    let mut where_clause: syn::WhereClause = match existing_where {
        Some(wc) => wc.clone(),
        None => syn::parse_quote!(where),
    };
    where_clause
        .predicates
        .push(syn::parse_quote!(#ident #ty_gen: Send + 'static));

    let expanded = quote! {
        impl #impl_gen ::nexus_rt::Resource for #ident #ty_gen
        #where_clause
        {}
    };
    expanded.into()
}
/// Derives `core::ops::Deref` for a struct by forwarding to one of its fields.
///
/// Field selection is delegated to `deref_field`. If it rejects the input
/// (wrong item kind, ambiguous field, ...), its error is emitted as a compile
/// error instead of an impl.
#[proc_macro_derive(Deref, attributes(deref))]
pub fn derive_deref(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    let ident = &ast.ident;
    let (impl_gen, ty_gen, where_gen) = ast.generics.split_for_impl();

    deref_field(&ast.data, ident)
        .map(|(target_ty, member)| {
            quote! {
                impl #impl_gen ::core::ops::Deref for #ident #ty_gen
                #where_gen
                {
                    type Target = #target_ty;

                    #[inline]
                    fn deref(&self) -> &Self::Target {
                        &self.#member
                    }
                }
            }
        })
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
/// Derives `core::ops::DerefMut` for a struct by forwarding to one of its fields.
///
/// The generated `deref_mut` returns `&mut Self::Target`, so the type must also
/// implement `Deref` (for example via `#[derive(Deref)]`) with the same target.
/// Field selection shares `deref_field` with the `Deref` derive, so both
/// derives always agree on which field is used.
#[proc_macro_derive(DerefMut, attributes(deref))]
pub fn derive_deref_mut(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    let ident = &ast.ident;
    let (impl_gen, ty_gen, where_gen) = ast.generics.split_for_impl();

    // Only the access path is needed here; `Target` comes from the `Deref` impl.
    deref_field(&ast.data, ident)
        .map(|(_, member)| {
            quote! {
                impl #impl_gen ::core::ops::DerefMut for #ident #ty_gen
                #where_gen
                {
                    #[inline]
                    fn deref_mut(&mut self) -> &mut Self::Target {
                        &mut self.#member
                    }
                }
            }
        })
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
fn deref_field(
data: &Data,
name: &syn::Ident,
) -> Result<(syn::Type, proc_macro2::TokenStream), syn::Error> {
let fields = match data {
Data::Struct(s) => &s.fields,
Data::Enum(_) => {
return Err(syn::Error::new_spanned(
name,
"Deref/DerefMut can only be derived for structs, not enums",
));
}
Data::Union(_) => {
return Err(syn::Error::new_spanned(
name,
"Deref/DerefMut can only be derived for structs, not unions",
));
}
};
match fields {
Fields::Unnamed(f) if f.unnamed.len() == 1 => {
let field = f.unnamed.first().unwrap();
let ty = field.ty.clone();
let access = quote!(0);
Ok((ty, access))
}
Fields::Named(f) if f.named.len() == 1 => {
let field = f.named.first().unwrap();
let ty = field.ty.clone();
let ident = field.ident.as_ref().unwrap();
let access = quote!(#ident);
Ok((ty, access))
}
Fields::Named(f) => {
let marked: Vec<_> = f
.named
.iter()
.filter(|field| field.attrs.iter().any(|a| a.path().is_ident("deref")))
.collect();
match marked.len() {
0 => Err(syn::Error::new_spanned(
name,
"multiple fields require exactly one `#[deref]` attribute",
)),
1 => {
let field = marked[0];
let ty = field.ty.clone();
let ident = field.ident.as_ref().unwrap();
let access = quote!(#ident);
Ok((ty, access))
}
_ => Err(syn::Error::new_spanned(
name,
"only one field may have `#[deref]`",
)),
}
}
Fields::Unnamed(f) => {
let marked: Vec<_> = f
.unnamed
.iter()
.enumerate()
.filter(|(_, field)| field.attrs.iter().any(|a| a.path().is_ident("deref")))
.collect();
match marked.len() {
0 => Err(syn::Error::new_spanned(
name,
"multiple fields require exactly one `#[deref]` attribute",
)),
1 => {
let (idx, field) = marked[0];
let ty = field.ty.clone();
let idx = syn::Index::from(idx);
let access = quote!(#idx);
Ok((ty, access))
}
_ => Err(syn::Error::new_spanned(
name,
"only one field may have `#[deref]`",
)),
}
}
Fields::Unit => Err(syn::Error::new_spanned(
name,
"Deref/DerefMut cannot be derived for unit structs",
)),
}
}
/// Derives `nexus_rt::Param` for a struct whose fields are themselves `Param`s.
///
/// This is a thin entry point: validation and code generation live in
/// `derive_param_impl`, and any error it reports is emitted as a compile error.
#[proc_macro_derive(Param, attributes(param))]
pub fn derive_param(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    derive_param_impl(&ast)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
/// Generates the `nexus_rt::Param` impl, plus its companion state struct, for
/// `#[derive(Param)]`.
///
/// For `struct Foo<'w> { a: A<'w>, #[param(ignore)] b: B }` this emits:
/// - a hidden `FooState` struct holding `<A<'static> as Param>::State` for
///   each non-ignored field, with `()` for each ignored field;
/// - `impl Param for Foo<'_>` with `Item<'w> = Foo<'w>`:
///   - its `init` calls each non-ignored field's `Param::init(registry)`;
///   - its `fetch` calls each non-ignored field's `Param::fetch` and fills
///     ignored fields with `Default::default()`, so ignored field types must
///     implement `Default`.
///
/// Field types are named with the struct's lifetime rewritten to `'static`
/// (via `LifetimeReplacer`), so they can appear in the lifetime-free state
/// struct.
///
/// # Errors
///
/// Returns a spanned `syn::Error` in any of these cases:
/// - the input is not a struct;
/// - it does not have exactly one lifetime parameter;
/// - it has type or const generics;
/// - it does not use named fields;
/// - a field carries a malformed or unknown `#[param(...)]` attribute.
fn derive_param_impl(input: &DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
    let name = &input.ident;
    let fields = match &input.data {
        Data::Struct(s) => &s.fields,
        _ => {
            return Err(syn::Error::new_spanned(
                name,
                "derive(Param) can only be applied to structs",
            ));
        }
    };

    let lifetimes: Vec<_> = input.generics.lifetimes().collect();
    if lifetimes.len() != 1 {
        return Err(syn::Error::new_spanned(
            &input.generics,
            "derive(Param) requires exactly one lifetime parameter, \
             e.g., `struct MyParam<'w>`",
        ));
    }
    if input.generics.type_params().next().is_some()
        || input.generics.const_params().next().is_some()
    {
        return Err(syn::Error::new_spanned(
            &input.generics,
            "derive(Param) does not yet support type or const generics — \
             only a single lifetime parameter (e.g., `struct MyParam<'w>`). \
             Use a concrete type instead (e.g., `Res<'w, Buffer<64>>` not `Res<'w, Buffer<N>>`)",
        ));
    }
    let world_lifetime = &lifetimes[0].lifetime;

    let named_fields = match fields {
        Fields::Named(f) => &f.named,
        _ => {
            return Err(syn::Error::new_spanned(
                name,
                "derive(Param) requires named fields",
            ));
        }
    };

    // Rewrites the struct's lifetime to `'static` in field types. It is built
    // once because the target lifetime is the same for every field.
    let mut replacer = LifetimeReplacer {
        from: world_lifetime.ident.to_string(),
    };

    // Split fields into those fetched through their own `Param` impl and
    // those excluded with `#[param(ignore)]`.
    let mut param_fields = Vec::new();
    let mut ignored_fields = Vec::new();
    for field in named_fields {
        let field_name = field
            .ident
            .as_ref()
            .expect("fields of a `Fields::Named` struct always have an identifier");
        if param_field_is_ignored(field)? {
            ignored_fields.push(field_name);
        } else {
            let mut static_ty = field.ty.clone();
            replacer.visit_type_mut(&mut static_ty);
            param_fields.push((field_name, static_ty));
        }
    }

    let state_name = format_ident!("{}State", name);
    let state_fields = param_fields.iter().map(|(field_name, static_ty)| {
        quote! {
            #field_name: <#static_ty as ::nexus_rt::Param>::State
        }
    });
    let ignored_state_fields = ignored_fields.iter().map(|field_name| {
        quote! {
            #field_name: ()
        }
    });
    let init_fields = param_fields.iter().map(|(field_name, static_ty)| {
        quote! {
            #field_name: <#static_ty as ::nexus_rt::Param>::init(registry)
        }
    });
    let init_ignored = ignored_fields.iter().map(|field_name| {
        quote! { #field_name: () }
    });
    // Each field's `fetch` is an unsafe call. It is wrapped in its own
    // `unsafe {}` block so the generated code stays clean under the
    // `unsafe_op_in_unsafe_fn` lint (warn-by-default in edition 2024). A
    // per-field block, rather than one around the whole struct literal, also
    // avoids an `unused_unsafe` block when every field is ignored.
    // NOTE(review): this forwards the enclosing `fetch`'s safety contract to
    // each field unchanged. Confirm that `nexus_rt::Param::fetch` needs nothing
    // extra (e.g. no conflicting access between sibling fields).
    let fetch_fields = param_fields.iter().map(|(field_name, static_ty)| {
        quote! {
            #field_name: unsafe {
                <#static_ty as ::nexus_rt::Param>::fetch(world, &mut state.#field_name)
            }
        }
    });
    let fetch_ignored = ignored_fields.iter().map(|field_name| {
        quote! {
            #field_name: ::core::default::Default::default()
        }
    });

    Ok(quote! {
        #[doc(hidden)]
        #[allow(non_camel_case_types)]
        pub struct #state_name {
            #(#state_fields,)*
            #(#ignored_state_fields,)*
        }

        impl ::nexus_rt::Param for #name<'_> {
            type State = #state_name;
            type Item<'w> = #name<'w>;

            fn init(registry: &::nexus_rt::Registry) -> Self::State {
                #state_name {
                    #(#init_fields,)*
                    #(#init_ignored,)*
                }
            }

            unsafe fn fetch<'w>(
                world: &'w ::nexus_rt::World,
                state: &'w mut Self::State,
            ) -> #name<'w> {
                #name {
                    #(#fetch_fields,)*
                    #(#fetch_ignored,)*
                }
            }
        }
    })
}

/// Returns whether `field` carries `#[param(ignore)]`.
///
/// Every `#[param ...]` attribute on the field is parsed. Any option other than
/// `ignore` is reported as an error, and so is a form without a parenthesized
/// list (bare `#[param]` or `#[param = ...]`). Previously such typos were
/// silently treated as "not ignored".
fn param_field_is_ignored(field: &syn::Field) -> syn::Result<bool> {
    let mut ignored = false;
    for attr in field.attrs.iter().filter(|a| a.path().is_ident("param")) {
        attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("ignore") {
                ignored = true;
                Ok(())
            } else {
                Err(meta.error(
                    "unsupported `param` option; the only supported form is `#[param(ignore)]`",
                ))
            }
        })?;
    }
    Ok(ignored)
}
/// Syntax-tree visitor that rewrites every occurrence of one named lifetime
/// into `'static`.
///
/// Lifetimes whose identifier differs from `from` are left untouched.
struct LifetimeReplacer {
    /// Identifier of the lifetime to replace, compared against
    /// `Lifetime::ident` (i.e. without the leading apostrophe).
    from: String,
}

impl VisitMut for LifetimeReplacer {
    fn visit_lifetime_mut(&mut self, lt: &mut Lifetime) {
        let is_target = lt.ident == self.from;
        if !is_target {
            return;
        }
        // Reuse the original apostrophe span so the replacement still carries
        // a span from the user's source.
        *lt = Lifetime::new("'static", lt.apostrophe);
    }
}