cubecl_core/frontend/operation/
branch.rs1use cubecl_macros::intrinsic;
2
3use crate as cubecl;
4use crate::prelude::{CubePrimitive, Line};
5use crate::{
6 ir::{Operator, Scope, Select},
7 prelude::*,
8};
9
10pub fn select<C: CubePrimitive>(condition: bool, then: C, or_else: C) -> C {
18 if condition { then } else { or_else }
19}
20
21#[cube]
23#[allow(unused_variables)]
24pub fn select_many<C: CubePrimitive>(
25 condition: Line<bool>,
26 then: Line<C>,
27 or_else: Line<C>,
28) -> Line<C> {
29 intrinsic!(|scope| select::expand(scope, condition.expand.into(), then, or_else))
30}
31
32pub mod select {
33 use std::num::NonZero;
34
35 use crate::ir::Instruction;
36
37 use super::*;
38
39 pub fn expand<C: CubePrimitive>(
40 scope: &mut Scope,
41 condition: ExpandElementTyped<bool>,
42 then: ExpandElementTyped<C>,
43 or_else: ExpandElementTyped<C>,
44 ) -> ExpandElementTyped<C> {
45 let cond = condition.expand.consume();
46 let then = then.expand.consume();
47 let or_else = or_else.expand.consume();
48
49 let vf = cond.vectorization_factor();
50 let vf = Ord::max(vf, then.vectorization_factor());
51 let vf = Ord::max(vf, or_else.vectorization_factor());
52
53 let output = scope.create_local(then.item.vectorize(NonZero::new(vf)));
54 let out = *output;
55
56 let select = Operator::Select(Select {
57 cond,
58 then,
59 or_else,
60 });
61 scope.register(Instruction::new(select, out));
62
63 output.into()
64 }
65}