1use cubecl_core::ir as core;
2use cubecl_opt::{ControlFlow, NodeIndex};
3use rspirv::{
4 dr::Operand,
5 spirv::{LoopControl, SelectionControl, Word},
6};
7
8use crate::{SpirvCompiler, SpirvTarget, item::Item, variable::Variable};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11 pub fn compile_read_bound(
12 &mut self,
13 arr: &Variable,
14 index: Word,
15 item: Item,
16 read: impl FnOnce(&mut Self) -> Word,
17 ) -> Word {
18 let ty = item.id(self);
19 let len = match arr.has_buffer_len() {
20 true => self.buffer_length(arr, None, false),
21 false => self.length(arr, None, false),
22 };
23 let bool = self.type_bool();
24 let cond = self.u_less_than(bool, None, index, len).unwrap();
25
26 let current_block = self.current_block.unwrap();
27
28 let current_label = self.end_label(current_block);
29 let in_bounds = self.id();
30 let next = self.id();
31
32 self.selection_merge(next, SelectionControl::DONT_FLATTEN)
33 .unwrap();
34 self.branch_conditional(cond, in_bounds, next, vec![1, 0])
35 .unwrap();
36
37 self.begin_block(Some(in_bounds)).unwrap();
38 let value = read(self);
39 self.branch(next).unwrap();
40
41 let fallback_value = item.constant(self, 0u32.into());
42
43 self.state.end_labels.insert(current_block, next);
44
45 self.begin_block(Some(next)).unwrap();
46 self.phi(
47 ty,
48 None,
49 vec![(value, in_bounds), (fallback_value, current_label)],
50 )
51 .unwrap()
52 }
53
54 pub fn compile_write_bound(
55 &mut self,
56 arr: &Variable,
57 index: Word,
58 write: impl FnOnce(&mut Self),
59 ) {
60 let len = match arr.has_buffer_len() {
61 true => self.buffer_length(arr, None, false),
62 false => self.length(arr, None, false),
63 };
64 let bool = self.type_bool();
65 let cond = self.u_less_than(bool, None, index, len).unwrap();
66 let current_block = self.current_block.unwrap();
67
68 let in_bounds = self.id();
69 let next = self.id();
70
71 self.selection_merge(next, SelectionControl::DONT_FLATTEN)
72 .unwrap();
73 self.branch_conditional(cond, in_bounds, next, vec![1, 0])
74 .unwrap();
75
76 self.begin_block(Some(in_bounds)).unwrap();
77 write(self);
78 self.branch(next).unwrap();
79
80 self.begin_block(Some(next)).unwrap();
81 self.state.end_labels.insert(current_block, next);
82 }
83
84 pub fn compile_copy_bound(
85 &mut self,
86 input: &Variable,
87 out: &Variable,
88 in_index: Word,
89 out_index: Word,
90 len: Option<u32>,
91 copy: impl FnOnce(&mut Self),
92 ) {
93 let in_len = match input.has_buffer_len() {
94 true => self.buffer_length(input, None, false),
95 false => self.length(input, None, false),
96 };
97 let out_len = match out.has_buffer_len() {
98 true => self.buffer_length(out, None, false),
99 false => self.length(out, None, false),
100 };
101
102 let bool = self.type_bool();
103 let int = self.type_int(32, 0);
104 let in_index = match len {
105 Some(len) => self.i_add(int, None, in_index, len).unwrap(),
106 None => in_index,
107 };
108 let out_index = match len {
109 Some(len) => self.i_add(int, None, out_index, len).unwrap(),
110 None => out_index,
111 };
112 let cond_in = self.u_less_than(bool, None, in_index, in_len).unwrap();
113 let cond_out = self.u_less_than(bool, None, out_index, out_len).unwrap();
114 let cond = self.logical_and(bool, None, cond_in, cond_out).unwrap();
115
116 let current_block = self.current_block.unwrap();
117
118 let in_bounds = self.id();
119 let next = self.id();
120
121 self.selection_merge(next, SelectionControl::DONT_FLATTEN)
122 .unwrap();
123 self.branch_conditional(cond, in_bounds, next, vec![1, 0])
124 .unwrap();
125
126 self.begin_block(Some(in_bounds)).unwrap();
127 copy(self);
128 self.branch(next).unwrap();
129
130 self.begin_block(Some(next)).unwrap();
131 self.state.end_labels.insert(current_block, next);
132 }
133
134 pub fn compile_control_flow(&mut self, control_flow: ControlFlow) {
135 match control_flow {
136 ControlFlow::IfElse {
137 cond,
138 then,
139 or_else,
140 merge,
141 } => self.compile_if_else(cond, then, or_else, merge),
142 ControlFlow::Switch {
143 value,
144 default,
145 branches,
146 merge,
147 } => self.compile_switch(value, default, branches, merge),
148 ControlFlow::Loop {
149 body,
150 continue_target,
151 merge,
152 } => self.compile_loop(body, continue_target, merge),
153 ControlFlow::LoopBreak {
154 break_cond,
155 body,
156 continue_target,
157 merge,
158 } => self.compile_loop_break(break_cond, body, continue_target, merge),
159 ControlFlow::Return => {
160 self.ret().unwrap();
161 self.current_block = None;
162 }
163 ControlFlow::Unreachable => {
164 self.unreachable().unwrap();
165 self.current_block = None;
166 }
167 ControlFlow::None => {
168 let opt = self.opt.clone();
169 let children = opt.successors(self.current_block.unwrap());
170 assert_eq!(
171 children.len(),
172 1,
173 "None control flow should have only 1 outgoing edge"
174 );
175 let label = self.label(children[0]);
176 self.branch(label).unwrap();
177 self.compile_block(children[0]);
178 }
179 }
180 }
181
182 fn compile_if_else(
183 &mut self,
184 cond: core::Variable,
185 then: NodeIndex,
186 or_else: NodeIndex,
187 merge: Option<NodeIndex>,
188 ) {
189 let cond = self.compile_variable(cond);
190 let then_label = self.label(then);
191 let else_label = self.label(or_else);
192 let cond_id = self.read(&cond);
193
194 if let Some(merge) = merge {
195 let merge_label = self.label(merge);
196 self.selection_merge(merge_label, SelectionControl::NONE)
197 .unwrap();
198 }
199 self.branch_conditional(cond_id, then_label, else_label, None)
200 .unwrap();
201 self.compile_block(then);
202 self.compile_block(or_else);
203 if let Some(it) = merge {
204 self.compile_block(it);
205 }
206 }
207
208 fn compile_switch(
209 &mut self,
210 value: core::Variable,
211 default: NodeIndex,
212 branches: Vec<(u32, NodeIndex)>,
213 merge: Option<NodeIndex>,
214 ) {
215 let value = self.compile_variable(value);
216 let value_id = self.read(&value);
217
218 let default_label = self.label(default);
219 let targets = branches
220 .iter()
221 .map(|(value, block)| {
222 let label = self.label(*block);
223 (Operand::LiteralBit32(*value), label)
224 })
225 .collect::<Vec<_>>();
226
227 if let Some(merge) = merge {
228 let merge_label = self.label(merge);
229 self.selection_merge(merge_label, SelectionControl::NONE)
230 .unwrap();
231 }
232
233 self.switch(value_id, default_label, targets).unwrap();
234 self.compile_block(default);
235 for (_, block) in branches {
236 self.compile_block(block);
237 }
238 if let Some(it) = merge {
239 self.compile_block(it);
240 }
241 }
242
243 fn compile_loop(&mut self, body: NodeIndex, continue_target: NodeIndex, merge: NodeIndex) {
244 let body_label = self.label(body);
245 let continue_label = self.label(continue_target);
246 let merge_label = self.label(merge);
247
248 self.loop_merge(merge_label, continue_label, LoopControl::NONE, vec![])
249 .unwrap();
250 self.branch(body_label).unwrap();
251 self.compile_block(body);
252 self.compile_block(continue_target);
253 self.compile_block(merge);
254 }
255
256 fn compile_loop_break(
257 &mut self,
258 break_cond: core::Variable,
259 body: NodeIndex,
260 continue_target: NodeIndex,
261 merge: NodeIndex,
262 ) {
263 let break_cond = self.compile_variable(break_cond);
264 let cond_id = self.read(&break_cond);
265 let body_label = self.label(body);
266 let continue_label = self.label(continue_target);
267 let merge_label = self.label(merge);
268
269 self.loop_merge(merge_label, continue_label, LoopControl::NONE, [])
270 .unwrap();
271 self.branch_conditional(cond_id, body_label, merge_label, [])
272 .unwrap();
273 self.compile_block(body);
274 self.compile_block(continue_target);
275 self.compile_block(merge);
276 }
277}