cubecl_core/frontend/operation/
branch.rs

1use cubecl_macros::intrinsic;
2
3use crate as cubecl;
4use crate::prelude::{CubePrimitive, Line};
5use crate::{
6    ir::{Operator, Scope, Select},
7    prelude::*,
8};
9
10/// Executes both branches, *then* selects a value based on the condition. This *should* be
11/// branchless, but might depend on the compiler.
12///
13/// # Safety
14///
15/// Since both branches are *evaluated* regardless of the condition, both branches must be *valid*
16/// regardless of the condition. Illegal memory accesses should not be done in either branch.
17pub fn select<C: CubePrimitive>(condition: bool, then: C, or_else: C) -> C {
18    if condition { then } else { or_else }
19}
20
/// Same as [select()] but with lines instead.
///
/// The body is a single `intrinsic!` closure that forwards `condition`, `then` and `or_else`
/// to `select::expand`. The `expand` field of `condition` is converted with `.into()` first;
/// the two value lines are handed over unchanged.
// NOTE(review): the `allow` below presumably exists because the parameters are only referenced
// inside `intrinsic!` — confirm against the `#[cube]` macro expansion.
#[cube]
#[allow(unused_variables)]
pub fn select_many<C: CubePrimitive>(
    condition: Line<bool>,
    then: Line<C>,
    or_else: Line<C>,
) -> Line<C> {
    intrinsic!(|scope| select::expand(scope, condition.expand.into(), then, or_else))
}
31
32pub mod select {
33    use crate::ir::Instruction;
34
35    use super::*;
36
37    pub fn expand<C: CubePrimitive>(
38        scope: &mut Scope,
39        condition: ExpandElementTyped<bool>,
40        then: ExpandElementTyped<C>,
41        or_else: ExpandElementTyped<C>,
42    ) -> ExpandElementTyped<C> {
43        let cond = condition.expand.consume();
44        let then = then.expand.consume();
45        let or_else = or_else.expand.consume();
46
47        let vf = cond.line_size();
48        let vf = Ord::max(vf, then.line_size());
49        let vf = Ord::max(vf, or_else.line_size());
50
51        let output = scope.create_local(then.ty.line(vf));
52        let out = *output;
53
54        let select = Operator::Select(Select {
55            cond,
56            then,
57            or_else,
58        });
59        scope.register(Instruction::new(select, out));
60
61        output.into()
62    }
63}