use {
crate::{
analysis::patterns::get_apply_macro_parameters,
core::constants::markers,
hkt::{
ApplyInput,
apply::apply_worker,
},
},
std::collections::HashMap,
syn::{
GenericArgument,
ImplItem,
Item,
PathArguments,
ReturnType,
Type,
WherePredicate,
},
};
/// Everything the analysis extracts about one dispatch trait and its `Val` impl.
///
/// Built by `extract_dispatch_info` from the trait definition (when found) and
/// the impl whose trait path carries the `Val` marker type argument.
#[derive(Debug, Clone)]
pub struct DispatchTraitInfo {
/// Name of the dispatch trait this record describes.
#[expect(dead_code, reason = "Stored for diagnostics")]
pub trait_name: String,
/// Name of the brand type parameter; falls back to
/// `markers::DEFAULT_BRAND_PARAM` when none could be detected.
pub brand_param: String,
/// Name of the bound on the brand parameter (in the trait definition) that
/// starts with `markers::KIND_PREFIX`, if any.
pub kind_trait_name: Option<String>,
/// First bound on the brand parameter in the `Val` impl that
/// `is_semantic_type_class` accepts, if any.
pub semantic_constraint: Option<String>,
/// `(bounded type, trait name)` pairs taken from `where` predicates of the
/// `Val` impl that bound something other than the brand parameter.
pub secondary_constraints: Vec<(String, String)>,
/// Closure signature recovered from `Fn`/`FnMut`/`FnOnce` bounds of the
/// `Val` impl; `None` when no such bound exists.
pub arrow_type: Option<DispatchArrow>,
/// `true` when the `Val` impl's self type is a tuple with at least two elements.
pub tuple_closure: bool,
/// Shape of the trait's `dispatch` return type; `Plain("?")` when the trait
/// definition or its `dispatch` method was not found.
pub return_structure: ReturnStructure,
/// Trait type parameters whose argument in the `Val` impl is a macro type
/// recognised by `get_apply_macro_parameters`.
pub container_params: Vec<ContainerParam>,
/// `(name, rendered arguments)` for each associated type in the `Val` impl
/// whose value is a macro type recognised by `get_apply_macro_parameters`.
pub associated_types: Vec<(String, Vec<String>)>,
/// Rendered arguments of the `Val` impl's self type when that self type is a
/// macro type recognised by `get_apply_macro_parameters`.
pub self_type_elements: Option<Vec<String>>,
/// Trait type parameters in declaration order, excluding
/// `markers::FN_BRAND_PARAM`, `markers::MARKER_PARAM` and container parameters.
pub type_param_order: Vec<String>,
}
/// A trait type parameter that is instantiated with an applied container type
/// in the `Val` impl.
///
/// Derives `PartialEq`/`Eq` so extracted values can be compared directly
/// (for example with `assert_eq!`) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContainerParam {
    /// Name of the type parameter as declared on the dispatch trait.
    pub name: String,
    /// Zero-based index of the parameter among the trait's *type* parameters
    /// (lifetime and const parameters are not counted).
    pub position: usize,
    /// Whitespace-stripped renderings of the arguments of the applied type.
    pub element_types: Vec<String>,
}
/// Shape of the return type of a dispatch trait's `dispatch` method, as
/// produced by `classify_return_type`.
///
/// Derives `PartialEq`/`Eq` so results can be compared with `assert_eq!`
/// rather than through `matches!` guards.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReturnStructure {
    /// Any type that is not a tuple or a recognised applied type, rendered
    /// without spaces.
    Plain(String),
    /// The brand parameter applied to the given rendered arguments.
    Applied(Vec<String>),
    /// A constructor other than the brand parameter applied to arguments.
    Nested {
        /// Name of the outer constructor (`"G"` when it could not be named).
        outer_param: String,
        /// Rendered arguments of the outer application.
        inner_args: Vec<String>,
    },
    /// A non-empty tuple; each element is rendered as a list of arguments.
    Tuple(Vec<Vec<String>>),
    /// A constructor other than the brand parameter applied to a single tuple
    /// with at least two elements.
    NestedTuple {
        /// Name of the outer constructor (`"G"` when it could not be named).
        outer_param: String,
        /// One rendered argument list per tuple element.
        inner_elements: Vec<Vec<String>>,
    },
}
/// A closure signature recovered from an `Fn`/`FnMut`/`FnOnce` bound.
///
/// For tuple-closure impls, `extract_tuple_arrow` builds a synthetic arrow
/// whose inputs are the individual closures (as `SubArrow`s) and whose output
/// is the output of the last closure found.
#[derive(Debug, Clone)]
pub struct DispatchArrow {
/// Parameters of the closure, in declaration order.
pub inputs: Vec<DispatchArrowParam>,
/// Classified return type of the closure.
pub output: ArrowOutput,
}
/// Classified return type of a closure bound, as produced by
/// `classify_arrow_output`.
///
/// Derives `PartialEq`/`Eq` so results can be compared with `assert_eq!`
/// rather than through `matches!` guards.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArrowOutput {
    /// A return type that is not a recognised applied type (`"()"` when the
    /// closure bound has no explicit return type).
    Plain(String),
    /// The brand parameter applied to the given rendered arguments.
    BrandApplied(Vec<String>),
    /// A constructor other than the brand parameter applied to arguments.
    OtherApplied {
        /// Name of the applied constructor.
        brand: String,
        /// Rendered arguments of the application.
        args: Vec<String>,
    },
}
/// One input of a `DispatchArrow`.
#[derive(Debug, Clone)]
pub enum DispatchArrowParam {
/// Any input type that is not a two-segment path rooted at the brand
/// parameter, rendered without spaces.
TypeParam(String),
/// A two-segment path whose first segment is the brand parameter; only the
/// second segment's name is kept.
AssociatedType { assoc_name: String },
/// A whole closure signature used as an input (tuple-closure dispatch).
SubArrow(DispatchArrow),
}
/// Trait names that `is_semantic_type_class` never treats as a semantic
/// type-class constraint.
const INFRASTRUCTURE_TRAITS: &[&str] = &[
    "Send",
    "Sync",
    "Clone",
    "Copy",
    "Debug",
    "Display",
    "Sized",
    "LiftFn",
    "SendLiftFn",
];
/// Returns the whitespace-stripped renderings of the arguments of `ty` when it
/// is a macro type accepted by `get_apply_macro_parameters`.
///
/// Yields `None` for every non-macro type and for macro types the helper
/// rejects.
fn extract_apply_type_args(ty: &Type) -> Option<Vec<String>> {
    match ty {
        Type::Macro(apply_macro) => {
            // The brand half of the pair is irrelevant here.
            let (_, element_types) = get_apply_macro_parameters(apply_macro)?;
            let rendered = element_types
                .iter()
                .map(|element| quote::quote!(#element).to_string().replace(' ', ""))
                .collect();
            Some(rendered)
        }
        _ => None,
    }
}
/// Analyses every dispatch trait in `items` that has a `Val` impl.
///
/// A trait counts as a dispatch trait when its name ends with
/// `markers::DISPATCH_SUFFIX`. Traits without a matching `Val` impl are left
/// out of the returned map, which is keyed by trait name.
pub fn analyze_dispatch_traits(items: &[Item]) -> HashMap<String, DispatchTraitInfo> {
    let dispatch_names = items
        .iter()
        .filter_map(|item| match item {
            Item::Trait(definition) => Some(definition.ident.to_string()),
            _ => None,
        })
        .filter(|name| name.ends_with(markers::DISPATCH_SUFFIX));

    let mut infos = HashMap::new();
    for name in dispatch_names {
        let Some(val_impl) = find_val_impl(items, &name) else {
            continue;
        };
        // Always resolve to the first trait carrying this name, even if the
        // name occurs more than once in `items`.
        let definition = items.iter().find_map(|item| match item {
            Item::Trait(candidate) if candidate.ident == name.as_str() => Some(candidate),
            _ => None,
        });
        let info = extract_dispatch_info(&name, val_impl, definition);
        infos.insert(name, info);
    }
    infos
}
/// Finds the first impl of `trait_name` whose trait path carries `Val` among
/// its generic type arguments.
///
/// Only the last segment of the implemented trait's path is compared against
/// `trait_name`.
fn find_val_impl<'a>(
    items: &'a [Item],
    trait_name: &str,
) -> Option<&'a syn::ItemImpl> {
    for item in items {
        let Item::Impl(candidate) = item else {
            continue;
        };
        let Some((_, implemented, _)) = &candidate.trait_ else {
            continue;
        };
        let last_ident =
            implemented.segments.last().map(|seg| seg.ident.to_string()).unwrap_or_default();
        if last_ident == trait_name && has_marker_type_arg(implemented, "Val") {
            return Some(candidate);
        }
    }
    None
}
fn has_marker_type_arg(
path: &syn::Path,
marker_name: &str,
) -> bool {
let Some(last_segment) = path.segments.last() else {
return false;
};
let PathArguments::AngleBracketed(args) = &last_segment.arguments else {
return false;
};
args.args.iter().any(|arg| {
if let GenericArgument::Type(Type::Path(type_path)) = arg {
type_path.path.get_ident().is_some_and(|ident| ident == marker_name)
} else {
false
}
})
}
/// Assembles the full `DispatchTraitInfo` for one dispatch trait.
///
/// `trait_def` is optional: everything that needs the trait definition falls
/// back to an empty/unknown value when it is `None`.
fn extract_dispatch_info(
    trait_name: &str,
    val_impl: &syn::ItemImpl,
    trait_def: Option<&syn::ItemTrait>,
) -> DispatchTraitInfo {
    // Prefer the brand parameter named by the trait definition; fall back to
    // the heuristic over the impl's bounds.
    let brand: Option<String> = match trait_def.and_then(find_brand_param_from_trait_def) {
        Some(found) => Some(found),
        None => find_brand_param(val_impl),
    };

    let kind_trait_name = match trait_def {
        Some(definition) => extract_kind_trait_name(
            definition,
            brand.as_deref().unwrap_or(markers::DEFAULT_BRAND_PARAM),
        ),
        None => None,
    };

    let (semantic_constraint, secondary_constraints) = match brand.as_deref() {
        Some(name) => (
            extract_semantic_constraint(val_impl, name),
            extract_secondary_constraints(val_impl, name),
        ),
        None => (None, Vec::new()),
    };

    let tuple_closure = is_tuple_closure(val_impl);
    let arrow_type = match tuple_closure {
        true => extract_tuple_arrow(val_impl, brand.as_deref()),
        false => extract_single_arrow(val_impl, brand.as_deref()),
    };

    let return_structure = trait_def
        .and_then(|definition| extract_return_structure(definition, brand.as_deref()))
        .unwrap_or(ReturnStructure::Plain("?".to_string()));

    let container_params = match trait_def {
        Some(definition) => extract_container_params(definition, val_impl),
        None => Vec::new(),
    };
    // Depends on `container_params`, so it must be computed after them.
    let type_param_order = match trait_def {
        Some(definition) => extract_type_param_order(definition, &container_params),
        None => Vec::new(),
    };

    let associated_types = extract_associated_types(val_impl);
    let self_type_elements = extract_self_type_elements(val_impl);
    let brand_param = brand.unwrap_or_else(|| markers::DEFAULT_BRAND_PARAM.to_string());

    DispatchTraitInfo {
        trait_name: trait_name.to_string(),
        brand_param,
        kind_trait_name,
        semantic_constraint,
        secondary_constraints,
        arrow_type,
        tuple_closure,
        return_structure,
        container_params,
        associated_types,
        self_type_elements,
        type_param_order,
    }
}
/// Returns the name of the first bound on `brand_param_name` in `trait_def`
/// that starts with `markers::KIND_PREFIX`.
///
/// Inline bounds on the generic parameter are searched first, then the
/// `where` clause.
fn extract_kind_trait_name(
    trait_def: &syn::ItemTrait,
    brand_param_name: &str,
) -> Option<String> {
    /// First trait bound in `bounds` whose last path segment has the kind prefix.
    fn first_kind_bound<'b>(
        bounds: impl IntoIterator<Item = &'b syn::TypeParamBound>
    ) -> Option<String> {
        bounds.into_iter().find_map(|bound| {
            let syn::TypeParamBound::Trait(trait_bound) = bound else {
                return None;
            };
            let name = trait_bound
                .path
                .segments
                .last()
                .map(|seg| seg.ident.to_string())
                .unwrap_or_default();
            name.starts_with(markers::KIND_PREFIX).then_some(name)
        })
    }

    let inline = trait_def.generics.params.iter().find_map(|param| match param {
        syn::GenericParam::Type(tp) if tp.ident == brand_param_name => {
            first_kind_bound(&tp.bounds)
        }
        _ => None,
    });
    if inline.is_some() {
        return inline;
    }

    let where_clause = trait_def.generics.where_clause.as_ref()?;
    where_clause.predicates.iter().find_map(|predicate| match predicate {
        WherePredicate::Type(pt) if type_to_string(&pt.bounded_ty) == brand_param_name => {
            first_kind_bound(&pt.bounds)
        }
        _ => None,
    })
}
/// Lists the trait type parameters whose argument in the `Val` impl is an
/// applied type with at least one argument.
///
/// Parameters named `markers::DEFAULT_BRAND_PARAM`, `markers::FN_BRAND_PARAM`
/// or `markers::MARKER_PARAM`, and all single-character names, are skipped.
/// Positions index the trait's type parameters only and are matched against
/// the impl's type arguments by the same index.
fn extract_container_params(
    trait_def: &syn::ItemTrait,
    val_impl: &syn::ItemImpl,
) -> Vec<ContainerParam> {
    let Some((_, implemented, _)) = &val_impl.trait_ else {
        return Vec::new();
    };
    let impl_types: Vec<&Type> = match implemented.segments.last().map(|seg| &seg.arguments) {
        Some(PathArguments::AngleBracketed(angle)) => angle
            .args
            .iter()
            .filter_map(|arg| match arg {
                GenericArgument::Type(ty) => Some(ty),
                _ => None,
            })
            .collect(),
        _ => return Vec::new(),
    };

    let is_skipped = |name: &str| {
        name == markers::DEFAULT_BRAND_PARAM
            || name == markers::FN_BRAND_PARAM
            || name == markers::MARKER_PARAM
            || name.len() == 1
    };

    trait_def
        .generics
        .params
        .iter()
        .filter_map(|param| match param {
            syn::GenericParam::Type(tp) => Some(tp.ident.to_string()),
            _ => None,
        })
        // Enumerate after dropping non-type parameters so positions line up
        // with `impl_types`.
        .enumerate()
        .filter(|(_, name)| !is_skipped(name.as_str()))
        .filter_map(|(position, name)| {
            let element_types = extract_apply_type_args(impl_types.get(position)?)?;
            if element_types.is_empty() {
                return None;
            }
            Some(ContainerParam {
                name,
                position,
                element_types,
            })
        })
        .collect()
}
/// Classifies the return type of the first method named `dispatch` in
/// `trait_def`.
///
/// Returns `None` when the trait has no such method, and `Plain("()")` when
/// the method has no explicit return type.
fn extract_return_structure(
    trait_def: &syn::ItemTrait,
    brand_param: Option<&str>,
) -> Option<ReturnStructure> {
    let dispatch = trait_def.items.iter().find_map(|item| match item {
        syn::TraitItem::Fn(method) if method.sig.ident == "dispatch" => Some(method),
        _ => None,
    })?;

    let structure = match &dispatch.sig.output {
        syn::ReturnType::Type(_, return_ty) => classify_return_type(return_ty, brand_param),
        syn::ReturnType::Default => ReturnStructure::Plain("()".to_string()),
    };
    Some(structure)
}
/// Classifies a `dispatch` return type into a `ReturnStructure`.
///
/// Order of checks:
/// 1. a non-empty tuple becomes `Tuple`;
/// 2. an applied type whose constructor is `brand_param` becomes `Applied`;
/// 3. any other applied type becomes `NestedTuple` (single tuple argument with
///    two or more elements) or `Nested`;
/// 4. everything else becomes `Plain`.
fn classify_return_type(
    ty: &Type,
    brand_param: Option<&str>,
) -> ReturnStructure {
    let render = |t: &Type| quote::quote!(#t).to_string().replace(' ', "");
    // An element that is itself an applied type contributes its arguments;
    // anything else contributes its own rendering.
    let element_args =
        |elem: &Type| extract_apply_type_args(elem).unwrap_or_else(|| vec![render(elem)]);

    if let Type::Tuple(tuple) = ty
        && !tuple.elems.is_empty()
    {
        return ReturnStructure::Tuple(tuple.elems.iter().map(&element_args).collect());
    }

    let applied = match ty {
        Type::Macro(type_macro) => get_apply_macro_parameters(type_macro),
        _ => None,
    };
    let Some((brand, raw_args)) = applied else {
        return ReturnStructure::Plain(render(ty));
    };

    let brand_name = match &brand {
        Type::Path(tp) => tp.path.segments.last().map(|seg| seg.ident.to_string()),
        _ => None,
    };
    if brand_param.is_some_and(|bp| brand_name.as_deref() == Some(bp)) {
        return ReturnStructure::Applied(raw_args.iter().map(&render).collect());
    }

    let outer_param = brand_name.unwrap_or_else(|| "G".to_string());

    if let [Type::Tuple(tuple)] = raw_args.as_slice()
        && tuple.elems.len() >= 2
    {
        return ReturnStructure::NestedTuple {
            outer_param,
            inner_elements: tuple.elems.iter().map(&element_args).collect(),
        };
    }

    let mut inner_args = Vec::new();
    for raw_arg in &raw_args {
        match extract_apply_type_args(raw_arg) {
            // NOTE(review): a nested applied argument *replaces* everything
            // gathered so far rather than being appended — confirm this is
            // intended for multi-argument applications.
            Some(nested_args) => inner_args = nested_args,
            None => inner_args.push(render(raw_arg)),
        }
    }
    ReturnStructure::Nested {
        outer_param,
        inner_args,
    }
}
/// Names the first type in `trait_def` that carries a bound starting with
/// `markers::KIND_PREFIX`.
///
/// Inline generic parameters are searched first (yielding the parameter name),
/// then `where` predicates (yielding the whitespace-stripped bounded type).
fn find_brand_param_from_trait_def(trait_def: &syn::ItemTrait) -> Option<String> {
    /// Whether any trait bound in `bounds` ends in a kind-prefixed segment.
    fn has_kind_bound<'b>(bounds: impl IntoIterator<Item = &'b syn::TypeParamBound>) -> bool {
        bounds.into_iter().any(|bound| match bound {
            syn::TypeParamBound::Trait(trait_bound) => trait_bound
                .path
                .segments
                .last()
                .is_some_and(|seg| seg.ident.to_string().starts_with(markers::KIND_PREFIX)),
            _ => false,
        })
    }

    let generics = &trait_def.generics;
    let inline = generics.params.iter().find_map(|param| match param {
        syn::GenericParam::Type(tp) if has_kind_bound(&tp.bounds) => Some(tp.ident.to_string()),
        _ => None,
    });

    inline.or_else(|| {
        let where_clause = generics.where_clause.as_ref()?;
        where_clause.predicates.iter().find_map(|predicate| match predicate {
            WherePredicate::Type(pt) if has_kind_bound(&pt.bounds) => {
                Some(type_to_string(&pt.bounded_ty))
            }
            _ => None,
        })
    })
}
/// Heuristically names the brand parameter of a `Val` impl: the first bounded
/// type that has a bound accepted by `is_semantic_type_class`.
///
/// `where` predicates are searched first, then inline bounds on the impl's
/// generic parameters. The inline bounds are consulted even when the impl has
/// no `where` clause at all; previously a missing `where` clause returned
/// `None` immediately, so impls written purely with inline bounds
/// (`impl<Brand: Functor, ...>`) were never matched.
fn find_brand_param(val_impl: &syn::ItemImpl) -> Option<String> {
    /// Whether any trait bound in `bounds` names a semantic type class.
    fn has_semantic_bound<'b>(bounds: impl IntoIterator<Item = &'b syn::TypeParamBound>) -> bool {
        bounds.into_iter().any(|bound| match bound {
            syn::TypeParamBound::Trait(trait_bound) => {
                let bound_name = trait_bound
                    .path
                    .segments
                    .last()
                    .map(|seg| seg.ident.to_string())
                    .unwrap_or_default();
                is_semantic_type_class(&bound_name)
            }
            // Lifetime and other non-trait bounds never identify a brand.
            _ => false,
        })
    }

    let generics = &val_impl.generics;

    let from_where = generics.where_clause.as_ref().and_then(|where_clause| {
        where_clause.predicates.iter().find_map(|predicate| match predicate {
            WherePredicate::Type(pt) if has_semantic_bound(&pt.bounds) => {
                Some(type_to_string(&pt.bounded_ty))
            }
            _ => None,
        })
    });

    from_where.or_else(|| {
        generics.params.iter().find_map(|param| match param {
            syn::GenericParam::Type(tp) if has_semantic_bound(&tp.bounds) => {
                Some(tp.ident.to_string())
            }
            _ => None,
        })
    })
}
/// Decides whether a trait name denotes a semantic type-class constraint.
///
/// A name is rejected when it is a closure trait (`Fn`, `FnMut`, `FnOnce`),
/// starts with `markers::KIND_PREFIX` or `markers::INFERABLE_BRAND_PREFIX`,
/// is listed in `INFRASTRUCTURE_TRAITS`, or ends with
/// `markers::DISPATCH_SUFFIX`. Every other name is accepted.
fn is_semantic_type_class(name: &str) -> bool {
    let is_closure_trait = matches!(name, "Fn" | "FnMut" | "FnOnce");
    let is_generated = name.starts_with(markers::KIND_PREFIX)
        || name.starts_with(markers::INFERABLE_BRAND_PREFIX);
    let is_infrastructure = INFRASTRUCTURE_TRAITS.contains(&name);
    let is_dispatch = name.ends_with(markers::DISPATCH_SUFFIX);

    !(is_closure_trait || is_generated || is_infrastructure || is_dispatch)
}
/// Returns the first bound on `brand_param` in the `Val` impl that
/// `is_semantic_type_class` accepts.
///
/// `where` predicates bounding `brand_param` are searched before the inline
/// bounds on the matching generic parameter.
fn extract_semantic_constraint(
    val_impl: &syn::ItemImpl,
    brand_param: &str,
) -> Option<String> {
    /// First trait bound in `bounds` that names a semantic type class.
    fn first_semantic_bound<'b>(
        bounds: impl IntoIterator<Item = &'b syn::TypeParamBound>
    ) -> Option<String> {
        bounds.into_iter().find_map(|bound| {
            let syn::TypeParamBound::Trait(trait_bound) = bound else {
                return None;
            };
            let name = trait_bound
                .path
                .segments
                .last()
                .map(|seg| seg.ident.to_string())
                .unwrap_or_default();
            is_semantic_type_class(&name).then_some(name)
        })
    }

    let generics = &val_impl.generics;

    let from_where = generics.where_clause.as_ref().and_then(|where_clause| {
        where_clause.predicates.iter().find_map(|predicate| match predicate {
            WherePredicate::Type(pt) if type_to_string(&pt.bounded_ty) == brand_param => {
                first_semantic_bound(&pt.bounds)
            }
            _ => None,
        })
    });

    from_where.or_else(|| {
        generics.params.iter().find_map(|param| match param {
            syn::GenericParam::Type(tp) if tp.ident == brand_param => {
                first_semantic_bound(&tp.bounds)
            }
            _ => None,
        })
    })
}
/// Collects `(bounded type, trait name)` pairs for semantic type-class bounds
/// on anything other than `brand_param`.
///
/// Only the impl's `where` clause is inspected; inline bounds on generic
/// parameters are not considered.
fn extract_secondary_constraints(
    val_impl: &syn::ItemImpl,
    brand_param: &str,
) -> Vec<(String, String)> {
    let Some(where_clause) = &val_impl.generics.where_clause else {
        return Vec::new();
    };

    let mut constraints = Vec::new();
    for predicate in &where_clause.predicates {
        let WherePredicate::Type(pt) = predicate else {
            continue;
        };
        let bounded = type_to_string(&pt.bounded_ty);
        if bounded == brand_param {
            continue;
        }

        let semantic_names = pt
            .bounds
            .iter()
            .filter_map(|bound| match bound {
                syn::TypeParamBound::Trait(trait_bound) => Some(
                    trait_bound
                        .path
                        .segments
                        .last()
                        .map(|seg| seg.ident.to_string())
                        .unwrap_or_default(),
                ),
                _ => None,
            })
            .filter(|name| is_semantic_type_class(name) && !is_fn_like(name));

        constraints.extend(semantic_names.map(|name| (bounded.clone(), name)));
    }
    constraints
}
/// Reports whether the impl's self type is a tuple with two or more elements.
fn is_tuple_closure(val_impl: &syn::ItemImpl) -> bool {
    matches!(&*val_impl.self_ty, Type::Tuple(tuple) if tuple.elems.len() >= 2)
}
/// Returns the first closure signature found among the impl's bounds.
///
/// Bounds in the `where` clause are scanned before inline bounds on the
/// generic parameters.
fn extract_single_arrow(
    val_impl: &syn::ItemImpl,
    brand_param: Option<&str>,
) -> Option<DispatchArrow> {
    let generics = &val_impl.generics;

    let where_bounds = generics
        .where_clause
        .iter()
        .flat_map(|where_clause| &where_clause.predicates)
        .filter_map(|predicate| match predicate {
            WherePredicate::Type(pt) => Some(&pt.bounds),
            _ => None,
        })
        .flatten();

    let inline_bounds = generics
        .params
        .iter()
        .filter_map(|param| match param {
            syn::GenericParam::Type(tp) => Some(&tp.bounds),
            _ => None,
        })
        .flatten();

    where_bounds
        .chain(inline_bounds)
        .find_map(|bound| extract_fn_arrow_from_bound(bound, brand_param))
}
/// Builds a synthetic arrow for a tuple-of-closures impl.
///
/// Every closure bound found (first in the `where` clause, then inline on the
/// generic parameters) becomes a `SubArrow` input; the output is that of the
/// last closure found. Returns `None` when there is no closure bound at all.
fn extract_tuple_arrow(
    val_impl: &syn::ItemImpl,
    brand_param: Option<&str>,
) -> Option<DispatchArrow> {
    let generics = &val_impl.generics;
    let mut sub_arrows: Vec<DispatchArrow> = Vec::new();

    if let Some(where_clause) = &generics.where_clause {
        let where_bounds = where_clause
            .predicates
            .iter()
            .filter_map(|predicate| match predicate {
                WherePredicate::Type(pt) => Some(&pt.bounds),
                _ => None,
            })
            .flatten();
        sub_arrows.extend(
            where_bounds.filter_map(|bound| extract_fn_arrow_from_bound(bound, brand_param)),
        );
    }

    let inline_bounds = generics
        .params
        .iter()
        .filter_map(|param| match param {
            syn::GenericParam::Type(tp) => Some(&tp.bounds),
            _ => None,
        })
        .flatten();
    sub_arrows
        .extend(inline_bounds.filter_map(|bound| extract_fn_arrow_from_bound(bound, brand_param)));

    // `?` covers the "no closure bounds" case.
    let output = sub_arrows.last()?.output.clone();
    let inputs = sub_arrows.into_iter().map(DispatchArrowParam::SubArrow).collect();
    Some(DispatchArrow {
        inputs,
        output,
    })
}
/// Converts an `Fn`/`FnMut`/`FnOnce` bound written with parenthesised
/// arguments into a `DispatchArrow`.
///
/// Returns `None` for non-trait bounds, for traits with any other name, and
/// for closure traits not written in the `Fn(A) -> B` form.
fn extract_fn_arrow_from_bound(
    bound: &syn::TypeParamBound,
    brand_param: Option<&str>,
) -> Option<DispatchArrow> {
    let segment = match bound {
        syn::TypeParamBound::Trait(trait_bound) => trait_bound.path.segments.last()?,
        _ => return None,
    };
    if !matches!(segment.ident.to_string().as_str(), "Fn" | "FnMut" | "FnOnce") {
        return None;
    }
    let PathArguments::Parenthesized(signature) = &segment.arguments else {
        return None;
    };

    let output = match &signature.output {
        ReturnType::Type(_, returned) => classify_arrow_output(returned, brand_param),
        // No `-> T` means the closure returns unit.
        ReturnType::Default => ArrowOutput::Plain("()".to_string()),
    };
    let inputs: Vec<DispatchArrowParam> =
        signature.inputs.iter().map(|input| type_to_arrow_param(input, brand_param)).collect();

    Some(DispatchArrow {
        inputs,
        output,
    })
}
fn type_to_arrow_param(
ty: &Type,
brand_param: Option<&str>,
) -> DispatchArrowParam {
if let Type::Path(type_path) = ty {
let segments: Vec<_> = type_path.path.segments.iter().collect();
if let [first_seg, second_seg] = segments.as_slice() {
let first = first_seg.ident.to_string();
let second = second_seg.ident.to_string();
if brand_param.is_some_and(|bp| bp == first) {
return DispatchArrowParam::AssociatedType {
assoc_name: second,
};
}
}
}
DispatchArrowParam::TypeParam(type_to_string(ty))
}
/// Classifies a closure's return type.
///
/// An applied type with at least one argument becomes `BrandApplied` when its
/// constructor is `brand_param`, or `OtherApplied` when the constructor can be
/// named otherwise. Everything else becomes `Plain`.
fn classify_arrow_output(
    ty: &Type,
    brand_param: Option<&str>,
) -> ArrowOutput {
    // NOTE(review): unlike the other renderings in this file, the plain
    // fallback keeps the spaces emitted by `quote!` — confirm this is intended.
    let plain = || ArrowOutput::Plain(quote::quote!(#ty).to_string());

    let Type::Macro(type_macro) = ty else {
        return plain();
    };
    let Some((brand, args)) = get_apply_macro_parameters(type_macro) else {
        return plain();
    };
    if args.is_empty() {
        return plain();
    }

    let rendered: Vec<String> =
        args.iter().map(|arg| quote::quote!(#arg).to_string().replace(' ', "")).collect();
    let brand_name = match &brand {
        Type::Path(tp) => tp.path.segments.last().map(|seg| seg.ident.to_string()),
        _ => None,
    };

    match brand_name {
        Some(name) if brand_param == Some(name.as_str()) => ArrowOutput::BrandApplied(rendered),
        Some(name) => ArrowOutput::OtherApplied {
            brand: name,
            args: rendered,
        },
        None => plain(),
    }
}
/// Lists the trait's type parameters in declaration order, leaving out
/// `markers::FN_BRAND_PARAM`, `markers::MARKER_PARAM` and every parameter
/// named in `container_params`.
fn extract_type_param_order(
    trait_def: &syn::ItemTrait,
    container_params: &[ContainerParam],
) -> Vec<String> {
    let mut order = Vec::new();
    for param in &trait_def.generics.params {
        let syn::GenericParam::Type(tp) = param else {
            continue;
        };
        let name = tp.ident.to_string();
        let hidden = name == markers::FN_BRAND_PARAM
            || name == markers::MARKER_PARAM
            || container_params.iter().any(|container| container.name == name);
        if !hidden {
            order.push(name);
        }
    }
    order
}
/// Collects `(name, rendered arguments)` for each associated type in the impl
/// whose value is an applied type; other associated types are skipped.
fn extract_associated_types(val_impl: &syn::ItemImpl) -> Vec<(String, Vec<String>)> {
    val_impl
        .items
        .iter()
        .filter_map(|item| match item {
            ImplItem::Type(assoc) => {
                extract_apply_type_args(&assoc.ty).map(|args| (assoc.ident.to_string(), args))
            }
            _ => None,
        })
        .collect()
}
/// Renders the arguments of the impl's self type when that self type is an
/// applied macro type; returns `None` otherwise.
///
/// An argument that is itself a macro type is first parsed as `ApplyInput` and
/// resolved through `apply_worker`; when both succeed the resolved value's
/// string form is used as-is. In every other case the argument is rendered
/// with spaces removed.
fn extract_self_type_elements(val_impl: &syn::ItemImpl) -> Option<Vec<String>> {
    let (_, args) = match &*val_impl.self_ty {
        Type::Macro(type_macro) => get_apply_macro_parameters(type_macro)?,
        _ => return None,
    };

    let mut elements = Vec::with_capacity(args.len());
    for arg in &args {
        let resolved = match arg {
            Type::Macro(inner_macro) => syn::parse2::<ApplyInput>(inner_macro.mac.tokens.clone())
                .ok()
                .and_then(|apply_input| apply_worker(apply_input).ok()),
            _ => None,
        };
        elements.push(match resolved {
            Some(expanded) => expanded.to_string(),
            None => quote::quote!(#arg).to_string().replace(' ', ""),
        });
    }
    Some(elements)
}
/// Reports whether `name` is exactly one of the closure traits `Fn`, `FnMut`
/// or `FnOnce`.
fn is_fn_like(name: &str) -> bool {
    matches!(name, "Fn" | "FnMut" | "FnOnce")
}
/// Renders a type through `quote!` and removes every ASCII space from the
/// result.
fn type_to_string(ty: &Type) -> String {
    let rendered = quote::quote!(#ty).to_string();
    rendered.chars().filter(|ch| *ch != ' ').collect()
}
// Unit tests: each case parses a small source snippet with `syn` and checks
// what `analyze_dispatch_traits` extracts from it.
#[cfg(test)]
#[expect(
clippy::unwrap_used,
clippy::expect_used,
clippy::indexing_slicing,
reason = "Tests use panicking operations for brevity and clarity"
)]
mod tests {
use super::*;
// Parses `code` as a whole file and returns its top-level items; panics on
// invalid syntax.
fn make_items(code: &str) -> Vec<Item> {
let file: syn::File = syn::parse_str(code).expect("Failed to parse test code");
file.items
}
// A single-closure impl: the semantic constraint comes from the `Brand`
// bound and the arrow from the `Fn(A) -> B` bound.
#[test]
fn test_simple_dispatch_trait() {
let items = make_items(
r#"
trait FunctorDispatch<'a, Brand, A, B, FA, Marker> {
fn dispatch(self, fa: FA) -> ();
}
impl<'a, Brand, A, B, F> FunctorDispatch<'a, Brand, A, B, (), Val> for F
where
Brand: Functor,
A: 'a,
B: 'a,
F: Fn(A) -> B + 'a,
{
fn dispatch(self, fa: ()) -> () {}
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
assert_eq!(result.len(), 1);
let info = result.get("FunctorDispatch").unwrap();
assert_eq!(info.semantic_constraint.as_deref(), Some("Functor"));
assert!(info.arrow_type.is_some());
assert!(!info.tuple_closure);
let arrow = info.arrow_type.as_ref().unwrap();
assert_eq!(arrow.inputs.len(), 1);
assert!(matches!(arrow.output, ArrowOutput::Plain(ref s) if s == "B"));
}
// An impl without any closure bound yields no arrow.
#[test]
fn test_closureless_dispatch() {
let items = make_items(
r#"
trait AltDispatch<'a, Brand, A, Marker> {
fn dispatch(self, other: Self) -> ();
}
impl<'a, Brand, A> AltDispatch<'a, Brand, A, Val> for ()
where
Brand: Alt,
A: 'a,
{
fn dispatch(self, other: Self) -> () {}
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
let info = result.get("AltDispatch").unwrap();
assert_eq!(info.semantic_constraint.as_deref(), Some("Alt"));
assert!(info.arrow_type.is_none());
assert!(!info.tuple_closure);
}
// A `(F, G)` self type is a tuple closure; each closure bound becomes one
// arrow input.
#[test]
fn test_tuple_closure_dispatch() {
let items = make_items(
r#"
trait BimapDispatch<'a, Brand, A, B, C, D, FA, Marker> {
fn dispatch(self, fa: FA) -> ();
}
impl<'a, Brand, A, B, C, D, F, G>
BimapDispatch<'a, Brand, A, B, C, D, (), Val> for (F, G)
where
Brand: Bifunctor,
A: 'a,
B: 'a,
C: 'a,
D: 'a,
F: Fn(A) -> B + 'a,
G: Fn(C) -> D + 'a,
{
fn dispatch(self, fa: ()) -> () {}
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
let info = result.get("BimapDispatch").unwrap();
assert_eq!(info.semantic_constraint.as_deref(), Some("Bifunctor"));
assert!(info.arrow_type.is_some());
assert!(info.tuple_closure);
let arrow = info.arrow_type.as_ref().unwrap();
assert_eq!(arrow.inputs.len(), 2);
}
// A semantic bound on a non-brand parameter is reported as a secondary
// constraint; lifetime and closure bounds are not.
#[test]
fn test_secondary_constraints() {
let items = make_items(
r#"
trait TraverseDispatch<'a, FnBrand, Brand, A, B, F, FA, Marker> {
fn dispatch(self, fa: FA) -> ();
}
impl<'a, FnBrand, Brand, A, B, F, Func>
TraverseDispatch<'a, FnBrand, Brand, A, B, F, (), Val> for Func
where
Brand: Traversable,
A: 'a,
B: 'a,
F: Applicative,
Func: Fn(A) -> () + 'a,
{
fn dispatch(self, fa: ()) -> () {}
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
let info = result.get("TraverseDispatch").unwrap();
assert_eq!(info.semantic_constraint.as_deref(), Some("Traversable"));
assert_eq!(info.secondary_constraints.len(), 1);
assert_eq!(info.secondary_constraints[0], ("F".to_string(), "Applicative".to_string()));
}
// An `Apply!` argument in the impl's trait path marks the matching trait
// parameter as a container; an `Apply!` return on the brand is `Applied`.
#[test]
fn test_container_params_with_apply() {
let items = make_items(
r#"
trait FunctorDispatch<'a, Brand: Kind_cdc7cd43dac7585f, A: 'a, B: 'a, FA, Marker> {
fn dispatch(self, fa: FA) -> Apply!(<Brand as Kind!( type Of<'a, T: 'a>: 'a; )>::Of<'a, B>);
}
impl<'a, Brand, A, B, F>
FunctorDispatch<
'a,
Brand,
A,
B,
Apply!(<Brand as Kind!( type Of<'a, T: 'a>: 'a; )>::Of<'a, A>),
Val,
> for F
where
Brand: Functor,
A: 'a,
B: 'a,
F: Fn(A) -> B + 'a,
{
fn dispatch(self, fa: Apply!(<Brand as Kind!( type Of<'a, T: 'a>: 'a; )>::Of<'a, A>)) -> Apply!(<Brand as Kind!( type Of<'a, T: 'a>: 'a; )>::Of<'a, B>) { todo!() }
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
let info = result.get("FunctorDispatch").unwrap();
assert_eq!(
info.container_params.len(),
1,
"Expected 1 container param, got {:?}",
info.container_params
);
assert_eq!(info.container_params[0].name, "FA");
assert_eq!(info.container_params[0].element_types, vec!["A".to_string()]);
assert!(
matches!(info.return_structure, ReturnStructure::Applied(ref args) if args == &["B"]),
"Expected Applied([B]), got {:?}",
info.return_structure
);
}
// The brand parameter is found through its `Kind_*` bound even when it is
// not the first type parameter.
#[test]
fn test_brand_param_in_middle_of_param_list() {
let items = make_items(
r#"
trait TraverseDispatch<'a, FnBrand, Brand: Kind_abc123, A: 'a, B: 'a, F, FA, Marker> {
fn dispatch(self, fa: FA) -> ();
}
impl<'a, FnBrand, Brand, A, B, F, Func>
TraverseDispatch<'a, FnBrand, Brand, A, B, F, (), Val> for Func
where
Brand: Traversable,
A: 'a,
B: 'a,
F: Applicative,
Func: Fn(A) -> () + 'a,
{
fn dispatch(self, fa: ()) -> () {}
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
let info = result.get("TraverseDispatch").unwrap();
assert_eq!(info.brand_param, "Brand");
assert_eq!(info.kind_trait_name.as_deref(), Some("Kind_abc123"));
assert_eq!(info.semantic_constraint.as_deref(), Some("Traversable"));
}
// The brand parameter is identified by its bound, not by being called
// `Brand`.
#[test]
fn test_brand_param_with_unusual_name() {
let items = make_items(
r#"
trait MyDispatch<'a, F: Kind_xyz789, A: 'a, B: 'a, FA, Marker> {
fn dispatch(self, fa: FA) -> ();
}
impl<'a, F, A, B, Func>
MyDispatch<'a, F, A, B, (), Val> for Func
where
F: Functor,
A: 'a,
B: 'a,
Func: Fn(A) -> B + 'a,
{
fn dispatch(self, fa: ()) -> () {}
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
let info = result.get("MyDispatch").unwrap();
assert_eq!(info.brand_param, "F");
assert_eq!(info.kind_trait_name.as_deref(), Some("Kind_xyz789"));
assert_eq!(info.semantic_constraint.as_deref(), Some("Functor"));
}
// With no semantic bound on the brand in the impl, the constraint is absent
// while brand and kind are still taken from the trait definition.
#[test]
fn test_no_semantic_constraint() {
let items = make_items(
r#"
trait WeirdDispatch<'a, Brand: Kind_abc123, A: 'a, Marker> {
fn dispatch(self) -> ();
}
impl<'a, Brand, A>
WeirdDispatch<'a, Brand, A, Val> for ()
where
A: 'a,
{
fn dispatch(self) -> () {}
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
let info = result.get("WeirdDispatch").unwrap();
assert_eq!(info.brand_param, "Brand");
assert_eq!(info.kind_trait_name.as_deref(), Some("Kind_abc123"));
assert!(info.semantic_constraint.is_none());
}
// A dispatch trait without a `Val` impl produces no entry.
#[test]
fn test_no_val_impl() {
let items = make_items(
r#"
trait OrphanDispatch<'a, Brand: Kind_abc123, A: 'a, Marker> {
fn dispatch(self) -> ();
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
assert!(
!result.contains_key("OrphanDispatch"),
"Should not find info for dispatch trait without Val impl"
);
}
// Many element-type parameters: the order keeps everything except the
// marker, and a `()` argument is not a container.
#[test]
fn test_extra_type_params_ignored() {
let items = make_items(
r#"
trait BigDispatch<'a, Brand: Kind_abc123, A: 'a, B: 'a, C: 'a, D: 'a, E: 'a, FA, Marker> {
fn dispatch(self, fa: FA) -> ();
}
impl<'a, Brand, A, B, C, D, E, F>
BigDispatch<'a, Brand, A, B, C, D, E, (), Val> for F
where
Brand: Functor,
A: 'a,
B: 'a,
C: 'a,
D: 'a,
E: 'a,
F: Fn(A) -> B + 'a,
{
fn dispatch(self, fa: ()) -> () {}
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
let info = result.get("BigDispatch").unwrap();
assert_eq!(info.brand_param, "Brand");
assert_eq!(info.type_param_order, vec!["Brand", "A", "B", "C", "D", "E", "FA"]);
assert!(info.container_params.is_empty());
}
// Associated types bound to `Apply!` and an `Apply!` self type both expose
// their element types.
#[test]
fn test_associated_types_extracted() {
let items = make_items(
r#"
trait ApplyFirstDispatch<'a, Brand: Kind_abc123, A: 'a, B: 'a, Marker> {
type FB;
fn dispatch(self, fb: Self::FB) -> ();
}
impl<'a, Brand, A, B>
ApplyFirstDispatch<'a, Brand, A, B, Val>
for Apply!(<Brand as Kind!( type Of<'a, T: 'a>: 'a; )>::Of<'a, A>)
where
Brand: ApplyFirst,
A: 'a,
B: 'a,
{
type FB = Apply!(<Brand as Kind!( type Of<'a, T: 'a>: 'a; )>::Of<'a, B>);
fn dispatch(self, fb: Self::FB) -> () {}
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
let info = result.get("ApplyFirstDispatch").unwrap();
assert_eq!(info.associated_types.len(), 1);
assert_eq!(info.associated_types[0].0, "FB");
assert_eq!(info.associated_types[0].1, vec!["B".to_string()]);
assert_eq!(info.self_type_elements, Some(vec!["A".to_string()]));
}
// Container positions index the trait's type parameters (the lifetime is
// not counted).
#[test]
fn test_container_param_position() {
let items = make_items(
r#"
trait Lift2Dispatch<'a, Brand: Kind_abc123, A: 'a, B: 'a, C: 'a, FA, FB, Marker> {
fn dispatch(self, fa: FA, fb: FB) -> ();
}
impl<'a, Brand, A, B, C, F>
Lift2Dispatch<
'a,
Brand,
A,
B,
C,
Apply!(<Brand as Kind!( type Of<'a, T: 'a>: 'a; )>::Of<'a, A>),
Apply!(<Brand as Kind!( type Of<'a, T: 'a>: 'a; )>::Of<'a, B>),
Val,
> for F
where
Brand: Lift,
A: 'a,
B: 'a,
C: 'a,
F: Fn(A, B) -> C + 'a,
{
fn dispatch(self, fa: Apply!(<Brand as Kind!( type Of<'a, T: 'a>: 'a; )>::Of<'a, A>), fb: Apply!(<Brand as Kind!( type Of<'a, T: 'a>: 'a; )>::Of<'a, B>)) -> () { todo!() }
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
let info = result.get("Lift2Dispatch").unwrap();
assert_eq!(info.container_params.len(), 2);
assert_eq!(info.container_params[0].name, "FA");
assert_eq!(info.container_params[0].position, 4);
assert_eq!(info.container_params[0].element_types, vec!["A".to_string()]);
assert_eq!(info.container_params[1].name, "FB");
assert_eq!(info.container_params[1].position, 5);
assert_eq!(info.container_params[1].element_types, vec!["B".to_string()]);
}
// The order follows the trait definition, minus `FnBrand` and `Marker`.
#[test]
fn test_type_param_order_preserves_trait_definition_order() {
let items = make_items(
r#"
trait WiltDispatch<'a, FnBrand, Brand: Kind_abc123, M, A: 'a, E: 'a, O: 'a, FA, Marker> {
fn dispatch(self, ta: FA) -> ();
}
impl<'a, FnBrand, Brand, M, A, E, O, Func>
WiltDispatch<'a, FnBrand, Brand, M, A, E, O, (), Val> for Func
where
Brand: Witherable,
A: 'a,
E: 'a,
O: 'a,
M: Applicative,
Func: Fn(A) -> () + 'a,
{
fn dispatch(self, ta: ()) -> () {}
}
struct Val;
"#,
);
let result = analyze_dispatch_traits(&items);
let info = result.get("WiltDispatch").unwrap();
assert_eq!(info.type_param_order, vec!["Brand", "M", "A", "E", "O", "FA"]);
}
}