use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::Ident;
use crate::kernel_ir::{KernelSignature, KernelType};
/// Generates the tokens for the host-side `pub fn launch(...)` wrapper of a kernel.
///
/// The emitted function takes, in order:
///
/// 1. `device: &kaio::runtime::KaioDevice`,
/// 2. one parameter per entry in `sig.params`
///    (`&GpuBuffer<T>` for `SliceRef`, `&mut GpuBuffer<T>` for `SliceMutRef`,
///    the plain Rust scalar type for everything else),
/// 3. for 2D kernels only (`sig.config.block_size_y` is `Some`), a trailing
///    `grid: (u32, u32, u32)` parameter.
///
/// For 1D kernels the grid size is computed in the generated code as
/// `n.div_ceil(block_size)`, where `n` is the **last** `u32` scalar parameter
/// of the kernel.
///
/// The generated body references `PTX_CACHE` and `build_ptx`; these are not
/// defined here and are presumably emitted next to `launch` by the caller.
///
/// # Errors
///
/// Returns a `syn::Error` at `sig.name_span` when the kernel is 1D and has no
/// `u32` scalar parameter to use as the element count.
///
/// # Panics
///
/// Panics (via `rust_type_tokens`) if a slice element type, or a parameter
/// type handled by the scalar arm, has no scalar Rust mapping.
pub fn generate_launch_fn(sig: &KernelSignature) -> syn::Result<TokenStream> {
    let kernel_name = &sig.name;
    let is_2d = sig.config.block_size_y.is_some();

    let mut launch_params: Vec<TokenStream> = Vec::new();
    let mut arg_calls: Vec<TokenStream> = Vec::new();
    let mut elem_count_ident: Option<Ident> = None;

    // The device handle is always the first parameter of `launch`.
    launch_params.push(quote! { device: &kaio::runtime::KaioDevice });

    for param in &sig.params {
        let name = format_ident!("{}", param.name);
        match &param.ty {
            KernelType::SliceRef(elem) => {
                let rust_ty = rust_type_tokens(elem);
                launch_params.push(quote! { #name: &kaio::runtime::GpuBuffer<#rust_ty> });
                arg_calls.push(quote! { .arg(#name.inner()) });
            }
            KernelType::SliceMutRef(elem) => {
                let rust_ty = rust_type_tokens(elem);
                launch_params.push(quote! { #name: &mut kaio::runtime::GpuBuffer<#rust_ty> });
                arg_calls.push(quote! { .arg(#name.inner_mut()) });
            }
            scalar_ty => {
                let rust_ty = rust_type_tokens(scalar_ty);
                launch_params.push(quote! { #name: #rust_ty });
                arg_calls.push(quote! { .arg(&#name) });
                // 1D only: remember this `u32` as the element count. Each later
                // `u32` overwrites the previous one, so the last one wins.
                if !is_2d && *scalar_ty == KernelType::U32 {
                    elem_count_ident = Some(name);
                }
            }
        }
    }

    let launch_config_expr = match sig.config.block_size_y {
        // 2D kernel: the caller supplies the grid explicitly.
        Some(by) => {
            let bx = sig.config.block_size;
            launch_params.push(quote! { grid: (u32, u32, u32) });
            quote! {
                kaio::runtime::LaunchConfig {
                    grid_dim: grid,
                    block_dim: (#bx, #by, 1),
                    shared_mem_bytes: 0,
                }
            }
        }
        // 1D kernel: derive the grid from the element-count parameter.
        None => {
            let n_ident = elem_count_ident.ok_or_else(|| {
                syn::Error::new(
                    sig.name_span,
                    "GPU kernel must have at least one `u32` parameter for element count \
                     (used to compute grid size). For 2D kernels, use \
                     `block_size = (X, Y)` which accepts a `grid: (u32, u32, u32)` parameter.",
                )
            })?;
            let bs = sig.config.block_size;
            quote! {
                kaio::runtime::LaunchConfig {
                    grid_dim: (#n_ident.div_ceil(#bs), 1, 1),
                    block_dim: (#bs, 1, 1),
                    shared_mem_bytes: 0,
                }
            }
        }
    };

    // NOTE(review): the generated `unsafe` block presumably relies on the pushed
    // arguments matching the kernel's PTX signature in order and type — confirm
    // against the runtime's `launch_builder` contract.
    Ok(quote! {
        pub fn launch(#(#launch_params),*) -> Result<(), kaio::runtime::KaioError> {
            use kaio::runtime::PushKernelArg;
            let ptx = PTX_CACHE.get_or_init(build_ptx);
            let module = device.load_ptx(ptx)?;
            let func = module.function(#kernel_name)?;
            let cfg = #launch_config_expr;
            unsafe {
                device
                    .stream()
                    .launch_builder(func.inner())
                    #(#arg_calls)*
                    .launch(cfg)?;
            }
            Ok(())
        }
    })
}
/// Maps a scalar [`KernelType`] to the tokens of the matching Rust primitive type.
///
/// # Panics
///
/// Panics for any variant without an explicit arm below; the message indicates
/// this is expected to happen only for slice types.
fn rust_type_tokens(ty: &KernelType) -> TokenStream {
    match ty {
        // Floating point.
        KernelType::F32 => quote!(f32),
        KernelType::F64 => quote!(f64),
        // 32-bit integers.
        KernelType::I32 => quote!(i32),
        KernelType::U32 => quote!(u32),
        // 64-bit integers.
        KernelType::I64 => quote!(i64),
        KernelType::U64 => quote!(u64),
        // Boolean.
        KernelType::Bool => quote!(bool),
        // Anything else has no scalar Rust equivalent.
        _ => panic!("rust_type_tokens called on slice type"),
    }
}