cubecl_core/frontend/operation/
branch.rs1use cubecl_macros::intrinsic;
2
3use crate as cubecl;
4use crate::prelude::{CubePrimitive, Vector};
5use crate::{
6 ir::{Operator, Scope, Select},
7 prelude::*,
8};
9
10pub fn select<C: CubePrimitive>(condition: bool, then: C, or_else: C) -> C {
18 if condition { then } else { or_else }
19}
20
21#[cube]
23#[allow(unused_variables)]
24pub fn select_many<C: Scalar, N: Size>(
25 condition: Vector<bool, N>,
26 then: Vector<C, N>,
27 or_else: Vector<C, N>,
28) -> Vector<C, N> {
29 intrinsic!(|scope| select::expand(scope, condition.expand.into(), then, or_else))
30}
31
32pub mod select {
33 use cubecl_ir::VariableKind;
34
35 use crate::ir::Instruction;
36
37 use super::*;
38
39 pub fn expand<C: CubePrimitive>(
40 scope: &mut Scope,
41 condition: NativeExpand<bool>,
42 then: NativeExpand<C>,
43 or_else: NativeExpand<C>,
44 ) -> NativeExpand<C> {
45 let cond = condition.expand.consume();
46
47 if let VariableKind::Constant(value) = cond.kind {
48 if value.as_bool() {
49 return then;
50 } else {
51 return or_else;
52 }
53 }
54
55 let then = then.expand.consume();
56 let or_else = or_else.expand.consume();
57
58 let vf = cond.vector_size();
59 let vf = Ord::max(vf, then.vector_size());
60 let vf = Ord::max(vf, or_else.vector_size());
61
62 let output = scope.create_local(then.ty.with_vector_size(vf));
63 let out = *output;
64
65 let select = Operator::Select(Select {
66 cond,
67 then,
68 or_else,
69 });
70 scope.register(Instruction::new(select, out));
71
72 output.into()
73 }
74}