cubecl_core/frontend/
branch.rs

1use cubecl_ir::ExpandElement;
2use num_traits::NumCast;
3
4use crate::ir::Switch;
5use crate::ir::{Branch, If, IfElse, Loop, RangeLoop, Scope, Type};
6
7use super::{CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric, assign};
8
/// Something that can be iterated on by a for loop. Currently only includes `Range`, `StepBy` and
/// `Sequence`.
pub trait Iterable<T: CubeType>: Sized {
    /// Expand a runtime loop without unrolling.
    ///
    /// # Arguments
    /// * `scope` - the expansion scope
    /// * `body` - the loop body to be executed repeatedly
    fn expand(self, scope: &mut Scope, body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType));
    /// Expand an unrolled loop. The body should be invoked `n` times, where `n` is the number of
    /// iterations.
    ///
    /// # Arguments
    /// * `scope` - the expansion scope
    /// * `body` - the loop body to be executed repeatedly
    fn expand_unroll(
        self,
        scope: &mut Scope,
        body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
    );
    /// Return the comptime length of this iterable, if possible.
    ///
    /// The default implementation reports the length as unknown (`None`).
    fn const_len(&self) -> Option<usize> {
        None
    }
}
35
/// Expansion-side description of an integer range used as a loop iterable.
pub struct RangeExpand<I: Int> {
    /// First index of the range.
    pub start: ExpandElementTyped<I>,
    /// Upper bound of the range; whether it is visited depends on `inclusive`.
    pub end: ExpandElementTyped<I>,
    /// When `true` the range is `start..=end`, otherwise `start..end`.
    pub inclusive: bool,
}
41
42impl<I: Int> RangeExpand<I> {
43    pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
44        RangeExpand {
45            start,
46            end,
47            inclusive,
48        }
49    }
50
51    pub fn __expand_step_by_method(
52        self,
53        n: impl Into<ExpandElementTyped<I>>,
54    ) -> SteppedRangeExpand<I> {
55        SteppedRangeExpand {
56            start: self.start,
57            end: self.end,
58            step: n.into(),
59            inclusive: self.inclusive,
60        }
61    }
62}
63
64impl<I: Int> Iterable<I> for RangeExpand<I> {
65    fn expand_unroll(
66        self,
67        scope: &mut Scope,
68        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
69    ) {
70        let start = self
71            .start
72            .expand
73            .as_const()
74            .expect("Only constant start can be unrolled.")
75            .as_i64();
76        let end = self
77            .end
78            .expand
79            .as_const()
80            .expect("Only constant end can be unrolled.")
81            .as_i64();
82
83        if self.inclusive {
84            for i in start..=end {
85                let var = I::from_int(i);
86                body(scope, var.into())
87            }
88        } else {
89            for i in start..end {
90                let var = I::from_int(i);
91                body(scope, var.into())
92            }
93        }
94    }
95
96    fn expand(
97        self,
98        scope: &mut Scope,
99        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
100    ) {
101        let mut child = scope.child();
102        let index_ty = Type::new(I::as_type(scope));
103        let i = child.create_local_restricted(index_ty);
104
105        body(&mut child, i.clone().into());
106
107        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
108            i: *i,
109            start: *self.start.expand,
110            end: *self.end.expand,
111            step: None,
112            scope: child,
113            inclusive: self.inclusive,
114        })));
115    }
116
117    fn const_len(&self) -> Option<usize> {
118        let start = self.start.expand.as_const()?.as_i64();
119        let end = self.end.expand.as_const()?.as_i64();
120        Some(start.abs_diff(end) as usize)
121    }
122}
123
/// Expansion-side description of an integer range that advances by a custom step.
pub struct SteppedRangeExpand<I: Int> {
    /// First bound of the range.
    start: ExpandElementTyped<I>,
    /// Second bound of the range; whether it is visited depends on `inclusive`.
    end: ExpandElementTyped<I>,
    /// Distance between two consecutive indices; may be negative.
    step: ExpandElementTyped<I>,
    /// Whether the range is inclusive rather than half-open.
    inclusive: bool,
}
130
131impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
132    fn expand(
133        self,
134        scope: &mut Scope,
135        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
136    ) {
137        let mut child = scope.child();
138        let index_ty = Type::new(I::as_type(scope));
139        let i = child.create_local_restricted(index_ty);
140
141        body(&mut child, i.clone().into());
142
143        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
144            i: *i,
145            start: *self.start.expand,
146            end: *self.end.expand,
147            step: Some(*self.step.expand),
148            scope: child,
149            inclusive: self.inclusive,
150        })));
151    }
152
153    fn expand_unroll(
154        self,
155        scope: &mut Scope,
156        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
157    ) {
158        let start = self
159            .start
160            .expand
161            .as_const()
162            .expect("Only constant start can be unrolled.")
163            .as_i128();
164        let end = self
165            .end
166            .expand
167            .as_const()
168            .expect("Only constant end can be unrolled.")
169            .as_i128();
170        let step = self
171            .step
172            .expand
173            .as_const()
174            .expect("Only constant step can be unrolled.")
175            .as_i128();
176
177        match (self.inclusive, step.is_negative()) {
178            (true, true) => {
179                for i in (end..=start).rev().step_by(step.unsigned_abs() as usize) {
180                    let var = I::from_int_128(i);
181                    body(scope, var.into())
182                }
183            }
184            (true, false) => {
185                for i in (start..=end).step_by(step.unsigned_abs() as usize) {
186                    let var = I::from_int_128(i);
187                    body(scope, var.into())
188                }
189            }
190            (false, true) => {
191                for i in (end..start).rev().step_by(step.unsigned_abs() as usize) {
192                    let var = I::from_int_128(i);
193                    body(scope, var.into())
194                }
195            }
196            (false, false) => {
197                for i in (start..end).step_by(step.unsigned_abs() as usize) {
198                    let var = I::from_int_128(i);
199                    body(scope, var.into())
200                }
201            }
202        }
203    }
204
205    fn const_len(&self) -> Option<usize> {
206        let start = self.start.constant()?.as_i128();
207        let end = self.end.constant()?.as_i128();
208        let step = self.step.constant()?.as_i128().unsigned_abs();
209        Some((start.abs_diff(end) / step) as usize)
210    }
211}
212
213/// integer range. Equivalent to:
214///
215/// ```ignore
216/// start..end
217/// ```
218pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
219    let start: i64 = start.to_i64().unwrap();
220    let end: i64 = end.to_i64().unwrap();
221    (start..end).map(<T as NumCast>::from).map(Option::unwrap)
222}
223
224pub mod range {
225    use cubecl_ir::Scope;
226
227    use crate::prelude::{ExpandElementTyped, Int};
228
229    use super::RangeExpand;
230
231    pub fn expand<I: Int>(
232        _scope: &mut Scope,
233        start: ExpandElementTyped<I>,
234        end: ExpandElementTyped<I>,
235    ) -> RangeExpand<I> {
236        RangeExpand {
237            start,
238            end,
239            inclusive: false,
240        }
241    }
242}
243
244/// Stepped range. Equivalent to:
245///
246/// ```ignore
247/// (start..end).step_by(step)
248/// ```
249///
250/// Allows using any integer for the step, instead of just usize
251pub fn range_stepped<I: Int>(start: I, end: I, step: I) -> Box<dyn Iterator<Item = I>> {
252    let start = start.to_i128().unwrap();
253    let end = end.to_i128().unwrap();
254    let step = step.to_i128().unwrap();
255
256    if step < 0 {
257        Box::new(
258            (end..start)
259                .rev()
260                .step_by(step.unsigned_abs() as usize)
261                .map(<I as NumCast>::from)
262                .map(Option::unwrap),
263        )
264    } else {
265        Box::new(
266            (start..end)
267                .step_by(step.unsigned_abs() as usize)
268                .map(<I as NumCast>::from)
269                .map(Option::unwrap),
270        )
271    }
272}
273
274pub mod range_stepped {
275    use cubecl_ir::Scope;
276
277    use crate::prelude::{ExpandElementTyped, Int};
278
279    use super::SteppedRangeExpand;
280
281    pub fn expand<I: Int>(
282        _scope: &mut Scope,
283        start: ExpandElementTyped<I>,
284        end: ExpandElementTyped<I>,
285        step: ExpandElementTyped<I>,
286    ) -> SteppedRangeExpand<I> {
287        SteppedRangeExpand {
288            start,
289            end,
290            step,
291            inclusive: false,
292        }
293    }
294}
295
296pub fn for_expand<I: Numeric>(
297    scope: &mut Scope,
298    range: impl Iterable<I>,
299    unroll: bool,
300    body: impl FnMut(&mut Scope, ExpandElementTyped<I>),
301) {
302    if unroll || range.const_len() == Some(1) {
303        range.expand_unroll(scope, body);
304    } else {
305        range.expand(scope, body);
306    }
307}
308
309pub fn if_expand(scope: &mut Scope, runtime_cond: ExpandElement, block: impl FnOnce(&mut Scope)) {
310    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
311    match comptime_cond {
312        Some(cond) => {
313            if cond {
314                block(scope);
315            }
316        }
317        None => {
318            let mut child = scope.child();
319
320            block(&mut child);
321
322            scope.register(Branch::If(Box::new(If {
323                cond: *runtime_cond,
324                scope: child,
325            })));
326        }
327    }
328}
329
/// Intermediate state of an `if`/`else` statement after its `then` block has been handled,
/// waiting for `or_else` to supply the `else` block.
#[allow(clippy::large_enum_variant)]
pub enum IfElseExpand {
    /// The condition was a constant `true`; the `then` block was already expanded.
    ComptimeThen,
    /// The condition was a constant `false`; only the `else` block must be expanded.
    ComptimeElse,
    /// The condition is only known at runtime.
    Runtime {
        /// Condition to branch on.
        runtime_cond: ExpandElement,
        /// Child scope holding the expanded `then` block.
        then_child: Scope,
    },
}
339
340impl IfElseExpand {
341    pub fn or_else(self, scope: &mut Scope, else_block: impl FnOnce(&mut Scope)) {
342        match self {
343            Self::Runtime {
344                runtime_cond,
345                then_child,
346            } => {
347                let mut else_child = scope.child();
348                else_block(&mut else_child);
349
350                scope.register(Branch::IfElse(Box::new(IfElse {
351                    cond: *runtime_cond,
352                    scope_if: then_child,
353                    scope_else: else_child,
354                })));
355            }
356            Self::ComptimeElse => else_block(scope),
357            Self::ComptimeThen => (),
358        }
359    }
360}
361
362pub fn if_else_expand(
363    scope: &mut Scope,
364    runtime_cond: ExpandElement,
365    then_block: impl FnOnce(&mut Scope),
366) -> IfElseExpand {
367    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
368    match comptime_cond {
369        Some(true) => {
370            then_block(scope);
371            IfElseExpand::ComptimeThen
372        }
373        Some(false) => IfElseExpand::ComptimeElse,
374        None => {
375            let mut then_child = scope.child();
376            then_block(&mut then_child);
377
378            IfElseExpand::Runtime {
379                runtime_cond,
380                then_child,
381            }
382        }
383    }
384}
385
/// Intermediate state of an `if`/`else` expression after its `then` block has been handled,
/// waiting for `or_else` to supply the `else` block.
#[allow(clippy::large_enum_variant)]
pub enum IfElseExprExpand<C: CubeType> {
    /// The condition was a constant `true`; holds the value produced by the `then` block.
    ComptimeThen(ExpandElementTyped<C>),
    /// The condition was a constant `false`; the value will come from the `else` block.
    ComptimeElse,
    /// The condition is only known at runtime.
    Runtime {
        /// Condition to branch on.
        runtime_cond: ExpandElement,
        /// Variable that both branches assign their result to.
        out: ExpandElementTyped<C>,
        /// Child scope holding the expanded `then` block.
        then_child: Scope,
    },
}
396
397impl<C: CubePrimitive> IfElseExprExpand<C> {
398    pub fn or_else(
399        self,
400        scope: &mut Scope,
401        else_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
402    ) -> ExpandElementTyped<C> {
403        match self {
404            Self::Runtime {
405                runtime_cond,
406                out,
407                then_child,
408            } => {
409                let mut else_child = scope.child();
410                let ret = else_block(&mut else_child);
411                assign::expand_no_check::<C>(&mut else_child, ret, out.clone());
412
413                scope.register(Branch::IfElse(Box::new(IfElse {
414                    cond: *runtime_cond,
415                    scope_if: then_child,
416                    scope_else: else_child,
417                })));
418                out
419            }
420            Self::ComptimeElse => else_block(scope),
421            Self::ComptimeThen(ret) => ret,
422        }
423    }
424}
425
426pub fn if_else_expr_expand<C: CubePrimitive>(
427    scope: &mut Scope,
428    runtime_cond: ExpandElement,
429    then_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
430) -> IfElseExprExpand<C> {
431    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
432    match comptime_cond {
433        Some(true) => {
434            let ret = then_block(scope);
435            IfElseExprExpand::ComptimeThen(ret)
436        }
437        Some(false) => IfElseExprExpand::ComptimeElse,
438        None => {
439            let mut then_child = scope.child();
440            let ret = then_block(&mut then_child);
441            let out: ExpandElementTyped<C> = scope.create_local_mut(ret.expand.ty).into();
442            assign::expand_no_check::<C>(&mut then_child, ret, out.clone());
443
444            IfElseExprExpand::Runtime {
445                runtime_cond,
446                out,
447                then_child,
448            }
449        }
450    }
451}
452
/// Builder for a `switch` statement: collects the cases until `finish` registers the branch.
pub struct SwitchExpand<I: Int> {
    /// Value being switched on.
    value: ExpandElementTyped<I>,
    /// Scope holding the expanded default block.
    default: Scope,
    /// Case values paired with the scope holding their expanded block.
    cases: Vec<(ExpandElementTyped<I>, Scope)>,
}
458
459impl<I: Int> SwitchExpand<I> {
460    pub fn case(
461        mut self,
462        scope: &mut Scope,
463        value: impl Int,
464        block: impl FnOnce(&mut Scope),
465    ) -> Self {
466        let value = I::from(value).unwrap();
467        let mut case_child = scope.child();
468        block(&mut case_child);
469        self.cases.push((value.into(), case_child));
470        self
471    }
472
473    pub fn finish(self, scope: &mut Scope) {
474        let value_var = *self.value.expand;
475        scope.register(Branch::Switch(Box::new(Switch {
476            value: value_var,
477            scope_default: self.default,
478            cases: self
479                .cases
480                .into_iter()
481                .map(|it| (*it.0.expand, it.1))
482                .collect(),
483        })));
484    }
485}
486
487pub fn switch_expand<I: Int>(
488    scope: &mut Scope,
489    value: ExpandElementTyped<I>,
490    default_block: impl FnOnce(&mut Scope),
491) -> SwitchExpand<I> {
492    let mut default_child = scope.child();
493    default_block(&mut default_child);
494
495    SwitchExpand {
496        value,
497        default: default_child,
498        cases: Vec::new(),
499    }
500}
501
/// Builder for a `switch` expression: like `SwitchExpand`, but every block produces a value
/// that is assigned to `out`.
pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
    /// Value being switched on.
    value: ExpandElementTyped<I>,
    /// Variable that receives the result of whichever block is taken.
    out: ExpandElementTyped<C>,
    /// Scope holding the expanded default block.
    default: Scope,
    /// Case values paired with the scope holding their expanded block.
    cases: Vec<(ExpandElementTyped<I>, Scope)>,
}
508
509impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
510    pub fn case(
511        mut self,
512        scope: &mut Scope,
513        value: impl Int,
514        block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
515    ) -> Self {
516        let value = I::from(value).unwrap();
517        let mut case_child = scope.child();
518        let ret = block(&mut case_child);
519        assign::expand_no_check::<C>(&mut case_child, ret, self.out.clone());
520        self.cases.push((value.into(), case_child));
521        self
522    }
523
524    pub fn finish(self, scope: &mut Scope) -> ExpandElementTyped<C> {
525        let value_var = *self.value.expand;
526        scope.register(Branch::Switch(Box::new(Switch {
527            value: value_var,
528            scope_default: self.default,
529            cases: self
530                .cases
531                .into_iter()
532                .map(|it| (*it.0.expand, it.1))
533                .collect(),
534        })));
535        self.out
536    }
537}
538
539pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
540    scope: &mut Scope,
541    value: ExpandElementTyped<I>,
542    default_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
543) -> SwitchExpandExpr<I, C> {
544    let mut default_child = scope.child();
545    let default = default_block(&mut default_child);
546    let out: ExpandElementTyped<C> = scope.create_local_mut(default.expand.ty).into();
547    assign::expand_no_check::<C>(&mut default_child, default, out.clone());
548
549    SwitchExpandExpr {
550        value,
551        out,
552        default: default_child,
553        cases: Vec::new(),
554    }
555}
556
557pub fn break_expand(scope: &mut Scope) {
558    scope.register(Branch::Break);
559}
560
561pub fn return_expand(scope: &mut Scope) {
562    scope.register(Branch::Return);
563}
564
565// Don't make this `FnOnce`, it must be executable multiple times
566pub fn loop_expand(scope: &mut Scope, mut block: impl FnMut(&mut Scope)) {
567    let mut inside_loop = scope.child();
568
569    block(&mut inside_loop);
570    scope.register(Branch::Loop(Box::new(Loop { scope: inside_loop })));
571}