use cubecl_ir::ExpandElement;
use num_traits::NumCast;
use crate::ir::Switch;
use crate::ir::{Branch, If, IfElse, Loop, RangeLoop, Scope, Type};
use super::{CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric, assign};
/// A range-like value that a `for` loop can iterate over during kernel expansion.
///
/// The implementations in this module either register a runtime loop branch in the scope
/// (`expand`) or call the loop body once per index at expansion time (`expand_unroll`).
pub trait Iterable<T: CubeType>: Sized {
    /// Expands the loop as a runtime construct; `body` is invoked to populate the loop scope.
    fn expand(self, scope: &mut Scope, body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType));
    /// Expands the loop by unrolling it; `body` is invoked once per iteration with the index.
    fn expand_unroll(
        self,
        scope: &mut Scope,
        body: impl FnMut(&mut Scope, <T as CubeType>::ExpandType),
    );
    /// Number of iterations, when it can be determined at expansion time.
    ///
    /// The default is `None` (unknown). `for_expand` unrolls any loop that reports exactly one
    /// iteration.
    fn const_len(&self) -> Option<usize> {
        None
    }
}
/// Expand-time representation of an integer range `start..end` (or `start..=end`).
pub struct RangeExpand<I: Int> {
    /// First index of the range.
    pub start: ExpandElementTyped<I>,
    /// End bound of the range; excluded unless `inclusive` is set.
    pub end: ExpandElementTyped<I>,
    /// Whether `end` itself belongs to the range (`..=` rather than `..`).
    pub inclusive: bool,
}
impl<I: Int> RangeExpand<I> {
    /// Builds a range from its two bounds; `inclusive` selects `..=` over `..`.
    pub fn new(start: ExpandElementTyped<I>, end: ExpandElementTyped<I>, inclusive: bool) -> Self {
        Self {
            inclusive,
            end,
            start,
        }
    }

    /// Expansion of `.step_by(n)` on a range.
    ///
    /// Both bounds and the inclusivity are carried over unchanged, and `n` becomes the step.
    pub fn __expand_step_by_method(
        self,
        n: impl Into<ExpandElementTyped<I>>,
    ) -> SteppedRangeExpand<I> {
        let Self {
            start,
            end,
            inclusive,
        } = self;
        let step = n.into();
        SteppedRangeExpand {
            start,
            end,
            step,
            inclusive,
        }
    }
}
impl<I: Int> Iterable<I> for RangeExpand<I> {
    /// Unrolls the range at expansion time, calling `body` once per index with that index as a
    /// constant value.
    ///
    /// # Panics
    ///
    /// Panics if `start` or `end` is not a compile-time constant.
    fn expand_unroll(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
    ) {
        let start = self
            .start
            .expand
            .as_const()
            .expect("Only constant start can be unrolled.")
            .as_i64();
        let end = self
            .end
            .expand
            .as_const()
            .expect("Only constant end can be unrolled.")
            .as_i64();
        if self.inclusive {
            for i in start..=end {
                let var = I::from_int(i);
                body(scope, var.into())
            }
        } else {
            for i in start..end {
                let var = I::from_int(i);
                body(scope, var.into())
            }
        }
    }

