use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse::Parser,
punctuated::Punctuated,
spanned::Spanned,
Error,
GenericArgument,
Ident,
ImplItem,
ItemImpl,
PathArguments,
ReturnType,
Token,
Type,
};
/// Shape of the single method in the operator trait impl being extended.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OpKind {
    /// Consumes `self`, takes a right-hand operand and returns a value.
    Binary,
    /// Takes `&mut self` and a right-hand operand and returns nothing.
    BinaryAssign,
    /// Consumes `self`, takes no operand and returns a value.
    Unary,
}
/// An extra operator impl that can be derived from the base impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Variant {
    /// `owned`: implement for the owned `Self` type with an owned right-hand side.
    Owned,
    /// `borrowed`: implement for `&Self` (with a borrowed right-hand side for binary ops).
    Borrowed,
    /// `flipped`: implement for `&Self` with an owned right-hand side.
    Flipped,
    /// `flipped_commutative`: like `flipped`, but evaluated with the operands swapped.
    FlippedCommutative,
}

impl Variant {
    /// Parses one attribute argument into a [`Variant`].
    ///
    /// Only names that make sense for `op_kind` are accepted:
    /// `owned` (binary / binary-assign), `borrowed` (binary / unary), and
    /// `flipped` / `flipped_commutative` (binary only).
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if `s` is not a known variant name, or if it names a
    /// variant that is not supported for `op_kind`.
    fn from_str(s: &str, op_kind: OpKind) -> Result<Self, String> {
        let invalid = || format!("Invalid variant '{s}' for {op_kind:?}");

        let candidate = match s {
            "owned" => Self::Owned,
            "borrowed" => Self::Borrowed,
            "flipped" => Self::Flipped,
            "flipped_commutative" => Self::FlippedCommutative,
            _ => return Err(invalid()),
        };

        let supported = match candidate {
            Self::Owned => matches!(op_kind, OpKind::Binary | OpKind::BinaryAssign),
            Self::Borrowed => matches!(op_kind, OpKind::Binary | OpKind::Unary),
            Self::Flipped | Self::FlippedCommutative => op_kind == OpKind::Binary,
        };

        if supported {
            Ok(candidate)
        } else {
            Err(invalid())
        }
    }
}
/// Public entry point: expands `item` (an operator trait impl) using the variant names in `attr`.
///
/// On failure the `syn::Error` is turned into a `compile_error!` token stream, so problems
/// surface as ordinary compiler diagnostics instead of a panic. Presumably registered as an
/// attribute macro in the crate root — TODO confirm.
pub fn op_variants(attr: TokenStream, item: TokenStream) -> TokenStream {
    match _op_variants(attr, item) {
        Ok(expanded) => expanded,
        Err(err) => err.to_compile_error().into(),
    }
}
fn _op_variants(attr: TokenStream, item: TokenStream) -> Result<TokenStream, Error> {
let impl_block: ItemImpl = syn::parse(item)?;
let attrs = Punctuated::<Ident, Token![,]>::parse_terminated.parse(attr)?;
let (_, trait_path, _) = (impl_block.trait_.as_ref())
.ok_or_else(|| Error::new_spanned(&impl_block, "Expected trait impl"))?;
let trait_name = trait_path.segments.last().unwrap().ident.clone();
let rhs_ty = trait_path.segments.last().and_then(|seg| {
let PathArguments::AngleBracketed(args) = &seg.arguments else {
return None;
};
args.args.first().and_then(|arg| match arg {
GenericArgument::Type(ty) => Some(ty.clone()),
_ => None,
})
});
let methods: Vec<_> = (impl_block.items.iter())
.filter_map(|item| match item {
ImplItem::Fn(m) => Some(m),
_ => None,
})
.collect();
if methods.is_empty() {
return Err(Error::new_spanned(&impl_block, "No method found"));
}
if methods.len() > 1 {
return Err(Error::new_spanned(&impl_block, "Multiple methods found"));
}
let op_fn = methods.first().unwrap();
let op = op_fn.sig.ident.clone();
let has_return = !matches!(op_fn.sig.output, ReturnType::Default);
let op_kind = match (rhs_ty.is_some(), has_return) {
(true, true) => OpKind::Binary,
(true, false) => OpKind::BinaryAssign,
(false, true) => OpKind::Unary,
_ => return Err(Error::new_spanned(op_fn, "Invalid trait signature")),
};
let out_ty = if has_return {
(impl_block.items.iter())
.find_map(|item| match (item, &op_fn.sig.output) {
(ImplItem::Type(t), _) if t.ident == "Output" => Some(t.ty.clone()),
(_, ReturnType::Type(_, ty)) => Some((**ty).clone()),
_ => None,
})
.ok_or_else(|| Error::new_spanned(op_fn, "Cannot determine output type"))?
} else {
syn::parse_quote!(())
};
let variants: Result<Vec<_>, _> = (attrs.iter())
.map(|ident| Variant::from_str(&ident.to_string(), op_kind))
.collect();
let variants = variants.map_err(|e| Error::new(attrs.span(), e))?;
let ty = (*impl_block.self_ty).clone();
if matches!(ty, Type::Reference(_)) {
return Err(Error::new_spanned(
&ty,
"impl must be for an owned type, not a reference",
));
}
let gen = impl_block.generics.clone();
let wc = gen.where_clause.clone();
let mut output = quote! { #impl_block };
let ref_ty: Type = syn::parse_quote!(&#ty);
for variant in variants {
let (rhs_own, rhs_ref) = owned_and_ref(rhs_ty.as_ref().unwrap_or(&ty));
let (impl_ty, rhs_ty, op_body) = match (variant, op_kind) {
(Variant::Owned, _) => (&ty, Some(rhs_own), quote! { self.#op(&rhs) }),
(Variant::Borrowed, OpKind::Binary) => {
(&ref_ty, Some(rhs_ref), quote! { self.to_owned().#op(rhs) })
}
(Variant::Borrowed, OpKind::Unary) => (&ref_ty, None, quote! { self.to_owned().#op() }),
(Variant::Flipped, OpKind::Binary) => {
(&ref_ty, Some(rhs_own), quote! { self.to_owned().#op(&rhs) })
}
(Variant::FlippedCommutative, OpKind::Binary) => {
(&ref_ty, Some(rhs_own), quote! { rhs.#op(self) })
}
_ => unreachable!(),
};
let needs_clone = matches!(variant, Variant::Borrowed | Variant::Flipped);
let wc = if needs_clone {
let mut wc = wc.clone().unwrap_or_else(|| syn::parse_quote!(where));
wc.predicates.push(syn::parse_quote!(#ty: Clone));
Some(wc)
} else {
wc.clone()
};
let impl_block = match (rhs_ty, op_kind) {
(Some(rhs), OpKind::Binary) => quote! {
impl #gen #trait_name<#rhs> for #impl_ty #wc {
type Output = #out_ty;
#[inline]
fn #op(self, rhs: #rhs) -> Self::Output { #op_body }
}
},
(Some(rhs), OpKind::BinaryAssign) => quote! {
impl #gen #trait_name<#rhs> for #impl_ty #wc {
#[inline]
fn #op(&mut self, rhs: #rhs) { #op_body }
}
},
(None, OpKind::Unary) => quote! {
impl #gen #trait_name for #impl_ty #wc {
type Output = #out_ty;
#[inline]
fn #op(self) -> Self::Output { #op_body }
}
},
_ => unreachable!(),
};
output.extend(impl_block);
}
Ok(output.into())
}
/// Splits a right-hand-side type into its owned form and its reference form.
///
/// - A reference type `&U` (or `&mut U`) yields `(U, <the reference type as written>)`.
/// - Any other type `U` yields `(U, &U)`.
///
/// Invisible-delimiter groups (`Type::Group`) and redundant parentheses (`Type::Paren`) are
/// looked through when deciding whether `ty` is a reference. Invisible groups can appear around
/// types substituted from `macro_rules!` fragments such as `$t:ty`, and without this step a
/// reference passed that way would be mistaken for an owned type and turned into `&&U`.
fn owned_and_ref(ty: &Type) -> (Type, Type) {
    let mut bare = ty;
    loop {
        match bare {
            Type::Group(group) => bare = &*group.elem,
            Type::Paren(paren) => bare = &*paren.elem,
            _ => break,
        }
    }

    match bare {
        Type::Reference(reference) => ((*reference.elem).clone(), ty.clone()),
        _ => (ty.clone(), syn::parse_quote!(&#ty)),
    }
}