1use num_traits::NumCast;
2
3use crate::ir::{Branch, If, IfElse, Item, Loop, RangeLoop};
4use crate::{
5 frontend::{CubeContext, ExpandElement},
6 ir::Switch,
7};
8
9use super::{assign, CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric};
10
/// A range-like value that an expanded `for` loop can iterate over (see [`for_expand`]).
///
/// `T` is the type of the loop index handed to the body closure.
pub trait Iterable<T: CubeType>: Sized {
    /// Expands the loop without unrolling it.
    ///
    /// In the implementations in this module, `body` is invoked once with a child
    /// context and the expanded loop index, and the resulting scope is registered in
    /// `context` as a `Branch::RangeLoop`.
    fn expand(
        self,
        context: &mut CubeContext,
        body: impl FnMut(&mut CubeContext, <T as CubeType>::ExpandType),
    );
    /// Unrolls the loop at expansion time, invoking `body` once per iteration.
    ///
    /// In the implementations in this module, the bounds (and step, if any) must be
    /// constants; they panic otherwise. Each invocation of `body` receives `context`
    /// itself and the current index as a constant.
    fn expand_unroll(
        self,
        context: &mut CubeContext,
        body: impl FnMut(&mut CubeContext, <T as CubeType>::ExpandType),
    );
}
36
/// Expanded form of an integer range, as built by `range::expand` or [`RangeExpand::new`].
pub struct RangeExpand<I: Int> {
    /// Expanded lower bound of the range (always included).
    pub start: ExpandElementTyped<I>,
    /// Expanded upper bound of the range.
    pub end: ExpandElementTyped<I>,
    /// `true` for `start..=end`, `false` for `start..end`.
    pub inclusive: bool,
}
42
43impl<I: Int> RangeExpand<I> {
44 pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
45 RangeExpand {
46 start,
47 end,
48 inclusive,
49 }
50 }
51
52 pub fn __expand_step_by_method(
53 self,
54 n: impl Into<ExpandElementTyped<u32>>,
55 ) -> SteppedRangeExpand<I> {
56 SteppedRangeExpand {
57 start: self.start,
58 end: self.end,
59 step: n.into(),
60 inclusive: self.inclusive,
61 }
62 }
63}
64
65impl<I: Int> Iterable<I> for RangeExpand<I> {
66 fn expand_unroll(
67 self,
68 context: &mut CubeContext,
69 mut body: impl FnMut(&mut CubeContext, <I as CubeType>::ExpandType),
70 ) {
71 let start = self
72 .start
73 .expand
74 .as_const()
75 .expect("Only constant start can be unrolled.")
76 .as_i64();
77 let end = self
78 .end
79 .expand
80 .as_const()
81 .expect("Only constant end can be unrolled.")
82 .as_i64();
83
84 if self.inclusive {
85 for i in start..=end {
86 let var = I::from_int(i);
87 body(context, var.into())
88 }
89 } else {
90 for i in start..end {
91 let var = I::from_int(i);
92 body(context, var.into())
93 }
94 }
95 }
96
97 fn expand(
98 self,
99 context: &mut CubeContext,
100 mut body: impl FnMut(&mut CubeContext, <I as CubeType>::ExpandType),
101 ) {
102 let mut child = context.child();
103 let index_ty = Item::new(I::as_elem(context));
104 let i = child.create_local_restricted(index_ty);
105
106 body(&mut child, i.clone().into());
107
108 context.register(Branch::RangeLoop(Box::new(RangeLoop {
109 i: *i,
110 start: *self.start.expand,
111 end: *self.end.expand,
112 step: None,
113 scope: child.into_scope(),
114 inclusive: self.inclusive,
115 })));
116 }
117}
118
/// Expanded form of an integer range with an explicit step, as built by
/// [`RangeExpand::__expand_step_by_method`] or `range_stepped::expand`.
pub struct SteppedRangeExpand<I: Int> {
    /// Expanded lower bound of the range.
    start: ExpandElementTyped<I>,
    /// Expanded upper bound of the range.
    end: ExpandElementTyped<I>,
    /// Expanded distance between consecutive indices (a `u32`).
    step: ExpandElementTyped<u32>,
    /// `true` for `start..=end`, `false` for `start..end`.
    inclusive: bool,
}
125
126impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
127 fn expand(
128 self,
129 context: &mut CubeContext,
130 mut body: impl FnMut(&mut CubeContext, <I as CubeType>::ExpandType),
131 ) {
132 let mut child = context.child();
133 let index_ty = Item::new(I::as_elem(context));
134 let i = child.create_local_restricted(index_ty);
135
136 body(&mut child, i.clone().into());
137
138 context.register(Branch::RangeLoop(Box::new(RangeLoop {
139 i: *i,
140 start: *self.start.expand,
141 end: *self.end.expand,
142 step: Some(*self.step.expand),
143 scope: child.into_scope(),
144 inclusive: self.inclusive,
145 })));
146 }
147
148 fn expand_unroll(
149 self,
150 context: &mut CubeContext,
151 mut body: impl FnMut(&mut CubeContext, <I as CubeType>::ExpandType),
152 ) {
153 let start = self
154 .start
155 .expand
156 .as_const()
157 .expect("Only constant start can be unrolled.")
158 .as_i64();
159 let end = self
160 .end
161 .expand
162 .as_const()
163 .expect("Only constant end can be unrolled.")
164 .as_i64();
165 let step = self
166 .step
167 .expand
168 .as_const()
169 .expect("Only constant step can be unrolled.")
170 .as_usize();
171
172 if self.inclusive {
173 for i in (start..=end).step_by(step) {
174 let var = I::from_int(i);
175 body(context, var.into())
176 }
177 } else {
178 for i in (start..end).step_by(step) {
179 let var = I::from_int(i);
180 body(context, var.into())
181 }
182 }
183 }
184}
185
186pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
192 let start: i64 = start.to_i64().unwrap();
193 let end: i64 = end.to_i64().unwrap();
194 (start..end).map(<T as NumCast>::from).map(Option::unwrap)
195}
196
197pub mod range {
198 use crate::prelude::{CubeContext, ExpandElementTyped, Int};
199
200 use super::RangeExpand;
201
202 pub fn expand<I: Int>(
203 _context: &mut CubeContext,
204 start: ExpandElementTyped<I>,
205 end: ExpandElementTyped<I>,
206 ) -> RangeExpand<I> {
207 RangeExpand {
208 start,
209 end,
210 inclusive: false,
211 }
212 }
213}
214
215pub fn range_stepped<I: Int>(start: I, end: I, step: impl Int) -> impl Iterator<Item = I> {
223 let start = start.to_i64().unwrap();
224 let end = end.to_i64().unwrap();
225 let step = step.to_usize().unwrap();
226 (start..end)
227 .step_by(step)
228 .map(<I as NumCast>::from)
229 .map(Option::unwrap)
230}
231
232pub mod range_stepped {
233 use crate::prelude::{CubeContext, ExpandElementTyped, Int};
234
235 use super::SteppedRangeExpand;
236
237 pub fn expand<I: Int>(
238 _context: &mut CubeContext,
239 start: ExpandElementTyped<I>,
240 end: ExpandElementTyped<I>,
241 step: ExpandElementTyped<u32>,
242 ) -> SteppedRangeExpand<I> {
243 SteppedRangeExpand {
244 start,
245 end,
246 step,
247 inclusive: false,
248 }
249 }
250}
251
252pub fn for_expand<I: Numeric>(
253 context: &mut CubeContext,
254 range: impl Iterable<I>,
255 unroll: bool,
256 body: impl FnMut(&mut CubeContext, ExpandElementTyped<I>),
257) {
258 if unroll {
259 range.expand_unroll(context, body);
260 } else {
261 range.expand(context, body);
262 }
263}
264
265pub fn if_expand(
266 context: &mut CubeContext,
267 runtime_cond: ExpandElement,
268 block: impl FnOnce(&mut CubeContext),
269) {
270 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
271 match comptime_cond {
272 Some(cond) => {
273 if cond {
274 block(context);
275 }
276 }
277 None => {
278 let mut child = context.child();
279
280 block(&mut child);
281
282 context.register(Branch::If(Box::new(If {
283 cond: *runtime_cond,
284 scope: child.into_scope(),
285 })));
286 }
287 }
288}
289
/// Pending state of an expanded `if`/`else` statement, returned by
/// [`if_else_expand`] and completed by [`IfElseExpand::or_else`].
pub enum IfElseExpand {
    /// The condition was the constant `true`: the `then` block has already been
    /// expanded into the parent context, and the `else` block will be skipped.
    ComptimeThen,
    /// The condition was the constant `false`: the `then` block was skipped, and
    /// the `else` block will be expanded into the parent context.
    ComptimeElse,
    /// The condition is not a constant.
    Runtime {
        /// Condition the emitted `Branch::IfElse` branches on.
        runtime_cond: ExpandElement,
        /// Child context into which the `then` block has already been expanded.
        then_child: CubeContext,
    },
}
298
299impl IfElseExpand {
300 pub fn or_else(self, context: &mut CubeContext, else_block: impl FnOnce(&mut CubeContext)) {
301 match self {
302 Self::Runtime {
303 runtime_cond,
304 then_child,
305 } => {
306 let mut else_child = context.child();
307 else_block(&mut else_child);
308
309 context.register(Branch::IfElse(Box::new(IfElse {
310 cond: *runtime_cond,
311 scope_if: then_child.into_scope(),
312 scope_else: else_child.into_scope(),
313 })));
314 }
315 Self::ComptimeElse => else_block(context),
316 Self::ComptimeThen => (),
317 }
318 }
319}
320
321pub fn if_else_expand(
322 context: &mut CubeContext,
323 runtime_cond: ExpandElement,
324 then_block: impl FnOnce(&mut CubeContext),
325) -> IfElseExpand {
326 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
327 match comptime_cond {
328 Some(true) => {
329 then_block(context);
330 IfElseExpand::ComptimeThen
331 }
332 Some(false) => IfElseExpand::ComptimeElse,
333 None => {
334 let mut then_child = context.child();
335 then_block(&mut then_child);
336
337 IfElseExpand::Runtime {
338 runtime_cond,
339 then_child,
340 }
341 }
342 }
343}
344
/// Pending state of an expanded `if`/`else` *expression* producing a `C`,
/// returned by [`if_else_expr_expand`] and completed by [`IfElseExprExpand::or_else`].
pub enum IfElseExprExpand<C: CubeType> {
    /// The condition was the constant `true`; holds the value of the `then` block,
    /// which was expanded into the parent context.
    ComptimeThen(ExpandElementTyped<C>),
    /// The condition was the constant `false`; the `else` block has yet to be expanded.
    ComptimeElse,
    /// The condition is not a constant.
    Runtime {
        /// Condition the emitted `Branch::IfElse` branches on.
        runtime_cond: ExpandElement,
        /// Mutable local created in the parent context; each branch assigns its
        /// result to it.
        out: ExpandElementTyped<C>,
        /// Child context holding the expanded `then` block, including its
        /// assignment to `out`.
        then_child: CubeContext,
    },
}
354
355impl<C: CubePrimitive> IfElseExprExpand<C> {
356 pub fn or_else(
357 self,
358 context: &mut CubeContext,
359 else_block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
360 ) -> ExpandElementTyped<C> {
361 match self {
362 Self::Runtime {
363 runtime_cond,
364 out,
365 then_child,
366 } => {
367 let mut else_child = context.child();
368 let ret = else_block(&mut else_child);
369 assign::expand(&mut else_child, ret, out.clone());
370
371 context.register(Branch::IfElse(Box::new(IfElse {
372 cond: *runtime_cond,
373 scope_if: then_child.into_scope(),
374 scope_else: else_child.into_scope(),
375 })));
376 out
377 }
378 Self::ComptimeElse => else_block(context),
379 Self::ComptimeThen(ret) => ret,
380 }
381 }
382}
383
384pub fn if_else_expr_expand<C: CubePrimitive>(
385 context: &mut CubeContext,
386 runtime_cond: ExpandElement,
387 then_block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
388) -> IfElseExprExpand<C> {
389 let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
390 match comptime_cond {
391 Some(true) => {
392 let ret = then_block(context);
393 IfElseExprExpand::ComptimeThen(ret)
394 }
395 Some(false) => IfElseExprExpand::ComptimeElse,
396 None => {
397 let mut then_child = context.child();
398 let ret = then_block(&mut then_child);
399 let out: ExpandElementTyped<C> = context.create_local_mut(ret.expand.item).into();
400 assign::expand(&mut then_child, ret, out.clone());
401
402 IfElseExprExpand::Runtime {
403 runtime_cond,
404 out,
405 then_child,
406 }
407 }
408 }
409}
410
/// Builder for an expanded `switch` statement: created by [`switch_expand`],
/// extended with [`SwitchExpand::case`] and registered by [`SwitchExpand::finish`].
pub struct SwitchExpand<I: Int> {
    /// Value the switch dispatches on.
    value: ExpandElementTyped<I>,
    /// Child context holding the expanded default arm.
    default: CubeContext,
    /// Case values paired with the child contexts holding their arms, in the
    /// order they were added.
    cases: Vec<(ExpandElementTyped<I>, CubeContext)>,
}
416
417impl<I: Int> SwitchExpand<I> {
418 pub fn case(
419 mut self,
420 context: &mut CubeContext,
421 value: impl Int,
422 block: impl FnOnce(&mut CubeContext),
423 ) -> Self {
424 let value = I::from(value).unwrap();
425 let mut case_child = context.child();
426 block(&mut case_child);
427 self.cases.push((value.into(), case_child));
428 self
429 }
430
431 pub fn finish(self, context: &mut CubeContext) {
432 let value_var = *self.value.expand;
433 context.register(Branch::Switch(Box::new(Switch {
434 value: value_var,
435 scope_default: self.default.into_scope(),
436 cases: self
437 .cases
438 .into_iter()
439 .map(|it| (*it.0.expand, it.1.into_scope()))
440 .collect(),
441 })));
442 }
443}
444
445pub fn switch_expand<I: Int>(
446 context: &mut CubeContext,
447 value: ExpandElementTyped<I>,
448 default_block: impl FnOnce(&mut CubeContext),
449) -> SwitchExpand<I> {
450 let mut default_child = context.child();
451 default_block(&mut default_child);
452
453 SwitchExpand {
454 value,
455 default: default_child,
456 cases: Vec::new(),
457 }
458}
459
/// Builder for an expanded `switch` *expression* producing a `C`: created by
/// [`switch_expand_expr`], extended with [`SwitchExpandExpr::case`] and registered
/// by [`SwitchExpandExpr::finish`].
pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
    /// Value the switch dispatches on.
    value: ExpandElementTyped<I>,
    /// Mutable local, created in the parent context, that every arm assigns its
    /// result to.
    out: ExpandElementTyped<C>,
    /// Child context holding the expanded default arm (including its assignment to `out`).
    default: CubeContext,
    /// Case values paired with the child contexts holding their arms, in the
    /// order they were added.
    cases: Vec<(ExpandElementTyped<I>, CubeContext)>,
}
466
467impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
468 pub fn case(
469 mut self,
470 context: &mut CubeContext,
471 value: impl Int,
472 block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
473 ) -> Self {
474 let value = I::from(value).unwrap();
475 let mut case_child = context.child();
476 let ret = block(&mut case_child);
477 assign::expand(&mut case_child, ret, self.out.clone());
478 self.cases.push((value.into(), case_child));
479 self
480 }
481
482 pub fn finish(self, context: &mut CubeContext) -> ExpandElementTyped<C> {
483 let value_var = *self.value.expand;
484 context.register(Branch::Switch(Box::new(Switch {
485 value: value_var,
486 scope_default: self.default.into_scope(),
487 cases: self
488 .cases
489 .into_iter()
490 .map(|it| (*it.0.expand, it.1.into_scope()))
491 .collect(),
492 })));
493 self.out
494 }
495}
496
497pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
498 context: &mut CubeContext,
499 value: ExpandElementTyped<I>,
500 default_block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
501) -> SwitchExpandExpr<I, C> {
502 let mut default_child = context.child();
503 let default = default_block(&mut default_child);
504 let out: ExpandElementTyped<C> = context.create_local_mut(default.expand.item).into();
505 assign::expand(&mut default_child, default, out.clone());
506
507 SwitchExpandExpr {
508 value,
509 out,
510 default: default_child,
511 cases: Vec::new(),
512 }
513}
514
515pub fn break_expand(context: &mut CubeContext) {
516 context.register(Branch::Break);
517}
518
519pub fn return_expand(context: &mut CubeContext) {
520 context.register(Branch::Return);
521}
522
523pub fn loop_expand(context: &mut CubeContext, mut block: impl FnMut(&mut CubeContext)) {
525 let mut inside_loop = context.child();
526
527 block(&mut inside_loop);
528 context.register(Branch::Loop(Box::new(Loop {
529 scope: inside_loop.into_scope(),
530 })));
531}