    /// Registers a runtime `RangeLoop` branch in `scope`.
    ///
    /// `body` is expanded once into a child scope. The loop index is a local created in that
    /// child with `create_local_restricted`.
    fn expand(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
    ) {
        let mut child = scope.child();
        let index_ty = Type::new(I::as_type(scope));
        let i = child.create_local_restricted(index_ty);
        body(&mut child, i.clone().into());
        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
            i: *i,
            start: *self.start.expand,
            end: *self.end.expand,
            step: None,
            scope: child,
            inclusive: self.inclusive,
        })));
    }

    /// Number of iterations, when both bounds are compile-time constants.
    ///
    /// The count matches what `expand_unroll` produces:
    /// - `start..end` yields `end - start` indices;
    /// - `start..=end` yields one more;
    /// - a range whose end lies before its start is empty.
    fn const_len(&self) -> Option<usize> {
        let start = self.start.expand.as_const()?.as_i64();
        let end = self.end.expand.as_const()?.as_i64();
        let len = if self.inclusive {
            if end < start {
                0
            } else {
                // `start..=end` contains both bounds, hence the extra element.
                end.abs_diff(start).checked_add(1)?
            }
        } else if end <= start {
            0
        } else {
            end.abs_diff(start)
        };
        usize::try_from(len).ok()
    }
}
/// Expand-time representation of an integer range with an explicit step.
///
/// Produced by `range_stepped::expand` and by `.step_by(..)` on a [`RangeExpand`].
pub struct SteppedRangeExpand<I: Int> {
    /// `start` bound of the range.
    start: ExpandElementTyped<I>,
    /// `end` bound of the range.
    end: ExpandElementTyped<I>,
    /// Distance between consecutive indices; its sign selects the iteration direction.
    step: ExpandElementTyped<I>,
    /// Whether the range includes both bounds (`..=`) instead of only one.
    inclusive: bool,
}
impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
    /// Registers a runtime `RangeLoop` branch carrying the step in `scope`.
    ///
    /// `body` is expanded once into a child scope. The loop index is a local created in that
    /// child with `create_local_restricted`.
    fn expand(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
    ) {
        let mut child = scope.child();
        let index_ty = Type::new(I::as_type(scope));
        let i = child.create_local_restricted(index_ty);
        body(&mut child, i.clone().into());
        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
            i: *i,
            start: *self.start.expand,
            end: *self.end.expand,
            step: Some(*self.step.expand),
            scope: child,
            inclusive: self.inclusive,
        })));
    }

    /// Unrolls the stepped range at expansion time, calling `body` once per index with that
    /// index as a constant value.
    ///
    /// - A positive step walks `start..end` (or `start..=end`) upwards.
    /// - A negative step walks `(end..start).rev()` (or `(end..=start).rev()`) downwards.
    ///
    /// # Panics
    ///
    /// Panics if `start`, `end` or `step` is not a compile-time constant, or if `step` is zero.
    fn expand_unroll(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <I as CubeType>::ExpandType),
    ) {
        let start = self
            .start
            .expand
            .as_const()
            .expect("Only constant start can be unrolled.")
            .as_i128();
        let end = self
            .end
            .expand
            .as_const()
            .expect("Only constant end can be unrolled.")
            .as_i128();
        let step = self
            .step
            .expand
            .as_const()
            .expect("Only constant step can be unrolled.")
            .as_i128();
        // `step_by(0)` would panic anyway; fail with a message that names the actual problem.
        assert!(step != 0, "Unrolled range step must not be zero.");
        let stride = step.unsigned_abs() as usize;

        // NOTE(review): for a negative step the exclusive form iterates `(end..start).rev()`,
        // i.e. it starts at `start - 1` and includes `end` (the host-side `range_stepped` does
        // the same). Confirm this matches the lowering of the runtime `RangeLoop`.
        match (self.inclusive, step.is_negative()) {
            (true, true) => {
                for i in (end..=start).rev().step_by(stride) {
                    let var = I::from_int_128(i);
                    body(scope, var.into())
                }
            }
            (true, false) => {
                for i in (start..=end).step_by(stride) {
                    let var = I::from_int_128(i);
                    body(scope, var.into())
                }
            }
            (false, true) => {
                for i in (end..start).rev().step_by(stride) {
                    let var = I::from_int_128(i);
                    body(scope, var.into())
                }
            }
            (false, false) => {
                for i in (start..end).step_by(stride) {
                    let var = I::from_int_128(i);
                    body(scope, var.into())
                }
            }
        }
    }

    /// Number of iterations `expand_unroll` would perform, when `start`, `end` and `step` are
    /// all compile-time constants.
    ///
    /// Returns `None` for a zero step.
    fn const_len(&self) -> Option<usize> {
        let start = self.start.expand.as_const()?.as_i128();
        let end = self.end.expand.as_const()?.as_i128();
        let step = self.step.expand.as_const()?.as_i128();
        if step == 0 {
            // `expand_unroll` panics on a zero step, so there is no iteration count to report.
            return None;
        }
        // Mirror `expand_unroll`: a negative step walks the range between `end` and `start`.
        let (low, high) = if step.is_negative() {
            (end, start)
        } else {
            (start, end)
        };
        // Number of indices in the un-stepped range, before the stride is applied.
        let span = if self.inclusive {
            if high < low {
                0
            } else {
                high.abs_diff(low).checked_add(1)?
            }
        } else if high <= low {
            0
        } else {
            high.abs_diff(low)
        };
        // `step_by` keeps the first element of every `|step|`-sized chunk, so round up.
        usize::try_from(span.div_ceil(step.unsigned_abs())).ok()
    }
}
/// Plain (non-expanded) version of `range`: iterates `start..end` directly.
///
/// `range::expand` is the kernel-expansion counterpart. Bounds are widened to `i128`, like in
/// `range_stepped`, so every 64-bit bound is accepted, including `u64` values above `i64::MAX`.
///
/// # Panics
///
/// Panics if a bound cannot be represented as an `i128`.
pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
    let start: i128 = start.to_i128().unwrap();
    let end: i128 = end.to_i128().unwrap();
    // Every value in `start..end` lies between two values of `T`, so the cast back cannot fail.
    (start..end).map(<T as NumCast>::from).map(Option::unwrap)
}
pub mod range {
    use cubecl_ir::Scope;

