cubecl_core/frontend/operation/
branch.rs

1use cubecl_macros::intrinsic;
2
3use crate as cubecl;
4use crate::prelude::{CubePrimitive, Line};
5use crate::{
6    ir::{Operator, Scope, Select},
7    prelude::*,
8};
9
10/// Executes both branches, *then* selects a value based on the condition. This *should* be
11/// branchless, but might depend on the compiler.
12///
13/// # Safety
14///
15/// Since both branches are *evaluated* regardless of the condition, both branches must be *valid*
16/// regardless of the condition. Illegal memory accesses should not be done in either branch.
17pub fn select<C: CubePrimitive>(condition: bool, then: C, or_else: C) -> C {
18    if condition { then } else { or_else }
19}
20
21/// Same as [select()] but with lines instead.
22#[cube]
23#[allow(unused_variables)]
24pub fn select_many<C: CubePrimitive>(
25    condition: Line<bool>,
26    then: Line<C>,
27    or_else: Line<C>,
28) -> Line<C> {
29    intrinsic!(|scope| select::expand(scope, condition.expand.into(), then, or_else))
30}
31
32pub mod select {
33    use std::num::NonZero;
34
35    use crate::ir::Instruction;
36
37    use super::*;
38
39    pub fn expand<C: CubePrimitive>(
40        scope: &mut Scope,
41        condition: ExpandElementTyped<bool>,
42        then: ExpandElementTyped<C>,
43        or_else: ExpandElementTyped<C>,
44    ) -> ExpandElementTyped<C> {
45        let cond = condition.expand.consume();
46        let then = then.expand.consume();
47        let or_else = or_else.expand.consume();
48
49        let vf = cond.vectorization_factor();
50        let vf = Ord::max(vf, then.vectorization_factor());
51        let vf = Ord::max(vf, or_else.vectorization_factor());
52
53        let output = scope.create_local(then.item.vectorize(NonZero::new(vf)));
54        let out = *output;
55
56        let select = Operator::Select(Select {
57            cond,
58            then,
59            or_else,
60        });
61        scope.register(Instruction::new(select, out));
62
63        output.into()
64    }
65}