use crate::CubeRuntime;
use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
use burn_backend::quantization::QuantScheme;
use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata};
use burn_std::{Metadata, strides, tensor::is_contiguous};
use cubecl::server::Handle;
use cubecl::std::tensor::TensorHandle;
use cubecl::{client::ComputeClient, std::tensor::layout::linear::LinearViewLaunch};
use cubecl::{frontend::Numeric, std::tensor::layout::linear::LinearViewLayoutLaunch};
use cubecl::{
prelude::{TensorBinding, *},
std::tensor::layout::linear::LinearViewLayout,
};
use std::marker::PhantomData;
use super::QParams;
/// A tensor whose data lives on a CubeCL runtime device.
pub struct CubeTensor<R: CubeRuntime> {
    /// Compute client used to launch kernels and transfer data for this tensor.
    pub client: ComputeClient<R>,
    /// Backing memory handle on the device.
    pub handle: Handle,
    /// Shape and strides metadata.
    // NOTE(review): boxed presumably to keep the struct cheap to move — confirm.
    pub meta: Box<Metadata>,
    /// Device on which the tensor's memory resides.
    pub device: R::Device,
    /// Element data type (may be `DType::QFloat` for quantized tensors).
    pub dtype: DType,
    /// Quantization parameters; `None` for non-quantized tensors.
    pub qparams: Option<QParams>,
}
impl<R: CubeRuntime> From<CubeTensor<R>> for TensorHandle<R> {
    /// Converts a [`CubeTensor`] into a raw [`TensorHandle`].
    ///
    /// `val` is consumed by value, so the handle is moved out directly —
    /// the previous `val.handle.clone()` cloned a handle the function
    /// already owned (clippy `redundant_clone`). Shape and strides are
    /// cloned because `meta` accessors return references.
    fn from(val: CubeTensor<R>) -> Self {
        TensorHandle::new(
            val.handle,
            val.meta.shape().clone(),
            val.meta.strides().clone(),
            val.dtype,
        )
    }
}
/// Autotune integration: lets a `CubeTensor` serve as the output of an
/// autotuned operation.
impl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {
    // Only compiled with the `autotune-checks` feature: reads both tensors
    // back to the host and asserts they are approximately equal, catching
    // autotune candidates that disagree on results.
    #[cfg(feature = "autotune-checks")]
    fn check_equivalence(&self, other: Self) {
        use crate::ops::into_data_sync;
        use burn_backend::Tolerance;

        let expected = into_data_sync::<R>(self.clone());
        let actual = into_data_sync::<R>(other);

        expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());
    }
}
impl<R> core::fmt::Debug for CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Renders the tensor's shape, device, strides, element type name and
    /// runtime name in a single-line `CubeTensor { .. }` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}",
            self.meta.shape(),
            self.device,
            self.meta.strides(),
            self.dtype.name(),
            R::name(&self.client),
        )
    }
}
// Manual impl: `#[derive(Clone)]` would add an `R: Clone` bound, which
// runtime types need not satisfy. Cloning is shallow for the device data —
// `handle` is a handle clone, not a buffer copy (see `copy` for a deep copy).
impl<R> Clone for CubeTensor<R>
where
    R: CubeRuntime,
{
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            handle: self.handle.clone(),
            meta: self.meta.clone(),
            device: self.device.clone(),
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }
}
impl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {
    /// Element data type of the tensor.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Owned copy of the tensor's shape.
    fn shape(&self) -> Shape {
        self.meta.shape().to_owned()
    }

