Skip to main content

kaio_core/ir/
param.rs

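//! PTX kernel parameter types.
//!
//! Defines [`PtxParam`], the IR representation of one parameter in a PTX
//! kernel's `.entry` signature.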
2
3use crate::types::PtxType;
4
5/// A parameter to a PTX kernel function.
6///
7/// In PTX, kernel parameters are declared in the `.entry` signature:
8/// - Scalar: `.param .u32 n` — a value passed by copy
9/// - Pointer: `.param .u64 a_ptr` — a device memory address (always 64-bit)
10///
11/// For pointer params, `elem_type` records the type of data pointed to
12/// (used by the code generator for load/store instruction types).
13#[derive(Debug, Clone)]
14pub enum PtxParam {
15    /// A scalar value (e.g. `.param .u32 n`).
16    Scalar {
17        /// Parameter name.
18        name: String,
19        /// PTX type of the scalar value.
20        ptx_type: PtxType,
21    },
22    /// A pointer to device memory (declared as `.param .u64 name`).
23    Pointer {
24        /// Parameter name.
25        name: String,
26        /// Type of the data pointed to (for codegen, not the param declaration).
27        elem_type: PtxType,
28    },
29}
30
31impl PtxParam {
32    /// Create a scalar parameter.
33    pub fn scalar(name: &str, ptx_type: PtxType) -> Self {
34        Self::Scalar {
35            name: name.to_string(),
36            ptx_type,
37        }
38    }
39
40    /// Create a pointer parameter.
41    pub fn pointer(name: &str, elem_type: PtxType) -> Self {
42        Self::Pointer {
43            name: name.to_string(),
44            elem_type,
45        }
46    }
47
48    /// The parameter name.
49    pub fn name(&self) -> &str {
50        match self {
51            Self::Scalar { name, .. } | Self::Pointer { name, .. } => name,
52        }
53    }
54
55    /// PTX parameter declaration string (without trailing comma).
56    ///
57    /// - Scalar: `.param .u32 n`
58    /// - Pointer: `.param .u64 a_ptr` (always 64-bit address)
59    pub fn ptx_decl(&self) -> String {
60        match self {
61            Self::Scalar { name, ptx_type } => {
62                format!(".param {} {}", ptx_type.ptx_suffix(), name)
63            }
64            Self::Pointer { name, .. } => {
65                format!(".param .u64 {}", name)
66            }
67        }
68    }
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scalar_param() {
        let param = PtxParam::scalar("n", PtxType::U32);
        assert_eq!("n", param.name());
        let is_u32_scalar = matches!(
            param,
            PtxParam::Scalar {
                ptx_type: PtxType::U32,
                ..
            }
        );
        assert!(is_u32_scalar);
    }

    #[test]
    fn pointer_param() {
        let param = PtxParam::pointer("a_ptr", PtxType::F32);
        assert_eq!("a_ptr", param.name());
        let is_f32_pointer = matches!(
            param,
            PtxParam::Pointer {
                elem_type: PtxType::F32,
                ..
            }
        );
        assert!(is_f32_pointer);
    }

    #[test]
    fn scalar_ptx_decl() {
        let decl = PtxParam::scalar("n", PtxType::U32).ptx_decl();
        assert_eq!(".param .u32 n", decl);
    }

    #[test]
    fn pointer_ptx_decl() {
        // The declared type is `.u64` no matter which elem_type the pointer carries.
        let decl = PtxParam::pointer("a_ptr", PtxType::F32).ptx_decl();
        assert_eq!(".param .u64 a_ptr", decl);
    }
}