use crate::{CubeRuntime, kernel, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::{
DType, ExecutionError, QTensorPrimitive, Shape, TensorData,
quantization::{QuantLevel, QuantStore, params_shape},
};
use burn_backend::{TensorMetadata, ops::unfold::calculate_unfold_shape};
use burn_std::{
Metadata, strides,
tensor::{ReshapeAction, contiguous_strides, reshape_action},
};
use cubecl::{ir::LineSize, server::CopyDescriptor};
use cubecl::{quant::scheme::BlockSize, tensor_line_size_parallel};
/// Uploads host-side [`TensorData`] to `device` and wraps it in a [`CubeTensor`].
pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
    let client = R::client(device);
    let shape: Shape = (&data.shape).into();
    // The client allocates and copies the bytes, returning the handle plus the
    // strides it chose for this layout.
    let allocation = client.create_tensor(data.bytes, &data.shape, data.dtype.size());
    let meta = Metadata::new(shape, allocation.strides);
    CubeTensor::new(client, allocation.handle, meta, device.clone(), data.dtype)
}
/// Reads `tensor` back to the host as [`TensorData`].
///
/// The tensor is first made contiguous (aligned) so its bytes can be copied
/// out in a single transfer.
///
/// # Errors
/// Returns [`ExecutionError::WithContext`] when the device read fails.
pub(crate) async fn into_data<R: CubeRuntime>(
    tensor: CubeTensor<R>,
) -> Result<TensorData, ExecutionError> {
    let tensor = kernel::into_contiguous_aligned(tensor);
    let elem_size = tensor.elem_size();
    let descriptor = CopyDescriptor::new(
        tensor.handle.binding(),
        tensor.meta.shape(),
        tensor.meta.strides(),
        elem_size,
    );
    let bytes = match tensor.client.read_one_tensor_async(descriptor).await {
        Ok(bytes) => bytes,
        Err(err) => {
            return Err(ExecutionError::WithContext {
                reason: format!("{err}"),
            });
        }
    };
    Ok(TensorData::from_bytes(bytes, tensor.meta.shape, tensor.dtype))
}
/// Blocking variant of [`into_data`]; panics if the read fails.
#[allow(unused, reason = "useful for debugging kernels")]
pub fn into_data_sync<R: CubeRuntime>(tensor: CubeTensor<R>) -> TensorData {
    let result = burn_std::future::block_on(into_data(tensor));
    result.unwrap()
}
/// Moves `tensor` to `device`, returning it unchanged when it already lives there.
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "trace", skip(tensor, device))
)]
pub(crate) fn to_device<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    device: &R::Device,
) -> CubeTensor<R> {
    if device == &tensor.device {
        tensor
    } else {
        // Cross-device transfers copy raw bytes, so the source must be contiguous.
        let contiguous = kernel::into_contiguous_aligned(tensor);
        contiguous.to_client(R::client(device), device.clone())
    }
}
/// Allocates an uninitialized tensor of `shape` with element type `dtype` on `device`.
pub(crate) fn empty<R: CubeRuntime>(
    shape: Shape,
    device: &R::Device,
    dtype: DType,
) -> CubeTensor<R> {
    let client = R::client(device);
    let allocation = client.empty_tensor(&shape, dtype.size());
    let meta = Metadata::new(shape, allocation.strides);
    CubeTensor::new(client, allocation.handle, meta, device.clone(), dtype)
}
/// Swaps dimensions `dim1` and `dim2` of `tensor` (a metadata-only view change).
///
/// For quantized tensors the quantization layout must follow the swap:
/// block sizes and scales metadata for block-level quantization, and the
/// packed dimension for packed storage.
pub(crate) fn swap_dims<R: CubeRuntime>(
    mut tensor: CubeTensor<R>,
    dim1: usize,
    dim2: usize,
) -> CubeTensor<R> {
    tensor.meta.swap(dim1, dim2);
    if let DType::QFloat(scheme) = tensor.dtype
        && let QuantLevel::Block(block_size) = scheme.level
    {
        let rank = tensor.rank();
        let qparams = tensor.qparams.as_mut().unwrap();
        // Expand to one block-size entry per dim so the swap lines up with the shape.
        let mut block_size = block_size.to_dim_vec(rank);
        block_size.swap(dim1, dim2);
        // NOTE(review): `new_trim` presumably drops redundant unit dims — confirm.
        let block_size = BlockSize::new_trim(block_size);
        if block_size.len() > BlockSize::MAX_DIMS {
            panic!("Swapped block size would exceed max dims");
        }
        // The per-block scales tensor mirrors the value layout, so swap it too.
        qparams.scales.metadata.swap(dim1, dim2);
        tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::Block(block_size)))
    }
    if let DType::QFloat(scheme) = &mut tensor.dtype
        && let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) =
            &mut scheme.store
    {
        // `packed_dim` appears to be counted from the innermost axis
        // (`rank - axis - 1`); when the packed axis is one of the swapped
        // dims, retarget it to the other — TODO(review): confirm convention.
        let rank = tensor.meta.num_dims();
        if *packed_dim == rank - dim1 - 1 {
            *packed_dim = rank - dim2 - 1;
        } else if *packed_dim == rank - dim2 - 1 {
            *packed_dim = rank - dim1 - 1;
        }
    }
    tensor
}
/// Permutes tensor dimensions according to `axes` (a metadata-only view change).
///
/// For quantized tensors this also permutes the block size (for block-level
/// quantization) and the scales metadata, and re-derives the packed dimension
/// for packed `u32` storage.
///
/// # Panics
/// Panics when `axes` is not a valid permutation, or when the permuted block
/// size would exceed [`BlockSize::MAX_DIMS`].
pub fn permute<R: CubeRuntime>(mut tensor: CubeTensor<R>, axes: &[usize]) -> CubeTensor<R> {
    tensor.meta.permute(axes).unwrap();
    if let DType::QFloat(scheme) = tensor.dtype
        && let QuantLevel::Block(block_size) = scheme.level
    {
        let rank = tensor.rank();
        let qparams = tensor.qparams.as_mut().unwrap();
        // Expand to one block-size entry per dim so it permutes like the shape.
        let mut block_size = block_size.to_dim_vec(rank);
        block_size = axes.iter().map(|i| block_size[*i]).collect();
        // Drop leading unit dims (mirrors `BlockSize::new_trim` in `swap_dims`).
        let block_size = block_size
            .into_iter()
            .skip_while(|it| *it == 1)
            .collect::<Vec<_>>();
        if block_size.len() > BlockSize::MAX_DIMS {
            // Fixed message: this is the permute path, not the swap path.
            panic!("Permuted block size would exceed max dims");
        }
        qparams.scales.metadata.permute(axes).unwrap();
        tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::block(&block_size)))
    }
    if let DType::QFloat(scheme) = &mut tensor.dtype
        && let QuantStore::PackedU32(packed_dim) = &mut scheme.store
    {
        // NOTE(review): unlike `swap_dims`, `QuantStore::PackedNative` is not
        // remapped here — confirm whether it needs the same treatment.
        let rank = tensor.meta.num_dims();
        // `packed_dim` counts from the innermost axis; find where that axis moved.
        let new_pos = axes
            .iter()
            .position(|axis| *axis == rank - *packed_dim - 1)
            .unwrap_or(0);
        *packed_dim = rank - new_pos - 1;
    }
    tensor
}
/// Reorders a rank-N channels-first (NCHW-style) tensor to channels-last (NHWC-style).
pub fn permute_nchw_to_nhwc<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
    let rank = tensor.meta.num_dims();
    // Axes `[0, 2, 3, ..., rank-1, 1]`: batch first, spatial dims, channels last.
    let axes: Vec<usize> = core::iter::once(0)
        .chain(2..rank)
        .chain(core::iter::once(1))
        .collect();
    permute(tensor, &axes)
}
/// Computes the shape [`permute_nchw_to_nhwc`] would produce, without a tensor.
pub fn permute_nchw_to_nhwc_shape(shape: Shape) -> Shape {
    let rank = shape.num_dims();
    // Axes `[0, 2, 3, ..., rank-1, 1]`: batch first, spatial dims, channels last.
    let axes: Vec<usize> = core::iter::once(0)
        .chain(2..rank)
        .chain(core::iter::once(1))
        .collect();
    shape.permuted(&axes).expect("Shape permute should succeed")
}
/// Reorders a rank-N channels-last (NHWC-style) tensor back to channels-first (NCHW-style).
pub fn permute_nhwc_to_nchw<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
    let rank = tensor.meta.num_dims();
    let channel_axis = rank - 1;
    // Axes `[0, rank-1, 1, ..., rank-2]`: batch, channels, then spatial dims.
    let axes: Vec<usize> = [0, channel_axis]
        .into_iter()
        .chain(1..channel_axis)
        .collect();
    permute(tensor, &axes)
}
/// Computes the shape [`permute_nhwc_to_nchw`] would produce, without a tensor.
pub fn permute_nhwc_to_nchw_shape(shape: Shape) -> Shape {
    let rank = shape.num_dims();
    let channel_axis = rank - 1;
    // Axes `[0, rank-1, 1, ..., rank-2]`: batch, channels, then spatial dims.
    let axes: Vec<usize> = [0, channel_axis]
        .into_iter()
        .chain(1..channel_axis)
        .collect();
    shape.permuted(&axes).expect("Shape permute should succeed")
}
/// Broadcast-expands `tensor` to `target_shape` without copying data.
///
/// Dimensions are aligned from the right. Each source dim must either match
/// the target dim (its stride is kept) or be `1` (its stride becomes `0`, so
/// the single element is reused); extra leading target dims get stride `0`.
///
/// # Panics
/// Panics when a source dim is neither `1` nor equal to the target dim, and
/// (currently `todo!`) for block-level quantized tensors.
pub(crate) fn expand<R: CubeRuntime>(tensor: CubeTensor<R>, target_shape: Shape) -> CubeTensor<R> {
    let ndims_in = tensor.meta.shape().num_dims();
    let ndims_out = target_shape.num_dims();
    let mut new_strides = strides![0usize; ndims_out];
    // Number of leading target dims with no counterpart in the source shape.
    let dim_diff = ndims_out.saturating_sub(ndims_in);
    // Walk both shapes from the innermost dim outwards in lockstep.
    let mut tensor_dim_iter = tensor.meta.shape().iter().rev();
    for i in (0..ndims_out).rev() {
        if i >= dim_diff {
            if let Some(&tensor_dim) = tensor_dim_iter.next() {
                if tensor_dim == target_shape[i] || tensor_dim == 1 {
                    new_strides[i] = if tensor_dim == target_shape[i] {
                        // Sizes match: keep the source stride, shifted by the rank difference.
                        tensor.meta.strides()[i - dim_diff]
                    } else {
                        // Broadcasting a size-1 dim: stride 0 repeats the element.
                        0
                    };
                } else {
                    panic!(
                        "Dimension mismatch: cannot broadcast dimension {tensor_dim} of tensor to target shape"
                    );
                }
            } else {
                // Defensive: source dims exhausted (shouldn't happen given dim_diff).
                new_strides[i] = 0;
            }
        } else {
            // Leading dim introduced purely by the expansion.
            new_strides[i] = 0;
        }
    }
    if tensor.qparams.is_some() {
        match tensor.scheme().level {
            // Per-tensor qparams are layout-independent; nothing to adjust.
            QuantLevel::Tensor => {}
            // Block-level qparams would need their own expansion — unimplemented.
            QuantLevel::Block(_) => todo!(),
        }
    }
    CubeTensor {
        client: tensor.client,
        device: tensor.device,
        meta: Box::new(Metadata::new(target_shape, new_strides)),
        handle: tensor.handle,
        dtype: tensor.dtype,
        qparams: tensor.qparams,
    }
}
/// Reshapes `tensor` to `shape`, avoiding a copy whenever the current layout
/// allows the change to be expressed through metadata alone; otherwise the
/// data is copied into a freshly allocated contiguous tensor.
pub fn reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
    match reshape_action(tensor.meta.shape(), tensor.meta.strides(), &shape) {
        // Already in the requested shape.
        ReshapeAction::NoChange => tensor,
        // The existing buffer supports the new shape with recomputed strides.
        ReshapeAction::UpdateStrides { strides } => {
            *tensor.meta = Metadata::new(shape, strides);
            tensor
        }
        // Incompatible layout: materialize into a new tensor.
        ReshapeAction::Recompute => {
            let out = empty_device_dtype(
                tensor.client.clone(),
                tensor.device.clone(),
                shape,
                tensor.dtype,
            );
            cubecl::std::tensor::copy_into(
                &tensor.client,
                &tensor.as_handle_ref(),
                &out.as_handle_ref(),
                tensor.dtype.into(),
            )
            .expect("Kernel should not fail");
            out
        }
    }
}
/// Reshapes a quantized tensor to the logical `shape`.
///
/// Quantized values may be packed (several quantized elements per stored
/// element), so the packed values shape divides the innermost dim by the
/// number of quants per stored element, and the scales have their own shape
/// derived from the quantization level. Both the values and scales handles
/// are reshaped via metadata updates when possible; otherwise the tensor is
/// made contiguous and contiguous strides are recomputed for both.
pub fn q_reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
    let scheme = *tensor.scheme();
    // Shape of the packed value storage: innermost dim shrinks by the pack factor.
    let shape_values = {
        let rank = shape.num_dims();
        let mut shape = shape.clone();
        shape[rank - 1] = shape[rank - 1].div_ceil(scheme.num_quants());
        shape
    };
    // Shape of the quantization scales for the new logical shape.
    let shape_scales = params_shape(&shape, scheme.level);
    let (values, scales) = tensor.quantized_handles().unwrap();
    // Decide, independently for values and scales, whether a metadata-only reshape works.
    let analysis_values = reshape_action(values.meta.shape(), values.meta.strides(), &shape_values);
    let analysis_scales = reshape_action(scales.meta.shape(), scales.meta.strides(), &shape_scales);
    match (analysis_values, analysis_scales) {
        // Both handles can be reshaped in place.
        (
            ReshapeAction::UpdateStrides { strides },
            ReshapeAction::UpdateStrides {
                strides: scales_strides,
            },
        ) => {
            let qparams = tensor.qparams.as_mut().unwrap();
            // NOTE(review): the tensor metadata keeps the *logical* shape paired
            // with strides computed for the packed values shape — confirm this
            // convention matches the fallback branch below.
            *tensor.meta = Metadata::new(shape, strides);
            qparams.scales.metadata = Metadata::new(shape_scales, scales_strides);
        }
        // Only the values need new strides; scales are already correct.
        (ReshapeAction::UpdateStrides { strides }, ReshapeAction::NoChange) => {
            *tensor.meta = Metadata::new(shape, strides);
        }
        // Only the scales need new strides; values are already correct.
        (
            ReshapeAction::NoChange,
            ReshapeAction::UpdateStrides {
                strides: scales_strides,
            },
        ) => {
            let qparams = tensor.qparams.as_mut().unwrap();
            qparams.scales.metadata = Metadata::new(shape_scales, scales_strides);
        }
        // Nothing to do.
        (ReshapeAction::NoChange, ReshapeAction::NoChange) => {}
        // At least one handle would require a data copy: make the whole tensor
        // contiguous and rebuild both metadata entries with contiguous strides.
        _ => {
            tensor = kernel::into_contiguous(tensor);
            *tensor.meta = Metadata::new(shape, contiguous_strides(&shape_values));
            let qparams = tensor.qparams.as_mut().unwrap();
            let strides = contiguous_strides(&shape_scales);
            qparams.scales.metadata = Metadata::new(shape_scales, strides);
        }
    }
    tensor
}
/// Largest line (vectorization) size usable along the innermost axis of `tensor`.
pub(crate) fn max_line_size<R: CubeRuntime>(tensor: &CubeTensor<R>) -> LineSize {
    let last_axis = tensor.meta.num_dims() - 1;
    let candidates = tensor.client.io_optimized_line_sizes(tensor.dtype.size());
    tensor_line_size_parallel(
        candidates,
        tensor.meta.shape(),
        tensor.meta.strides(),
        last_axis,
    )
}
/// Largest line size simultaneously valid along `axis` for every tensor in
/// `tensors`, i.e. the minimum of the per-tensor maxima.
pub(crate) fn max_line_size_many<R: CubeRuntime>(
    tensors: &[&CubeTensor<R>],
    axis: usize,
) -> LineSize {
    tensors
        .iter()
        .map(|tensor| {
            tensor_line_size_parallel(
                tensor.client.io_optimized_line_sizes(tensor.dtype.size()),
                tensor.meta.shape(),
                tensor.meta.strides(),
                axis,
            )
        })
        .min()
        // NOTE(review): an empty slice yields 0, which is not a meaningful
        // line size — confirm callers never pass an empty slice.
        .unwrap_or(0)
}
/// Returns a view of `tensor` with sliding windows of length `size`,
/// advancing by `step`, extracted along `dim`; the elements of each window
/// appear along a new trailing axis.
///
/// Zero-copy: only the metadata changes, the storage handle is shared.
pub fn unfold<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    dim: usize,
    size: usize,
    step: usize,
) -> CubeTensor<R> {
    let out_shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
    let dim_stride = tensor.meta.strides()[dim];
    // Moving one window along `dim` jumps `step` original elements; within a
    // window, consecutive elements keep the original stride (trailing axis).
    let mut out_strides = tensor.meta.strides.clone();
    out_strides[dim] = step * dim_stride;
    out_strides.push(dim_stride);
    CubeTensor {
        meta: Box::new(Metadata::new(out_shape, out_strides)),
        ..tensor
    }
}