use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input, punctuated::Punctuated, DeriveInput, Expr, ExprLit, Fields, Lit, Meta,
MetaList, MetaNameValue, Token,
};
/// Expands the `HandleEnum` derive for `input`.
///
/// Presumably called from the crate's `#[proc_macro_derive]` entry point
/// (not visible here — confirm). The input is parsed as a `DeriveInput`;
/// if that fails, `parse_macro_input!` returns syn's error tokens directly.
/// Otherwise the expansion is delegated to `derive_impl`, and any `syn::Error`
/// it reports is rendered via `to_compile_error` instead of panicking.
pub fn derive(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    let expanded = derive_impl(ast).unwrap_or_else(|err| err.to_compile_error());
    expanded.into()
}
/// Options read from the enum-level `#[handle(...)]` attribute.
struct HandleEnumAttrs {
    /// `plugin_id = "..."` (required): path naming the plugin id value.
    plugin_id: String,
    /// `plugin_id_type = "..."` (optional): type that owns the plugin id as an
    /// associated item.
    plugin_id_type: Option<String>,
    /// `version = "..."` (required): version string embedded in every handle.
    version: String,
    /// `crate_path = "..."`: path prefix for the runtime types
    /// (`Handle`, `HandleParseError`, ...); the parser defaults it to
    /// `plexus_core`.
    crate_path: String,
}
impl Parse for HandleEnumAttrs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut plugin_id = None;
let mut plugin_id_type = None;
let mut version = None;
let mut crate_path = "plexus_core".to_string();
if !input.is_empty() {
let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
for meta in metas {
if let Meta::NameValue(MetaNameValue { path, value, .. }) = meta {
if let Expr::Lit(ExprLit {
lit: Lit::Str(s), ..
}) = value
{
if path.is_ident("plugin_id") {
plugin_id = Some(s.value());
} else if path.is_ident("plugin_id_type") {
plugin_id_type = Some(s.value());
} else if path.is_ident("version") {
version = Some(s.value());
} else if path.is_ident("crate_path") {
crate_path = s.value();
}
}
}
}
}
Ok(HandleEnumAttrs {
plugin_id: plugin_id.ok_or_else(|| {
syn::Error::new(input.span(), "HandleEnum requires plugin_id = \"...\" attribute")
})?,
plugin_id_type,
version: version.ok_or_else(|| {
syn::Error::new(input.span(), "HandleEnum requires version = \"...\" attribute")
})?,
crate_path,
})
}
}
/// Options read from a variant-level `#[handle(...)]` attribute.
struct HandleVariantAttrs {
    /// `method = "..."`: method string that identifies this variant in a handle.
    method: String,
    /// `table = "..."`: resolution table; only takes effect together with `key`.
    table: Option<String>,
    /// `key = "..."`: resolution key column; only takes effect together with `table`.
    key: Option<String>,
    /// `key_field = "..."`: name of the field that supplies the key value.
    key_field: Option<String>,
    /// `strip_prefix = "..."`: prefix removed from the key value when present.
    strip_prefix: Option<String>,
}
/// Finds the first list-form `#[handle(...)]` among `attrs` that names a
/// `method`, and returns the options it declares.
///
/// A `#[handle(...)]` without `method` is skipped, so a later one can still
/// supply it. Only `name = "string literal"` entries for `method`, `table`,
/// `key`, `key_field` and `strip_prefix` are read; anything else is ignored.
/// A repeated name keeps its last value.
///
/// Returns `Ok(None)` when no such attribute exists.
///
/// # Errors
/// Propagates the syn error when a `#[handle(...)]` body is not a
/// comma-separated meta list.
fn parse_variant_attrs(attrs: &[syn::Attribute]) -> syn::Result<Option<HandleVariantAttrs>> {
    // List-form `#[handle(...)]` bodies, in declaration order.
    let handle_lists = attrs
        .iter()
        .filter(|attr| attr.path().is_ident("handle"))
        .filter_map(|attr| match &attr.meta {
            Meta::List(list) => Some(&list.tokens),
            _ => None,
        });

    for tokens in handle_lists {
        let entries = syn::parse::Parser::parse2(
            Punctuated::<Meta, Token![,]>::parse_terminated,
            tokens.clone(),
        )?;

        let mut method: Option<String> = None;
        let mut table: Option<String> = None;
        let mut key: Option<String> = None;
        let mut key_field: Option<String> = None;
        let mut strip_prefix: Option<String> = None;

        for entry in entries {
            let (path, lit) = match entry {
                Meta::NameValue(MetaNameValue {
                    path,
                    value: Expr::Lit(ExprLit { lit: Lit::Str(lit), .. }),
                    ..
                }) => (path, lit),
                _ => continue,
            };
            let name = path.get_ident().map(|ident| ident.to_string());
            let slot = match name.as_deref() {
                Some("method") => &mut method,
                Some("table") => &mut table,
                Some("key") => &mut key,
                Some("key_field") => &mut key_field,
                Some("strip_prefix") => &mut strip_prefix,
                _ => continue,
            };
            *slot = Some(lit.value());
        }

        if let Some(method) = method {
            return Ok(Some(HandleVariantAttrs {
                method,
                table,
                key,
                key_field,
                strip_prefix,
            }));
        }
    }
    Ok(None)
}
/// A named field of a handle variant, in declaration order.
struct FieldInfo {
    /// The field's identifier, reused as the binding name in generated patterns.
    name: syn::Ident,
    /// `name.to_string()`; emitted as a string literal (metadata error label
    /// and resolution context key).
    name_str: String,
}
/// Resolution settings of a variant that declares both `table` and `key`.
struct ResolutionInfo {
    /// Value emitted as `HandleResolutionParams::table`.
    table: String,
    /// Value emitted as `HandleResolutionParams::key_column`.
    key: String,
    /// Index into the owning variant's `fields` of the field that supplies
    /// `key_value`; the remaining fields become context pairs.
    key_field_index: usize,
    /// Optional prefix stripped from the key value.
    strip_prefix: Option<String>,
}
/// A variant that carries `#[handle(method = "...")]`, with everything the
/// code generators need.
struct VariantInfo {
    /// The variant identifier (`Self::<name>` in generated code).
    name: syn::Ident,
    /// Method string this variant encodes to and is decoded from.
    method: String,
    /// Named fields in declaration order (empty for unit variants).
    fields: Vec<FieldInfo>,
    /// Present only when the variant declares both `table` and `key`.
    resolution: Option<ResolutionInfo>,
}
/// Expands `#[derive(HandleEnum)]` for `input`.
///
/// For the annotated enum, this generates:
/// - an inherent `to_handle(&self)`, which builds a `Handle` from the configured
///   plugin id, version and the variant's method. The variant's named fields are
///   cloned, in declaration order, into the handle metadata.
/// - an inherent `resolution_params(&self)`.
/// - `From<Enum> for Handle`, which delegates to `to_handle`.
/// - `TryFrom<&Handle> for Enum`. It checks the plugin id, dispatches on the
///   method string and rebuilds fields from the metadata by position.
///
/// Every generated path to a runtime type is rooted at the enum-level
/// `crate_path` attribute (default `plexus_core`).
///
/// # Errors
/// Returns a `syn::Error` in any of these cases:
/// - the enum-level attribute is missing or malformed;
/// - `plugin_id` or `plugin_id_type` does not parse;
/// - the input is not an enum;
/// - a variant is invalid;
/// - no variant carries `#[handle(method = "...")]`.
fn derive_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
    let enum_name = &input.ident;
    let enum_attrs = extract_enum_attrs(&input)?;
    let crate_path: syn::Path = syn::parse_str(&enum_attrs.crate_path)?;
    // Expression naming the plugin id in the generated code.
    // - With `plugin_id_type`, only the final segment of `plugin_id` is kept,
    //   and it is resolved as an associated item of that type (`<Type>::NAME`).
    // - Otherwise `plugin_id` is used verbatim as a path.
    let plugin_id_expr: TokenStream2 = match enum_attrs.plugin_id_type.as_deref() {
        Some(type_str) => {
            let qualified_type: syn::Type = syn::parse_str(type_str).map_err(|e| {
                syn::Error::new_spanned(
                    &input,
                    format!("Invalid plugin_id_type '{}': {}", type_str, e),
                )
            })?;
            let plugin_id_path: syn::Path =
                syn::parse_str(&enum_attrs.plugin_id).map_err(|e| {
                    syn::Error::new_spanned(
                        &input,
                        format!(
                            "Invalid plugin_id path '{}': {}",
                            enum_attrs.plugin_id, e
                        ),
                    )
                })?;
            let tail = plugin_id_path.segments.last().ok_or_else(|| {
                syn::Error::new_spanned(
                    &input,
                    format!("plugin_id '{}' has no path segments", enum_attrs.plugin_id),
                )
            })?;
            let tail_ident = &tail.ident;
            quote! { <#qualified_type>::#tail_ident }
        }
        None => {
            let plugin_id_path: syn::Path =
                syn::parse_str(&enum_attrs.plugin_id).map_err(|e| {
                    syn::Error::new_spanned(
                        &input,
                        format!(
                            "Invalid plugin_id constant name '{}': {}",
                            enum_attrs.plugin_id, e
                        ),
                    )
                })?;
            quote! { #plugin_id_path }
        }
    };
    let version = &enum_attrs.version;
    let variants = extract_variants(&input)?;
    if variants.is_empty() {
        return Err(syn::Error::new_spanned(
            &input,
            "HandleEnum requires at least one variant with #[handle(method = \"...\")]",
        ));
    }
    let to_handle_arms =
        generate_to_handle_arms(&variants, &plugin_id_expr, version, &crate_path);
    let try_from_arms = generate_try_from_arms(&variants, &crate_path);
    let resolution_arms = generate_resolution_arms(&variants, &crate_path);
    Ok(quote! {
        impl #enum_name {
            pub fn to_handle(&self) -> #crate_path::Handle {
                match self {
                    #(#to_handle_arms)*
                }
            }
            pub fn resolution_params(&self) -> Option<#crate_path::HandleResolutionParams> {
                match self {
                    #(#resolution_arms)*
                }
            }
        }
        impl From<#enum_name> for #crate_path::Handle {
            fn from(h: #enum_name) -> #crate_path::Handle {
                h.to_handle()
            }
        }
        impl TryFrom<&#crate_path::Handle> for #enum_name {
            type Error = #crate_path::HandleParseError;
            fn try_from(handle: &#crate_path::Handle) -> Result<Self, Self::Error> {
                if handle.plugin_id != #plugin_id_expr {
                    return Err(#crate_path::HandleParseError::WrongPlugin {
                        expected: #plugin_id_expr,
                        got: handle.plugin_id,
                    });
                }
                match handle.method.as_str() {
                    #(#try_from_arms)*
                    _ => Err(#crate_path::HandleParseError::UnknownMethod(handle.method.clone()))
                }
            }
        }
    })
}

