use std::collections::BTreeMap;
use std::collections::BTreeSet;
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro2::TokenStream as TokenStream2;
use quote::ToTokens;
use quote::format_ident;
use quote::quote;
use syn::Data;
use syn::DataEnum;
use syn::DataStruct;
use syn::DeriveInput;
use syn::Fields;
use syn::GenericArgument;
use syn::GenericParam;
use syn::Ident;
use syn::ItemFn;
use syn::ItemImpl;
use syn::PathArguments;
use syn::Token;
use syn::Type;
use syn::Visibility;
use syn::parse::Parse;
use syn::parse::ParseStream;
use syn::parse_macro_input;
use syn::punctuated::Punctuated;
#[proc_macro_derive(RefView, attributes(ref_view))]
pub fn derive_ref_view(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
expand_ref_view(input)
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
/// Attribute macro `#[impl_trait(ViewTrait => TargetA, TargetB, ...)]`.
///
/// Parses the attribute arguments first, then the annotated `impl` block, and
/// delegates to the expansion routine. Expansion errors are reported through
/// `compile_error!`.
#[proc_macro_attribute]
pub fn impl_trait(attr: TokenStream, item: TokenStream) -> TokenStream {
    let impl_args = parse_macro_input!(attr as ImplTraitArgs);
    let impl_block = parse_macro_input!(item as ItemImpl);
    expand_impl_trait(impl_args, impl_block)
        .map_or_else(|error| error.to_compile_error(), |tokens| tokens)
        .into()
}
#[proc_macro_attribute]
pub fn impl_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
let trait_path = parse_macro_input!(attr as syn::Path);
let item = parse_macro_input!(item as ItemFn);
expand_impl_fn(trait_path, item)
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
fn expand_ref_view(input: DeriveInput) -> syn::Result<TokenStream2> {
if !input.generics.params.is_empty() {
return Err(syn::Error::new_spanned(
input.generics,
"RefView does not support generic types yet",
));
}
let type_ident = input.ident;
let vis = input.vis;
let config = Config::parse(&input.attrs, &type_ident)?;
match input.data {
Data::Struct(data) => expand_struct(&vis, &type_ident, data, config),
Data::Enum(data) => expand_enum(&vis, &type_ident, data, config),
Data::Union(data) => Err(syn::Error::new_spanned(
data.union_token,
"RefView does not support unions",
)),
}
}
/// Configuration collected from all `#[ref_view(...)]` attributes on the
/// deriving type.
#[derive(Debug, Default)]
struct Config {
    /// Explicit accessor-trait name from `trait_name = ...`; when `None`, the
    /// name falls back to `<Type>View` (see `Config::trait_ident`).
    trait_name: Option<Ident>,
    /// Extra derives from `derive(...)`, applied to every generated view type
    /// in addition to `Clone` and `Copy`.
    derives: Vec<syn::Path>,
    /// View types to generate, in emission order.
    views: Vec<ViewSpec>,
}
/// Description of a single generated view type.
#[derive(Debug, Clone)]
struct ViewSpec {
    /// Identifier of the generated view struct/enum.
    name: Ident,
    /// Fields this view does not borrow; the view's accessors for them return
    /// `None`.
    omitted: BTreeSet<OmitTarget>,
}
/// One entry of an `omit(...)` list, stored by name rather than by span.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
enum OmitTarget {
    /// `omit(field)`: the field of a struct, or that field in every enum
    /// variant that declares it.
    Field(String),
    /// `omit(Variant.field)`: the field only within the named enum variant
    /// (rejected for structs).
    VariantField { variant: String, field: String },
}
impl Config {
fn parse(attrs: &[syn::Attribute], type_ident: &Ident) -> syn::Result<Self> {
let mut config = Config::default();
for attr in attrs {
if !attr.path().is_ident("ref_view") {
continue;
}
let mut view_name = None;
let mut omitted = BTreeSet::new();
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("trait_name") {
let value = meta.value()?;
config.trait_name = Some(value.parse()?);
return Ok(());
}
if meta.path.is_ident("name") {
let value = meta.value()?;
view_name = Some(value.parse()?);
return Ok(());
}
if meta.path.is_ident("derive") {
let content;
syn::parenthesized!(content in meta.input);
while !content.is_empty() {
config.derives.push(content.parse()?);
if content.is_empty() {
break;
}
content.parse::<Token![,]>()?;
}
return Ok(());
}
if meta.path.is_ident("omit") {
let content;
syn::parenthesized!(content in meta.input);
while !content.is_empty() {
omitted.insert(parse_omit_target(&content)?);
if content.is_empty() {
break;
}
content.parse::<Token![,]>()?;
}
return Ok(());
}
Err(meta.error("unsupported ref_view option"))
})?;
if let Some(name) = view_name {
config.views.push(ViewSpec { name, omitted });
}
}
let default_ref = format_ident!("{}Ref", type_ident);
if !config.views.iter().any(|view| view.name == default_ref) {
config.views.insert(
0,
ViewSpec {
name: default_ref,
omitted: BTreeSet::new(),
},
);
}
Ok(config)
}
fn trait_ident(&self, type_ident: &Ident) -> Ident {
self.trait_name
.clone()
.unwrap_or_else(|| format_ident!("{}View", type_ident))
}
fn view_derives(&self) -> Vec<syn::Path> {
let mut derives = vec![syn::parse_quote!(Clone), syn::parse_quote!(Copy)];
let mut seen = derives.iter().map(type_key_path).collect::<BTreeSet<_>>();
for derive in &self.derives {
if seen.insert(type_key_path(derive)) {
derives.push(derive.clone());
}
}
derives
}
}
/// Parses one `omit(...)` entry: either `field` or `Variant.field`.
fn parse_omit_target(input: syn::parse::ParseStream) -> syn::Result<OmitTarget> {
    let head: Ident = input.parse()?;
    if !input.peek(Token![.]) {
        return Ok(OmitTarget::Field(head.to_string()));
    }
    let _dot: Token![.] = input.parse()?;
    let tail: Ident = input.parse()?;
    Ok(OmitTarget::VariantField {
        variant: head.to_string(),
        field: tail.to_string(),
    })
}
/// Per-field metadata used to generate accessors and view fields.
#[derive(Debug, Clone)]
struct FieldInfo {
    /// Field name; also used as the accessor method name.
    ident: Ident,
    /// Declared field type.
    ty: Type,
    /// `Some(T)` when the declared type's last path segment is `Option<T>`.
    option_inner: Option<Type>,
    /// Whether the accessor must return `Option<&_>`. This is set for struct
    /// fields that are `Option`s or omitted from some view, and for every
    /// enum field.
    returns_option: bool,
}
impl FieldInfo {
    /// Builds field metadata and records whether the type is `Option<T>`.
    /// `returns_option` starts out `false`; callers adjust it afterwards.
    fn new(ident: Ident, ty: Type) -> Self {
        Self {
            option_inner: option_inner_type(&ty),
            ident,
            ty,
            returns_option: false,
        }
    }

