1use alloc::{boxed::Box, vec::Vec};
2use cubecl_ir::ManagedVariable;
3use num_traits::NumCast;
4
5use crate::{ir::Switch, prelude::CubeEnum};
6use crate::{
7 ir::{Branch, If, IfElse, Loop, RangeLoop, Scope},
8 prelude::Assign,
9};
10
11use super::{CubeType, Int, NativeExpand, Numeric};
12
13pub trait Iterable<T: CubeType>: Sized {
16 fn expand(self, scope: &mut Scope, body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType));
23 fn expand_unroll(
30 self,
31 scope: &mut Scope,
32 body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
33 );
34 fn const_len(&self) -> Option<usize> {
36 None
37 }
38}
39
40pub struct RangeExpand<I: Int> {
41 pub start: NativeExpand<I>,
42 pub end: NativeExpand<I>,
43 pub inclusive: bool,
44}
45
46impl<I: Int> RangeExpand<I> {
47 pub fn new(start: NativeExpand<I>, end: NativeExpand<I>, inclusive: bool) -> Self {
48 RangeExpand {
49 start,
50 end,
51 inclusive,
52 }
53 }
54
55 pub fn __expand_step_by_method(self, n: impl Into<NativeExpand<I>>) -> SteppedRangeExpand<I> {
56 SteppedRangeExpand {
57 start: self.start,
58 end: self.end,
59 step: n.into(),
60 inclusive: self.inclusive,
61 }
62 }
63}
64
65impl<I: Int> Iterable<I> for RangeExpand<I> {
66 fn expand_unroll(
67 self,
68 scope: &mut Scope,
69 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
70 ) {
71 let start = self
72 .start
73 .expand
74 .as_const()
75 .expect("Only constant start can be unrolled.")
76 .as_i64();
77 let end = self
78 .end
79 .expand
80 .as_const()
81 .expect("Only constant end can be unrolled.")
82 .as_i64();
83
84 if self.inclusive {
85 for i in start..=end {
86 let var = I::from_int(i);
87 body(scope, var.into())
88 }
89 } else {
90 for i in start..end {
91 let var = I::from_int(i);
92 body(scope, var.into())
93 }
94 }
95 }
96
97 fn expand(
98 self,
99 scope: &mut Scope,
100 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
101 ) {
102 let mut child = scope.child();
103 let index_ty = I::as_type(scope);
104 let i = child.create_local_restricted(index_ty);
105
106 body(&mut child, i.clone().into());
107
108 let mut start = *self.start.expand;
109 let mut end = *self.end.expand;
110
111 start.ty = I::as_type(scope);
113 end.ty = I::as_type(scope);
114
115 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
116 i: *i,
117 start,
118 end,
119 step: None,
120 scope: child,
121 inclusive: self.inclusive,
122 })));
123 }
124
125 fn const_len(&self) -> Option<usize> {
126 let start = self.start.expand.as_const()?.as_i64();
127 let end = self.end.expand.as_const()?.as_i64();
128 Some(start.abs_diff(end) as usize)
129 }
130}
131
132pub struct SteppedRangeExpand<I: Int> {
133 start: NativeExpand<I>,
134 end: NativeExpand<I>,
135 step: NativeExpand<I>,
136 inclusive: bool,
137}
138
139impl<I: Int + Into<ManagedVariable>> Iterable<I> for SteppedRangeExpand<I> {
140 fn expand(
141 self,
142 scope: &mut Scope,
143 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
144 ) {
145 let mut child = scope.child();
146 let index_ty = I::as_type(scope);
147 let i = child.create_local_restricted(index_ty);
148
149 body(&mut child, i.clone().into());
150
151 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
152 i: *i,
153 start: *self.start.expand,
154 end: *self.end.expand,
155 step: Some(*self.step.expand),
156 scope: child,
157 inclusive: self.inclusive,
158 })));
159 }
160
161 fn expand_unroll(
162 self,
163 scope: &mut Scope,
164 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
165 ) {
166 let start = self
167 .start
168 .expand
169 .as_const()
170 .expect("Only constant start can be unrolled.")
171 .as_i128();
172 let end = self
173 .end
174 .expand
175 .as_const()
176 .expect("Only constant end can be unrolled.")
177 .as_i128();
178 let step = self
179 .step
180 .expand
181 .as_const()
182 .expect("Only constant step can be unrolled.")
183 .as_i128();
184
185 match (self.inclusive, step.is_negative()) {
186 (true, true) => {
187 for i in (end..=start).rev().step_by(step.unsigned_abs() as usize) {
188 let var = I::from_int_128(i);
189 body(scope, var.into())
190 }
191 }
192 (true, false) => {
193 for i in (start..=end).step_by(step.unsigned_abs() as usize) {
194 let var = I::from_int_128(i);
195 body(scope, var.into())
196 }
197 }
198 (false, true) => {
199 for i in (end..start).rev().step_by(step.unsigned_abs() as usize) {
200 let var = I::from_int_128(i);
201 body(scope, var.into())
202 }
203 }
204 (false, false) => {
205 for i in (start..end).step_by(step.unsigned_abs() as usize) {
206 let var = I::from_int_128(i);
207 body(scope, var.into())
208 }
209 }
210 }
211 }
212
213 fn const_len(&self) -> Option<usize> {
214 let start = self.start.constant()?.as_i128();
215 let end = self.end.constant()?.as_i128();
216 let step = self.step.constant()?.as_i128().unsigned_abs();
217 Some((start.abs_diff(end) / step) as usize)
218 }
219}
220
221pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
227 let start: i64 = start.to_i64().unwrap();
228 let end: i64 = end.to_i64().unwrap();
229 (start..end).map(<T as NumCast>::from).map(Option::unwrap)
230}
231
232pub mod range {
233 use cubecl_ir::Scope;
234
235 use crate::prelude::{Int, NativeExpand};
236
237 use super::RangeExpand;
238
239 pub fn expand<I: Int>(
240 _scope: &mut Scope,
241 start: NativeExpand<I>,
242 end: NativeExpand<I>,
243 ) -> RangeExpand<I> {
244 RangeExpand {
245 start,
246 end,
247 inclusive: false,
248 }
249 }
250}
251
252pub fn range_stepped<I: Int>(start: I, end: I, step: I) -> Box<dyn Iterator<Item = I>> {
260 let start = start.to_i128().unwrap();
261 let end = end.to_i128().unwrap();
262 let step = step.to_i128().unwrap();
263
264 if step < 0 {
265 Box::new(
266 (end..start)
267 .rev()
268 .step_by(step.unsigned_abs() as usize)
269 .map(<I as NumCast>::from)
270 .map(Option::unwrap),
271 )
272 } else {
273 Box::new(
274 (start..end)
275 .step_by(step.unsigned_abs() as usize)
276 .map(<I as NumCast>::from)
277 .map(Option::unwrap),
278 )
279 }
280}
281
282pub mod range_stepped {
283 use cubecl_ir::Scope;
284
285 use crate::prelude::{Int, NativeExpand};
286
287 use super::SteppedRangeExpand;
288
289 pub fn expand<I: Int>(
290 _scope: &mut Scope,
291 start: NativeExpand<I>,
292 end: NativeExpand<I>,
293 step: NativeExpand<I>,
294 ) -> SteppedRangeExpand<I> {
295 SteppedRangeExpand {
296 start,
297 end,
298 step,
299 inclusive: false,
300 }
301 }
302}
303
304pub fn for_expand<I: Numeric>(
305 scope: &mut Scope,
306 range: impl Iterable<I>,
307 unroll: bool,
308 body: impl FnMut(&mut Scope, NativeExpand<I>),
309) {
310 if unroll || range.const_len() == Some(1) {
311 range.expand_unroll(scope, body);
312 } else {
313 range.expand(scope, body);
314 }
315}
316
317pub fn if_expand(scope: &mut Scope, condition: NativeExpand<bool>, block: impl FnOnce(&mut Scope)) {
318 let comptime_cond = condition.expand.as_const().map(|it| it.as_bool());
319 match comptime_cond {
320 Some(cond) => {
321 if cond {
322 block(scope);
323 }
324 }
325 None => {
326 let mut child = scope.child();
327
328 block(&mut child);
329
330 scope.register(Branch::If(Box::new(If {
331 cond: *condition.expand,
332 scope: child,
333 })));
334 }
335 }
336}
337
338#[allow(clippy::large_enum_variant)]
339pub enum IfElseExpand {
340 ComptimeThen,
341 ComptimeElse,
342 Runtime {
343 runtime_cond: NativeExpand<bool>,
344 then_child: Scope,
345 },
346}
347
348impl IfElseExpand {
349 pub fn or_else(self, scope: &mut Scope, else_block: impl FnOnce(&mut Scope)) {
350 match self {
351 Self::Runtime {
352 runtime_cond,
353 then_child,
354 } => {
355 let mut else_child = scope.child();
356 else_block(&mut else_child);
357
358 scope.register(Branch::IfElse(Box::new(IfElse {
359 cond: *runtime_cond.expand,
360 scope_if: then_child,
361 scope_else: else_child,
362 })));
363 }
364 Self::ComptimeElse => else_block(scope),
365 Self::ComptimeThen => (),
366 }
367 }
368}
369
370pub fn if_else_expand(
371 scope: &mut Scope,
372 condition: NativeExpand<bool>,
373 then_block: impl FnOnce(&mut Scope),
374) -> IfElseExpand {
375 let comptime_cond = condition.expand.as_const().map(|it| it.as_bool());
376 match comptime_cond {
377 Some(true) => {
378 then_block(scope);
379 IfElseExpand::ComptimeThen
380 }
381 Some(false) => IfElseExpand::ComptimeElse,
382 None => {
383 let mut then_child = scope.child();
384 then_block(&mut then_child);
385
386 IfElseExpand::Runtime {
387 runtime_cond: condition,
388 then_child,
389 }
390 }
391 }
392}
393
394#[allow(clippy::large_enum_variant)]
395pub enum IfElseExprExpand<C: Assign> {
396 ComptimeThen(C),
397 ComptimeElse,
398 Runtime {
399 runtime_cond: NativeExpand<bool>,
400 out: C,
401 then_child: Scope,
402 },
403}
404
405impl<C: Assign> IfElseExprExpand<C> {
406 pub fn or_else(self, scope: &mut Scope, else_block: impl FnOnce(&mut Scope) -> C) -> C {
407 match self {
408 Self::Runtime {
409 runtime_cond,
410 mut out,
411 then_child,
412 } => {
413 let mut else_child = scope.child();
414 let ret = else_block(&mut else_child);
415 out.expand_assign(&mut else_child, ret);
416
417 scope.register(Branch::IfElse(Box::new(IfElse {
418 cond: *runtime_cond.expand,
419 scope_if: then_child,
420 scope_else: else_child,
421 })));
422 out
423 }
424 Self::ComptimeElse => else_block(scope),
425 Self::ComptimeThen(ret) => ret,
426 }
427 }
428}
429
430pub fn if_else_expr_expand<C: Assign>(
431 scope: &mut Scope,
432 condition: NativeExpand<bool>,
433 then_block: impl FnOnce(&mut Scope) -> C,
434) -> IfElseExprExpand<C> {
435 let comptime_cond = condition.expand.as_const().map(|it| it.as_bool());
436 match comptime_cond {
437 Some(true) => {
438 let ret = then_block(scope);
439 IfElseExprExpand::ComptimeThen(ret)
440 }
441 Some(false) => IfElseExprExpand::ComptimeElse,
442 None => {
443 let mut then_child = scope.child();
444 let ret = then_block(&mut then_child);
445 let mut out = ret.init_mut(scope);
446 out.expand_assign(&mut then_child, ret);
447
448 IfElseExprExpand::Runtime {
449 runtime_cond: condition,
450 out,
451 then_child,
452 }
453 }
454 }
455}
456
457pub struct SwitchExpand<I: Int> {
458 value: NativeExpand<I>,
459 default: Scope,
460 cases: Vec<(NativeExpand<I>, Scope)>,
461}
462
463impl<I: Int> SwitchExpand<I> {
464 pub fn case(
465 mut self,
466 scope: &mut Scope,
467 value: impl Int,
468 block: impl FnOnce(&mut Scope),
469 ) -> Self {
470 let value = I::from(value).unwrap();
471 let mut case_child = scope.child();
472 block(&mut case_child);
473 self.cases.push((value.into(), case_child));
474 self
475 }
476
477 pub fn finish(self, scope: &mut Scope) {
478 let value_var = *self.value.expand;
479 scope.register(Branch::Switch(Box::new(Switch {
480 value: value_var,
481 scope_default: self.default,
482 cases: self
483 .cases
484 .into_iter()
485 .map(|it| (*it.0.expand, it.1))
486 .collect(),
487 })));
488 }
489}
490
491pub fn switch_expand<I: Int>(
492 scope: &mut Scope,
493 value: NativeExpand<I>,
494 default_block: impl FnOnce(&mut Scope),
495) -> SwitchExpand<I> {
496 let mut default_child = scope.child();
497 default_block(&mut default_child);
498
499 SwitchExpand {
500 value,
501 default: default_child,
502 cases: Vec::new(),
503 }
504}
505
506pub struct SwitchExpandExpr<I: Int, C: Assign> {
507 value: NativeExpand<I>,
508 out: C,
509 default: Scope,
510 cases: Vec<(NativeExpand<I>, Scope)>,
511}
512
513impl<I: Int, C: Assign> SwitchExpandExpr<I, C> {
514 pub fn case(
515 mut self,
516 scope: &mut Scope,
517 value: impl Int,
518 block: impl FnOnce(&mut Scope) -> C,
519 ) -> Self {
520 let value = I::from(value).unwrap();
521 let mut case_child = scope.child();
522 let ret = block(&mut case_child);
523 self.out.expand_assign(&mut case_child, ret);
524 self.cases.push((value.into(), case_child));
525 self
526 }
527
528 pub fn finish(self, scope: &mut Scope) -> C {
529 let value_var = *self.value.expand;
530 scope.register(Branch::Switch(Box::new(Switch {
531 value: value_var,
532 scope_default: self.default,
533 cases: self
534 .cases
535 .into_iter()
536 .map(|it| (*it.0.expand, it.1))
537 .collect(),
538 })));
539 self.out
540 }
541}
542
543pub fn switch_expand_expr<I: Int, C: Assign>(
544 scope: &mut Scope,
545 value: NativeExpand<I>,
546 default_block: impl FnOnce(&mut Scope) -> C,
547) -> SwitchExpandExpr<I, C> {
548 let mut default_child = scope.child();
549 let default = default_block(&mut default_child);
550 let mut out = default.init_mut(scope);
551 out.expand_assign(&mut default_child, default);
552
553 SwitchExpandExpr {
554 value,
555 out,
556 default: default_child,
557 cases: Vec::new(),
558 }
559}
560
561#[allow(clippy::large_enum_variant)]
562pub enum MatchExpand<T: CubeEnum> {
563 ComptimeVariant {
564 variant: i32,
565 runtime_value: T::RuntimeValue,
566 matched: bool,
567 },
568 RuntimeVariant {
569 variant: NativeExpand<i32>,
570 cases: Vec<(NativeExpand<i32>, Scope)>,
571 runtime_value: T::RuntimeValue,
572 default: Option<Scope>,
573 },
574}
575
576impl<T: CubeEnum> MatchExpand<T> {
577 pub fn case(
578 mut self,
579 scope: &mut Scope,
580 value: i32,
581 block: impl FnOnce(&mut Scope, T::RuntimeValue),
582 ) -> Self {
583 match &mut self {
584 Self::RuntimeVariant {
585 cases,
586 runtime_value,
587 ..
588 } => {
589 let mut case_child = scope.child();
590 block(&mut case_child, runtime_value.clone());
591 cases.push((value.into(), case_child));
592 }
593 Self::ComptimeVariant {
594 variant,
595 runtime_value,
596 matched,
597 } => {
598 if value == *variant {
599 block(scope, runtime_value.clone());
600 *matched = true;
601 }
602 }
603 }
604 self
605 }
606
607 pub fn default(
608 mut self,
609 scope: &mut Scope,
610 block: impl FnOnce(&mut Scope, T::RuntimeValue),
611 ) -> Self {
612 match &mut self {
613 Self::RuntimeVariant {
614 runtime_value,
615 default,
616 ..
617 } => {
618 let mut case_child = scope.child();
619 block(&mut case_child, runtime_value.clone());
620 *default = Some(case_child);
621 }
622 Self::ComptimeVariant {
623 runtime_value,
624 matched,
625 ..
626 } => {
627 if !*matched {
628 block(scope, runtime_value.clone());
629 *matched = true;
630 }
631 }
632 }
633 self
634 }
635
636 pub fn finish(self, scope: &mut Scope) {
637 match self {
638 MatchExpand::ComptimeVariant { .. } => {}
639 MatchExpand::RuntimeVariant {
640 variant,
641 cases,
642 default,
643 ..
644 } => {
645 let variant_var = *variant.expand;
646 let scope_default = default.unwrap_or_else(|| {
647 let mut scope_default = scope.child();
648 unreachable_unchecked::expand(&mut scope_default);
649 scope_default
650 });
651
652 scope.register(Branch::Switch(Box::new(Switch {
653 value: variant_var,
654 scope_default,
655 cases: cases.into_iter().map(|it| (*it.0.expand, it.1)).collect(),
656 })));
657 }
658 }
659 }
660}
661
662pub fn match_expand<T: CubeEnum>(
663 scope: &mut Scope,
664 value: T,
665 discriminant0: i32,
666 arm0: impl FnOnce(&mut Scope, T::RuntimeValue),
667) -> MatchExpand<T> {
668 let discriminant = value.discriminant();
669 match discriminant.constant() {
670 Some(const_variant) if const_variant.as_i32() == discriminant0 => {
671 let runtime_value = value.runtime_value();
672 arm0(scope, runtime_value.clone());
673 MatchExpand::ComptimeVariant {
674 variant: const_variant.as_i32(),
675 runtime_value,
676 matched: true,
677 }
678 }
679 Some(const_variant) => MatchExpand::ComptimeVariant {
680 variant: const_variant.as_i32(),
681 runtime_value: value.runtime_value(),
682 matched: false,
683 },
684 None => {
685 let runtime_value = value.runtime_value();
686 let mut case_child = scope.child();
687 arm0(&mut case_child, runtime_value.clone());
688
689 MatchExpand::RuntimeVariant {
690 variant: discriminant,
691 cases: alloc::vec![(discriminant0.into(), case_child)],
692 runtime_value,
693 default: None,
694 }
695 }
696 }
697}
698
699#[allow(clippy::large_enum_variant)]
700pub enum MatchExpandExpr<T: CubeEnum, C: Assign> {
701 ComptimeVariant {
702 variant: i32,
703 runtime_value: T::RuntimeValue,
704 out: Option<C>,
705 matched: bool,
706 },
707 RuntimeVariant {
708 variant: NativeExpand<i32>,
709 out: C,
710 cases: Vec<(NativeExpand<i32>, Scope)>,
711 runtime_value: T::RuntimeValue,
712 default: Option<Scope>,
713 },
714}
715
716impl<T: CubeEnum, C: Assign> MatchExpandExpr<T, C> {
717 pub fn case(
718 mut self,
719 scope: &mut Scope,
720 value: i32,
721 block: impl FnOnce(&mut Scope, T::RuntimeValue) -> C,
722 ) -> Self {
723 match &mut self {
724 Self::RuntimeVariant {
725 cases,
726 out,
727 runtime_value,
728 ..
729 } => {
730 let mut case_child = scope.child();
731 let ret_val = block(&mut case_child, runtime_value.clone());
732 out.expand_assign(&mut case_child, ret_val);
733 cases.push((value.into(), case_child));
734 }
735 Self::ComptimeVariant {
736 variant,
737 runtime_value,
738 out,
739 matched,
740 } => {
741 if value == *variant {
742 *out = Some(block(scope, runtime_value.clone()));
743 *matched = true;
744 }
745 }
746 }
747 self
748 }
749
750 pub fn default(
751 mut self,
752 scope: &mut Scope,
753 block: impl FnOnce(&mut Scope, T::RuntimeValue) -> C,
754 ) -> Self {
755 match &mut self {
756 Self::RuntimeVariant {
757 runtime_value,
758 out,
759 default,
760 ..
761 } => {
762 let mut case_child = scope.child();
763 let ret_val = block(&mut case_child, runtime_value.clone());
764 out.expand_assign(&mut case_child, ret_val);
765 *default = Some(case_child);
766 }
767 Self::ComptimeVariant {
768 runtime_value,
769 out,
770 matched,
771 ..
772 } => {
773 if !*matched {
774 *out = Some(block(scope, runtime_value.clone()));
775 *matched = true;
776 }
777 }
778 }
779 self
780 }
781
782 pub fn finish(self, scope: &mut Scope) -> C {
783 match self {
784 MatchExpandExpr::ComptimeVariant { out, .. } => {
785 out.expect("At least one variant should be matched")
786 }
787 MatchExpandExpr::RuntimeVariant {
788 variant,
789 cases,
790 out,
791 default,
792 ..
793 } => {
794 let variant_var = *variant.expand;
795 let scope_default = default.unwrap_or_else(|| {
796 let mut scope_default = scope.child();
797 unreachable_unchecked::expand(&mut scope_default);
798 scope_default
799 });
800 scope.register(Branch::Switch(Box::new(Switch {
801 value: variant_var,
802 scope_default,
803 cases: cases.into_iter().map(|it| (*it.0.expand, it.1)).collect(),
804 })));
805 out
806 }
807 }
808 }
809}
810
811pub fn match_expand_expr<T: CubeEnum, C: Assign>(
812 scope: &mut Scope,
813 value: T,
814 discriminant0: i32,
815 arm0: impl FnOnce(&mut Scope, T::RuntimeValue) -> C,
816) -> MatchExpandExpr<T, C> {
817 let discriminant = value.discriminant();
818 match discriminant.constant() {
819 Some(const_variant) if const_variant.as_i32() == discriminant0 => {
820 let runtime_value = value.runtime_value();
821 let out = arm0(scope, runtime_value.clone());
822 MatchExpandExpr::ComptimeVariant {
823 variant: const_variant.as_i32(),
824 out: Some(out),
825 runtime_value,
826 matched: true,
827 }
828 }
829 Some(const_variant) => MatchExpandExpr::ComptimeVariant {
830 variant: const_variant.as_i32(),
831 out: None,
832 runtime_value: value.runtime_value(),
833 matched: false,
834 },
835 None => {
836 let runtime_value = value.runtime_value();
837 let mut case_child = scope.child();
838 let ret_val = arm0(&mut case_child, runtime_value.clone());
839
840 let mut out = ret_val.init_mut(scope);
841 out.expand_assign(&mut case_child, ret_val);
842
843 MatchExpandExpr::RuntimeVariant {
844 variant: discriminant,
845 out,
846 cases: alloc::vec![(discriminant0.into(), case_child)],
847 runtime_value,
848 default: None,
849 }
850 }
851 }
852}
853
854pub fn break_expand(scope: &mut Scope) {
855 scope.register(Branch::Break);
856}
857
858pub fn return_expand(scope: &mut Scope) {
859 scope.register(Branch::Return);
860}
861
862pub mod unreachable_unchecked {
863 use super::*;
864
865 pub fn expand(scope: &mut Scope) {
866 scope.register(Branch::Unreachable);
867 }
868}
869
870pub fn loop_expand(scope: &mut Scope, mut block: impl FnMut(&mut Scope)) {
872 let mut inside_loop = scope.child();
873
874 block(&mut inside_loop);
875 scope.register(Branch::Loop(Box::new(Loop { scope: inside_loop })));
876}