// cubecl_std/tensor/handle.rs
1use core::marker::PhantomData;
2use cubecl_core::{Runtime, server, zspace::strides};
3use cubecl_core::{calculate_cube_count_elemwise, server::Allocation};
4use cubecl_core::{ir::StorageType, zspace::metadata::Metadata};
5use cubecl_core::{prelude::*, server::CopyDescriptor};
6use cubecl_core::{
7    tensor_line_size_parallel,
8    zspace::{Shape, Strides},
9};
10use cubecl_runtime::server::Handle;
11
/// Tensor representation containing a [server handle](Handle) as well as basic tensor metadata.
pub struct TensorHandle<R>
where
    R: Runtime,
{
    /// The buffer where the data are stored.
    pub handle: server::Handle,
    /// Shape and strides describing the memory layout of the buffer.
    pub metadata: Box<Metadata>,
    /// The type used as storage.
    pub dtype: StorageType,
    /// Ties the handle to a runtime type without storing a runtime value.
    runtime: PhantomData<R>,
}
24
25impl<R> core::fmt::Debug for TensorHandle<R>
26where
27    R: Runtime,
28{
29    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
30        f.write_fmt(format_args!(
31            "Tensor {{ shape: {:?}, strides: {:?}, dtype: {}}}",
32            self.shape(),
33            self.strides(),
34            self.dtype,
35        ))
36    }
37}
38
39impl<R> Clone for TensorHandle<R>
40where
41    R: Runtime,
42{
43    fn clone(&self) -> Self {
44        Self {
45            handle: self.handle.clone(),
46            metadata: self.metadata.clone(),
47            dtype: self.dtype,
48            runtime: PhantomData,
49        }
50    }
51}
52
53impl<R> TensorHandle<R>
54where
55    R: Runtime,
56{
57    /// Create a new tensor.
58    pub fn new(
59        handle: server::Handle,
60        shape: impl Into<Shape>,
61        strides: impl Into<Strides>,
62        storage: StorageType,
63    ) -> Self {
64        Self {
65            handle,
66            metadata: Box::new(Metadata::new(shape, strides)),
67            dtype: storage,
68            runtime: PhantomData,
69        }
70    }
71
72    pub fn empty(client: &ComputeClient<R>, shape: impl Into<Shape>, storage: StorageType) -> Self {
73        let shape = shape.into();
74        let elem_size = storage.size();
75        let Allocation { handle, strides } = client.empty_tensor(&shape, elem_size);
76
77        Self::new(handle, shape, strides, storage)
78    }
79
80    /// Create a new tensor.
81    pub fn from_ref(handle: &TensorHandleRef<'_, R>, storage: StorageType) -> Self {
82        Self {
83            handle: handle.handle.clone(),
84            metadata: Box::new(Metadata::new(handle.shape, handle.strides)),
85            dtype: storage,
86            runtime: PhantomData,
87        }
88    }
89
90    /// Create a new tensor with a contiguous memory layout.
91    pub fn new_contiguous(shape: impl Into<Shape>, handle: Handle, storage: StorageType) -> Self {
92        let shape = shape.into();
93        let strides = Self::contiguous_strides(&shape);
94
95        Self {
96            handle,
97            metadata: Box::new(Metadata::new(shape, strides)),
98            dtype: storage,
99            runtime: PhantomData,
100        }
101    }
102
103    /// Check if the tensor is safe to mutate.
104    pub fn can_mut(&self) -> bool {
105        self.handle.can_mut()
106    }
107
108    pub fn as_ref(&self) -> TensorHandleRef<'_, R> {
109        unsafe {
110            TensorHandleRef::from_raw_parts(
111                &self.handle,
112                self.strides(),
113                self.shape(),
114                self.dtype.size(),
115            )
116        }
117    }
118
119    /// Return the reference to a tensor argument.
120    pub fn as_arg<'a>(&'a self, line_size: LineSize) -> TensorArg<'a, R> {
121        let handle: TensorHandleRef<'a, R> = self.as_ref();
122
123        unsafe {
124            TensorArg::from_raw_parts_and_size(
125                handle.handle,
126                handle.strides,
127                handle.shape,
128                line_size,
129                handle.elem_size,
130            )
131        }
132    }
133
134    pub fn as_copy_descriptor<'a>(&'a self) -> CopyDescriptor<'a> {
135        CopyDescriptor {
136            binding: self.handle.clone().binding(),
137            shape: self.shape(),
138            strides: self.strides(),
139            elem_size: self.dtype.size(),
140        }
141    }
142
143    pub fn required_address_type(&self) -> AddressType {
144        let len = self.handle.size() / self.dtype.size() as u64;
145        AddressType::from_len(len as usize)
146    }
147
148    pub fn shape(&self) -> &Shape {
149        self.metadata.shape()
150    }
151
152    pub fn strides(&self) -> &Strides {
153        self.metadata.strides()
154    }
155
156    fn contiguous_strides(shape: &[usize]) -> Strides {
157        let mut strides = strides![1; shape.len()];
158
159        let mut current = 1;
160        shape.iter().rev().enumerate().for_each(|(i, val)| {
161            strides[i] = current;
162            current *= val;
163        });
164        strides.reverse();
165        strides
166    }
167}
168impl<R> TensorHandle<R>
169where
170    R: Runtime,
171{
172    pub fn zeros(client: &ComputeClient<R>, shape: impl Into<Shape>, dtype: StorageType) -> Self {
173        let shape = shape.into();
174        let num_elements: usize = shape.iter().product();
175        let rank = shape.len();
176        let output = Self::empty(client, shape, dtype);
177
178        let line_size = tensor_line_size_parallel(
179            client.io_optimized_line_sizes(dtype.size()),
180            output.shape(),
181            output.strides(),
182            rank - 1,
183        );
184
185        let working_units = num_elements / line_size as usize;
186        let cube_dim = CubeDim::new(client, working_units);
187        let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
188        let array_len = output.handle.size() as usize / dtype.size();
189
190        unsafe {
191            init::zeros_array::launch_unchecked(
192                client,
193                cube_count,
194                cube_dim,
195                output.required_address_type(),
196                ArrayArg::from_raw_parts_and_size(
197                    &output.handle,
198                    array_len,
199                    line_size,
200                    dtype.size(),
201                ),
202                dtype,
203            )
204            .expect("Should be able to launch the kernel all the time")
205        };
206
207        output
208    }
209}
210
/// Kernels used to initialize tensor memory.
pub(crate) mod init {
    use cubecl::prelude::*;
    use cubecl_core::{self as cubecl, ir::StorageType};

    /// Writes a zero line to every position of `output`, one line per unit.
    ///
    /// `_elem` is not read in the kernel body; it only carries the runtime
    /// storage type used to instantiate the generic `C` (via `#[define(C)]`).
    #[cube(launch_unchecked, address_type = "dynamic")]
    pub fn zeros_array<C: Numeric>(output: &mut Array<Line<C>>, #[define(C)] _elem: StorageType) {
        // Bounds guard: the launch may provision more units than lines.
        if ABSOLUTE_POS < output.len() {
            output[ABSOLUTE_POS] = Line::cast_from(C::from_int(0));
        }
    }
}