use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::Ident;
use crate::kernel_ir::{KernelSignature, KernelType};
pub fn generate_launch_fn(sig: &KernelSignature) -> syn::Result<TokenStream> {
let kernel_name = &sig.name;
let mut launch_params: Vec<TokenStream> = Vec::new();
let mut arg_calls: Vec<TokenStream> = Vec::new();
let mut elem_count_ident: Option<Ident> = None;
launch_params.push(quote! { device: &kaio::runtime::KaioDevice });
for param in &sig.params {
let name = format_ident!("{}", param.name);
match ¶m.ty {
KernelType::SliceRef(elem) => {
let rust_ty = rust_type_tokens(elem);
launch_params.push(quote! { #name: &kaio::runtime::GpuBuffer<#rust_ty> });
arg_calls.push(quote! { .arg(#name.inner()) });
}
KernelType::SliceMutRef(elem) => {
let rust_ty = rust_type_tokens(elem);
launch_params.push(quote! { #name: &mut kaio::runtime::GpuBuffer<#rust_ty> });
arg_calls.push(quote! { .arg(#name.inner_mut()) });
}
scalar_ty => {
let rust_ty = rust_type_tokens(scalar_ty);
launch_params.push(quote! { #name: #rust_ty });
arg_calls.push(quote! { .arg(&#name) });
if *scalar_ty == KernelType::U32 {
elem_count_ident = Some(name.clone());
}
}
}
}
let n_ident = elem_count_ident.ok_or_else(|| {
syn::Error::new(
sig.name_span,
"GPU kernel must have at least one `u32` parameter for element count \
(used by LaunchConfig::for_num_elems)",
)
})?;
Ok(quote! {
pub fn launch(#(#launch_params),*) -> Result<(), kaio::runtime::KaioError> {
use kaio::runtime::PushKernelArg;
let ptx = PTX_CACHE.get_or_init(build_ptx);
let module = device.load_ptx(ptx)?;
let func = module.function(#kernel_name)?;
let cfg = kaio::runtime::LaunchConfig::for_num_elems(#n_ident);
unsafe {
device
.stream()
.launch_builder(func.inner())
#(#arg_calls)*
.launch(cfg)?;
}
Ok(())
}
})
}
fn rust_type_tokens(ty: &KernelType) -> TokenStream {
match ty {
KernelType::F32 => quote! { f32 },
KernelType::F64 => quote! { f64 },
KernelType::I32 => quote! { i32 },
KernelType::U32 => quote! { u32 },
KernelType::I64 => quote! { i64 },
KernelType::U64 => quote! { u64 },
KernelType::Bool => quote! { bool },
_ => panic!("rust_type_tokens called on slice type"),
}
}