//! HIR kernel definitions (`wave_compiler/hir/kernel.rs`).
use super::stmt::Stmt;
use super::types::{AddressSpace, Type};

12#[derive(Debug, Clone, PartialEq)]
14pub struct KernelParam {
15 pub name: String,
17 pub ty: Type,
19 pub address_space: AddressSpace,
21}
22
/// Optional attributes attached to a kernel.
///
/// `Default` leaves every attribute unset (`None`). Every field is a plain integer option,
/// so the type is also `Copy`, `Eq` and `Hash`. It can be passed by value and used as a
/// map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct KernelAttributes {
    /// Requested workgroup dimensions as `[x, y, z]`, or `None` if unspecified.
    pub workgroup_size: Option<[u32; 3]>,
    /// Requested register limit, or `None` for no explicit limit.
    /// NOTE(review): presumably a per-thread limit; confirm with the backend.
    pub max_registers: Option<u32>,
}

32#[derive(Debug, Clone, PartialEq)]
34pub struct Kernel {
35 pub name: String,
37 pub params: Vec<KernelParam>,
39 pub body: Vec<Stmt>,
41 pub attributes: KernelAttributes,
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::expr::{BinOp, Dimension, Expr, Literal};

    /// Builds a parameter of type `Ptr(Device)` in the `Device` address space.
    fn device_ptr(name: &str) -> KernelParam {
        KernelParam {
            name: name.into(),
            ty: Type::Ptr(AddressSpace::Device),
            address_space: AddressSpace::Device,
        }
    }

    /// Builds a `u32` parameter in the `Private` address space.
    fn private_u32(name: &str) -> KernelParam {
        KernelParam {
            name: name.into(),
            ty: Type::U32,
            address_space: AddressSpace::Private,
        }
    }

    #[test]
    fn test_kernel_construction() {
        let kernel = Kernel {
            name: "vector_add".into(),
            params: vec![
                device_ptr("a"),
                device_ptr("b"),
                device_ptr("out"),
                private_u32("n"),
            ],
            body: vec![Stmt::Assign {
                target: "gid".into(),
                value: Expr::ThreadId(Dimension::X),
            }],
            attributes: KernelAttributes::default(),
        };
        assert_eq!(kernel.name, "vector_add");
        assert_eq!(kernel.params.len(), 4);
        // Parameters must stay in exactly the order they were declared.
        let names: Vec<&str> = kernel.params.iter().map(|p| p.name.as_str()).collect();
        assert_eq!(names, ["a", "b", "out", "n"]);
        assert_eq!(kernel.params[3].ty, Type::U32);
        assert_eq!(kernel.params[3].address_space, AddressSpace::Private);
        assert_eq!(kernel.body.len(), 1);
        assert_eq!(kernel.attributes, KernelAttributes::default());
    }

    #[test]
    fn test_kernel_with_if_body() {
        let kernel = Kernel {
            name: "guarded_add".into(),
            params: vec![private_u32("n")],
            body: vec![
                Stmt::Assign {
                    target: "gid".into(),
                    value: Expr::ThreadId(Dimension::X),
                },
                Stmt::If {
                    condition: Expr::BinOp {
                        op: BinOp::Lt,
                        lhs: Box::new(Expr::Var("gid".into())),
                        rhs: Box::new(Expr::Var("n".into())),
                    },
                    then_body: vec![Stmt::Assign {
                        target: "x".into(),
                        value: Expr::Literal(Literal::Int(1)),
                    }],
                    else_body: None,
                },
            ],
            attributes: KernelAttributes {
                workgroup_size: Some([256, 1, 1]),
                max_registers: Some(32),
            },
        };
        assert_eq!(kernel.body.len(), 2);
        // The guard must survive as an `If` with a single-statement then-branch and no else.
        assert!(matches!(
            &kernel.body[1],
            Stmt::If { then_body, else_body: None, .. } if then_body.len() == 1
        ));
        assert_eq!(kernel.attributes.workgroup_size, Some([256, 1, 1]));
        assert_eq!(kernel.attributes.max_registers, Some(32));
    }
}