cubecl_core/ir/
branch.rs

1use std::fmt::Display;
2
3use super::{Elem, Item, Scope, UIntKind, Variable};
4use serde::{Deserialize, Serialize};
5
6/// All branching types.
7#[allow(clippy::large_enum_variant)]
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
9pub enum Branch {
10    /// An if statement.
11    If(Box<If>),
12    /// An if else statement.
13    IfElse(Box<IfElse>),
14    /// A switch statement
15    Switch(Box<Switch>),
16    /// A range loop.
17    RangeLoop(Box<RangeLoop>),
18    /// A loop.
19    Loop(Box<Loop>),
20    /// A return statement.
21    Return,
22    /// A break statement.
23    Break,
24}
25
26impl Display for Branch {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            Branch::If(if_) => write!(f, "if({})", if_.cond),
30            Branch::IfElse(if_else) => write!(f, "if({})", if_else.cond),
31            Branch::Switch(switch) => write!(
32                f,
33                "switch({}) {:?}",
34                switch.value,
35                switch
36                    .cases
37                    .iter()
38                    .map(|case| format!("{}", case.0))
39                    .collect::<Vec<_>>(),
40            ),
41            Branch::RangeLoop(range_loop) => write!(
42                f,
43                "for({} in {}{}{})",
44                range_loop.i,
45                range_loop.start,
46                if range_loop.inclusive { "..=" } else { ".." },
47                range_loop.end
48            ),
49            Branch::Loop(_) => write!(f, "loop{{}}"),
50            Branch::Return => write!(f, "return"),
51            Branch::Break => write!(f, "break"),
52        }
53    }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57#[allow(missing_docs)]
58pub struct If {
59    pub cond: Variable,
60    pub scope: Scope,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
64#[allow(missing_docs)]
65pub struct IfElse {
66    pub cond: Variable,
67    pub scope_if: Scope,
68    pub scope_else: Scope,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72#[allow(missing_docs)]
73pub struct Select {
74    pub cond: Variable,
75    pub then: Variable,
76    pub or_else: Variable,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80#[allow(missing_docs)]
81pub struct Switch {
82    pub value: Variable,
83    pub scope_default: Scope,
84    pub cases: Vec<(Variable, Scope)>,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
88#[allow(missing_docs)]
89pub struct RangeLoop {
90    pub i: Variable,
91    pub start: Variable,
92    pub end: Variable,
93    pub step: Option<Variable>,
94    pub inclusive: bool,
95    pub scope: Scope,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
99#[allow(missing_docs)]
100pub struct Loop {
101    pub scope: Scope,
102}
103
104impl If {
105    /// Registers an if statement to the given scope.
106    pub fn register<F: Fn(&mut Scope)>(parent_scope: &mut Scope, cond: Variable, func: F) {
107        let mut scope = parent_scope.child();
108
109        func(&mut scope);
110
111        let op = Self { cond, scope };
112        parent_scope.register(Branch::If(Box::new(op)));
113    }
114}
115
116impl IfElse {
117    /// Registers an if else statement to the given scope.
118    pub fn register<IF, ELSE>(
119        parent_scope: &mut Scope,
120        cond: Variable,
121        func_if: IF,
122        func_else: ELSE,
123    ) where
124        IF: Fn(&mut Scope),
125        ELSE: Fn(&mut Scope),
126    {
127        let mut scope_if = parent_scope.child();
128        let mut scope_else = parent_scope.child();
129
130        func_if(&mut scope_if);
131        func_else(&mut scope_else);
132
133        parent_scope.register(Branch::IfElse(Box::new(Self {
134            cond,
135            scope_if,
136            scope_else,
137        })));
138    }
139}
140
141impl RangeLoop {
142    /// Registers a range loop to the given scope.
143    pub fn register<F: Fn(Variable, &mut Scope)>(
144        parent_scope: &mut Scope,
145        start: Variable,
146        end: Variable,
147        step: Option<Variable>,
148        inclusive: bool,
149        func: F,
150    ) {
151        let mut scope = parent_scope.child();
152        let index_ty = Item::new(Elem::UInt(UIntKind::U32));
153        let i = scope.create_local_restricted(index_ty);
154
155        func(i, &mut scope);
156
157        parent_scope.register(Branch::RangeLoop(Box::new(Self {
158            i,
159            start,
160            end,
161            step,
162            scope,
163            inclusive,
164        })));
165    }
166}
167
168impl Loop {
169    /// Registers a loop to the given scope.
170    pub fn register<F: Fn(&mut Scope)>(parent_scope: &mut Scope, func: F) {
171        let mut scope = parent_scope.child();
172
173        func(&mut scope);
174
175        let op = Self { scope };
176        parent_scope.register(Branch::Loop(Box::new(op)));
177    }
178}
179
/// Helper for emitting a range loop that is fully unrolled at registration time.
///
/// Unlike [`RangeLoop`], no loop construct is recorded: the body builder is
/// invoked once per index, with the index converted into a [`Variable`] via
/// `From<u32>`.
pub struct UnrolledRangeLoop;
182
183impl UnrolledRangeLoop {
184    /// Registers an unrolled range loop to the given scope.
185    pub fn register<F: Fn(Variable, &mut Scope)>(
186        scope: &mut Scope,
187        start: u32,
188        end: u32,
189        step: Option<u32>,
190        inclusive: bool,
191        func: F,
192    ) {
193        if inclusive {
194            if let Some(step) = step {
195                for i in (start..=end).step_by(step as usize) {
196                    func(i.into(), scope);
197                }
198            } else {
199                for i in start..=end {
200                    func(i.into(), scope);
201                }
202            }
203        } else if let Some(step) = step {
204            for i in (start..end).step_by(step as usize) {
205                func(i.into(), scope);
206            }
207        } else {
208            for i in start..end {
209                func(i.into(), scope);
210            }
211        }
212    }
213}