use cutile_compiler::kernel_naming::KernelNaming;
use cutile_compiler::syn_utils::*;
use cutile_compiler::types::get_ptr_type;
use proc_macro2::Ident;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{quote, ToTokens};
use std::collections::HashMap;
use syn::{
parse_quote, AngleBracketedGenericArguments, Expr, ExprBlock, FnArg, GenericArgument,
GenericParam, Generics, ImplItemFn, ItemFn, PatType, Stmt, Type, TypeParam, TypeReference,
WherePredicate,
};
use crate::error::{Error, SpannedError};
/// Classification of a generic parameter name, as reported by
/// `RequiredGenerics::get_ty`.
///
/// The variant order matters: the derived `PartialOrd` compares variants by
/// their declaration order.
#[derive(Debug, PartialEq, PartialOrd)]
pub(crate) enum SupportedGenericType {
    /// A tracked generic for which no type was recorded.
    TypeParam,
    /// A tracked generic whose recorded type is anything but an array type.
    ConstScalar,
    /// A tracked generic whose recorded type is an array type.
    ConstArray,
    /// The name is not one of the tracked generics.
    Unknown,
}
/// Bookkeeping for the generic parameters of a kernel entry point while its
/// launcher is being generated.
#[derive(Debug)]
pub(crate) struct RequiredGenerics {
    /// Names of the tracked generic parameters, in the order returned by
    /// `get_supported_generic_params`.
    names: Vec<String>,
    /// Names recorded as type parameters that the launcher itself must expose.
    launcher_type_params: Vec<String>,
    /// Parallel to `names`: the type recorded for each generic, if any.
    types: Vec<Option<Type>>,
    /// Source text of the expression that yields each generic's value(s) at
    /// launch time; `None` until one has been inferred.
    expressions: HashMap<String, Option<String>>,
}
impl RequiredGenerics {
    /// Starts tracking the supported generic parameters of `generics`, with no
    /// expression inferred for any of them yet.
    fn new(generics: &Generics) -> Self {
        let (names, types): (Vec<String>, _) =
            get_supported_generic_params(generics).into_iter().unzip();
        let expressions = names
            .iter()
            .map(|name| (name.clone(), None))
            .collect::<HashMap<String, Option<String>>>();
        Self {
            names,
            launcher_type_params: Vec::new(),
            types,
            expressions,
        }
    }

    /// Returns `true` unless an expression has already been recorded for `s`.
    ///
    /// Names that are not tracked at all also yield `true`.
    pub(crate) fn is_required(&self, s: &str) -> bool {
        !matches!(self.expressions.get(s), Some(Some(_)))
    }

    /// Renders the source text of an expression that concatenates the recorded
    /// expressions of all tracked generics, in `names` order.
    ///
    /// If any tracked generic has no expression, the returned text is a
    /// `panic!(...)` invocation naming the first such generic.
    pub(crate) fn to_expr_str(&self) -> String {
        let mut parts: Vec<&str> = Vec::with_capacity(self.names.len());
        for name in &self.names {
            match self.expressions.get(name) {
                Some(Some(expr_str)) => parts.push(expr_str.as_str()),
                _ => {
                    return format!(
                        "panic!(\"Failed to infer value for generic parameter {name}\")"
                    );
                }
            }
        }
        format!("vec![{}].concat()", parts.join(","))
    }

    /// Classifies `name` based on whether it is tracked and on the type that
    /// was recorded for it.
    pub(crate) fn get_ty(&self, name: &str) -> SupportedGenericType {
        match self.names.iter().position(|n| n == name) {
            None => SupportedGenericType::Unknown,
            Some(index) => match &self.types[index] {
                None => SupportedGenericType::TypeParam,
                Some(Type::Array(_)) => SupportedGenericType::ConstArray,
                Some(_) => SupportedGenericType::ConstScalar,
            },
        }
    }

