cubecl_spirv/
branch.rs

1use cubecl_core::ir as core;
2use cubecl_opt::{ControlFlow, NodeIndex};
3use rspirv::{
4    dr::Operand,
5    spirv::{LoopControl, SelectionControl, Word},
6};
7
8use crate::{SpirvCompiler, SpirvTarget, item::Item, variable::Variable};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11    pub fn compile_read_bound(
12        &mut self,
13        arr: &Variable,
14        index: Word,
15        item: Item,
16        read: impl FnOnce(&mut Self) -> Word,
17    ) -> Word {
18        let ty = item.id(self);
19        let len = match arr.has_buffer_len() {
20            true => self.buffer_length(arr, None, false),
21            false => self.length(arr, None, false),
22        };
23        let bool = self.type_bool();
24        let cond = self.u_less_than(bool, None, index, len).unwrap();
25
26        let current_block = self.current_block.unwrap();
27
28        let current_label = self.end_label(current_block);
29        let in_bounds = self.id();
30        let next = self.id();
31
32        self.selection_merge(next, SelectionControl::DONT_FLATTEN)
33            .unwrap();
34        self.branch_conditional(cond, in_bounds, next, vec![1, 0])
35            .unwrap();
36
37        self.begin_block(Some(in_bounds)).unwrap();
38        let value = read(self);
39        self.branch(next).unwrap();
40
41        let fallback_value = item.constant(self, 0u32.into());
42
43        self.state.end_labels.insert(current_block, next);
44
45        self.begin_block(Some(next)).unwrap();
46        self.phi(
47            ty,
48            None,
49            vec![(value, in_bounds), (fallback_value, current_label)],
50        )
51        .unwrap()
52    }
53
54    pub fn compile_write_bound(
55        &mut self,
56        arr: &Variable,
57        index: Word,
58        write: impl FnOnce(&mut Self),
59    ) {
60        let len = match arr.has_buffer_len() {
61            true => self.buffer_length(arr, None, false),
62            false => self.length(arr, None, false),
63        };
64        let bool = self.type_bool();
65        let cond = self.u_less_than(bool, None, index, len).unwrap();
66        let current_block = self.current_block.unwrap();
67
68        let in_bounds = self.id();
69        let next = self.id();
70
71        self.selection_merge(next, SelectionControl::DONT_FLATTEN)
72            .unwrap();
73        self.branch_conditional(cond, in_bounds, next, vec![1, 0])
74            .unwrap();
75
76        self.begin_block(Some(in_bounds)).unwrap();
77        write(self);
78        self.branch(next).unwrap();
79
80        self.begin_block(Some(next)).unwrap();
81        self.state.end_labels.insert(current_block, next);
82    }
83
84    pub fn compile_copy_bound(
85        &mut self,
86        input: &Variable,
87        out: &Variable,
88        in_index: Word,
89        out_index: Word,
90        len: Option<u32>,
91        copy: impl FnOnce(&mut Self),
92    ) {
93        let in_len = match input.has_buffer_len() {
94            true => self.buffer_length(input, None, false),
95            false => self.length(input, None, false),
96        };
97        let out_len = match out.has_buffer_len() {
98            true => self.buffer_length(out, None, false),
99            false => self.length(out, None, false),
100        };
101
102        let bool = self.type_bool();
103        let int = self.type_int(32, 0);
104        let in_index = match len {
105            Some(len) => self.i_add(int, None, in_index, len).unwrap(),
106            None => in_index,
107        };
108        let out_index = match len {
109            Some(len) => self.i_add(int, None, out_index, len).unwrap(),
110            None => out_index,
111        };
112        let cond_in = self.u_less_than(bool, None, in_index, in_len).unwrap();
113        let cond_out = self.u_less_than(bool, None, out_index, out_len).unwrap();
114        let cond = self.logical_and(bool, None, cond_in, cond_out).unwrap();
115
116        let current_block = self.current_block.unwrap();
117
118        let in_bounds = self.id();
119        let next = self.id();
120
121        self.selection_merge(next, SelectionControl::DONT_FLATTEN)
122            .unwrap();
123        self.branch_conditional(cond, in_bounds, next, vec![1, 0])
124            .unwrap();
125
126        self.begin_block(Some(in_bounds)).unwrap();
127        copy(self);
128        self.branch(next).unwrap();
129
130        self.begin_block(Some(next)).unwrap();
131        self.state.end_labels.insert(current_block, next);
132    }
133
134    pub fn compile_control_flow(&mut self, control_flow: ControlFlow) {
135        match control_flow {
136            ControlFlow::IfElse {
137                cond,
138                then,
139                or_else,
140                merge,
141            } => self.compile_if_else(cond, then, or_else, merge),
142            ControlFlow::Switch {
143                value,
144                default,
145                branches,
146                merge,
147            } => self.compile_switch(value, default, branches, merge),
148            ControlFlow::Loop {
149                body,
150                continue_target,
151                merge,
152            } => self.compile_loop(body, continue_target, merge),
153            ControlFlow::LoopBreak {
154                break_cond,
155                body,
156                continue_target,
157                merge,
158            } => self.compile_loop_break(break_cond, body, continue_target, merge),
159            ControlFlow::Return => {
160                self.ret().unwrap();
161                self.current_block = None;
162            }
163            ControlFlow::None => {
164                let opt = self.opt.clone();
165                let children = opt.successors(self.current_block.unwrap());
166                assert_eq!(
167                    children.len(),
168                    1,
169                    "None control flow should have only 1 outgoing edge"
170                );
171                let label = self.label(children[0]);
172                self.branch(label).unwrap();
173                self.compile_block(children[0]);
174            }
175        }
176    }
177
178    fn compile_if_else(
179        &mut self,
180        cond: core::Variable,
181        then: NodeIndex,
182        or_else: NodeIndex,
183        merge: Option<NodeIndex>,
184    ) {
185        let cond = self.compile_variable(cond);
186        let then_label = self.label(then);
187        let else_label = self.label(or_else);
188        let cond_id = self.read(&cond);
189
190        if let Some(merge) = merge {
191            let merge_label = self.label(merge);
192            self.selection_merge(merge_label, SelectionControl::NONE)
193                .unwrap();
194        }
195        self.branch_conditional(cond_id, then_label, else_label, None)
196            .unwrap();
197        self.compile_block(then);
198        self.compile_block(or_else);
199        if let Some(it) = merge {
200            self.compile_block(it);
201        }
202    }
203
204    fn compile_switch(
205        &mut self,
206        value: core::Variable,
207        default: NodeIndex,
208        branches: Vec<(u32, NodeIndex)>,
209        merge: Option<NodeIndex>,
210    ) {
211        let value = self.compile_variable(value);
212        let value_id = self.read(&value);
213
214        let default_label = self.label(default);
215        let targets = branches
216            .iter()
217            .map(|(value, block)| {
218                let label = self.label(*block);
219                (Operand::LiteralBit32(*value), label)
220            })
221            .collect::<Vec<_>>();
222
223        if let Some(merge) = merge {
224            let merge_label = self.label(merge);
225            self.selection_merge(merge_label, SelectionControl::NONE)
226                .unwrap();
227        }
228
229        self.switch(value_id, default_label, targets).unwrap();
230        self.compile_block(default);
231        for (_, block) in branches {
232            self.compile_block(block);
233        }
234        if let Some(it) = merge {
235            self.compile_block(it);
236        }
237    }
238
239    fn compile_loop(&mut self, body: NodeIndex, continue_target: NodeIndex, merge: NodeIndex) {
240        let body_label = self.label(body);
241        let continue_label = self.label(continue_target);
242        let merge_label = self.label(merge);
243
244        self.loop_merge(merge_label, continue_label, LoopControl::NONE, vec![])
245            .unwrap();
246        self.branch(body_label).unwrap();
247        self.compile_block(body);
248        self.compile_block(continue_target);
249        self.compile_block(merge);
250    }
251
252    fn compile_loop_break(
253        &mut self,
254        break_cond: core::Variable,
255        body: NodeIndex,
256        continue_target: NodeIndex,
257        merge: NodeIndex,
258    ) {
259        let break_cond = self.compile_variable(break_cond);
260        let cond_id = self.read(&break_cond);
261        let body_label = self.label(body);
262        let continue_label = self.label(continue_target);
263        let merge_label = self.label(merge);
264
265        self.loop_merge(merge_label, continue_label, LoopControl::NONE, [])
266            .unwrap();
267        self.branch_conditional(cond_id, body_label, merge_label, [])
268            .unwrap();
269        self.compile_block(body);
270        self.compile_block(continue_target);
271        self.compile_block(merge);
272    }
273}