// cubecl_core/frontend/branch.rs

1use cubecl_ir::ExpandElement;
2use num_traits::NumCast;
3
4use crate::ir::Switch;
5use crate::ir::{Branch, If, IfElse, Loop, RangeLoop, Scope, Type};
6
7use super::{CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric, assign};
8
9/// Something that can be iterated on by a for loop. Currently only includes `Range`, `StepBy` and
10/// `Sequence`.
11pub trait Iterable<T: CubeType>: Sized {
12    /// Expand a runtime loop without unrolling
13    ///
14    /// # Arguments
15    /// # Arguments
16    /// * `scope` - the expansion scope
17    /// * `body` - the loop body to be executed repeatedly
18    fn expand(self, scope: &mut Scope, body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType));
19    /// Expand an unrolled loop. The body should be invoced `n` times, where `n` is the number of
20    /// iterations.
21    ///
22    /// # Arguments
23    /// * `scope` - the expansion scope
24    /// * `body` - the loop body to be executed repeatedly
25    fn expand_unroll(
26        self,
27        scope: &mut Scope,
28        body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
29    );
30    /// Return the comptime length of this iterable, if possible
31    fn const_len(&self) -> Option<usize> {
32        None
33    }
34}
35
36pub struct RangeExpand<I: Int> {
37    pub start: ExpandElementTyped<I>,
38    pub end: ExpandElementTyped<I>,
39    pub inclusive: bool,
40}
41
42impl<I: Int> RangeExpand<I> {
43    pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
44        RangeExpand {
45            start,
46            end,
47            inclusive,
48        }
49    }
50
51    pub fn __expand_step_by_method(
52        self,
53        n: impl Into<ExpandElementTyped<u32>>,
54    ) -> SteppedRangeExpand<I> {
55        SteppedRangeExpand {
56            start: self.start,
57            end: self.end,
58            step: n.into(),
59            inclusive: self.inclusive,
60        }
61    }
62}
63
64impl<I: Int> Iterable<I> for RangeExpand<I> {
65    fn expand_unroll(
66        self,
67        scope: &mut Scope,
68        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
69    ) {
70        let start = self
71            .start
72            .expand
73            .as_const()
74            .expect("Only constant start can be unrolled.")
75            .as_i64();
76        let end = self
77            .end
78            .expand
79            .as_const()
80            .expect("Only constant end can be unrolled.")
81            .as_i64();
82
83        if self.inclusive {
84            for i in start..=end {
85                let var = I::from_int(i);
86                body(scope, var.into())
87            }
88        } else {
89            for i in start..end {
90                let var = I::from_int(i);
91                body(scope, var.into())
92            }
93        }
94    }
95
96    fn expand(
97        self,
98        scope: &mut Scope,
99        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
100    ) {
101        let mut child = scope.child();
102        let index_ty = Type::new(I::as_type(scope));
103        let i = child.create_local_restricted(index_ty);
104
105        body(&mut child, i.clone().into());
106
107        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
108            i: *i,
109            start: *self.start.expand,
110            end: *self.end.expand,
111            step: None,
112            scope: child,
113            inclusive: self.inclusive,
114        })));
115    }
116
117    fn const_len(&self) -> Option<usize> {
118        let start = self.start.expand.as_const()?.as_i64();
119        let end = self.end.expand.as_const()?.as_i64();
120        Some(start.abs_diff(end) as usize)
121    }
122}
123
124pub struct SteppedRangeExpand<I: Int> {
125    start: ExpandElementTyped<I>,
126    end: ExpandElementTyped<I>,
127    step: ExpandElementTyped<u32>,
128    inclusive: bool,
129}
130
131impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
132    fn expand(
133        self,
134        scope: &mut Scope,
135        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
136    ) {
137        let mut child = scope.child();
138        let index_ty = Type::new(I::as_type(scope));
139        let i = child.create_local_restricted(index_ty);
140
141        body(&mut child, i.clone().into());
142
143        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
144            i: *i,
145            start: *self.start.expand,
146            end: *self.end.expand,
147            step: Some(*self.step.expand),
148            scope: child,
149            inclusive: self.inclusive,
150        })));
151    }
152
153    fn expand_unroll(
154        self,
155        scope: &mut Scope,
156        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
157    ) {
158        let start = self
159            .start
160            .expand
161            .as_const()
162            .expect("Only constant start can be unrolled.")
163            .as_i64();
164        let end = self
165            .end
166            .expand
167            .as_const()
168            .expect("Only constant end can be unrolled.")
169            .as_i64();
170        let step = self
171            .step
172            .expand
173            .as_const()
174            .expect("Only constant step can be unrolled.")
175            .as_usize();
176
177        if self.inclusive {
178            for i in (start..=end).step_by(step) {
179                let var = I::from_int(i);
180                body(scope, var.into())
181            }
182        } else {
183            for i in (start..end).step_by(step) {
184                let var = I::from_int(i);
185                body(scope, var.into())
186            }
187        }
188    }
189
190    fn const_len(&self) -> Option<usize> {
191        let start = self.start.constant()?.as_i64();
192        let end = self.end.constant()?.as_i64();
193        let step = self.step.constant()?.as_u64();
194        Some((start.abs_diff(end) / step) as usize)
195    }
196}
197
198/// integer range. Equivalent to:
199///
200/// ```ignore
201/// start..end
202/// ```
203pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
204    let start: i64 = start.to_i64().unwrap();
205    let end: i64 = end.to_i64().unwrap();
206    (start..end).map(<T as NumCast>::from).map(Option::unwrap)
207}
208
209pub mod range {
210    use cubecl_ir::Scope;
211
212    use crate::prelude::{ExpandElementTyped, Int};
213
214    use super::RangeExpand;
215
216    pub fn expand<I: Int>(
217        _scope: &mut Scope,
218        start: ExpandElementTyped<I>,
219        end: ExpandElementTyped<I>,
220    ) -> RangeExpand<I> {
221        RangeExpand {
222            start,
223            end,
224            inclusive: false,
225        }
226    }
227}
228
229/// Stepped range. Equivalent to:
230///
231/// ```ignore
232/// (start..end).step_by(step)
233/// ```
234///
235/// Allows using any integer for the step, instead of just usize
236pub fn range_stepped<I: Int>(start: I, end: I, step: impl Int) -> impl Iterator<Item = I> {
237    let start = start.to_i64().unwrap();
238    let end = end.to_i64().unwrap();
239    let step = step.to_usize().unwrap();
240    (start..end)
241        .step_by(step)
242        .map(<I as NumCast>::from)
243        .map(Option::unwrap)
244}
245
246pub mod range_stepped {
247    use cubecl_ir::Scope;
248
249    use crate::prelude::{ExpandElementTyped, Int};
250
251    use super::SteppedRangeExpand;
252
253    pub fn expand<I: Int>(
254        _scope: &mut Scope,
255        start: ExpandElementTyped<I>,
256        end: ExpandElementTyped<I>,
257        step: ExpandElementTyped<u32>,
258    ) -> SteppedRangeExpand<I> {
259        SteppedRangeExpand {
260            start,
261            end,
262            step,
263            inclusive: false,
264        }
265    }
266}
267
268pub fn for_expand<I: Numeric>(
269    scope: &mut Scope,
270    range: impl Iterable<I>,
271    unroll: bool,
272    body: impl FnMut(&mut Scope, ExpandElementTyped<I>),
273) {
274    if unroll || range.const_len() == Some(1) {
275        range.expand_unroll(scope, body);
276    } else {
277        range.expand(scope, body);
278    }
279}
280
281pub fn if_expand(scope: &mut Scope, runtime_cond: ExpandElement, block: impl FnOnce(&mut Scope)) {
282    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
283    match comptime_cond {
284        Some(cond) => {
285            if cond {
286                block(scope);
287            }
288        }
289        None => {
290            let mut child = scope.child();
291
292            block(&mut child);
293
294            scope.register(Branch::If(Box::new(If {
295                cond: *runtime_cond,
296                scope: child,
297            })));
298        }
299    }
300}
301
302#[allow(clippy::large_enum_variant)]
303pub enum IfElseExpand {
304    ComptimeThen,
305    ComptimeElse,
306    Runtime {
307        runtime_cond: ExpandElement,
308        then_child: Scope,
309    },
310}
311
312impl IfElseExpand {
313    pub fn or_else(self, scope: &mut Scope, else_block: impl FnOnce(&mut Scope)) {
314        match self {
315            Self::Runtime {
316                runtime_cond,
317                then_child,
318            } => {
319                let mut else_child = scope.child();
320                else_block(&mut else_child);
321
322                scope.register(Branch::IfElse(Box::new(IfElse {
323                    cond: *runtime_cond,
324                    scope_if: then_child,
325                    scope_else: else_child,
326                })));
327            }
328            Self::ComptimeElse => else_block(scope),
329            Self::ComptimeThen => (),
330        }
331    }
332}
333
334pub fn if_else_expand(
335    scope: &mut Scope,
336    runtime_cond: ExpandElement,
337    then_block: impl FnOnce(&mut Scope),
338) -> IfElseExpand {
339    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
340    match comptime_cond {
341        Some(true) => {
342            then_block(scope);
343            IfElseExpand::ComptimeThen
344        }
345        Some(false) => IfElseExpand::ComptimeElse,
346        None => {
347            let mut then_child = scope.child();
348            then_block(&mut then_child);
349
350            IfElseExpand::Runtime {
351                runtime_cond,
352                then_child,
353            }
354        }
355    }
356}
357
358#[allow(clippy::large_enum_variant)]
359pub enum IfElseExprExpand<C: CubeType> {
360    ComptimeThen(ExpandElementTyped<C>),
361    ComptimeElse,
362    Runtime {
363        runtime_cond: ExpandElement,
364        out: ExpandElementTyped<C>,
365        then_child: Scope,
366    },
367}
368
369impl<C: CubePrimitive> IfElseExprExpand<C> {
370    pub fn or_else(
371        self,
372        scope: &mut Scope,
373        else_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
374    ) -> ExpandElementTyped<C> {
375        match self {
376            Self::Runtime {
377                runtime_cond,
378                out,
379                then_child,
380            } => {
381                let mut else_child = scope.child();
382                let ret = else_block(&mut else_child);
383                assign::expand_no_check::<C>(&mut else_child, ret, out.clone());
384
385                scope.register(Branch::IfElse(Box::new(IfElse {
386                    cond: *runtime_cond,
387                    scope_if: then_child,
388                    scope_else: else_child,
389                })));
390                out
391            }
392            Self::ComptimeElse => else_block(scope),
393            Self::ComptimeThen(ret) => ret,
394        }
395    }
396}
397
398pub fn if_else_expr_expand<C: CubePrimitive>(
399    scope: &mut Scope,
400    runtime_cond: ExpandElement,
401    then_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
402) -> IfElseExprExpand<C> {
403    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
404    match comptime_cond {
405        Some(true) => {
406            let ret = then_block(scope);
407            IfElseExprExpand::ComptimeThen(ret)
408        }
409        Some(false) => IfElseExprExpand::ComptimeElse,
410        None => {
411            let mut then_child = scope.child();
412            let ret = then_block(&mut then_child);
413            let out: ExpandElementTyped<C> = scope.create_local_mut(ret.expand.ty).into();
414            assign::expand_no_check::<C>(&mut then_child, ret, out.clone());
415
416            IfElseExprExpand::Runtime {
417                runtime_cond,
418                out,
419                then_child,
420            }
421        }
422    }
423}
424
425pub struct SwitchExpand<I: Int> {
426    value: ExpandElementTyped<I>,
427    default: Scope,
428    cases: Vec<(ExpandElementTyped<I>, Scope)>,
429}
430
431impl<I: Int> SwitchExpand<I> {
432    pub fn case(
433        mut self,
434        scope: &mut Scope,
435        value: impl Int,
436        block: impl FnOnce(&mut Scope),
437    ) -> Self {
438        let value = I::from(value).unwrap();
439        let mut case_child = scope.child();
440        block(&mut case_child);
441        self.cases.push((value.into(), case_child));
442        self
443    }
444
445    pub fn finish(self, scope: &mut Scope) {
446        let value_var = *self.value.expand;
447        scope.register(Branch::Switch(Box::new(Switch {
448            value: value_var,
449            scope_default: self.default,
450            cases: self
451                .cases
452                .into_iter()
453                .map(|it| (*it.0.expand, it.1))
454                .collect(),
455        })));
456    }
457}
458
459pub fn switch_expand<I: Int>(
460    scope: &mut Scope,
461    value: ExpandElementTyped<I>,
462    default_block: impl FnOnce(&mut Scope),
463) -> SwitchExpand<I> {
464    let mut default_child = scope.child();
465    default_block(&mut default_child);
466
467    SwitchExpand {
468        value,
469        default: default_child,
470        cases: Vec::new(),
471    }
472}
473
474pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
475    value: ExpandElementTyped<I>,
476    out: ExpandElementTyped<C>,
477    default: Scope,
478    cases: Vec<(ExpandElementTyped<I>, Scope)>,
479}
480
481impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
482    pub fn case(
483        mut self,
484        scope: &mut Scope,
485        value: impl Int,
486        block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
487    ) -> Self {
488        let value = I::from(value).unwrap();
489        let mut case_child = scope.child();
490        let ret = block(&mut case_child);
491        assign::expand_no_check::<C>(&mut case_child, ret, self.out.clone());
492        self.cases.push((value.into(), case_child));
493        self
494    }
495
496    pub fn finish(self, scope: &mut Scope) -> ExpandElementTyped<C> {
497        let value_var = *self.value.expand;
498        scope.register(Branch::Switch(Box::new(Switch {
499            value: value_var,
500            scope_default: self.default,
501            cases: self
502                .cases
503                .into_iter()
504                .map(|it| (*it.0.expand, it.1))
505                .collect(),
506        })));
507        self.out
508    }
509}
510
511pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
512    scope: &mut Scope,
513    value: ExpandElementTyped<I>,
514    default_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
515) -> SwitchExpandExpr<I, C> {
516    let mut default_child = scope.child();
517    let default = default_block(&mut default_child);
518    let out: ExpandElementTyped<C> = scope.create_local_mut(default.expand.ty).into();
519    assign::expand_no_check::<C>(&mut default_child, default, out.clone());
520
521    SwitchExpandExpr {
522        value,
523        out,
524        default: default_child,
525        cases: Vec::new(),
526    }
527}
528
529pub fn break_expand(scope: &mut Scope) {
530    scope.register(Branch::Break);
531}
532
533pub fn return_expand(scope: &mut Scope) {
534    scope.register(Branch::Return);
535}
536
537// Don't make this `FnOnce`, it must be executable multiple times
538pub fn loop_expand(scope: &mut Scope, mut block: impl FnMut(&mut Scope)) {
539    let mut inside_loop = scope.child();
540
541    block(&mut inside_loop);
542    scope.register(Branch::Loop(Box::new(Loop { scope: inside_loop })));
543}