cubecl_core/frontend/operation/
branch.rs1use crate::{
2 ir::{Operator, Scope, Select},
3 prelude::*,
4};
5use crate::{
6 prelude::{CubePrimitive, Line},
7 unexpanded,
8};
9
10pub fn select<C: CubePrimitive>(condition: bool, then: C, or_else: C) -> C {
18 if condition { then } else { or_else }
19}
20
21#[allow(unused_variables)]
23pub fn select_many<C: CubePrimitive>(
24 condition: Line<bool>,
25 then: Line<C>,
26 or_else: Line<C>,
27) -> Line<C> {
28 unexpanded!()
29}
30
31pub mod select {
32 use std::num::NonZero;
33
34 use crate::ir::Instruction;
35
36 use super::*;
37
38 pub fn expand<C: CubePrimitive>(
39 scope: &mut Scope,
40 condition: ExpandElementTyped<bool>,
41 then: ExpandElementTyped<C>,
42 or_else: ExpandElementTyped<C>,
43 ) -> ExpandElementTyped<C> {
44 let cond = condition.expand.consume();
45 let then = then.expand.consume();
46 let or_else = or_else.expand.consume();
47
48 let vf = cond.vectorization_factor();
49 let vf = Ord::max(vf, then.vectorization_factor());
50 let vf = Ord::max(vf, or_else.vectorization_factor());
51
52 let output = scope.create_local(then.item.vectorize(NonZero::new(vf)));
53 let out = *output;
54
55 let select = Operator::Select(Select {
56 cond,
57 then,
58 or_else,
59 });
60 scope.register(Instruction::new(select, out));
61
62 output.into()
63 }
64}
65
66pub mod select_many {
67 use super::*;
68
69 pub fn expand<C: CubePrimitive>(
70 scope: &mut Scope,
71 condition: ExpandElementTyped<Line<bool>>,
72 then: ExpandElementTyped<Line<C>>,
73 or_else: ExpandElementTyped<Line<C>>,
74 ) -> ExpandElementTyped<Line<C>> {
75 select::expand(scope, condition.expand.into(), then, or_else)
76 }
77}