    use crate::prelude::{ExpandElementTyped, Int};

    use super::RangeExpand;

    /// Expand-time counterpart of the host-side `range` function.
    ///
    /// Builds an exclusive `start..end` range. The scope is not touched here.
    pub fn expand<I: Int>(
        _scope: &mut Scope,
        start: ExpandElementTyped<I>,
        end: ExpandElementTyped<I>,
    ) -> RangeExpand<I> {
        RangeExpand::new(start, end, false)
    }
}
/// Plain (non-expanded) stepped integer range.
///
/// - A non-negative step iterates `start..end` upwards.
/// - A negative step iterates `(end..start).rev()`, i.e. downwards from `start - 1` to `end`.
///
/// # Panics
///
/// Panics if a bound or the step cannot be represented as an `i128`, or if the step is zero.
pub fn range_stepped<I: Int>(start: I, end: I, step: I) -> Box<dyn Iterator<Item = I>> {
    let lower = start.to_i128().unwrap();
    let upper = end.to_i128().unwrap();
    let stride = step.to_i128().unwrap();
    let stride_len = stride.unsigned_abs() as usize;
    // Every produced value lies between two values of `I`, so the cast back cannot fail.
    let to_int = |value: i128| <I as NumCast>::from(value).unwrap();

    if stride.is_negative() {
        Box::new((upper..lower).rev().step_by(stride_len).map(to_int))
    } else {
        Box::new((lower..upper).step_by(stride_len).map(to_int))
    }
}
pub mod range_stepped {
    use cubecl_ir::Scope;

    use crate::prelude::{ExpandElementTyped, Int};

    use super::SteppedRangeExpand;

