cubecl_core/frontend/operation/branch.rs

1use crate::{
2    ir::{Operator, Scope, Select},
3    prelude::*,
4};
5use crate::{
6    prelude::{CubePrimitive, Line},
7    unexpanded,
8};
9
10/// Executes both branches, *then* selects a value based on the condition. This *should* be
11/// branchless, but might depend on the compiler.
12///
13/// # Safety
14///
15/// Since both branches are *evaluated* regardless of the condition, both branches must be *valid*
16/// regardless of the condition. Illegal memory accesses should not be done in either branch.
17pub fn select<C: CubePrimitive>(condition: bool, then: C, or_else: C) -> C {
18    if condition { then } else { or_else }
19}
20
21/// Same as [select()] but with lines instead.
22#[allow(unused_variables)]
23pub fn select_many<C: CubePrimitive>(
24    condition: Line<bool>,
25    then: Line<C>,
26    or_else: Line<C>,
27) -> Line<C> {
28    unexpanded!()
29}
30
31pub mod select {
32    use std::num::NonZero;
33
34    use crate::ir::Instruction;
35
36    use super::*;
37
38    pub fn expand<C: CubePrimitive>(
39        scope: &mut Scope,
40        condition: ExpandElementTyped<bool>,
41        then: ExpandElementTyped<C>,
42        or_else: ExpandElementTyped<C>,
43    ) -> ExpandElementTyped<C> {
44        let cond = condition.expand.consume();
45        let then = then.expand.consume();
46        let or_else = or_else.expand.consume();
47
48        let vf = cond.vectorization_factor();
49        let vf = Ord::max(vf, then.vectorization_factor());
50        let vf = Ord::max(vf, or_else.vectorization_factor());
51
52        let output = scope.create_local(then.item.vectorize(NonZero::new(vf)));
53        let out = *output;
54
55        let select = Operator::Select(Select {
56            cond,
57            then,
58            or_else,
59        });
60        scope.register(Instruction::new(select, out));
61
62        output.into()
63    }
64}
65
66pub mod select_many {
67    use super::*;
68
69    pub fn expand<C: CubePrimitive>(
70        scope: &mut Scope,
71        condition: ExpandElementTyped<Line<bool>>,
72        then: ExpandElementTyped<Line<C>>,
73        or_else: ExpandElementTyped<Line<C>>,
74    ) -> ExpandElementTyped<Line<C>> {
75        select::expand(scope, condition.expand.into(), then, or_else)
76    }
77}