use burn_backend::{DType, QTensorPrimitive, TensorMetadata};
use cubecl::quant::scheme::{QuantStore, QuantValue};
use cubecl::server::AllocationKind;
use crate::{CubeRuntime, ops::empty_qtensor, tensor::CubeTensor};
/// Make `tensor` contiguous, copying it only when necessary.
///
/// - A tensor that already reports `is_contiguous()` is returned as-is, with no copy.
/// - A quantized tensor (`qparams` is set) is copied into a new quantized tensor
///   allocated with [`AllocationKind::Contiguous`].
/// - Any other tensor is copied by cubecl's `into_contiguous_ref` kernel. The result is a
///   new tensor on the same client and device, with the same dtype.
///
/// # Panics
///
/// Panics if the copy kernel returns an error.
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "trace", skip(tensor))
)]
pub fn into_contiguous<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
    if tensor.is_contiguous() {
        return tensor;
    }
    if tensor.qparams.is_some() {
        return into_contiguous_quantized(tensor, AllocationKind::Contiguous);
    }
    let output = cubecl::std::tensor::into_contiguous_ref(
        &tensor.client,
        &tensor.as_handle_ref(),
        tensor.dtype.into(),
    )
    .expect("Kernel to never fail");
    CubeTensor::new(
        tensor.client,
        output.handle,
        *output.metadata,
        tensor.device,
        tensor.dtype,
    )
}
/// Return a tensor whose layout the runtime `R` can read directly, copying only if needed.
///
/// If `R::can_read_tensor` accepts the current shape and strides, `tensor` is returned
/// unchanged. A quantized tensor (`qparams` is set) is copied into a new quantized tensor
/// allocated with [`AllocationKind::Optimized`].
///
/// Any other tensor is rewritten by cubecl's `into_contiguous_pitched_ref` kernel. That
/// kernel produces a pitched layout, presumably with rows padded for alignment (confirm
/// against cubecl). The result is a new tensor on the same client and device, with the
/// same dtype.
///
/// # Panics
///
/// Panics if the copy kernel returns an error.
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "trace", skip(tensor))
)]
pub fn into_contiguous_aligned<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {
    let already_readable = R::can_read_tensor(tensor.meta.shape(), tensor.meta.strides());
    if already_readable {
        return tensor;
    }

    let is_quantized = tensor.qparams.is_some();
    if is_quantized {
        return into_contiguous_quantized(tensor, AllocationKind::Optimized);
    }

    // Keep the handle borrow scoped so `tensor`'s fields can be moved out afterwards.
    let pitched = {
        let source = tensor.as_handle_ref();
        cubecl::std::tensor::into_contiguous_pitched_ref(
            &tensor.client,
            &source,
            tensor.dtype.into(),
        )
        .expect("Kernel to never fail")
    };

    let metadata = *pitched.metadata;
    CubeTensor::new(tensor.client, pitched.handle, metadata, tensor.device, tensor.dtype)
}
/// Copy a quantized tensor into a freshly allocated quantized tensor.
///
/// The new tensor has the same shape, quantization scheme and device as the input. `kind`
/// selects how it is allocated. The quantized values and the scales are copied separately:
///
/// - Values stored as `QuantStore::PackedU32` go through the packing-aware
///   `into_contiguous_packed_ref` kernel with `u32` storage.
/// - Values stored as `QuantStore::PackedNative` with an `E2M1` scheme go through the same
///   kernel with `u8` storage.
/// - Values in any other storage layout are copied with `copy_into`, using the values'
///   own dtype.
/// - Scales are always copied with `copy_into`.
///
/// # Panics
///
/// Panics in two cases:
///
/// - the quantized handles of the input or output tensor cannot be obtained (the `unwrap`
///   calls);
/// - any copy kernel returns an error.
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(level = "trace", skip(tensor))
)]
fn into_contiguous_quantized<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    kind: AllocationKind,
) -> CubeTensor<R> {
    // Copy the scheme once instead of looking it up twice.
    let scheme = *tensor.scheme();
    let output = empty_qtensor(tensor.shape(), scheme, &tensor.device, kind);

    let (values, scales) = tensor.quantized_handles().unwrap();
    let (out_values, out_scales) = output.quantized_handles().unwrap();

    // Packed layouts need the packing-aware kernel. Select the packed dimension and the
    // storage dtype handed to that kernel. `None` means a plain copy is enough.
    let packing = match scheme.store {
        QuantStore::PackedU32(packed_dim) => Some((packed_dim, DType::U32)),
        QuantStore::PackedNative(packed_dim) if scheme.value == QuantValue::E2M1 => {
            Some((packed_dim, DType::U8))
        }
        _ => None,
    };

    match packing {
        Some((packed_dim, storage)) => {
            cubecl::std::tensor::into_contiguous_packed_ref(
                &values.client,
                &values.as_handle_ref(),
                &out_values.as_handle_ref(),
                packed_dim,
                tensor.meta.shape(),
                scheme.num_quants(),
                storage.into(),
            )
            .expect("Kernel to never fail");
        }
        None => {
            cubecl::std::tensor::copy_into(
                &values.client,
                &values.as_handle_ref(),
                &out_values.as_handle_ref(),
                values.dtype.into(),
            )
            .expect("Kernel to never fail");
        }
    }

    // Scales are never packed, so a plain copy is always used for them.
    cubecl::std::tensor::copy_into(
        &scales.client,
        &scales.as_handle_ref(),
        &out_scales.as_handle_ref(),
        scales.dtype.into(),
    )
    .expect("Kernel to never fail");

    output
}