1use cubecl_core::ir as core;
2use cubecl_opt::{ControlFlow, NodeIndex};
3use rspirv::{
4 dr::Operand,
5 spirv::{LoopControl, SelectionControl, Word},
6};
7
8use crate::{SpirvCompiler, SpirvTarget, item::Item, variable::Variable};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11 pub fn compile_read_bound(
12 &mut self,
13 arr: &Variable,
14 index: Word,
15 item: Item,
16 read: impl FnOnce(&mut Self) -> Word,
17 ) -> Word {
18 let ty = item.id(self);
19 let len = match arr.has_buffer_len() {
20 true => self.buffer_length(arr, None, false),
21 false => self.length(arr, None, false),
22 };
23 let bool = self.type_bool();
24 let cond = self.u_less_than(bool, None, index, len).unwrap();
25
26 let current_block = self.current_block.unwrap();
27
28 let current_label = self.end_label(current_block);
29 let in_bounds = self.id();
30 let next = self.id();
31
32 self.selection_merge(next, SelectionControl::DONT_FLATTEN)
33 .unwrap();
34 self.branch_conditional(cond, in_bounds, next, vec![1, 0])
35 .unwrap();
36
37 self.begin_block(Some(in_bounds)).unwrap();
38 let value = read(self);
39 self.branch(next).unwrap();
40
41 let fallback_value = item.constant(self, 0u32.into());
42
43 self.state.end_labels.insert(current_block, next);
44
45 self.begin_block(Some(next)).unwrap();
46 self.phi(
47 ty,
48 None,
49 vec![(value, in_bounds), (fallback_value, current_label)],
50 )
51 .unwrap()
52 }
53
54 pub fn compile_write_bound(
55 &mut self,
56 arr: &Variable,
57 index: Word,
58 write: impl FnOnce(&mut Self),
59 ) {
60 let len = match arr.has_buffer_len() {
61 true => self.buffer_length(arr, None, false),
62 false => self.length(arr, None, false),
63 };
64 let bool = self.type_bool();
65 let cond = self.u_less_than(bool, None, index, len).unwrap();
66 let current_block = self.current_block.unwrap();
67
68 let in_bounds = self.id();
69 let next = self.id();
70
71 self.selection_merge(next, SelectionControl::DONT_FLATTEN)
72 .unwrap();
73 self.branch_conditional(cond, in_bounds, next, vec![1, 0])
74 .unwrap();
75
76 self.begin_block(Some(in_bounds)).unwrap();
77 write(self);
78 self.branch(next).unwrap();
79
80 self.begin_block(Some(next)).unwrap();
81 self.state.end_labels.insert(current_block, next);
82 }
83
84 pub fn compile_copy_bound(
85 &mut self,
86 input: &Variable,
87 out: &Variable,
88 in_index: Word,
89 out_index: Word,
90 len: Option<u32>,
91 copy: impl FnOnce(&mut Self),
92 ) {
93 let in_len = match input.has_buffer_len() {
94 true => self.buffer_length(input, None, false),
95 false => self.length(input, None, false),
96 };
97 let out_len = match out.has_buffer_len() {
98 true => self.buffer_length(out, None, false),
99 false => self.length(out, None, false),
100 };
101
102 let bool = self.type_bool();
103 let int = self.type_int(32, 0);
104 let in_index = match len {
105 Some(len) => self.i_add(int, None, in_index, len).unwrap(),
106 None => in_index,
107 };
108 let out_index = match len {
109 Some(len) => self.i_add(int, None, out_index, len).unwrap(),
110 None => out_index,
111 };
112 let cond_in = self.u_less_than(bool, None, in_index, in_len).unwrap();
113 let cond_out = self.u_less_than(bool, None, out_index, out_len).unwrap();
114 let cond = self.logical_and(bool, None, cond_in, cond_out).unwrap();
115
116 let current_block = self.current_block.unwrap();
117
118 let in_bounds = self.id();
119 let next = self.id();
120
121 self.selection_merge(next, SelectionControl::DONT_FLATTEN)
122 .unwrap();
123 self.branch_conditional(cond, in_bounds, next, vec![1, 0])
124 .unwrap();
125
126 self.begin_block(Some(in_bounds)).unwrap();
127 copy(self);
128 self.branch(next).unwrap();
129
130 self.begin_block(Some(next)).unwrap();
131 self.state.end_labels.insert(current_block, next);
132 }
133
134 pub fn compile_control_flow(&mut self, control_flow: ControlFlow) {
135 match control_flow {
136 ControlFlow::IfElse {
137 cond,
138 then,
139 or_else,
140 merge,
141 } => self.compile_if_else(cond, then, or_else, merge),
142 ControlFlow::Switch {
143 value,
144 default,
145 branches,
146 merge,
147 } => self.compile_switch(value, default, branches, merge),
148 ControlFlow::Loop {
149 body,
150 continue_target,
151 merge,
152 } => self.compile_loop(body, continue_target, merge),
153 ControlFlow::LoopBreak {
154 break_cond,
155 body,
156 continue_target,
157 merge,
158 } => self.compile_loop_break(break_cond, body, continue_target, merge),
159 ControlFlow::Return => {
160 self.ret().unwrap();
161 self.current_block = None;
162 }
163 ControlFlow::None => {
164 let opt = self.opt.clone();
165 let children = opt.successors(self.current_block.unwrap());
166 assert_eq!(
167 children.len(),
168 1,
169 "None control flow should have only 1 outgoing edge"
170 );
171 let label = self.label(children[0]);
172 self.branch(label).unwrap();
173 self.compile_block(children[0]);
174 }
175 }
176 }
177
178 fn compile_if_else(
179 &mut self,
180 cond: core::Variable,
181 then: NodeIndex,
182 or_else: NodeIndex,
183 merge: Option<NodeIndex>,
184 ) {
185 let cond = self.compile_variable(cond);
186 let then_label = self.label(then);
187 let else_label = self.label(or_else);
188 let cond_id = self.read(&cond);
189
190 if let Some(merge) = merge {
191 let merge_label = self.label(merge);
192 self.selection_merge(merge_label, SelectionControl::NONE)
193 .unwrap();
194 }
195 self.branch_conditional(cond_id, then_label, else_label, None)
196 .unwrap();
197 self.compile_block(then);
198 self.compile_block(or_else);
199 if let Some(it) = merge {
200 self.compile_block(it);
201 }
202 }
203
204 fn compile_switch(
205 &mut self,
206 value: core::Variable,
207 default: NodeIndex,
208 branches: Vec<(u32, NodeIndex)>,
209 merge: Option<NodeIndex>,
210 ) {
211 let value = self.compile_variable(value);
212 let value_id = self.read(&value);
213
214 let default_label = self.label(default);
215 let targets = branches
216 .iter()
217 .map(|(value, block)| {
218 let label = self.label(*block);
219 (Operand::LiteralBit32(*value), label)
220 })
221 .collect::<Vec<_>>();
222
223 if let Some(merge) = merge {
224 let merge_label = self.label(merge);
225 self.selection_merge(merge_label, SelectionControl::NONE)
226 .unwrap();
227 }
228
229 self.switch(value_id, default_label, targets).unwrap();
230 self.compile_block(default);
231 for (_, block) in branches {
232 self.compile_block(block);
233 }
234 if let Some(it) = merge {
235 self.compile_block(it);
236 }
237 }
238
239 fn compile_loop(&mut self, body: NodeIndex, continue_target: NodeIndex, merge: NodeIndex) {
240 let body_label = self.label(body);
241 let continue_label = self.label(continue_target);
242 let merge_label = self.label(merge);
243
244 self.loop_merge(merge_label, continue_label, LoopControl::NONE, vec![])
245 .unwrap();
246 self.branch(body_label).unwrap();
247 self.compile_block(body);
248 self.compile_block(continue_target);
249 self.compile_block(merge);
250 }
251
252 fn compile_loop_break(
253 &mut self,
254 break_cond: core::Variable,
255 body: NodeIndex,
256 continue_target: NodeIndex,
257 merge: NodeIndex,
258 ) {
259 let break_cond = self.compile_variable(break_cond);
260 let cond_id = self.read(&break_cond);
261 let body_label = self.label(body);
262 let continue_label = self.label(continue_target);
263 let merge_label = self.label(merge);
264
265 self.loop_merge(merge_label, continue_label, LoopControl::NONE, [])
266 .unwrap();
267 self.branch_conditional(cond_id, body_label, merge_label, [])
268 .unwrap();
269 self.compile_block(body);
270 self.compile_block(continue_target);
271 self.compile_block(merge);
272 }
273}