    /// Expand-time counterpart of the host-side `range_stepped` function.
    ///
    /// Builds a non-inclusive stepped range. The scope is not touched here.
    pub fn expand<I: Int>(
        _scope: &mut Scope,
        start: ExpandElementTyped<I>,
        end: ExpandElementTyped<I>,
        step: ExpandElementTyped<I>,
    ) -> SteppedRangeExpand<I> {
        // `range_stepped` never produces an inclusive range.
        let inclusive = false;
        SteppedRangeExpand {
            inclusive,
            step,
            start,
            end,
        }
    }
}
/// Expands a `for` loop over `range`.
///
/// The loop is unrolled when `unroll` is requested, or when `range` reports exactly one
/// iteration. Otherwise it is expanded through `Iterable::expand`.
pub fn for_expand<I: Numeric>(
    scope: &mut Scope,
    range: impl Iterable<I>,
    unroll: bool,
    body: impl FnMut(&mut Scope, ExpandElementTyped<I>),
) {
    // A single-iteration loop gains nothing from a loop construct, so it is inlined.
    let should_unroll = unroll || matches!(range.const_len(), Some(1));
    if should_unroll {
        range.expand_unroll(scope, body);
    } else {
        range.expand(scope, body);
    }
}
/// Expands an `if` without an `else`.
///
/// - A compile-time constant condition either inlines `block` into `scope` or drops it.
/// - Otherwise `block` is expanded into a child scope and registered as a `Branch::If`.
pub fn if_expand(scope: &mut Scope, runtime_cond: ExpandElement, block: impl FnOnce(&mut Scope)) {
    let known_cond = runtime_cond.as_const().map(|it| it.as_bool());
    if let Some(cond) = known_cond {
        // Resolved during expansion: no branch is emitted.
        if cond {
            block(scope);
        }
        return;
    }

    let mut child = scope.child();
    block(&mut child);
    scope.register(Branch::If(Box::new(If {
        cond: *runtime_cond,
        scope: child,
    })));
}
/// Intermediate state of an `if`/`else` statement expansion.
///
/// Returned by `if_else_expand` and consumed by `IfElseExpand::or_else`.
// `Runtime` embeds a whole `Scope` by value while the other variants are empty, which is what
// `large_enum_variant` reports.
#[allow(clippy::large_enum_variant)]
pub enum IfElseExpand {
    /// The condition was a constant `true`; the `then` branch has already been inlined.
    ComptimeThen,
    /// The condition was a constant `false`; only the `else` branch remains to be inlined.
    ComptimeElse,
    /// The condition is only known at runtime.
    Runtime {
        /// Condition used by the `IfElse` branch that `or_else` registers.
        runtime_cond: ExpandElement,
        /// Child scope holding the already-expanded `then` branch.
        then_child: Scope,
    },
}
impl IfElseExpand {
    /// Finishes the `if`/`else` statement by expanding `else_block`.
    ///
    /// - Constant `true`: the `then` branch is already inlined, and `else_block` is never called.
    /// - Constant `false`: `else_block` is inlined directly into `scope`.
    /// - Runtime: `else_block` is expanded into its own child scope, and an `IfElse` branch
    ///   holding both children is registered in `scope`.
    pub fn or_else(self, scope: &mut Scope, else_block: impl FnOnce(&mut Scope)) {
        let (runtime_cond, then_child) = match self {
            Self::ComptimeThen => return,
            Self::ComptimeElse => {
                else_block(scope);
                return;
            }
            Self::Runtime {
                runtime_cond,
                then_child,
            } => (runtime_cond, then_child),
        };

        let mut else_child = scope.child();
        else_block(&mut else_child);
        scope.register(Branch::IfElse(Box::new(IfElse {
            cond: *runtime_cond,
            scope_if: then_child,
            scope_else: else_child,
        })));
    }
}
/// Starts expanding an `if`/`else` statement by expanding the `then` branch.
///
/// - A constant `true` condition inlines `then_block` into `scope`.
/// - A constant `false` condition skips `then_block`.
/// - A runtime condition expands `then_block` into a child scope that is kept for
///   `IfElseExpand::or_else`.
pub fn if_else_expand(
    scope: &mut Scope,
    runtime_cond: ExpandElement,
    then_block: impl FnOnce(&mut Scope),
) -> IfElseExpand {
    let known_cond = runtime_cond.as_const().map(|it| it.as_bool());
    let Some(cond) = known_cond else {
        let mut then_child = scope.child();
        then_block(&mut then_child);
        return IfElseExpand::Runtime {
            runtime_cond,
            then_child,
        };
    };

    if cond {
        then_block(scope);
        IfElseExpand::ComptimeThen
    } else {
        IfElseExpand::ComptimeElse
    }
}
/// Intermediate state of an `if`/`else` *expression* expansion.
///
/// Returned by `if_else_expr_expand` and consumed by `IfElseExprExpand::or_else`.
// `Runtime` embeds a whole `Scope` by value while the other variants are small, which is what
// `large_enum_variant` reports.
#[allow(clippy::large_enum_variant)]
pub enum IfElseExprExpand<C: CubeType> {
    /// The condition was a constant `true`; holds the value of the inlined `then` branch.
    ComptimeThen(ExpandElementTyped<C>),
    /// The condition was a constant `false`; only the `else` branch remains to be inlined.
    ComptimeElse,
    /// The condition is only known at runtime.
    Runtime {
        /// Condition used by the `IfElse` branch that `or_else` registers.
        runtime_cond: ExpandElement,
        /// Variable created in the parent scope that receives the value of the branch taken.
        out: ExpandElementTyped<C>,
        /// Child scope holding the expanded `then` branch, ending with an assignment to `out`.
        then_child: Scope,
    },
}
impl<C: CubePrimitive> IfElseExprExpand<C> {
    /// Finishes the `if`/`else` expression by expanding `else_block`, and returns the value of
    /// the whole expression.
    ///
    /// - Constant `true`: the stored `then` value is returned, and `else_block` is never called.
    /// - Constant `false`: `else_block` is inlined into `scope` and its value is returned.
    /// - Runtime: `else_block` is expanded into its own child scope and its value is assigned
    ///   to `out`. An `IfElse` branch is registered in `scope` and `out` is returned.
    pub fn or_else(
        self,
        scope: &mut Scope,
        else_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
    ) -> ExpandElementTyped<C> {
        let (runtime_cond, out, then_child) = match self {
            Self::ComptimeThen(then_value) => return then_value,
            Self::ComptimeElse => return else_block(scope),
            Self::Runtime {
                runtime_cond,
                out,
                then_child,
            } => (runtime_cond, out, then_child),
        };

        let mut else_child = scope.child();
        let else_value = else_block(&mut else_child);
        // Route the `else` result into the output variable shared with the `then` branch.
        assign::expand_no_check::<C>(&mut else_child, else_value, out.clone());
        scope.register(Branch::IfElse(Box::new(IfElse {
            cond: *runtime_cond,
            scope_if: then_child,
            scope_else: else_child,
        })));
        out
    }
}
/// Starts expanding an `if`/`else` expression by expanding the `then` branch.
///
/// - A constant `true` condition inlines `then_block` and keeps its value.
/// - A constant `false` condition skips `then_block`.
/// - A runtime condition expands `then_block` into a child scope. A mutable output local is
///   created in `scope`, and the branch's value is assigned to it inside the child.
pub fn if_else_expr_expand<C: CubePrimitive>(
    scope: &mut Scope,
    runtime_cond: ExpandElement,
    then_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
) -> IfElseExprExpand<C> {
    let known_cond = runtime_cond.as_const().map(|it| it.as_bool());
    let Some(cond) = known_cond else {
        let mut then_child = scope.child();
        let then_value = then_block(&mut then_child);
        // The output lives in the parent scope: both branches assign to it, and it is what the
        // finished expression evaluates to.
        let out: ExpandElementTyped<C> = scope.create_local_mut(then_value.expand.ty).into();
        assign::expand_no_check::<C>(&mut then_child, then_value, out.clone());
        return IfElseExprExpand::Runtime {
            runtime_cond,
            out,
            then_child,
        };
    };

    if cond {
        IfElseExprExpand::ComptimeThen(then_block(scope))
    } else {
        IfElseExprExpand::ComptimeElse
    }
}
/// Builder for a statement-level `switch`, created by `switch_expand`.
///
/// Arms are added with `case`, and the branch is registered by `finish`.
pub struct SwitchExpand<I: Int> {
    /// Value being switched on.
    value: ExpandElementTyped<I>,
    /// Child scope holding the expanded default arm.
    default: Scope,
    /// Case values, each paired with the child scope holding that arm's expanded body.
    cases: Vec<(ExpandElementTyped<I>, Scope)>,
}
impl<I: Int> SwitchExpand<I> {
    /// Adds a `case value => block` arm.
    ///
    /// `value` is converted to `I`, panicking if the conversion fails, and `block` is expanded
    /// into a fresh child scope.
    pub fn case(
        mut self,
        scope: &mut Scope,
        value: impl Int,
        block: impl FnOnce(&mut Scope),
    ) -> Self {
        let case_value = I::from(value).unwrap();
        let mut arm_scope = scope.child();
        block(&mut arm_scope);
        self.cases.push((case_value.into(), arm_scope));
        self
    }

