use crate::{
error::{syn_err, Error},
types::{
concrete_name, get_variadic_function_suffix, get_variadic_method_data,
get_variadic_op_data, get_variadic_trait_type_data, get_variadic_type_data,
ConstGenericArrayType, DimType, VariadicOpData, VariadicTypeData,
},
};
use cutile_compiler::syn_utils::*;
use cutile_compiler::train_map::TrainMap;
use cutile_compiler::types::parse_signed_literal_as_i32;
use proc_macro2::{Ident, Span, TokenTree};
use quote::ToTokens;
use std::collections::BTreeMap;
#[allow(unused_assignments)]
use std::collections::{HashMap, HashSet};
use syn::{
parse_quote, spanned::Spanned, AngleBracketedGenericArguments, Expr, ExprCall, ExprMethodCall,
ExprPath, FnArg, GenericArgument, GenericParam, Generics, ImplItem, ImplItemFn, Item, ItemFn,
ItemImpl, ItemStruct, ItemTrait, Member, Pat, Path, PathArguments, PathSegment, ReturnType,
Signature, Stmt, TraitItem, Type,
};
/// Looks up variadic metadata for a trait method called on `maybe_primitive`.
///
/// Returns `Ok(None)` when `maybe_primitive`/`method_name` has no variadic trait type entry,
/// or when that type has no variadic method named `method_name`. Otherwise returns the op
/// name together with the type data and the op data.
///
/// # Errors
/// Propagates errors from `get_variadic_method_data`.
pub fn get_variadic_trait_impl_meta_data(
    maybe_primitive: &str,
    method_name: &str,
) -> Result<Option<(&'static str, VariadicTypeData, VariadicOpData)>, Error> {
    let Some(vtd) = get_variadic_trait_type_data(maybe_primitive, method_name) else {
        return Ok(None);
    };
    let method_data = get_variadic_method_data(&vtd, method_name)?;
    Ok(method_data.map(|(op_name, vod)| (op_name, vtd, vod)))
}
/// Looks up variadic metadata for `method_name` called on a receiver of type `receiver_ty`.
///
/// The receiver's variadic type is found with `get_vtd` (references are looked through).
/// Returns `Ok(None)` when the receiver is not a variadic type or has no such variadic method.
///
/// # Errors
/// Propagates errors from `get_vtd` and `get_variadic_method_data`.
pub fn get_variadic_method_meta_data(
    receiver_ty: &Type,
    method_name: &str,
) -> Result<Option<(&'static str, VariadicTypeData, VariadicOpData)>, Error> {
    let Some(vtd) = get_vtd(receiver_ty)? else {
        return Ok(None);
    };
    let method_data = get_variadic_method_data(&vtd, method_name)?;
    Ok(method_data.map(|(op_name, vod)| (op_name, vtd, vod)))
}
/// Returns the identifier of a single-segment path expression (e.g. a bare function name).
///
/// Returns `Ok(None)` if `maybe_path_expr` is not a path expression at all.
///
/// # Errors
/// Returns an error if the expression is a path with zero or more than one segment.
fn try_get_path_expr_ident_str(maybe_path_expr: &Expr) -> Result<Option<String>, Error> {
    let Expr::Path(path_expr) = maybe_path_expr else {
        return Ok(None);
    };
    let segments = &path_expr.path.segments;
    if segments.len() == 1 {
        return Ok(Some(segments[0].ident.to_string()));
    }
    Err(syn_err(
        path_expr.path.span(),
        &format!(
            "Expected single-segment path, got: {:?}",
            segments.to_token_stream().to_string()
        ),
    ))
}
/// Looks up the variadic op data for the function invoked by a call expression.
///
/// The callee must be a path expression; its last segment's identifier is used as the op
/// name. Returns `Ok(None)` for non-path callees, empty paths, or names that are not
/// registered variadic ops.
fn get_vod_from_call(expr: &mut ExprCall) -> Result<Option<VariadicOpData>, Error> {
    let Expr::Path(path_expr) = &*expr.func else {
        return Ok(None);
    };
    let Some(last_segment) = path_expr.path.segments.last() else {
        return Ok(None);
    };
    Ok(get_variadic_op_data(last_segment.ident.to_string().as_str()))
}
/// Returns the variadic type data for `ty`, if it names a variadic type.
///
/// For path types, the identifier of the last segment is looked up. Reference types are
/// looked through recursively. Every other type yields `Ok(None)`.
fn get_vtd(ty: &Type) -> Result<Option<VariadicTypeData>, Error> {
    match ty {
        Type::Reference(ref_type) => get_vtd(&ref_type.elem),
        Type::Path(ty_path) => Ok(ty_path
            .path
            .segments
            .last()
            .and_then(|seg| get_variadic_type_data(seg.ident.to_string().as_str()))),
        _ => Ok(None),
    }
}
/// Splits a variadic type (optionally behind references) into its identifier and its
/// angle-bracketed generic arguments.
///
/// # Errors
/// Returns an error if the path is empty, if the last segment's name is not `vtd.name`,
/// if the last segment has no angle-bracketed arguments, or if `ty` is neither a path
/// nor a reference.
fn get_ident_generic_args(
    ty: &Type,
    vtd: &VariadicTypeData,
) -> Result<(Ident, AngleBracketedGenericArguments), Error> {
    match ty {
        Type::Path(type_path) => {
            // Borrow the last segment directly; only the returned pieces are cloned.
            let last_seg = type_path.path.segments.last().ok_or_else(|| {
                syn_err(type_path.span(), "Expected at least one path segment")
            })?;
            if last_seg.ident.to_string() != vtd.name {
                return Err(syn_err(
                    last_seg.ident.span(),
                    &format!(
                        "get_ident_generic_args: Expected type '{}', got '{}'",
                        vtd.name, last_seg.ident
                    ),
                ));
            }
            match &last_seg.arguments {
                PathArguments::AngleBracketed(type_params) => {
                    Ok((last_seg.ident.clone(), type_params.clone()))
                }
                _ => Err(syn_err(type_path.span(), "Unexpected generic arguments")),
            }
        }
        Type::Reference(ref_type) => get_ident_generic_args(&ref_type.elem, vtd),
        _ => Err(syn_err(ty.span(), "Unexpected type")),
    }
}
/// Resolves the concrete identifier (and return type) for a call to the op `op_ident`.
///
/// If `op_ident` is not a registered variadic op, it is returned unchanged together with
/// `output_type`. Otherwise resolution is delegated to
/// `get_concrete_op_or_method_ident_from_types`.
fn get_concrete_op_ident_from_types(
    op_ident: &Ident,
    input_types: &Vec<Option<Type>>,
    output_type: Option<Type>,
    const_instances: &ConstInstances,
    disable_output_inference: bool,
) -> Result<(Ident, Option<Type>), Error> {
    let Some(vod) = get_variadic_op_data(op_ident.to_string().as_str()) else {
        return Ok((op_ident.clone(), output_type));
    };
    get_concrete_op_or_method_ident_from_types(
        vod,
        op_ident,
        input_types,
        output_type,
        const_instances,
        disable_output_inference,
    )
}
/// Resolves the concrete, length-suffixed identifier for a call to a variadic op or method,
/// and optionally infers the call's return type.
///
/// Lengths for every length variable in `vod.const_length_vars` are collected from:
/// 1. the positional argument types in `input_types`, as described by `vod.input_map`;
/// 2. if that leaves some length variables unknown, the caller-supplied `output_type`,
///    as described by `vod.output_map`.
///
/// The lengths, in `vod.const_length_vars` order, are passed to `get_variadic_op_ident`.
///
/// # Parameters
/// - `vod`: variadic op description (input/output maps, CGA map, length vars, return type).
/// - `op_or_method_ident`: the op/method name; its span is used for every error.
/// - `input_types`: argument types, `None` where the type is not known.
/// - `output_type`: the expected result type, if known from context.
/// - `const_instances`: const generic instances in scope, forwarded to `get_cga_type`.
/// - `disable_output_inference`: when `true`, `output_type` is returned unchanged instead of
///   building a return type from `vod.return_type`.
///
/// # Returns
/// The concrete identifier and the (inferred or passed-through) return type.
///
/// # Errors
/// Returns an error on argument/output type mismatches, on conflicting lengths for the same
/// length variable, or when not every length variable can be determined.
fn get_concrete_op_or_method_ident_from_types(
    vod: VariadicOpData,
    op_or_method_ident: &Ident,
    input_types: &Vec<Option<Type>>,
    output_type: Option<Type>,
    const_instances: &ConstInstances,
    disable_output_inference: bool,
) -> Result<(Ident, Option<Type>), Error> {
    // CGA variable name declared by `vod` -> CGA argument string from the caller's type
    // (from `get_cga_type`); used below to spell out the return type.
    let mut vod_cga_name_to_context_cga_name = HashMap::<&str, Option<String>>::new();
    // Length variable (a value of `vod.cga_map`) -> inferred length.
    let mut const_length_values = HashMap::<&str, u32>::new();
    // Positional arguments whose types are unknown; only used in the final error message.
    let mut missing_idx = vec![];
    let mut missing_types = vec![];
    // Phase 1: collect CGA lengths from the positional argument types.
    for (idx, expected_type_name, vod_cga_var_names) in vod.input_map {
        // NOTE(review): `input_types[idx]` panics if `vod.input_map` names an argument index
        // beyond `input_types.len()`.
        let Some(ty) = &input_types[idx] else {
            missing_idx.push(idx);
            missing_types.push(expected_type_name);
            continue;
        };
        let Some(vtd) = get_vtd(ty)? else {
            return Err(syn_err(
                op_or_method_ident.span(),
                &format!("Unable to infer type for argument {idx} for call to {op_or_method_ident}. Expected {expected_type_name}. Required by calls to variadic functions and methods."),
            ));
        };
        let Some(cga_instances) = get_cga_type(ty, const_instances)? else {
            return Err(syn_err(
                op_or_method_ident.span(),
                &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Unable to get cga instances for type: {}", ty.to_token_stream().to_string()),
            ));
        };
        let type_name = vtd.name;
        if expected_type_name != type_name {
            return Err(syn_err(
                op_or_method_ident.span(),
                &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Unexpected positional argument type: {:#?}", (idx, ty.to_token_stream().to_string())),
            ));
        }
        if vod_cga_var_names.len() != cga_instances.n.len() {
            return Err(syn_err(
                op_or_method_ident.span(),
                &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Expected {} cga instances for {type_name}, got {:?}.", vod_cga_var_names.len(), cga_instances.n),
            ));
        }
        for i in 0..vod_cga_var_names.len() {
            let cga_var_name = vod_cga_var_names[i];
            let cga_var_length_var = vod.cga_map.get(cga_var_name).ok_or_else(|| {
                syn_err(
                    op_or_method_ident.span(),
                    &format!("Missing cga_map entry for '{cga_var_name}'"),
                )
            })?;
            let cga_var_length = cga_instances.n[i];
            vod_cga_name_to_context_cga_name
                .insert(cga_var_name, cga_instances.cga_arg_strings[i].clone());
            // The same length variable may be shared by several CGAs; all of them must agree.
            if const_length_values.contains_key(cga_var_length_var) {
                let current_var_length = *const_length_values.get(cga_var_length_var).unwrap();
                if current_var_length != cga_var_length {
                    return Err(syn_err(
                        op_or_method_ident.span(),
                        &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): CGA instance var length mismatch. Expected {current_var_length} but got {cga_var_length} for cga {cga_var_name}."),
                    ));
                }
            } else {
                const_length_values.insert(cga_var_length_var, cga_var_length);
            }
        }
    }
    // Phase 2: if the arguments did not determine every length variable, fall back to the
    // statically known output type.
    if const_length_values.len() > vod.const_length_vars.len() {
        return Err(syn_err(
            op_or_method_ident.span(),
            &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Unexpected number of cga instances: {:#?} ", const_length_values),
        ));
    } else if const_length_values.len() < vod.const_length_vars.len() {
        if output_type.is_none() {
            return Err(syn_err(
                op_or_method_ident.span(),
                &format!("Unable to infer call to {}. Try binding it to a statically typed variable. \nDebug info:\n const_length_values={:#?}, vod.const_length_vars={:#?}",
                    op_or_method_ident.to_string(),
                    const_length_values,
                    vod.const_length_vars),
            ));
        }
        let output_type = output_type.clone().unwrap();
        let maybe_vtd = get_vtd(&output_type)?;
        if maybe_vtd.is_none() {
            return Err(syn_err(
                op_or_method_ident.span(),
                &format!(
                    "Unable to infer call to {}. Try binding it to a statically typed variable.",
                    op_or_method_ident.to_string()
                ),
            ));
        }
        let vtd = maybe_vtd.unwrap();
        let cga_instances = get_cga_type(&output_type, const_instances)?;
        if cga_instances.is_none() {
            return Err(syn_err(
                op_or_method_ident.span(),
                &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Unable to get cga instances for output type: {}", output_type.to_token_stream().to_string()),
            ));
        }
        let cga_instances = cga_instances.unwrap();
        let (expected_type_name, vod_cga_var_names) = vod.output_map.clone();
        if expected_type_name != vtd.name {
            return Err(syn_err(
                op_or_method_ident.span(),
                &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Unexpected output type: {}", output_type.to_token_stream().to_string()),
            ));
        }
        if vod_cga_var_names.len() != cga_instances.n.len() {
            return Err(syn_err(
                op_or_method_ident.span(),
                &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Expected {} cga instances, got {}.", vod_cga_var_names.len(), cga_instances.n.len()),
            ));
        }
        // Unlike Phase 1, this loop does not record entries in
        // `vod_cga_name_to_context_cga_name`; it only fills in the missing lengths.
        for i in 0..vod_cga_var_names.len() {
            let cga_var_name = &vod_cga_var_names[i];
            let cga_var_length_var = vod.cga_map.get(cga_var_name).ok_or_else(|| {
                syn_err(
                    op_or_method_ident.span(),
                    &format!("Missing cga_map entry for '{cga_var_name}'"),
                )
            })?;
            let cga_var_length = cga_instances.n[i];
            if const_length_values.contains_key(cga_var_length_var) {
                if *const_length_values.get(cga_var_length_var).unwrap() != cga_var_length {
                    return Err(syn_err(
                        op_or_method_ident.span(),
                        &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): CGA instance var length mismatch for output type."),
                    ));
                }
            } else {
                const_length_values.insert(cga_var_length_var, cga_var_length);
            }
        }
    }
    // Every length variable must now be known; otherwise report the arguments whose types
    // were missing.
    if const_length_values.len() != vod.const_length_vars.len() {
        return Err(syn_err(
            op_or_method_ident.span(),
            &format!("Unable to infer type for argument(s) {missing_idx:?} for call to {op_or_method_ident}. Expected {missing_types:?}. Required by calls to variadic functions and methods."),
        ));
    }
    // Phase 3: compute the return type, unless inference is disabled.
    let rtype: Option<Type> = if disable_output_inference {
        output_type
    } else {
        let (return_type_name, return_type_generic_args) = vod.return_type;
        if return_type_generic_args.len() == 0 {
            // Non-generic return type: parse the name as-is.
            let ty = syn::parse::<Type>(return_type_name.parse().map_err(|_| {
                syn_err(
                    op_or_method_ident.span(),
                    &format!("Unable to parse {return_type_name}"),
                )
            })?)
            .map_err(|e| {
                syn_err(
                    op_or_method_ident.span(),
                    &format!("Unable to parse {return_type_name}: {e}"),
                )
            })?;
            Some(ty)
        } else {
            // Generic return type: substitute each CGA argument (those present in
            // `vod.cga_map`) with the matching argument string seen in Phase 1. Generic
            // args that are not in `vod.cga_map` are skipped entirely.
            let mut missing_cgas = vec![];
            let mut return_type_generic_arg_strings = vec![];
            let mut num_cgas = 0;
            for arg in return_type_generic_args {
                if !vod.cga_map.contains_key(*arg) {
                    continue;
                }
                num_cgas += 1;
                match vod_cga_name_to_context_cga_name.get(arg) {
                    Some(Some(s)) => return_type_generic_arg_strings.push(s.clone()),
                    _ => missing_cgas.push(arg),
                };
            }
            if return_type_generic_arg_strings.len() != num_cgas {
                // Some CGA could not be spelled from the inputs; fall back to the
                // caller-provided output type, if any.
                if output_type.is_none() {
                    return Err(syn_err(
                        op_or_method_ident.span(),
                        &format!("Failed to infer return type generic args {:?} \nop={} \nvod_cga_name_to_context_cga_name={vod_cga_name_to_context_cga_name:#?}", missing_cgas, op_or_method_ident.to_string()),
                    ));
                }
                output_type
            } else {
                let return_type_str = format!(
                    "{}<{}>",
                    return_type_name,
                    return_type_generic_arg_strings.join(", ")
                );
                let ty = syn::parse::<Type>(return_type_str.parse().map_err(|_| {
                    syn_err(
                        op_or_method_ident.span(),
                        &format!("Unable to parse {return_type_str}"),
                    )
                })?)
                .map_err(|e| {
                    syn_err(
                        op_or_method_ident.span(),
                        &format!("Unable to parse {return_type_str}: {e}"),
                    )
                })?;
                Some(ty)
            }
        }
    };
    // NOTE(review): this check is unreachable. The identical length comparison above already
    // returned, and `const_length_values` is not modified in between.
    if vod.const_length_vars.len() != const_length_values.len() {
        return Err(syn_err(
            op_or_method_ident.span(),
            &format!("Failed to infer op name from given parameters {op_or_method_ident}"),
        ));
    }
    // Phase 4: order the lengths as declared in `vod.const_length_vars` and mangle the name.
    let mut length_vec = Vec::with_capacity(vod.const_length_vars.len());
    for const_length_var in vod.const_length_vars {
        let val = const_length_values.get(const_length_var).ok_or_else(|| {
            syn_err(
                op_or_method_ident.span(),
                &format!("Missing const_length_value for '{const_length_var}'"),
            )
        })?;
        length_vec.push(*val);
    }
    let ident = get_variadic_op_ident(op_or_method_ident, &length_vec);
    Ok((ident, rtype))
}
/// Builds the mangled identifier `<ident>__<suffix>` for a concrete variadic op, where
/// `<suffix>` comes from `get_variadic_function_suffix(const_ga_lengths)`. The new
/// identifier keeps the span of `ident`.
fn get_variadic_op_ident(ident: &Ident, const_ga_lengths: &Vec<u32>) -> Ident {
    let suffix = get_variadic_function_suffix(const_ga_lengths);
    let mangled = format!("{ident}__{suffix}");
    Ident::new(mangled.as_str(), ident.span())
}
/// Const-generic instantiation state used while expanding a variadic item.
///
/// Built per length combination (see `ConstInstances::from_variadic`) or from an item's
/// generics, and consulted by the `desugar_*` functions.
#[derive(Debug, Clone)]
pub struct ConstInstances {
    /// Concrete value of each length variable, keyed by the variable's name.
    inst_u32: HashMap<String, u32>,
    /// Variadic CGA parameter declarations, keyed by CGA name.
    var_arrays: HashMap<String, VarCGAParameter>,
    /// Instantiated CGA parameters (each with a concrete `length`), keyed by CGA name.
    inst_array: HashMap<String, CGAParameter>,
}
impl ConstInstances {
    /// Creates an instance set with no length values and no arrays.
    fn new() -> Self {
        ConstInstances {
            inst_u32: HashMap::new(),
            inst_array: HashMap::new(),
            var_arrays: HashMap::new(),
        }
    }

