use {
crate::analysis::patterns::get_apply_macro_parameters,
syn::{
GenericParam,
Signature,
Type,
TypeParamBound,
},
};
/// A generic type parameter of a function signature that could be replaced by an
/// argument-position `impl Trait` type.
#[derive(Debug, Clone)]
pub struct ImplTraitCandidate {
	/// The type parameter's identifier rendered as a string (e.g. `"F"`).
	pub param_name: String,
	/// Span of the type parameter's identifier in the generics list.
	pub param_span: proc_macro2::Span,
	/// The parameter's bounds (inline and from matching `where` predicates), each
	/// rendered as token text and joined with `" + "`.
	pub bounds_display: String,
}
pub fn find_impl_trait_candidates(sig: &Signature) -> Vec<ImplTraitCandidate> {
let mut candidates = Vec::new();
for param in &sig.generics.params {
let GenericParam::Type(type_param) = param else {
continue;
};
let name = type_param.ident.to_string();
let all_bounds = collect_all_bounds(type_param, &name, &sig.generics.where_clause);
if !has_trait_bounds(&all_bounds) {
continue;
}
if !appears_once_at_top_level(sig, &name) {
continue;
}
if appears_in_return_type(sig, &name) {
continue;
}
if is_cross_referenced(sig, &name) {
continue;
}
let bounds_display = format_bounds(&all_bounds);
candidates.push(ImplTraitCandidate {
param_name: name,
param_span: type_param.ident.span(),
bounds_display,
});
}
candidates
}
/// Gathers every bound that applies to `type_param`.
///
/// The inline bounds from the generics list come first. They are followed by the
/// bounds of each `where` predicate whose bounded type is exactly the identifier
/// `name`, in predicate order.
fn collect_all_bounds<'a>(
	type_param: &'a syn::TypeParam,
	name: &str,
	where_clause: &'a Option<syn::WhereClause>,
) -> Vec<&'a TypeParamBound> {
	let where_bounds = where_clause
		.iter()
		.flat_map(|clause| clause.predicates.iter())
		.filter_map(|predicate| match predicate {
			syn::WherePredicate::Type(typed) if type_is_ident(&typed.bounded_ty, name) =>
				Some(typed.bounds.iter()),
			_ => None,
		})
		.flatten();
	type_param.bounds.iter().chain(where_bounds).collect()
}
fn has_trait_bounds(bounds: &[&TypeParamBound]) -> bool {
bounds.iter().any(|b| matches!(b, TypeParamBound::Trait(_)))
}
/// Returns `true` when exactly one typed argument of `sig` mentions `name`, and in
/// that argument `name` sits at the top level (see `is_top_level_type_param`).
///
/// Receiver arguments (`self`, `&self`, ...) are not considered.
fn appears_once_at_top_level(
	sig: &Signature,
	name: &str,
) -> bool {
	let mut mentioning_args = sig.inputs.iter().filter_map(|arg| match arg {
		syn::FnArg::Typed(pat_type) if contains_type_param(&pat_type.ty, name) => Some(pat_type),
		_ => None,
	});
	let first = mentioning_args.next();
	let second = mentioning_args.next();
	match (first, second) {
		(Some(only_arg), None) => is_top_level_type_param(&only_arg.ty, name),
		_ => false,
	}
}
/// Returns `true` if `ty` is the bare type parameter `name`, optionally wrapped in
/// references (`&`, `&mut`), parentheses, or invisible groups.
///
/// Anything else returns `false`. This includes qualified paths such as
/// `<F as Trait>::Assoc` and `name` nested inside generics or tuples.
fn is_top_level_type_param(
	ty: &Type,
	name: &str,
) -> bool {
	let mut current = ty;
	loop {
		current = match current {
			Type::Reference(wrapper) => &*wrapper.elem,
			Type::Paren(wrapper) => &*wrapper.elem,
			Type::Group(wrapper) => &*wrapper.elem,
			Type::Path(type_path) => {
				return type_path.qself.is_none() && type_path.path.is_ident(name);
			}
			_ => return false,
		};
	}
}
/// Returns `true` if the declared return type of `sig` mentions `name`.
///
/// A function with no explicit return type (`-> ()` implied) never does.
fn appears_in_return_type(
	sig: &Signature,
	name: &str,
) -> bool {
	if let syn::ReturnType::Type(_, ret_ty) = &sig.output {
		contains_type_param(ret_ty, name)
	} else {
		false
	}
}
fn is_cross_referenced(
sig: &Signature,
name: &str,
) -> bool {
for param in &sig.generics.params {
let GenericParam::Type(other_param) = param else {
continue;
};
if other_param.ident == name {
continue;
}
if bounds_contain_type_param(other_param.bounds.iter(), name) {
return true;
}
}
if let Some(wc) = &sig.generics.where_clause {
for pred in &wc.predicates {
if let syn::WherePredicate::Type(pred_type) = pred
&& !type_is_ident(&pred_type.bounded_ty, name)
&& bounds_contain_type_param(pred_type.bounds.iter(), name)
{
return true;
}
}
}
false
}
/// Returns `true` if any trait bound in `bounds` mentions the type parameter
/// `name` (see `trait_bound_contains_type_param`).
///
/// Non-trait bounds, such as lifetimes, are skipped.
fn bounds_contain_type_param<'a>(
	bounds: impl Iterator<Item = &'a TypeParamBound>,
	name: &str,
) -> bool {
	bounds
		.filter_map(|bound| match bound {
			TypeParamBound::Trait(trait_bound) => Some(trait_bound),
			_ => None,
		})
		.any(|trait_bound| trait_bound_contains_type_param(trait_bound, name))
}
/// Renders `bounds` as token text separated by `" + "`, in the given order.
fn format_bounds(bounds: &[&TypeParamBound]) -> String {
	use quote::ToTokens;
	let mut rendered = String::new();
	for (index, bound) in bounds.iter().enumerate() {
		if index > 0 {
			rendered.push_str(" + ");
		}
		rendered.push_str(&bound.to_token_stream().to_string());
	}
	rendered
}
/// Returns `true` if `ty` is exactly the unqualified single identifier `name`.
///
/// Parentheses, references, and generic arguments all cause a `false` result.
fn type_is_ident(
	ty: &Type,
	name: &str,
) -> bool {
	matches!(
		ty,
		Type::Path(type_path) if type_path.qself.is_none() && type_path.path.is_ident(name)
	)
}
pub fn contains_type_param(
ty: &Type,
name: &str,
) -> bool {
match ty {
Type::Path(type_path) => {
if type_path.qself.is_none() && type_path.path.is_ident(name) {
return true;
}
if let Some(qself) = &type_path.qself
&& contains_type_param(&qself.ty, name)
{
return true;
}
for segment in &type_path.path.segments {
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
for arg in &args.args {
match arg {
syn::GenericArgument::Type(inner_ty) => {
if contains_type_param(inner_ty, name) {
return true;
}
}
syn::GenericArgument::AssocType(assoc) => {
if contains_type_param(&assoc.ty, name) {
return true;
}
}
_ => {}
}
}
}
}
false
}
Type::Macro(type_macro) => {
if let Some((brand, args)) = get_apply_macro_parameters(type_macro) {
if contains_type_param(&brand, name) {
return true;
}
for arg_ty in &args {
if contains_type_param(arg_ty, name) {
return true;
}
}
}
false
}
Type::Reference(type_ref) => contains_type_param(&type_ref.elem, name),
Type::Tuple(type_tuple) =>
type_tuple.elems.iter().any(|elem| contains_type_param(elem, name)),
Type::ImplTrait(type_impl) => type_impl.bounds.iter().any(|bound| {
if let TypeParamBound::Trait(trait_bound) = bound {
trait_bound_contains_type_param(trait_bound, name)
} else {
false
}
}),
Type::TraitObject(type_obj) => type_obj.bounds.iter().any(|bound| {
if let TypeParamBound::Trait(trait_bound) = bound {
trait_bound_contains_type_param(trait_bound, name)
} else {
false
}
}),
Type::BareFn(type_fn) => {
for input in &type_fn.inputs {
if contains_type_param(&input.ty, name) {
return true;
}
}
if let syn::ReturnType::Type(_, ret_ty) = &type_fn.output
&& contains_type_param(ret_ty, name)
{
return true;
}
false
}
Type::Array(type_array) => contains_type_param(&type_array.elem, name),
Type::Slice(type_slice) => contains_type_param(&type_slice.elem, name),
Type::Paren(type_paren) => contains_type_param(&type_paren.elem, name),
Type::Group(type_group) => contains_type_param(&type_group.elem, name),
_ => false,
}
}
/// Returns `true` if the trait path of `trait_bound` mentions the type parameter
/// `name`.
///
/// A segment matches when either of the following holds:
///
/// - its identifier equals `name`;
/// - its arguments mention `name`, via angle-bracketed type or associated-type
///   arguments, or via `Fn(..) -> ..` sugar inputs and return type.
fn trait_bound_contains_type_param(
	trait_bound: &syn::TraitBound,
	name: &str,
) -> bool {
	trait_bound.path.segments.iter().any(|segment| {
		if segment.ident == name {
			return true;
		}
		match &segment.arguments {
			syn::PathArguments::AngleBracketed(angle) => angle.args.iter().any(|arg| match arg {
				syn::GenericArgument::Type(ty) => contains_type_param(ty, name),
				syn::GenericArgument::AssocType(assoc) => contains_type_param(&assoc.ty, name),
				_ => false,
			}),
			syn::PathArguments::Parenthesized(sugar) =>
				sugar.inputs.iter().any(|input| contains_type_param(input, name))
					|| matches!(
						&sugar.output,
						syn::ReturnType::Type(_, ret_ty) if contains_type_param(ret_ty, name)
					),
			_ => false,
		}
	})
}
#[cfg(test)]
#[expect(
	clippy::unwrap_used,
	clippy::indexing_slicing,
	reason = "Tests use panicking operations for brevity and clarity"
)]
mod tests {
	//! Unit tests for the type-parameter search helpers and for
	//! `find_impl_trait_candidates`.