    /// Tracked names (in `names` order) that were recorded as launcher type
    /// params and classify as `SupportedGenericType::TypeParam`.
    fn launcher_type_param_names(&self) -> impl Iterator<Item = &String> + '_ {
        self.names.iter().filter(move |name| {
            self.launcher_type_params.contains(*name)
                && self.get_ty(name) == SupportedGenericType::TypeParam
        })
    }

    /// Builds the `<T: Send + DType, ...>` parameter list for the launcher.
    ///
    /// Panics if the rendered text fails to parse as `Generics`.
    pub(crate) fn get_required_generics(&self) -> Generics {
        let bounded: Vec<String> = self
            .launcher_type_param_names()
            .map(|name| format!("{}: Send + DType", name))
            .collect();
        syn::parse2::<Generics>(format!("<{}>", bounded.join(", ")).parse().unwrap()).unwrap()
    }

    /// Builds the `<T, ...>` argument list matching `get_required_generics`.
    ///
    /// Panics if the rendered text fails to parse as generic arguments.
    pub(crate) fn get_generic_args(&self) -> AngleBracketedGenericArguments {
        let bare: Vec<&str> = self
            .launcher_type_param_names()
            .map(|name| name.as_str())
            .collect();
        syn::parse2::<AngleBracketedGenericArguments>(
            format!("<{}>", bare.join(", ")).parse().unwrap(),
        )
        .unwrap()
    }
}
/// Joins `vals` into a right-nested pair ("cons") tuple string.
///
/// * no values   -> `()`
/// * one value   -> the value itself
/// * many values -> `(a, (b, c))`
pub fn join_as_cons_tuple(vals: &[String]) -> String {
    let Some((last, init)) = vals.split_last() else {
        return "()".to_string();
    };
    // Fold from the right so the last value ends up innermost.
    init.iter()
        .rfold(last.clone(), |tail, head| format!("({}, {})", head, tail))
}
/// Returns `expr` unchanged, or wrapped as `value(expr)` when `wrap_as_val`
/// is set.
fn zippable(expr: &str, wrap_as_val: bool) -> String {
    if wrap_as_val {
        format!("value({})", expr)
    } else {
        expr.to_string()
    }
}
#[allow(dead_code)]
pub fn zip_cons(inputs: &[String], var_name: &str, wrap_as_val: bool) -> ExprBlock {
let mut zip_block = syn::parse2::<ExprBlock>(quote! {{
}})
.unwrap();
if inputs.is_empty() {
return zip_block;
}
let mut i = inputs.len() - 1;
zip_block.block.stmts.push(parse_stmt(format!(
"let {var_name} = {};",
zippable(&inputs[i], wrap_as_val)
)));
while i != 0 {
i -= 1;
zip_block.block.stmts.push(parse_stmt(format!(
"let {var_name} = zip!({}, {var_name});",
zippable(&inputs[i], wrap_as_val)
)));
}
zip_block
}
pub fn zip_and_then_flatten(inputs: &[String], var_name: &str, wrap_as_val: bool) -> ExprBlock {
let mut zip_block = syn::parse2::<ExprBlock>(quote! {{
}})
.unwrap();
if inputs.is_empty() {
zip_block
.block
.stmts
.push(parse_stmt(format!("let {var_name} = value(());")));
return zip_block;
}
let mut i = inputs.len() - 1;
zip_block.block.stmts.push(parse_stmt(format!(
"let {var_name} = {};",
zippable(&inputs[i], wrap_as_val)
)));
while i != 0 {
i -= 1;
zip_block.block.stmts.push(parse_stmt(format!(
"let {var_name} = zip!({}, {var_name});",
zippable(&inputs[i], wrap_as_val)
)));
}
zip_block.block.stmts.push(parse_stmt(format!(
r#"
let {var_name} = {var_name}.then(|{}| {{
value({})
}});
"#,
join_as_cons_tuple(inputs),
to_tuple_string(inputs)
)));
zip_block
}
/// Returns the alias name `{launcher_name}Arg{i}` for the `i`-th launcher
/// argument.
#[allow(dead_code)]
fn kernel_arg_alias(launcher_name: &str, i: usize) -> String {
    let mut alias = String::from(launcher_name);
    alias.push_str("Arg");
    alias.push_str(&i.to_string());
    alias
}
/// Builds the launcher-args type and the `type` alias item that defines it.
///
/// # Parameters
/// * `generic_args` - generic arguments appended to the alias name when
///   non-empty.
/// * `arg_tys` - the argument types; the alias expands to a tuple of them.
///   Accepts any slice, so existing `&Vec<Type>` call sites keep working.
/// * `_launcher_name` - unused.
/// * `launcher_args_name` - identifier of the alias.
///
/// # Returns
/// `(alias_type, tokens)` where `alias_type` is `Name` or `Name<...>` and
/// `tokens` is `type {alias_type} = (T0, T1, ...,);`.
///
/// # Panics
/// Panics if `launcher_args_name` is not a valid identifier.
pub fn generate_launcher_arg_types(
    generic_args: &AngleBracketedGenericArguments,
    arg_tys: &[Type],
    _launcher_name: &str,
    launcher_args_name: &str,
) -> (Type, TokenStream2) {
    let launcher_args_ident = Ident::new(launcher_args_name, Span::call_site());
    let launcher_args_type: Type = if generic_args.args.is_empty() {
        parse_quote! { #launcher_args_ident }
    } else {
        parse_quote! { #launcher_args_ident #generic_args }
    };
    // Render the alias first so the type can be moved into the result
    // without cloning it.
    let alias_tokens = quote! { type #launcher_args_type = ( #(#arg_tys,)* ); };
    (launcher_args_type, alias_tokens)
}
/// Renders `args` as tuple source text with a trailing comma after every
/// element, e.g. `(a,b,)`; an empty slice renders as `()`.
pub fn to_tuple_string(args: &[String]) -> String {
    let mut tuple = String::from("(");
    for arg in args {
        tuple.push_str(arg);
        tuple.push(',');
    }
    tuple.push(')');
    tuple
}
pub fn generate_kernel_launcher(
item: &ItemFn,
module_name: &str,
function_name: &str,
function_entry_name: &str,
launcher_name: &str,
_launcher_args_name: &str,
) -> Result<
(
RequiredGenerics,
(Type, Type),
TokenStream2,
KernelInputInfo,
),
Error,
> {
let unsafety = item.sig.unsafety;
let is_unsafe = unsafety.is_some();
let launcher_ident = Ident::new(launcher_name, Span::call_site());
let mut launcher_method = syn::parse2::<ImplItemFn>(quote! {
unsafe fn execute(mut self, ctx: &ExecutionContext) -> Result<<Self as DeviceOp>::Output, DeviceError> {}
})
.unwrap();
let param_names = get_sig_param_names(&item.sig);
let param_names_tuple_str = to_tuple_string(¶m_names);
let (input_types, _output_type) = get_sig_types(&item.sig, None);
let mut stride_args = vec![];
let mut spec_args: Vec<String> = vec![];
let mut scalar_hint_exprs: Vec<String> = vec![];
let mut builder_statements = vec![];
let mut launch_grid_expr_strs = vec![];
let mut validator_statements = vec![];
let mut arg_types: Vec<Type> = vec![];
let mut param_element_types: Vec<Option<String>> = vec![];
let mut required_generics: RequiredGenerics = RequiredGenerics::new(&item.sig.generics);
for (i, ty) in input_types.iter().enumerate() {
let var_name = ¶m_names[i];
match ty {
Type::Reference(ref_ty) => {
let res = get_tensor_code(i, var_name, ref_ty, &mut required_generics)?;
arg_types.push(res.fn_arg.ty.as_ref().clone());
param_element_types.push(res.element_type_name);
stride_args.push(res.stride_expr_str);
spec_args.push(res.spec_expr_str);
builder_statements.extend(res.builder_statements);
launch_grid_expr_strs.extend(res.launch_grid_expr_strs);
validator_statements.extend(res.validator_statements.block.stmts);
}
Type::Path(path_ty) => {
let ident = get_ident_from_path(&path_ty.path);
let type_name = ident.to_string();
arg_types.push(syn::parse2::<Type>(type_name.parse().unwrap()).unwrap());
if required_generics.is_required(&type_name) {
required_generics
.launcher_type_params
.push(type_name.clone());
required_generics.expressions.insert(
type_name.clone(),
Some(format!("vec![{type_name}::DTYPE.as_str().to_string()]")),
);
}
builder_statements.push(parse_stmt(format!("kernel_launch.push_arg({var_name});")));
if cutile_compiler::specialization::is_integer_scalar(&type_name) {
scalar_hint_exprs.push(format!(
r#"("{var_name}".to_string(), cutile_compiler::specialization::DivHint::from_value({var_name} as i32))"#
));
}
param_element_types.push(None);
}
Type::Ptr(ptr_type) => {
if !is_unsafe {
return ptr_type
.err("Pointers can only be used in unsafe kernel entry points.");
}
let ptr_str = ptr_type.to_token_stream().to_string();
let Some((is_mutable, type_name)) = get_ptr_type(&ptr_str) else {
return ptr_type.err(&format!("Unexpected pointer type: {}", ptr_str));
};
if !is_mutable {
return ptr_type.err("Pointers must be * mut.");
}
arg_types.push(
syn::parse2::<Type>(format!("DevicePointer<{}>", type_name).parse().unwrap())
.unwrap(),
);
if required_generics.is_required(&type_name) {
required_generics
.launcher_type_params
.push(type_name.clone());
required_generics.expressions.insert(
type_name.clone(),
Some(format!("vec![{type_name}::DTYPE.as_str().to_string()]")),
);
}
builder_statements.push(parse_stmt(format!(
"unsafe {{ kernel_launch.push_device_ptr({var_name}.cu_deviceptr()); }}"
)));
param_element_types.push(None);
}
_ => {
return ty.err("Unable to generate launcher: unsupported parameter type.");
}
}
}
let mut ki_type_param_names: Vec<String> = vec![];
let mut ki_element_type_names: Vec<String> = vec![];
let mut ki_param_idx: Vec<Option<usize>> = vec![];
let mut ko_type_param_names: Vec<String> = vec![];
let mut ko_element_type_names: Vec<String> = vec![];
let mut ko_param_idx: Vec<Option<usize>> = vec![];
for (i, ty) in arg_types.iter().enumerate() {
let ty_str = ty.to_token_stream().to_string();
if ty_str.starts_with("Arc <") {
let idx = ki_type_param_names.len();
ki_type_param_names.push(format!("_K{}", idx));
ki_element_type_names.push(
param_element_types[i]
.clone()
.expect("&Tensor param must have element type"),
);
ki_param_idx.push(Some(idx));
ko_param_idx.push(None);
} else if ty_str.contains("Partition") {
let idx = ko_type_param_names.len();
ko_type_param_names.push(format!("_P{}", idx));
let elem = ty_str
.split("Tensor <")
.nth(1)
.and_then(|s| s.split('>').next())
.map(|s| s.trim().to_string())
.expect("Partition param must have element type");
ko_element_type_names.push(elem);
ko_param_idx.push(Some(idx));
ki_param_idx.push(None);
} else {
ki_param_idx.push(None);
ko_param_idx.push(None);
}
}
let recovered_fields: Vec<String> = param_names
.iter()
.enumerate()
.map(|(i, name)| {
if let Some(ki_idx) = ki_param_idx[i] {
let ki_name = &ki_type_param_names[ki_idx];
let elem = &ki_element_type_names[ki_idx];
format!("<{ki_name} as KernelInput<{elem}>>::recover({name})")
} else if let Some(ko_idx) = ko_param_idx[i] {
let ko_name = &ko_type_param_names[ko_idx];
let elem = &ko_element_type_names[ko_idx];
format!("<{ko_name} as KernelOutput<{elem}>>::recover({name})")
} else {
name.clone()
}
})
.collect();
let recovered_tuple_str = to_tuple_string(&recovered_fields);
let kernel_input_info = KernelInputInfo {
type_param_names: ki_type_param_names,
element_type_names: ki_element_type_names,
param_kernel_input_idx: ki_param_idx.clone(),
ko_type_param_names: ko_type_param_names.clone(),
ko_element_type_names: ko_element_type_names.clone(),
param_kernel_output_idx: ko_param_idx.clone(),
recovered_tuple_str: recovered_tuple_str.clone(),
};
let ki_info = &kernel_input_info;
let stored_arg_types: Vec<Type> = arg_types
.iter()
.enumerate()
.map(|(i, ty)| {
if let Some(ki_idx) = ki_info.param_kernel_input_idx[i] {
let ki_name = &ki_info.type_param_names[ki_idx];
let elem = &ki_info.element_type_names[ki_idx];
syn::parse_str::<Type>(&format!("<{ki_name} as KernelInput<{elem}>>::Stored"))
.unwrap()
} else if let Some(ko_idx) = ki_info.param_kernel_output_idx[i] {
let ko_name = &ki_info.ko_type_param_names[ko_idx];
let elem = &ki_info.ko_element_type_names[ko_idx];
syn::parse_str::<Type>(&format!("<{ko_name} as KernelOutput<{elem}>>::Stored"))
.unwrap()
} else {
ty.clone()
}
})
.collect();
let returned_arg_types: Vec<Type> = arg_types
.iter()
.enumerate()
.map(|(i, ty)| {
if let Some(ki_idx) = ki_info.param_kernel_input_idx[i] {
let ki_name = &ki_info.type_param_names[ki_idx];
let elem = &ki_info.element_type_names[ki_idx];
syn::parse_str::<Type>(&format!("<{ki_name} as KernelInput<{elem}>>::Returned"))
.unwrap()
} else if let Some(ko_idx) = ki_info.param_kernel_output_idx[i] {
let ko_name = &ki_info.ko_type_param_names[ko_idx];
let elem = &ki_info.ko_element_type_names[ko_idx];
syn::parse_str::<Type>(&format!("<{ko_name} as KernelOutput<{elem}>>::Returned"))
.unwrap()
} else {
ty.clone()
}
})
.collect();
let stored_args_type: Type = parse_quote! { ( #(#stored_arg_types,)* ) };
let returned_args_type: Type = parse_quote! { ( #(#returned_arg_types,)* ) };
let stored_args_type_str = stored_args_type.to_token_stream().to_string();
let generic_params = required_generics.get_required_generics();
let generic_args = required_generics.get_generic_args();
let mut struct_generics = generic_params.clone();
for (ki_idx, ki_name) in ki_info.type_param_names.iter().enumerate() {
let elem = &ki_info.element_type_names[ki_idx];
struct_generics.params.push(
syn::parse_str::<GenericParam>(&format!("{ki_name}: KernelInput<{elem}>")).unwrap(),
);
}
for (ko_idx, ko_name) in ki_info.ko_type_param_names.iter().enumerate() {
let elem = &ki_info.ko_element_type_names[ko_idx];
struct_generics.params.push(
syn::parse_str::<GenericParam>(&format!("{ko_name}: KernelOutput<{elem}>")).unwrap(),
);
}
let device_op_param: GenericParam = parse_quote! { DI: DeviceOp<Output=#stored_args_type> };
struct_generics.params.push(device_op_param.clone());
let mut struct_args = generic_args.clone();
for ki_name in &ki_info.type_param_names {
struct_args
.args
.push(syn::parse_str::<GenericArgument>(ki_name).unwrap());
}
for ko_name in &ki_info.ko_type_param_names {
struct_args
.args
.push(syn::parse_str::<GenericArgument>(ko_name).unwrap());
}
let device_op_arg: GenericArgument = parse_quote! { DI };
struct_args.args.push(device_op_arg.clone());
let launcher_stored_arg_types: Vec<Type> = arg_types
.iter()
.enumerate()
.map(|(i, ty)| {
if let Some(ki_idx) = ki_info.param_kernel_input_idx[i] {
let elem = &ki_info.element_type_names[ki_idx];
syn::parse_str::<Type>(&format!("<_S{i} as KernelInput<{elem}>>::Stored")).unwrap()
} else if let Some(ko_idx) = ki_info.param_kernel_output_idx[i] {
let elem = &ki_info.ko_element_type_names[ko_idx];
syn::parse_str::<Type>(&format!("<_Q{i} as KernelOutput<{elem}>>::Stored")).unwrap()
} else {
ty.clone()
}
})
.collect();
let launcher_stored_args_type: Type = parse_quote! { ( #(#launcher_stored_arg_types,)* ) };
let mut launch_output_type = generic_args.clone();
for (i, is_arc) in ki_param_idx.iter().enumerate() {
if is_arc.is_some() {
launch_output_type
.args
.push(syn::parse_str::<GenericArgument>(&format!("_S{}", i)).unwrap());
}
}
for (i, is_part) in ko_param_idx.iter().enumerate() {
if is_part.is_some() {
launch_output_type
.args
.push(syn::parse_str::<GenericArgument>(&format!("_Q{}", i)).unwrap());
}
}
let impl_device_op: GenericArgument =
parse_quote! { impl DeviceOp<Output=#launcher_stored_args_type> };
launch_output_type.args.push(impl_device_op);
let init_stmts = syn::parse2::<ExprBlock>(quote! {{
let module_name = #module_name;
let function_name = #function_name;
let function_entry = #function_entry_name;
let input = self.input.take().unwrap();
}})
.unwrap()
.block
.stmts;
launcher_method.block.stmts.extend(init_stmts);
launcher_method.block.stmts.push(parse_stmt(format!(
r#"let {param_names_tuple_str}: {stored_args_type_str} = input.execute(ctx)?;"#
)));
if !required_generics.names.is_empty() {
launcher_method.block.stmts.push(parse_stmt(format!(
r#"
let function_generics: Vec<String> = if self.function_generics.is_some() {{
self.function_generics.take().unwrap()
}} else {{
{}
}};
"#,
required_generics.to_expr_str()
)));
} else {
launcher_method.block.stmts.push(parse_stmt(
"let function_generics: Vec<String> = vec![];".to_string(),
));
}
launcher_method.block.stmts.push(parse_stmt(format!(
"let stride_args: Vec<(String, Vec<i32>)> = vec![{}];",
stride_args.join(",")
)));
launcher_method.block.stmts.push(parse_stmt(format!(
"let spec_args = vec![{}];",
spec_args.join(",")
)));
launcher_method.block.stmts.push(parse_stmt(format!(
"let scalar_hints: Vec<(String, cutile_compiler::specialization::DivHint)> = vec![{}];",
scalar_hint_exprs.join(",")
)));
let compile_stmts = syn::parse2::<ExprBlock>(quote! {{
let const_grid = if self._const_grid { Some(self._grid) } else { None };
let compile_options = std::mem::take(&mut self._compile_options);
let (function, validator) = self.compile(
ctx, _module_asts,
module_name, function_name, function_entry,
function_generics, stride_args, spec_args.clone(), scalar_hints,
const_grid,
compile_options
)?;
}})
.unwrap()
.block
.stmts;
launcher_method.block.stmts.extend(compile_stmts);
launcher_method.block.stmts.extend(validator_statements);
launcher_method.block.stmts.push(parse_stmt(
"let mut kernel_launch = AsyncKernelLaunch::new(function.clone());".to_string(),
));
launcher_method.block.stmts.extend(builder_statements);
launcher_method.block.stmts.push(parse_stmt(format!(
"let launch_grid: (u32, u32, u32) = self.infer_launch_grid(&[{}])?;",
launch_grid_expr_strs.join(",")
)));
let launch_stmts = syn::parse2::<ExprBlock>(quote! {{
kernel_launch
.set_launch_config(LaunchConfig {
grid_dim: launch_grid,
block_dim: (1, 1, 1),
shared_mem_bytes: 0
});
kernel_launch.execute(ctx)?;
}})
.unwrap()
.block
.stmts;
launcher_method.block.stmts.extend(launch_stmts);
launcher_method
.block
.stmts
.push(parse_stmt(format!(r#"return Ok({recovered_tuple_str});"#)));
let kernel_naming = KernelNaming::new(function_name);
let concrete_args_type: Type = parse_quote! { ( #(#arg_types,)* ) };
let mut apply_launch_output_type = generic_args.clone();
for ki_name in &ki_info.type_param_names {
let ki_idx_in_types = ki_info
.param_kernel_input_idx
.iter()
.enumerate()
.find(|(_, idx)| {
idx.map(|k| ki_info.type_param_names[k] == *ki_name)
.unwrap_or(false)
})
.map(|(i, _)| i)
.unwrap();
let arc_type = &arg_types[ki_idx_in_types];
apply_launch_output_type.args.push(
syn::parse_str::<GenericArgument>(&arc_type.to_token_stream().to_string()).unwrap(),
);
}
for ko_name in &ki_info.ko_type_param_names {
let ko_idx_in_types = ki_info
.param_kernel_output_idx
.iter()
.enumerate()
.find(|(_, idx)| {
idx.map(|k| ki_info.ko_type_param_names[k] == *ko_name)
.unwrap_or(false)
})
.map(|(i, _)| i)
.unwrap();
let part_type = &arg_types[ko_idx_in_types];
apply_launch_output_type.args.push(
syn::parse_str::<GenericArgument>(&part_type.to_token_stream().to_string()).unwrap(),
);
}
let apply_impl_device_op: GenericArgument =
parse_quote! { impl DeviceOp<Output=#concrete_args_type> };
apply_launch_output_type.args.push(apply_impl_device_op);
let apply_return_type = quote! { #launcher_ident #apply_launch_output_type };
let apply_name = kernel_naming.apply_name();
let launcher_apply_ident = Ident::new(apply_name.as_str(), Span::call_site());
let launcher_apply = syn::parse2::<ItemFn>(quote! {
pub #unsafety fn #launcher_apply_ident #generic_params (input: #concrete_args_type) -> #apply_return_type {
return #launcher_ident::launch(value(input));
}
})
.unwrap();
let kernel_return_type = quote! { #launcher_ident #launch_output_type };
let arg_aliases = arg_types
.iter()
.map(|i| i.to_token_stream().to_string())
.collect::<Vec<_>>();
let launcher_direct_ident = Ident::new(kernel_naming.public_name(), Span::call_site());
let mut launcher_direct = syn::parse2::<ItemFn>(quote! {
pub #unsafety fn #launcher_direct_ident #generic_params() -> #kernel_return_type {}
})
.unwrap();
launcher_direct.sig.generics.make_where_clause();
let mut function_params = vec![];
let mut is_arc_param = vec![];
for (i, _arg_ty) in arg_types.iter().enumerate() {
let function_param = format!("arg{}", i);
let type_param_name = format!("_A{}", i);
let arg_type_str = &arg_aliases[i];
let is_arc = arg_type_str.starts_with("Arc <");
is_arc_param.push(is_arc);
launcher_direct.sig.inputs.push(FnArg::Typed(
syn::parse2::<PatType>(
format!("{}: {}", function_param, type_param_name)
.parse()
.unwrap(),
)
.unwrap(),
));
let where_clause = launcher_direct
.sig
.generics
.where_clause
.as_mut()
.expect("Impossible.");
if is_arc {
let intermediate_type = format!("_S{}", i);
launcher_direct.sig.generics.params.push(GenericParam::Type(
syn::parse2::<TypeParam>(type_param_name.parse().unwrap()).unwrap(),
));
launcher_direct.sig.generics.params.push(GenericParam::Type(
syn::parse2::<TypeParam>(intermediate_type.parse().unwrap()).unwrap(),
));
where_clause.predicates.push(
syn::parse2::<WherePredicate>(
format!("{}: IntoDeviceOp<{}>", type_param_name, intermediate_type)
.parse()
.unwrap(),
)
.unwrap(),
);
let ki_idx = ki_param_idx[i].unwrap();
let elem = &ki_info.element_type_names[ki_idx];
where_clause.predicates.push(
syn::parse2::<WherePredicate>(
format!("{}: KernelInput<{}>", intermediate_type, elem)
.parse()
.unwrap(),
)
.unwrap(),
);
} else if ko_param_idx[i].is_some() {
let intermediate_type = format!("_Q{}", i);
launcher_direct.sig.generics.params.push(GenericParam::Type(
syn::parse2::<TypeParam>(type_param_name.parse().unwrap()).unwrap(),
));
launcher_direct.sig.generics.params.push(GenericParam::Type(
syn::parse2::<TypeParam>(intermediate_type.parse().unwrap()).unwrap(),
));
where_clause.predicates.push(
syn::parse2::<WherePredicate>(
format!("{}: IntoDeviceOp<{}>", type_param_name, intermediate_type)
.parse()
.unwrap(),
)
.unwrap(),
);
let ko_idx = ko_param_idx[i].unwrap();
let elem = &ki_info.ko_element_type_names[ko_idx];
where_clause.predicates.push(
syn::parse2::<WherePredicate>(
format!("{}: KernelOutput<{}>", intermediate_type, elem)
.parse()
.unwrap(),
)
.unwrap(),
);
} else {
launcher_direct.sig.generics.params.push(GenericParam::Type(
syn::parse2::<TypeParam>(type_param_name.parse().unwrap()).unwrap(),
));
where_clause.predicates.push(
syn::parse2::<WherePredicate>(
format!("{}: IntoDeviceOp<{}>", type_param_name, arg_type_str)
.parse()
.unwrap(),
)
.unwrap(),
);
}
function_params.push(function_param);
}
let mut di_var_names: Vec<String> = vec![];
for (i, var) in function_params.iter().enumerate() {
let di_var = format!("_di{}", i);
if is_arc_param[i] {
launcher_direct.block.stmts.push(parse_stmt(format!(
"let {di_var} = {var}.into_op().map(KernelInput::prepare);"
)));
} else if ko_param_idx[i].is_some() {
launcher_direct.block.stmts.push(parse_stmt(format!(
"let {di_var} = {var}.into_op().map(KernelOutput::prepare);"
)));
} else {
launcher_direct
.block
.stmts
.push(parse_stmt(format!("let {di_var} = {var}.into_op();")));
}
di_var_names.push(di_var);
}
let input_zips = zip_and_then_flatten(&di_var_names, "input", false);
launcher_direct.block.stmts.extend(input_zips.block.stmts);
launcher_direct.block.stmts.push(parse_stmt(format!(
"return {}::launch(input);",
launcher_ident
)));
let returned_args_type_2 = returned_args_type.clone();
Ok((
required_generics,
(stored_args_type, returned_args_type),
quote! {
impl #struct_generics DeviceOp for #launcher_ident #struct_args {
type Output = #returned_args_type_2;
#launcher_method
}
impl #struct_generics GraphNode for #launcher_ident #struct_args {}
#launcher_apply
#launcher_direct
},
kernel_input_info,
))
}
/// Parses the source text `s` as a single Rust statement.
///
/// Panics if `s` cannot be tokenized or does not parse as a statement.
fn parse_stmt(s: String) -> Stmt {
    let tokens = s.parse().unwrap();
    syn::parse::<Stmt>(tokens).unwrap()
}
/// Parses the source text `s` as a single Rust expression.
///
/// Panics if `s` cannot be tokenized or does not parse as an expression.
#[allow(dead_code)]
fn parse_expr(s: String) -> Expr {
    let tokens = s.parse().unwrap();
    syn::parse::<Expr>(tokens).unwrap()
}
/// Describes which launcher parameters are kernel inputs or kernel outputs,
/// as computed by `generate_kernel_launcher`.
///
/// `Debug` is derived so that results carrying this struct can be inspected
/// and used with `Result::unwrap_err` / `expect_err`; `Clone` is derived
/// because every field is plain owned data.
#[derive(Debug, Clone)]
pub struct KernelInputInfo {
    /// Type parameter name of each kernel input (`_K0`, `_K1`, ...).
    pub type_param_names: Vec<String>,
    /// Element type name of each kernel input; parallel to `type_param_names`.
    pub element_type_names: Vec<String>,
    /// One entry per launcher parameter: index into `type_param_names`, or
    /// `None` if the parameter is not a kernel input.
    pub param_kernel_input_idx: Vec<Option<usize>>,
    /// Type parameter name of each kernel output (`_P0`, `_P1`, ...).
    pub ko_type_param_names: Vec<String>,
    /// Element type name of each kernel output; parallel to
    /// `ko_type_param_names`.
    pub ko_element_type_names: Vec<String>,
    /// One entry per launcher parameter: index into `ko_type_param_names`, or
    /// `None` if the parameter is not a kernel output.
    pub param_kernel_output_idx: Vec<Option<usize>>,
    /// Source text of the tuple expression the generated `execute` returns.
    pub recovered_tuple_str: String,
}
/// Code fragments produced by `get_tensor_code` for one tensor parameter.
struct TensorLaunchCode {
    /// The launcher argument as a `name: type` pattern.
    fn_arg: PatType,
    /// Source text of a `(name, strides)` tuple expression.
    stride_expr_str: String,
    /// Source text of a `(name, spec)` tuple expression.
    spec_expr_str: String,
    /// Statements that push this parameter's kernel arguments.
    builder_statements: Vec<Stmt>,
    /// Source text of grid expressions contributed by this parameter.
    launch_grid_expr_strs: Vec<String>,
    /// Block whose statements validate this parameter's shape at launch time.
    validator_statements: ExprBlock,
    /// Element type name; set only for immutable (`&Tensor`) parameters.
    element_type_name: Option<String>,
}
/// Produces the launcher code fragments for one `&Tensor<..>` or
/// `&mut Tensor<..>` kernel parameter.
///
/// # Parameters
/// * `var_idx` - position of the parameter; used to index `validator.params`
///   in the generated validation code.
/// * `var_name` - name of the parameter in the generated code.
/// * `ty` - the reference type of the parameter; its mutability selects the
///   kernel-output (`&mut`) or kernel-input (`&`) code path.
/// * `required_generics` - updated with expressions inferred from the
///   tensor's generic arguments.
///
/// # Errors
/// Fails if the referenced type has no identifier, is not named `Tensor`,
/// has no type-path first generic argument, or if
/// `infer_shape_params_from_tensor_type` fails.
fn get_tensor_code(
    var_idx: usize,
    var_name: &str,
    ty: &TypeReference,
    required_generics: &mut RequiredGenerics,
) -> Result<TensorLaunchCode, Error> {
    let is_mutable = ty.mutability.is_some();

    let (type_ident, type_generic_args) = get_ident_generic_args(&Type::Reference(ty.clone()));
    let type_ident = match type_ident {
        Some(found) => found,
        None => return ty.err("Expected a named type identifier for tensor parameter."),
    };
    if type_ident != "Tensor" {
        return ty.err(&format!("Expected Tensor type, got {}.", type_ident));
    }
    let element_type_path = match type_generic_args.args.first() {
        Some(GenericArgument::Type(syn::Type::Path(path))) => path,
        _ => return ty.err("Expected generic argument type path for tensor element type."),
    };
    infer_shape_params_from_tensor_type(
        var_name,
        &type_generic_args,
        required_generics,
        is_mutable,
    )?;
    let dtype = element_type_path
        .path
        .segments
        .last()
        .unwrap()
        .ident
        .to_string();

    // Launcher-side argument type plus the stride/spec tuple expressions.
    let (tensor_type, stride_expr_str, spec_expr_str) = if is_mutable {
        (
            format!("tensor::Partition<tensor::Tensor<{dtype}>>"),
            format!(
                r#"(
"{var_name}".to_string(),
KernelOutputStored::strides_hint(&{var_name})
)"#
            ),
            format!(
                r#"(
"{var_name}".to_string(),
KernelOutputStored::spec(&{var_name}).clone()
)"#
            ),
        )
    } else {
        (
            format!("Arc<tensor::Tensor<{dtype}>>"),
            format!(
                r#"(
"{var_name}".to_string(),
{var_name}.spec().stride_one.iter()
.map(|&is_one| if is_one {{ 1 }} else {{ -1 }})
.collect::<Vec<i32>>()
)"#
            ),
            format!(
                r#"(
"{var_name}".to_string(),
{var_name}.spec().clone()
)"#
            ),
        )
    };
    let fn_arg =
        syn::parse::<PatType>(format!("{var_name}: {tensor_type}").parse().unwrap()).unwrap();
    let var_ident = Ident::new(var_name, Span::call_site());

    // Argument-push statement, grid expressions and shape validation.
    let (builder_statements, launch_grid_expr_strs, validator_statements) = if is_mutable {
        let push_args = parse_stmt(format!(
            "KernelOutputStored::push_kernel_args(&{var_name}, &mut kernel_launch);"
        ));
        let grid_expr = format!("KernelOutputStored::grid(&{var_name})?");
        // Outputs must match the validator's shape exactly.
        let validators = syn::parse2::<ExprBlock>(quote! {{
            {
                let ValidParamType::Tensor(tensor_validator) = &validator.params[#var_idx] else {
                    panic!("Unexpected validator type {:#?}", &validator.params[#var_idx]);
                };
                let valid_shape = &tensor_validator.shape;
                let given_shape: Vec<i32> = KernelOutputStored::partition_shape_as_i32(&#var_ident);
                kernel_launch_assert(valid_shape.len() == given_shape.len(),
                    format!("{} rank mismatch: Expected {}, got {}", #var_name, valid_shape.len(), given_shape.len()).as_str())?;
                kernel_launch_assert(valid_shape == &given_shape,
                    format!("{} partition shape mismatch. Expected {:?}, got {:?}", #var_name, valid_shape, given_shape).as_str())?;
            }
        }})
        .unwrap();
        (vec![push_args], vec![grid_expr], validators)
    } else {
        let push_args = parse_stmt(format!(
            "KernelInputStored::push_kernel_args(&{var_name}, &mut kernel_launch);"
        ));
        // For inputs, an expected dimension of -1 accepts the given one.
        let validators = syn::parse2::<ExprBlock>(quote! {{
            {
                let ValidParamType::Tensor(tensor_validator) = &validator.params[#var_idx] else {
                    panic!("Unexpected validator type {:#?}", &validator.params[#var_idx]);
                };
                let valid_shape = &tensor_validator.shape;
                let given_shape = #var_ident.shape();
                kernel_launch_assert(valid_shape.len() == given_shape.len(),
                    format!("{} rank mismatch: Expected {}, got {}", #var_name, valid_shape.len(), given_shape.len()).as_str())?;
                let valid_shape_mixed = zip(valid_shape, given_shape).map(|(&expected, &given)|{
                    if expected == -1 { given } else { expected }
                }).collect::<Vec<_>>();
                let pred = zip(&valid_shape_mixed, given_shape).all(|(&expected, &given)|{
                    expected == given
                });
                kernel_launch_assert(pred,
                    format!("{} partition shape mismatch. Expected {:?}, got {:?}", #var_name, valid_shape_mixed, given_shape).as_str())?;
            }
        }})
        .unwrap();
        (vec![push_args], Vec::new(), validators)
    };

    // Only immutable tensors report an element type name.
    let element_type_name = (!is_mutable).then_some(dtype);
    Ok(TensorLaunchCode {
        fn_arg,
        stride_expr_str,
        spec_expr_str,
        builder_statements,
        launch_grid_expr_strs,
        validator_statements,
        element_type_name,
    })
}
pub fn infer_shape_params_from_tensor_type(
var_name: &str,
type_generic_args: &AngleBracketedGenericArguments,
required_generics: &mut RequiredGenerics,
is_mutable: bool,
) -> Result<(), Error> {
for generic_arg in &type_generic_args.args {
match generic_arg {
GenericArgument::Type(syn::Type::Path(type_path)) => {
let last_ident = type_path.path.segments.last().unwrap().ident.to_string();
match required_generics.get_ty(&last_ident) {
SupportedGenericType::TypeParam => {
required_generics
.launcher_type_params
.push(last_ident.clone());
required_generics.expressions.insert(
last_ident.clone(),
Some(format!("vec![{var_name}.dtype_str().to_string()]")),
);
}
SupportedGenericType::ConstArray => {
if is_mutable {
required_generics.expressions.insert(last_ident.clone(), Some(format!("KernelOutputStored::partition_shape_as_i32(&{var_name}).iter().map(|x| x.to_string()).collect::<Vec<String>>()")));
} else {
required_generics.expressions.insert(last_ident.clone(), Some(format!("{var_name}.shape().iter().map(|x| x.to_string()).collect::<Vec<String>>()")));
}
}
SupportedGenericType::ConstScalar => {
return type_path
.err("Unexpected constant scalar type in tensor generic argument.");
}
SupportedGenericType::Unknown => {}
}
}
GenericArgument::Const(Expr::Block(block_expr)) => {
if block_expr.block.stmts.len() != 1 {
return block_expr.err(&format!(
"Expected exactly 1 statement in block expression, got {}.",
block_expr.block.stmts.len()
));
}
let statement = &block_expr.block.stmts[0];
let Stmt::Expr(statement_expr, _) = statement else {
return block_expr
.err("Unexpected block expression: expected an expression statement.");
};
match statement_expr {
Expr::Array(array_expr) => {
for (i, elem) in array_expr.elems.iter().enumerate() {
match elem {
Expr::Lit(_lit) => {
continue;
}
Expr::Unary(_unary_expr) => {
continue;
}
Expr::Path(path) => {
let ident = get_ident_from_path_expr(path).to_string();
match required_generics.get_ty(&ident) {
SupportedGenericType::TypeParam => {
return path.err(
"Unexpected type param in array type expression.",
);
}
SupportedGenericType::ConstArray => {
return path.err("Unexpected const generic array param in array type expression.");
}
SupportedGenericType::ConstScalar => {
if is_mutable {
required_generics.expressions.insert(ident.clone(), Some(format!("vec![KernelOutputStored::partition_shape_as_i32(&{var_name})[{i}].to_string()]")));
} else {
required_generics.expressions.insert(
ident.clone(),
Some(format!(
"vec![{var_name}.shape()[{i}].to_string()]"
)),
);
}
}
SupportedGenericType::Unknown => {}
}
}
_ => {
return elem.err(
"Unsupported array element in tensor shape expression.",
);
}
}
}
}
Expr::Repeat(repeat_expr) => {
return repeat_expr
.err("Repeat expressions in tensor shape are not yet supported.");
}
_ => {
return block_expr
.err("Unexpected block expression in tensor const generic argument.");
}
}
}
_ => {}
}
}
Ok(())
}