    /// Builds the instances for one length combination produced by `VariadicLengthIterator`.
    ///
    /// Every entry of `cga_lengths.variadic_length_instance` becomes a length value. The
    /// `i`-th entry of `cga_lengths.cga_length_instance` instantiates `var_cgas[i]` with that
    /// length.
    ///
    /// # Errors
    /// Returns an error in three cases:
    /// - `var_cgas` has fewer entries than `cga_lengths.cga_length_instance`;
    /// - a CGA's `length_var` does not match the name recorded for it;
    /// - one length variable receives two different values.
    fn from_variadic(
        cga_lengths: &VariadicLengthItem,
        var_cgas: &Vec<VarCGAParameter>,
    ) -> Result<Self, Error> {
        let mut inst_u32: HashMap<String, u32> = cga_lengths
            .variadic_length_instance
            .iter()
            .map(|(length_var_name, length_instance)| {
                (length_var_name.clone(), *length_instance as u32)
            })
            .collect();
        let mut inst_array: HashMap<String, CGAParameter> = HashMap::new();
        let mut var_arrays: HashMap<String, VarCGAParameter> = HashMap::new();
        for (i, (length_var_name, length_instance)) in
            cga_lengths.cga_length_instance.iter().enumerate()
        {
            let length_instance = *length_instance as u32;
            // Checked access: a length/CGA count mismatch becomes a macro error instead of
            // an index-out-of-bounds panic inside the proc macro.
            let cga = var_cgas.get(i).ok_or_else(|| {
                syn_err(
                    Span::call_site(),
                    &format!(
                        "Missing const generic array parameter for length var '{}'",
                        length_var_name
                    ),
                )
            })?;
            if length_var_name != &cga.length_var {
                return Err(syn_err(
                    Span::call_site(),
                    &format!(
                        "CGA length var name mismatch: expected '{}', got '{}'",
                        cga.length_var, length_var_name
                    ),
                ));
            }
            if let Some(existing_length) = inst_u32.insert(length_var_name.clone(), length_instance)
            {
                if existing_length != length_instance {
                    return Err(syn_err(
                        Span::call_site(),
                        &format!(
                            "CGA length instance mismatch for '{}': expected {}, got {}",
                            length_var_name, existing_length, length_instance
                        ),
                    ));
                }
            }
            inst_array.insert(cga.name.clone(), cga.instance(length_instance));
            var_arrays.insert(cga.name.clone(), cga.clone());
        }
        Ok(ConstInstances {
            inst_u32,
            inst_array,
            var_arrays,
        })
    }

    /// Builds instances from the (already concrete) CGA parameters declared in `generics`.
    fn from_generics(generics: &Generics) -> Result<Self, Error> {
        let (cga_param, _u32_param) = parse_cgas(generics);
        let inst_u32: HashMap<String, u32> = HashMap::new();
        let mut inst_array: HashMap<String, CGAParameter> = HashMap::new();
        let var_arrays: HashMap<String, VarCGAParameter> = HashMap::new();
        for cga in cga_param {
            inst_array.insert(cga.name.clone(), cga.clone());
        }
        Ok(ConstInstances {
            inst_u32,
            inst_array,
            var_arrays,
        })
    }

