Skip to main content

cubecl_spirv/
branch.rs

1use cubecl_core::ir as core;
2use cubecl_opt::{ControlFlow, NodeIndex};
3use rspirv::{
4    dr::Operand,
5    spirv::{LoopControl, SelectionControl, Word},
6};
7
8use crate::{SpirvCompiler, SpirvTarget, item::Item, variable::Variable};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11    pub fn compile_read_bound(
12        &mut self,
13        arr: &Variable,
14        index: Word,
15        item: Item,
16        read: impl FnOnce(&mut Self) -> Word,
17    ) -> Word {
18        let ty = item.id(self);
19        let len = match arr.has_buffer_len() {
20            true => self.buffer_length(arr, None, false),
21            false => self.length(arr, None, false),
22        };
23        let bool = self.type_bool();
24        let cond = self.u_less_than(bool, None, index, len).unwrap();
25
26        let current_block = self.current_block.unwrap();
27
28        let current_label = self.end_label(current_block);
29        let in_bounds = self.id();
30        let next = self.id();
31
32        self.selection_merge(next, SelectionControl::DONT_FLATTEN)
33            .unwrap();
34        self.branch_conditional(cond, in_bounds, next, vec![1, 0])
35            .unwrap();
36
37        self.begin_block(Some(in_bounds)).unwrap();
38        let value = read(self);
39        self.branch(next).unwrap();
40
41        let fallback_value = item.constant(self, 0u32.into());
42
43        self.state.end_labels.insert(current_block, next);
44
45        self.begin_block(Some(next)).unwrap();
46        self.phi(
47            ty,
48            None,
49            vec![(value, in_bounds), (fallback_value, current_label)],
50        )
51        .unwrap()
52    }
53
54    pub fn compile_write_bound(
55        &mut self,
56        arr: &Variable,
57        index: Word,
58        write: impl FnOnce(&mut Self),
59    ) {
60        let len = match arr.has_buffer_len() {
61            true => self.buffer_length(arr, None, false),
62            false => self.length(arr, None, false),
63        };
64        let bool = self.type_bool();
65        let cond = self.u_less_than(bool, None, index, len).unwrap();
66        let current_block = self.current_block.unwrap();
67
68        let in_bounds = self.id();
69        let next = self.id();
70
71        self.selection_merge(next, SelectionControl::DONT_FLATTEN)
72            .unwrap();
73        self.branch_conditional(cond, in_bounds, next, vec![1, 0])
74            .unwrap();
75
76        self.begin_block(Some(in_bounds)).unwrap();
77        write(self);
78        self.branch(next).unwrap();
79
80        self.begin_block(Some(next)).unwrap();
81        self.state.end_labels.insert(current_block, next);
82    }
83
84    pub fn compile_copy_bound(
85        &mut self,
86        input: &Variable,
87        out: &Variable,
88        in_index: Word,
89        out_index: Word,
90        len: Option<u32>,
91        copy: impl FnOnce(&mut Self),
92    ) {
93        let in_len = match input.has_buffer_len() {
94            true => self.buffer_length(input, None, false),
95            false => self.length(input, None, false),
96        };
97        let out_len = match out.has_buffer_len() {
98            true => self.buffer_length(out, None, false),
99            false => self.length(out, None, false),
100        };
101
102        let bool = self.type_bool();
103        let int = self.type_int(32, 0);
104        let in_index = match len {
105            Some(len) => self.i_add(int, None, in_index, len).unwrap(),
106            None => in_index,
107        };
108        let out_index = match len {
109            Some(len) => self.i_add(int, None, out_index, len).unwrap(),
110            None => out_index,
111        };
112        let cond_in = self.u_less_than(bool, None, in_index, in_len).unwrap();
113        let cond_out = self.u_less_than(bool, None, out_index, out_len).unwrap();
114        let cond = self.logical_and(bool, None, cond_in, cond_out).unwrap();
115
116        let current_block = self.current_block.unwrap();
117
118        let in_bounds = self.id();
119        let next = self.id();
120
121        self.selection_merge(next, SelectionControl::DONT_FLATTEN)
122            .unwrap();
123        self.branch_conditional(cond, in_bounds, next, vec![1, 0])
124            .unwrap();
125
126        self.begin_block(Some(in_bounds)).unwrap();
127        copy(self);
128        self.branch(next).unwrap();
129
130        self.begin_block(Some(next)).unwrap();
131        self.state.end_labels.insert(current_block, next);
132    }
133
134    pub fn compile_control_flow(&mut self, control_flow: ControlFlow) {
135        match control_flow {
136            ControlFlow::IfElse {
137                cond,
138                then,
139                or_else,
140                merge,
141            } => self.compile_if_else(cond, then, or_else, merge),
142            ControlFlow::Switch {
143                value,
144                default,
145                branches,
146                merge,
147            } => self.compile_switch(value, default, branches, merge),
148            ControlFlow::Loop {
149                body,
150                continue_target,
151                merge,
152            } => self.compile_loop(body, continue_target, merge),
153            ControlFlow::LoopBreak {
154                break_cond,
155                body,
156                continue_target,
157                merge,
158            } => self.compile_loop_break(break_cond, body, continue_target, merge),
159            ControlFlow::Return => {
160                self.ret().unwrap();
161                self.current_block = None;
162            }
163            ControlFlow::Unreachable => {
164                self.unreachable().unwrap();
165                self.current_block = None;
166            }
167            ControlFlow::None => {
168                let opt = self.opt.clone();
169                let children = opt.successors(self.current_block.unwrap());
170                assert_eq!(
171                    children.len(),
172                    1,
173                    "None control flow should have only 1 outgoing edge"
174                );
175                let label = self.label(children[0]);
176                self.branch(label).unwrap();
177                self.compile_block(children[0]);
178            }
179        }
180    }
181
182    fn compile_if_else(
183        &mut self,
184        cond: core::Variable,
185        then: NodeIndex,
186        or_else: NodeIndex,
187        merge: Option<NodeIndex>,
188    ) {
189        let cond = self.compile_variable(cond);
190        let then_label = self.label(then);
191        let else_label = self.label(or_else);
192        let cond_id = self.read(&cond);
193
194        if let Some(merge) = merge {
195            let merge_label = self.label(merge);
196            self.selection_merge(merge_label, SelectionControl::NONE)
197                .unwrap();
198        }
199        self.branch_conditional(cond_id, then_label, else_label, None)
200            .unwrap();
201        self.compile_block(then);
202        self.compile_block(or_else);
203        if let Some(it) = merge {
204            self.compile_block(it);
205        }
206    }
207
208    fn compile_switch(
209        &mut self,
210        value: core::Variable,
211        default: NodeIndex,
212        branches: Vec<(u32, NodeIndex)>,
213        merge: Option<NodeIndex>,
214    ) {
215        let value = self.compile_variable(value);
216        let value_id = self.read(&value);
217
218        let default_label = self.label(default);
219        let targets = branches
220            .iter()
221            .map(|(value, block)| {
222                let label = self.label(*block);
223                (Operand::LiteralBit32(*value), label)
224            })
225            .collect::<Vec<_>>();
226
227        if let Some(merge) = merge {
228            let merge_label = self.label(merge);
229            self.selection_merge(merge_label, SelectionControl::NONE)
230                .unwrap();
231        }
232
233        self.switch(value_id, default_label, targets).unwrap();
234        self.compile_block(default);
235        for (_, block) in branches {
236            self.compile_block(block);
237        }
238        if let Some(it) = merge {
239            self.compile_block(it);
240        }
241    }
242
243    fn compile_loop(&mut self, body: NodeIndex, continue_target: NodeIndex, merge: NodeIndex) {
244        let body_label = self.label(body);
245        let continue_label = self.label(continue_target);
246        let merge_label = self.label(merge);
247
248        self.loop_merge(merge_label, continue_label, LoopControl::NONE, vec![])
249            .unwrap();
250        self.branch(body_label).unwrap();
251        self.compile_block(body);
252        self.compile_block(continue_target);
253        self.compile_block(merge);
254    }
255
256    fn compile_loop_break(
257        &mut self,
258        break_cond: core::Variable,
259        body: NodeIndex,
260        continue_target: NodeIndex,
261        merge: NodeIndex,
262    ) {
263        let break_cond = self.compile_variable(break_cond);
264        let cond_id = self.read(&break_cond);
265        let body_label = self.label(body);
266        let continue_label = self.label(continue_target);
267        let merge_label = self.label(merge);
268
269        self.loop_merge(merge_label, continue_label, LoopControl::NONE, [])
270            .unwrap();
271        self.branch_conditional(cond_id, body_label, merge_label, [])
272            .unwrap();
273        self.compile_block(body);
274        self.compile_block(continue_target);
275        self.compile_block(merge);
276    }
277}