cubecl-core 0.9.0

CubeCL core crate
Documentation
use cubecl_ir::{
    Arithmetic, BinaryOperator, Comparison, ElemType, ExpandElement, IndexAssignOperator,
    IndexOperator, Instruction, LineSize, Operation, Operator, Scope, Type, UnaryOperator,
    Variable, VariableKind,
};
use cubecl_macros::cube;

use crate::{
    self as cubecl,
    prelude::{CubeIndex, CubeType, ExpandElementTyped, Int, eq, rem},
};

/// Expands a binary operation into `scope` and returns a new local holding its result.
///
/// The result keeps the type of `lhs`, re-lined to the line size chosen by
/// `find_vectorization` for the two operand types. `func` turns the operand pair into the
/// concrete operation that gets registered with the result as its output.
pub(crate) fn binary_expand<F, Op>(
    scope: &mut Scope,
    lhs: ExpandElement,
    rhs: ExpandElement,
    func: F,
) -> ExpandElement
where
    F: Fn(BinaryOperator) -> Op,
    Op: Into<Operation>,
{
    // Both operands are consumed before the output local is created.
    let (lhs, rhs) = (lhs.consume(), rhs.consume());

    let out_ty = lhs.ty.line(find_vectorization(lhs.ty, rhs.ty));
    let result = scope.create_local(out_ty);

    let instruction = Instruction::new(func(BinaryOperator { lhs, rhs }), *result);
    scope.register(instruction);

    result
}

/// Expands an index-style operator whose output is not vectorized.
///
/// The output local uses the list's type re-lined with line size `0`, which is the same
/// value `find_vectorization` returns for two scalar operands. The operator is also
/// given `line_size: 0` and an unroll factor of `1`.
pub(crate) fn index_expand_no_vec<F>(
    scope: &mut Scope,
    list: ExpandElement,
    index: ExpandElement,
    func: F,
) -> ExpandElement
where
    F: Fn(IndexOperator) -> Operator,
{
    let list = list.consume();
    let index = index.consume();

    let result = scope.create_local(list.ty.line(0));

    let operator = func(IndexOperator {
        list,
        index,
        line_size: 0,
        unroll_factor: 1,
    });
    scope.register(Instruction::new(operator, *result));

    result
}
/// Expands an index-style operator (built by `func`) into `scope`.
///
/// The output local has the list's type re-lined to `line_size` when one is given.
/// Otherwise it uses the line size `find_vectorization` infers from the list and index types.
/// The operator receives the explicit `line_size`, or `0` when none was requested.
pub(crate) fn index_expand<F, Op>(
    scope: &mut Scope,
    list: ExpandElement,
    index: ExpandElement,
    line_size: Option<LineSize>,
    func: F,
) -> ExpandElement
where
    F: Fn(IndexOperator) -> Op,
    Op: Into<Operation>,
{
    let list = list.consume();
    let index = index.consume();

    // An explicit line size wins; otherwise infer one from the operand types.
    let out_line_size = line_size.unwrap_or_else(|| find_vectorization(list.ty, index.ty));
    let result = scope.create_local(list.ty.line(out_line_size));

    let operation = func(IndexOperator {
        list,
        index,
        // Only the caller-provided line size is forwarded; `0` when absent.
        line_size: line_size.unwrap_or(0),
        unroll_factor: 1,
    });
    scope.register(Instruction::new(operation, *result));

    result
}

/// Expands a binary arithmetic operation whose output type is given by the caller.
///
/// Unlike `binary_expand`, no line size is inferred: the result local is created with
/// exactly `out_item`.
pub(crate) fn binary_expand_fixed_output<F>(
    scope: &mut Scope,
    lhs: ExpandElement,
    rhs: ExpandElement,
    out_item: Type,
    func: F,
) -> ExpandElement
where
    F: Fn(BinaryOperator) -> Arithmetic,
{
    // Operands are consumed before the output local is created.
    let lhs = lhs.consume();
    let rhs = rhs.consume();

    let result = scope.create_local(out_item);

    let arithmetic = func(BinaryOperator { lhs, rhs });
    scope.register(Instruction::new(arithmetic, *result));

    result
}