/// Parses the first list-form `#[handle(...)]` attribute on the enum itself.
///
/// # Errors
/// Propagates parse errors from `HandleEnumAttrs`. Fails if no list-form
/// `#[handle(...)]` attribute is present.
fn extract_enum_attrs(input: &DeriveInput) -> syn::Result<HandleEnumAttrs> {
    for attr in &input.attrs {
        if attr.path().is_ident("handle") {
            if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
                return syn::parse2(tokens.clone());
            }
        }
    }
    Err(syn::Error::new_spanned(
        input,
        "HandleEnum requires #[handle(plugin_id = \"...\", version = \"...\")] attribute",
    ))
}

/// Collects every enum variant that carries `#[handle(method = "...")]`.
///
/// Variants without such an attribute are skipped here.
///
/// # Errors
/// Fails in any of these cases:
/// - the input is not an enum;
/// - a handled variant uses tuple fields;
/// - a variant's `table`/`key` resolution cannot be tied to a field (see
///   `resolve_key_field_index`).
fn extract_variants(input: &DeriveInput) -> syn::Result<Vec<VariantInfo>> {
    let data_enum = match &input.data {
        syn::Data::Enum(data) => data,
        _ => {
            return Err(syn::Error::new_spanned(
                input,
                "HandleEnum can only be derived for enums",
            ))
        }
    };
    let mut variants = Vec::new();
    for variant in &data_enum.variants {
        let attrs = match parse_variant_attrs(&variant.attrs)? {
            Some(attrs) => attrs,
            None => continue,
        };
        let fields: Vec<FieldInfo> = match &variant.fields {
            Fields::Named(named) => named
                .named
                .iter()
                .map(|f| {
                    let name = f
                        .ident
                        .clone()
                        .expect("fields of `Fields::Named` always have an identifier");
                    let name_str = name.to_string();
                    FieldInfo { name, name_str }
                })
                .collect(),
            Fields::Unit => Vec::new(),
            Fields::Unnamed(_) => {
                return Err(syn::Error::new_spanned(
                    variant,
                    "HandleEnum variants must use named fields (e.g., `Variant { field: Type }`)",
                ))
            }
        };
        // Resolution is only enabled when both `table` and `key` are given.
        let resolution = match (attrs.table, attrs.key) {
            (Some(table), Some(key)) => {
                let key_field_index =
                    resolve_key_field_index(variant, &fields, attrs.key_field.as_deref())?;
                Some(ResolutionInfo {
                    table,
                    key,
                    key_field_index,
                    strip_prefix: attrs.strip_prefix,
                })
            }
            _ => None,
        };
        variants.push(VariantInfo {
            name: variant.ident.clone(),
            method: attrs.method,
            fields,
            resolution,
        });
    }
    Ok(variants)
}

