use crate::{
CubeRuntime,
kernel::utils::{address_type, linear_view, shape_divmod},
};
use crate::{element::CubeElement, tensor::CubeTensor};
use crate::{
kernel::{
AddOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp,
launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int,
},
ops::max_line_size,
};
use burn_backend::{DType, Shape, TensorMetadata};
use burn_std::Metadata;
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use cubecl::{client::ComputeClient, server::Allocation};
use cubecl::{
server::AllocationDescriptor,
std::{FastDivmod, tensor::layout::linear::LinearView},
};
/// Creates a tensor of the given `shape` on `device` with every element set to `value`.
///
/// Resolves a compute client for the device and defers to [`full_client`].
pub fn full<R: CubeRuntime, E: CubeElement>(
    shape: Shape,
    device: &R::Device,
    value: E,
) -> CubeTensor<R> {
    full_client::<R, E>(R::client(device), shape, device.clone(), value)
}
/// Same as [`full`], but reuses an existing compute `client` instead of resolving
/// one from the device.
pub fn full_client<R: CubeRuntime, E: CubeElement>(
    client: ComputeClient<R>,
    shape: Shape,
    device: R::Device,
    value: E,
) -> CubeTensor<R> {
    // The compile-time element type fixes both the scalar encoding and the output dtype.
    let dtype = E::dtype();
    let scalar = InputScalar::new(value, dtype);
    full_device_dtype(client, shape, device, scalar, dtype)
}
/// Creates a tensor of the given `shape` filled with `value`, where the element
/// type is chosen at runtime via `dtype` rather than as a generic parameter.
///
/// Allocates an uninitialized tensor, then launches an element-wise kernel that
/// writes `value` into every position.
pub fn full_device_dtype<R: CubeRuntime>(
    client: ComputeClient<R>,
    shape: Shape,
    device: R::Device,
    value: InputScalar,
    dtype: DType,
) -> CubeTensor<R> {
    let empty = empty_device_dtype(client, device, shape, dtype);

    /// Fill kernel: every in-bounds unit writes `value` (decoded to the kernel's
    /// element type `C`) at its own linear position.
    #[cube(launch_unchecked, address_type = "dynamic")]
    pub fn full_kernel<C: Numeric>(
        tensor: &mut LinearView<C, ReadWrite>,
        value: InputScalar,
        #[define(C)] _dtype: StorageType,
    ) {
        // Units past the end of the (possibly vectorized) view have nothing to do.
        if !tensor.is_in_bounds(ABSOLUTE_POS) {
            terminate!();
        }
        tensor[ABSOLUTE_POS] = value.get::<C>();
    }

    let num_elems = empty.meta.num_elements();
    let line_size = max_line_size(&empty);
    // One working unit per vector line, not per scalar element.
    // NOTE(review): assumes `num_elems` is divisible by the chosen line size
    // (or that the bounds check covers the tail) — confirm against `max_line_size`.
    let working_units = num_elems / line_size as usize;
    let cube_dim = CubeDim::new(&empty.client, working_units);
    let cube_count = calculate_cube_count_elemwise(&empty.client, working_units, cube_dim);
    // SAFETY (launch_unchecked): bounds are enforced inside the kernel via
    // `is_in_bounds`, so out-of-range units terminate without touching memory.
    unsafe {
        full_kernel::launch_unchecked(
            &empty.client,
            cube_count,
            cube_dim,
            address_type!(empty),
            linear_view(&empty, line_size),
            value,
            empty.dtype.into(),
        )
        .expect("Kernel to never fail");
    }
    empty
}
pub fn zeros<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {
let client = R::client(&device);
full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype)
}
pub fn ones<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {
let client = R::client(&device);
full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype)
}
/// Creates a zero-filled tensor of the given `shape`, reusing an existing
/// compute `client`.
pub fn zeros_client<R: CubeRuntime>(
    client: ComputeClient<R>,
    device: R::Device,
    shape: Shape,
    dtype: DType,
) -> CubeTensor<R> {
    // Encode the literal zero once in the requested dtype, then fill.
    let zero = InputScalar::new(0u32, dtype);
    full_device_dtype(client, shape, device, zero, dtype)
}
/// Creates a one-filled tensor of the given `shape`, reusing an existing
/// compute `client`.
pub fn ones_client<R: CubeRuntime>(
    client: ComputeClient<R>,
    device: R::Device,
    shape: Shape,
    dtype: DType,
) -> CubeTensor<R> {
    // Encode the literal one once in the requested dtype, then fill.
    let one = InputScalar::new(1u32, dtype);
    full_device_dtype(client, shape, device, one, dtype)
}
/// Allocates an uninitialized tensor of `shape` for element type `E` on `device`.
///
/// The element byte width comes from the Rust-side representation of `E`.
pub fn empty_device<R: CubeRuntime, E: CubeElement>(
    client: ComputeClient<R>,
    device: R::Device,
    shape: Shape,
) -> CubeTensor<R> {
    let allocation = client.empty_tensor(&shape, size_of::<E>());
    let meta = Metadata::new(shape, allocation.strides);
    CubeTensor::new(client, allocation.handle, meta, device, E::dtype())
}
/// Allocates an uninitialized tensor of `shape` whose element type is chosen
/// at runtime via `dtype`.
pub fn empty_device_dtype<R: CubeRuntime>(
    client: ComputeClient<R>,
    device: R::Device,
    shape: Shape,
    dtype: DType,
) -> CubeTensor<R> {
    // The runtime dtype carries its own byte width.
    let allocation = client.empty_tensor(&shape, dtype.size());
    let meta = Metadata::new(shape, allocation.strides);
    CubeTensor::new(client, allocation.handle, meta, device, dtype)
}
/// Allocates an uninitialized tensor of `shape` with contiguous strides,
/// as requested via [`AllocationDescriptor::contiguous`].
pub fn empty_device_contiguous_dtype<R: CubeRuntime>(
    client: ComputeClient<R>,
    device: R::Device,
    shape: Shape,
    dtype: DType,
) -> CubeTensor<R> {
    // Request a single contiguous allocation and take the only entry back out.
    let descriptors = vec![AllocationDescriptor::contiguous(&shape, dtype.size())];
    let allocation = client.empty_tensors(descriptors).remove(0);
    let meta = Metadata::new(shape, allocation.strides);
    CubeTensor::new(client, allocation.handle, meta, device, dtype)
}
/// Element-wise `lhs + rhs` via the binary-op launcher.
pub fn add<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, AddOp>(lhs, rhs)
}
/// Element-wise `lhs + rhs` where `rhs` is a scalar.
pub fn add_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, AddOp>(lhs, rhs)
}
/// Element-wise `lhs - rhs` via the binary-op launcher.
pub fn sub<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, SubOp>(lhs, rhs)
}
/// Element-wise `lhs - rhs` where `rhs` is a scalar.
pub fn sub_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, SubOp>(lhs, rhs)
}
/// Element-wise `lhs * rhs` via the binary-op launcher.
pub fn mul<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, MulOp>(lhs, rhs)
}
/// Element-wise `lhs * rhs` where `rhs` is a scalar.
pub fn mul_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, MulOp>(lhs, rhs)
}
/// Element-wise `lhs / rhs` via the binary-op launcher.
pub fn div<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, DivOp>(lhs, rhs)
}
/// Element-wise `lhs / rhs` where `rhs` is a scalar.
pub fn div_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, DivOp>(lhs, rhs)
}
/// Element-wise remainder of `lhs / rhs` via the binary-op launcher.
pub fn remainder<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, RemainderOp>(lhs, rhs)
}
/// Element-wise remainder of `lhs / rhs` where `rhs` is a scalar.
pub fn remainder_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, RemainderOp>(lhs, rhs)
}
/// Element-wise `lhs` raised to the power `rhs` via the binary-op launcher.
/// Note: there is intentionally no scalar variant here (unlike the other ops).
pub fn pow<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, PowOp>(lhs, rhs)
}
/// Element-wise `lhs & rhs`, using the integer-only binop launcher.
pub fn bitwise_and<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseAndOp>(lhs, rhs)
}
/// Element-wise `lhs & rhs` where `rhs` is a scalar (integer dtypes only).
pub fn bitwise_and_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseAndOp>(lhs, rhs)
}
/// Element-wise `lhs | rhs`, using the integer-only binop launcher.
pub fn bitwise_or<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseOrOp>(lhs, rhs)
}
/// Element-wise `lhs | rhs` where `rhs` is a scalar (integer dtypes only).
pub fn bitwise_or_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseOrOp>(lhs, rhs)
}
/// Element-wise `lhs ^ rhs`, using the integer-only binop launcher.
pub fn bitwise_xor<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseXorOp>(lhs, rhs)
}
/// Element-wise `lhs ^ rhs` where `rhs` is a scalar (integer dtypes only).
pub fn bitwise_xor_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseXorOp>(lhs, rhs)
}
/// Family trait mapping a cumulative-op marker type (e.g. `SumOp`) to its
/// per-element-type [`CumulativeOp`] implementation.
pub(crate) trait CumulativeOpFamily: Send + Sync + 'static {
    /// The concrete cumulative operation specialized for numeric type `C`.
    type CumulativeOp<C: Numeric>: CumulativeOp<C>;
}
/// A cumulative reduction: one binary accumulation step plus the seed value
/// used to initialize the accumulator. Expanded to GPU code by `#[cube]`.
#[cube]
pub(crate) trait CumulativeOp<C: Numeric>: 'static + Send + Sync {
    /// Combines the running accumulator (`lhs`) with the next element (`rhs`).
    fn execute(lhs: C, rhs: C) -> C;
    /// Seed for the accumulator; may depend on the first element read
    /// (used by min/max, ignored by sum/prod).
    fn init_value(first_element: C) -> C;
}
// Marker types selecting which cumulative reduction the generic kernel performs.
struct SumOp;
struct ProdOp;
struct MaxOp;
struct MinOp;
// Each marker is its own `CumulativeOp` implementation for every numeric type.
impl CumulativeOpFamily for SumOp {
    type CumulativeOp<C: Numeric> = Self;
}
impl CumulativeOpFamily for ProdOp {
    type CumulativeOp<C: Numeric> = Self;
}
impl CumulativeOpFamily for MaxOp {
    type CumulativeOp<C: Numeric> = Self;
}
impl CumulativeOpFamily for MinOp {
    type CumulativeOp<C: Numeric> = Self;
}
/// Cumulative sum: accumulate with `+`, seeded with the additive identity.
#[cube]
impl<N: Numeric> CumulativeOp<N> for SumOp {
    fn execute(lhs: N, rhs: N) -> N {
        lhs + rhs
    }
    // 0 is the identity for addition, so the first element is not needed.
    fn init_value(_first_element: N) -> N {
        N::from_int(0)
    }
}
/// Cumulative product: accumulate with `*`, seeded with the multiplicative identity.
#[cube]
impl<N: Numeric> CumulativeOp<N> for ProdOp {
    fn execute(lhs: N, rhs: N) -> N {
        lhs * rhs
    }
    // 1 is the identity for multiplication, so the first element is not needed.
    fn init_value(_first_element: N) -> N {
        N::from_int(1)
    }
}
/// Cumulative maximum: there is no universal identity across all numeric
/// types, so the accumulator is seeded with an element from the scanned range
/// (re-folding that element is idempotent for `max`).
#[cube]
impl<N: Numeric> CumulativeOp<N> for MaxOp {
    fn execute(lhs: N, rhs: N) -> N {
        max(lhs, rhs)
    }
    fn init_value(first_element: N) -> N {
        first_element
    }
}
/// Cumulative minimum: seeded with an element from the scanned range, like
/// [`MaxOp`] (re-folding that element is idempotent for `min`).
#[cube]
impl<N: Numeric> CumulativeOp<N> for MinOp {
    fn execute(lhs: N, rhs: N) -> N {
        min(lhs, rhs)
    }
    fn init_value(first_element: N) -> N {
        first_element
    }
}
/// Cumulative-op kernel: one unit per output element. Each unit independently
/// re-folds the input along `dim` from index 0 up to (and including) its own
/// coordinate, so the total work is O(len(dim)) per element (no shared scan).
#[cube(launch_unchecked, address_type = "dynamic")]
fn cumulative_kernel<C: Numeric, O: CumulativeOpFamily>(
    input: &Tensor<C>,
    output: &mut LinearView<C, ReadWrite>,
    shape: Sequence<FastDivmod<usize>>,
    #[comptime] dim: usize,
    #[define(C)] _dtype: StorageType,
) {
    // Units past the end of the output have nothing to do.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }
    let rank = comptime![shape.len()];
    let dim_stride = input.stride(dim);
    // Decompose the linear output index into per-dimension coordinates,
    // peeling the last dimension first. Accumulate the input offset of all
    // non-`dim` coordinates and capture this unit's coordinate along `dim`.
    let mut remainder = ABSOLUTE_POS;
    let mut offset = 0;
    let mut dim_idx = 0;
    #[unroll]
    for i in 0..shape.len() {
        // Walk dimensions in reverse: rank-1, rank-2, ..., 0.
        let i = comptime![rank - i - 1];
        let (rem, local_idx) = shape.index(i).div_mod(remainder);
        remainder = rem;
        if i == dim {
            dim_idx = local_idx;
        } else {
            offset += local_idx * input.stride(i);
        }
    }
    // Seed the accumulator from the element at this unit's own position:
    // min/max use it directly (re-folded below, which is harmless), while
    // sum/prod ignore it and return their identity.
    let first_read_idx = offset + dim_idx * dim_stride;
    let first_elem = input[first_read_idx];
    let mut result = O::CumulativeOp::<C>::init_value(first_elem);
    // Fold elements 0..=dim_idx along `dim` into the accumulator.
    for i in 0..=dim_idx {
        let read_idx = offset + i * dim_stride;
        result = O::CumulativeOp::<C>::execute(result, input[read_idx]);
    }
    output[ABSOLUTE_POS] = result;
}
/// Cumulative sum of `input` along dimension `dim`.
pub fn cumsum<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, SumOp>(input, dim)
}
/// Cumulative product of `input` along dimension `dim`.
pub fn cumprod<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, ProdOp>(input, dim)
}
/// Cumulative minimum of `input` along dimension `dim`.
pub fn cummin<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, MinOp>(input, dim)
}
/// Cumulative maximum of `input` along dimension `dim`.
pub fn cummax<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, MaxOp>(input, dim)
}
/// Shared launcher for the cumulative ops: allocates an output with the same
/// shape and dtype as `input`, then runs [`cumulative_kernel`] with the
/// operation selected by the marker type `O`.
fn cumulative_op<R: CubeRuntime, O: CumulativeOpFamily>(
    input: CubeTensor<R>,
    dim: usize,
) -> CubeTensor<R> {
    let client = input.client.clone();
    let device = input.device.clone();
    let output = empty_device_dtype(client.clone(), device, input.shape(), input.dtype);
    // No vectorization here (line size 1 below): one unit per scalar element,
    // since neighboring outputs need different prefix lengths along `dim`.
    let num_elems = output.meta.num_elements();
    let working_units = num_elems;
    let cube_dim = CubeDim::new(&client, working_units);
    let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);
    // SAFETY (launch_unchecked): the kernel guards every access with
    // `is_in_bounds` and terminates out-of-range units.
    unsafe {
        cumulative_kernel::launch_unchecked::<O, R>(
            &client,
            cube_count,
            cube_dim,
            address_type!(input, output),
            input.as_tensor_arg(1),
            linear_view(&output, 1),
            shape_divmod(&input),
            dim,
            output.dtype.into(),
        )
        .expect("Kernel to never fail");
    }
    output
}