/// Expands a comparison into `scope` and returns a new boolean local.
///
/// The result is `Bool`, re-lined to the line size `find_vectorization` picks for the
/// two operand types, so a lined comparison yields a lined boolean.
pub(crate) fn cmp_expand<F>(
    scope: &mut Scope,
    lhs: ExpandElement,
    rhs: ExpandElement,
    func: F,
) -> ExpandElement
where
    F: Fn(BinaryOperator) -> Comparison,
{
    let lhs = lhs.consume();
    let rhs = rhs.consume();

    let bool_ty = Type::scalar(ElemType::Bool).line(find_vectorization(lhs.ty, rhs.ty));
    let result = scope.create_local(bool_ty);

    scope.register(Instruction::new(func(BinaryOperator { lhs, rhs }), *result));

    result
}

/// Expands a compound assignment (`lhs op= rhs`) that writes its result into `lhs`.
///
/// Returns `lhs` itself, since it is the instruction's output.
///
/// # Panics
///
/// Panics if `lhs` reports itself as immutable (`is_immutable()`).
pub(crate) fn assign_op_expand<F, Op>(
    scope: &mut Scope,
    lhs: ExpandElement,
    rhs: ExpandElement,
    func: F,
) -> ExpandElement
where
    F: Fn(BinaryOperator) -> Op,
    Op: Into<Operation>,
{
    assert!(
        !lhs.is_immutable(),
        "Can't have a mutable operation on a const variable. Try to use `RuntimeCell`."
    );

    // `lhs` is both an operand and the destination of the instruction.
    let target: Variable = *lhs;
    let operation = func(BinaryOperator {
        lhs: target,
        rhs: *rhs,
    });
    scope.register(Instruction::new(operation, target));

    lhs
}

/// Expands a unary operation into `scope`; the result local has the same type as `input`.
pub fn unary_expand<F, Op>(scope: &mut Scope, input: ExpandElement, func: F) -> ExpandElement
where
    F: Fn(UnaryOperator) -> Op,
    Op: Into<Operation>,
{
    let input = input.consume();
    let result = scope.create_local(input.ty);

    let operation = func(UnaryOperator { input });
    scope.register(Instruction::new(operation, *result));

    result
}

/// Expands a unary operation whose result type `out_item` is chosen by the caller
/// (for example a cast), rather than copied from `input`.
pub fn unary_expand_fixed_output<F, Op>(
    scope: &mut Scope,
    input: ExpandElement,
    out_item: Type,
    func: F,
) -> ExpandElement
where
    F: Fn(UnaryOperator) -> Op,
    Op: Into<Operation>,
{
    let input = input.consume();
    let result = scope.create_local(out_item);

    scope.register(Instruction::new(func(UnaryOperator { input }), *result));

    result
}

/// Creates a new local of `input`'s type and registers `func(input)` as its initializer.
///
/// When `mutable` is true the local is created with `create_local_mut`, otherwise with
/// `create_local`.
pub fn init_expand<F>(
    scope: &mut Scope,
    input: ExpandElement,
    mutable: bool,
    func: F,
) -> ExpandElement
where
    F: Fn(Variable) -> Operation,
{
    // `input` is only dereferenced, not consumed, so it is still alive when the output
    // local is created. NOTE(review): presumably deliberate so the two never share a slot.
    let source: Variable = *input;
    let ty = source.ty;

    let out = if !mutable {
        scope.create_local(ty)
    } else {
        scope.create_local_mut(ty)
    };

    let initializer = func(source);
    scope.register(Instruction::new(initializer, *out));

    out
}

/// Picks the output line size for an operation on `lhs` and `rhs`.
///
/// Two scalar operands give `0`. Otherwise the larger of the two operands' line sizes is used.
fn find_vectorization(lhs: Type, rhs: Type) -> LineSize {
    match (&lhs, &rhs) {
        (Type::Scalar(_), Type::Scalar(_)) => 0,
        _ => lhs.line_size().max(rhs.line_size()),
    }
}