    /// Returns a copy of `self` in which each CGA of `var_cgas` is instantiated with the
    /// length already recorded for its `length_var`.
    ///
    /// # Errors
    /// Returns an error if some CGA's length variable has no recorded value.
    fn instantiate_var_cgas(&self, var_cgas: &Vec<VarCGAParameter>) -> Result<Self, Error> {
        let mut result = self.clone();
        for cga in var_cgas {
            let Some(&n) = result.inst_u32.get(&cga.length_var) else {
                return Err(syn_err(
                    Span::call_site(),
                    &format!(
                        "instantiate_var_cgas: Missing inst_u32 entry for '{}'",
                        cga.length_var
                    ),
                ));
            };
            result.inst_array.insert(cga.name.clone(), cga.instance(n));
            result.var_arrays.insert(cga.name.clone(), cga.clone());
        }
        Ok(result)
    }

    /// Returns a copy of `self` extended with new CGAs: `var_cgas[i]` is instantiated with
    /// length `n_list[i]`, and that length is recorded for its `length_var`.
    ///
    /// # Errors
    /// Returns an error in two cases:
    /// - a length variable already has a recorded value;
    /// - `var_cgas` is shorter than `n_list`.
    fn instantiate_new_var_cgas(
        &self,
        n_list: &Vec<u32>,
        var_cgas: &Vec<VarCGAParameter>,
    ) -> Result<Self, Error> {
        let mut result = self.clone();
        for (i, &n) in n_list.iter().enumerate() {
            // Checked access, for the same reason as in `from_variadic`.
            let cga = var_cgas.get(i).ok_or_else(|| {
                syn_err(
                    Span::call_site(),
                    &format!(
                        "instantiate_new_var_cgas: Missing const generic array parameter for length index {i}"
                    ),
                )
            })?;
            if result.inst_u32.contains_key(&cga.length_var) {
                return Err(syn_err(
                    Span::call_site(),
                    &format!(
                        "instantiate_new_var_cgas: inst_u32 already contains entry for '{}'",
                        cga.length_var
                    ),
                ));
            }
            result.inst_u32.insert(cga.length_var.clone(), n);
            result.inst_array.insert(cga.name.clone(), cga.instance(n));
            result.var_arrays.insert(cga.name.clone(), cga.clone());
        }
        Ok(result)
    }
}
/// Enumerates every combination of length-variable values for a variadic item.
///
/// Combinations are numbered `0..i_max`. Each number is decoded into per-variable values
/// by the `Iterator` implementation.
#[derive(Debug)]
struct VariadicLengthIterator {
    /// Number of the next combination to yield.
    i: usize,
    /// Total number of combinations: the product of the values in `variadic_lengths`.
    i_max: usize,
    /// For each length variable, keyed by name: how many values it ranges over (attribute
    /// value + 1).
    variadic_lengths: BTreeMap<String, usize>,
    /// The `length_var` of each CGA passed to `new`, in that order.
    cga_length_vars: Vec<String>,
}
impl VariadicLengthIterator {
fn new(attribute_list: &SingleMetaList, arrays: &Vec<VarCGAParameter>) -> Result<Self, Error> {
let mut i_max: usize = 1;
let mut variadic_lengths: BTreeMap<String, usize> = BTreeMap::new();
if let Some(variadic_length_vars) = attribute_list.parse_string_arr("variadic_length_vars")
{
for var in variadic_length_vars {
let len = (attribute_list.parse_int(var.as_str()).ok_or_else(|| {
syn_err(
Span::call_site(),
&format!("Missing attribute value for '{var}'"),
)
})? + 1) as usize;
i_max *= len;
if variadic_lengths.insert(var.clone(), len.clone()).is_some() {
return Err(syn_err(
Span::call_site(),
&format!("Duplicate variadic_length_var '{var}'"),
));
}
}
}
let mut cga_length_vars = vec![];
for cga in arrays {
let var = cga.length_var.clone();
cga_length_vars.push(var.clone());
let len = (attribute_list.parse_int(var.as_str()).ok_or_else(|| {
syn_err(
Span::call_site(),
&format!("Missing attribute value for '{var}'"),
)
})? + 1) as usize;
if variadic_lengths.contains_key(&var) {
if *variadic_lengths.get(&var).unwrap() != len {
return Err(syn_err(
Span::call_site(),
&format!("Variadic length mismatch for '{var}'"),
));
}
} else {
i_max *= len;
variadic_lengths.insert(var.clone(), len.clone());
}
}
Ok(VariadicLengthIterator {
i: 0,
i_max,
variadic_lengths,
cga_length_vars,
})
}
}
/// One combination of length-variable values yielded by `VariadicLengthIterator`.
pub struct VariadicLengthItem {
    /// Value of every length variable, ordered by variable name.
    variadic_length_instance: BTreeMap<String, usize>,
    /// `(length_var, value)` for each CGA, in CGA order.
    cga_length_instance: Vec<(String, usize)>,
}
impl VariadicLengthItem {
    /// Returns the length of each CGA, in CGA order.
    pub fn vec_of_cga_lengths(&self) -> Vec<u32> {
        let mut lengths = Vec::with_capacity(self.cga_length_instance.len());
        for (_, len) in &self.cga_length_instance {
            lengths.push(*len as u32);
        }
        lengths
    }
    /// Returns the value of every distinct length variable, ordered by variable name.
    pub fn vec_of_unique_lengths(&self) -> Vec<u32> {
        self.variadic_length_instance
            .iter()
            .map(|(_, &len)| len as u32)
            .collect()
    }
}
impl Iterator for VariadicLengthIterator {
type Item = VariadicLengthItem;
fn next(&mut self) -> Option<Self::Item> {
if self.i < self.i_max {
let mut variadic_length_instance: BTreeMap<String, usize> = BTreeMap::new();
let mut i = self.i;
for (len_var, len) in self.variadic_lengths.iter() {
let pos = i % len;
i /= len;
variadic_length_instance.insert(len_var.clone(), pos);
}
self.i += 1;
let mut cga_length_instance: Vec<(String, usize)> = vec![];
for len_var in &self.cga_length_vars {
let len = *variadic_length_instance
.get(len_var)
.expect(&format!("Unexpected length var {len_var}"));
cga_length_instance.push((len_var.clone(), len));
}
Some(VariadicLengthItem {
variadic_length_instance,
cga_length_instance,
})
} else {
None
}
}
}
/// Collects the const generic array (CGA) parameters declared in `generics`.
///
/// A CGA parameter is a const generic whose declared type is an array. Each one is
/// converted with `VarCGAParameter::from_const_param`, in declaration order. Type, lifetime,
/// and non-array const parameters are skipped.
pub fn parse_var_cgas(generics: &Generics) -> Vec<VarCGAParameter> {
    generics
        .params
        .iter()
        .filter_map(|param| match param {
            GenericParam::Const(const_param) if matches!(const_param.ty, Type::Array(_)) => {
                Some(VarCGAParameter::from_const_param(const_param))
            }
            _ => None,
        })
        .collect()
}
/// Expands a variadic struct definition into one concrete struct per combination of
/// const-generic-array (CGA) lengths produced by `VariadicLengthIterator`.
///
/// For each combination:
/// - the struct is cloned and renamed to `vtd.concrete_name(lengths)`;
/// - its CGA generic parameters are desugared into scalar const parameters.
///
/// When the `constructor` attribute is set, an `impl` with generated constructors is also
/// produced for every CGA whose dim type is `DimType::Mixed`.
///
/// # Errors
/// Returns an error in four cases:
/// - the struct has no `VARIADIC_TYPES` entry;
/// - the number of CGA parameters differs from that entry;
/// - the length attributes are invalid;
/// - the generated constructor impl fails to parse.
pub fn variadic_struct(
    attributes: &SingleMetaList,
    item: ItemStruct,
) -> Result<Vec<(ItemStruct, Option<ItemImpl>)>, Error> {
    let Some(vtd) = get_variadic_type_data(item.ident.to_string().as_str()) else {
        return Err(syn_err(
            item.ident.span(),
            &format!(
                "Generating {} requires a corresponding entry in VARIADIC_TYPES",
                item.ident
            ),
        ));
    };
    let cgas = parse_var_cgas(&item.generics);
    let cga_iter = VariadicLengthIterator::new(attributes, &cgas)?;
    let maybe_constructor_name = attributes.parse_string("constructor");
    let num_cgas = vtd.num_cgas();
    if cgas.len() as u32 != num_cgas {
        return Err(syn_err(
            item.ident.span(),
            &format!(
                "Expected {} const generic arrays, got {}",
                num_cgas,
                cgas.len()
            ),
        ));
    }
    let mut result: Vec<(ItemStruct, Option<ItemImpl>)> = vec![];
    for var_cga_iter_item in cga_iter {
        // Computed once per combination rather than once per CGA inside the loop below.
        let cga_lengths = var_cga_iter_item.vec_of_cga_lengths();
        let const_instances = ConstInstances::from_variadic(&var_cga_iter_item, &cgas)?;
        let mut concrete = item.clone();
        concrete.ident = Ident::new(&vtd.concrete_name(&cga_lengths), concrete.ident.span());
        desugar_generics(&mut concrete.generics, &const_instances)?;
        let concrete_impl = match &maybe_constructor_name {
            None => None,
            Some(constructor_base_name) => {
                let struct_name = concrete.ident.to_string();
                let mut type_params: Vec<String> = vec![];
                let mut type_args: Vec<String> = vec![];
                let mut constructors: Vec<String> = vec![];
                for cga_idx in 0..num_cgas as usize {
                    let n = cga_lengths[cga_idx];
                    let cga_name: &str = vtd.cga_names[cga_idx];
                    let cga_index_type: &str = vtd.cga_index_types[cga_idx];
                    // One scalar const generic per dimension, e.g. `const S0: <index type>`.
                    for dim_idx in 0..n {
                        type_params.push(format!("const {cga_name}{dim_idx}: {cga_index_type}"));
                        type_args.push(format!("{cga_name}{dim_idx}"));
                    }
                    match &vtd.cga_dim_types[cga_idx] {
                        DimType::Mixed => {
                            // One constructor for each possible count of dynamic dims (0..=n).
                            for num_dynamic in 0..(n + 1) {
                                let constructor_name =
                                    format!("{constructor_base_name}_{num_dynamic}");
                                let dim_type_str = "i32";
                                constructors.push(format!(
                                    r#"
    pub fn {constructor_name}(dims: &'a [{dim_type_str}; {num_dynamic}]) -> Self {{
        {struct_name} {{ dims: dims }}
    }}
"#
                                ));
                                if num_dynamic == 0 {
                                    constructors.push(format!(
                                        r#"
    pub fn const_{constructor_base_name}() -> Self {{
        {struct_name} {{ dims: &[] }}
    }}
"#
                                    ));
                                }
                            }
                        }
                        DimType::Static => {}
                    }
                }
                if constructors.is_empty() {
                    None
                } else {
                    let impl_generics = type_params.join(",");
                    let impl_constructors = constructors.join("\n");
                    let impl_type_args = type_args.join(",");
                    let constructor_impl = format!(
                        r#"
impl<'a, {impl_generics}> {struct_name}<'a, {impl_type_args}> {{
    {impl_constructors}
}}
"#
                    );
                    let parsed_impl =
                        syn::parse::<ItemImpl>(constructor_impl.parse().map_err(|_| {
                            syn_err(item.ident.span(), "Failed to parse constructor impl")
                        })?)
                        .map_err(|e| {
                            syn_err(
                                item.ident.span(),
                                &format!("Failed to parse constructor impl: {e}"),
                            )
                        })?;
                    Some(parsed_impl)
                }
            }
        };
        result.push((concrete, concrete_impl));
    }
    Ok(result)
}
pub fn variadic_trait(
attributes: &SingleMetaList,
item: ItemTrait,
) -> Result<Vec<ItemTrait>, Error> {
let cgas = parse_var_cgas(&item.generics);
let cga_iter = VariadicLengthIterator::new(attributes, &cgas)?;
let rewrite_variadics = RewriteVariadicsPass {};
let mut result: Vec<ItemTrait> = vec![];
for (_, n_list) in cga_iter.enumerate() {
let const_instances = ConstInstances::from_variadic(&n_list, &cgas)?;
result.push(rewrite_variadics.rewrite_trait(&item, &const_instances)?);
}
Ok(result)
}
pub fn variadic_impl(attributes: &SingleMetaList, item: ItemImpl) -> Result<Vec<ItemImpl>, Error> {
let cgas = parse_var_cgas(&item.generics);
let cga_iter = VariadicLengthIterator::new(attributes, &cgas)?;
let rewrite_variadics = RewriteVariadicsPass {};
let mut result: Vec<ItemImpl> = vec![];
for (_, n_list) in cga_iter.enumerate() {
let const_instances = ConstInstances::from_variadic(&n_list, &cgas)?;
result.push(rewrite_variadics.rewrite_impl(&item, &const_instances)?);
}
Ok(result)
}
pub(self) fn variadic_impl_fn_gen(
attributes: &SingleMetaList,
self_ty: &Type,
item: &ImplItemFn,
const_instances_impl: &ConstInstances,
) -> Result<Vec<ImplItemFn>, Error> {
let cgas = parse_var_cgas(&item.sig.generics);
let _fn_types = get_sig_types(&item.sig, Some(self_ty));
let cga_iter = VariadicLengthIterator::new(attributes, &cgas)?;
let mut result: Vec<ImplItemFn> = vec![];
for (_, cga_iter_item) in cga_iter.enumerate() {
let const_instances = const_instances_impl
.instantiate_new_var_cgas(&cga_iter_item.vec_of_cga_lengths(), &cgas)?;
let concrete_fn = rewrite_impl_fn(self_ty, item, &const_instances)?;
result.push(concrete_fn);
}
Ok(result)
}
/// Returns a copy of `item` rewritten by `RewriteVariadicsPass::rewrite_impl_fn` for the
/// given const instances. The original `item` is left untouched.
pub(self) fn rewrite_impl_fn(
    self_ty: &Type,
    item: &ImplItemFn,
    const_instances: &ConstInstances,
) -> Result<ImplItemFn, Error> {
    let pass = RewriteVariadicsPass {};
    let mut rewritten = item.clone();
    pass.rewrite_impl_fn(self_ty, &mut rewritten, const_instances, None)?;
    Ok(rewritten)
}
/// Desugars a function signature in place: its generics, the types of its typed
/// parameters, and its return type. Receiver parameters (`self`) are left as-is.
///
/// The generics are rewritten first. Inputs and output are then rewritten on copies and
/// assigned back only after each copy has been desugared without error.
pub(self) fn rewrite_fn_sig(
    sig: &mut Signature,
    const_instances: &ConstInstances,
) -> Result<(), Error> {
    desugar_generics(&mut sig.generics, const_instances)?;

    let mut new_inputs = sig.inputs.clone();
    for arg in new_inputs.iter_mut() {
        if let FnArg::Typed(pat_ty) = arg {
            *pat_ty.ty = desugar_ty(&pat_ty.ty, const_instances)?;
        }
    }
    sig.inputs = new_inputs;

    let mut new_output = sig.output.clone();
    if let ReturnType::Type(_, ret_ty) = &mut new_output {
        *ret_ty = Box::new(desugar_ty(&**ret_ty, const_instances)?);
    }
    sig.output = new_output;
    Ok(())
}
pub fn variadic_op(attributes: &SingleMetaList, item: ItemFn) -> Result<Vec<ItemFn>, Error> {
let op_name = item.sig.ident.to_string();
if get_variadic_op_data(&op_name).is_none() {
return Err(syn_err(
item.sig.ident.span(),
&format!("Variadic op data not found for {op_name}. VariadicOpData entry is required for ops with cuda_tile::variadic_op annotation."),
));
}
let cgas = parse_var_cgas(&item.sig.generics);
let cga_iter = VariadicLengthIterator::new(attributes, &cgas)?;
let mut result: Vec<ItemFn> = vec![];
let rewrite_variadics = RewriteVariadicsPass {};
for (_, n_list) in cga_iter.enumerate() {
let const_instances = ConstInstances::from_variadic(&n_list, &cgas)?;
result.push(rewrite_variadics.rewrite_function(&item, &const_instances)?);
}
Ok(result)
}
/// Replaces each const-generic-array parameter in `generics` with one scalar const
/// parameter per dimension.
///
/// For a CGA named `S` whose instance in `const_instances.inst_array` has length `n`, this
/// emits `const S0: i32, ..., const S{n-1}: i32`. All other parameters are kept in order.
///
/// # Errors
/// Returns an error in three cases:
/// - a CGA has no instance;
/// - a CGA's element type is not `i32`;
/// - a generated parameter fails to parse.
///
/// On error, `generics` is left unmodified.
pub(self) fn desugar_generics(
    generics: &mut Generics,
    const_instances: &ConstInstances,
) -> Result<(), Error> {
    // Build into a fresh sequence instead of cloning `generics.params` only to clear it.
    // It is assigned only on success.
    let mut concrete_type_params: syn::punctuated::Punctuated<GenericParam, syn::token::Comma> =
        syn::punctuated::Punctuated::new();
    for param in generics.params.iter() {
        let GenericParam::Const(const_param) = param else {
            concrete_type_params.push(param.clone());
            continue;
        };
        if !matches!(const_param.ty, Type::Array(_)) {
            concrete_type_params.push(param.clone());
            continue;
        }
        let const_param_name = const_param.ident.to_string();
        let cga = const_instances
            .inst_array
            .get(const_param_name.as_str())
            .ok_or_else(|| {
                syn_err(
                    const_param.ident.span(),
                    &format!("Missing inst_array entry for '{const_param_name}'"),
                )
            })?;
        if cga.element_type != "i32" {
            return Err(syn_err(
                const_param.ident.span(),
                &format!("Expected element_type 'i32', got '{}'", cga.element_type),
            ));
        }
        for i in 0..cga.length {
            let const_str = format!("const {}{}: {}", cga.name, i, cga.element_type);
            let generic_param = syn::parse::<GenericParam>(const_str.parse().map_err(|_| {
                syn_err(
                    const_param.ident.span(),
                    &format!("Failed to parse generic param '{const_str}'"),
                )
            })?)
            .map_err(|e| {
                syn_err(
                    const_param.ident.span(),
                    &format!("Failed to parse generic param '{const_str}': {e}"),
                )
            })?;
            concrete_type_params.push(generic_param);
        }
    }
    generics.params = concrete_type_params;
    Ok(())
}
/// Expands a path that names a const generic array into its per-dimension generic
/// arguments.
///
/// If the last segment of `path` is a CGA `S` with length `n` in `instances.inst_array`,
/// returns `<S0,S1,...,S{n-1}>`.
///
/// # Errors
/// Returns an error in three cases:
/// - `path` is empty;
/// - the last segment is not a known CGA;
/// - the generated arguments fail to parse.
pub(self) fn expand_cga(
    path: &Path,
    instances: &ConstInstances,
) -> Result<AngleBracketedGenericArguments, Error> {
    let last_seg = path.segments.last().ok_or_else(|| {
        syn_err(
            path.span(),
            "Expected at least one path segment in expand_cga",
        )
    })?;
    let param_name = last_seg.ident.to_string();
    // Fixes the mis-encoded `¶m_name` (originally `&param_name`), which did not compile.
    let Some(cga) = instances.inst_array.get(&param_name) else {
        return Err(syn_err(
            path.span(),
            &format!(
                "{} is not a const generic array.",
                path.to_token_stream().to_string()
            ),
        ));
    };
    let generic_args_result: Vec<String> = (0..cga.length)
        .map(|j| format!("{}{}", cga.name, j))
        .collect();
    let formatted = format!("<{}>", generic_args_result.join(","));
    syn::parse::<AngleBracketedGenericArguments>(formatted.parse().map_err(|_| {
        syn_err(
            path.span(),
            &format!("Failed to parse angle bracketed args '{formatted}'"),
        )
    })?)
    .map_err(|e| {
        syn_err(
            path.span(),
            &format!("Failed to parse angle bracketed args '{formatted}': {e}"),
        )
    })
}
/// Desugars every segment of a type path.
///
/// How each segment is handled depends on its arguments:
/// - Angle-bracketed arguments go through `desugar_cga`, which may also rename the segment
///   to a concrete variadic type name.
/// - A segment with no arguments is kept as-is, unless it names a variadic type.
///
/// # Errors
/// Returns an error in three cases:
/// - a segment is itself a CGA name;
/// - a variadic type appears without type arguments;
/// - a segment has parenthesized arguments.
pub(self) fn desugar_path(path: &Path, instances: &ConstInstances) -> Result<Path, Error> {
    let mut result_path = path.clone();
    for (i, seg) in path.segments.iter().enumerate() {
        let param_name = seg.ident.to_string();
        // Fixes the mis-encoded `¶m_name` (originally `&param_name`), which did not compile.
        if instances.inst_array.contains_key(&param_name) {
            return Err(syn_err(
                seg.ident.span(),
                &format!(
                    "Unexpected use of desugar_path for {}",
                    path.to_token_stream().to_string()
                ),
            ));
        }
        let (ident, arguments) = match &seg.arguments {
            PathArguments::AngleBracketed(type_params) => {
                let (type_ident, seg_args) = desugar_cga(instances, &seg.ident, type_params)?;
                (type_ident, PathArguments::AngleBracketed(seg_args))
            }
            PathArguments::None => {
                if get_variadic_type_data(param_name.as_str()).is_some() {
                    return Err(syn_err(
                        seg.ident.span(),
                        "Variadic type arguments are required to desugar variadic types.",
                    ));
                }
                (seg.ident.clone(), PathArguments::None)
            }
            _ => return Err(syn_err(seg.ident.span(), "Unexpected Path arguments.")),
        };
        result_path.segments[i] = PathSegment { ident, arguments };
    }
    Ok(result_path)
}
/// Desugars every type argument of `generic_args` in place with `desugar_ty`.
///
/// # Errors
/// Returns an error for any non-type generic argument (lifetime, const, binding, ...),
/// reported at the span of the whole argument list.
pub(self) fn desugar_generic_arguments(
    generic_args: &mut AngleBracketedGenericArguments,
    const_instances: &ConstInstances,
) -> Result<(), Error> {
    let span = generic_args.span();
    for arg in generic_args.args.iter_mut() {
        match arg {
            GenericArgument::Type(ty) => {
                *ty = desugar_ty(ty, const_instances)?;
            }
            other => {
                return Err(syn_err(
                    span,
                    &format!(
                        "Unsupported generic argument {}",
                        other.to_token_stream().to_string()
                    ),
                ));
            }
        }
    }
    Ok(())
}
/// Rewrites a type so that uses of const generic arrays are replaced by their concrete,
/// length-specialized forms.
///
/// Each kind of type is handled as follows:
/// - Path types go through `desugar_path`. The exception is `Option<..>`: the `Option`
///   segment is kept and only its type arguments are desugared recursively.
/// - Array types have their element type desugared. A length expression that names a
///   variable in `instances.inst_u32` is replaced by that variable's value.
/// - Reference and tuple types are desugared element-wise.
/// - Every other type is returned unchanged.
pub(self) fn desugar_ty(ty: &Type, instances: &ConstInstances) -> Result<Type, Error> {
    match ty {
        Type::Path(type_path) => {
            let Some(last_segment) = type_path.path.segments.last() else {
                return Err(syn_err(
                    type_path.span(),
                    "Expected at least one path segment in desugar_ty",
                ));
            };
            let mut desugared = type_path.clone();
            if last_segment.ident != "Option" {
                desugared.path = desugar_path(&desugared.path, instances)?;
                return Ok(desugared.into());
            }
            // `Option<T>`: keep the wrapper and desugar only its type arguments.
            if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
                let mut inner_args = args.clone();
                for inner in inner_args.args.iter_mut() {
                    if let GenericArgument::Type(inner_ty) = inner {
                        *inner_ty = desugar_ty(inner_ty, instances)?;
                    }
                }
                let last_idx = desugared.path.segments.len() - 1;
                desugared.path.segments[last_idx].arguments =
                    PathArguments::AngleBracketed(inner_args);
            }
            Ok(desugared.into())
        }
        Type::Array(type_array) => {
            let mut desugared = type_array.clone();
            *desugared.elem = desugar_ty(&type_array.elem, instances)?;
            let arr_len = desugared.len.to_token_stream().to_string();
            if let Some(n) = instances.inst_u32.get(&arr_len) {
                desugared.len = syn::parse::<Expr>(format!("{}", n).parse().map_err(|_| {
                    syn_err(
                        type_array.span(),
                        &format!("Failed to parse array length '{n}'"),
                    )
                })?)
                .map_err(|e| {
                    syn_err(
                        type_array.span(),
                        &format!("Failed to parse array length '{n}': {e}"),
                    )
                })?;
            }
            Ok(desugared.into())
        }
        Type::Reference(ref_type) => {
            let mut desugared = ref_type.clone();
            *desugared.elem = desugar_ty(&ref_type.elem, instances)?;
            Ok(desugared.into())
        }
        Type::Tuple(tuple_type) => {
            let mut desugared = tuple_type.clone();
            for elem in desugared.elems.iter_mut() {
                *elem = desugar_ty(elem, instances)?;
            }
            Ok(Type::Tuple(desugared))
        }
        _ => Ok(ty.clone()),
    }
}
/// Desugars the generic arguments of `type_ident<generic_args>` against the
/// const-generic-array (CGA) instances in `instances`.
///
/// * A type argument whose last path segment names an entry of
///   `instances.inst_array` is expanded into one argument per element,
///   named `{cga.name}0`, `{cga.name}1`, ...
/// * A const block argument `{[a, b, ...]}` is flattened into its elements.
/// * A const block argument `{[x; LEN]}` is expanded into `LEN` copies of `x`,
///   where `LEN` is either a name bound in `instances.inst_u32` or an integer
///   literal (type suffixes such as `4usize` and `_` separators are accepted).
/// * Every other argument is passed through unchanged (as its token text).
///
/// When `type_ident` names a variadic type (per `get_variadic_type_data`), the
/// returned identifier is `vtd.concrete_name` of the expanded rank; otherwise
/// the original identifier is returned.
///
/// # Errors
/// A const block that is not a single array/repeat expression, a repeat length
/// that cannot be resolved, or a rebuilt argument list that fails to re-parse.
pub(self) fn desugar_cga(
    instances: &ConstInstances,
    type_ident: &Ident,
    generic_args: &AngleBracketedGenericArguments,
) -> Result<(Ident, AngleBracketedGenericArguments), Error> {
    let mut expanded_param_name = type_ident.to_string();
    let variadic_type_data: Option<VariadicTypeData> =
        get_variadic_type_data(expanded_param_name.as_str());
    // Arguments are accumulated as source text and re-parsed at the end.
    let mut generic_args_result: Vec<String> = vec![];
    for generic_arg in &generic_args.args {
        match generic_arg {
            GenericArgument::Type(type_param) => {
                match type_param {
                    Type::Path(type_path) => {
                        let last_ident = type_path
                            .path
                            .segments
                            .last()
                            .ok_or_else(|| {
                                syn_err(type_path.span(), "Expected at least one path segment")
                            })?
                            .ident
                            .to_string();
                        if let Some(cga) = instances.inst_array.get(&last_ident) {
                            for j in 0..cga.length {
                                generic_args_result.push(format!("{}{}", cga.name, j));
                            }
                            expanded_param_name = match &variadic_type_data {
                                Some(vtd) => vtd.concrete_name(&vec![cga.length]),
                                None => type_ident.to_string(),
                            };
                        } else {
                            generic_args_result.push(generic_arg.to_token_stream().to_string());
                        }
                    }
                    Type::Reference(type_ref) => {
                        generic_args_result.push(type_ref.to_token_stream().to_string());
                    }
                    _ => {
                        generic_args_result.push(generic_arg.to_token_stream().to_string());
                    }
                }
            }
            GenericArgument::Const(const_param) => {
                match const_param {
                    Expr::Block(block_expr) => {
                        if block_expr.block.stmts.len() != 1 {
                            return Err(syn_err(
                                block_expr.span(),
                                &format!(
                                    "Expected exactly 1 statement in block expression, got {}",
                                    block_expr.block.stmts.len()
                                ),
                            ));
                        }
                        let statement = &block_expr.block.stmts[0];
                        let Stmt::Expr(statement_expr, _) = statement else {
                            return Err(syn_err(block_expr.span(), "Unexpected block expression."));
                        };
                        match statement_expr {
                            Expr::Array(array_expr) => {
                                let rank = array_expr.elems.len();
                                for elem in &array_expr.elems {
                                    let val = elem.to_token_stream().to_string();
                                    generic_args_result.push(val);
                                }
                                expanded_param_name = match &variadic_type_data {
                                    Some(vtd) => vtd.concrete_name(&vec![rank as u32]),
                                    None => type_ident.to_string(),
                                };
                            }
                            Expr::Repeat(repeat_expr) => {
                                let thing_to_repeat =
                                    repeat_expr.expr.to_token_stream().to_string();
                                let num_repetitions: u32 = match &*repeat_expr.len {
                                    Expr::Path(len_path) => {
                                        let num_rep_var = len_path.to_token_stream().to_string();
                                        let Some(num_repetitions) =
                                            instances.inst_u32.get(&num_rep_var)
                                        else {
                                            return Err(syn_err(
                                                len_path.span(),
                                                &format!(
                                                    "Expected instance for generic argument {}",
                                                    num_rep_var
                                                ),
                                            ));
                                        };
                                        *num_repetitions
                                    }
                                    Expr::Lit(len_lit) => match &len_lit.lit {
                                        // `base10_parse` ignores type suffixes (`4usize`) and
                                        // `_` separators, unlike `str::parse` on the raw text.
                                        syn::Lit::Int(int_lit) => {
                                            int_lit.base10_parse::<u32>().map_err(|e| {
                                                syn_err(
                                                    len_lit.span(),
                                                    &format!(
                                                        "Failed to parse repeat length as u32: {e}"
                                                    ),
                                                )
                                            })?
                                        }
                                        other => {
                                            return Err(syn_err(
                                                len_lit.span(),
                                                &format!(
                                                    "Failed to parse repeat length as u32: expected an integer literal, got `{}`",
                                                    other.to_token_stream()
                                                ),
                                            ))
                                        }
                                    },
                                    _ => {
                                        return Err(syn_err(
                                            generic_args.span(),
                                            "Unexpected repeat expression.",
                                        ))
                                    }
                                };
                                for _ in 0..num_repetitions {
                                    generic_args_result.push(thing_to_repeat.clone());
                                }
                                expanded_param_name = match &variadic_type_data {
                                    Some(vtd) => vtd.concrete_name(&vec![num_repetitions]),
                                    None => type_ident.to_string(),
                                };
                            }
                            _ => {
                                return Err(syn_err(
                                    block_expr.span(),
                                    "Unexpected block expression.",
                                ))
                            }
                        }
                    }
                    Expr::Lit(lit_expr) => {
                        generic_args_result.push(lit_expr.lit.to_token_stream().to_string());
                    }
                    _ => {
                        generic_args_result.push(generic_arg.to_token_stream().to_string());
                    }
                }
            }
            _ => {
                generic_args_result.push(generic_arg.to_token_stream().to_string());
            }
        }
    }
    let expanded_param_ident = Ident::new(expanded_param_name.as_str(), type_ident.span());
    let formatted = format!("<{}>", generic_args_result.join(","));
    Ok((
        expanded_param_ident,
        syn::parse::<AngleBracketedGenericArguments>(formatted.parse().map_err(|_| {
            syn_err(
                type_ident.span(),
                &format!("Failed to parse angle bracketed args '{formatted}'"),
            )
        })?)
        .map_err(|e| {
            syn_err(
                type_ident.span(),
                &format!("Failed to parse angle bracketed args '{formatted}': {e}"),
            )
        })?,
    ))
}
pub(self) fn get_cga_type(
ty: &Type,
const_instances: &ConstInstances,
) -> Result<Option<ConstGenericArrayType>, Error> {
let vtd = match get_vtd(ty)? {
Some(vtd) => vtd,
None => return Ok(None),
};
let (_type_ident, generic_args) = get_ident_generic_args(ty, &vtd)?;
let mut n: Vec<u32> = vec![];
let mut cgas: Vec<Option<String>> = vec![];
for generic_arg in &generic_args.args {
match generic_arg {
GenericArgument::Type(type_param) => {
match type_param {
Type::Path(type_path) => {
let last_ident = type_path
.path
.segments
.last()
.ok_or_else(|| {
syn_err(
type_path.span(),
"Expected at least one path segment in get_cga_type",
)
})?
.ident
.to_string();
if const_instances.inst_array.contains_key(&last_ident) {
let cga = const_instances.inst_array.get(&last_ident).unwrap();
n.push(cga.length);
cgas.push(Some(generic_arg.to_token_stream().to_string()));
}
}
Type::Reference(type_ref) => {
return Err(syn_err(
type_ref.span(),
&format!(
"get_cga_type: Type::Reference not supported: {}",
type_ref.to_token_stream().to_string()
),
));
}
_ => {}
}
}
GenericArgument::Const(const_param) => {
match const_param {
Expr::Block(block_expr) => {
if block_expr.block.stmts.len() != 1 {
return Err(syn_err(
block_expr.span(),
&format!(
"Expected exactly 1 statement in block expression, got {}",
block_expr.block.stmts.len()
),
));
}
let statement = &block_expr.block.stmts[0];
let Stmt::Expr(statement_expr, _) = statement else {
return Err(syn_err(block_expr.span(), "Unexpected block expression."));
};
match statement_expr {
Expr::Array(array_expr) => {
n.push(array_expr.elems.len() as u32);
cgas.push(Some(generic_arg.to_token_stream().to_string()));
}
Expr::Repeat(repeat_expr) => {
let _thing_to_repeat =
repeat_expr.expr.to_token_stream().to_string();
match &*repeat_expr.len {
Expr::Path(len_path) => {
let num_rep_var = len_path.to_token_stream().to_string();
if !const_instances.inst_u32.contains_key(&num_rep_var) {
return Err(syn_err(
len_path.span(),
&format!(
"Expected instance for generic argument {}",
num_rep_var
),
));
}
let num_rep =
const_instances.inst_u32.get(&num_rep_var).unwrap();
n.push(*num_rep);
cgas.push(Some(generic_arg.to_token_stream().to_string()));
}
Expr::Lit(len_lit) => {
let num_repetitions: u32 = len_lit
.to_token_stream()
.to_string()
.parse::<u32>()
.map_err(|e| {
syn_err(
len_lit.span(),
&format!(
"Failed to parse repeat length as u32: {e}"
),
)
})?;
n.push(num_repetitions);
cgas.push(Some(generic_arg.to_token_stream().to_string()));
}
_ => {
return Err(syn_err(
ty.span(),
"Unexpected repeat expression.",
))
}
}
}
_ => {
return Err(syn_err(
block_expr.span(),
"Unexpected block expression.",
))
}
}
}
_ => {}
}
}
_ => {}
}
}
if n.len() != cgas.len() {
return Err(syn_err(
ty.span(),
&format!(
"get_cga_type: n.len() ({}) != cgas.len() ({})",
n.len(),
cgas.len()
),
));
}
Ok(Some(ConstGenericArrayType {
cga_arg_strings: cgas,
n,
}))
}
/// What the rewrite pass knows about a local name (function parameter, `let`
/// binding, or local `const`). `ty` is `None` when no type could be recorded
/// or inferred for the name.
#[derive(Debug)]
pub struct Binding {
    ty: Option<Type>,
}
impl Binding {
    /// CGA shape of the recorded type (via the free `get_cga_type`), or
    /// `Ok(None)` when no type is recorded.
    fn get_cga_type(
        &self,
        const_instances: &ConstInstances,
    ) -> Result<Option<ConstGenericArrayType>, Error> {
        let Some(bound_ty) = self.ty.as_ref() else {
            return Ok(None);
        };
        get_cga_type(bound_ty, const_instances)
    }
    /// Variadic type data of the recorded type (via the free `get_vtd`), or
    /// `Ok(None)` when no type is recorded.
    fn get_vtd(&self) -> Result<Option<VariadicTypeData>, Error> {
        if let Some(ref bound_ty) = self.ty {
            get_vtd(bound_ty)
        } else {
            Ok(None)
        }
    }
}
/// Stateless rewriting pass (it has no fields) that desugars structs,
/// functions, traits and impls against a `ConstInstances` supplied per call.
pub struct RewriteVariadicsPass {}
impl RewriteVariadicsPass {
fn rewrite_struct(
&self,
item: &ItemStruct,
const_instances: &ConstInstances,
) -> Result<ItemStruct, Error> {
let mut item = item.clone();
for field in &mut item.fields {
field.ty = desugar_ty(&field.ty, &const_instances)?;
}
Ok(item)
}
fn rewrite_function(
&self,
item: &ItemFn,
const_instances: &ConstInstances,
) -> Result<ItemFn, Error> {
let mut item = item.clone();
let mut variables: TrainMap<String, Binding> = self.bind_parameters(None, &item.sig)?;
let (inputs, output) = get_sig_types(&item.sig, None);
let inputs = inputs.into_iter().map(|x| Some(x)).collect::<Vec<_>>();
item.sig.ident = get_concrete_op_ident_from_types(
&item.sig.ident,
&inputs,
Some(output.clone()),
&const_instances,
true,
)?
.0;
self.rewrite_sig(&mut item.sig, &const_instances)?;
self.rewrite_statements(
&mut item.block.stmts,
&const_instances,
&mut variables,
Some(output),
)?;
Ok(item)
}
fn rewrite_trait(
&self,
item: &ItemTrait,
const_instances: &ConstInstances,
) -> Result<ItemTrait, Error> {
let mut item = item.clone();
if const_instances.inst_u32.len() == 0 {
return Ok(item);
}
if const_instances.inst_u32.len() != 1 {
return Err(syn_err(
item.ident.span(),
"Only one CGA is permitted for variadic traits.",
));
}
let key = const_instances.inst_u32.keys().next().unwrap().clone();
let n = const_instances.inst_u32.get(&key).unwrap().clone();
let trait_name = item.ident.to_string();
let concrete_name = concrete_name(&trait_name, &vec![n]);
item.ident = Ident::new(&concrete_name, item.ident.span());
desugar_generics(&mut item.generics, &const_instances)?;
let mut impl_items: Vec<TraitItem> = vec![];
for concrete_item in &mut item.items {
match concrete_item {
TraitItem::Fn(trait_item_fn) => {
let mut result = trait_item_fn.clone();
let cgas = parse_var_cgas(&result.sig.generics);
let const_instances = const_instances.instantiate_var_cgas(&cgas)?;
let method_name = result.sig.ident.to_string();
if let Some(vtd) = get_variadic_type_data(&trait_name) {
if let Some((_op_name, vod)) = get_variadic_method_data(&vtd, &method_name)?
{
let self_type =
syn::parse2::<syn::Type>("Self".parse().map_err(|_| {
syn_err(result.sig.ident.span(), "Failed to parse 'Self' type")
})?)
.map_err(|e| {
syn_err(
result.sig.ident.span(),
&format!("Failed to parse 'Self' type: {e}"),
)
})?;
let (inputs, output) = get_sig_types(&result.sig, Some(&self_type));
let inputs = inputs.into_iter().map(|x| Some(x)).collect::<Vec<_>>();
result.sig.ident = get_concrete_op_or_method_ident_from_types(
vod,
&result.sig.ident,
&inputs,
Some(output.clone()),
&const_instances,
true,
)?
.0;
}
}
self.rewrite_sig(&mut result.sig, &const_instances)?;
impl_items.push(TraitItem::Fn(result));
}
_ => {
return Err(syn_err(
concrete_item.span(),
&format!("Unsupported impl item"),
))
}
}
}
item.items = impl_items;
Ok(item)
}
fn rewrite_impl(
&self,
item: &ItemImpl,
const_instances: &ConstInstances,
) -> Result<ItemImpl, Error> {
let mut item = item.clone();
let self_ty = *item.self_ty.clone();
*item.self_ty = desugar_ty(&*item.self_ty, &const_instances)?;
desugar_generics(&mut item.generics, &const_instances)?;
let mut variadic_trait_vtd = None;
if let Some(trait_) = &mut item.trait_ {
let path_copy = trait_.1.clone();
let path = &mut trait_.1;
if path.segments.len() == 0 {
return Err(syn_err(
path.span(),
"Expected at least one path segment in trait path",
));
}
let last_seg = path.segments.last_mut().unwrap();
let ident_vtd = get_variadic_type_data(last_seg.ident.to_string().as_str());
if ident_vtd.is_some() {
match ident_vtd {
Some(vtd) => {
if const_instances.inst_u32.len() != 1 {
return Err(syn_err(
path.span(),
"Only one CGA is permitted for variadic traits.",
));
}
*path = desugar_path(&path_copy, const_instances)?;
variadic_trait_vtd = Some(vtd);
}
None => {}
}
} else {
match &mut last_seg.arguments {
PathArguments::AngleBracketed(path_args) => {
desugar_generic_arguments(path_args, &const_instances)?
}
_ => {}
}
}
}
let mut impl_items: Vec<ImplItem> = vec![];
for concrete_item in &mut item.items {
match concrete_item {
ImplItem::Type(type_impl) => {
let mut result = type_impl.clone();
result.ty = desugar_ty(&type_impl.ty, &const_instances)?;
impl_items.push(ImplItem::Type(result));
}
ImplItem::Fn(fn_impl) => {
let attributes = get_meta_list("cuda_tile :: variadic_impl_fn", &fn_impl.attrs);
match attributes {
Some(attributes) => {
if variadic_trait_vtd.is_some() {
return Err(syn_err(fn_impl.sig.ident.span(), "variadic_impl_fn attributes are not supported for variadic traits."));
}
clear_attributes(
HashSet::from(["cuda_tile :: variadic_impl_fn"]),
&mut fn_impl.attrs,
);
let results: Vec<ImplItemFn> = variadic_impl_fn_gen(
&attributes,
&self_ty,
&fn_impl,
&const_instances,
)?;
for result in results {
impl_items.push(ImplItem::Fn(result));
}
}
None => {
let mut result = fn_impl.clone();
self.rewrite_impl_fn(
&self_ty,
&mut result,
&const_instances,
variadic_trait_vtd.clone(),
)?;
impl_items.push(ImplItem::Fn(result));
}
}
}
_ => {
return Err(syn_err(
concrete_item.span(),
&format!("Unsupported impl item."),
))
}
}
}
item.items = impl_items;
Ok(item)
}
fn rewrite_impl_fn(
&self,
self_ty: &Type,
item: &mut ImplItemFn,
const_instances: &ConstInstances,
variadic_trait_vtd: Option<VariadicTypeData>,
) -> Result<(), Error> {
let cgas = parse_var_cgas(&item.sig.generics);
let const_instances = const_instances.instantiate_var_cgas(&cgas)?;
let mut variables: TrainMap<String, Binding> =
self.bind_parameters(Some(self_ty), &item.sig)?;
let return_type: Option<Type> = match item.sig.output {
ReturnType::Type(_, ref return_type) => Some(*return_type.clone()),
_ => None,
};
let method_name = item.sig.ident.to_string();
let vmmd = if variadic_trait_vtd.is_some() {
let vtd = variadic_trait_vtd.unwrap();
match get_variadic_method_data(&vtd, &method_name)? {
Some((op_name, vod)) => Some((op_name, vtd, vod)),
None => None,
}
} else {
get_variadic_method_meta_data(&self_ty, &method_name)?
};
if let Some((_op_name, _vtd, vod)) = vmmd {
let (inputs, output) = get_sig_types(&item.sig, Some(self_ty));
let inputs = inputs.into_iter().map(|x| Some(x)).collect::<Vec<_>>();
item.sig.ident = get_concrete_op_or_method_ident_from_types(
vod,
&item.sig.ident,
&inputs,
Some(output.clone()),
&const_instances,
true,
)?
.0;
};
self.rewrite_sig(&mut item.sig, &const_instances)?;
self.rewrite_statements(
&mut item.block.stmts,
&const_instances,
&mut variables,
return_type,
)?;
Ok(())
}
fn bind_parameters(
&self,
self_ty: Option<&Type>,
sig: &Signature,
) -> Result<TrainMap<'_, String, Binding>, Error> {
let mut variables: TrainMap<String, Binding> = TrainMap::new();
for input in sig.inputs.iter() {
match input {
FnArg::Typed(fn_param) => {
let name = {
match &*fn_param.pat {
Pat::Ident(ident) => ident.ident.to_string(),
_ => {
return Err(syn_err(
fn_param.span(),
&format!("Unexpected function param pattern."),
))
}
}
};
let ty = &*fn_param.ty;
variables.insert(
name.clone(),
Binding {
ty: Some(ty.clone()),
},
);
}
FnArg::Receiver(_fn_self) => {
if self_ty.is_none() {
return Err(syn_err(
sig.ident.span(),
"bind_parameters for impls requires self_ty.",
));
}
let self_ty = self_ty.unwrap().clone();
variables.insert("self".to_string(), Binding { ty: Some(self_ty) });
}
}
}
Ok(variables)
}
fn rewrite_sig(
&self,
sig: &mut Signature,
const_instances: &ConstInstances,
) -> Result<(), Error> {
rewrite_fn_sig(sig, &const_instances)
}
    /// Rewrites a sequence of statements in place, threading local bindings
    /// forward through `variables`.
    ///
    /// `return_type` is the expected type of the enclosing block. It is passed
    /// as the expected type to every expression statement. It is replaced by the
    /// type of a top-level `return` expression, or by the type of a trailing
    /// expression (last statement without a semicolon). Returns the resulting
    /// block type.
    ///
    /// # Errors
    /// Unsupported `let` patterns, non-`const` local items, assignments to
    /// anything other than a single-segment path, and any error from
    /// `rewrite_expr`.
    fn rewrite_statements(
        &self,
        statements: &mut [Stmt],
        const_instances: &ConstInstances,
        variables: &mut TrainMap<String, Binding>,
        mut return_type: Option<Type>,
    ) -> Result<Option<Type>, Error> {
        let num_statements = statements.len();
        for (i, statement) in statements.iter_mut().enumerate() {
            // Not evaluated for an empty slice, so `num_statements - 1` cannot underflow.
            let is_last = i == num_statements - 1;
            match statement {
                Stmt::Local(local) => {
                    let mut binding_name: Option<String> = None;
                    let mut binding_ty: Option<Type> = None;
                    match &mut local.pat {
                        Pat::Type(pat_type) => {
                            match &*pat_type.pat {
                                Pat::Ident(pat_ident) => {
                                    binding_name = Some(pat_ident.ident.to_string());
                                }
                                Pat::Tuple(_) => {
                                    // Typed tuple destructuring: rewrite the initializer with no
                                    // expected type and desugar the annotation; no names are bound.
                                    // (The `binding_ty` assignment is never read because of `continue`.)
                                    if let Some(init) = &mut local.init {
                                        self.rewrite_expr(
                                            &mut init.expr,
                                            const_instances,
                                            variables,
                                            None,
                                        )?;
                                    }
                                    binding_ty = Some(*pat_type.ty.clone());
                                    let new_ty = desugar_ty(&*pat_type.ty, &const_instances)?;
                                    *pat_type.ty = new_ty;
                                    continue;
                                }
                                _ => {
                                    return Err(syn_err(
                                        pat_type.span(),
                                        &format!("let binding LHS not implemented."),
                                    ))
                                }
                            }
                            // The binding records the annotation as written (before
                            // desugaring); only the emitted annotation is desugared.
                            binding_ty = Some(*pat_type.ty.clone());
                            let new_ty = desugar_ty(&*pat_type.ty, &const_instances)?;
                            *pat_type.ty = new_ty;
                        }
                        Pat::Ident(pat_ident) => {
                            // Untyped binding: its type is inferred from the initializer below.
                            binding_name = Some(pat_ident.ident.to_string());
                            binding_ty = None;
                        } Pat::Tuple(_) => {
                            // Untyped tuple destructuring: only the initializer is rewritten;
                            // no names are bound.
                            if let Some(init) = &mut local.init {
                                self.rewrite_expr(
                                    &mut init.expr,
                                    const_instances,
                                    variables,
                                    None,
                                )?;
                            }
                            continue; }
                        _ => {
                            return Err(syn_err(
                                local.span(),
                                &format!("Local pattern type not supported"),
                            ))
                        }
                    }
                    if binding_name.is_none() {
                        return Err(syn_err(local.span(), &format!("Unable to rewrite expr.")));
                    }
                    let binding_name = binding_name.unwrap();
                    // The annotation (if any) is the initializer's expected type; without
                    // an annotation, the initializer's inferred type becomes the binding type.
                    match &mut local.init {
                        Some(init) => {
                            let inferred_ty = self.rewrite_expr(
                                &mut *init.expr,
                                const_instances,
                                variables,
                                binding_ty.clone(),
                            )?;
                            if binding_ty.is_none() {
                                binding_ty = inferred_ty;
                            }
                        }
                        None => {}
                    }
                    variables.insert(
                        binding_name.clone(),
                        Binding {
                            ty: binding_ty.clone(),
                        },
                    );
                }
                Stmt::Item(item) => {
                    // Only local `const` items are supported. The bound type is whatever
                    // `rewrite_expr` returns for the value, given the declared type as the
                    // expected type.
                    let mut binding_name: Option<String> = None;
                    let binding_ty: Option<Type> = match item {
                        Item::Const(const_item) => {
                            binding_name = Some(const_item.ident.to_string());
                            let return_type = Some(*const_item.ty.clone());
                            self.rewrite_expr(
                                &mut *const_item.expr,
                                const_instances,
                                variables,
                                return_type,
                            )?
                        }
                        _ => {
                            return Err(syn_err(
                                item.span(),
                                &format!(
                                    "{}\nOnly const local item definitions are supported.",
                                    item.to_token_stream().to_string()
                                ),
                            ))
                        }
                    };
                    let Some(binding_name) = binding_name else {
                        return Err(syn_err(item.span(), &format!("Unable to rewrite expr.")));
                    };
                    variables.insert(
                        binding_name.clone(),
                        Binding {
                            ty: binding_ty.clone(),
                        },
                    );
                }
                Stmt::Expr(expr, semicolon) => {
                    let ty =
                        self.rewrite_expr(expr, const_instances, variables, return_type.clone())?;
                    match expr {
                        Expr::Assign(assign_expr) => {
                            // Re-inserts the assigned name with its currently visible type
                            // (`None` if the name is unknown). NOTE(review): presumably this
                            // records the name at the current TrainMap level; confirm intent.
                            let binding_name: String;
                            let mut binding_ty: Option<Type> = None;
                            match &mut *assign_expr.left {
                                Expr::Path(path_expr) => {
                                    if path_expr.path.segments.len() != 1 {
                                        return Err(syn_err(
                                            path_expr.span(),
                                            "Expected single-segment path in assignment",
                                        ));
                                    }
                                    binding_name = path_expr.path.segments[0].ident.to_string()
                                }
                                _ => {
                                    return Err(syn_err(
                                        assign_expr.span(),
                                        &format!("Expr::Assign not supported"),
                                    ))
                                }
                            }
                            binding_ty = match variables.get(&binding_name) {
                                Some(old) => old.ty.clone(),
                                None => {
                                    None
                                }
                            };
                            let _ = variables.insert(
                                binding_name,
                                Binding {
                                    ty: binding_ty.clone(),
                                },
                            );
                        }
                        Expr::Return(_return_expr) => {
                            // A top-level `return` ends processing: its type becomes the
                            // block type and any later statements are not rewritten.
                            return_type = ty;
                            break;
                        }
                        _ => {
                            // A trailing expression (no semicolon) determines the block type.
                            if is_last && semicolon.is_none() {
                                return_type = ty;
                            } else {
                            }
                        }
                    }
                }
                // Statement macros are left untouched.
                Stmt::Macro(_mac_stmt) => continue, }
        }
        Ok(return_type)
    }
