cubecl_ir/
branch.rs

1use alloc::{boxed::Box, format, vec::Vec};
2use core::fmt::Display;
3
4use crate::OperationReflect;
5
6use super::{Elem, Item, OperationCode, Scope, UIntKind, Variable};
7use crate::TypeHash;
8
/// All branching types.
///
/// Every variant that carries data stores it behind a [`Box`]; `Return` and `Break` carry no
/// payload. The `#[operation(opcode_name = BranchOpCode)]` attribute names the opcode type that
/// the [`OperationReflect`] implementation of this enum reports.
#[allow(clippy::large_enum_variant)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationCode)]
#[operation(opcode_name = BranchOpCode)]
pub enum Branch {
    /// An if statement: a condition and the scope attached to it.
    If(Box<If>),
    /// An if else statement: a condition and one scope per branch.
    IfElse(Box<IfElse>),
    /// A switch statement: a value, a list of cases and a default scope.
    Switch(Box<Switch>),
    /// A range loop over `start..end` or `start..=end`, with an optional step.
    RangeLoop(Box<RangeLoop>),
    /// A loop that carries only a body scope (no bounds or condition are stored).
    Loop(Box<Loop>),
    /// A return statement.
    Return,
    /// A break statement.
    Break,
}
30
/// Exposes the opcode of a [`Branch`] through the generic [`OperationReflect`] interface.
impl OperationReflect for Branch {
    /// Opcode type named by the `#[operation(opcode_name = BranchOpCode)]` attribute on
    /// [`Branch`].
    type OpCode = BranchOpCode;

