1use cubecl_ir::ExpandElement;
2use num_traits::NumCast;
3
4use crate::ir::Switch;
5use crate::ir::{Branch, If, IfElse, Loop, RangeLoop, Scope, Type};
6
7use super::{CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric, assign};
8
/// A range-like value that can drive a `for` loop during kernel expansion.
///
/// Implementors decide how the loop body is emitted: either as a single runtime loop
/// ([`Iterable::expand`]) or by replaying the body once per iteration at expansion time
/// ([`Iterable::expand_unroll`]).
pub trait Iterable<T: CubeType>: Sized {
    /// Expands the loop as a runtime loop.
    ///
    /// `body` receives the scope to expand the loop body into and the expanded loop
    /// variable. In the range implementations of this file, `body` is called exactly once
    /// with a child scope, which is then registered on `scope` as a `RangeLoop` branch.
    fn expand(self, scope: &mut Scope, body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType));
    /// Expands the loop by unrolling it at expansion time.
    ///
    /// In the range implementations of this file, `body` is called once per iteration
    /// directly on `scope` with the loop variable as a constant, and every bound must be a
    /// constant (they panic otherwise).
    fn expand_unroll(
        self,
        scope: &mut Scope,
        body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
    );
    /// Number of iterations when it is known at expansion time, or `None` otherwise.
    ///
    /// The default implementation always returns `None`. `for_expand` consults this to
    /// decide whether a short loop should be unrolled.
    fn const_len(&self) -> Option<usize> {
        None
    }
}
35
36pub struct RangeExpand<I: Int> {
37 pub start: ExpandElementTyped<I>,
38 pub end: ExpandElementTyped<I>,
39 pub inclusive: bool,
40}
41
42impl<I: Int> RangeExpand<I> {
43 pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
44 RangeExpand {
45 start,
46 end,
47 inclusive,
48 }
49 }
50
51 pub fn __expand_step_by_method(
52 self,
53 n: impl Into<ExpandElementTyped<u32>>,
54 ) -> SteppedRangeExpand<I> {
55 SteppedRangeExpand {
56 start: self.start,
57 end: self.end,
58 step: n.into(),
59 inclusive: self.inclusive,
60 }
61 }
62}
63
64impl<I: Int> Iterable<I> for RangeExpand<I> {
65 fn expand_unroll(
66 self,
67 scope: &mut Scope,
68 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
69 ) {
70 let start = self
71 .start
72 .expand
73 .as_const()
74 .expect("Only constant start can be unrolled.")
75 .as_i64();
76 let end = self
77 .end
78 .expand
79 .as_const()
80 .expect("Only constant end can be unrolled.")
81 .as_i64();
82
83 if self.inclusive {
84 for i in start..=end {
85 let var = I::from_int(i);
86 body(scope, var.into())
87 }
88 } else {
89 for i in start..end {
90 let var = I::from_int(i);
91 body(scope, var.into())
92 }
93 }
94 }
95
96 fn expand(
97 self,
98 scope: &mut Scope,
99 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
100 ) {
101 let mut child = scope.child();
102 let index_ty = Type::new(I::as_type(scope));
103 let i = child.create_local_restricted(index_ty);
104
105 body(&mut child, i.clone().into());
106
107 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
108 i: *i,
109 start: *self.start.expand,
110 end: *self.end.expand,
111 step: None,
112 scope: child,
113 inclusive: self.inclusive,
114 })));
115 }
116
117 fn const_len(&self) -> Option<usize> {
118 let start = self.start.expand.as_const()?.as_i64();
119 let end = self.end.expand.as_const()?.as_i64();
120 Some(start.abs_diff(end) as usize)
121 }
122}
123
124pub struct SteppedRangeExpand<I: Int> {
125 start: ExpandElementTyped<I>,
126 end: ExpandElementTyped<I>,
127 step: ExpandElementTyped<u32>,
128 inclusive: bool,
129}
130
131impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
132 fn expand(
133 self,
134 scope: &mut Scope,
135 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
136 ) {
137 let mut child = scope.child();
138 let index_ty = Type::new(I::as_type(scope));
139 let i = child.create_local_restricted(index_ty);
140
141 body(&mut child, i.clone().into());
142
143 scope.register(Branch::RangeLoop(Box::new(RangeLoop {
144 i: *i,
145 start: *self.start.expand,
146 end: *self.end.expand,
147 step: Some(*self.step.expand),
148 scope: child,
149 inclusive: self.inclusive,
150 })));
151 }
152
153 fn expand_unroll(
154 self,
155 scope: &mut Scope,
156 mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
157 ) {
158 let start = self
159 .start
160 .expand
161 .as_const()
162 .expect("Only constant start can be unrolled.")
163 .as_i64();
164 let end = self
165 .end
166 .expand
167 .as_const()
168 .expect("Only constant end can be unrolled.")
169 .as_i64();
170 let step = self
171 .step
172 .expand
173 .as_const()
174 .expect("Only constant step can be unrolled.")
175 .as_usize();
176
177 if self.inclusive {
178 for i in (start..=end).step_by(step) {
179 let var = I::from_int(i);
180 body(scope, var.into())
181 }
182 } else {
183 for i in (start..end).step_by(step) {
184 let var = I::from_int(i);
185 body(scope, var.into())
186 }
187 }
188 }
189
190 fn const_len(&self) -> Option<usize> {
191 let start = self.start.constant()?.as_i64();
192 let end = self.end.constant()?.as_i64();
193 let step = self.step.constant()?.as_u64();
194 Some((start.abs_diff(end) / step) as usize)
195 }
196}
197
198pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
204 let start: i64 = start.to_i64().unwrap();
205 let end: i64 = end.to_i64().unwrap();
206 (start..end).map(<T as NumCast>::from).map(Option::unwrap)
207}
208
209pub mod range {
210 use cubecl_ir::Scope;
211
212 use crate::prelude::{ExpandElementTyped, Int};
213
214 use super::RangeExpand;
215
216 pub fn expand<I: Int>(
217 _scope: &mut Scope,
218 start: ExpandElementTyped<I>,
219 end: ExpandElementTyped<I>,
220 ) -> RangeExpand<I> {
221 RangeExpand {
222 start,
223 end,
224 inclusive: false,
225 }
226 }
227}
228
229pub fn range_stepped<I: Int>(start: I, end: I, step: impl Int) -> impl Iterator<Item = I> {
237 let start = start.to_i64().unwrap();
238 let end = end.to_i64().unwrap();
239 let step = step.to_usize().unwrap();
240 (start..end)
241 .step_by(step)
242 .map(<I as NumCast>::from)
243 .map(Option::unwrap)
244}
245
246pub mod range_stepped {
247 use cubecl_ir::Scope;
248
249 use crate::prelude::{ExpandElementTyped, Int};
250
251 use super::SteppedRangeExpand;
252
253 pub fn expand<I: Int>(
254 _scope: &mut Scope,
255 start: ExpandElementTyped<I>,
256 end: ExpandElementTyped<I>,
257 step: ExpandElementTyped<u32>,
258 ) -> SteppedRangeExpand<I> {
259 SteppedRangeExpand {
260 start,
261 end,
262 step,
263 inclusive: false,
264 }
265 }
266}
267
268pub fn for_expand<I: Numeric>(
269 scope: &mut Scope,
270 range: impl Iterable<I>,
271 unroll: bool,
272 body: impl FnMut(&mut Scope, ExpandElementTyped<I>),
273) {
274 if unroll || range.const_len() == Some(1) {
275 range.expand_unroll(scope, body);
276 } else {
277 range.expand(scope, body);
278 }
279}
280
281pub fn if_expand(scope: &mut Scope, runtime_cond: ExpandElement, block: impl FnOnce(&mut Scope)) {
282 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
283 match comptime_cond {
284 Some(cond) => {
285 if cond {
286 block(scope);
287 }
288 }
289 None => {
290 let mut child = scope.child();
291
292 block(&mut child);
293
294 scope.register(Branch::If(Box::new(If {
295 cond: *runtime_cond,
296 scope: child,
297 })));
298 }
299 }
300}
301
// `Runtime` carries a whole `Scope` by value, unlike the unit variants. Presumably this is
// accepted because the value only lives between expanding the then- and else-branches.
// TODO confirm.
#[allow(clippy::large_enum_variant)]
/// Intermediate state of an `if`/`else` statement after its then-branch was expanded.
///
/// Created by `if_else_expand` and finished by `IfElseExpand::or_else`.
pub enum IfElseExpand {
    /// The condition was a constant `true`; the then-branch is already expanded into the
    /// parent scope, and the else-branch will be dropped.
    ComptimeThen,
    /// The condition was a constant `false`; nothing has been expanded yet.
    ComptimeElse,
    /// The condition is only known at runtime.
    Runtime {
        /// The runtime condition.
        runtime_cond: ExpandElement,
        /// Child scope holding the expanded then-branch.
        then_child: Scope,
    },
}
311
312impl IfElseExpand {
313 pub fn or_else(self, scope: &mut Scope, else_block: impl FnOnce(&mut Scope)) {
314 match self {
315 Self::Runtime {
316 runtime_cond,
317 then_child,
318 } => {
319 let mut else_child = scope.child();
320 else_block(&mut else_child);
321
322 scope.register(Branch::IfElse(Box::new(IfElse {
323 cond: *runtime_cond,
324 scope_if: then_child,
325 scope_else: else_child,
326 })));
327 }
328 Self::ComptimeElse => else_block(scope),
329 Self::ComptimeThen => (),
330 }
331 }
332}
333
334pub fn if_else_expand(
335 scope: &mut Scope,
336 runtime_cond: ExpandElement,
337 then_block: impl FnOnce(&mut Scope),
338) -> IfElseExpand {
339 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
340 match comptime_cond {
341 Some(true) => {
342 then_block(scope);
343 IfElseExpand::ComptimeThen
344 }
345 Some(false) => IfElseExpand::ComptimeElse,
346 None => {
347 let mut then_child = scope.child();
348 then_block(&mut then_child);
349
350 IfElseExpand::Runtime {
351 runtime_cond,
352 then_child,
353 }
354 }
355 }
356}
357
// `Runtime` carries a whole `Scope` by value, unlike the other variants. Presumably this is
// accepted because the value only lives between expanding the then- and else-branches.
// TODO confirm.
#[allow(clippy::large_enum_variant)]
/// Intermediate state of an `if`/`else` expression after its then-branch was expanded.
///
/// Created by `if_else_expr_expand` and finished by `IfElseExprExpand::or_else`, which
/// returns the value of the whole expression.
pub enum IfElseExprExpand<C: CubeType> {
    /// The condition was a constant `true`; holds the value of the then-branch, which was
    /// expanded directly into the parent scope.
    ComptimeThen(ExpandElementTyped<C>),
    /// The condition was a constant `false`; nothing has been expanded yet.
    ComptimeElse,
    /// The condition is only known at runtime.
    Runtime {
        /// The runtime condition.
        runtime_cond: ExpandElement,
        /// Mutable local, created in the parent scope, that each branch assigns its value to.
        out: ExpandElementTyped<C>,
        /// Child scope holding the expanded then-branch, including its assignment to `out`.
        then_child: Scope,
    },
}
368
369impl<C: CubePrimitive> IfElseExprExpand<C> {
370 pub fn or_else(
371 self,
372 scope: &mut Scope,
373 else_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
374 ) -> ExpandElementTyped<C> {
375 match self {
376 Self::Runtime {
377 runtime_cond,
378 out,
379 then_child,
380 } => {
381 let mut else_child = scope.child();
382 let ret = else_block(&mut else_child);
383 assign::expand_no_check::<C>(&mut else_child, ret, out.clone());
384
385 scope.register(Branch::IfElse(Box::new(IfElse {
386 cond: *runtime_cond,
387 scope_if: then_child,
388 scope_else: else_child,
389 })));
390 out
391 }
392 Self::ComptimeElse => else_block(scope),
393 Self::ComptimeThen(ret) => ret,
394 }
395 }
396}
397
398pub fn if_else_expr_expand<C: CubePrimitive>(
399 scope: &mut Scope,
400 runtime_cond: ExpandElement,
401 then_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
402) -> IfElseExprExpand<C> {
403 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
404 match comptime_cond {
405 Some(true) => {
406 let ret = then_block(scope);
407 IfElseExprExpand::ComptimeThen(ret)
408 }
409 Some(false) => IfElseExprExpand::ComptimeElse,
410 None => {
411 let mut then_child = scope.child();
412 let ret = then_block(&mut then_child);
413 let out: ExpandElementTyped<C> = scope.create_local_mut(ret.expand.ty).into();
414 assign::expand_no_check::<C>(&mut then_child, ret, out.clone());
415
416 IfElseExprExpand::Runtime {
417 runtime_cond,
418 out,
419 then_child,
420 }
421 }
422 }
423}
424
/// Builder for a `switch` statement over an integer value.
///
/// Created by `switch_expand`, extended with `SwitchExpand::case`, and registered on the
/// parent scope by `SwitchExpand::finish`.
pub struct SwitchExpand<I: Int> {
    /// The value being switched on.
    value: ExpandElementTyped<I>,
    /// Child scope holding the expanded default branch.
    default: Scope,
    /// `(case value, expanded case scope)` pairs, in the order the cases were added.
    cases: Vec<(ExpandElementTyped<I>, Scope)>,
}
430
431impl<I: Int> SwitchExpand<I> {
432 pub fn case(
433 mut self,
434 scope: &mut Scope,
435 value: impl Int,
436 block: impl FnOnce(&mut Scope),
437 ) -> Self {
438 let value = I::from(value).unwrap();
439 let mut case_child = scope.child();
440 block(&mut case_child);
441 self.cases.push((value.into(), case_child));
442 self
443 }
444
445 pub fn finish(self, scope: &mut Scope) {
446 let value_var = *self.value.expand;
447 scope.register(Branch::Switch(Box::new(Switch {
448 value: value_var,
449 scope_default: self.default,
450 cases: self
451 .cases
452 .into_iter()
453 .map(|it| (*it.0.expand, it.1))
454 .collect(),
455 })));
456 }
457}
458
459pub fn switch_expand<I: Int>(
460 scope: &mut Scope,
461 value: ExpandElementTyped<I>,
462 default_block: impl FnOnce(&mut Scope),
463) -> SwitchExpand<I> {
464 let mut default_child = scope.child();
465 default_block(&mut default_child);
466
467 SwitchExpand {
468 value,
469 default: default_child,
470 cases: Vec::new(),
471 }
472}
473
/// Builder for a `switch` expression over an integer value, producing a value of type `C`.
///
/// Created by `switch_expand_expr`, extended with `SwitchExpandExpr::case`, and registered
/// on the parent scope by `SwitchExpandExpr::finish`, which returns `out`.
pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
    /// The value being switched on.
    value: ExpandElementTyped<I>,
    /// Mutable local, created in the parent scope, that every branch assigns its value to.
    out: ExpandElementTyped<C>,
    /// Child scope holding the expanded default branch, including its assignment to `out`.
    default: Scope,
    /// `(case value, expanded case scope)` pairs, in the order the cases were added.
    cases: Vec<(ExpandElementTyped<I>, Scope)>,
}
480
481impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
482 pub fn case(
483 mut self,
484 scope: &mut Scope,
485 value: impl Int,
486 block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
487 ) -> Self {
488 let value = I::from(value).unwrap();
489 let mut case_child = scope.child();
490 let ret = block(&mut case_child);
491 assign::expand_no_check::<C>(&mut case_child, ret, self.out.clone());
492 self.cases.push((value.into(), case_child));
493 self
494 }
495
496 pub fn finish(self, scope: &mut Scope) -> ExpandElementTyped<C> {
497 let value_var = *self.value.expand;
498 scope.register(Branch::Switch(Box::new(Switch {
499 value: value_var,
500 scope_default: self.default,
501 cases: self
502 .cases
503 .into_iter()
504 .map(|it| (*it.0.expand, it.1))
505 .collect(),
506 })));
507 self.out
508 }
509}
510
511pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
512 scope: &mut Scope,
513 value: ExpandElementTyped<I>,
514 default_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
515) -> SwitchExpandExpr<I, C> {
516 let mut default_child = scope.child();
517 let default = default_block(&mut default_child);
518 let out: ExpandElementTyped<C> = scope.create_local_mut(default.expand.ty).into();
519 assign::expand_no_check::<C>(&mut default_child, default, out.clone());
520
521 SwitchExpandExpr {
522 value,
523 out,
524 default: default_child,
525 cases: Vec::new(),
526 }
527}
528
529pub fn break_expand(scope: &mut Scope) {
530 scope.register(Branch::Break);
531}
532
533pub fn return_expand(scope: &mut Scope) {
534 scope.register(Branch::Return);
535}
536
537pub fn loop_expand(scope: &mut Scope, mut block: impl FnMut(&mut Scope)) {
539 let mut inside_loop = scope.child();
540
541 block(&mut inside_loop);
542 scope.register(Branch::Loop(Box::new(Loop { scope: inside_loop })));
543}