use crate::error::{syn_err, Error};
use cutile_compiler::syn_utils::*;
use cutile_compiler::types::parse_signed_literal_as_i32;
use proc_macro2::{Ident, Span};
use quote::{format_ident, ToTokens};
use std::collections::BTreeMap;
use std::collections::HashMap;
use syn::{
parse::Parser,
parse_quote,
punctuated::Punctuated,
spanned::Spanned,
visit_mut::{self, VisitMut},
AngleBracketedGenericArguments, Expr, ExprPath, ExprStruct, FnArg, GenericArgument,
GenericParam, Generics, ImplItem, ImplItemFn, ItemFn, ItemImpl, ItemStatic, ItemStruct,
ItemType, Macro, Path, PathArguments, PathSegment, ReturnType, Signature, Stmt, Token, Type,
};
/// Renders a list of ranks as an underscore-separated suffix.
///
/// `[1, 2, 3]` becomes `"1_2_3"`; an empty slice becomes the empty string.
pub fn rank_suffix(n: &[u32]) -> String {
    let mut suffix = String::new();
    for (position, rank) in n.iter().enumerate() {
        // Separator goes between entries only, never before the first one.
        if position > 0 {
            suffix.push('_');
        }
        suffix.push_str(&rank.to_string());
    }
    suffix
}
/// Builds the name of a concrete instantiation: `name`, an underscore, then the
/// rank suffix produced by `rank_suffix(n)`.
pub fn concrete_name(name: &str, n: &[u32]) -> String {
    let suffix = rank_suffix(n);
    [name, suffix.as_str()].join("_")
}
/// Concrete values bound to the variadic generic parameters of one instantiation.
///
/// "CGA" stands for "const generic array" throughout this file.
#[derive(Debug, Clone)]
pub struct RankBindings {
// Length variable name -> concrete `u32` value substituted for it.
inst_u32: HashMap<String, u32>,
// CGA parameter name -> the variadic parameter it was instantiated from.
var_arrays: HashMap<String, VarCGAParameter>,
// CGA parameter name -> the instantiated (fixed-length) parameter.
inst_array: HashMap<String, CGAParameter>,
}
impl RankBindings {
    /// Creates bindings with no length variables and no const generic arrays.
    fn new() -> Self {
        RankBindings {
            inst_u32: HashMap::new(),
            inst_array: HashMap::new(),
            var_arrays: HashMap::new(),
        }
    }

    /// Builds the bindings for one combination of variadic lengths.
    ///
    /// Every entry of `cga_lengths.variadic_length_instance` becomes a `u32`
    /// binding. Each CGA in `var_cgas` is then paired, in order, with the
    /// corresponding entry of `cga_lengths.cga_length_instance` and instantiated
    /// at that length.
    ///
    /// # Errors
    ///
    /// Fails when a paired length variable name differs from the CGA's own
    /// `length_var`, or when the same length variable is given two different
    /// values.
    fn from_variadic(
        cga_lengths: &VariadicLengthItem,
        var_cgas: &[VarCGAParameter],
    ) -> Result<Self, Error> {
        let mut bindings = Self::new();
        for (length_var_name, length_instance) in &cga_lengths.variadic_length_instance {
            // Lengths are stored as `usize` by the iterator; the bindings keep them as `u32`.
            bindings
                .inst_u32
                .insert(length_var_name.clone(), *length_instance as u32);
        }
        for (cga, (length_var_name, length_instance)) in
            var_cgas.iter().zip(cga_lengths.cga_length_instance.iter())
        {
            let length_instance = *length_instance as u32;
            if length_var_name != &cga.length_var {
                return Err(syn_err(
                    Span::call_site(),
                    &format!(
                        "CGA length var name mismatch: expected '{}', got '{}'",
                        cga.length_var, length_var_name
                    ),
                ));
            }
            // `insert` returns the previous value, which must agree with the new one.
            if let Some(existing_length) = bindings
                .inst_u32
                .insert(length_var_name.clone(), length_instance)
            {
                if existing_length != length_instance {
                    return Err(syn_err(
                        Span::call_site(),
                        &format!(
                            "CGA length instance mismatch for '{}': expected {}, got {}",
                            length_var_name, existing_length, length_instance
                        ),
                    ));
                }
            }
            bindings
                .inst_array
                .insert(cga.name.clone(), cga.instance(length_instance));
            bindings.var_arrays.insert(cga.name.clone(), cga.clone());
        }
        Ok(bindings)
    }

    /// Builds bindings from generics whose CGAs are already concrete.
    ///
    /// Only the array parameters returned by `parse_cgas` are recorded; the
    /// second value it returns is ignored here.
    fn from_generics(generics: &Generics) -> Result<Self, Error> {
        let (cga_param, _u32_param) = parse_cgas(generics);
        let mut bindings = Self::new();
        for cga in cga_param {
            bindings.inst_array.insert(cga.name.clone(), cga.clone());
        }
        Ok(bindings)
    }

    /// Returns a copy of `self` extended with `var_cgas`, each instantiated at
    /// the length already bound to its length variable.
    ///
    /// # Errors
    ///
    /// Fails when a CGA's length variable has no `u32` binding.
    fn instantiate_var_cgas(&self, var_cgas: &[VarCGAParameter]) -> Result<Self, Error> {
        let mut result = self.clone();
        for cga in var_cgas {
            let Some(&n) = result.inst_u32.get(&cga.length_var) else {
                return Err(syn_err(
                    Span::call_site(),
                    &format!(
                        "instantiate_var_cgas: Missing inst_u32 entry for '{}'",
                        cga.length_var
                    ),
                ));
            };
            result.inst_array.insert(cga.name.clone(), cga.instance(n));
            result.var_arrays.insert(cga.name.clone(), cga.clone());
        }
        Ok(result)
    }

