Skip to main content

oxicuda_ptx/ir/
function.rs

1//! PTX kernel and device function definitions.
2//!
3//! A [`PtxFunction`] represents a complete PTX function (`.entry` kernel or
4//! `.func` device function) including its parameters, body instructions,
5//! shared memory declarations, and optional launch bounds.
6
7use super::instruction::Instruction;
8use super::types::PtxType;
9
10/// A PTX kernel or device function definition.
11///
12/// This structure holds all the information needed to emit a complete PTX
13/// function: the function signature (name and typed parameters), the instruction
14/// body, any shared memory allocations, and optional performance hints.
15///
16/// # Examples
17///
18/// ```
19/// use oxicuda_ptx::ir::{PtxFunction, PtxType};
20///
21/// let func = PtxFunction {
22///     name: "vector_add".to_string(),
23///     params: vec![
24///         ("a_ptr".to_string(), PtxType::U64),
25///         ("b_ptr".to_string(), PtxType::U64),
26///         ("c_ptr".to_string(), PtxType::U64),
27///         ("n".to_string(), PtxType::U32),
28///     ],
29///     body: Vec::new(),
30///     shared_mem: Vec::new(),
31///     max_threads: Some(256),
32/// };
33/// assert_eq!(func.params.len(), 4);
34/// ```
35#[derive(Debug, Clone)]
36pub struct PtxFunction {
37    /// The function name (without leading underscore — emitter adds `$` prefix if needed).
38    pub name: String,
39    /// Kernel parameters as `(name, type)` pairs.
40    pub params: Vec<(String, PtxType)>,
41    /// The instruction body of the function.
42    pub body: Vec<Instruction>,
43    /// Static shared memory declarations as `(name, element_type, num_elements)`.
44    pub shared_mem: Vec<(String, PtxType, usize)>,
45    /// Optional `.maxnthreads` directive (launch bounds hint to `ptxas`).
46    pub max_threads: Option<u32>,
47}
48
49impl PtxFunction {
50    /// Creates a new empty function with the given name.
51    #[must_use]
52    pub fn new(name: impl Into<String>) -> Self {
53        Self {
54            name: name.into(),
55            params: Vec::new(),
56            body: Vec::new(),
57            shared_mem: Vec::new(),
58            max_threads: None,
59        }
60    }
61
62    /// Adds a parameter to the function signature.
63    pub fn add_param(&mut self, name: impl Into<String>, ty: PtxType) {
64        self.params.push((name.into(), ty));
65    }
66
67    /// Adds a static shared memory allocation.
68    pub fn add_shared_mem(&mut self, name: impl Into<String>, ty: PtxType, count: usize) {
69        self.shared_mem.push((name.into(), ty, count));
70    }
71
72    /// Appends an instruction to the function body.
73    pub fn push(&mut self, inst: Instruction) {
74        self.body.push(inst);
75    }
76}