1use std::fmt::Write;
9
10use crate::arch::SmVersion;
11use crate::error::PtxGenError;
12use crate::ir::{Instruction, PtxType, RegisterAllocator};
13
14use super::body_builder::BodyBuilder;
15
16type BodyFn = Box<dyn FnOnce(&mut BodyBuilder<'_>)>;
18
19pub struct KernelBuilder {
55 name: String,
57 target: SmVersion,
59 params: Vec<(String, PtxType)>,
61 body_fn: Option<BodyFn>,
63 shared_mem_declarations: Vec<(String, PtxType, usize)>,
65 max_threads: Option<u32>,
67}
68
69impl KernelBuilder {
70 #[must_use]
77 pub fn new(name: &str) -> Self {
78 Self {
79 name: name.to_string(),
80 target: SmVersion::Sm80,
81 params: Vec::new(),
82 body_fn: None,
83 shared_mem_declarations: Vec::new(),
84 max_threads: None,
85 }
86 }
87
88 #[must_use]
94 pub const fn target(mut self, sm: SmVersion) -> Self {
95 self.target = sm;
96 self
97 }
98
99 #[must_use]
105 pub fn param(mut self, name: &str, ty: PtxType) -> Self {
106 self.params.push((name.to_string(), ty));
107 self
108 }
109
110 #[must_use]
115 pub fn shared_mem(mut self, name: &str, ty: PtxType, count: usize) -> Self {
116 self.shared_mem_declarations
117 .push((name.to_string(), ty, count));
118 self
119 }
120
121 #[must_use]
126 pub const fn max_threads_per_block(mut self, n: u32) -> Self {
127 self.max_threads = Some(n);
128 self
129 }
130
131 #[must_use]
137 pub fn body<F>(mut self, f: F) -> Self
138 where
139 F: FnOnce(&mut BodyBuilder<'_>) + 'static,
140 {
141 self.body_fn = Some(Box::new(f));
142 self
143 }
144
145 pub fn build(self) -> Result<String, PtxGenError> {
152 let body_fn = self.body_fn.ok_or(PtxGenError::MissingBody)?;
153
154 let mut regs = RegisterAllocator::new();
156 let mut instructions: Vec<Instruction> = Vec::new();
157 {
158 let param_names: Vec<String> = self.params.iter().map(|(n, _)| n.clone()).collect();
159 let mut bb = BodyBuilder::new(&mut regs, &mut instructions, ¶m_names, self.target);
160 body_fn(&mut bb);
161 }
162
163 let mut ptx = String::with_capacity(4096);
165
166 writeln!(ptx, ".version {}", self.target.ptx_version())?;
168 writeln!(ptx, ".target {}", self.target.as_ptx_str())?;
169 writeln!(ptx, ".address_size 64")?;
170 writeln!(ptx)?;
171
172 write!(ptx, ".visible .entry {}(", self.name)?;
174 for (i, (pname, pty)) in self.params.iter().enumerate() {
175 if i > 0 {
176 write!(ptx, ",")?;
177 }
178 writeln!(ptx)?;
179 write!(ptx, " {} {}", param_type_str(*pty), param_ident(pname))?;
180 }
181 writeln!(ptx)?;
182 writeln!(ptx, ")")?;
183 writeln!(ptx, "{{")?;
184
185 if let Some(n) = self.max_threads {
187 writeln!(ptx, " .maxntid {n}, 1, 1;")?;
188 }
189
190 let reg_decls = regs.emit_declarations();
192 for decl in ®_decls {
193 writeln!(ptx, " {decl}")?;
194 }
195
196 for (sname, sty, count) in &self.shared_mem_declarations {
198 let align = sty.size_bytes().max(4);
199 let total_bytes = sty.size_bytes() * count;
200 writeln!(
201 ptx,
202 " .shared .align {align} .b8 {sname}[{total_bytes}];"
203 )?;
204 }
205
206 if !reg_decls.is_empty() || !self.shared_mem_declarations.is_empty() {
207 writeln!(ptx)?;
208 }
209
210 for inst in &instructions {
212 emit_instruction(&mut ptx, inst)?;
213 }
214
215 writeln!(ptx, "}}")?;
216
217 Ok(ptx)
218 }
219}
220
221const fn param_type_str(ty: PtxType) -> &'static str {
223 match ty {
224 PtxType::U8 => ".param .u8",
225 PtxType::U16 => ".param .u16",
226 PtxType::U32 | PtxType::Pred => ".param .u32",
227 PtxType::U64 => ".param .u64",
228 PtxType::S8 => ".param .s8",
229 PtxType::S16 => ".param .s16",
230 PtxType::S32 => ".param .s32",
231 PtxType::S64 => ".param .s64",
232 PtxType::F16 => ".param .f16",
233 PtxType::BF16 | PtxType::B16 | PtxType::E2M3 | PtxType::E3M2 => ".param .b16",
234 PtxType::F32 => ".param .f32",
235 PtxType::F64 => ".param .f64",
236 PtxType::B8 | PtxType::E4M3 | PtxType::E5M2 | PtxType::E2M1 => ".param .b8",
237 PtxType::B32 | PtxType::F16x2 | PtxType::BF16x2 | PtxType::TF32 => ".param .b32",
238 PtxType::B64 => ".param .b64",
239 PtxType::B128 => ".param .b128",
240 }
241}
242
/// Returns the PTX identifier for the kernel parameter `name`, i.e. `name`
/// prefixed with `%param_`.
fn param_ident(name: &str) -> String {
    const PREFIX: &str = "%param_";

    let mut ident = String::with_capacity(PREFIX.len() + name.len());
    ident.push_str(PREFIX);
    ident.push_str(name);
    ident
}
247
248fn emit_instruction(out: &mut String, inst: &Instruction) -> Result<(), std::fmt::Error> {
252 let text = inst.emit();
253 match inst {
254 Instruction::Label(_) => writeln!(out, "{text}"),
255 _ => writeln!(out, " {text}"),
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Gives `builder` a body consisting of a single `ret` and builds it,
    /// panicking if the build fails.
    fn build_ret_only(builder: KernelBuilder) -> String {
        builder
            .body(|b| {
                b.ret();
            })
            .build()
            .expect("build should succeed")
    }

    /// A one-parameter kernel contains the header, the entry signature, the
    /// parameter declaration and the body instruction.
    #[test]
    fn build_minimal_kernel() {
        let builder = KernelBuilder::new("test_kernel")
            .target(SmVersion::Sm80)
            .param("n", PtxType::U32);
        let ptx = build_ret_only(builder);

        let expected_fragments = [
            ".version 7.0",
            ".target sm_80",
            ".address_size 64",
            ".entry test_kernel",
            ".param .u32 %param_n",
            "ret;",
        ];
        for fragment in expected_fragments {
            assert!(ptx.contains(fragment));
        }
    }

    /// Building without a body closure fails with `MissingBody`.
    #[test]
    fn build_missing_body() {
        let outcome = KernelBuilder::new("no_body")
            .target(SmVersion::Sm75)
            .build();

        assert!(matches!(outcome, Err(PtxGenError::MissingBody)));
    }

    /// A shared-memory declaration is rendered as an aligned byte array.
    #[test]
    fn build_with_shared_mem() {
        let builder = KernelBuilder::new("smem_kernel")
            .target(SmVersion::Sm80)
            .shared_mem("tile_a", PtxType::F32, 1024);
        let ptx = build_ret_only(builder);

        assert!(ptx.contains(".shared .align 4 .b8 tile_a[4096];"));
    }

    /// The thread bound is rendered as a `.maxntid` line.
    #[test]
    fn build_with_max_threads() {
        let builder = KernelBuilder::new("bounded_kernel")
            .target(SmVersion::Sm80)
            .max_threads_per_block(256);
        let ptx = build_ret_only(builder);

        assert!(ptx.contains(".maxntid 256, 1, 1;"));
    }

    /// Spot-checks the parameter type mapping.
    #[test]
    fn param_type_str_coverage() {
        let cases = [
            (PtxType::U32, ".param .u32"),
            (PtxType::U64, ".param .u64"),
            (PtxType::F32, ".param .f32"),
            (PtxType::F64, ".param .f64"),
            (PtxType::S32, ".param .s32"),
            (PtxType::B32, ".param .b32"),
            (PtxType::B128, ".param .b128"),
        ];
        for (ty, expected) in cases {
            assert_eq!(param_type_str(ty), expected);
        }
    }
}