oxicuda_ptx/ir/function.rs
1//! PTX kernel and device function definitions.
2//!
3//! A [`PtxFunction`] represents a complete PTX function (`.entry` kernel or
4//! `.func` device function) including its parameters, body instructions,
5//! shared memory declarations, and optional launch bounds.
6
7use super::instruction::Instruction;
8use super::types::PtxType;
9
10/// A PTX kernel or device function definition.
11///
12/// This structure holds all the information needed to emit a complete PTX
13/// function: the function signature (name and typed parameters), the instruction
14/// body, any shared memory allocations, and optional performance hints.
15///
16/// # Examples
17///
18/// ```
19/// use oxicuda_ptx::ir::{PtxFunction, PtxType};
20///
21/// let func = PtxFunction {
22/// name: "vector_add".to_string(),
23/// params: vec![
24/// ("a_ptr".to_string(), PtxType::U64),
25/// ("b_ptr".to_string(), PtxType::U64),
26/// ("c_ptr".to_string(), PtxType::U64),
27/// ("n".to_string(), PtxType::U32),
28/// ],
29/// body: Vec::new(),
30/// shared_mem: Vec::new(),
31/// max_threads: Some(256),
32/// };
33/// assert_eq!(func.params.len(), 4);
34/// ```
35#[derive(Debug, Clone)]
36pub struct PtxFunction {
37 /// The function name (without leading underscore — emitter adds `$` prefix if needed).
38 pub name: String,
39 /// Kernel parameters as `(name, type)` pairs.
40 pub params: Vec<(String, PtxType)>,
41 /// The instruction body of the function.
42 pub body: Vec<Instruction>,
43 /// Static shared memory declarations as `(name, element_type, num_elements)`.
44 pub shared_mem: Vec<(String, PtxType, usize)>,
45 /// Optional `.maxnthreads` directive (launch bounds hint to `ptxas`).
46 pub max_threads: Option<u32>,
47}
48
49impl PtxFunction {
50 /// Creates a new empty function with the given name.
51 #[must_use]
52 pub fn new(name: impl Into<String>) -> Self {
53 Self {
54 name: name.into(),
55 params: Vec::new(),
56 body: Vec::new(),
57 shared_mem: Vec::new(),
58 max_threads: None,
59 }
60 }
61
62 /// Adds a parameter to the function signature.
63 pub fn add_param(&mut self, name: impl Into<String>, ty: PtxType) {
64 self.params.push((name.into(), ty));
65 }
66
67 /// Adds a static shared memory allocation.
68 pub fn add_shared_mem(&mut self, name: impl Into<String>, ty: PtxType, count: usize) {
69 self.shared_mem.push((name.into(), ty, count));
70 }
71
72 /// Appends an instruction to the function body.
73 pub fn push(&mut self, inst: Instruction) {
74 self.body.push(inst);
75 }
76}