// cubecl_core/frontend/branch.rs

1use cubecl_ir::ExpandElement;
2use num_traits::NumCast;
3
4use crate::ir::Switch;
5use crate::ir::{Branch, If, IfElse, Item, Loop, RangeLoop, Scope};
6
7use super::{CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric, assign};
8
9/// Something that can be iterated on by a for loop. Currently only includes `Range`, `StepBy` and
10/// `Sequence`.
11pub trait Iterable<T: CubeType>: Sized {
12    /// Expand a runtime loop without unrolling
13    ///
14    /// # Arguments
15    /// # Arguments
16    /// * `scope` - the expansion scope
17    /// * `body` - the loop body to be executed repeatedly
18    fn expand(self, scope: &mut Scope, body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType));
19    /// Expand an unrolled loop. The body should be invoced `n` times, where `n` is the number of
20    /// iterations.
21    ///
22    /// # Arguments
23    /// * `scope` - the expansion scope
24    /// * `body` - the loop body to be executed repeatedly
25    fn expand_unroll(
26        self,
27        scope: &mut Scope,
28        body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
29    );
30}
31
32pub struct RangeExpand<I: Int> {
33    pub start: ExpandElementTyped<I>,
34    pub end: ExpandElementTyped<I>,
35    pub inclusive: bool,
36}
37
38impl<I: Int> RangeExpand<I> {
39    pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
40        RangeExpand {
41            start,
42            end,
43            inclusive,
44        }
45    }
46
47    pub fn __expand_step_by_method(
48        self,
49        n: impl Into<ExpandElementTyped<u32>>,
50    ) -> SteppedRangeExpand<I> {
51        SteppedRangeExpand {
52            start: self.start,
53            end: self.end,
54            step: n.into(),
55            inclusive: self.inclusive,
56        }
57    }
58}
59
60impl<I: Int> Iterable<I> for RangeExpand<I> {
61    fn expand_unroll(
62        self,
63        scope: &mut Scope,
64        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
65    ) {
66        let start = self
67            .start
68            .expand
69            .as_const()
70            .expect("Only constant start can be unrolled.")
71            .as_i64();
72        let end = self
73            .end
74            .expand
75            .as_const()
76            .expect("Only constant end can be unrolled.")
77            .as_i64();
78
79        if self.inclusive {
80            for i in start..=end {
81                let var = I::from_int(i);
82                body(scope, var.into())
83            }
84        } else {
85            for i in start..end {
86                let var = I::from_int(i);
87                body(scope, var.into())
88            }
89        }
90    }
91
92    fn expand(
93        self,
94        scope: &mut Scope,
95        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
96    ) {
97        let mut child = scope.child();
98        let index_ty = Item::new(I::as_elem(scope));
99        let i = child.create_local_restricted(index_ty);
100
101        body(&mut child, i.clone().into());
102
103        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
104            i: *i,
105            start: *self.start.expand,
106            end: *self.end.expand,
107            step: None,
108            scope: child,
109            inclusive: self.inclusive,
110        })));
111    }
112}
113
114pub struct SteppedRangeExpand<I: Int> {
115    start: ExpandElementTyped<I>,
116    end: ExpandElementTyped<I>,
117    step: ExpandElementTyped<u32>,
118    inclusive: bool,
119}
120
121impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
122    fn expand(
123        self,
124        scope: &mut Scope,
125        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
126    ) {
127        let mut child = scope.child();
128        let index_ty = Item::new(I::as_elem(scope));
129        let i = child.create_local_restricted(index_ty);
130
131        body(&mut child, i.clone().into());
132
133        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
134            i: *i,
135            start: *self.start.expand,
136            end: *self.end.expand,
137            step: Some(*self.step.expand),
138            scope: child,
139            inclusive: self.inclusive,
140        })));
141    }
142
143    fn expand_unroll(
144        self,
145        scope: &mut Scope,
146        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
147    ) {
148        let start = self
149            .start
150            .expand
151            .as_const()
152            .expect("Only constant start can be unrolled.")
153            .as_i64();
154        let end = self
155            .end
156            .expand
157            .as_const()
158            .expect("Only constant end can be unrolled.")
159            .as_i64();
160        let step = self
161            .step
162            .expand
163            .as_const()
164            .expect("Only constant step can be unrolled.")
165            .as_usize();
166
167        if self.inclusive {
168            for i in (start..=end).step_by(step) {
169                let var = I::from_int(i);
170                body(scope, var.into())
171            }
172        } else {
173            for i in (start..end).step_by(step) {
174                let var = I::from_int(i);
175                body(scope, var.into())
176            }
177        }
178    }
179}
180
181/// integer range. Equivalent to:
182///
183/// ```ignore
184/// start..end
185/// ```
186pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
187    let start: i64 = start.to_i64().unwrap();
188    let end: i64 = end.to_i64().unwrap();
189    (start..end).map(<T as NumCast>::from).map(Option::unwrap)
190}
191
192pub mod range {
193    use cubecl_ir::Scope;
194
195    use crate::prelude::{ExpandElementTyped, Int};
196
197    use super::RangeExpand;
198
199    pub fn expand<I: Int>(
200        _scope: &mut Scope,
201        start: ExpandElementTyped<I>,
202        end: ExpandElementTyped<I>,
203    ) -> RangeExpand<I> {
204        RangeExpand {
205            start,
206            end,
207            inclusive: false,
208        }
209    }
210}
211
212/// Stepped range. Equivalent to:
213///
214/// ```ignore
215/// (start..end).step_by(step)
216/// ```
217///
218/// Allows using any integer for the step, instead of just usize
219pub fn range_stepped<I: Int>(start: I, end: I, step: impl Int) -> impl Iterator<Item = I> {
220    let start = start.to_i64().unwrap();
221    let end = end.to_i64().unwrap();
222    let step = step.to_usize().unwrap();
223    (start..end)
224        .step_by(step)
225        .map(<I as NumCast>::from)
226        .map(Option::unwrap)
227}
228
229pub mod range_stepped {
230    use cubecl_ir::Scope;
231
232    use crate::prelude::{ExpandElementTyped, Int};
233
234    use super::SteppedRangeExpand;
235
236    pub fn expand<I: Int>(
237        _scope: &mut Scope,
238        start: ExpandElementTyped<I>,
239        end: ExpandElementTyped<I>,
240        step: ExpandElementTyped<u32>,
241    ) -> SteppedRangeExpand<I> {
242        SteppedRangeExpand {
243            start,
244            end,
245            step,
246            inclusive: false,
247        }
248    }
249}
250
251pub fn for_expand<I: Numeric>(
252    scope: &mut Scope,
253    range: impl Iterable<I>,
254    unroll: bool,
255    body: impl FnMut(&mut Scope, ExpandElementTyped<I>),
256) {
257    if unroll {
258        range.expand_unroll(scope, body);
259    } else {
260        range.expand(scope, body);
261    }
262}
263
264pub fn if_expand(scope: &mut Scope, runtime_cond: ExpandElement, block: impl FnOnce(&mut Scope)) {
265    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
266    match comptime_cond {
267        Some(cond) => {
268            if cond {
269                block(scope);
270            }
271        }
272        None => {
273            let mut child = scope.child();
274
275            block(&mut child);
276
277            scope.register(Branch::If(Box::new(If {
278                cond: *runtime_cond,
279                scope: child,
280            })));
281        }
282    }
283}
284
285#[allow(clippy::large_enum_variant)]
286pub enum IfElseExpand {
287    ComptimeThen,
288    ComptimeElse,
289    Runtime {
290        runtime_cond: ExpandElement,
291        then_child: Scope,
292    },
293}
294
295impl IfElseExpand {
296    pub fn or_else(self, scope: &mut Scope, else_block: impl FnOnce(&mut Scope)) {
297        match self {
298            Self::Runtime {
299                runtime_cond,
300                then_child,
301            } => {
302                let mut else_child = scope.child();
303                else_block(&mut else_child);
304
305                scope.register(Branch::IfElse(Box::new(IfElse {
306                    cond: *runtime_cond,
307                    scope_if: then_child,
308                    scope_else: else_child,
309                })));
310            }
311            Self::ComptimeElse => else_block(scope),
312            Self::ComptimeThen => (),
313        }
314    }
315}
316
317pub fn if_else_expand(
318    scope: &mut Scope,
319    runtime_cond: ExpandElement,
320    then_block: impl FnOnce(&mut Scope),
321) -> IfElseExpand {
322    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
323    match comptime_cond {
324        Some(true) => {
325            then_block(scope);
326            IfElseExpand::ComptimeThen
327        }
328        Some(false) => IfElseExpand::ComptimeElse,
329        None => {
330            let mut then_child = scope.child();
331            then_block(&mut then_child);
332
333            IfElseExpand::Runtime {
334                runtime_cond,
335                then_child,
336            }
337        }
338    }
339}
340
341#[allow(clippy::large_enum_variant)]
342pub enum IfElseExprExpand<C: CubeType> {
343    ComptimeThen(ExpandElementTyped<C>),
344    ComptimeElse,
345    Runtime {
346        runtime_cond: ExpandElement,
347        out: ExpandElementTyped<C>,
348        then_child: Scope,
349    },
350}
351
352impl<C: CubePrimitive> IfElseExprExpand<C> {
353    pub fn or_else(
354        self,
355        scope: &mut Scope,
356        else_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
357    ) -> ExpandElementTyped<C> {
358        match self {
359            Self::Runtime {
360                runtime_cond,
361                out,
362                then_child,
363            } => {
364                let mut else_child = scope.child();
365                let ret = else_block(&mut else_child);
366                assign::expand::<C>(&mut else_child, ret, out.clone());
367
368                scope.register(Branch::IfElse(Box::new(IfElse {
369                    cond: *runtime_cond,
370                    scope_if: then_child,
371                    scope_else: else_child,
372                })));
373                out
374            }
375            Self::ComptimeElse => else_block(scope),
376            Self::ComptimeThen(ret) => ret,
377        }
378    }
379}
380
381pub fn if_else_expr_expand<C: CubePrimitive>(
382    scope: &mut Scope,
383    runtime_cond: ExpandElement,
384    then_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
385) -> IfElseExprExpand<C> {
386    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
387    match comptime_cond {
388        Some(true) => {
389            let ret = then_block(scope);
390            IfElseExprExpand::ComptimeThen(ret)
391        }
392        Some(false) => IfElseExprExpand::ComptimeElse,
393        None => {
394            let mut then_child = scope.child();
395            let ret = then_block(&mut then_child);
396            let out: ExpandElementTyped<C> = scope.create_local_mut(ret.expand.item).into();
397            assign::expand::<C>(&mut then_child, ret, out.clone());
398
399            IfElseExprExpand::Runtime {
400                runtime_cond,
401                out,
402                then_child,
403            }
404        }
405    }
406}
407
408pub struct SwitchExpand<I: Int> {
409    value: ExpandElementTyped<I>,
410    default: Scope,
411    cases: Vec<(ExpandElementTyped<I>, Scope)>,
412}
413
414impl<I: Int> SwitchExpand<I> {
415    pub fn case(
416        mut self,
417        scope: &mut Scope,
418        value: impl Int,
419        block: impl FnOnce(&mut Scope),
420    ) -> Self {
421        let value = I::from(value).unwrap();
422        let mut case_child = scope.child();
423        block(&mut case_child);
424        self.cases.push((value.into(), case_child));
425        self
426    }
427
428    pub fn finish(self, scope: &mut Scope) {
429        let value_var = *self.value.expand;
430        scope.register(Branch::Switch(Box::new(Switch {
431            value: value_var,
432            scope_default: self.default,
433            cases: self
434                .cases
435                .into_iter()
436                .map(|it| (*it.0.expand, it.1))
437                .collect(),
438        })));
439    }
440}
441
442pub fn switch_expand<I: Int>(
443    scope: &mut Scope,
444    value: ExpandElementTyped<I>,
445    default_block: impl FnOnce(&mut Scope),
446) -> SwitchExpand<I> {
447    let mut default_child = scope.child();
448    default_block(&mut default_child);
449
450    SwitchExpand {
451        value,
452        default: default_child,
453        cases: Vec::new(),
454    }
455}
456
457pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
458    value: ExpandElementTyped<I>,
459    out: ExpandElementTyped<C>,
460    default: Scope,
461    cases: Vec<(ExpandElementTyped<I>, Scope)>,
462}
463
464impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
465    pub fn case(
466        mut self,
467        scope: &mut Scope,
468        value: impl Int,
469        block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
470    ) -> Self {
471        let value = I::from(value).unwrap();
472        let mut case_child = scope.child();
473        let ret = block(&mut case_child);
474        assign::expand::<C>(&mut case_child, ret, self.out.clone());
475        self.cases.push((value.into(), case_child));
476        self
477    }
478
479    pub fn finish(self, scope: &mut Scope) -> ExpandElementTyped<C> {
480        let value_var = *self.value.expand;
481        scope.register(Branch::Switch(Box::new(Switch {
482            value: value_var,
483            scope_default: self.default,
484            cases: self
485                .cases
486                .into_iter()
487                .map(|it| (*it.0.expand, it.1))
488                .collect(),
489        })));
490        self.out
491    }
492}
493
494pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
495    scope: &mut Scope,
496    value: ExpandElementTyped<I>,
497    default_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
498) -> SwitchExpandExpr<I, C> {
499    let mut default_child = scope.child();
500    let default = default_block(&mut default_child);
501    let out: ExpandElementTyped<C> = scope.create_local_mut(default.expand.item).into();
502    assign::expand::<C>(&mut default_child, default, out.clone());
503
504    SwitchExpandExpr {
505        value,
506        out,
507        default: default_child,
508        cases: Vec::new(),
509    }
510}
511
512pub fn break_expand(scope: &mut Scope) {
513    scope.register(Branch::Break);
514}
515
516pub fn return_expand(scope: &mut Scope) {
517    scope.register(Branch::Return);
518}
519
520// Don't make this `FnOnce`, it must be executable multiple times
521pub fn loop_expand(scope: &mut Scope, mut block: impl FnMut(&mut Scope)) {
522    let mut inside_loop = scope.child();
523
524    block(&mut inside_loop);
525    scope.register(Branch::Loop(Box::new(Loop { scope: inside_loop })));
526}