cubecl_std/tensor/handle.rs

use core::marker::PhantomData;
use cubecl_core::ir::StorageType;
use cubecl_core::tensor_line_size_parallel;
use cubecl_core::{Runtime, server};
use cubecl_core::{calculate_cube_count_elemwise, server::Allocation};
use cubecl_core::{prelude::*, server::CopyDescriptor};
use cubecl_runtime::server::Handle;

/// Tensor representation containing a [server handle](Handle) as well as basic tensor metadata.
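///
/// Sketch (not from the original source): wrapping an existing allocation, assuming a
/// `handle` to a buffer of 24 elements and a matching `dtype` are already in scope for
/// some runtime `R`:
///
/// ```ignore
/// // Row-major layout: shape [2, 3, 4] gives strides [12, 4, 1].
/// let tensor = TensorHandle::<R>::new_contiguous(vec![2, 3, 4], handle, dtype);
/// ```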
pub struct TensorHandle<R>
where
    R: Runtime,
{
    /// The buffer where the data are stored.
    pub handle: server::Handle,
    /// The shape of the tensor.
    pub shape: Vec<usize>,
    /// The strides of the tensor.
    pub strides: Vec<usize>,
    /// The type used as storage.
    pub dtype: StorageType,
    runtime: PhantomData<R>,
}

impl<R> core::fmt::Debug for TensorHandle<R>
where
    R: Runtime,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!(
            "Tensor {{ shape: {:?}, strides: {:?}, dtype: {} }}",
            self.shape, self.strides, self.dtype,
        ))
    }
}

impl<R> Clone for TensorHandle<R>
where
    R: Runtime,
{
    fn clone(&self) -> Self {
        Self {
            handle: self.handle.clone(),
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            dtype: self.dtype,
            runtime: PhantomData,
        }
    }
}

impl<R> TensorHandle<R>
where
    R: Runtime,
{
    /// Create a new tensor.
    pub fn new(
        handle: server::Handle,
        shape: Vec<usize>,
        strides: Vec<usize>,
        storage: StorageType,
    ) -> Self {
        Self {
            handle,
            shape,
            strides,
            dtype: storage,
            runtime: PhantomData,
        }
    }

    /// Create a new tensor with an uninitialized buffer of the given shape.
    pub fn empty(
        client: &ComputeClient<R::Server>,
        shape: Vec<usize>,
        storage: StorageType,
    ) -> Self {
        let elem_size = storage.size();
        let Allocation { handle, strides } = client.empty_tensor(&shape, elem_size);

        Self::new(handle, shape, strides, storage)
    }

    /// Create a new owned tensor from a `TensorHandleRef`.
    pub fn from_ref(handle: &TensorHandleRef<'_, R>, storage: StorageType) -> Self {
        Self {
            handle: handle.handle.clone(),
            shape: handle.shape.to_vec(),
            strides: handle.strides.to_vec(),
            dtype: storage,
            runtime: PhantomData,
        }
    }

    /// Create a new tensor with a contiguous (row-major) memory layout computed from the shape.
    pub fn new_contiguous(shape: Vec<usize>, handle: Handle, storage: StorageType) -> Self {
        let strides = Self::contiguous_strides(&shape);

        Self {
            handle,
            shape,
            strides,
            dtype: storage,
            runtime: PhantomData,
        }
    }

    /// Check if the tensor is safe to mutate.
    pub fn can_mut(&self) -> bool {
        self.handle.can_mut()
    }

    /// Return a `TensorHandleRef` borrowing this tensor's handle, strides and shape.
    pub fn as_ref(&self) -> TensorHandleRef<'_, R> {
        unsafe {
            TensorHandleRef::from_raw_parts(
                &self.handle,
                &self.strides,
                &self.shape,
                self.dtype.size(),
            )
        }
    }

    /// Return this tensor as a `TensorArg` with the given vectorization factor (line size).
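    ///
    /// Sketch (not from the original source): passing the tensor to the launch of a
    /// hypothetical `my_kernel` cube function with a line size of 4, assuming `tensor` is
    /// a `TensorHandle<R>` and `client`, `cube_count` and `cube_dim` are in scope:
    ///
    /// ```ignore
    /// unsafe {
    ///     my_kernel::launch_unchecked::<f32, R>(&client, cube_count, cube_dim, tensor.as_arg(4));
    /// }
    /// ```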
    pub fn as_arg<'a>(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
        let handle: TensorHandleRef<'a, R> = self.as_ref();

        unsafe {
            TensorArg::from_raw_parts_and_size(
                handle.handle,
                handle.strides,
                handle.shape,
                vectorisation,
                handle.elem_size,
            )
        }
    }

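    /// Describe this tensor for a copy operation, borrowing its shape and strides and
    /// binding its buffer.
    ///
    /// Sketch (not from the original source), assuming `tensor` is a `TensorHandle<R>`:
    ///
    /// ```ignore
    /// let desc = tensor.as_copy_descriptor();
    /// assert_eq!(desc.shape, tensor.shape.as_slice());
    /// assert_eq!(desc.elem_size, tensor.dtype.size());
    /// ```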
    pub fn as_copy_descriptor<'a>(&'a self) -> CopyDescriptor<'a> {
        CopyDescriptor {
            binding: self.handle.clone().binding(),
            shape: &self.shape,
            strides: &self.strides,
            elem_size: self.dtype.size(),
        }
    }

    /// Compute row-major (contiguous) strides for the given shape,
    /// e.g. shape `[2, 3, 4]` yields strides `[12, 4, 1]`.
    fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = Vec::with_capacity(shape.len());

        // Walk the shape from the innermost dimension outward, accumulating the number
        // of elements covered by each dimension.
        let mut current = 1;
        shape.iter().rev().for_each(|val| {
            strides.push(current);
            current *= val;
        });
        strides.reverse();
        strides
    }
}

impl<R> TensorHandle<R>
where
    R: Runtime,
{
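    /// Create a tensor of the given shape filled with zeros by allocating an empty tensor
    /// and launching a small element-wise kernel over its buffer.
    ///
    /// Sketch (not from the original source), assuming a `client` for the runtime `R` and
    /// a `dtype` describing the stored element type are already in scope:
    ///
    /// ```ignore
    /// let output = TensorHandle::<R>::zeros(&client, vec![32, 32], dtype);
    /// assert_eq!(output.shape, vec![32, 32]);
    /// ```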
    pub fn zeros(client: &ComputeClient<R::Server>, shape: Vec<usize>, dtype: StorageType) -> Self {
        let num_elements: usize = shape.iter().product();
        let rank = shape.len();
        let output = Self::empty(client, shape, dtype);

        // Choose a line size supported by the runtime for the innermost (contiguous) axis.
        let line_size = tensor_line_size_parallel(
            R::supported_line_sizes().iter().cloned(),
            &output.shape,
            &output.strides,
            rank - 1,
        );

        let cube_dim = CubeDim::default();
        let cube_count = calculate_cube_count_elemwise(num_elements / line_size as usize, cube_dim);
        // Zero the entire underlying buffer, which may be larger than `num_elements`.
        let array_len = output.handle.size() as usize / dtype.size();

        unsafe {
            init::zeros_array::launch_unchecked::<R>(
                client,
                cube_count,
                cube_dim,
                ArrayArg::from_raw_parts_and_size(
                    &output.handle,
                    array_len,
                    line_size,
                    dtype.size(),
                ),
                dtype,
            )
        };

        output
    }
}

pub(crate) mod init {
    use cubecl::prelude::*;
    use cubecl_core::{self as cubecl, ir::StorageType};

    /// Write a zeroed `Line` to every position of `output`, guarding against
    /// out-of-bounds positions.
    #[cube(launch_unchecked)]
    pub fn zeros_array<C: Numeric>(output: &mut Array<Line<C>>, #[define(C)] _elem: StorageType) {
        if ABSOLUTE_POS < output.len() {
            output[ABSOLUTE_POS] = Line::cast_from(C::from_int(0));
        }
    }
}