/// Expands a compound assignment on an indexed element, i.e. `array[index] op= value`.
///
/// Emits three instructions into `scope`, registered in this order:
/// 1. read `array[index]` into a new local,
/// 2. apply `func` to that element and `value`, writing into a second new local,
/// 3. store that result back into `array` at `index`.
///
/// The element type is `array.ty`, except for `LocalMut` arrays. Those are treated as a
/// line and use `array.ty.line(0)` instead.
pub fn array_assign_binary_op_expand<
    A: CubeType + CubeIndex,
    V: CubeType,
    F: Fn(BinaryOperator) -> Op,
    Op: Into<Operation>,
>(
    scope: &mut Scope,
    array: ExpandElementTyped<A>,
    index: ExpandElementTyped<usize>,
    value: ExpandElementTyped<V>,
    func: F,
) where
    A::Output: CubeType + Sized,
{
    // Convert the typed wrappers into untyped `ExpandElement`s.
    let array: ExpandElement = array.into();
    let index: ExpandElement = index.into();
    let value: ExpandElement = value.into();

    // Type of a single element read from / written to `array`.
    let array_item = match array.kind {
        // In that case, the array is a line.
        VariableKind::LocalMut { .. } => array.ty.line(0),
        _ => array.ty,
    };
    let array_value = scope.create_local(array_item);

    // Step 1: `array_value = array[index]` (built here, registered at the end).
    let read = Instruction::new(
        Operator::Index(IndexOperator {
            list: *array,
            index: *index,
            line_size: 0,
            unroll_factor: 1,
        }),
        *array_value,
    );
    // NOTE(review): `array_value` is consumed before `op_out` is created; this ordering
    // presumably allows the scope to reuse the read local for the result — confirm before
    // reordering these statements.
    let array_value = array_value.consume();
    let op_out = scope.create_local(array_item);
    // Step 2: `op_out = func(array_value, value)`.
    let calculate = Instruction::new(
        func(BinaryOperator {
            lhs: array_value,
            rhs: *value,
        }),
        *op_out,
    );

    // Step 3: `array[index] = op_out`; the array itself is the instruction's output.
    let write = Operator::IndexAssign(IndexAssignOperator {
        index: *index,
        value: op_out.consume(),
        line_size: 0,
        unroll_factor: 1,
    });
    // Register in program order: read, compute, write back.
    scope.register(read);
    scope.register(calculate);
    scope.register(Instruction::new(write, *array));
}

// Expansions for the integer methods `div_ceil` and `is_multiple_of`. Code written with
// these methods (presumably in response to clippy's manual-arithmetic lints) can then also
// be used inside cube functions.
impl<E: Int> ExpandElementTyped<E> {
    /// Expansion of `self.div_ceil(divisor)`: delegates to the `div_ceil` cube function.
    pub fn __expand_div_ceil_method(
        self,
        scope: &mut Scope,
        divisor: ExpandElementTyped<E>,
    ) -> ExpandElementTyped<E> {
        let dividend = self;
        div_ceil::expand::<E>(scope, dividend, divisor)
    }

    /// Expansion of `self.is_multiple_of(factor)`, computed as `self % factor == 0`.
    pub fn __expand_is_multiple_of_method(
        self,
        scope: &mut Scope,
        factor: ExpandElementTyped<E>,
    ) -> ExpandElementTyped<bool> {
        let remainder = rem::expand(scope, self, factor);
        let zero = E::from_int(0).into();
        eq::expand(scope, remainder, zero)
    }
}

/// Integer division of `a` by `b`, rounded up, computed as `(a + b - 1) / b`.
///
/// NOTE(review): this formula has some limitations:
/// - The intermediate `a + b - 1` can overflow when `a` is close to `E`'s maximum value.
/// - The result only equals the mathematical ceiling for non-negative `a` and positive `b`.
///   For example, `a = -4, b = 2` gives `-3 / 2`, not `-2`.
///
/// Callers presumably pass non-negative sizes/counts; confirm before using with other values.
#[cube]
pub fn div_ceil<E: Int>(a: E, b: E) -> E {
    (a + b - E::new(1)) / b
}