Skip to main content

cubecl_core/frontend/
branch.rs

1use alloc::{boxed::Box, vec::Vec};
2use cubecl_ir::ExpandElement;
3use num_traits::NumCast;
4
5use crate::ir::Switch;
6use crate::ir::{Branch, If, IfElse, Loop, RangeLoop, Scope, Type};
7
8use super::{CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric, assign};
9
10/// Something that can be iterated on by a for loop. Currently only includes `Range`, `StepBy` and
11/// `Sequence`.
12pub trait Iterable<T: CubeType>: Sized {
13    /// Expand a runtime loop without unrolling
14    ///
15    /// # Arguments
16    /// # Arguments
17    /// * `scope` - the expansion scope
18    /// * `body` - the loop body to be executed repeatedly
19    fn expand(self, scope: &mut Scope, body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType));
20    /// Expand an unrolled loop. The body should be invoced `n` times, where `n` is the number of
21    /// iterations.
22    ///
23    /// # Arguments
24    /// * `scope` - the expansion scope
25    /// * `body` - the loop body to be executed repeatedly
26    fn expand_unroll(
27        self,
28        scope: &mut Scope,
29        body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
30    );
31    /// Return the comptime length of this iterable, if possible
32    fn const_len(&self) -> Option<usize> {
33        None
34    }
35}
36
37pub struct RangeExpand<I: Int> {
38    pub start: ExpandElementTyped<I>,
39    pub end: ExpandElementTyped<I>,
40    pub inclusive: bool,
41}
42
43impl<I: Int> RangeExpand<I> {
44    pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
45        RangeExpand {
46            start,
47            end,
48            inclusive,
49        }
50    }
51
52    pub fn __expand_step_by_method(
53        self,
54        n: impl Into<ExpandElementTyped<I>>,
55    ) -> SteppedRangeExpand<I> {
56        SteppedRangeExpand {
57            start: self.start,
58            end: self.end,
59            step: n.into(),
60            inclusive: self.inclusive,
61        }
62    }
63}
64
65impl<I: Int> Iterable<I> for RangeExpand<I> {
66    fn expand_unroll(
67        self,
68        scope: &mut Scope,
69        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
70    ) {
71        let start = self
72            .start
73            .expand
74            .as_const()
75            .expect("Only constant start can be unrolled.")
76            .as_i64();
77        let end = self
78            .end
79            .expand
80            .as_const()
81            .expect("Only constant end can be unrolled.")
82            .as_i64();
83
84        if self.inclusive {
85            for i in start..=end {
86                let var = I::from_int(i);
87                body(scope, var.into())
88            }
89        } else {
90            for i in start..end {
91                let var = I::from_int(i);
92                body(scope, var.into())
93            }
94        }
95    }
96
97    fn expand(
98        self,
99        scope: &mut Scope,
100        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
101    ) {
102        let mut child = scope.child();
103        let index_ty = Type::new(I::as_type(scope));
104        let i = child.create_local_restricted(index_ty);
105
106        body(&mut child, i.clone().into());
107
108        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
109            i: *i,
110            start: *self.start.expand,
111            end: *self.end.expand,
112            step: None,
113            scope: child,
114            inclusive: self.inclusive,
115        })));
116    }
117
118    fn const_len(&self) -> Option<usize> {
119        let start = self.start.expand.as_const()?.as_i64();
120        let end = self.end.expand.as_const()?.as_i64();
121        Some(start.abs_diff(end) as usize)
122    }
123}
124
125pub struct SteppedRangeExpand<I: Int> {
126    start: ExpandElementTyped<I>,
127    end: ExpandElementTyped<I>,
128    step: ExpandElementTyped<I>,
129    inclusive: bool,
130}
131
132impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
133    fn expand(
134        self,
135        scope: &mut Scope,
136        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
137    ) {
138        let mut child = scope.child();
139        let index_ty = Type::new(I::as_type(scope));
140        let i = child.create_local_restricted(index_ty);
141
142        body(&mut child, i.clone().into());
143
144        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
145            i: *i,
146            start: *self.start.expand,
147            end: *self.end.expand,
148            step: Some(*self.step.expand),
149            scope: child,
150            inclusive: self.inclusive,
151        })));
152    }
153
154    fn expand_unroll(
155        self,
156        scope: &mut Scope,
157        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
158    ) {
159        let start = self
160            .start
161            .expand
162            .as_const()
163            .expect("Only constant start can be unrolled.")
164            .as_i128();
165        let end = self
166            .end
167            .expand
168            .as_const()
169            .expect("Only constant end can be unrolled.")
170            .as_i128();
171        let step = self
172            .step
173            .expand
174            .as_const()
175            .expect("Only constant step can be unrolled.")
176            .as_i128();
177
178        match (self.inclusive, step.is_negative()) {
179            (true, true) => {
180                for i in (end..=start).rev().step_by(step.unsigned_abs() as usize) {
181                    let var = I::from_int_128(i);
182                    body(scope, var.into())
183                }
184            }
185            (true, false) => {
186                for i in (start..=end).step_by(step.unsigned_abs() as usize) {
187                    let var = I::from_int_128(i);
188                    body(scope, var.into())
189                }
190            }
191            (false, true) => {
192                for i in (end..start).rev().step_by(step.unsigned_abs() as usize) {
193                    let var = I::from_int_128(i);
194                    body(scope, var.into())
195                }
196            }
197            (false, false) => {
198                for i in (start..end).step_by(step.unsigned_abs() as usize) {
199                    let var = I::from_int_128(i);
200                    body(scope, var.into())
201                }
202            }
203        }
204    }
205
206    fn const_len(&self) -> Option<usize> {
207        let start = self.start.constant()?.as_i128();
208        let end = self.end.constant()?.as_i128();
209        let step = self.step.constant()?.as_i128().unsigned_abs();
210        Some((start.abs_diff(end) / step) as usize)
211    }
212}
213
214/// integer range. Equivalent to:
215///
216/// ```ignore
217/// start..end
218/// ```
219pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
220    let start: i64 = start.to_i64().unwrap();
221    let end: i64 = end.to_i64().unwrap();
222    (start..end).map(<T as NumCast>::from).map(Option::unwrap)
223}
224
225pub mod range {
226    use cubecl_ir::Scope;
227
228    use crate::prelude::{ExpandElementTyped, Int};
229
230    use super::RangeExpand;
231
232    pub fn expand<I: Int>(
233        _scope: &mut Scope,
234        start: ExpandElementTyped<I>,
235        end: ExpandElementTyped<I>,
236    ) -> RangeExpand<I> {
237        RangeExpand {
238            start,
239            end,
240            inclusive: false,
241        }
242    }
243}
244
245/// Stepped range. Equivalent to:
246///
247/// ```ignore
248/// (start..end).step_by(step)
249/// ```
250///
251/// Allows using any integer for the step, instead of just usize
252pub fn range_stepped<I: Int>(start: I, end: I, step: I) -> Box<dyn Iterator<Item = I>> {
253    let start = start.to_i128().unwrap();
254    let end = end.to_i128().unwrap();
255    let step = step.to_i128().unwrap();
256
257    if step < 0 {
258        Box::new(
259            (end..start)
260                .rev()
261                .step_by(step.unsigned_abs() as usize)
262                .map(<I as NumCast>::from)
263                .map(Option::unwrap),
264        )
265    } else {
266        Box::new(
267            (start..end)
268                .step_by(step.unsigned_abs() as usize)
269                .map(<I as NumCast>::from)
270                .map(Option::unwrap),
271        )
272    }
273}
274
275pub mod range_stepped {
276    use cubecl_ir::Scope;
277
278    use crate::prelude::{ExpandElementTyped, Int};
279
280    use super::SteppedRangeExpand;
281
282    pub fn expand<I: Int>(
283        _scope: &mut Scope,
284        start: ExpandElementTyped<I>,
285        end: ExpandElementTyped<I>,
286        step: ExpandElementTyped<I>,
287    ) -> SteppedRangeExpand<I> {
288        SteppedRangeExpand {
289            start,
290            end,
291            step,
292            inclusive: false,
293        }
294    }
295}
296
297pub fn for_expand<I: Numeric>(
298    scope: &mut Scope,
299    range: impl Iterable<I>,
300    unroll: bool,
301    body: impl FnMut(&mut Scope, ExpandElementTyped<I>),
302) {
303    if unroll || range.const_len() == Some(1) {
304        range.expand_unroll(scope, body);
305    } else {
306        range.expand(scope, body);
307    }
308}
309
310pub fn if_expand(scope: &mut Scope, runtime_cond: ExpandElement, block: impl FnOnce(&mut Scope)) {
311    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
312    match comptime_cond {
313        Some(cond) => {
314            if cond {
315                block(scope);
316            }
317        }
318        None => {
319            let mut child = scope.child();
320
321            block(&mut child);
322
323            scope.register(Branch::If(Box::new(If {
324                cond: *runtime_cond,
325                scope: child,
326            })));
327        }
328    }
329}
330
331#[allow(clippy::large_enum_variant)]
332pub enum IfElseExpand {
333    ComptimeThen,
334    ComptimeElse,
335    Runtime {
336        runtime_cond: ExpandElement,
337        then_child: Scope,
338    },
339}
340
341impl IfElseExpand {
342    pub fn or_else(self, scope: &mut Scope, else_block: impl FnOnce(&mut Scope)) {
343        match self {
344            Self::Runtime {
345                runtime_cond,
346                then_child,
347            } => {
348                let mut else_child = scope.child();
349                else_block(&mut else_child);
350
351                scope.register(Branch::IfElse(Box::new(IfElse {
352                    cond: *runtime_cond,
353                    scope_if: then_child,
354                    scope_else: else_child,
355                })));
356            }
357            Self::ComptimeElse => else_block(scope),
358            Self::ComptimeThen => (),
359        }
360    }
361}
362
363pub fn if_else_expand(
364    scope: &mut Scope,
365    runtime_cond: ExpandElement,
366    then_block: impl FnOnce(&mut Scope),
367) -> IfElseExpand {
368    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
369    match comptime_cond {
370        Some(true) => {
371            then_block(scope);
372            IfElseExpand::ComptimeThen
373        }
374        Some(false) => IfElseExpand::ComptimeElse,
375        None => {
376            let mut then_child = scope.child();
377            then_block(&mut then_child);
378
379            IfElseExpand::Runtime {
380                runtime_cond,
381                then_child,
382            }
383        }
384    }
385}
386
387#[allow(clippy::large_enum_variant)]
388pub enum IfElseExprExpand<C: CubeType> {
389    ComptimeThen(ExpandElementTyped<C>),
390    ComptimeElse,
391    Runtime {
392        runtime_cond: ExpandElement,
393        out: ExpandElementTyped<C>,
394        then_child: Scope,
395    },
396}
397
398impl<C: CubePrimitive> IfElseExprExpand<C> {
399    pub fn or_else(
400        self,
401        scope: &mut Scope,
402        else_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
403    ) -> ExpandElementTyped<C> {
404        match self {
405            Self::Runtime {
406                runtime_cond,
407                out,
408                then_child,
409            } => {
410                let mut else_child = scope.child();
411                let ret = else_block(&mut else_child);
412                assign::expand_no_check::<C>(&mut else_child, ret, out.clone());
413
414                scope.register(Branch::IfElse(Box::new(IfElse {
415                    cond: *runtime_cond,
416                    scope_if: then_child,
417                    scope_else: else_child,
418                })));
419                out
420            }
421            Self::ComptimeElse => else_block(scope),
422            Self::ComptimeThen(ret) => ret,
423        }
424    }
425}
426
427pub fn if_else_expr_expand<C: CubePrimitive>(
428    scope: &mut Scope,
429    runtime_cond: ExpandElement,
430    then_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
431) -> IfElseExprExpand<C> {
432    let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
433    match comptime_cond {
434        Some(true) => {
435            let ret = then_block(scope);
436            IfElseExprExpand::ComptimeThen(ret)
437        }
438        Some(false) => IfElseExprExpand::ComptimeElse,
439        None => {
440            let mut then_child = scope.child();
441            let ret = then_block(&mut then_child);
442            let out: ExpandElementTyped<C> = scope.create_local_mut(ret.expand.ty).into();
443            assign::expand_no_check::<C>(&mut then_child, ret, out.clone());
444
445            IfElseExprExpand::Runtime {
446                runtime_cond,
447                out,
448                then_child,
449            }
450        }
451    }
452}
453
454pub struct SwitchExpand<I: Int> {
455    value: ExpandElementTyped<I>,
456    default: Scope,
457    cases: Vec<(ExpandElementTyped<I>, Scope)>,
458}
459
460impl<I: Int> SwitchExpand<I> {
461    pub fn case(
462        mut self,
463        scope: &mut Scope,
464        value: impl Int,
465        block: impl FnOnce(&mut Scope),
466    ) -> Self {
467        let value = I::from(value).unwrap();
468        let mut case_child = scope.child();
469        block(&mut case_child);
470        self.cases.push((value.into(), case_child));
471        self
472    }
473
474    pub fn finish(self, scope: &mut Scope) {
475        let value_var = *self.value.expand;
476        scope.register(Branch::Switch(Box::new(Switch {
477            value: value_var,
478            scope_default: self.default,
479            cases: self
480                .cases
481                .into_iter()
482                .map(|it| (*it.0.expand, it.1))
483                .collect(),
484        })));
485    }
486}
487
488pub fn switch_expand<I: Int>(
489    scope: &mut Scope,
490    value: ExpandElementTyped<I>,
491    default_block: impl FnOnce(&mut Scope),
492) -> SwitchExpand<I> {
493    let mut default_child = scope.child();
494    default_block(&mut default_child);
495
496    SwitchExpand {
497        value,
498        default: default_child,
499        cases: Vec::new(),
500    }
501}
502
503pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
504    value: ExpandElementTyped<I>,
505    out: ExpandElementTyped<C>,
506    default: Scope,
507    cases: Vec<(ExpandElementTyped<I>, Scope)>,
508}
509
510impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
511    pub fn case(
512        mut self,
513        scope: &mut Scope,
514        value: impl Int,
515        block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
516    ) -> Self {
517        let value = I::from(value).unwrap();
518        let mut case_child = scope.child();
519        let ret = block(&mut case_child);
520        assign::expand_no_check::<C>(&mut case_child, ret, self.out.clone());
521        self.cases.push((value.into(), case_child));
522        self
523    }
524
525    pub fn finish(self, scope: &mut Scope) -> ExpandElementTyped<C> {
526        let value_var = *self.value.expand;
527        scope.register(Branch::Switch(Box::new(Switch {
528            value: value_var,
529            scope_default: self.default,
530            cases: self
531                .cases
532                .into_iter()
533                .map(|it| (*it.0.expand, it.1))
534                .collect(),
535        })));
536        self.out
537    }
538}
539
540pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
541    scope: &mut Scope,
542    value: ExpandElementTyped<I>,
543    default_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
544) -> SwitchExpandExpr<I, C> {
545    let mut default_child = scope.child();
546    let default = default_block(&mut default_child);
547    let out: ExpandElementTyped<C> = scope.create_local_mut(default.expand.ty).into();
548    assign::expand_no_check::<C>(&mut default_child, default, out.clone());
549
550    SwitchExpandExpr {
551        value,
552        out,
553        default: default_child,
554        cases: Vec::new(),
555    }
556}
557
558pub fn break_expand(scope: &mut Scope) {
559    scope.register(Branch::Break);
560}
561
562pub fn return_expand(scope: &mut Scope) {
563    scope.register(Branch::Return);
564}
565
566// Don't make this `FnOnce`, it must be executable multiple times
567pub fn loop_expand(scope: &mut Scope, mut block: impl FnMut(&mut Scope)) {
568    let mut inside_loop = scope.child();
569
570    block(&mut inside_loop);
571    scope.register(Branch::Loop(Box::new(Loop { scope: inside_loop })));
572}