use cubecl_ir::{
Arithmetic, BinaryOperator, Comparison, ElemType, IndexAssignOperator, IndexOperator,
Instruction, ManagedVariable, Operation, Operator, Scope, Type, UnaryOperator, Variable,
VariableKind, VectorSize,
};
use cubecl_macros::cube;
use crate::{
self as cubecl,
prelude::{CubeIndex, CubeType, Int, NativeExpand, eq, rem},
};
/// Registers a binary operation on `lhs` and `rhs` and returns the local that receives it.
///
/// Both operands are consumed first. The output local takes the type of `lhs`, with its
/// vector size replaced by the value `find_vectorization` computes from the two operand
/// types. `func` turns the `BinaryOperator` into the operation that gets registered.
pub(crate) fn binary_expand<F, Op>(
    scope: &mut Scope,
    lhs: ManagedVariable,
    rhs: ManagedVariable,
    func: F,
) -> ManagedVariable
where
    F: Fn(BinaryOperator) -> Op,
    Op: Into<Operation>,
{
    let lhs_var = lhs.consume();
    let rhs_var = rhs.consume();

    // Output type: element type of the left operand, widened to the common vector size.
    let out_ty = lhs_var
        .ty
        .with_vector_size(find_vectorization(lhs_var.ty, rhs_var.ty));
    let result = scope.create_local(out_ty);

    let operation = func(BinaryOperator {
        lhs: lhs_var,
        rhs: rhs_var,
    });
    scope.register(Instruction::new(operation, *result));

    result
}
/// Registers an index operation whose output has its vector size forced to `0`.
///
/// The output local uses the type of `list` with vector size `0`, and the `IndexOperator`
/// handed to `func` is built with `vector_size: 0` and `unroll_factor: 1`.
pub(crate) fn index_expand_no_vec<F>(
    scope: &mut Scope,
    list: ManagedVariable,
    index: ManagedVariable,
    func: F,
) -> ManagedVariable
where
    F: Fn(IndexOperator) -> Operator,
{
    let list_var = list.consume();
    let index_var = index.consume();

    let result = scope.create_local(list_var.ty.with_vector_size(0));

    let operator = func(IndexOperator {
        list: list_var,
        index: index_var,
        vector_size: 0,
        unroll_factor: 1,
    });
    scope.register(Instruction::new(operator, *result));

    result
}
/// Registers an index operation on `list` at `index` and returns the output local.
///
/// The output local has the type of `list` with its vector size set to the explicit
/// `vector_size` when one is given, or otherwise to the value `find_vectorization`
/// computes from the types of `list` and `index`.
///
/// The `vector_size` field of the `IndexOperator` passed to `func` is the explicit value,
/// or `0` when none was given; it does not receive the inferred width.
pub(crate) fn index_expand<F, Op>(
    scope: &mut Scope,
    list: ManagedVariable,
    index: ManagedVariable,
    vector_size: Option<VectorSize>,
    func: F,
) -> ManagedVariable
where
    F: Fn(IndexOperator) -> Op,
    Op: Into<Operation>,
{
    let list_var = list.consume();
    let index_var = index.consume();

    let out_vector_size = match vector_size {
        Some(explicit) => explicit,
        None => find_vectorization(list_var.ty, index_var.ty),
    };
    let result = scope.create_local(list_var.ty.with_vector_size(out_vector_size));

    let operation = func(IndexOperator {
        list: list_var,
        index: index_var,
        vector_size: vector_size.unwrap_or(0),
        unroll_factor: 1,
    });
    scope.register(Instruction::new(operation, *result));

    result
}
/// Registers a binary arithmetic operation whose output type is chosen by the caller.
///
/// Unlike `binary_expand`, no vector size is derived from the operands: the output local
/// is created with exactly `out_item`.
pub(crate) fn binary_expand_fixed_output<F>(
    scope: &mut Scope,
    lhs: ManagedVariable,
    rhs: ManagedVariable,
    out_item: Type,
    func: F,
) -> ManagedVariable
where
    F: Fn(BinaryOperator) -> Arithmetic,
{
    let operands = BinaryOperator {
        lhs: lhs.consume(),
        rhs: rhs.consume(),
    };

    let result = scope.create_local(out_item);
    let arithmetic = func(operands);
    scope.register(Instruction::new(arithmetic, *result));

    result
}
/// Registers a comparison between `lhs` and `rhs` and returns the boolean output local.
///
/// The output type is a `Bool` scalar type whose vector size is set to the value
/// `find_vectorization` computes from the two operand types.
pub(crate) fn cmp_expand<F>(
    scope: &mut Scope,
    lhs: ManagedVariable,
    rhs: ManagedVariable,
    func: F,
) -> ManagedVariable
where
    F: Fn(BinaryOperator) -> Comparison,
{
    let lhs_var = lhs.consume();
    let rhs_var = rhs.consume();

    let width = find_vectorization(lhs_var.ty, rhs_var.ty);
    let bool_ty = Type::scalar(ElemType::Bool).with_vector_size(width);
    let result = scope.create_local(bool_ty);

    let comparison = func(BinaryOperator {
        lhs: lhs_var,
        rhs: rhs_var,
    });
    scope.register(Instruction::new(comparison, *result));

    result
}
/// Registers a compound-assignment style operation that writes its result back into `lhs`.
///
/// `lhs` is both the left operand and the output of the registered instruction, and it is
/// returned to the caller. Neither operand is consumed; both are read through a dereference.
///
/// # Panics
///
/// Panics when `lhs.is_immutable()` returns `true`.
pub(crate) fn assign_op_expand<F, Op>(
    scope: &mut Scope,
    lhs: ManagedVariable,
    rhs: ManagedVariable,
    func: F,
) -> ManagedVariable
where
    F: Fn(BinaryOperator) -> Op,
    Op: Into<Operation>,
{
    if lhs.is_immutable() {
        panic!("Can't have a mutable operation on a const variable. Try to use `RuntimeCell`.");
    }

    let target: Variable = *lhs;
    let operand: Variable = *rhs;

    let operation = func(BinaryOperator {
        lhs: target,
        rhs: operand,
    });
    // The left operand doubles as the destination of the instruction.
    scope.register(Instruction::new(operation, target));

    lhs
}
/// Registers a unary operation on `input` and returns the local that receives its result.
///
/// The output local has the same type as the consumed input.
pub fn unary_expand<F, Op>(scope: &mut Scope, input: ManagedVariable, func: F) -> ManagedVariable
where
    F: Fn(UnaryOperator) -> Op,
    Op: Into<Operation>,
{
    let operand = input.consume();
    let result = scope.create_local(operand.ty);

    let operation = func(UnaryOperator { input: operand });
    scope.register(Instruction::new(operation, *result));

    result
}
/// Registers a unary operation on `input` whose output type is chosen by the caller.
///
/// The output local is created with exactly `out_item`, regardless of the input's type.
pub fn unary_expand_fixed_output<F, Op>(
    scope: &mut Scope,
    input: ManagedVariable,
    out_item: Type,
    func: F,
) -> ManagedVariable
where
    F: Fn(UnaryOperator) -> Op,
    Op: Into<Operation>,
{
    let operand = input.consume();
    let result = scope.create_local(out_item);

    let operation = func(UnaryOperator { input: operand });
    scope.register(Instruction::new(operation, *result));

    result
}
/// Creates a new local of the same type as `input` and registers `func(input)` writing to it.
///
/// When `mutable` is `true` the local comes from `Scope::create_local_mut`, otherwise from
/// `Scope::create_local`. `input` is only read through a dereference, not consumed.
pub fn init_expand<F>(
    scope: &mut Scope,
    input: ManagedVariable,
    mutable: bool,
    func: F,
) -> ManagedVariable
where
    F: Fn(Variable) -> Operation,
{
    let source: Variable = *input;

    let result = if mutable {
        scope.create_local_mut(source.ty)
    } else {
        scope.create_local(source.ty)
    };

    let operation = func(source);
    scope.register(Instruction::new(operation, *result));

    result
}
/// Picks the vector size for the result of an operation over two operand types.
///
/// Returns `0` when both types are `Type::Scalar`; otherwise returns the larger of the two
/// operands' `vector_size()` values.
pub(crate) fn find_vectorization(lhs: Type, rhs: Type) -> VectorSize {
    match (lhs, rhs) {
        (Type::Scalar(_), Type::Scalar(_)) => 0,
        _ => lhs.vector_size().max(rhs.vector_size()),
    }
}
/// Expands an in-place binary update of one array element.
///
/// Three instructions are registered, in this order:
/// 1. an `Operator::Index` read of `array` at `index` into a temporary local,
/// 2. `func(BinaryOperator { lhs: <temporary>, rhs: value })` into a second temporary,
/// 3. an `Operator::IndexAssign` that writes the second temporary back into `array` at
///    `index`.
///
/// Every index operator here is built with `vector_size: 0` and `unroll_factor: 1`.
pub fn array_assign_binary_op_expand<
A: CubeType + CubeIndex,
V: CubeType,
F: Fn(BinaryOperator) -> Op,
Op: Into<Operation>,
>(
scope: &mut Scope,
array: NativeExpand<A>,
index: NativeExpand<usize>,
value: NativeExpand<V>,
func: F,
) where
A::Output: CubeType + Sized,
{
let array: ManagedVariable = array.into();
let index: ManagedVariable = index.into();
let value: ManagedVariable = value.into();
// Type of the temporaries: for a `LocalMut` array the vector size is reset to 0,
// for any other kind the array's own type is used unchanged.
let array_item = match array.kind {
VariableKind::LocalMut { .. } => array.ty.with_vector_size(0),
_ => array.ty,
};
// Temporary that receives the current element value.
let array_value = scope.create_local(array_item);
let read = Instruction::new(
Operator::Index(IndexOperator {
list: *array,
index: *index,
vector_size: 0,
unroll_factor: 1,
}),
*array_value,
);
// The first temporary is consumed before the second one is created.
// NOTE(review): this ordering may matter to the scope's local allocation — confirm
// before reordering.
let array_value = array_value.consume();
// Temporary that receives the result of the binary operation.
let op_out = scope.create_local(array_item);
let calculate = Instruction::new(
func(BinaryOperator {
lhs: array_value,
rhs: *value,
}),
*op_out,
);
let write = Operator::IndexAssign(IndexAssignOperator {
index: *index,
value: op_out.consume(),
vector_size: 0,
unroll_factor: 1,
});
// Instructions are built above and only registered here: read, compute, write back.
scope.register(read);
scope.register(calculate);
// The array itself is the output operand of the index-assign instruction.
scope.register(Instruction::new(write, *array));
}
/// Division of integer cube types through a `div_ceil` method.
///
/// Implementors must be `Int` types whose expand type implements [`DivCeilExpand`].
/// The implementations generated in this file compute `(self + divisor - 1) / divisor`.
pub trait DivCeil: Int + CubeType<ExpandType: DivCeilExpand<Self>> {
/// Host-side computation of the division of `self` by `divisor`.
fn div_ceil(self, divisor: Self) -> Self;
/// Expansion counterpart of [`DivCeil::div_ceil`]: forwards to
/// `a.__expand_div_ceil_method(scope, b)`.
fn __expand_div_ceil(
scope: &mut Scope,
a: NativeExpand<Self>,
b: NativeExpand<Self>,
) -> NativeExpand<Self> {
a.__expand_div_ceil_method(scope, b)
}
}
/// Method-style expansion hook for `div_ceil`, implemented on expand types.
///
/// `E` is the integer type whose expansion is being divided.
pub trait DivCeilExpand<E: Int> {
/// Expands the division of `self` by `divisor` into `scope` and returns the expanded result.
fn __expand_div_ceil_method(self, scope: &mut Scope, divisor: Self) -> Self;
}
/// Routes the method-style expansion of `div_ceil` to the generated `div_ceil::expand`.
impl<E: DivCeil> DivCeilExpand<E> for NativeExpand<E> {
    fn __expand_div_ceil_method(self, scope: &mut Scope, divisor: Self) -> Self {
        div_ceil::expand::<E>(scope, self, divisor)
    }
}
// Implements the host-side `DivCeil::div_ceil` for every listed integer type using the
// formula `(self + divisor - 1) / divisor`.
//
// NOTE(review): for negative operands, or when `self + divisor` overflows, this is not a
// mathematical ceiling — confirm callers only pass values where the formula holds.
macro_rules! impl_div_ceil {
    ($($int:ty),*) => {
        $(
            impl DivCeil for $int {
                // Kept as the manual formula, which is the same expression the
                // `#[cube]` `div_ceil` function in this file uses.
                #[allow(clippy::manual_div_ceil)]
                fn div_ceil(self, divisor: Self) -> Self {
                    let biased = self + divisor - 1;
                    biased / divisor
                }
            }
        )*
    };
}
impl_div_ceil!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize);
impl<E: Int> NativeExpand<E> {
    /// Expands a test of whether `self` is a multiple of `factor`.
    ///
    /// Expands `self % factor` with `rem::expand`, then compares the remainder against
    /// `E::from_int(0)` with `eq::expand`.
    pub fn __expand_is_multiple_of_method(
        self,
        scope: &mut Scope,
        factor: NativeExpand<E>,
    ) -> NativeExpand<bool> {
        let remainder = rem::expand(scope, self, factor);
        let zero = E::from_int(0).into();
        eq::expand(scope, remainder, zero)
    }
}
/// Cube function computing `(a + b - 1) / b` for an integer type `E`.
///
/// NOTE(review): this equals the ceiling of `a / b` only when the operands are such that
/// `a + b - 1` does not overflow and integer division truncates as expected — confirm the
/// intended operand range (e.g. non-negative `a`, positive `b`).
#[cube]
pub fn div_ceil<E: Int>(a: E, b: E) -> E {
(a + b - E::new(1)) / b
}