use {
crate::{
analysis::{
format_brand_name,
get_type_parameters,
},
core::{
config::Config,
constants::{
macros,
types,
},
error_handling::ErrorCollector,
},
hkt::ApplyInput,
resolution::ProjectionKey,
},
quote::quote,
std::collections::HashMap,
syn::{
Error,
GenericParam,
Signature,
parse_quote,
spanned::Spanned,
visit_mut::{
self,
VisitMut,
},
},
};
/// Returns the brand name for a path type, formatted via `format_brand_name`.
///
/// The *last* path segment is used, so a qualified path such as
/// `std::vec::Vec<A>` yields the same name as `Vec<A>`. This matches how
/// `get_self_type_info` extracts a base name. Non-path types yield `None`.
pub fn get_concrete_type_name(
    ty: &syn::Type,
    config: &Config,
) -> Option<String> {
    match ty {
        syn::Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map(|segment| format_brand_name(&segment.ident.to_string(), config)),
        _ => None,
    }
}
/// Splits an impl's self type into its base name and the impl's type-parameter names.
///
/// The base name is the identifier of the last path segment (`Foo` for `a::b::Foo<T>`).
/// It is `None` when `self_ty` is not a path type. The parameter names are whatever
/// `get_type_parameters` reports for `impl_generics`.
pub fn get_self_type_info(
    self_ty: &syn::Type,
    impl_generics: &syn::Generics,
) -> (Option<String>, Vec<String>) {
    let base_name = if let syn::Type::Path(syn::TypePath { path, .. }) = self_ty {
        path.segments.last().map(|segment| segment.ident.to_string())
    } else {
        None
    };
    (base_name, get_type_parameters(impl_generics))
}
/// Builds the type `Base` or `Base<P0, P1, ...>` from identifier strings.
///
/// All identifiers get `Span::call_site()`. `syn::Ident::new` panics if `base_name`
/// or any entry of `generic_params` is not a valid identifier.
pub fn build_parameterized_type(
    base_name: &str,
    generic_params: &[String],
) -> syn::Type {
    let span = proc_macro2::Span::call_site();
    let head = syn::Ident::new(base_name, span);
    if generic_params.is_empty() {
        return parse_quote!(#head);
    }
    let args = generic_params.iter().map(|name| syn::Ident::new(name, span));
    parse_quote!(#head<#(#args),*>)
}
/// Prepends the impl's generic parameters to a method signature's own parameters.
///
/// The combined list is ordered lifetimes, then types, then consts. Within each kind,
/// impl parameters come before the method's, in their original order. Any impl
/// `where` predicates are appended to the signature's `where` clause.
pub fn merge_generics(
    sig: &mut Signature,
    impl_generics: &syn::Generics,
) {
    /// Sort key placing lifetimes first, then types, then consts.
    fn kind_rank(param: &GenericParam) -> u8 {
        match param {
            GenericParam::Lifetime(_) => 0,
            GenericParam::Type(_) => 1,
            GenericParam::Const(_) => 2,
        }
    }

    let mut combined: Vec<GenericParam> = impl_generics
        .params
        .iter()
        .chain(sig.generics.params.iter())
        .cloned()
        .collect();
    // `sort_by_key` is stable, so relative order within each kind is preserved.
    combined.sort_by_key(kind_rank);
    sig.generics.params = combined.into_iter().collect();

    if let Some(impl_where) = &impl_generics.where_clause {
        sig.generics
            .make_where_clause()
            .predicates
            .extend(impl_where.predicates.iter().cloned());
    }
}
/// `VisitMut` pass that rewrites `Self`-based types into concrete types.
///
/// It handles bare `Self`, `Self::Assoc<..>` and `self` receivers, using the projection
/// and default tables stored in [`Config`].
pub struct SelfSubstitutor<'a> {
    /// Errors recorded while resolving `Self` (missing defaults, unknown associated
    /// types). Public so the caller can read them after the visit.
    pub errors: ErrorCollector,
    /// Configuration read for `projections`, `scoped_defaults` and `module_defaults`.
    config: &'a Config,
    /// The impl's self type, used verbatim when no base name is available.
    self_ty: &'a syn::Type,
    /// Key identifying the self type in the configuration tables.
    self_ty_path: &'a str,
    /// Path of the trait being implemented, if any; enables trait-scoped lookups.
    trait_path: Option<&'a str>,
    /// Explicit associated-type name for bare `Self`; takes precedence over defaults.
    document_use: Option<&'a str>,
    /// Hash set while visiting inside the apply macro (`macros::APPLY_MACRO`);
    /// refines module-level projection lookups.
    signature_hash: Option<u64>,
    /// Identifier used to rebuild the self type as `Base<Params>`, if known.
    base_type_name: Option<String>,
    /// Names of the impl's type parameters, used as `base_type_name`'s arguments.
    impl_generic_params: Vec<String>,
}
impl<'a> SelfSubstitutor<'a> {
    /// Creates a substitutor with no active signature hash and an empty error collector.
    pub fn new(
        self_ty: &'a syn::Type,
        self_ty_path: &'a str,
        trait_path: Option<&'a str>,
        document_use: Option<&'a str>,
        config: &'a Config,
        base_type_name: Option<String>,
        impl_generic_params: Vec<String>,
    ) -> Self {
        let errors = ErrorCollector::new();
        Self {
            self_ty,
            self_ty_path,
            trait_path,
            document_use,
            signature_hash: None,
            config,
            errors,
            base_type_name,
            impl_generic_params,
        }
    }

    /// Picks the associated-type name that bare `Self` should resolve to.
    ///
    /// Priority:
    /// 1. The explicit `document_use`.
    /// 2. The trait-scoped default for `(self_ty_path, trait_path)`.
    /// 3. The module-level default for `self_ty_path`.
    fn resolve_default_assoc_name(&self) -> Option<String> {
        if let Some(explicit) = self.document_use {
            return Some(explicit.to_string());
        }
        let trait_scoped = self.trait_path.and_then(|trait_path| {
            let key = (self.self_ty_path.to_string(), trait_path.to_string());
            self.config.scoped_defaults.get(&key).cloned()
        });
        trait_scoped.or_else(|| self.config.module_defaults.get(self.self_ty_path).cloned())
    }

    /// Finds the `(generics, target)` projection registered for `assoc_name`.
    ///
    /// Lookup order:
    /// 1. The trait-scoped key (when a trait path is known).
    /// 2. The module-level key, which carries the active signature hash if one is set.
    /// 3. Only when a hash is set: the unhashed module-level key. Debug builds print a
    ///    warning when this legacy fallback hits.
    fn lookup_projection(
        &self,
        assoc_name: &str,
    ) -> Option<&(syn::Generics, syn::Type)> {
        let projections = &self.config.projections;

        if let Some(trait_path) = self.trait_path {
            let scoped_key = ProjectionKey::scoped(self.self_ty_path, trait_path, assoc_name);
            if let Some(found) = projections.get(&scoped_key) {
                return Some(found);
            }
        }

        // Without a hash, the plain module-level key is the final candidate.
        let Some(hash) = self.signature_hash else {
            return projections.get(&ProjectionKey::new(self.self_ty_path, assoc_name));
        };

        let hashed_key =
            ProjectionKey::new(self.self_ty_path, assoc_name).with_signature_hash(hash);
        if let Some(found) = projections.get(&hashed_key) {
            return Some(found);
        }

        let legacy = projections.get(&ProjectionKey::new(self.self_ty_path, assoc_name));
        #[cfg(debug_assertions)]
        {
            if legacy.is_some() {
                eprintln!(
                    "Warning: Signature hash lookup failed for {}.{assoc_name}, falling back to legacy lookup",
                    self.self_ty_path
                );
            }
        }
        legacy
    }

    /// Returns `Base<Params>` when a base type name is known, else the raw self type.
    fn build_fallback_type(&self) -> syn::Type {
        match self.base_type_name.as_deref() {
            Some(base_name) => build_parameterized_type(base_name, &self.impl_generic_params),
            None => self.self_ty.clone(),
        }
    }

    /// Resolves a bare `Self` type.
    ///
    /// If a default associated type exists, its projection target is used; if that
    /// projection is missing, the fallback type is used. Without a default, the
    /// fallback type is used when a base name is known. Otherwise an error is recorded
    /// and the raw self type is returned.
    fn resolve_bare_self(
        &mut self,
        tp: &syn::TypePath,
    ) -> syn::Type {
        let Some(assoc_name) = self.resolve_default_assoc_name() else {
            if self.base_type_name.is_none() {
                self.errors.push(create_missing_default_error(
                    tp.span(),
                    self.self_ty_path,
                    self.trait_path,
                    self.config,
                ));
                return self.self_ty.clone();
            }
            return self.build_fallback_type();
        };
        match self.lookup_projection(&assoc_name) {
            Some((_, target)) => target.clone(),
            None => self.build_fallback_type(),
        }
    }

    /// Resolves `Self::Assoc<..>` using the projection registered for `Assoc`.
    ///
    /// Angle-bracketed arguments are substituted into the projection's generics. When
    /// no projection exists, an error is recorded and `<SelfTy>::Assoc<..>` (the path
    /// without its leading `Self`) is returned instead.
    fn resolve_self_assoc_type(
        &mut self,
        tp: &syn::TypePath,
        segment: &syn::PathSegment,
    ) -> syn::Type {
        let assoc_name = segment.ident.to_string();
        match self.lookup_projection(&assoc_name) {
            Some((generics, target)) => match &segment.arguments {
                syn::PathArguments::AngleBracketed(bracketed) => {
                    substitute_generics(target.clone(), generics, &bracketed.args)
                }
                _ => target.clone(),
            },
            None => {
                self.errors.push(create_missing_assoc_type_error(
                    tp.span(),
                    self.self_ty_path,
                    &assoc_name,
                    self.trait_path,
                    self.config,
                ));
                let self_ty = self.self_ty;
                let rest = tp.path.segments.iter().skip(1);
                parse_quote!(<#self_ty>::#(#rest)::*)
            }
        }
    }
}
impl<'a> VisitMut for SelfSubstitutor<'a> {
    /// Replaces bare `Self` and `Self::Assoc..` type paths, then recurses into the
    /// (possibly replaced) type so nested occurrences are handled as well.
    fn visit_type_mut(
        &mut self,
        i: &mut syn::Type,
    ) {
        if let syn::Type::Path(tp) = i {
            if tp.path.is_ident(types::SELF) {
                *i = self.resolve_bare_self(tp);
            } else if let Some(first) = tp.path.segments.first()
                && first.ident == types::SELF
                && tp.path.segments.len() > 1
            {
                #[expect(clippy::indexing_slicing, reason = "segments.len() > 1 checked above")]
                let segment = &tp.path.segments[1];
                *i = self.resolve_self_assoc_type(tp, segment);
            }
        }
        visit_mut::visit_type_mut(self, i);
    }

    /// Handles type macros named `macros::APPLY_MACRO` whose tokens parse as `ApplyInput`.
    ///
    /// 1. Hashes the signature of the associated type named by the macro (if found).
    /// 2. Visits the brand and the type arguments with that hash active.
    /// 3. Rewrites the tokens to `<brand as Kind!(kind_input)>::assoc args`.
    ///
    /// The previous hash is restored afterwards so nested macros do not leak state.
    fn visit_type_macro_mut(
        &mut self,
        i: &mut syn::TypeMacro,
    ) {
        if i.mac.path.is_ident(macros::APPLY_MACRO)
            && let Ok(mut apply_input) = syn::parse2::<ApplyInput>(i.mac.tokens.clone())
        {
            let prev_hash = self.signature_hash;
            let current_hash = match apply_input
                .kind_input
                .associated_types
                .iter()
                .find(|a| a.signature.name == apply_input.assoc_name)
            {
                Some(a) => match crate::hkt::canonicalizer::hash_assoc_signature(&a.signature) {
                    Ok(h) => Some(h),
                    Err(e) => {
                        self.errors.push(syn::Error::new(
                            a.signature.name.span(),
                            format!("Failed to compute signature hash: {e}"),
                        ));
                        None
                    }
                },
                None => None,
            };
            self.signature_hash = current_hash;
            self.visit_type_mut(&mut apply_input.brand);
            for arg in apply_input.args.args.iter_mut() {
                if let syn::GenericArgument::Type(ty) = arg {
                    self.visit_type_mut(ty);
                }
            }
            let brand = &apply_input.brand;
            let kind_input = &apply_input.kind_input;
            let assoc_name = &apply_input.assoc_name;
            let args = &apply_input.args;
            i.mac.tokens = quote! { <#brand as Kind!(#kind_input)>::#assoc_name #args };
            self.signature_hash = prev_hash;
        }
        visit_mut::visit_type_macro_mut(self, i);
    }

    /// Turns every `self` receiver into an explicitly typed `self: T` argument.
    ///
    /// `T` is built from the concrete self type given by `build_fallback_type`:
    /// - `self` becomes `self: Concrete`, and `&'l self` / `&'l mut self` become
    ///   `self: &'l Concrete` / `self: &'l mut Concrete`.
    /// - An explicitly typed receiver (`self: Box<Self>`, `self: &Self`, ...) keeps its
    ///   declared type, with each bare `Self` in it replaced by the concrete type.
    ///
    /// `mut` on a by-value receiver is not carried over. The remaining signature is
    /// then visited normally, which resolves any `Self::Assoc` left in receiver types.
    fn visit_signature_mut(
        &mut self,
        i: &mut Signature,
    ) {
        /// Replaces every bare (unqualified) `Self` type inside `ty` with `concrete`.
        fn replace_bare_self(
            ty: &mut syn::Type,
            concrete: &syn::Type,
        ) {
            struct BareSelfReplacer<'c> {
                concrete: &'c syn::Type,
            }
            impl VisitMut for BareSelfReplacer<'_> {
                fn visit_type_mut(
                    &mut self,
                    node: &mut syn::Type,
                ) {
                    if let syn::Type::Path(tp) = node
                        && tp.qself.is_none()
                        && tp.path.is_ident(types::SELF)
                    {
                        *node = self.concrete.clone();
                        return;
                    }
                    visit_mut::visit_type_mut(self, node);
                }
            }
            let mut replacer = BareSelfReplacer {
                concrete,
            };
            replacer.visit_type_mut(ty);
        }

        for input in &mut i.inputs {
            if let syn::FnArg::Receiver(r) = input {
                let concrete_ty = self.build_fallback_type();
                let ty: syn::Type = if r.colon_token.is_some() {
                    // Explicit receiver type: keep its wrapper (Box, Rc, Pin, &, ...).
                    let mut explicit = (*r.ty).clone();
                    replace_bare_self(&mut explicit, &concrete_ty);
                    explicit
                } else if let Some((_, lt)) = &r.reference {
                    let mutability = &r.mutability;
                    parse_quote!(&#lt #mutability #concrete_ty)
                } else {
                    concrete_ty
                };
                let pat: syn::Pat = parse_quote!(self);
                *input = syn::FnArg::Typed(syn::PatType {
                    attrs: r.attrs.clone(),
                    pat: Box::new(pat),
                    colon_token: Default::default(),
                    ty: Box::new(ty),
                });
            }
        }
        visit_mut::visit_signature_mut(self, i);
    }
}
pub fn type_uses_self_assoc(ty: &syn::Type) -> bool {
struct SelfAssocVisitor {
found: bool,
}
impl syn::visit::Visit<'_> for SelfAssocVisitor {
fn visit_type_path(
&mut self,
i: &syn::TypePath,
) {
if let Some(first) = i.path.segments.first()
&& first.ident == types::SELF
&& i.path.segments.len() > 1
{
self.found = true;
}
syn::visit::visit_type_path(self, i);
}
}
let mut visitor = SelfAssocVisitor {
found: false,
};
syn::visit::visit_type(&mut visitor, ty);
visitor.found
}
pub(crate) fn substitute_generics(
mut ty: syn::Type,
generics: &syn::Generics,
args: &syn::punctuated::Punctuated<syn::GenericArgument, syn::token::Comma>,
) -> syn::Type {
let mut mapping = HashMap::new();
let mut const_mapping = HashMap::new();
for (param, arg) in generics.params.iter().zip(args.iter()) {
match (param, arg) {
(syn::GenericParam::Type(tp), syn::GenericArgument::Type(at)) => {
mapping.insert(tp.ident.to_string(), at.clone());
}
(syn::GenericParam::Const(cp), syn::GenericArgument::Const(ca)) => {
const_mapping.insert(cp.ident.to_string(), ca.clone());
}
(syn::GenericParam::Const(cp), syn::GenericArgument::Type(syn::Type::Path(tp)))
if tp.path.get_ident().is_some() =>
{
if let Some(ident) = tp.path.get_ident() {
const_mapping.insert(cp.ident.to_string(), syn::parse_quote!(#ident));
}
}
_ => {}
}
}
struct SubstitutionVisitor<'a> {
mapping: &'a HashMap<String, syn::Type>,
const_mapping: &'a HashMap<String, syn::Expr>,
}
impl VisitMut for SubstitutionVisitor<'_> {
fn visit_type_mut(
&mut self,
i: &mut syn::Type,
) {
if let syn::Type::Path(tp) = i
&& let Some(ident) = tp.path.get_ident()
&& let Some(target) = self.mapping.get(&ident.to_string())
{
*i = target.clone();
return;
}
visit_mut::visit_type_mut(self, i);
}
fn visit_expr_mut(
&mut self,
i: &mut syn::Expr,
) {
if let syn::Expr::Path(ep) = i
&& let Some(ident) = ep.path.get_ident()
&& let Some(target) = self.const_mapping.get(&ident.to_string())
{
*i = target.clone();
return;
}
visit_mut::visit_expr_mut(self, i);
}
}
let mut visitor = SubstitutionVisitor {
mapping: &mapping,
const_mapping: &const_mapping,
};
visitor.visit_type_mut(&mut ty);
ty
}
/// Renames the type parameters of `generics` inside `ty` to positional placeholders.
///
/// The placeholders are `T0`, `T1`, …, numbered in declaration order and counting only
/// type parameters. Lifetimes and consts are not touched. A replaced type is not
/// visited further.
pub fn normalize_type(
    mut ty: syn::Type,
    generics: &syn::Generics,
) -> syn::Type {
    let renames: HashMap<String, syn::Type> = generics
        .params
        .iter()
        .filter_map(|param| match param {
            syn::GenericParam::Type(type_param) => Some(&type_param.ident),
            _ => None,
        })
        .enumerate()
        .map(|(position, ident)| {
            let placeholder = quote::format_ident!("T{position}");
            let placeholder_ty: syn::Type = parse_quote!(#placeholder);
            (ident.to_string(), placeholder_ty)
        })
        .collect();

    struct Renamer<'m> {
        renames: &'m HashMap<String, syn::Type>,
    }
    impl VisitMut for Renamer<'_> {
        fn visit_type_mut(
            &mut self,
            node: &mut syn::Type,
        ) {
            let replacement = match node {
                syn::Type::Path(path_ty) => path_ty
                    .path
                    .get_ident()
                    .and_then(|ident| self.renames.get(&ident.to_string()))
                    .cloned(),
                _ => None,
            };
            match replacement {
                Some(new_ty) => *node = new_ty,
                None => visit_mut::visit_type_mut(self, node),
            }
        }
    }

    let mut renamer = Renamer {
        renames: &renames,
    };
    renamer.visit_type_mut(&mut ty);
    ty
}
/// Lists the associated-type names registered in `config.projections` for `self_ty_path`.
///
/// Returns `(in_this_impl, in_other_traits)`, each sorted and deduplicated:
/// - `in_this_impl` holds names whose key is scoped to exactly `trait_path`.
/// - `in_other_traits` holds every other name for the type: keys scoped to a different
///   trait, and module-level keys. When `trait_path` is `None`, every key lands here.
fn get_available_types_for_brand(
    config: &Config,
    self_ty_path: &str,
    trait_path: Option<&str>,
) -> (Vec<String>, Vec<String>) {
    let mut in_this_impl: Vec<String> = Vec::new();
    let mut in_other_traits: Vec<String> = Vec::new();

    let keys_for_type = config.projections.keys().filter(|key| key.type_path() == self_ty_path);
    for key in keys_for_type {
        let from_current_trait = matches!(
            (key.trait_path(), trait_path),
            (Some(t), Some(current)) if t == current
        );
        let bucket = if from_current_trait { &mut in_this_impl } else { &mut in_other_traits };
        bucket.push(key.assoc_name().to_string());
    }

    for names in [&mut in_this_impl, &mut in_other_traits] {
        names.sort();
        names.dedup();
    }
    (in_this_impl, in_other_traits)
}
fn create_missing_default_error(
span: proc_macro2::Span,
self_ty_path: &str,
trait_path: Option<&str>,
config: &Config,
) -> Error {
let (in_this_impl, in_other_traits) =
get_available_types_for_brand(config, self_ty_path, trait_path);
let mut message =
format!("Cannot resolve bare `Self` for type `{self_ty_path}` - no default specified");
if !in_this_impl.is_empty() {
message.push_str(&format!(
r#"
= note: Available in this impl: {}"#,
in_this_impl.join(", ")
));
}
if !in_other_traits.is_empty() {
message.push_str(&format!(
r#"
= note: Available in other traits: {}"#,
in_other_traits.join(", ")
));
}
message.push_str(
r#"
= help: Mark one as default with #[document_default], or use explicit #[document_use = "AssocName"]"#,
);
Error::new(span, message)
}
/// Builds the error reported when `Self::{assoc_name}` has no registered projection.
///
/// The message lists every associated type known for `self_ty_path`. Each name appears
/// once, with names from the current impl first. The message ends with an example
/// `impl_kind!` definition the user can add.
fn create_missing_assoc_type_error(
    span: proc_macro2::Span,
    self_ty_path: &str,
    assoc_name: &str,
    trait_path: Option<&str>,
    config: &Config,
) -> Error {
    let (in_this_impl, in_other_traits) =
        get_available_types_for_brand(config, self_ty_path, trait_path);
    let mut message = format!("Cannot resolve `Self::{assoc_name}` for type `{self_ty_path}`");

    // A name may be registered both for this impl and for another trait; list it once.
    let mut all_available = in_this_impl;
    for name in in_other_traits {
        if !all_available.contains(&name) {
            all_available.push(name);
        }
    }

    if all_available.is_empty() {
        message.push_str("\n= note: No associated types found for this type");
    } else {
        message.push_str(&format!(
            "\n= note: Available associated types: {}",
            all_available.join(", ")
        ));
    }

    // Inside `format!`, `{{` / `}}` emit a single literal brace. The previous `{{{{`
    // printed `{{` into the user-facing hint.
    message.push_str(&format!(
        "\n= help: Add an associated type definition:\n\
         impl_kind! {{\n\
         for {self_ty_path} {{\n\
         type {assoc_name}<T> = YourType<T>;\n\
         }}\n\
         }}"
    ));
    Error::new(span, message)
}