    /// Returns a copy of `self` extended with fresh length bindings: `n_list[i]`
    /// is bound to the length variable of `var_cgas[i]`, and that CGA is
    /// instantiated at that length.
    ///
    /// # Errors
    ///
    /// Fails when `n_list` has more entries than `var_cgas` (previously an
    /// out-of-bounds panic), or when a length variable is already bound.
    fn instantiate_new_var_cgas(
        &self,
        n_list: &[u32],
        var_cgas: &[VarCGAParameter],
    ) -> Result<Self, Error> {
        if n_list.len() > var_cgas.len() {
            return Err(syn_err(
                Span::call_site(),
                &format!(
                    "instantiate_new_var_cgas: got {} lengths for {} const generic arrays",
                    n_list.len(),
                    var_cgas.len()
                ),
            ));
        }
        let mut result = self.clone();
        for (&n, cga) in n_list.iter().zip(var_cgas) {
            if result.inst_u32.contains_key(&cga.length_var) {
                return Err(syn_err(
                    Span::call_site(),
                    &format!(
                        "instantiate_new_var_cgas: inst_u32 already contains entry for '{}'",
                        cga.length_var
                    ),
                ));
            }
            result.inst_u32.insert(cga.length_var.clone(), n);
            result.inst_array.insert(cga.name.clone(), cga.instance(n));
            result.var_arrays.insert(cga.name.clone(), cga.clone());
        }
        Ok(result)
    }
}
/// Enumerates every combination of values for a set of variadic length variables.
#[derive(Debug)]
struct VariadicLengthIterator {
// Index of the next combination to yield.
i: usize,
// Total number of combinations: the product of the per-variable counts.
i_max: usize,
// `variadic_lengths`: length variable name -> number of values it takes.
// `cga_length_vars`: the length variable of each const generic array, in the
// order the arrays were passed to `new`.
variadic_lengths: BTreeMap<String, usize>, cga_length_vars: Vec<String>,
}
impl VariadicLengthIterator {
fn new(attribute_list: &SingleMetaList, arrays: &[VarCGAParameter]) -> Result<Self, Error> {
let mut i_max: usize = 1;
let mut variadic_lengths: BTreeMap<String, usize> = BTreeMap::new();
if let Some(variadic_length_vars) = attribute_list.parse_string_arr("variadic_length_vars")
{
for var in variadic_length_vars {
let len = (attribute_list.parse_int(var.as_str()).ok_or_else(|| {
syn_err(
Span::call_site(),
&format!("Missing attribute value for '{var}'"),
)
})? + 1) as usize;
i_max *= len;
if variadic_lengths.insert(var.clone(), len).is_some() {
return Err(syn_err(
Span::call_site(),
&format!("Duplicate variadic_length_var '{var}'"),
));
}
}
}
let mut cga_length_vars = vec![];
for cga in arrays {
let var = cga.length_var.clone();
cga_length_vars.push(var.clone());
let len = (attribute_list.parse_int(var.as_str()).ok_or_else(|| {
syn_err(
Span::call_site(),
&format!("Missing attribute value for '{var}'"),
)
})? + 1) as usize;
if variadic_lengths.contains_key(&var) {
if *variadic_lengths.get(&var).unwrap() != len {
return Err(syn_err(
Span::call_site(),
&format!("Variadic length mismatch for '{var}'"),
));
}
} else {
i_max *= len;
variadic_lengths.insert(var.clone(), len);
}
}
Ok(VariadicLengthIterator {
i: 0,
i_max,
variadic_lengths,
cga_length_vars,
})
}
}
/// One combination of concrete lengths for the variadic length variables.
pub struct VariadicLengthItem {
    /// Value chosen for each distinct length variable, keyed by variable name.
    variadic_length_instance: BTreeMap<String, usize>,
    /// `(length variable, value)` for each const generic array, in array order.
    cga_length_instance: Vec<(String, usize)>,
}

impl VariadicLengthItem {
    /// Returns the length of every const generic array, in array order.
    pub fn vec_of_cga_lengths(&self) -> Vec<u32> {
        let mut lengths = Vec::with_capacity(self.cga_length_instance.len());
        for (_, len) in &self.cga_length_instance {
            lengths.push(*len as u32);
        }
        lengths
    }

