1use cubecl_ir::ExpandElement;
2use num_traits::NumCast;
3
4use crate::ir::Switch;
5use crate::ir::{Branch, If, IfElse, Item, Loop, RangeLoop, Scope};
6
7use super::{CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric, assign};
8
9pub trait Iterable<T: CubeType>: Sized {
12 fn expand(self, scope: &mut Scope, body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType));
19 fn expand_unroll(
26 self,
27 scope: &mut Scope,
28 body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
29 );
30}
31
/// Expansion-time representation of a `start..end` or `start..=end` range over an
/// integer type `I`.
pub struct RangeExpand<I: Int> {
    /// First value of the range.
    pub start: ExpandElementTyped<I>,
    /// Upper bound of the range; whether it is visited depends on `inclusive`.
    pub end: ExpandElementTyped<I>,
    /// `true` for `start..=end`, `false` for `start..end`.
    pub inclusive: bool,
}
37
38impl<I: Int> RangeExpand<I> {
39 pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
40 RangeExpand {
41 start,
42 end,
43 inclusive,
44 }
45 }
46
47 pub fn __expand_step_by_method(
48 self,
49 n: impl Into<ExpandElementTyped<u32>>,
50 ) -> SteppedRangeExpand<I> {
51 SteppedRangeExpand {
52 start: self.start,
53 end: self.end,
54 step: n.into(),
55 inclusive: self.inclusive,
56 }
57 }
58}
59
60impl<I: Int> Iterable<I> for RangeExpand<I> {
61 fn expand_unroll(
62 self,
63 scope: &mut Scope,
64 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
65 ) {
66 let start = self
67 .start
68 .expand
69 .as_const()
70 .expect("Only constant start can be unrolled.")
71 .as_i64();
72 let end = self
73 .end
74 .expand
75 .as_const()
76 .expect("Only constant end can be unrolled.")
77 .as_i64();
78
79 if self.inclusive {
80 for i in start..=end {
81 let var = I::from_int(i);
82 body(scope, var.into())
83 }
84 } else {
85 for i in start..end {
86 let var = I::from_int(i);
87 body(scope, var.into())
88 }
89 }
90 }
91
92 fn expand(
93 self,
94 scope: &mut Scope,
95 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
96 ) {
97 let mut child = scope.child();
98 let index_ty = Item::new(I::as_elem(scope));
99 let i = child.create_local_restricted(index_ty);
100
101 body(&mut child, i.clone().into());
102
103 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
104 i: *i,
105 start: *self.start.expand,
106 end: *self.end.expand,
107 step: None,
108 scope: child,
109 inclusive: self.inclusive,
110 })));
111 }
112}
113
/// Expansion-time representation of a stepped range such as `(start..end).step_by(step)`.
pub struct SteppedRangeExpand<I: Int> {
    /// First value of the range.
    start: ExpandElementTyped<I>,
    /// Upper bound of the range; whether it can be visited depends on `inclusive`.
    end: ExpandElementTyped<I>,
    /// Distance between consecutive visited values.
    step: ExpandElementTyped<u32>,
    /// `true` when the range includes `end`.
    inclusive: bool,
}
120
121impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
122 fn expand(
123 self,
124 scope: &mut Scope,
125 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
126 ) {
127 let mut child = scope.child();
128 let index_ty = Item::new(I::as_elem(scope));
129 let i = child.create_local_restricted(index_ty);
130
131 body(&mut child, i.clone().into());
132
133 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
134 i: *i,
135 start: *self.start.expand,
136 end: *self.end.expand,
137 step: Some(*self.step.expand),
138 scope: child,
139 inclusive: self.inclusive,
140 })));
141 }
142
143 fn expand_unroll(
144 self,
145 scope: &mut Scope,
146 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
147 ) {
148 let start = self
149 .start
150 .expand
151 .as_const()
152 .expect("Only constant start can be unrolled.")
153 .as_i64();
154 let end = self
155 .end
156 .expand
157 .as_const()
158 .expect("Only constant end can be unrolled.")
159 .as_i64();
160 let step = self
161 .step
162 .expand
163 .as_const()
164 .expect("Only constant step can be unrolled.")
165 .as_usize();
166
167 if self.inclusive {
168 for i in (start..=end).step_by(step) {
169 let var = I::from_int(i);
170 body(scope, var.into())
171 }
172 } else {
173 for i in (start..end).step_by(step) {
174 let var = I::from_int(i);
175 body(scope, var.into())
176 }
177 }
178 }
179}
180
181pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
187 let start: i64 = start.to_i64().unwrap();
188 let end: i64 = end.to_i64().unwrap();
189 (start..end).map(<T as NumCast>::from).map(Option::unwrap)
190}
191
192pub mod range {
193 use cubecl_ir::Scope;
194
195 use crate::prelude::{ExpandElementTyped, Int};
196
197 use super::RangeExpand;
198
199 pub fn expand<I: Int>(
200 _scope: &mut Scope,
201 start: ExpandElementTyped<I>,
202 end: ExpandElementTyped<I>,
203 ) -> RangeExpand<I> {
204 RangeExpand {
205 start,
206 end,
207 inclusive: false,
208 }
209 }
210}
211
212pub fn range_stepped<I: Int>(start: I, end: I, step: impl Int) -> impl Iterator<Item = I> {
220 let start = start.to_i64().unwrap();
221 let end = end.to_i64().unwrap();
222 let step = step.to_usize().unwrap();
223 (start..end)
224 .step_by(step)
225 .map(<I as NumCast>::from)
226 .map(Option::unwrap)
227}
228
229pub mod range_stepped {
230 use cubecl_ir::Scope;
231
232 use crate::prelude::{ExpandElementTyped, Int};
233
234 use super::SteppedRangeExpand;
235
236 pub fn expand<I: Int>(
237 _scope: &mut Scope,
238 start: ExpandElementTyped<I>,
239 end: ExpandElementTyped<I>,
240 step: ExpandElementTyped<u32>,
241 ) -> SteppedRangeExpand<I> {
242 SteppedRangeExpand {
243 start,
244 end,
245 step,
246 inclusive: false,
247 }
248 }
249}
250
251pub fn for_expand<I: Numeric>(
252 scope: &mut Scope,
253 range: impl Iterable<I>,
254 unroll: bool,
255 body: impl FnMut(&mut Scope, ExpandElementTyped<I>),
256) {
257 if unroll {
258 range.expand_unroll(scope, body);
259 } else {
260 range.expand(scope, body);
261 }
262}
263
264pub fn if_expand(scope: &mut Scope, runtime_cond: ExpandElement, block: impl FnOnce(&mut Scope)) {
265 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
266 match comptime_cond {
267 Some(cond) => {
268 if cond {
269 block(scope);
270 }
271 }
272 None => {
273 let mut child = scope.child();
274
275 block(&mut child);
276
277 scope.register(Branch::If(Box::new(If {
278 cond: *runtime_cond,
279 scope: child,
280 })));
281 }
282 }
283}
284
/// Partially expanded `if`/`else` statement.
///
/// Produced by `if_else_expand` after the `then` block is handled, and completed by
/// `IfElseExpand::or_else`.
// `Runtime` stores a child `Scope` by value, which presumably makes it far larger than the
// unit variants; it is intentionally left unboxed (NOTE(review): confirm the rationale).
#[allow(clippy::large_enum_variant)]
pub enum IfElseExpand {
    /// The condition was a compile-time `true`; the `then` block was expanded inline.
    ComptimeThen,
    /// The condition was a compile-time `false`; the `else` block still has to be expanded.
    ComptimeElse,
    /// The condition is only known at runtime.
    Runtime {
        /// The runtime condition value.
        runtime_cond: ExpandElement,
        /// Child scope containing the expanded `then` block.
        then_child: Scope,
    },
}
294
295impl IfElseExpand {
296 pub fn or_else(self, scope: &mut Scope, else_block: impl FnOnce(&mut Scope)) {
297 match self {
298 Self::Runtime {
299 runtime_cond,
300 then_child,
301 } => {
302 let mut else_child = scope.child();
303 else_block(&mut else_child);
304
305 scope.register(Branch::IfElse(Box::new(IfElse {
306 cond: *runtime_cond,
307 scope_if: then_child,
308 scope_else: else_child,
309 })));
310 }
311 Self::ComptimeElse => else_block(scope),
312 Self::ComptimeThen => (),
313 }
314 }
315}
316
317pub fn if_else_expand(
318 scope: &mut Scope,
319 runtime_cond: ExpandElement,
320 then_block: impl FnOnce(&mut Scope),
321) -> IfElseExpand {
322 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
323 match comptime_cond {
324 Some(true) => {
325 then_block(scope);
326 IfElseExpand::ComptimeThen
327 }
328 Some(false) => IfElseExpand::ComptimeElse,
329 None => {
330 let mut then_child = scope.child();
331 then_block(&mut then_child);
332
333 IfElseExpand::Runtime {
334 runtime_cond,
335 then_child,
336 }
337 }
338 }
339}
340
/// Partially expanded `if`/`else` expression producing a value of type `C`.
///
/// Produced by `if_else_expr_expand` after the `then` block is handled, and completed by
/// `IfElseExprExpand::or_else`.
// `Runtime` stores a child `Scope` by value, which presumably makes it far larger than the
// other variants; it is intentionally left unboxed (NOTE(review): confirm the rationale).
#[allow(clippy::large_enum_variant)]
pub enum IfElseExprExpand<C: CubeType> {
    /// The condition was a compile-time `true`; holds the value produced by the `then` block.
    ComptimeThen(ExpandElementTyped<C>),
    /// The condition was a compile-time `false`; the `else` block still has to be expanded.
    ComptimeElse,
    /// The condition is only known at runtime.
    Runtime {
        /// The runtime condition value.
        runtime_cond: ExpandElement,
        /// Variable created in the parent scope that both branches assign their result to.
        out: ExpandElementTyped<C>,
        /// Child scope containing the expanded `then` block and its assignment to `out`.
        then_child: Scope,
    },
}
351
352impl<C: CubePrimitive> IfElseExprExpand<C> {
353 pub fn or_else(
354 self,
355 scope: &mut Scope,
356 else_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
357 ) -> ExpandElementTyped<C> {
358 match self {
359 Self::Runtime {
360 runtime_cond,
361 out,
362 then_child,
363 } => {
364 let mut else_child = scope.child();
365 let ret = else_block(&mut else_child);
366 assign::expand::<C>(&mut else_child, ret, out.clone());
367
368 scope.register(Branch::IfElse(Box::new(IfElse {
369 cond: *runtime_cond,
370 scope_if: then_child,
371 scope_else: else_child,
372 })));
373 out
374 }
375 Self::ComptimeElse => else_block(scope),
376 Self::ComptimeThen(ret) => ret,
377 }
378 }
379}
380
381pub fn if_else_expr_expand<C: CubePrimitive>(
382 scope: &mut Scope,
383 runtime_cond: ExpandElement,
384 then_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
385) -> IfElseExprExpand<C> {
386 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
387 match comptime_cond {
388 Some(true) => {
389 let ret = then_block(scope);
390 IfElseExprExpand::ComptimeThen(ret)
391 }
392 Some(false) => IfElseExprExpand::ComptimeElse,
393 None => {
394 let mut then_child = scope.child();
395 let ret = then_block(&mut then_child);
396 let out: ExpandElementTyped<C> = scope.create_local_mut(ret.expand.item).into();
397 assign::expand::<C>(&mut then_child, ret, out.clone());
398
399 IfElseExprExpand::Runtime {
400 runtime_cond,
401 out,
402 then_child,
403 }
404 }
405 }
406}
407
/// Switch statement under construction.
///
/// Created by `switch_expand`, extended with `SwitchExpand::case`, and registered in the
/// parent scope by `SwitchExpand::finish`.
pub struct SwitchExpand<I: Int> {
    /// Value being switched on.
    value: ExpandElementTyped<I>,
    /// Child scope containing the expanded default block.
    default: Scope,
    /// Case values paired with the child scopes of their expanded blocks, in insertion order.
    cases: Vec<(ExpandElementTyped<I>, Scope)>,
}
413
414impl<I: Int> SwitchExpand<I> {
415 pub fn case(
416 mut self,
417 scope: &mut Scope,
418 value: impl Int,
419 block: impl FnOnce(&mut Scope),
420 ) -> Self {
421 let value = I::from(value).unwrap();
422 let mut case_child = scope.child();
423 block(&mut case_child);
424 self.cases.push((value.into(), case_child));
425 self
426 }
427
428 pub fn finish(self, scope: &mut Scope) {
429 let value_var = *self.value.expand;
430 scope.register(Branch::Switch(Box::new(Switch {
431 value: value_var,
432 scope_default: self.default,
433 cases: self
434 .cases
435 .into_iter()
436 .map(|it| (*it.0.expand, it.1))
437 .collect(),
438 })));
439 }
440}
441
442pub fn switch_expand<I: Int>(
443 scope: &mut Scope,
444 value: ExpandElementTyped<I>,
445 default_block: impl FnOnce(&mut Scope),
446) -> SwitchExpand<I> {
447 let mut default_child = scope.child();
448 default_block(&mut default_child);
449
450 SwitchExpand {
451 value,
452 default: default_child,
453 cases: Vec::new(),
454 }
455}
456
/// Switch expression under construction that produces a value of type `C`.
///
/// Created by `switch_expand_expr`, extended with `SwitchExpandExpr::case`, and registered
/// in the parent scope by `SwitchExpandExpr::finish`.
pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
    /// Value being switched on.
    value: ExpandElementTyped<I>,
    /// Variable created in the parent scope that every arm assigns its result to.
    out: ExpandElementTyped<C>,
    /// Child scope containing the expanded default block and its assignment to `out`.
    default: Scope,
    /// Case values paired with the child scopes of their expanded blocks, in insertion order.
    cases: Vec<(ExpandElementTyped<I>, Scope)>,
}
463
464impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
465 pub fn case(
466 mut self,
467 scope: &mut Scope,
468 value: impl Int,
469 block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
470 ) -> Self {
471 let value = I::from(value).unwrap();
472 let mut case_child = scope.child();
473 let ret = block(&mut case_child);
474 assign::expand::<C>(&mut case_child, ret, self.out.clone());
475 self.cases.push((value.into(), case_child));
476 self
477 }
478
479 pub fn finish(self, scope: &mut Scope) -> ExpandElementTyped<C> {
480 let value_var = *self.value.expand;
481 scope.register(Branch::Switch(Box::new(Switch {
482 value: value_var,
483 scope_default: self.default,
484 cases: self
485 .cases
486 .into_iter()
487 .map(|it| (*it.0.expand, it.1))
488 .collect(),
489 })));
490 self.out
491 }
492}
493
494pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
495 scope: &mut Scope,
496 value: ExpandElementTyped<I>,
497 default_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
498) -> SwitchExpandExpr<I, C> {
499 let mut default_child = scope.child();
500 let default = default_block(&mut default_child);
501 let out: ExpandElementTyped<C> = scope.create_local_mut(default.expand.item).into();
502 assign::expand::<C>(&mut default_child, default, out.clone());
503
504 SwitchExpandExpr {
505 value,
506 out,
507 default: default_child,
508 cases: Vec::new(),
509 }
510}
511
512pub fn break_expand(scope: &mut Scope) {
513 scope.register(Branch::Break);
514}
515
516pub fn return_expand(scope: &mut Scope) {
517 scope.register(Branch::Return);
518}
519
520pub fn loop_expand(scope: &mut Scope, mut block: impl FnMut(&mut Scope)) {
522 let mut inside_loop = scope.child();
523
524 block(&mut inside_loop);
525 scope.register(Branch::Loop(Box::new(Loop { scope: inside_loop })));
526}