1use alloc::{boxed::Box, vec::Vec};
2use cubecl_ir::ExpandElement;
3use num_traits::NumCast;
4
5use crate::ir::Switch;
6use crate::ir::{Branch, If, IfElse, Loop, RangeLoop, Scope, Type};
7
8use super::{CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric, assign};
9
10pub trait Iterable<T: CubeType>: Sized {
13 fn expand(self, scope: &mut Scope, body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType));
20 fn expand_unroll(
27 self,
28 scope: &mut Scope,
29 body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
30 );
31 fn const_len(&self) -> Option<usize> {
33 None
34 }
35}
36
37pub struct RangeExpand<I: Int> {
38 pub start: ExpandElementTyped<I>,
39 pub end: ExpandElementTyped<I>,
40 pub inclusive: bool,
41}
42
43impl<I: Int> RangeExpand<I> {
44 pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
45 RangeExpand {
46 start,
47 end,
48 inclusive,
49 }
50 }
51
52 pub fn __expand_step_by_method(
53 self,
54 n: impl Into<ExpandElementTyped<I>>,
55 ) -> SteppedRangeExpand<I> {
56 SteppedRangeExpand {
57 start: self.start,
58 end: self.end,
59 step: n.into(),
60 inclusive: self.inclusive,
61 }
62 }
63}
64
65impl<I: Int> Iterable<I> for RangeExpand<I> {
66 fn expand_unroll(
67 self,
68 scope: &mut Scope,
69 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
70 ) {
71 let start = self
72 .start
73 .expand
74 .as_const()
75 .expect("Only constant start can be unrolled.")
76 .as_i64();
77 let end = self
78 .end
79 .expand
80 .as_const()
81 .expect("Only constant end can be unrolled.")
82 .as_i64();
83
84 if self.inclusive {
85 for i in start..=end {
86 let var = I::from_int(i);
87 body(scope, var.into())
88 }
89 } else {
90 for i in start..end {
91 let var = I::from_int(i);
92 body(scope, var.into())
93 }
94 }
95 }
96
97 fn expand(
98 self,
99 scope: &mut Scope,
100 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
101 ) {
102 let mut child = scope.child();
103 let index_ty = Type::new(I::as_type(scope));
104 let i = child.create_local_restricted(index_ty);
105
106 body(&mut child, i.clone().into());
107
108 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
109 i: *i,
110 start: *self.start.expand,
111 end: *self.end.expand,
112 step: None,
113 scope: child,
114 inclusive: self.inclusive,
115 })));
116 }
117
118 fn const_len(&self) -> Option<usize> {
119 let start = self.start.expand.as_const()?.as_i64();
120 let end = self.end.expand.as_const()?.as_i64();
121 Some(start.abs_diff(end) as usize)
122 }
123}
124
125pub struct SteppedRangeExpand<I: Int> {
126 start: ExpandElementTyped<I>,
127 end: ExpandElementTyped<I>,
128 step: ExpandElementTyped<I>,
129 inclusive: bool,
130}
131
132impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
133 fn expand(
134 self,
135 scope: &mut Scope,
136 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
137 ) {
138 let mut child = scope.child();
139 let index_ty = Type::new(I::as_type(scope));
140 let i = child.create_local_restricted(index_ty);
141
142 body(&mut child, i.clone().into());
143
144 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
145 i: *i,
146 start: *self.start.expand,
147 end: *self.end.expand,
148 step: Some(*self.step.expand),
149 scope: child,
150 inclusive: self.inclusive,
151 })));
152 }
153
154 fn expand_unroll(
155 self,
156 scope: &mut Scope,
157 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
158 ) {
159 let start = self
160 .start
161 .expand
162 .as_const()
163 .expect("Only constant start can be unrolled.")
164 .as_i128();
165 let end = self
166 .end
167 .expand
168 .as_const()
169 .expect("Only constant end can be unrolled.")
170 .as_i128();
171 let step = self
172 .step
173 .expand
174 .as_const()
175 .expect("Only constant step can be unrolled.")
176 .as_i128();
177
178 match (self.inclusive, step.is_negative()) {
179 (true, true) => {
180 for i in (end..=start).rev().step_by(step.unsigned_abs() as usize) {
181 let var = I::from_int_128(i);
182 body(scope, var.into())
183 }
184 }
185 (true, false) => {
186 for i in (start..=end).step_by(step.unsigned_abs() as usize) {
187 let var = I::from_int_128(i);
188 body(scope, var.into())
189 }
190 }
191 (false, true) => {
192 for i in (end..start).rev().step_by(step.unsigned_abs() as usize) {
193 let var = I::from_int_128(i);
194 body(scope, var.into())
195 }
196 }
197 (false, false) => {
198 for i in (start..end).step_by(step.unsigned_abs() as usize) {
199 let var = I::from_int_128(i);
200 body(scope, var.into())
201 }
202 }
203 }
204 }
205
206 fn const_len(&self) -> Option<usize> {
207 let start = self.start.constant()?.as_i128();
208 let end = self.end.constant()?.as_i128();
209 let step = self.step.constant()?.as_i128().unsigned_abs();
210 Some((start.abs_diff(end) / step) as usize)
211 }
212}
213
214pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
220 let start: i64 = start.to_i64().unwrap();
221 let end: i64 = end.to_i64().unwrap();
222 (start..end).map(<T as NumCast>::from).map(Option::unwrap)
223}
224
225pub mod range {
226 use cubecl_ir::Scope;
227
228 use crate::prelude::{ExpandElementTyped, Int};
229
230 use super::RangeExpand;
231
232 pub fn expand<I: Int>(
233 _scope: &mut Scope,
234 start: ExpandElementTyped<I>,
235 end: ExpandElementTyped<I>,
236 ) -> RangeExpand<I> {
237 RangeExpand {
238 start,
239 end,
240 inclusive: false,
241 }
242 }
243}
244
245pub fn range_stepped<I: Int>(start: I, end: I, step: I) -> Box<dyn Iterator<Item = I>> {
253 let start = start.to_i128().unwrap();
254 let end = end.to_i128().unwrap();
255 let step = step.to_i128().unwrap();
256
257 if step < 0 {
258 Box::new(
259 (end..start)
260 .rev()
261 .step_by(step.unsigned_abs() as usize)
262 .map(<I as NumCast>::from)
263 .map(Option::unwrap),
264 )
265 } else {
266 Box::new(
267 (start..end)
268 .step_by(step.unsigned_abs() as usize)
269 .map(<I as NumCast>::from)
270 .map(Option::unwrap),
271 )
272 }
273}
274
275pub mod range_stepped {
276 use cubecl_ir::Scope;
277
278 use crate::prelude::{ExpandElementTyped, Int};
279
280 use super::SteppedRangeExpand;
281
282 pub fn expand<I: Int>(
283 _scope: &mut Scope,
284 start: ExpandElementTyped<I>,
285 end: ExpandElementTyped<I>,
286 step: ExpandElementTyped<I>,
287 ) -> SteppedRangeExpand<I> {
288 SteppedRangeExpand {
289 start,
290 end,
291 step,
292 inclusive: false,
293 }
294 }
295}
296
297pub fn for_expand<I: Numeric>(
298 scope: &mut Scope,
299 range: impl Iterable<I>,
300 unroll: bool,
301 body: impl FnMut(&mut Scope, ExpandElementTyped<I>),
302) {
303 if unroll || range.const_len() == Some(1) {
304 range.expand_unroll(scope, body);
305 } else {
306 range.expand(scope, body);
307 }
308}
309
310pub fn if_expand(scope: &mut Scope, runtime_cond: ExpandElement, block: impl FnOnce(&mut Scope)) {
311 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
312 match comptime_cond {
313 Some(cond) => {
314 if cond {
315 block(scope);
316 }
317 }
318 None => {
319 let mut child = scope.child();
320
321 block(&mut child);
322
323 scope.register(Branch::If(Box::new(If {
324 cond: *runtime_cond,
325 scope: child,
326 })));
327 }
328 }
329}
330
331#[allow(clippy::large_enum_variant)]
332pub enum IfElseExpand {
333 ComptimeThen,
334 ComptimeElse,
335 Runtime {
336 runtime_cond: ExpandElement,
337 then_child: Scope,
338 },
339}
340
341impl IfElseExpand {
342 pub fn or_else(self, scope: &mut Scope, else_block: impl FnOnce(&mut Scope)) {
343 match self {
344 Self::Runtime {
345 runtime_cond,
346 then_child,
347 } => {
348 let mut else_child = scope.child();
349 else_block(&mut else_child);
350
351 scope.register(Branch::IfElse(Box::new(IfElse {
352 cond: *runtime_cond,
353 scope_if: then_child,
354 scope_else: else_child,
355 })));
356 }
357 Self::ComptimeElse => else_block(scope),
358 Self::ComptimeThen => (),
359 }
360 }
361}
362
363pub fn if_else_expand(
364 scope: &mut Scope,
365 runtime_cond: ExpandElement,
366 then_block: impl FnOnce(&mut Scope),
367) -> IfElseExpand {
368 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
369 match comptime_cond {
370 Some(true) => {
371 then_block(scope);
372 IfElseExpand::ComptimeThen
373 }
374 Some(false) => IfElseExpand::ComptimeElse,
375 None => {
376 let mut then_child = scope.child();
377 then_block(&mut then_child);
378
379 IfElseExpand::Runtime {
380 runtime_cond,
381 then_child,
382 }
383 }
384 }
385}
386
387#[allow(clippy::large_enum_variant)]
388pub enum IfElseExprExpand<C: CubeType> {
389 ComptimeThen(ExpandElementTyped<C>),
390 ComptimeElse,
391 Runtime {
392 runtime_cond: ExpandElement,
393 out: ExpandElementTyped<C>,
394 then_child: Scope,
395 },
396}
397
398impl<C: CubePrimitive> IfElseExprExpand<C> {
399 pub fn or_else(
400 self,
401 scope: &mut Scope,
402 else_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
403 ) -> ExpandElementTyped<C> {
404 match self {
405 Self::Runtime {
406 runtime_cond,
407 out,
408 then_child,
409 } => {
410 let mut else_child = scope.child();
411 let ret = else_block(&mut else_child);
412 assign::expand_no_check::<C>(&mut else_child, ret, out.clone());
413
414 scope.register(Branch::IfElse(Box::new(IfElse {
415 cond: *runtime_cond,
416 scope_if: then_child,
417 scope_else: else_child,
418 })));
419 out
420 }
421 Self::ComptimeElse => else_block(scope),
422 Self::ComptimeThen(ret) => ret,
423 }
424 }
425}
426
427pub fn if_else_expr_expand<C: CubePrimitive>(
428 scope: &mut Scope,
429 runtime_cond: ExpandElement,
430 then_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
431) -> IfElseExprExpand<C> {
432 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
433 match comptime_cond {
434 Some(true) => {
435 let ret = then_block(scope);
436 IfElseExprExpand::ComptimeThen(ret)
437 }
438 Some(false) => IfElseExprExpand::ComptimeElse,
439 None => {
440 let mut then_child = scope.child();
441 let ret = then_block(&mut then_child);
442 let out: ExpandElementTyped<C> = scope.create_local_mut(ret.expand.ty).into();
443 assign::expand_no_check::<C>(&mut then_child, ret, out.clone());
444
445 IfElseExprExpand::Runtime {
446 runtime_cond,
447 out,
448 then_child,
449 }
450 }
451 }
452}
453
454pub struct SwitchExpand<I: Int> {
455 value: ExpandElementTyped<I>,
456 default: Scope,
457 cases: Vec<(ExpandElementTyped<I>, Scope)>,
458}
459
460impl<I: Int> SwitchExpand<I> {
461 pub fn case(
462 mut self,
463 scope: &mut Scope,
464 value: impl Int,
465 block: impl FnOnce(&mut Scope),
466 ) -> Self {
467 let value = I::from(value).unwrap();
468 let mut case_child = scope.child();
469 block(&mut case_child);
470 self.cases.push((value.into(), case_child));
471 self
472 }
473
474 pub fn finish(self, scope: &mut Scope) {
475 let value_var = *self.value.expand;
476 scope.register(Branch::Switch(Box::new(Switch {
477 value: value_var,
478 scope_default: self.default,
479 cases: self
480 .cases
481 .into_iter()
482 .map(|it| (*it.0.expand, it.1))
483 .collect(),
484 })));
485 }
486}
487
488pub fn switch_expand<I: Int>(
489 scope: &mut Scope,
490 value: ExpandElementTyped<I>,
491 default_block: impl FnOnce(&mut Scope),
492) -> SwitchExpand<I> {
493 let mut default_child = scope.child();
494 default_block(&mut default_child);
495
496 SwitchExpand {
497 value,
498 default: default_child,
499 cases: Vec::new(),
500 }
501}
502
503pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
504 value: ExpandElementTyped<I>,
505 out: ExpandElementTyped<C>,
506 default: Scope,
507 cases: Vec<(ExpandElementTyped<I>, Scope)>,
508}
509
510impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
511 pub fn case(
512 mut self,
513 scope: &mut Scope,
514 value: impl Int,
515 block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
516 ) -> Self {
517 let value = I::from(value).unwrap();
518 let mut case_child = scope.child();
519 let ret = block(&mut case_child);
520 assign::expand_no_check::<C>(&mut case_child, ret, self.out.clone());
521 self.cases.push((value.into(), case_child));
522 self
523 }
524
525 pub fn finish(self, scope: &mut Scope) -> ExpandElementTyped<C> {
526 let value_var = *self.value.expand;
527 scope.register(Branch::Switch(Box::new(Switch {
528 value: value_var,
529 scope_default: self.default,
530 cases: self
531 .cases
532 .into_iter()
533 .map(|it| (*it.0.expand, it.1))
534 .collect(),
535 })));
536 self.out
537 }
538}
539
540pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
541 scope: &mut Scope,
542 value: ExpandElementTyped<I>,
543 default_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
544) -> SwitchExpandExpr<I, C> {
545 let mut default_child = scope.child();
546 let default = default_block(&mut default_child);
547 let out: ExpandElementTyped<C> = scope.create_local_mut(default.expand.ty).into();
548 assign::expand_no_check::<C>(&mut default_child, default, out.clone());
549
550 SwitchExpandExpr {
551 value,
552 out,
553 default: default_child,
554 cases: Vec::new(),
555 }
556}
557
558pub fn break_expand(scope: &mut Scope) {
559 scope.register(Branch::Break);
560}
561
562pub fn return_expand(scope: &mut Scope) {
563 scope.register(Branch::Return);
564}
565
566pub fn loop_expand(scope: &mut Scope, mut block: impl FnMut(&mut Scope)) {
568 let mut inside_loop = scope.child();
569
570 block(&mut inside_loop);
571 scope.register(Branch::Loop(Box::new(Loop { scope: inside_loop })));
572}