    /// Number of dimensions.
    fn rank(&self) -> usize {
        self.meta.rank()
    }
}
impl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {
    /// Returns the quantization scheme of this tensor.
    ///
    /// # Panics
    /// Panics if the tensor's dtype is not `DType::QFloat`.
    fn scheme(&self) -> &QuantScheme {
        match &self.dtype {
            DType::QFloat(scheme) => scheme,
            other => panic!(
                "Quantization scheme is not valid for dtype {:?}",
                other,
            ),
        }
    }
}
impl<R> CubeTensor<R>
where
    R: CubeRuntime,
{
    /// Creates a new tensor from an existing handle and explicit metadata
    /// (shape + strides). No quantization parameters are attached.
    pub fn new(
        client: ComputeClient<R>,
        handle: Handle,
        metadata: Metadata,
        device: R::Device,
        dtype: DType,
    ) -> Self {
        CubeTensor {
            client,
            handle,
            meta: Box::new(metadata),
            device,
            dtype,
            qparams: None,
        }
    }

    /// Creates a new tensor with contiguous (row-major) strides computed
    /// from `shape`.
    pub fn new_contiguous(
        client: ComputeClient<R>,
        device: R::Device,
        shape: Shape,
        handle: Handle,
        dtype: DType,
    ) -> Self {
        let ndims = shape.num_dims();
        let mut strides = strides![0; ndims];

        // Row-major: the last dimension has stride 1, and each earlier
        // stride is the product of all later dimension sizes.
        let mut current = 1;
        shape.iter().enumerate().rev().for_each(|(index, val)| {
            strides[index] = current;
            current *= val;
        });

        Self {
            client,
            handle,
            meta: Box::new(Metadata::new(shape, strides)),
            device,
            dtype,
            qparams: None,
        }
    }

    /// Copies this tensor's data to another compute client/device and
    /// returns the copy; `self` keeps its own data.
    // NOTE(review): takes `&mut self` but never mutates — `&self` would
    // likely suffice; confirm against callers before changing.
    pub fn to_client(&mut self, client: ComputeClient<R>, device: R::Device) -> Self {
        // Describe the source region (shape, strides, element size) …
        let desc = self.handle.clone().copy_descriptor(
            self.meta.shape().clone(),
            self.meta.strides().clone(),
            self.elem_size(),
        );
        // … then let the source client perform the cross-client transfer.
        let handle = self
            .client
            .to_client_tensor(desc, &client, self.dtype.into());

        Self {
            client,
            handle,
            meta: Box::new(Metadata::new(self.shape(), self.meta.strides().clone())),
            device,
            dtype: self.dtype,
            qparams: self.qparams.clone(),
        }
    }

    /// Consumes the tensor, returning the kernel binding (memory handle
    /// plus layout) used to pass it to a launch.
    pub fn binding(self) -> TensorBinding<R> {
        TensorBinding {
            handle: self.handle.binding(),
            strides: self.meta.strides,
            shape: self.meta.shape,
            runtime: PhantomData,
        }
    }

    /// Size in bytes of one element, derived from the dtype.
    pub fn elem_size(&self) -> usize {
        self.dtype.size()
    }

    /// Consumes the tensor into a launch-ready tensor argument.
    pub fn into_tensor_arg(self) -> TensorArg<R> {
        self.binding().into_tensor_arg()
    }

    /// Consumes the tensor into a launch-ready array argument.
    pub fn into_array_arg(self) -> ArrayArg<R> {
        self.into_tensor_arg().into_array_arg()
    }

    /// Builds an aliasing tensor argument that reuses the input buffer at
    /// position `input_pos` instead of binding this tensor's own handle.
    pub fn as_tensor_alias(&self, input_pos: usize) -> TensorArg<R> {
        TensorArg::Alias {
            input_pos,
            strides: self.meta.strides().clone(),
            shape: self.meta.shape().clone(),
        }
    }

    /// Consumes the tensor into a linear (flat) view for launch.
    pub fn into_linear_view(self) -> LinearViewLaunch<R> {
        let layout = LinearViewLayoutLaunch::new();
        let buffer = self.into_tensor_arg();
        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
    }

    /// Linear view that aliases the input at `input_pos` (see
    /// `as_tensor_alias`).
    pub fn as_linear_view_alias(&self, input_pos: usize) -> LinearViewLaunch<R> {
        let layout = LinearViewLayoutLaunch::new();
        let buffer = self.as_tensor_alias(input_pos);
        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
    }

    /// Linear view whose layout is derived from `reference`'s shape rather
    /// than this tensor's own.
    pub fn into_linear_view_like(self, reference: &Self) -> LinearViewLaunch<R> {
        let layout = LinearViewLayoutLaunch::from_reference_shape(reference.shape());
        let buffer = self.into_tensor_arg();
        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)
    }

    /// Smallest address type capable of indexing every value in this
    /// tensor's buffer.
    pub fn required_address_type(&self) -> AddressType {
        match self.try_scheme() {
            Some(scheme) => {
                // Quantized values can be sub-byte: derive the logical value
                // count from the scheme's bit width (buffer bytes * 8 bits
                // / bits per value).
                let len = self.handle.size() as usize * 8 / scheme.size_bits_value();
                AddressType::from_len(len)
            }
            None => AddressType::from_len(self.handle.size() as usize / self.dtype.size()),
        }
    }

    /// Returns the quantization scheme if this tensor is quantized, `None`
    /// otherwise. Non-panicking counterpart of `QTensorPrimitive::scheme`.
    pub fn try_scheme(&self) -> Option<&QuantScheme> {
        match &self.dtype {
            DType::QFloat(scheme) => Some(scheme),
            _ => None,
        }
    }

    /// Whether this tensor's buffer may be reused as the output of a
    /// broadcasted binary op with `rhs`: the handle must be exclusively
    /// mutable, the layout non-overlapping, and every dimension of `self`
    /// at least as large as the matching dimension of `rhs`.
    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {
        if !self.handle.can_mut() || !self.is_nonoverlapping() {
            return false;
        }
        let ndims = self.meta.num_dims();

        for i in 0..ndims {
            let shape_lhs = self.meta.shape()[i];
            let shape_rhs = rhs.meta.shape()[i];
            // A smaller lhs dim would force broadcasting writes into lhs.
            if shape_lhs < shape_rhs {
                return false;
            }
        }
        true
    }

    /// Returns a deep copy of this tensor by launching an identity
    /// element-wise kernel (contrast with `clone`, which shares the buffer).
    pub fn copy(&self) -> Self {
        // Identity unary op: returns its input unchanged.
        struct Copy;

        #[cube]
        impl<T: Numeric, N: Size> NumericUnaryOp<T, N> for Copy {
            type Options = ();

            fn execute(input: Vector<T, N>, _options: &Self::Options) -> Vector<T, N> {
                input
            }
        }

        impl NumericUnaryOpFamily for Copy {
            type Options = ();
            type Unary<T: Numeric, N: Size> = Self;
        }

        let tensor = self.clone();
        launch_unary_numeric::<R, Copy, _>(tensor, |_| ())
    }

    /// Whether the underlying handle can be mutated in place.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Asserts that `other` lives on the same device as `self`.
    ///
    /// # Panics
    /// Panics if the devices differ.
    pub fn assert_is_on_same_device(&self, other: &Self) {
        if self.device != other.device {
            panic!(
                "Both tensors should be on the same device {:?} != {:?}",
                self.device, other.device
            );
        }
    }

    /// Whether the tensor's memory layout is contiguous (row-major,
    /// densely packed).
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(self.meta.shape(), self.meta.strides())
    }

    /// Whether the buffer holds exactly the tensor's elements and nothing
    /// more (no slack from slicing, padding, or offset views).
    pub fn is_contiguous_buffer(&self) -> bool {
        self.meta.shape().num_elements() * self.dtype.size() == self.handle.size() as usize
    }

    /// Whether no two logical indices map to the same memory location.
    ///
    /// A zero stride aliases an entire dimension, so it always overlaps.
    /// Otherwise dimensions are sorted by stride and each stride must
    /// exceed the maximum offset reachable using all smaller-stride
    /// dimensions combined.
    pub fn is_nonoverlapping(&self) -> bool {
        let shape = self.meta.shape();
        let strides = self.meta.strides();

        if strides.contains(&0) {
            return false;
        }
        let rank = self.rank();
        if rank > 1 {
            let mut dims = shape.iter().zip(strides.iter()).collect::<Vec<_>>();
            dims.sort_by_key(|(_, stride)| **stride);

            let mut max_offset = 0;
            for (shape, stride) in dims.into_iter() {
                // Size-1 dimensions can never cause overlap, whatever
                // their stride.
                if *stride <= max_offset && *shape != 1 {
                    return false;
                }
                max_offset += (*shape - 1) * *stride;
            }
        }
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Trailing size-1 dims may carry a "wrong" stride and still be
    // contiguous, since they contribute no addressing freedom.
    #[test]
    fn is_contiguous_non_increasing() {
        assert!(is_contiguous(&[3, 1], &[1, 1]));
    }

    // Plain row-major 2D layout.
    #[test]
    fn is_contiguous_basic() {
        assert!(is_contiguous(&[32, 32], &[32, 1]));
    }

    // Column-major (transposed) strides are not contiguous.
    #[test]
    fn is_contiguous_permuted() {
        assert!(!is_contiguous(&[32, 32], &[1, 32]));
    }

    // A slice whose strides skip over elements is not contiguous.
    #[test]
    fn is_contiguous_slice() {
        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));
    }

    // Standard NCHW-style dense 4D layout.
    #[test]
    fn is_contiguous_4d_positive() {
        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));
    }

    // Two leading dims swapped (permuted) — not contiguous.
    #[test]
    fn is_contiguous_4d_negative() {
        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));
    }

    // Unit dims with odd strides plus a strided last dim — not contiguous.
    #[test]
    fn is_contiguous_4d_unit_shape() {
        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));
    }
}