    /// Returns the opcode matching the current variant.
    fn op_code(&self) -> Self::OpCode {
        // NOTE(review): `__match_opcode` is not defined in this file; presumably it is
        // generated by the `OperationCode` derive on `Branch` — confirm.
        self.__match_opcode()
    }
}
38
39impl Display for Branch {
40    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
41        match self {
42            Branch::If(if_) => write!(f, "if({}) {}", if_.cond, if_.scope),
43            Branch::IfElse(if_else) => write!(
44                f,
45                "if({}) {} else {}",
46                if_else.cond, if_else.scope_if, if_else.scope_else
47            ),
48            Branch::Switch(switch) => write!(
49                f,
50                "switch({}) {:?}",
51                switch.value,
52                switch
53                    .cases
54                    .iter()
55                    .map(|case| format!("{}", case.0))
56                    .collect::<Vec<_>>(),
57            ),
58            Branch::RangeLoop(range_loop) => write!(
59                f,
60                "for({} in {}{}{}) {}",
61                range_loop.i,
62                range_loop.start,
63                if range_loop.inclusive { "..=" } else { ".." },
64                range_loop.end,
65                range_loop.scope
66            ),
67            Branch::Loop(loop_) => write!(f, "loop {}", loop_.scope),
68            Branch::Return => write!(f, "return"),
69            Branch::Break => write!(f, "break"),
70        }
71    }
72}
73
/// Payload of [`Branch::If`]: a condition together with the scope attached to it.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct If {
    // Condition variable, printed inside `if(...)` by the `Display` impl of `Branch`.
    pub cond: Variable,
    // Body of the statement; `If::register` builds it from `parent_scope.child()`.
    pub scope: Scope,
}
81
/// Payload of [`Branch::IfElse`]: a condition and one scope per branch.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct IfElse {
    // Condition variable, printed inside `if(...)` by the `Display` impl of `Branch`.
    pub cond: Variable,
    // Scope printed right after the condition (the "if" branch).
    pub scope_if: Scope,
    // Scope printed after `else`.
    pub scope_else: Scope,
}
90
/// Payload of [`Branch::Switch`]: the switched value, its cases and a default scope.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Switch {
    // Variable being switched on, printed inside `switch(...)`.
    pub value: Variable,
    // Scope of the default case. It is not part of the `Display` output.
    pub scope_default: Scope,
    // One `(variable, scope)` pair per case; `Display` prints only the first element of each
    // pair. NOTE(review): presumably the variable is the value the case matches — confirm.
    pub cases: Vec<(Variable, Scope)>,
}
99
/// Payload of [`Branch::RangeLoop`]: a loop variable, its bounds, an optional step and a body.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct RangeLoop {
    // Loop variable; `RangeLoop::register` creates it in the loop scope with a `u32` item type.
    pub i: Variable,
    // First bound of the range (printed before the range operator).
    pub start: Variable,
    // Second bound of the range (printed after the range operator).
    pub end: Variable,
    // Optional step. It is not part of the `Display` output.
    pub step: Option<Variable>,
    // Selects the range operator: rendered as `..=` when true and `..` otherwise.
    pub inclusive: bool,
    // Body of the loop.
    pub scope: Scope,
}
111
/// Payload of [`Branch::Loop`]: a loop described only by its body scope.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Loop {
    // Body of the loop; `Loop::register` builds it from `parent_scope.child()`.
    pub scope: Scope,
}
118
119impl If {
120    /// Registers an if statement to the given scope.
121    pub fn register<F: Fn(&mut Scope)>(parent_scope: &mut Scope, cond: Variable, func: F) {
122        let mut scope = parent_scope.child();
123
124        func(&mut scope);
125
126        let op = Self { cond, scope };
127        parent_scope.register(Branch::If(Box::new(op)));
128    }
129}
130
131impl IfElse {
132    /// Registers an if else statement to the given scope.
133    pub fn register<IF, ELSE>(
134        parent_scope: &mut Scope,
135        cond: Variable,
136        func_if: IF,
137        func_else: ELSE,
138    ) where
139        IF: Fn(&mut Scope),
140        ELSE: Fn(&mut Scope),
141    {
142        let mut scope_if = parent_scope.child();
143        let mut scope_else = parent_scope.child();
144
145        func_if(&mut scope_if);
146        func_else(&mut scope_else);
147
148        parent_scope.register(Branch::IfElse(Box::new(Self {
149            cond,
150            scope_if,
151            scope_else,
152        })));
153    }
154}
155
156impl RangeLoop {
157    /// Registers a range loop to the given scope.
158    pub fn register<F: Fn(Variable, &mut Scope)>(
159        parent_scope: &mut Scope,
160        start: Variable,
161        end: Variable,
162        step: Option<Variable>,
163        inclusive: bool,
164        func: F,
165    ) {
166        let mut scope = parent_scope.child();
167        let index_ty = Item::new(Elem::UInt(UIntKind::U32));
168        let i = *scope.create_local_restricted(index_ty);
169
170        func(i, &mut scope);
171
172        parent_scope.register(Branch::RangeLoop(Box::new(Self {
173            i,
174            start,
175            end,
176            step,
177            scope,
178            inclusive,
179        })));
180    }
181}
182
183impl Loop {
184    /// Registers a loop to the given scope.
185    pub fn register<F: Fn(&mut Scope)>(parent_scope: &mut Scope, func: F) {
186        let mut scope = parent_scope.child();
187
188        func(&mut scope);
189
190        let op = Self { scope };
191        parent_scope.register(Branch::Loop(Box::new(op)));
192    }
193}
194
/// Unit type whose `register` function expands a range loop with `u32` bounds by invoking the
/// body closure once per index, directly on the given scope.
#[allow(missing_docs)]
pub struct UnrolledRangeLoop;
197
198impl UnrolledRangeLoop {
199    /// Registers an unrolled range loop to the given scope.
200    pub fn register<F: Fn(Variable, &mut Scope)>(
201        scope: &mut Scope,
202        start: u32,
203        end: u32,
204        step: Option<u32>,
205        inclusive: bool,
206        func: F,
207    ) {
208        if inclusive {
209            if let Some(step) = step {
210                for i in (start..=end).step_by(step as usize) {
211                    func(i.into(), scope);
212                }
213            } else {
214                for i in start..=end {
215                    func(i.into(), scope);
216                }
217            }
218        } else if let Some(step) = step {
219            for i in (start..end).step_by(step as usize) {
220                func(i.into(), scope);
221            }
222        } else {
223            for i in start..end {
224                func(i.into(), scope);
225            }
226        }
227    }
228}