/// Picks the field of `variant` that supplies the resolution key value.
///
/// Uses the field named by `key_field`, or the first field when `key_field` is
/// absent. The returned index is always valid for `fields`.
///
/// # Errors
/// Fails if `fields` is empty, since there is no value to resolve. Also fails
/// if `key_field` names no field. Previously that case silently fell back to
/// the first field and used the wrong value as the key.
fn resolve_key_field_index(
    variant: &syn::Variant,
    fields: &[FieldInfo],
    key_field: Option<&str>,
) -> syn::Result<usize> {
    if fields.is_empty() {
        return Err(syn::Error::new_spanned(
            variant,
            "HandleEnum `table`/`key` resolution requires the variant to have at least one named field",
        ));
    }
    match key_field {
        None => Ok(0),
        Some(wanted) => fields
            .iter()
            .position(|f| f.name_str == wanted)
            .ok_or_else(|| {
                let available: Vec<&str> = fields.iter().map(|f| f.name_str.as_str()).collect();
                syn::Error::new_spanned(
                    variant,
                    format!(
                        "HandleEnum key_field = \"{}\" does not match any field of variant `{}` (fields: {})",
                        wanted,
                        variant.ident,
                        available.join(", ")
                    ),
                )
            }),
    }
}

/// Builds the `match self` arms of the generated `to_handle` method.
///
/// Each arm constructs `#crate_path::Handle::new(plugin_id, version, method)`.
/// Variants with named fields additionally call `.with_meta(vec![...])` with
/// each field cloned, in declaration order. `crate_path` must be the same path
/// used for the method's return type. It was previously hard-coded to
/// `plexus_core`, which broke any non-default `crate_path`.
fn generate_to_handle_arms(
    variants: &[VariantInfo],
    plugin_id_expr: &TokenStream2,
    version: &str,
    crate_path: &syn::Path,
) -> Vec<TokenStream2> {
    variants
        .iter()
        .map(|v| {
            let variant_name = &v.name;
            let method = &v.method;
            let new_handle = quote! {
                #crate_path::Handle::new(#plugin_id_expr, #version, #method)
            };
            if v.fields.is_empty() {
                quote! {
                    Self::#variant_name => {
                        #new_handle
                    }
                }
            } else {
                let field_names: Vec<_> = v.fields.iter().map(|f| &f.name).collect();
                quote! {
                    Self::#variant_name { #(#field_names),* } => {
                        #new_handle
                            .with_meta(vec![#(#field_names.clone()),*])
                    }
                }
            }
        })
        .collect()
}
/// Builds the `match handle.method.as_str()` arms of the generated
/// `TryFrom<&Handle>` impl.
///
/// Each arm matches one variant's method string:
/// - Unit variants are returned directly.
/// - For variants with named fields, field `i` is cloned out of
///   `handle.meta.get(i)`. A missing entry returns early with
///   `HandleParseError::MissingMeta { index, field }`.
fn generate_try_from_arms(variants: &[VariantInfo], crate_path: &syn::Path) -> Vec<TokenStream2> {
    let mut arms = Vec::with_capacity(variants.len());
    for variant in variants {
        let ident = &variant.name;
        let method = &variant.method;

        if variant.fields.is_empty() {
            arms.push(quote! {
                #method => Ok(Self::#ident),
            });
            continue;
        }

        // One `let <field> = ...;` statement per field, indexed by position.
        let mut bindings = Vec::with_capacity(variant.fields.len());
        for (position, field) in variant.fields.iter().enumerate() {
            let binding = &field.name;
            let label = &field.name_str;
            bindings.push(quote! {
                let #binding = handle.meta.get(#position)
                    .ok_or(#crate_path::HandleParseError::MissingMeta {
                        index: #position,
                        field: #label,
                    })?
                    .clone();
            });
        }
        let names: Vec<&syn::Ident> = variant.fields.iter().map(|field| &field.name).collect();
        arms.push(quote! {
            #method => {
                #(#bindings)*
                Ok(Self::#ident { #(#names),* })
            }
        });
    }
    arms
}
/// Builds the `match self` arms of the generated `resolution_params` method.
///
/// A variant yields `Some(HandleResolutionParams { .. })` only when it has
/// resolution info and at least one field. In that case:
/// - `key_value` is taken from the key field, with `strip_prefix` removed when
///   it matches, and converted to an owned `String`;
/// - every other field becomes a `(name, value)` context pair, in declaration
///   order.
///
/// All other variants map to `None`.
fn generate_resolution_arms(variants: &[VariantInfo], crate_path: &syn::Path) -> Vec<TokenStream2> {
    variants
        .iter()
        .map(|v| {
            let variant_name = &v.name;
            // A unit variant has no value to resolve a key from. Checking this
            // before looking up `fields[key_field_index]` avoids indexing an
            // empty vector, which previously panicked the macro when a unit
            // variant declared `table`/`key`.
            if v.fields.is_empty() {
                return quote! {
                    Self::#variant_name => None,
                };
            }
            let res = match &v.resolution {
                Some(res) => res,
                None => {
                    return quote! {
                        Self::#variant_name { .. } => None,
                    }
                }
            };
            let table = &res.table;
            let key = &res.key;
            let key_field_index = res.key_field_index;
            // `extract_variants` only produces indices within `fields` for
            // non-empty variants, so this lookup cannot fail.
            let key_field_name = &v
                .fields
                .get(key_field_index)
                .expect("key_field_index is within the variant's fields")
                .name;
            let context_items: Vec<_> = v
                .fields
                .iter()
                .enumerate()
                .filter(|(idx, _)| *idx != key_field_index)
                .map(|(_, f)| {
                    let name = &f.name;
                    let name_str = &f.name_str;
                    quote! {
                        (#name_str.to_string(), #name.clone())
                    }
                })
                .collect();
            let key_value_expr = match &res.strip_prefix {
                Some(prefix) => quote! {
                    #key_field_name.strip_prefix(#prefix).unwrap_or(&#key_field_name).to_string()
                },
                None => quote! {
                    #key_field_name.clone()
                },
            };
            let field_names: Vec<_> = v.fields.iter().map(|f| &f.name).collect();
            quote! {
                Self::#variant_name { #(#field_names),* } => Some(#crate_path::HandleResolutionParams {
                    table: #table,
                    key_column: #key,
                    key_value: #key_value_expr,
                    context: vec![#(#context_items),*],
                }),
            }
        })
        .collect()
}