	use {
		super::*,
		syn::parse_str,
	};

	// ---- contains_type_param ----

	#[test]
	fn test_simple_path() {
		let ty: Type = parse_str("F").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_simple_path_mismatch() {
		let ty: Type = parse_str("F").unwrap();
		assert!(!contains_type_param(&ty, "G"));
	}

	#[test]
	fn test_nested_in_generic() {
		let ty: Type = parse_str("Option<F>").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_deeply_nested() {
		let ty: Type = parse_str("Vec<Option<F>>").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_reference() {
		let ty: Type = parse_str("&F").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_mutable_reference() {
		let ty: Type = parse_str("&mut F").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_tuple() {
		let ty: Type = parse_str("(A, F, B)").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_tuple_absent() {
		let ty: Type = parse_str("(A, B)").unwrap();
		assert!(!contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_bare_fn() {
		let ty: Type = parse_str("fn(F) -> B").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_bare_fn_return() {
		let ty: Type = parse_str("fn(A) -> F").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_impl_trait_bound() {
		let ty: Type = parse_str("impl Iterator<Item = F>").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_dyn_trait_bound() {
		let ty: Type = parse_str("dyn Fn(F) -> B").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_array() {
		let ty: Type = parse_str("[F; 3]").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_slice() {
		let ty: Type = parse_str("[F]").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_parenthesized() {
		let ty: Type = parse_str("(F)").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_qualified_self_type() {
		let ty: Type = parse_str("<F as Trait>::Assoc").unwrap();
		assert!(contains_type_param(&ty, "F"));
	}

	#[test]
	fn test_no_match_in_complex() {
		let ty: Type = parse_str("Vec<Option<&str>>").unwrap();
		assert!(!contains_type_param(&ty, "F"));
	}

	// ---- find_impl_trait_candidates ----

	/// Parses `s` as a function signature by appending an empty body.
	fn parse_sig(s: &str) -> Signature {
		let item: syn::ItemFn = parse_str(&format!("{s} {{}}")).unwrap();
		item.sig
	}

	#[test]
	fn test_basic_fn_bound() {
		let sig = parse_sig("fn new<F>(f: F) where F: FnOnce() -> A");
		let candidates = find_impl_trait_candidates(&sig);
		assert_eq!(candidates.len(), 1);
		assert_eq!(candidates[0].param_name, "F");
		assert!(candidates[0].bounds_display.contains("FnOnce"));
	}

	#[test]
	fn test_inline_bounds() {
		let sig = parse_sig("fn apply<F: Fn(A) -> B>(f: F, a: A) -> B");
		let candidates = find_impl_trait_candidates(&sig);
		assert_eq!(candidates.len(), 1);
		assert_eq!(candidates[0].param_name, "F");
	}

	#[test]
	fn test_multiple_candidates() {
		let sig = parse_sig("fn foo<F: Fn(A), G: Fn(B)>(f: F, g: G)");
		let candidates = find_impl_trait_candidates(&sig);
		assert_eq!(candidates.len(), 2);
		let names: Vec<&str> = candidates.iter().map(|c| c.param_name.as_str()).collect();
		assert!(names.contains(&"F"));
		assert!(names.contains(&"G"));
	}

	#[test]
	fn test_mixed_where_and_inline() {
		let sig = parse_sig("fn bar<B: 'static, F>(f: F) -> Out where F: FnOnce(A) -> B + 'static");
		let candidates = find_impl_trait_candidates(&sig);
		assert_eq!(candidates.len(), 1);
		assert_eq!(candidates[0].param_name, "F");
	}

	#[test]
	fn test_lifetime_only_bound_skipped() {
		let sig = parse_sig("fn baz<B: 'static>(x: B) -> B");
		let candidates = find_impl_trait_candidates(&sig);
		assert!(candidates.is_empty());
	}

	#[test]
	fn test_in_return_type() {
		let sig = parse_sig("fn identity<T: Clone>(x: T) -> T");
		let candidates = find_impl_trait_candidates(&sig);
		assert!(candidates.is_empty());
	}

	#[test]
	fn test_multiple_param_positions() {
		let sig = parse_sig("fn combine<T: Clone>(a: T, b: T) -> T");
		let candidates = find_impl_trait_candidates(&sig);
		assert!(candidates.is_empty());
	}

	#[test]
	fn test_no_trait_bounds() {
		let sig = parse_sig("fn wrap<T>(x: T) -> Box<T>");
		let candidates = find_impl_trait_candidates(&sig);
		assert!(candidates.is_empty());
	}

	#[test]
	fn test_cross_referenced() {
		let sig = parse_sig("fn foo<F: Clone, G: Fn(F)>(f: F, g: G)");
		let candidates = find_impl_trait_candidates(&sig);
		let names: Vec<&str> = candidates.iter().map(|c| c.param_name.as_str()).collect();
		assert!(!names.contains(&"F"));
		assert!(names.contains(&"G"));
	}

	#[test]
	fn test_only_lifetime_bounds() {
		let sig = parse_sig("fn bar<T: 'a>(x: T)");
		let candidates = find_impl_trait_candidates(&sig);
		assert!(candidates.is_empty());
	}

	#[test]
	fn test_param_used_in_two_args_not_candidate() {
		// `self_` is an ordinary typed argument, not a receiver. `F` is rejected
		// because two arguments mention it.
		let sig = parse_sig("fn method<F: Fn()>(self_: &Self, f: F, f2: F)");
		let candidates = find_impl_trait_candidates(&sig);
		assert!(candidates.is_empty());
	}

	#[test]
	fn test_no_generics() {
		let sig = parse_sig("fn foo(x: i32) -> i32");
		let candidates = find_impl_trait_candidates(&sig);
		assert!(candidates.is_empty());
	}

	#[test]
	fn test_empty_where_clause() {
		let sig = parse_sig("fn foo<T: Clone>(x: T) -> T where");
		let candidates = find_impl_trait_candidates(&sig);
		assert!(candidates.is_empty());
	}

	#[test]
	fn test_self_receiver_not_counted() {
		let item: syn::TraitItemFn = parse_str("fn method<F: Fn()>(&self, f: F);").unwrap();
		let candidates = find_impl_trait_candidates(&item.sig);
		assert_eq!(candidates.len(), 1);
		assert_eq!(candidates[0].param_name, "F");
	}

	#[test]
	fn test_multiple_bounds_displayed() {
		let sig = parse_sig("fn foo<F: Clone + Send + Fn()>(f: F)");
		let candidates = find_impl_trait_candidates(&sig);
		assert_eq!(candidates.len(), 1);
		assert!(candidates[0].bounds_display.contains("Clone"));
		assert!(candidates[0].bounds_display.contains("Send"));
		assert!(candidates[0].bounds_display.contains("Fn"));
	}

	#[test]
	fn test_where_clause_cross_ref() {
		let sig = parse_sig("fn foo<A: Clone, B>(a: A, b: B) where B: From<A>");
		let candidates = find_impl_trait_candidates(&sig);
		let names: Vec<&str> = candidates.iter().map(|c| c.param_name.as_str()).collect();
		assert!(!names.contains(&"A"));
	}

	// ---- is_top_level_type_param ----

	#[test]
	fn test_top_level_bare_ident() {
		let ty: Type = parse_str("F").unwrap();
		assert!(is_top_level_type_param(&ty, "F"));
	}

	#[test]
	fn test_top_level_reference() {
		let ty: Type = parse_str("&F").unwrap();
		assert!(is_top_level_type_param(&ty, "F"));
	}

	#[test]
	fn test_top_level_mut_reference() {
		let ty: Type = parse_str("&mut F").unwrap();
		assert!(is_top_level_type_param(&ty, "F"));
	}

	#[test]
	fn test_top_level_parenthesized() {
		let ty: Type = parse_str("(F)").unwrap();
		assert!(is_top_level_type_param(&ty, "F"));
	}

	#[test]
	fn test_not_top_level_in_option() {
		let ty: Type = parse_str("Option<F>").unwrap();
		assert!(!is_top_level_type_param(&ty, "F"));
	}

	#[test]
	fn test_not_top_level_in_vec() {
		let ty: Type = parse_str("Vec<F>").unwrap();
		assert!(!is_top_level_type_param(&ty, "F"));
	}

	#[test]
	fn test_not_top_level_in_tuple() {
		let ty: Type = parse_str("(A, F)").unwrap();
		assert!(!is_top_level_type_param(&ty, "F"));
	}

	#[test]
	fn test_not_top_level_in_associated_type() {
		let ty: Type = parse_str("<F as Trait>::Assoc").unwrap();
		assert!(!is_top_level_type_param(&ty, "F"));
	}

	// ---- top-level placement in full signatures ----

	#[test]
	fn test_nested_param_not_candidate() {
		let sig = parse_sig("fn foo<F: Clone>(x: Option<F>)");
		let candidates = find_impl_trait_candidates(&sig);
		assert!(candidates.is_empty());
	}

	#[test]
	fn test_reference_param_is_candidate() {
		let sig = parse_sig("fn foo<F: Clone>(x: &F)");
		let candidates = find_impl_trait_candidates(&sig);
		assert_eq!(candidates.len(), 1);
		assert_eq!(candidates[0].param_name, "F");
	}

	#[test]
	fn test_associated_type_projection_not_candidate() {
		let sig = parse_sig("fn foo<F: Iterator>(x: <F as Iterator>::Item)");
		let candidates = find_impl_trait_candidates(&sig);
		assert!(candidates.is_empty());
	}
}