use quote::{quote, quote_spanned};
use syn::{
AngleBracketedGenericArguments, Block, FnArg, GenericArgument, Ident, ItemTrait, Pat, PatWild,
PathArguments, ReturnType, Token, TraitItem, TraitItemFn, Type, TypeParamBound, parse2,
punctuated::Punctuated, spanned::Spanned, token::Comma,
};
pub fn double_trait(org_trait: ItemTrait) -> syn::Result<ItemTrait> {
let items = org_trait
.items
.into_iter()
.map(|item| transform_trait_item(item, org_trait.ident.clone()))
.collect::<syn::Result<_>>()?;
Ok(ItemTrait { items, ..org_trait })
}
/// Completes a single trait item.
///
/// Methods are handed to `transform_function`; every other kind of trait item (constants,
/// associated types, macros, verbatim tokens, ...) is returned exactly as it came in.
fn transform_trait_item(
    trait_item: TraitItem,
    double_trait_name: Ident,
) -> syn::Result<TraitItem> {
    match trait_item {
        TraitItem::Fn(fn_item) => {
            transform_function(fn_item, double_trait_name).map(TraitItem::Fn)
        }
        unchanged => Ok(unchanged),
    }
}
/// Attaches a generated default body to a trait method that has none.
///
/// Methods that already have a default body are returned untouched. Otherwise the method's
/// parameter patterns are replaced by `_` and a body chosen from its return type is attached;
/// `double_trait_name` and the method's own name end up in that body's `unimplemented!`
/// message where one is generated.
fn transform_function(
    mut fn_item: TraitItemFn,
    double_trait_name: Ident,
) -> syn::Result<TraitItemFn> {
    // A body the user already wrote always wins over a generated one.
    if fn_item.default.is_none() {
        strip_parameter_names(&mut fn_item.sig.inputs);
        let fn_name = fn_item.sig.ident.clone();
        let generated_body = return_type_info(&fn_item.sig.output).default_impl(
            &fn_item,
            double_trait_name,
            fn_name,
        );
        fn_item.default = Some(generated_body);
    }
    Ok(fn_item)
}
fn strip_parameter_names(input: &mut Punctuated<FnArg, Comma>) {
for arg in input {
if let FnArg::Typed(pat_type) = arg {
*pat_type.pat = Pat::Wild(PatWild {
attrs: Vec::new(),
underscore_token: Token),
})
}
}
}
/// Classifies a method's return type; an omitted return type (`fn f()`) counts as `Empty`.
fn return_type_info(output: &ReturnType) -> ReturnTypeInfo {
    match output {
        ReturnType::Default => ReturnTypeInfo::Empty,
        ReturnType::Type(_, returned) => type_info(returned),
    }
}
/// Classifies `ty` so that a matching default body can be generated for it.
///
/// * `()` → `ReturnTypeInfo::Empty`.
/// * `impl Future<Output = T>` / `impl Iterator<Item = T>` → `ImplFuture` / `ImplIterator`,
///   with `T` classified recursively (`None` when the associated type is not written out).
///   The trait is recognised by the *last* segment of the first trait bound's path, so
///   qualified spellings such as `impl std::future::Future<Output = T>` work as well.
/// * Any other `impl Trait`, including one without any trait bound → `UnknownImpl`, which is
///   later reported via `compile_error!` rather than by panicking inside the macro.
/// * Invisible (`None`-delimited) groups and parentheses are looked through.
/// * Everything else → `ReturnTypeInfo::Other`.
fn type_info(ty: &Type) -> ReturnTypeInfo {
    match ty {
        Type::ImplTrait(impl_trait) => {
            // Lifetime bounds, precise captures (`use<..>`) and verbatim bounds carry no
            // trait path, so only the first real trait bound is considered.
            let first_trait_bound = impl_trait.bounds.iter().find_map(|bound| match bound {
                TypeParamBound::Trait(trait_bound) => Some(trait_bound),
                _ => None,
            });
            // The trait's name and its generic / associated-type arguments are attached to
            // the last path segment (`std::future::Future<Output = T>`).
            let Some(trait_segment) =
                first_trait_bound.and_then(|bound| bound.path.segments.last())
            else {
                return ReturnTypeInfo::UnknownImpl;
            };
            match trait_segment.ident.to_string().as_str() {
                "Future" => ReturnTypeInfo::ImplFuture {
                    output: assoctiated_type(&trait_segment.arguments, "Output")
                        .map(|inner| Box::new(type_info(inner))),
                },
                "Iterator" => ReturnTypeInfo::ImplIterator {
                    item: assoctiated_type(&trait_segment.arguments, "Item")
                        .map(|inner| Box::new(type_info(inner))),
                },
                _ => ReturnTypeInfo::UnknownImpl,
            }
        }
        Type::Tuple(tuple_type) if tuple_type.elems.is_empty() => ReturnTypeInfo::Empty,
        // Types forwarded through `macro_rules!` fragments arrive wrapped in an invisible
        // group; redundant parentheses are equally transparent.
        Type::Group(group) => type_info(&group.elem),
        Type::Paren(paren) => type_info(&paren.elem),
        _ => ReturnTypeInfo::Other,
    }
}
/// Finds the type bound to the associated type called `associated` in a trait's path
/// arguments, e.g. `i32` for `"Output"` in `Future<Output = i32>`.
///
/// Returns `None` when the arguments are not angle-bracketed (absent, or `Fn(..)`-style
/// parenthesized) or when no binding with that name exists. The first matching binding wins.
fn assoctiated_type<'a>(trait_args: &'a PathArguments, associated: &str) -> Option<&'a Type> {
    match trait_args {
        PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => {
            args.iter().find_map(|arg| match arg {
                GenericArgument::AssocType(binding) if binding.ident == associated => {
                    Some(&binding.ty)
                }
                _ => None,
            })
        }
        _ => None,
    }
}
/// What a trait method's return type looks like, as far as choosing a generated default body
/// is concerned.
#[derive(Debug)]
enum ReturnTypeInfo {
    /// No return type at all, or `()`.
    Empty,
    /// `impl Future<..>`; `output` classifies the `Output` binding and is `None` when the
    /// binding is not written out.
    ImplFuture { output: Option<Box<Self>> },
    /// `impl Iterator<..>`; `item` classifies the `Item` binding and is `None` when the
    /// binding is not written out.
    ImplIterator { item: Option<Box<Self>> },
    /// Any other `impl Trait`.
    UnknownImpl,
    /// Any other type.
    Other,
}
impl ReturnTypeInfo {
    /// Builds the block used as the default implementation of a trait method.
    ///
    /// * `fn_item` — the method being completed; only the span of its return type is used,
    ///   to anchor the `compile_error!` emitted for unsupported `impl Trait` returns.
    /// * `double_trait_name`, `fn_name` — stringified into the `unimplemented!` message.
    ///
    /// `impl Future` wraps the body for its output in an `async` block. `impl Iterator` yields
    /// an iterator that never produces an item. A missing `Output`/`Item` binding is handled
    /// like `ReturnTypeInfo::Other`.
    fn default_impl(
        &self,
        fn_item: &TraitItemFn,
        double_trait_name: Ident,
        fn_name: Ident,
    ) -> Block {
        let tokens = match self {
            ReturnTypeInfo::Empty => quote! { { } },
            ReturnTypeInfo::Other => quote! {{
                let double_trait_name = stringify!(#double_trait_name);
                let fn_name = stringify!(#fn_name);
                unimplemented!("{double_trait_name}::{fn_name}")
            }},
            ReturnTypeInfo::UnknownImpl => quote_spanned! {
                fn_item.sig.output.span() => {
                    compile_error!(
                        "impl Trait is currently not supported by double-trait. Apart from the \
                        special case of impl Future."
                    )
                }
            },
            ReturnTypeInfo::ImplFuture { output } => {
                let future_body = output
                    .as_deref()
                    .unwrap_or(&ReturnTypeInfo::Other)
                    .default_impl(fn_item, double_trait_name, fn_name);
                quote! {{ async #future_body }}
            }
            ReturnTypeInfo::ImplIterator { item } => {
                let element = item
                    .as_deref()
                    .unwrap_or(&ReturnTypeInfo::Other)
                    .default_impl(fn_item, double_trait_name, fn_name);
                // `if false` keeps the element expression type-checked (so the item type is
                // inferred) while the iterator itself never yields anything.
                quote! {{
                    #[allow(unreachable_code)]
                    std::iter::from_fn(move || {
                        if false { Some(#element) } else { None }
                    })
                }}
            }
        };
        parse2(tokens).unwrap()
    }
}
#[cfg(test)]
mod tests {
    use super::{ReturnTypeInfo, double_trait, return_type_info};
    use quote::quote;
    use syn::{ItemTrait, ReturnType, parse2};

    #[test]
    fn return_type_info_unit() {
        let rt: ReturnType = parse2(quote! {-> () }).unwrap();
        assert!(matches!(return_type_info(&rt), ReturnTypeInfo::Empty));
    }

    #[test]
    fn return_type_info_i32() {
        let rt: ReturnType = parse2(quote! {-> i32 }).unwrap();
        assert!(matches!(return_type_info(&rt), ReturnTypeInfo::Other));
    }

    #[test]
    fn return_type_info_impl_future_i32() {
        let rt: ReturnType = parse2(quote! {-> impl Future<Output = i32> }).unwrap();
        let ReturnTypeInfo::ImplFuture {
            output: Some(output),
        } = return_type_info(&rt)
        else {
            panic!("Expected ReturnTypeInfo::ImplFuture with Some output");
        };
        assert!(matches!(*output, ReturnTypeInfo::Other));
    }

    #[test]
    fn return_type_info_impl_future_unit() {
        let rt: ReturnType = parse2(quote! {-> impl Future<Output = ()> }).unwrap();
        let ReturnTypeInfo::ImplFuture {
            output: Some(output),
        } = return_type_info(&rt)
        else {
            panic!("Expected ReturnTypeInfo::ImplFuture with Some output");
        };
        assert!(matches!(*output, ReturnTypeInfo::Empty));
    }

    #[test]
    fn return_type_info_impl_future_impl_iterator_i32() {
        let rt: ReturnType =
            parse2(quote! {-> impl Future<Output = impl Iterator<Item=i32>> }).unwrap();
        let ReturnTypeInfo::ImplFuture {
            output: Some(output),
        } = return_type_info(&rt)
        else {
            panic!("Expected ReturnTypeInfo::ImplFuture with Some output");
        };
        assert!(matches!(
            *output,
            ReturnTypeInfo::ImplIterator { item: Some(_) }
        ));
    }

    #[test]
    fn default_impl_for_method_with_impl_future_output_unit() {
        let org_trait = quote! {
            trait MyTrait {
                fn method(&self) -> impl Future<Output = ()>;
            }
        };
        let expected = quote! {
            trait MyTrait {
                fn method(&self) -> impl Future<Output = ()> {
                    async { }
                }
            }
        };
        assert_doubled(org_trait, expected);
    }

    #[test]
    fn default_impl_for_method_with_impl_future_output_i32() {
        let org_trait = quote! {
            trait MyTrait {
                fn method(&self) -> impl Future<Output = i32>;
            }
        };
        let expected = quote! {
            trait MyTrait {
                fn method(&self) -> impl Future<Output = i32> {
                    async {
                        let double_trait_name = stringify!(MyTrait);
                        let fn_name = stringify!(method);
                        unimplemented!("{double_trait_name}::{fn_name}")
                    }
                }
            }
        };
        assert_doubled(org_trait, expected);
    }

    #[test]
    fn default_impl_for_method_with_impl_iterator_return() {
        let org_trait = quote! {
            trait MyTrait {
                fn method(&self) -> impl Iterator<Item = String>;
            }
        };
        // syn prints the empty closure parameter list as two separate `|` tokens.
        let expected = quote! {
            trait MyTrait {
                fn method(&self) -> impl Iterator<Item = String> {
                    #[allow(unreachable_code)]
                    std::iter::from_fn(move | | {
                        if false {
                            Some({
                                let double_trait_name = stringify!(MyTrait);
                                let fn_name = stringify!(method);
                                unimplemented!("{double_trait_name}::{fn_name}")
                            })
                        } else {
                            None
                        }
                    })
                }
            }
        };
        assert_doubled(org_trait, expected);
    }

    #[test]
    fn empty_default_implementation_if_function_does_not_return_anything() {
        let org_trait = quote! {
            trait MyTrait {
                fn method(x: i32);
            }
        };
        let expected = quote! {
            trait MyTrait {
                fn method(_: i32) {}
            }
        };
        assert_doubled(org_trait, expected);
    }

    #[test]
    fn default_implementation_for_function_with_i32_result() {
        let org_trait = quote! {
            trait MyTrait {
                fn method(x: i32) -> i32;
            }
        };
        let expected = quote! {
            trait MyTrait {
                fn method(_: i32) -> i32 {
                    let double_trait_name = stringify!(MyTrait);
                    let fn_name = stringify!(method);
                    unimplemented!("{double_trait_name}::{fn_name}")
                }
            }
        };
        assert_doubled(org_trait, expected);
    }

    #[test]
    fn compiler_error_for_unknown_return_impl() {
        let org_trait = quote! {
            trait MyTrait {
                fn method() -> impl UnsupportedTrait;
            }
        };
        // NOTE: the text of a string literal token presumably includes the whitespace after
        // its `\` line continuation, so this continuation line has to stay indented like the
        // one in `default_impl` for the string comparison to succeed.
        let expected = quote! {
            trait MyTrait {
                fn method() -> impl UnsupportedTrait {
                    compile_error!(
                        "impl Trait is currently not supported by double-trait. Apart from the \
                        special case of impl Future."
                    )
                }
            }
        };
        assert_doubled(org_trait, expected);
    }

    #[test]
    fn strip_parameter_names_from_default_implementation() {
        // Receiver must survive; every named parameter becomes `_`.
        let org_trait = quote! {
            trait MyTrait {
                fn method(&self, x: i32, y: String) -> i32;
            }
        };
        let expected = quote! {
            trait MyTrait {
                fn method(&self, _: i32, _: String) -> i32 {
                    let double_trait_name = stringify!(MyTrait);
                    let fn_name = stringify!(method);
                    unimplemented!("{double_trait_name}::{fn_name}")
                }
            }
        };
        assert_doubled(org_trait, expected);
    }

    #[test]
    fn methods_with_default_body_are_left_untouched() {
        let org_trait = quote! {
            trait MyTrait {
                fn method(x: i32) -> i32 {
                    x
                }
            }
        };
        let expected = quote! {
            trait MyTrait {
                fn method(x: i32) -> i32 {
                    x
                }
            }
        };
        assert_doubled(org_trait, expected);
    }

    /// Runs `double_trait` on `org_trait` and compares the printed result with `expected`.
    #[track_caller]
    fn assert_doubled(org_trait: proc_macro2::TokenStream, expected: proc_macro2::TokenStream) {
        let doubled = double_trait(given(org_trait)).unwrap();
        let actual = quote! { #doubled };
        assert_eq!(actual.to_string(), expected.to_string());
    }

    fn given(item: proc_macro2::TokenStream) -> ItemTrait {
        parse2(item).unwrap()
    }
}