cubecl_core/frontend/operation/
branch.rs

1use crate::{ir::Select, prelude::*};
2use crate::{
3    prelude::{CubePrimitive, Line},
4    unexpanded,
5};
6
7/// Executes both branches, *then* selects a value based on the condition. This *should* be
8/// branchless, but might depend on the compiler.
9///
10/// # Safety
11///
12/// Since both branches are *evaluated* regardless of the condition, both branches must be *valid*
13/// regardless of the condition. Illegal memory accesses should not be done in either branch.
14pub fn select<C: CubePrimitive>(condition: bool, then: C, or_else: C) -> C {
15    if condition {
16        then
17    } else {
18        or_else
19    }
20}
21
22/// Same as [select] but with lines instead.
23#[allow(unused_variables)]
24pub fn select_many<C: CubePrimitive>(
25    condition: Line<bool>,
26    then: Line<C>,
27    or_else: Line<C>,
28) -> Line<C> {
29    unexpanded!()
30}
31
32pub mod select {
33    use std::num::NonZero;
34
35    use crate::ir::{Instruction, Operator};
36
37    use super::*;
38
39    pub fn expand<C: CubePrimitive>(
40        context: &mut CubeContext,
41        condition: ExpandElementTyped<bool>,
42        then: ExpandElementTyped<C>,
43        or_else: ExpandElementTyped<C>,
44    ) -> ExpandElementTyped<C> {
45        let cond = condition.expand.consume();
46        let then = then.expand.consume();
47        let or_else = or_else.expand.consume();
48
49        let vf = cond.vectorization_factor();
50        let vf = Ord::max(vf, then.vectorization_factor());
51        let vf = Ord::max(vf, or_else.vectorization_factor());
52
53        let output = context.create_local(then.item.vectorize(NonZero::new(vf)));
54        let out = *output;
55
56        let select = Operator::Select(Select {
57            cond,
58            then,
59            or_else,
60        });
61        context.register(Instruction::new(select, out));
62
63        output.into()
64    }
65}
66
67pub mod select_many {
68    use super::*;
69
70    pub fn expand<C: CubePrimitive>(
71        context: &mut CubeContext,
72        condition: ExpandElementTyped<Line<bool>>,
73        then: ExpandElementTyped<Line<C>>,
74        or_else: ExpandElementTyped<Line<C>>,
75    ) -> ExpandElementTyped<Line<C>> {
76        select::expand(context, condition.expand.into(), then, or_else)
77    }
78}