    /// Returns the value of every distinct length variable, ordered by variable name.
    pub fn vec_of_unique_lengths(&self) -> Vec<u32> {
        let mut lengths = Vec::with_capacity(self.variadic_length_instance.len());
        for len in self.variadic_length_instance.values() {
            lengths.push(*len as u32);
        }
        lengths
    }
}
impl Iterator for VariadicLengthIterator {
type Item = VariadicLengthItem;
fn next(&mut self) -> Option<Self::Item> {
if self.i < self.i_max {
let mut variadic_length_instance: BTreeMap<String, usize> = BTreeMap::new();
let mut i = self.i;
for (len_var, len) in self.variadic_lengths.iter() {
let pos = i % len;
i /= len;
variadic_length_instance.insert(len_var.clone(), pos);
}
self.i += 1;
let mut cga_length_instance: Vec<(String, usize)> = vec![];
for len_var in &self.cga_length_vars {
let len = *variadic_length_instance
.get(len_var)
.unwrap_or_else(|| panic!("Unexpected length var {len_var}"));
cga_length_instance.push((len_var.clone(), len));
}
Some(VariadicLengthItem {
variadic_length_instance,
cga_length_instance,
})
} else {
None
}
}
}
/// Collects the const generic array parameters declared in `generics`.
///
/// Only `const` parameters whose type is an array are kept, in declaration
/// order; every other generic parameter is skipped.
pub fn cgas_from_generics(generics: &Generics) -> Vec<VarCGAParameter> {
    generics
        .params
        .iter()
        .filter_map(|param| match param {
            GenericParam::Const(const_param) if matches!(const_param.ty, Type::Array(_)) => {
                Some(VarCGAParameter::from_const_param(const_param))
            }
            _ => None,
        })
        .collect()
}
/// Expands a variadic struct into one concrete struct per combination of lengths.
///
/// Each concrete struct is named `<base>_<lengths>` and has its const generic
/// array parameters replaced by per-dimension const parameters. When the
/// `constructor` attribute is present and the struct has a `dims` slice field,
/// an `impl` with generated constructors accompanies the struct; otherwise the
/// second tuple element is `None`.
///
/// # Errors
///
/// Propagates failures from building the length iterator, the bindings, the
/// generics rewrite, or parsing the generated constructor `impl`.
pub fn variadic_struct(
    attributes: &SingleMetaList,
    item: ItemStruct,
) -> Result<Vec<(ItemStruct, Option<ItemImpl>)>, Error> {
    let cgas = cgas_from_generics(&item.generics);
    let length_combinations = VariadicLengthIterator::new(attributes, &cgas)?;
    let maybe_constructor_name = attributes.parse_string("constructor");
    let base_name = item.ident.to_string();
    let has_dims_field = struct_has_dims_field(&item);

    let mut instantiations: Vec<(ItemStruct, Option<ItemImpl>)> = Vec::new();
    for combination in length_combinations {
        let lengths = combination.vec_of_cga_lengths();
        let mut concrete = item.clone();
        let bindings = RankBindings::from_variadic(&combination, &cgas)?;
        let renamed = Ident::new(&concrete_name(&base_name, &lengths), concrete.ident.span());
        concrete.ident = renamed;
        rewrite_generics_for_rank(&mut concrete.generics, &bindings)?;

        let concrete_impl = match maybe_constructor_name.as_ref() {
            None => None,
            Some(base_constructor) => {
                let mut type_params: Vec<String> = Vec::new();
                let mut type_args: Vec<String> = Vec::new();
                let mut constructors: Vec<String> = Vec::new();
                for (cga, &n) in cgas.iter().zip(lengths.iter()) {
                    let cga_name = &cga.name;
                    let cga_index_type = &cga.element_type;
                    // One const parameter (and matching argument) per dimension.
                    for dim_idx in 0..n {
                        type_params.push(format!("const {cga_name}{dim_idx}: {cga_index_type}"));
                        type_args.push(format!("{cga_name}{dim_idx}"));
                    }
                    if !has_dims_field {
                        continue;
                    }
                    // One constructor per possible count of dynamic dimensions, 0..=n.
                    for num_dynamic in 0..(n + 1) {
                        let struct_name = concrete.ident.to_string();
                        let constructor_name = format!("{}_{}", base_constructor, num_dynamic);
                        let dim_type_str = "i32";
                        let dyn_constructor = format!(
                            r#"
pub fn {constructor_name}(dims: &'a [{dim_type_str}; {num_dynamic}]) -> Self {{
{struct_name} {{ dims: dims }}
}}
"#
                        );
                        constructors.push(dyn_constructor);
                        if num_dynamic == 0 {
                            // The fully static case also gets an argument-free constructor.
                            let constructor_name = base_constructor.to_string();
                            let const_constructor = format!(
                                r#"
pub fn const_{constructor_name}() -> Self {{
{struct_name} {{ dims: &[] }}
}}
"#
                            );
                            constructors.push(const_constructor);
                        }
                    }
                }

                if constructors.is_empty() {
                    None
                } else {
                    let name = concrete.ident.to_string();
                    let impl_generics = type_params.join(",");
                    let impl_constructors = constructors.join("\n");
                    let impl_type_args = type_args.join(",");
                    let constructor_impl = format!(
                        r#"
impl<'a, {impl_generics}> {name}<'a, {impl_type_args}> {{
{impl_constructors}
}}
"#
                    );
                    let parsed_impl =
                        syn::parse::<ItemImpl>(constructor_impl.parse().map_err(|_| {
                            syn_err(item.ident.span(), "Failed to parse constructor impl")
                        })?)
                        .map_err(|e| {
                            syn_err(
                                item.ident.span(),
                                &format!("Failed to parse constructor impl: {e}"),
                            )
                        })?;
                    Some(parsed_impl)
                }
            }
        };
        instantiations.push((concrete, concrete_impl));
    }
    Ok(instantiations)
}
/// Reports whether `item` has a named field `dims` whose type is a reference to a slice.
///
/// Structs without named fields never match.
fn struct_has_dims_field(item: &ItemStruct) -> bool {
    match &item.fields {
        syn::Fields::Named(fields) => fields.named.iter().any(|field| {
            let named_dims = field.ident.as_ref().is_some_and(|ident| ident == "dims");
            named_dims
                && matches!(&field.ty, Type::Reference(r) if matches!(&*r.elem, Type::Slice(_)))
        }),
        _ => false,
    }
}
/// Expands a variadic `impl` block into one concrete `impl` per combination of lengths.
///
/// # Errors
///
/// Returns the first failure from building the length iterator, the bindings
/// for a combination, or rewriting the `impl` for that combination.
pub fn variadic_impl(attributes: &SingleMetaList, item: ItemImpl) -> Result<Vec<ItemImpl>, Error> {
    let cgas = cgas_from_generics(&item.generics);
    VariadicLengthIterator::new(attributes, &cgas)?
        .map(|n_list| {
            let bindings = RankBindings::from_variadic(&n_list, &cgas)?;
            RankInstantiator::new(bindings).rewrite_impl(&item)
        })
        .collect()
}
/// Rewrites a function signature in place for the given bindings.
///
/// The generics are expanded first, then every typed parameter and the
/// explicit return type (if any) are rewritten. Receivers are left as they are.
fn rewrite_fn_sig(sig: &mut Signature, const_instances: &RankBindings) -> Result<(), Error> {
    rewrite_generics_for_rank(&mut sig.generics, const_instances)?;

    for input in sig.inputs.iter_mut() {
        if let FnArg::Typed(fn_param) = input {
            let rewritten = rewrite_type_for_rank(&fn_param.ty, const_instances)?;
            *fn_param.ty = rewritten;
        }
    }

    match &mut sig.output {
        ReturnType::Type(_, return_type) => {
            let rewritten = rewrite_type_for_rank(&**return_type, const_instances)?;
            **return_type = rewritten;
        }
        ReturnType::Default => {}
    }
    Ok(())
}
fn rewrite_generics_for_rank(
generics: &mut Generics,
const_instances: &RankBindings,
) -> Result<(), Error> {
let mut concrete_type_params = generics.params.clone();
concrete_type_params.clear();
for param in generics.params.iter() {
match param {
GenericParam::Const(const_param) => match &const_param.ty {
Type::Array(_ty_arr) => {
let const_param_name = const_param.ident.to_string();
let cga = const_instances
.inst_array
.get(const_param_name.as_str())
.ok_or_else(|| {
syn_err(
const_param.ident.span(),
&format!("Missing inst_array entry for '{const_param_name}'"),
)
})?;
if cga.element_type != "i32" {
return Err(syn_err(
const_param.ident.span(),
&format!("Expected element_type 'i32', got '{}'", cga.element_type),
));
}
for i in 0..cga.length {
let const_str = format!("const {}{}: {}", cga.name, i, cga.element_type);
let generic_param =
syn::parse::<GenericParam>(const_str.parse().map_err(|_| {
syn_err(
const_param.ident.span(),
&format!("Failed to parse generic param '{const_str}'"),
)
})?)
.map_err(|e| {
syn_err(
const_param.ident.span(),
&format!("Failed to parse generic param '{const_str}': {e}"),
)
})?;
concrete_type_params.push(generic_param);
}
}
_ => concrete_type_params.push(param.clone()),
},
_ => concrete_type_params.push(param.clone()),
}
}
generics.params = concrete_type_params;
Ok(())
}
/// Expands a path that names a const generic array into `<NAME0,NAME1,...>`.
///
/// The last segment of `path` is looked up in `instances.inst_array`; one
/// generic argument is produced per dimension of the bound array.
///
/// # Errors
///
/// Fails when `path` has no segments, when its last segment does not name a
/// bound const generic array, or when the generated argument list cannot be
/// parsed.
fn instantiate_cga(
    path: &Path,
    instances: &RankBindings,
) -> Result<AngleBracketedGenericArguments, Error> {
    let last_seg = path.segments.last().ok_or_else(|| {
        syn_err(
            path.span(),
            "Expected at least one path segment in instantiate_cga",
        )
    })?;
    let param_name = last_seg.ident.to_string();
    let Some(cga) = instances.inst_array.get(&param_name) else {
        return Err(syn_err(
            path.span(),
            &format!("{} is not a const generic array.", path.to_token_stream()),
        ));
    };

    let generic_args_result: Vec<String> = (0..cga.length)
        .map(|j| format!("{}{}", cga.name, j))
        .collect();
    let formatted = format!("<{}>", generic_args_result.join(","));
    syn::parse::<AngleBracketedGenericArguments>(formatted.parse().map_err(|_| {
        syn_err(
            path.span(),
            &format!("Failed to parse angle bracketed args '{formatted}'"),
        )
    })?)
    .map_err(|e| {
        syn_err(
            path.span(),
            &format!("Failed to parse angle bracketed args '{formatted}': {e}"),
        )
    })
}
/// Where a path being rewritten appears.
///
/// `rewrite_path_for_rank` skips the rank suffix on the last segment of a path
/// in `ExprPath` context; in `Type` context no segment is exempt.
#[derive(Copy, Clone, Eq, PartialEq)]
enum PathContext {
// The path is used as a type.
Type,
// The path is used as an expression path.
ExprPath,
}
/// Rewrites every segment of `path` for the given bindings.
///
/// Segments with angle-bracketed arguments have those arguments expanded by
/// `instantiate_cga_args`, which may also append a rank suffix to the segment
/// name. In `PathContext::ExprPath` the last segment keeps its name. Segments
/// without arguments are kept unchanged.
///
/// # Errors
///
/// Fails when a segment itself names a bound const generic array, when a
/// segment has arguments that are neither absent nor angle-bracketed, or when
/// argument expansion fails.
fn rewrite_path_for_rank(
    path: &Path,
    instances: &RankBindings,
    context: PathContext,
) -> Result<Path, Error> {
    let mut result_path = path.clone();
    let last_idx = path.segments.len().saturating_sub(1);
    for (i, seg) in path.segments.iter().enumerate() {
        let param_name = seg.ident.to_string();
        if instances.inst_array.contains_key(&param_name) {
            return Err(syn_err(
                seg.ident.span(),
                &format!(
                    "Unexpected use of rewrite_path_for_rank for {}",
                    path.to_token_stream()
                ),
            ));
        }

        let skip_suffix = context == PathContext::ExprPath && i == last_idx;
        let (ident, arguments) = match &seg.arguments {
            PathArguments::AngleBracketed(type_params) => {
                let (type_ident, seg_args) =
                    instantiate_cga_args(instances, &seg.ident, type_params, skip_suffix)?;
                (type_ident, PathArguments::AngleBracketed(seg_args))
            }
            PathArguments::None => (seg.ident.clone(), PathArguments::None),
            _ => return Err(syn_err(seg.ident.span(), "Unexpected Path arguments.")),
        };
        result_path.segments[i] = PathSegment { ident, arguments };
    }
    Ok(result_path)
}
fn rewrite_generic_args_for_rank(
generic_args: &mut AngleBracketedGenericArguments,
const_instances: &RankBindings,
) -> Result<(), Error> {
let mut new_args: syn::punctuated::Punctuated<GenericArgument, syn::Token![,]> =
syn::punctuated::Punctuated::new();
for arg in &generic_args.args {
match arg {
GenericArgument::Type(ty) => {
if let Type::Path(type_path) = ty {
if let Some(ident) = type_path.path.get_ident() {
let ident_str = ident.to_string();
if let Some(cga) = const_instances.inst_array.get(&ident_str) {
for j in 0..cga.length {
let dim = format_ident!("{}{}", cga.name, j);
new_args.push(parse_quote! { #dim });
}
continue;
}
}
}
new_args.push(GenericArgument::Type(rewrite_type_for_rank(
ty,
const_instances,
)?));
}
other => new_args.push(other.clone()),
}
}
generic_args.args = new_args;
Ok(())
}
/// Returns a copy of `ty` rewritten for the given bindings.
///
/// * Path types are rewritten with `rewrite_path_for_rank`, except a path whose
///   last segment is `Option`: only its type arguments are rewritten and the
///   path itself is kept.
/// * Array types have their element rewritten, and their length replaced when
///   its token text names a bound `u32` length variable.
/// * Reference and tuple types have their inner types rewritten.
/// * Any other type is cloned unchanged.
fn rewrite_type_for_rank(ty: &Type, instances: &RankBindings) -> Result<Type, Error> {
    match ty {
        Type::Path(type_path) => {
            let last_segment = type_path.path.segments.last().ok_or_else(|| {
                syn_err(
                    type_path.span(),
                    "Expected at least one path segment in rewrite_type_for_rank",
                )
            })?;
            let mut result_type = type_path.clone();
            if last_segment.ident == "Option" {
                if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
                    let mut new_args = args.clone();
                    for arg in &mut new_args.args {
                        if let GenericArgument::Type(inner_ty) = arg {
                            *inner_ty = rewrite_type_for_rank(inner_ty, instances)?;
                        }
                    }
                    // The path is known to be non-empty from the lookup above.
                    if let Some(tail) = result_type.path.segments.last_mut() {
                        tail.arguments = PathArguments::AngleBracketed(new_args);
                    }
                }
            } else {
                result_type.path =
                    rewrite_path_for_rank(&type_path.path, instances, PathContext::Type)?;
            }
            Ok(Type::Path(result_type))
        }
        Type::Array(type_array) => {
            let mut result = type_array.clone();
            *result.elem = rewrite_type_for_rank(&type_array.elem, instances)?;
            let arr_len = result.len.to_token_stream().to_string();
            if let Some(n) = instances.inst_u32.get(&arr_len) {
                result.len = syn::parse::<Expr>(n.to_string().parse().map_err(|_| {
                    syn_err(
                        type_array.span(),
                        &format!("Failed to parse array length '{n}'"),
                    )
                })?)
                .map_err(|e| {
                    syn_err(
                        type_array.span(),
                        &format!("Failed to parse array length '{n}': {e}"),
                    )
                })?;
            }
            Ok(Type::Array(result))
        }
        Type::Reference(ref_type) => {
            let mut result = ref_type.clone();
            *result.elem = rewrite_type_for_rank(&ref_type.elem, instances)?;
            Ok(Type::Reference(result))
        }
        Type::Tuple(tuple_type) => {
            let mut result = tuple_type.clone();
            for elem in &mut result.elems {
                *elem = rewrite_type_for_rank(elem, instances)?;
            }
            Ok(Type::Tuple(result))
        }
        _ => Ok(ty.clone()),
    }
}
/// Expands the generic arguments of one path segment for the given bindings.
///
/// Returns the (possibly renamed) segment identifier together with the
/// rewritten argument list. Arguments are collected as strings and re-parsed
/// as a single `<a,b,...>` list at the end.
///
/// * A type argument whose last path segment names a bound const generic array
///   is replaced by one argument per dimension.
/// * A const block argument `{ [a, b, ...] }` is flattened into its elements.
/// * A const block argument `{ [x; N] }` repeats `x`; `N` is either a bound
///   `u32` length variable or an integer literal.
/// * Every other argument is copied by its token text.
///
/// Each expansion records its length; unless `skip_suffix` is set or nothing
/// was expanded, the identifier becomes `concrete_name(ident, lengths)`.
///
/// # Errors
///
/// Fails on an empty type path, a block that is not exactly one array or
/// repeat expression, an unbound or unsupported repeat length, or when the
/// assembled argument list cannot be parsed.
fn instantiate_cga_args(
instances: &RankBindings,
type_ident: &Ident,
generic_args: &AngleBracketedGenericArguments,
skip_suffix: bool,
) -> Result<(Ident, AngleBracketedGenericArguments), Error> {
let mut instantiated_param_name = type_ident.to_string();
// Rewritten arguments, as token text.
let mut generic_args_result: Vec<String> = vec![];
// Length of every expansion performed, in argument order; feeds the name suffix.
let mut rank_suffixes: Vec<u32> = vec![];
for generic_arg in &generic_args.args {
match generic_arg {
GenericArgument::Type(type_param) => {
match type_param {
Type::Path(type_path) => {
let last_ident = type_path
.path
.segments
.last()
.ok_or_else(|| {
syn_err(type_path.span(), "Expected at least one path segment")
})?
.ident
.to_string();
// A bound const generic array expands to one argument per dimension.
if instances.inst_array.contains_key(&last_ident) {
let cga = instances.inst_array.get(&last_ident).unwrap();
for j in 0..cga.length {
generic_args_result.push(format!("{}{}", cga.name, j));
}
rank_suffixes.push(cga.length);
} else {
generic_args_result.push(generic_arg.to_token_stream().to_string());
}
}
Type::Reference(type_ref) => {
generic_args_result.push(type_ref.to_token_stream().to_string());
}
_ => {
generic_args_result.push(generic_arg.to_token_stream().to_string());
}
}
}
GenericArgument::Const(const_param) => {
match const_param {
// `{ ... }` const argument: must hold exactly one array or repeat expression.
Expr::Block(block_expr) => {
if block_expr.block.stmts.len() != 1 {
return Err(syn_err(
block_expr.span(),
&format!(
"Expected exactly 1 statement in block expression, got {}",
block_expr.block.stmts.len()
),
));
}
let statement = &block_expr.block.stmts[0];
let Stmt::Expr(statement_expr, _) = statement else {
return Err(syn_err(block_expr.span(), "Unexpected block expression."));
};
match statement_expr {
// `[a, b, ...]`: each element becomes its own argument.
Expr::Array(array_expr) => {
let rank = array_expr.elems.len();
for elem in &array_expr.elems {
generic_args_result
.push(format_rank_const_expr(elem, instances)?);
}
rank_suffixes.push(rank as u32);
}
// `[x; N]`: `x` is emitted N times.
Expr::Repeat(repeat_expr) => {
let thing_to_repeat =
format_rank_const_expr(&repeat_expr.expr, instances)?;
let num_repetitions = match &*repeat_expr.len {
// N is a path: it must name a bound `u32` length variable.
Expr::Path(len_path) => {
let num_rep_var = len_path.to_token_stream().to_string();
if !instances.inst_u32.contains_key(&num_rep_var) {
return Err(syn_err(
len_path.span(),
&format!(
"Expected instance for generic argument {}",
num_rep_var
),
));
}
let num_repetitions =
*instances.inst_u32.get(&num_rep_var).unwrap();
for _ in 0..num_repetitions {
generic_args_result.push(thing_to_repeat.clone());
}
num_repetitions
}
// N is a literal: its token text is parsed as a `u32`.
Expr::Lit(len_lit) => {
let num_repetitions: u32 = len_lit
.to_token_stream()
.to_string()
.parse::<u32>()
.map_err(|e| {
syn_err(
len_lit.span(),
&format!(
"Failed to parse repeat length as u32: {e}"
),
)
})?;
for _ in 0..num_repetitions {
generic_args_result.push(thing_to_repeat.clone());
}
num_repetitions
}
_ => {
return Err(syn_err(
generic_args.span(),
"Unexpected repeat expression.",
))
}
};
rank_suffixes.push(num_repetitions);
}
_ => {
return Err(syn_err(
block_expr.span(),
"Unexpected block expression.",
))
}
}
}
Expr::Lit(lit_expr) => {
generic_args_result.push(lit_expr.lit.to_token_stream().to_string());
}
_ => {
generic_args_result.push(generic_arg.to_token_stream().to_string());
}
}
}
_ => {
generic_args_result.push(generic_arg.to_token_stream().to_string());
}
}
}
// Rename the segment only when something was expanded and the caller allows it.
if !skip_suffix && !rank_suffixes.is_empty() {
instantiated_param_name = concrete_name(&type_ident.to_string(), &rank_suffixes);
}
let instantiated_param_ident = Ident::new(instantiated_param_name.as_str(), type_ident.span());
// Re-parse the collected argument text as one angle-bracketed list.
let formatted = format!("<{}>", generic_args_result.join(","));
Ok((
instantiated_param_ident,
syn::parse::<AngleBracketedGenericArguments>(formatted.parse().map_err(|_| {
syn_err(
type_ident.span(),
&format!("Failed to parse angle bracketed args '{formatted}'"),
)
})?)
.map_err(|e| {
syn_err(
type_ident.span(),
&format!("Failed to parse angle bracketed args '{formatted}': {e}"),
)
})?,
))
}
/// Renders a const expression as token text, resolving `NAME[i]` on a bound
/// const generic array to the per-dimension parameter `NAME{i}`.
///
/// Any expression that is not an index into a bound array is returned as its
/// plain token text.
///
/// # Errors
///
/// Fails when the index is negative or not less than the array's length.
fn format_rank_const_expr(expr: &Expr, instances: &RankBindings) -> Result<String, Error> {
    let indexed_cga = match expr {
        Expr::Index(index) => match index.expr.as_ref() {
            Expr::Path(path) => {
                let name = path
                    .path
                    .segments
                    .last()
                    .map(|segment| segment.ident.to_string())
                    .unwrap_or_default();
                instances.inst_array.get(&name).map(|cga| (index, cga))
            }
            _ => None,
        },
        _ => None,
    };

    let Some((index, cga)) = indexed_cga else {
        return Ok(expr.to_token_stream().to_string());
    };

    let i = parse_signed_literal_as_i32(&index.index);
    if i < 0 || (i as u32) >= cga.length {
        return Err(syn_err(
            index.index.span(),
            &format!(
                "Index {i} out of bounds for CGA `{}` of length {}",
                cga.name, cga.length
            ),
        ));
    }
    Ok(format!("{}{}", cga.name, i as u32))
}
/// Rewrites syntax trees for one set of rank bindings.
///
/// The `VisitMut` callbacks cannot return errors, so the first failure is
/// stored in `error` and surfaced by `into_result`.
pub struct RankInstantiator {
// Bindings applied by every rewrite.
bindings: RankBindings,
// Error recorded during visiting, if any; visiting stops once it is set.
error: Option<Error>,
}
impl RankInstantiator {
pub fn new(bindings: RankBindings) -> Self {
Self {
bindings,
error: None,
}
}
fn into_result<T>(self, value: T) -> Result<T, Error> {
match self.error {
Some(err) => Err(err),
None => Ok(value),
}
}
pub fn rewrite_struct(mut self, item: &ItemStruct) -> Result<ItemStruct, Error> {
let mut item = item.clone();
for field in &mut item.fields {
self.visit_type_mut(&mut field.ty);
}
self.into_result(item)
}
pub fn rewrite_function(mut self, item: &ItemFn) -> Result<ItemFn, Error> {
let mut item = item.clone();
if let Err(e) = rewrite_fn_sig(&mut item.sig, &self.bindings) {
return Err(e);
}
self.visit_block_mut(&mut item.block);
self.into_result(item)
}
pub fn rewrite_impl(mut self, item: &ItemImpl) -> Result<ItemImpl, Error> {
let mut item = item.clone();
let original_self_ty = (*item.self_ty).clone();
match rewrite_type_for_rank(&item.self_ty, &self.bindings) {
Ok(t) => *item.self_ty = t,
Err(e) => return Err(e),
}
if let Err(e) = rewrite_generics_for_rank(&mut item.generics, &self.bindings) {
return Err(e);
}
if let Some(trait_) = &mut item.trait_ {
let path = &mut trait_.1;
if path.segments.is_empty() {
return Err(syn_err(
path.span(),
"Expected at least one path segment in trait path",
));
}
let last_seg = path.segments.last_mut().unwrap();
if let PathArguments::AngleBracketed(path_args) = &mut last_seg.arguments {
if let Err(e) = rewrite_generic_args_for_rank(path_args, &self.bindings) {
return Err(e);
}
}
}
let mut impl_items: Vec<ImplItem> = Vec::new();
for item_in_impl in &mut item.items {
match item_in_impl {
ImplItem::Type(type_impl) => {
let mut result = type_impl.clone();
match rewrite_type_for_rank(&type_impl.ty, &self.bindings) {
Ok(t) => result.ty = t,
Err(e) => return Err(e),
}
impl_items.push(ImplItem::Type(result));
}
ImplItem::Const(c) => impl_items.push(ImplItem::Const(c.clone())),
ImplItem::Fn(fn_impl) => {
if get_meta_list("cuda_tile :: variadic_impl_fn", &fn_impl.attrs).is_some() {
continue;
}
let mut result = fn_impl.clone();
self.rewrite_impl_method(&original_self_ty, &mut result);
if self.error.is_some() {
return Err(self.error.unwrap());
}
impl_items.push(ImplItem::Fn(result));
}
_ => return Err(syn_err(item_in_impl.span(), "Unsupported impl item.")),
}
}
item.items = impl_items;
self.into_result(item)
}
fn rewrite_impl_method(&mut self, _self_ty: &Type, item: &mut ImplItemFn) {
let cgas = cgas_from_generics(&item.sig.generics);
let inner_bindings = match self.bindings.instantiate_var_cgas(&cgas) {
Ok(b) => b,
Err(e) => {
self.error = Some(e);
return;
}
};
let prev = std::mem::replace(&mut self.bindings, inner_bindings);
if let Err(e) = rewrite_fn_sig(&mut item.sig, &self.bindings) {
self.error = Some(e);
} else {
self.visit_block_mut(&mut item.block);
}
self.bindings = prev;
}
fn expand_shape_macro(&self, mac: &Macro, kind: &str) -> Result<Expr, Error> {
let mut args: Vec<String> = Vec::new();
let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
let exprs = parser
.parse2(mac.tokens.clone())
.map_err(|e| syn_err(mac.span(), &format!("Failed to parse {kind}! args: {e}")))?;
let expr_count = exprs.len();
for expr in exprs {
match &expr {
Expr::Path(path) if path.path.segments.len() == 1 => {
let ident = &path.path.segments[0].ident;
let name = ident.to_string();
if let Some(cga) = self.bindings.inst_array.get(&name) {
if expr_count != 1 {
return Err(syn_err(
expr.span(),
&format!(
"`{name}` names a const generic array; use it alone or index it as `{name}[i]`"
),
));
}
for i in 0..cga.length {
args.push(format!("{}{}", cga.name, i));
}
continue;
}
}
Expr::Index(index) => {
if let Expr::Path(path) = index.expr.as_ref() {
let name = path
.path
.segments
.last()
.map(|segment| segment.ident.to_string())
.unwrap_or_default();
if let Some(cga) = self.bindings.inst_array.get(&name) {
let i = parse_signed_literal_as_i32(&index.index);
if !(0 <= i && (i as u32) < cga.length) {
return Err(syn_err(
index.index.span(),
&format!(
"Index {i} out of bounds for CGA `{}` of length {}",
cga.name, cga.length
),
));
}
args.push(format!("{}{}", cga.name, i as u32));
continue;
}
}
}
_ => {}
}
args.push(expr.to_token_stream().to_string());
}
let cga_str = format!("{{[{}]}}", args.join(", "));
let ty_str = if kind == "const_shape" {
"Shape"
} else {
"Array"
};
let expr_str = format!("{ty_str}::<{cga_str}>::const_new()");
syn::parse_str::<Expr>(&expr_str)
.map_err(|e| syn_err(mac.span(), &format!("Failed to parse '{expr_str}': {e}")))
}
}
impl VisitMut for RankInstantiator {
    /// Rewrites one expression, then recurses into it where applicable.
    ///
    /// In order: `const_shape!` / `const_array!` invocations are expanded and
    /// the result visited; a single-segment path naming a bound `u32` length
    /// variable becomes that literal; `NAME[i]` on a bound const generic array
    /// becomes the identifier `NAME{i}`; anything else is visited by default.
    /// Nothing is done once an error has been recorded.
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if self.error.is_some() {
            return;
        }

        if let Expr::Macro(macro_expr) = expr {
            let macro_name = macro_expr
                .mac
                .path
                .segments
                .last()
                .map(|s| s.ident.to_string())
                .unwrap_or_default();
            if matches!(macro_name.as_str(), "const_shape" | "const_array") {
                let expanded = self.expand_shape_macro(&macro_expr.mac, &macro_name);
                match expanded {
                    Ok(new_expr) => {
                        *expr = new_expr;
                        visit_mut::visit_expr_mut(self, expr);
                    }
                    Err(e) => self.error = Some(e),
                }
                return;
            }
        }

        if let Expr::Path(expr_path) = expr {
            if expr_path.path.segments.len() == 1 {
                let name = expr_path.path.segments[0].ident.to_string();
                if let Some(n) = self.bindings.inst_u32.get(&name).copied() {
                    *expr = parse_quote! { #n };
                    return;
                }
            }
        }

        if let Expr::Index(expr_index) = expr {
            if let Expr::Path(p) = &*expr_index.expr {
                let name = p
                    .path
                    .segments
                    .last()
                    .map(|s| s.ident.to_string())
                    .unwrap_or_default();
                if let Some(cga) = self.bindings.inst_array.get(&name).cloned() {
                    let i = parse_signed_literal_as_i32(&expr_index.index);
                    if i < 0 || (i as u32) >= cga.length {
                        self.error = Some(syn_err(
                            expr_index.index.span(),
                            &format!(
                                "Index {i} out of bounds for CGA `{}` of length {}",
                                cga.name, cga.length
                            ),
                        ));
                        return;
                    }
                    let dim_ident = Ident::new(&format!("{}{}", cga.name, i as u32), p.span());
                    *expr = parse_quote! { #dim_ident };
                    return;
                }
            }
        }

        visit_mut::visit_expr_mut(self, expr);
    }

    /// Rewrites an expression path; its last segment keeps its name.
    fn visit_expr_path_mut(&mut self, e: &mut ExprPath) {
        if self.error.is_none() {
            match rewrite_path_for_rank(&e.path, &self.bindings, PathContext::ExprPath) {
                Ok(path) => e.path = path,
                Err(err) => self.error = Some(err),
            }
        }
    }

    /// Rewrites a struct literal: field values and the `..rest` expression are
    /// visited first, then the struct path is rewritten as a type path.
    fn visit_expr_struct_mut(&mut self, s: &mut ExprStruct) {
        if self.error.is_some() {
            return;
        }
        for field in &mut s.fields {
            self.visit_expr_mut(&mut field.expr);
        }
        if let Some(rest) = &mut s.rest {
            self.visit_expr_mut(rest);
        }
        match rewrite_path_for_rank(&s.path, &self.bindings, PathContext::Type) {
            Ok(path) => s.path = path,
            Err(err) => self.error = Some(err),
        }
    }

    /// Replaces a type with its rewritten form, recording any failure.
    fn visit_type_mut(&mut self, ty: &mut Type) {
        if self.error.is_none() {
            match rewrite_type_for_rank(ty, &self.bindings) {
                Ok(rewritten) => *ty = rewritten,
                Err(e) => self.error = Some(e),
            }
        }
    }
}
/// Rewrites a struct's field types using bindings taken from its own generics.
pub fn instantiate_struct_for_rank(item: &ItemStruct) -> Result<ItemStruct, Error> {
    let instantiator = RankInstantiator::new(RankBindings::from_generics(&item.generics)?);
    instantiator.rewrite_struct(item)
}
/// Rewrites a function's signature and body using bindings taken from its own generics.
pub fn instantiate_function_for_rank(item: &ItemFn) -> Result<ItemFn, Error> {
    let instantiator = RankInstantiator::new(RankBindings::from_generics(&item.sig.generics)?);
    instantiator.rewrite_function(item)
}
/// Rewrites a type alias: its generics are expanded and its aliased type is
/// rewritten, both using bindings taken from the alias's own generics.
pub fn instantiate_type_alias_for_rank(item: &ItemType) -> Result<ItemType, Error> {
    let bindings = RankBindings::from_generics(&item.generics)?;
    let mut rewritten = item.clone();
    rewrite_generics_for_rank(&mut rewritten.generics, &bindings)?;
    let aliased = rewrite_type_for_rank(&rewritten.ty, &bindings)?;
    rewritten.ty = Box::new(aliased);
    Ok(rewritten)
}
/// Rewrites a `static` item with empty bindings.
///
/// The declared type is rewritten first, then the initializer expression is
/// visited. If the rewritten type is a path, a `Type::new(..)`-style call in
/// the initializer has its first path segment replaced by the last identifier
/// of that type.
pub fn instantiate_static_for_rank(item: &ItemStatic) -> Result<ItemStatic, Error> {
    let bindings = RankBindings::new();
    let mut rewritten = item.clone();
    let static_ty = rewrite_type_for_rank(&rewritten.ty, &bindings)?;
    rewritten.ty = Box::new(static_ty);
    let constructor_ident = concrete_type_ident(&rewritten.ty);

    let mut instantiator = RankInstantiator::new(bindings);
    instantiator.visit_expr_mut(&mut rewritten.expr);
    let mut rewritten = instantiator.into_result(rewritten)?;

    if let Some(ident) = constructor_ident {
        rewrite_static_constructor_path(&mut rewritten.expr, &ident);
    }
    Ok(rewritten)
}
/// Returns the identifier of the last segment of a path type, or `None` when
/// `ty` is not a path type or the path has no segments.
fn concrete_type_ident(ty: &Type) -> Option<Ident> {
    match ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map(|segment| segment.ident.clone()),
        _ => None,
    }
}
/// Replaces the first path segment of a `...::new(..)` call with `concrete_type_ident`.
///
/// Only applies when `expr` is a call whose callee is a path with at least two
/// segments ending in `new`; any other expression is left unchanged.
fn rewrite_static_constructor_path(expr: &mut Expr, concrete_type_ident: &Ident) {
    if let Expr::Call(call) = expr {
        if let Expr::Path(callee) = &mut *call.func {
            let segments = &mut callee.path.segments;
            let is_qualified_new =
                segments.len() >= 2 && segments.last().is_some_and(|last| last.ident == "new");
            if is_qualified_new {
                if let Some(first) = segments.first_mut() {
                    first.ident = concrete_type_ident.clone();
                }
            }
        }
    }
}
/// Rewrites an `impl` block using bindings taken from its own generics.
pub fn instantiate_impl_for_rank(item: &ItemImpl) -> Result<ItemImpl, Error> {
    let instantiator = RankInstantiator::new(RankBindings::from_generics(&item.generics)?);
    instantiator.rewrite_impl(item)
}