    /// Registers the collected `Switch` branch in `scope`.
    ///
    /// The branch holds the default scope and every case, in the order they were added.
    pub fn finish(self, scope: &mut Scope) {
        let Self {
            value,
            default,
            cases,
        } = self;
        let cases = cases
            .into_iter()
            .map(|(case_value, case_scope)| (*case_value.expand, case_scope))
            .collect();
        scope.register(Branch::Switch(Box::new(Switch {
            value: *value.expand,
            scope_default: default,
            cases,
        })));
    }
}
/// Starts expanding a statement-level `switch` on `value`.
///
/// `default_block` is expanded into a child scope. Cases are added afterwards through
/// `SwitchExpand::case`.
pub fn switch_expand<I: Int>(
    scope: &mut Scope,
    value: ExpandElementTyped<I>,
    default_block: impl FnOnce(&mut Scope),
) -> SwitchExpand<I> {
    let default = {
        let mut child = scope.child();
        default_block(&mut child);
        child
    };
    SwitchExpand {
        value,
        default,
        cases: Vec::new(),
    }
}
/// Builder for a `switch` *expression*, created by `switch_expand_expr`.
///
/// Every arm assigns its result to `out`, which `finish` returns.
pub struct SwitchExpandExpr<I: Int, C: CubePrimitive> {
    /// Value being switched on.
    value: ExpandElementTyped<I>,
    /// Mutable local, created in the parent scope, that receives the value of the arm taken.
    out: ExpandElementTyped<C>,
    /// Child scope holding the expanded default arm, ending with an assignment to `out`.
    default: Scope,
    /// Case values, each paired with the child scope holding that arm's expanded body.
    cases: Vec<(ExpandElementTyped<I>, Scope)>,
}
impl<I: Int, C: CubePrimitive> SwitchExpandExpr<I, C> {
    /// Adds a `case value => block` arm whose result is assigned to the shared output.
    ///
    /// `value` is converted to `I`, panicking if the conversion fails.
    pub fn case(
        mut self,
        scope: &mut Scope,
        value: impl Int,
        block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
    ) -> Self {
        let case_value = I::from(value).unwrap();
        let mut arm_scope = scope.child();
        let arm_result = block(&mut arm_scope);
        // Every arm writes its result into the same output variable.
        assign::expand_no_check::<C>(&mut arm_scope, arm_result, self.out.clone());
        self.cases.push((case_value.into(), arm_scope));
        self
    }