fn rewrite_expr(
&self,
expr: &mut Expr,
const_instances: &ConstInstances,
variables: &mut TrainMap<String, Binding>,
mut return_type: Option<Type>,
) -> Result<Option<Type>, Error> {
match expr {
Expr::Index(index_expr) => {
let index_span = index_expr.span();
let inner_expr_span = index_expr.expr.span();
let is_path = matches!(&*index_expr.expr, Expr::Path(_));
if !is_path {
return Err(syn_err(
inner_expr_span,
&format!(
"Index expression not supported: {}",
index_expr.expr.to_token_stream().to_string()
),
));
}
let path_expr = match &*index_expr.expr {
Expr::Path(p) => p.clone(),
_ => unreachable!(),
};
let expr_str = path_expr.to_token_stream().to_string();
if let Some(cga) = const_instances.inst_array.get(&expr_str) {
let expanded_cga = expand_cga(&path_expr.path, const_instances)?;
let index = parse_signed_literal_as_i32(&index_expr.index);
if !(0 <= index && (index as u32) < cga.length) {
return Err(syn_err(
index_span,
&format!(
"Index {index} out of bounds for CGA of length {}",
cga.length
),
));
}
if cga.element_type != "i32" {
return Err(syn_err(
index_span,
&format!("Expected element_type 'i32', got '{}'", cga.element_type),
));
}
let desugared_idx_expression = expanded_cga.args[index as usize].clone();
*expr = parse_quote!(#desugared_idx_expression);
let return_type: Type = parse_quote! { i32 };
Ok(Some(return_type))
} else {
let expr_span = expr.span();
let index_expr = match expr {
Expr::Index(ie) => ie,
_ => unreachable!(),
};
let index_expr_expr = &mut *index_expr.expr;
match self.rewrite_expr(
index_expr_expr,
const_instances,
variables,
return_type,
)? {
Some(Type::Array(ty)) => Ok(Some(*ty.elem.clone())),
Some(Type::Reference(ty)) => {
let Type::Slice(slice_ty) = *ty.elem.clone() else {
return Err(syn_err(
expr_span,
&format!("Index expression not supported (reference)"),
));
};
Ok(Some(*slice_ty.elem.clone()))
}
None => Err(syn_err(
expr_span,
&format!("Failed to compute type for index expression"),
)),
Some(_other) => Err(syn_err(
expr_span,
&format!("Index expression not supported"),
)),
}
}
}
Expr::Const(const_expr) => {
let mut block_vars = variables.fork();
self.rewrite_statements(
&mut const_expr.block.stmts,
const_instances,
&mut block_vars,
return_type.clone(),
)
}
Expr::Block(block_expr) => {
let mut block_vars = variables.fork();
self.rewrite_statements(
&mut block_expr.block.stmts,
const_instances,
&mut block_vars,
return_type.clone(),
)
}
Expr::Unsafe(block_expr) => {
let mut block_vars = variables.fork();
self.rewrite_statements(
&mut block_expr.block.stmts,
const_instances,
&mut block_vars,
return_type.clone(),
)
}
Expr::ForLoop(for_expr) => {
let mut block_vars = variables.fork();
self.rewrite_statements(
&mut for_expr.body.stmts,
const_instances,
&mut block_vars,
return_type.clone(),
)
}
Expr::While(while_expr) => {
self.rewrite_expr(&mut *while_expr.cond, const_instances, variables, None)?;
let mut block_vars = variables.fork();
self.rewrite_statements(
&mut while_expr.body.stmts,
const_instances,
&mut block_vars,
return_type.clone(),
)
}
Expr::Loop(loop_expr) => {
let mut block_vars = variables.fork();
self.rewrite_statements(
&mut loop_expr.body.stmts,
const_instances,
&mut block_vars,
return_type.clone(),
)
}
Expr::If(if_expr) => {
self.rewrite_expr(
&mut if_expr.cond,
const_instances,
variables,
return_type.clone(),
)?;
if let Some((_Else, else_expr)) = &mut if_expr.else_branch {
let mut block_vars = variables.fork();
self.rewrite_expr(
&mut **else_expr,
const_instances,
&mut block_vars,
return_type.clone(),
)?;
}
let mut block_vars = variables.fork();
self.rewrite_statements(
&mut if_expr.then_branch.stmts,
const_instances,
&mut block_vars,
return_type.clone(),
)
}
Expr::Continue(_continue_expr) => Ok(None),
Expr::Break(_break_expr) => Ok(None),
Expr::Call(call_expr) => {
self.rewrite_call(call_expr, const_instances, variables, return_type.clone())
}
Expr::MethodCall(method_call_expr) => self.rewrite_method_call(
method_call_expr,
const_instances,
variables,
return_type.clone(),
),
Expr::Cast(cast_expr) => {
self.rewrite_expr(
&mut *cast_expr.expr,
const_instances,
variables,
return_type.clone(),
)?;
*cast_expr.ty = desugar_ty(&*cast_expr.ty, const_instances)?;
Ok(return_type)
}
Expr::Path(path_expr) => {
let path_span = path_expr.span();
let last_seg = path_expr
.path
.segments
.last_mut()
.ok_or_else(|| syn_err(path_span, "Expected at least one path segment"))?;
let name = last_seg.ident.to_string();
if let Some(n) = const_instances.inst_u32.get(name.as_str()) {
let new_expr = syn::parse::<Expr>(
format!("{n}")
.parse()
.map_err(|_| syn_err(path_span, &format!("Failed to parse '{n}'")))?,
)
.map_err(|e| syn_err(path_span, &format!("Failed to parse '{n}': {e}")))?;
*expr = new_expr;
return Ok(Some(
syn::parse::<Type>(
"i32"
.parse()
.map_err(|_| syn_err(path_span, "Failed to parse 'i32'"))?,
)
.map_err(|e| syn_err(path_span, &format!("Failed to parse 'i32': {e}")))?,
));
}
self.rewrite_path_expr_type(
path_expr,
const_instances,
variables,
return_type.clone(),
)
}
Expr::Reference(ref_expr) => self.rewrite_expr(
&mut *ref_expr.expr,
const_instances,
variables,
return_type.clone(),
),
Expr::Return(return_expr) => match &mut return_expr.expr {
Some(return_expr) => self.rewrite_expr(
&mut *return_expr,
const_instances,
variables,
return_type.clone(),
),
None => Ok(return_type),
},
Expr::Assign(assign_expr) => self.rewrite_expr(
&mut *assign_expr.right,
const_instances,
variables,
return_type.clone(),
),
Expr::Unary(unary_expr) => self.rewrite_expr(
&mut *unary_expr.expr,
const_instances,
variables,
return_type.clone(),
),
Expr::Binary(bin_expr) => {
self.rewrite_expr(
&mut *bin_expr.left,
const_instances,
variables,
return_type.clone(),
)?;
self.rewrite_expr(
&mut *bin_expr.right,
const_instances,
variables,
return_type.clone(),
)?;
Ok(return_type)
}
Expr::Tuple(tuple_expr) => {
for elem_expr in tuple_expr.elems.iter_mut() {
self.rewrite_expr(elem_expr, const_instances, variables, None)?;
}
Ok(return_type)
}
Expr::Array(arr_expr) => {
for elem_expr in arr_expr.elems.iter_mut() {
self.rewrite_expr(elem_expr, const_instances, variables, None)?;
}
Ok(return_type)
}
Expr::Repeat(repeat_expr) => {
self.rewrite_expr(&mut *repeat_expr.len, const_instances, variables, None)?;
Ok(return_type)
}
Expr::Field(field_expr) => {
return_type = self.rewrite_expr(
&mut *field_expr.base,
const_instances,
variables,
return_type.clone(),
)?;
Ok(return_type)
}
Expr::Struct(struct_expr) => {
if struct_expr.path.segments.len() == 0 {
return Err(syn_err(
struct_expr.span(),
"Expected at least one path segment in struct expression",
));
}
let last_seg = struct_expr.path.segments.last_mut().unwrap();
let name = last_seg.ident.to_string();
let vtd = get_variadic_type_data(name.as_str());
match vtd {
Some(_vtd) => {
if return_type.is_none() {
return Err(syn_err(
struct_expr.span(),
"Variadic structs require a static type annotation. Try assigning to a statically typed let binding.",
));
}
let (last_type_ident, last_seg_args) = match &last_seg.arguments {
PathArguments::AngleBracketed(type_params) => {
let (type_ident, last_seg_args) =
desugar_cga(&const_instances, &last_seg.ident, &type_params)?;
(
type_ident.clone(),
PathArguments::AngleBracketed(last_seg_args),
)
}
PathArguments::None => (last_seg.ident.clone(), PathArguments::None),
_ => {
return Err(syn_err(
struct_expr.span(),
"Unexpected Path arguments.",
))
}
};
*last_seg = PathSegment {
ident: last_type_ident,
arguments: last_seg_args,
};
}
None => {}
}
for field in &mut struct_expr.fields {
self.rewrite_expr(&mut field.expr, const_instances, variables, None)?;
match &mut field.member {
Member::Named(_named) => {}
Member::Unnamed(_idx) => {
return Err(syn_err(struct_expr.span(), "Tuples not supported."))
}
}
}
Ok(return_type)
}
Expr::Macro(mac_expr) => {
let last_seg = mac_expr.mac.path.segments.last();
if last_seg.is_none() {
return Ok(return_type);
}
let last_seg = last_seg.unwrap();
let mac_name = last_seg.ident.to_string();
match mac_name.as_str() {
"const_shape" | "const_array" => {
let mut args = vec![];
#[allow(unused_variables)]
let mut is_cga = false;
#[allow(unused_variables)]
let mut is_consts = false;
for token in mac_expr.mac.tokens.clone() {
match token {
TokenTree::Literal(lit) => {
args.push(lit.to_string());
}
TokenTree::Ident(ident) => {
let const_var = ident.to_string();
if let Some(_cga) = const_instances.inst_array.get(&const_var) {
is_cga = true;
let path: Path = parse_quote! { #ident };
let generic_args = expand_cga(&path, const_instances)?;
args = generic_args
.args
.iter()
.map(|x| x.to_token_stream().to_string())
.collect::<Vec<String>>();
} else {
is_consts = true;
args.push(const_var);
}
}
TokenTree::Punct(punct) => {
if punct.as_char() == ',' {
continue;
} else {
return Err(syn_err(
mac_expr.span(),
&format!("Unexpected punctuation {punct:}"),
));
}
}
_ => {
return Err(syn_err(
mac_expr.span(),
&format!("Unexpected token {:?}", token),
))
}
}
}
let cga_str = format!("{{[{}]}}", args.join(", "));
let ty_str = if mac_name == "const_shape" {
"Shape"
} else {
"Array"
};
let mac_span = mac_expr.span();
let shape_fmt = format!("{ty_str}::<{cga_str}>::const_new()");
let shape_expr = syn::parse2::<Expr>(shape_fmt.parse().map_err(|_| {
syn_err(mac_span, &format!("Failed to parse '{shape_fmt}'"))
})?)
.map_err(|e| {
syn_err(mac_span, &format!("Failed to parse '{shape_fmt}': {e}"))
})?;
*expr = shape_expr;
let shape_str = format!("{ty_str}<{cga_str}>");
let shape_ty = syn::parse::<Type>(shape_str.parse().map_err(|_| {
syn_err(mac_span, &format!("Failed to parse '{shape_str}'"))
})?)
.map_err(|e| {
syn_err(mac_span, &format!("Failed to parse '{shape_str}': {e}"))
})?;
self.rewrite_expr(&mut *expr, const_instances, variables, Some(shape_ty))
}
_ => Ok(return_type),
}
}
Expr::Lit(_lit_expr) => Ok(return_type),
Expr::Paren(paren_expr) => {
return_type = self.rewrite_expr(
&mut *paren_expr.expr,
const_instances,
variables,
return_type.clone(),
)?;
Ok(return_type)
}
Expr::Closure(_closure_expr) => {
Ok(return_type)
}
_ => Err(syn_err(
expr.span(),
&format!("Expression type not supported"),
)),
}
}
fn rewrite_path_expr_type(
&self,
expr: &mut ExprPath,
const_instances: &ConstInstances,
variables: &mut TrainMap<String, Binding>,
return_type: Option<Type>,
) -> Result<Option<Type>, Error> {
let result_path = desugar_path(&expr.path, const_instances)?;
expr.path = result_path;
if expr.path.segments.len() == 0 {
return Ok(None);
}
let last_seg = expr.path.segments.last_mut().unwrap();
let name = last_seg.ident.to_string();
Ok(match variables.get(&name) {
Some(var) => var.ty.clone(),
None => return_type, })
}
fn try_get_var<'a>(
&self,
maybe_path_expr: &mut Expr,
variables: &'a TrainMap<String, Binding>,
) -> Result<Option<&'a Binding>, Error> {
Ok(match try_get_path_expr_ident_str(maybe_path_expr)? {
Some(name) => variables.get(&name),
None => None,
})
}
fn rewrite_method_call(
&self,
expr: &mut ExprMethodCall,
const_instances: &ConstInstances,
variables: &mut TrainMap<String, Binding>,
return_type: Option<Type>,
) -> Result<Option<Type>, Error> {
let method_ident = &expr.method;
let method_name = method_ident.to_string();
let self_ty =
match self.rewrite_expr(&mut *expr.receiver, const_instances, variables, None)? {
Some(ty) => ty,
None => {
return Err(syn_err(
expr.receiver.span(),
&format!("Unable to infer receiver type"),
))
}
};
let maybe_primitive_type = self_ty.to_token_stream().to_string();
let variadic_meta = get_variadic_trait_impl_meta_data(&maybe_primitive_type, &method_name)?;
let variadic_meta = if variadic_meta.is_some() {
variadic_meta
} else {
get_variadic_method_meta_data(&self_ty, &method_name)?
};
if let Some((_op_name, _vtd, vod)) = variadic_meta {
let rtype = return_type.clone();
let mut maybe_input_types = vec![Some(self_ty.clone())];
for arg in &mut expr.args {
maybe_input_types.push(self.rewrite_expr(arg, const_instances, variables, None)?);
}
let (concrete_ident, inferred_rtype) = get_concrete_op_or_method_ident_from_types(
vod,
method_ident,
&maybe_input_types,
rtype.clone(),
const_instances,
false,
)?;
expr.method = concrete_ident;
if inferred_rtype.is_some() {
Ok(inferred_rtype)
} else {
Ok(rtype)
}
} else {
Ok(return_type)
}
}
fn rewrite_call(
&self,
expr: &mut ExprCall,
const_instances: &ConstInstances,
variables: &mut TrainMap<String, Binding>,
return_type: Option<Type>,
) -> Result<Option<Type>, Error> {
let vod = get_vod_from_call(expr)?;
let maybe_inferred_rtype = match vod {
Some(_vod) => {
let rtype = return_type.clone();
let last_seg = match &mut *expr.func {
Expr::Path(path_expr) => {
if path_expr.path.segments.is_empty() {
return Err(syn_err(
path_expr.span(),
"Expected at least one path segment in function call",
));
}
path_expr.path.segments.last_mut().unwrap()
}
_ => {
return Err(syn_err(
expr.func.span(),
&format!("Unexpected function call expression."),
))
}
};
let mut maybe_input_types = vec![];
for arg in &mut expr.args {
maybe_input_types.push(self.rewrite_expr(
arg,
const_instances,
variables,
None,
)?);
}
let (concrete_ident, inferred_rtype) = get_concrete_op_ident_from_types(
&last_seg.ident,
&maybe_input_types,
rtype.clone(),
const_instances,
false,
)?;
last_seg.ident = concrete_ident;
if inferred_rtype.is_some() {
inferred_rtype
} else {
rtype
}
}
None => {
for arg in &mut expr.args {
self.rewrite_expr(arg, const_instances, variables, None)?;
}
return_type
}
};
self.rewrite_expr(&mut *expr.func, const_instances, variables, None)?;
Ok(maybe_inferred_rtype)
}
}
/// Desugars a struct definition: const instances are built from the struct's
/// generics and every field type is rewritten against them.
pub fn desugar_structure_cgas(item: &ItemStruct) -> Result<ItemStruct, Error> {
    let instances = ConstInstances::from_generics(&item.generics)?;
    let pass = RewriteVariadicsPass {};
    pass.rewrite_struct(item, &instances)
}
/// Desugars a free function using const instances built from its signature's
/// generics.
pub fn desugar_function_cgas(item: &ItemFn) -> Result<ItemFn, Error> {
    let pass = RewriteVariadicsPass {};
    let instances = ConstInstances::from_generics(&item.sig.generics)?;
    pass.rewrite_function(item, &instances)
}
/// Desugars an `impl` block using const instances built from its generics.
pub fn desugar_impl_cgas(item: &ItemImpl) -> Result<ItemImpl, Error> {
    let pass = RewriteVariadicsPass {};
    let instances = ConstInstances::from_generics(&item.generics)?;
    pass.rewrite_impl(item, &instances)
}
/// Desugars a trait definition using const instances built from its generics.
pub fn desugar_trait_cgas(item: &ItemTrait) -> Result<ItemTrait, Error> {
    let pass = RewriteVariadicsPass {};
    let instances = ConstInstances::from_generics(&item.generics)?;
    pass.rewrite_trait(item, &instances)
}