Skip to main content

cubecl_ir/
branch.rs

1use alloc::{boxed::Box, format, vec::Vec};
2use core::fmt::Display;
3
4use crate::OperationReflect;
5
6use super::{OperationCode, Scope, Variable};
7use crate::TypeHash;
8
9/// All branching types.
10#[allow(clippy::large_enum_variant)]
11#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
12#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationCode)]
13#[operation(opcode_name = BranchOpCode)]
14pub enum Branch {
15    /// An if statement.
16    If(Box<If>),
17    /// An if else statement.
18    IfElse(Box<IfElse>),
19    /// A switch statement
20    Switch(Box<Switch>),
21    /// A range loop.
22    RangeLoop(Box<RangeLoop>),
23    /// A loop.
24    Loop(Box<Loop>),
25    /// A return statement.
26    Return,
27    /// A break statement.
28    Break,
29    /// Unreachable block end (equivalent to `unreachable_unchecked()`)
30    Unreachable,
31}
32
33impl OperationReflect for Branch {
34    type OpCode = BranchOpCode;
35
36    fn op_code(&self) -> Self::OpCode {
37        self.__match_opcode()
38    }
39}
40
41impl Display for Branch {
42    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
43        match self {
44            Branch::If(if_) => write!(f, "if({}) {}", if_.cond, if_.scope),
45            Branch::IfElse(if_else) => write!(
46                f,
47                "if({}) {} else {}",
48                if_else.cond, if_else.scope_if, if_else.scope_else
49            ),
50            Branch::Switch(switch) => write!(
51                f,
52                "switch({}) {:?}",
53                switch.value,
54                switch
55                    .cases
56                    .iter()
57                    .map(|case| format!("{}", case.0))
58                    .collect::<Vec<_>>(),
59            ),
60            Branch::RangeLoop(range_loop) => write!(
61                f,
62                "for({} in {}{}{}) {}",
63                range_loop.i,
64                range_loop.start,
65                if range_loop.inclusive { "..=" } else { ".." },
66                range_loop.end,
67                range_loop.scope
68            ),
69            Branch::Loop(loop_) => write!(f, "loop {}", loop_.scope),
70            Branch::Return => write!(f, "return"),
71            Branch::Break => write!(f, "break"),
72            Branch::Unreachable => write!(f, "unreachable"),
73        }
74    }
75}
76
77#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
78#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
79#[allow(missing_docs)]
80pub struct If {
81    pub cond: Variable,
82    pub scope: Scope,
83}
84
85#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
86#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
87#[allow(missing_docs)]
88pub struct IfElse {
89    pub cond: Variable,
90    pub scope_if: Scope,
91    pub scope_else: Scope,
92}
93
94#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
95#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
96#[allow(missing_docs)]
97pub struct Switch {
98    pub value: Variable,
99    pub scope_default: Scope,
100    pub cases: Vec<(Variable, Scope)>,
101}
102
103#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
104#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
105#[allow(missing_docs)]
106pub struct RangeLoop {
107    pub i: Variable,
108    pub start: Variable,
109    pub end: Variable,
110    pub step: Option<Variable>,
111    pub inclusive: bool,
112    pub scope: Scope,
113}
114
115#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
116#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
117#[allow(missing_docs)]
118pub struct Loop {
119    pub scope: Scope,
120}