// cubecl_core/frontend/operation/branch.rs

1use cubecl_macros::intrinsic;
2
3use crate as cubecl;
4use crate::prelude::{CubePrimitive, Vector};
5use crate::{
6    ir::{Operator, Scope, Select},
7    prelude::*,
8};
9
10/// Executes both branches, *then* selects a value based on the condition. This *should* be
11/// branchless, but might depend on the compiler.
12///
13/// # Safety
14///
15/// Since both branches are *evaluated* regardless of the condition, both branches must be *valid*
16/// regardless of the condition. Illegal memory accesses should not be done in either branch.
17pub fn select<C: CubePrimitive>(condition: bool, then: C, or_else: C) -> C {
18    if condition { then } else { or_else }
19}
20
/// Same as [`select()`] but with vectors instead.
///
/// The condition is a `Vector<bool, N>` rather than a scalar `bool`. It is forwarded to
/// [`select::expand`] together with both operands, and the emitted `Select` instruction gets an
/// output sized to the widest of the condition and the two operands. Per-lane semantics are
/// presumably provided by the backend's lowering of `Select`; confirm there.
///
/// As with [`select()`], both operands are evaluated regardless of the condition.
#[cube]
// NOTE(review): the body is only an `intrinsic!`, so the parameters are presumably not read by the
// non-expanded version the macro generates; this allow silences that warning. Confirm against
// the `#[cube]`/`intrinsic!` expansion.
#[allow(unused_variables)]
pub fn select_many<C: Scalar, N: Size>(
    condition: Vector<bool, N>,
    then: Vector<C, N>,
    or_else: Vector<C, N>,
) -> Vector<C, N> {
    // `condition.expand.into()` converts the vector condition into the `NativeExpand<bool>`
    // parameter type that `select::expand` takes; the operands are passed through unchanged.
    intrinsic!(|scope| select::expand(scope, condition.expand.into(), then, or_else))
}
31
32pub mod select {
33    use cubecl_ir::VariableKind;
34
35    use crate::ir::Instruction;
36
37    use super::*;
38
39    pub fn expand<C: CubePrimitive>(
40        scope: &mut Scope,
41        condition: NativeExpand<bool>,
42        then: NativeExpand<C>,
43        or_else: NativeExpand<C>,
44    ) -> NativeExpand<C> {
45        let cond = condition.expand.consume();
46
47        if let VariableKind::Constant(value) = cond.kind {
48            if value.as_bool() {
49                return then;
50            } else {
51                return or_else;
52            }
53        }
54
55        let then = then.expand.consume();
56        let or_else = or_else.expand.consume();
57
58        let vf = cond.vector_size();
59        let vf = Ord::max(vf, then.vector_size());
60        let vf = Ord::max(vf, or_else.vector_size());
61
62        let output = scope.create_local(then.ty.with_vector_size(vf));
63        let out = *output;
64
65        let select = Operator::Select(Select {
66            cond,
67            then,
68            or_else,
69        });
70        scope.register(Instruction::new(select, out));
71
72        output.into()
73    }
74}