    /// Return type of the accessor: `Option<&Inner>` when the value may be
    /// absent, otherwise `&Ty`.
    fn accessor_return_ty(&self) -> TokenStream2 {
        let target = self.return_inner_ty();
        let wrap_in_option = self.option_inner.is_some() || self.returns_option;
        if !wrap_in_option {
            return quote! { &#target };
        }
        quote! { ::core::option::Option<&#target> }
    }

    /// The `T` of an `Option<T>` field, or the declared type otherwise.
    fn return_inner_ty(&self) -> &Type {
        match &self.option_inner {
            Some(inner) => inner,
            None => &self.ty,
        }
    }

    /// Type of this field inside a view: a borrow with the view's `'a`.
    fn view_field_ty(&self) -> TokenStream2 {
        let declared = &self.ty;
        quote! { &'a #declared }
    }

    /// Accessor body on the source type, where the field is owned.
    fn full_value_expr(&self, receiver: TokenStream2) -> TokenStream2 {
        let name = &self.ident;
        match (self.option_inner.is_some(), self.returns_option) {
            (true, _) => quote! { #receiver.#name.as_ref() },
            (false, true) => quote! { ::core::option::Option::Some(&#receiver.#name) },
            (false, false) => quote! { &#receiver.#name },
        }
    }

    /// Accessor body on a struct view, where the field is already `&'a Ty`.
    fn view_value_expr(&self, receiver: TokenStream2) -> TokenStream2 {
        let name = &self.ident;
        match (self.option_inner.is_some(), self.returns_option) {
            (true, _) => quote! { #receiver.#name.as_ref() },
            (false, true) => quote! { ::core::option::Option::Some(#receiver.#name) },
            (false, false) => quote! { #receiver.#name },
        }
    }

    /// Accessor value for a field bound by an enum match arm.
    fn enum_field_expr(&self, field: TokenStream2) -> TokenStream2 {
        if self.option_inner.is_none() {
            return quote! { ::core::option::Option::Some(#field) };
        }
        quote! { #field.as_ref() }
    }
}
fn expand_struct(
vis: &Visibility,
type_ident: &Ident,
data: DataStruct,
config: Config,
) -> syn::Result<TokenStream2> {
let Fields::Named(fields) = data.fields else {
return Err(syn::Error::new_spanned(
data.struct_token,
"RefView only supports structs with named fields",
));
};
let mut field_infos = fields
.named
.into_iter()
.map(|field| {
let ident = field.ident.expect("named field");
FieldInfo::new(ident, field.ty)
})
.collect::<Vec<_>>();
let omitted_fields = config
.views
.iter()
.flat_map(|view| view.omitted.iter())
.map(|target| match target {
OmitTarget::Field(field) => Ok(field.clone()),
OmitTarget::VariantField { .. } => Err(syn::Error::new(
Span::call_site(),
"struct omit targets must use `field`, not `Variant.field`",
)),
})
.collect::<syn::Result<BTreeSet<_>>>()?;
for field in &mut field_infos {
field.returns_option =
field.option_inner.is_some() || omitted_fields.contains(&field.ident.to_string());
}
validate_struct_omits(&field_infos, &omitted_fields)?;
let trait_ident = config.trait_ident(type_ident);
let view_derives = config.view_derives();
let trait_def = generate_trait(vis, &trait_ident, &field_infos);
let source_trait_impl =
generate_struct_source_trait_impl(&trait_ident, type_ident, &field_infos);
let views = config
.views
.iter()
.map(|view| {
generate_struct_view(
vis,
type_ident,
&trait_ident,
&field_infos,
view,
&view_derives,
)
})
.collect::<syn::Result<Vec<_>>>()?;
Ok(quote! {
#trait_def
#source_trait_impl
#(#views)*
})
}
fn validate_struct_omits(
fields: &[FieldInfo],
omitted_fields: &BTreeSet<String>,
) -> syn::Result<()> {
let known = fields
.iter()
.map(|field| field.ident.to_string())
.collect::<BTreeSet<_>>();
for omitted in omitted_fields {
if !known.contains(omitted) {
return Err(syn::Error::new(
Span::call_site(),
format!("unknown omitted field `{}`", omitted),
));
}
}
Ok(())
}
fn generate_trait(vis: &Visibility, trait_ident: &Ident, fields: &[FieldInfo]) -> TokenStream2 {
let methods = fields.iter().map(|field| {
let ident = &field.ident;
let return_ty = field.accessor_return_ty();
quote! {
fn #ident(&self) -> #return_ty;
}
});
quote! {
#vis trait #trait_ident {
#(#methods)*
}
}
}
fn generate_struct_source_trait_impl(
trait_ident: &Ident,
type_ident: &Ident,
fields: &[FieldInfo],
) -> TokenStream2 {
let methods = fields.iter().map(|field| {
let ident = &field.ident;
let return_ty = field.accessor_return_ty();
let value = field.full_value_expr(quote! { self });
quote! {
fn #ident(&self) -> #return_ty {
#value
}
}
});
quote! {
impl #trait_ident for #type_ident {
#(#methods)*
}
}
}
/// Emits one struct view and its impls.
///
/// The output is the borrowing struct, a `From<&Type>` impl, and an impl of
/// the accessor trait.
///
/// Included fields are stored as `&'a FieldTy`; accessors for fields omitted
/// from this view return `None`.
///
/// When the view keeps no field at all (empty source struct, or every field
/// omitted), a private `PhantomData<&'a Type>` marker field is added.
/// Without it, the `'a` parameter is unused and the generated struct fails
/// with E0392.
///
/// # Errors
///
/// Returns an error if the view's omit list contains a `Variant.field` target.
fn generate_struct_view(
    vis: &Visibility,
    type_ident: &Ident,
    trait_ident: &Ident,
    fields: &[FieldInfo],
    view: &ViewSpec,
    view_derives: &[syn::Path],
) -> syn::Result<TokenStream2> {
    let view_ident = &view.name;
    let omitted = view
        .omitted
        .iter()
        .map(|target| match target {
            OmitTarget::Field(field) => Ok(field.clone()),
            OmitTarget::VariantField { .. } => Err(syn::Error::new(
                Span::call_site(),
                "struct omit targets must use `field`, not `Variant.field`",
            )),
        })
        .collect::<syn::Result<BTreeSet<_>>>()?;
    let included_fields = fields
        .iter()
        .filter(|field| !omitted.contains(&field.ident.to_string()))
        .collect::<Vec<_>>();
    let mut struct_fields = included_fields
        .iter()
        .map(|field| {
            let ident = &field.ident;
            let ty = field.view_field_ty();
            quote! { #ident: #ty }
        })
        .collect::<Vec<_>>();
    let mut from_bindings = included_fields
        .iter()
        .map(|field| {
            let ident = &field.ident;
            quote! { #ident: &value.#ident }
        })
        .collect::<Vec<_>>();
    if included_fields.is_empty() {
        // Keep `'a` used so the view still compiles (avoids E0392).
        struct_fields.push(quote! {
            __ref_view_marker: ::core::marker::PhantomData<&'a #type_ident>
        });
        from_bindings.push(quote! {
            __ref_view_marker: ::core::marker::PhantomData
        });
    }
    let methods = fields.iter().map(|field| {
        let ident = &field.ident;
        let return_ty = field.accessor_return_ty();
        let body = if omitted.contains(&field.ident.to_string()) {
            quote! { ::core::option::Option::None }
        } else {
            field.view_value_expr(quote! { self })
        };
        quote! {
            fn #ident(&self) -> #return_ty {
                #body
            }
        }
    });
    Ok(quote! {
        #[derive(#(#view_derives),*)]
        #vis struct #view_ident<'a> {
            #(#struct_fields,)*
        }
        impl<'a> ::core::convert::From<&'a #type_ident> for #view_ident<'a> {
            fn from(value: &'a #type_ident) -> Self {
                Self {
                    #(#from_bindings,)*
                }
            }
        }
        impl<'a> #trait_ident for #view_ident<'a> {
            #(#methods)*
        }
    })
}
/// One enum variant together with its named fields.
#[derive(Debug, Clone)]
struct VariantInfo {
    /// Variant name.
    ident: Ident,
    /// The variant's named fields, in declaration order.
    fields: Vec<FieldInfo>,
}
/// Expands `RefView` for an enum whose variants all have named fields.
///
/// Every enum accessor returns `Option<&_>`, because a field may be missing
/// from the active variant.
fn expand_enum(
    vis: &Visibility,
    type_ident: &Ident,
    data: DataEnum,
    config: Config,
) -> syn::Result<TokenStream2> {
    let mut variants = Vec::with_capacity(data.variants.len());
    for variant in data.variants {
        let named = match variant.fields {
            Fields::Named(named) => named.named,
            _ => {
                return Err(syn::Error::new_spanned(
                    variant.ident,
                    "RefView only supports enum variants with named fields",
                ))
            }
        };
        let fields = named
            .into_iter()
            .map(|field| {
                let mut info = FieldInfo::new(field.ident.expect("named field"), field.ty);
                info.returns_option = true;
                info
            })
            .collect();
        variants.push(VariantInfo {
            ident: variant.ident,
            fields,
        });
    }
    let fields = collect_enum_fields(&variants)?;
    validate_enum_omits(&variants, &config.views)?;
    let trait_ident = config.trait_ident(type_ident);
    let view_derives = config.view_derives();
    let trait_def = generate_trait(vis, &trait_ident, &fields);
    let source_trait_impl =
        generate_enum_source_trait_impl(&trait_ident, type_ident, &variants, &fields);
    let mut views = Vec::with_capacity(config.views.len());
    for view in &config.views {
        views.push(generate_enum_view(
            vis,
            type_ident,
            &trait_ident,
            &variants,
            &fields,
            view,
            &view_derives,
        )?);
    }
    Ok(quote! {
        #trait_def
        #source_trait_impl
        #(#views)*
    })
}
/// Merges the fields of all variants into one accessor list.
///
/// The list is keyed by field name, keeps the first occurrence of each name,
/// and is ordered by name. Every later occurrence must have exactly the same
/// declared type (compared by token text).
///
/// The generated accessor uses a single return type and a single value
/// expression, both taken from the first occurrence, for every variant.
/// Accepting e.g. `u32` in one variant and `Option<u32>` in another would
/// therefore expand to code that does not type-check.
///
/// # Errors
///
/// Returns an error spanned at the first conflicting field.
fn collect_enum_fields(variants: &[VariantInfo]) -> syn::Result<Vec<FieldInfo>> {
    let mut fields = BTreeMap::<String, FieldInfo>::new();
    for variant in variants {
        for field in &variant.fields {
            let name = field.ident.to_string();
            if let Some(existing) = fields.get(&name) {
                // Compare full declared types, not `return_inner_ty()`:
                // `T` and `Option<T>` share an inner type but need different
                // accessor bodies (`Some(x)` vs `x.as_ref()`).
                if type_key(&existing.ty) != type_key(&field.ty) {
                    return Err(syn::Error::new_spanned(
                        &field.ident,
                        format!(
                            "field `{}` appears with different types across variants",
                            name
                        ),
                    ));
                }
            } else {
                fields.insert(name, field.clone());
            }
        }
    }
    Ok(fields.into_values().collect())
}
/// Ensures every omit target of every view names an existing field.
///
/// `Variant.field` must exist in that exact variant; a bare `field` must
/// exist in at least one variant.
fn validate_enum_omits(variants: &[VariantInfo], views: &[ViewSpec]) -> syn::Result<()> {
    let mut known: BTreeSet<(String, String)> = BTreeSet::new();
    for variant in variants {
        let variant_name = variant.ident.to_string();
        for field in &variant.fields {
            known.insert((variant_name.clone(), field.ident.to_string()));
        }
    }
    for target in views.iter().flat_map(|view| view.omitted.iter()) {
        let missing = match target {
            OmitTarget::VariantField { variant, field } => {
                (!known.contains(&(variant.clone(), field.clone())))
                    .then(|| format!("{}.{}", variant, field))
            }
            OmitTarget::Field(field) => {
                (!known.iter().any(|(_, known_field)| known_field == field))
                    .then(|| field.clone())
            }
        };
        if let Some(label) = missing {
            return Err(syn::Error::new(
                Span::call_site(),
                format!("unknown omitted field `{}`", label),
            ));
        }
    }
    Ok(())
}
fn generate_enum_source_trait_impl(
trait_ident: &Ident,
type_ident: &Ident,
variants: &[VariantInfo],
fields: &[FieldInfo],
) -> TokenStream2 {
let methods = fields.iter().map(|field| {
let field_ident = &field.ident;
let return_ty = field.accessor_return_ty();
let arms = variants.iter().map(|variant| {
let variant_ident = &variant.ident;
if variant
.fields
.iter()
.any(|field| field.ident == *field_ident)
{
let value = field.enum_field_expr(quote! { #field_ident });
quote! {
#type_ident::#variant_ident { #field_ident, .. } => #value
}
} else {
quote! {
#type_ident::#variant_ident { .. } => ::core::option::Option::None
}
}
});
quote! {
fn #field_ident(&self) -> #return_ty {
match self {
#(#arms,)*
}
}
}
});
quote! {
impl #trait_ident for #type_ident {
#(#methods)*
}
}
}
/// Emits one enum view and its impls.
///
/// The output is the borrowing enum (same variants, kept fields as
/// `&'a FieldTy`), a `From<&Type>` impl, and an impl of the accessor trait.
/// Accessors return `None` for variants lacking the field or omitting it in
/// this view.
fn generate_enum_view(
    vis: &Visibility,
    type_ident: &Ident,
    trait_ident: &Ident,
    variants: &[VariantInfo],
    fields: &[FieldInfo],
    view: &ViewSpec,
    view_derives: &[syn::Path],
) -> syn::Result<TokenStream2> {
    let view_ident = &view.name;
    let mut variant_defs = Vec::with_capacity(variants.len());
    let mut from_arms = Vec::with_capacity(variants.len());
    for variant in variants {
        let variant_ident = &variant.ident;
        let kept: Vec<&FieldInfo> = variant
            .fields
            .iter()
            .filter(|field| !is_enum_field_omitted(view, variant_ident, &field.ident))
            .collect();
        let declarations = kept.iter().map(|field| {
            let name = &field.ident;
            let ty = field.view_field_ty();
            quote! { #name: #ty }
        });
        variant_defs.push(quote! { #variant_ident { #(#declarations,)* } });
        let bindings = kept.iter().map(|field| &field.ident);
        let rebuilt = kept.iter().map(|field| &field.ident);
        from_arms.push(quote! {
            #type_ident::#variant_ident { #(#bindings,)* .. } => {
                Self::#variant_ident { #(#rebuilt,)* }
            }
        });
    }
    let methods = fields.iter().map(|field| {
        let accessor = &field.ident;
        let output = field.accessor_return_ty();
        let arms = variants.iter().map(|variant| {
            let variant_ident = &variant.ident;
            let present = variant
                .fields
                .iter()
                .any(|candidate| candidate.ident == *accessor)
                && !is_enum_field_omitted(view, variant_ident, accessor);
            if present {
                let value = field.enum_field_expr(quote! { #accessor });
                quote! {
                    #view_ident::#variant_ident { #accessor, .. } => #value
                }
            } else {
                quote! {
                    #view_ident::#variant_ident { .. } => ::core::option::Option::None
                }
            }
        });
        quote! {
            fn #accessor(&self) -> #output {
                match self {
                    #(#arms,)*
                }
            }
        }
    });
    Ok(quote! {
        #[derive(#(#view_derives),*)]
        #vis enum #view_ident<'a> {
            #(#variant_defs,)*
        }
        impl<'a> ::core::convert::From<&'a #type_ident> for #view_ident<'a> {
            fn from(value: &'a #type_ident) -> Self {
                match value {
                    #(#from_arms,)*
                }
            }
        }
        impl<'a> #trait_ident for #view_ident<'a> {
            #(#methods)*
        }
    })
}
/// Parsed arguments of `#[impl_trait(ViewTrait => TypeA, TypeB, ...)]`.
struct ImplTraitArgs {
    /// Trait path before `=>`; parsed, but not used when generating impls.
    view_trait: syn::Path,
    /// Types that receive a copy of the annotated `impl` block.
    target_types: Punctuated<Type, Token![,]>,
}
impl Parse for ImplTraitArgs {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let view_trait = input.parse()?;
input.parse::<Token![=>]>()?;
Ok(Self {
view_trait,
target_types: Punctuated::parse_terminated(input)?,
})
}
}
fn expand_impl_trait(args: ImplTraitArgs, item: ItemImpl) -> syn::Result<TokenStream2> {
if args.target_types.is_empty() {
return Err(syn::Error::new(
Span::call_site(),
"impl_trait requires at least one target type after `=>`",
));
}
let _view_trait = &args.view_trait;
let impls = args.target_types.iter().map(|target_ty| {
let mut item = item.clone();
item.self_ty = Box::new(target_ty.clone());
item
});
Ok(quote! {
#item
#(#impls)*
})
}
/// Expands `#[impl_fn(Trait)]` by making a free function generic over its
/// first argument.
///
/// The function's first argument `&T` becomes `&__RefView`, and a new generic
/// parameter `__RefView: Trait + ?Sized` is appended. The function then
/// accepts a reference to any type implementing `trait_path`.
///
/// # Errors
///
/// Returns an error when the function:
/// - already declares a type parameter named `__RefView`;
/// - has no arguments, or its first argument is a receiver (`self`);
/// - has a first argument that is not a reference, or is a `&mut` reference.
fn expand_impl_fn(trait_path: syn::Path, mut item: ItemFn) -> syn::Result<TokenStream2> {
    // `__RefView` is the name injected below; refuse to clash with a user parameter.
    if item.sig.generics.params.iter().any(|param| match param {
        GenericParam::Type(param) => param.ident == "__RefView",
        _ => false,
    }) {
        return Err(syn::Error::new_spanned(
            item.sig.generics,
            "impl_fn reserves the generic parameter name `__RefView`",
        ));
    }
    let Some(first_arg) = item.sig.inputs.first_mut() else {
        return Err(syn::Error::new_spanned(
            item.sig.ident,
            "impl_fn requires a first argument",
        ));
    };
    // Receivers (`self`, `&self`, ...) are rejected; only `pat: Type` arguments qualify.
    let syn::FnArg::Typed(first_arg) = first_arg else {
        return Err(syn::Error::new_spanned(
            first_arg,
            "impl_fn requires the first argument to be a typed argument",
        ));
    };
    let Type::Reference(reference) = first_arg.ty.as_mut() else {
        return Err(syn::Error::new_spanned(
            &first_arg.ty,
            "impl_fn requires the first argument to be a shared reference",
        ));
    };
    if reference.mutability.is_some() {
        return Err(syn::Error::new_spanned(
            &first_arg.ty,
            "impl_fn does not support mutable first arguments",
        ));
    }
    // Only the referenced type is replaced; the `&` and any explicit lifetime are kept.
    reference.elem = Box::new(syn::parse_quote! { __RefView });
    // `?Sized` lets the parameter also be instantiated with unsized implementors.
    item.sig
        .generics
        .params
        .push(syn::parse_quote! { __RefView: #trait_path + ?Sized });
    Ok(quote! {
        #item
    })
}
/// Whether `view` omits `field` in `variant`, either through a bare
/// `omit(field)` or through `omit(Variant.field)`.
fn is_enum_field_omitted(view: &ViewSpec, variant: &Ident, field: &Ident) -> bool {
    let field_name = field.to_string();
    let everywhere = OmitTarget::Field(field_name.clone());
    if view.omitted.contains(&everywhere) {
        return true;
    }
    let in_this_variant = OmitTarget::VariantField {
        variant: variant.to_string(),
        field: field_name,
    };
    view.omitted.contains(&in_this_variant)
}
/// Returns `T` when `ty` is a path type whose last segment is `Option<T>`.
///
/// The check is purely syntactic: any path ending in a segment named
/// `Option` with a type as its first angle-bracketed argument matches.
fn option_inner_type(ty: &Type) -> Option<Type> {
    match ty {
        Type::Path(type_path) => {
            let last = type_path.path.segments.last()?;
            if last.ident != "Option" {
                return None;
            }
            match &last.arguments {
                PathArguments::AngleBracketed(bracketed) => match bracketed.args.first() {
                    Some(GenericArgument::Type(inner)) => Some(inner.clone()),
                    _ => None,
                },
                _ => None,
            }
        }
        _ => None,
    }
}
fn type_key(ty: &Type) -> String {
ty.to_token_stream().to_string()
}
fn type_key_path(path: &syn::Path) -> String {
path.to_token_stream().to_string()
}