kaio_core/ir/param.rs
//! PTX kernel parameter types, as declared in a kernel's `.entry` signature.
2
3use crate::types::PtxType;
4
5/// A parameter to a PTX kernel function.
6///
7/// In PTX, kernel parameters are declared in the `.entry` signature:
8/// - Scalar: `.param .u32 n` — a value passed by copy
9/// - Pointer: `.param .u64 a_ptr` — a device memory address (always 64-bit)
10///
11/// For pointer params, `elem_type` records the type of data pointed to
12/// (used by the code generator for load/store instruction types).
13#[derive(Debug, Clone)]
14pub enum PtxParam {
15 /// A scalar value (e.g. `.param .u32 n`).
16 Scalar {
17 /// Parameter name.
18 name: String,
19 /// PTX type of the scalar value.
20 ptx_type: PtxType,
21 },
22 /// A pointer to device memory (declared as `.param .u64 name`).
23 Pointer {
24 /// Parameter name.
25 name: String,
26 /// Type of the data pointed to (for codegen, not the param declaration).
27 elem_type: PtxType,
28 },
29}
30
31impl PtxParam {
32 /// Create a scalar parameter.
33 pub fn scalar(name: &str, ptx_type: PtxType) -> Self {
34 Self::Scalar {
35 name: name.to_string(),
36 ptx_type,
37 }
38 }
39
40 /// Create a pointer parameter.
41 pub fn pointer(name: &str, elem_type: PtxType) -> Self {
42 Self::Pointer {
43 name: name.to_string(),
44 elem_type,
45 }
46 }
47
48 /// The parameter name.
49 pub fn name(&self) -> &str {
50 match self {
51 Self::Scalar { name, .. } | Self::Pointer { name, .. } => name,
52 }
53 }
54
55 /// PTX parameter declaration string (without trailing comma).
56 ///
57 /// - Scalar: `.param .u32 n`
58 /// - Pointer: `.param .u64 a_ptr` (always 64-bit address)
59 pub fn ptx_decl(&self) -> String {
60 match self {
61 Self::Scalar { name, ptx_type } => {
62 format!(".param {} {}", ptx_type.ptx_suffix(), name)
63 }
64 Self::Pointer { name, .. } => {
65 format!(".param .u64 {}", name)
66 }
67 }
68 }
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scalar_param() {
        let p = PtxParam::scalar("n", PtxType::U32);
        assert_eq!(p.name(), "n");
        assert!(matches!(
            p,
            PtxParam::Scalar {
                ptx_type: PtxType::U32,
                ..
            }
        ));
    }

    #[test]
    fn pointer_param() {
        let p = PtxParam::pointer("a_ptr", PtxType::F32);
        assert_eq!(p.name(), "a_ptr");
        assert!(matches!(
            p,
            PtxParam::Pointer {
                elem_type: PtxType::F32,
                ..
            }
        ));
    }

    #[test]
    fn scalar_ptx_decl() {
        let p = PtxParam::scalar("n", PtxType::U32);
        assert_eq!(p.ptx_decl(), ".param .u32 n");
    }

    #[test]
    fn pointer_ptx_decl() {
        // Pointers are always declared as .u64 regardless of elem_type, so
        // check more than one element type to actually pin that property.
        for elem_type in [PtxType::F32, PtxType::U32] {
            let p = PtxParam::pointer("a_ptr", elem_type);
            assert_eq!(p.ptx_decl(), ".param .u64 a_ptr");
        }
    }

    #[test]
    fn clone_preserves_variant_and_name() {
        let original = PtxParam::pointer("b_ptr", PtxType::U32);
        let copy = original.clone();
        assert_eq!(copy.name(), "b_ptr");
        assert_eq!(copy.ptx_decl(), original.ptx_decl());
        assert!(matches!(
            copy,
            PtxParam::Pointer {
                elem_type: PtxType::U32,
                ..
            }
        ));
    }
}
113}