use proc_macro2::TokenStream;
use quote::quote;
use crate::kernel_ir::KernelSignature;
use crate::kernel_ir::stmt::KernelStmt;
use crate::lower;
use crate::lower::LoweringContext;
pub fn generate_build_ptx(sig: &KernelSignature, body: &[KernelStmt]) -> syn::Result<TokenStream> {
let kernel_name = &sig.name;
let mut ctx = LoweringContext::new();
ctx.block_size = Some(sig.config.block_size);
let param_tokens = lower::params::lower_params(&mut ctx, &sig.params)?;
let body_tokens = lower::lower_stmts(&mut ctx, body)?;
let total_shared_bytes: u32 = ctx
.shared_arrays
.values()
.map(|(ty, count)| (ty.size_bytes() * count) as u32)
.sum::<u32>()
+ if ctx.reduce_smem_allocated {
(ctx.block_size.unwrap_or(256) / 32) * 4
} else {
0
};
let shared_mem_diagnostic = if total_shared_bytes > 0 {
let kb = total_shared_bytes as f64 / 1024.0;
if total_shared_bytes > 49152 {
quote! {
eprintln!("KAIO warning: kernel '{}' uses {} bytes ({:.1} KB) of shared memory — exceeds 48 KB default limit",
#kernel_name, #total_shared_bytes, #kb);
}
} else {
quote! {}
}
} else {
quote! {}
};
Ok(quote! {
fn build_ptx() -> String {
use kaio::core::emit::{Emit, PtxWriter};
use kaio::core::instr::ArithOp;
use kaio::core::instr::control::{CmpOp, ControlOp};
use kaio::core::instr::memory::MemoryOp;
use kaio::core::instr::special;
use kaio::core::ir::{
Operand, PtxInstruction, PtxKernel, PtxModule, PtxParam, RegisterAllocator,
SharedDecl,
};
use kaio::core::types::PtxType;
let mut alloc = RegisterAllocator::new();
let mut kernel = PtxKernel::new(#kernel_name);
#param_tokens
#body_tokens
kernel.push(PtxInstruction::Control(ControlOp::Ret));
kernel.set_registers(alloc.into_allocated());
let sm_target = std::env::var("KAIO_SM_TARGET")
.unwrap_or_else(|_| "sm_70".to_string());
let mut module = PtxModule::new(&sm_target);
module.add_kernel(kernel);
let mut w = PtxWriter::new();
module.emit(&mut w).unwrap();
let ptx = w.finish();
#shared_mem_diagnostic
if std::env::var("KAIO_DUMP_PTX").is_ok() {
eprintln!("=== KAIO PTX: {} ===\n{}", #kernel_name, ptx);
}
ptx
}
})
}