Skip to main content

wave_compiler/hir/
kernel.rs

// Copyright 2026 Ojima Abraham
// SPDX-License-Identifier: Apache-2.0

//! HIR kernel definition for WAVE GPU kernels.
//!
//! A kernel is the top-level compilation unit, containing parameters,
//! body statements, and optional attributes that control compilation.
8
9use super::stmt::Stmt;
10use super::types::{AddressSpace, Type};
11
12/// A kernel parameter with name, type, and address space.
13#[derive(Debug, Clone, PartialEq)]
14pub struct KernelParam {
15    /// Parameter name.
16    pub name: String,
17    /// Parameter type.
18    pub ty: Type,
19    /// Address space for pointer parameters.
20    pub address_space: AddressSpace,
21}
22
/// Attributes controlling kernel compilation behavior.
///
/// Every attribute is optional; `KernelAttributes::default()` leaves all of
/// them as `None`.
///
/// Because every field is an `Option` of plain integers, the type derives
/// `Eq` and `Hash` in addition to `PartialEq`, so a set of attributes can be
/// used as a key in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct KernelAttributes {
    /// Requested workgroup size as `[x, y, z]`, or `None` when not specified.
    pub workgroup_size: Option<[u32; 3]>,
    /// Maximum number of registers to use, or `None` when not specified.
    pub max_registers: Option<u32>,
}
31
32/// A GPU kernel definition.
33#[derive(Debug, Clone, PartialEq)]
34pub struct Kernel {
35    /// Kernel name.
36    pub name: String,
37    /// Kernel parameters.
38    pub params: Vec<KernelParam>,
39    /// Kernel body.
40    pub body: Vec<Stmt>,
41    /// Kernel attributes.
42    pub attributes: KernelAttributes,
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::expr::{BinOp, Dimension, Expr, Literal};

    /// Builds a `Type::Ptr(AddressSpace::Device)` parameter tagged with
    /// `AddressSpace::Device`.
    fn device_ptr(name: &str) -> KernelParam {
        KernelParam {
            name: name.to_owned(),
            ty: Type::Ptr(AddressSpace::Device),
            address_space: AddressSpace::Device,
        }
    }

    /// Builds a `Type::U32` parameter tagged with `AddressSpace::Private`.
    fn private_u32(name: &str) -> KernelParam {
        KernelParam {
            name: name.to_owned(),
            ty: Type::U32,
            address_space: AddressSpace::Private,
        }
    }

    /// Builds `target = value` as a `Stmt::Assign`.
    fn assign(target: &'static str, value: Expr) -> Stmt {
        Stmt::Assign {
            target: target.into(),
            value,
        }
    }

    /// Builds a reference to the variable called `name`.
    fn var(name: &'static str) -> Expr {
        Expr::Var(name.into())
    }

    #[test]
    fn test_kernel_construction() {
        let params = vec![
            device_ptr("a"),
            device_ptr("b"),
            device_ptr("out"),
            private_u32("n"),
        ];
        let body = vec![assign("gid", Expr::ThreadId(Dimension::X))];

        let kernel = Kernel {
            name: "vector_add".into(),
            params,
            body,
            attributes: KernelAttributes::default(),
        };

        assert_eq!(kernel.name, "vector_add");
        assert_eq!(kernel.params.len(), 4);
        assert_eq!(kernel.body.len(), 1);
    }

    #[test]
    fn test_kernel_with_if_body() {
        // `gid < n` guard around the single assignment in the `then` branch.
        let in_bounds = Expr::BinOp {
            op: BinOp::Lt,
            lhs: Box::new(var("gid")),
            rhs: Box::new(var("n")),
        };
        let guarded = Stmt::If {
            condition: in_bounds,
            then_body: vec![assign("x", Expr::Literal(Literal::Int(1)))],
            else_body: None,
        };
        let attributes = KernelAttributes {
            workgroup_size: Some([256, 1, 1]),
            max_registers: Some(32),
        };

        let kernel = Kernel {
            name: "guarded_add".into(),
            params: vec![private_u32("n")],
            body: vec![assign("gid", Expr::ThreadId(Dimension::X)), guarded],
            attributes,
        };

        assert_eq!(kernel.body.len(), 2);
        assert_eq!(kernel.attributes.workgroup_size, Some([256, 1, 1]));
        assert_eq!(kernel.attributes.max_registers, Some(32));
    }
}