use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::Ident;
use crate::kernel_ir::{KernelSignature, KernelType};
/// Generates the host-side `launch` wrapper for a kernel signature.
///
/// The emitted function loads a PTX module compiled for the device's compute
/// capability, binds each kernel parameter, and launches with a grid derived
/// either from an element count (1D) or a caller-supplied `grid` tuple (2D).
///
/// # Errors
///
/// Returns a `syn::Error` spanned at the kernel name when a 1D kernel has no
/// `u32` parameter to derive the grid size from.
pub fn generate_launch_fn(sig: &KernelSignature) -> syn::Result<TokenStream> {
    let kernel_name = &sig.name;
    // A `block_size_y` marks the kernel as 2D: the caller then supplies the
    // grid explicitly instead of having it derived from an element count.
    let is_2d = sig.config.block_size_y.is_some();

    let mut launch_params: Vec<TokenStream> = Vec::new();
    let mut arg_calls: Vec<TokenStream> = Vec::new();
    // For 1D kernels: the `u32` parameter used to compute the grid size.
    // NOTE(review): when several `u32` params exist, the *last* one wins
    // (each match overwrites the previous) — TODO confirm this is intended
    // rather than first-wins.
    let mut elem_count_ident: Option<Ident> = None;

    // Every wrapper takes the device handle first.
    launch_params.push(quote! { device: &kaio::runtime::KaioDevice });

    for param in &sig.params {
        let name = format_ident!("{}", param.name);
        match &param.ty {
            // Read-only buffers are passed by shared reference.
            KernelType::SliceRef(elem) => {
                let rust_ty = rust_type_tokens(elem);
                launch_params.push(quote! { #name: &kaio::runtime::GpuBuffer<#rust_ty> });
                arg_calls.push(quote! { .arg(#name.inner()) });
            }
            // Writable buffers require an exclusive reference.
            KernelType::SliceMutRef(elem) => {
                let rust_ty = rust_type_tokens(elem);
                launch_params.push(quote! { #name: &mut kaio::runtime::GpuBuffer<#rust_ty> });
                arg_calls.push(quote! { .arg(#name.inner_mut()) });
            }
            // Scalars are passed by value and bound by reference at launch.
            scalar_ty => {
                let rust_ty = rust_type_tokens(scalar_ty);
                launch_params.push(quote! { #name: #rust_ty });
                arg_calls.push(quote! { .arg(&#name) });
                if !is_2d && *scalar_ty == KernelType::U32 {
                    elem_count_ident = Some(name.clone());
                }
            }
        }
    }

    let launch_config_expr = if is_2d {
        let bx = sig.config.block_size;
        let by = sig.config.block_size_y.unwrap();
        // 2D kernels take the grid dimensions as an extra trailing parameter.
        launch_params.push(quote! { grid: (u32, u32, u32) });
        quote! {
            kaio::runtime::LaunchConfig {
                grid_dim: grid,
                block_dim: (#bx, #by, 1),
                shared_mem_bytes: 0,
            }
        }
    } else {
        let n_ident = elem_count_ident.ok_or_else(|| {
            syn::Error::new(
                sig.name_span,
                "GPU kernel must have at least one `u32` parameter for element count \
                 (used to compute grid size). For 2D kernels, use \
                 `block_size = (X, Y)` which accepts a `grid: (u32, u32, u32)` parameter.",
            )
        })?;
        let bs = sig.config.block_size;
        // Grid size rounds the element count up to a whole number of blocks.
        quote! {
            kaio::runtime::LaunchConfig {
                grid_dim: (#n_ident.div_ceil(#bs), 1, 1),
                block_dim: (#bs, 1, 1),
                shared_mem_bytes: 0,
            }
        }
    };

    Ok(quote! {
        pub fn launch(#(#launch_params),*) -> Result<(), kaio::runtime::KaioError> {
            use kaio::runtime::PushKernelArg;
            // Compile/load PTX targeted at this device's compute capability.
            let info = device.info()?;
            let (major, minor) = info.compute_capability;
            let sm = format!("sm_{major}{minor}");
            let ptx_module = build_module(&sm);
            let module = device.load_module(&ptx_module)?;
            let func = module.function(#kernel_name)?;
            let cfg = #launch_config_expr;
            // SAFETY: argument count/types are generated to match the kernel
            // signature, so the raw launch is sound by construction.
            unsafe {
                device
                    .stream()
                    .launch_builder(func.inner())
                    #(#arg_calls)*
                    .launch(cfg)?;
            }
            Ok(())
        }
    })
}
fn rust_type_tokens(ty: &KernelType) -> TokenStream {
match ty {
KernelType::F32 => quote! { f32 },
KernelType::F64 => quote! { f64 },
KernelType::I32 => quote! { i32 },
KernelType::U32 => quote! { u32 },
KernelType::I64 => quote! { i64 },
KernelType::U64 => quote! { u64 },
KernelType::Bool => quote! { bool },
_ => panic!("rust_type_tokens called on slice type"),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel_ir::{KernelConfig, KernelParam, KernelSignature, KernelType};
    use proc_macro2::Span;

    /// Builds a `&mut [f32]` buffer parameter for mock signatures.
    fn mock_param_slice_mut_f32(name: &str) -> KernelParam {
        KernelParam {
            span: Span::call_site(),
            ty: KernelType::SliceMutRef(Box::new(KernelType::F32)),
            name: name.into(),
        }
    }

    /// Builds a scalar `u32` parameter (serves as the 1D element count).
    fn mock_param_u32(name: &str) -> KernelParam {
        KernelParam {
            span: Span::call_site(),
            ty: KernelType::U32,
            name: name.into(),
        }
    }

    /// A 1D signature: one output buffer plus a `u32` element count.
    fn mock_signature_1d(name: &str, block_size: u32) -> KernelSignature {
        let config = KernelConfig {
            block_size,
            block_size_y: None,
            block_size_span: Span::call_site(),
        };
        KernelSignature {
            name: name.into(),
            params: vec![mock_param_slice_mut_f32("out"), mock_param_u32("n")],
            config,
            name_span: Span::call_site(),
        }
    }

    /// A 2D signature: output buffer only; the grid comes from the caller.
    fn mock_signature_2d(name: &str, block_size_x: u32, block_size_y: u32) -> KernelSignature {
        let config = KernelConfig {
            block_size: block_size_x,
            block_size_y: Some(block_size_y),
            block_size_span: Span::call_site(),
        };
        KernelSignature {
            name: name.into(),
            params: vec![mock_param_slice_mut_f32("out")],
            config,
            name_span: Span::call_site(),
        }
    }

    #[test]
    fn launch_wrapper_emits_correct_block_dim_1d() {
        let sig = mock_signature_1d("test_kernel_1d", 256);
        let tokens = generate_launch_fn(&sig).expect("codegen should succeed for mock signature");
        let output = tokens.to_string();
        assert!(
            output.contains("block_dim : (256u32 , 1 , 1)"),
            "expected block_dim (256u32, 1, 1) for block_size=256, got:\n{output}"
        );
    }

    #[test]
    fn launch_wrapper_emits_correct_block_dim_2d() {
        let sig = mock_signature_2d("test_kernel_2d", 16, 8);
        let tokens = generate_launch_fn(&sig).expect("codegen should succeed for mock signature");
        let output = tokens.to_string();
        assert!(
            output.contains("block_dim : (16u32 , 8u32 , 1)"),
            "expected block_dim (16u32, 8u32, 1) for block_size=(16, 8), got:\n{output}"
        );
    }

    #[test]
    fn launch_wrapper_threads_compute_capability_into_module_build() {
        let sig = mock_signature_1d("test_kernel_sm", 256);
        let tokens = generate_launch_fn(&sig).expect("codegen should succeed for mock signature");
        let output = tokens.to_string();
        // The wrapper must query the device and thread `sm_XY` into the build.
        assert!(
            output.contains("compute_capability"),
            "expected launch wrapper to read device.info().compute_capability, got:\n{output}"
        );
        assert!(
            output.contains("load_module"),
            "expected launch wrapper to call device.load_module, got:\n{output}"
        );
        // Guard against regressing to the pre-migration API.
        assert!(
            !output.contains("load_ptx"),
            "launch wrapper should not call device.load_ptx after D1a migration, got:\n{output}"
        );
    }
}