use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{Expr, Ident, ItemStruct, Path, Token, Type, bracketed, parse_macro_input};
/// Implementation of the `#[module(...)]` attribute macro.
///
/// Accepts `imports = [...]` and `providers = [...]` argument lists and re-emits the
/// annotated struct together with:
///
/// * an `impl ::nest_rs_core::Module` whose `register` registers every import and then the
///   module's own providers (in dependency order), and whose `collect` forwards to every
///   import's `collect`. Both return early when `ContainerBuilder::mark_registered` /
///   `mark_collected` return `false` (presumably: this module was already visited);
/// * an `inventory::submit!` of a `ModuleDescriptor` naming the module, its path-style
///   imports and its providers.
///
/// An import written as a plain path is treated as a static module type
/// (`<Path as Module>`); any other expression is treated as a `DynamicModule` value.
/// A provider written `Provider as dyn Trait` is constructed via
/// `Provider::from_container` and stored as `Arc<dyn Trait>`.
///
/// # Panics (in the generated code)
///
/// The generated `register` panics when the providers cannot all be registered: either
/// some provider injects a dependency that nothing registers, or the remaining providers
/// wait on each other (a dependency cycle).
pub fn module(args: TokenStream, input: TokenStream) -> TokenStream {
    let args = parse_macro_input!(args as ModuleArgs);
    let item = parse_macro_input!(input as ItemStruct);
    let name = item.ident.clone();
    let name_str = name.to_string();

    // `register` step per import: a bare path is a static module type; anything else is
    // evaluated as a `DynamicModule` value and consumed.
    let import_calls = args.imports.iter().map(|import| match import {
        Expr::Path(p) => {
            let path = &p.path;
            quote! { builder = <#path as ::nest_rs_core::Module>::register(builder); }
        }
        other => {
            quote! { builder = ::nest_rs_core::DynamicModule::register(#other, builder); }
        }
    });
    // `collect` step per import; dynamic imports are borrowed here rather than consumed.
    let collect_calls = args.imports.iter().map(|import| match import {
        Expr::Path(p) => {
            let path = &p.path;
            quote! { builder = <#path as ::nest_rs_core::Module>::collect(builder); }
        }
        other => {
            quote! { builder = ::nest_rs_core::DynamicModule::collect(&(#other), builder); }
        }
    });
    // Only static (path) imports have a type whose `TypeId` the descriptor can record.
    let import_type_ids = args.imports.iter().filter_map(|import| match import {
        Expr::Path(p) => {
            let path = &p.path;
            Some(quote! { || ::std::any::TypeId::of::<#path>() })
        }
        _ => None,
    });
    // One `ProviderDescriptor` per provider: a concrete provider is keyed by its own type,
    // a `Provider as dyn Trait` binding by `Arc<dyn Trait>`.
    let provider_descriptors = args.providers.iter().map(|binding| match binding {
        ProviderBinding::Concrete(p) => {
            let name_lit = path_tail(p);
            quote! {
                ::nest_rs_core::ProviderDescriptor {
                    name: #name_lit,
                    provides: || ::std::any::TypeId::of::<#p>(),
                    injects: <#p as ::nest_rs_core::Discoverable>::injected,
                }
            }
        }
        ProviderBinding::Dyn { provider, trait_ty } => {
            let name_lit = format!("dyn {}", path_tail_of_type(trait_ty));
            quote! {
                ::nest_rs_core::ProviderDescriptor {
                    name: #name_lit,
                    provides: || ::std::any::TypeId::of::<::std::sync::Arc<#trait_ty>>(),
                    injects: <#provider as ::nest_rs_core::Discoverable>::injected,
                }
            }
        }
    });
    let descriptor_submission = quote! {
        ::nest_rs_core::inventory::submit! {
            ::nest_rs_core::ModuleDescriptor {
                module: || ::std::any::TypeId::of::<#name>(),
                name: #name_str,
                imports: &[ #(#import_type_ids),* ],
                providers: &[ #(#provider_descriptors),* ],
            }
        }
    };

    let body = if args.providers.is_empty() {
        quote! {
            #(#import_calls)*
            ::nest_rs_core::__module_registered(#name_str);
            builder
        }
    } else {
        // Providers may depend on one another, so they are registered in a fixed-point
        // loop: every pass registers each pending provider whose dependencies are now
        // satisfied, and the loop panics as soon as a pass makes no progress. Each provider
        // contributes three snippets:
        //   * `step`     — try to register it this pass;
        //   * `key_push` — record the key it will provide while it is still pending;
        //   * `classify` — on a stalled pass, explain why it is stuck.
        let count = proc_macro2::Literal::usize_unsuffixed(args.providers.len());
        let parts: Vec<(
            proc_macro2::TokenStream,
            proc_macro2::TokenStream,
            proc_macro2::TokenStream,
        )> = args
            .providers
            .iter()
            .enumerate()
            .map(|(i, binding)| {
                let idx = proc_macro2::Literal::usize_unsuffixed(i);
                let (provider, name_lit, provided_key, register_action) = match binding {
                    ProviderBinding::Concrete(p) => (
                        p,
                        path_tail(p),
                        quote! { ::std::any::TypeId::of::<#p>() },
                        quote! {
                            builder = <#p as ::nest_rs_core::Discoverable>::register(builder);
                        },
                    ),
                    ProviderBinding::Dyn { provider, trait_ty } => (
                        provider,
                        path_tail(provider),
                        quote! { ::std::any::TypeId::of::<::std::sync::Arc<#trait_ty>>() },
                        quote! {
                            let __snapshot = builder.snapshot();
                            // Qualified-self form so provider paths carrying generic
                            // arguments (`Foo<Bar>`) are valid in expression position.
                            let __provider = <#provider>::from_container(&__snapshot);
                            let __dyn: ::std::sync::Arc<#trait_ty> =
                                ::std::sync::Arc::new(__provider);
                            builder = builder.provide_dyn::<#trait_ty>(__dyn);
                        },
                    ),
                };
                let step = quote! {
                    if !__done[#idx] {
                        let __required_ready =
                            <#provider as ::nest_rs_core::Discoverable>::dependencies()
                                .iter()
                                .all(|__id| builder.contains(*__id));
                        // An optional dependency only blocks while some still-pending
                        // provider of this module is going to supply it.
                        let __optional_ready =
                            <#provider as ::nest_rs_core::Discoverable>::optional_dependencies()
                                .iter()
                                .all(|__id| builder.contains(*__id) || !__pending_keys.contains(__id));
                        if __required_ready && __optional_ready {
                            #register_action
                            __done[#idx] = true;
                            __progressed = true;
                        } else {
                            __any_pending = true;
                        }
                    }
                };
                let key_push = quote! {
                    if !__done[#idx] {
                        __pending_keys.push(#provided_key);
                    }
                };
                let classify = quote! {
                    if !__done[#idx] {
                        let __deps = <#provider as ::nest_rs_core::Discoverable>::dependencies();
                        let __dep_names =
                            <#provider as ::nest_rs_core::Discoverable>::dependency_names();
                        let mut __missing_ids: ::std::vec::Vec<::std::any::TypeId> =
                            ::std::vec::Vec::new();
                        let mut __missing_names: ::std::vec::Vec<&'static str> =
                            ::std::vec::Vec::new();
                        for (__k, __id) in __deps.iter().enumerate() {
                            if !builder.contains(*__id) {
                                __missing_ids.push(*__id);
                                __missing_names.push(*__dep_names.get(__k).unwrap_or(&"?"));
                            }
                        }
                        // Stuck only on keys other pending providers will supply => cycle.
                        // An empty `__missing_ids` means every required dependency is
                        // present, so the provider is blocked on a pending optional one —
                        // also a cycle (`all` is vacuously true for that case).
                        if __missing_ids.iter().all(|__id| __pending_keys.contains(__id)) {
                            __cyclic.push(#name_lit);
                        } else {
                            __unprovided.push(::std::format!(
                                "{} (needs {})", #name_lit, __missing_names.join(", ")
                            ));
                        }
                    }
                };
                (step, key_push, classify)
            })
            .collect();
        let steps = parts.iter().map(|p| &p.0);
        let key_pushes = parts.iter().map(|p| &p.1);
        let classifies = parts.iter().map(|p| &p.2);
        quote! {
            #(#import_calls)*
            let mut __done = [false; #count];
            loop {
                // Keys that still-pending providers of this module will eventually supply.
                let mut __pending_keys: ::std::vec::Vec<::std::any::TypeId> =
                    ::std::vec::Vec::new();
                #(#key_pushes)*
                let mut __any_pending = false;
                let mut __progressed = false;
                #(#steps)*
                if !__any_pending {
                    break;
                }
                if !__progressed {
                    let mut __cyclic: ::std::vec::Vec<&'static str> = ::std::vec::Vec::new();
                    let mut __unprovided: ::std::vec::Vec<::std::string::String> =
                        ::std::vec::Vec::new();
                    #(#classifies)*
                    // Genuinely missing dependencies are reported in preference to cycles.
                    if !__unprovided.is_empty() {
                        ::std::panic!(
                            "module `{}`: cannot register provider(s) {:?} — each injects a dependency that neither this module's `providers` nor its `imports` registers; add the provider or import the module that exposes it",
                            #name_str, __unprovided
                        );
                    } else {
                        ::std::panic!(
                            "module `{}`: dependency cycle among provider(s) {:?} — each waits on another provider in the same module; break it by injecting `Arc<dyn Trait>` instead of the concrete type",
                            #name_str, __cyclic
                        );
                    }
                }
            }
            ::nest_rs_core::__module_registered(#name_str);
            builder
        }
    };

    quote! {
        #item
        impl ::nest_rs_core::Module for #name {
            fn register(
                mut builder: ::nest_rs_core::ContainerBuilder,
            ) -> ::nest_rs_core::ContainerBuilder {
                if !::nest_rs_core::ContainerBuilder::mark_registered(
                    &mut builder,
                    ::std::any::TypeId::of::<#name>(),
                ) {
                    return builder;
                }
                #body
            }
            fn collect(
                mut builder: ::nest_rs_core::ContainerBuilder,
            ) -> ::nest_rs_core::ContainerBuilder {
                if !::nest_rs_core::ContainerBuilder::mark_collected(
                    &mut builder,
                    ::std::any::TypeId::of::<#name>(),
                ) {
                    return builder;
                }
                #(#collect_calls)*
                builder
            }
        }
        #descriptor_submission
    }
    .into()
}
/// Returns the identifier of the last segment of `p`, used as a short human-readable
/// provider name (e.g. `Foo` for `crate::services::Foo`; generic arguments are dropped).
/// Falls back to the token text of the whole path if it has no segments.
fn path_tail(p: &Path) -> String {
    match p.segments.last() {
        Some(segment) => segment.ident.to_string(),
        None => quote!(#p).to_string(),
    }
}
/// Returns a short display name for the trait behind a `dyn` type: the identifier of the
/// last path segment of its first trait bound (e.g. `Greeter` for
/// `dyn crate::Greeter + Send`).
///
/// Parentheses (`(dyn Greeter + Send)`) and invisible groups (what a type becomes when it
/// is forwarded through a `macro_rules!` `$t:ty` fragment) are looked through first.
/// Otherwise such types would hit the token-text fallback, and the caller, which prefixes
/// the result with `dyn `, would produce names like `dyn (dyn Greeter + Send)`.
/// Types that are not trait objects fall back to their token text.
fn path_tail_of_type(ty: &Type) -> String {
    let mut inner = ty;
    loop {
        inner = match inner {
            Type::Paren(paren) => &*paren.elem,
            Type::Group(group) => &*group.elem,
            _ => break,
        };
    }
    if let Type::TraitObject(obj) = inner {
        // Lifetime bounds (`+ 'static`) are skipped; the first trait bound names the trait.
        for bound in &obj.bounds {
            if let syn::TypeParamBound::Trait(t) = bound
                && let Some(seg) = t.path.segments.last()
            {
                return seg.ident.to_string();
            }
        }
    }
    quote!(#inner).to_string()
}
/// Parsed arguments of the `#[module(...)]` attribute.
///
/// Both lists start empty (via `Default`), so either key may be omitted.
#[derive(Default)]
struct ModuleArgs {
    /// Entries of `imports = [...]`. A plain path names a static module type; any other
    /// expression is used as a dynamic module value.
    imports: Vec<Expr>,
    /// Entries of `providers = [...]`, in declaration order.
    providers: Vec<ProviderBinding>,
}
/// One entry of a `providers = [...]` list.
enum ProviderBinding {
    /// `Provider` — registered under its own concrete type.
    Concrete(Path),
    /// `Provider as Type` — `Provider` is constructed and exposed as `Arc<Type>`
    /// (typically `Type` is `dyn Trait`).
    Dyn {
        /// The concrete provider type that gets constructed.
        provider: Path,
        /// The type written after `as`; boxed (presumably to keep the enum small, since
        /// `syn::Type` is a large enum).
        trait_ty: Box<Type>,
    },
}
impl Parse for ProviderBinding {
    /// Parses either `Path` (a concrete provider) or `Path as Type` (a provider exposed
    /// through `Type`).
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let provider = input.parse::<Path>()?;
        // `Option<T>` parses `T` only when the next token is `T`, otherwise yields `None`.
        let binding = match input.parse::<Option<Token![as]>>()? {
            Some(_) => Self::Dyn {
                provider,
                trait_ty: Box::new(input.parse::<Type>()?),
            },
            None => Self::Concrete(provider),
        };
        Ok(binding)
    }
}
impl Parse for ModuleArgs {
    /// Parses a comma-separated list of `key = [ ... ]` entries, where `key` is `imports`
    /// (a list of expressions) or `providers` (a list of provider bindings). A key may
    /// appear more than once, in which case its lists are concatenated. Trailing commas
    /// are accepted both inside the brackets and after the last entry.
    ///
    /// The key is validated before its value is parsed, so an unsupported key is reported
    /// as such even when its value is not a bracketed list (e.g. `exports = Foo`).
    fn parse(input: ParseStream) -> syn::Result<Self> {
        enum Key {
            Imports,
            Providers,
        }

        let mut args = ModuleArgs::default();
        while !input.is_empty() {
            let key: Ident = input.parse()?;
            let which = match key.to_string().as_str() {
                "imports" => Key::Imports,
                "providers" => Key::Providers,
                other => {
                    return Err(syn::Error::new(
                        key.span(),
                        format!(
                            "unknown #[module] key `{other}` (expected `imports` or `providers`)"
                        ),
                    ));
                }
            };
            input.parse::<Token![=]>()?;
            let content;
            bracketed!(content in input);
            match which {
                Key::Imports => {
                    let exprs: Punctuated<Expr, Token![,]> =
                        Punctuated::parse_terminated(&content)?;
                    args.imports.extend(exprs);
                }
                Key::Providers => {
                    let bindings: Punctuated<ProviderBinding, Token![,]> =
                        Punctuated::parse_terminated(&content)?;
                    args.providers.extend(bindings);
                }
            }
            if !input.is_empty() {
                input.parse::<Token![,]>()?;
            }
        }
        Ok(args)
    }
}