use crate::{CubeRuntime, kernel, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::{
    DType, ExecutionError, QTensorPrimitive, Shape, TensorData, TensorMetadata,
    ops::unfold::calculate_unfold_shape,
    quantization::{QuantLevel, QuantStore, params_shape},
};
use burn_std::{
Metadata, QuantValue, ReshapeAnalysis, reshape_analysis, strides,
tensor::{ReshapeAction, contiguous_strides, reshape_action},
};
use cubecl::{
    ir::VectorSize, quant::scheme::BlockSize, server::CopyDescriptor,
    tensor_vector_size_parallel,
};
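/// Creates a tensor on `device` from host [`TensorData`], uploading the bytes
/// into a new allocation.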
pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
let client = R::client(device);
let alloc = client.create_tensor(data.bytes, data.shape.clone(), data.dtype.size());
let shape: Shape = (&data.shape).into();
CubeTensor::new(
client,
alloc.memory,
Metadata::new(shape, alloc.strides),
device.clone(),
data.dtype,
)
}
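/// Reads `tensor` back into host [`TensorData`].
///
/// The tensor is first made contiguous (with aligned strides) so its bytes can
/// be fetched in a single read.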
pub(crate) async fn into_data<R: CubeRuntime>(
tensor: CubeTensor<R>,
) -> Result<TensorData, ExecutionError> {
let tensor = kernel::into_contiguous_aligned(tensor);
let elem_size = tensor.elem_size();
let shape = tensor.meta.shape().clone();
let strides = tensor.meta.strides().clone();
let binding = CopyDescriptor::new(tensor.handle.binding(), shape, strides, elem_size);
let bytes = tensor
.client
.read_one_tensor_async(binding)
.await
.map_err(|err| ExecutionError::WithContext {
reason: format!("{err}"),
})?;
Ok(TensorData::from_bytes(
bytes,
        tensor.meta.shape().clone(),
tensor.dtype,
))
}
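/// Blocking variant of [`into_data`]; handy when debugging kernels.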
#[allow(unused, reason = "useful for debugging kernels")]
pub fn into_data_sync<R: CubeRuntime>(tensor: CubeTensor<R>) -> TensorData {
burn_std::future::block_on(into_data(tensor)).unwrap()
}
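/// Moves `tensor` to `device`, returning it as-is when it already lives there.
/// The tensor is made contiguous before the cross-device transfer.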
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(tensor, device))
)]
pub(crate) fn to_device<R: CubeRuntime>(
tensor: CubeTensor<R>,
device: &R::Device,
) -> CubeTensor<R> {
if &tensor.device == device {
return tensor;
}
let mut tensor = kernel::into_contiguous_aligned(tensor);
let client = R::client(device);
tensor.to_client(client, device.clone())
}
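/// Allocates an uninitialized tensor of the given `shape` and `dtype` on `device`.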
pub(crate) fn empty<R: CubeRuntime>(
shape: Shape,
device: &R::Device,
dtype: DType,
) -> CubeTensor<R> {
let client = R::client(device);
let alloc = client.empty_tensor(shape.clone(), dtype.size());
CubeTensor::new(
client,
alloc.memory,
Metadata::new(shape, alloc.strides),
device.clone(),
dtype,
)
}
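/// Swaps two dimensions by adjusting metadata only (no data movement).
///
/// For block-quantized tensors the block size and the scales metadata are
/// swapped to match; for packed stores the packed dimension is remapped when
/// it is one of the swapped dims.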
pub(crate) fn swap_dims<R: CubeRuntime>(
mut tensor: CubeTensor<R>,
dim1: usize,
dim2: usize,
) -> CubeTensor<R> {
tensor.meta.swap(dim1, dim2);
if let DType::QFloat(scheme) = tensor.dtype
&& let QuantLevel::Block(block_size) = scheme.level
{
let rank = tensor.rank();
let qparams = tensor.qparams.as_mut().unwrap();
let mut block_size = block_size.to_dim_vec(rank);
block_size.swap(dim1, dim2);
let block_size = BlockSize::new_trim(block_size);
if block_size.len() > BlockSize::MAX_DIMS {
panic!("Swapped block size would exceed max dims");
}
qparams.scales.metadata.swap(dim1, dim2);
tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::Block(block_size)))
}
if let DType::QFloat(scheme) = &mut tensor.dtype
&& let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) =
&mut scheme.store
{
let rank = tensor.meta.num_dims();
if *packed_dim == rank - dim1 - 1 {
*packed_dim = rank - dim2 - 1;
} else if *packed_dim == rank - dim2 - 1 {
*packed_dim = rank - dim1 - 1;
}
}
tensor
}
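/// Permutes the tensor dimensions according to `axes`, adjusting metadata only.
///
/// For block-quantized tensors the block size and the scales metadata follow
/// the permutation; for packed stores the packed dimension is moved to its new
/// position.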
pub fn permute<R: CubeRuntime>(mut tensor: CubeTensor<R>, axes: &[usize]) -> CubeTensor<R> {
tensor.meta.permute(axes).unwrap();
if let DType::QFloat(scheme) = tensor.dtype
&& let QuantLevel::Block(block_size) = scheme.level
{
let rank = tensor.rank();
let qparams = tensor.qparams.as_mut().unwrap();
let mut block_size = block_size.to_dim_vec(rank);
block_size = axes.iter().map(|i| block_size[*i]).collect();
let block_size = block_size
.into_iter()
.skip_while(|it| *it == 1)
.collect::<Vec<_>>();
if block_size.len() > BlockSize::MAX_DIMS {
panic!("Swapped block size would exceed max dims");
}
qparams.scales.metadata.permute(axes).unwrap();
tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::block(&block_size)))
}
if let DType::QFloat(scheme) = &mut tensor.dtype
&& let QuantStore::PackedU32(packed_dim) = &mut scheme.store
{
let rank = tensor.meta.num_dims();
        let new_pos = axes
            .iter()
            .position(|axis| *axis == rank - *packed_dim - 1)
            .expect("permutation axes must include the packed dim");
*packed_dim = rank - new_pos - 1;
}
tensor
}
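/// Permutes a channels-first (NCHW) tensor to channels-last (NHWC) layout.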
pub fn permute_nchw_to_nhwc<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
let rank = tensor.meta.num_dims();
let c_dim = 1;
let mut dims = vec![0];
dims.extend(2..rank);
dims.push(c_dim);
permute(tensor, &dims)
}
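/// Returns the shape that [`permute_nchw_to_nhwc`] would produce.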
pub fn permute_nchw_to_nhwc_shape(shape: Shape) -> Shape {
let rank = shape.num_dims();
let c_dim = 1;
let mut dims = vec![0];
dims.extend(2..rank);
dims.push(c_dim);
shape.permuted(&dims).expect("Shape permute should succeed")
}
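/// Permutes a channels-last (NHWC) tensor back to channels-first (NCHW) layout.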
pub fn permute_nhwc_to_nchw<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
let rank = tensor.meta.num_dims();
let c_dim = rank - 1;
let mut dims = vec![0];
dims.push(c_dim);
dims.extend(1..c_dim);
permute(tensor, &dims)
}
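/// Returns the shape that [`permute_nhwc_to_nchw`] would produce.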
pub fn permute_nhwc_to_nchw_shape(shape: Shape) -> Shape {
let rank = shape.num_dims();
let c_dim = rank - 1;
let mut dims = vec![0];
dims.push(c_dim);
dims.extend(1..c_dim);
shape.permuted(&dims).expect("Shape permute should succeed")
}
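/// Broadcasts `tensor` to `target_shape` without copying: broadcast dimensions
/// get a stride of 0 so the same elements are read repeatedly.
///
/// Panics if a non-unit input dimension differs from the target dimension.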
pub(crate) fn expand<R: CubeRuntime>(tensor: CubeTensor<R>, target_shape: Shape) -> CubeTensor<R> {
let ndims_in = tensor.meta.shape().num_dims();
let ndims_out = target_shape.num_dims();
let mut new_strides = strides![0usize; ndims_out];
let dim_diff = ndims_out.saturating_sub(ndims_in);
let mut tensor_dim_iter = tensor.meta.shape().iter().rev();
    for i in (0..ndims_out).rev() {
        // Leading output dims with no matching input dim stay at the
        // initialized stride of 0 (pure broadcast).
        if i < dim_diff {
            continue;
        }
        if let Some(&tensor_dim) = tensor_dim_iter.next() {
            if tensor_dim == target_shape[i] {
                // Sizes match: reuse the input stride.
                new_strides[i] = tensor.meta.strides()[i - dim_diff];
            } else if tensor_dim == 1 {
                // Size-1 input dim broadcasts with stride 0.
                new_strides[i] = 0;
            } else {
                panic!(
                    "Dimension mismatch: cannot broadcast dimension {tensor_dim} of tensor to target shape"
                );
            }
        }
    }
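    // Per-tensor quantization params are unaffected by broadcasting; block-wise
    // scales would have to be expanded alongside the values.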
if tensor.qparams.is_some() {
match tensor.scheme().level {
QuantLevel::Tensor => {}
QuantLevel::Block(_) => todo!(),
}
}
CubeTensor {
client: tensor.client.clone(),
device: tensor.device.clone(),
meta: Box::new(Metadata::new(target_shape, new_strides)),
handle: tensor.handle.clone(),
dtype: tensor.dtype,
qparams: tensor.qparams.clone(),
}
}
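/// Reshapes `tensor` to `shape`, updating the strides in place when the current
/// layout allows it, and otherwise copying into a new contiguous buffer.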
pub fn reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
let analysis = reshape_action(tensor.meta.shape(), tensor.meta.strides(), &shape);
match analysis {
ReshapeAction::UpdateStrides { strides } => {
*tensor.meta = Metadata::new(shape, strides);
return tensor;
}
ReshapeAction::NoChange => return tensor,
ReshapeAction::Recompute => (),
}
let out = empty_device_dtype(
tensor.client.clone(),
tensor.device.clone(),
shape,
tensor.dtype,
);
cubecl::std::tensor::copy_into(
&out.client,
tensor.binding(),
out.clone().binding(),
out.dtype.into(),
);
out
}
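/// Reshapes a quantized tensor, keeping the value buffer and the scales
/// metadata consistent with the new logical shape.
///
/// Several cases are rejected or unimplemented: packed dimensions that would
/// lose alignment with the packing factor, sub-byte values outside a plain
/// unsqueeze, and block layouts that the reshape would split.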
pub fn q_reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
let scheme = *tensor.scheme();
let curr_shape = tensor.meta.shape();
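    // For packed stores, the value buffer divides the packed dimension by the
    // number of quantized values per storage element; native stores keep the
    // logical shape.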
let shape_values = match scheme.store {
QuantStore::Native => shape.clone(),
QuantStore::PackedNative(packed_dim) | QuantStore::PackedU32(packed_dim) => {
let rank = shape.num_dims();
let mut shape = shape.clone();
let packed_d = rank - packed_dim - 1;
let num_quants = scheme.num_quants();
if !shape[packed_d].is_multiple_of(num_quants) {
unimplemented!(
"Cannot reshape packed tensor: inner dimension {} is not aligned with packing factor {num_quants}",
shape[packed_d]
);
}
shape[packed_d] = shape[packed_d].div_ceil(num_quants);
shape
}
};
let (values, scales) = tensor.quantized_handles().unwrap();
let analysis_values = reshape_analysis(
values.meta.shape(),
Some(values.meta.strides()),
&shape_values,
);
let action_values =
analysis_values.action(values.meta.shape(), values.meta.strides(), &shape_values);
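    // A pure unsqueeze only prepends size-1 dims, which never disturbs the
    // packed layout or the block structure.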
let n_new_dims = shape.num_dims().saturating_sub(curr_shape.num_dims());
let is_unsqueeze = n_new_dims > 0 && shape[n_new_dims..] == **curr_shape;
if !is_unsqueeze
&& matches!(
scheme.value,
QuantValue::Q4S | QuantValue::Q4F | QuantValue::Q2S | QuantValue::Q2F
)
{
todo!("Reshape with sub-byte values is not supported")
}
if let ReshapeAction::UpdateStrides { .. } = &action_values {
match analysis_values {
ReshapeAnalysis::IsContiguous => {
if let QuantLevel::Block(block_size) = scheme.level
&& block_size.len() > 1
&& !is_unsqueeze
{
unimplemented!("Reshape of ND block-quantized tensor is not yet supported.");
}
}
            ReshapeAnalysis::Broadcasted => {}
            ReshapeAnalysis::Split => {
if let QuantLevel::Block(block_size) = scheme.level
&& block_size.len() > 1
{
unimplemented!(
"Split reshape of ND block-quantized tensor is not yet supported."
);
}
}
other => unreachable!("Reshape analysis {other:?} should not update strides."),
}
}
let shape_last = *shape.last().unwrap();
let shape_scales = match scheme.level {
        QuantLevel::Tensor => scales.meta.shape().clone(),
        QuantLevel::Block(block_size)
            if block_size.len() == 1 && shape_last < (block_size[0] as usize) =>
{
if scales.meta.shape().num_elements() > 1 {
unimplemented!("Reshape would split a block across multiple rows.");
}
scales.meta.shape().clone()
}
        QuantLevel::Block(_) => params_shape(&shape, scheme.level),
};
let action_scales = reshape_action(scales.meta.shape(), scales.meta.strides(), &shape_scales);
match (action_values, action_scales) {
(
ReshapeAction::UpdateStrides { strides },
ReshapeAction::UpdateStrides {
strides: scales_strides,
},
) => {
let qparams = tensor.qparams.as_mut().unwrap();
*tensor.meta = Metadata::new(shape, strides);
qparams.scales.metadata = Metadata::new(shape_scales, scales_strides);
}
(ReshapeAction::UpdateStrides { strides }, ReshapeAction::NoChange) => {
*tensor.meta = Metadata::new(shape, strides);
}
(
ReshapeAction::NoChange,
ReshapeAction::UpdateStrides {
strides: scales_strides,
},
) => {
let qparams = tensor.qparams.as_mut().unwrap();
qparams.scales.metadata = Metadata::new(shape_scales, scales_strides);
}
(ReshapeAction::Recompute, _) | (_, ReshapeAction::Recompute) => {
if let QuantLevel::Block(_) = scheme.level
&& shape_scales.num_elements() > 1
{
unimplemented!(
"Cannot reshape a block-quantized tensor when the reshape requires recomputing the buffer."
);
}
tensor = kernel::into_contiguous(tensor);
*tensor.meta = Metadata::new(shape, contiguous_strides(&shape_values));
let qparams = tensor.qparams.as_mut().unwrap();
let strides = contiguous_strides(&shape_scales);
qparams.scales.metadata = Metadata::new(shape_scales, strides);
}
(ReshapeAction::NoChange, ReshapeAction::NoChange) => {}
}
tensor
}
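/// Finds the widest vector (line) size supported for I/O on `tensor` along its
/// last dimension, given its shape and strides.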
pub(crate) fn max_vector_size<R: CubeRuntime>(tensor: &CubeTensor<R>) -> VectorSize {
tensor_vector_size_parallel(
tensor.client.io_optimized_vector_sizes(tensor.dtype.size()),
tensor.meta.shape(),
tensor.meta.strides(),
tensor.meta.num_dims() - 1,
)
}
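/// Finds the widest vector size usable by *all* `tensors` along `axis`: the
/// minimum of the per-tensor maxima, or 0 when the slice is empty.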
pub(crate) fn max_vector_size_many<R: CubeRuntime>(
tensors: &[&CubeTensor<R>],
axis: usize,
) -> VectorSize {
let vec = tensors
.iter()
.map(|tensor| {
tensor_vector_size_parallel(
tensor.client.io_optimized_vector_sizes(tensor.dtype.size()),
tensor.meta.shape(),
tensor.meta.strides(),
axis,
)
})
.min();
vec.unwrap_or(0)
}
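/// Creates a zero-copy view containing all windows of length `size`, taken
/// every `step` elements along `dim`.
///
/// The window elements become a new trailing dimension that reuses the stride
/// of `dim`, so windows overlap in memory whenever `step < size`.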
pub fn unfold<R: CubeRuntime>(
tensor: CubeTensor<R>,
dim: usize,
size: usize,
step: usize,
) -> CubeTensor<R> {
let shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
let d_stride = tensor.meta.strides()[dim];
    let mut strides = tensor.meta.strides().clone();
strides[dim] = step * d_stride;
strides.push(d_stride);
CubeTensor {
meta: Box::new(Metadata::new(shape, strides)),
client: tensor.client.clone(),
handle: tensor.handle.clone(),
device: tensor.device.clone(),
dtype: tensor.dtype,
qparams: tensor.qparams.clone(),
}
}