1use cubecl_ir::ExpandElement;
2use num_traits::NumCast;
3
4use crate::ir::Switch;
5use crate::ir::{Branch, If, IfElse, Loop, RangeLoop, Scope, Type};
6
7use super::{CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric, assign};
8
9pub trait Iterable<T: CubeType>: Sized {
12 fn expand(self, scope: &mut Scope, body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType));
19 fn expand_unroll(
26 self,
27 scope: &mut Scope,
28 body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
29 );
30 fn const_len(&self) -> Option<usize> {
32 None
33 }
34}
35
36pub struct RangeExpand<I: Int> {
37 pub start: ExpandElementTyped<I>,
38 pub end: ExpandElementTyped<I>,
39 pub inclusive: bool,
40}
41
42impl<I: Int> RangeExpand<I> {
43 pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
44 RangeExpand {
45 start,
46 end,
47 inclusive,
48 }
49 }
50
51 pub fn __expand_step_by_method(
52 self,
53 n: impl Into<ExpandElementTyped<I>>,
54 ) -> SteppedRangeExpand<I> {
55 SteppedRangeExpand {
56 start: self.start,
57 end: self.end,
58 step: n.into(),
59 inclusive: self.inclusive,
60 }
61 }
62}
63
64impl<I: Int> Iterable<I> for RangeExpand<I> {
65 fn expand_unroll(
66 self,
67 scope: &mut Scope,
68 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
69 ) {
70 let start = self
71 .start
72 .expand
73 .as_const()
74 .expect("Only constant start can be unrolled.")
75 .as_i64();
76 let end = self
77 .end
78 .expand
79 .as_const()
80 .expect("Only constant end can be unrolled.")
81 .as_i64();
82
83 if self.inclusive {
84 for i in start..=end {
85 let var = I::from_int(i);
86 body(scope, var.into())
87 }
88 } else {
89 for i in start..end {
90 let var = I::from_int(i);
91 body(scope, var.into())
92 }
93 }
94 }
95
96 fn expand(
97 self,
98 scope: &mut Scope,
99 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
100 ) {
101 let mut child = scope.child();
102 let index_ty = Type::new(I::as_type(scope));
103 let i = child.create_local_restricted(index_ty);
104
105 body(&mut child, i.clone().into());
106
107 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
108 i: *i,
109 start: *self.start.expand,
110 end: *self.end.expand,
111 step: None,
112 scope: child,
113 inclusive: self.inclusive,
114 })));
115 }
116
117 fn const_len(&self) -> Option<usize> {
118 let start = self.start.expand.as_const()?.as_i64();
119 let end = self.end.expand.as_const()?.as_i64();
120 Some(start.abs_diff(end) as usize)
121 }
122}
123
124pub struct SteppedRangeExpand<I: Int> {
125 start: ExpandElementTyped<I>,
126 end: ExpandElementTyped<I>,
127 step: ExpandElementTyped<I>,
128 inclusive: bool,
129}
130
131impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
132 fn expand(
133 self,
134 scope: &mut Scope,
135 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
136 ) {
137 let mut child = scope.child();
138 let index_ty = Type::new(I::as_type(scope));
139 let i = child.create_local_restricted(index_ty);
140
141 body(&mut child, i.clone().into());
142
143 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
144 i: *i,
145 start: *self.start.expand,
146 end: *self.end.expand,
147 step: Some(*self.step.expand),
148 scope: child,
149 inclusive: self.inclusive,
150 })));
151 }
152
153 fn expand_unroll(
154 self,
155 scope: &mut Scope,
156 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
157 ) {
158 let start = self
159 .start
160 .expand
161 .as_const()
162 .expect("Only constant start can be unrolled.")
163 .as_i128();
164 let end = self
165 .end
166 .expand
167 .as_const()
168 .expect("Only constant end can be unrolled.")
169 .as_i128();
170 let step = self
171 .step
172 .expand
173 .as_const()
174 .expect("Only constant step can be unrolled.")
175 .as_i128();
176
177 match (self.inclusive, step.is_negative()) {
178 (true, true) => {
179 for i in (end..=start).rev().step_by(step.unsigned_abs() as usize) {
180 let var = I::from_int_128(i);
181 body(scope, var.into())
182 }
183 }
184 (true, false) => {
185 for i in (start..=end).step_by(step.unsigned_abs() as usize) {
186 let var = I::from_int_128(i);
187 body(scope, var.into())
188 }
189 }
190 (false, true) => {
191 for i in (end..start).rev().step_by(step.unsigned_abs() as usize) {
192 let var = I::from_int_128(i);
193 body(scope, var.into())
194 }
195 }
196 (false, false) => {
197 for i in (start..end).step_by(step.unsigned_abs() as usize) {
198 let var = I::from_int_128(i);
199 body(scope, var.into())
200 }
201 }
202 }
203 }
204
205 fn const_len(&self) -> Option<usize> {
206 let start = self.start.constant()?.as_i128();
207 let end = self.end.constant()?.as_i128();
208 let step = self.step.constant()?.as_i128().unsigned_abs();
209 Some((start.abs_diff(end) / step) as usize)
210 }
211}
212
213pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
219 let start: i64 = start.to_i64().unwrap();
220 let end: i64 = end.to_i64().unwrap();
221 (start..end).map(<T as NumCast>::from).map(Option::unwrap)
222}
223
224pub mod range {
225 use cubecl_ir::Scope;
226
227 use crate::prelude::{ExpandElementTyped, Int};
228
229 use super::RangeExpand;
230
231 pub fn expand<I: Int>(
232 _scope: &mut Scope,
233 start: ExpandElementTyped<I>,
234 end: ExpandElementTyped<I>,
235 ) -> RangeExpand<I> {
236 RangeExpand {
237 start,
238 end,
239 inclusive: false,
240 }
241 }
242}
243
244pub fn range_stepped<I: Int>(start: I, end: I, step: I) -> Box<dyn Iterator<Item = I>> {
252 let start = start.to_i128().unwrap();
253 let end = end.to_i128().unwrap();
254 let step = step.to_i128().unwrap();
255
256 if step < 0 {
257 Box::new(
258 (end..start)
259 .rev()
260 .step_by(step.unsigned_abs() as usize)
261 .map(<I as NumCast>::from)
262 .map(Option::unwrap),
263 )
264 } else {
265 Box::new(
266 (start..end)
267 .step_by(step.unsigned_abs() as usize)
268 .map(<I as NumCast>::from)
269 .map(Option::unwrap),
270 )
271 }
272}
273
274pub mod range_stepped {
275 use cubecl_ir::Scope;
276
277 use crate::prelude::{ExpandElementTyped, Int};
278
279 use super::SteppedRangeExpand;
280
281 pub fn expand<I: Int>(
282 _scope: &mut Scope,
283 start: ExpandElementTyped<I>,
284 end: ExpandElementTyped<I>,
285 step: ExpandElementTyped<I>,
286 ) -> SteppedRangeExpand<I> {
287 SteppedRangeExpand {
288 start,
289 end,
290 step,
291 inclusive: false,
292 }
293 }
294}
295
296pub fn for_expand<I: Numeric>(
297 scope: &mut Scope,
298 range: impl Iterable<I>,
299 unroll: bool,
300 body: impl FnMut(&mut Scope, ExpandElementTyped<I>),
301) {
302 if unroll || range.const_len() == Some(1) {
303 range.expand_unroll(scope, body);
304 } else {
305 range.expand(scope, body);
306 }
307}
308
309pub fn if_expand(scope: &mut Scope, runtime_cond: ExpandElement, block: impl FnOnce(&mut Scope)) {
310 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
311 match comptime_cond {
312 Some(cond) => {
313 if cond {
314 block(scope);
315 }
316 }
317 None => {
318 let mut child = scope.child();
319
320 block(&mut child);
321
322 scope.register(Branch::If(Box::new(If {
323 cond: *runtime_cond,
324 scope: child,
325 })));
326 }
327 }
328}
329
330#[allow(clippy::large_enum_variant)]
331pub enum IfElseExpand {
332 ComptimeThen,
333 ComptimeElse,
334 Runtime {
335 runtime_cond: ExpandElement,
336 then_child: Scope,
337 },
338}
339
340impl IfElseExpand {
341 pub fn or_else(self, scope: &mut Scope, else_block: impl FnOnce(&mut Scope)) {
342 match self {
343 Self::Runtime {
344 runtime_cond,
345 then_child,
346 } => {
347 let mut else_child = scope.child();
348 else_block(&mut else_child);
349
350 scope.register(Branch::IfElse(Box::new(IfElse {
351 cond: *runtime_cond,
352 scope_if: then_child,
353 scope_else: else_child,
354 })));
355 }
356 Self::ComptimeElse => else_block(scope),
357 Self::ComptimeThen => (),
358 }
359 }
360}
361
362pub fn if_else_expand(
363 scope: &mut Scope,
364 runtime_cond: ExpandElement,
365 then_block: impl FnOnce(&mut Scope),
366) -> IfElseExpand {
367 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
368 match comptime_cond {
369 Some(true) => {
370 then_block(scope);
371 IfElseExpand::ComptimeThen
372 }
373 Some(false) => IfElseExpand::ComptimeElse,
374 None => {
375 let mut then_child = scope.child();
376 then_block(&mut then_child);
377
378 IfElseExpand::Runtime {
379 runtime_cond,
380 then_child,
381 }
382 }
383 }
384}
385
386#[allow(clippy::large_enum_variant)]
387pub enum IfElseExprExpand<C: CubeType> {
388 ComptimeThen(ExpandElementTyped<C>),
389 ComptimeElse,
390 Runtime {
391 runtime_cond: ExpandElement,
392 out: ExpandElementTyped<C>,
393 then_child: Scope,
394 },
395}
396
397impl<C: CubePrimitive> IfElseExprExpand<C> {
398 pub fn or_else(
399 self,
400 scope: &mut Scope,
401 else_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
402 ) -> ExpandElementTyped<C> {
403 match self {
404 Self::Runtime {
405 runtime_cond,
406 out,
407 then_child,
408 } => {
409 let mut else_child = scope.child();
410 let ret = else_block(&mut else_child);
411 assign::expand_no_check::<C>(&mut else_child, ret, out.clone());
412
413 scope.register(Branch::IfElse(Box::new(IfElse {
414 cond: *runtime_cond,
415 scope_if: then_child,
416 scope_else: else_child,
417 })));
418 out
419 }
420 Self::ComptimeElse => else_block(scope),
421 Self::ComptimeThen(ret) => ret,
422 }
423 }
424}
425
426pub fn if_else_expr_expand<C: CubePrimitive>(
427 scope: &mut Scope,
428 runtime_cond: ExpandElement,
429 then_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
430) -> IfElseExprExpand<C> {
431 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
432 match comptime_cond {
433 Some(true) => {
434 let ret = then_block(scope);
435 IfElseExprExpand::ComptimeThen(ret)
436 }
437 Some(false) => IfElseExprExpand::ComptimeElse,
438 None => {
439 let mut then_child = scope.child();
440 let ret = then_block(&mut then_child);
441 let out: ExpandElementTyped<C> = scope.create_local_mut(ret.expand.ty).into();
442 assign::expand_no_check::<C>(&mut then_child, ret, out.clone());
443
444 IfElseExprExpand::Runtime {
445 runtime_cond,
446 out,
447 then_child,
448 }
449 }
450 }
451}
452
453pub struct SwitchExpand<I: Int> {
454 value: ExpandElementTyped<I>,
455 default: Scope,
456 cases: Vec<(ExpandElementTyped<I>, Scope)>,
457}
458
459impl<I: Int> SwitchExpand<I> {
460 pub fn case(
461 mut self,
462 scope: &mut Scope,
463 value: impl Int,
464 block: impl FnOnce(&mut Scope),
465 ) -> Self {
466 let value = I::from(value).unwrap();
467 let mut case_child = scope.child();
468 block(&mut case_child);
469 self.cases.push((value.into(), case_child));
470 self
471 }
472
473 pub fn finish(self, scope: &mut Scope) {
474 let value_var = *self.value.expand;
475 scope.register(Branch::Switch(Box::new(Switch {
476 value: value_var,
477 scope_default: self.default,
478 cases: self
479 .cases
480 .into_iter()
481 .map(|it| (*it.0.expand, it.1))
482 .collect(),
483 })));
484 }
485}
486
487pub fn switch_expand<I: Int>(
488 scope: &mut Scope,
489 value: ExpandElementTyped<I>,
490 default_block: impl FnOnce(&mut Scope),
491) -> SwitchExpand<I> {
492 let mut default_child = scope.child();
493 default_block(&mut default_child);
494
495 SwitchExpand {
496 value,
497 default: default_child,
498 cases: Vec::new(),
499 }
500}
501
502pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
503 value: ExpandElementTyped<I>,
504 out: ExpandElementTyped<C>,
505 default: Scope,
506 cases: Vec<(ExpandElementTyped<I>, Scope)>,
507}
508
509impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
510 pub fn case(
511 mut self,
512 scope: &mut Scope,
513 value: impl Int,
514 block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
515 ) -> Self {
516 let value = I::from(value).unwrap();
517 let mut case_child = scope.child();
518 let ret = block(&mut case_child);
519 assign::expand_no_check::<C>(&mut case_child, ret, self.out.clone());
520 self.cases.push((value.into(), case_child));
521 self
522 }
523
524 pub fn finish(self, scope: &mut Scope) -> ExpandElementTyped<C> {
525 let value_var = *self.value.expand;
526 scope.register(Branch::Switch(Box::new(Switch {
527 value: value_var,
528 scope_default: self.default,
529 cases: self
530 .cases
531 .into_iter()
532 .map(|it| (*it.0.expand, it.1))
533 .collect(),
534 })));
535 self.out
536 }
537}
538
539pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
540 scope: &mut Scope,
541 value: ExpandElementTyped<I>,
542 default_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
543) -> SwitchExpandExpr<I, C> {
544 let mut default_child = scope.child();
545 let default = default_block(&mut default_child);
546 let out: ExpandElementTyped<C> = scope.create_local_mut(default.expand.ty).into();
547 assign::expand_no_check::<C>(&mut default_child, default, out.clone());
548
549 SwitchExpandExpr {
550 value,
551 out,
552 default: default_child,
553 cases: Vec::new(),
554 }
555}
556
557pub fn break_expand(scope: &mut Scope) {
558 scope.register(Branch::Break);
559}
560
561pub fn return_expand(scope: &mut Scope) {
562 scope.register(Branch::Return);
563}
564
565pub fn loop_expand(scope: &mut Scope, mut block: impl FnMut(&mut Scope)) {
567 let mut inside_loop = scope.child();
568
569 block(&mut inside_loop);
570 scope.register(Branch::Loop(Box::new(Loop { scope: inside_loop })));
571}