use std::collections::BTreeMap;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{format_ident, quote};
use syn::{
spanned::Spanned, AngleBracketedGenericArguments, Expr, ExprPath, FnArg, GenericArgument,
GenericParam, Generics, Ident, ImplItem, ImplItemFn, ItemFn, ItemImpl, ItemTrait, Pat,
PathArguments, ReturnType, TraitItem, Type,
};
use crate::error::{syn_err, syn_error_at, Error};
/// Highest rank enumerated per length group when generating impls.
///
/// `RankPolyOpSpec::emit_impls` passes this to `RankSpace::new`, which then
/// enumerates every rank in `0..=MAX_RANK` for each length group.
pub const MAX_RANK: usize = 6;
/// Expands a rank-polymorphic free function into its shadow-dispatch form.
///
/// The output is, in order: the dispatch trait declaration, one impl per
/// rank combination, and the public wrapper function.
///
/// * `item` - the annotated free function.
/// * `method_override` - optional name for the trait method (defaults to the
///   function name).
/// * `trait_name_override` - optional name for the generated trait (defaults
///   to the PascalCase form of the function name).
///
/// # Errors
/// Propagates any error produced by `RankPolyOpSpec::parse`.
pub fn emit_shadow_dispatch(
    item: &ItemFn,
    method_override: Option<Ident>,
    trait_name_override: Option<Ident>,
) -> Result<TokenStream2, Error> {
    RankPolyOpSpec::parse(item, method_override, trait_name_override).map(|spec| {
        let trait_decl = spec.emit_trait();
        let impls = spec.emit_impls();
        let wrapper = spec.emit_wrapper();
        quote! {
            #trait_decl
            #(#impls)*
            #wrapper
        }
    })
}
/// One const-generic array parameter of the form `const X: [i32; N]`,
/// as collected by `find_cgas`.
#[derive(Clone, Debug)]
struct CgaInfo {
// Name of the const generic itself (`X`).
cga_ident: Ident,
// Identifier used as the array length (`N`).
length_ident: Ident,
}
/// Everything `parse` extracts from a rank-polymorphic free function; the
/// `emit_*` methods generate the trait, impls and wrapper from it.
struct RankPolyOpSpec {
// Name of the source function; reused as the wrapper function name.
fn_ident: Ident,
// Trait method name: the override if given, else the function name.
method_ident: Ident,
// Trait name: the override if given, else PascalCase of the function name.
trait_ident: Ident,
// All `const X: [i32; N]` generics found on the function.
cgas: Vec<CgaInfo>,
// The same CGAs grouped by shared length identifier.
length_groups: Vec<LengthGroup>,
// Generic params referenced by at least one non-shape-bearing argument type;
// these are emitted on the trait as well as on the impls.
trait_level_params: Vec<GenericParam>,
// Remaining generic params (CGAs, their length idents and dead lifetimes are
// excluded from both lists); emitted on the impls only.
impl_only_params: Vec<GenericParam>,
// The function's `where` clause, cloned verbatim.
where_clause: Option<syn::WhereClause>,
// Arguments in declaration order.
args: Vec<ParsedArg>,
// First CGA referenced by the return type, if any.
return_cga: Option<Ident>,
// Declared return type; `()` when the function has no explicit return type.
return_type: Type,
// Whether the source function is `unsafe`.
is_unsafe: bool,
// Names of lifetime params that occur in the return type but in no argument type.
dead_lifetimes: Vec<String>,
// Fresh lifetimes `'__td_lt{i}`, one per entry of `dead_lifetimes`.
dead_lt_idents: Vec<syn::Lifetime>,
// One slot per argument: `Some(__td_arg{k})` for a non-shape-bearing argument
// whose type mentions a CGA length identifier, `None` otherwise.
rank_dep_arg_idents: Vec<Option<Ident>>,
}
/// CGAs that share one length identifier, as built by `group_cgas_by_length`.
#[derive(Clone, Debug)]
struct LengthGroup {
// The shared length identifier.
length_ident: Ident,
// CGA identifiers using that length, in order of first appearance.
cgas: Vec<Ident>,
}
/// One typed argument of the source function, as classified by
/// `RankPolyOpSpec::parse`.
#[derive(Debug)]
struct ParsedArg {
// Binding name (only plain identifier patterns are accepted by `parse`).
name: Ident,
// First CGA whose identifier occurs in `ty`; `None` for non-shape-bearing args.
cga: Option<Ident>,
// Declared type of the argument.
ty: Type,
// True for the first shape-bearing argument and for any later shape-bearing
// argument whose type renders identically to it; always false without a CGA.
is_self_like: bool,
}
impl RankPolyOpSpec {
fn parse(
item: &ItemFn,
method_override: Option<Ident>,
trait_name_override: Option<Ident>,
) -> Result<Self, Error> {
let fn_ident = item.sig.ident.clone();
let method_ident = method_override.unwrap_or_else(|| fn_ident.clone());
let trait_ident = trait_name_override
.unwrap_or_else(|| format_ident!("{}", pascal_case(&fn_ident.to_string())));
let cgas = find_cgas(&item.sig.generics, fn_ident.span())?;
if cgas.is_empty() {
return syn_error_at(
fn_ident.span(),
"shadow_dispatch: no `const X: [i32; N]` generic found \
(is this op rank-polymorphic?)",
);
}
let length_groups = group_cgas_by_length(&cgas);
let mut args = Vec::new();
let mut first_shape_bearing_ty: Option<Type> = None;
for arg in &item.sig.inputs {
match arg {
FnArg::Typed(pat_type) => {
let name = match &*pat_type.pat {
Pat::Ident(pat_ident) => pat_ident.ident.clone(),
_ => {
return syn_error_at(
pat_type.pat.span(),
"shadow_dispatch: unsupported argument pattern",
);
}
};
let ty = (*pat_type.ty).clone();
let cga = cgas
.iter()
.find(|c| type_references_ident(&ty, &c.cga_ident))
.map(|c| c.cga_ident.clone());
let is_self_like = if cga.is_some() {
match &first_shape_bearing_ty {
None => {
first_shape_bearing_ty = Some(ty.clone());
true
}
Some(prev) => types_equal(prev, &ty),
}
} else {
false
};
args.push(ParsedArg {
name,
cga,
ty,
is_self_like,
});
}
FnArg::Receiver(r) => {
return syn_error_at(
r.span(),
"shadow_dispatch: free fns only (no `self` receiver)",
);
}
}
}
if first_shape_bearing_ty.is_none() {
return syn_error_at(
fn_ident.span(),
"shadow_dispatch: no shape-bearing argument found",
);
}
let (raw_return_type, return_cga) = match &item.sig.output {
ReturnType::Default => (syn::parse_quote! { () }, None),
ReturnType::Type(_, ty) => {
let cga = cgas
.iter()
.find(|c| type_references_ident(ty, &c.cga_ident))
.map(|c| c.cga_ident.clone());
((**ty).clone(), cga)
}
};
let all_arg_types: Vec<&Type> = args.iter().map(|a| &a.ty).collect();
let dead_lifetimes: Vec<String> = item
.sig
.generics
.params
.iter()
.filter_map(|p| {
if let GenericParam::Lifetime(lt) = p {
let ident_s = lt.lifetime.ident.to_string();
let in_args = all_arg_types
.iter()
.any(|ty| type_references_ident(ty, <.lifetime.ident));
let in_return = type_references_ident(&raw_return_type, <.lifetime.ident);
if in_return && !in_args {
return Some(ident_s);
}
}
None
})
.collect();
let return_type = raw_return_type;
let non_shape_bearing_types: Vec<&Type> = args
.iter()
.filter(|a| a.cga.is_none())
.map(|a| &a.ty)
.collect();
let cga_idents: Vec<&Ident> = cgas.iter().map(|c| &c.cga_ident).collect();
let length_idents: Vec<&Ident> = cgas.iter().map(|c| &c.length_ident).collect();
let mut trait_level_params = Vec::new();
let mut impl_only_params = Vec::new();
for param in &item.sig.generics.params {
if let GenericParam::Lifetime(lt) = param {
if dead_lifetimes
.iter()
.any(|d| d == <.lifetime.ident.to_string())
{
continue;
}
}
let ident = match param {
GenericParam::Type(t) => t.ident.clone(),
GenericParam::Const(c) if cga_idents.iter().any(|i| **i == c.ident) => continue,
GenericParam::Const(c) if length_idents.iter().any(|i| **i == c.ident) => continue,
GenericParam::Const(c) => c.ident.clone(),
GenericParam::Lifetime(lt) => lt.lifetime.ident.clone(),
};
let used_in_non_shape_arg = non_shape_bearing_types
.iter()
.any(|ty| type_references_ident(ty, &ident));
if used_in_non_shape_arg {
trait_level_params.push(param.clone());
} else {
impl_only_params.push(param.clone());
}
}
let dead_lt_idents: Vec<syn::Lifetime> = (0..dead_lifetimes.len())
.map(|i| syn::parse_str(&format!("'__td_lt{}", i)).unwrap())
.collect();
let mut rank_dep_arg_idents: Vec<Option<Ident>> = vec![None; args.len()];
let mut next_ra = 0usize;
for (i, arg) in args.iter().enumerate() {
if arg.cga.is_some() {
continue;
}
let mentions_length = cgas
.iter()
.any(|c| type_references_ident(&arg.ty, &c.length_ident));
if mentions_length {
rank_dep_arg_idents[i] = Some(format_ident!("__td_arg{}", next_ra));
next_ra += 1;
}
}
let is_unsafe = item.sig.unsafety.is_some();
Ok(Self {
fn_ident,
method_ident,
trait_ident,
cgas,
length_groups,
trait_level_params,
impl_only_params,
where_clause: item.sig.generics.where_clause.clone(),
args,
return_cga,
return_type,
dead_lifetimes,
dead_lt_idents,
rank_dep_arg_idents,
is_unsafe,
})
}
/// True when the return type references the same CGA as the first
/// shape-bearing argument *and* renders identically to that argument's type.
///
/// Always false when the return type references no CGA.
///
/// # Panics
/// Panics if the return type references a CGA but no argument is
/// shape-bearing; `parse` rejects such functions before a spec is built.
fn is_same_shape(&self) -> bool {
    if let Some(ret_cga) = self.return_cga.as_ref() {
        let receiver = self.args.iter().find(|arg| arg.cga.is_some()).unwrap();
        let receiver_cga = receiver.cga.as_ref().unwrap();
        return receiver_cga == ret_cga && types_equal(&receiver.ty, &self.return_type);
    }
    false
}
/// True when the return type is the empty tuple `()`.
fn is_void_return(&self) -> bool {
    match &self.return_type {
        Type::Tuple(tuple) => tuple.elems.is_empty(),
        _ => false,
    }
}
/// Decides whether the return type must become a free `Out` trait parameter.
///
/// The return is "free" when it is not same-shape, references a CGA, and
/// either
/// * some impl-only type/const parameter occurs in the return type but in no
///   argument type, or
/// * the CGA referenced by the return type is not the CGA of any argument.
fn return_is_free(&self) -> bool {
    if self.is_same_shape() {
        return false;
    }
    let ret_cga = match &self.return_cga {
        Some(cga) => cga,
        None => return false,
    };

    // Render each argument type once; every parameter is searched in all of them.
    let rendered_args: Vec<String> = self
        .args
        .iter()
        .map(|arg| {
            let ty = &arg.ty;
            quote! { #ty }.to_string()
        })
        .collect();

    let has_return_only_param = self
        .impl_only_params
        .iter()
        .filter_map(|param| match param {
            GenericParam::Type(t) => Some(&t.ident),
            GenericParam::Const(c) => Some(&c.ident),
            GenericParam::Lifetime(_) => None,
        })
        .any(|ident| {
            let word = ident.to_string();
            let seen_in_args = rendered_args
                .iter()
                .any(|text| contains_whole_word(text, &word));
            !seen_in_args && type_references_ident(&self.return_type, ident)
        });

    has_return_only_param || !self.args.iter().any(|arg| arg.cga.as_ref() == Some(ret_cga))
}
/// Emits the dispatch trait declaration.
///
/// Trait generics, in order: dead-lifetime replacements, one `Sh{k}` per
/// extra shape argument, one `__td_arg{k}` per rank-dependent argument, `Out`
/// when the return type is free, then the trait-level params of the source fn.
///
/// The method returns `()`, `Self`, `Out`, or the associated type `Self::Out`
/// depending on how the return type is classified.
///
/// NOTE(review): `trait_level_params` may contain lifetime params, and they
/// are emitted after type params here; rustc requires lifetimes first.
/// Confirm whether that combination can occur in practice.
fn emit_trait(&self) -> TokenStream2 {
    let trait_ident = &self.trait_ident;
    let method_ident = &self.method_ident;
    let where_clause = &self.where_clause;
    let shape_generics = self.extra_shape_trait_generics();
    let method_inputs = self.method_inputs_tokens(&shape_generics);

    let (return_token, assoc_type, needs_free_out) = if self.is_void_return() {
        (quote! { () }, quote! {}, false)
    } else if self.is_same_shape() {
        (quote! { Self }, quote! {}, false)
    } else if self.return_is_free() {
        (quote! { Out }, quote! {}, true)
    } else {
        (quote! { Self::Out }, quote! { type Out; }, false)
    };

    let all_trait_params: Vec<TokenStream2> = self
        .dead_lt_idents
        .iter()
        .map(|lt| quote! { #lt })
        .chain(shape_generics.iter().map(|sh| quote! { #sh }))
        .chain(
            self.rank_dep_arg_idents
                .iter()
                .flatten()
                .map(|id| quote! { #id }),
        )
        .chain(if needs_free_out {
            Some(quote! { Out })
        } else {
            None
        })
        .chain(self.trait_level_params.iter().map(|p| quote! { #p }))
        .collect();

    let unsafe_kw = match self.is_unsafe {
        true => quote! { unsafe },
        false => quote! {},
    };

    quote! {
        #[allow(non_camel_case_types)]
        pub trait #trait_ident < #(#all_trait_params),* >
        #where_clause
        {
            #assoc_type
            #unsafe_kw fn #method_ident(#(#method_inputs),*) -> #return_token;
        }
    }
}
/// Emits one trait impl per rank combination, with every length group
/// ranging over `0..=MAX_RANK`.
fn emit_impls(&self) -> Vec<TokenStream2> {
    let combos = RankSpace::new(&self.length_groups, MAX_RANK).iter();
    combos.iter().map(|combo| self.emit_impl(combo)).collect()
}
/// Emits the trait impl for one concrete rank combination.
///
/// The impl target is the first shape-bearing argument's type rewritten for
/// `combo`. The method body only mentions its arguments and then hits
/// `::std::unreachable!()`.
///
/// Trait instantiation arguments follow the order used by `emit_trait`:
/// dead-lifetime replacements, extra shape types, rank-dependent argument
/// types, the concrete return type when it is free, then trait-level params.
///
/// # Panics
/// Panics if no argument is shape-bearing; `parse` rejects such functions.
fn emit_impl(&self, combo: &RankCombo) -> TokenStream2 {
    let trait_ident = &self.trait_ident;
    let method_ident = &self.method_ident;
    let where_clause = &self.where_clause;
    let impl_only = &self.impl_only_params;
    let trait_level = &self.trait_level_params;
    let dim_params: Vec<TokenStream2> = combo.dim_params_tokens().into_iter().collect();
    let first_shape_bearing = self.args.iter().find(|a| a.cga.is_some()).unwrap();
    let self_type = rewrite_ty_for_rank(&first_shape_bearing.ty, combo, &self.cgas);

    // Concrete types bound to the trait's `Sh{k}` parameters.
    let extra_shape_bindings: Vec<Type> = self
        .args
        .iter()
        .filter(|a| a.cga.is_some() && !a.is_self_like)
        .map(|a| rewrite_ty_for_rank(&a.ty, combo, &self.cgas))
        .collect();

    // Trait-level params are passed through by name.
    let trait_level_args = self
        .trait_level_params
        .iter()
        .map(|p| match p {
            GenericParam::Type(t) => {
                let i = &t.ident;
                quote! { #i }
            }
            GenericParam::Const(c) => {
                let i = &c.ident;
                quote! { #i }
            }
            GenericParam::Lifetime(lt) => {
                let l = &lt.lifetime;
                quote! { #l }
            }
        })
        .collect::<Vec<_>>();

    // Rewrite the return type for this rank, then swap each dead lifetime for
    // its replacement.
    let mut return_concrete = rewrite_ty_for_rank(&self.return_type, combo, &self.cgas);
    for (orig, replacement) in self.dead_lifetimes.iter().zip(self.dead_lt_idents.iter()) {
        return_concrete =
            replace_lifetimes_with(&return_concrete, &[orig.clone()], replacement);
    }

    let mut trait_instantiation_args: Vec<TokenStream2> = Vec::new();
    for lt in &self.dead_lt_idents {
        trait_instantiation_args.push(quote! { #lt });
    }
    for binding in &extra_shape_bindings {
        trait_instantiation_args.push(quote! { #binding });
    }
    for (i, slot) in self.rank_dep_arg_idents.iter().enumerate() {
        if slot.is_some() {
            let ty = rewrite_ty_for_rank(&self.args[i].ty, combo, &self.cgas);
            trait_instantiation_args.push(quote! { #ty });
        }
    }
    if !self.is_same_shape() && self.return_is_free() {
        trait_instantiation_args.push(quote! { #return_concrete });
    }
    for arg in &trait_level_args {
        trait_instantiation_args.push(arg.clone());
    }

    let method_inputs = self.impl_method_inputs_tokens(combo);
    let method_input_muted = self.method_arg_idents_muted();
    let (return_in_method, assoc_type_item): (TokenStream2, TokenStream2) =
        if self.is_void_return() {
            (quote! { () }, quote! {})
        } else if self.is_same_shape() {
            (quote! { Self }, quote! {})
        } else if self.return_is_free() {
            (quote! { #return_concrete }, quote! {})
        } else {
            (
                quote! { Self::Out },
                quote! { type Out = #return_concrete; },
            )
        };

    let mut all_impl_params: Vec<TokenStream2> = Vec::new();
    for lt in &self.dead_lt_idents {
        all_impl_params.push(quote! { #lt });
    }
    for p in impl_only {
        all_impl_params.push(quote! { #p });
    }
    for p in trait_level {
        all_impl_params.push(quote! { #p });
    }
    for p in &dim_params {
        all_impl_params.push(p.clone());
    }

    let unsafe_kw: TokenStream2 = if self.is_unsafe {
        quote! { unsafe }
    } else {
        quote! {}
    };
    quote! {
        impl < #(#all_impl_params),* > #trait_ident < #(#trait_instantiation_args),* >
        for #self_type
        #where_clause
        {
            #assoc_type_item
            #unsafe_kw fn #method_ident(#(#method_inputs),*) -> #return_in_method {
                let _ = (#(#method_input_muted),*);
                ::std::unreachable!()
            }
        }
    }
}
/// Builds the method parameter list used inside a concrete impl.
///
/// The list starts with `self`, which stands in for the first shape-bearing
/// argument. Later self-like shape-bearing arguments are typed `Self`; every
/// other argument gets its declared type rewritten for `combo`.
fn impl_method_inputs_tokens(&self, combo: &RankCombo) -> Vec<TokenStream2> {
    let mut inputs = vec![quote! { self }];
    let mut receiver_consumed = false;
    for arg in &self.args {
        let name = &arg.name;
        let shape_bearing = arg.cga.is_some();
        if shape_bearing && !receiver_consumed {
            // The receiver is already represented by `self`.
            receiver_consumed = true;
            continue;
        }
        if shape_bearing && arg.is_self_like {
            inputs.push(quote! { #name: Self });
        } else {
            let ty = rewrite_ty_for_rank(&arg.ty, combo, &self.cgas);
            inputs.push(quote! { #name: #ty });
        }
    }
    inputs
}
/// Emits the public wrapper function that forwards to the trait method.
///
/// The wrapper keeps the source function's name and is generic over `__T`,
/// the receiver type. If the first shape-bearing argument is a reference, the
/// receiver is `&'__td_recv __T` or `&'__td_recv mut __T`. Every self-like
/// argument is typed as the receiver, extra shape arguments as `Sh{k}`,
/// rank-dependent arguments as `__td_arg{k}`, and the rest keep their declared
/// type with literal CGAs rewritten.
///
/// The `where` clause requires the receiver type to implement the trait,
/// followed by the source function's own predicates. The body calls the
/// trait method on the first shape-bearing argument, inside `unsafe { .. }`
/// for unsafe functions.
///
/// # Panics
/// Panics if no argument is shape-bearing; `parse` rejects such functions.
fn emit_wrapper(&self) -> TokenStream2 {
    let fn_ident = &self.fn_ident;
    let method_ident = &self.method_ident;
    let trait_ident = &self.trait_ident;
    let trait_params = &self.trait_level_params;
    let where_clause = &self.where_clause;
    let t_ident = Ident::new("__T", Span::call_site());
    let extra_shape_generic_idents = self.extra_shape_trait_generics();
    let first_shape_arg = self.args.iter().find(|a| a.cga.is_some()).unwrap();

    // `Some(true)` = `&mut` receiver, `Some(false)` = `&` receiver,
    // `None` = by-value receiver.
    let recv_ref_kind: Option<bool> = match &first_shape_arg.ty {
        Type::Reference(r) => Some(r.mutability.is_some()),
        _ => None,
    };
    let recv_lifetime: Option<syn::Lifetime> =
        recv_ref_kind.map(|_| syn::parse_quote! { '__td_recv });
    let receiver_ty_expr: TokenStream2 = match (&recv_ref_kind, &recv_lifetime) {
        (Some(true), Some(lt)) => quote! { & #lt mut #t_ident },
        (Some(false), Some(lt)) => quote! { & #lt #t_ident },
        _ => quote! { #t_ident },
    };

    let mut wrapper_args = Vec::new();
    let mut extra_shape_idx = 0;
    for (i, arg) in self.args.iter().enumerate() {
        let name = &arg.name;
        if arg.cga.is_some() {
            if arg.is_self_like {
                // The receiver and every argument of the same type share the
                // receiver type expression.
                wrapper_args.push(quote! { #name: #receiver_ty_expr });
            } else {
                let sh = &extra_shape_generic_idents[extra_shape_idx];
                extra_shape_idx += 1;
                wrapper_args.push(quote! { #name: #sh });
            }
        } else if let Some(rd) = &self.rank_dep_arg_idents[i] {
            wrapper_args.push(quote! { #name: #rd });
        } else {
            let ty = rewrite_literal_cgas_only(&arg.ty);
            wrapper_args.push(quote! { #name: #ty });
        }
    }

    // Trait-level params are passed through by name.
    let trait_level_args = trait_params
        .iter()
        .map(|p| match p {
            GenericParam::Type(t) => {
                let i = &t.ident;
                quote! { #i }
            }
            GenericParam::Const(c) => {
                let i = &c.ident;
                quote! { #i }
            }
            GenericParam::Lifetime(lt) => {
                let l = &lt.lifetime;
                quote! { #l }
            }
        })
        .collect::<Vec<_>>();

    let out_ident = Ident::new("Out", Span::call_site());
    let use_free_out = !self.is_same_shape() && self.return_is_free();

    // Arguments for the trait bound, in the order `emit_trait` declares them.
    let mut trait_args: Vec<TokenStream2> = Vec::new();
    for lt in &self.dead_lt_idents {
        trait_args.push(quote! { #lt });
    }
    for i in &extra_shape_generic_idents {
        trait_args.push(quote! { #i });
    }
    for slot in &self.rank_dep_arg_idents {
        if let Some(id) = slot {
            trait_args.push(quote! { #id });
        }
    }
    if use_free_out {
        trait_args.push(quote! { #out_ident });
    }
    for arg in &trait_level_args {
        trait_args.push(arg.clone());
    }

    let wrapper_return = if self.is_void_return() {
        quote! { () }
    } else if self.is_same_shape() {
        receiver_ty_expr.clone()
    } else if use_free_out {
        quote! { #out_ident }
    } else {
        quote! { <#receiver_ty_expr as #trait_ident < #(#trait_args),* >>::Out }
    };

    let receiver_name = first_shape_arg.name.clone();
    // Every argument except the receiver is forwarded positionally.
    let mut method_call_args: Vec<TokenStream2> = Vec::new();
    let mut seen_receiver = false;
    for arg in &self.args {
        if !seen_receiver && arg.cga.is_some() {
            seen_receiver = true;
            continue;
        }
        let n = &arg.name;
        method_call_args.push(quote! { #n });
    }

    let mut all_wrapper_generics: Vec<TokenStream2> = Vec::new();
    if let Some(lt) = &recv_lifetime {
        all_wrapper_generics.push(quote! { #lt });
    }
    for lt in &self.dead_lt_idents {
        all_wrapper_generics.push(quote! { #lt });
    }
    all_wrapper_generics.push(quote! { #t_ident });
    for i in &extra_shape_generic_idents {
        all_wrapper_generics.push(quote! { #i });
    }
    for slot in &self.rank_dep_arg_idents {
        if let Some(id) = slot {
            all_wrapper_generics.push(quote! { #id });
        }
    }
    if use_free_out {
        all_wrapper_generics.push(quote! { #out_ident });
    }
    for p in trait_params.iter() {
        all_wrapper_generics.push(quote! { #p });
    }

    let wrapper_where = if let Some(wc) = where_clause {
        let preds = &wc.predicates;
        quote! { where #receiver_ty_expr: #trait_ident < #(#trait_args),* >, #preds }
    } else {
        quote! { where #receiver_ty_expr: #trait_ident < #(#trait_args),* > }
    };
    let unsafe_kw: TokenStream2 = if self.is_unsafe {
        quote! { unsafe }
    } else {
        quote! {}
    };
    let body_call: TokenStream2 = if self.is_unsafe {
        quote! { unsafe { #receiver_name.#method_ident( #(#method_call_args),* ) } }
    } else {
        quote! { #receiver_name.#method_ident( #(#method_call_args),* ) }
    };
    quote! {
        pub #unsafe_kw fn #fn_ident < #(#all_wrapper_generics),* > ( #(#wrapper_args),* ) -> #wrapper_return
        #wrapper_where
        {
            #body_call
        }
    }
}
/// Builds the method parameter list used in the trait declaration.
///
/// The list starts with `self`, which stands in for the first shape-bearing
/// argument. Later shape-bearing arguments are typed `Self` when self-like,
/// otherwise they take the next entry of `extra_shape_generics`. Other
/// arguments use their `__td_arg{k}` placeholder if they have one, else their
/// declared type with literal CGAs rewritten.
///
/// # Panics
/// Panics if `extra_shape_generics` has fewer entries than there are
/// non-self-like shape-bearing arguments.
fn method_inputs_tokens(&self, extra_shape_generics: &[Ident]) -> Vec<TokenStream2> {
    let mut inputs = vec![quote! { self }];
    let mut receiver_consumed = false;
    let mut next_shape = 0usize;
    for (slot, arg) in self.args.iter().enumerate() {
        let name = &arg.name;
        if arg.cga.is_none() {
            let input = match &self.rank_dep_arg_idents[slot] {
                Some(placeholder) => quote! { #name: #placeholder },
                None => {
                    let ty = rewrite_literal_cgas_only(&arg.ty);
                    quote! { #name: #ty }
                }
            };
            inputs.push(input);
            continue;
        }
        if !receiver_consumed {
            // The receiver is already represented by `self`.
            receiver_consumed = true;
            continue;
        }
        if arg.is_self_like {
            inputs.push(quote! { #name: Self });
        } else {
            let sh = &extra_shape_generics[next_shape];
            next_shape += 1;
            inputs.push(quote! { #name: #sh });
        }
    }
    inputs
}
/// Names of every argument except the first shape-bearing one (the
/// receiver), as tokens. `emit_impl` uses them in a `let _ = (..);`
/// statement inside the generated method body.
fn method_arg_idents_muted(&self) -> Vec<TokenStream2> {
    let mut receiver_skipped = false;
    self.args
        .iter()
        .filter(|arg| {
            if arg.cga.is_some() && !receiver_skipped {
                receiver_skipped = true;
                false
            } else {
                true
            }
        })
        .map(|arg| {
            let name = &arg.name;
            quote! { #name }
        })
        .collect()
}
/// Produces `Sh0`, `Sh1`, ... - one identifier per shape-bearing argument
/// that comes after the first one and is not self-like.
fn extra_shape_trait_generics(&self) -> Vec<Ident> {
    self.args
        .iter()
        .filter(|arg| arg.cga.is_some())
        // The first shape-bearing argument is the receiver.
        .skip(1)
        .filter(|arg| !arg.is_self_like)
        .enumerate()
        .map(|(k, _)| format_ident!("Sh{}", k))
        .collect()
}
}
/// The set of rank assignments to enumerate: every length group takes each
/// rank in `0..=max_rank`.
struct RankSpace {
// group_lengths: length identifier of each group, in group order.
// cga_to_group:  CGA name -> index of its group.
// cgas:          all CGAs, flattened in group order.
group_lengths: Vec<Ident>, cga_to_group: BTreeMap<String, usize>, cgas: Vec<CgaInfo>,
// Inclusive upper bound on the rank assigned to any group.
max_rank: usize,
}
impl RankSpace {
fn new(groups: &[LengthGroup], max_rank: usize) -> Self {
let group_lengths: Vec<Ident> = groups.iter().map(|g| g.length_ident.clone()).collect();
let mut cga_to_group = BTreeMap::new();
let mut cgas = Vec::new();
for (idx, g) in groups.iter().enumerate() {
for cga_ident in &g.cgas {
cga_to_group.insert(cga_ident.to_string(), idx);
cgas.push(CgaInfo {
cga_ident: cga_ident.clone(),
length_ident: g.length_ident.clone(),
});
}
}
Self {
group_lengths,
cga_to_group,
cgas,
max_rank,
}
}
fn iter(&self) -> Vec<RankCombo> {
let n_groups = self.group_lengths.len();
if n_groups == 0 {
return vec![];
}
let mut result = Vec::new();
let per_group_range: usize = self.max_rank + 1;
let total = per_group_range.pow(n_groups as u32);
for code in 0..total {
let mut ranks_per_group = Vec::with_capacity(n_groups);
let mut n = code;
for _ in 0..n_groups {
ranks_per_group.push(n % per_group_range);
n /= per_group_range;
}
result.push(RankCombo {
ranks_per_group,
cga_to_group: self.cga_to_group.clone(),
cgas: self.cgas.clone(),
});
}
result
}
}
/// One concrete rank assignment produced by `RankSpace::iter`.
struct RankCombo {
// Rank chosen for each length group, indexed by group.
ranks_per_group: Vec<usize>,
// CGA name -> index of its group (copied from the `RankSpace`).
cga_to_group: BTreeMap<String, usize>,
// All CGAs (copied from the `RankSpace`).
cgas: Vec<CgaInfo>,
}
impl RankCombo {
    /// Rank assigned to `cga_ident`. A CGA missing from the map falls back to
    /// group 0.
    ///
    /// # Panics
    /// Panics if the resolved group index is out of range for
    /// `ranks_per_group`.
    fn rank_of(&self, cga_ident: &Ident) -> usize {
        let group_idx = self
            .cga_to_group
            .get(&cga_ident.to_string())
            .copied()
            .unwrap_or(0);
        self.ranks_per_group[group_idx]
    }

    /// Rank assigned to the first CGA whose length identifier is
    /// `length_ident`, or `None` when no CGA uses that length.
    fn rank_of_length(&self, length_ident: &Ident) -> Option<usize> {
        self.cgas
            .iter()
            .find(|cga| cga.length_ident == *length_ident)
            .map(|cga| self.rank_of(&cga.cga_ident))
    }

    /// Per-dimension const names `{cga}_0 .. {cga}_{rank-1}` for one CGA.
    fn dim_names_for(&self, cga_ident: &Ident) -> Vec<Ident> {
        let mut names = Vec::new();
        for dim in 0..self.rank_of(cga_ident) {
            names.push(format_ident!("{}_{}", cga_ident, dim));
        }
        names
    }

    /// Maps every CGA name to its per-dimension const names.
    fn dim_names_per_cga(&self) -> BTreeMap<String, Vec<Ident>> {
        self.cgas
            .iter()
            .map(|cga| {
                (
                    cga.cga_ident.to_string(),
                    self.dim_names_for(&cga.cga_ident),
                )
            })
            .collect()
    }

    /// `const {name}: i32` declarations for every dimension of every CGA, in
    /// CGA order.
    fn dim_params_tokens(&self) -> Vec<TokenStream2> {
        self.cgas
            .iter()
            .flat_map(|cga| self.dim_names_for(&cga.cga_ident))
            .map(|name| quote! { const #name: i32 })
            .collect()
    }
}
/// Collects every const generic of the form `const X: [i32; N]`.
///
/// Const params whose type is not an array, whose element type is not the
/// bare path `i32`, or whose length is not a path expression are skipped.
/// `_span` is unused.
///
/// # Errors
/// Fails when the array length is a path that is not a single identifier.
fn find_cgas(generics: &Generics, _span: Span) -> Result<Vec<CgaInfo>, Error> {
    let mut found = Vec::new();
    for param in generics.params.iter() {
        let const_param = match param {
            GenericParam::Const(c) => c,
            _ => continue,
        };
        let array = match &const_param.ty {
            Type::Array(arr) => arr,
            _ => continue,
        };
        let elem_is_i32 = match &*array.elem {
            Type::Path(p) => p.path.is_ident("i32"),
            _ => false,
        };
        if !elem_is_i32 {
            continue;
        }
        let len_path = match &array.len {
            Expr::Path(ExprPath { path, .. }) => path,
            _ => continue,
        };
        let length_ident = len_path
            .get_ident()
            .cloned()
            .ok_or_else(|| syn_err(const_param.ty.span(), "CGA length must be a simple ident"))?;
        found.push(CgaInfo {
            cga_ident: const_param.ident.clone(),
            length_ident,
        });
    }
    Ok(found)
}
/// Groups CGAs by their length identifier.
///
/// Groups are ordered by the first appearance of their length; within a
/// group, CGAs keep their input order.
fn group_cgas_by_length(cgas: &[CgaInfo]) -> Vec<LengthGroup> {
    let mut groups: Vec<LengthGroup> = Vec::new();
    for cga in cgas {
        let existing = groups
            .iter()
            .position(|group| group.length_ident == cga.length_ident);
        match existing {
            Some(idx) => groups[idx].cgas.push(cga.cga_ident.clone()),
            None => groups.push(LengthGroup {
                length_ident: cga.length_ident.clone(),
                cgas: vec![cga.cga_ident.clone()],
            }),
        }
    }
    groups
}
/// Converts a `snake_case` name to `PascalCase`.
///
/// Underscores are dropped and the first character of each underscore-
/// separated segment is upper-cased; all other characters are kept as-is.
fn pascal_case(name: &str) -> String {
    let mut out = String::new();
    for segment in name.split('_') {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
/// Reports whether `ident` occurs as a whole word in the token-stream
/// rendering of `ty`. This is a textual check on the rendered tokens, not a
/// structural walk of the type.
fn type_references_ident(ty: &Type, ident: &Ident) -> bool {
    let rendered = quote! { #ty }.to_string();
    let word = ident.to_string();
    contains_whole_word(&rendered, &word)
}
/// Reports whether `needle` occurs in `haystack` delimited on both sides by
/// a non-word byte (as decided by `is_word_byte`) or by the start/end of the
/// haystack.
fn contains_whole_word(haystack: &str, needle: &str) -> bool {
    let hay = haystack.as_bytes();
    let pat = needle.as_bytes();
    // Last index at which the needle could still fit.
    let last_start = match hay.len().checked_sub(pat.len()) {
        Some(n) => n,
        None => return false,
    };
    (0..=last_start).any(|start| {
        let end = start + pat.len();
        hay[start..end] == *pat
            && (start == 0 || !is_word_byte(hay[start - 1]))
            && (end == hay.len() || !is_word_byte(hay[end]))
    })
}
/// Reports whether `b` can be part of an identifier.
///
/// ASCII letters, digits and `_` count as word bytes. So does every
/// non-ASCII byte: such bytes only occur inside multi-byte UTF-8 sequences,
/// and treating them as boundaries would let `contains_whole_word` match a
/// needle in the middle of a Unicode identifier (e.g. `b` inside `añb`).
fn is_word_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii()
}
/// Compares two types by their token-stream renderings; they are equal when
/// both render to the same string.
fn types_equal(a: &Type, b: &Type) -> bool {
    let lhs = quote! { #a }.to_string();
    let rhs = quote! { #b }.to_string();
    lhs == rhs
}
/// Rewrites a type that carries a *literal* CGA (a `{ [a, b, ..] }` const
/// argument) into its rank-suffixed form, independent of any rank combo.
///
/// For a path type whose last segment has such an argument, the segment
/// `Base` is renamed to `Base_{rank}` and the array argument is replaced by
/// its elements as individual const arguments. An elided `'_` lifetime is
/// prepended when `Base` is listed in `LIFETIME_BEARING_BASES` and no lifetime
/// argument is present. Type arguments are rewritten recursively, as are the
/// inner types of references, arrays and tuples. Any other type is cloned
/// unchanged.
fn rewrite_literal_cgas_only(ty: &Type) -> Type {
match ty {
Type::Path(tp) => {
let mut new_tp = tp.clone();
// Only the last path segment is inspected.
if let Some(last_seg) = new_tp.path.segments.last_mut() {
if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args, ..
}) = &mut last_seg.arguments
{
// The first literal array argument determines the rank suffix.
if let Some(rank) = args.iter().find_map(literal_cga_rank) {
let original_base = last_seg.ident.clone();
last_seg.ident = format_ident!("{}_{}", last_seg.ident, rank);
let already_has_lifetime = args
.iter()
.any(|a| matches!(a, GenericArgument::Lifetime(_)));
let needs_lifetime =
base_has_lifetime(&original_base) && !already_has_lifetime;
let mut new_args: syn::punctuated::Punctuated<
GenericArgument,
syn::Token![,],
> = syn::punctuated::Punctuated::new();
if needs_lifetime {
let lt: syn::Lifetime = syn::parse_quote! { '_ };
new_args.push(GenericArgument::Lifetime(lt));
}
for arg in args.iter() {
// Expand each literal array into one const argument per element.
if let Some(dims) = literal_cga_dims(arg) {
for d in dims {
let dim_expr: Expr = syn::parse_quote! { #d };
new_args.push(GenericArgument::Const(dim_expr));
}
} else if let GenericArgument::Type(t) = arg {
new_args.push(GenericArgument::Type(rewrite_literal_cgas_only(t)));
} else {
new_args.push(arg.clone());
}
}
*args = new_args;
// A rank-0 literal can leave the list empty; drop the `<>` entirely.
if let PathArguments::AngleBracketed(a) = &last_seg.arguments {
if a.args.is_empty() {
last_seg.arguments = PathArguments::None;
}
}
} else {
// No literal CGA on this segment: only recurse into type arguments.
let mut new_args: syn::punctuated::Punctuated<
GenericArgument,
syn::Token![,],
> = syn::punctuated::Punctuated::new();
for arg in args.iter() {
if let GenericArgument::Type(t) = arg {
new_args.push(GenericArgument::Type(rewrite_literal_cgas_only(t)));
} else {
new_args.push(arg.clone());
}
}
*args = new_args;
}
}
}
Type::Path(new_tp)
}
Type::Reference(r) => {
let mut new_r = r.clone();
new_r.elem = Box::new(rewrite_literal_cgas_only(&r.elem));
Type::Reference(new_r)
}
Type::Array(a) => {
let mut new_a = a.clone();
new_a.elem = Box::new(rewrite_literal_cgas_only(&a.elem));
Type::Array(new_a)
}
Type::Tuple(t) => {
let mut new_t = t.clone();
new_t.elems = t.elems.iter().map(rewrite_literal_cgas_only).collect();
Type::Tuple(new_t)
}
other => other.clone(),
}
}
/// Rank of a literal CGA argument: for a const block argument whose first
/// statement is an array expression (`{ [a, b, c] }`), the element count.
/// Any other argument yields `None`.
fn literal_cga_rank(arg: &GenericArgument) -> Option<usize> {
    match arg {
        GenericArgument::Const(Expr::Block(block)) => match block.block.stmts.first() {
            Some(syn::Stmt::Expr(Expr::Array(arr), _)) => Some(arr.elems.len()),
            _ => None,
        },
        _ => None,
    }
}
/// Like `literal_cga_rank`, but also understands repeat literals.
///
/// For `{ [expr; N] }` where `N` is a single identifier, the rank is whatever
/// `combo` assigns to the length `N` (or `None` if no CGA uses that length).
fn literal_cga_rank_with_combo(arg: &GenericArgument, combo: &RankCombo) -> Option<usize> {
    literal_cga_rank(arg).or_else(|| {
        let block = match arg {
            GenericArgument::Const(Expr::Block(block)) => block,
            _ => return None,
        };
        let repeat = match block.block.stmts.first() {
            Some(syn::Stmt::Expr(Expr::Repeat(rep), _)) => rep,
            _ => return None,
        };
        match &*repeat.len {
            Expr::Path(p) => p
                .path
                .get_ident()
                .and_then(|ident| combo.rank_of_length(ident)),
            _ => None,
        }
    })
}
/// Elements of a literal CGA argument: for a const block argument whose
/// first statement is an array expression (`{ [a, b, c] }`), clones of its
/// elements in order. Any other argument yields `None`.
fn literal_cga_dims(arg: &GenericArgument) -> Option<Vec<Expr>> {
    match arg {
        GenericArgument::Const(Expr::Block(block)) => match block.block.stmts.first() {
            Some(syn::Stmt::Expr(Expr::Array(arr), _)) => {
                Some(arr.elems.iter().cloned().collect())
            }
            _ => None,
        },
        _ => None,
    }
}
/// Like `literal_cga_dims`, but also expands repeat literals.
///
/// For `{ [expr; N] }` where `N` is a single identifier whose rank `combo`
/// knows, the result is `expr` cloned `rank` times.
fn literal_cga_dims_with_combo(arg: &GenericArgument, combo: &RankCombo) -> Option<Vec<Expr>> {
    literal_cga_dims(arg).or_else(|| {
        let block = match arg {
            GenericArgument::Const(Expr::Block(block)) => block,
            _ => return None,
        };
        let repeat = match block.block.stmts.first() {
            Some(syn::Stmt::Expr(Expr::Repeat(rep), _)) => rep,
            _ => return None,
        };
        let len_ident = match &*repeat.len {
            Expr::Path(p) => p.path.get_ident()?,
            _ => return None,
        };
        let rank = combo.rank_of_length(len_ident)?;
        Some(
            std::iter::repeat_with(|| (*repeat.expr).clone())
                .take(rank)
                .collect(),
        )
    })
}
/// Returns a copy of `ty` with the outermost reference's lifetime set to
/// `lt`. Types that are not references are cloned unchanged.
fn bind_outer_reference_lifetime(ty: &Type, lt: &syn::Lifetime) -> Type {
    match ty {
        Type::Reference(reference) => {
            let mut bound = reference.clone();
            bound.lifetime = Some(lt.clone());
            Type::Reference(bound)
        }
        _ => ty.clone(),
    }
}
/// Returns a copy of `ty` in which every lifetime named in `dead` is
/// replaced by `replacement`.
///
/// * `dead` - lifetime names without the leading apostrophe.
/// * `replacement` - the lifetime substituted for each match.
///
/// Lifetimes are replaced in the generic arguments of every path segment and
/// on references; the walk recurses through type arguments, reference
/// targets, array elements and tuple elements. Other type kinds are cloned
/// unchanged.
fn replace_lifetimes_with(ty: &Type, dead: &[String], replacement: &syn::Lifetime) -> Type {
    // Render the lifetime name once per check instead of once per `dead` entry.
    let is_dead = |lt: &syn::Lifetime| {
        let name = lt.ident.to_string();
        dead.iter().any(|d| *d == name)
    };
    match ty {
        Type::Path(tp) => {
            let mut new_tp = tp.clone();
            for seg in new_tp.path.segments.iter_mut() {
                if let PathArguments::AngleBracketed(ab) = &mut seg.arguments {
                    let mut new_args: syn::punctuated::Punctuated<GenericArgument, syn::Token![,]> =
                        syn::punctuated::Punctuated::new();
                    for arg in ab.args.iter() {
                        match arg {
                            GenericArgument::Lifetime(lt) if is_dead(lt) => {
                                new_args.push(GenericArgument::Lifetime(replacement.clone()));
                            }
                            GenericArgument::Type(t) => new_args.push(GenericArgument::Type(
                                replace_lifetimes_with(t, dead, replacement),
                            )),
                            other => new_args.push(other.clone()),
                        }
                    }
                    ab.args = new_args;
                }
            }
            Type::Path(new_tp)
        }
        Type::Reference(r) => {
            let mut new_r = r.clone();
            if let Some(lt) = &r.lifetime {
                if is_dead(lt) {
                    new_r.lifetime = Some(replacement.clone());
                }
            }
            new_r.elem = Box::new(replace_lifetimes_with(&r.elem, dead, replacement));
            Type::Reference(new_r)
        }
        Type::Array(a) => {
            let mut new_a = a.clone();
            new_a.elem = Box::new(replace_lifetimes_with(&a.elem, dead, replacement));
            Type::Array(new_a)
        }
        Type::Tuple(t) => {
            let mut new_t = t.clone();
            new_t.elems = t
                .elems
                .iter()
                .map(|e| replace_lifetimes_with(e, dead, replacement))
                .collect();
            Type::Tuple(new_t)
        }
        other => other.clone(),
    }
}
/// Base type names for which the rank rewriters prepend an elided `'_`
/// lifetime argument when they rename the type to its rank-suffixed form and
/// the original has no lifetime argument (see `base_has_lifetime`).
const LIFETIME_BEARING_BASES: &[&str] = &["Shape", "Array", "Partition", "PartitionMut"];
/// Reports whether `base` is one of the names in `LIFETIME_BEARING_BASES`.
fn base_has_lifetime(base: &Ident) -> bool {
    let name = base.to_string();
    LIFETIME_BEARING_BASES.contains(&name.as_str())
}
/// Rewrites `ty` into its concrete form for one rank combination.
///
/// For a path type, only the last segment is inspected:
/// * If one of its generic arguments names a CGA, the segment `Base` becomes
///   `Base_{rank}` (rank of the first such CGA) and each CGA argument is
///   replaced by that CGA's per-dimension const names.
/// * Otherwise, if an argument is a literal CGA (`{ [..] }` or `{ [e; N] }`),
///   the segment is renamed likewise and the literal is expanded into its
///   elements.
/// * Otherwise only type arguments are rewritten recursively.
///
/// In both renaming cases an elided `'_` lifetime is prepended when the base
/// is in `LIFETIME_BEARING_BASES` and no lifetime argument exists, and an
/// argument list that ends up empty is removed.
///
/// References, arrays and tuples are rewritten element-wise; an array whose
/// length is a single identifier known to `combo` as a CGA length gets that
/// rank as its length. Any other type is cloned unchanged.
fn rewrite_ty_for_rank(ty: &Type, combo: &RankCombo, cgas: &[CgaInfo]) -> Type {
match ty {
Type::Path(tp) => {
let mut new_tp = tp.clone();
if let Some(last_seg) = new_tp.path.segments.last_mut() {
let original_base = last_seg.ident.clone();
let mut cga_for_this_seg: Option<&CgaInfo> = None;
let mut already_has_lifetime = false;
// First pass (read-only): find the first CGA argument and note whether a
// lifetime argument is already present.
if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args, ..
}) = &last_seg.arguments
{
for arg in args.iter() {
if matches!(arg, GenericArgument::Lifetime(_)) {
already_has_lifetime = true;
}
if cga_for_this_seg.is_none() {
if let Some(cga) = match_arg_to_cga(arg, cgas) {
cga_for_this_seg = Some(cga);
}
}
}
}
if let Some(cga) = cga_for_this_seg {
// Case 1: the segment names a CGA generic.
let rank = combo.rank_of(&cga.cga_ident);
last_seg.ident = format_ident!("{}_{}", last_seg.ident, rank);
let needs_lifetime = base_has_lifetime(&original_base) && !already_has_lifetime;
if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args,
..
}) = &mut last_seg.arguments
{
let mut new_args: syn::punctuated::Punctuated<
GenericArgument,
syn::Token![,],
> = syn::punctuated::Punctuated::new();
if needs_lifetime {
let lt: syn::Lifetime = syn::parse_quote! { '_ };
new_args.push(GenericArgument::Lifetime(lt));
}
for arg in args.iter() {
// Each CGA argument expands to its per-dimension const names.
if let Some(cga_ref) = match_arg_to_cga(arg, cgas) {
let dim_names = combo.dim_names_for(&cga_ref.cga_ident);
for dim_name in dim_names {
let dim_expr: Expr = syn::parse_quote! { #dim_name };
new_args.push(GenericArgument::Const(dim_expr));
}
} else if let GenericArgument::Type(t) = arg {
let rewritten = rewrite_ty_for_rank(t, combo, cgas);
new_args.push(GenericArgument::Type(rewritten));
} else {
new_args.push(arg.clone());
}
}
*args = new_args;
// Rank 0 can leave the list empty; drop the `<>` entirely.
if let PathArguments::AngleBracketed(a) = &last_seg.arguments {
if a.args.is_empty() {
last_seg.arguments = PathArguments::None;
}
}
} else if needs_lifetime {
let args: AngleBracketedGenericArguments = syn::parse_quote! { <'_> };
last_seg.arguments = PathArguments::AngleBracketed(args);
}
} else if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args,
..
}) = &mut last_seg.arguments
{
// Case 2: no CGA generic; look for a literal CGA argument instead.
let literal_rank = args
.iter()
.find_map(|a| literal_cga_rank_with_combo(a, combo));
if let Some(rank) = literal_rank {
let original_base = last_seg.ident.clone();
last_seg.ident = format_ident!("{}_{}", last_seg.ident, rank);
let already_has_lifetime = args
.iter()
.any(|a| matches!(a, GenericArgument::Lifetime(_)));
let needs_lifetime =
base_has_lifetime(&original_base) && !already_has_lifetime;
let mut new_args: syn::punctuated::Punctuated<
GenericArgument,
syn::Token![,],
> = syn::punctuated::Punctuated::new();
if needs_lifetime {
let lt: syn::Lifetime = syn::parse_quote! { '_ };
new_args.push(GenericArgument::Lifetime(lt));
}
for arg in args.iter() {
if let Some(dims) = literal_cga_dims_with_combo(arg, combo) {
for d in dims {
let dim_expr: Expr = syn::parse_quote! { #d };
new_args.push(GenericArgument::Const(dim_expr));
}
} else if let GenericArgument::Type(t) = arg {
let rewritten = rewrite_ty_for_rank(t, combo, cgas);
new_args.push(GenericArgument::Type(rewritten));
} else {
new_args.push(arg.clone());
}
}
*args = new_args;
if let PathArguments::AngleBracketed(a) = &last_seg.arguments {
if a.args.is_empty() {
last_seg.arguments = PathArguments::None;
}
}
} else {
// Case 3: nothing rank-related here; only recurse into type arguments.
let mut new_args: syn::punctuated::Punctuated<
GenericArgument,
syn::Token![,],
> = syn::punctuated::Punctuated::new();
for arg in args.iter() {
if let GenericArgument::Type(t) = arg {
let rewritten = rewrite_ty_for_rank(t, combo, cgas);
new_args.push(GenericArgument::Type(rewritten));
} else {
new_args.push(arg.clone());
}
}
*args = new_args;
}
}
}
Type::Path(new_tp)
}
Type::Reference(r) => {
let mut new_r = r.clone();
new_r.elem = Box::new(rewrite_ty_for_rank(&r.elem, combo, cgas));
Type::Reference(new_r)
}
Type::Array(a) => {
let mut new_a = a.clone();
new_a.elem = Box::new(rewrite_ty_for_rank(&a.elem, combo, cgas));
// `[T; N]` with `N` a CGA length becomes `[T; <rank>]`.
if let Expr::Path(ExprPath { path, .. }) = &a.len {
if let Some(ident) = path.get_ident() {
if let Some(rank) = combo.rank_of_length(ident) {
new_a.len = syn::parse_quote! { #rank };
}
}
}
Type::Array(new_a)
}
Type::Tuple(t) => {
let mut new_t = t.clone();
new_t.elems = t
.elems
.iter()
.map(|e| rewrite_ty_for_rank(e, combo, cgas))
.collect();
Type::Tuple(new_t)
}
other => other.clone(),
}
}
/// Resolves a generic argument to the CGA entry it names, if any.
///
/// Only arguments consisting of a single bare identifier are considered,
/// whether `syn` parsed them as a type path or as a const path expression.
/// Any other argument shape, or an identifier matching no entry in `cgas`,
/// yields `None`.
fn match_arg_to_cga<'a>(arg: &GenericArgument, cgas: &'a [CgaInfo]) -> Option<&'a CgaInfo> {
    let name = match arg {
        GenericArgument::Type(Type::Path(type_path)) => type_path.path.get_ident()?,
        GenericArgument::Const(Expr::Path(ExprPath { path, .. })) => path.get_ident()?,
        _ => return None,
    };
    cgas.iter().find(|info| info.cga_ident == *name)
}
/// Role of a single CGA within a rank-polymorphic trait, as assigned by
/// `classify_cgas`.
#[derive(Clone, Debug)]
enum CgaRole {
/// The CGA is referenced by at least one method argument type. `sh_ident`
/// is the generated `Sh{n}` identifier allocated for it.
ShapeBound { sh_ident: Ident },
/// The CGA is not referenced by any method argument type.
Free,
}
/// Summary of how a trait's (or trait impl's) CGAs are used across its
/// method signatures, as populated by `classify_cgas`.
struct RankPolyShape {
// One role per CGA, in the same order as the `cgas` slice that was classified.
roles: Vec<CgaRole>,
// The `Sh{n}` identifiers generated for shape-bound CGAs, in allocation order.
sh_idents: Vec<Ident>,
// True when at least one CGA ended up with `CgaRole::Free`.
has_free_cga: bool,
// True when any of the supplied return types references any CGA identifier.
any_return_uses_cga: bool,
}
/// Classifies every CGA by how the supplied method signatures use it.
///
/// * `cgas` - the CGAs to classify; the resulting `roles` follow this order.
/// * `methods_args_iter` - the argument types of all methods under consideration.
/// * `methods_returns_iter` - the return types of all methods under consideration.
///
/// A CGA referenced by any argument type becomes `CgaRole::ShapeBound` and is
/// given the next `Sh{n}` identifier (numbered from 0 in CGA order); all other
/// CGAs become `CgaRole::Free`. The result also records whether any CGA is free
/// and whether any return type references any CGA.
fn classify_cgas<'a, AI, RI>(
    cgas: &[CgaInfo],
    methods_args_iter: AI,
    methods_returns_iter: RI,
) -> RankPolyShape
where
    AI: IntoIterator<Item = &'a Type>,
    RI: IntoIterator<Item = &'a Type>,
{
    let arg_types: Vec<&Type> = methods_args_iter.into_iter().collect();
    let return_types: Vec<&Type> = methods_returns_iter.into_iter().collect();

    // The next `Sh` index is always the number of identifiers handed out so far.
    let mut sh_idents: Vec<Ident> = Vec::new();
    let roles: Vec<CgaRole> = cgas
        .iter()
        .map(|cga| {
            let used_by_some_arg = arg_types
                .iter()
                .any(|ty| type_references_ident(ty, &cga.cga_ident));
            if !used_by_some_arg {
                return CgaRole::Free;
            }
            let sh_ident = format_ident!("Sh{}", sh_idents.len());
            sh_idents.push(sh_ident.clone());
            CgaRole::ShapeBound { sh_ident }
        })
        .collect();

    let has_free_cga = roles.iter().any(|role| matches!(role, CgaRole::Free));
    let any_return_uses_cga = return_types.iter().any(|ret| {
        cgas.iter()
            .any(|cga| type_references_ident(ret, &cga.cga_ident))
    });

    RankPolyShape {
        roles,
        sh_idents,
        has_free_cga,
        any_return_uses_cga,
    }
}
/// Desugars a rank-polymorphic trait declaration into an ordinary trait.
///
/// The CGA const generics (reported by `find_cgas`) are removed from the
/// trait's parameter list and replaced as follows:
///
/// * every CGA referenced by a method argument contributes one `Sh{n}` type
///   parameter, and method arguments referencing it are retyped to that
///   parameter (see `rewrite_trait_method_for_rank_poly`);
/// * when some return type references a CGA, the return type is abstracted
///   either as an associated `type Out;` (no free CGA) or as an extra trait
///   type parameter `Out` (at least one free CGA).
///
/// Attributes (minus those removed by `filter_cuda_tile_attrs`), visibility,
/// the `unsafe` qualifier, supertraits, the where-clause and all non-method
/// items are carried over from the input.
///
/// # Errors
///
/// Returns an error if `find_cgas` fails or if the trait declares no CGA.
pub fn desugar_variadic_trait_decl(item: &ItemTrait) -> Result<TokenStream2, Error> {
    let cgas = find_cgas(&item.generics, item.ident.span())?;
    if cgas.is_empty() {
        return syn_error_at(
            item.ident.span(),
            "variadic_trait: no `const X: [i32; N]` generic found \
             (is this trait rank-polymorphic?)",
        );
    }

    // Gather every typed argument and every explicit return type across all
    // trait methods; these drive the CGA classification.
    let arg_types: Vec<&Type> = item
        .items
        .iter()
        .filter_map(|ti| match ti {
            TraitItem::Fn(tf) => Some(tf.sig.inputs.iter().filter_map(|a| match a {
                FnArg::Typed(pt) => Some(&*pt.ty),
                _ => None,
            })),
            _ => None,
        })
        .flatten()
        .collect();
    let return_types: Vec<&Type> = item
        .items
        .iter()
        .filter_map(|ti| match ti {
            TraitItem::Fn(tf) => match &tf.sig.output {
                ReturnType::Type(_, ret) => Some(&**ret),
                _ => None,
            },
            _ => None,
        })
        .collect();
    let shape = classify_cgas(
        &cgas,
        arg_types.iter().copied(),
        return_types.iter().copied(),
    );

    // Rewrite method signatures; every other trait item passes through untouched.
    let mut new_items: Vec<TraitItem> = item
        .items
        .iter()
        .map(|ti| match ti {
            TraitItem::Fn(tf) => {
                let mut new_tf = tf.clone();
                rewrite_trait_method_for_rank_poly(&mut new_tf.sig, &cgas, &shape);
                TraitItem::Fn(new_tf)
            }
            other => other.clone(),
        })
        .collect();
    // Without a free CGA the rewritten methods return `Self::Out`, so the
    // associated type has to be declared.
    if shape.any_return_uses_cga && !shape.has_free_cga {
        new_items.insert(0, syn::parse_quote! { type Out; });
    }

    // Rebuild the generic parameter list: original parameters minus the CGA
    // consts, then the `Sh{n}` parameters, then `Out` when it is a trait parameter.
    let mut new_params: syn::punctuated::Punctuated<GenericParam, syn::Token![,]> =
        syn::punctuated::Punctuated::new();
    for param in &item.generics.params {
        let is_cga_param = matches!(
            param,
            GenericParam::Const(c) if cgas.iter().any(|cga| cga.cga_ident == c.ident)
        );
        if !is_cga_param {
            new_params.push(param.clone());
        }
    }
    for sh in &shape.sh_idents {
        new_params.push(syn::parse_quote! { #sh });
    }
    if shape.has_free_cga && shape.any_return_uses_cga {
        new_params.push(syn::parse_quote! { Out });
    }

    let trait_ident = &item.ident;
    let where_clause = &item.generics.where_clause;
    let supertraits_marker = if item.supertraits.is_empty() {
        quote! {}
    } else {
        let st = &item.supertraits;
        quote! { : #st }
    };
    let vis = &item.vis;
    // Preserve `unsafe` so an `unsafe trait` is not silently re-emitted as a
    // safe one (interpolates to nothing for ordinary traits).
    let unsafety = &item.unsafety;
    let attrs = filter_cuda_tile_attrs(&item.attrs);
    Ok(quote! {
        #(#attrs)*
        #[allow(non_camel_case_types)]
        #vis #unsafety trait #trait_ident < #new_params > #supertraits_marker
        #where_clause
        {
            #(#new_items)*
        }
    })
}
/// Desugars a rank-polymorphic trait impl into one concrete impl per rank
/// combination.
///
/// The rank combinations come from `RankSpace::new` over the CGA length
/// groups, bounded by `MAX_RANK`. For every CGA the first method argument type
/// (in method order, then argument order) that references it is recorded, and
/// the first method return type referencing any CGA is recorded as the source
/// for `Out`; both are passed to `emit_variadic_trait_impl_for_rank` for each
/// combination.
///
/// # Errors
///
/// Returns an error if `find_cgas` fails, if the impl declares no CGA, if it
/// is not a trait impl, or if a free CGA exists but no return type references
/// a CGA.
pub fn desugar_variadic_trait_impl(item: &ItemImpl) -> Result<TokenStream2, Error> {
    let cgas = find_cgas(&item.generics, item.span())?;
    if cgas.is_empty() {
        return syn_error_at(
            item.span(),
            "variadic_trait_impl: no `const X: [i32; N]` generic found",
        );
    }
    if item.trait_.is_none() {
        return syn_error_at(
            item.span(),
            "variadic_trait_impl: expected a trait impl (e.g. `impl T for X`)",
        );
    }

    let length_groups = group_cgas_by_length(&cgas);
    let rank_space = RankSpace::new(&length_groups, MAX_RANK);

    // Flat lists of every typed argument type and every explicit return type,
    // in method order.
    let arg_types: Vec<&Type> = item
        .items
        .iter()
        .filter_map(|member| match member {
            ImplItem::Fn(method) => Some(method.sig.inputs.iter().filter_map(|input| {
                match input {
                    FnArg::Typed(typed) => Some(&*typed.ty),
                    _ => None,
                }
            })),
            _ => None,
        })
        .flatten()
        .collect();
    let return_types: Vec<&Type> = item
        .items
        .iter()
        .filter_map(|member| match member {
            ImplItem::Fn(method) => match &method.sig.output {
                ReturnType::Type(_, ret) => Some(&**ret),
                _ => None,
            },
            _ => None,
        })
        .collect();

    // For each CGA: the first argument type that mentions it, if any.
    let cga_shape_types: Vec<Option<Type>> = cgas
        .iter()
        .map(|cga| {
            arg_types
                .iter()
                .copied()
                .find(|ty| type_references_ident(ty, &cga.cga_ident))
                .cloned()
        })
        .collect();
    // The first return type that mentions any CGA, if any.
    let return_type_for_out: Option<Type> = return_types
        .iter()
        .copied()
        .find(|ret| {
            cgas.iter()
                .any(|cga| type_references_ident(ret, &cga.cga_ident))
        })
        .cloned();

    let shape = classify_cgas(
        &cgas,
        arg_types.iter().copied(),
        return_types.iter().copied(),
    );
    if shape.has_free_cga && return_type_for_out.is_none() {
        return syn_error_at(
            item.span(),
            "variadic_trait_impl: free CGAs detected but no method has a return \
             type that uses them — cannot substitute case-3c trait args",
        );
    }

    let impls: Vec<TokenStream2> = rank_space
        .iter()
        .map(|combo| {
            emit_variadic_trait_impl_for_rank(
                item,
                &cgas,
                &shape,
                &cga_shape_types,
                &return_type_for_out,
                &combo,
            )
        })
        .collect();
    Ok(quote! { #(#impls)* })
}
/// Emits the concrete trait impl for a single rank combination.
///
/// * `item` - the original rank-polymorphic impl; it must be a trait impl.
/// * `cgas` / `shape` - the impl's CGAs and their classification.
/// * `cga_shape_types` - per CGA, the argument type chosen to stand in for it
///   in the trait path (or `None`).
/// * `return_type_for_out` - the return type used for `Out`, if any.
/// * `combo` - the rank combination to specialise for.
///
/// The generated impl:
/// * declares the `'__td_recv` lifetime when any recorded shape type, the
///   `Out` source type or the self type is reported by `type_uses_lifetime`;
/// * keeps every original generic parameter except the CGA consts and appends
///   the parameters from `combo.dim_params_tokens()`;
/// * binds `type Out = ...;` when there is no free CGA and a return type was
///   recorded;
/// * replaces each method with the stub produced by
///   `rewrite_impl_method_body_for_rank` and passes other items through;
/// * carries over the `unsafe` qualifier, the where-clause and the attributes
///   kept by `filter_cuda_tile_attrs`.
///
/// # Panics
///
/// Panics if `item` is not a trait impl.
fn emit_variadic_trait_impl_for_rank(
    item: &ItemImpl,
    cgas: &[CgaInfo],
    shape: &RankPolyShape,
    cga_shape_types: &[Option<Type>],
    return_type_for_out: &Option<Type>,
    combo: &RankCombo,
) -> TokenStream2 {
    let recv_lt: syn::Lifetime = syn::parse_quote! { '__td_recv };
    let needs_recv_lt = cga_shape_types
        .iter()
        .any(|t| t.as_ref().is_some_and(type_uses_lifetime))
        || return_type_for_out.as_ref().is_some_and(type_uses_lifetime)
        || type_uses_lifetime(&item.self_ty);

    let (_, trait_path_orig, _) = item
        .trait_
        .as_ref()
        .expect("emit_variadic_trait_impl_for_rank requires a trait impl");
    let trait_path = rewrite_trait_path_args_for_rank(
        trait_path_orig,
        cgas,
        shape,
        cga_shape_types,
        return_type_for_out,
        combo,
        &recv_lt,
    );
    let self_ty_rewritten = rewrite_ty_for_rank(&item.self_ty, combo, cgas);
    let self_ty_rewritten = bind_anon_lifetimes_to(&self_ty_rewritten, &recv_lt);

    // Impl generics, in order: optional receiver lifetime, surviving original
    // parameters, then the per-rank dimension parameters.
    let dim_params = combo.dim_params_tokens();
    let mut all_impl_params: Vec<TokenStream2> = Vec::new();
    if needs_recv_lt {
        all_impl_params.push(quote! { #recv_lt });
    }
    for param in &item.generics.params {
        let is_cga_param = matches!(
            param,
            GenericParam::Const(c) if cgas.iter().any(|cga| cga.cga_ident == c.ident)
        );
        if !is_cga_param {
            all_impl_params.push(quote! { #param });
        }
    }
    for d in &dim_params {
        all_impl_params.push(d.clone());
    }

    // With a free CGA no associated `Out` binding is emitted here.
    let out_binding = if shape.has_free_cga {
        quote! {}
    } else {
        match return_type_for_out {
            Some(ret) => {
                let ret_concrete = rewrite_ty_for_rank(ret, combo, cgas);
                let ret_concrete = bind_anon_lifetimes_to(&ret_concrete, &recv_lt);
                quote! { type Out = #ret_concrete; }
            }
            None => quote! {},
        }
    };

    let method_tokens: Vec<TokenStream2> = item
        .items
        .iter()
        .map(|ii| match ii {
            ImplItem::Fn(impl_fn) => {
                rewrite_impl_method_body_for_rank(impl_fn, cgas, combo, &recv_lt)
            }
            other => quote! { #other },
        })
        .collect();

    let where_clause = &item.generics.where_clause;
    // Preserve `unsafe` so an `unsafe impl` of an unsafe trait stays valid
    // (interpolates to nothing for ordinary impls).
    let unsafety = &item.unsafety;
    let attrs = filter_cuda_tile_attrs(&item.attrs);
    quote! {
        #(#attrs)*
        #unsafety impl < #(#all_impl_params),* > #trait_path for #self_ty_rewritten
        #where_clause
        {
            #out_binding
            #(#method_tokens)*
        }
    }
}
/// Abstracts CGA-dependent types out of a trait method signature, in place.
///
/// * Each typed argument is matched against the first CGA (in `cgas` order)
///   that its type references; if that CGA is shape-bound, the argument's type
///   is replaced by the CGA's `Sh{n}` identifier.
/// * A return type that references any CGA is replaced by `Out` when the trait
///   has a free CGA, and by `Self::Out` otherwise.
///
/// Receivers and types that mention no CGA are left untouched.
fn rewrite_trait_method_for_rank_poly(
    sig: &mut syn::Signature,
    cgas: &[CgaInfo],
    shape: &RankPolyShape,
) {
    for input in sig.inputs.iter_mut() {
        let FnArg::Typed(typed) = input else {
            continue;
        };
        let Some(idx) = cgas
            .iter()
            .position(|cga| type_references_ident(&typed.ty, &cga.cga_ident))
        else {
            continue;
        };
        if let CgaRole::ShapeBound { sh_ident } = &shape.roles[idx] {
            typed.ty = Box::new(syn::parse_quote! { #sh_ident });
        }
    }

    let ReturnType::Type(_, ret) = &mut sig.output else {
        return;
    };
    let mentions_cga = cgas
        .iter()
        .any(|cga| type_references_ident(ret, &cga.cga_ident));
    if !mentions_cga {
        return;
    }
    let replacement: Type = if shape.has_free_cga {
        syn::parse_quote! { Out }
    } else {
        syn::parse_quote! { Self::Out }
    };
    *ret = Box::new(replacement);
}
/// Builds the trait path for one rank combination by substituting concrete
/// types for CGA arguments in the path's last segment.
///
/// An angle-bracketed argument that is a bare identifier naming a CGA is
/// replaced by a type argument derived from:
/// * the CGA's entry in `cga_shape_types` when the CGA is shape-bound, or
/// * `return_type_for_out` when the CGA is free.
///
/// The source type is specialised with `rewrite_ty_for_rank` and then passed
/// through `bind_anon_lifetimes_to` with `recv_lt`. Arguments that name no
/// CGA, or whose source type is missing, are copied unchanged. Paths whose
/// last segment has no angle-bracketed arguments are returned as a plain clone.
fn rewrite_trait_path_args_for_rank(
    path: &syn::Path,
    cgas: &[CgaInfo],
    shape: &RankPolyShape,
    cga_shape_types: &[Option<Type>],
    return_type_for_out: &Option<Type>,
    combo: &RankCombo,
    recv_lt: &syn::Lifetime,
) -> syn::Path {
    let mut rewritten = path.clone();
    let last_args = rewritten
        .segments
        .last_mut()
        .map(|segment| &mut segment.arguments);
    if let Some(PathArguments::AngleBracketed(generic_args)) = last_args {
        let substituted: syn::punctuated::Punctuated<GenericArgument, syn::Token![,]> =
            generic_args
                .args
                .iter()
                .map(|arg| {
                    // Identifier named by this argument, if it is a bare one.
                    let named = match arg {
                        GenericArgument::Const(Expr::Path(ExprPath { path, .. })) => {
                            path.get_ident()
                        }
                        GenericArgument::Type(Type::Path(type_path)) => {
                            type_path.path.get_ident()
                        }
                        _ => None,
                    };
                    // Source type to substitute, chosen by the CGA's role.
                    let source: Option<&Type> = named
                        .and_then(|id| cgas.iter().position(|cga| cga.cga_ident == *id))
                        .and_then(|idx| match &shape.roles[idx] {
                            CgaRole::ShapeBound { .. } => {
                                cga_shape_types.get(idx).and_then(Option::as_ref)
                            }
                            CgaRole::Free => return_type_for_out.as_ref(),
                        });
                    match source {
                        Some(src) => {
                            let concrete = rewrite_ty_for_rank(src, combo, cgas);
                            GenericArgument::Type(bind_anon_lifetimes_to(&concrete, recv_lt))
                        }
                        None => arg.clone(),
                    }
                })
                .collect();
        generic_args.args = substituted;
    }
    rewritten
}
fn rewrite_impl_method_body_for_rank(
impl_fn: &ImplItemFn,
cgas: &[CgaInfo],
combo: &RankCombo,
recv_lt: &syn::Lifetime,
) -> TokenStream2 {
let mut new_sig = impl_fn.sig.clone();
for arg in new_sig.inputs.iter_mut() {
if let FnArg::Typed(pt) = arg {
let new_ty = rewrite_ty_for_rank(&pt.ty, combo, cgas);
let new_ty = bind_anon_lifetimes_to(&new_ty, recv_lt);
pt.ty = Box::new(new_ty);
}
}
if let ReturnType::Type(_, ret) = &mut new_sig.output {
let new_ret = rewrite_ty_for_rank(ret, combo, cgas);
let new_ret = bind_anon_lifetimes_to(&new_ret, recv_lt);
*ret = Box::new(new_ret);
}
let muted_args: Vec<TokenStream2> = new_sig
.inputs
.iter()
.filter_map(|a| match a {
FnArg::Typed(pt) => match &*pt.pat {
Pat::Ident(pi) => {
let n = &pi.ident;
Some(quote! { #n })
}
_ => None,
},
FnArg::Receiver(_) => None,
})
.collect();
let attrs = filter_cuda_tile_attrs(&impl_fn.attrs);
quote! {
#(#attrs)*
#new_sig {
let _ = (#(#muted_args),*);
::std::unreachable!()
}
}
}
/// Returns a copy of `ty` in which anonymous lifetimes are replaced by `lt`.
///
/// * Path types: in every segment's angle-bracketed arguments, `'_` lifetime
///   arguments become `lt` and type arguments are processed recursively;
///   other arguments are copied as-is.
/// * References: a reference without a lifetime gets `lt`; one that already
///   carries a lifetime keeps it. The referent is processed recursively.
/// * Arrays and tuples: element types are processed recursively.
/// * Any other kind of type is cloned unchanged.
fn bind_anon_lifetimes_to(ty: &Type, lt: &syn::Lifetime) -> Type {
    match ty {
        Type::Path(type_path) => {
            let mut bound = type_path.clone();
            for segment in bound.path.segments.iter_mut() {
                let PathArguments::AngleBracketed(generics) = &mut segment.arguments else {
                    continue;
                };
                let rebound: syn::punctuated::Punctuated<GenericArgument, syn::Token![,]> =
                    generics
                        .args
                        .iter()
                        .map(|arg| match arg {
                            GenericArgument::Lifetime(existing) if existing.ident == "_" => {
                                GenericArgument::Lifetime(lt.clone())
                            }
                            GenericArgument::Type(inner) => {
                                GenericArgument::Type(bind_anon_lifetimes_to(inner, lt))
                            }
                            untouched => untouched.clone(),
                        })
                        .collect();
                generics.args = rebound;
            }
            Type::Path(bound)
        }
        Type::Reference(reference) => {
            let mut bound = reference.clone();
            bound.lifetime.get_or_insert_with(|| lt.clone());
            bound.elem = Box::new(bind_anon_lifetimes_to(&reference.elem, lt));
            Type::Reference(bound)
        }
        Type::Array(array) => {
            let mut bound = array.clone();
            bound.elem = Box::new(bind_anon_lifetimes_to(&array.elem, lt));
            Type::Array(bound)
        }
        Type::Tuple(tuple) => {
            let mut bound = tuple.clone();
            bound.elems = tuple
                .elems
                .iter()
                .map(|elem| bind_anon_lifetimes_to(elem, lt))
                .collect();
            Type::Tuple(bound)
        }
        untouched => untouched.clone(),
    }
}
/// Reports whether `ty` involves a lifetime, as far as this syntactic check
/// can tell.
///
/// Returns `true` for any reference type, for a path with a segment whose
/// identifier satisfies `base_has_lifetime`, and for a path with a segment
/// carrying a lifetime argument or a type argument that itself uses a
/// lifetime. Arrays and tuples are checked through their element types. All
/// other kinds of type yield `false`.
fn type_uses_lifetime(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => type_path.path.segments.iter().any(|segment| {
            if base_has_lifetime(&segment.ident) {
                return true;
            }
            match &segment.arguments {
                PathArguments::AngleBracketed(generics) => {
                    generics.args.iter().any(|arg| match arg {
                        GenericArgument::Lifetime(_) => true,
                        GenericArgument::Type(inner) => type_uses_lifetime(inner),
                        _ => false,
                    })
                }
                _ => false,
            }
        }),
        Type::Reference(_) => true,
        Type::Array(array) => type_uses_lifetime(&array.elem),
        Type::Tuple(tuple) => tuple.elems.iter().any(type_uses_lifetime),
        _ => false,
    }
}
/// Returns clones of `attrs`, in order, omitting every attribute whose path
/// starts with the segment `cuda_tile`.
fn filter_cuda_tile_attrs(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
    let mut kept = Vec::new();
    for attr in attrs {
        let is_cuda_tile = matches!(
            attr.path().segments.first(),
            Some(first) if first.ident == "cuda_tile"
        );
        if !is_cuda_tile {
            kept.push(attr.clone());
        }
    }
    kept
}