use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{
FnArg, GenericArgument, ItemTrait, Pat, PathArguments, ReturnType, TraitItem, Type,
TypeParamBound, WherePredicate, parse_macro_input,
};
/// Heuristically decides whether `ty` names a `Result`-like type.
///
/// Only the final path segment is inspected: any identifier ending in `Result`
/// (`Result`, `io::Result`, `ParseResult`, ...) qualifies. Non-path types never do.
fn looks_like_result(ty: &Type) -> bool {
    let Type::Path(type_path) = ty else {
        return false;
    };
    match type_path.path.segments.last() {
        // `ends_with` also covers the exact name `Result`.
        Some(last_segment) => last_segment.ident.to_string().ends_with("Result"),
        None => false,
    }
}
/// Reports whether the first generic argument of `ty`'s last path segment is `()`.
///
/// For `Result<(), E>` this is `true`. Non-path types, segments without angle-bracketed
/// arguments, and a first argument that is not a type (e.g. a lifetime) all yield `false`.
fn result_ok_is_unit(ty: &Type) -> bool {
    let ok_ty = match ty {
        Type::Path(type_path) => match type_path.path.segments.last().map(|s| &s.arguments) {
            Some(PathArguments::AngleBracketed(generic_args)) => match generic_args.args.first() {
                Some(GenericArgument::Type(inner)) => inner,
                _ => return false,
            },
            _ => return false,
        },
        _ => return false,
    };
    matches!(ok_ty, Type::Tuple(tuple) if tuple.elems.is_empty())
}
/// Reports whether the first generic argument of `ty`'s last path segment is exactly `Self`.
///
/// For `Result<Self, E>` this is `true`. Non-path types, segments without angle-bracketed
/// arguments, and a first argument that is not a type all yield `false`.
fn result_ok_is_self(ty: &Type) -> bool {
    let Type::Path(type_path) = ty else {
        return false;
    };
    let first_arg_ty = type_path
        .path
        .segments
        .last()
        .and_then(|segment| match &segment.arguments {
            PathArguments::AngleBracketed(generic_args) => match generic_args.args.first() {
                Some(GenericArgument::Type(inner)) => Some(inner),
                _ => None,
            },
            _ => None,
        });
    matches!(first_arg_ty, Some(Type::Path(p)) if p.path.is_ident("Self"))
}
fn is_chainable_return(ret: &ReturnType) -> bool {
match ret {
ReturnType::Default => true,
ReturnType::Type(_, ty) => {
if matches!(ty.as_ref(), Type::Path(p) if p.path.is_ident("Self")) {
return true;
}
if matches!(ty.as_ref(), Type::Tuple(t) if t.elems.is_empty()) {
return true;
}
if looks_like_result(ty) {
return result_ok_is_unit(ty) || result_ok_is_self(ty);
}
false
}
}
}
fn is_consuming_receiver(arg: &FnArg) -> bool {
matches!(arg, FnArg::Receiver(r) if r.reference.is_none())
}
fn collect_self_assoc_in_type(ty: &Type, found: &mut Vec<syn::Ident>) {
match ty {
Type::Path(tp) if tp.qself.is_none() => {
let segs: Vec<_> = tp.path.segments.iter().collect();
if segs.len() == 2 && segs[0].ident == "Self" {
let name = segs[1].ident.clone();
if !found.iter().any(|i: &syn::Ident| *i == name) {
found.push(name);
}
}
for seg in &tp.path.segments {
if let PathArguments::AngleBracketed(args) = &seg.arguments {
for ga in &args.args {
if let GenericArgument::Type(inner) = ga {
collect_self_assoc_in_type(inner, found);
}
}
}
}
}
Type::Reference(r) => collect_self_assoc_in_type(&r.elem, found),
Type::Slice(s) => collect_self_assoc_in_type(&s.elem, found),
Type::Array(a) => collect_self_assoc_in_type(&a.elem, found),
Type::Tuple(t) => t
.elems
.iter()
.for_each(|e| collect_self_assoc_in_type(e, found)),
_ => {}
}
}
#[proc_macro_attribute]
pub fn explainable(_args: TokenStream, input: TokenStream) -> TokenStream {
let trait_def = parse_macro_input!(input as ItemTrait);
let trait_name = &trait_def.ident;
let explain_text_trait_name = format_ident!("{}ExplainText", trait_name);
let ext_trait_name = format_ident!("{}Ext", trait_name);
let vis = &trait_def.vis;
let self_methods: Vec<_> = trait_def
.items
.iter()
.filter_map(|item| {
if let TraitItem::Fn(f) = item {
let has_receiver = f
.sig
.inputs
.first()
.map(|a| matches!(a, FnArg::Receiver(_)))
.unwrap_or(false);
let chainable = is_chainable_return(&f.sig.output);
if has_receiver && chainable {
Some(f)
} else {
None
}
} else {
None
}
})
.collect();
let mut assoc_idents: Vec<syn::Ident> = Vec::new();
for m in &self_methods {
for param in m.sig.inputs.iter() {
if let FnArg::Typed(pt) = param {
collect_self_assoc_in_type(&pt.ty, &mut assoc_idents);
}
}
}
let where_bounds: Vec<Vec<&TypeParamBound>> = assoc_idents
.iter()
.map(|name| {
let mut bounds: Vec<&TypeParamBound> = Vec::new();
if let Some(wc) = &trait_def.generics.where_clause {
for pred in &wc.predicates {
if let WherePredicate::Type(pt) = pred {
if let Type::Path(tp) = &pt.bounded_ty {
let segs: Vec<_> = tp.path.segments.iter().collect();
if segs.len() == 2 && segs[0].ident == "Self" && &segs[1].ident == name
{
bounds.extend(pt.bounds.iter());
}
}
}
}
}
bounds
})
.collect();
let ext_assoc_type_decls: Vec<TokenStream2> = assoc_idents
.iter()
.zip(where_bounds.iter())
.map(|(name, bounds)| {
let doc = format!("Associated type `{}` forwarded from the domain type.", name);
if bounds.is_empty() {
quote! {
#[doc = #doc]
type #name;
}
} else {
quote! {
#[doc = #doc]
type #name: #(#bounds)+*;
}
}
})
.collect();
let ext_assoc_type_impls: Vec<TokenStream2> = assoc_idents
.iter()
.map(|name| quote! { type #name = T::#name; })
.collect();
let explain_text_methods: Vec<TokenStream2> = self_methods
.iter()
.map(|m| {
let method_name = &m.sig.ident;
let explain_fn = format_ident!("explain_text_{}", method_name);
let cfg_attrs: Vec<_> = m
.attrs
.iter()
.filter(|a| a.path().is_ident("cfg"))
.collect();
quote! {
#(#cfg_attrs)*
fn #explain_fn(before: &Self, after: &Self) -> String;
}
})
.collect();
let ext_method_sigs: Vec<TokenStream2> = self_methods
.iter()
.map(|m| {
let method_name = &m.sig.ident;
let cfg_attrs: Vec<_> = m
.attrs
.iter()
.filter(|a| a.path().is_ident("cfg"))
.collect();
let non_recv_params: Vec<_> = m
.sig
.inputs
.iter()
.filter(|a| !matches!(a, FnArg::Receiver(_)))
.collect();
quote! {
#(#cfg_attrs)*
fn #method_name(&mut self, #(#non_recv_params),*) -> &mut Self;
}
})
.collect();
let ext_method_impls: Vec<TokenStream2> = self_methods
.iter()
.map(|m| {
let method_name = &m.sig.ident;
let explain_fn = format_ident!("explain_text_{}", method_name);
let cfg_attrs: Vec<_> = m
.attrs
.iter()
.filter(|a| a.path().is_ident("cfg"))
.collect();
let non_recv_params: Vec<_> = m
.sig
.inputs
.iter()
.filter(|a| !matches!(a, FnArg::Receiver(_)))
.collect();
let arg_idents: Vec<_> = non_recv_params
.iter()
.filter_map(|a| {
if let FnArg::Typed(pt) = a {
if let Pat::Ident(pi) = pt.pat.as_ref() {
Some(&pi.ident)
} else {
None
}
} else {
None
}
})
.collect();
let consuming = m
.sig
.inputs
.first()
.map(|a| is_consuming_receiver(a))
.unwrap_or(false);
let (is_result, is_void) = match &m.sig.output {
ReturnType::Type(_, ty) => {
let r = looks_like_result(ty);
(r, r && result_ok_is_unit(ty))
}
ReturnType::Default => (false, true),
};
let update_inner = if is_void {
if is_result {
quote! { self.inner.#method_name(#(#arg_idents),*).unwrap(); }
} else {
quote! { self.inner.#method_name(#(#arg_idents),*); }
}
} else if consuming {
if is_result {
quote! {
let __taken = ::std::mem::replace(&mut self.inner, before.clone());
self.inner = __taken.#method_name(#(#arg_idents),*).unwrap();
}
} else {
quote! {
let __taken = ::std::mem::replace(&mut self.inner, before.clone());
self.inner = __taken.#method_name(#(#arg_idents),*);
}
}
} else if is_result {
quote! { self.inner = self.inner.#method_name(#(#arg_idents),*).unwrap(); }
} else {
quote! { self.inner = self.inner.#method_name(#(#arg_idents),*); }
};
quote! {
#(#cfg_attrs)*
fn #method_name(&mut self, #(#non_recv_params),*) -> &mut Self {
let before = self.inner.clone();
#update_inner
let text = match self.mode {
::explainable::ExplainMode::Text
| ::explainable::ExplainMode::Both => Some(
<T as #explain_text_trait_name>::#explain_fn(
&before,
&self.inner,
),
),
_ => None,
};
let visual = match self.mode {
::explainable::ExplainMode::Visual
| ::explainable::ExplainMode::Both => Some(
<T as ::explainable::RenderVisual>::render_visual(
&before,
&self.inner,
),
),
_ => None,
};
self.explanations.push(::explainable::Explanation::new(
self.mode,
text,
visual,
));
self
}
}
})
.collect();
let explain_text_doc = format!(
"Companion text trait generated by `#[explainable]` for [`{}`].\n\n\
Implement one `explain_text_<method>` per operation to supply the pedagogical \
text explanation shown when that operation runs inside an explaining chain.",
trait_name
);
let ext_trait_doc = format!(
"Extension trait generated by `#[explainable]` for [`{}`].\n\n\
Bring this into scope to call `{}` operations on an \
[`explainable::Explaining`] chain.",
trait_name, trait_name
);
let output = quote! {
#trait_def
#[doc = #explain_text_doc]
#[allow(missing_docs)]
#vis trait #explain_text_trait_name:
::explainable::Explainable + #trait_name
{
#(#explain_text_methods)*
}
#[doc = #ext_trait_doc]
#[allow(missing_docs)]
#vis trait #ext_trait_name {
#(#ext_assoc_type_decls)*
#(#ext_method_sigs)*
}
#[allow(missing_docs)]
impl<T> #ext_trait_name for ::explainable::Explaining<T>
where
T: ::explainable::Explainable + #trait_name + #explain_text_trait_name,
{
#(#ext_assoc_type_impls)*
#(#ext_method_impls)*
}
};
output.into()
}