use cutile_compiler::syn_utils::*;
use cutile_compiler::types::get_ptr_type;
use proc_macro2::Ident;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{quote, ToTokens};
use std::collections::HashMap;
use syn::{
parse_quote, AngleBracketedGenericArguments, Expr, ExprBlock, FnArg, GenericArgument,
GenericParam, Generics, ImplItemFn, ItemFn, PatType, Stmt, Type, TypeParam, TypeReference,
WherePredicate,
};
use crate::error::{Error, SpannedError};
/// Classification of a generic parameter name, as reported by
/// `RequiredGenerics::get_ty`.
#[derive(Debug, PartialOrd, PartialEq)]
pub(crate) enum SupportedGenericType {
// The name is known and has no associated type recorded.
TypeParam,
// The name is known and its recorded type is anything other than an array type.
ConstScalar,
// The name is known and its recorded type is an array type.
ConstArray,
// The name is not one of the recorded generic parameters.
Unknown,
}
/// Bookkeeping for the supported generic parameters of a kernel entry point:
/// which ones exist and which source expression yields each one's value.
#[derive(Debug)]
pub(crate) struct RequiredGenerics {
// Names of the supported generic parameters, in the order returned by
// `get_supported_generic_params`.
names: Vec<String>,
// Names recorded while generating the launcher as being derived from a
// launcher argument; the same name may be pushed more than once.
launcher_type_params: Vec<String>,
// Parallel to `names`. `get_ty` classifies `None` as a type parameter, an
// array type as a const array, and any other type as a const scalar.
types: Vec<Option<Type>>,
// Source text of the expression producing each generic's value(s); `None`
// until an expression has been recorded for that name.
expressions: HashMap<String, Option<String>>,
}
impl RequiredGenerics {
fn new(generics: &Generics) -> Self {
let req_generics = get_supported_generic_params(generics);
let (names, types): (Vec<String>, _) = req_generics.into_iter().unzip();
let mut expressions: HashMap<String, Option<String>> = HashMap::new();
for name in &names {
expressions.insert(name.clone(), None);
}
Self {
names,
launcher_type_params: vec![],
types,
expressions,
}
}
pub(crate) fn is_required(&self, s: &str) -> bool {
match self.expressions.get(s) {
None | Some(None) => true,
_ => false,
}
}
pub(crate) fn to_expr_str(&self) -> String {
let mut res = vec![];
for name in &self.names {
let Some(Some(expr_str)) = self.expressions.get(name) else {
return format!("panic!(\"Failed to infer value for generic parameter {name}\")");
};
res.push(expr_str.clone());
}
format!("vec![{}].concat()", res.join(","))
}
pub(crate) fn get_ty(&self, name: &str) -> SupportedGenericType {
let Some(index) = self.names.iter().position(|n| n == name) else {
return SupportedGenericType::Unknown;
};
let Some(ty) = &self.types[index] else {
return SupportedGenericType::TypeParam;
};
match ty {
Type::Array(_) => SupportedGenericType::ConstArray,
_ => SupportedGenericType::ConstScalar,
}
}
pub(crate) fn get_required_generics(&self) -> Generics {
let mut type_params = vec![];
for name in &self.names {
let is_launcher_type_param = self.launcher_type_params.contains(name);
if is_launcher_type_param && self.get_ty(&name) == SupportedGenericType::TypeParam {
type_params.push(format!("{}: Send + WithDType", name.clone()));
}
}
syn::parse2::<Generics>(format!("<{}>", type_params.join(", ")).parse().unwrap()).unwrap()
}
pub(crate) fn get_generic_args(&self) -> AngleBracketedGenericArguments {
let mut type_params = vec![];
for name in &self.names {
let is_launcher_type_param = self.launcher_type_params.contains(name);
if is_launcher_type_param && self.get_ty(&name) == SupportedGenericType::TypeParam {
type_params.push(format!("{}", name.clone()));
}
}
syn::parse2::<AngleBracketedGenericArguments>(
format!("<{}>", type_params.join(", ")).parse().unwrap(),
)
.unwrap()
}
}
/// Joins `vals` into a right-nested pair pattern: `(a, (b, c))`.
///
/// An empty slice yields `()` and a single value is returned unwrapped.
pub fn join_as_cons_tuple(vals: &Vec<String>) -> String {
    match vals.split_last() {
        None => "()".to_string(),
        // Fold from the right so the last value ends up innermost.
        Some((innermost, outer)) => outer
            .iter()
            .rev()
            .fold(innermost.clone(), |nested, val| format!("({}, {})", val, nested)),
    }
}
/// Returns `expr` as-is, or wrapped as `value(expr)` when `wrap_as_val` is set.
fn zippable(expr: &str, wrap_as_val: bool) -> String {
    if wrap_as_val {
        format!("value({})", expr)
    } else {
        expr.to_string()
    }
}
/// Builds a block of `let` statements that fold `inputs` into a single
/// right-nested `zip!` chain bound to `var_name`.
///
/// The last input is bound first; every earlier input is then zipped in front
/// of it. An empty `inputs` yields an empty block.
#[allow(dead_code)]
pub fn zip_cons(inputs: &Vec<String>, var_name: &str, wrap_as_val: bool) -> ExprBlock {
    let mut zip_block: ExprBlock = syn::parse2(quote! {{}}).unwrap();
    let Some((innermost, outer)) = inputs.split_last() else {
        return zip_block;
    };
    let stmts = &mut zip_block.block.stmts;
    stmts.push(parse_stmt(format!(
        "let {var_name} = {};",
        zippable(innermost, wrap_as_val)
    )));
    stmts.extend(outer.iter().rev().map(|input| {
        parse_stmt(format!(
            "let {var_name} = zip!({}, {var_name});",
            zippable(input, wrap_as_val)
        ))
    }));
    zip_block
}
/// Builds a block that zips `inputs` into a right-nested chain bound to
/// `var_name`, then appends an `and_then` that destructures the nested pairs
/// and re-wraps the values as a flat tuple.
///
/// With no inputs the block only binds `var_name` to `value(())`.
pub fn zip_and_then_flatten(inputs: &Vec<String>, var_name: &str, wrap_as_val: bool) -> ExprBlock {
    let mut zip_block: ExprBlock = syn::parse2(quote! {{}}).unwrap();
    let stmts = &mut zip_block.block.stmts;
    match inputs.split_last() {
        None => {
            stmts.push(parse_stmt(format!("let {var_name} = value(());")));
        }
        Some((innermost, outer)) => {
            stmts.push(parse_stmt(format!(
                "let {var_name} = {};",
                zippable(innermost, wrap_as_val)
            )));
            stmts.extend(outer.iter().rev().map(|input| {
                parse_stmt(format!(
                    "let {var_name} = zip!({}, {var_name});",
                    zippable(input, wrap_as_val)
                ))
            }));
            // The closure pattern mirrors the nesting produced by the zip
            // chain; the body is the flat tuple of the same names.
            stmts.push(parse_stmt(format!(
                r#"
let {var_name} = {var_name}.and_then(|{}| {{
value({})
}});
"#,
                join_as_cons_tuple(inputs),
                to_tuple_string(inputs)
            )));
        }
    }
    zip_block
}
/// Name of the type alias for the `i`-th argument of `launcher_name`:
/// `<launcher_name>Arg<i>`.
#[allow(dead_code)]
fn kernel_arg_alias(launcher_name: &str, i: usize) -> String {
    let index = i.to_string();
    [launcher_name, "Arg", index.as_str()].concat()
}
/// Produces the launcher argument tuple type and the `type` alias defining it.
///
/// Returns `(alias_type, alias_definition)`. `alias_type` is
/// `launcher_args_name`, followed by `generic_args` when those are non-empty.
/// `alias_definition` is `type <alias_type> = (<arg_tys>,...);`.
pub fn generate_launcher_arg_types(
    generic_args: &AngleBracketedGenericArguments,
    arg_tys: &Vec<Type>,
    _launcher_name: &str,
    launcher_args_name: &str,
) -> (Type, TokenStream2) {
    let alias_ident = Ident::new(launcher_args_name, Span::call_site());
    let alias_type: Type = if generic_args.args.is_empty() {
        parse_quote! { #alias_ident }
    } else {
        parse_quote! { #alias_ident #generic_args }
    };
    let alias_definition = quote! { type #alias_type = ( #(#arg_tys,)* ); };
    (alias_type, alias_definition)
}
/// Renders `args` as tuple source text with a trailing comma after every
/// element, e.g. `(a,b,)`; an empty slice yields `()`.
pub fn to_tuple_string(args: &Vec<String>) -> String {
    let mut tuple = String::from("(");
    for arg in args {
        tuple.push_str(arg);
        tuple.push(',');
    }
    tuple.push(')');
    tuple
}
/// Registers `type_name` as a launcher type parameter whose generic value is
/// its `DTYPE` string, unless a value expression was already recorded for it.
fn register_dtype_generic(required_generics: &mut RequiredGenerics, type_name: &str) {
    if required_generics.is_required(type_name) {
        required_generics
            .launcher_type_params
            .push(type_name.to_string());
        required_generics.expressions.insert(
            type_name.to_string(),
            Some(format!("vec![{type_name}::DTYPE.as_str().to_string()]")),
        );
    }
}

/// Generates the launcher code for the kernel entry point `item`.
///
/// Returns a triple:
/// 1. the `RequiredGenerics` collected while walking the kernel parameters,
/// 2. the launcher argument tuple type together with its `type` alias
///    definition (see `generate_launcher_arg_types`),
/// 3. tokens containing the `DeviceOperation` impl for `launcher_name` plus the
///    `<function_name>_apply`, `<function_name>_async` and
///    `<function_name>_sync` wrapper functions.
///
/// # Errors
/// Fails when a parameter is neither a reference, a path type nor a raw
/// pointer; when a raw pointer is used in a non-`unsafe` kernel, is not
/// `*mut`, or cannot be decoded by `get_ptr_type`; and when `get_tensor_code`
/// rejects a reference parameter.
///
/// # Panics
/// Panics if any of the internally assembled source strings fails to parse.
pub fn generate_kernel_launcher(
    item: &ItemFn,
    module_name: &str,
    function_name: &str,
    function_entry_name: &str,
    launcher_name: &str,
    launcher_args_name: &str,
) -> Result<(RequiredGenerics, (Type, TokenStream2), TokenStream2), Error> {
    let unsafety = item.sig.unsafety;
    let is_unsafe = unsafety.is_some();
    let launcher_ident = Ident::new(launcher_name, Span::call_site());
    let mut launcher_method = syn::parse2::<ImplItemFn>(quote! {
        unsafe fn execute(mut self, ctx: &ExecutionContext) -> Result<<Self as DeviceOperation>::Output, DeviceError> {}
    })
    .unwrap();
    let param_names = get_sig_param_names(&item.sig);
    let param_names_tuple_str = to_tuple_string(&param_names);
    let (input_types, _output_type) = get_sig_types(&item.sig, None);

    // Per-parameter fragments, accumulated in parameter order.
    let mut stride_args = vec![];
    let mut builder_statements = vec![];
    let mut launch_grid_expr_strs = vec![];
    let mut validator_statements = vec![];
    let mut arg_types: Vec<Type> = vec![];
    let mut required_generics: RequiredGenerics = RequiredGenerics::new(&item.sig.generics);
    for (i, ty) in input_types.iter().enumerate() {
        let var_name = &param_names[i];
        match ty {
            Type::Reference(ref_ty) => {
                let res = get_tensor_code(i, var_name, ref_ty, &mut required_generics)?;
                arg_types.push(res.fn_arg.ty.as_ref().clone());
                stride_args.push(res.stride_expr_str);
                builder_statements.extend(res.builder_statements);
                launch_grid_expr_strs.extend(res.launch_grid_expr_strs);
                validator_statements.extend(res.validator_statements.block.stmts);
            }
            Type::Path(path_ty) => {
                let ident = get_ident_from_path(&path_ty.path);
                let type_name = ident.to_string();
                arg_types.push(syn::parse2::<Type>(type_name.parse().unwrap()).unwrap());
                register_dtype_generic(&mut required_generics, &type_name);
                builder_statements.push(parse_stmt(format!(
                    "kernel_launch.push_arg(Box::new({var_name}));"
                )));
            }
            Type::Ptr(ptr_type) => {
                if !is_unsafe {
                    return ptr_type
                        .err("Pointers can only be used in unsafe kernel entry points.");
                }
                let ptr_str = ptr_type.to_token_stream().to_string();
                let Some((is_mutable, type_name)) = get_ptr_type(&ptr_str) else {
                    return ptr_type.err(&format!("Unexpected pointer type: {}", ptr_str));
                };
                if !is_mutable {
                    return ptr_type.err("Pointers must be * mut.");
                }
                arg_types.push(
                    syn::parse2::<Type>(format!("DevicePointer<{}>", type_name).parse().unwrap())
                        .unwrap(),
                );
                register_dtype_generic(&mut required_generics, &type_name);
                builder_statements.push(parse_stmt(format!("kernel_launch.push_arg({var_name});")));
            }
            _ => {
                return ty.err("Unable to generate launcher: unsupported parameter type.");
            }
        }
    }

    // Generic parameter/argument lists for the launcher struct and wrappers.
    let generic_params = required_generics.get_required_generics();
    let generic_args = required_generics.get_generic_args();
    let (launcher_args_type, launcher_arg_type_def) =
        generate_launcher_arg_types(&generic_args, &arg_types, launcher_name, launcher_args_name);
    let launcher_args_type_str = launcher_args_type.to_token_stream().to_string();
    let device_op_param: GenericParam =
        parse_quote! { DI: DeviceOperation<Output=#launcher_args_type> };
    let device_op_arg: GenericArgument = parse_quote! { DI };
    let mut struct_generics = generic_params.clone();
    struct_generics.params.push(device_op_param);
    let mut struct_args = generic_args.clone();
    struct_args.args.push(device_op_arg);
    let mut launch_output_type = generic_args.clone();
    let impl_device_op: GenericArgument =
        parse_quote! { impl DeviceOperation<Output=#launcher_args_type> };
    launch_output_type.args.push(impl_device_op);

    // Body of `execute`, assembled statement by statement.
    let init_stmts = syn::parse2::<ExprBlock>(quote! {{
        let module_name = #module_name;
        let function_name = #function_name;
        let function_entry = #function_entry_name;
        let input = self.input.take().unwrap();
    }})
    .unwrap()
    .block
    .stmts;
    launcher_method.block.stmts.extend(init_stmts);
    launcher_method.block.stmts.push(parse_stmt(format!(
        r#"let {param_names_tuple_str}: {launcher_args_type_str} = input.execute(ctx)?;"#
    )));
    if !required_generics.names.is_empty() {
        launcher_method.block.stmts.push(parse_stmt(format!(
            r#"
let function_generics: Vec<String> = if self.function_generics.is_some() {{
self.function_generics.take().unwrap()
}} else {{
{}
}};
"#,
            required_generics.to_expr_str()
        )));
    } else {
        launcher_method.block.stmts.push(parse_stmt(
            "let function_generics: Vec<String> = vec![];".to_string(),
        ));
    }
    launcher_method.block.stmts.push(parse_stmt(format!(
        "let stride_args: Vec<(String, Vec<i32>)> = vec![{}];",
        stride_args.join(",")
    )));
    let compile_stmts = syn::parse2::<ExprBlock>(quote! {{
        let const_grid = if self._const_grid { Some(self._grid) } else { None };
        let (function, validator) = self.compile(
            ctx, _module_asts,
            module_name, function_name, function_entry,
            function_generics, stride_args, const_grid
        )?;
    }})
    .unwrap()
    .block
    .stmts;
    launcher_method.block.stmts.extend(compile_stmts);
    launcher_method.block.stmts.extend(validator_statements);
    launcher_method.block.stmts.push(parse_stmt(
        "let mut kernel_launch = AsyncKernelLaunch::new(function.clone());".to_string(),
    ));
    launcher_method.block.stmts.extend(builder_statements);
    launcher_method.block.stmts.push(parse_stmt(format!(
        "let launch_grid: (u32, u32, u32) = self.infer_launch_grid(&[{}])?;",
        launch_grid_expr_strs.join(",")
    )));
    let launch_stmts = syn::parse2::<ExprBlock>(quote! {{
        kernel_launch
            .set_launch_config(LaunchConfig {
                grid_dim: launch_grid,
                block_dim: (1, 1, 1),
                shared_mem_bytes: 0
            });
        kernel_launch.execute(ctx)?;
    }})
    .unwrap()
    .block
    .stmts;
    launcher_method.block.stmts.extend(launch_stmts);
    launcher_method.block.stmts.push(parse_stmt(format!(
        r#"return Ok({param_names_tuple_str});"#
    )));

    // `<function_name>_apply`: wraps one device operation yielding the whole
    // argument tuple.
    let kernel_return_type = quote! {
        #launcher_ident #launch_output_type
    };
    let apply_name = format!("{}_apply", function_name);
    let launcher_apply_ident = Ident::new(apply_name.as_str(), Span::call_site());
    let launcher_apply = syn::parse2::<ItemFn>(quote! {
        pub #unsafety fn #launcher_apply_ident #struct_generics (input: DI) -> #kernel_return_type {
            return #launcher_ident::launch(input);
        }
    })
    .unwrap();
    let arg_aliases: Vec<String> = arg_types
        .iter()
        .map(|arg_ty| arg_ty.to_token_stream().to_string())
        .collect();

    // `<function_name>_async`: one device operation per kernel argument, zipped
    // together and forwarded to `_apply`.
    let async_name = format!("{}_async", function_name);
    let launcher_async_ident = Ident::new(async_name.as_str(), Span::call_site());
    let mut launcher_async = syn::parse2::<ItemFn>(quote! {
        pub #unsafety fn #launcher_async_ident #generic_params() -> #kernel_return_type {}
    })
    .unwrap();
    let mut function_params = vec![];
    launcher_async.sig.generics.make_where_clause();
    for (i, arg_alias) in arg_aliases.iter().enumerate() {
        let function_param = format!("arg{}", i);
        let type_param = format!("DI{}", i);
        let type_bound = format!("DeviceOperation<Output={}>", arg_alias);
        launcher_async.sig.inputs.push(FnArg::Typed(
            syn::parse2::<PatType>(
                format!("{}: {}", function_param, type_param)
                    .parse()
                    .unwrap(),
            )
            .unwrap(),
        ));
        launcher_async.sig.generics.params.push(GenericParam::Type(
            syn::parse2::<TypeParam>(type_param.parse().unwrap()).unwrap(),
        ));
        let where_clause = launcher_async
            .sig
            .generics
            .where_clause
            .as_mut()
            .expect("Impossible.");
        where_clause.predicates.push(
            syn::parse2::<WherePredicate>(
                format!("{}: {}", type_param, type_bound).parse().unwrap(),
            )
            .unwrap(),
        );
        function_params.push(function_param);
    }
    let input_zips = zip_and_then_flatten(&function_params, "input", false);
    launcher_async.block.stmts.extend(input_zips.block.stmts);
    launcher_async
        .block
        .stmts
        .push(parse_stmt(format!("return {}(input);", apply_name)));

    // `<function_name>_sync`: takes the argument values directly, wraps each in
    // `value(...)` and forwards to `_async`.
    let launcher_sync_ident = Ident::new(
        format!("{}_sync", function_name).as_str(),
        Span::call_site(),
    );
    let mut launcher_sync = syn::parse2::<ItemFn>(quote! {
        pub #unsafety fn #launcher_sync_ident #generic_params() -> #kernel_return_type {}
    })
    .unwrap();
    for (function_param, arg_alias) in function_params.iter().zip(&arg_aliases) {
        launcher_sync.sig.inputs.push(FnArg::Typed(
            syn::parse2::<PatType>(
                format!("{}: {}", function_param, arg_alias)
                    .parse()
                    .unwrap(),
            )
            .unwrap(),
        ));
    }
    let return_op = format!(
        "return {async_name}({});",
        function_params
            .iter()
            .map(|var| zippable(var, true))
            .collect::<Vec<String>>()
            .join(", ")
    );
    launcher_sync.block.stmts.push(parse_stmt(return_op));

    Ok((
        required_generics,
        (launcher_args_type.clone(), launcher_arg_type_def),
        quote! {
            impl #struct_generics DeviceOperation for #launcher_ident #struct_args {
                type Output = #launcher_args_type;
                #launcher_method
            }
            #launcher_apply
            #launcher_async
            #launcher_sync
        },
    ))
}
/// Parses source text into a single statement.
///
/// Panics if `s` cannot be tokenized or is not a valid statement.
fn parse_stmt(s: String) -> Stmt {
    let tokens = s.parse().unwrap();
    let statement: Stmt = syn::parse(tokens).unwrap();
    statement
}
/// Parses source text into a single expression.
///
/// Panics if `s` cannot be tokenized or is not a valid expression.
#[allow(dead_code)]
fn parse_expr(s: String) -> Expr {
    let tokens = s.parse().unwrap();
    let expression: Expr = syn::parse(tokens).unwrap();
    expression
}
/// Code fragments produced by `get_tensor_code` for a single tensor parameter.
struct TensorLaunchCode {
// `fn_arg` is the parameter written as `name: Type`; `stride_expr_str` is the
// source text of a `(name, strides)` tuple expression.
fn_arg: PatType, stride_expr_str: String,
// Statements that push the tensor onto `kernel_launch`.
builder_statements: Vec<Stmt>,
// Source text of expressions handed to `infer_launch_grid`; only filled in
// for `&mut` tensors.
launch_grid_expr_strs: Vec<String>,
// Block whose statements compare the tensor's shape with the validator
// returned by compilation.
validator_statements: ExprBlock,
}
/// Generates the launcher fragments for one `&Tensor<..>` / `&mut Tensor<..>`
/// kernel parameter.
///
/// * `var_idx` - position of the parameter; used to index `validator.params`
///   in the generated checks.
/// * `var_name` - name of the parameter in the generated code.
/// * `ty` - the reference type of the parameter.
/// * `required_generics` - updated with value expressions for any generics
///   that can be inferred from this tensor's type.
///
/// `&mut` tensors are passed as `tensor::Partition<tensor::Tensor<T>>`,
/// contribute a launch-grid expression and are validated against their
/// `partition_shape`. Shared references are passed as `Arc<tensor::Tensor<T>>`
/// and validated against `shape`, with `-1` in the expected shape matching any
/// extent.
///
/// # Errors
/// Fails when the referenced type is not named `Tensor`, when its first
/// generic argument is not a type path, or when shape inference fails.
fn get_tensor_code(
    var_idx: usize,
    var_name: &str,
    ty: &TypeReference,
    required_generics: &mut RequiredGenerics,
) -> Result<TensorLaunchCode, Error> {
    let is_mutable = ty.mutability.is_some();
    let (type_ident, type_generic_args) = get_ident_generic_args(&Type::Reference(ty.clone()));
    let Some(type_ident) = type_ident else {
        return ty.err("Expected a named type identifier for tensor parameter.");
    };
    if type_ident.to_string() != "Tensor" {
        return ty.err(&format!("Expected Tensor type, got {}.", type_ident));
    }
    let Some(GenericArgument::Type(syn::Type::Path(element_type_path))) =
        type_generic_args.args.first()
    else {
        return ty.err("Expected generic argument type path for tensor element type.");
    };
    infer_shape_params_from_tensor_type(
        var_name,
        &type_generic_args,
        required_generics,
        is_mutable,
    )?;
    let dtype = element_type_path
        .path
        .segments
        .last()
        .unwrap()
        .ident
        .to_string();
    let tensor_type = if is_mutable {
        format!("tensor::Partition<tensor::Tensor<{dtype}>>")
    } else {
        format!("Arc<tensor::Tensor<{dtype}>>")
    };
    let fn_arg =
        syn::parse::<PatType>(format!("{var_name}: {tensor_type}").parse().unwrap()).unwrap();

    // Stride hint: `-1` for every dimension except the last, which is `1`.
    // `last_mut` keeps the generated code from panicking on a rank-0 tensor,
    // where `res[len - 1]` would underflow.
    let strides_field = if is_mutable {
        "partition_strides"
    } else {
        "strides"
    };
    let stride_expr_str = format!(
        r#"(
            "{var_name}".to_string(),
            {{
                let len = {var_name}.{strides_field}.len();
                let mut res = vec![-1; len];
                if let Some(last) = res.last_mut() {{
                    *last = 1;
                }}
                res
            }}
        )"#
    );

    let var_ident = Ident::new(var_name, Span::call_site());
    let mut builder_statements = vec![];
    let mut launch_grid_expr_strs = vec![];
    let validator_statements = if is_mutable {
        builder_statements.push(parse_stmt(format!("kernel_launch.push_arg(&{var_name});")));
        launch_grid_expr_strs.push(format!("{var_name}.grid()?"));
        syn::parse2::<ExprBlock>(quote! {{
            {
                let ValidParamType::Tensor(tensor_validator) = &validator.params[#var_idx] else {
                    panic!("Unexpected validator type {:#?}", &validator.params[#var_idx]);
                };
                let valid_shape = &tensor_validator.shape;
                let given_shape = &#var_ident.partition_shape;
                kernel_launch_assert(valid_shape.len() == given_shape.len(),
                    format!("{} rank mismatch: Expected {}, got {}", #var_name, valid_shape.len(), given_shape.len()).as_str())?;
                kernel_launch_assert(valid_shape == given_shape,
                    format!("{} partition shape mismatch. Expected {:?}, got {:?}", #var_name, valid_shape, given_shape).as_str())?;
            }
        }})
        .unwrap()
    } else {
        builder_statements.push(parse_stmt(format!(
            "kernel_launch.push_arg_arc(&{var_name});"
        )));
        syn::parse2::<ExprBlock>(quote! {{
            {
                let ValidParamType::Tensor(tensor_validator) = &validator.params[#var_idx] else {
                    panic!("Unexpected validator type {:#?}", &validator.params[#var_idx]);
                };
                let valid_shape = &tensor_validator.shape;
                let given_shape = &#var_ident.shape;
                kernel_launch_assert(valid_shape.len() == given_shape.len(),
                    format!("{} rank mismatch: Expected {}, got {}", #var_name, valid_shape.len(), given_shape.len()).as_str())?;
                let valid_shape_mixed = zip(valid_shape, given_shape).map(|(&expected, &given)|{
                    if expected == -1 { given } else { expected }
                }).collect::<Vec<_>>();
                let pred = zip(&valid_shape_mixed, given_shape).all(|(&expected, &given)|{
                    expected == given
                });
                kernel_launch_assert(pred,
                    format!("{} partition shape mismatch. Expected {:?}, got {:?}", #var_name, valid_shape_mixed, given_shape).as_str())?;
            }
        }})
        .unwrap()
    };
    Ok(TensorLaunchCode {
        fn_arg,
        stride_expr_str,
        builder_statements,
        launch_grid_expr_strs,
        validator_statements,
    })
}
/// Records, for every generic mentioned in a tensor type's generic arguments,
/// the source expression that recovers its value from the tensor `var_name`.
///
/// * A type-path argument naming a type parameter is bound to the tensor's
///   dtype; one naming a const array is bound to the tensor's whole shape.
/// * A const block argument must hold a single array expression; each element
///   that names a const scalar is bound to the matching shape entry. Literal
///   and unary elements are skipped.
///
/// Shapes are read from `partition_shape` when `is_mutable` is set and from
/// `shape` otherwise. Arguments of any other form, and names that are not
/// known generics, are ignored.
///
/// # Errors
/// Fails on a const scalar used as a type argument, on a block that is not a
/// single array expression, on repeat expressions, on unsupported array
/// elements, and on type or const-array generics used as array elements.
pub fn infer_shape_params_from_tensor_type(
    var_name: &str,
    type_generic_args: &AngleBracketedGenericArguments,
    required_generics: &mut RequiredGenerics,
    is_mutable: bool,
) -> Result<(), Error> {
    let shape_field = if is_mutable {
        "partition_shape"
    } else {
        "shape"
    };
    for generic_arg in &type_generic_args.args {
        match generic_arg {
            GenericArgument::Type(syn::Type::Path(type_path)) => {
                let generic_name = type_path.path.segments.last().unwrap().ident.to_string();
                match required_generics.get_ty(&generic_name) {
                    SupportedGenericType::Unknown => {}
                    SupportedGenericType::ConstScalar => {
                        return type_path
                            .err("Unexpected constant scalar type in tensor generic argument.");
                    }
                    SupportedGenericType::ConstArray => {
                        let shape_expr = format!(
                            "{var_name}.{shape_field}.iter().map(|x| x.to_string()).collect::<Vec<String>>()"
                        );
                        required_generics
                            .expressions
                            .insert(generic_name, Some(shape_expr));
                    }
                    SupportedGenericType::TypeParam => {
                        required_generics
                            .launcher_type_params
                            .push(generic_name.clone());
                        required_generics.expressions.insert(
                            generic_name,
                            Some(format!("vec![{var_name}.dtype().as_str().to_string()]")),
                        );
                    }
                }
            }
            GenericArgument::Const(Expr::Block(block_expr)) => {
                let [statement] = block_expr.block.stmts.as_slice() else {
                    return block_expr.err(&format!(
                        "Expected exactly 1 statement in block expression, got {}.",
                        block_expr.block.stmts.len()
                    ));
                };
                let Stmt::Expr(statement_expr, _) = statement else {
                    return block_expr
                        .err("Unexpected block expression: expected an expression statement.");
                };
                let array_expr = match statement_expr {
                    Expr::Array(array_expr) => array_expr,
                    Expr::Repeat(repeat_expr) => {
                        return repeat_expr
                            .err("Repeat expressions in tensor shape are not yet supported.");
                    }
                    _ => {
                        return block_expr
                            .err("Unexpected block expression in tensor const generic argument.");
                    }
                };
                for (dim, elem) in array_expr.elems.iter().enumerate() {
                    let path = match elem {
                        // Fixed extents carry no generic to infer.
                        Expr::Lit(_) | Expr::Unary(_) => continue,
                        Expr::Path(path) => path,
                        _ => {
                            return elem
                                .err("Unsupported array element in tensor shape expression.");
                        }
                    };
                    let ident = get_ident_from_path_expr(path).to_string();
                    match required_generics.get_ty(&ident) {
                        SupportedGenericType::Unknown => {}
                        SupportedGenericType::TypeParam => {
                            return path.err("Unexpected type param in array type expression.");
                        }
                        SupportedGenericType::ConstArray => {
                            return path.err(
                                "Unexpected const generic array param in array type expression.",
                            );
                        }
                        SupportedGenericType::ConstScalar => {
                            required_generics.expressions.insert(
                                ident,
                                Some(format!(
                                    "vec![{var_name}.{shape_field}[{dim}].to_string()]"
                                )),
                            );
                        }
                    }
                }
            }
            _ => {}
        }
    }
    Ok(())
}