cubecl_core/frontend/
branch.rs

1use num_traits::NumCast;
2
3use crate::ir::{Branch, If, IfElse, Item, Loop, RangeLoop};
4use crate::{
5    frontend::{CubeContext, ExpandElement},
6    ir::Switch,
7};
8
9use super::{assign, CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric};
10
11/// Something that can be iterated on by a for loop. Currently only includes `Range`, `StepBy` and
12/// `Sequence`.
13pub trait Iterable<T: CubeType>: Sized {
14    /// Expand a runtime loop without unrolling
15    ///
16    /// # Arguments
17    /// * `context` - the expansion context
18    /// * `body` - the loop body to be executed repeatedly
19    fn expand(
20        self,
21        context: &mut CubeContext,
22        body: impl FnMut(&mut CubeContext, <T as CubeType>::ExpandType),
23    );
24    /// Expand an unrolled loop. The body should be invoced `n` times, where `n` is the number of
25    /// iterations.
26    ///
27    /// # Arguments
28    /// * `context` - the expansion context
29    /// * `body` - the loop body to be executed repeatedly
30    fn expand_unroll(
31        self,
32        context: &mut CubeContext,
33        body: impl FnMut(&mut CubeContext, <T as CubeType>::ExpandType),
34    );
35}
36
37pub struct RangeExpand<I: Int> {
38    pub start: ExpandElementTyped<I>,
39    pub end: ExpandElementTyped<I>,
40    pub inclusive: bool,
41}
42
43impl<I: Int> RangeExpand<I> {
44    pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
45        RangeExpand {
46            start,
47            end,
48            inclusive,
49        }
50    }
51
52    pub fn __expand_step_by_method(
53        self,
54        n: impl Into<ExpandElementTyped<u32>>,
55    ) -> SteppedRangeExpand<I> {
56        SteppedRangeExpand {
57            start: self.start,
58            end: self.end,
59            step: n.into(),
60            inclusive: self.inclusive,
61        }
62    }
63}
64
65impl<I: Int> Iterable<I> for RangeExpand<I> {
66    fn expand_unroll(
67        self,
68        context: &mut CubeContext,
69        mut body: impl FnMut(&mut CubeContext, <I as CubeType>::ExpandType),
70    ) {
71        let start = self
72            .start
73            .expand
74            .as_const()
75            .expect("Only constant start can be unrolled.")
76            .as_i64();
77        let end = self
78            .end
79            .expand
80            .as_const()
81            .expect("Only constant end can be unrolled.")
82            .as_i64();
83
84        if self.inclusive {
85            for i in start..=end {
86                let var = I::from_int(i);
87                body(context, var.into())
88            }
89        } else {
90            for i in start..end {
91                let var = I::from_int(i);
92                body(context, var.into())
93            }
94        }
95    }
96
97    fn expand(
98        self,
99        context: &mut CubeContext,
100        mut body: impl FnMut(&mut CubeContext, <I as CubeType>::ExpandType),
101    ) {
102        let mut child = context.child();
103        let index_ty = Item::new(I::as_elem(context));
104        let i = child.create_local_restricted(index_ty);
105
106        body(&mut child, i.clone().into());
107
108        context.register(Branch::RangeLoop(Box::new(RangeLoop {
109            i: *i,
110            start: *self.start.expand,
111            end: *self.end.expand,
112            step: None,
113            scope: child.into_scope(),
114            inclusive: self.inclusive,
115        })));
116    }
117}
118
119pub struct SteppedRangeExpand<I: Int> {
120    start: ExpandElementTyped<I>,
121    end: ExpandElementTyped<I>,
122    step: ExpandElementTyped<u32>,
123    inclusive: bool,
124}
125
126impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
127    fn expand(
128        self,
129        context: &mut CubeContext,
130        mut body: impl FnMut(&mut CubeContext, <I as CubeType>::ExpandType),
131    ) {
132        let mut child = context.child();
133        let index_ty = Item::new(I::as_elem(context));
134        let i = child.create_local_restricted(index_ty);
135
136        body(&mut child, i.clone().into());
137
138        context.register(Branch::RangeLoop(Box::new(RangeLoop {
139            i: *i,
140            start: *self.start.expand,
141            end: *self.end.expand,
142            step: Some(*self.step.expand),
143            scope: child.into_scope(),
144            inclusive: self.inclusive,
145        })));
146    }
147
148    fn expand_unroll(
149        self,
150        context: &mut CubeContext,
151        mut body: impl FnMut(&mut CubeContext, <I as CubeType>::ExpandType),
152    ) {
153        let start = self
154            .start
155            .expand
156            .as_const()
157            .expect("Only constant start can be unrolled.")
158            .as_i64();
159        let end = self
160            .end
161            .expand
162            .as_const()
163            .expect("Only constant end can be unrolled.")
164            .as_i64();
165        let step = self
166            .step
167            .expand
168            .as_const()
169            .expect("Only constant step can be unrolled.")
170            .as_usize();
171
172        if self.inclusive {
173            for i in (start..=end).step_by(step) {
174                let var = I::from_int(i);
175                body(context, var.into())
176            }
177        } else {
178            for i in (start..end).step_by(step) {
179                let var = I::from_int(i);
180                body(context, var.into())
181            }
182        }
183    }
184}
185
186/// integer range. Equivalent to:
187///
188/// ```ignore
189/// start..end
190/// ```
191pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
192    let start: i64 = start.to_i64().unwrap();
193    let end: i64 = end.to_i64().unwrap();
194    (start..end).map(<T as NumCast>::from).map(Option::unwrap)
195}
196
197pub mod range {
198    use crate::prelude::{CubeContext, ExpandElementTyped, Int};
199
200    use super::RangeExpand;
201
202    pub fn expand<I: Int>(
203        _context: &mut CubeContext,
204        start: ExpandElementTyped<I>,
205        end: ExpandElementTyped<I>,
206    ) -> RangeExpand<I> {
207        RangeExpand {
208            start,
209            end,
210            inclusive: false,
211        }
212    }
213}
214
215/// Stepped range. Equivalent to:
216///
217/// ```ignore
218/// (start..end).step_by(step)
219/// ```
220///
221/// Allows using any integer for the step, instead of just usize
222pub fn range_stepped<I: Int>(start: I, end: I, step: impl Int) -> impl Iterator<Item = I> {
223    let start = start.to_i64().unwrap();
224    let end = end.to_i64().unwrap();
225    let step = step.to_usize().unwrap();
226    (start..end)
227        .step_by(step)
228        .map(<I as NumCast>::from)
229        .map(Option::unwrap)
230}
231
232pub mod range_stepped {
233    use crate::prelude::{CubeContext, ExpandElementTyped, Int};
234
235    use super::SteppedRangeExpand;
236
237    pub fn expand<I: Int>(
238        _context: &mut CubeContext,
239        start: ExpandElementTyped<I>,
240        end: ExpandElementTyped<I>,
241        step: ExpandElementTyped<u32>,
242    ) -> SteppedRangeExpand<I> {
243        SteppedRangeExpand {
244            start,
245            end,
246            step,
247            inclusive: false,
248        }
249    }
250}
251
252pub fn for_expand<I: Numeric>(
253    context: &mut CubeContext,
254    range: impl Iterable<I>,
255    unroll: bool,
256    body: impl FnMut(&mut CubeContext, ExpandElementTyped<I>),
257) {
258    if unroll {
259        range.expand_unroll(context, body);
260    } else {
261        range.expand(context, body);
262    }
263}
264
265pub fn if_expand(
266    context: &mut CubeContext,
267    runtime_cond: ExpandElement,
268    block: impl FnOnce(&mut CubeContext),
269) {
270    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
271    match comptime_cond {
272        Some(cond) => {
273            if cond {
274                block(context);
275            }
276        }
277        None => {
278            let mut child = context.child();
279
280            block(&mut child);
281
282            context.register(Branch::If(Box::new(If {
283                cond: *runtime_cond,
284                scope: child.into_scope(),
285            })));
286        }
287    }
288}
289
290pub enum IfElseExpand {
291    ComptimeThen,
292    ComptimeElse,
293    Runtime {
294        runtime_cond: ExpandElement,
295        then_child: CubeContext,
296    },
297}
298
299impl IfElseExpand {
300    pub fn or_else(self, context: &mut CubeContext, else_block: impl FnOnce(&mut CubeContext)) {
301        match self {
302            Self::Runtime {
303                runtime_cond,
304                then_child,
305            } => {
306                let mut else_child = context.child();
307                else_block(&mut else_child);
308
309                context.register(Branch::IfElse(Box::new(IfElse {
310                    cond: *runtime_cond,
311                    scope_if: then_child.into_scope(),
312                    scope_else: else_child.into_scope(),
313                })));
314            }
315            Self::ComptimeElse => else_block(context),
316            Self::ComptimeThen => (),
317        }
318    }
319}
320
321pub fn if_else_expand(
322    context: &mut CubeContext,
323    runtime_cond: ExpandElement,
324    then_block: impl FnOnce(&mut CubeContext),
325) -> IfElseExpand {
326    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
327    match comptime_cond {
328        Some(true) => {
329            then_block(context);
330            IfElseExpand::ComptimeThen
331        }
332        Some(false) => IfElseExpand::ComptimeElse,
333        None => {
334            let mut then_child = context.child();
335            then_block(&mut then_child);
336
337            IfElseExpand::Runtime {
338                runtime_cond,
339                then_child,
340            }
341        }
342    }
343}
344
345pub enum IfElseExprExpand<C: CubeType> {
346    ComptimeThen(ExpandElementTyped<C>),
347    ComptimeElse,
348    Runtime {
349        runtime_cond: ExpandElement,
350        out: ExpandElementTyped<C>,
351        then_child: CubeContext,
352    },
353}
354
355impl<C: CubePrimitive> IfElseExprExpand<C> {
356    pub fn or_else(
357        self,
358        context: &mut CubeContext,
359        else_block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
360    ) -> ExpandElementTyped<C> {
361        match self {
362            Self::Runtime {
363                runtime_cond,
364                out,
365                then_child,
366            } => {
367                let mut else_child = context.child();
368                let ret = else_block(&mut else_child);
369                assign::expand(&mut else_child, ret, out.clone());
370
371                context.register(Branch::IfElse(Box::new(IfElse {
372                    cond: *runtime_cond,
373                    scope_if: then_child.into_scope(),
374                    scope_else: else_child.into_scope(),
375                })));
376                out
377            }
378            Self::ComptimeElse => else_block(context),
379            Self::ComptimeThen(ret) => ret,
380        }
381    }
382}
383
384pub fn if_else_expr_expand<C: CubePrimitive>(
385    context: &mut CubeContext,
386    runtime_cond: ExpandElement,
387    then_block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
388) -> IfElseExprExpand<C> {
389    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
390    match comptime_cond {
391        Some(true) => {
392            let ret = then_block(context);
393            IfElseExprExpand::ComptimeThen(ret)
394        }
395        Some(false) => IfElseExprExpand::ComptimeElse,
396        None => {
397            let mut then_child = context.child();
398            let ret = then_block(&mut then_child);
399            let out: ExpandElementTyped<C> = context.create_local_mut(ret.expand.item).into();
400            assign::expand(&mut then_child, ret, out.clone());
401
402            IfElseExprExpand::Runtime {
403                runtime_cond,
404                out,
405                then_child,
406            }
407        }
408    }
409}
410
411pub struct SwitchExpand<I: Int> {
412    value: ExpandElementTyped<I>,
413    default: CubeContext,
414    cases: Vec<(ExpandElementTyped<I>, CubeContext)>,
415}
416
417impl<I: Int> SwitchExpand<I> {
418    pub fn case(
419        mut self,
420        context: &mut CubeContext,
421        value: impl Int,
422        block: impl FnOnce(&mut CubeContext),
423    ) -> Self {
424        let value = I::from(value).unwrap();
425        let mut case_child = context.child();
426        block(&mut case_child);
427        self.cases.push((value.into(), case_child));
428        self
429    }
430
431    pub fn finish(self, context: &mut CubeContext) {
432        let value_var = *self.value.expand;
433        context.register(Branch::Switch(Box::new(Switch {
434            value: value_var,
435            scope_default: self.default.into_scope(),
436            cases: self
437                .cases
438                .into_iter()
439                .map(|it| (*it.0.expand, it.1.into_scope()))
440                .collect(),
441        })));
442    }
443}
444
445pub fn switch_expand<I: Int>(
446    context: &mut CubeContext,
447    value: ExpandElementTyped<I>,
448    default_block: impl FnOnce(&mut CubeContext),
449) -> SwitchExpand<I> {
450    let mut default_child = context.child();
451    default_block(&mut default_child);
452
453    SwitchExpand {
454        value,
455        default: default_child,
456        cases: Vec::new(),
457    }
458}
459
460pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
461    value: ExpandElementTyped<I>,
462    out: ExpandElementTyped<C>,
463    default: CubeContext,
464    cases: Vec<(ExpandElementTyped<I>, CubeContext)>,
465}
466
467impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
468    pub fn case(
469        mut self,
470        context: &mut CubeContext,
471        value: impl Int,
472        block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
473    ) -> Self {
474        let value = I::from(value).unwrap();
475        let mut case_child = context.child();
476        let ret = block(&mut case_child);
477        assign::expand(&mut case_child, ret, self.out.clone());
478        self.cases.push((value.into(), case_child));
479        self
480    }
481
482    pub fn finish(self, context: &mut CubeContext) -> ExpandElementTyped<C> {
483        let value_var = *self.value.expand;
484        context.register(Branch::Switch(Box::new(Switch {
485            value: value_var,
486            scope_default: self.default.into_scope(),
487            cases: self
488                .cases
489                .into_iter()
490                .map(|it| (*it.0.expand, it.1.into_scope()))
491                .collect(),
492        })));
493        self.out
494    }
495}
496
497pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
498    context: &mut CubeContext,
499    value: ExpandElementTyped<I>,
500    default_block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
501) -> SwitchExpandExpr<I, C> {
502    let mut default_child = context.child();
503    let default = default_block(&mut default_child);
504    let out: ExpandElementTyped<C> = context.create_local_mut(default.expand.item).into();
505    assign::expand(&mut default_child, default, out.clone());
506
507    SwitchExpandExpr {
508        value,
509        out,
510        default: default_child,
511        cases: Vec::new(),
512    }
513}
514
515pub fn break_expand(context: &mut CubeContext) {
516    context.register(Branch::Break);
517}
518
519pub fn return_expand(context: &mut CubeContext) {
520    context.register(Branch::Return);
521}
522
523// Don't make this `FnOnce`, it must be executable multiple times
524pub fn loop_expand(context: &mut CubeContext, mut block: impl FnMut(&mut CubeContext)) {
525    let mut inside_loop = context.child();
526
527    block(&mut inside_loop);
528    context.register(Branch::Loop(Box::new(Loop {
529        scope: inside_loop.into_scope(),
530    })));
531}