Skip to main content

oxicuda_ptx/builder/
kernel_builder.rs

//! Kernel-level PTX builder.
//!
//! [`KernelBuilder`] is the top-level entry point for constructing a complete
//! PTX module containing a single kernel entry point. It collects kernel
//! metadata (name, target architecture, parameters, shared memory declarations)
//! and delegates instruction generation to a closure that drives a
//! [`BodyBuilder`].
7
8use std::fmt::Write;
9
10use crate::arch::SmVersion;
11use crate::error::PtxGenError;
12use crate::ir::{Instruction, PtxType, RegisterAllocator};
13
14use super::body_builder::BodyBuilder;
15
16/// Type alias for the body closure to reduce type complexity.
17type BodyFn = Box<dyn FnOnce(&mut BodyBuilder<'_>)>;
18
19/// Builder for constructing complete PTX kernel modules.
20///
21/// `KernelBuilder` follows the fluent builder pattern: chain configuration
22/// methods, supply a body closure, and call [`build`] to produce the final
23/// PTX text.
24///
25/// # Example
26///
27/// ```
28/// use oxicuda_ptx::builder::KernelBuilder;
29/// use oxicuda_ptx::arch::SmVersion;
30/// use oxicuda_ptx::ir::PtxType;
31///
32/// let ptx = KernelBuilder::new("vector_add")
33///     .target(SmVersion::Sm80)
34///     .param("a", PtxType::U64)
35///     .param("b", PtxType::U64)
36///     .param("c", PtxType::U64)
37///     .param("n", PtxType::U32)
38///     .body(|b| {
39///         let tid = b.global_thread_id_x();
40///         let n_reg = b.load_param_u32("n");
41///         b.if_lt_u32(tid, n_reg, |b| {
42///             b.comment("kernel body goes here");
43///         });
44///         b.ret();
45///     })
46///     .build()
47///     .expect("PTX generation failed");
48///
49/// assert!(ptx.contains(".entry vector_add"));
50/// assert!(ptx.contains(".target sm_80"));
51/// ```
52///
53/// [`build`]: KernelBuilder::build
54pub struct KernelBuilder {
55    /// Kernel function name.
56    name: String,
57    /// Target GPU architecture.
58    target: SmVersion,
59    /// Kernel parameters as (name, type) pairs.
60    params: Vec<(String, PtxType)>,
61    /// Body closure that populates instructions via `BodyBuilder`.
62    body_fn: Option<BodyFn>,
63    /// Static shared memory declarations: (name, `element_type`, `element_count`).
64    shared_mem_declarations: Vec<(String, PtxType, usize)>,
65    /// Optional `.maxntid` directive (maximum threads per block).
66    max_threads: Option<u32>,
67}
68
69impl KernelBuilder {
70    /// Creates a new kernel builder with the given kernel name.
71    ///
72    /// The default target is [`SmVersion::Sm80`] (Ampere). Call [`target`]
73    /// to override.
74    ///
75    /// [`target`]: KernelBuilder::target
76    #[must_use]
77    pub fn new(name: &str) -> Self {
78        Self {
79            name: name.to_string(),
80            target: SmVersion::Sm80,
81            params: Vec::new(),
82            body_fn: None,
83            shared_mem_declarations: Vec::new(),
84            max_threads: None,
85        }
86    }
87
88    /// Sets the target GPU architecture for this kernel.
89    ///
90    /// This determines the `.target` and `.version` directives in the
91    /// generated PTX, and also controls which instructions the
92    /// [`BodyBuilder`] may emit.
93    #[must_use]
94    pub const fn target(mut self, sm: SmVersion) -> Self {
95        self.target = sm;
96        self
97    }
98
99    /// Adds a kernel parameter with the given name and type.
100    ///
101    /// Parameters are emitted in declaration order in the `.entry` signature.
102    /// Common types: `PtxType::U64` for pointers, `PtxType::U32` / `PtxType::F32`
103    /// for scalar arguments.
104    #[must_use]
105    pub fn param(mut self, name: &str, ty: PtxType) -> Self {
106        self.params.push((name.to_string(), ty));
107        self
108    }
109
110    /// Declares a static shared memory allocation.
111    ///
112    /// This generates a `.shared .align` declaration at the top of the
113    /// kernel body. The total size is `count * ty.size_bytes()` bytes.
114    #[must_use]
115    pub fn shared_mem(mut self, name: &str, ty: PtxType, count: usize) -> Self {
116        self.shared_mem_declarations
117            .push((name.to_string(), ty, count));
118        self
119    }
120
121    /// Sets the `.maxntid` directive, hinting to `ptxas` the maximum
122    /// number of threads per block this kernel will be launched with.
123    ///
124    /// This can improve register allocation and occupancy planning.
125    #[must_use]
126    pub const fn max_threads_per_block(mut self, n: u32) -> Self {
127        self.max_threads = Some(n);
128        self
129    }
130
131    /// Supplies the body closure that generates the kernel's instructions.
132    ///
133    /// The closure receives a mutable reference to a [`BodyBuilder`] which
134    /// provides the instruction emission API (loads, stores, arithmetic,
135    /// control flow, tensor core ops, etc.).
136    #[must_use]
137    pub fn body<F>(mut self, f: F) -> Self
138    where
139        F: FnOnce(&mut BodyBuilder<'_>) + 'static,
140    {
141        self.body_fn = Some(Box::new(f));
142        self
143    }
144
145    /// Consumes the builder and generates the complete PTX module text.
146    ///
147    /// # Errors
148    ///
149    /// Returns [`PtxGenError::MissingBody`] if no body closure was provided.
150    /// Returns [`PtxGenError::FormatError`] if string formatting fails.
151    pub fn build(self) -> Result<String, PtxGenError> {
152        let body_fn = self.body_fn.ok_or(PtxGenError::MissingBody)?;
153
154        // Phase 1: Execute the body closure to collect instructions.
155        let mut regs = RegisterAllocator::new();
156        let mut instructions: Vec<Instruction> = Vec::new();
157        {
158            let param_names: Vec<String> = self.params.iter().map(|(n, _)| n.clone()).collect();
159            let mut bb = BodyBuilder::new(&mut regs, &mut instructions, &param_names, self.target);
160            body_fn(&mut bb);
161        }
162
163        // Phase 2: Generate PTX text.
164        let mut ptx = String::with_capacity(4096);
165
166        // Header directives.
167        writeln!(ptx, ".version {}", self.target.ptx_version())?;
168        writeln!(ptx, ".target {}", self.target.as_ptx_str())?;
169        writeln!(ptx, ".address_size 64")?;
170        writeln!(ptx)?;
171
172        // Kernel entry point.
173        write!(ptx, ".visible .entry {}(", self.name)?;
174        for (i, (pname, pty)) in self.params.iter().enumerate() {
175            if i > 0 {
176                write!(ptx, ",")?;
177            }
178            writeln!(ptx)?;
179            write!(ptx, "    {} {}", param_type_str(*pty), param_ident(pname))?;
180        }
181        writeln!(ptx)?;
182        writeln!(ptx, ")")?;
183        writeln!(ptx, "{{")?;
184
185        // .maxntid directive.
186        if let Some(n) = self.max_threads {
187            writeln!(ptx, "    .maxntid {n}, 1, 1;")?;
188        }
189
190        // Register declarations.
191        let reg_decls = regs.emit_declarations();
192        for decl in &reg_decls {
193            writeln!(ptx, "    {decl}")?;
194        }
195
196        // Shared memory declarations.
197        for (sname, sty, count) in &self.shared_mem_declarations {
198            let align = sty.size_bytes().max(4);
199            let total_bytes = sty.size_bytes() * count;
200            writeln!(
201                ptx,
202                "    .shared .align {align} .b8 {sname}[{total_bytes}];"
203            )?;
204        }
205
206        if !reg_decls.is_empty() || !self.shared_mem_declarations.is_empty() {
207            writeln!(ptx)?;
208        }
209
210        // Instructions.
211        for inst in &instructions {
212            emit_instruction(&mut ptx, inst)?;
213        }
214
215        writeln!(ptx, "}}")?;
216
217        Ok(ptx)
218    }
219}
220
221/// Returns the PTX `.param` type annotation for a kernel parameter type.
222const fn param_type_str(ty: PtxType) -> &'static str {
223    match ty {
224        PtxType::U8 => ".param .u8",
225        PtxType::U16 => ".param .u16",
226        PtxType::U32 | PtxType::Pred => ".param .u32",
227        PtxType::U64 => ".param .u64",
228        PtxType::S8 => ".param .s8",
229        PtxType::S16 => ".param .s16",
230        PtxType::S32 => ".param .s32",
231        PtxType::S64 => ".param .s64",
232        PtxType::F16 => ".param .f16",
233        PtxType::BF16 | PtxType::B16 | PtxType::E2M3 | PtxType::E3M2 => ".param .b16",
234        PtxType::F32 => ".param .f32",
235        PtxType::F64 => ".param .f64",
236        PtxType::B8 | PtxType::E4M3 | PtxType::E5M2 | PtxType::E2M1 => ".param .b8",
237        PtxType::B32 | PtxType::F16x2 | PtxType::BF16x2 | PtxType::TF32 => ".param .b32",
238        PtxType::B64 => ".param .b64",
239        PtxType::B128 => ".param .b128",
240    }
241}
242
/// Builds the identifier under which parameter `name` is declared in the
/// `.entry` signature: the fixed prefix `%param_` followed by `name`.
fn param_ident(name: &str) -> String {
    ["%param_", name].concat()
}
247
248/// Emits a single PTX instruction as text, appending to `out`.
249///
250/// Each instruction is indented by 4 spaces (labels have no indentation).
251fn emit_instruction(out: &mut String, inst: &Instruction) -> Result<(), std::fmt::Error> {
252    let text = inst.emit();
253    match inst {
254        Instruction::Label(_) => writeln!(out, "{text}"),
255        _ => writeln!(out, "    {text}"),
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// A one-parameter kernel with a bare `ret` produces the expected header,
    /// entry signature and instruction text.
    #[test]
    fn build_minimal_kernel() {
        let ptx = KernelBuilder::new("test_kernel")
            .target(SmVersion::Sm80)
            .param("n", PtxType::U32)
            .body(|b| {
                b.ret();
            })
            .build()
            .expect("build should succeed");

        let expected_fragments = [
            ".version 7.0",
            ".target sm_80",
            ".address_size 64",
            ".entry test_kernel",
            ".param .u32 %param_n",
            "ret;",
        ];
        for fragment in expected_fragments {
            assert!(ptx.contains(fragment));
        }
    }

    /// Building without a body closure fails with `MissingBody`.
    #[test]
    fn build_missing_body() {
        let outcome = KernelBuilder::new("no_body")
            .target(SmVersion::Sm75)
            .build();

        assert!(outcome.is_err());
        let failure = outcome.expect_err("should be MissingBody");
        assert!(matches!(failure, PtxGenError::MissingBody));
    }

    /// 1024 `F32` elements are declared as a 4096-byte, 4-aligned array.
    #[test]
    fn build_with_shared_mem() {
        let builder = KernelBuilder::new("smem_kernel")
            .target(SmVersion::Sm80)
            .shared_mem("tile_a", PtxType::F32, 1024)
            .body(|b| {
                b.ret();
            });

        let ptx = builder.build().expect("build should succeed");

        assert!(ptx.contains(".shared .align 4 .b8 tile_a[4096];"));
    }

    /// A thread limit shows up as a `.maxntid` line with y and z set to 1.
    #[test]
    fn build_with_max_threads() {
        let builder = KernelBuilder::new("bounded_kernel")
            .target(SmVersion::Sm80)
            .max_threads_per_block(256)
            .body(|b| {
                b.ret();
            });

        let ptx = builder.build().expect("build should succeed");

        assert!(ptx.contains(".maxntid 256, 1, 1;"));
    }

    /// Spot-checks the type-to-annotation table.
    #[test]
    fn param_type_str_coverage() {
        let cases = [
            (PtxType::U32, ".param .u32"),
            (PtxType::U64, ".param .u64"),
            (PtxType::F32, ".param .f32"),
            (PtxType::F64, ".param .f64"),
            (PtxType::S32, ".param .s32"),
            (PtxType::B32, ".param .b32"),
            (PtxType::B128, ".param .b128"),
        ];
        for (ty, annotation) in cases {
            assert_eq!(param_type_str(ty), annotation);
        }
    }
}
331}