    /// Registers the collected `Switch` branch in `scope`, and returns the output variable
    /// holding the value of the expression.
    pub fn finish(self, scope: &mut Scope) -> ExpandElementTyped<C> {
        let Self {
            value,
            out,
            default,
            cases,
        } = self;
        let cases = cases
            .into_iter()
            .map(|(case_value, case_scope)| (*case_value.expand, case_scope))
            .collect();
        scope.register(Branch::Switch(Box::new(Switch {
            value: *value.expand,
            scope_default: default,
            cases,
        })));
        out
    }
}
/// Starts expanding a `switch` expression on `value`.
///
/// `default_block` is expanded into a child scope. A mutable output local is created in
/// `scope` with the type of the default result, and the default result is assigned to it.
pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
    scope: &mut Scope,
    value: ExpandElementTyped<I>,
    default_block: impl FnOnce(&mut Scope) -> ExpandElementTyped<C>,
) -> SwitchExpandExpr<I, C> {
    let mut default = scope.child();
    let default_value = default_block(&mut default);
    // The output belongs to the parent scope so every arm can assign to it.
    let out: ExpandElementTyped<C> = scope.create_local_mut(default_value.expand.ty).into();
    assign::expand_no_check::<C>(&mut default, default_value, out.clone());
    SwitchExpandExpr {
        value,
        out,
        default,
        cases: Vec::new(),
    }
}
/// Expands a `break` statement by registering a `Branch::Break` in `scope`.
pub fn break_expand(scope: &mut Scope) {
    let branch = Branch::Break;
    scope.register(branch);
}
/// Expands a `return` statement by registering a `Branch::Return` in `scope`.
pub fn return_expand(scope: &mut Scope) {
    let branch = Branch::Return;
    scope.register(branch);
}
/// Expands a `loop` construct.
///
/// `block` is expanded once into a child scope, which becomes the body of a `Branch::Loop`
/// registered in `scope`.
pub fn loop_expand(scope: &mut Scope, mut block: impl FnMut(&mut Scope)) {
    let body = {
        let mut child = scope.child();
        block(&mut child);
        child
    };
    scope.register(Branch::Loop(Box::new(Loop { scope: body })));
}